use crate::recursive::validate::{Score, Validate};
use regex::Regex;
use smallvec::SmallVec;
use std::sync::OnceLock;
/// Starts a new, empty [`Checks`] builder.
///
/// Shorthand for `Checks::default()`, intended as the entry point of the
/// fluent API, e.g. `checks().require("fn ").forbid("todo!()")`.
#[inline]
pub fn checks() -> Checks {
    Checks::default()
}
/// The condition that a single [`Check`] evaluates against the validated text.
#[derive(Clone)]
pub enum CheckKind {
    /// Passes when the text contains this substring.
    Require(String),
    /// Passes when the text does NOT contain this substring.
    Forbid(String),
    /// Passes when the text matches this regular-expression source; the
    /// `OnceLock` holds the compiled form so it is built at most once.
    Regex(String, OnceLock<Regex>),
    /// Passes when the text is at least this long.
    MinLen(usize),
    /// Passes when the text is at most this long.
    MaxLen(usize),
    /// Passes when no more than this many lines contain "error", "Error" or "ERROR".
    MaxErrors(usize),
    /// Passes when the function returns `true`; the `&'static str` names the predicate.
    Predicate(fn(&str) -> bool, &'static str),
}
/// One configured check: what to test, how much it counts, and what to report
/// when it fails.
#[derive(Clone)]
pub struct Check {
    /// Label recorded in the score breakdown (e.g. `"require"`, or a predicate's name).
    name: &'static str,
    /// The condition to evaluate.
    kind: CheckKind,
    /// Relative contribution of this check to the normalized score.
    weight: f64,
    /// Message emitted when this check fails.
    feedback: String,
}
/// An ordered, weighted collection of checks assembled with a fluent builder
/// API and evaluated through the `Validate` trait.
#[derive(Clone)]
pub struct Checks {
    /// Configured checks in insertion order; the first eight are stored inline.
    checks: SmallVec<[Check; 8]>,
    /// Running sum of every check's weight, used to normalize the final score.
    total_weight: f64,
}
impl Default for Checks {
fn default() -> Self {
Self::new()
}
}
impl Checks {
pub fn new() -> Self {
Self {
checks: SmallVec::new(),
total_weight: 0.0,
}
}
fn add_check(mut self, check: Check) -> Self {
self.total_weight += check.weight;
self.checks.push(check);
self
}
pub fn require(self, pattern: impl Into<String>) -> Self {
let pattern = pattern.into();
let feedback = format!("Missing required: '{}'", pattern);
self.add_check(Check {
name: "require",
kind: CheckKind::Require(pattern),
weight: 1.0,
feedback,
})
}
pub fn require_weighted(self, pattern: impl Into<String>, weight: f64) -> Self {
let pattern = pattern.into();
let feedback = format!("Missing required: '{}'", pattern);
self.add_check(Check {
name: "require",
kind: CheckKind::Require(pattern),
weight,
feedback,
})
}
pub fn require_all<I, S>(mut self, patterns: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
for pattern in patterns {
self = self.require(pattern);
}
self
}
pub fn require_all_weighted<I, S>(mut self, patterns: I, weight: f64) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
for pattern in patterns {
self = self.require_weighted(pattern, weight);
}
self
}
pub fn forbid(self, pattern: impl Into<String>) -> Self {
let pattern = pattern.into();
let feedback = format!("Must not contain: '{}'", pattern);
self.add_check(Check {
name: "forbid",
kind: CheckKind::Forbid(pattern),
weight: 1.0,
feedback,
})
}
pub fn forbid_weighted(self, pattern: impl Into<String>, weight: f64) -> Self {
let pattern = pattern.into();
let feedback = format!("Must not contain: '{}'", pattern);
self.add_check(Check {
name: "forbid",
kind: CheckKind::Forbid(pattern),
weight,
feedback,
})
}
pub fn forbid_all<I, S>(mut self, patterns: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
for pattern in patterns {
self = self.forbid(pattern);
}
self
}
pub fn forbid_all_weighted<I, S>(mut self, patterns: I, weight: f64) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
for pattern in patterns {
self = self.forbid_weighted(pattern, weight);
}
self
}
pub fn regex(self, pattern: impl Into<String>) -> Self {
let pattern = pattern.into();
let feedback = format!("Regex not matched: '{}'", pattern);
self.add_check(Check {
name: "regex",
kind: CheckKind::Regex(pattern, OnceLock::new()),
weight: 1.0,
feedback,
})
}
pub fn regex_weighted(self, pattern: impl Into<String>, weight: f64) -> Self {
let pattern = pattern.into();
let feedback = format!("Regex not matched: '{}'", pattern);
self.add_check(Check {
name: "regex",
kind: CheckKind::Regex(pattern, OnceLock::new()),
weight,
feedback,
})
}
pub fn regex_all<I, S>(mut self, patterns: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
for pattern in patterns {
self = self.regex(pattern);
}
self
}
pub fn regex_all_weighted<I, S>(mut self, patterns: I, weight: f64) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
for pattern in patterns {
self = self.regex_weighted(pattern, weight);
}
self
}
pub fn min_len(self, n: usize) -> Self {
self.add_check(Check {
name: "min_len",
kind: CheckKind::MinLen(n),
weight: 1.0,
feedback: format!("Text too short (min {} chars)", n),
})
}
pub fn min_len_weighted(self, n: usize, weight: f64) -> Self {
self.add_check(Check {
name: "min_len",
kind: CheckKind::MinLen(n),
weight,
feedback: format!("Text too short (min {} chars)", n),
})
}
pub fn max_len(self, n: usize) -> Self {
self.add_check(Check {
name: "max_len",
kind: CheckKind::MaxLen(n),
weight: 1.0,
feedback: format!("Text too long (max {} chars)", n),
})
}
pub fn max_len_weighted(self, n: usize, weight: f64) -> Self {
self.add_check(Check {
name: "max_len",
kind: CheckKind::MaxLen(n),
weight,
feedback: format!("Text too long (max {} chars)", n),
})
}
pub fn max_errors(self, n: usize) -> Self {
self.add_check(Check {
name: "max_errors",
kind: CheckKind::MaxErrors(n),
weight: 1.0,
feedback: format!("Too many error lines (max {})", n),
})
}
pub fn pred(self, name: &'static str, f: fn(&str) -> bool) -> Self {
self.add_check(Check {
name,
kind: CheckKind::Predicate(f, name),
weight: 1.0,
feedback: format!("Predicate failed: '{}'", name),
})
}
pub fn pred_weighted(self, name: &'static str, f: fn(&str) -> bool, weight: f64) -> Self {
self.add_check(Check {
name,
kind: CheckKind::Predicate(f, name),
weight,
feedback: format!("Predicate failed: '{}'", name),
})
}
pub fn weight(mut self, w: f64) -> Self {
if let Some(check) = self.checks.last_mut() {
self.total_weight -= check.weight;
check.weight = w;
self.total_weight += w;
}
self
}
pub fn feedback(mut self, msg: impl Into<String>) -> Self {
if let Some(check) = self.checks.last_mut() {
check.feedback = msg.into();
}
self
}
pub fn require_if(self, cond: bool, pattern: impl Into<String>) -> Self {
if cond {
self.require(pattern)
} else {
self
}
}
pub fn forbid_if(self, cond: bool, pattern: impl Into<String>) -> Self {
if cond {
self.forbid(pattern)
} else {
self
}
}
pub fn python(self) -> Self {
self.require("def ")
.forbid("SyntaxError")
.forbid("IndentationError")
}
pub fn rust(self) -> Self {
self.require("fn ")
.forbid("todo!()")
.forbid("unimplemented!()")
}
pub fn json(self) -> Self {
self.pred("valid_json", |s| {
serde_json::from_str::<serde_json::Value>(s).is_ok()
})
}
pub fn yaml(self) -> Self {
self.forbid("\t").pred("yaml_structure", |s| {
s.lines()
.any(|l| l.contains(": ") || l.trim_start().starts_with("- "))
})
}
pub fn sql(self) -> Self {
self.pred("sql_keyword", |s| {
let upper = s.to_uppercase();
upper.contains("SELECT")
|| upper.contains("INSERT")
|| upper.contains("UPDATE")
|| upper.contains("DELETE")
|| upper.contains("CREATE")
|| upper.contains("ALTER")
})
}
fn evaluate_check(check: &Check, text: &str) -> bool {
match &check.kind {
CheckKind::Require(pattern) => text.contains(pattern.as_str()),
CheckKind::Forbid(pattern) => !text.contains(pattern.as_str()),
CheckKind::Regex(pattern, compiled) => {
let regex = compiled.get_or_init(|| {
Regex::new(pattern).unwrap_or_else(|_| Regex::new("^$").unwrap())
});
regex.is_match(text)
}
CheckKind::MinLen(n) => text.len() >= *n,
CheckKind::MaxLen(n) => text.len() <= *n,
CheckKind::MaxErrors(n) => {
let count = text
.lines()
.filter(|line| {
line.contains("error") || line.contains("Error") || line.contains("ERROR")
})
.count();
count <= *n
}
CheckKind::Predicate(f, _) => f(text),
}
}
}
impl Validate for Checks {
    /// Runs every configured check against `text` and folds the results into
    /// a weighted `Score`.
    ///
    /// An empty check set always passes. Otherwise each check records a
    /// `(name, 1.0 | 0.0)` breakdown entry. If every check passed, the result
    /// is a plain pass carrying the breakdown. If any failed, the score is the
    /// sum of passing weights divided by the total weight (or `1.0` when the
    /// total weight is not positive). The feedback of every failing check is
    /// joined with `"; "`.
    fn validate(&self, text: &str) -> Score<'static> {
        if self.checks.is_empty() {
            return Score::pass();
        }

        let mut breakdown = SmallVec::new();
        let mut failures: Vec<&str> = Vec::new();
        let mut earned = 0.0;

        for check in self.checks.iter() {
            let ok = Self::evaluate_check(check, text);
            let value = if ok { 1.0 } else { 0.0 };
            earned += value * check.weight;
            breakdown.push((check.name, value));
            if !ok {
                failures.push(check.feedback.as_str());
            }
        }

        if failures.is_empty() {
            return Score::pass().with_breakdown(breakdown);
        }

        let normalized = if self.total_weight > 0.0 {
            earned / self.total_weight
        } else {
            1.0
        };
        Score::with_feedback(normalized, failures.join("; ")).with_breakdown(breakdown)
    }

    /// Identifier of this validator.
    fn name(&self) -> &'static str {
        "checks"
    }
}
/// Unit tests for the `Checks` builder and its `Validate` implementation.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_checks() {
        let v = checks();
        let score = v.validate("anything");
        assert!(score.is_perfect());
    }

    #[test]
    fn test_require() {
        let v = checks().require("fn ");
        assert!(v.validate("fn main() {}").is_perfect());
        assert!(!v.validate("let x = 1").is_perfect());
    }

    #[test]
    fn test_forbid() {
        let v = checks().forbid(".unwrap()");
        assert!(v.validate("let x = 1").is_perfect());
        assert!(!v.validate("x.unwrap()").is_perfect());
    }

    #[test]
    fn test_min_len() {
        let v = checks().min_len(10);
        assert!(v.validate("hello world").is_perfect());
        assert!(!v.validate("hello").is_perfect());
    }

    #[test]
    fn test_max_len() {
        let v = checks().max_len(10);
        assert!(v.validate("hello").is_perfect());
        assert!(!v.validate("hello world").is_perfect());
    }

    #[test]
    fn test_regex() {
        let v = checks().regex(r"fn\s+\w+");
        assert!(v.validate("fn main() {}").is_perfect());
        assert!(!v.validate("let x = 1").is_perfect());
    }

    #[test]
    fn test_max_errors() {
        let v = checks().max_errors(1);
        assert!(v.validate("line1\nline2").is_perfect());
        assert!(v.validate("error: one\nline2").is_perfect());
        assert!(!v.validate("error: one\nerror: two").is_perfect());
    }

    #[test]
    fn test_max_errors_counts_case_variants() {
        let v = checks().max_errors(0);
        assert!(v.validate("all good").is_perfect());
        assert!(!v.validate("Error: boom").is_perfect());
        assert!(!v.validate("FATAL ERROR").is_perfect());
    }

    #[test]
    fn test_pred() {
        let v = checks().pred("has_return", |s| s.contains("return"));
        assert!(v.validate("return 42;").is_perfect());
        assert!(!v.validate("let x = 42;").is_perfect());
    }

    #[test]
    fn test_weighted_checks() {
        let v = checks()
            .require_weighted("fn ", 0.5)
            .require_weighted("->", 0.5);
        let score = v.validate("fn foo() -> i32 {}");
        assert!((score.value - 1.0).abs() < f64::EPSILON);
        let score = v.validate("fn foo() {}");
        assert!((score.value - 0.5).abs() < f64::EPSILON);
        let score = v.validate("let x = 1");
        assert!((score.value - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_weight_modifier() {
        let v = checks()
            .require("fn ")
            .weight(2.0)
            .require("->")
            .weight(1.0);
        let score = v.validate("fn foo() {}");
        assert!((score.value - 2.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_modifiers_on_empty_checks_are_noops() {
        let v = checks().weight(5.0).feedback("ignored");
        assert!(v.validate("anything").is_perfect());
    }

    #[test]
    fn test_breakdown() {
        let v = checks().require("fn ").forbid(".unwrap()");
        let score = v.validate("fn foo() { x.unwrap() }");
        assert!(score.breakdown.is_some());
        let breakdown = score.breakdown.unwrap();
        assert_eq!(breakdown.len(), 2);
        assert_eq!(breakdown[0], ("require", 1.0));
        assert_eq!(breakdown[1], ("forbid", 0.0));
    }

    #[test]
    fn test_combined_checks() {
        let v = checks()
            .require("fn ")
            .require("->")
            .forbid(".unwrap()")
            .forbid("panic!")
            .min_len(20);
        let code = "fn parse(s: &str) -> Option<i32> { s.parse().ok() }";
        let score = v.validate(code);
        assert!(score.is_perfect());
        let code = "fn parse(s: &str) -> i32 { s.parse().unwrap() }";
        let score = v.validate(code);
        assert!(score.value < 1.0);
    }

    #[test]
    fn test_feedback_modifier() {
        let v = checks().require("fn ").feedback("Missing function keyword");
        let score = v.validate("let x = 1");
        assert_eq!(score.feedback_str(), Some("Missing function keyword"));
    }

    #[test]
    fn test_feedback_joins_all_failures() {
        let v = checks()
            .require("fn ")
            .feedback("a")
            .forbid("x")
            .feedback("b");
        let score = v.validate("x");
        assert_eq!(score.feedback_str(), Some("a; b"));
    }

    #[test]
    fn test_require_if_and_forbid_if() {
        let v = checks().require_if(true, "fn ").forbid_if(false, "fn ");
        assert!(v.validate("fn main() {}").is_perfect());
        assert!(!v.validate("let x = 1").is_perfect());

        let v = checks().require_if(false, "fn ").forbid_if(true, "panic!");
        assert!(v.validate("let x = 1").is_perfect());
        assert!(!v.validate("panic!()").is_perfect());
    }

    #[test]
    fn test_require_all() {
        let v = checks().require_all(["fn ", "-> i32", "pub"]);
        let score = v.validate("pub fn add(a: i32, b: i32) -> i32 { a + b }");
        assert!(score.is_perfect());
        let score = v.validate("fn add(a: i32, b: i32) -> i32 { a + b }");
        assert!(!score.is_perfect());
    }

    #[test]
    fn test_forbid_all() {
        let v = checks().forbid_all([".unwrap()", "panic!", "todo!"]);
        let score = v.validate("fn safe() -> i32 { 42 }");
        assert!(score.is_perfect());
        let score = v.validate("fn bad() { panic!(\"oh no\") }");
        assert!(!score.is_perfect());
    }

    #[test]
    fn test_regex_all() {
        let v = checks().regex_all([r"fn \w+", r"-> \w+"]);
        let score = v.validate("fn add(a: i32) -> i32 { a + 1 }");
        assert!(score.is_perfect());
        let score = v.validate("let x = 42;");
        assert!(!score.is_perfect());
    }

    #[test]
    fn test_batch_with_vec() {
        let patterns: Vec<String> = vec!["fn ".to_string(), "pub ".to_string()];
        let v = checks().require_all(patterns);
        let score = v.validate("pub fn test() {}");
        assert!(score.is_perfect());
    }

    #[test]
    fn test_batch_empty_iterator() {
        let v = checks().require_all(Vec::<String>::new()).require("fn ");
        let score = v.validate("fn test() {}");
        assert!(score.is_perfect());
    }

    #[test]
    fn test_batch_mixed_with_individual() {
        let v = checks()
            .require_all(["fn ", "-> i32"])
            .forbid_all([".unwrap()", "panic!"])
            .min_len(10);
        let score = v.validate("fn add(a: i32, b: i32) -> i32 { a + b }");
        assert!(score.is_perfect());
    }

    #[test]
    fn test_batch_weighted() {
        let v = checks()
            .require_all_weighted(["fn ", "pub "], 0.5)
            .forbid_all_weighted(["unsafe", "unwrap"], 2.0);
        let score = v.validate("pub fn safe() -> i32 { 42 }");
        assert!(score.is_perfect());
    }

    #[test]
    fn test_rust_preset() {
        let v = checks().rust();
        assert!(v.validate("fn main() { println!(\"hi\"); }").is_perfect());
        assert!(!v.validate("fn main() { todo!() }").is_perfect());
    }

    #[test]
    fn test_python_preset() {
        let v = checks().python();
        assert!(v.validate("def f():\n    return 1").is_perfect());
        assert!(!v.validate("SyntaxError: invalid syntax").is_perfect());
    }

    #[test]
    fn test_sql_preset() {
        let v = checks().sql();
        assert!(v.validate("select * from users").is_perfect());
        assert!(!v.validate("hello world").is_perfect());
    }

    #[test]
    fn test_yaml_preset() {
        let v = checks().yaml();
        assert!(v.validate("key: value").is_perfect());
        assert!(!v.validate("key:\tvalue").is_perfect());
    }

    #[test]
    fn test_json_preset() {
        let v = checks().json();
        assert!(v.validate("{\"a\": 1}").is_perfect());
        assert!(!v.validate("{\"a\":").is_perfect());
    }
}