use std::sync::Arc;
use tracing::Instrument as _;
pub use zeph_config::memory::TieredRetrievalConfig;
use zeph_llm::any::AnyProvider;
use crate::embedding_store::SearchFilter;
use crate::error::MemoryError;
use crate::router::{HeuristicRouter, HybridRouter, MemoryRoute, MemoryRouter};
use crate::semantic::RecalledMessage;
use crate::semantic::SemanticMemory;
use crate::types::ConversationId;
/// Coarse retrieval intent for a query; the tier decides how many memories are recalled.
///
/// Variants are listed from narrowest to broadest, which is also the order that
/// `IntentClass::escalate` walks through.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IntentClass {
    /// Narrow lookup, mapped from keyword/episodic routes; smallest recall size.
    ProfileLookup,
    /// Focused retrieval, mapped from semantic/hybrid routes.
    TargetedRetrieval,
    /// Broad retrieval, mapped from the graph route; largest recall size and the final tier.
    DeepReasoning,
}
impl IntentClass {
fn from_route(route: MemoryRoute) -> Self {
match route {
MemoryRoute::Keyword | MemoryRoute::Episodic => Self::ProfileLookup,
MemoryRoute::Semantic | MemoryRoute::Hybrid => Self::TargetedRetrieval,
MemoryRoute::Graph => Self::DeepReasoning,
}
}
fn top_k(self) -> usize {
match self {
Self::ProfileLookup => 3,
Self::TargetedRetrieval => 10,
Self::DeepReasoning => 20,
}
}
fn escalate(self) -> Option<Self> {
match self {
Self::ProfileLookup => Some(Self::TargetedRetrieval),
Self::TargetedRetrieval => Some(Self::DeepReasoning),
Self::DeepReasoning => None,
}
}
}
/// Renders the tier as its variant name (used in tracing fields and logs).
impl std::fmt::Display for IntentClass {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::ProfileLookup => "ProfileLookup",
            Self::TargetedRetrieval => "TargetedRetrieval",
            Self::DeepReasoning => "DeepReasoning",
        };
        f.write_str(label)
    }
}
/// Outcome of a tiered recall (`recall_tiered`).
#[derive(Debug)]
pub struct TieredRetrievalResult {
    /// Recalled messages that fit inside the effective token budget, kept in retrieval order.
    pub messages: Vec<RecalledMessage>,
    /// Tier that produced `messages` (the final tier after any escalation).
    pub intent: IntentClass,
    /// Sum of the estimated token counts of `messages`.
    pub tokens_used: usize,
    /// `true` when at least one escalation to a broader tier happened.
    pub tier_escalated: bool,
}
/// Retrieves memory for `query` using intent-tiered recall with optional evidence validation.
///
/// The query is first classified into an [`IntentClass`]:
/// - with a `classifier`, an LLM-backed [`HybridRouter`] is used, bounded by
///   `config.classifier_timeout_secs`; on timeout the heuristic router is used instead;
/// - without one, [`HeuristicRouter`] is used directly.
///
/// Retrieval then starts at that tier and may escalate to broader tiers when a `validator` is
/// supplied and validation is enabled in `config`.
///
/// The token budget is `config.token_budget`, further capped by `remaining_budget` when given.
/// The classified starting intent is recorded on the `memory.tiered.retrieve` span's `intent`
/// field.
///
/// # Errors
///
/// Returns any [`MemoryError`] produced by the underlying memory recall.
#[tracing::instrument(name = "memory.tiered.retrieve", skip_all, fields(intent = tracing::field::Empty))]
pub async fn recall_tiered(
    memory: &SemanticMemory,
    query: &str,
    conversation_id: Option<ConversationId>,
    classifier: Option<&Arc<AnyProvider>>,
    validator: Option<&Arc<AnyProvider>>,
    config: &TieredRetrievalConfig,
    remaining_budget: Option<usize>,
) -> Result<TieredRetrievalResult, MemoryError> {
    // Never exceed the configured budget, even if the caller reports more room left.
    let effective_budget =
        remaining_budget.map_or(config.token_budget, |rb| rb.min(config.token_budget));

    let decision = if let Some(classifier_provider) = classifier {
        // NOTE(review): 0.7 is HybridRouter's third constructor argument (presumably a
        // confidence threshold for trusting the LLM route) — confirm against HybridRouter::new.
        let hybrid = HybridRouter::new(
            Arc::clone(classifier_provider),
            MemoryRoute::Hybrid,
            0.7,
        );
        match tokio::time::timeout(
            std::time::Duration::from_secs(config.classifier_timeout_secs),
            hybrid.classify_async(query),
        )
        .await
        {
            Ok(decision) => decision,
            Err(_elapsed) => {
                tracing::warn!("tiered: classifier LLM timed out, falling back to heuristic");
                HeuristicRouter.route_with_confidence(query)
            }
        }
    } else {
        HeuristicRouter.route_with_confidence(query)
    };
    let initial_intent = IntentClass::from_route(decision.route);

    // The instrument attribute declares `intent` as `Empty`; fill it in now that it is known,
    // otherwise the span never carries the classified intent.
    tracing::Span::current().record("intent", tracing::field::display(initial_intent));
    tracing::debug!(intent = %initial_intent, query_len = query.len(), "tiered: classified intent");

    escalation_loop(
        memory,
        query,
        conversation_id,
        initial_intent,
        validator,
        config,
        effective_budget,
    )
    .await
}
/// Retrieves at `initial_intent`, escalating to broader tiers while the validator judges the
/// evidence insufficient.
///
/// An escalation is attempted only when all of the following hold:
/// - `config.validation_enabled` is set;
/// - a `validator` is supplied;
/// - fewer than `config.max_escalations` escalations have happened;
/// - the current tier has a broader successor.
///
/// Each escalation re-runs retrieval at the new tier and re-assembles within `effective_budget`.
///
/// # Errors
///
/// Propagates any [`MemoryError`] from `retrieve_tier`.
async fn escalation_loop(
    memory: &SemanticMemory,
    query: &str,
    conversation_id: Option<ConversationId>,
    initial_intent: IntentClass,
    validator: Option<&Arc<AnyProvider>>,
    config: &TieredRetrievalConfig,
    effective_budget: usize,
) -> Result<TieredRetrievalResult, MemoryError> {
    let mut tier = initial_intent;
    let mut escalation_count: u8 = 0;

    loop {
        let retrieved = retrieve_tier(memory, query, conversation_id, tier)
            .instrument(tracing::debug_span!("memory.tiered.retrieve_tier", tier = %tier))
            .await?;

        let (messages, tokens_used) = {
            let _assemble_guard = tracing::debug_span!("memory.tiered.assemble").entered();
            assemble_within_budget(retrieved, effective_budget)
        };

        // Pair the validator with the next tier only when escalation is both enabled and possible.
        let escalation = validator
            .filter(|_| config.validation_enabled && escalation_count < config.max_escalations)
            .zip(tier.escalate());

        if let Some((validator_provider, next_tier)) = escalation {
            let sufficient = validate_evidence(
                validator_provider,
                query,
                &messages,
                config.validation_threshold,
                config.validator_timeout_secs,
            )
            .instrument(tracing::debug_span!("memory.tiered.validate"))
            .await;

            if !sufficient {
                tracing::debug!(
                    current_tier = %tier,
                    next_tier = %next_tier,
                    escalations = escalation_count,
                    "tiered: evidence insufficient, escalating tier"
                );
                tier = next_tier;
                escalation_count += 1;
                continue;
            }
        }

        break Ok(TieredRetrievalResult {
            messages,
            intent: tier,
            tokens_used,
            tier_escalated: escalation_count > 0,
        });
    }
}
/// Recalls up to `intent.top_k()` messages for `query`, scoped to `conversation_id` when given.
///
/// Routing inside the tier always goes through [`HeuristicRouter`]; the tier only determines
/// how many candidates are requested.
///
/// # Errors
///
/// Propagates any [`MemoryError`] returned by `SemanticMemory::recall_routed`.
async fn retrieve_tier(
    memory: &SemanticMemory,
    query: &str,
    conversation_id: Option<ConversationId>,
    intent: IntentClass,
) -> Result<Vec<RecalledMessage>, MemoryError> {
    let scope_to_conversation = |id: ConversationId| SearchFilter {
        conversation_id: Some(id),
        role: None,
        category: None,
    };
    let scope = conversation_id.map(scope_to_conversation);
    let router = HeuristicRouter;
    memory
        .recall_routed(query, intent.top_k(), scope, &router)
        .await
}
/// Keeps the longest prefix of `candidates` whose estimated token total fits within `budget`.
///
/// Candidates are consumed in order. The first candidate that would push the running total
/// past `budget` ends assembly, so later (possibly smaller) candidates are not considered.
/// Returns the kept messages together with their combined estimated token count.
fn assemble_within_budget(
    candidates: Vec<RecalledMessage>,
    budget: usize,
) -> (Vec<RecalledMessage>, usize) {
    let mut spent: usize = 0;
    let kept: Vec<RecalledMessage> = candidates
        .into_iter()
        .take_while(|candidate| {
            let cost = zeph_common::text::estimate_tokens(&candidate.message.content);
            let fits = spent.saturating_add(cost) <= budget;
            if fits {
                spent += cost;
            }
            fits
        })
        .collect();
    (kept, spent)
}
/// Asks the `provider` LLM whether `messages` are sufficient evidence to answer `query`.
///
/// Returns `false` immediately when `messages` is empty. Otherwise it sends a fixed system
/// prompt and a user prompt containing:
/// - the query, control-character-stripped and cut to 500 chars;
/// - up to five evidence snippets, each stripped the same way and cut to 200 chars.
///
/// The reply is interpreted by `parse_validation_response` against `threshold`.
///
/// The check fails open: an LLM error or a call exceeding `timeout_secs` is logged and treated
/// as sufficient (`true`).
async fn validate_evidence(
    provider: &Arc<AnyProvider>,
    query: &str,
    messages: &[RecalledMessage],
    threshold: f32,
    timeout_secs: u64,
) -> bool {
    use zeph_llm::provider::{LlmProvider as _, Message, MessageMetadata, Role};

    const SYSTEM_PROMPT: &str = "You are an evidence quality judge. \
        Given a query and evidence snippets, decide if the evidence is sufficient to answer the query. \
        Respond ONLY with a JSON object: {\"sufficient\": true|false, \"confidence\": 0.0-1.0}";

    if messages.is_empty() {
        return false;
    }

    // Strip control characters first, then cut to at most `limit` chars.
    let clip = |text: &str, limit: usize| -> String {
        zeph_common::sanitize::strip_control_chars_preserve_whitespace(text)
            .chars()
            .take(limit)
            .collect()
    };

    let snippets: Vec<String> = messages
        .iter()
        .take(5)
        .map(|recalled| clip(recalled.message.content.as_str(), 200))
        .collect();
    let evidence = snippets.join("\n---\n");
    let query_excerpt = clip(query, 500);
    let user_prompt = format!("<query>{query_excerpt}</query>\n<evidence>{evidence}</evidence>");

    let message = |role: Role, content: String| Message {
        role,
        content,
        parts: vec![],
        metadata: MessageMetadata::default(),
    };
    let conversation = vec![
        message(Role::System, SYSTEM_PROMPT.to_owned()),
        message(Role::User, user_prompt),
    ];

    let reply = tokio::time::timeout(
        std::time::Duration::from_secs(timeout_secs),
        provider.chat(&conversation),
    )
    .await;

    let Ok(outcome) = reply else {
        tracing::warn!("tiered: validator LLM call timed out, treating as sufficient");
        return true;
    };

    match outcome {
        Ok(raw) => parse_validation_response(&raw, threshold),
        Err(e) => {
            tracing::warn!(error = %e, "tiered: validator LLM call failed, treating as sufficient");
            true
        }
    }
}
/// Interprets the validator's reply and decides whether the evidence is sufficient.
///
/// The reply is expected to contain a JSON object such as
/// `{"sufficient": true, "confidence": 0.9}`, possibly wrapped in prose or code fences.
/// Each `{` in `raw` is tried in turn as the start of a JSON value, and the first one that
/// parses is used. Anything after the end of that value is ignored, so trailing commentary
/// (even commentary containing braces) cannot invalidate an otherwise well-formed verdict.
///
/// Returns `true` only when `sufficient` is true and `confidence` (clamped to `[0, 1]`) is at
/// least `threshold`. Missing fields default as follows:
/// - a missing or non-boolean `sufficient` is treated as `true`;
/// - a missing or non-numeric `confidence` is treated as `1.0`.
///
/// If no JSON value can be parsed at all, the function fails open and returns `true`.
fn parse_validation_response(raw: &str, threshold: f32) -> bool {
    let parsed = raw.match_indices('{').find_map(|(start, _)| {
        // Stream-deserialize so parsing stops at the end of the first complete value instead
        // of rejecting the whole input because of trailing text after it.
        serde_json::Deserializer::from_str(&raw[start..])
            .into_iter::<serde_json::Value>()
            .next()
            .and_then(Result::ok)
    });

    let Some(v) = parsed else {
        tracing::debug!("tiered: could not parse validator response, treating as sufficient");
        return true;
    };

    let sufficient = v
        .get("sufficient")
        .and_then(serde_json::Value::as_bool)
        .unwrap_or(true);
    // The value is clamped to [0.0, 1.0] before narrowing, so the f64 -> f32 cast only loses
    // precision, never range.
    #[allow(clippy::cast_possible_truncation)]
    let confidence = v
        .get("confidence")
        .and_then(serde_json::Value::as_f64)
        .map_or(1.0, |c| c.clamp(0.0, 1.0) as f32);
    sufficient && confidence >= threshold
}
/// Unit tests for intent mapping, budget assembly and validator parsing, plus async tests that
/// drive `recall_tiered` against a mock semantic memory and mock LLM providers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::router::MemoryRoute;
    use crate::semantic::RecalledMessage;
    use zeph_llm::provider::{Message, MessageMetadata, Role};

    /// Builds a user-role `RecalledMessage` with the given content and a score of 1.0.
    fn make_message(content: &str) -> RecalledMessage {
        RecalledMessage {
            message: Message {
                role: Role::User,
                content: content.to_owned(),
                parts: vec![],
                metadata: MessageMetadata::default(),
            },
            score: 1.0,
        }
    }

    /// Every router route maps to the expected retrieval tier.
    #[test]
    fn intent_class_from_route_mapping() {
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Keyword),
            IntentClass::ProfileLookup
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Episodic),
            IntentClass::ProfileLookup
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Semantic),
            IntentClass::TargetedRetrieval
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Hybrid),
            IntentClass::TargetedRetrieval
        );
        assert_eq!(
            IntentClass::from_route(MemoryRoute::Graph),
            IntentClass::DeepReasoning
        );
    }

    /// Recall sizes per tier are pinned.
    #[test]
    fn intent_class_top_k() {
        assert_eq!(IntentClass::ProfileLookup.top_k(), 3);
        assert_eq!(IntentClass::TargetedRetrieval.top_k(), 10);
        assert_eq!(IntentClass::DeepReasoning.top_k(), 20);
    }

    /// Escalation walks narrow -> broad and stops at `DeepReasoning`.
    #[test]
    fn intent_class_escalate_chain() {
        assert_eq!(
            IntentClass::ProfileLookup.escalate(),
            Some(IntentClass::TargetedRetrieval)
        );
        assert_eq!(
            IntentClass::TargetedRetrieval.escalate(),
            Some(IntentClass::DeepReasoning)
        );
        assert_eq!(IntentClass::DeepReasoning.escalate(), None);
    }

    /// No candidates in, nothing out.
    #[test]
    fn assemble_within_budget_empty_input() {
        let (retained, tokens) = assemble_within_budget(vec![], 4096);
        assert!(retained.is_empty());
        assert_eq!(tokens, 0);
    }

    /// A zero budget admits nothing (assumes non-empty content estimates to > 0 tokens).
    #[test]
    fn assemble_within_budget_zero_budget_returns_nothing() {
        let candidates = vec![make_message("hello"), make_message("world")];
        let (retained, tokens) = assemble_within_budget(candidates, 0);
        assert!(retained.is_empty(), "budget=0 must retain no messages");
        assert_eq!(tokens, 0);
    }

    /// Assembly stops at the first message that would overflow the budget.
    // NOTE(review): the expected 200 assumes estimate_tokens returns 200 for this 800-char
    // string — confirm against zeph_common::text::estimate_tokens.
    #[test]
    fn assemble_within_budget_truncates_at_limit() {
        let msg = "a ".repeat(400);
        let candidates = vec![make_message(&msg), make_message(&msg)];
        let (retained, tokens) = assemble_within_budget(candidates, 250);
        assert_eq!(
            retained.len(),
            1,
            "tight budget must keep only first message"
        );
        assert_eq!(tokens, 200);
    }

    /// An empty JSON object falls back to the fail-open defaults.
    #[test]
    fn parse_validation_response_missing_fields_defaults_to_sufficient() {
        let raw = "{}";
        assert!(
            parse_validation_response(raw, 0.6),
            "missing fields must default to sufficient"
        );
    }

    /// Pins the config defaults this module relies on.
    #[test]
    fn tiered_retrieval_config_defaults() {
        let cfg = TieredRetrievalConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.token_budget, 4096);
        assert!(!cfg.validation_enabled);
        assert_eq!(cfg.max_escalations, 1);
        assert_eq!(cfg.classifier_timeout_secs, 5);
        assert_eq!(cfg.validator_timeout_secs, 5);
    }

    /// Overridden timeout fields survive struct-update syntax and convert to `Duration` seconds.
    #[test]
    fn tiered_retrieval_config_timeout_fields_propagate() {
        let cfg = TieredRetrievalConfig {
            classifier_timeout_secs: 10,
            validator_timeout_secs: 15,
            ..TieredRetrievalConfig::default()
        };
        assert_eq!(cfg.classifier_timeout_secs, 10);
        assert_eq!(cfg.validator_timeout_secs, 15);
        let classifier_dur = std::time::Duration::from_secs(cfg.classifier_timeout_secs);
        let validator_dur = std::time::Duration::from_secs(cfg.validator_timeout_secs);
        assert_eq!(classifier_dur.as_secs(), 10);
        assert_eq!(validator_dur.as_secs(), 15);
    }

    /// Sufficient verdict above threshold is accepted.
    #[test]
    fn parse_validation_response_sufficient() {
        let raw = r#"{"sufficient": true, "confidence": 0.9}"#;
        assert!(parse_validation_response(raw, 0.6));
    }

    /// Explicit insufficient verdict is rejected.
    #[test]
    fn parse_validation_response_insufficient() {
        let raw = r#"{"sufficient": false, "confidence": 0.4}"#;
        assert!(!parse_validation_response(raw, 0.6));
    }

    /// A sufficient verdict below the confidence threshold is rejected.
    #[test]
    fn parse_validation_response_low_confidence() {
        let raw = r#"{"sufficient": true, "confidence": 0.3}"#;
        assert!(!parse_validation_response(raw, 0.6));
    }

    /// Unparseable replies fail open.
    #[test]
    fn parse_validation_response_malformed_json_treats_as_sufficient() {
        let raw = "not json at all";
        assert!(parse_validation_response(raw, 0.6));
    }

    /// `Display` renders the variant name.
    #[test]
    fn intent_class_display() {
        assert_eq!(IntentClass::ProfileLookup.to_string(), "ProfileLookup");
        assert_eq!(
            IntentClass::TargetedRetrieval.to_string(),
            "TargetedRetrieval"
        );
        assert_eq!(IntentClass::DeepReasoning.to_string(), "DeepReasoning");
    }

    /// Without a classifier or validator, recall succeeds, never escalates and respects the budget.
    #[tokio::test]
    async fn recall_tiered_no_classifier_uses_heuristic_router() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };
        let result = recall_tiered(&memory, "what is my name", None, None, None, &config, None)
            .await
            .expect("recall_tiered must not fail");
        assert!(
            !result.tier_escalated,
            "no escalation when validation is off"
        );
        assert!(result.tokens_used <= config.token_budget);
    }

    /// With a mock classifier returning a route JSON, recall succeeds and respects the budget.
    #[tokio::test]
    async fn recall_tiered_with_classifier_uses_hybrid_router() {
        use zeph_llm::mock::MockProvider;
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let route_json = r#"{"route": "Semantic", "confidence": 0.9}"#.to_owned();
        let mut mock = MockProvider::with_responses(vec![route_json]);
        mock.supports_embeddings = true;
        mock.embedding = vec![0.1_f32; 384];
        let classifier = Arc::new(AnyProvider::Mock(mock));
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };
        let result = recall_tiered(
            &memory,
            "semantic query about the user",
            None,
            Some(&classifier),
            None,
            &config,
            None,
        )
        .await
        .expect("recall_tiered with classifier must not fail");
        assert!(!result.tier_escalated);
        assert!(result.tokens_used <= config.token_budget);
    }

    /// An "insufficient" verdict from the validator triggers at least one escalation.
    // NOTE(review): nothing is stored in memory here. If recall returns no messages,
    // validate_evidence returns false without calling the LLM, so escalation happens regardless
    // of the mock responses — confirm this test exercises the validator path as intended.
    #[tokio::test]
    async fn recall_tiered_escalates_when_evidence_insufficient() {
        use zeph_llm::mock::MockProvider;
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let insufficient = r#"{"sufficient": false, "confidence": 0.1}"#.to_owned();
        let sufficient = r#"{"sufficient": true, "confidence": 0.95}"#.to_owned();
        let mut validator_mock = MockProvider::with_responses(vec![insufficient, sufficient]);
        validator_mock.supports_embeddings = true;
        let validator = Arc::new(AnyProvider::Mock(validator_mock));
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 2,
            ..TieredRetrievalConfig::default()
        };
        let result = recall_tiered(
            &memory,
            "deep query",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("escalation path must not fail");
        assert!(
            result.tier_escalated,
            "must set tier_escalated when validator triggers escalation"
        );
    }

    /// A validator slower than `validator_timeout_secs` is treated as sufficient.
    // NOTE(review): assumes `with_delay` is in milliseconds; the test then waits out the real
    // 5-second timeout, which makes it slow. Consider a paused tokio clock if the mock sleeps
    // via tokio::time.
    #[tokio::test]
    async fn validate_evidence_timeout_is_fail_open() {
        use zeph_llm::mock::MockProvider;
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let conv_id = memory
            .sqlite()
            .create_conversation()
            .await
            .expect("create_conversation");
        // Seed one message so recall has evidence and the validator is actually consulted.
        memory
            .remember(conv_id, "user", "some evidence content", None)
            .await
            .expect("remember");
        let slow_mock = MockProvider::default().with_delay(6_000);
        let validator = Arc::new(AnyProvider::Mock(slow_mock));
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 1,
            validator_timeout_secs: 5,
            ..TieredRetrievalConfig::default()
        };
        let result = recall_tiered(
            &memory,
            "evidence",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("timeout path must not propagate as error");
        assert!(
            !result.tier_escalated,
            "validator timeout must be treated as sufficient (fail-open)"
        );
    }

    /// A validator LLM error is treated as sufficient rather than surfacing as a recall error.
    #[tokio::test]
    async fn validate_evidence_llm_error_is_fail_open() {
        use zeph_llm::mock::MockProvider;
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let conv_id = memory
            .sqlite()
            .create_conversation()
            .await
            .expect("create_conversation");
        // Seed one message so recall has evidence and the validator is actually consulted.
        memory
            .remember(conv_id, "user", "some evidence content", None)
            .await
            .expect("remember");
        let failing_mock = MockProvider::failing();
        let validator = Arc::new(AnyProvider::Mock(failing_mock));
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: true,
            validation_threshold: 0.6,
            max_escalations: 1,
            ..TieredRetrievalConfig::default()
        };
        let result = recall_tiered(
            &memory,
            "evidence",
            None,
            None,
            Some(&validator),
            &config,
            None,
        )
        .await
        .expect("LLM error path must not propagate as retrieval error");
        assert!(
            !result.tier_escalated,
            "validator LLM error must be treated as sufficient (fail-open)"
        );
    }

    /// Scoping to a conversation with no stored messages yields an empty result.
    #[tokio::test]
    async fn recall_tiered_with_conversation_id_filter() {
        let memory = crate::testing::mock_semantic_memory()
            .await
            .expect("mock_semantic_memory");
        let conv_id = ConversationId(42);
        let config = TieredRetrievalConfig {
            enabled: true,
            validation_enabled: false,
            ..TieredRetrievalConfig::default()
        };
        let result = recall_tiered(
            &memory,
            "what did we discuss",
            Some(conv_id),
            None,
            None,
            &config,
            None,
        )
        .await
        .expect("conversation-scoped recall must not fail");
        assert!(result.messages.is_empty());
        assert_eq!(result.tokens_used, 0);
        assert!(!result.tier_escalated);
    }
}