use std::sync::Arc;
use std::time::Duration;
use zeph_llm::any::AnyProvider;
use zeph_llm::provider::LlmProvider as _;
use crate::graph::GraphStore;
/// Tunables for the memory-write quality gate.
#[derive(Debug, Clone)]
pub struct QualityGateConfig {
    /// Master switch; when `false` every write is admitted unchecked.
    pub enabled: bool,
    /// Minimum final score required to admit a write.
    pub threshold: f32,
    /// Number of recent outcomes considered for the rejection-rate alarm.
    pub recent_window: usize,
    /// Graph edges younger than this many seconds are not treated as
    /// established facts for contradiction purposes.
    pub contradiction_grace_seconds: u64,
    /// Weight of the novelty (information-value) sub-score.
    pub information_value_weight: f32,
    /// Weight of the reference-completeness sub-score.
    pub reference_completeness_weight: f32,
    /// Weight of the (inverted) contradiction-risk sub-score.
    pub contradiction_weight: f32,
    /// Rolling rejection ratio above which a warning is logged.
    pub rejection_rate_alarm_ratio: f32,
    /// Time budget for the optional LLM scorer call.
    pub llm_timeout_ms: u64,
    /// Blend factor between rule score and LLM score (0.0 = rules only).
    pub llm_weight: f32,
    /// Enables the English-specific pronoun/deictic reference checks.
    pub reference_check_lang_en: bool,
}
impl Default for QualityGateConfig {
    /// Conservative defaults: gate disabled, rule weights 0.4/0.3/0.3,
    /// 500 ms LLM budget with an even rule/LLM blend.
    fn default() -> Self {
        Self {
            enabled: false,
            threshold: 0.55,
            recent_window: 32,
            contradiction_grace_seconds: 300,
            information_value_weight: 0.4,
            reference_completeness_weight: 0.3,
            contradiction_weight: 0.3,
            rejection_rate_alarm_ratio: 0.35,
            llm_timeout_ms: 500,
            llm_weight: 0.5,
            reference_check_lang_en: true,
        }
    }
}
/// Breakdown of the sub-scores produced for one candidate write.
#[derive(Debug, Clone)]
pub struct QualityScore {
    /// Novelty relative to recent embeddings (1.0 = fully novel).
    pub information_value: f32,
    /// Reference resolvability (1.0 = no unresolved pronouns/deictics).
    pub reference_completeness: f32,
    /// Estimated risk that the write contradicts stored facts.
    pub contradiction_risk: f32,
    // NOTE(review): this struct is not constructed anywhere in this file;
    // presumably `combined` is the weighted rule score and `final_score`
    // includes the LLM blend — confirm against callers.
    pub combined: f32,
    /// Score that is compared against the admission threshold.
    pub final_score: f32,
}
/// Why a candidate write was rejected by the quality gate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize)]
#[serde(rename_all = "snake_case")]
pub enum QualityRejectionReason {
    /// Too similar to a recently stored embedding.
    Redundant,
    /// Pronoun/deictic-heavy text without resolvable anchors.
    IncompleteReference,
    /// Conflicts with an established (past-grace-period) graph edge.
    Contradiction,
    /// Below threshold without a more specific attributable cause.
    LlmLowConfidence,
}
impl QualityRejectionReason {
#[must_use]
pub fn label(self) -> &'static str {
match self {
Self::Redundant => "redundant",
Self::IncompleteReference => "incomplete_reference",
Self::Contradiction => "contradiction",
Self::LlmLowConfidence => "llm_low_confidence",
}
}
}
/// Fixed-capacity FIFO of accept/reject outcomes with an O(1) running
/// rejection count.
struct RollingRateTracker {
    window: std::collections::VecDeque<bool>,
    capacity: usize,
    reject_count: usize,
}
impl RollingRateTracker {
    /// Creates a tracker that remembers at most `capacity` outcomes.
    fn new(capacity: usize) -> Self {
        Self {
            window: std::collections::VecDeque::with_capacity(capacity + 1),
            capacity,
            reject_count: 0,
        }
    }
    /// Records one outcome, evicting the oldest entry once at capacity.
    fn push(&mut self, rejected: bool) {
        if self.window.len() >= self.capacity {
            // The evicted sample must stop contributing to the count.
            if self.window.pop_front() == Some(true) {
                self.reject_count = self.reject_count.saturating_sub(1);
            }
        }
        if rejected {
            self.reject_count += 1;
        }
        self.window.push_back(rejected);
    }
    /// Fraction of tracked outcomes that were rejections (0.0 when empty).
    #[allow(clippy::cast_precision_loss)]
    fn rate(&self) -> f32 {
        match self.window.len() {
            0 => 0.0,
            n => self.reject_count as f32 / n as f32,
        }
    }
}
/// Admission gate for long-term memory writes.
///
/// Combines rule-based scoring with an optional LLM judge and an optional
/// knowledge-graph contradiction check.
pub struct QualityGate {
    config: Arc<QualityGateConfig>,
    /// Optional secondary judge, blended in with `config.llm_weight`.
    llm_provider: Option<Arc<AnyProvider>>,
    /// Optional graph used to detect contradictions with stored facts.
    graph_store: Option<Arc<GraphStore>>,
    /// Per-reason rejection counters for observability.
    rejection_counts: std::sync::Mutex<std::collections::HashMap<QualityRejectionReason, u64>>,
    /// Rolling accept/reject history backing the rejection-rate alarm.
    rate_tracker: std::sync::Mutex<RollingRateTracker>,
}
impl QualityGate {
    /// Builds a gate from `config` with no LLM scorer or graph store attached.
    #[must_use]
    pub fn new(config: QualityGateConfig) -> Self {
        // Size the rolling-rate window from the config so that the alarm
        // log's `window_size` field matches the actual window. This was
        // previously hard-coded to 100 while the log reported
        // `config.recent_window` (default 32). `.max(1)` guards against a
        // degenerate zero-capacity window.
        let window = config.recent_window.max(1);
        Self {
            config: Arc::new(config),
            llm_provider: None,
            graph_store: None,
            rejection_counts: std::sync::Mutex::new(std::collections::HashMap::new()),
            rate_tracker: std::sync::Mutex::new(RollingRateTracker::new(window)),
        }
    }
    /// Attaches an LLM provider used as a secondary quality judge.
    #[must_use]
    pub fn with_llm_provider(mut self, provider: AnyProvider) -> Self {
        self.llm_provider = Some(Arc::new(provider));
        self
    }
    /// Attaches a graph store used for contradiction detection.
    #[must_use]
    pub fn with_graph_store(mut self, store: Arc<GraphStore>) -> Self {
        self.graph_store = Some(store);
        self
    }
    /// Returns the active configuration.
    #[must_use]
    pub fn config(&self) -> &QualityGateConfig {
        &self.config
    }
    /// Snapshot of per-reason rejection counts; empty if the lock is poisoned.
    #[must_use]
    pub fn rejection_counts(&self) -> std::collections::HashMap<QualityRejectionReason, u64> {
        self.rejection_counts
            .lock()
            .map(|g| g.clone())
            .unwrap_or_default()
    }
    /// Scores `content` and decides whether it may be written to memory.
    ///
    /// Returns `None` to admit the write, or `Some(reason)` describing why it
    /// was rejected. Rule-based sub-scores (information value, reference
    /// completeness, contradiction risk) are combined via the configured
    /// weights; when an LLM provider is attached, its score is blended in
    /// with `llm_weight`. Degraded dependencies (embed failures, LLM
    /// timeouts) are handled fail-open by the helpers so they never block
    /// writes outright.
    #[tracing::instrument(name = "memory.quality_gate.evaluate", skip_all)]
    pub async fn evaluate(
        &self,
        content: &str,
        embed_provider: &AnyProvider,
        recent_embeddings: &[Vec<f32>],
    ) -> Option<QualityRejectionReason> {
        if !self.config.enabled {
            return None;
        }
        let info_val = compute_information_value(content, embed_provider, recent_embeddings).await;
        // The reference heuristics are English-specific; skip when disabled.
        let ref_comp = if self.config.reference_check_lang_en {
            compute_reference_completeness(content)
        } else {
            1.0
        };
        let contradiction_risk =
            compute_contradiction_risk(content, self.graph_store.as_deref(), &self.config).await;
        let w_v = self.config.information_value_weight;
        let w_c = self.config.reference_completeness_weight;
        let w_k = self.config.contradiction_weight;
        // Contradiction risk is inverted: high risk lowers the score.
        let rule_score = w_v * info_val + w_c * ref_comp + w_k * (1.0 - contradiction_risk);
        let final_score = if let Some(ref llm) = self.llm_provider {
            let llm_score = call_llm_scorer(content, llm, self.config.llm_timeout_ms).await;
            let lw = self.config.llm_weight;
            (1.0 - lw) * rule_score + lw * llm_score
        } else {
            rule_score
        };
        let rejected = final_score < self.config.threshold;
        // Track the outcome and raise an operator alarm when the recent
        // rejection rate exceeds the configured ratio.
        if let Ok(mut tracker) = self.rate_tracker.lock() {
            tracker.push(rejected);
            let rate = tracker.rate();
            if rate > self.config.rejection_rate_alarm_ratio {
                tracing::warn!(
                    rate = %format!("{:.2}", rate),
                    window_size = self.config.recent_window,
                    threshold = self.config.rejection_rate_alarm_ratio,
                    "quality_gate: high rejection rate alarm"
                );
            }
        }
        if !rejected {
            return None;
        }
        // Attribute the rejection to the most likely cause, most specific
        // signal first; LlmLowConfidence is the catch-all.
        let reason = if info_val < 0.1 {
            QualityRejectionReason::Redundant
        } else if ref_comp < 0.5 && self.config.reference_check_lang_en {
            QualityRejectionReason::IncompleteReference
        } else if contradiction_risk >= 1.0 {
            QualityRejectionReason::Contradiction
        } else {
            QualityRejectionReason::LlmLowConfidence
        };
        if let Ok(mut counts) = self.rejection_counts.lock() {
            *counts.entry(reason).or_insert(0) += 1;
        }
        tracing::debug!(
            reason = reason.label(),
            final_score,
            info_val,
            ref_comp,
            contradiction_risk,
            "quality_gate: rejected write"
        );
        Some(reason)
    }
}
/// Novelty score in [0, 1]: 1.0 means no recent embedding resembles
/// `content`; 0.0 means an identical embedding was seen recently.
///
/// Fails open (returns 1.0) when there is nothing to compare against, the
/// provider lacks embedding support, or embedding the content errors.
async fn compute_information_value(
    content: &str,
    provider: &AnyProvider,
    recent_embeddings: &[Vec<f32>],
) -> f32 {
    if recent_embeddings.is_empty() || !provider.supports_embeddings() {
        return 1.0;
    }
    let embedding = match provider.embed(content).await {
        Ok(vector) => vector,
        Err(e) => {
            tracing::debug!(error = %e, "quality_gate: embed failed, treating info_val = 1.0 (fail-open)");
            return 1.0;
        }
    };
    // Novelty is the complement of the closest recent match.
    let best_similarity = recent_embeddings
        .iter()
        .map(|previous| zeph_common::math::cosine_similarity(&embedding, previous))
        .fold(0.0f32, f32::max);
    (1.0 - best_similarity).max(0.0)
}
/// Scores how self-contained an English message is, in [0, 1].
///
/// Unresolved pronouns and deictic time phrases ("yesterday", "next week")
/// count as issues; deictic phrases are forgiven when a date anchor (month
/// name, month abbreviation, or 4-digit year) is present. The score drops
/// with the density of issues relative to the word count.
#[must_use]
pub fn compute_reference_completeness(content: &str) -> f32 {
    const PRONOUNS: &[&str] = &[" he ", " she ", " they ", " it ", " him ", " her ", " them "];
    const DEICTIC_TIME: &[&str] = &[
        "yesterday",
        "tomorrow",
        "last week",
        "next week",
        "last month",
        "next month",
        "last year",
        "next year",
    ];
    const DATE_ANCHORS: &[&str] = &[
        "january", "february", "march", "april", "may", "june", "july", "august", "september",
        "october", "november", "december", "jan ", "feb ", "mar ", "apr ", "jun ", "jul ", "aug ",
        "sep ", "oct ", "nov ", "dec ",
    ];
    let lowered = content.to_lowercase();
    // Space-pad so word-boundary pronoun patterns match at either edge.
    let spaced = format!(" {lowered} ");
    let mut issues = PRONOUNS.iter().filter(|&&p| spaced.contains(p)).count();
    let anchored =
        has_4digit_year_anchor(&lowered) || DATE_ANCHORS.iter().any(|&m| lowered.contains(m));
    if !anchored {
        issues += DEICTIC_TIME.iter().filter(|&&m| lowered.contains(m)).count();
    }
    if issues == 0 {
        return 1.0;
    }
    let words = content.split_ascii_whitespace().count().max(1);
    #[allow(clippy::cast_precision_loss)]
    let density = issues as f32 / words as f32;
    (1.0 - density * 2.0).clamp(0.0, 1.0)
}
/// Detects a standalone 4-digit year starting with "19" or "20" (i.e.
/// 1900-2099) that is not embedded in a longer digit run.
fn has_4digit_year_anchor(text: &str) -> bool {
    let bytes = text.as_bytes();
    let len = bytes.len();
    (0..len.saturating_sub(3)).any(|i| {
        let century = matches!((bytes[i], bytes[i + 1]), (b'1', b'9') | (b'2', b'0'));
        century
            && bytes[i + 2].is_ascii_digit()
            && bytes[i + 3].is_ascii_digit()
            // Reject when adjacent digits would make this part of a longer number.
            && (i == 0 || !bytes[i - 1].is_ascii_digit())
            && (i + 4 >= len || !bytes[i + 4].is_ascii_digit())
    })
}
/// Estimates the risk that `content` contradicts facts already in the graph.
///
/// Returns 0.0 when no graph is attached, the subject cannot be resolved, or
/// any lookup fails (fail-open). Returns 1.0 when a matching edge older than
/// the grace period exists, 0.5 when only recent matching edges exist.
async fn compute_contradiction_risk(
    content: &str,
    graph: Option<&GraphStore>,
    config: &QualityGateConfig,
) -> f32 {
    let Some(store) = graph else {
        return 0.0;
    };
    let lowered = content.to_lowercase();
    let subject = extract_subject_tokens(&lowered);
    if subject.is_empty() {
        return 0.0;
    }
    let candidates = match store.find_entities_fuzzy(&subject, 1).await {
        Ok(list) => list,
        Err(_) => return 0.0,
    };
    let Some(entity) = candidates.into_iter().next() else {
        return 0.0;
    };
    let predicate = extract_predicate_token(&lowered);
    let Ok(all_edges) = store.edges_for_entity(entity.id).await else {
        return 0.0;
    };
    // Keep edges originating from the subject; narrow by predicate when one
    // was extracted.
    let matching: Vec<_> = all_edges
        .iter()
        .filter(|edge| edge.source_entity_id == entity.id)
        .filter(|edge| predicate.as_ref().is_none_or(|p| edge.relation == *p))
        .collect();
    if matching.is_empty() {
        return 0.0;
    }
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |d| d.as_secs());
    // An edge older than the grace period counts as an established fact.
    let has_established_conflict = matching.iter().any(|edge| {
        let edge_secs = chrono::DateTime::parse_from_rfc3339(&edge.created_at)
            .map_or(0u64, |dt| u64::try_from(dt.timestamp()).unwrap_or(0));
        now.saturating_sub(edge_secs) > config.contradiction_grace_seconds
    });
    if has_established_conflict { 1.0 } else { 0.5 }
}
/// Heuristically extracts the sentence subject: the words before the first
/// copular/auxiliary verb, capped at three tokens (two when no verb occurs).
fn extract_subject_tokens(content_lower: &str) -> String {
    const VERB_MARKERS: &[&str] = &["is", "was", "are", "were", "has", "have", "had", "will"];
    let words: Vec<&str> = content_lower.split_ascii_whitespace().collect();
    let cutoff = match words.iter().position(|w| VERB_MARKERS.contains(w)) {
        Some(idx) => idx,
        None => 2.min(words.len()),
    };
    words[..cutoff.min(3)].join(" ")
}
/// Returns the first copular/auxiliary verb in the text, if any, as an
/// owned token usable as a canonical predicate.
fn extract_predicate_token(content_lower: &str) -> Option<String> {
    const VERB_MARKERS: &[&str] = &["is", "was", "are", "were", "has", "have", "had", "will"];
    for token in content_lower.split_ascii_whitespace() {
        if VERB_MARKERS.contains(&token) {
            return Some(token.to_owned());
        }
    }
    None
}
/// Asks the LLM provider to judge `content` and returns a score in [0, 1].
///
/// Content is truncated to 500 chars before prompting. Any failure —
/// provider error or exceeding `timeout_ms` — yields the neutral score 0.5.
async fn call_llm_scorer(content: &str, provider: &AnyProvider, timeout_ms: u64) -> f32 {
    use zeph_llm::provider::{Message, MessageMetadata, Role};
    // Prompt text is part of the scorer's contract with parse_llm_score;
    // keep it stable.
    let system = "You are a memory quality judge. Rate the quality of the following message \
for long-term storage on a scale of 0.0 to 1.0. Consider: information density, \
completeness of references, factual clarity. \
Respond with ONLY a JSON object: \
{\"information_value\": 0.0-1.0, \"reference_completeness\": 0.0-1.0, \
\"contradiction_risk\": 0.0-1.0}";
    let truncated: String = content.chars().take(500).collect();
    let build = |role: Role, text: String| Message {
        role,
        content: text,
        parts: vec![],
        metadata: MessageMetadata::default(),
    };
    let messages = vec![
        build(Role::System, system.to_owned()),
        build(Role::User, format!("Message: {truncated}\n\nQuality JSON:")),
    ];
    let budget = Duration::from_millis(timeout_ms);
    let reply = match tokio::time::timeout(budget, provider.chat(&messages)).await {
        Err(_) => {
            tracing::debug!("quality_gate: LLM scorer timed out, using 0.5");
            return 0.5;
        }
        Ok(Err(e)) => {
            tracing::debug!(error = %e, "quality_gate: LLM scorer failed, using 0.5");
            return 0.5;
        }
        Ok(Ok(text)) => text,
    };
    parse_llm_score(&reply)
}
/// Parses the scorer's JSON reply into a single weighted score in [0, 1].
///
/// Tolerates surrounding prose by slicing from the first '{' to the last
/// '}'. Any missing or unparsable input maps to the neutral value 0.5;
/// missing fields default to 0.5 (value/completeness) and 0.0 (risk).
fn parse_llm_score(response: &str) -> f32 {
    let (Some(open), Some(close)) = (response.find('{'), response.rfind('}')) else {
        return 0.5;
    };
    let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&response[open..=close]) else {
        return 0.5;
    };
    // Pull a field, fall back to its default, clamp into the valid range.
    let field = |key: &str, default: f64| {
        #[allow(clippy::cast_possible_truncation)]
        let value = parsed[key].as_f64().unwrap_or(default) as f32;
        value.clamp(0.0, 1.0)
    };
    let iv = field("information_value", 0.5);
    let rc = field("reference_completeness", 0.5);
    let cr = field("contradiction_risk", 0.0);
    (0.4 * iv + 0.3 * rc + 0.3 * (1.0 - cr)).clamp(0.0, 1.0)
}
#[cfg(test)]
mod tests {
    // Reformatted: several statements were squashed onto shared lines,
    // which rustfmt rejects and which hid the test structure. Logic is
    // unchanged.
    use super::*;

    #[test]
    fn reference_completeness_clean_text() {
        let score = compute_reference_completeness("The Rust compiler enforces memory safety.");
        assert!((score - 1.0).abs() < 0.01, "clean text should score 1.0");
    }

    #[test]
    fn reference_completeness_pronoun_heavy() {
        let score = compute_reference_completeness("yeah he said they confirmed it");
        assert!(
            score < 0.5,
            "pronoun-heavy message should score below 0.5, got {score}"
        );
    }

    #[test]
    fn reference_completeness_deictic_without_anchor() {
        let score = compute_reference_completeness("We agreed yesterday to postpone");
        assert!(
            score < 1.0,
            "deictic time without anchor should penalize, got {score}"
        );
    }

    #[test]
    fn reference_completeness_deictic_with_anchor() {
        let score = compute_reference_completeness("We agreed yesterday (2026-04-18) to postpone");
        assert!(
            score >= 0.9,
            "deictic with anchor '20' should not penalize, got {score}"
        );
    }

    #[test]
    fn rejection_reason_labels() {
        assert_eq!(QualityRejectionReason::Redundant.label(), "redundant");
        assert_eq!(
            QualityRejectionReason::IncompleteReference.label(),
            "incomplete_reference"
        );
        assert_eq!(
            QualityRejectionReason::Contradiction.label(),
            "contradiction"
        );
        assert_eq!(
            QualityRejectionReason::LlmLowConfidence.label(),
            "llm_low_confidence"
        );
    }

    #[test]
    fn rolling_rate_tracker_basic() {
        let mut tracker = RollingRateTracker::new(4);
        tracker.push(true);
        tracker.push(true);
        tracker.push(false);
        tracker.push(false);
        let rate = tracker.rate();
        assert!((rate - 0.5).abs() < 0.01, "rate should be 0.5, got {rate}");
    }

    #[test]
    fn rolling_rate_tracker_evicts_oldest() {
        let mut tracker = RollingRateTracker::new(3);
        tracker.push(true);
        tracker.push(false);
        tracker.push(false);
        tracker.push(false);
        let rate = tracker.rate();
        assert!(
            rate < 0.01,
            "evicted rejection should not count, rate={rate}"
        );
    }

    #[test]
    fn parse_llm_score_valid_json() {
        let json = r#"{"information_value": 0.8, "reference_completeness": 0.9, "contradiction_risk": 0.1}"#;
        let score = parse_llm_score(json);
        assert!(
            score > 0.7,
            "high-quality JSON should yield high score, got {score}"
        );
    }

    #[test]
    fn parse_llm_score_malformed_returns_neutral() {
        let score = parse_llm_score("not json");
        assert!(
            (score - 0.5).abs() < 0.01,
            "malformed JSON should return 0.5"
        );
    }

    /// Shared mock provider for gate-level tests.
    fn mock_provider() -> zeph_llm::any::AnyProvider {
        zeph_llm::any::AnyProvider::Mock(zeph_llm::mock::MockProvider::default())
    }

    #[tokio::test]
    async fn gate_disabled_always_passes() {
        let config = QualityGateConfig {
            enabled: false,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let provider = mock_provider();
        let result = gate.evaluate("yeah he confirmed it", &provider, &[]).await;
        assert!(result.is_none(), "disabled gate must always pass");
    }

    #[tokio::test]
    async fn gate_admits_novel_clean_content() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.3,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let provider = mock_provider();
        let result = gate
            .evaluate(
                "The Rust compiler enforces memory safety through the borrow checker.",
                &provider,
                &[],
            )
            .await;
        assert!(result.is_none(), "clean novel content should be admitted");
    }

    #[tokio::test]
    async fn gate_rejects_pronoun_only_at_low_threshold() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.75,
            reference_completeness_weight: 0.9,
            information_value_weight: 0.05,
            contradiction_weight: 0.05,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let provider = mock_provider();
        let result = gate
            .evaluate("yeah he confirmed it they said so", &provider, &[])
            .await;
        assert!(
            result == Some(QualityRejectionReason::IncompleteReference),
            "pronoun-heavy message should be rejected as IncompleteReference, got {result:?}"
        );
    }

    #[test]
    fn quality_gate_counts_rejections() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.99,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        if let Ok(mut counts) = gate.rejection_counts.lock() {
            *counts.entry(QualityRejectionReason::Redundant).or_insert(0) += 1;
        }
        let counts = gate.rejection_counts();
        assert_eq!(counts.get(&QualityRejectionReason::Redundant), Some(&1));
    }

    #[tokio::test]
    async fn gate_fail_open_on_embed_error() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.5,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::default().with_embed_invalid_input(),
        );
        let result = gate
            .evaluate("Alice confirmed the meeting at 3pm.", &provider, &[])
            .await;
        assert!(
            result.is_none(),
            "embed error must be treated as fail-open (admitted), got {result:?}"
        );
    }

    #[tokio::test]
    async fn gate_rejects_redundant_with_populated_embeddings() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.5,
            information_value_weight: 0.9,
            reference_completeness_weight: 0.05,
            contradiction_weight: 0.05,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let fixed_embedding = vec![0.1_f32; 384];
        let provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::default().with_embedding(fixed_embedding.clone()),
        );
        let result = gate
            .evaluate(
                "The Rust compiler enforces memory safety through the borrow checker.",
                &provider,
                &[fixed_embedding],
            )
            .await;
        assert_eq!(
            result,
            Some(QualityRejectionReason::Redundant),
            "identical recent embedding must trigger Redundant rejection"
        );
    }

    #[tokio::test]
    async fn gate_llm_timeout_falls_back_to_rule_score() {
        let config = QualityGateConfig {
            enabled: true,
            threshold: 0.3,
            llm_timeout_ms: 50,
            llm_weight: 0.5,
            ..QualityGateConfig::default()
        };
        let gate = QualityGate::new(config);
        let slow_provider = zeph_llm::any::AnyProvider::Mock(
            zeph_llm::mock::MockProvider::default().with_delay(600),
        );
        let gate = gate.with_llm_provider(slow_provider);
        let embed_provider = mock_provider();
        let result = gate
            .evaluate(
                "The release is scheduled for next Friday.",
                &embed_provider,
                &[],
            )
            .await;
        assert!(
            result.is_none(),
            "LLM timeout must fall back to rule score and admit clean content, got {result:?}"
        );
    }
}