use crate::patterns::*;
use crate::constants::*;
/// Heuristic validator for text produced by an LLM.
///
/// The engine is built from an `LLMSecurityConfig` and exposes a quick
/// pass/fail check as well as a detailed report over the same output text.
pub struct ValidationEngine {
    /// Security settings consulted by the validation routines
    /// (e.g. the `enable_output_validation` switch).
    config: crate::types::LLMSecurityConfig,
}
impl ValidationEngine {
    /// Creates a validation engine that uses `config` for its checks.
    pub fn new(config: crate::types::LLMSecurityConfig) -> Self {
        Self { config }
    }

    /// Performs a quick pass/fail check of raw LLM output.
    ///
    /// When `enable_output_validation` is off in the configuration the
    /// output is accepted without inspection. Otherwise the output is
    /// matched against `get_suspicious_output_patterns()`. A code fence in
    /// output that does not start with `{`, or output longer than
    /// `DEFAULT_MAX_OUTPUT_SIZE` bytes, only produces a logged warning and
    /// does not fail validation.
    ///
    /// # Errors
    ///
    /// Returns `Err` with a fixed message when any suspicious pattern
    /// matches `output`.
    pub fn validate_llm_output(&self, output: &str) -> Result<(), String> {
        if !self.config.enable_output_validation {
            return Ok(());
        }

        let has_suspicious_pattern = get_suspicious_output_patterns()
            .iter()
            .any(|pattern| pattern.is_match(output));
        if has_suspicious_pattern {
            #[cfg(feature = "tracing")]
            tracing::warn!("SECURITY: Suspicious LLM output detected");
            #[cfg(not(feature = "tracing"))]
            eprintln!("WARN: Suspicious LLM output detected");
            return Err("LLM output contains suspicious patterns".to_string());
        }

        // A code fence in output that does not open with `{` is treated as a
        // sign the model left the expected JSON shape; warn but do not fail.
        if output.contains("```") && !output.trim().starts_with('{') {
            #[cfg(feature = "tracing")]
            tracing::warn!("SECURITY: LLM output may be trying to escape JSON format");
            #[cfg(not(feature = "tracing"))]
            eprintln!("WARN: LLM output may be trying to escape JSON format");
        }

        // `len()` is the UTF-8 byte length, not the character count.
        if output.len() > DEFAULT_MAX_OUTPUT_SIZE {
            #[cfg(feature = "tracing")]
            tracing::warn!("SECURITY: Unusually large LLM output");
            #[cfg(not(feature = "tracing"))]
            eprintln!("WARN: Unusually large LLM output");
        }

        Ok(())
    }

    /// Runs every output check and returns a detailed report.
    ///
    /// One `High` issue is recorded per matching suspicious pattern, a
    /// `Critical` issue when an exfiltration phrase is found, and a `High`
    /// issue when a personality-change phrase is found. Format and size
    /// concerns are recorded as warnings and do not affect `is_valid`.
    ///
    /// Note: unlike [`validate_llm_output`](Self::validate_llm_output), this
    /// method does not consult `enable_output_validation`; the checks always
    /// run.
    ///
    /// `validation_timestamp` is the number of seconds since the Unix epoch,
    /// or `0` if the system clock reports a time before the epoch.
    pub fn validate_output_comprehensive(&self, output: &str) -> ValidationResult {
        let mut issues = Vec::new();
        let mut warnings = Vec::new();

        for pattern in get_suspicious_output_patterns().iter() {
            if pattern.is_match(output) {
                issues.push(ValidationIssue {
                    severity: ValidationSeverity::High,
                    message: "LLM output contains suspicious patterns".to_string(),
                    pattern: pattern.as_str().to_string(),
                });
            }
        }

        if output.contains("```") && !output.trim().starts_with('{') {
            warnings.push(ValidationWarning {
                message: "LLM output may be trying to escape JSON format".to_string(),
                suggestion: "Ensure output follows expected JSON format".to_string(),
            });
        }

        if output.len() > DEFAULT_MAX_OUTPUT_SIZE {
            warnings.push(ValidationWarning {
                message: format!(
                    "Output size {} exceeds recommended limit {}",
                    output.len(),
                    DEFAULT_MAX_OUTPUT_SIZE
                ),
                suggestion: "Consider breaking output into smaller chunks".to_string(),
            });
        }

        if self.detect_data_exfiltration(output) {
            issues.push(ValidationIssue {
                severity: ValidationSeverity::Critical,
                message: "Potential data exfiltration detected".to_string(),
                pattern: "suspicious_content".to_string(),
            });
        }

        if self.detect_personality_change(output) {
            issues.push(ValidationIssue {
                severity: ValidationSeverity::High,
                message: "LLM personality change detected".to_string(),
                pattern: "personality_shift".to_string(),
            });
        }

        let is_valid = issues.is_empty();
        let risk_level = self.calculate_risk_level(&issues);

        // `duration_since` fails when the system clock is set before the
        // epoch; a validator must not panic on that, so fall back to 0.
        let validation_timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0);

        ValidationResult {
            is_valid,
            risk_level,
            issues,
            warnings,
            output_size: output.len(),
            validation_timestamp,
        }
    }

    /// Returns `true` when the lowercased output contains any phrase from a
    /// fixed list associated with handing over data.
    fn detect_data_exfiltration(&self, output: &str) -> bool {
        let exfiltration_patterns = [
            "here is the data",
            "as requested, here",
            "the information you asked for",
            "confidential data",
            "sensitive information",
            "private details",
            "secret content",
        ];
        let lower_output = output.to_lowercase();
        exfiltration_patterns
            .iter()
            .any(|pattern| lower_output.contains(pattern))
    }

    /// Returns `true` when the lowercased output contains any phrase from a
    /// fixed list suggesting the model has announced a change of role.
    fn detect_personality_change(&self, output: &str) -> bool {
        // "i am no longer bound by" is intentionally not listed separately:
        // any text containing it also contains "i am no longer".
        let personality_patterns = [
            "i am now",
            "i have become",
            "i am acting as",
            "i am operating as",
            "i am functioning as",
            "from now on",
            "i will now",
            "i can now",
            "i am no longer",
        ];
        let lower_output = output.to_lowercase();
        personality_patterns
            .iter()
            .any(|pattern| lower_output.contains(pattern))
    }

    /// Maps a list of issues to the risk level of its most severe entry.
    ///
    /// An empty list yields `ValidationRiskLevel::None`; a non-empty list
    /// with nothing at `Medium` or above yields `Low`.
    fn calculate_risk_level(&self, issues: &[ValidationIssue]) -> ValidationRiskLevel {
        if issues.iter().any(|i| matches!(i.severity, ValidationSeverity::Critical)) {
            ValidationRiskLevel::Critical
        } else if issues.iter().any(|i| matches!(i.severity, ValidationSeverity::High)) {
            ValidationRiskLevel::High
        } else if issues.iter().any(|i| matches!(i.severity, ValidationSeverity::Medium)) {
            ValidationRiskLevel::Medium
        } else if !issues.is_empty() {
            ValidationRiskLevel::Low
        } else {
            ValidationRiskLevel::None
        }
    }

    /// Formats a one-line, human-readable summary of `result`.
    pub fn get_validation_summary(&self, result: &ValidationResult) -> String {
        format!(
            "Validation: {} - {} issues, {} warnings, risk level: {:?}",
            if result.is_valid { "PASSED" } else { "FAILED" },
            result.issues.len(),
            result.warnings.len(),
            result.risk_level
        )
    }
}
/// Detailed outcome of validating a piece of LLM output.
///
/// The field descriptions below reflect how
/// `ValidationEngine::validate_output_comprehensive` populates the struct.
#[derive(Clone, Debug)]
pub struct ValidationResult {
    /// `true` when no issues were recorded (warnings do not count).
    pub is_valid: bool,
    /// Overall risk derived from the recorded issues.
    pub risk_level: ValidationRiskLevel,
    /// Problems found in the output.
    pub issues: Vec<ValidationIssue>,
    /// Non-fatal observations about the output.
    pub warnings: Vec<ValidationWarning>,
    /// Length of the validated output in bytes.
    pub output_size: usize,
    /// Time of validation, in seconds since the Unix epoch.
    pub validation_timestamp: u64,
}
/// Severity attached to a single validation issue, from least to most
/// serious in declaration order.
#[derive(Clone, Debug, PartialEq)]
pub enum ValidationSeverity {
    /// Minor concern.
    Low,
    /// Moderate concern.
    Medium,
    /// Serious concern.
    High,
    /// Most serious concern.
    Critical,
}
/// Overall risk assigned to a validated output.
#[derive(Clone, Debug, PartialEq)]
pub enum ValidationRiskLevel {
    /// No issues were recorded.
    None,
    /// Issues exist, but none at medium severity or above.
    Low,
    /// The most severe issue is of medium severity.
    Medium,
    /// The most severe issue is of high severity.
    High,
    /// At least one critical issue was recorded.
    Critical,
}
/// A single problem found while validating LLM output.
#[derive(Clone, Debug)]
pub struct ValidationIssue {
    /// How serious the problem is.
    pub severity: ValidationSeverity,
    /// Human-readable description of the problem.
    pub message: String,
    /// The pattern text or category label that triggered the issue.
    pub pattern: String,
}
/// A non-fatal observation about validated output, paired with advice.
#[derive(Clone, Debug)]
pub struct ValidationWarning {
    /// Human-readable description of the concern.
    pub message: String,
    /// Recommended way to address the concern.
    pub suggestion: String,
}
impl ValidationResult {
    /// Builds a one-line, human-readable report of this result.
    pub fn summary(&self) -> String {
        let verdict = if self.is_valid { "PASSED" } else { "FAILED" };
        let issue_count = self.issues.len();
        let warning_count = self.warnings.len();
        format!(
            "Validation {}: {} issues, {} warnings, risk: {:?}",
            verdict, issue_count, warning_count, self.risk_level
        )
    }

    /// Reports whether the risk level is `Medium`, `High` or `Critical`.
    pub fn has_security_risk(&self) -> bool {
        match self.risk_level {
            ValidationRiskLevel::None | ValidationRiskLevel::Low => false,
            ValidationRiskLevel::Medium
            | ValidationRiskLevel::High
            | ValidationRiskLevel::Critical => true,
        }
    }

    /// Collects references to the issues whose severity is `Critical`.
    pub fn get_critical_issues(&self) -> Vec<&ValidationIssue> {
        let mut critical = Vec::new();
        for issue in &self.issues {
            if issue.severity == ValidationSeverity::Critical {
                critical.push(issue);
            }
        }
        critical
    }

    /// Collects references to the issues whose severity is `High` or
    /// `Critical`, in their original order.
    pub fn get_high_severity_issues(&self) -> Vec<&ValidationIssue> {
        self.issues
            .iter()
            .filter(|issue| {
                issue.severity == ValidationSeverity::High
                    || issue.severity == ValidationSeverity::Critical
            })
            .collect()
    }
}