// synthclaw 0.1.3
//
// Lightweight synthetic data generation library/CLI.
mod dedup;
mod validators;

pub use dedup::*;
pub use validators::*;

use crate::config::ValidationConfig;
use crate::generation::GenerationResult;

/// The verdict produced by a [`Validator`] (or a whole pipeline) for one
/// generated result.
///
/// NOTE(review): the derived `Default` yields `is_valid == false` with an
/// empty `errors` list, which differs from `ValidationResult::valid()`.
/// Confirm whether any caller relies on that before changing it.
#[derive(Clone, Debug, Default)]
pub struct ValidationResult {
    /// Whether the result passed every check that was applied.
    pub is_valid: bool,
    /// Messages explaining why the result failed.
    pub errors: Vec<String>,
}

impl ValidationResult {
    /// Builds a passing result that carries no error messages.
    pub fn valid() -> Self {
        Self {
            errors: Vec::new(),
            is_valid: true,
        }
    }

    /// Builds a failing result holding exactly one error message.
    pub fn invalid(error: impl Into<String>) -> Self {
        let message: String = error.into();
        Self {
            errors: vec![message],
            is_valid: false,
        }
    }

    /// Folds `other` into `self`.
    ///
    /// A passing `other` is ignored entirely (including any errors it might
    /// carry). A failing one marks `self` as failed and appends its errors
    /// after the ones already collected.
    pub fn merge(&mut self, other: Self) {
        if other.is_valid {
            return;
        }
        self.is_valid = false;
        self.errors.extend(other.errors);
    }
}

/// A single check applied to a [`GenerationResult`].
///
/// The `Send + Sync` supertrait bounds make boxed `dyn Validator` trait
/// objects themselves `Send + Sync`.
pub trait Validator
where
    Self: Send + Sync,
{
    /// Inspects `result` and reports whether it passes this check.
    fn validate(&self, result: &GenerationResult) -> ValidationResult;
}

/// An ordered collection of [`Validator`]s applied to each generated result.
///
/// Every validator is run; there is no short-circuit on the first failure.
/// A rejected result therefore carries the error messages of all validators
/// that failed it.
#[derive(Default)]
pub struct ValidationPipeline {
    validators: Vec<Box<dyn Validator>>,
}

impl ValidationPipeline {
    /// Creates an empty pipeline. An empty pipeline accepts every result.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a pipeline from the validators enabled in `config`.
    ///
    /// Validators are appended in a fixed order: minimum length, maximum
    /// length, JSON syntax, required JSON fields, LLM-artifact blocklist,
    /// repetition.
    #[must_use]
    pub fn from_config(config: &ValidationConfig) -> Self {
        let mut p = Self::new();
        if let Some(n) = config.min_length {
            p = p.add(MinLength(n));
        }
        if let Some(n) = config.max_length {
            p = p.add(MaxLength(n));
        }
        if config.json {
            p = p.add(Json);
        }
        if let Some(fields) = &config.json_schema {
            // `JsonSchema::require` takes borrowed field names.
            let fields: Vec<&str> = fields.iter().map(|s| s.as_str()).collect();
            p = p.add(JsonSchema::require(&fields));
        }
        if config.blocklist {
            p = p.add(Blocklist::llm_artifacts());
        }
        if config.repetition {
            p = p.add(Repetition::default());
        }
        p
    }

    /// Appends `v` to the end of the pipeline and returns the extended pipeline.
    ///
    /// This consumes `self`. Discarding the return value (`pipeline.add(v);`)
    /// would silently drop the whole pipeline.
    #[must_use = "`add` consumes the pipeline and returns the extended one"]
    pub fn add<V: Validator + 'static>(mut self, v: V) -> Self {
        self.validators.push(Box::new(v));
        self
    }

    /// Runs every validator against `result`, in insertion order, and merges
    /// their verdicts.
    ///
    /// The combined result is valid only if no validator failed. Its `errors`
    /// are the failing validators' messages, concatenated in validator order.
    pub fn validate(&self, result: &GenerationResult) -> ValidationResult {
        self.validators
            .iter()
            .fold(ValidationResult::valid(), |mut acc, v| {
                acc.merge(v.validate(result));
                acc
            })
    }

    /// Splits `results` into those that pass and those that fail.
    ///
    /// Each rejected result is paired with the [`ValidationResult`] explaining
    /// why. Input order is preserved within both groups.
    pub fn filter(
        &self,
        results: Vec<GenerationResult>,
    ) -> (
        Vec<GenerationResult>,
        Vec<(GenerationResult, ValidationResult)>,
    ) {
        let (mut valid, mut invalid) = (vec![], vec![]);
        for r in results {
            let v = self.validate(&r);
            if v.is_valid {
                valid.push(r);
            } else {
                invalid.push((r, v));
            }
        }
        (valid, invalid)
    }
}

/// Counters summarising one validate-and-dedupe pass.
///
/// As populated by [`validate_and_dedupe`], these satisfy
/// `total == passed + failed + duplicates_removed`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ValidationStats {
    /// Number of results handed to the pass.
    pub total: usize,
    /// Results that passed validation and survived deduplication.
    pub passed: usize,
    /// Results rejected by at least one validator.
    pub failed: usize,
    /// Valid results dropped as duplicates (0 when no deduplicator was used).
    pub duplicates_removed: usize,
}

/// Output of a validate-and-dedupe pass.
pub struct ValidatedResults {
    /// Results that passed validation (and deduplication, if it was enabled).
    pub results: Vec<GenerationResult>,
    /// Counters describing what happened during the pass.
    pub stats: ValidationStats,
    /// Results that failed validation, each paired with the verdict explaining why.
    pub rejected: Vec<(GenerationResult, ValidationResult)>,
}

/// Validates `results` with `pipeline`, then optionally removes duplicates.
///
/// Deduplication (when `dedup` is `Some`) runs only over the results that
/// passed validation. Rejected results are returned in `rejected`, each
/// paired with the [`ValidationResult`] explaining why.
///
/// The returned stats satisfy `total == passed + failed + duplicates_removed`.
/// This assumes `Deduplicator::dedupe` never returns more items than it was
/// given.
#[must_use = "ignoring the return value discards every validated result"]
pub fn validate_and_dedupe(
    results: Vec<GenerationResult>,
    pipeline: &ValidationPipeline,
    dedup: Option<&Deduplicator>,
) -> ValidatedResults {
    let total = results.len();
    let (valid, rejected) = pipeline.filter(results);
    let failed = rejected.len();

    let (results, duplicates_removed) = match dedup {
        Some(d) => {
            let before = valid.len();
            let deduped = d.dedupe(valid);
            // NOTE(review): assumes `dedupe` only ever removes items; a longer
            // output would underflow this subtraction.
            let removed = before - deduped.len();
            (deduped, removed)
        }
        None => (valid, 0),
    };

    ValidatedResults {
        stats: ValidationStats {
            total,
            passed: results.len(),
            failed,
            duplicates_removed,
        },
        results,
        rejected,
    }
}