reasonkit/m2/
protocol_generator.rs

//! M2 Protocol Generator
//!
//! Generates `InterleavedProtocol` definitions from a `TaskClassification`: the task's
//! complexity level selects the execution phases, and its expected output size selects
//! the protocol-wide time and token budgets.
5
6use crate::error::Error;
7use crate::m2::types::*;
8use uuid::Uuid;
9
10#[derive(Debug, Clone, Default)]
11pub struct ProtocolGenerator;
12
13impl ProtocolGenerator {
14    pub fn new() -> Self {
15        Self
16    }
17
18    /// Generates a full execution protocol based on the classified task.
19    pub fn generate_protocol(
20        &self,
21        classification: &TaskClassification,
22        name: Option<String>,
23    ) -> Result<InterleavedProtocol, Error> {
24        let protocol_id = Uuid::new_v4().to_string();
25        let name = name.unwrap_or_else(|| format!("m2-protocol-{}", protocol_id));
26
27        // Determine phases based on complexity and task type
28        let phases = self.determine_phases(classification);
29
30        // Determine global constraints
31        let constraints = self.determine_constraints(classification);
32
33        // Determine optimizations
34        let optimizations = self.determine_optimizations(classification);
35
36        Ok(InterleavedProtocol {
37            id: protocol_id,
38            name,
39            version: "1.0.0".to_string(),
40            description: format!("Generated protocol for {:?} task", classification.task_type),
41            phases,
42            constraints,
43            m2_optimizations: optimizations,
44            framework_compatibility: vec![],
45            language_support: vec![],
46        })
47    }
48
49    fn determine_phases(&self, classification: &TaskClassification) -> Vec<InterleavedPhase> {
50        match classification.complexity_level {
51            ComplexityLevel::Simple => vec![
52                self.create_phase("reasoning", 1, 0.7),
53                self.create_phase("verification", 1, 0.8),
54            ],
55            ComplexityLevel::Moderate => vec![
56                self.create_phase("analysis", 2, 0.75),
57                self.create_phase("synthesis", 1, 0.8),
58                self.create_phase("verification", 1, 0.85),
59            ],
60            ComplexityLevel::Complex => vec![
61                self.create_phase("decomposition", 1, 0.8),
62                self.create_phase("parallel_analysis", 3, 0.8),
63                self.create_phase("integration", 1, 0.85),
64                self.create_phase("final_validation", 2, 0.9),
65            ],
66        }
67    }
68
69    fn create_phase(&self, name: &str, branches: u32, confidence: f64) -> InterleavedPhase {
70        InterleavedPhase {
71            name: name.to_string(),
72            parallel_branches: branches,
73            required_confidence: confidence,
74            validation_methods: vec![ValidationMethod::SelfCheck],
75            synthesis_methods: vec![SynthesisMethod::WeightedAverage],
76            constraints: CompositeConstraints {
77                time_budget_ms: 5000,
78                token_budget: 4000,
79                dependencies: vec![],
80            },
81        }
82    }
83
84    fn determine_constraints(&self, classification: &TaskClassification) -> CompositeConstraints {
85        let (time, tokens) = match classification.expected_output_size {
86            OutputSize::Small => (10_000, 2000),
87            OutputSize::Medium => (30_000, 8000),
88            OutputSize::Large => (60_000, 32000),
89        };
90
91        CompositeConstraints {
92            time_budget_ms: time,
93            token_budget: tokens,
94            dependencies: vec![],
95        }
96    }
97
98    fn determine_optimizations(&self, _classification: &TaskClassification) -> M2Optimizations {
99        M2Optimizations {
100            target_parameters: 200_000_000_000,
101            context_optimization: ContextOptimization {
102                method: "auto".to_string(),
103                compression_ratio: 0.8,
104            },
105            output_optimization: OutputOptimization {
106                format: "markdown".to_string(),
107                template: "standard".to_string(),
108                max_output_length: 16000,
109                streaming_enabled: true,
110                compression_enabled: false,
111            },
112            cost_optimization: CostOptimization {
113                strategy: "balanced".to_string(),
114                max_budget: 5.0,
115                target_cost_reduction: 0.5,
116                target_latency_reduction: 0.2,
117                parallel_processing_enabled: true,
118                caching_enabled: true,
119            },
120        }
121    }
122}
123
124#[cfg(test)]
125mod tests {
126    use super::*;
127
128    // ============================================================================
129    // HELPER FUNCTIONS
130    // ============================================================================
131
132    /// Creates a TaskClassification with specified parameters for testing
133    fn create_classification(
134        task_type: TaskType,
135        complexity: ComplexityLevel,
136        domain: TaskDomain,
137        output_size: OutputSize,
138    ) -> TaskClassification {
139        TaskClassification {
140            task_type,
141            complexity_level: complexity,
142            domain,
143            expected_output_size: output_size,
144            time_constraints: TimeConstraints::default(),
145            quality_requirements: QualityRequirements::default(),
146        }
147    }
148
149    /// Creates a simple classification for quick tests
150    fn simple_classification() -> TaskClassification {
151        create_classification(
152            TaskType::General,
153            ComplexityLevel::Simple,
154            TaskDomain::General,
155            OutputSize::Small,
156        )
157    }
158
159    /// Creates a moderate classification for quick tests
160    fn moderate_classification() -> TaskClassification {
161        create_classification(
162            TaskType::Documentation,
163            ComplexityLevel::Moderate,
164            TaskDomain::General,
165            OutputSize::Medium,
166        )
167    }
168
169    /// Creates a complex classification for quick tests
170    fn complex_classification() -> TaskClassification {
171        create_classification(
172            TaskType::CodeAnalysis,
173            ComplexityLevel::Complex,
174            TaskDomain::SystemProgramming,
175            OutputSize::Large,
176        )
177    }
178
179    // ============================================================================
180    // PROTOCOL STRUCTURE GENERATION TESTS
181    // ============================================================================
182
183    #[test]
184    fn test_simple_protocol_generation() {
185        let generator = ProtocolGenerator::new();
186        let classification = simple_classification();
187
188        let protocol = generator
189            .generate_protocol(&classification, Some("test-proto".into()))
190            .unwrap();
191
192        assert_eq!(protocol.name, "test-proto");
193        assert_eq!(protocol.phases.len(), 2);
194        assert_eq!(protocol.phases[0].name, "reasoning");
195        assert_eq!(protocol.phases[1].name, "verification");
196    }
197
198    #[test]
199    fn test_complex_protocol_generation() {
200        let generator = ProtocolGenerator::new();
201        let classification = complex_classification();
202
203        let protocol = generator.generate_protocol(&classification, None).unwrap();
204
205        assert_eq!(protocol.phases.len(), 4);
206        assert_eq!(protocol.phases[1].parallel_branches, 3); // parallel_analysis
207        assert!(protocol.constraints.token_budget >= 30000);
208    }
209
210    #[test]
211    fn test_moderate_protocol_generation() {
212        let generator = ProtocolGenerator::new();
213        let classification = moderate_classification();
214
215        let protocol = generator.generate_protocol(&classification, None).unwrap();
216
217        assert_eq!(protocol.phases.len(), 3);
218        assert_eq!(protocol.phases[0].name, "analysis");
219        assert_eq!(protocol.phases[1].name, "synthesis");
220        assert_eq!(protocol.phases[2].name, "verification");
221    }
222
223    #[test]
224    fn test_protocol_has_unique_id() {
225        let generator = ProtocolGenerator::new();
226        let classification = simple_classification();
227
228        let protocol1 = generator.generate_protocol(&classification, None).unwrap();
229        let protocol2 = generator.generate_protocol(&classification, None).unwrap();
230
231        assert_ne!(protocol1.id, protocol2.id);
232        assert!(!protocol1.id.is_empty());
233        assert!(!protocol2.id.is_empty());
234    }
235
236    #[test]
237    fn test_protocol_id_is_valid_uuid() {
238        let generator = ProtocolGenerator::new();
239        let classification = simple_classification();
240
241        let protocol = generator.generate_protocol(&classification, None).unwrap();
242
243        // Attempt to parse the ID as a UUID - should succeed
244        let parsed = Uuid::parse_str(&protocol.id);
245        assert!(parsed.is_ok(), "Protocol ID should be a valid UUID");
246    }
247
248    #[test]
249    fn test_protocol_version_is_set() {
250        let generator = ProtocolGenerator::new();
251        let classification = simple_classification();
252
253        let protocol = generator.generate_protocol(&classification, None).unwrap();
254
255        assert_eq!(protocol.version, "1.0.0");
256    }
257
258    #[test]
259    fn test_protocol_description_includes_task_type() {
260        let generator = ProtocolGenerator::new();
261
262        // Test with General task type
263        let general_classification = simple_classification();
264        let protocol = generator
265            .generate_protocol(&general_classification, None)
266            .unwrap();
267        assert!(protocol.description.contains("General"));
268
269        // Test with CodeAnalysis task type
270        let code_classification = create_classification(
271            TaskType::CodeAnalysis,
272            ComplexityLevel::Simple,
273            TaskDomain::General,
274            OutputSize::Small,
275        );
276        let protocol = generator
277            .generate_protocol(&code_classification, None)
278            .unwrap();
279        assert!(protocol.description.contains("CodeAnalysis"));
280    }
281
282    #[test]
283    fn test_auto_generated_name_contains_protocol_id() {
284        let generator = ProtocolGenerator::new();
285        let classification = simple_classification();
286
287        let protocol = generator.generate_protocol(&classification, None).unwrap();
288
289        assert!(protocol.name.starts_with("m2-protocol-"));
290        assert!(protocol.name.contains(&protocol.id));
291    }
292
293    #[test]
294    fn test_custom_name_is_preserved() {
295        let generator = ProtocolGenerator::new();
296        let classification = simple_classification();
297
298        let custom_name = "my-custom-protocol-name";
299        let protocol = generator
300            .generate_protocol(&classification, Some(custom_name.to_string()))
301            .unwrap();
302
303        assert_eq!(protocol.name, custom_name);
304    }
305
306    // ============================================================================
307    // PHASE STRUCTURE TESTS
308    // ============================================================================
309
310    #[test]
311    fn test_simple_phases_have_correct_confidence() {
312        let generator = ProtocolGenerator::new();
313        let classification = simple_classification();
314
315        let protocol = generator.generate_protocol(&classification, None).unwrap();
316
317        assert_eq!(protocol.phases[0].required_confidence, 0.7);
318        assert_eq!(protocol.phases[1].required_confidence, 0.8);
319    }
320
321    #[test]
322    fn test_moderate_phases_have_correct_confidence() {
323        let generator = ProtocolGenerator::new();
324        let classification = moderate_classification();
325
326        let protocol = generator.generate_protocol(&classification, None).unwrap();
327
328        assert_eq!(protocol.phases[0].required_confidence, 0.75); // analysis
329        assert_eq!(protocol.phases[1].required_confidence, 0.8); // synthesis
330        assert_eq!(protocol.phases[2].required_confidence, 0.85); // verification
331    }
332
333    #[test]
334    fn test_complex_phases_have_correct_confidence() {
335        let generator = ProtocolGenerator::new();
336        let classification = complex_classification();
337
338        let protocol = generator.generate_protocol(&classification, None).unwrap();
339
340        assert_eq!(protocol.phases[0].required_confidence, 0.8); // decomposition
341        assert_eq!(protocol.phases[1].required_confidence, 0.8); // parallel_analysis
342        assert_eq!(protocol.phases[2].required_confidence, 0.85); // integration
343        assert_eq!(protocol.phases[3].required_confidence, 0.9); // final_validation
344    }
345
346    #[test]
347    fn test_simple_phases_have_single_branch() {
348        let generator = ProtocolGenerator::new();
349        let classification = simple_classification();
350
351        let protocol = generator.generate_protocol(&classification, None).unwrap();
352
353        for phase in &protocol.phases {
354            assert_eq!(phase.parallel_branches, 1);
355        }
356    }
357
358    #[test]
359    fn test_moderate_analysis_has_two_branches() {
360        let generator = ProtocolGenerator::new();
361        let classification = moderate_classification();
362
363        let protocol = generator.generate_protocol(&classification, None).unwrap();
364
365        assert_eq!(protocol.phases[0].parallel_branches, 2); // analysis
366        assert_eq!(protocol.phases[1].parallel_branches, 1); // synthesis
367        assert_eq!(protocol.phases[2].parallel_branches, 1); // verification
368    }
369
370    #[test]
371    fn test_complex_parallel_analysis_has_three_branches() {
372        let generator = ProtocolGenerator::new();
373        let classification = complex_classification();
374
375        let protocol = generator.generate_protocol(&classification, None).unwrap();
376
377        assert_eq!(protocol.phases[0].parallel_branches, 1); // decomposition
378        assert_eq!(protocol.phases[1].parallel_branches, 3); // parallel_analysis
379        assert_eq!(protocol.phases[2].parallel_branches, 1); // integration
380        assert_eq!(protocol.phases[3].parallel_branches, 2); // final_validation
381    }
382
383    #[test]
384    fn test_all_phases_have_validation_methods() {
385        let generator = ProtocolGenerator::new();
386        let classification = complex_classification();
387
388        let protocol = generator.generate_protocol(&classification, None).unwrap();
389
390        for phase in &protocol.phases {
391            assert!(!phase.validation_methods.is_empty());
392            assert!(phase
393                .validation_methods
394                .contains(&ValidationMethod::SelfCheck));
395        }
396    }
397
398    #[test]
399    fn test_all_phases_have_synthesis_methods() {
400        let generator = ProtocolGenerator::new();
401        let classification = complex_classification();
402
403        let protocol = generator.generate_protocol(&classification, None).unwrap();
404
405        for phase in &protocol.phases {
406            assert!(!phase.synthesis_methods.is_empty());
407            assert!(phase
408                .synthesis_methods
409                .contains(&SynthesisMethod::WeightedAverage));
410        }
411    }
412
413    #[test]
414    fn test_phase_constraints_have_default_values() {
415        let generator = ProtocolGenerator::new();
416        let classification = simple_classification();
417
418        let protocol = generator.generate_protocol(&classification, None).unwrap();
419
420        for phase in &protocol.phases {
421            assert_eq!(phase.constraints.time_budget_ms, 5000);
422            assert_eq!(phase.constraints.token_budget, 4000);
423            assert!(phase.constraints.dependencies.is_empty());
424        }
425    }
426
427    // ============================================================================
428    // CONSTRAINT TESTS BY OUTPUT SIZE
429    // ============================================================================
430
431    #[test]
432    fn test_small_output_constraints() {
433        let generator = ProtocolGenerator::new();
434        let classification = create_classification(
435            TaskType::General,
436            ComplexityLevel::Simple,
437            TaskDomain::General,
438            OutputSize::Small,
439        );
440
441        let protocol = generator.generate_protocol(&classification, None).unwrap();
442
443        assert_eq!(protocol.constraints.time_budget_ms, 10_000);
444        assert_eq!(protocol.constraints.token_budget, 2000);
445    }
446
447    #[test]
448    fn test_medium_output_constraints() {
449        let generator = ProtocolGenerator::new();
450        let classification = create_classification(
451            TaskType::General,
452            ComplexityLevel::Simple,
453            TaskDomain::General,
454            OutputSize::Medium,
455        );
456
457        let protocol = generator.generate_protocol(&classification, None).unwrap();
458
459        assert_eq!(protocol.constraints.time_budget_ms, 30_000);
460        assert_eq!(protocol.constraints.token_budget, 8000);
461    }
462
463    #[test]
464    fn test_large_output_constraints() {
465        let generator = ProtocolGenerator::new();
466        let classification = create_classification(
467            TaskType::General,
468            ComplexityLevel::Simple,
469            TaskDomain::General,
470            OutputSize::Large,
471        );
472
473        let protocol = generator.generate_protocol(&classification, None).unwrap();
474
475        assert_eq!(protocol.constraints.time_budget_ms, 60_000);
476        assert_eq!(protocol.constraints.token_budget, 32000);
477    }
478
479    #[test]
480    fn test_constraints_have_empty_dependencies() {
481        let generator = ProtocolGenerator::new();
482        let classification = complex_classification();
483
484        let protocol = generator.generate_protocol(&classification, None).unwrap();
485
486        assert!(protocol.constraints.dependencies.is_empty());
487    }
488
489    // ============================================================================
490    // OPTIMIZATION TESTS
491    // ============================================================================
492
493    #[test]
494    fn test_optimizations_target_parameters() {
495        let generator = ProtocolGenerator::new();
496        let classification = simple_classification();
497
498        let protocol = generator.generate_protocol(&classification, None).unwrap();
499
500        assert_eq!(protocol.m2_optimizations.target_parameters, 200_000_000_000);
501    }
502
503    #[test]
504    fn test_context_optimization_defaults() {
505        let generator = ProtocolGenerator::new();
506        let classification = simple_classification();
507
508        let protocol = generator.generate_protocol(&classification, None).unwrap();
509
510        assert_eq!(
511            protocol.m2_optimizations.context_optimization.method,
512            "auto"
513        );
514        assert_eq!(
515            protocol
516                .m2_optimizations
517                .context_optimization
518                .compression_ratio,
519            0.8
520        );
521    }
522
523    #[test]
524    fn test_output_optimization_defaults() {
525        let generator = ProtocolGenerator::new();
526        let classification = simple_classification();
527
528        let protocol = generator.generate_protocol(&classification, None).unwrap();
529
530        let output_opt = &protocol.m2_optimizations.output_optimization;
531        assert_eq!(output_opt.format, "markdown");
532        assert_eq!(output_opt.template, "standard");
533        assert_eq!(output_opt.max_output_length, 16000);
534        assert!(output_opt.streaming_enabled);
535        assert!(!output_opt.compression_enabled);
536    }
537
538    #[test]
539    fn test_cost_optimization_defaults() {
540        let generator = ProtocolGenerator::new();
541        let classification = simple_classification();
542
543        let protocol = generator.generate_protocol(&classification, None).unwrap();
544
545        let cost_opt = &protocol.m2_optimizations.cost_optimization;
546        assert_eq!(cost_opt.strategy, "balanced");
547        assert_eq!(cost_opt.max_budget, 5.0);
548        assert_eq!(cost_opt.target_cost_reduction, 0.5);
549        assert_eq!(cost_opt.target_latency_reduction, 0.2);
550        assert!(cost_opt.parallel_processing_enabled);
551        assert!(cost_opt.caching_enabled);
552    }
553
554    // ============================================================================
555    // SERIALIZATION / DESERIALIZATION TESTS
556    // ============================================================================
557
558    #[test]
559    fn test_protocol_serialization_roundtrip() {
560        let generator = ProtocolGenerator::new();
561        let classification = complex_classification();
562
563        let protocol = generator.generate_protocol(&classification, None).unwrap();
564
565        // Serialize to JSON
566        let json = serde_json::to_string(&protocol).expect("Serialization should succeed");
567
568        // Deserialize back
569        let deserialized: InterleavedProtocol =
570            serde_json::from_str(&json).expect("Deserialization should succeed");
571
572        // Verify key fields match
573        assert_eq!(protocol.id, deserialized.id);
574        assert_eq!(protocol.name, deserialized.name);
575        assert_eq!(protocol.version, deserialized.version);
576        assert_eq!(protocol.description, deserialized.description);
577        assert_eq!(protocol.phases.len(), deserialized.phases.len());
578    }
579
580    #[test]
581    fn test_protocol_pretty_json_serialization() {
582        let generator = ProtocolGenerator::new();
583        let classification = simple_classification();
584
585        let protocol = generator.generate_protocol(&classification, None).unwrap();
586
587        // Pretty print should succeed
588        let pretty_json =
589            serde_json::to_string_pretty(&protocol).expect("Pretty serialization should succeed");
590
591        // Should contain expected structure
592        assert!(pretty_json.contains("\"phases\""));
593        assert!(pretty_json.contains("\"constraints\""));
594        assert!(pretty_json.contains("\"m2_optimizations\""));
595    }
596
597    #[test]
598    fn test_phase_serialization_roundtrip() {
599        let generator = ProtocolGenerator::new();
600        let classification = simple_classification();
601
602        let protocol = generator.generate_protocol(&classification, None).unwrap();
603        let original_phase = &protocol.phases[0];
604
605        let json = serde_json::to_string(original_phase).expect("Phase serialization should work");
606        let deserialized: InterleavedPhase =
607            serde_json::from_str(&json).expect("Phase deserialization should work");
608
609        assert_eq!(original_phase.name, deserialized.name);
610        assert_eq!(
611            original_phase.parallel_branches,
612            deserialized.parallel_branches
613        );
614        assert_eq!(
615            original_phase.required_confidence,
616            deserialized.required_confidence
617        );
618    }
619
620    #[test]
621    fn test_constraints_serialization_roundtrip() {
622        let generator = ProtocolGenerator::new();
623        let classification = complex_classification();
624
625        let protocol = generator.generate_protocol(&classification, None).unwrap();
626        let original = &protocol.constraints;
627
628        let json = serde_json::to_string(original).expect("Constraints serialization should work");
629        let deserialized: CompositeConstraints =
630            serde_json::from_str(&json).expect("Constraints deserialization should work");
631
632        assert_eq!(original.time_budget_ms, deserialized.time_budget_ms);
633        assert_eq!(original.token_budget, deserialized.token_budget);
634        assert_eq!(original.dependencies.len(), deserialized.dependencies.len());
635    }
636
637    #[test]
638    fn test_optimizations_serialization_roundtrip() {
639        let generator = ProtocolGenerator::new();
640        let classification = simple_classification();
641
642        let protocol = generator.generate_protocol(&classification, None).unwrap();
643        let original = &protocol.m2_optimizations;
644
645        let json =
646            serde_json::to_string(original).expect("Optimizations serialization should work");
647        let deserialized: M2Optimizations =
648            serde_json::from_str(&json).expect("Optimizations deserialization should work");
649
650        assert_eq!(original.target_parameters, deserialized.target_parameters);
651        assert_eq!(
652            original.context_optimization.method,
653            deserialized.context_optimization.method
654        );
655        assert_eq!(
656            original.output_optimization.format,
657            deserialized.output_optimization.format
658        );
659        assert_eq!(
660            original.cost_optimization.strategy,
661            deserialized.cost_optimization.strategy
662        );
663    }
664
665    // ============================================================================
666    // VARIOUS PROTOCOL TYPES / TASK TYPES TESTS
667    // ============================================================================
668
669    #[test]
670    fn test_code_analysis_protocol() {
671        let generator = ProtocolGenerator::new();
672        let classification = create_classification(
673            TaskType::CodeAnalysis,
674            ComplexityLevel::Complex,
675            TaskDomain::SystemProgramming,
676            OutputSize::Large,
677        );
678
679        let protocol = generator
680            .generate_protocol(&classification, Some("code-analysis".into()))
681            .unwrap();
682
683        assert!(protocol.description.contains("CodeAnalysis"));
684        assert_eq!(protocol.phases.len(), 4); // Complex has 4 phases
685    }
686
687    #[test]
688    fn test_bug_finding_protocol() {
689        let generator = ProtocolGenerator::new();
690        let classification = create_classification(
691            TaskType::BugFinding,
692            ComplexityLevel::Moderate,
693            TaskDomain::SystemProgramming,
694            OutputSize::Medium,
695        );
696
697        let protocol = generator
698            .generate_protocol(&classification, Some("bug-finding".into()))
699            .unwrap();
700
701        assert!(protocol.description.contains("BugFinding"));
702        assert_eq!(protocol.phases.len(), 3); // Moderate has 3 phases
703    }
704
705    #[test]
706    fn test_documentation_protocol() {
707        let generator = ProtocolGenerator::new();
708        let classification = create_classification(
709            TaskType::Documentation,
710            ComplexityLevel::Simple,
711            TaskDomain::General,
712            OutputSize::Medium,
713        );
714
715        let protocol = generator
716            .generate_protocol(&classification, Some("docs".into()))
717            .unwrap();
718
719        assert!(protocol.description.contains("Documentation"));
720        assert_eq!(protocol.phases.len(), 2); // Simple has 2 phases
721    }
722
723    #[test]
724    fn test_architecture_protocol() {
725        let generator = ProtocolGenerator::new();
726        let classification = create_classification(
727            TaskType::Architecture,
728            ComplexityLevel::Complex,
729            TaskDomain::General,
730            OutputSize::Large,
731        );
732
733        let protocol = generator
734            .generate_protocol(&classification, Some("architecture".into()))
735            .unwrap();
736
737        assert!(protocol.description.contains("Architecture"));
738        assert_eq!(protocol.phases.len(), 4); // Complex has 4 phases
739    }
740
741    #[test]
742    fn test_general_protocol() {
743        let generator = ProtocolGenerator::new();
744        let classification = create_classification(
745            TaskType::General,
746            ComplexityLevel::Moderate,
747            TaskDomain::General,
748            OutputSize::Small,
749        );
750
751        let protocol = generator.generate_protocol(&classification, None).unwrap();
752
753        assert!(protocol.description.contains("General"));
754    }
755
756    // ============================================================================
757    // DOMAIN VARIATION TESTS
758    // ============================================================================
759
760    #[test]
761    fn test_system_programming_domain() {
762        let generator = ProtocolGenerator::new();
763        let classification = create_classification(
764            TaskType::CodeAnalysis,
765            ComplexityLevel::Complex,
766            TaskDomain::SystemProgramming,
767            OutputSize::Large,
768        );
769
770        let protocol = generator.generate_protocol(&classification, None).unwrap();
771
772        // Protocol should be generated regardless of domain
773        assert!(!protocol.id.is_empty());
774        assert_eq!(protocol.phases.len(), 4);
775    }
776
777    #[test]
778    fn test_web_domain() {
779        let generator = ProtocolGenerator::new();
780        let classification = create_classification(
781            TaskType::CodeAnalysis,
782            ComplexityLevel::Moderate,
783            TaskDomain::Web,
784            OutputSize::Medium,
785        );
786
787        let protocol = generator.generate_protocol(&classification, None).unwrap();
788
789        assert!(!protocol.id.is_empty());
790        assert_eq!(protocol.phases.len(), 3);
791    }
792
793    #[test]
794    fn test_data_domain() {
795        let generator = ProtocolGenerator::new();
796        let classification = create_classification(
797            TaskType::General,
798            ComplexityLevel::Simple,
799            TaskDomain::Data,
800            OutputSize::Small,
801        );
802
803        let protocol = generator.generate_protocol(&classification, None).unwrap();
804
805        assert!(!protocol.id.is_empty());
806        assert_eq!(protocol.phases.len(), 2);
807    }
808
809    // ============================================================================
810    // EDGE CASE TESTS
811    // ============================================================================
812
813    #[test]
814    fn test_empty_name_option() {
815        let generator = ProtocolGenerator::new();
816        let classification = simple_classification();
817
818        // None should generate auto name
819        let protocol = generator.generate_protocol(&classification, None).unwrap();
820        assert!(protocol.name.starts_with("m2-protocol-"));
821
822        // Empty string should be preserved (if explicitly provided)
823        let protocol = generator
824            .generate_protocol(&classification, Some(String::new()))
825            .unwrap();
826        assert!(protocol.name.is_empty());
827    }
828
829    #[test]
830    fn test_special_characters_in_name() {
831        let generator = ProtocolGenerator::new();
832        let classification = simple_classification();
833
834        let special_name = "protocol-with-special_chars.v1.2.3!@#$%";
835        let protocol = generator
836            .generate_protocol(&classification, Some(special_name.to_string()))
837            .unwrap();
838
839        assert_eq!(protocol.name, special_name);
840    }
841
842    #[test]
843    fn test_unicode_in_name() {
844        let generator = ProtocolGenerator::new();
845        let classification = simple_classification();
846
847        let unicode_name = "protocol-unicode";
848        let protocol = generator
849            .generate_protocol(&classification, Some(unicode_name.to_string()))
850            .unwrap();
851
852        assert_eq!(protocol.name, unicode_name);
853    }
854
855    #[test]
856    fn test_generator_is_stateless() {
857        let generator = ProtocolGenerator::new();
858        let classification1 = simple_classification();
859        let classification2 = complex_classification();
860
861        // Generate multiple protocols
862        let p1 = generator.generate_protocol(&classification1, None).unwrap();
863        let p2 = generator.generate_protocol(&classification2, None).unwrap();
864        let p3 = generator.generate_protocol(&classification1, None).unwrap();
865
866        // All should be independent
867        assert_ne!(p1.id, p2.id);
868        assert_ne!(p1.id, p3.id);
869        assert_ne!(p2.id, p3.id);
870
871        // Same classification should produce same structure
872        assert_eq!(p1.phases.len(), p3.phases.len());
873        assert_ne!(p1.phases.len(), p2.phases.len());
874    }
875
876    #[test]
877    fn test_generator_clone() {
878        let generator1 = ProtocolGenerator::new();
879        let generator2 = generator1.clone();
880
881        let classification = simple_classification();
882
883        let p1 = generator1.generate_protocol(&classification, None).unwrap();
884        let p2 = generator2.generate_protocol(&classification, None).unwrap();
885
886        // Both should produce valid protocols
887        assert!(!p1.id.is_empty());
888        assert!(!p2.id.is_empty());
889        assert_eq!(p1.phases.len(), p2.phases.len());
890    }
891
892    #[test]
893    fn test_generator_default() {
894        let generator = ProtocolGenerator;
895        let classification = simple_classification();
896
897        let protocol = generator.generate_protocol(&classification, None).unwrap();
898
899        assert!(!protocol.id.is_empty());
900        assert!(!protocol.name.is_empty());
901    }
902
903    // ============================================================================
904    // VALIDATION METHOD AND SYNTHESIS METHOD TESTS
905    // ============================================================================
906
907    #[test]
908    fn test_validation_methods_are_consistent() {
909        let generator = ProtocolGenerator::new();
910
911        // Test across all complexity levels
912        for complexity in [
913            ComplexityLevel::Simple,
914            ComplexityLevel::Moderate,
915            ComplexityLevel::Complex,
916        ] {
917            let classification = create_classification(
918                TaskType::General,
919                complexity,
920                TaskDomain::General,
921                OutputSize::Medium,
922            );
923
924            let protocol = generator.generate_protocol(&classification, None).unwrap();
925
926            for phase in &protocol.phases {
927                assert_eq!(phase.validation_methods.len(), 1);
928                assert_eq!(phase.validation_methods[0], ValidationMethod::SelfCheck);
929            }
930        }
931    }
932
933    #[test]
934    fn test_synthesis_methods_are_consistent() {
935        let generator = ProtocolGenerator::new();
936
937        // Test across all complexity levels
938        for complexity in [
939            ComplexityLevel::Simple,
940            ComplexityLevel::Moderate,
941            ComplexityLevel::Complex,
942        ] {
943            let classification = create_classification(
944                TaskType::General,
945                complexity,
946                TaskDomain::General,
947                OutputSize::Medium,
948            );
949
950            let protocol = generator.generate_protocol(&classification, None).unwrap();
951
952            for phase in &protocol.phases {
953                assert_eq!(phase.synthesis_methods.len(), 1);
954                assert_eq!(phase.synthesis_methods[0], SynthesisMethod::WeightedAverage);
955            }
956        }
957    }
958
959    // ============================================================================
960    // PROTOCOL SCHEMA VALIDATION TESTS
961    // ============================================================================
962
963    #[test]
964    fn test_protocol_has_all_required_fields() {
965        let generator = ProtocolGenerator::new();
966        let classification = complex_classification();
967
968        let protocol = generator.generate_protocol(&classification, None).unwrap();
969
970        // Verify all required fields are present and not empty/default
971        assert!(!protocol.id.is_empty());
972        assert!(!protocol.name.is_empty());
973        assert!(!protocol.version.is_empty());
974        assert!(!protocol.description.is_empty());
975        assert!(!protocol.phases.is_empty());
976    }
977
978    #[test]
979    fn test_phases_have_all_required_fields() {
980        let generator = ProtocolGenerator::new();
981        let classification = complex_classification();
982
983        let protocol = generator.generate_protocol(&classification, None).unwrap();
984
985        for phase in &protocol.phases {
986            assert!(!phase.name.is_empty());
987            assert!(phase.parallel_branches >= 1);
988            assert!(phase.required_confidence > 0.0 && phase.required_confidence <= 1.0);
989            assert!(!phase.validation_methods.is_empty());
990            assert!(!phase.synthesis_methods.is_empty());
991        }
992    }
993
994    #[test]
995    fn test_confidence_values_are_valid() {
996        let generator = ProtocolGenerator::new();
997
998        // Test all complexity levels
999        for complexity in [
1000            ComplexityLevel::Simple,
1001            ComplexityLevel::Moderate,
1002            ComplexityLevel::Complex,
1003        ] {
1004            let classification = create_classification(
1005                TaskType::General,
1006                complexity,
1007                TaskDomain::General,
1008                OutputSize::Medium,
1009            );
1010
1011            let protocol = generator.generate_protocol(&classification, None).unwrap();
1012
1013            for phase in &protocol.phases {
1014                assert!(
1015                    phase.required_confidence >= 0.0,
1016                    "Confidence should be non-negative"
1017                );
1018                assert!(
1019                    phase.required_confidence <= 1.0,
1020                    "Confidence should not exceed 1.0"
1021                );
1022            }
1023        }
1024    }
1025
1026    #[test]
1027    fn test_token_budgets_are_positive() {
1028        let generator = ProtocolGenerator::new();
1029
1030        for output_size in [OutputSize::Small, OutputSize::Medium, OutputSize::Large] {
1031            let classification = create_classification(
1032                TaskType::General,
1033                ComplexityLevel::Simple,
1034                TaskDomain::General,
1035                output_size,
1036            );
1037
1038            let protocol = generator.generate_protocol(&classification, None).unwrap();
1039
1040            assert!(
1041                protocol.constraints.token_budget > 0,
1042                "Token budget should be positive"
1043            );
1044        }
1045    }
1046
1047    #[test]
1048    fn test_time_budgets_are_positive() {
1049        let generator = ProtocolGenerator::new();
1050
1051        for output_size in [OutputSize::Small, OutputSize::Medium, OutputSize::Large] {
1052            let classification = create_classification(
1053                TaskType::General,
1054                ComplexityLevel::Simple,
1055                TaskDomain::General,
1056                output_size,
1057            );
1058
1059            let protocol = generator.generate_protocol(&classification, None).unwrap();
1060
1061            assert!(
1062                protocol.constraints.time_budget_ms > 0,
1063                "Time budget should be positive"
1064            );
1065        }
1066    }
1067
1068    // ============================================================================
1069    // FRAMEWORK AND LANGUAGE SUPPORT TESTS
1070    // ============================================================================
1071
1072    #[test]
1073    fn test_framework_compatibility_is_empty_by_default() {
1074        let generator = ProtocolGenerator::new();
1075        let classification = simple_classification();
1076
1077        let protocol = generator.generate_protocol(&classification, None).unwrap();
1078
1079        assert!(protocol.framework_compatibility.is_empty());
1080    }
1081
1082    #[test]
1083    fn test_language_support_is_empty_by_default() {
1084        let generator = ProtocolGenerator::new();
1085        let classification = simple_classification();
1086
1087        let protocol = generator.generate_protocol(&classification, None).unwrap();
1088
1089        assert!(protocol.language_support.is_empty());
1090    }
1091
1092    // ============================================================================
1093    // TASK CLASSIFICATION FROM USE CASE TESTS
1094    // ============================================================================
1095
1096    #[test]
1097    fn test_use_case_code_analysis_classification() {
1098        let classification = TaskClassification::from(UseCase::CodeAnalysis);
1099
1100        assert_eq!(classification.task_type, TaskType::CodeAnalysis);
1101        assert_eq!(classification.complexity_level, ComplexityLevel::Complex);
1102        assert_eq!(classification.domain, TaskDomain::SystemProgramming);
1103        assert_eq!(classification.expected_output_size, OutputSize::Large);
1104    }
1105
1106    #[test]
1107    fn test_use_case_bug_finding_classification() {
1108        let classification = TaskClassification::from(UseCase::BugFinding);
1109
1110        assert_eq!(classification.task_type, TaskType::BugFinding);
1111        assert_eq!(classification.complexity_level, ComplexityLevel::Moderate);
1112        assert_eq!(classification.domain, TaskDomain::SystemProgramming);
1113        assert_eq!(classification.expected_output_size, OutputSize::Medium);
1114    }
1115
1116    #[test]
1117    fn test_use_case_documentation_classification() {
1118        let classification = TaskClassification::from(UseCase::Documentation);
1119
1120        assert_eq!(classification.task_type, TaskType::Documentation);
1121        assert_eq!(classification.complexity_level, ComplexityLevel::Moderate);
1122        assert_eq!(classification.domain, TaskDomain::General);
1123        assert_eq!(classification.expected_output_size, OutputSize::Medium);
1124    }
1125
1126    #[test]
1127    fn test_use_case_architecture_classification() {
1128        let classification = TaskClassification::from(UseCase::Architecture);
1129
1130        assert_eq!(classification.task_type, TaskType::Architecture);
1131        assert_eq!(classification.complexity_level, ComplexityLevel::Complex);
1132        assert_eq!(classification.domain, TaskDomain::General);
1133        assert_eq!(classification.expected_output_size, OutputSize::Large);
1134    }
1135
1136    #[test]
1137    fn test_use_case_general_classification() {
1138        let classification = TaskClassification::from(UseCase::General);
1139
1140        assert_eq!(classification.task_type, TaskType::General);
1141        assert_eq!(classification.complexity_level, ComplexityLevel::Moderate);
1142        assert_eq!(classification.domain, TaskDomain::General);
1143        assert_eq!(classification.expected_output_size, OutputSize::Medium);
1144    }
1145
1146    // ============================================================================
1147    // INTEGRATION TESTS - COMPLETE WORKFLOW
1148    // ============================================================================
1149
1150    #[test]
1151    fn test_complete_protocol_generation_workflow() {
1152        let generator = ProtocolGenerator::new();
1153
1154        // Generate from UseCase
1155        let use_case = UseCase::CodeAnalysis;
1156        let classification = TaskClassification::from(use_case);
1157
1158        let protocol = generator
1159            .generate_protocol(&classification, Some("integration-test".into()))
1160            .unwrap();
1161
1162        // Verify complete protocol structure
1163        assert_eq!(protocol.name, "integration-test");
1164        assert!(!protocol.id.is_empty());
1165        assert_eq!(protocol.version, "1.0.0");
1166
1167        // Complex task should have 4 phases
1168        assert_eq!(protocol.phases.len(), 4);
1169
1170        // Large output should have high token budget
1171        assert_eq!(protocol.constraints.token_budget, 32000);
1172
1173        // Serialize and verify it's valid JSON
1174        let json = serde_json::to_string(&protocol).unwrap();
1175        assert!(!json.is_empty());
1176
1177        // Deserialize and verify roundtrip
1178        let deserialized: InterleavedProtocol = serde_json::from_str(&json).unwrap();
1179        assert_eq!(protocol.id, deserialized.id);
1180    }
1181
1182    #[test]
1183    fn test_multiple_protocols_for_all_use_cases() {
1184        let generator = ProtocolGenerator::new();
1185
1186        let use_cases = vec![
1187            UseCase::CodeAnalysis,
1188            UseCase::BugFinding,
1189            UseCase::Documentation,
1190            UseCase::Architecture,
1191            UseCase::General,
1192        ];
1193
1194        for use_case in use_cases {
1195            let classification = TaskClassification::from(use_case);
1196            let protocol = generator.generate_protocol(&classification, None);
1197
1198            assert!(
1199                protocol.is_ok(),
1200                "Protocol generation should succeed for {:?}",
1201                use_case
1202            );
1203
1204            let protocol = protocol.unwrap();
1205            assert!(!protocol.id.is_empty());
1206            assert!(!protocol.phases.is_empty());
1207        }
1208    }
1209}