use regex::Regex;
use super::voting::ResponseMetadata;
#[derive(Clone, Debug)]
pub struct RedFlagConfig {
pub max_response_tokens: u32,
pub require_exact_format: bool,
pub flag_self_correction: bool,
pub confusion_patterns: Vec<String>,
pub min_response_length: u32,
pub max_empty_line_ratio: f32,
}
impl Default for RedFlagConfig {
fn default() -> Self {
Self::strict()
}
}
impl RedFlagConfig {
pub fn strict() -> Self {
Self {
max_response_tokens: 750,
require_exact_format: true,
flag_self_correction: true,
confusion_patterns: vec![
"Wait,".to_string(),
"Actually,".to_string(),
"Let me reconsider".to_string(),
"I made a mistake".to_string(),
"On second thought".to_string(),
"Hmm,".to_string(),
"I think I".to_string(),
"Let me correct".to_string(),
"Sorry, I meant".to_string(),
"That's not right".to_string(),
],
min_response_length: 1,
max_empty_line_ratio: 0.5,
}
}
pub fn relaxed() -> Self {
Self {
max_response_tokens: 1500,
require_exact_format: false,
flag_self_correction: false,
confusion_patterns: vec![],
min_response_length: 0,
max_empty_line_ratio: 0.8,
}
}
pub fn builder() -> RedFlagConfigBuilder {
RedFlagConfigBuilder::default()
}
}
#[derive(Default)]
pub struct RedFlagConfigBuilder {
config: RedFlagConfig,
}
impl RedFlagConfigBuilder {
pub fn max_response_tokens(mut self, tokens: u32) -> Self {
self.config.max_response_tokens = tokens;
self
}
pub fn require_exact_format(mut self, require: bool) -> Self {
self.config.require_exact_format = require;
self
}
pub fn flag_self_correction(mut self, flag: bool) -> Self {
self.config.flag_self_correction = flag;
self
}
pub fn add_confusion_pattern(mut self, pattern: impl Into<String>) -> Self {
self.config.confusion_patterns.push(pattern.into());
self
}
pub fn confusion_patterns(mut self, patterns: Vec<String>) -> Self {
self.config.confusion_patterns = patterns;
self
}
pub fn min_response_length(mut self, length: u32) -> Self {
self.config.min_response_length = length;
self
}
pub fn max_empty_line_ratio(mut self, ratio: f32) -> Self {
self.config.max_empty_line_ratio = ratio;
self
}
pub fn build(self) -> RedFlagConfig {
self.config
}
}
/// Outcome of running a response through a red-flag validator.
#[derive(Clone, Debug)]
pub enum RedFlagResult {
    /// No check objected to the response.
    Valid,
    /// A check rejected the response.
    Flagged {
        /// Which check fired and why.
        reason: RedFlagReason,
        /// Score attached by the check that fired.
        severity: f32,
    },
}

impl RedFlagResult {
    /// Returns `true` when the response passed validation.
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Flagged { .. } => false,
        }
    }

    /// Returns `true` when the response was flagged.
    pub fn is_flagged(&self) -> bool {
        !self.is_valid()
    }
}
/// The specific problem that caused a response to be red-flagged.
///
/// The `Display` implementation renders each variant as one human-readable
/// line.
#[derive(Clone, Debug)]
pub enum RedFlagReason {
    /// The reported token count exceeded the configured limit.
    ResponseTooLong { tokens: u32, limit: u32 },
    /// The response was shorter than the configured minimum.
    ResponseTooShort { length: u32, minimum: u32 },
    /// The response did not satisfy the expected output format.
    InvalidFormat { expected: String, got: String },
    /// One of the configured self-correction phrases was found.
    SelfCorrectionDetected { pattern: String },
    /// A phrase indicating confused reasoning was found.
    ConfusedReasoning { pattern: String },
    /// The response could not be parsed.
    ParseError { message: String },
    /// The response was empty.
    EmptyResponse,
    /// Too large a share of the response's lines were blank.
    TooManyEmptyLines { ratio: f32, max: f32 },
    /// The response was not valid JSON.
    InvalidJson { message: String },
    /// A required field was absent.
    MissingField { field: String },
    /// Generation stopped early.
    Truncated { reason: String },
}

impl std::fmt::Display for RedFlagReason {
    /// Writes a one-line description of the reason.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ResponseTooLong { tokens, limit } => {
                write!(f, "Response too long: {} tokens > {} limit", tokens, limit)
            }
            Self::ResponseTooShort { length, minimum } => write!(
                f,
                "Response too short: {} chars < {} minimum",
                length, minimum
            ),
            Self::InvalidFormat { expected, got } => {
                write!(f, "Invalid format: expected {}, got {}", expected, got)
            }
            Self::SelfCorrectionDetected { pattern } => {
                write!(f, "Self-correction detected: '{}'", pattern)
            }
            Self::ConfusedReasoning { pattern } => {
                write!(f, "Confused reasoning: '{}'", pattern)
            }
            Self::ParseError { message } => write!(f, "Parse error: {}", message),
            Self::EmptyResponse => f.write_str("Empty response"),
            Self::TooManyEmptyLines { ratio, max } => {
                // Ratios are stored as fractions; show them as percentages.
                let observed_pct = ratio * 100.0;
                let allowed_pct = max * 100.0;
                write!(
                    f,
                    "Too many empty lines: {:.1}% > {:.1}% max",
                    observed_pct, allowed_pct
                )
            }
            Self::InvalidJson { message } => write!(f, "Invalid JSON: {}", message),
            Self::MissingField { field } => write!(f, "Missing required field: {}", field),
            Self::Truncated { reason } => write!(f, "Response truncated: {}", reason),
        }
    }
}
/// Contract for anything that can screen a response for red flags.
///
/// Implementors must be `Send + Sync`.
pub trait RedFlagValidator: Send + Sync {
/// Inspects `response` together with its `metadata` and returns either
/// `RedFlagResult::Valid` or a `RedFlagResult::Flagged` describing the problem.
fn validate(&self, response: &str, metadata: &ResponseMetadata) -> RedFlagResult;
}
pub struct StandardRedFlagValidator {
config: RedFlagConfig,
expected_format: Option<OutputFormat>,
confusion_regexes: Vec<Regex>,
}
impl StandardRedFlagValidator {
pub fn new(config: RedFlagConfig, expected_format: Option<OutputFormat>) -> Self {
let confusion_regexes = config
.confusion_patterns
.iter()
.filter_map(|p| {
Regex::new(®ex::escape(p)).ok()
})
.collect();
Self {
config,
expected_format,
confusion_regexes,
}
}
pub fn strict() -> Self {
Self::new(RedFlagConfig::strict(), None)
}
pub fn with_format(format: OutputFormat) -> Self {
Self::new(RedFlagConfig::strict(), Some(format))
}
pub fn set_expected_format(&mut self, format: Option<OutputFormat>) {
self.expected_format = format;
}
fn check_length(&self, response: &str, metadata: &ResponseMetadata) -> Option<RedFlagResult> {
if response.trim().is_empty() {
return Some(RedFlagResult::Flagged {
reason: RedFlagReason::EmptyResponse,
severity: 1.0,
});
}
if (response.len() as u32) < self.config.min_response_length {
return Some(RedFlagResult::Flagged {
reason: RedFlagReason::ResponseTooShort {
length: response.len() as u32,
minimum: self.config.min_response_length,
},
severity: 0.9,
});
}
if metadata.token_count > self.config.max_response_tokens {
return Some(RedFlagResult::Flagged {
reason: RedFlagReason::ResponseTooLong {
tokens: metadata.token_count,
limit: self.config.max_response_tokens,
},
severity: 0.8,
});
}
None
}
fn check_self_correction(&self, response: &str) -> Option<RedFlagResult> {
if !self.config.flag_self_correction {
return None;
}
for (regex, pattern) in self
.confusion_regexes
.iter()
.zip(&self.config.confusion_patterns)
{
if regex.is_match(response) {
return Some(RedFlagResult::Flagged {
reason: RedFlagReason::SelfCorrectionDetected {
pattern: pattern.clone(),
},
severity: 0.7,
});
}
}
None
}
fn check_format(&self, response: &str) -> Option<RedFlagResult> {
if !self.config.require_exact_format {
return None;
}
if let Some(ref format) = self.expected_format
&& !format.matches(response)
{
return Some(RedFlagResult::Flagged {
reason: RedFlagReason::InvalidFormat {
expected: format.description(),
got: self.extract_format_sample(response),
},
severity: 0.9,
});
}
None
}
fn check_truncation(&self, metadata: &ResponseMetadata) -> Option<RedFlagResult> {
if let Some(ref reason) = metadata.finish_reason {
let reason_lower = reason.to_lowercase();
if reason_lower.contains("length") || reason_lower.contains("max_tokens") {
return Some(RedFlagResult::Flagged {
reason: RedFlagReason::Truncated {
reason: reason.clone(),
},
severity: 0.85,
});
}
}
None
}
fn check_empty_lines(&self, response: &str) -> Option<RedFlagResult> {
let lines: Vec<&str> = response.lines().collect();
if lines.is_empty() {
return None;
}
let empty_count = lines.iter().filter(|l| l.trim().is_empty()).count();
let ratio = empty_count as f32 / lines.len() as f32;
if ratio > self.config.max_empty_line_ratio {
return Some(RedFlagResult::Flagged {
reason: RedFlagReason::TooManyEmptyLines {
ratio,
max: self.config.max_empty_line_ratio,
},
severity: 0.6,
});
}
None
}
fn extract_format_sample(&self, response: &str) -> String {
let trimmed = response.trim();
if trimmed.len() <= 50 {
trimmed.to_string()
} else {
format!("{}...", &trimmed[..50])
}
}
}
impl RedFlagValidator for StandardRedFlagValidator {
    /// Runs the checks in a fixed order (length, truncation, format,
    /// self-correction, blank lines) and returns the first flag raised, or
    /// `RedFlagResult::Valid` when none fires.
    fn validate(&self, response: &str, metadata: &ResponseMetadata) -> RedFlagResult {
        // `or_else` is lazy, so later checks only run if earlier ones passed.
        self.check_length(response, metadata)
            .or_else(|| self.check_truncation(metadata))
            .or_else(|| self.check_format(response))
            .or_else(|| self.check_self_correction(response))
            .or_else(|| self.check_empty_lines(response))
            .unwrap_or(RedFlagResult::Valid)
    }
}
/// Describes the shape a response is expected to have.
///
/// See `OutputFormat::matches` for how each variant is evaluated and
/// `OutputFormat::description` for the text used when reporting a mismatch.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub enum OutputFormat {
/// The trimmed response must equal this string (also compared trimmed).
Exact(String),
/// The trimmed response must match this regular expression; a pattern
/// that fails to compile never matches.
Pattern(String),
/// The trimmed response must parse as JSON.
Json,
/// The trimmed response must be a JSON object containing every listed key.
JsonWithFields(Vec<String>),
/// The response must contain both the `start` and the `end` marker text.
Markers {
start: String,
end: String,
},
/// The trimmed response must equal one of these options (compared trimmed).
OneOf(Vec<String>),
/// Externally validated format. `description` is used for reporting;
/// `OutputFormat::matches` accepts any response for this variant and does
/// not consult `validator_id`.
Custom {
description: String,
validator_id: String,
},
}
impl OutputFormat {
    /// Returns `true` when `response` (with surrounding whitespace trimmed)
    /// satisfies this format.
    ///
    /// * `Exact` / `OneOf`: equality against the trimmed expected string(s).
    /// * `Pattern`: regex search; a pattern that fails to compile never
    ///   matches. The pattern is compiled on every call.
    /// * `Json`: the text parses as JSON.
    /// * `JsonWithFields`: the text parses as a JSON object containing every
    ///   listed key.
    /// * `Markers`: the `start` marker occurs and the `end` marker occurs
    ///   somewhere after it. The two are matched as separate occurrences, so
    ///   identical start/end markers (for example a code fence) must appear
    ///   twice.
    /// * `Custom`: always accepted; this type has no way to evaluate it.
    pub fn matches(&self, response: &str) -> bool {
        let trimmed = response.trim();
        match self {
            OutputFormat::Exact(expected) => trimmed == expected.trim(),
            OutputFormat::Pattern(pattern) => Regex::new(pattern)
                .map(|re| re.is_match(trimmed))
                .unwrap_or(false),
            OutputFormat::Json => serde_json::from_str::<serde_json::Value>(trimmed).is_ok(),
            OutputFormat::JsonWithFields(fields) => {
                match serde_json::from_str::<serde_json::Value>(trimmed) {
                    Ok(serde_json::Value::Object(object)) => fields
                        .iter()
                        .all(|field| object.contains_key(field.as_str())),
                    _ => false,
                }
            }
            OutputFormat::Markers { start, end } => {
                // Look for `end` only in the text after the first `start`, so
                // one occurrence cannot serve as both markers and the order
                // of the two is respected.
                trimmed.find(start.as_str()).map_or(false, |at| {
                    trimmed[at + start.len()..].contains(end.as_str())
                })
            }
            OutputFormat::OneOf(options) => options.iter().any(|option| trimmed == option.trim()),
            OutputFormat::Custom { .. } => true,
        }
    }

    /// Human-readable description of the expected format, used in
    /// `InvalidFormat` reports.
    pub fn description(&self) -> String {
        match self {
            OutputFormat::Exact(s) => format!("exact: '{}'", s),
            OutputFormat::Pattern(p) => format!("pattern: {}", p),
            OutputFormat::Json => "valid JSON".to_string(),
            OutputFormat::JsonWithFields(fields) => {
                format!("JSON with fields: {}", fields.join(", "))
            }
            OutputFormat::Markers { start, end } => format!("markers: {}...{}", start, end),
            OutputFormat::OneOf(options) => format!("one of: {}", options.join(", ")),
            OutputFormat::Custom { description, .. } => description.clone(),
        }
    }
}
/// Validator that never flags anything.
pub struct AcceptAllValidator;

impl RedFlagValidator for AcceptAllValidator {
    /// Ignores both arguments and always reports `RedFlagResult::Valid`.
    fn validate(&self, _: &str, _: &ResponseMetadata) -> RedFlagResult {
        RedFlagResult::Valid
    }
}
/// Runs several validators in sequence and reports the first flag raised.
pub struct CompositeValidator {
    /// Validators in the order they were added, which is the order they run.
    validators: Vec<Box<dyn RedFlagValidator>>,
}

impl CompositeValidator {
    /// Creates a composite with no validators; it accepts every response.
    pub fn new() -> Self {
        CompositeValidator { validators: vec![] }
    }

    /// Appends `validator` to the end of the chain and returns the composite.
    pub fn with_validator(self, validator: Box<dyn RedFlagValidator>) -> Self {
        let mut validators = self.validators;
        validators.push(validator);
        CompositeValidator { validators }
    }
}

impl Default for CompositeValidator {
    /// Same as [`CompositeValidator::new`].
    fn default() -> Self {
        CompositeValidator::new()
    }
}

impl RedFlagValidator for CompositeValidator {
    /// Returns the first flagged result produced by the chain, or
    /// `RedFlagResult::Valid` if every validator accepts the response.
    /// Validators after the first one that flags are not invoked.
    fn validate(&self, response: &str, metadata: &ResponseMetadata) -> RedFlagResult {
        self.validators
            .iter()
            .map(|validator| validator.validate(response, metadata))
            .find(|outcome| outcome.is_flagged())
            .unwrap_or(RedFlagResult::Valid)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds metadata reporting `tokens` tokens and otherwise neutral values.
    fn make_metadata(tokens: u32) -> ResponseMetadata {
        ResponseMetadata {
            token_count: tokens,
            response_time_ms: 100,
            format_valid: true,
            finish_reason: None,
            model: None,
        }
    }

    /// Validates `response` with `validator`, reporting `tokens` tokens.
    fn check(validator: &dyn RedFlagValidator, response: &str, tokens: u32) -> RedFlagResult {
        validator.validate(response, &make_metadata(tokens))
    }

    #[test]
    fn test_valid_response() {
        let validator = StandardRedFlagValidator::strict();
        assert!(check(&validator, "This is a valid response.", 50).is_valid());
    }

    #[test]
    fn test_empty_response() {
        let validator = StandardRedFlagValidator::strict();
        let outcome = check(&validator, "", 0);
        assert!(matches!(
            outcome,
            RedFlagResult::Flagged {
                reason: RedFlagReason::EmptyResponse,
                ..
            }
        ));
    }

    #[test]
    fn test_response_too_long() {
        let validator = StandardRedFlagValidator::strict();
        let outcome = check(&validator, "Some response", 800);
        assert!(matches!(
            outcome,
            RedFlagResult::Flagged {
                reason: RedFlagReason::ResponseTooLong { .. },
                ..
            }
        ));
    }

    #[test]
    fn test_self_correction_detected() {
        let validator = StandardRedFlagValidator::strict();
        let outcome = check(
            &validator,
            "Wait, I think I made an error. Let me reconsider.",
            50,
        );
        assert!(matches!(
            outcome,
            RedFlagResult::Flagged {
                reason: RedFlagReason::SelfCorrectionDetected { .. },
                ..
            }
        ));
    }

    #[test]
    fn test_confused_reasoning() {
        let validator = StandardRedFlagValidator::strict();
        let outcome = check(
            &validator,
            "Actually, that's not right. On second thought...",
            50,
        );
        assert!(outcome.is_flagged());
    }

    #[test]
    fn test_format_validation_exact() {
        let validator =
            StandardRedFlagValidator::with_format(OutputFormat::Exact("hello".to_string()));
        assert!(check(&validator, "hello", 10).is_valid());
        assert!(check(&validator, " hello ", 10).is_valid());
        assert!(check(&validator, "world", 10).is_flagged());
    }

    #[test]
    fn test_format_validation_json() {
        let validator = StandardRedFlagValidator::with_format(OutputFormat::Json);
        assert!(check(&validator, r#"{"key": "value"}"#, 20).is_valid());
        assert!(check(&validator, "not json", 10).is_flagged());
    }

    #[test]
    fn test_format_validation_json_with_fields() {
        let required = vec!["name".to_string(), "value".to_string()];
        let validator =
            StandardRedFlagValidator::with_format(OutputFormat::JsonWithFields(required));
        assert!(check(&validator, r#"{"name": "test", "value": 42}"#, 30).is_valid());
        assert!(check(&validator, r#"{"name": "test"}"#, 20).is_flagged());
    }

    #[test]
    fn test_format_validation_markers() {
        let validator = StandardRedFlagValidator::with_format(OutputFormat::Markers {
            start: "```".to_string(),
            end: "```".to_string(),
        });
        assert!(check(&validator, "```code here```", 20).is_valid());
        assert!(check(&validator, "no markers", 10).is_flagged());
    }

    #[test]
    fn test_format_validation_one_of() {
        let options = vec!["yes".to_string(), "no".to_string(), "maybe".to_string()];
        let validator = StandardRedFlagValidator::with_format(OutputFormat::OneOf(options));
        assert!(check(&validator, "yes", 5).is_valid());
        assert!(check(&validator, "no", 5).is_valid());
        assert!(check(&validator, "perhaps", 10).is_flagged());
    }

    #[test]
    fn test_truncation_detection() {
        let validator = StandardRedFlagValidator::strict();
        let metadata = ResponseMetadata {
            finish_reason: Some("length".to_string()),
            ..make_metadata(50)
        };
        let outcome = validator.validate("Truncated response", &metadata);
        assert!(matches!(
            outcome,
            RedFlagResult::Flagged {
                reason: RedFlagReason::Truncated { .. },
                ..
            }
        ));
    }

    #[test]
    fn test_relaxed_config() {
        let validator = StandardRedFlagValidator::new(RedFlagConfig::relaxed(), None);
        assert!(check(&validator, "Wait, let me reconsider this.", 50).is_valid());
    }

    #[test]
    fn test_config_builder() {
        let config = RedFlagConfig::builder()
            .max_response_tokens(500)
            .flag_self_correction(false)
            .add_confusion_pattern("Oops")
            .build();
        assert_eq!(config.max_response_tokens, 500);
        assert!(!config.flag_self_correction);
        assert!(config.confusion_patterns.contains(&"Oops".to_string()));
    }

    #[test]
    fn test_accept_all_validator() {
        let validator = AcceptAllValidator;
        assert!(check(&validator, "", 0).is_valid());
        assert!(check(&validator, "anything", 10000).is_valid());
    }

    #[test]
    fn test_composite_validator() {
        let validator =
            CompositeValidator::new().with_validator(Box::new(StandardRedFlagValidator::strict()));
        assert!(check(&validator, "valid", 10).is_valid());
        assert!(check(&validator, "", 0).is_flagged());
    }

    #[test]
    fn test_red_flag_reason_display() {
        let too_long = RedFlagReason::ResponseTooLong {
            tokens: 800,
            limit: 750,
        };
        assert_eq!(
            too_long.to_string(),
            "Response too long: 800 tokens > 750 limit"
        );

        let self_correction = RedFlagReason::SelfCorrectionDetected {
            pattern: "Wait,".to_string(),
        };
        assert!(self_correction.to_string().contains("Wait,"));
    }

    #[test]
    fn test_empty_line_ratio() {
        let config = RedFlagConfig::builder()
            .max_empty_line_ratio(0.3)
            .flag_self_correction(false)
            .build();
        let validator = StandardRedFlagValidator::new(config, None);
        // Three of the five lines are blank, which exceeds the 30% limit.
        assert!(check(&validator, "line1\n\n\n\nline2", 10).is_flagged());
    }
}