1use crate::error::{MagiError, ProviderError};
6use crate::provider::{CompletionConfig, LlmProvider};
7use crate::schema::{AgentName, Mode};
8use std::collections::BTreeMap;
9use std::path::Path;
10use std::sync::Arc;
11
/// The modes for which `AgentFactory::with_custom_prompt` registers a single
/// prompt, so that one call covers each of them.
// NOTE(review): presumably this lists every `Mode` variant — confirm whenever a
// variant is added to `Mode`.
const ALL_MODES: [Mode; 3] = [Mode::CodeReview, Mode::Design, Mode::Analysis];
14
/// A single agent: an identity, the system prompt it runs with, and the LLM
/// provider that answers on its behalf.
pub struct Agent {
    /// Which agent this is.
    name: AgentName,
    /// System prompt sent to the provider on every `execute` call.
    system_prompt: String,
    /// Shared handle to the backend that performs completions.
    provider: Arc<dyn LlmProvider>,
}
25
26impl Agent {
27 pub fn new(name: AgentName, provider: Arc<dyn LlmProvider>) -> Self {
37 let prompt = crate::prompts::embedded_prompt_for(name);
38 Self {
39 name,
40 system_prompt: prompt.to_string(),
41 provider,
42 }
43 }
44
45 pub fn with_custom_prompt(
52 name: AgentName,
53 provider: Arc<dyn LlmProvider>,
54 prompt: String,
55 ) -> Self {
56 Self {
57 name,
58 system_prompt: prompt,
59 provider,
60 }
61 }
62
63 pub fn from_file(
75 name: AgentName,
76 provider: Arc<dyn LlmProvider>,
77 path: &Path,
78 ) -> Result<Self, MagiError> {
79 let prompt = std::fs::read_to_string(path)?;
80 Ok(Self {
81 name,
82 system_prompt: prompt,
83 provider,
84 })
85 }
86
87 pub async fn execute(
99 &self,
100 user_prompt: &str,
101 config: &CompletionConfig,
102 ) -> Result<String, ProviderError> {
103 self.provider
104 .complete(&self.system_prompt, user_prompt, config)
105 .await
106 }
107
108 pub fn name(&self) -> AgentName {
110 self.name
111 }
112
113 pub fn system_prompt(&self) -> &str {
115 &self.system_prompt
116 }
117
118 pub fn provider_name(&self) -> &str {
120 self.provider.name()
121 }
122
123 pub fn provider_model(&self) -> &str {
125 self.provider.model()
126 }
127
128 pub fn display_name(&self) -> &str {
130 self.name.display_name()
131 }
132
133 pub fn title(&self) -> &str {
135 self.name.title()
136 }
137}
138
/// Builder-style factory that produces the Melchior / Balthasar / Caspar
/// agents, with optional per-agent providers and custom prompts.
pub struct AgentFactory {
    /// Provider used for any agent that has no entry in `agent_providers`.
    default_provider: Arc<dyn LlmProvider>,
    /// Per-agent provider overrides.
    agent_providers: BTreeMap<AgentName, Arc<dyn LlmProvider>>,
    /// Custom system prompts keyed by `(agent, mode)`.
    custom_prompts: BTreeMap<(AgentName, Mode), String>,
}
149
150impl AgentFactory {
151 pub fn new(default_provider: Arc<dyn LlmProvider>) -> Self {
156 Self {
157 default_provider,
158 agent_providers: BTreeMap::new(),
159 custom_prompts: BTreeMap::new(),
160 }
161 }
162
163 pub fn with_provider(mut self, name: AgentName, provider: Arc<dyn LlmProvider>) -> Self {
169 self.agent_providers.insert(name, provider);
170 self
171 }
172
173 pub fn with_custom_prompt(mut self, name: AgentName, prompt: String) -> Self {
179 for mode in ALL_MODES {
180 self.custom_prompts.insert((name, mode), prompt.clone());
181 }
182 self
183 }
184
185 pub fn from_directory(mut self, dir: &Path) -> Result<Self, MagiError> {
194 std::fs::read_dir(dir)?;
196
197 let agents = ["melchior", "balthasar", "caspar"];
198 let modes = ["code_review", "design", "analysis"];
199
200 for agent_str in &agents {
201 for mode_str in &modes {
202 let filename = format!("{agent_str}_{mode_str}.md");
203 let path = dir.join(&filename);
204 if path.exists() {
205 let content = std::fs::read_to_string(&path)?;
206 let agent_name = match *agent_str {
207 "melchior" => AgentName::Melchior,
208 "balthasar" => AgentName::Balthasar,
209 "caspar" => AgentName::Caspar,
210 _ => unreachable!(),
211 };
212 let mode = match *mode_str {
213 "code_review" => Mode::CodeReview,
214 "design" => Mode::Design,
215 "analysis" => Mode::Analysis,
216 _ => unreachable!(),
217 };
218 self.custom_prompts.insert((agent_name, mode), content);
219 }
220 }
221 }
222
223 Ok(self)
224 }
225
226 pub(crate) fn create_agents_with_prompts(
239 &self,
240 mode: Mode,
241 overrides: &std::collections::BTreeMap<(AgentName, Option<Mode>), String>,
242 ) -> Vec<Agent> {
243 let names = [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar];
244 names
245 .iter()
246 .map(|&name| {
247 let provider = self
248 .agent_providers
249 .get(&name)
250 .cloned()
251 .unwrap_or_else(|| self.default_provider.clone());
252 let prompt = crate::prompts::lookup_prompt(name, mode, overrides).to_string();
253 Agent::with_custom_prompt(name, provider, prompt)
254 })
255 .collect()
256 }
257
258 #[deprecated(
275 since = "0.3.0",
276 note = "create_agents does NOT apply overrides set via \
277 MagiBuilder::with_custom_prompt_for_mode / with_custom_prompt_all_modes. \
278 If you need overrides, call create_agents_with_prompts(mode, &overrides) directly, \
279 or (preferred) use MagiBuilder::build() which wires overrides automatically. \
280 See docs/migration-v0.3.md §3 for the correct upgrade path."
281 )]
282 pub fn create_agents(&self, mode: Mode) -> Vec<Agent> {
283 let names = [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar];
284
285 names
286 .iter()
287 .map(|&name| {
288 let provider = self
289 .agent_providers
290 .get(&name)
291 .cloned()
292 .unwrap_or_else(|| self.default_provider.clone());
293
294 if let Some(prompt) = self.custom_prompts.get(&(name, mode)) {
295 Agent::with_custom_prompt(name, provider, prompt.clone())
296 } else {
297 Agent::new(name, provider)
298 }
299 })
300 .collect()
301 }
302
303 pub(crate) fn custom_prompts(&self) -> &BTreeMap<(AgentName, Mode), String> {
310 &self.custom_prompts
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that returns a canned response and counts `complete` calls.
    struct MockProvider {
        name: String,
        model: String,
        response: String,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &str, model: &str, response: &str) -> Self {
            Self {
                name: name.to_owned(),
                model: model.to_owned(),
                response: response.to_owned(),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of `complete` calls observed so far.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    impl Default for MockProvider {
        fn default() -> Self {
            Self::new("mock", "model", "response")
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _config: &CompletionConfig,
        ) -> Result<String, ProviderError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(self.response.clone())
        }

        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn model(&self) -> &str {
            self.model.as_str()
        }
    }

    /// Builds a mock that stays concretely typed so its call counter can be inspected.
    fn shared_mock(name: &str, model: &str, response: &str) -> Arc<MockProvider> {
        Arc::new(MockProvider::new(name, model, response))
    }

    /// Hands out another handle to `provider` as a trait object.
    fn as_dyn(provider: &Arc<MockProvider>) -> Arc<dyn LlmProvider> {
        provider.clone() as Arc<dyn LlmProvider>
    }

    /// The plain `("mock", "m1", "r1")` provider most tests use.
    fn default_mock() -> Arc<dyn LlmProvider> {
        Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>
    }

    /// Looks up the agent called `name`, panicking if it is absent.
    fn find_agent(agents: &[Agent], name: AgentName) -> &Agent {
        agents.iter().find(|agent| agent.name() == name).unwrap()
    }

    /// Each agent wired to its own provider must call that provider exactly once.
    #[allow(deprecated)]
    #[tokio::test]
    async fn test_each_agent_uses_its_own_provider() {
        let p1 = shared_mock("p1", "m1", "r1");
        let p2 = shared_mock("p2", "m2", "r2");
        let p3 = shared_mock("p3", "m3", "r3");

        let factory = AgentFactory::new(as_dyn(&p1))
            .with_provider(AgentName::Melchior, as_dyn(&p1))
            .with_provider(AgentName::Balthasar, as_dyn(&p2))
            .with_provider(AgentName::Caspar, as_dyn(&p3));

        let agents = factory.create_agents(Mode::CodeReview);
        let config = CompletionConfig::default();

        for agent in agents.iter() {
            let _ = agent.execute("test input", &config).await;
        }

        assert_eq!(p1.calls(), 1, "p1 should receive exactly 1 call");
        assert_eq!(p2.calls(), 1, "p2 should receive exactly 1 call");
        assert_eq!(p3.calls(), 1, "p3 should receive exactly 1 call");
    }

    /// Agents without an override fall back to the default provider.
    #[allow(deprecated)]
    #[tokio::test]
    async fn test_factory_default_and_override_providers() {
        let default = shared_mock("default", "m1", "r1");
        let caspar_override = shared_mock("caspar-special", "m2", "r2");

        let factory = AgentFactory::new(as_dyn(&default))
            .with_provider(AgentName::Caspar, as_dyn(&caspar_override));

        let agents = factory.create_agents(Mode::CodeReview);

        let melchior = find_agent(&agents, AgentName::Melchior);
        let balthasar = find_agent(&agents, AgentName::Balthasar);
        let caspar = find_agent(&agents, AgentName::Caspar);

        assert_eq!(melchior.provider_name(), "default");
        assert_eq!(balthasar.provider_name(), "default");
        assert_eq!(caspar.provider_name(), "caspar-special");
    }

    /// The embedded prompts of the three agents are pairwise different.
    #[test]
    fn test_different_agent_identities_produce_distinct_prompts() {
        let provider = default_mock();

        let [melchior, balthasar, caspar] =
            [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar]
                .map(|name| Agent::new(name, provider.clone()));

        assert_ne!(melchior.system_prompt(), balthasar.system_prompt());
        assert_ne!(melchior.system_prompt(), caspar.system_prompt());
        assert_ne!(balthasar.system_prompt(), caspar.system_prompt());
    }

    /// A missing directory surfaces as `MagiError::Io`.
    #[test]
    fn test_from_directory_returns_io_error_for_nonexistent_path() {
        let missing_dir = Path::new("/nonexistent/path");
        let result = AgentFactory::new(default_mock()).from_directory(missing_dir);
        assert!(matches!(result, Err(MagiError::Io(_))));
    }

    /// `Agent::new` always yields a non-empty system prompt.
    #[test]
    fn test_agent_new_generates_system_prompt() {
        let agent = Agent::new(AgentName::Melchior, default_mock());
        assert!(!agent.system_prompt().is_empty());
    }

    /// A custom prompt is stored verbatim.
    #[test]
    fn test_agent_with_custom_prompt_uses_provided_prompt() {
        let custom = "Custom prompt".to_string();
        let agent = Agent::with_custom_prompt(AgentName::Melchior, default_mock(), custom);
        assert_eq!(agent.system_prompt(), "Custom prompt");
    }

    /// `execute` returns the provider's response and calls it exactly once.
    #[tokio::test]
    async fn test_agent_execute_delegates_to_provider() {
        let provider = shared_mock("mock", "m1", "response text");
        let agent = Agent::new(AgentName::Melchior, as_dyn(&provider));
        let config = CompletionConfig::default();

        let outcome = agent.execute("user input", &config).await;
        assert_eq!(outcome.unwrap(), "response text");
        assert_eq!(provider.calls(), 1);
    }

    /// Accessors expose the agent identity and the provider metadata.
    #[test]
    fn test_agent_accessors() {
        let provider = shared_mock("test-provider", "test-model", "r");
        let agent = Agent::new(AgentName::Balthasar, as_dyn(&provider));

        assert_eq!(agent.name(), AgentName::Balthasar);
        assert_eq!(agent.provider_name(), "test-provider");
        assert_eq!(agent.provider_model(), "test-model");
        assert_eq!(agent.display_name(), "Balthasar");
        assert_eq!(agent.title(), "Pragmatist");
    }

    /// The factory yields exactly the three known agents.
    #[allow(deprecated)]
    #[test]
    fn test_agent_factory_creates_three_agents() {
        let agents = AgentFactory::new(default_mock()).create_agents(Mode::CodeReview);

        assert_eq!(agents.len(), 3);

        let names: Vec<AgentName> = agents.iter().map(Agent::name).collect();
        for expected in [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar] {
            assert!(names.contains(&expected));
        }
    }

    /// Agents come back in Melchior, Balthasar, Caspar order.
    #[allow(deprecated)]
    #[test]
    fn test_agent_factory_creates_agents_in_order() {
        let agents = AgentFactory::new(default_mock()).create_agents(Mode::CodeReview);

        let expected_order = [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar];
        for (index, expected) in expected_order.into_iter().enumerate() {
            assert_eq!(agents[index].name(), expected);
        }
    }

    /// A provider override affects only the targeted agent.
    #[allow(deprecated)]
    #[test]
    fn test_agent_factory_with_provider_overrides_specific_agent() {
        let default: Arc<dyn LlmProvider> = Arc::new(MockProvider::new("default", "m1", "r1"));
        let override_p: Arc<dyn LlmProvider> = Arc::new(MockProvider::new("override", "m2", "r2"));

        let agents = AgentFactory::new(default)
            .with_provider(AgentName::Caspar, override_p)
            .create_agents(Mode::CodeReview);

        assert_eq!(
            find_agent(&agents, AgentName::Caspar).provider_name(),
            "override"
        );
        assert_eq!(
            find_agent(&agents, AgentName::Melchior).provider_name(),
            "default"
        );
    }

    /// A custom prompt replaces only the targeted agent's prompt.
    #[allow(deprecated)]
    #[test]
    fn test_agent_factory_with_custom_prompt_overrides_prompt() {
        let agents = AgentFactory::new(default_mock())
            .with_custom_prompt(AgentName::Melchior, "My custom prompt".to_string())
            .create_agents(Mode::CodeReview);

        let melchior = find_agent(&agents, AgentName::Melchior);
        assert_eq!(melchior.system_prompt(), "My custom prompt");

        let balthasar = find_agent(&agents, AgentName::Balthasar);
        assert_ne!(balthasar.system_prompt(), "My custom prompt");
        assert!(!balthasar.system_prompt().is_empty());
    }

    /// Every mode produces three agents.
    #[allow(deprecated)]
    #[test]
    fn test_agent_factory_creates_three_agents_for_all_modes() {
        let factory = AgentFactory::new(default_mock());

        for mode in [Mode::CodeReview, Mode::Design, Mode::Analysis] {
            let agent_count = factory.create_agents(mode).len();
            assert_eq!(agent_count, 3, "Expected 3 agents for mode {mode}");
        }
    }

    /// Every embedded prompt mentions both JSON and English.
    #[test]
    fn test_default_prompts_contain_json_and_english_constraints() {
        let provider = default_mock();

        for name in [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar] {
            let agent = Agent::new(name, provider.clone());
            let prompt = agent.system_prompt();
            assert!(
                prompt.contains("JSON"),
                "{name:?} prompt should mention JSON"
            );
            assert!(
                prompt.contains("English"),
                "{name:?} prompt should mention English"
            );
        }
    }

    /// A missing prompt file surfaces as `MagiError::Io`.
    #[test]
    fn test_from_file_returns_io_error_for_nonexistent_path() {
        let missing_file = Path::new("/nonexistent/prompt.md");
        let result = Agent::from_file(AgentName::Melchior, default_mock(), missing_file);
        assert!(matches!(result, Err(MagiError::Io(_))));
    }

    /// Compile-time check that `Agent::new` takes only a name and a provider.
    #[test]
    fn test_agent_new_no_longer_requires_mode_parameter() {
        let provider: Arc<dyn LlmProvider> = Arc::new(MockProvider::default());
        let _agent = Agent::new(AgentName::Melchior, provider);
    }
}