synthclaw 0.1.3

Lightweight synthetic data generation library/CLI.
Documentation
use super::{ValidationResult, Validator};
use crate::generation::GenerationResult;
use regex::Regex;
use std::collections::HashMap;

/// Validator requiring the whitespace-trimmed content to be at least this many bytes long.
pub struct MinLength(pub usize);
/// Validator requiring the raw (untrimmed) content to be at most this many bytes long.
pub struct MaxLength(pub usize);

impl Validator for MinLength {
    /// Accepts `r` when its whitespace-trimmed content is at least `self.0` bytes long.
    ///
    /// The rejection message reports the measured length and the required minimum.
    fn validate(&self, r: &GenerationResult) -> ValidationResult {
        let minimum = self.0;
        let actual = r.content.trim().len();

        if actual < minimum {
            return ValidationResult::invalid(format!("too short: {} < {}", actual, minimum));
        }
        ValidationResult::valid()
    }
}

impl Validator for MaxLength {
    /// Accepts `r` when its content, measured in bytes without trimming, is at most `self.0`.
    ///
    /// The rejection message reports the measured length and the allowed maximum.
    fn validate(&self, r: &GenerationResult) -> ValidationResult {
        let limit = self.0;
        let actual = r.content.len();

        if actual > limit {
            return ValidationResult::invalid(format!("too long: {} > {}", actual, limit));
        }
        ValidationResult::valid()
    }
}

/// Validator that accepts content which parses as JSON, after unwrapping a Markdown code
/// fence if one is present.
pub struct Json;
/// Validator that checks a JSON object for the presence of required top-level keys.
pub struct JsonSchema {
    /// Names of the keys that must be present in the object.
    required: Vec<String>,
}

impl JsonSchema {
    /// Builds a schema that requires every name listed in `fields`.
    pub fn require(fields: &[&str]) -> Self {
        let required = fields.iter().copied().map(String::from).collect();
        Self { required }
    }
}

impl Validator for Json {
    /// Accepts `r` when the payload returned by `extract_json` parses as a JSON value.
    ///
    /// On failure the parser's error is included in the rejection message.
    fn validate(&self, r: &GenerationResult) -> ValidationResult {
        let payload = extract_json(&r.content);
        if let Err(err) = serde_json::from_str::<serde_json::Value>(&payload) {
            return ValidationResult::invalid(format!("invalid json: {}", err));
        }
        ValidationResult::valid()
    }
}

impl Validator for JsonSchema {
    /// Accepts `r` when its payload is a JSON object containing every required key.
    ///
    /// Rejects with, in order of checking: a parse error, "expected json object" when the
    /// value is not an object, or the list of required keys that are absent.
    fn validate(&self, r: &GenerationResult) -> ValidationResult {
        let payload = extract_json(&r.content);
        let parsed = match serde_json::from_str::<serde_json::Value>(&payload) {
            Ok(value) => value,
            Err(err) => return ValidationResult::invalid(format!("invalid json: {}", err)),
        };

        let fields = if let Some(map) = parsed.as_object() {
            map
        } else {
            return ValidationResult::invalid("expected json object");
        };

        // Collect the absent names in declaration order so the message is deterministic.
        let mut absent: Vec<&String> = Vec::new();
        for name in &self.required {
            if !fields.contains_key(name) {
                absent.push(name);
            }
        }

        if !absent.is_empty() {
            return ValidationResult::invalid(format!("missing fields: {:?}", absent));
        }
        ValidationResult::valid()
    }
}

/// Validator that rejects content matching any of its regex patterns; each pattern is paired
/// with the short reason reported when it matches.
pub struct Blocklist(Vec<(Regex, &'static str)>);

impl Blocklist {
    /// Builds a blocklist of common LLM response artifacts: leading filler ("Sure!",
    /// "Certainly, "), "Here's" / "Here is" openers, "I'd be happy to" politeness, mentions of
    /// being an AI, and refusal phrases.
    ///
    /// All patterns are case-insensitive. The phrase patterns that are not anchored to the
    /// start of the content are wrapped in word boundaries so they do not fire inside
    /// unrelated words (for example "has an aim" or "Rossini cannot").
    ///
    /// # Panics
    ///
    /// Panics if one of the built-in patterns fails to compile. The patterns are literals, so
    /// that would be a bug in this function rather than a runtime condition.
    pub fn llm_artifacts() -> Self {
        let patterns = [
            (r"(?i)^(sure|certainly|of course)[,!]?\s", "filler"),
            (r"(?i)^here('s| is)\b", "here is"),
            (r"(?i)^I('d| would) be happy to", "politeness"),
            (r"(?i)\bas an AI\b", "ai mention"),
            (r"(?i)\b(I cannot|I can't|I'm unable)\b", "refusal"),
        ];
        Self(
            patterns
                .iter()
                .map(|&(pattern, reason)| {
                    // Silently skipping a pattern would disable a check without anyone
                    // noticing, so a broken literal is surfaced immediately instead.
                    let compiled = Regex::new(pattern)
                        .expect("built-in blocklist pattern must compile");
                    (compiled, reason)
                })
                .collect(),
        )
    }
}

impl Validator for Blocklist {
    /// Rejects `r` with the reason of the first pattern (in list order) that matches its
    /// content; accepts it when no pattern matches.
    fn validate(&self, r: &GenerationResult) -> ValidationResult {
        let first_hit = self
            .0
            .iter()
            .find(|entry| entry.0.is_match(&r.content));

        match first_hit {
            Some((_, label)) => ValidationResult::invalid(format!("blocked: {}", label)),
            None => ValidationResult::valid(),
        }
    }
}

/// Validator that limits how much of the content consists of repeated word n-grams.
pub struct Repetition {
    /// Highest accepted ratio of repeated n-grams to total n-grams.
    pub max_ratio: f32,
    /// Number of consecutive words that make up one n-gram.
    pub ngram_size: usize,
}

impl Default for Repetition {
    /// Defaults to 3-word n-grams with at most half of them being repeats.
    fn default() -> Self {
        const DEFAULT_NGRAM_SIZE: usize = 3;
        const DEFAULT_MAX_RATIO: f32 = 0.5;

        Repetition {
            ngram_size: DEFAULT_NGRAM_SIZE,
            max_ratio: DEFAULT_MAX_RATIO,
        }
    }
}

impl Validator for Repetition {
    /// Rejects content whose share of repeated word n-grams exceeds `max_ratio`.
    ///
    /// The content is split on whitespace and every window of `ngram_size` consecutive words
    /// is lowercased and counted. Each occurrence of an n-gram beyond its first counts as a
    /// repeat, and the ratio is repeats divided by the total number of windows.
    ///
    /// Content with fewer than `2 * ngram_size` words is always accepted, as is any content
    /// when `ngram_size` is zero (there is nothing meaningful to compare).
    fn validate(&self, r: &GenerationResult) -> ValidationResult {
        // `slice::windows` panics on a zero window size, so bail out before reaching it.
        if self.ngram_size == 0 {
            return ValidationResult::valid();
        }

        let words: Vec<_> = r.content.split_whitespace().collect();
        // Saturate so an absurdly large n-gram size cannot overflow the threshold.
        if words.len() < self.ngram_size.saturating_mul(2) {
            return ValidationResult::valid();
        }

        let mut counts: HashMap<String, usize> = HashMap::new();
        for w in words.windows(self.ngram_size) {
            *counts.entry(w.join(" ").to_lowercase()).or_default() += 1;
        }

        // Cannot underflow: the guard above ensures `words.len() >= ngram_size`.
        let total = words.len() - self.ngram_size + 1;
        // Only occurrences after the first one count as repetition.
        let repeated: usize = counts.values().filter(|&&c| c > 1).map(|c| c - 1).sum();
        let ratio = repeated as f32 / total as f32;

        if ratio <= self.max_ratio {
            ValidationResult::valid()
        } else {
            ValidationResult::invalid(format!(
                "repetitive: {:.0}% > {:.0}%",
                ratio * 100.0,
                self.max_ratio * 100.0
            ))
        }
    }
}

/// Validator that delegates to the wrapped closure or function.
pub struct Custom<F>(pub F);

impl<F: Fn(&GenerationResult) -> ValidationResult + Send + Sync> Validator for Custom<F> {
    fn validate(&self, r: &GenerationResult) -> ValidationResult {
        self.0(r)
    }
}

/// Extracts the JSON payload from `content`, unwrapping a Markdown code fence if present.
///
/// Lookup order:
/// 1. the first fence opened with three backticks immediately followed by `json`;
/// 2. otherwise the first three-backtick fence of any kind, with its language-tag line
///    (possibly empty) removed;
/// 3. otherwise the whole content.
///
/// The result is always trimmed. A fence without a closing marker is not unwrapped.
fn extract_json(content: &str) -> String {
    const FENCE: &str = "```";
    const JSON_FENCE: &str = "```json";

    let s = content.trim();

    if let Some(start) = s.find(JSON_FENCE) {
        let rest = &s[start + JSON_FENCE.len()..];
        if let Some(end) = rest.find(FENCE) {
            return rest[..end].trim().to_string();
        }
    }

    if let Some(start) = s.find(FENCE) {
        let rest = &s[start + FENCE.len()..];
        if let Some(end) = rest.find(FENCE) {
            // Strip the tag line BEFORE trimming: trimming first would swallow the empty
            // tag line of an untagged fence and the first payload line would be lost.
            return strip_fence_tag(&rest[..end]).trim().to_string();
        }
    }

    s.to_string()
}

/// Returns the fence interior without its leading language-tag line.
///
/// The first line is dropped only when it is empty or looks like a tag (letters, digits and
/// a few punctuation characters); otherwise the payload starts on the fence line itself and
/// the interior is returned unchanged.
fn strip_fence_tag(inner: &str) -> &str {
    match inner.find('\n') {
        Some(newline) if is_fence_tag(inner[..newline].trim()) => &inner[newline + 1..],
        _ => inner,
    }
}

/// Reports whether `line` is empty or consists solely of characters plausible in a fence
/// language tag.
fn is_fence_tag(line: &str) -> bool {
    line.chars()
        .all(|c| c.is_alphanumeric() || c == '-' || c == '_' || c == '+' || c == '.' || c == '#')
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `text` in a `GenerationResult` with every other field left empty or zero.
    fn sample(text: &str) -> GenerationResult {
        GenerationResult {
            content: String::from(text),
            source_index: None,
            category: None,
            input_tokens: 0,
            output_tokens: 0,
        }
    }

    /// Reports whether `validator` accepts `text`.
    fn accepts<V: Validator>(validator: &V, text: &str) -> bool {
        validator.validate(&sample(text)).is_valid
    }

    #[test]
    fn test_length() {
        // Lower bound.
        assert!(!accepts(&MinLength(10), "short"));
        assert!(accepts(&MinLength(5), "hello"));
        // Upper bound.
        assert!(accepts(&MaxLength(10), "short"));
        assert!(!accepts(&MaxLength(5), "too long"));
    }

    #[test]
    fn test_json() {
        assert!(accepts(&Json, r#"{"a":1}"#));
        assert!(!accepts(&Json, "not json"));
        // A fenced payload is unwrapped before parsing.
        assert!(accepts(&Json, "```json\n{\"a\":1}\n```"));
    }

    #[test]
    fn test_schema() {
        let schema = JsonSchema::require(&["a", "b"]);
        assert!(accepts(&schema, r#"{"a":1,"b":2}"#));
        assert!(!accepts(&schema, r#"{"a":1}"#));
    }

    #[test]
    fn test_blocklist() {
        let blocklist = Blocklist::llm_artifacts();
        assert!(!accepts(&blocklist, "Sure! Here you go"));
        assert!(!accepts(&blocklist, "As an AI, I"));
        assert!(accepts(&blocklist, "Normal text"));
    }

    #[test]
    fn test_repetition() {
        let bigrams = Repetition {
            max_ratio: 0.3,
            ngram_size: 2,
        };
        assert!(!accepts(&bigrams, "the cat the cat the cat the cat"));
        assert!(accepts(&bigrams, "the quick brown fox jumps"));
    }
}