1use regex::Regex;
13
14use super::voting::ResponseMetadata;
15
/// Thresholds and toggles that decide which responses get red-flagged.
#[derive(Debug, Clone)]
pub struct RedFlagConfig {
    /// Largest token count a response may report before it is flagged as too long.
    pub max_response_tokens: u32,
    /// When `true`, responses are checked against the validator's expected output format.
    pub require_exact_format: bool,
    /// When `true`, responses containing any of `confusion_patterns` are flagged.
    pub flag_self_correction: bool,
    /// Literal phrases treated as signs of self-correction or confusion.
    pub confusion_patterns: Vec<String>,
    /// Smallest acceptable response length; shorter responses are flagged.
    pub min_response_length: u32,
    /// Largest tolerated fraction of blank lines in a response.
    pub max_empty_line_ratio: f32,
}
32
33impl Default for RedFlagConfig {
34 fn default() -> Self {
35 Self::strict()
36 }
37}
38
39impl RedFlagConfig {
40 pub fn strict() -> Self {
42 Self {
43 max_response_tokens: 750,
44 require_exact_format: true,
45 flag_self_correction: true,
46 confusion_patterns: vec![
47 "Wait,".to_string(),
48 "Actually,".to_string(),
49 "Let me reconsider".to_string(),
50 "I made a mistake".to_string(),
51 "On second thought".to_string(),
52 "Hmm,".to_string(),
53 "I think I".to_string(),
54 "Let me correct".to_string(),
55 "Sorry, I meant".to_string(),
56 "That's not right".to_string(),
57 ],
58 min_response_length: 1,
59 max_empty_line_ratio: 0.5,
60 }
61 }
62
63 pub fn relaxed() -> Self {
65 Self {
66 max_response_tokens: 1500,
67 require_exact_format: false,
68 flag_self_correction: false,
69 confusion_patterns: vec![],
70 min_response_length: 0,
71 max_empty_line_ratio: 0.8,
72 }
73 }
74
75 pub fn builder() -> RedFlagConfigBuilder {
77 RedFlagConfigBuilder::default()
78 }
79}
80
81#[derive(Default)]
83pub struct RedFlagConfigBuilder {
84 config: RedFlagConfig,
85}
86
87impl RedFlagConfigBuilder {
88 pub fn max_response_tokens(mut self, tokens: u32) -> Self {
90 self.config.max_response_tokens = tokens;
91 self
92 }
93
94 pub fn require_exact_format(mut self, require: bool) -> Self {
96 self.config.require_exact_format = require;
97 self
98 }
99
100 pub fn flag_self_correction(mut self, flag: bool) -> Self {
102 self.config.flag_self_correction = flag;
103 self
104 }
105
106 pub fn add_confusion_pattern(mut self, pattern: impl Into<String>) -> Self {
108 self.config.confusion_patterns.push(pattern.into());
109 self
110 }
111
112 pub fn confusion_patterns(mut self, patterns: Vec<String>) -> Self {
114 self.config.confusion_patterns = patterns;
115 self
116 }
117
118 pub fn min_response_length(mut self, length: u32) -> Self {
120 self.config.min_response_length = length;
121 self
122 }
123
124 pub fn max_empty_line_ratio(mut self, ratio: f32) -> Self {
126 self.config.max_empty_line_ratio = ratio;
127 self
128 }
129
130 pub fn build(self) -> RedFlagConfig {
132 self.config
133 }
134}
135
136#[derive(Clone, Debug)]
138pub enum RedFlagResult {
139 Valid,
141 Flagged {
143 reason: RedFlagReason,
145 severity: f32,
147 },
148}
149
150impl RedFlagResult {
151 pub fn is_valid(&self) -> bool {
153 matches!(self, RedFlagResult::Valid)
154 }
155
156 pub fn is_flagged(&self) -> bool {
158 matches!(self, RedFlagResult::Flagged { .. })
159 }
160}
161
/// Specific reason a response was red-flagged.
#[derive(Debug, Clone)]
pub enum RedFlagReason {
    /// The reported token count exceeded the configured limit.
    ResponseTooLong {
        /// Token count reported for the response.
        tokens: u32,
        /// Configured maximum.
        limit: u32,
    },
    /// The response was shorter than the configured minimum.
    ResponseTooShort {
        /// Measured length of the response.
        length: u32,
        /// Configured minimum.
        minimum: u32,
    },
    /// The response did not match the expected output format.
    InvalidFormat {
        /// Description of the format that was expected.
        expected: String,
        /// Sample of what was actually received.
        got: String,
    },
    /// A self-correction phrase was found in the response.
    SelfCorrectionDetected {
        /// The phrase that matched.
        pattern: String,
    },
    /// A confused-reasoning phrase was found in the response.
    ConfusedReasoning {
        /// The phrase that matched.
        pattern: String,
    },
    /// The response could not be parsed.
    ParseError {
        /// Details of the failure.
        message: String,
    },
    /// The response was empty.
    EmptyResponse,
    /// The fraction of empty lines exceeded the configured maximum.
    TooManyEmptyLines {
        /// Observed fraction of empty lines.
        ratio: f32,
        /// Configured maximum fraction.
        max: f32,
    },
    /// The response was not valid JSON.
    InvalidJson {
        /// Details of the failure.
        message: String,
    },
    /// A required field was missing.
    MissingField {
        /// Name of the missing field.
        field: String,
    },
    /// The response was cut off.
    Truncated {
        /// Reason reported for the truncation.
        reason: String,
    },
}
226
227impl std::fmt::Display for RedFlagReason {
228 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
229 match self {
230 RedFlagReason::ResponseTooLong { tokens, limit } => {
231 write!(f, "Response too long: {} tokens > {} limit", tokens, limit)
232 }
233 RedFlagReason::ResponseTooShort { length, minimum } => {
234 write!(
235 f,
236 "Response too short: {} chars < {} minimum",
237 length, minimum
238 )
239 }
240 RedFlagReason::InvalidFormat { expected, got } => {
241 write!(f, "Invalid format: expected {}, got {}", expected, got)
242 }
243 RedFlagReason::SelfCorrectionDetected { pattern } => {
244 write!(f, "Self-correction detected: '{}'", pattern)
245 }
246 RedFlagReason::ConfusedReasoning { pattern } => {
247 write!(f, "Confused reasoning: '{}'", pattern)
248 }
249 RedFlagReason::ParseError { message } => {
250 write!(f, "Parse error: {}", message)
251 }
252 RedFlagReason::EmptyResponse => write!(f, "Empty response"),
253 RedFlagReason::TooManyEmptyLines { ratio, max } => {
254 write!(
255 f,
256 "Too many empty lines: {:.1}% > {:.1}% max",
257 ratio * 100.0,
258 max * 100.0
259 )
260 }
261 RedFlagReason::InvalidJson { message } => {
262 write!(f, "Invalid JSON: {}", message)
263 }
264 RedFlagReason::MissingField { field } => {
265 write!(f, "Missing required field: {}", field)
266 }
267 RedFlagReason::Truncated { reason } => {
268 write!(f, "Response truncated: {}", reason)
269 }
270 }
271 }
272}
273
274pub trait RedFlagValidator: Send + Sync {
276 fn validate(&self, response: &str, metadata: &ResponseMetadata) -> RedFlagResult;
278}
279
280pub struct StandardRedFlagValidator {
282 config: RedFlagConfig,
283 expected_format: Option<OutputFormat>,
284 confusion_regexes: Vec<Regex>,
285}
286
287impl StandardRedFlagValidator {
288 pub fn new(config: RedFlagConfig, expected_format: Option<OutputFormat>) -> Self {
290 let confusion_regexes = config
292 .confusion_patterns
293 .iter()
294 .filter_map(|p| {
295 Regex::new(®ex::escape(p)).ok()
297 })
298 .collect();
299
300 Self {
301 config,
302 expected_format,
303 confusion_regexes,
304 }
305 }
306
307 pub fn strict() -> Self {
309 Self::new(RedFlagConfig::strict(), None)
310 }
311
312 pub fn with_format(format: OutputFormat) -> Self {
314 Self::new(RedFlagConfig::strict(), Some(format))
315 }
316
317 pub fn set_expected_format(&mut self, format: Option<OutputFormat>) {
319 self.expected_format = format;
320 }
321
322 fn check_length(&self, response: &str, metadata: &ResponseMetadata) -> Option<RedFlagResult> {
324 if response.trim().is_empty() {
326 return Some(RedFlagResult::Flagged {
327 reason: RedFlagReason::EmptyResponse,
328 severity: 1.0,
329 });
330 }
331
332 if (response.len() as u32) < self.config.min_response_length {
334 return Some(RedFlagResult::Flagged {
335 reason: RedFlagReason::ResponseTooShort {
336 length: response.len() as u32,
337 minimum: self.config.min_response_length,
338 },
339 severity: 0.9,
340 });
341 }
342
343 if metadata.token_count > self.config.max_response_tokens {
345 return Some(RedFlagResult::Flagged {
346 reason: RedFlagReason::ResponseTooLong {
347 tokens: metadata.token_count,
348 limit: self.config.max_response_tokens,
349 },
350 severity: 0.8,
351 });
352 }
353
354 None
355 }
356
357 fn check_self_correction(&self, response: &str) -> Option<RedFlagResult> {
359 if !self.config.flag_self_correction {
360 return None;
361 }
362
363 for (regex, pattern) in self
364 .confusion_regexes
365 .iter()
366 .zip(&self.config.confusion_patterns)
367 {
368 if regex.is_match(response) {
369 return Some(RedFlagResult::Flagged {
370 reason: RedFlagReason::SelfCorrectionDetected {
371 pattern: pattern.clone(),
372 },
373 severity: 0.7,
374 });
375 }
376 }
377
378 None
379 }
380
381 fn check_format(&self, response: &str) -> Option<RedFlagResult> {
383 if !self.config.require_exact_format {
384 return None;
385 }
386
387 if let Some(ref format) = self.expected_format
388 && !format.matches(response)
389 {
390 return Some(RedFlagResult::Flagged {
391 reason: RedFlagReason::InvalidFormat {
392 expected: format.description(),
393 got: self.extract_format_sample(response),
394 },
395 severity: 0.9,
396 });
397 }
398
399 None
400 }
401
402 fn check_truncation(&self, metadata: &ResponseMetadata) -> Option<RedFlagResult> {
404 if let Some(ref reason) = metadata.finish_reason {
405 let reason_lower = reason.to_lowercase();
406 if reason_lower.contains("length") || reason_lower.contains("max_tokens") {
407 return Some(RedFlagResult::Flagged {
408 reason: RedFlagReason::Truncated {
409 reason: reason.clone(),
410 },
411 severity: 0.85,
412 });
413 }
414 }
415 None
416 }
417
418 fn check_empty_lines(&self, response: &str) -> Option<RedFlagResult> {
420 let lines: Vec<&str> = response.lines().collect();
421 if lines.is_empty() {
422 return None;
423 }
424
425 let empty_count = lines.iter().filter(|l| l.trim().is_empty()).count();
426 let ratio = empty_count as f32 / lines.len() as f32;
427
428 if ratio > self.config.max_empty_line_ratio {
429 return Some(RedFlagResult::Flagged {
430 reason: RedFlagReason::TooManyEmptyLines {
431 ratio,
432 max: self.config.max_empty_line_ratio,
433 },
434 severity: 0.6,
435 });
436 }
437
438 None
439 }
440
441 fn extract_format_sample(&self, response: &str) -> String {
443 let trimmed = response.trim();
444 if trimmed.len() <= 50 {
445 trimmed.to_string()
446 } else {
447 format!("{}...", &trimmed[..50])
448 }
449 }
450}
451
452impl RedFlagValidator for StandardRedFlagValidator {
453 fn validate(&self, response: &str, metadata: &ResponseMetadata) -> RedFlagResult {
454 if let Some(result) = self.check_length(response, metadata) {
458 return result;
459 }
460
461 if let Some(result) = self.check_truncation(metadata) {
463 return result;
464 }
465
466 if let Some(result) = self.check_format(response) {
468 return result;
469 }
470
471 if let Some(result) = self.check_self_correction(response) {
473 return result;
474 }
475
476 if let Some(result) = self.check_empty_lines(response) {
478 return result;
479 }
480
481 RedFlagResult::Valid
482 }
483}
484
485#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
487pub enum OutputFormat {
488 Exact(String),
490 Pattern(String),
492 Json,
494 JsonWithFields(Vec<String>),
496 Markers {
498 start: String,
500 end: String,
502 },
503 OneOf(Vec<String>),
505 Custom {
507 description: String,
509 validator_id: String,
511 },
512}
513
514impl OutputFormat {
515 pub fn matches(&self, response: &str) -> bool {
517 let trimmed = response.trim();
518 match self {
519 OutputFormat::Exact(s) => trimmed == s.trim(),
520 OutputFormat::Pattern(pattern) => Regex::new(pattern)
521 .map(|re| re.is_match(trimmed))
522 .unwrap_or(false),
523 OutputFormat::Json => serde_json::from_str::<serde_json::Value>(trimmed).is_ok(),
524 OutputFormat::JsonWithFields(fields) => {
525 if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed)
526 && let Some(obj) = value.as_object()
527 {
528 return fields.iter().all(|f| obj.contains_key(f));
529 }
530 false
531 }
532 OutputFormat::Markers { start, end } => {
533 trimmed.contains(start) && trimmed.contains(end)
534 }
535 OutputFormat::OneOf(options) => options.iter().any(|o| trimmed == o.trim()),
536 OutputFormat::Custom { .. } => {
537 true
540 }
541 }
542 }
543
544 pub fn description(&self) -> String {
546 match self {
547 OutputFormat::Exact(s) => format!("exact: '{}'", s),
548 OutputFormat::Pattern(p) => format!("pattern: {}", p),
549 OutputFormat::Json => "valid JSON".to_string(),
550 OutputFormat::JsonWithFields(fields) => {
551 format!("JSON with fields: {}", fields.join(", "))
552 }
553 OutputFormat::Markers { start, end } => format!("markers: {}...{}", start, end),
554 OutputFormat::OneOf(options) => format!("one of: {}", options.join(", ")),
555 OutputFormat::Custom { description, .. } => description.clone(),
556 }
557 }
558}
559
/// Validator that never flags anything.
pub struct AcceptAllValidator;
562
563impl RedFlagValidator for AcceptAllValidator {
564 fn validate(&self, _response: &str, _metadata: &ResponseMetadata) -> RedFlagResult {
565 RedFlagResult::Valid
566 }
567}
568
569pub struct CompositeValidator {
571 validators: Vec<Box<dyn RedFlagValidator>>,
572}
573
574impl CompositeValidator {
575 pub fn new() -> Self {
577 Self {
578 validators: Vec::new(),
579 }
580 }
581
582 pub fn with_validator(mut self, validator: Box<dyn RedFlagValidator>) -> Self {
584 self.validators.push(validator);
585 self
586 }
587}
588
589impl Default for CompositeValidator {
590 fn default() -> Self {
591 Self::new()
592 }
593}
594
595impl RedFlagValidator for CompositeValidator {
596 fn validate(&self, response: &str, metadata: &ResponseMetadata) -> RedFlagResult {
597 for validator in &self.validators {
598 let result = validator.validate(response, metadata);
599 if result.is_flagged() {
600 return result;
601 }
602 }
603 RedFlagResult::Valid
604 }
605}
606
#[cfg(test)]
mod tests {
    use super::*;

    /// Metadata fixture in which only the token count varies.
    fn meta(tokens: u32) -> ResponseMetadata {
        ResponseMetadata {
            token_count: tokens,
            response_time_ms: 100,
            format_valid: true,
            finish_reason: None,
            model: None,
        }
    }

    #[test]
    fn test_valid_response() {
        let verdict =
            StandardRedFlagValidator::strict().validate("This is a valid response.", &meta(50));
        assert!(verdict.is_valid());
    }

    #[test]
    fn test_empty_response() {
        let verdict = StandardRedFlagValidator::strict().validate("", &meta(0));
        assert!(matches!(
            verdict,
            RedFlagResult::Flagged {
                reason: RedFlagReason::EmptyResponse,
                ..
            }
        ));
    }

    #[test]
    fn test_response_too_long() {
        let verdict = StandardRedFlagValidator::strict().validate("Some response", &meta(800));
        assert!(matches!(
            verdict,
            RedFlagResult::Flagged {
                reason: RedFlagReason::ResponseTooLong { .. },
                ..
            }
        ));
    }

    #[test]
    fn test_self_correction_detected() {
        let verdict = StandardRedFlagValidator::strict().validate(
            "Wait, I think I made an error. Let me reconsider.",
            &meta(50),
        );
        assert!(matches!(
            verdict,
            RedFlagResult::Flagged {
                reason: RedFlagReason::SelfCorrectionDetected { .. },
                ..
            }
        ));
    }

    #[test]
    fn test_confused_reasoning() {
        let verdict = StandardRedFlagValidator::strict()
            .validate("Actually, that's not right. On second thought...", &meta(50));
        assert!(verdict.is_flagged());
    }

    #[test]
    fn test_format_validation_exact() {
        let validator =
            StandardRedFlagValidator::with_format(OutputFormat::Exact("hello".to_string()));
        let check = |text: &str| validator.validate(text, &meta(10));

        assert!(check("hello").is_valid());
        assert!(check(" hello ").is_valid());
        assert!(check("world").is_flagged());
    }

    #[test]
    fn test_format_validation_json() {
        let validator = StandardRedFlagValidator::with_format(OutputFormat::Json);

        let good = validator.validate(r#"{"key": "value"}"#, &meta(20));
        assert!(good.is_valid());

        let bad = validator.validate("not json", &meta(10));
        assert!(bad.is_flagged());
    }

    #[test]
    fn test_format_validation_json_with_fields() {
        let required = vec!["name".to_string(), "value".to_string()];
        let validator =
            StandardRedFlagValidator::with_format(OutputFormat::JsonWithFields(required));

        let complete = validator.validate(r#"{"name": "test", "value": 42}"#, &meta(30));
        assert!(complete.is_valid());

        let incomplete = validator.validate(r#"{"name": "test"}"#, &meta(20));
        assert!(incomplete.is_flagged());
    }

    #[test]
    fn test_format_validation_markers() {
        let fence = "```".to_string();
        let validator = StandardRedFlagValidator::with_format(OutputFormat::Markers {
            start: fence.clone(),
            end: fence,
        });

        let fenced = validator.validate("```code here```", &meta(20));
        assert!(fenced.is_valid());

        let bare = validator.validate("no markers", &meta(10));
        assert!(bare.is_flagged());
    }

    #[test]
    fn test_format_validation_one_of() {
        let choices = ["yes", "no", "maybe"]
            .iter()
            .map(|choice| choice.to_string())
            .collect();
        let validator = StandardRedFlagValidator::with_format(OutputFormat::OneOf(choices));

        assert!(validator.validate("yes", &meta(5)).is_valid());
        assert!(validator.validate("no", &meta(5)).is_valid());
        assert!(validator.validate("perhaps", &meta(10)).is_flagged());
    }

    #[test]
    fn test_truncation_detection() {
        let metadata = ResponseMetadata {
            finish_reason: Some("length".to_string()),
            ..meta(50)
        };

        let verdict = StandardRedFlagValidator::strict().validate("Truncated response", &metadata);
        assert!(matches!(
            verdict,
            RedFlagResult::Flagged {
                reason: RedFlagReason::Truncated { .. },
                ..
            }
        ));
    }

    #[test]
    fn test_relaxed_config() {
        let validator = StandardRedFlagValidator::new(RedFlagConfig::relaxed(), None);

        let verdict = validator.validate("Wait, let me reconsider this.", &meta(50));
        assert!(verdict.is_valid());
    }

    #[test]
    fn test_config_builder() {
        let config = RedFlagConfig::builder()
            .max_response_tokens(500)
            .flag_self_correction(false)
            .add_confusion_pattern("Oops")
            .build();

        assert_eq!(config.max_response_tokens, 500);
        assert!(!config.flag_self_correction);
        assert!(config.confusion_patterns.iter().any(|p| p == "Oops"));
    }

    #[test]
    fn test_accept_all_validator() {
        let validator = AcceptAllValidator;

        assert!(validator.validate("", &meta(0)).is_valid());
        assert!(validator.validate("anything", &meta(10000)).is_valid());
    }

    #[test]
    fn test_composite_validator() {
        let strict: Box<dyn RedFlagValidator> = Box::new(StandardRedFlagValidator::strict());
        let validator = CompositeValidator::new().with_validator(strict);

        assert!(validator.validate("valid", &meta(10)).is_valid());
        assert!(validator.validate("", &meta(0)).is_flagged());
    }

    #[test]
    fn test_red_flag_reason_display() {
        let too_long = RedFlagReason::ResponseTooLong {
            tokens: 800,
            limit: 750,
        };
        assert_eq!(
            too_long.to_string(),
            "Response too long: 800 tokens > 750 limit"
        );

        let self_correction = RedFlagReason::SelfCorrectionDetected {
            pattern: "Wait,".to_string(),
        };
        assert!(self_correction.to_string().contains("Wait,"));
    }

    #[test]
    fn test_empty_line_ratio() {
        let config = RedFlagConfig::builder()
            .max_empty_line_ratio(0.3)
            .flag_self_correction(false)
            .build();
        let validator = StandardRedFlagValidator::new(config, None);

        // Three of the five lines are blank, which exceeds the 0.3 limit.
        let verdict = validator.validate("line1\n\n\n\nline2", &meta(10));
        assert!(verdict.is_flagged());
    }
}