use std::time::Duration;
use serde::Deserialize;
use zeph_llm::any::AnyProvider;
use super::parser::{ChatJsonError, chat_json};
use super::prompts::{CHECKER_SYSTEM, checker_user};
use super::types::{Assertion, AssertionVerdict, StageOutcome};
/// Top-level shape of the checker model's JSON reply.
///
/// The reply is expected to be an object carrying a `verdicts` array whose elements
/// deserialize as [`AssertionVerdict`].
#[derive(Deserialize, Debug)]
struct CheckerOutput {
    /// Verdicts exactly as the model returned them.
    verdicts: Vec<AssertionVerdict>,
}
/// Asks the checker model to judge `assertions` against `evidence` for `user_query`.
///
/// The assertions are serialized to JSON and embedded in the user prompt built by
/// `checker_user`. The reply is then decoded as a `CheckerOutput` through [`chat_json`].
///
/// Returns `(verdicts, tokens, outcome, retries)`:
/// - `verdicts` — the decoded verdicts on success, empty on any failure;
/// - `tokens` — the token count reported by [`chat_json`] on success, `0` on any failure;
/// - `outcome` — [`StageOutcome::Ok`] on success, otherwise the failure kind;
/// - `retries` — `attempt - 1` (saturating) on success, `1` on a parse failure, `0` otherwise.
///
/// If the assertions cannot be serialized, the provider is never called. That failure is
/// reported as [`StageOutcome::LlmError`].
pub async fn run_checker(
    provider: &AnyProvider,
    assertions: &[Assertion],
    evidence: &str,
    user_query: &str,
    per_call_timeout: Duration,
) -> (Vec<AssertionVerdict>, u64, StageOutcome, u32) {
    let serialized = match serde_json::to_string(assertions) {
        Ok(text) => text,
        Err(err) => {
            let outcome = StageOutcome::LlmError {
                msg: format!("assertion serialization failed: {err}"),
            };
            return (Vec::new(), 0, outcome, 0);
        }
    };

    let prompt = checker_user(user_query, evidence, &serialized);
    let reply =
        chat_json::<CheckerOutput>(provider, CHECKER_SYSTEM, &prompt, per_call_timeout).await;

    // Every failure path yields no verdicts and zero tokens; only the outcome and the
    // reported retry count differ between them.
    let (outcome, retries) = match reply {
        Ok((parsed, used_tokens, attempt)) => {
            let retries = attempt.saturating_sub(1);
            return (parsed.verdicts, used_tokens, StageOutcome::Ok, retries);
        }
        Err(ChatJsonError::Llm(err)) => (StageOutcome::LlmError { msg: err.to_string() }, 0),
        Err(ChatJsonError::Timeout(ms)) => (StageOutcome::Timeout { ms }, 0),
        // NOTE(review): the retry count is hard-coded to 1 rather than reported by
        // `chat_json`. Presumably `chat_json` makes exactly one extra attempt after a
        // parse failure; confirm against the parser.
        Err(ChatJsonError::Parse(raw)) => (StageOutcome::ParseError { raw_truncated: raw }, 1),
    };
    (Vec::new(), 0, outcome, retries)
}