use async_trait::async_trait;
use crate::models::graph::Agent;
use crate::models::tools::ToolRegistryTrait;
use regex::Regex;
/// A single named check applied to an agent's input.
///
/// `validator` returns `Ok(())` when the input satisfies the rule, or `Err(message)` describing
/// the violation. Cloning a rule is cheap: the closure is shared through an [`Arc`].
#[derive(Clone)]
pub struct ValidationRule {
    /// Identifier used to prefix violation messages.
    pub name: String,
    /// The check itself; shared so rules can be cloned without cloning captured state.
    pub validator: Arc<dyn Fn(&str) -> Result<(), String> + Send + Sync>,
    /// Whether a violation is always an error (as opposed to possibly only a warning).
    pub critical: bool,
}

// `Debug` cannot be derived because closures do not implement it; show everything else and
// mark the omitted `validator` with `..`.
impl std::fmt::Debug for ValidationRule {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ValidationRule")
            .field("name", &self.name)
            .field("critical", &self.critical)
            .finish_non_exhaustive()
    }
}
use std::sync::Arc;
/// Agent that checks its input against a list of [`ValidationRule`]s and picks the next node
/// based on whether any error-level violation occurred.
pub struct ValidatorAgent {
    /// Rules, evaluated in insertion order.
    rules: Vec<ValidationRule>,
    /// Display name of this agent.
    name: String,
    /// Node returned when validation produces no errors; passed through as-is (may be `None`).
    next_node_on_success: Option<i32>,
    /// Node returned when validation produces at least one error; passed through as-is.
    next_node_on_failure: Option<i32>,
    /// When `true`, every rule violation counts as an error, even for non-critical rules.
    strict_mode: bool,
}

// Rules hold closures, which are not `Debug`; list the rule names instead.
impl std::fmt::Debug for ValidatorAgent {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let rule_names: Vec<&str> = self.rules.iter().map(|r| r.name.as_str()).collect();
        f.debug_struct("ValidatorAgent")
            .field("name", &self.name)
            .field("rules", &rule_names)
            .field("next_node_on_success", &self.next_node_on_success)
            .field("next_node_on_failure", &self.next_node_on_failure)
            .field("strict_mode", &self.strict_mode)
            .finish()
    }
}
impl ValidatorAgent {
    /// Creates a validator with no rules, the name `"Validator"`, no success/failure routes,
    /// and strict mode disabled.
    pub fn new() -> Self {
        Self {
            rules: Vec::new(),
            name: "Validator".to_string(),
            next_node_on_success: None,
            next_node_on_failure: None,
            strict_mode: false,
        }
    }

    /// Appends a custom rule. Rules are evaluated in the order they were added.
    pub fn add_rule(mut self, rule: ValidationRule) -> Self {
        self.rules.push(rule);
        self
    }

    /// Appends a rule that fails with `error_msg` unless `Regex::is_match` accepts the input.
    ///
    /// `is_match` searches anywhere in the input; anchor the pattern with `^...$` to require a
    /// whole-string match.
    ///
    /// The pattern is compiled once, here, rather than on every validation. If `pattern` is not
    /// a valid regular expression, the rule rejects every input with a message carrying the
    /// compile error, instead of silently accepting everything.
    pub fn add_pattern_rule(
        mut self,
        name: impl Into<String>,
        pattern: impl Into<String>,
        error_msg: impl Into<String>,
        critical: bool,
    ) -> Self {
        let pattern_str = pattern.into();
        let error = error_msg.into();
        // Fail closed: a misconfigured pattern must not turn the rule into a no-op.
        let compiled = Regex::new(&pattern_str)
            .map_err(|e| format!("invalid pattern `{}`: {}", pattern_str, e));
        let rule = ValidationRule {
            name: name.into(),
            validator: Arc::new(move |input| match &compiled {
                Ok(re) if re.is_match(input) => Ok(()),
                Ok(_) => Err(error.clone()),
                Err(compile_error) => Err(compile_error.clone()),
            }),
            critical,
        };
        self.rules.push(rule);
        self
    }

    /// Appends a rule named `"length_check"` that enforces inclusive length bounds.
    ///
    /// Length is counted in `char`s (Unicode scalar values), not UTF-8 bytes, so non-ASCII
    /// input is not over-counted. Pass `None` for either bound to leave that side unchecked.
    pub fn add_length_rule(mut self, min: Option<usize>, max: Option<usize>, critical: bool) -> Self {
        let rule = ValidationRule {
            name: "length_check".to_string(),
            validator: Arc::new(move |input| {
                let len = input.chars().count();
                if let Some(min_len) = min {
                    if len < min_len {
                        return Err(format!("Input too short (min: {})", min_len));
                    }
                }
                if let Some(max_len) = max {
                    if len > max_len {
                        return Err(format!("Input too long (max: {})", max_len));
                    }
                }
                Ok(())
            }),
            critical,
        };
        self.rules.push(rule);
        self
    }

    /// Sets the node returned when validation produces no errors.
    pub fn with_success_route(mut self, node_id: i32) -> Self {
        self.next_node_on_success = Some(node_id);
        self
    }

    /// Sets the node returned when validation produces at least one error.
    pub fn with_failure_route(mut self, node_id: i32) -> Self {
        self.next_node_on_failure = Some(node_id);
        self
    }

    /// When `strict` is `true`, violations of non-critical rules are also treated as errors.
    pub fn with_strict_mode(mut self, strict: bool) -> Self {
        self.strict_mode = strict;
        self
    }

    /// Overrides the agent's display name.
    pub fn with_name(mut self, name: impl Into<String>) -> Self {
        self.name = name.into();
        self
    }

    /// Runs every rule against `input` and sorts the violations into errors and warnings.
    ///
    /// A violation is an error if its rule is `critical` or strict mode is on; otherwise it is
    /// a warning. Each message is formatted as `"<rule name>: <message>"`.
    ///
    /// Returns `Ok(warnings)` when there are no errors. Otherwise returns `Err(errors)`, and
    /// any warnings are discarded.
    fn validate(&self, input: &str) -> Result<Vec<String>, Vec<String>> {
        let mut warnings = Vec::new();
        let mut errors = Vec::new();
        for rule in &self.rules {
            if let Err(msg) = (rule.validator)(input) {
                let message = format!("{}: {}", rule.name, msg);
                if rule.critical || self.strict_mode {
                    errors.push(message);
                } else {
                    warnings.push(message);
                }
            }
        }
        if errors.is_empty() {
            Ok(warnings)
        } else {
            Err(errors)
        }
    }
}
impl Default for ValidatorAgent {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Agent for ValidatorAgent {
    /// Validates `input` and returns a human-readable report plus the next node to visit.
    ///
    /// The report echoes `input`. If validation yields errors, the report lists them and the
    /// failure route is returned. Otherwise the report notes success, appends any warnings,
    /// and the success route is returned. The tool registry is not used.
    async fn run(
        &mut self,
        input: &str,
        _tool_registry: &(dyn ToolRegistryTrait + Send + Sync),
    ) -> (String, Option<i32>) {
        let warnings = match self.validate(input) {
            Ok(found) => found,
            Err(failures) => {
                let report = format!(
                    "Validation failed for: {}\nErrors: {}",
                    input,
                    failures.join(", ")
                );
                return (report, self.next_node_on_failure);
            }
        };

        let mut report = format!("Validation passed for: {}", input);
        if !warnings.is_empty() {
            let listed = warnings.join(", ");
            report += &format!("\nWarnings: {}", listed);
        }
        (report, self.next_node_on_success)
    }

    /// Returns the agent's configured display name.
    fn get_name(&self) -> &str {
        self.name.as_str()
    }
}