use std::sync::Arc;
use std::time::{Duration, Instant};
use zeph_llm::any::AnyProvider;
use super::checker::run_checker;
use super::config::{QualityConfig, TriggerPolicy};
use super::proposer::run_proposer;
use super::types::{AssertionVerdict, SelfCheckReport, SkipReason, StageOutcome, VerdictStatus};
/// Context snippets retrieved for the current turn, grouped by retrieval
/// channel. Every entry borrows from the caller, so the struct is cheap to
/// assemble and drop once per turn.
#[derive(Debug, Default)]
pub struct RetrievedContext<'a> {
    /// Entries from the recall channel.
    pub recall: Vec<&'a str>,
    /// Entries from the knowledge-graph channel.
    pub graph_facts: Vec<&'a str>,
    /// Entries carried over from other sessions.
    pub cross_session: Vec<&'a str>,
    /// Summary entries.
    pub summaries: Vec<&'a str>,
}

impl RetrievedContext<'_> {
    /// `true` when no channel holds any entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        [
            &self.recall,
            &self.graph_facts,
            &self.cross_session,
            &self.summaries,
        ]
        .iter()
        .all(|section| section.is_empty())
    }

    /// Concatenates every entry with `sep`, channel by channel: recall
    /// first, then graph facts, cross-session, and summaries.
    #[must_use]
    pub fn joined(&self, sep: &str) -> String {
        let mut parts: Vec<&str> = Vec::new();
        for section in [
            &self.recall,
            &self.graph_facts,
            &self.cross_session,
            &self.summaries,
        ] {
            parts.extend(section.iter().copied());
        }
        parts.join(sep)
    }
}
/// Two-stage response self-check: a proposer provider extracts assertions
/// from an assistant response, then a checker provider grades them against
/// retrieved evidence (see `run`).
pub struct SelfCheckPipeline {
    /// Quality/trigger configuration; crate-visible so callers can inspect it.
    pub(crate) cfg: QualityConfig,
    /// Provider used by the assertion-proposing stage.
    proposer: AnyProvider,
    /// Provider used by the verdict stage; built with prompt caching
    /// disabled when `cfg.cache_disabled_for_checker` is set (see `build`).
    checker: AnyProvider,
}
impl SelfCheckPipeline {
    /// Crate-internal read access to the pipeline's configuration.
    pub(crate) fn cfg_ref(&self) -> &QualityConfig {
        &self.cfg
    }
}
impl SelfCheckPipeline {
    /// Validates `config` and assembles the pipeline behind an [`Arc`].
    ///
    /// Both stages run on clones of `main_provider`; when
    /// `config.cache_disabled_for_checker` is set, the checker's clone has
    /// prompt caching disabled.
    ///
    /// # Errors
    /// Returns the stringified validation error when `config` is invalid.
    pub fn build(config: &QualityConfig, main_provider: &AnyProvider) -> Result<Arc<Self>, String> {
        config.validate().map_err(|e| e.to_string())?;
        let proposer = main_provider.clone();
        let checker = if config.cache_disabled_for_checker {
            main_provider.with_prompt_cache_disabled()
        } else {
            main_provider.clone()
        };
        Ok(Arc::new(Self {
            cfg: config.clone(),
            proposer,
            checker,
        }))
    }

    /// Runs the two-stage self-check for one turn and returns its report.
    ///
    /// Stage 1 (proposer) extracts assertions from `response`; stage 2
    /// (checker) grades each assertion against the joined retrieved
    /// evidence and `user_query`. Short-circuits with a skipped report
    /// when the trigger policy requires retrieval and `retrieved_context`
    /// is empty, and skips the checker when the proposer yields no
    /// assertions. `latency_ms` always covers the whole call.
    pub async fn run(
        &self,
        response: &str,
        retrieved_context: RetrievedContext<'_>,
        user_query: &str,
        turn_id: u64,
    ) -> SelfCheckReport {
        let started = Instant::now();
        let per_call = Duration::from_millis(self.cfg.per_call_timeout_ms);
        // Policy gate: under `HasRetrieval` there is no evidence to check
        // against, so skip both stages outright.
        if self.cfg.trigger == TriggerPolicy::HasRetrieval && retrieved_context.is_empty() {
            return SelfCheckReport {
                turn_id,
                assertions: vec![],
                verdicts: vec![],
                flagged_ids: vec![],
                latency_ms: Self::elapsed_ms(started),
                proposer_tokens: 0,
                checker_tokens: 0,
                proposer_outcome: StageOutcome::Skipped(SkipReason::NoRetrievedContext),
                checker_outcome: StageOutcome::Skipped(SkipReason::NoRetrievedContext),
                parse_retries: 0,
            };
        }
        let (assertions, p_tokens, p_outcome, p_retries) =
            run_proposer(&self.proposer, response, self.cfg.max_assertions, per_call).await;
        if assertions.is_empty() {
            // No assertions to grade: surface the proposer outcome and skip
            // the checker stage.
            // NOTE(review): `NoAssistantText` reads oddly here — the
            // proposer ran but produced zero assertions; confirm whether a
            // dedicated `SkipReason` variant exists for this case.
            return SelfCheckReport {
                turn_id,
                assertions,
                verdicts: vec![],
                flagged_ids: vec![],
                latency_ms: Self::elapsed_ms(started),
                proposer_tokens: p_tokens,
                checker_tokens: 0,
                proposer_outcome: p_outcome,
                checker_outcome: StageOutcome::Skipped(SkipReason::NoAssistantText),
                parse_retries: p_retries,
            };
        }
        let evidence = retrieved_context.joined("\n\n");
        let (verdicts, c_tokens, c_outcome, c_retries) =
            run_checker(&self.checker, &assertions, &evidence, user_query, per_call).await;
        let flagged_ids = self.compute_flagged(&verdicts);
        SelfCheckReport {
            turn_id,
            assertions,
            verdicts,
            flagged_ids,
            latency_ms: Self::elapsed_ms(started),
            proposer_tokens: p_tokens,
            checker_tokens: c_tokens,
            proposer_outcome: p_outcome,
            checker_outcome: c_outcome,
            parse_retries: p_retries + c_retries,
        }
    }

    /// Milliseconds elapsed since `started`, saturating at `u64::MAX`
    /// instead of panicking on the (practically impossible) `u128 -> u64`
    /// overflow. Extracted to avoid repeating the conversion at every
    /// report-construction site.
    fn elapsed_ms(started: Instant) -> u64 {
        u64::try_from(started.elapsed().as_millis()).unwrap_or(u64::MAX)
    }

    /// IDs of assertions worth flagging: every contradicted assertion, plus
    /// any non-irrelevant assertion whose evidence score falls below
    /// `cfg.min_evidence`.
    fn compute_flagged(&self, verdicts: &[AssertionVerdict]) -> Vec<u32> {
        verdicts
            .iter()
            .filter(|v| {
                v.status == VerdictStatus::Contradicted
                    || (v.status != VerdictStatus::Irrelevant && v.evidence < self.cfg.min_evidence)
            })
            .map(|v| v.id)
            .collect()
    }
}