use std::time::Duration;
use zeph_llm::any::AnyProvider;
use zeph_llm::provider::{LlmProvider as _, Message, Role};
use crate::error::MemoryError;
use crate::graph::types::Edge;
/// How [`ConflictResolver`] chooses one edge among several conflicting candidates.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConflictStrategy {
    /// The candidate with the greatest `valid_from` wins.
    Recency,
    /// The candidate with the greatest `confidence` wins.
    Confidence,
    /// An LLM chooses the winner, limited by a per-turn call budget; whenever the
    /// LLM path cannot be used the resolver falls back to recency ordering.
    Llm,
}
/// Depth cap for supersede chains (judging by the name).
///
/// NOTE(review): not referenced in this file; presumably bounds how far callers follow
/// the `supersedes` / `superseded_by` links on `Edge` — confirm against the call sites.
pub const SUPERSEDE_DEPTH_CAP: usize = 64;
/// Outcome of [`ConflictResolver::resolve`].
pub struct ConflictResult {
    /// The candidate selected by the active strategy.
    pub winner: Edge,
    /// The losing candidates in their original relative order. Empty when the
    /// resolver was built with `retain_alternatives = false` or only one candidate
    /// was supplied.
    pub alternatives: Vec<Edge>,
}
/// Picks a single winning [`Edge`] from a set of conflicting candidates.
///
/// The [`ConflictStrategy::Llm`] strategy is limited by a per-turn call budget. It falls
/// back to recency ordering when:
/// - the budget is exhausted,
/// - no provider is configured, or
/// - the LLM call fails, times out, or returns an unusable answer.
pub struct ConflictResolver {
    /// Configured resolution strategy.
    strategy: ConflictStrategy,
    /// Maximum wait for a single LLM resolution call.
    timeout: Duration,
    /// Remaining LLM calls allowed this turn; only ever decremented while positive.
    llm_budget: std::sync::atomic::AtomicI32,
    /// Whether losing candidates are returned in [`ConflictResult::alternatives`].
    retain_alternatives: bool,
    /// Provider used by [`ConflictStrategy::Llm`]; `None` means always use recency.
    llm_provider: Option<AnyProvider>,
}

impl ConflictResolver {
    /// Creates a resolver with no LLM provider attached.
    ///
    /// `llm_budget_per_turn` saturates to `i32::MAX` if it does not fit in an `i32`.
    #[must_use]
    pub fn new(
        strategy: ConflictStrategy,
        timeout_ms: u64,
        llm_budget_per_turn: usize,
        retain_alternatives: bool,
    ) -> Self {
        let budget = i32::try_from(llm_budget_per_turn).unwrap_or(i32::MAX);
        Self {
            strategy,
            timeout: Duration::from_millis(timeout_ms),
            llm_budget: std::sync::atomic::AtomicI32::new(budget),
            retain_alternatives,
            llm_provider: None,
        }
    }

    /// Attaches the provider used by [`ConflictStrategy::Llm`].
    #[must_use]
    pub fn with_llm_provider(mut self, provider: AnyProvider) -> Self {
        self.llm_provider = Some(provider);
        self
    }

    /// Sets the remaining LLM budget for a new turn, saturating to `i32::MAX`.
    pub fn reset_turn_budget(&self, budget: usize) {
        let budget_i32 = i32::try_from(budget).unwrap_or(i32::MAX);
        self.llm_budget
            .store(budget_i32, std::sync::atomic::Ordering::Relaxed);
    }

    /// Resolves `candidates` to a single winner.
    ///
    /// A lone candidate is returned unchanged, and metrics are not touched. Otherwise
    /// the configured strategy selects the winner and `metrics.conflicts_total` is
    /// incremented. When `retain_alternatives` is enabled, the remaining candidates are
    /// returned in their original relative order.
    ///
    /// # Errors
    ///
    /// Returns [`MemoryError::InvalidInput`] if `candidates` is empty.
    pub async fn resolve(
        &self,
        mut candidates: Vec<Edge>,
        metrics: &ApexMetrics,
    ) -> Result<ConflictResult, MemoryError> {
        tracing::debug!(target: "memory.graph.apex.conflict_resolve", candidates = candidates.len());
        if candidates.is_empty() {
            return Err(MemoryError::InvalidInput(
                "conflict resolver called with empty candidate list".into(),
            ));
        }
        if candidates.len() == 1 {
            return Ok(ConflictResult {
                winner: candidates.remove(0),
                alternatives: Vec::new(),
            });
        }
        let winner_idx = match self.strategy {
            ConflictStrategy::Recency => recency_winner(&candidates),
            ConflictStrategy::Confidence => confidence_winner(&candidates),
            ConflictStrategy::Llm => {
                // Check for a provider before spending budget: without one no LLM call
                // is made, so this turn's budget must not shrink.
                if let Some(provider) = &self.llm_provider
                    && self.try_consume_llm_budget()
                {
                    self.llm_winner(provider, &candidates, metrics).await
                } else {
                    recency_winner(&candidates)
                }
            }
        };
        metrics
            .conflicts_total
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let winner = candidates.remove(winner_idx);
        let alternatives = if self.retain_alternatives {
            candidates
        } else {
            Vec::new()
        };
        Ok(ConflictResult {
            winner,
            alternatives,
        })
    }

    /// Atomically takes one unit of LLM budget; returns `false` if none is left.
    ///
    /// This is a single read-modify-write. A separate load followed by `fetch_sub` would
    /// let concurrent callers all observe a positive budget, overspend it, and drive
    /// the counter negative.
    fn try_consume_llm_budget(&self) -> bool {
        self.llm_budget
            .fetch_update(
                std::sync::atomic::Ordering::Relaxed,
                std::sync::atomic::Ordering::Relaxed,
                |remaining| if remaining > 0 { Some(remaining - 1) } else { None },
            )
            .is_ok()
    }

    /// Asks `provider` for the index of the winning candidate.
    ///
    /// Falls back to [`recency_winner`] when:
    /// - the answer is not an in-range integer,
    /// - the provider returns an error, or
    /// - the call exceeds `self.timeout`, which also increments
    ///   `metrics.llm_timeouts_total`.
    ///
    /// The caller is responsible for having consumed LLM budget.
    async fn llm_winner(
        &self,
        provider: &AnyProvider,
        candidates: &[Edge],
        metrics: &ApexMetrics,
    ) -> usize {
        tracing::debug!(target: "memory.graph.apex.conflict_llm", candidates = candidates.len());
        let prompt = build_conflict_prompt(candidates);
        let messages = [
            Message::from_legacy(
                Role::System,
                "You are a knowledge graph conflict resolver. Given a list of conflicting \
                 edge facts indexed from 0, respond with only the index of the most \
                 authoritative or recent fact. Output a single integer and nothing else.",
            ),
            Message::from_legacy(Role::User, prompt),
        ];
        let timeout = self.timeout;
        match tokio::time::timeout(timeout, provider.chat(&messages)).await {
            Ok(Ok(response)) => {
                let trimmed = response.trim();
                if let Ok(idx) = trimmed.parse::<usize>()
                    && idx < candidates.len()
                {
                    return idx;
                }
                tracing::warn!(
                    raw = %trimmed,
                    "apex_mem: LLM conflict resolver returned unparseable index, falling back to recency"
                );
                recency_winner(candidates)
            }
            Ok(Err(e)) => {
                tracing::warn!(error = %e,
                    "apex_mem: LLM conflict resolver call failed, falling back to recency");
                recency_winner(candidates)
            }
            Err(_) => {
                metrics
                    .llm_timeouts_total
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                tracing::warn!(
                    "apex_mem: LLM conflict resolver timed out after {}ms, falling back to recency",
                    timeout.as_millis()
                );
                recency_winner(candidates)
            }
        }
    }
}
/// Builds the user prompt listing each candidate on one line as
/// `"{index}: [{valid_from}] {fact}"`, followed by the answer instruction.
///
/// Carriage returns and newlines inside `fact` are replaced with spaces. This keeps every
/// candidate on exactly one indexed line; otherwise a multi-line fact could introduce
/// lines that look like extra `N:` entries and mislead the model's index choice.
fn build_conflict_prompt(candidates: &[Edge]) -> String {
    use std::fmt::Write as _;

    let mut lines = String::from("Conflicting facts for the same predicate:\n");
    for (i, edge) in candidates.iter().enumerate() {
        let fact = edge.fact.replace(['\r', '\n'], " ");
        // Writing into a `String` cannot fail.
        let _ = writeln!(lines, "{i}: [{}] {fact}", edge.valid_from);
    }
    lines.push_str(
        "\nWhich index (0-based) is the most authoritative? Respond with only the integer.",
    );
    lines
}
fn recency_winner(candidates: &[Edge]) -> usize {
candidates
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.valid_from.cmp(&b.valid_from))
.map_or(0, |(i, _)| i)
}
fn confidence_winner(candidates: &[Edge]) -> usize {
candidates
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| {
a.confidence
.partial_cmp(&b.confidence)
.unwrap_or(std::cmp::Ordering::Equal)
})
.map_or(0, |(i, _)| i)
}
/// Counters describing APEX-mem conflict and supersede activity.
///
/// This file updates and reads them with `Ordering::Relaxed`. Each counter is
/// independent, so a [`ApexMetrics::snapshot`] is not a consistent cross-field view.
#[derive(Debug, Default)]
pub struct ApexMetrics {
    /// Exported as `apex_mem_supersedes_total`; not incremented in this file.
    pub supersedes_total: std::sync::atomic::AtomicU64,
    /// Resolutions with two or more candidates, incremented by the conflict resolver.
    pub conflicts_total: std::sync::atomic::AtomicU64,
    /// LLM resolution calls that exceeded the configured timeout.
    pub llm_timeouts_total: std::sync::atomic::AtomicU64,
    /// Exported as `apex_mem_unmapped_predicates_total`; not incremented in this file.
    pub unmapped_predicates_total: std::sync::atomic::AtomicU64,
}

impl ApexMetrics {
    /// Returns `(metric_name, current_value)` for every counter, always in the same order.
    #[must_use]
    pub fn snapshot(&self) -> Vec<(&'static str, u64)> {
        use std::sync::atomic::{AtomicU64, Ordering};

        let table: [(&'static str, &AtomicU64); 4] = [
            ("apex_mem_supersedes_total", &self.supersedes_total),
            ("apex_mem_conflicts_total", &self.conflicts_total),
            ("apex_mem_llm_timeouts_total", &self.llm_timeouts_total),
            (
                "apex_mem_unmapped_predicates_total",
                &self.unmapped_predicates_total,
            ),
        ];
        table
            .into_iter()
            .map(|(name, counter)| (name, counter.load(Ordering::Relaxed)))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `works_at` edge between entities 1 and 2. `valid_from` also serves as
    /// `created_at`; every optional field is left empty.
    fn make_edge(id: i64, valid_from: &str, confidence: f32) -> Edge {
        let stamp = valid_from.to_string();
        Edge {
            id,
            source_entity_id: 1,
            target_entity_id: 2,
            relation: "works_at".into(),
            canonical_relation: "works_at".into(),
            fact: "fact".into(),
            confidence,
            valid_from: stamp.clone(),
            valid_to: None,
            created_at: stamp,
            expired_at: None,
            source_message_id: None,
            qdrant_point_id: None,
            edge_type: crate::graph::types::EdgeType::Semantic,
            retrieval_count: 0,
            last_retrieved_at: None,
            superseded_by: None,
            supersedes: None,
        }
    }

    /// Oldest-but-most-confident (id 1) vs newest-but-least-confident (id 2).
    fn pair() -> Vec<Edge> {
        vec![
            make_edge(1, "2026-01-01 00:00:00", 0.9),
            make_edge(2, "2026-06-01 00:00:00", 0.5),
        ]
    }

    /// `pair()` plus a middle-aged, middle-confidence edge (id 3).
    fn trio() -> Vec<Edge> {
        let mut edges = pair();
        edges.push(make_edge(3, "2026-03-01 00:00:00", 0.7));
        edges
    }

    /// Runs `resolver` over `candidates` with fresh metrics, expecting success.
    async fn run(resolver: &ConflictResolver, candidates: Vec<Edge>) -> ConflictResult {
        let metrics = ApexMetrics::default();
        resolver
            .resolve(candidates, &metrics)
            .await
            .expect("resolution should succeed")
    }

    #[tokio::test]
    async fn recency_strategy_picks_newest() {
        let resolver = ConflictResolver::new(ConflictStrategy::Recency, 500, 3, false);
        let outcome = run(&resolver, trio()).await;
        assert_eq!(outcome.winner.id, 2, "newest valid_from wins");
    }

    #[tokio::test]
    async fn confidence_strategy_picks_highest() {
        let resolver = ConflictResolver::new(ConflictStrategy::Confidence, 500, 3, false);
        let outcome = run(&resolver, trio()).await;
        assert_eq!(outcome.winner.id, 1);
    }

    #[tokio::test]
    async fn single_candidate_passes_through() {
        let resolver = ConflictResolver::new(ConflictStrategy::Recency, 500, 3, false);
        let outcome = run(&resolver, vec![make_edge(42, "2026-01-01 00:00:00", 0.8)]).await;
        assert_eq!(outcome.winner.id, 42);
        assert!(outcome.alternatives.is_empty());
    }

    #[tokio::test]
    async fn retain_alternatives_when_enabled() {
        let resolver = ConflictResolver::new(ConflictStrategy::Recency, 500, 3, true);
        let outcome = run(&resolver, pair()).await;
        assert_eq!(outcome.winner.id, 2);
        assert_eq!(outcome.alternatives.len(), 1);
        assert_eq!(outcome.alternatives[0].id, 1);
    }

    #[tokio::test]
    async fn budget_exhaustion_falls_back_to_recency() {
        let resolver = ConflictResolver::new(ConflictStrategy::Llm, 500, 0, false);
        let outcome = run(&resolver, pair()).await;
        assert_eq!(outcome.winner.id, 2);
    }

    #[test]
    fn metrics_snapshot_has_four_entries() {
        assert_eq!(ApexMetrics::default().snapshot().len(), 4);
    }
}