1use std::collections::HashMap;
32use std::future::Future;
33use std::pin::Pin;
34use std::sync::Arc;
35use std::time::Instant;
36
37use async_trait::async_trait;
38use parking_lot::RwLock;
39use serde::{Deserialize, Serialize};
40use tracing::{debug, warn};
41
/// Score awarded to an output along one named quality dimension.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DimensionScore {
    /// Name of the dimension being scored.
    pub dimension: String,
    /// Numeric score; `DimensionScore::new` clamps it into `0.0..=1.0`.
    pub score: f32,
    /// Free-form explanation of the score (empty until `with_explanation` is called).
    pub explanation: String,
    /// Optional feedback text; when present it is included by
    /// `Judgment::feedback_for_retry`.
    pub feedback: Option<String>,
}
54
55impl DimensionScore {
56 pub fn new(dimension: impl Into<String>, score: f32) -> Self {
57 Self {
58 dimension: dimension.into(),
59 score: score.clamp(0.0, 1.0),
60 explanation: String::new(),
61 feedback: None,
62 }
63 }
64
65 pub fn with_explanation(mut self, explanation: impl Into<String>) -> Self {
66 self.explanation = explanation.into();
67 self
68 }
69
70 pub fn with_feedback(mut self, feedback: impl Into<String>) -> Self {
71 self.feedback = Some(feedback.into());
72 self
73 }
74
75 pub fn passed(&self, threshold: f32) -> bool {
76 self.score >= threshold
77 }
78}
79
/// Outcome of evaluating one output: a pass/fail verdict plus supporting detail.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Judgment {
    /// Whether the evaluated output was accepted.
    pub passed: bool,
    /// Overall score; the `passed`/`failed` constructors clamp it into `0.0..=1.0`.
    pub overall_score: f32,
    /// Per-dimension scores attached to this judgment.
    pub dimension_scores: Vec<DimensionScore>,
    /// Human-readable summary of the verdict.
    pub summary: String,
    /// Concrete corrections requested for the next attempt.
    pub corrections: Vec<String>,
    /// Arbitrary string key/value annotations.
    pub metadata: HashMap<String, String>,
}
96
97impl Judgment {
98 pub fn passed(score: f32) -> Self {
99 Self {
100 passed: true,
101 overall_score: score.clamp(0.0, 1.0),
102 dimension_scores: Vec::new(),
103 summary: "Quality check passed".to_string(),
104 corrections: Vec::new(),
105 metadata: HashMap::new(),
106 }
107 }
108
109 pub fn failed(score: f32, summary: impl Into<String>) -> Self {
110 Self {
111 passed: false,
112 overall_score: score.clamp(0.0, 1.0),
113 dimension_scores: Vec::new(),
114 summary: summary.into(),
115 corrections: Vec::new(),
116 metadata: HashMap::new(),
117 }
118 }
119
120 pub fn with_dimension(mut self, dimension: DimensionScore) -> Self {
121 self.dimension_scores.push(dimension);
122 self
123 }
124
125 pub fn with_correction(mut self, correction: impl Into<String>) -> Self {
126 self.corrections.push(correction.into());
127 self
128 }
129
130 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
131 self.metadata.insert(key.into(), value.into());
132 self
133 }
134
135 pub fn feedback_for_retry(&self) -> String {
137 let mut feedback = vec![format!(
138 "Previous attempt scored {:.1}%",
139 self.overall_score * 100.0
140 )];
141 feedback.push(format!("Feedback: {}", self.summary));
142
143 if !self.corrections.is_empty() {
144 feedback.push("Corrections needed:".to_string());
145 for (i, correction) in self.corrections.iter().enumerate() {
146 feedback.push(format!(" {}. {}", i + 1, correction));
147 }
148 }
149
150 for dim in &self.dimension_scores {
151 if let Some(ref fb) = dim.feedback {
152 feedback.push(format!("- {}: {}", dim.dimension, fb));
153 }
154 }
155
156 feedback.join("\n")
157 }
158}
159
/// Everything a judge receives when asked to evaluate one attempt.
#[derive(Debug, Clone)]
pub struct JudgmentContext {
    /// The original input the output was generated for.
    pub input: String,
    /// The generated output under evaluation.
    pub output: String,
    /// Iteration number of this attempt (`JudgmentContext::new` starts it at 1).
    pub iteration: u32,
    /// Judgments recorded for earlier attempts, oldest first as supplied by the caller.
    pub previous_judgments: Vec<Judgment>,
    /// Arbitrary string key/value annotations.
    pub metadata: HashMap<String, String>,
}
174
175impl JudgmentContext {
176 pub fn new(input: impl Into<String>, output: impl Into<String>) -> Self {
177 Self {
178 input: input.into(),
179 output: output.into(),
180 iteration: 1,
181 previous_judgments: Vec::new(),
182 metadata: HashMap::new(),
183 }
184 }
185
186 pub fn with_iteration(mut self, iteration: u32) -> Self {
187 self.iteration = iteration;
188 self
189 }
190
191 pub fn with_previous(mut self, judgments: Vec<Judgment>) -> Self {
192 self.previous_judgments = judgments;
193 self
194 }
195
196 pub fn is_improving(&self) -> bool {
197 if self.previous_judgments.len() < 2 {
198 return true;
199 }
200 let last = &self.previous_judgments[self.previous_judgments.len() - 1];
201 let prev = &self.previous_judgments[self.previous_judgments.len() - 2];
202 last.overall_score > prev.overall_score
203 }
204}
205
/// An evaluator that scores a generated output.
///
/// Implementors must be `Send + Sync`; the workflow in this file stores them
/// behind `Arc<dyn Judge>`.
#[async_trait]
pub trait Judge: Send + Sync {
    /// Name of the judge, used in aggregated summaries and error messages.
    fn name(&self) -> &str;

    /// Evaluates the attempt described by `context` and returns a verdict.
    async fn evaluate(&self, context: &JudgmentContext) -> Judgment;

    /// Relative weight of this judge's score when several judges are
    /// averaged. Defaults to `1.0`.
    fn weight(&self) -> f32 {
        1.0
    }

    /// Whether a failing verdict from this judge should abort the workflow.
    /// Defaults to `false`.
    fn is_critical(&self) -> bool {
        false
    }
}
225
/// Shared, type-erased handle to a [`Judge`] (an `Arc`, despite the name).
pub type BoxedJudge = Arc<dyn Judge>;
228
/// Errors a self-correcting run can report.
#[derive(Debug, thiserror::Error)]
pub enum SelfCorrectError {
    /// The iteration budget (the carried value) ran out without a passing result.
    // NOTE(review): not constructed anywhere in the visible code — confirm whether
    // it is still needed.
    #[error("Max iterations ({0}) exceeded without passing quality threshold")]
    MaxIterationsExceeded(u32),

    /// A judge flagged as critical returned a failing verdict; carries the judge's name.
    #[error("Critical judge '{0}' failed")]
    CriticalJudgeFailed(String),

    /// The generator callback returned an error; carries its message.
    #[error("Execution failed: {0}")]
    ExecutionFailed(String),

    /// Scores stopped improving for the carried number of iterations.
    // NOTE(review): not constructed anywhere in the visible code — confirm whether
    // it is still needed.
    #[error("No improvement after {0} iterations")]
    NoImprovement(u32),
}
244
/// Tunable limits for a self-correcting run.
#[derive(Debug, Clone)]
pub struct SelfCorrectConfig {
    /// Upper bound on generate-and-judge iterations.
    pub max_iterations: u32,
    /// Minimum aggregate score an attempt needs in order to pass.
    pub quality_threshold: f32,
    /// When `Some(n)`, stop early after `n` consecutive iterations that fail
    /// to beat the best score so far; `None` disables the early stop.
    pub stop_on_plateau: Option<u32>,
    /// Whether feedback from a failed attempt is appended to the next prompt.
    pub include_feedback: bool,
}

impl Default for SelfCorrectConfig {
    /// Three iterations, a 0.8 threshold, a plateau limit of two, feedback enabled.
    fn default() -> Self {
        let max_iterations = 3;
        let quality_threshold = 0.8;
        let stop_on_plateau = Some(2);
        let include_feedback = true;
        Self {
            max_iterations,
            quality_threshold,
            stop_on_plateau,
            include_feedback,
        }
    }
}
268
/// Final report of a self-correcting run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelfCorrectResult {
    /// The output selected as the result of the run.
    pub output: String,
    /// Whether the run met the quality threshold.
    pub passed: bool,
    /// Score associated with `output`.
    pub final_score: f32,
    /// Number of iterations reported for the run.
    pub iterations: u32,
    /// Judgments recorded during the run, in iteration order.
    pub judgment_history: Vec<Judgment>,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: u64,
}
285
286impl SelfCorrectResult {
287 pub fn improvement(&self) -> Option<f32> {
288 if self.judgment_history.len() < 2 {
289 return None;
290 }
291 let first = self.judgment_history.first()?.overall_score;
292 let last = self.judgment_history.last()?.overall_score;
293 Some(last - first)
294 }
295}
296
/// Running statistics accumulated by a `SelfCorrectingWorkflow`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SelfCorrectStats {
    /// Number of runs that produced a result.
    pub total_executions: u64,
    /// Runs whose result was marked as passed.
    pub successful_executions: u64,
    /// Runs whose result was marked as not passed.
    pub failed_executions: u64,
    /// Sum of the iteration counts reported by all runs.
    pub total_iterations: u64,
    /// Running mean of iterations per run.
    pub average_iterations: f64,
    /// Running mean of the final score per run.
    pub average_final_score: f64,
}
307
/// Generate-judge-retry loop: runs a generator, scores its output with the
/// registered judges, and retries until the output passes or a limit is hit.
pub struct SelfCorrectingWorkflow {
    /// Judges consulted for every attempt, in registration order.
    judges: Vec<BoxedJudge>,
    /// Iteration, threshold, plateau and feedback settings.
    config: SelfCorrectConfig,
    /// Statistics shared behind a lock so they can be updated through `&self`.
    stats: Arc<RwLock<SelfCorrectStats>>,
}
314
315impl SelfCorrectingWorkflow {
316 pub fn new() -> Self {
317 Self {
318 judges: Vec::new(),
319 config: SelfCorrectConfig::default(),
320 stats: Arc::new(RwLock::new(SelfCorrectStats::default())),
321 }
322 }
323
324 pub fn with_config(mut self, config: SelfCorrectConfig) -> Self {
325 self.config = config;
326 self
327 }
328
329 pub fn add_judge<J: Judge + 'static>(mut self, judge: J) -> Self {
330 self.judges.push(Arc::new(judge));
331 self
332 }
333
334 pub fn add_judge_boxed(mut self, judge: BoxedJudge) -> Self {
335 self.judges.push(judge);
336 self
337 }
338
339 pub fn max_iterations(mut self, max: u32) -> Self {
340 self.config.max_iterations = max;
341 self
342 }
343
344 pub fn quality_threshold(mut self, threshold: f32) -> Self {
345 self.config.quality_threshold = threshold.clamp(0.0, 1.0);
346 self
347 }
348
349 pub async fn execute<F, Fut>(
351 &self,
352 input: impl Into<String>,
353 generator: F,
354 ) -> Result<SelfCorrectResult, SelfCorrectError>
355 where
356 F: Fn(String) -> Fut,
357 Fut: Future<Output = Result<String, String>>,
358 {
359 let input = input.into();
360 let start = Instant::now();
361 let mut judgment_history = Vec::new();
362 let mut current_prompt = input.clone();
363 let mut best_output = String::new();
364 let mut best_score = 0.0f32;
365 let mut plateau_count = 0u32;
366
367 for iteration in 1..=self.config.max_iterations {
368 debug!(iteration, "Starting self-correct iteration");
369
370 let output = generator(current_prompt.clone())
372 .await
373 .map_err(SelfCorrectError::ExecutionFailed)?;
374
375 let context = JudgmentContext::new(&input, &output)
377 .with_iteration(iteration)
378 .with_previous(judgment_history.clone());
379
380 let judgment = self.evaluate_all(&context).await;
381
382 if judgment.overall_score > best_score {
384 best_score = judgment.overall_score;
385 best_output = output.clone();
386 plateau_count = 0;
387 } else {
388 plateau_count += 1;
389 }
390
391 judgment_history.push(judgment.clone());
392
393 if judgment.passed {
395 return Ok(self.create_result(
396 output,
397 true,
398 judgment.overall_score,
399 iteration,
400 judgment_history,
401 start.elapsed().as_millis() as u64,
402 ));
403 }
404
405 for judge in &self.judges {
407 if judge.is_critical() {
408 let judge_result = judge.evaluate(&context).await;
409 if !judge_result.passed {
410 return Err(SelfCorrectError::CriticalJudgeFailed(
411 judge.name().to_string(),
412 ));
413 }
414 }
415 }
416
417 if let Some(plateau_limit) = self.config.stop_on_plateau {
419 if plateau_count >= plateau_limit {
420 warn!(iteration, plateau_count, "No improvement, stopping early");
421 break;
422 }
423 }
424
425 if self.config.include_feedback && iteration < self.config.max_iterations {
427 current_prompt = format!(
428 "{}\n\n--- Previous Attempt Feedback ---\n{}",
429 input,
430 judgment.feedback_for_retry()
431 );
432 }
433 }
434
435 Ok(self.create_result(
437 best_output,
438 best_score >= self.config.quality_threshold,
439 best_score,
440 self.config.max_iterations,
441 judgment_history,
442 start.elapsed().as_millis() as u64,
443 ))
444 }
445
446 async fn evaluate_all(&self, context: &JudgmentContext) -> Judgment {
447 if self.judges.is_empty() {
448 return Judgment::passed(1.0);
449 }
450
451 let mut total_score = 0.0f32;
452 let mut total_weight = 0.0f32;
453 let mut all_dimensions = Vec::new();
454 let mut all_corrections = Vec::new();
455 let mut summaries = Vec::new();
456
457 for judge in &self.judges {
458 let result = judge.evaluate(context).await;
459 let weight = judge.weight();
460
461 total_score += result.overall_score * weight;
462 total_weight += weight;
463
464 all_dimensions.extend(result.dimension_scores);
465 all_corrections.extend(result.corrections);
466
467 if !result.summary.is_empty() {
468 summaries.push(format!("{}: {}", judge.name(), result.summary));
469 }
470 }
471
472 let overall_score = if total_weight > 0.0 {
473 total_score / total_weight
474 } else {
475 0.0
476 };
477
478 let passed = overall_score >= self.config.quality_threshold;
479
480 Judgment {
481 passed,
482 overall_score,
483 dimension_scores: all_dimensions,
484 summary: summaries.join("; "),
485 corrections: all_corrections,
486 metadata: HashMap::new(),
487 }
488 }
489
490 fn create_result(
491 &self,
492 output: String,
493 passed: bool,
494 score: f32,
495 iterations: u32,
496 history: Vec<Judgment>,
497 duration_ms: u64,
498 ) -> SelfCorrectResult {
499 {
501 let mut stats = self.stats.write();
502 stats.total_executions += 1;
503 if passed {
504 stats.successful_executions += 1;
505 } else {
506 stats.failed_executions += 1;
507 }
508 stats.total_iterations += iterations as u64;
509 let total = stats.total_executions as f64;
510 stats.average_iterations =
511 (stats.average_iterations * (total - 1.0) + iterations as f64) / total;
512 stats.average_final_score =
513 (stats.average_final_score * (total - 1.0) + score as f64) / total;
514 }
515
516 SelfCorrectResult {
517 output,
518 passed,
519 final_score: score,
520 iterations,
521 judgment_history: history,
522 duration_ms,
523 }
524 }
525
526 pub fn stats(&self) -> SelfCorrectStats {
527 self.stats.read().clone()
528 }
529}
530
531impl Default for SelfCorrectingWorkflow {
532 fn default() -> Self {
533 Self::new()
534 }
535}
536
/// Judge that checks the output length against optional bounds.
pub struct LengthJudge {
    /// Lower bound on the output length, if any.
    min_length: Option<usize>,
    /// Upper bound on the output length, if any.
    max_length: Option<usize>,
}

impl LengthJudge {
    /// Creates a judge with no bounds.
    pub fn new() -> Self {
        let unbounded = None;
        Self {
            min_length: unbounded,
            max_length: unbounded,
        }
    }

    /// Sets the lower bound, leaving the upper bound untouched.
    pub fn min(self, len: usize) -> Self {
        Self {
            min_length: Some(len),
            ..self
        }
    }

    /// Sets the upper bound, leaving the lower bound untouched.
    pub fn max(self, len: usize) -> Self {
        Self {
            max_length: Some(len),
            ..self
        }
    }

    /// Sets both bounds at once.
    pub fn range(self, min: usize, max: usize) -> Self {
        self.min(min).max(max)
    }
}

impl Default for LengthJudge {
    /// Equivalent to [`LengthJudge::new`].
    fn default() -> Self {
        LengthJudge::new()
    }
}
577
578#[async_trait]
579impl Judge for LengthJudge {
580 fn name(&self) -> &str {
581 "length_judge"
582 }
583
584 async fn evaluate(&self, context: &JudgmentContext) -> Judgment {
585 let len = context.output.len();
586 let mut score = 1.0f32;
587 let mut feedback = Vec::new();
588
589 if let Some(min) = self.min_length {
590 if len < min {
591 score *= len as f32 / min as f32;
592 feedback.push(format!("Output too short ({} chars, minimum {})", len, min));
593 }
594 }
595
596 if let Some(max) = self.max_length {
597 if len > max {
598 score *= max as f32 / len as f32;
599 feedback.push(format!("Output too long ({} chars, maximum {})", len, max));
600 }
601 }
602
603 if feedback.is_empty() {
604 Judgment::passed(score)
605 } else {
606 Judgment::failed(score, feedback.join("; "))
607 }
608 }
609}
610
/// Judge that checks the output for required and forbidden keywords.
pub struct KeywordJudge {
    /// Keywords that must appear in the output.
    required: Vec<String>,
    /// Keywords that must not appear in the output.
    forbidden: Vec<String>,
}

impl KeywordJudge {
    /// Creates a judge with empty keyword lists.
    pub fn new() -> Self {
        let (required, forbidden) = (Vec::new(), Vec::new());
        Self {
            required,
            forbidden,
        }
    }

    /// Adds a keyword that must be present.
    pub fn require(mut self, keyword: impl Into<String>) -> Self {
        let word = keyword.into();
        self.required.push(word);
        self
    }

    /// Adds a keyword that must be absent.
    pub fn forbid(mut self, keyword: impl Into<String>) -> Self {
        let word = keyword.into();
        self.forbidden.push(word);
        self
    }
}

impl Default for KeywordJudge {
    /// Equivalent to [`KeywordJudge::new`].
    fn default() -> Self {
        KeywordJudge::new()
    }
}
641
642#[async_trait]
643impl Judge for KeywordJudge {
644 fn name(&self) -> &str {
645 "keyword_judge"
646 }
647
648 async fn evaluate(&self, context: &JudgmentContext) -> Judgment {
649 let output_lower = context.output.to_lowercase();
650 let mut missing = Vec::new();
651 let mut found_forbidden = Vec::new();
652
653 for keyword in &self.required {
654 if !output_lower.contains(&keyword.to_lowercase()) {
655 missing.push(keyword.clone());
656 }
657 }
658
659 for keyword in &self.forbidden {
660 if output_lower.contains(&keyword.to_lowercase()) {
661 found_forbidden.push(keyword.clone());
662 }
663 }
664
665 let required_score = if self.required.is_empty() {
666 1.0
667 } else {
668 (self.required.len() - missing.len()) as f32 / self.required.len() as f32
669 };
670
671 let forbidden_score = if self.forbidden.is_empty() {
672 1.0
673 } else if found_forbidden.is_empty() {
674 1.0
675 } else {
676 0.0
677 };
678
679 let score = required_score * 0.7 + forbidden_score * 0.3;
680
681 let mut judgment = if missing.is_empty() && found_forbidden.is_empty() {
682 Judgment::passed(score)
683 } else {
684 let mut summary = Vec::new();
685 if !missing.is_empty() {
686 summary.push(format!("Missing keywords: {}", missing.join(", ")));
687 }
688 if !found_forbidden.is_empty() {
689 summary.push(format!(
690 "Forbidden keywords found: {}",
691 found_forbidden.join(", ")
692 ));
693 }
694 Judgment::failed(score, summary.join("; "))
695 };
696
697 for keyword in &missing {
698 judgment = judgment.with_correction(format!("Include '{}' in the response", keyword));
699 }
700
701 for keyword in &found_forbidden {
702 judgment = judgment.with_correction(format!("Remove '{}' from the response", keyword));
703 }
704
705 judgment
706 }
707
708 fn is_critical(&self) -> bool {
709 !self.forbidden.is_empty() }
711}
712
713pub struct PatternJudge {
715 name: String,
716 required_patterns: Vec<(regex::Regex, String)>,
717 forbidden_patterns: Vec<(regex::Regex, String)>,
718}
719
720impl PatternJudge {
721 pub fn new(name: impl Into<String>) -> Self {
722 Self {
723 name: name.into(),
724 required_patterns: Vec::new(),
725 forbidden_patterns: Vec::new(),
726 }
727 }
728
729 pub fn require(
730 mut self,
731 pattern: &str,
732 description: impl Into<String>,
733 ) -> Result<Self, regex::Error> {
734 self.required_patterns
735 .push((regex::Regex::new(pattern)?, description.into()));
736 Ok(self)
737 }
738
739 pub fn forbid(
740 mut self,
741 pattern: &str,
742 description: impl Into<String>,
743 ) -> Result<Self, regex::Error> {
744 self.forbidden_patterns
745 .push((regex::Regex::new(pattern)?, description.into()));
746 Ok(self)
747 }
748}
749
750#[async_trait]
751impl Judge for PatternJudge {
752 fn name(&self) -> &str {
753 &self.name
754 }
755
756 async fn evaluate(&self, context: &JudgmentContext) -> Judgment {
757 let mut missing = Vec::new();
758 let mut found_forbidden = Vec::new();
759
760 for (pattern, desc) in &self.required_patterns {
761 if !pattern.is_match(&context.output) {
762 missing.push(desc.clone());
763 }
764 }
765
766 for (pattern, desc) in &self.forbidden_patterns {
767 if pattern.is_match(&context.output) {
768 found_forbidden.push(desc.clone());
769 }
770 }
771
772 let total = self.required_patterns.len() + self.forbidden_patterns.len();
773 let failed = missing.len() + found_forbidden.len();
774 let score = if total == 0 {
775 1.0
776 } else {
777 (total - failed) as f32 / total as f32
778 };
779
780 if missing.is_empty() && found_forbidden.is_empty() {
781 Judgment::passed(score)
782 } else {
783 let mut summary = Vec::new();
784 if !missing.is_empty() {
785 summary.push(format!("Missing: {}", missing.join(", ")));
786 }
787 if !found_forbidden.is_empty() {
788 summary.push(format!("Found forbidden: {}", found_forbidden.join(", ")));
789 }
790 Judgment::failed(score, summary.join("; "))
791 }
792 }
793}
794
795pub struct FnJudge<F>
797where
798 F: Fn(&JudgmentContext) -> Pin<Box<dyn Future<Output = Judgment> + Send>> + Send + Sync,
799{
800 name: String,
801 evaluate_fn: F,
802 weight: f32,
803 critical: bool,
804}
805
806impl<F> FnJudge<F>
807where
808 F: Fn(&JudgmentContext) -> Pin<Box<dyn Future<Output = Judgment> + Send>> + Send + Sync,
809{
810 pub fn new(name: impl Into<String>, evaluate_fn: F) -> Self {
811 Self {
812 name: name.into(),
813 evaluate_fn,
814 weight: 1.0,
815 critical: false,
816 }
817 }
818
819 pub fn with_weight(mut self, weight: f32) -> Self {
820 self.weight = weight;
821 self
822 }
823
824 pub fn critical(mut self) -> Self {
825 self.critical = true;
826 self
827 }
828}
829
830#[async_trait]
831impl<F> Judge for FnJudge<F>
832where
833 F: Fn(&JudgmentContext) -> Pin<Box<dyn Future<Output = Judgment> + Send>> + Send + Sync,
834{
835 fn name(&self) -> &str {
836 &self.name
837 }
838
839 async fn evaluate(&self, context: &JudgmentContext) -> Judgment {
840 (self.evaluate_fn)(context).await
841 }
842
843 fn weight(&self) -> f32 {
844 self.weight
845 }
846
847 fn is_critical(&self) -> bool {
848 self.critical
849 }
850}
851
/// Unit tests for the judges, the judgment types and the workflow loop.
#[cfg(test)]
mod tests {
    use super::*;

    /// `LengthJudge` fails an output below the minimum and passes one inside the range.
    #[tokio::test]
    async fn test_length_judge() {
        let judge = LengthJudge::new().range(10, 100);

        // Five characters: below the minimum of 10.
        let ctx = JudgmentContext::new("input", "short");
        let result = judge.evaluate(&ctx).await;
        assert!(!result.passed);

        // Within the 10..=100 range.
        let ctx = JudgmentContext::new("input", "This is a properly sized response.");
        let result = judge.evaluate(&ctx).await;
        assert!(result.passed);
    }

    /// `KeywordJudge` matches case-insensitively and fails on missing or forbidden keywords.
    #[tokio::test]
    async fn test_keyword_judge() {
        let judge = KeywordJudge::new()
            .require("rust")
            .require("programming")
            .forbid("python");

        // Both required keywords present, forbidden keyword absent.
        let ctx = JudgmentContext::new("input", "Rust is a great programming language.");
        let result = judge.evaluate(&ctx).await;
        assert!(result.passed);

        // "programming" is missing and must be named in the summary.
        let ctx = JudgmentContext::new("input", "Rust is great.");
        let result = judge.evaluate(&ctx).await;
        assert!(!result.passed);
        assert!(result.summary.contains("programming"));

        // Required keywords present, but the forbidden "python" appears.
        let ctx = JudgmentContext::new("input", "Rust is better than Python for programming.");
        let result = judge.evaluate(&ctx).await;
        assert!(!result.passed);
    }

    /// The workflow retries after a failing first attempt and passes on the second.
    #[tokio::test]
    async fn test_self_correcting_workflow() {
        use std::sync::atomic::{AtomicU32, Ordering};

        let workflow = SelfCorrectingWorkflow::new()
            .add_judge(LengthJudge::new().min(20))
            .quality_threshold(0.8);

        // Counts generator invocations so the first call can return a too-short output.
        let attempt = AtomicU32::new(0);
        let result = workflow
            .execute("Write something", |_prompt| {
                let current = attempt.fetch_add(1, Ordering::SeqCst);
                async move {
                    if current == 0 {
                        Ok("Too short".to_string())
                    } else {
                        Ok(
                            "This is a much longer response that should pass the length check."
                                .to_string(),
                        )
                    }
                }
            })
            .await
            .unwrap();

        assert!(result.passed);
        assert!(result.iterations <= 2);
    }

    /// A generator that never satisfies the judge exhausts the iteration budget.
    #[tokio::test]
    async fn test_workflow_max_iterations() {
        let workflow = SelfCorrectingWorkflow::new()
            // A minimum the five-character output can never reach.
            .add_judge(LengthJudge::new().min(1000))
            .max_iterations(2)
            .quality_threshold(0.9);

        let result = workflow
            .execute("input", |_| async { Ok("short".to_string()) })
            .await
            .unwrap();

        assert!(!result.passed);
        assert_eq!(result.iterations, 2);
    }

    /// Retry feedback contains the percentage score and every correction.
    #[tokio::test]
    async fn test_judgment_feedback() {
        let judgment = Judgment::failed(0.5, "Quality issues found")
            .with_correction("Fix the formatting")
            .with_correction("Add more details");

        let feedback = judgment.feedback_for_retry();
        assert!(feedback.contains("50.0%"));
        assert!(feedback.contains("Fix the formatting"));
        assert!(feedback.contains("Add more details"));
    }

    /// `DimensionScore` threshold check and feedback builder.
    #[tokio::test]
    async fn test_dimension_score() {
        let dim = DimensionScore::new("accuracy", 0.8)
            .with_explanation("Good accuracy overall")
            .with_feedback("Could improve citation quality");

        assert!(dim.passed(0.7));
        assert!(!dim.passed(0.9));
        assert_eq!(
            dim.feedback.as_deref(),
            Some("Could improve citation quality")
        );
    }

    /// `PatternJudge` passes when the required regex matches and fails otherwise.
    #[tokio::test]
    async fn test_pattern_judge() {
        let judge = PatternJudge::new("format_check")
            .require(r"\d{4}-\d{2}-\d{2}", "date format YYYY-MM-DD")
            .unwrap();

        let ctx = JudgmentContext::new("input", "The date is 2024-01-15");
        let result = judge.evaluate(&ctx).await;
        assert!(result.passed);

        let ctx = JudgmentContext::new("input", "The date is January 15");
        let result = judge.evaluate(&ctx).await;
        assert!(!result.passed);
    }

    /// `FnJudge` delegates evaluation to the supplied closure.
    #[tokio::test]
    async fn test_fn_judge() {
        let judge = FnJudge::new("custom", |ctx| {
            // Computed before the future is built, so the future does not borrow `ctx`.
            let has_greeting = ctx.output.to_lowercase().contains("hello");
            Box::pin(async move {
                if has_greeting {
                    Judgment::passed(1.0)
                } else {
                    Judgment::failed(0.0, "Missing greeting")
                }
            })
        });

        let ctx = JudgmentContext::new("input", "Hello, world!");
        let result = judge.evaluate(&ctx).await;
        assert!(result.passed);
    }

    /// Statistics count every completed run and every successful one.
    #[tokio::test]
    async fn test_workflow_stats() {
        let workflow = SelfCorrectingWorkflow::new()
            .add_judge(LengthJudge::new().min(5))
            .quality_threshold(0.8);

        workflow
            .execute("test", |_| async { Ok("Hello World".to_string()) })
            .await
            .unwrap();

        workflow
            .execute("test", |_| async {
                Ok("Another test response".to_string())
            })
            .await
            .unwrap();

        let stats = workflow.stats();
        assert_eq!(stats.total_executions, 2);
        assert_eq!(stats.successful_executions, 2);
    }

    /// `is_improving` compares the two most recent judgments (0.5 -> 0.8 here).
    #[tokio::test]
    async fn test_judgment_context_improving() {
        let j1 = Judgment::failed(0.3, "Poor");
        let j2 = Judgment::failed(0.5, "Better");
        let j3 = Judgment::passed(0.8);

        let ctx = JudgmentContext::new("input", "output").with_previous(vec![j1, j2, j3]);

        assert!(ctx.is_improving());
    }

    /// `improvement` is the last score minus the first (0.9 - 0.3 here).
    #[tokio::test]
    async fn test_result_improvement() {
        let result = SelfCorrectResult {
            output: "final".to_string(),
            passed: true,
            final_score: 0.9,
            iterations: 3,
            judgment_history: vec![
                Judgment::failed(0.3, ""),
                Judgment::failed(0.6, ""),
                Judgment::passed(0.9),
            ],
            duration_ms: 100,
        };

        let improvement = result.improvement().unwrap();
        // Float subtraction is inexact, so compare with a tolerance.
        assert!((improvement - 0.6).abs() < 0.01);
    }
}