1use crate::error::{MagiError, ProviderError};
6use crate::provider::{CompletionConfig, LlmProvider};
7use crate::schema::{AgentName, Mode};
8use std::collections::BTreeMap;
9use std::path::Path;
10use std::sync::Arc;
11
12use crate::prompts;
13
14const ALL_MODES: [Mode; 3] = [Mode::CodeReview, Mode::Design, Mode::Analysis];
16
17pub struct Agent {
23 name: AgentName,
24 mode: Mode,
25 system_prompt: String,
26 provider: Arc<dyn LlmProvider>,
27}
28
29impl Agent {
30 pub fn new(name: AgentName, mode: Mode, provider: Arc<dyn LlmProvider>) -> Self {
39 let prompt = match name {
40 AgentName::Melchior => prompts::melchior::prompt_for_mode(&mode),
41 AgentName::Balthasar => prompts::balthasar::prompt_for_mode(&mode),
42 AgentName::Caspar => prompts::caspar::prompt_for_mode(&mode),
43 };
44 Self {
45 name,
46 mode,
47 system_prompt: prompt.to_string(),
48 provider,
49 }
50 }
51
52 pub fn with_custom_prompt(
60 name: AgentName,
61 mode: Mode,
62 provider: Arc<dyn LlmProvider>,
63 prompt: String,
64 ) -> Self {
65 Self {
66 name,
67 mode,
68 system_prompt: prompt,
69 provider,
70 }
71 }
72
73 pub fn from_file(
86 name: AgentName,
87 mode: Mode,
88 provider: Arc<dyn LlmProvider>,
89 path: &Path,
90 ) -> Result<Self, MagiError> {
91 let prompt = std::fs::read_to_string(path)?;
92 Ok(Self {
93 name,
94 mode,
95 system_prompt: prompt,
96 provider,
97 })
98 }
99
100 pub async fn execute(
112 &self,
113 user_prompt: &str,
114 config: &CompletionConfig,
115 ) -> Result<String, ProviderError> {
116 self.provider
117 .complete(&self.system_prompt, user_prompt, config)
118 .await
119 }
120
121 pub fn name(&self) -> AgentName {
123 self.name
124 }
125
126 pub fn mode(&self) -> Mode {
128 self.mode
129 }
130
131 pub fn system_prompt(&self) -> &str {
133 &self.system_prompt
134 }
135
136 pub fn provider_name(&self) -> &str {
138 self.provider.name()
139 }
140
141 pub fn provider_model(&self) -> &str {
143 self.provider.model()
144 }
145
146 pub fn display_name(&self) -> &str {
148 self.name.display_name()
149 }
150
151 pub fn title(&self) -> &str {
153 self.name.title()
154 }
155}
156
157pub struct AgentFactory {
163 default_provider: Arc<dyn LlmProvider>,
164 agent_providers: BTreeMap<AgentName, Arc<dyn LlmProvider>>,
165 custom_prompts: BTreeMap<(AgentName, Mode), String>,
166}
167
168impl AgentFactory {
169 pub fn new(default_provider: Arc<dyn LlmProvider>) -> Self {
174 Self {
175 default_provider,
176 agent_providers: BTreeMap::new(),
177 custom_prompts: BTreeMap::new(),
178 }
179 }
180
181 pub fn with_provider(mut self, name: AgentName, provider: Arc<dyn LlmProvider>) -> Self {
187 self.agent_providers.insert(name, provider);
188 self
189 }
190
191 pub fn with_custom_prompt(mut self, name: AgentName, prompt: String) -> Self {
197 for mode in ALL_MODES {
198 self.custom_prompts.insert((name, mode), prompt.clone());
199 }
200 self
201 }
202
203 pub fn from_directory(mut self, dir: &Path) -> Result<Self, MagiError> {
212 std::fs::read_dir(dir)?;
214
215 let agents = ["melchior", "balthasar", "caspar"];
216 let modes = ["code_review", "design", "analysis"];
217
218 for agent_str in &agents {
219 for mode_str in &modes {
220 let filename = format!("{agent_str}_{mode_str}.md");
221 let path = dir.join(&filename);
222 if path.exists() {
223 let content = std::fs::read_to_string(&path)?;
224 let agent_name = match *agent_str {
225 "melchior" => AgentName::Melchior,
226 "balthasar" => AgentName::Balthasar,
227 "caspar" => AgentName::Caspar,
228 _ => unreachable!(),
229 };
230 let mode = match *mode_str {
231 "code_review" => Mode::CodeReview,
232 "design" => Mode::Design,
233 "analysis" => Mode::Analysis,
234 _ => unreachable!(),
235 };
236 self.custom_prompts.insert((agent_name, mode), content);
237 }
238 }
239 }
240
241 Ok(self)
242 }
243
244 pub fn create_agents(&self, mode: Mode) -> Vec<Agent> {
253 let names = [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar];
254
255 names
256 .iter()
257 .map(|&name| {
258 let provider = self
259 .agent_providers
260 .get(&name)
261 .cloned()
262 .unwrap_or_else(|| self.default_provider.clone());
263
264 if let Some(prompt) = self.custom_prompts.get(&(name, mode)) {
265 Agent::with_custom_prompt(name, mode, provider, prompt.clone())
266 } else {
267 Agent::new(name, mode, provider)
268 }
269 })
270 .collect()
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double for `LlmProvider`: returns a fixed response and counts `complete` calls.
    struct MockProvider {
        name: String,
        model: String,
        response: String,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        fn new(name: &str, model: &str, response: &str) -> Self {
            Self {
                name: name.to_string(),
                model: model.to_string(),
                response: response.to_string(),
                call_count: AtomicUsize::new(0),
            }
        }

        /// Number of times `complete` has been invoked on this mock.
        fn calls(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(
            &self,
            _system_prompt: &str,
            _user_prompt: &str,
            _config: &CompletionConfig,
        ) -> Result<String, ProviderError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(self.response.clone())
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn model(&self) -> &str {
            &self.model
        }
    }

    /// Creates a fresh scratch directory under the system temp dir.
    ///
    /// `tag` keeps concurrently running tests from sharing a directory.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let dir = std::env::temp_dir().join(format!(
            "magi_agent_{tag}_{}_{nanos}",
            std::process::id()
        ));
        std::fs::create_dir_all(&dir).expect("create scratch dir");
        dir
    }

    /// Each agent must send its request to its own configured provider, exactly once.
    #[tokio::test]
    async fn test_each_agent_uses_its_own_provider() {
        let p1 = Arc::new(MockProvider::new("p1", "m1", "r1"));
        let p2 = Arc::new(MockProvider::new("p2", "m2", "r2"));
        let p3 = Arc::new(MockProvider::new("p3", "m3", "r3"));

        let factory = AgentFactory::new(p1.clone() as Arc<dyn LlmProvider>)
            .with_provider(AgentName::Melchior, p1.clone() as Arc<dyn LlmProvider>)
            .with_provider(AgentName::Balthasar, p2.clone() as Arc<dyn LlmProvider>)
            .with_provider(AgentName::Caspar, p3.clone() as Arc<dyn LlmProvider>);

        let agents = factory.create_agents(Mode::CodeReview);
        let config = CompletionConfig::default();

        for agent in &agents {
            let _ = agent.execute("test input", &config).await;
        }

        assert_eq!(p1.calls(), 1, "p1 should receive exactly 1 call");
        assert_eq!(p2.calls(), 1, "p2 should receive exactly 1 call");
        assert_eq!(p3.calls(), 1, "p3 should receive exactly 1 call");
    }

    /// Agents without an override fall back to the default provider.
    #[tokio::test]
    async fn test_factory_default_and_override_providers() {
        let default = Arc::new(MockProvider::new("default", "m1", "r1"));
        let caspar_override = Arc::new(MockProvider::new("caspar-special", "m2", "r2"));

        let factory = AgentFactory::new(default.clone() as Arc<dyn LlmProvider>).with_provider(
            AgentName::Caspar,
            caspar_override.clone() as Arc<dyn LlmProvider>,
        );

        let agents = factory.create_agents(Mode::CodeReview);

        let melchior = agents
            .iter()
            .find(|a| a.name() == AgentName::Melchior)
            .unwrap();
        let balthasar = agents
            .iter()
            .find(|a| a.name() == AgentName::Balthasar)
            .unwrap();
        let caspar = agents
            .iter()
            .find(|a| a.name() == AgentName::Caspar)
            .unwrap();

        assert_eq!(melchior.provider_name(), "default");
        assert_eq!(balthasar.provider_name(), "default");
        assert_eq!(caspar.provider_name(), "caspar-special");
    }

    /// Built-in prompts differ between modes for the same agent.
    #[test]
    fn test_different_modes_produce_distinct_prompts() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;

        let cr = Agent::new(AgentName::Melchior, Mode::CodeReview, provider.clone());
        let design = Agent::new(AgentName::Melchior, Mode::Design, provider.clone());
        let analysis = Agent::new(AgentName::Melchior, Mode::Analysis, provider.clone());

        assert_ne!(cr.system_prompt(), design.system_prompt());
        assert_ne!(cr.system_prompt(), analysis.system_prompt());
        assert_ne!(design.system_prompt(), analysis.system_prompt());
    }

    /// A missing prompt directory is reported as `MagiError::Io`.
    #[test]
    fn test_from_directory_returns_io_error_for_nonexistent_path() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider);
        let result = factory.from_directory(Path::new("/nonexistent/path"));
        assert!(matches!(result, Err(MagiError::Io(_))));
    }

    /// `from_directory` loads the `<agent>_<mode>.md` files that exist.
    /// Every other (agent, mode) pair keeps its built-in prompt.
    #[test]
    fn test_from_directory_loads_present_files_only() {
        let dir = scratch_dir("from_directory");
        std::fs::write(dir.join("melchior_design.md"), "Melchior design override")
            .expect("write prompt file");

        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let result = AgentFactory::new(provider).from_directory(&dir);
        // Clean up before asserting so a failure does not leak the scratch directory.
        let _ = std::fs::remove_dir_all(&dir);
        let factory = match result {
            Ok(factory) => factory,
            Err(_) => panic!("from_directory should succeed for an existing directory"),
        };

        let design = factory.create_agents(Mode::Design);
        assert_eq!(design[0].system_prompt(), "Melchior design override");
        assert_ne!(design[1].system_prompt(), "Melchior design override");
        assert_ne!(design[2].system_prompt(), "Melchior design override");

        let code_review = factory.create_agents(Mode::CodeReview);
        assert_ne!(code_review[0].system_prompt(), "Melchior design override");
        assert!(!code_review[0].system_prompt().is_empty());
    }

    /// `Agent::new` always yields a non-empty built-in prompt.
    #[test]
    fn test_agent_new_generates_system_prompt() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let agent = Agent::new(AgentName::Melchior, Mode::CodeReview, provider);
        assert!(!agent.system_prompt().is_empty());
    }

    /// `Agent::with_custom_prompt` stores the given prompt verbatim.
    #[test]
    fn test_agent_with_custom_prompt_uses_provided_prompt() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let agent = Agent::with_custom_prompt(
            AgentName::Melchior,
            Mode::CodeReview,
            provider,
            "Custom prompt".to_string(),
        );
        assert_eq!(agent.system_prompt(), "Custom prompt");
    }

    /// `execute` returns the provider's response and calls it exactly once.
    #[tokio::test]
    async fn test_agent_execute_delegates_to_provider() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "response text"));
        let provider_arc = provider.clone() as Arc<dyn LlmProvider>;
        let agent = Agent::new(AgentName::Melchior, Mode::CodeReview, provider_arc);
        let config = CompletionConfig::default();

        let result = agent.execute("user input", &config).await;
        assert_eq!(result.unwrap(), "response text");
        assert_eq!(provider.calls(), 1);
    }

    /// Accessors expose the agent's identity and its provider's metadata.
    #[test]
    fn test_agent_accessors() {
        let provider = Arc::new(MockProvider::new("test-provider", "test-model", "r"));
        let provider_arc = provider.clone() as Arc<dyn LlmProvider>;
        let agent = Agent::new(AgentName::Balthasar, Mode::Design, provider_arc);

        assert_eq!(agent.name(), AgentName::Balthasar);
        assert_eq!(agent.mode(), Mode::Design);
        assert_eq!(agent.provider_name(), "test-provider");
        assert_eq!(agent.provider_model(), "test-model");
        assert_eq!(agent.display_name(), "Balthasar");
        assert_eq!(agent.title(), "Pragmatist");
    }

    /// The factory produces exactly one agent of each kind.
    #[test]
    fn test_agent_factory_creates_three_agents() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider);
        let agents = factory.create_agents(Mode::CodeReview);

        assert_eq!(agents.len(), 3);

        let names: Vec<AgentName> = agents.iter().map(|a| a.name()).collect();
        assert!(names.contains(&AgentName::Melchior));
        assert!(names.contains(&AgentName::Balthasar));
        assert!(names.contains(&AgentName::Caspar));
    }

    /// Agents are returned in the fixed order Melchior, Balthasar, Caspar.
    #[test]
    fn test_agent_factory_creates_agents_in_order() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider);
        let agents = factory.create_agents(Mode::CodeReview);

        assert_eq!(agents[0].name(), AgentName::Melchior);
        assert_eq!(agents[1].name(), AgentName::Balthasar);
        assert_eq!(agents[2].name(), AgentName::Caspar);
    }

    /// A provider override affects only the agent it was registered for.
    #[test]
    fn test_agent_factory_with_provider_overrides_specific_agent() {
        let default = Arc::new(MockProvider::new("default", "m1", "r1")) as Arc<dyn LlmProvider>;
        let override_p =
            Arc::new(MockProvider::new("override", "m2", "r2")) as Arc<dyn LlmProvider>;

        let factory = AgentFactory::new(default).with_provider(AgentName::Caspar, override_p);
        let agents = factory.create_agents(Mode::CodeReview);

        let caspar = agents
            .iter()
            .find(|a| a.name() == AgentName::Caspar)
            .unwrap();
        assert_eq!(caspar.provider_name(), "override");

        let melchior = agents
            .iter()
            .find(|a| a.name() == AgentName::Melchior)
            .unwrap();
        assert_eq!(melchior.provider_name(), "default");
    }

    /// A custom prompt replaces only the targeted agent's prompt.
    #[test]
    fn test_agent_factory_with_custom_prompt_overrides_prompt() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;

        let factory = AgentFactory::new(provider)
            .with_custom_prompt(AgentName::Melchior, "My custom prompt".to_string());
        let agents = factory.create_agents(Mode::CodeReview);

        let melchior = agents
            .iter()
            .find(|a| a.name() == AgentName::Melchior)
            .unwrap();
        assert_eq!(melchior.system_prompt(), "My custom prompt");

        let balthasar = agents
            .iter()
            .find(|a| a.name() == AgentName::Balthasar)
            .unwrap();
        assert_ne!(balthasar.system_prompt(), "My custom prompt");
        assert!(!balthasar.system_prompt().is_empty());
    }

    /// `AgentFactory::with_custom_prompt` applies the prompt in every mode, not just one.
    #[test]
    fn test_agent_factory_custom_prompt_applies_to_every_mode() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider)
            .with_custom_prompt(AgentName::Caspar, "Caspar everywhere".to_string());

        for mode in [Mode::CodeReview, Mode::Design, Mode::Analysis] {
            let agents = factory.create_agents(mode);
            assert_eq!(
                agents[2].system_prompt(),
                "Caspar everywhere",
                "custom prompt missing for mode {mode}"
            );
        }
    }

    /// Every mode yields the full set of three agents.
    #[test]
    fn test_agent_factory_creates_three_agents_for_all_modes() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let factory = AgentFactory::new(provider);

        for mode in [Mode::CodeReview, Mode::Design, Mode::Analysis] {
            let agents = factory.create_agents(mode);
            assert_eq!(agents.len(), 3, "Expected 3 agents for mode {mode}");
        }
    }

    /// Every built-in prompt mentions JSON output and the English-language constraint.
    #[test]
    fn test_default_prompts_contain_json_and_english_constraints() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;

        for name in [AgentName::Melchior, AgentName::Balthasar, AgentName::Caspar] {
            for mode in [Mode::CodeReview, Mode::Design, Mode::Analysis] {
                let agent = Agent::new(name, mode, provider.clone());
                let prompt = agent.system_prompt();
                assert!(
                    prompt.contains("JSON"),
                    "{name:?}/{mode:?} prompt should mention JSON"
                );
                assert!(
                    prompt.contains("English"),
                    "{name:?}/{mode:?} prompt should mention English"
                );
            }
        }
    }

    /// A missing prompt file is reported as `MagiError::Io`.
    #[test]
    fn test_from_file_returns_io_error_for_nonexistent_path() {
        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let result = Agent::from_file(
            AgentName::Melchior,
            Mode::CodeReview,
            provider,
            Path::new("/nonexistent/prompt.md"),
        );
        assert!(matches!(result, Err(MagiError::Io(_))));
    }

    /// `Agent::from_file` uses the file's full contents as the system prompt.
    #[test]
    fn test_from_file_reads_prompt_contents() {
        let dir = scratch_dir("from_file");
        let path = dir.join("prompt.md");
        std::fs::write(&path, "Prompt from disk").expect("write prompt file");

        let provider = Arc::new(MockProvider::new("mock", "m1", "r1")) as Arc<dyn LlmProvider>;
        let result = Agent::from_file(AgentName::Caspar, Mode::Analysis, provider, &path);
        let _ = std::fs::remove_dir_all(&dir);

        match result {
            Ok(agent) => {
                assert_eq!(agent.system_prompt(), "Prompt from disk");
                assert_eq!(agent.name(), AgentName::Caspar);
                assert_eq!(agent.mode(), Mode::Analysis);
            }
            Err(_) => panic!("from_file should succeed for an existing file"),
        }
    }
}