use std::time::Duration;
use serde::{Deserialize, Serialize};
use zeph_llm::provider::{LlmProvider, Message, Role};
use zeph_sanitizer::{ContentSanitizer, ContentSource, ContentSourceKind};
use super::error::OrchestrationError;
/// Criterion that a task output is verified against.
///
/// Serialized in adjacently tagged form: the snake_case variant name is stored
/// under `"kind"` and the variant's string payload under `"value"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum VerifyPredicate {
/// Free-form natural-language criterion; the only variant `as_natural` accepts.
Natural(String),
/// Expression-based criterion; `as_natural` rejects it as not supported in v1.
Expression(String),
}
impl VerifyPredicate {
    /// Borrows the criterion text of a `Natural` predicate.
    ///
    /// # Errors
    ///
    /// Returns [`OrchestrationError::PredicateNotSupported`] when the predicate
    /// is an `Expression`; the error message quotes the rejected expression.
    pub fn as_natural(&self) -> Result<&str, OrchestrationError> {
        match self {
            Self::Natural(text) => Ok(text.as_str()),
            Self::Expression(expr) => {
                let message =
                    format!("Expression predicate '{expr}' is not supported in v1; use Natural");
                Err(OrchestrationError::PredicateNotSupported(message))
            }
        }
    }
}
/// Verdict produced by `PredicateEvaluator::evaluate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PredicateOutcome {
/// Whether the output is treated as satisfying the predicate. Also `true` on
/// the evaluator's fail-open paths (unsupported predicate, provider error, timeout).
pub passed: bool,
/// Confidence attached to the verdict. `evaluate` clamps the provider's value
/// to `0.0..=1.0` and reports `0.0` on its fail-open paths.
pub confidence: f32,
/// Explanation of the verdict, or a description of why evaluation did not run.
pub reason: String,
}
/// Checks task output against a [`VerifyPredicate`] by querying an LLM provider.
pub struct PredicateEvaluator<P: LlmProvider> {
/// Provider that receives the typed chat request.
provider: P,
/// Sanitizer applied to the task output before it is placed in the user message.
sanitizer: ContentSanitizer,
/// Upper bound on how long a single provider call is awaited.
timeout: Duration,
}
impl<P: LlmProvider> PredicateEvaluator<P> {
pub fn new(provider: P, sanitizer: ContentSanitizer, timeout_secs: u64) -> Self {
Self {
provider,
sanitizer,
timeout: Duration::from_secs(timeout_secs),
}
}
pub async fn evaluate(
&self,
predicate: &VerifyPredicate,
output: &str,
prior_failure_reason: Option<&str>,
) -> PredicateOutcome {
let criterion = match predicate.as_natural() {
Ok(s) => s,
Err(e) => {
tracing::warn!(error = %e, "unsupported predicate variant, skipping evaluation (fail-open)");
return PredicateOutcome {
passed: true,
confidence: 0.0,
reason: format!("predicate not evaluated: {e}"),
};
}
};
let prior_note = prior_failure_reason
.map(|r| {
let truncated: String = r.chars().take(256).collect();
format!(
"\n\n<prior_failure_reason>{truncated}</prior_failure_reason>\n\
Note: a previous evaluation failed with this reason. Take it into account."
)
})
.unwrap_or_default();
let system = format!(
"You are a strict output verifier. Evaluate whether the task output satisfies \
the given criterion. Respond with a JSON object: \
{{\"passed\": true/false, \"confidence\": 0.0-1.0, \"reason\": \"...\"}}\n\
Criterion: {criterion}{prior_note}"
);
let source = ContentSource::new(ContentSourceKind::ToolResult)
.with_identifier("predicate-evaluator-input");
let sanitized = self.sanitizer.sanitize(output, source);
let user = format!("Task output:\n\n{}", sanitized.body);
let messages = vec![
Message::from_legacy(Role::System, system),
Message::from_legacy(Role::User, user),
];
match tokio::time::timeout(
self.timeout,
self.provider.chat_typed::<EvalResponse>(&messages),
)
.await
{
Ok(Ok(resp)) => {
let outcome = PredicateOutcome {
passed: resp.passed,
confidence: resp.confidence.clamp(0.0, 1.0),
reason: resp.reason,
};
if outcome.passed && outcome.confidence < 0.5 {
tracing::warn!(
confidence = outcome.confidence,
reason = %outcome.reason,
"weak predicate pass (confidence < 0.5)"
);
}
outcome
}
Ok(Err(e)) => {
tracing::warn!(
error = %e,
"predicate evaluation LLM call failed, returning fail-open outcome"
);
PredicateOutcome {
passed: true,
confidence: 0.0,
reason: format!("evaluation failed: {e}"),
}
}
Err(_elapsed) => {
tracing::warn!(
timeout_secs = self.timeout.as_secs(),
"predicate evaluation timed out, returning fail-open outcome"
);
PredicateOutcome {
passed: true,
confidence: 0.0,
reason: "evaluation timed out".to_string(),
}
}
}
}
}
// Structured reply requested from the provider through `chat_typed`; the fields
// match the JSON object described in the evaluator's system prompt.
// Plain `//` comments are used on purpose: `schemars` turns `///` doc comments
// into schema descriptions, which would alter the generated schema.
#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
struct EvalResponse {
// Whether the model judged the criterion to be satisfied.
passed: bool,
// Model-reported confidence; `evaluate` clamps it to `0.0..=1.0` before use.
confidence: f32,
// Model-provided explanation for the verdict.
reason: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A natural-language predicate hands back its criterion text unchanged.
    #[test]
    fn natural_predicate_as_natural() {
        let predicate = VerifyPredicate::Natural(String::from("must contain JSON"));
        let criterion = predicate.as_natural().unwrap();
        assert_eq!(criterion, "must contain JSON");
    }

    /// An expression predicate is rejected by `as_natural`.
    #[test]
    fn expression_predicate_returns_error() {
        let predicate = VerifyPredicate::Expression(String::from("len(output) > 0"));
        let result = predicate.as_natural();
        assert!(result.is_err());
    }

    /// A `PredicateOutcome` survives a JSON encode/decode cycle field by field.
    #[test]
    fn predicate_outcome_serde_roundtrip() {
        let original = PredicateOutcome {
            passed: true,
            confidence: 0.85,
            reason: String::from("looks good"),
        };
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded = serde_json::from_str::<PredicateOutcome>(&encoded).expect("deserialize");

        assert_eq!(decoded.passed, original.passed);
        let drift = (decoded.confidence - original.confidence).abs();
        assert!(drift < f32::EPSILON);
        assert_eq!(decoded.reason, original.reason);
    }

    /// A `Natural` predicate survives a JSON encode/decode cycle.
    #[test]
    fn verify_predicate_serde_roundtrip_natural() {
        let original = VerifyPredicate::Natural(String::from("criterion"));
        let encoded = serde_json::to_string(&original).expect("serialize");
        let decoded = serde_json::from_str::<VerifyPredicate>(&encoded).expect("deserialize");
        assert_eq!(original, decoded);
    }

    /// JSON written without the predicate fields simply lacks those keys.
    ///
    /// NOTE(review): this only inspects the parsed `serde_json::Value`; no
    /// task-node type is deserialized here, so the check cannot fail for this
    /// fixture — consider exercising the real type where it is in scope.
    #[test]
    fn task_node_missing_predicate_fields_deserialize_as_none() {
        let fixture = r#"{
"id": 0,
"title": "t",
"description": "d",
"agent_hint": null,
"status": "pending",
"depends_on": [],
"result": null,
"assigned_agent": null,
"retry_count": 0,
"failure_strategy": null,
"max_retries": null
}"#;
        let parsed: serde_json::Value = serde_json::from_str(fixture).expect("parse");
        for absent_key in ["verify_predicate", "predicate_outcome"] {
            assert!(parsed.get(absent_key).is_none());
        }
    }
}