use crate::recursive::llm::Llm;
use crate::recursive::validate::{Score, Validate};
use smallvec::SmallVec;
/// Starts configuring a [`SemanticValidator`] that uses `llm` as the judge.
///
/// Shorthand for [`SemanticBuilder::new`]: the returned builder has no
/// criteria, a pass threshold of `0.7`, and no system prompt.
pub fn semantic<L: Llm>(llm: &L) -> SemanticBuilder<'_, L> {
    SemanticBuilder::<L>::new(llm)
}
/// Fluent configuration for a [`SemanticValidator`].
///
/// Obtain one from [`semantic`] or [`SemanticBuilder::new`] and finish it
/// with [`SemanticBuilder::build`].
pub struct SemanticBuilder<'a, L>
where
    L: Llm,
{
    /// Model that acts as the judge.
    llm: &'a L,
    /// Criteria to score, in insertion order (stored inline for up to 8).
    criteria: SmallVec<[&'a str; 8]>,
    /// Minimum `overall` score that counts as a pass.
    threshold: f64,
    /// Optional system prompt supplied by the caller.
    system_prompt: Option<&'a str>,
}
impl<'a, L: Llm> SemanticBuilder<'a, L> {
    /// Creates a builder bound to `llm` with no criteria, a pass threshold of
    /// `0.7`, and no system prompt.
    #[must_use]
    pub fn new(llm: &'a L) -> Self {
        Self {
            llm,
            criteria: SmallVec::new(),
            threshold: 0.7,
            system_prompt: None,
        }
    }

    /// Appends one natural-language criterion for the judge to score.
    ///
    /// Criteria appear in the judge prompt in the order they were added.
    #[must_use]
    pub fn criterion(mut self, criterion: &'a str) -> Self {
        self.criteria.push(criterion);
        self
    }

    /// Sets the minimum `overall` score required to pass.
    ///
    /// Values outside `[0.0, 1.0]` are clamped into that range. A NaN value is
    /// ignored and the current threshold is kept.
    #[must_use]
    pub fn threshold(mut self, threshold: f64) -> Self {
        // `f64::clamp` returns NaN unchanged, and `x >= NaN` is always false.
        // Storing NaN would therefore make every judgment report failure.
        if !threshold.is_nan() {
            self.threshold = threshold.clamp(0.0, 1.0);
        }
        self
    }

    /// Stores an optional system prompt.
    ///
    /// NOTE(review): the built validator keeps this value, but its field is
    /// marked `dead_code` and is not passed to the LLM yet.
    #[must_use]
    pub fn system_prompt(mut self, prompt: &'a str) -> Self {
        self.system_prompt = Some(prompt);
        self
    }

    /// Finalizes the configuration into a [`SemanticValidator`].
    #[must_use]
    pub fn build(self) -> SemanticValidator<'a, L> {
        SemanticValidator {
            llm: self.llm,
            criteria: self.criteria,
            threshold: self.threshold,
            system_prompt: self.system_prompt,
        }
    }
}
/// Validator that scores text by asking an LLM to judge it against a list of
/// criteria.
///
/// Construct it through [`semantic`]; it implements [`Validate`].
pub struct SemanticValidator<'a, L>
where
    L: Llm,
{
    /// Model that acts as the judge.
    llm: &'a L,
    /// Criteria listed, numbered, in the judge prompt.
    criteria: SmallVec<[&'a str; 8]>,
    /// Minimum `overall` score for the feedback text to report a pass.
    threshold: f64,
    // Carried over from the builder but not read anywhere in this module yet,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    system_prompt: Option<&'a str>,
}
impl<'a, L: Llm> SemanticValidator<'a, L> {
    /// Renders the prompt sent to the judge model.
    ///
    /// The prompt embeds `text` in a fenced block and lists the criteria as a
    /// 1-based numbered list. It asks for a JSON reply with `overall` and
    /// `confidence` fields, which [`Self::parse_judgment`] reads back.
    fn build_judge_prompt(&self, text: &str) -> String {
        let criteria_list = self
            .criteria
            .iter()
            .enumerate()
            .map(|(i, c)| format!("{}. {}", i + 1, c))
            .collect::<Vec<_>>()
            .join("\n");
        format!(
            r#"You are evaluating code/text against specific criteria.
Rate each criterion from 0.0 to 1.0.
TEXT TO EVALUATE:
```
{text}
```
CRITERIA:
{criteria}
Respond ONLY with a JSON object in this exact format:
{{"scores": [{{"criterion": "criterion text", "score": 0.0-1.0, "reason": "brief explanation"}}], "overall": 0.0-1.0, "confidence": 0.0-1.0}}
Important:
- "overall" is the weighted average of all criterion scores
- "confidence" indicates how certain you are about your assessment (1.0 = very certain, 0.5 = uncertain)
- Be strict but fair in your evaluation"#,
            text = text,
            criteria = criteria_list
        )
    }

    /// Converts a raw judge reply into a [`Score`].
    ///
    /// Reads `overall` and `confidence` from the reply's JSON. Each defaults to
    /// `0.5` when missing or unparseable, and each is clamped to `[0.0, 1.0]`.
    /// The feedback text states whether `overall` met the configured threshold.
    fn parse_judgment(&self, response: &str) -> Score<'static> {
        let json_str = self.extract_json(response);
        let (overall, confidence) = self.parse_scores(json_str);
        let feedback = if overall >= self.threshold {
            format!("Semantic validation passed with score {:.2}", overall)
        } else {
            format!(
                "Semantic validation failed: score {:.2} < threshold {:.2}",
                overall, self.threshold
            )
        };
        Score::with_feedback(overall, feedback).with_confidence(confidence)
    }

    /// Locates the JSON payload inside a free-form reply.
    ///
    /// Lookup order:
    /// 1. The trimmed contents of a `json` code fence.
    /// 2. Otherwise, the span from the first `{` to the last `}`.
    /// 3. Otherwise, the whole reply unchanged.
    fn extract_json<'b>(&self, response: &'b str) -> &'b str {
        const FENCE: &str = "```json";
        if let Some(start) = response.find(FENCE) {
            let after_marker = &response[start + FENCE.len()..];
            if let Some(end) = after_marker.find("```") {
                return after_marker[..end].trim();
            }
        }
        // Only slice when the last `}` actually follows the first `{`. Replies
        // such as "} ... {" would otherwise form an inverted range and panic.
        if let (Some(start), Some(end)) = (response.find('{'), response.rfind('}')) {
            if start < end {
                return &response[start..=end];
            }
        }
        response
    }

    /// Reads `(overall, confidence)` from the JSON text.
    ///
    /// Missing or unparseable values default to `0.5`, and both are clamped
    /// to `[0.0, 1.0]`.
    fn parse_scores(&self, json_str: &str) -> (f64, f64) {
        let overall = self.extract_number(json_str, "overall").unwrap_or(0.5);
        let confidence = self.extract_number(json_str, "confidence").unwrap_or(0.5);
        (overall.clamp(0.0, 1.0), confidence.clamp(0.0, 1.0))
    }

    /// Returns the numeric value stored under `"key"` in loosely formatted JSON.
    ///
    /// An occurrence of `"key"` counts only when it is followed (after optional
    /// whitespace) by `:`. This skips string values that merely equal the key
    /// text, such as a criterion named "overall". The first occurrence whose
    /// value starts with a number wins. Signs, fractions and exponents are
    /// accepted. Returns `None` when no such occurrence exists.
    fn extract_number(&self, json: &str, key: &str) -> Option<f64> {
        let pattern = format!("\"{}\"", key);
        json.match_indices(&pattern).find_map(|(pos, _)| {
            let value = json[pos + pattern.len()..]
                .trim_start()
                .strip_prefix(':')?
                .trim_start();
            Self::leading_number(value).parse::<f64>().ok()
        })
    }

    /// Returns the longest prefix of `s` shaped like a decimal number.
    ///
    /// The accepted shape is an optional sign, digits, an optional `.`
    /// fraction, and an optional exponent that has at least one digit. The
    /// result may be empty or otherwise unparseable; the caller's `parse`
    /// rejects those.
    fn leading_number(s: &str) -> &str {
        let bytes = s.as_bytes();
        let skip_digits = |mut i: usize| {
            while matches!(bytes.get(i), Some(b) if b.is_ascii_digit()) {
                i += 1;
            }
            i
        };
        let sign_len = usize::from(matches!(bytes.first(), Some(b'-' | b'+')));
        let mut end = skip_digits(sign_len);
        if bytes.get(end) == Some(&b'.') {
            end = skip_digits(end + 1);
        }
        if matches!(bytes.get(end), Some(b'e' | b'E')) {
            let exp_sign = usize::from(matches!(bytes.get(end + 1), Some(b'-' | b'+')));
            let exp_start = end + 1 + exp_sign;
            let exp_end = skip_digits(exp_start);
            // A bare `e` with no digits is not part of the number,
            // e.g. "0.5e" yields "0.5".
            if exp_end > exp_start {
                end = exp_end;
            }
        }
        // Every consumed byte is ASCII, so `end` lies on a char boundary.
        &s[..end]
    }
}
impl<'a, L: Llm> Validate for SemanticValidator<'a, L> {
fn validate(&self, text: &str) -> Score<'static> {
let prompt = self.build_judge_prompt(text);
let response = std::thread::scope(|s| {
s.spawn(|| crate::recursive::shared::block_on(self.llm.generate(&prompt, "", None)))
.join()
.unwrap_or_else(|_| {
Err(crate::error::Error::module(
"Semantic validation thread panicked",
))
})
});
match response {
Ok(output) => self.parse_judgment(&output.text),
Err(e) => Score::with_feedback(0.5, format!("Semantic validation error: {}", e))
.with_confidence(0.0),
}
}
fn name(&self) -> &'static str {
"semantic"
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::recursive::llm::MockLlm;

    // Builder calls accumulate criteria and record the requested threshold.
    #[test]
    fn test_semantic_builder() {
        let mock = MockLlm::new(|_, _| "test".to_string());
        let configured = semantic(&mock)
            .criterion("Is idiomatic")
            .criterion("Has error handling")
            .threshold(0.8);
        let threshold_error = (configured.threshold - 0.8).abs();
        assert_eq!(configured.criteria.len(), 2);
        assert!(threshold_error < f64::EPSILON);
    }

    // The rendered prompt contains every criterion and the evaluated text.
    #[test]
    fn test_build_judge_prompt() {
        let mock = MockLlm::new(|_, _| "test".to_string());
        let judge = semantic(&mock)
            .criterion("Code quality")
            .criterion("Readability")
            .build();
        let rendered = judge.build_judge_prompt("fn main() {}");
        for needle in ["Code quality", "Readability", "fn main() {}"] {
            assert!(rendered.contains(needle));
        }
    }

    // A bare JSON reply yields its `overall` and `confidence` values.
    #[test]
    fn test_parse_judgment_success() {
        let mock = MockLlm::new(|_, _| "test".to_string());
        let judge = semantic(&mock).criterion("Test").build();
        let reply = r#"{"scores": [], "overall": 0.85, "confidence": 0.9}"#;
        let judged = judge.parse_judgment(reply);
        assert!((judged.value - 0.85).abs() < 0.01);
        assert!((judged.confidence - 0.9).abs() < 0.01);
    }

    // JSON inside a `json` code fence, preceded by prose, is still found.
    #[test]
    fn test_parse_judgment_with_code_block() {
        let mock = MockLlm::new(|_, _| "test".to_string());
        let judge = semantic(&mock).criterion("Test").build();
        let reply = r#"Here's my evaluation:
```json
{"scores": [], "overall": 0.75, "confidence": 0.8}
```"#;
        let judged = judge.parse_judgment(reply);
        assert!((judged.value - 0.75).abs() < 0.01);
        assert!((judged.confidence - 0.8).abs() < 0.01);
    }

    // End to end: the mocked reply flows through `validate` and passes at 0.8.
    #[test]
    fn test_semantic_validator_validate() {
        let mock = MockLlm::new(|_, _| {
            r#"{"scores": [{"criterion": "Test", "score": 0.9, "reason": "Good"}], "overall": 0.9, "confidence": 0.95}"#.to_string()
        });
        let judge = semantic(&mock)
            .criterion("Test criterion")
            .threshold(0.8)
            .build();
        let verdict = judge.validate("fn main() {}");
        assert!((verdict.value - 0.9).abs() < 0.01);
        assert!((verdict.confidence - 0.95).abs() < 0.01);
        assert!(verdict.passes(0.8));
    }

    // A low `overall` misses the threshold, and the feedback says so.
    #[test]
    fn test_semantic_validator_below_threshold() {
        let mock = MockLlm::new(|_, _| r#"{"overall": 0.5, "confidence": 0.8}"#.to_string());
        let judge = semantic(&mock).criterion("Quality").threshold(0.7).build();
        let verdict = judge.validate("bad code");
        assert!((verdict.value - 0.5).abs() < 0.01);
        assert!(!verdict.passes(0.7));
        let feedback = verdict.feedback_str().unwrap();
        assert!(feedback.contains("failed"));
    }

    // Numbers are looked up by key; an absent key yields `None`.
    #[test]
    fn test_extract_number() {
        let mock = MockLlm::new(|_, _| "test".to_string());
        let judge = semantic(&mock).criterion("Test").build();
        let json = r#"{"overall": 0.85, "confidence": 0.9}"#;
        let overall = judge.extract_number(json, "overall").unwrap();
        let confidence = judge.extract_number(json, "confidence").unwrap();
        assert!((overall - 0.85).abs() < 0.01);
        assert!((confidence - 0.9).abs() < 0.01);
        assert!(judge.extract_number(json, "missing").is_none());
    }
}