1use std::collections::BTreeMap;
6use std::path::PathBuf;
7use std::sync::Arc;
8use std::time::Duration;
9
10use std::sync::Mutex;
11
12use crate::agent::{Agent, AgentFactory};
13use crate::consensus::{ConsensusConfig, ConsensusEngine};
14use crate::error::{MagiError, ProviderError};
15use crate::provider::{CompletionConfig, LlmProvider};
16use crate::reporting::{MagiReport, ReportConfig, ReportFormatter};
17use crate::schema::{AgentName, AgentOutput, Mode};
18use crate::user_prompt::{FastrandSource, RngLike, build_user_prompt};
19use crate::validate::{ValidationLimits, Validator};
20use tokio::task::AbortHandle;
21
22pub const DEFAULT_MAX_INPUT_LEN: usize = 4 * 1024 * 1024;
31
32#[non_exhaustive]
36#[derive(Debug, Clone)]
37pub struct MagiConfig {
38 pub timeout: Duration,
40 pub max_input_len: usize,
61 pub completion: CompletionConfig,
63}
64
65impl Default for MagiConfig {
66 fn default() -> Self {
67 Self {
68 timeout: Duration::from_secs(300),
69 max_input_len: DEFAULT_MAX_INPUT_LEN,
70 completion: CompletionConfig::default(),
71 }
72 }
73}
74
75pub struct MagiBuilder {
94 default_provider: Arc<dyn LlmProvider>,
95 agent_providers: BTreeMap<AgentName, Arc<dyn LlmProvider>>,
96 overrides: BTreeMap<(AgentName, Option<Mode>), String>,
97 prompts_dir: Option<PathBuf>,
98 config: MagiConfig,
99 validation_limits: ValidationLimits,
100 consensus_config: ConsensusConfig,
101 report_config: ReportConfig,
102 rng_source: Option<Box<dyn RngLike + Send>>,
103}
104
105impl MagiBuilder {
106 pub fn new(default_provider: Arc<dyn LlmProvider>) -> Self {
111 Self {
112 default_provider,
113 agent_providers: BTreeMap::new(),
114 overrides: BTreeMap::new(),
115 prompts_dir: None,
116 config: MagiConfig::default(),
117 validation_limits: ValidationLimits::default(),
118 consensus_config: ConsensusConfig::default(),
119 report_config: ReportConfig::default(),
120 rng_source: None,
121 }
122 }
123
124 pub fn with_provider(mut self, name: AgentName, provider: Arc<dyn LlmProvider>) -> Self {
130 self.agent_providers.insert(name, provider);
131 self
132 }
133
134 pub fn with_custom_prompt_for_mode(
144 mut self,
145 agent: AgentName,
146 mode: Mode,
147 prompt: String,
148 ) -> Self {
149 self.overrides.insert((agent, Some(mode)), prompt);
150 self
151 }
152
153 pub fn with_custom_prompt_all_modes(mut self, agent: AgentName, prompt: String) -> Self {
162 self.overrides.insert((agent, None), prompt);
163 self
164 }
165
166 pub(crate) fn with_rng_source(mut self, rng: Box<dyn RngLike + Send>) -> Self {
175 self.rng_source = Some(rng);
176 self
177 }
178
179 #[deprecated(since = "0.3.0", note = "use `with_custom_prompt_for_mode`")]
190 pub fn with_custom_prompt(self, agent: AgentName, mode: Mode, prompt: String) -> Self {
191 self.with_custom_prompt_for_mode(agent, mode, prompt)
192 }
193
194 pub fn with_prompts_dir(mut self, dir: PathBuf) -> Self {
199 self.prompts_dir = Some(dir);
200 self
201 }
202
203 pub fn with_timeout(mut self, timeout: Duration) -> Self {
208 self.config.timeout = timeout;
209 self
210 }
211
212 pub fn with_max_input_len(mut self, max: usize) -> Self {
217 self.config.max_input_len = max;
218 self
219 }
220
221 pub fn with_completion_config(mut self, config: CompletionConfig) -> Self {
226 self.config.completion = config;
227 self
228 }
229
230 pub fn with_validation_limits(mut self, limits: ValidationLimits) -> Self {
235 self.validation_limits = limits;
236 self
237 }
238
239 pub fn with_consensus_config(mut self, config: ConsensusConfig) -> Self {
244 self.consensus_config = config;
245 self
246 }
247
248 pub fn with_report_config(mut self, config: ReportConfig) -> Self {
253 self.report_config = config;
254 self
255 }
256
257 pub fn build(self) -> Result<Magi, MagiError> {
264 let mut factory = AgentFactory::new(self.default_provider);
265 for (name, provider) in self.agent_providers {
266 factory = factory.with_provider(name, provider);
267 }
268 let mut overrides = self.overrides;
269 if let Some(dir) = self.prompts_dir {
270 factory = factory.from_directory(&dir)?;
271 for ((agent, mode), prompt) in factory.custom_prompts() {
276 overrides
277 .entry((*agent, Some(*mode)))
278 .or_insert_with(|| prompt.clone());
279 }
280 }
281
282 let rng_source = self
283 .rng_source
284 .unwrap_or_else(|| Box::new(FastrandSource) as Box<dyn RngLike + Send>);
285
286 Ok(Magi {
287 config: self.config,
288 agent_factory: factory,
289 validator: Validator::with_limits(self.validation_limits),
290 consensus_engine: ConsensusEngine::new(self.consensus_config),
291 formatter: ReportFormatter::with_config(self.report_config)
292 .map_err(|e| MagiError::Validation(e.to_string()))?,
293 overrides,
294 rng_source: Arc::new(Mutex::new(rng_source)),
295 })
296 }
297}
298
299struct AbortGuard(Vec<AbortHandle>);
305
306impl Drop for AbortGuard {
307 fn drop(&mut self) {
308 for handle in &self.0 {
309 handle.abort();
310 }
311 }
312}
313
314pub struct Magi {
331 config: MagiConfig,
332 agent_factory: AgentFactory,
333 validator: Validator,
334 consensus_engine: ConsensusEngine,
335 formatter: ReportFormatter,
336 overrides: BTreeMap<(AgentName, Option<Mode>), String>,
337 rng_source: Arc<Mutex<Box<dyn RngLike + Send>>>,
338}
339
340impl Magi {
341 pub fn new(provider: Arc<dyn LlmProvider>) -> Self {
349 MagiBuilder::new(provider).build().expect(
351 "Magi::new uses all defaults and cannot fail; \
352 this is an internal invariant violation",
353 )
354 }
355
356 pub fn builder(provider: Arc<dyn LlmProvider>) -> MagiBuilder {
361 MagiBuilder::new(provider)
362 }
363
364 pub async fn analyze(&self, mode: &Mode, content: &str) -> Result<MagiReport, MagiError> {
389 if content.len() > self.config.max_input_len {
391 return Err(MagiError::InputTooLarge {
392 size: content.len(),
393 max: self.config.max_input_len,
394 });
395 }
396
397 let agents = self
401 .agent_factory
402 .create_agents_with_prompts(*mode, &self.overrides);
403
404 let prompt = {
407 let mut rng = self
408 .rng_source
409 .lock()
410 .unwrap_or_else(|poisoned| poisoned.into_inner());
411 build_user_prompt(*mode, content, &mut **rng)?
412 };
413
414 let agent_results = self.launch_agents(agents, &prompt).await;
416
417 let (successful, failed_agents) = self.process_results(agent_results)?;
419
420 let consensus = self.consensus_engine.determine(&successful)?;
422
423 let banner = self.formatter.format_banner(&successful, &consensus);
425 let report = self.formatter.format_report(&successful, &consensus);
426
427 let degraded = successful.len() < 3;
429 Ok(MagiReport {
430 agents: successful,
431 consensus,
432 banner,
433 report,
434 degraded,
435 failed_agents,
436 })
437 }
438
439 async fn launch_agents(
450 &self,
451 agents: Vec<Agent>,
452 prompt: &str,
453 ) -> Vec<(AgentName, Result<String, MagiError>)> {
454 let timeout = self.config.timeout;
455 let completion = self.config.completion.clone();
456 let mut handles = Vec::new();
457 let mut abort_handles = Vec::new();
458
459 for agent in agents {
460 let name = agent.name();
461 let user_prompt = prompt.to_string();
462 let config = completion.clone();
463
464 let handle = tokio::spawn(async move {
465 let result =
466 tokio::time::timeout(timeout, agent.execute(&user_prompt, &config)).await;
467 match result {
468 Ok(Ok(response)) => Ok(response),
469 Ok(Err(provider_err)) => Err(MagiError::Provider(provider_err)),
470 Err(_elapsed) => Err(MagiError::Provider(ProviderError::Timeout {
471 message: format!("agent timed out after {timeout:?}"),
472 })),
473 }
474 });
475 abort_handles.push(handle.abort_handle());
476 handles.push((name, handle));
477 }
478
479 let _guard = AbortGuard(abort_handles);
482
483 let mut results = Vec::new();
484 for (name, handle) in handles {
485 match handle.await {
486 Ok(result) => results.push((name, result)),
487 Err(join_err) => results.push((
488 name,
489 Err(MagiError::Provider(ProviderError::Process {
490 exit_code: None,
491 stderr: format!("agent task panicked: {join_err}"),
492 })),
493 )),
494 }
495 }
496
497 results
498 }
499
500 fn process_results(
506 &self,
507 results: Vec<(AgentName, Result<String, MagiError>)>,
508 ) -> Result<(Vec<AgentOutput>, BTreeMap<AgentName, String>), MagiError> {
509 let mut successful = Vec::new();
510 let mut failed_agents = BTreeMap::new();
511
512 for (name, result) in results {
513 match result {
514 Ok(raw) => match parse_agent_response(&raw) {
515 Ok(mut output) => match self.validator.validate_mut(&mut output) {
516 Ok(()) => successful.push(output),
517 Err(e) => {
518 failed_agents.insert(name, format!("validation: {e}"));
519 }
520 },
521 Err(e) => {
522 failed_agents.insert(name, format!("parse: {e}"));
523 }
524 },
525 Err(e) => {
526 failed_agents.insert(name, e.to_string());
527 }
528 }
529 }
530
531 let min_agents = self.consensus_engine.min_agents();
532 if successful.len() < min_agents {
533 return Err(MagiError::InsufficientAgents {
534 succeeded: successful.len(),
535 required: min_agents,
536 });
537 }
538
539 Ok((successful, failed_agents))
540 }
541
542 #[cfg(test)]
547 pub(crate) fn overrides(&self) -> &BTreeMap<(AgentName, Option<Mode>), String> {
548 &self.overrides
549 }
550}
551
552fn parse_agent_response(raw: &str) -> Result<AgentOutput, MagiError> {
565 let trimmed = raw.trim();
566
567 let stripped = if trimmed.starts_with("```") {
569 let without_opening = if let Some(rest) = trimmed.strip_prefix("```json") {
570 rest
571 } else {
572 trimmed.strip_prefix("```").unwrap_or(trimmed)
573 };
574 without_opening
575 .strip_suffix("```")
576 .unwrap_or(without_opening)
577 .trim()
578 } else {
579 trimmed
580 };
581
582 if let Ok(output) = serde_json::from_str::<AgentOutput>(stripped) {
585 return Ok(output);
586 }
587
588 for (start, _) in stripped.match_indices('{') {
593 let candidate = &stripped[start..];
594 if let Ok(output) = serde_json::from_str::<AgentOutput>(candidate) {
595 return Ok(output);
596 }
597 }
598
599 Err(MagiError::Deserialization(
600 "no valid JSON object found in agent response".to_string(),
601 ))
602}
603
604#[cfg(test)]
605mod tests {
606 use super::*;
607 use crate::prompts::lookup_prompt;
608 use crate::schema::*;
609 use std::sync::Arc;
610 use std::sync::atomic::{AtomicUsize, Ordering};
611 use std::time::Duration;
612
613 fn mock_agent_json(agent: &str, verdict: &str, confidence: f64) -> String {
615 format!(
616 r#"{{
617 "agent": "{agent}",
618 "verdict": "{verdict}",
619 "confidence": {confidence},
620 "summary": "Summary from {agent}",
621 "reasoning": "Reasoning from {agent}",
622 "findings": [],
623 "recommendation": "Recommendation from {agent}"
624 }}"#
625 )
626 }
627
628 struct MockProvider {
632 name: String,
633 model: String,
634 responses: Vec<Result<String, ProviderError>>,
635 call_count: AtomicUsize,
636 }
637
638 impl MockProvider {
639 fn success(name: &str, model: &str, responses: Vec<String>) -> Self {
640 Self {
641 name: name.to_string(),
642 model: model.to_string(),
643 responses: responses.into_iter().map(Ok).collect(),
644 call_count: AtomicUsize::new(0),
645 }
646 }
647
648 fn mixed(name: &str, model: &str, responses: Vec<Result<String, ProviderError>>) -> Self {
649 Self {
650 name: name.to_string(),
651 model: model.to_string(),
652 responses,
653 call_count: AtomicUsize::new(0),
654 }
655 }
656
657 fn calls(&self) -> usize {
658 self.call_count.load(Ordering::SeqCst)
659 }
660 }
661
662 #[async_trait::async_trait]
663 impl LlmProvider for MockProvider {
664 async fn complete(
665 &self,
666 _system_prompt: &str,
667 _user_prompt: &str,
668 _config: &CompletionConfig,
669 ) -> Result<String, ProviderError> {
670 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
671 let idx = idx % self.responses.len();
672 self.responses[idx].clone()
673 }
674
675 fn name(&self) -> &str {
676 &self.name
677 }
678
679 fn model(&self) -> &str {
680 &self.model
681 }
682 }
683
684 #[tokio::test]
688 async fn test_analyze_unanimous_approve_returns_complete_report() {
689 let responses = vec![
690 mock_agent_json("melchior", "approve", 0.9),
691 mock_agent_json("balthasar", "approve", 0.85),
692 mock_agent_json("caspar", "approve", 0.95),
693 ];
694 let provider = Arc::new(MockProvider::success("mock", "test-model", responses));
695 let magi = Magi::new(provider as Arc<dyn LlmProvider>);
696
697 let result = magi.analyze(&Mode::CodeReview, "fn main() {}").await;
698 let report = result.expect("analyze should succeed");
699
700 assert_eq!(report.agents.len(), 3);
701 assert!(!report.degraded);
702 assert!(report.failed_agents.is_empty());
703 assert_eq!(report.consensus.consensus_verdict, Verdict::Approve);
704 assert!(!report.banner.is_empty());
705 assert!(!report.report.is_empty());
706 }
707
708 #[tokio::test]
712 async fn test_analyze_one_agent_timeout_degrades_gracefully() {
713 let responses = vec![
714 Ok(mock_agent_json("melchior", "approve", 0.9)),
715 Ok(mock_agent_json("balthasar", "approve", 0.85)),
716 Err(ProviderError::Timeout {
717 message: "exceeded timeout".to_string(),
718 }),
719 ];
720 let provider = Arc::new(MockProvider::mixed("mock", "test-model", responses));
721 let magi = Magi::new(provider as Arc<dyn LlmProvider>);
722
723 let result = magi.analyze(&Mode::CodeReview, "fn main() {}").await;
724 let report = result.expect("analyze should succeed with degradation");
725
726 assert!(report.degraded);
727 assert_eq!(report.failed_agents.len(), 1);
728 assert_eq!(report.agents.len(), 2);
729 }
730
731 #[tokio::test]
735 async fn test_analyze_one_agent_bad_json_degrades_gracefully() {
736 let responses = vec![
737 Ok(mock_agent_json("melchior", "approve", 0.9)),
738 Ok(mock_agent_json("balthasar", "approve", 0.85)),
739 Ok("not valid json at all".to_string()),
740 ];
741 let provider = Arc::new(MockProvider::mixed("mock", "test-model", responses));
742 let magi = Magi::new(provider as Arc<dyn LlmProvider>);
743
744 let result = magi.analyze(&Mode::CodeReview, "fn main() {}").await;
745 let report = result.expect("analyze should succeed with degradation");
746
747 assert!(report.degraded);
748 }
749
750 #[tokio::test]
754 async fn test_analyze_two_agents_fail_returns_insufficient_agents() {
755 let responses = vec![
756 Ok(mock_agent_json("melchior", "approve", 0.9)),
757 Err(ProviderError::Timeout {
758 message: "timeout".to_string(),
759 }),
760 Err(ProviderError::Network {
761 message: "connection refused".to_string(),
762 }),
763 ];
764 let provider = Arc::new(MockProvider::mixed("mock", "test-model", responses));
765 let magi = Magi::new(provider as Arc<dyn LlmProvider>);
766
767 let result = magi.analyze(&Mode::CodeReview, "fn main() {}").await;
768
769 match result {
770 Err(MagiError::InsufficientAgents {
771 succeeded,
772 required,
773 }) => {
774 assert_eq!(succeeded, 1);
775 assert_eq!(required, 2);
776 }
777 other => panic!("Expected InsufficientAgents, got: {other:?}"),
778 }
779 }
780
781 #[tokio::test]
785 async fn test_analyze_all_agents_fail_returns_insufficient_agents() {
786 let responses = vec![
787 Err(ProviderError::Timeout {
788 message: "timeout".to_string(),
789 }),
790 Err(ProviderError::Network {
791 message: "network".to_string(),
792 }),
793 Err(ProviderError::Auth {
794 message: "auth".to_string(),
795 }),
796 ];
797 let provider = Arc::new(MockProvider::mixed("mock", "test-model", responses));
798 let magi = Magi::new(provider as Arc<dyn LlmProvider>);
799
800 let result = magi.analyze(&Mode::CodeReview, "fn main() {}").await;
801
802 match result {
803 Err(MagiError::InsufficientAgents {
804 succeeded,
805 required,
806 }) => {
807 assert_eq!(succeeded, 0);
808 assert_eq!(required, 2);
809 }
810 other => panic!("Expected InsufficientAgents, got: {other:?}"),
811 }
812 }
813
814 #[tokio::test]
818 async fn test_analyze_plain_text_response_treated_as_failure() {
819 let responses = vec![
820 Ok(mock_agent_json("melchior", "approve", 0.9)),
821 Ok(mock_agent_json("balthasar", "approve", 0.85)),
822 Ok("I think the code is good".to_string()),
823 ];
824 let provider = Arc::new(MockProvider::mixed("mock", "test-model", responses));
825 let magi = Magi::new(provider as Arc<dyn LlmProvider>);
826
827 let result = magi.analyze(&Mode::CodeReview, "fn main() {}").await;
828 let report = result.expect("should succeed with degradation");
829
830 assert!(report.degraded);
831 assert_eq!(report.agents.len(), 2);
832 }
833
834 #[tokio::test]
838 async fn test_magi_new_creates_with_defaults() {
839 let responses = vec![
840 mock_agent_json("melchior", "approve", 0.9),
841 mock_agent_json("balthasar", "approve", 0.85),
842 mock_agent_json("caspar", "approve", 0.95),
843 ];
844 let provider = Arc::new(MockProvider::success(
845 "test-provider",
846 "test-model",
847 responses,
848 ));
849 let magi = Magi::new(provider as Arc<dyn LlmProvider>);
850
851 let result = magi.analyze(&Mode::CodeReview, "test content").await;
852 let report = result.expect("should succeed");
853
854 assert_eq!(report.agents.len(), 3);
856 }
857
858 #[tokio::test]
862 async fn test_builder_with_mixed_providers_and_custom_config() {
863 let default_responses = vec![
864 mock_agent_json("melchior", "approve", 0.9),
865 mock_agent_json("balthasar", "approve", 0.85),
866 ];
867 let caspar_responses = vec![mock_agent_json("caspar", "reject", 0.8)];
868
869 let default_provider = Arc::new(MockProvider::success(
870 "default-provider",
871 "model-a",
872 default_responses,
873 ));
874 let caspar_provider = Arc::new(MockProvider::success(
875 "caspar-provider",
876 "model-b",
877 caspar_responses,
878 ));
879
880 let magi = MagiBuilder::new(default_provider.clone() as Arc<dyn LlmProvider>)
881 .with_provider(
882 AgentName::Caspar,
883 caspar_provider.clone() as Arc<dyn LlmProvider>,
884 )
885 .with_timeout(Duration::from_secs(60))
886 .build()
887 .expect("build should succeed");
888
889 let result = magi.analyze(&Mode::CodeReview, "test content").await;
890 let report = result.expect("should succeed");
891
892 assert_eq!(report.agents.len(), 3);
893 assert!(caspar_provider.calls() > 0);
895 }
896
897 #[tokio::test]
901 async fn test_analyze_input_too_large_rejects_without_launching_agents() {
902 let responses = vec![mock_agent_json("melchior", "approve", 0.9)];
903 let provider = Arc::new(MockProvider::success("mock", "test-model", responses));
904
905 let magi = MagiBuilder::new(provider.clone() as Arc<dyn LlmProvider>)
906 .with_max_input_len(100)
907 .build()
908 .expect("build should succeed");
909
910 let content = "x".repeat(200);
911 let result = magi.analyze(&Mode::CodeReview, &content).await;
912
913 match result {
914 Err(MagiError::InputTooLarge { size, max }) => {
915 assert_eq!(size, 200);
916 assert_eq!(max, 100);
917 }
918 other => panic!("Expected InputTooLarge, got: {other:?}"),
919 }
920
921 assert_eq!(provider.calls(), 0, "No agents should have been launched");
923 }
924
925 #[test]
929 fn test_magi_config_default_values() {
930 let config = MagiConfig::default();
931 assert_eq!(config.timeout, Duration::from_secs(300));
932 assert_eq!(config.max_input_len, 4 * 1024 * 1024);
933 }
934
935 #[tokio::test]
937 async fn test_builder_with_max_input_len_overrides_default() {
938 let responses = vec![mock_agent_json("melchior", "approve", 0.9)];
939 let provider =
940 Arc::new(MockProvider::success("mock", "model", responses)) as Arc<dyn LlmProvider>;
941
942 let magi = MagiBuilder::new(provider.clone())
943 .with_max_input_len(512)
944 .build()
945 .expect("build should succeed");
946
947 let too_large = "x".repeat(513);
948 let result = magi.analyze(&Mode::CodeReview, &too_large).await;
949 match result {
950 Err(MagiError::InputTooLarge { size, max }) => {
951 assert_eq!(size, 513);
952 assert_eq!(max, 512);
953 }
954 other => panic!("Expected InputTooLarge, got: {other:?}"),
955 }
956 }
957
958 #[test]
962 fn test_parse_agent_response_strips_code_fences() {
963 let json = mock_agent_json("melchior", "approve", 0.9);
964 let raw = format!("```json\n{json}\n```");
965
966 let result = parse_agent_response(&raw);
967 let output = result.expect("should parse successfully");
968 assert_eq!(output.agent, AgentName::Melchior);
969 assert_eq!(output.verdict, Verdict::Approve);
970 }
971
972 #[test]
974 fn test_parse_agent_response_extracts_json_from_preamble() {
975 let json = mock_agent_json("melchior", "approve", 0.9);
976 let raw = format!("Here is my analysis:\n{json}");
977
978 let result = parse_agent_response(&raw);
979 assert!(result.is_ok(), "should find JSON in preamble text");
980 }
981
982 #[test]
984 fn test_parse_agent_response_fails_on_invalid_input() {
985 let result = parse_agent_response("no json here");
986 assert!(result.is_err(), "should fail on invalid input");
987 }
988
989 #[test]
993 fn test_magi_builder_build_returns_result() {
994 let responses = vec![mock_agent_json("melchior", "approve", 0.9)];
995 let provider =
996 Arc::new(MockProvider::success("mock", "model", responses)) as Arc<dyn LlmProvider>;
997
998 let magi = MagiBuilder::new(provider).build();
999 assert!(magi.is_ok());
1000 }
1001
1002 #[derive(Clone)]
1012 struct CapturingMockProvider {
1013 captured: Arc<std::sync::Mutex<Vec<(String, String)>>>,
1015 routing: Arc<std::collections::HashMap<String, AgentName>>, }
1019
1020 impl CapturingMockProvider {
1021 fn for_default_prompts(captured: Arc<std::sync::Mutex<Vec<(String, String)>>>) -> Self {
1024 let mut routing = std::collections::HashMap::new();
1025 routing.insert(
1026 crate::prompts::melchior_prompt().to_string(),
1027 AgentName::Melchior,
1028 );
1029 routing.insert(
1030 crate::prompts::balthasar_prompt().to_string(),
1031 AgentName::Balthasar,
1032 );
1033 routing.insert(
1034 crate::prompts::caspar_prompt().to_string(),
1035 AgentName::Caspar,
1036 );
1037 Self {
1038 captured,
1039 routing: Arc::new(routing),
1040 }
1041 }
1042
1043 fn with_routing(
1047 captured: Arc<std::sync::Mutex<Vec<(String, String)>>>,
1048 mappings: Vec<(&'static str, AgentName)>,
1049 ) -> Self {
1050 let mut routing = std::collections::HashMap::new();
1051 routing.insert(
1053 crate::prompts::melchior_prompt().to_string(),
1054 AgentName::Melchior,
1055 );
1056 routing.insert(
1057 crate::prompts::balthasar_prompt().to_string(),
1058 AgentName::Balthasar,
1059 );
1060 routing.insert(
1061 crate::prompts::caspar_prompt().to_string(),
1062 AgentName::Caspar,
1063 );
1064 for (custom, name) in mappings {
1065 routing.insert(custom.to_string(), name);
1066 }
1067 Self {
1068 captured,
1069 routing: Arc::new(routing),
1070 }
1071 }
1072
1073 #[allow(dead_code)]
1076 fn for_prompt_capture(captured: Arc<std::sync::Mutex<Vec<(String, String)>>>) -> Self {
1077 Self::for_default_prompts(captured)
1078 }
1079 }
1080
1081 #[async_trait::async_trait]
1082 impl LlmProvider for CapturingMockProvider {
1083 async fn complete(
1084 &self,
1085 system_prompt: &str,
1086 user_prompt: &str,
1087 _config: &CompletionConfig,
1088 ) -> Result<String, ProviderError> {
1089 self.captured
1090 .lock()
1091 .unwrap()
1092 .push((system_prompt.to_string(), user_prompt.to_string()));
1093 let agent = self
1094 .routing
1095 .get(system_prompt)
1096 .copied()
1097 .unwrap_or(AgentName::Melchior);
1098 let agent_str = match agent {
1099 AgentName::Melchior => "melchior",
1100 AgentName::Balthasar => "balthasar",
1101 AgentName::Caspar => "caspar",
1102 };
1103 Ok(mock_agent_json(agent_str, "approve", 0.9))
1104 }
1105
1106 fn name(&self) -> &str {
1107 "capturing-mock"
1108 }
1109
1110 fn model(&self) -> &str {
1111 "test-model"
1112 }
1113 }
1114
1115 #[test]
1117 fn test_with_custom_prompt_for_mode_stores_with_some_key() {
1118 let provider: Arc<dyn LlmProvider> = Arc::new(MockProvider::success(
1119 "mock",
1120 "model",
1121 vec![mock_agent_json("melchior", "approve", 0.9)],
1122 ));
1123 let magi = MagiBuilder::new(provider)
1124 .with_custom_prompt_for_mode(AgentName::Melchior, Mode::CodeReview, "X".into())
1125 .build()
1126 .expect("build should succeed");
1127 assert_eq!(
1128 magi.overrides()
1129 .get(&(AgentName::Melchior, Some(Mode::CodeReview))),
1130 Some(&"X".to_string())
1131 );
1132 }
1133
1134 #[test]
1136 fn test_with_custom_prompt_all_modes_stores_with_none_key() {
1137 let provider: Arc<dyn LlmProvider> = Arc::new(MockProvider::success(
1138 "mock",
1139 "model",
1140 vec![mock_agent_json("melchior", "approve", 0.9)],
1141 ));
1142 let magi = MagiBuilder::new(provider)
1143 .with_custom_prompt_all_modes(AgentName::Balthasar, "Y".into())
1144 .build()
1145 .expect("build should succeed");
1146 assert_eq!(
1147 magi.overrides().get(&(AgentName::Balthasar, None)),
1148 Some(&"Y".to_string())
1149 );
1150 }
1151
1152 #[test]
1154 fn test_legacy_with_custom_prompt_delegates_to_for_mode() {
1155 let provider: Arc<dyn LlmProvider> = Arc::new(MockProvider::success(
1156 "mock",
1157 "model",
1158 vec![mock_agent_json("melchior", "approve", 0.9)],
1159 ));
1160 #[allow(deprecated)]
1161 let magi = MagiBuilder::new(provider)
1162 .with_custom_prompt(AgentName::Caspar, Mode::Design, "Z".into())
1163 .build()
1164 .expect("build should succeed");
1165 assert_eq!(
1166 magi.overrides()
1167 .get(&(AgentName::Caspar, Some(Mode::Design))),
1168 Some(&"Z".to_string())
1169 );
1170 }
1171
1172 #[test]
1177 fn test_lookup_prompt_prefers_mode_specific_override() {
1178 let mut overrides = BTreeMap::new();
1179 overrides.insert(
1180 (AgentName::Melchior, Some(Mode::CodeReview)),
1181 "SPECIFIC".to_string(),
1182 );
1183 overrides.insert((AgentName::Melchior, None), "GENERIC".to_string());
1184 assert_eq!(
1185 lookup_prompt(AgentName::Melchior, Mode::CodeReview, &overrides),
1186 "SPECIFIC"
1187 );
1188 }
1189
1190 #[test]
1192 fn test_lookup_prompt_falls_back_to_mode_agnostic_when_mode_specific_missing() {
1193 let mut overrides = BTreeMap::new();
1194 overrides.insert((AgentName::Melchior, None), "GENERIC".to_string());
1195 assert_eq!(
1196 lookup_prompt(AgentName::Melchior, Mode::CodeReview, &overrides),
1197 "GENERIC"
1198 );
1199 }
1200
1201 #[test]
1203 fn test_lookup_prompt_falls_back_to_embedded_default_when_no_override() {
1204 let overrides: BTreeMap<(AgentName, Option<Mode>), String> = BTreeMap::new();
1205 let result = lookup_prompt(AgentName::Caspar, Mode::Analysis, &overrides);
1206 assert_eq!(result, crate::prompts::caspar_prompt());
1207 }
1208
1209 #[test]
1211 fn test_lookup_prompt_returns_correct_embedded_default_per_agent() {
1212 let overrides: BTreeMap<(AgentName, Option<Mode>), String> = BTreeMap::new();
1213 assert_eq!(
1214 lookup_prompt(AgentName::Melchior, Mode::CodeReview, &overrides),
1215 crate::prompts::melchior_prompt()
1216 );
1217 assert_eq!(
1218 lookup_prompt(AgentName::Balthasar, Mode::Design, &overrides),
1219 crate::prompts::balthasar_prompt()
1220 );
1221 assert_eq!(
1222 lookup_prompt(AgentName::Caspar, Mode::Analysis, &overrides),
1223 crate::prompts::caspar_prompt()
1224 );
1225 }
1226
1227 #[tokio::test]
1229 async fn test_with_rng_source_injects_nonce_observable_in_user_prompt() {
1230 let captured: Arc<std::sync::Mutex<Vec<(String, String)>>> =
1233 Arc::new(std::sync::Mutex::new(Vec::new()));
1234 let provider = Arc::new(CapturingMockProvider::for_default_prompts(captured.clone()));
1235 let nonce_val: u128 = 0x1234_5678_9abc_def0_fedc_ba98_7654_3210;
1236 let expected_nonce_hex = format!("{nonce_val:032x}");
1237
1238 let rng = Box::new(crate::user_prompt::FixedRng::new(vec![nonce_val]))
1240 as Box<dyn crate::user_prompt::RngLike + Send>;
1241 let magi = MagiBuilder::new(provider as Arc<dyn LlmProvider>)
1242 .with_rng_source(rng)
1243 .build()
1244 .expect("build should succeed");
1245 let _ = magi.analyze(&Mode::Analysis, "hello").await.unwrap();
1246
1247 let calls = captured.lock().unwrap();
1248 assert!(
1249 !calls.is_empty(),
1250 "mock should have received at least one call"
1251 );
1252 let (_, user_prompt) = &calls[0];
1253 assert!(
1254 user_prompt.contains(&expected_nonce_hex),
1255 "user_prompt should contain the fixed nonce {expected_nonce_hex}"
1256 );
1257 }
1258
1259 #[tokio::test]
1265 async fn test_analyze_applies_mode_agnostic_override_to_melchior() {
1266 let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
1267 let provider = Arc::new(CapturingMockProvider::with_routing(
1268 captured.clone(),
1269 vec![("CUSTOM MEL", AgentName::Melchior)],
1270 ));
1271 let magi = MagiBuilder::new(provider as Arc<dyn LlmProvider>)
1272 .with_custom_prompt_all_modes(AgentName::Melchior, "CUSTOM MEL".into())
1273 .build()
1274 .expect("build should succeed");
1275 let _ = magi.analyze(&Mode::Design, "x").await.unwrap();
1276 let calls = captured.lock().unwrap();
1277 assert!(
1278 calls.iter().any(|(sys, _)| sys == "CUSTOM MEL"),
1279 "Melchior should have received the mode-agnostic custom prompt"
1280 );
1281 }
1282
1283 #[tokio::test]
1287 async fn test_analyze_per_mode_override_supersedes_all_modes() {
1288 let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
1289 let provider = Arc::new(CapturingMockProvider::with_routing(
1290 captured.clone(),
1291 vec![
1292 ("GENERIC MEL", AgentName::Melchior),
1293 ("SPECIFIC MEL", AgentName::Melchior),
1294 ],
1295 ));
1296 let magi = MagiBuilder::new(provider as Arc<dyn LlmProvider>)
1297 .with_custom_prompt_all_modes(AgentName::Melchior, "GENERIC MEL".into())
1298 .with_custom_prompt_for_mode(AgentName::Melchior, Mode::Design, "SPECIFIC MEL".into())
1299 .build()
1300 .expect("build should succeed");
1301 let _ = magi.analyze(&Mode::Design, "x").await.unwrap();
1302 let calls = captured.lock().unwrap();
1303 assert!(
1304 calls.iter().any(|(sys, _)| sys == "SPECIFIC MEL"),
1305 "mode-specific prompt should have been used for Mode::Design"
1306 );
1307 assert!(
1308 !calls.iter().any(|(sys, _)| sys == "GENERIC MEL"),
1309 "mode-agnostic prompt must NOT be used when a mode-specific one is present"
1310 );
1311 }
1312
1313 #[tokio::test]
1317 async fn test_analyze_nonce_collision_returns_invalid_input() {
1318 let captured = Arc::new(std::sync::Mutex::new(Vec::new()));
1319 let provider = Arc::new(CapturingMockProvider::for_default_prompts(captured));
1320 let fixed_nonce_val: u128 = 0x1234_5678_9012_3456_7890_1234_5678_9012;
1321 let fixed_nonce_hex = format!("{fixed_nonce_val:032x}");
1322 let colliding_content = fixed_nonce_hex.clone();
1324
1325 let magi = MagiBuilder::new(provider as Arc<dyn LlmProvider>)
1326 .with_rng_source(Box::new(crate::user_prompt::FixedRng::new(vec![
1327 fixed_nonce_val,
1328 ])))
1329 .build()
1330 .expect("build should succeed");
1331
1332 let result = magi.analyze(&Mode::Analysis, &colliding_content).await;
1333 assert!(
1334 matches!(result, Err(MagiError::InvalidInput { .. })),
1335 "nonce collision must yield MagiError::InvalidInput, got: {result:?}"
1336 );
1337 }
1338
1339 #[tokio::test]
1343 #[allow(deprecated)]
1344 async fn test_legacy_with_custom_prompt_shim_roundtrip() {
1345 let captured_legacy = Arc::new(std::sync::Mutex::new(Vec::new()));
1346 let captured_new = Arc::new(std::sync::Mutex::new(Vec::new()));
1347
1348 let provider_legacy = Arc::new(CapturingMockProvider::with_routing(
1349 captured_legacy.clone(),
1350 vec![("SHIM PROMPT", AgentName::Caspar)],
1351 ));
1352 let provider_new = Arc::new(CapturingMockProvider::with_routing(
1353 captured_new.clone(),
1354 vec![("SHIM PROMPT", AgentName::Caspar)],
1355 ));
1356
1357 let magi_legacy = MagiBuilder::new(provider_legacy as Arc<dyn LlmProvider>)
1358 .with_custom_prompt(AgentName::Caspar, Mode::CodeReview, "SHIM PROMPT".into())
1359 .build()
1360 .expect("legacy build should succeed");
1361
1362 let magi_new = MagiBuilder::new(provider_new as Arc<dyn LlmProvider>)
1363 .with_custom_prompt_for_mode(AgentName::Caspar, Mode::CodeReview, "SHIM PROMPT".into())
1364 .build()
1365 .expect("new build should succeed");
1366
1367 let _ = magi_legacy
1368 .analyze(&Mode::CodeReview, "test")
1369 .await
1370 .unwrap();
1371 let _ = magi_new.analyze(&Mode::CodeReview, "test").await.unwrap();
1372
1373 let legacy_calls = captured_legacy.lock().unwrap();
1374 let new_calls = captured_new.lock().unwrap();
1375
1376 assert!(
1378 legacy_calls.iter().any(|(sys, _)| sys == "SHIM PROMPT"),
1379 "legacy shim must forward the custom prompt to Caspar"
1380 );
1381 assert!(
1382 new_calls.iter().any(|(sys, _)| sys == "SHIM PROMPT"),
1383 "new API must forward the custom prompt to Caspar"
1384 );
1385 }
1386
1387 #[tokio::test]
1393 async fn test_analyze_respects_prompts_dir_loaded_files() {
1394 struct TmpDir(std::path::PathBuf);
1396 impl Drop for TmpDir {
1397 fn drop(&mut self) {
1398 let _ = std::fs::remove_dir_all(&self.0);
1399 }
1400 }
1401
1402 let uniq = std::time::SystemTime::now()
1404 .duration_since(std::time::SystemTime::UNIX_EPOCH)
1405 .unwrap_or_default()
1406 .as_nanos();
1407 let tmp = TmpDir(std::env::temp_dir().join(format!(
1408 "magi_v03_test_{}_{}",
1409 std::process::id(),
1410 uniq
1411 )));
1412 std::fs::create_dir_all(&tmp.0).unwrap();
1413
1414 std::fs::write(
1416 tmp.0.join("melchior_code_review.md"),
1417 "CUSTOM FROM FILESYSTEM",
1418 )
1419 .unwrap();
1420
1421 let captured: Arc<std::sync::Mutex<Vec<(String, String)>>> =
1422 Arc::new(std::sync::Mutex::new(Vec::new()));
1423 let provider = Arc::new(CapturingMockProvider::with_routing(
1424 captured.clone(),
1425 vec![("CUSTOM FROM FILESYSTEM", AgentName::Melchior)],
1426 ));
1427 let magi = MagiBuilder::new(provider as Arc<dyn LlmProvider>)
1428 .with_prompts_dir(tmp.0.clone())
1429 .build()
1430 .expect("build should succeed");
1431 let _ = magi.analyze(&Mode::CodeReview, "x").await.unwrap();
1432
1433 let calls = captured.lock().unwrap();
1434 assert!(
1435 calls.iter().any(|(sys, _)| sys == "CUSTOM FROM FILESYSTEM"),
1436 "with_prompts_dir file-based prompt should reach Melchior"
1437 );
1438 }
1440
1441 #[tokio::test]
1448 async fn test_analyze_shares_same_nonce_across_all_three_agents() {
1449 let captured: Arc<std::sync::Mutex<Vec<(String, String)>>> =
1450 Arc::new(std::sync::Mutex::new(Vec::new()));
1451 let provider = Arc::new(CapturingMockProvider::for_default_prompts(captured.clone()));
1452 let fixed: u128 = 0xabcd_ef01_2345_6789_0000_0000_0000_0001;
1453 let expected_nonce = format!("{fixed:032x}");
1454 let magi = MagiBuilder::new(provider as Arc<dyn LlmProvider>)
1455 .with_rng_source(Box::new(crate::user_prompt::FixedRng::new(vec![fixed])))
1456 .build()
1457 .expect("build should succeed");
1458 let _ = magi.analyze(&Mode::Analysis, "hello").await.unwrap();
1459 let calls = captured.lock().unwrap();
1460 assert_eq!(calls.len(), 3, "expected 3 agent calls per analyze");
1461 for (idx, (_, up)) in calls.iter().enumerate() {
1462 assert!(
1463 up.contains(&expected_nonce),
1464 "call {idx} user_prompt missing expected nonce"
1465 );
1466 }
1467 }
1468}