Skip to main content

cortexai_agents/
self_correct.rs

1//! # Self-Correcting Workflows
2//!
3//! LLM-as-Judge pattern for automatic quality assessment and correction.
4//!
5//! Inspired by Swarm's self-correcting workflows and AutoGen patterns.
6//!
7//! ## Features
8//!
9//! - **Quality Assessment**: Judge output quality using LLM or rules
10//! - **Auto-Correction**: Automatically retry with feedback on failure
11//! - **Validation Rules**: Custom validation criteria
12//! - **Scoring**: Multi-dimensional quality scoring
13//!
14//! ## Example
15//!
//! ```rust,ignore
//! use cortex::self_correct::{KeywordJudge, LengthJudge, SelfCorrectingWorkflow};
//!
//! let workflow = SelfCorrectingWorkflow::new()
//!     .add_judge(LengthJudge::new().min(20))
//!     .add_judge(KeywordJudge::new().require("rust"))
//!     .max_iterations(3)
//!     .quality_threshold(0.8);
//!
//! // The generator receives the prompt; on retries the prompt also carries
//! // the feedback produced by the judges for the previous attempt.
//! let result = workflow
//!     .execute("Explain ownership in Rust", |prompt| async move {
//!         generate_response(&prompt).await
//!     })
//!     .await?;
//! ```
30
31use std::collections::HashMap;
32use std::future::Future;
33use std::pin::Pin;
34use std::sync::Arc;
35use std::time::Instant;
36
37use async_trait::async_trait;
38use parking_lot::RwLock;
39use serde::{Deserialize, Serialize};
40use tracing::{debug, warn};
41
42/// Quality score for a dimension
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct DimensionScore {
45    /// Dimension name (e.g., "accuracy", "completeness", "clarity")
46    pub dimension: String,
47    /// Score from 0.0 to 1.0
48    pub score: f32,
49    /// Explanation for the score
50    pub explanation: String,
51    /// Specific feedback for improvement
52    pub feedback: Option<String>,
53}
54
55impl DimensionScore {
56    pub fn new(dimension: impl Into<String>, score: f32) -> Self {
57        Self {
58            dimension: dimension.into(),
59            score: score.clamp(0.0, 1.0),
60            explanation: String::new(),
61            feedback: None,
62        }
63    }
64
65    pub fn with_explanation(mut self, explanation: impl Into<String>) -> Self {
66        self.explanation = explanation.into();
67        self
68    }
69
70    pub fn with_feedback(mut self, feedback: impl Into<String>) -> Self {
71        self.feedback = Some(feedback.into());
72        self
73    }
74
75    pub fn passed(&self, threshold: f32) -> bool {
76        self.score >= threshold
77    }
78}
79
80/// Overall judgment result
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct Judgment {
83    /// Whether the output passed quality checks
84    pub passed: bool,
85    /// Overall quality score (0.0 to 1.0)
86    pub overall_score: f32,
87    /// Individual dimension scores
88    pub dimension_scores: Vec<DimensionScore>,
89    /// Summary feedback
90    pub summary: String,
91    /// Specific corrections needed
92    pub corrections: Vec<String>,
93    /// Metadata from the judge
94    pub metadata: HashMap<String, String>,
95}
96
97impl Judgment {
98    pub fn passed(score: f32) -> Self {
99        Self {
100            passed: true,
101            overall_score: score.clamp(0.0, 1.0),
102            dimension_scores: Vec::new(),
103            summary: "Quality check passed".to_string(),
104            corrections: Vec::new(),
105            metadata: HashMap::new(),
106        }
107    }
108
109    pub fn failed(score: f32, summary: impl Into<String>) -> Self {
110        Self {
111            passed: false,
112            overall_score: score.clamp(0.0, 1.0),
113            dimension_scores: Vec::new(),
114            summary: summary.into(),
115            corrections: Vec::new(),
116            metadata: HashMap::new(),
117        }
118    }
119
120    pub fn with_dimension(mut self, dimension: DimensionScore) -> Self {
121        self.dimension_scores.push(dimension);
122        self
123    }
124
125    pub fn with_correction(mut self, correction: impl Into<String>) -> Self {
126        self.corrections.push(correction.into());
127        self
128    }
129
130    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
131        self.metadata.insert(key.into(), value.into());
132        self
133    }
134
135    /// Get feedback string for retry prompt
136    pub fn feedback_for_retry(&self) -> String {
137        let mut feedback = vec![format!(
138            "Previous attempt scored {:.1}%",
139            self.overall_score * 100.0
140        )];
141        feedback.push(format!("Feedback: {}", self.summary));
142
143        if !self.corrections.is_empty() {
144            feedback.push("Corrections needed:".to_string());
145            for (i, correction) in self.corrections.iter().enumerate() {
146                feedback.push(format!("  {}. {}", i + 1, correction));
147            }
148        }
149
150        for dim in &self.dimension_scores {
151            if let Some(ref fb) = dim.feedback {
152                feedback.push(format!("- {}: {}", dim.dimension, fb));
153            }
154        }
155
156        feedback.join("\n")
157    }
158}
159
160/// Context for judgment
161#[derive(Debug, Clone)]
162pub struct JudgmentContext {
163    /// Original input/prompt
164    pub input: String,
165    /// Generated output to judge
166    pub output: String,
167    /// Iteration number (1-based)
168    pub iteration: u32,
169    /// Previous judgments (for comparison)
170    pub previous_judgments: Vec<Judgment>,
171    /// Additional context
172    pub metadata: HashMap<String, String>,
173}
174
175impl JudgmentContext {
176    pub fn new(input: impl Into<String>, output: impl Into<String>) -> Self {
177        Self {
178            input: input.into(),
179            output: output.into(),
180            iteration: 1,
181            previous_judgments: Vec::new(),
182            metadata: HashMap::new(),
183        }
184    }
185
186    pub fn with_iteration(mut self, iteration: u32) -> Self {
187        self.iteration = iteration;
188        self
189    }
190
191    pub fn with_previous(mut self, judgments: Vec<Judgment>) -> Self {
192        self.previous_judgments = judgments;
193        self
194    }
195
196    pub fn is_improving(&self) -> bool {
197        if self.previous_judgments.len() < 2 {
198            return true;
199        }
200        let last = &self.previous_judgments[self.previous_judgments.len() - 1];
201        let prev = &self.previous_judgments[self.previous_judgments.len() - 2];
202        last.overall_score > prev.overall_score
203    }
204}
205
206/// Trait for implementing judges
207#[async_trait]
208pub trait Judge: Send + Sync {
209    /// Unique name of this judge
210    fn name(&self) -> &str;
211
212    /// Evaluate the output and return a judgment
213    async fn evaluate(&self, context: &JudgmentContext) -> Judgment;
214
215    /// Weight of this judge in overall scoring (default 1.0)
216    fn weight(&self) -> f32 {
217        1.0
218    }
219
220    /// Whether this judge is critical (must pass)
221    fn is_critical(&self) -> bool {
222        false
223    }
224}
225
226/// Boxed judge for type erasure
227pub type BoxedJudge = Arc<dyn Judge>;
228
229/// Error type for self-correcting workflows
230#[derive(Debug, thiserror::Error)]
231pub enum SelfCorrectError {
232    #[error("Max iterations ({0}) exceeded without passing quality threshold")]
233    MaxIterationsExceeded(u32),
234
235    #[error("Critical judge '{0}' failed")]
236    CriticalJudgeFailed(String),
237
238    #[error("Execution failed: {0}")]
239    ExecutionFailed(String),
240
241    #[error("No improvement after {0} iterations")]
242    NoImprovement(u32),
243}
244
/// Tunable settings of a self-correcting workflow.
#[derive(Debug, Clone)]
pub struct SelfCorrectConfig {
    /// Upper bound on the number of generate/judge rounds.
    pub max_iterations: u32,
    /// Minimum aggregate score (0.0 to 1.0) an output needs to pass.
    pub quality_threshold: f32,
    /// Stop early after this many rounds without improvement; `None` disables the check.
    pub stop_on_plateau: Option<u32>,
    /// Whether the retry prompt is extended with feedback from the previous round.
    pub include_feedback: bool,
}

impl Default for SelfCorrectConfig {
    /// Three iterations, a 0.8 threshold, plateau detection after two
    /// stagnant rounds, and feedback included in retries.
    fn default() -> Self {
        let max_iterations = 3;
        let quality_threshold = 0.8;
        Self {
            max_iterations,
            quality_threshold,
            stop_on_plateau: Some(2),
            include_feedback: true,
        }
    }
}
268
269/// Result of a self-correcting workflow
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct SelfCorrectResult {
272    /// Final output
273    pub output: String,
274    /// Whether quality threshold was met
275    pub passed: bool,
276    /// Final quality score
277    pub final_score: f32,
278    /// Number of iterations taken
279    pub iterations: u32,
280    /// History of judgments
281    pub judgment_history: Vec<Judgment>,
282    /// Total time taken
283    pub duration_ms: u64,
284}
285
286impl SelfCorrectResult {
287    pub fn improvement(&self) -> Option<f32> {
288        if self.judgment_history.len() < 2 {
289            return None;
290        }
291        let first = self.judgment_history.first()?.overall_score;
292        let last = self.judgment_history.last()?.overall_score;
293        Some(last - first)
294    }
295}
296
297/// Statistics for workflow execution
298#[derive(Debug, Clone, Default, Serialize, Deserialize)]
299pub struct SelfCorrectStats {
300    pub total_executions: u64,
301    pub successful_executions: u64,
302    pub failed_executions: u64,
303    pub total_iterations: u64,
304    pub average_iterations: f64,
305    pub average_final_score: f64,
306}
307
308/// A self-correcting workflow with judges
309pub struct SelfCorrectingWorkflow {
310    judges: Vec<BoxedJudge>,
311    config: SelfCorrectConfig,
312    stats: Arc<RwLock<SelfCorrectStats>>,
313}
314
315impl SelfCorrectingWorkflow {
316    pub fn new() -> Self {
317        Self {
318            judges: Vec::new(),
319            config: SelfCorrectConfig::default(),
320            stats: Arc::new(RwLock::new(SelfCorrectStats::default())),
321        }
322    }
323
324    pub fn with_config(mut self, config: SelfCorrectConfig) -> Self {
325        self.config = config;
326        self
327    }
328
329    pub fn add_judge<J: Judge + 'static>(mut self, judge: J) -> Self {
330        self.judges.push(Arc::new(judge));
331        self
332    }
333
334    pub fn add_judge_boxed(mut self, judge: BoxedJudge) -> Self {
335        self.judges.push(judge);
336        self
337    }
338
339    pub fn max_iterations(mut self, max: u32) -> Self {
340        self.config.max_iterations = max;
341        self
342    }
343
344    pub fn quality_threshold(mut self, threshold: f32) -> Self {
345        self.config.quality_threshold = threshold.clamp(0.0, 1.0);
346        self
347    }
348
349    /// Execute with a generator function
350    pub async fn execute<F, Fut>(
351        &self,
352        input: impl Into<String>,
353        generator: F,
354    ) -> Result<SelfCorrectResult, SelfCorrectError>
355    where
356        F: Fn(String) -> Fut,
357        Fut: Future<Output = Result<String, String>>,
358    {
359        let input = input.into();
360        let start = Instant::now();
361        let mut judgment_history = Vec::new();
362        let mut current_prompt = input.clone();
363        let mut best_output = String::new();
364        let mut best_score = 0.0f32;
365        let mut plateau_count = 0u32;
366
367        for iteration in 1..=self.config.max_iterations {
368            debug!(iteration, "Starting self-correct iteration");
369
370            // Generate output
371            let output = generator(current_prompt.clone())
372                .await
373                .map_err(SelfCorrectError::ExecutionFailed)?;
374
375            // Evaluate with all judges
376            let context = JudgmentContext::new(&input, &output)
377                .with_iteration(iteration)
378                .with_previous(judgment_history.clone());
379
380            let judgment = self.evaluate_all(&context).await;
381
382            // Track best
383            if judgment.overall_score > best_score {
384                best_score = judgment.overall_score;
385                best_output = output.clone();
386                plateau_count = 0;
387            } else {
388                plateau_count += 1;
389            }
390
391            judgment_history.push(judgment.clone());
392
393            // Check if passed
394            if judgment.passed {
395                return Ok(self.create_result(
396                    output,
397                    true,
398                    judgment.overall_score,
399                    iteration,
400                    judgment_history,
401                    start.elapsed().as_millis() as u64,
402                ));
403            }
404
405            // Check for critical judge failure
406            for judge in &self.judges {
407                if judge.is_critical() {
408                    let judge_result = judge.evaluate(&context).await;
409                    if !judge_result.passed {
410                        return Err(SelfCorrectError::CriticalJudgeFailed(
411                            judge.name().to_string(),
412                        ));
413                    }
414                }
415            }
416
417            // Check for plateau
418            if let Some(plateau_limit) = self.config.stop_on_plateau {
419                if plateau_count >= plateau_limit {
420                    warn!(iteration, plateau_count, "No improvement, stopping early");
421                    break;
422                }
423            }
424
425            // Prepare retry prompt with feedback
426            if self.config.include_feedback && iteration < self.config.max_iterations {
427                current_prompt = format!(
428                    "{}\n\n--- Previous Attempt Feedback ---\n{}",
429                    input,
430                    judgment.feedback_for_retry()
431                );
432            }
433        }
434
435        // Return best result even if didn't pass
436        Ok(self.create_result(
437            best_output,
438            best_score >= self.config.quality_threshold,
439            best_score,
440            self.config.max_iterations,
441            judgment_history,
442            start.elapsed().as_millis() as u64,
443        ))
444    }
445
446    async fn evaluate_all(&self, context: &JudgmentContext) -> Judgment {
447        if self.judges.is_empty() {
448            return Judgment::passed(1.0);
449        }
450
451        let mut total_score = 0.0f32;
452        let mut total_weight = 0.0f32;
453        let mut all_dimensions = Vec::new();
454        let mut all_corrections = Vec::new();
455        let mut summaries = Vec::new();
456
457        for judge in &self.judges {
458            let result = judge.evaluate(context).await;
459            let weight = judge.weight();
460
461            total_score += result.overall_score * weight;
462            total_weight += weight;
463
464            all_dimensions.extend(result.dimension_scores);
465            all_corrections.extend(result.corrections);
466
467            if !result.summary.is_empty() {
468                summaries.push(format!("{}: {}", judge.name(), result.summary));
469            }
470        }
471
472        let overall_score = if total_weight > 0.0 {
473            total_score / total_weight
474        } else {
475            0.0
476        };
477
478        let passed = overall_score >= self.config.quality_threshold;
479
480        Judgment {
481            passed,
482            overall_score,
483            dimension_scores: all_dimensions,
484            summary: summaries.join("; "),
485            corrections: all_corrections,
486            metadata: HashMap::new(),
487        }
488    }
489
490    fn create_result(
491        &self,
492        output: String,
493        passed: bool,
494        score: f32,
495        iterations: u32,
496        history: Vec<Judgment>,
497        duration_ms: u64,
498    ) -> SelfCorrectResult {
499        // Update stats
500        {
501            let mut stats = self.stats.write();
502            stats.total_executions += 1;
503            if passed {
504                stats.successful_executions += 1;
505            } else {
506                stats.failed_executions += 1;
507            }
508            stats.total_iterations += iterations as u64;
509            let total = stats.total_executions as f64;
510            stats.average_iterations =
511                (stats.average_iterations * (total - 1.0) + iterations as f64) / total;
512            stats.average_final_score =
513                (stats.average_final_score * (total - 1.0) + score as f64) / total;
514        }
515
516        SelfCorrectResult {
517            output,
518            passed,
519            final_score: score,
520            iterations,
521            judgment_history: history,
522            duration_ms,
523        }
524    }
525
526    pub fn stats(&self) -> SelfCorrectStats {
527        self.stats.read().clone()
528    }
529}
530
531impl Default for SelfCorrectingWorkflow {
532    fn default() -> Self {
533        Self::new()
534    }
535}
536
537// ============================================================================
538// Built-in Judges
539// ============================================================================
540
/// Judge that scores an output purely by its length.
pub struct LengthJudge {
    /// Shortest acceptable length; `None` means no lower bound.
    min_length: Option<usize>,
    /// Longest acceptable length; `None` means no upper bound.
    max_length: Option<usize>,
}
546
547impl LengthJudge {
548    pub fn new() -> Self {
549        Self {
550            min_length: None,
551            max_length: None,
552        }
553    }
554
555    pub fn min(mut self, len: usize) -> Self {
556        self.min_length = Some(len);
557        self
558    }
559
560    pub fn max(mut self, len: usize) -> Self {
561        self.max_length = Some(len);
562        self
563    }
564
565    pub fn range(mut self, min: usize, max: usize) -> Self {
566        self.min_length = Some(min);
567        self.max_length = Some(max);
568        self
569    }
570}
571
572impl Default for LengthJudge {
573    fn default() -> Self {
574        Self::new()
575    }
576}
577
578#[async_trait]
579impl Judge for LengthJudge {
580    fn name(&self) -> &str {
581        "length_judge"
582    }
583
584    async fn evaluate(&self, context: &JudgmentContext) -> Judgment {
585        let len = context.output.len();
586        let mut score = 1.0f32;
587        let mut feedback = Vec::new();
588
589        if let Some(min) = self.min_length {
590            if len < min {
591                score *= len as f32 / min as f32;
592                feedback.push(format!("Output too short ({} chars, minimum {})", len, min));
593            }
594        }
595
596        if let Some(max) = self.max_length {
597            if len > max {
598                score *= max as f32 / len as f32;
599                feedback.push(format!("Output too long ({} chars, maximum {})", len, max));
600            }
601        }
602
603        if feedback.is_empty() {
604            Judgment::passed(score)
605        } else {
606            Judgment::failed(score, feedback.join("; "))
607        }
608    }
609}
610
/// Judge that checks for the presence or absence of keywords.
pub struct KeywordJudge {
    /// Keywords the output has to contain.
    required: Vec<String>,
    /// Keywords the output must not contain.
    forbidden: Vec<String>,
}
616
617impl KeywordJudge {
618    pub fn new() -> Self {
619        Self {
620            required: Vec::new(),
621            forbidden: Vec::new(),
622        }
623    }
624
625    pub fn require(mut self, keyword: impl Into<String>) -> Self {
626        self.required.push(keyword.into());
627        self
628    }
629
630    pub fn forbid(mut self, keyword: impl Into<String>) -> Self {
631        self.forbidden.push(keyword.into());
632        self
633    }
634}
635
636impl Default for KeywordJudge {
637    fn default() -> Self {
638        Self::new()
639    }
640}
641
642#[async_trait]
643impl Judge for KeywordJudge {
644    fn name(&self) -> &str {
645        "keyword_judge"
646    }
647
648    async fn evaluate(&self, context: &JudgmentContext) -> Judgment {
649        let output_lower = context.output.to_lowercase();
650        let mut missing = Vec::new();
651        let mut found_forbidden = Vec::new();
652
653        for keyword in &self.required {
654            if !output_lower.contains(&keyword.to_lowercase()) {
655                missing.push(keyword.clone());
656            }
657        }
658
659        for keyword in &self.forbidden {
660            if output_lower.contains(&keyword.to_lowercase()) {
661                found_forbidden.push(keyword.clone());
662            }
663        }
664
665        let required_score = if self.required.is_empty() {
666            1.0
667        } else {
668            (self.required.len() - missing.len()) as f32 / self.required.len() as f32
669        };
670
671        let forbidden_score = if self.forbidden.is_empty() {
672            1.0
673        } else if found_forbidden.is_empty() {
674            1.0
675        } else {
676            0.0
677        };
678
679        let score = required_score * 0.7 + forbidden_score * 0.3;
680
681        let mut judgment = if missing.is_empty() && found_forbidden.is_empty() {
682            Judgment::passed(score)
683        } else {
684            let mut summary = Vec::new();
685            if !missing.is_empty() {
686                summary.push(format!("Missing keywords: {}", missing.join(", ")));
687            }
688            if !found_forbidden.is_empty() {
689                summary.push(format!(
690                    "Forbidden keywords found: {}",
691                    found_forbidden.join(", ")
692                ));
693            }
694            Judgment::failed(score, summary.join("; "))
695        };
696
697        for keyword in &missing {
698            judgment = judgment.with_correction(format!("Include '{}' in the response", keyword));
699        }
700
701        for keyword in &found_forbidden {
702            judgment = judgment.with_correction(format!("Remove '{}' from the response", keyword));
703        }
704
705        judgment
706    }
707
708    fn is_critical(&self) -> bool {
709        !self.forbidden.is_empty() // Critical if we have forbidden keywords
710    }
711}
712
713/// Regex pattern judge
714pub struct PatternJudge {
715    name: String,
716    required_patterns: Vec<(regex::Regex, String)>,
717    forbidden_patterns: Vec<(regex::Regex, String)>,
718}
719
720impl PatternJudge {
721    pub fn new(name: impl Into<String>) -> Self {
722        Self {
723            name: name.into(),
724            required_patterns: Vec::new(),
725            forbidden_patterns: Vec::new(),
726        }
727    }
728
729    pub fn require(
730        mut self,
731        pattern: &str,
732        description: impl Into<String>,
733    ) -> Result<Self, regex::Error> {
734        self.required_patterns
735            .push((regex::Regex::new(pattern)?, description.into()));
736        Ok(self)
737    }
738
739    pub fn forbid(
740        mut self,
741        pattern: &str,
742        description: impl Into<String>,
743    ) -> Result<Self, regex::Error> {
744        self.forbidden_patterns
745            .push((regex::Regex::new(pattern)?, description.into()));
746        Ok(self)
747    }
748}
749
750#[async_trait]
751impl Judge for PatternJudge {
752    fn name(&self) -> &str {
753        &self.name
754    }
755
756    async fn evaluate(&self, context: &JudgmentContext) -> Judgment {
757        let mut missing = Vec::new();
758        let mut found_forbidden = Vec::new();
759
760        for (pattern, desc) in &self.required_patterns {
761            if !pattern.is_match(&context.output) {
762                missing.push(desc.clone());
763            }
764        }
765
766        for (pattern, desc) in &self.forbidden_patterns {
767            if pattern.is_match(&context.output) {
768                found_forbidden.push(desc.clone());
769            }
770        }
771
772        let total = self.required_patterns.len() + self.forbidden_patterns.len();
773        let failed = missing.len() + found_forbidden.len();
774        let score = if total == 0 {
775            1.0
776        } else {
777            (total - failed) as f32 / total as f32
778        };
779
780        if missing.is_empty() && found_forbidden.is_empty() {
781            Judgment::passed(score)
782        } else {
783            let mut summary = Vec::new();
784            if !missing.is_empty() {
785                summary.push(format!("Missing: {}", missing.join(", ")));
786            }
787            if !found_forbidden.is_empty() {
788                summary.push(format!("Found forbidden: {}", found_forbidden.join(", ")));
789            }
790            Judgment::failed(score, summary.join("; "))
791        }
792    }
793}
794
/// Judge backed by a user-supplied function or closure.
///
/// The bound on `F` (a function from a judgment context to a boxed, `Send`
/// future yielding a judgment) is stated on the `impl` blocks that need it,
/// not on the struct itself.
pub struct FnJudge<F> {
    /// Name reported for this judge.
    name: String,
    /// The evaluation function.
    evaluate_fn: F,
    /// Weight reported for this judge.
    weight: f32,
    /// Whether this judge reports itself as critical.
    critical: bool,
}
805
806impl<F> FnJudge<F>
807where
808    F: Fn(&JudgmentContext) -> Pin<Box<dyn Future<Output = Judgment> + Send>> + Send + Sync,
809{
810    pub fn new(name: impl Into<String>, evaluate_fn: F) -> Self {
811        Self {
812            name: name.into(),
813            evaluate_fn,
814            weight: 1.0,
815            critical: false,
816        }
817    }
818
819    pub fn with_weight(mut self, weight: f32) -> Self {
820        self.weight = weight;
821        self
822    }
823
824    pub fn critical(mut self) -> Self {
825        self.critical = true;
826        self
827    }
828}
829
830#[async_trait]
831impl<F> Judge for FnJudge<F>
832where
833    F: Fn(&JudgmentContext) -> Pin<Box<dyn Future<Output = Judgment> + Send>> + Send + Sync,
834{
835    fn name(&self) -> &str {
836        &self.name
837    }
838
839    async fn evaluate(&self, context: &JudgmentContext) -> Judgment {
840        (self.evaluate_fn)(context).await
841    }
842
843    fn weight(&self) -> f32 {
844        self.weight
845    }
846
847    fn is_critical(&self) -> bool {
848        self.critical
849    }
850}
851
#[cfg(test)]
mod tests {
    use super::*;

    /// Outputs outside the configured range fail; outputs inside it pass.
    #[tokio::test]
    async fn test_length_judge() {
        let judge = LengthJudge::new().range(10, 100);

        // Too short
        let ctx = JudgmentContext::new("input", "short");
        let result = judge.evaluate(&ctx).await;
        assert!(!result.passed);

        // Just right
        let ctx = JudgmentContext::new("input", "This is a properly sized response.");
        let result = judge.evaluate(&ctx).await;
        assert!(result.passed);
    }

    /// Required keywords must be present and forbidden ones absent (case-insensitive).
    #[tokio::test]
    async fn test_keyword_judge() {
        let judge = KeywordJudge::new()
            .require("rust")
            .require("programming")
            .forbid("python");

        // Has required, no forbidden
        let ctx = JudgmentContext::new("input", "Rust is a great programming language.");
        let result = judge.evaluate(&ctx).await;
        assert!(result.passed);

        // Missing required
        let ctx = JudgmentContext::new("input", "Rust is great.");
        let result = judge.evaluate(&ctx).await;
        assert!(!result.passed);
        assert!(result.summary.contains("programming"));

        // Has forbidden
        let ctx = JudgmentContext::new("input", "Rust is better than Python for programming.");
        let result = judge.evaluate(&ctx).await;
        assert!(!result.passed);
    }

    /// A generator that improves on the second attempt passes within two iterations.
    #[tokio::test]
    async fn test_self_correcting_workflow() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let workflow = SelfCorrectingWorkflow::new()
            .add_judge(LengthJudge::new().min(20))
            .quality_threshold(0.8);

        // This generator improves each iteration
        let attempt = AtomicU32::new(0);
        let result = workflow
            .execute("Write something", |_prompt| {
                let current = attempt.fetch_add(1, Ordering::SeqCst);
                async move {
                    if current == 0 {
                        Ok("Too short".to_string())
                    } else {
                        Ok(
                            "This is a much longer response that should pass the length check."
                                .to_string(),
                        )
                    }
                }
            })
            .await
            .unwrap();

        assert!(result.passed);
        assert!(result.iterations <= 2);
    }

    /// When no attempt can pass, the run ends after the iteration budget with `passed == false`.
    #[tokio::test]
    async fn test_workflow_max_iterations() {
        let workflow = SelfCorrectingWorkflow::new()
            .add_judge(LengthJudge::new().min(1000)) // Impossible to satisfy
            .max_iterations(2)
            .quality_threshold(0.9);

        let result = workflow
            .execute("input", |_| async { Ok("short".to_string()) })
            .await
            .unwrap();

        assert!(!result.passed);
        assert_eq!(result.iterations, 2);
    }

    /// The retry feedback contains the score and every correction.
    // Nothing is awaited here, so a plain synchronous test is enough.
    #[test]
    fn test_judgment_feedback() {
        let judgment = Judgment::failed(0.5, "Quality issues found")
            .with_correction("Fix the formatting")
            .with_correction("Add more details");

        let feedback = judgment.feedback_for_retry();
        assert!(feedback.contains("50.0%"));
        assert!(feedback.contains("Fix the formatting"));
        assert!(feedback.contains("Add more details"));
    }

    /// Threshold comparison and feedback storage of a dimension score.
    #[test]
    fn test_dimension_score() {
        let dim = DimensionScore::new("accuracy", 0.8)
            .with_explanation("Good accuracy overall")
            .with_feedback("Could improve citation quality");

        assert!(dim.passed(0.7));
        assert!(!dim.passed(0.9));
        assert_eq!(
            dim.feedback.as_deref(),
            Some("Could improve citation quality")
        );
    }

    /// Required patterns must match for the judgment to pass.
    #[tokio::test]
    async fn test_pattern_judge() {
        let judge = PatternJudge::new("format_check")
            .require(r"\d{4}-\d{2}-\d{2}", "date format YYYY-MM-DD")
            .unwrap();

        let ctx = JudgmentContext::new("input", "The date is 2024-01-15");
        let result = judge.evaluate(&ctx).await;
        assert!(result.passed);

        let ctx = JudgmentContext::new("input", "The date is January 15");
        let result = judge.evaluate(&ctx).await;
        assert!(!result.passed);
    }

    /// A closure-backed judge returns whatever its closure decides.
    #[tokio::test]
    async fn test_fn_judge() {
        let judge = FnJudge::new("custom", |ctx| {
            // Computed up front: the returned future may not borrow from `ctx`.
            let has_greeting = ctx.output.to_lowercase().contains("hello");
            Box::pin(async move {
                if has_greeting {
                    Judgment::passed(1.0)
                } else {
                    Judgment::failed(0.0, "Missing greeting")
                }
            })
        });

        let ctx = JudgmentContext::new("input", "Hello, world!");
        let result = judge.evaluate(&ctx).await;
        assert!(result.passed);
    }

    /// Every completed run is counted in the statistics.
    #[tokio::test]
    async fn test_workflow_stats() {
        let workflow = SelfCorrectingWorkflow::new()
            .add_judge(LengthJudge::new().min(5))
            .quality_threshold(0.8);

        workflow
            .execute("test", |_| async { Ok("Hello World".to_string()) })
            .await
            .unwrap();

        workflow
            .execute("test", |_| async {
                Ok("Another test response".to_string())
            })
            .await
            .unwrap();

        let stats = workflow.stats();
        assert_eq!(stats.total_executions, 2);
        assert_eq!(stats.successful_executions, 2);
    }

    /// A rising score between the last two judgments counts as improving.
    #[test]
    fn test_judgment_context_improving() {
        let j1 = Judgment::failed(0.3, "Poor");
        let j2 = Judgment::failed(0.5, "Better");
        let j3 = Judgment::passed(0.8);

        let ctx = JudgmentContext::new("input", "output").with_previous(vec![j1, j2, j3]);

        assert!(ctx.is_improving());
    }

    /// Improvement is the last score minus the first.
    #[test]
    fn test_result_improvement() {
        let result = SelfCorrectResult {
            output: "final".to_string(),
            passed: true,
            final_score: 0.9,
            iterations: 3,
            judgment_history: vec![
                Judgment::failed(0.3, ""),
                Judgment::failed(0.6, ""),
                Judgment::passed(0.9),
            ],
            duration_ms: 100,
        };

        let improvement = result.improvement().unwrap();
        assert!((improvement - 0.6).abs() < 0.01);
    }
}
1053}