use std::collections::HashSet;
use serde_json::Value;
use smos_domain::config::{HeatConfig, RetrievalConfig};
use smos_domain::{FactId, FactStatus, Heat, MemoryKey, SessionId, Timestamp};
use crate::errors::{ProviderError, UseCaseError};
use crate::helpers::memory_block::{self, MemoryBlockEntry};
use crate::helpers::request_enricher;
use crate::helpers::retrieval_planner::{self, RetrievalHit};
use crate::helpers::topic_extractor;
use crate::ports::{Clock, EmbeddingProvider, FactRepository, RerankProvider, SessionRepository};
use crate::types::{SearchHit, enrichment_messages_from_json};
/// Use case that enriches an outgoing chat request with facts retrieved from
/// memory.
///
/// The struct only borrows its collaborators; the trait bounds on the generic
/// parameters are declared on the `impl` block, not here.
pub struct EnrichRequest<'a, FR, SR, EP, RP, C> {
/// Fact store, queried with `search_similar` and updated with `update_heat_batch`.
pub facts: &'a FR,
/// Session store, used through `dedup_and_mark` to find facts not yet seen by a session.
pub sessions: &'a SR,
/// Embedding provider used to embed the extracted topic.
pub embedder: &'a EP,
/// Rerank provider used to order the prefiltered candidates against the topic.
pub reranker: &'a RP,
/// Source of the current time (`now()`), passed to prefiltering and the heat update.
pub clock: &'a C,
/// Retrieval settings; this module reads `min_topic_chars`, `top_k_initial` and `top_k_final`.
pub retrieval_cfg: &'a RetrievalConfig,
/// Heat settings, forwarded to the retrieval planner during prefiltering.
pub heat_cfg: &'a HeatConfig,
}
impl<'a, FR, SR, EP, RP, C> EnrichRequest<'a, FR, SR, EP, RP, C>
where
    FR: FactRepository,
    SR: SessionRepository,
    EP: EmbeddingProvider,
    RP: RerankProvider,
    C: Clock,
{
    /// Enriches `messages` with a memory block built from stored facts that
    /// are relevant to the conversation topic.
    ///
    /// Steps, in order:
    /// 1. extract a topic from the messages and skip if it is shorter than
    ///    `min_topic_chars` (measured in trimmed characters);
    /// 2. embed the topic and run a vector search limited to `top_k_initial`;
    /// 3. convert and prefilter the hits;
    /// 4. boost the heat of every prefilter survivor (best-effort);
    /// 5. rerank the survivors down to `top_k_final`;
    /// 6. drop facts the session store reports as already seen;
    /// 7. build the memory block and inject it into the messages.
    ///
    /// Embedding, search and dedup failures are "fail-open": they are logged
    /// and the original `messages` are returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns a `UseCaseError` when the reranker fails, returns nothing, or
    /// returns no index that maps to a candidate. Those are the only paths
    /// that fail the request.
    pub async fn execute(
        &self,
        messages: Vec<Value>,
        memory_key: &MemoryKey,
        session_id: &SessionId,
    ) -> Result<Vec<Value>, UseCaseError> {
        let typed_projection = enrichment_messages_from_json(&messages);
        let topic = topic_extractor::extract_from_messages(&typed_projection);

        // Count characters, not bytes, so multi-byte text is measured fairly.
        let trimmed_len = topic.trim().chars().count();
        if trimmed_len < self.retrieval_cfg.min_topic_chars {
            tracing::debug!(
                chars = trimmed_len,
                "enrichment skipped: topic below min_topic_chars"
            );
            return Ok(messages);
        }

        let embedding = match self.embedder.embed(&topic).await {
            Ok(Some(v)) => v,
            Ok(None) => {
                tracing::warn!("embedder returned None; skipping enrichment (fail-open)");
                return Ok(messages);
            }
            Err(e) => {
                tracing::warn!(error = %e, "embedder error; skipping enrichment (fail-open)");
                return Ok(messages);
            }
        };

        let hits = match self
            .facts
            .search_similar(embedding, memory_key, self.retrieval_cfg.top_k_initial)
            .await
        {
            Ok(h) => h,
            Err(e) => {
                tracing::warn!(error = %e, "vector search failed; skipping enrichment (fail-open)");
                return Ok(messages);
            }
        };
        if hits.is_empty() {
            tracing::info!(memory_key = %memory_key, "no vector hits; skipping enrichment");
            return Ok(messages);
        }

        // One timestamp is shared by prefiltering and the heat update so both
        // see the same notion of "now".
        let now = self.clock.now();
        let survivors = prefilter(hits, self.retrieval_cfg, self.heat_cfg, now);
        if survivors.is_empty() {
            return Ok(messages);
        }

        // The boost covers every prefilter survivor, not only the facts that
        // end up injected, and runs before reranking.
        self.boost_heat(&survivors, memory_key, now).await;

        // Reranker failures propagate: this is the one fail-closed step.
        let ranked_facts = self.rerank_survivors(&topic, &survivors).await?;
        if ranked_facts.is_empty() {
            return Err(UseCaseError::Provider(ProviderError::InvalidResponse(
                "reranker returned no usable results".to_string(),
            )));
        }

        let new_facts = self
            .dedup_against_session(&ranked_facts, session_id, memory_key)
            .await;
        if new_facts.is_empty() {
            return Ok(messages);
        }

        let block = build_memory_block(&new_facts, session_id, memory_key);
        let messages_value = Value::Array(messages);
        let enriched = request_enricher::inject_value(&messages_value, &block);
        match enriched {
            Value::Array(arr) => Ok(arr),
            // NOTE(review): the enricher is handed an array; a non-array
            // result is wrapped as a single message rather than discarded.
            // Confirm whether this branch is reachable.
            other => Ok(vec![other]),
        }
    }

    /// Sets the heat of every survivor to `Heat::MAX` at time `now`.
    ///
    /// Best-effort: a repository error is logged and otherwise ignored so a
    /// failed bookkeeping write never blocks enrichment.
    async fn boost_heat(&self, survivors: &[RetrievalHit], memory_key: &MemoryKey, now: Timestamp) {
        let ids: Vec<FactId> = survivors.iter().map(|h| h.id.clone()).collect();
        if let Err(e) = self
            .facts
            .update_heat_batch(&ids, memory_key, Heat::MAX, now)
            .await
        {
            tracing::warn!(error = %e, "heat boost failed (best-effort); continuing");
        }
    }

    /// Reranks `survivors` against `topic` and returns them in reranker order.
    ///
    /// Each reranker result carries an index into `survivors`. Indices that
    /// fall outside the slice and indices that repeat an earlier result are
    /// logged and skipped, so a misbehaving reranker cannot make the same fact
    /// appear twice in the output.
    ///
    /// # Errors
    ///
    /// Propagates the reranker's `ProviderError`, and returns
    /// `ProviderError::InvalidResponse` when the reranker returns no results.
    async fn rerank_survivors(
        &self,
        topic: &str,
        survivors: &[RetrievalHit],
    ) -> Result<Vec<RetrievalHit>, ProviderError> {
        let documents: Vec<String> = survivors.iter().map(|s| s.document.clone()).collect();
        let ranked = self
            .reranker
            .rerank(topic, &documents, self.retrieval_cfg.top_k_final)
            .await
            .map_err(|e| {
                tracing::error!(error = %e, "reranker unavailable; request will fail with 503");
                e
            })?;
        if ranked.is_empty() {
            tracing::error!("reranker returned empty results; request will fail with 503");
            return Err(ProviderError::InvalidResponse(
                "reranker returned empty results".to_string(),
            ));
        }

        let mut seen: HashSet<usize> = HashSet::new();
        let mut ordered: Vec<RetrievalHit> = Vec::new();
        for entry in ranked {
            let index = entry.index;
            match survivors.get(index) {
                None => {
                    tracing::warn!(
                        index,
                        candidates = survivors.len(),
                        "reranker returned an out-of-range index; ignoring"
                    );
                }
                // `insert` returns false when the index was already recorded.
                Some(_) if !seen.insert(index) => {
                    tracing::warn!(index, "reranker returned a duplicate index; ignoring repeat");
                }
                Some(hit) => ordered.push(hit.clone()),
            }
        }
        Ok(ordered)
    }

    /// Returns the subset of `ranked_facts` that the session store reports as
    /// new for this session, preserving the ranked order.
    ///
    /// `dedup_and_mark` is given every candidate id and answers with the ids
    /// that are new. On a store error the failure is logged and an empty list
    /// is returned, so the caller injects nothing.
    async fn dedup_against_session(
        &self,
        ranked_facts: &[RetrievalHit],
        session_id: &SessionId,
        memory_key: &MemoryKey,
    ) -> Vec<RetrievalHit> {
        let candidate_ids: Vec<FactId> = ranked_facts.iter().map(|f| f.id.clone()).collect();
        let new_ids: HashSet<FactId> = match self
            .sessions
            .dedup_and_mark(session_id, memory_key, &candidate_ids)
            .await
        {
            Ok(ids) => ids.into_iter().collect(),
            Err(e) => {
                tracing::warn!(error = %e, "dedup_and_mark failed; skipping injection (fail-open)");
                return Vec::new();
            }
        };

        // An empty `new_ids` simply filters everything out; no special case needed.
        ranked_facts
            .iter()
            .filter(|f| new_ids.contains(&f.id))
            .cloned()
            .collect()
    }
}
fn prefilter(
hits: Vec<SearchHit>,
retrieval_cfg: &RetrievalConfig,
heat_cfg: &HeatConfig,
now: Timestamp,
) -> Vec<RetrievalHit> {
let retrieval_hits: Vec<RetrievalHit> = hits.into_iter().filter_map(hit_to_retrieval).collect();
retrieval_planner::prefilter_and_heat(&retrieval_hits, retrieval_cfg, heat_cfg, now)
}
fn hit_to_retrieval(hit: SearchHit) -> Option<RetrievalHit> {
let status = match parse_fact_status(&hit.metadata.status) {
Some(s) => s,
None => {
tracing::warn!(status = %hit.metadata.status, "unparseable status; dropping hit");
return None;
}
};
let confidence = match smos_domain::Confidence::new(hit.metadata.confidence) {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, "out-of-range confidence; dropping hit");
return None;
}
};
let heat = match Heat::new(hit.metadata.heat_base) {
Ok(h) => h,
Err(e) => {
tracing::warn!(error = %e, "out-of-range heat_base; dropping hit");
return None;
}
};
let last_access_at = match Timestamp::from_unix_secs(hit.metadata.last_access_at as i64) {
Ok(ts) => ts,
Err(e) => {
tracing::warn!(error = %e, "out-of-range last_access_at; dropping hit");
return None;
}
};
let valid_until = hit
.metadata
.valid_until
.as_deref()
.and_then(parse_iso_timestamp);
Some(RetrievalHit {
id: hit.id,
document: hit.document,
memory_key: hit.memory_key,
status,
confidence,
valid_until,
heat_base: heat,
last_access_at,
})
}
/// Maps a stored status token to its `FactStatus`.
///
/// The token must equal a variant's `as_str()` exactly; anything else yields
/// `None`.
fn parse_fact_status(s: &str) -> Option<FactStatus> {
    for status in [
        FactStatus::Pending,
        FactStatus::Accepted,
        FactStatus::Rejected,
    ] {
        if s == status.as_str() {
            return Some(status);
        }
    }
    None
}
/// Parses an RFC 3339 date-time string into a `Timestamp`.
///
/// Returns `None` when the string does not parse or when
/// `Timestamp::from_unix_secs` rejects the resulting Unix seconds.
fn parse_iso_timestamp(s: &str) -> Option<Timestamp> {
    use time::format_description::well_known::Rfc3339;
    use time::OffsetDateTime;

    match OffsetDateTime::parse(s, &Rfc3339) {
        Ok(parsed) => Timestamp::from_unix_secs(parsed.unix_timestamp()).ok(),
        Err(_) => None,
    }
}
/// Renders `facts` as a memory block string for the given session and memory
/// key.
///
/// Each fact becomes a borrowed `MemoryBlockEntry` (id plus document text);
/// the formatting itself is done by `memory_block::build`.
fn build_memory_block(
    facts: &[RetrievalHit],
    session_id: &SessionId,
    memory_key: &MemoryKey,
) -> String {
    let mut entries: Vec<MemoryBlockEntry<'_>> = Vec::new();
    for fact in facts {
        entries.push(MemoryBlockEntry {
            id: &fact.id,
            document: fact.document.as_str(),
        });
    }
    memory_block::build(entries, session_id, memory_key)
}
// Unit tests for the pure helpers in this module: topic extraction, status and
// timestamp parsing, hit conversion, memory-block building and prefiltering.
// The async `execute` pipeline is not exercised here.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{ChatMessageDto, EnrichmentMessages, MessageContent};
use smos_domain::{MemoryKey, SessionId};
/// Builds a plain-text user message with no tool calls.
fn user_msg(content: &str) -> ChatMessageDto {
ChatMessageDto {
role: "user".into(),
content: MessageContent::Text(content.into()),
tool_calls: None,
}
}
// A single text message is returned verbatim as the topic.
#[test]
fn extract_topic_from_string_content() {
let msgs: EnrichmentMessages = vec![user_msg("hello world")];
assert_eq!(topic_extractor::extract_from_messages(&msgs), "hello world");
}
// No messages at all yields an empty topic.
#[test]
fn extract_topic_returns_empty_when_no_messages() {
let msgs: EnrichmentMessages = Vec::new();
assert_eq!(topic_extractor::extract_from_messages(&msgs), "");
}
// A message whose text content is empty yields an empty topic.
#[test]
fn extract_topic_returns_empty_when_missing_content() {
let msg = ChatMessageDto {
role: "user".into(),
content: MessageContent::Text(String::new()),
tool_calls: None,
};
let msgs: EnrichmentMessages = vec![msg];
assert_eq!(topic_extractor::extract_from_messages(&msgs), "");
}
// Multipart content: the expected topic is the two text parts joined by a
// single space; the `image_url` part (empty text) adds nothing.
#[test]
fn extract_topic_flattens_multipart() {
let msg = ChatMessageDto {
role: "user".into(),
content: MessageContent::Multipart(vec![
crate::types::ContentPart {
kind: "text".into(),
text: "alpha".into(),
},
crate::types::ContentPart {
kind: "image_url".into(),
text: String::new(),
},
crate::types::ContentPart {
kind: "text".into(),
text: "beta".into(),
},
]),
tool_calls: None,
};
let msgs: EnrichmentMessages = vec![msg];
assert_eq!(topic_extractor::extract_from_messages(&msgs), "alpha beta");
}
// The three lowercase tokens map to their variants.
#[test]
fn parse_fact_status_recognises_canonical_tokens() {
assert_eq!(parse_fact_status("pending"), Some(FactStatus::Pending));
assert_eq!(parse_fact_status("accepted"), Some(FactStatus::Accepted));
assert_eq!(parse_fact_status("rejected"), Some(FactStatus::Rejected));
}
// Unknown and empty tokens are rejected.
#[test]
fn parse_fact_status_rejects_unknown_tokens() {
assert_eq!(parse_fact_status("invalid"), None);
assert_eq!(parse_fact_status(""), None);
}
// Matching is exact: differently-cased tokens are rejected.
#[test]
fn parse_fact_status_is_case_sensitive() {
assert_eq!(parse_fact_status("Accepted"), None);
assert_eq!(parse_fact_status("ACCEPTED"), None);
}
// RFC 3339 with a `Z` suffix parses to the expected Unix seconds.
#[test]
fn parse_iso_timestamp_accepts_rfc3339_utc() {
let ts = parse_iso_timestamp("2025-06-18T12:00:00Z").expect("valid rfc3339");
assert_eq!(ts.as_unix_secs(), 1_750_248_000);
}
// The explicit `+00:00` offset form parses to the same instant.
#[test]
fn parse_iso_timestamp_accepts_offset_form() {
let ts = parse_iso_timestamp("2025-06-18T12:00:00+00:00").expect("valid offset");
assert_eq!(ts.as_unix_secs(), 1_750_248_000);
}
// Free text, the empty string and a date without a time are all rejected.
#[test]
fn parse_iso_timestamp_rejects_malformed_strings() {
assert_eq!(parse_iso_timestamp("not a date"), None);
assert_eq!(parse_iso_timestamp(""), None);
assert_eq!(parse_iso_timestamp("2025-06-18"), None);
}
/// Builds a `SearchHit` with fixed id, document and memory key, and the given
/// metadata values, so each test can vary one field at a time.
fn sample_hit(
status: &str,
confidence: f32,
heat_base: f32,
last_access_at: f32,
valid_until: Option<&str>,
) -> SearchHit {
SearchHit {
id: FactId::from_raw("fact_0123456789abcdef").expect("fact id"),
document: "doc".into(),
memory_key: MemoryKey::from_raw("origa").expect("memory key"),
metadata: crate::types::SearchHitMetadata {
status: status.into(),
confidence,
valid_until: valid_until.map(str::to_string),
heat_base,
last_access_at,
distance: Some(0.1),
},
}
}
// Happy path: every metadata field is carried over into the RetrievalHit.
#[test]
fn hit_to_retrieval_maps_well_formed_hit() {
let hit = sample_hit("accepted", 0.85, 1.0, 1_700_000_000.0, None);
let r = hit_to_retrieval(hit).expect("mapped");
assert_eq!(r.status, FactStatus::Accepted);
assert!((r.confidence.value() - 0.85).abs() < 1e-6);
assert!((r.heat_base.value() - 1.0).abs() < 1e-6);
assert_eq!(r.last_access_at.as_unix_secs(), 1_700_000_000);
assert!(r.valid_until.is_none());
}
// A parseable `valid_until` survives the conversion.
#[test]
fn hit_to_retrieval_carries_valid_until_tombstone() {
let hit = sample_hit(
"accepted",
0.9,
0.5,
1_700_000_000.0,
Some("2025-12-31T00:00:00Z"),
);
let r = hit_to_retrieval(hit).expect("mapped");
assert!(r.valid_until.is_some());
}
// An unknown status token drops the hit.
#[test]
fn hit_to_retrieval_drops_hit_with_unknown_status() {
let hit = sample_hit("weird", 0.9, 1.0, 1_700_000_000.0, None);
assert!(hit_to_retrieval(hit).is_none());
}
// A confidence of 1.5 is expected to be rejected, dropping the hit.
#[test]
fn hit_to_retrieval_drops_hit_with_out_of_range_confidence() {
let hit = sample_hit("accepted", 1.5, 1.0, 1_700_000_000.0, None);
assert!(hit_to_retrieval(hit).is_none());
}
// A heat_base of 2.0 is expected to be rejected, dropping the hit.
#[test]
fn hit_to_retrieval_drops_hit_with_out_of_range_heat() {
let hit = sample_hit("accepted", 0.9, 2.0, 1_700_000_000.0, None);
assert!(hit_to_retrieval(hit).is_none());
}
// An infinite last_access_at must drop the hit.
#[test]
fn hit_to_retrieval_drops_hit_with_out_of_range_last_access_at() {
let hit = sample_hit("accepted", 0.9, 1.0, f32::INFINITY, None);
assert!(hit_to_retrieval(hit).is_none());
}
// A malformed `valid_until` does not drop the hit; it is mapped to `None`.
#[test]
fn hit_to_retrieval_treats_malformed_valid_until_as_none() {
let hit = sample_hit("accepted", 0.9, 1.0, 1_700_000_000.0, Some("not-a-date"));
let r = hit_to_retrieval(hit).expect("mapped despite malformed valid_until");
assert!(r.valid_until.is_none());
}
// The rendered block contains the `<smos-memory` marker and the fact's text.
#[test]
fn build_memory_block_includes_session_and_fact_lines() {
let session = SessionId::from_raw("sess_0123456789ab").expect("session");
let key = MemoryKey::from_raw("origa").expect("key");
let facts = vec![RetrievalHit {
id: FactId::from_raw("fact_0123456789abcdef").expect("fact"),
document: "hello world".into(),
memory_key: key.clone(),
status: FactStatus::Accepted,
confidence: smos_domain::Confidence::new(0.9).unwrap(),
valid_until: None,
heat_base: Heat::MAX,
last_access_at: Timestamp::from_unix_secs(1_700_000_000).unwrap(),
}];
let block = build_memory_block(&facts, &session, &key);
assert!(block.contains("<smos-memory"));
assert!(block.contains("hello world"));
}
// With no hits and default configs, prefiltering returns nothing.
#[test]
fn prefilter_returns_empty_for_empty_input() {
let cfg = RetrievalConfig::default();
let heat = HeatConfig::default();
let now = Timestamp::from_unix_secs(1_700_000_000).unwrap();
assert!(prefilter(Vec::new(), &cfg, &heat, now).is_empty());
}
}