Skip to main content

magi_core/
agent.rs

1// Author: Julian Bolivar
2// Version: 1.0.0
3// Date: 2026-04-05
4
5use crate::error::{MagiError, ProviderError};
6use crate::provider::{CompletionConfig, LlmProvider};
7use crate::schema::{AgentName, Mode};
8use std::collections::BTreeMap;
9use std::path::Path;
10use std::sync::Arc;
11
12use crate::prompts;
13
/// The three analysis modes, in the order prompt overrides are registered.
///
/// `AgentFactory::with_custom_prompt` walks this array to apply a single
/// prompt to each listed mode.
// NOTE(review): assumed to cover every `Mode` variant — confirm against
// `crate::schema` and keep in sync if a variant is added.
const ALL_MODES: [Mode; 3] = [Mode::CodeReview, Mode::Design, Mode::Analysis];
16
/// An autonomous MAGI agent with its own identity, system prompt, and LLM provider.
///
/// Each agent combines an [`AgentName`] identity, a [`Mode`], the system prompt
/// text sent with every request, and an [`LlmProvider`] backend. The agent
/// delegates LLM communication to its provider via [`execute`](Agent::execute).
pub struct Agent {
    /// Which of the MAGI agents this is.
    name: AgentName,
    /// Analysis mode the agent was created for.
    mode: Mode,
    /// System prompt handed to the provider on every [`execute`](Agent::execute) call.
    system_prompt: String,
    /// Backend performing the completion; held behind an `Arc` so that
    /// several agents can share one provider instance.
    provider: Arc<dyn LlmProvider>,
}
28
29impl Agent {
30    /// Creates an agent with an auto-generated system prompt for the given name and mode.
31    ///
32    /// The prompt is selected from compiled-in markdown files via `include_str!`.
33    ///
34    /// # Parameters
35    /// - `name`: Which MAGI agent (Melchior, Balthasar, Caspar).
36    /// - `mode`: Analysis mode (CodeReview, Design, Analysis).
37    /// - `provider`: The LLM backend for this agent.
38    pub fn new(name: AgentName, mode: Mode, provider: Arc<dyn LlmProvider>) -> Self {
39        let prompt = match name {
40            AgentName::Melchior => prompts::melchior::prompt_for_mode(&mode),
41            AgentName::Balthasar => prompts::balthasar::prompt_for_mode(&mode),
42            AgentName::Caspar => prompts::caspar::prompt_for_mode(&mode),
43        };
44        Self {
45            name,
46            mode,
47            system_prompt: prompt.to_string(),
48            provider,
49        }
50    }
51
52    /// Creates an agent with a custom system prompt, bypassing the compiled-in defaults.
53    ///
54    /// # Parameters
55    /// - `name`: Which MAGI agent.
56    /// - `mode`: Analysis mode.
57    /// - `provider`: The LLM backend.
58    /// - `prompt`: Custom system prompt string.
59    pub fn with_custom_prompt(
60        name: AgentName,
61        mode: Mode,
62        provider: Arc<dyn LlmProvider>,
63        prompt: String,
64    ) -> Self {
65        Self {
66            name,
67            mode,
68            system_prompt: prompt,
69            provider,
70        }
71    }
72
73    /// Creates an agent by loading the system prompt from a filesystem path.
74    ///
75    /// Returns [`MagiError::Io`] if the file cannot be read.
76    ///
77    /// # Parameters
78    /// - `name`: Which MAGI agent.
79    /// - `mode`: Analysis mode.
80    /// - `provider`: The LLM backend.
81    /// - `path`: Path to the prompt file.
82    ///
83    /// # Errors
84    /// Returns `MagiError::Io` if the file does not exist or cannot be read.
85    pub fn from_file(
86        name: AgentName,
87        mode: Mode,
88        provider: Arc<dyn LlmProvider>,
89        path: &Path,
90    ) -> Result<Self, MagiError> {
91        let prompt = std::fs::read_to_string(path)?;
92        Ok(Self {
93            name,
94            mode,
95            system_prompt: prompt,
96            provider,
97        })
98    }
99
100    /// Executes the agent by sending the user prompt to the LLM provider.
101    ///
102    /// Delegates to [`LlmProvider::complete`] with this agent's system prompt.
103    /// Returns the raw LLM response string — parsing is the orchestrator's responsibility.
104    ///
105    /// # Parameters
106    /// - `user_prompt`: The user's input content.
107    /// - `config`: Completion parameters (max_tokens, temperature).
108    ///
109    /// # Errors
110    /// Returns `ProviderError` on LLM communication failure.
111    pub async fn execute(
112        &self,
113        user_prompt: &str,
114        config: &CompletionConfig,
115    ) -> Result<String, ProviderError> {
116        self.provider
117            .complete(&self.system_prompt, user_prompt, config)
118            .await
119    }
120
    /// Returns this agent's [`AgentName`] identity.
    ///
    /// The value is returned by copy; the agent is not modified.
    pub fn name(&self) -> AgentName {
        self.name
    }
125
    /// Returns the analysis [`Mode`] this agent was constructed with.
    ///
    /// The value is returned by copy; the agent is not modified.
    pub fn mode(&self) -> Mode {
        self.mode
    }
130
131    /// Returns the agent's system prompt.
132    pub fn system_prompt(&self) -> &str {
133        &self.system_prompt
134    }
135
136    /// Returns the provider's name (e.g., "claude", "openai").
137    pub fn provider_name(&self) -> &str {
138        self.provider.name()
139    }
140
141    /// Returns the provider's model identifier.
142    pub fn provider_model(&self) -> &str {
143        self.provider.model()
144    }
145
    /// Returns the agent's display name, as defined by `AgentName::display_name`
    /// (e.g., "Balthasar" for [`AgentName::Balthasar`]).
    pub fn display_name(&self) -> &str {
        self.name.display_name()
    }
150
    /// Returns the agent's analytical role title, as defined by `AgentName::title`
    /// (e.g., "Pragmatist" for [`AgentName::Balthasar`]).
    pub fn title(&self) -> &str {
        self.name.title()
    }
155}
156
/// Factory for creating sets of three MAGI agents with provider and prompt overrides.
///
/// Supports a default provider shared by all agents, per-agent provider overrides,
/// and custom prompt overrides. Always creates agents in order:
/// `[Melchior, Balthasar, Caspar]`.
pub struct AgentFactory {
    /// Provider used for any agent that has no entry in `agent_providers`.
    default_provider: Arc<dyn LlmProvider>,
    /// Per-agent provider overrides, keyed by agent.
    agent_providers: BTreeMap<AgentName, Arc<dyn LlmProvider>>,
    /// Prompt overrides keyed by `(agent, mode)`; a missing key means the
    /// agent is built with its default prompt for that mode.
    custom_prompts: BTreeMap<(AgentName, Mode), String>,
}
167
168impl AgentFactory {
169    /// Creates a factory with a default provider shared by all three agents.
170    ///
171    /// # Parameters
172    /// - `default_provider`: The LLM provider used for agents without a specific override.
173    pub fn new(default_provider: Arc<dyn LlmProvider>) -> Self {
174        Self {
175            default_provider,
176            agent_providers: BTreeMap::new(),
177            custom_prompts: BTreeMap::new(),
178        }
179    }
180
181    /// Registers a provider override for a specific agent.
182    ///
183    /// # Parameters
184    /// - `name`: Which agent to override.
185    /// - `provider`: The provider to use for this agent.
186    pub fn with_provider(mut self, name: AgentName, provider: Arc<dyn LlmProvider>) -> Self {
187        self.agent_providers.insert(name, provider);
188        self
189    }
190
191    /// Registers a custom prompt override for a specific agent across all modes.
192    ///
193    /// # Parameters
194    /// - `name`: Which agent to override.
195    /// - `prompt`: The custom system prompt applied to every analysis mode.
196    pub fn with_custom_prompt(mut self, name: AgentName, prompt: String) -> Self {
197        for mode in ALL_MODES {
198            self.custom_prompts.insert((name, mode), prompt.clone());
199        }
200        self
201    }
202
203    /// Loads custom prompts from a directory of markdown files.
204    ///
205    /// Expected filenames: `{agent}_{mode}.md` (e.g., `melchior_code_review.md`).
206    /// Only loads files that exist; missing files use the default compiled-in prompts.
207    /// Returns [`MagiError::Io`] if the directory itself does not exist.
208    ///
209    /// # Errors
210    /// Returns `MagiError::Io` if the directory does not exist or cannot be read.
211    pub fn from_directory(mut self, dir: &Path) -> Result<Self, MagiError> {
212        // Verify the directory exists
213        std::fs::read_dir(dir)?;
214
215        let agents = ["melchior", "balthasar", "caspar"];
216        let modes = ["code_review", "design", "analysis"];
217
218        for agent_str in &agents {
219            for mode_str in &modes {
220                let filename = format!("{agent_str}_{mode_str}.md");
221                let path = dir.join(&filename);
222                if path.exists() {
223                    let content = std::fs::read_to_string(&path)?;
224                    let agent_name = match *agent_str {
225                        "melchior" => AgentName::Melchior,
226                        "balthasar" => AgentName::Balthasar,
227                        "caspar" => AgentName::Caspar,
228                        _ => unreachable!(),
229                    };
230                    let mode = match *mode_str {
231                        "code_review" => Mode::CodeReview,
232                        "design" => Mode::Design,
233                        "analysis" => Mode::Analysis,
234                        _ => unreachable!(),
235                    };
236                    self.custom_prompts.insert((agent_name, mode), content);
237                }
238            }
239        }
240
241        Ok(self)
242    }
243
244    /// Creates exactly three agents for the given mode.
245    ///
246    /// Returns agents in fixed order: `[Melchior, Balthasar, Caspar]`.
247    /// Each agent uses its specific provider override or the default provider,
248    /// and its custom prompt override or the compiled-in default prompt.
249    ///
250    /// # Parameters
251    /// - `mode`: The analysis mode for all three agents.
252    pub fn create_agents(&self, mode: Mode) -> Vec<Agent> {
253        let names = [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar];
254
255        names
256            .iter()
257            .map(|&name| {
258                let provider = self
259                    .agent_providers
260                    .get(&name)
261                    .cloned()
262                    .unwrap_or_else(|| self.default_provider.clone());
263
264                if let Some(prompt) = self.custom_prompts.get(&(name, mode)) {
265                    Agent::with_custom_prompt(name, mode, provider, prompt.clone())
266                } else {
267                    Agent::new(name, mode, provider)
268                }
269            })
270            .collect()
271    }
272}
273
274#[cfg(test)]
275mod tests {
276    use super::*;
277    use crate::schema::*;
278    use std::sync::Arc;
279    use std::sync::atomic::{AtomicUsize, Ordering};
280
    /// Mock LlmProvider that tracks call count and returns a configurable response.
    struct MockProvider {
        /// Value returned from `LlmProvider::name`.
        name: String,
        /// Value returned from `LlmProvider::model`.
        model: String,
        /// Canned text returned by every `complete` call.
        response: String,
        /// Number of `complete` calls made so far; atomic so it can be
        /// bumped through `&self`.
        call_count: AtomicUsize,
    }
288
289    impl MockProvider {
290        fn new(name: &str, model: &str, response: &str) -> Self {
291            Self {
292                name: name.to_string(),
293                model: model.to_string(),
294                response: response.to_string(),
295                call_count: AtomicUsize::new(0),
296            }
297        }
298
299        fn calls(&self) -> usize {
300            self.call_count.load(Ordering::SeqCst)
301        }
302    }
303
304    #[async_trait::async_trait]
305    impl LlmProvider for MockProvider {
306        async fn complete(
307            &self,
308            _system_prompt: &str,
309            _user_prompt: &str,
310            _config: &CompletionConfig,
311        ) -> Result<String, ProviderError> {
312            self.call_count.fetch_add(1, Ordering::SeqCst);
313            Ok(self.response.clone())
314        }
315
316        fn name(&self) -> &str {
317            &self.name
318        }
319
320        fn model(&self) -> &str {
321            &self.model
322        }
323    }
324
325    // -- BDD Scenario 26: agents with different providers --
326
327    /// Each agent uses its own provider (verify mock receives exactly 1 call).
328    #[tokio::test]
329    async fn test_each_agent_uses_its_own_provider() {
330        let p1 = Arc::new(MockProvider::new("p1", "m1", "r1"));
331        let p2 = Arc::new(MockProvider::new("p2", "m2", "r2"));
332        let p3 = Arc::new(MockProvider::new("p3", "m3", "r3"));
333
334        let factory = AgentFactory::new(p1.clone() as Arc<dyn LlmProvider>)
335            .with_provider(AgentName::Melchior, p1.clone() as Arc<dyn LlmProvider>)
336            .with_provider(AgentName::Balthasar, p2.clone() as Arc<dyn LlmProvider>)
337            .with_provider(AgentName::Caspar, p3.clone() as Arc<dyn LlmProvider>);
338
339        let agents = factory.create_agents(Mode::CodeReview);
340        let config = CompletionConfig::default();
341
342        for agent in &agents {
343            let _ = agent.execute("test input", &config).await;
344        }
345
346        assert_eq!(p1.calls(), 1, "p1 should receive exactly 1 call");
347        assert_eq!(p2.calls(), 1, "p2 should receive exactly 1 call");
348        assert_eq!(p3.calls(), 1, "p3 should receive exactly 1 call");
349    }
350
351    // -- BDD Scenario 27: factory with default and override --
352
353    /// Factory uses default provider for unoverridden agents, override for Caspar.
354    #[tokio::test]
355    async fn test_factory_default_and_override_providers() {
356        let default = Arc::new(MockProvider::new("default", "m1", "r1"));
357        let caspar_override = Arc::new(MockProvider::new("caspar-special", "m2", "r2"));
358
359        let factory = AgentFactory::new(default.clone() as Arc<dyn LlmProvider>).with_provider(
360            AgentName::Caspar,
361            caspar_override.clone() as Arc<dyn LlmProvider>,
362        );
363
364        let agents = factory.create_agents(Mode::CodeReview);
365
366        let melchior = agents
367            .iter()
368            .find(|a| a.name() == AgentName::Melchior)
369            .unwrap();
370        let balthasar = agents
371            .iter()
372            .find(|a| a.name() == AgentName::Balthasar)
373            .unwrap();
374        let caspar = agents
375            .iter()
376            .find(|a| a.name() == AgentName::Caspar)
377            .unwrap();
378
379        assert_eq!(melchior.provider_name(), "default");
380        assert_eq!(balthasar.provider_name(), "default");
381        assert_eq!(caspar.provider_name(), "caspar-special");
382    }
383
384    // -- BDD Scenario 30: modes generate different prompts --
385
386    /// CodeReview, Design, Analysis produce distinct system prompts per agent.
387    #[test]
388    fn test_different_modes_produce_distinct_prompts() {
389        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
390
391        let cr = Agent::new(AgentName::Melchior, Mode::CodeReview, provider.clone());
392        let design = Agent::new(AgentName::Melchior, Mode::Design, provider.clone());
393        let analysis = Agent::new(AgentName::Melchior, Mode::Analysis, provider.clone());
394
395        assert_ne!(cr.system_prompt(), design.system_prompt());
396        assert_ne!(cr.system_prompt(), analysis.system_prompt());
397        assert_ne!(design.system_prompt(), analysis.system_prompt());
398    }
399
400    // -- BDD Scenario 31: from_directory with nonexistent path --
401
402    /// from_directory returns MagiError::Io for nonexistent directory.
403    #[test]
404    fn test_from_directory_returns_io_error_for_nonexistent_path() {
405        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
406        let factory = AgentFactory::new(provider);
407        let result = factory.from_directory(Path::new("/nonexistent/path"));
408        assert!(matches!(result, Err(MagiError::Io(_))));
409    }
410
411    // -- Agent construction and accessors --
412
413    /// Agent::new generates system prompt from include_str! prompts.
414    #[test]
415    fn test_agent_new_generates_system_prompt() {
416        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
417        let agent = Agent::new(AgentName::Melchior, Mode::CodeReview, provider);
418        assert!(!agent.system_prompt().is_empty());
419    }
420
421    /// Agent::with_custom_prompt uses provided prompt.
422    #[test]
423    fn test_agent_with_custom_prompt_uses_provided_prompt() {
424        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
425        let agent = Agent::with_custom_prompt(
426            AgentName::Melchior,
427            Mode::CodeReview,
428            provider,
429            "Custom prompt".to_string(),
430        );
431        assert_eq!(agent.system_prompt(), "Custom prompt");
432    }
433
434    /// Agent::execute delegates to provider.complete with system prompt.
435    #[tokio::test]
436    async fn test_agent_execute_delegates_to_provider() {
437        let provider = Arc::new(MockProvider::new("mock", "m1", "response text"));
438        let provider_arc = provider.clone() as Arc<dyn LlmProvider>;
439        let agent = Agent::new(AgentName::Melchior, Mode::CodeReview, provider_arc);
440        let config = CompletionConfig::default();
441
442        let result = agent.execute("user input", &config).await;
443        assert_eq!(result.unwrap(), "response text");
444        assert_eq!(provider.calls(), 1);
445    }
446
447    /// Agent accessors return correct values.
448    #[test]
449    fn test_agent_accessors() {
450        let provider = Arc::new(MockProvider::new("test-provider", "test-model", "r"));
451        let provider_arc = provider.clone() as Arc<dyn LlmProvider>;
452        let agent = Agent::new(AgentName::Balthasar, Mode::Design, provider_arc);
453
454        assert_eq!(agent.name(), AgentName::Balthasar);
455        assert_eq!(agent.mode(), Mode::Design);
456        assert_eq!(agent.provider_name(), "test-provider");
457        assert_eq!(agent.provider_model(), "test-model");
458        assert_eq!(agent.display_name(), "Balthasar");
459        assert_eq!(agent.title(), "Pragmatist");
460    }
461
462    // -- AgentFactory tests --
463
464    /// AgentFactory::new creates 3 agents sharing default provider.
465    #[test]
466    fn test_agent_factory_creates_three_agents() {
467        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
468        let factory = AgentFactory::new(provider);
469        let agents = factory.create_agents(Mode::CodeReview);
470
471        assert_eq!(agents.len(), 3);
472
473        let names: Vec<AgentName> = agents.iter().map(|a| a.name()).collect();
474        assert!(names.contains(&AgentName::Melchior));
475        assert!(names.contains(&AgentName::Balthasar));
476        assert!(names.contains(&AgentName::Caspar));
477    }
478
479    /// AgentFactory::create_agents returns agents in order [Melchior, Balthasar, Caspar].
480    #[test]
481    fn test_agent_factory_creates_agents_in_order() {
482        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
483        let factory = AgentFactory::new(provider);
484        let agents = factory.create_agents(Mode::CodeReview);
485
486        assert_eq!(agents[0].name(), AgentName::Melchior);
487        assert_eq!(agents[1].name(), AgentName::Balthasar);
488        assert_eq!(agents[2].name(), AgentName::Caspar);
489    }
490
491    /// AgentFactory::with_provider overrides provider for specific agent.
492    #[test]
493    fn test_agent_factory_with_provider_overrides_specific_agent() {
494        let default = Arc::new(MockProvider::new("default", "m1", "r1")) as Arc<dyn LlmProvider>;
495        let override_p =
496            Arc::new(MockProvider::new("override", "m2", "r2")) as Arc<dyn LlmProvider>;
497
498        let factory = AgentFactory::new(default).with_provider(AgentName::Caspar, override_p);
499        let agents = factory.create_agents(Mode::CodeReview);
500
501        let caspar = agents
502            .iter()
503            .find(|a| a.name() == AgentName::Caspar)
504            .unwrap();
505        assert_eq!(caspar.provider_name(), "override");
506
507        let melchior = agents
508            .iter()
509            .find(|a| a.name() == AgentName::Melchior)
510            .unwrap();
511        assert_eq!(melchior.provider_name(), "default");
512    }
513
514    /// AgentFactory::with_custom_prompt overrides prompt for specific agent.
515    #[test]
516    fn test_agent_factory_with_custom_prompt_overrides_prompt() {
517        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
518
519        let factory = AgentFactory::new(provider)
520            .with_custom_prompt(AgentName::Melchior, "My custom prompt".to_string());
521        let agents = factory.create_agents(Mode::CodeReview);
522
523        let melchior = agents
524            .iter()
525            .find(|a| a.name() == AgentName::Melchior)
526            .unwrap();
527        assert_eq!(melchior.system_prompt(), "My custom prompt");
528
529        let balthasar = agents
530            .iter()
531            .find(|a| a.name() == AgentName::Balthasar)
532            .unwrap();
533        assert_ne!(balthasar.system_prompt(), "My custom prompt");
534        assert!(!balthasar.system_prompt().is_empty());
535    }
536
537    /// AgentFactory::create_agents returns exactly 3 agents for all modes.
538    #[test]
539    fn test_agent_factory_creates_three_agents_for_all_modes() {
540        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
541        let factory = AgentFactory::new(provider);
542
543        for mode in [Mode::CodeReview, Mode::Design, Mode::Analysis] {
544            let agents = factory.create_agents(mode);
545            assert_eq!(agents.len(), 3, "Expected 3 agents for mode {mode}");
546        }
547    }
548
549    /// Default prompts contain JSON schema instructions and English constraint.
550    #[test]
551    fn test_default_prompts_contain_json_and_english_constraints() {
552        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
553
554        for name in [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar] {
555            for mode in [Mode::CodeReview, Mode::Design, Mode::Analysis] {
556                let agent = Agent::new(name, mode, provider.clone());
557                let prompt = agent.system_prompt();
558                assert!(
559                    prompt.contains("JSON"),
560                    "{name:?}/{mode:?} prompt should mention JSON"
561                );
562                assert!(
563                    prompt.contains("English"),
564                    "{name:?}/{mode:?} prompt should mention English"
565                );
566            }
567        }
568    }
569
570    /// from_file with nonexistent path returns MagiError::Io.
571    #[test]
572    fn test_from_file_returns_io_error_for_nonexistent_path() {
573        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
574        let result = Agent::from_file(
575            AgentName::Melchior,
576            Mode::CodeReview,
577            provider,
578            Path::new("/nonexistent/prompt.md"),
579        );
580        assert!(matches!(result, Err(MagiError::Io(_))));
581    }
582}