use std::collections::HashSet;
use serde_json::Value;
use smos_domain::config::{HeatConfig, RetrievalConfig};
use smos_domain::{FactId, Heat, MemoryKey, SessionId, Timestamp};
use crate::errors::{ProviderError, UseCaseError};
use crate::helpers::memory_block::{self, MemoryBlockEntry};
use crate::helpers::request_enricher;
use crate::helpers::retrieval_pipeline;
use crate::helpers::retrieval_planner::{self, RetrievalHit};
use crate::helpers::topic_extractor;
use crate::ports::{Clock, EmbeddingProvider, FactRepository, RerankProvider, SessionRepository};
use crate::types::{EnrichmentMessages, SearchHit, enrichment_messages_from_json};
/// Use case that augments an outgoing chat request with recalled memory facts.
///
/// The struct only borrows its collaborators for `'a`; it owns no state of its own.
/// Trait bounds (`FactRepository`, `SessionRepository`, ...) are placed on the `impl`
/// block, not on the struct definition.
pub struct EnrichRequest<'a, FR, SR, EP, RP, C> {
    /// Fact store: used for the similarity search and for the best-effort heat boost.
    pub facts: &'a FR,
    /// Session store: used to drop facts already injected into the current session
    /// (via `dedup_and_mark`).
    pub sessions: &'a SR,
    /// Produces the query embedding for the extracted topic.
    pub embedder: &'a EP,
    /// Reorders the prefiltered hits against the topic.
    pub reranker: &'a RP,
    /// Source of "now" for prefiltering and heat updates.
    pub clock: &'a C,
    /// Retrieval knobs read by the use case (`min_topic_chars`, `top_k_initial`,
    /// `top_k_final`); also forwarded to the prefilter step.
    pub retrieval_cfg: &'a RetrievalConfig,
    /// Heat settings forwarded to the prefilter step.
    pub heat_cfg: &'a HeatConfig,
}
impl<'a, FR, SR, EP, RP, C> EnrichRequest<'a, FR, SR, EP, RP, C>
where
FR: FactRepository,
SR: SessionRepository,
EP: EmbeddingProvider,
RP: RerankProvider,
C: Clock,
{
pub async fn execute(
&self,
messages: Vec<Value>,
memory_key: &MemoryKey,
session_id: &SessionId,
) -> Result<Vec<Value>, UseCaseError> {
let typed_projection = enrichment_messages_from_json(&messages);
let Some((topic, survivors)) = self
.retrieve_survivors(&typed_projection, memory_key)
.await?
else {
return Ok(messages);
};
let new_facts = self
.rerank_and_dedup(&topic, &survivors, session_id, memory_key)
.await?;
if new_facts.is_empty() {
return Ok(messages);
}
let block = build_memory_block(&new_facts, session_id, memory_key);
let messages_value = Value::Array(messages);
let enriched = request_enricher::inject_value(&messages_value, &block);
match enriched {
Value::Array(arr) => Ok(arr),
other => Ok(vec![other]),
}
}
async fn retrieve_survivors(
&self,
typed_projection: &EnrichmentMessages,
memory_key: &MemoryKey,
) -> Result<Option<(String, Vec<RetrievalHit>)>, UseCaseError> {
let topic = topic_extractor::extract_from_messages(typed_projection);
let trimmed_len = topic.trim().chars().count();
if trimmed_len < self.retrieval_cfg.min_topic_chars {
tracing::debug!(
chars = trimmed_len,
"enrichment skipped: topic below min_topic_chars"
);
return Ok(None);
}
let embedding = match self.embedder.embed(&topic).await {
Ok(Some(v)) => v,
Ok(None) => {
tracing::warn!("embedder returned None; skipping enrichment (fail-open)");
return Ok(None);
}
Err(e) => {
tracing::warn!(error = %e, "embedder error; skipping enrichment (fail-open)");
return Ok(None);
}
};
let hits = match self
.facts
.search_similar(embedding, memory_key, self.retrieval_cfg.top_k_initial)
.await
{
Ok(h) => h,
Err(e) => {
tracing::warn!(error = %e, "vector search failed; skipping enrichment (fail-open)");
return Ok(None);
}
};
if hits.is_empty() {
tracing::info!(memory_key = %memory_key, "no vector hits; skipping enrichment");
return Ok(None);
}
let now = self.clock.now();
let survivors = prefilter(hits, self.retrieval_cfg, self.heat_cfg, now);
if survivors.is_empty() {
return Ok(None);
}
self.boost_heat(&survivors, memory_key, now).await;
Ok(Some((topic, survivors)))
}
async fn rerank_and_dedup(
&self,
topic: &str,
survivors: &[RetrievalHit],
session_id: &SessionId,
memory_key: &MemoryKey,
) -> Result<Vec<RetrievalHit>, UseCaseError> {
let ranked_facts = self.rerank_survivors(topic, survivors).await?;
if ranked_facts.is_empty() {
return Err(UseCaseError::Provider(ProviderError::InvalidResponse(
"reranker returned no usable results".to_string(),
)));
}
Ok(self
.dedup_against_session(&ranked_facts, session_id, memory_key)
.await)
}
async fn boost_heat(&self, survivors: &[RetrievalHit], memory_key: &MemoryKey, now: Timestamp) {
let ids: Vec<FactId> = survivors.iter().map(|h| h.id.clone()).collect();
if let Err(e) = self
.facts
.update_heat_batch(&ids, memory_key, Heat::MAX, now)
.await
{
tracing::warn!(error = %e, "heat boost failed (best-effort); continuing");
}
}
async fn rerank_survivors(
&self,
topic: &str,
survivors: &[RetrievalHit],
) -> Result<Vec<RetrievalHit>, ProviderError> {
let ranked = retrieval_pipeline::rerank_hits(
topic,
survivors,
self.reranker,
self.retrieval_cfg.top_k_final,
)
.await?;
Ok(ranked.into_iter().map(|r| r.hit).collect())
}
async fn dedup_against_session(
&self,
ranked_facts: &[RetrievalHit],
session_id: &SessionId,
memory_key: &MemoryKey,
) -> Vec<RetrievalHit> {
let candidate_ids: Vec<FactId> = ranked_facts.iter().map(|f| f.id.clone()).collect();
let new_ids: HashSet<FactId> = match self
.sessions
.dedup_and_mark(session_id, memory_key, &candidate_ids)
.await
{
Ok(ids) => ids.into_iter().collect(),
Err(e) => {
tracing::warn!(error = %e, "dedup_and_mark failed; skipping injection (fail-open)");
return Vec::new();
}
};
if new_ids.is_empty() {
return Vec::new();
}
ranked_facts
.iter()
.filter(|f| new_ids.contains(&f.id))
.cloned()
.collect()
}
}
/// Converts raw vector-search hits into [`RetrievalHit`]s, then runs the planner's
/// prefilter/heat pass over them.
///
/// A hit for which `retrieval_pipeline::hit_to_retrieval` returns `None` is dropped before
/// the planner sees it. `now` is forwarded to the planner unchanged.
fn prefilter(
    hits: Vec<SearchHit>,
    retrieval_cfg: &RetrievalConfig,
    heat_cfg: &HeatConfig,
    now: Timestamp,
) -> Vec<RetrievalHit> {
    let mut converted: Vec<RetrievalHit> = Vec::new();
    for raw in hits {
        if let Some(candidate) = retrieval_pipeline::hit_to_retrieval(raw) {
            converted.push(candidate);
        }
    }
    retrieval_planner::prefilter_and_heat(&converted, retrieval_cfg, heat_cfg, now)
}
/// Renders `facts` into the memory block injected for `session_id` / `memory_key`.
///
/// Each fact contributes a borrowed id and document text. Entries are handed to
/// [`memory_block::build`] in the same order as `facts`.
fn build_memory_block(
    facts: &[RetrievalHit],
    session_id: &SessionId,
    memory_key: &MemoryKey,
) -> String {
    let mut entries: Vec<MemoryBlockEntry<'_>> = Vec::with_capacity(facts.len());
    for fact in facts {
        entries.push(MemoryBlockEntry {
            id: &fact.id,
            document: fact.document.as_str(),
        });
    }
    memory_block::build(entries, session_id, memory_key)
}
/// Unit tests for the synchronous pieces used by this use case: topic extraction,
/// memory-block rendering, and prefiltering.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ChatMessageDto, EnrichmentMessages, MessageContent};
    use smos_domain::{FactStatus, MemoryKey, SessionId};

    /// Builds a plain-text `user` message with no tool calls.
    fn user_msg(content: &str) -> ChatMessageDto {
        ChatMessageDto {
            role: "user".into(),
            content: MessageContent::Text(content.into()),
            tool_calls: None,
        }
    }

    /// A single text message yields its content verbatim as the topic.
    #[test]
    fn extract_topic_from_string_content() {
        let msgs: EnrichmentMessages = vec![user_msg("hello world")];
        assert_eq!(topic_extractor::extract_from_messages(&msgs), "hello world");
    }

    /// No messages means an empty topic.
    #[test]
    fn extract_topic_returns_empty_when_no_messages() {
        let msgs: EnrichmentMessages = Vec::new();
        assert_eq!(topic_extractor::extract_from_messages(&msgs), "");
    }

    /// A message whose text body is empty also yields an empty topic.
    #[test]
    fn extract_topic_returns_empty_when_missing_content() {
        let msg = ChatMessageDto {
            role: "user".into(),
            content: MessageContent::Text(String::new()),
            tool_calls: None,
        };
        let msgs: EnrichmentMessages = vec![msg];
        assert_eq!(topic_extractor::extract_from_messages(&msgs), "");
    }

    /// Multipart content is flattened. The expected topic is "alpha beta": the two `text`
    /// parts separated by a space, with the `image_url` part contributing nothing.
    #[test]
    fn extract_topic_flattens_multipart() {
        let msg = ChatMessageDto {
            role: "user".into(),
            content: MessageContent::Multipart(vec![
                crate::types::ContentPart {
                    kind: "text".into(),
                    text: "alpha".into(),
                },
                crate::types::ContentPart {
                    kind: "image_url".into(),
                    text: String::new(),
                },
                crate::types::ContentPart {
                    kind: "text".into(),
                    text: "beta".into(),
                },
            ]),
            tool_calls: None,
        };
        let msgs: EnrichmentMessages = vec![msg];
        assert_eq!(topic_extractor::extract_from_messages(&msgs), "alpha beta");
    }

    /// The rendered block contains the `<smos-memory` opening tag and the fact's document.
    ///
    /// NOTE(review): despite the test name, the session id is never asserted here; only the
    /// tag prefix and the fact text are checked.
    #[test]
    fn build_memory_block_includes_session_and_fact_lines() {
        let session = SessionId::from_raw("sess_0123456789ab").expect("session");
        let key = MemoryKey::from_raw("origa").expect("key");
        let facts = vec![RetrievalHit {
            id: FactId::from_raw("fact_0123456789abcdef").expect("fact"),
            document: "hello world".into(),
            memory_key: key.clone(),
            status: FactStatus::Accepted,
            confidence: smos_domain::Confidence::new(0.9).unwrap(),
            valid_until: None,
            heat_base: Heat::MAX,
            last_access_at: Timestamp::from_unix_secs(1_700_000_000).unwrap(),
        }];
        let block = build_memory_block(&facts, &session, &key);
        assert!(block.contains("<smos-memory"));
        assert!(block.contains("hello world"));
    }

    /// With no search hits, the prefilter produces no survivors (default configs).
    #[test]
    fn prefilter_returns_empty_for_empty_input() {
        let cfg = RetrievalConfig::default();
        let heat = HeatConfig::default();
        let now = Timestamp::from_unix_secs(1_700_000_000).unwrap();
        assert!(prefilter(Vec::new(), &cfg, &heat, now).is_empty());
    }
}
/// Tests that drive `EnrichRequest::execute` end to end. They use the in-memory doubles
/// from `crate::testkit` plus the local provider stubs defined below.
#[cfg(test)]
mod execute_tests {
    use super::*;
    use crate::ports::{EmbeddingProvider, RerankProvider};
    use crate::testkit::{ConstantEmbedder, FixedClock, InMemoryFacts, InMemorySessions};
    use crate::types::{RerankResult, SearchHit, SearchHitMetadata};
    use serde_json::json;

    /// Embedder stub that always reports "no embedding" (`Ok(None)`).
    struct NoneEmbedder;
    // Exercises the `Ok(None)` fail-open branch of the embedding step.
    impl EmbeddingProvider for NoneEmbedder {
        async fn embed(&self, _text: &str) -> Result<Option<Vec<f32>>, ProviderError> {
            Ok(None)
        }
    }

    /// Reranker stub that always returns an empty result list.
    struct EmptyReranker;
    // An empty rerank result is the fail-closed case in `rerank_and_dedup`.
    impl RerankProvider for EmptyReranker {
        async fn rerank(
            &self,
            _query: &str,
            _documents: &[String],
            _top_k: usize,
        ) -> Result<Vec<RerankResult>, ProviderError> {
            Ok(Vec::new())
        }
    }

    /// Fixed "now" (Unix seconds 1_700_000_000), shared by the clock and the fixture hit.
    fn now_ts() -> Timestamp {
        Timestamp::from_unix_secs(1_700_000_000).unwrap()
    }

    /// Memory key used by every test.
    fn key() -> MemoryKey {
        MemoryKey::from_raw("origa").unwrap()
    }

    /// Session id used by every test.
    fn sid() -> SessionId {
        SessionId::from_raw("sess_0123456789ab").unwrap()
    }

    /// A search hit intended to survive prefiltering. Its fields:
    /// - status "accepted" and confidence 0.85;
    /// - `heat_base` 1.0 and a last access equal to `now_ts()`;
    /// - distance 0.1;
    /// - an id derived from its document text.
    ///
    /// NOTE(review): whether it really survives depends on the `RetrievalConfig` /
    /// `HeatConfig` defaults and on `hit_to_retrieval`, none of which are visible here. The
    /// reranker test below relies on it reaching the rerank step.
    fn survivable_hit() -> SearchHit {
        SearchHit {
            id: FactId::from_content("an accepted fact about rust ownership"),
            document: "an accepted fact about rust ownership".to_string(),
            memory_key: key(),
            metadata: SearchHitMetadata {
                status: "accepted".into(),
                confidence: 0.85,
                valid_until: None,
                heat_base: 1.0,
                last_access_at: 1_700_000_000.0,
                distance: Some(0.1),
                created_at: None,
                conflicts_with: Vec::new(),
            },
        }
    }

    /// A too-short topic short-circuits before any provider is consulted. The messages are
    /// returned unchanged.
    #[tokio::test]
    async fn enrich_skips_when_topic_below_min_chars() {
        let facts = InMemoryFacts::default();
        let sessions = InMemorySessions::default();
        let embedder = ConstantEmbedder(vec![0.1, 0.2, 0.3]);
        let reranker = EmptyReranker;
        let clock = FixedClock(now_ts());
        let retrieval = RetrievalConfig::default();
        let heat = HeatConfig::default();
        let uc = EnrichRequest {
            facts: &facts,
            sessions: &sessions,
            embedder: &embedder,
            reranker: &reranker,
            clock: &clock,
            retrieval_cfg: &retrieval,
            heat_cfg: &heat,
        };
        // "ok" is 2 chars; this assumes the default `min_topic_chars` is greater than 2.
        // TODO confirm against `RetrievalConfig::default()`.
        let original = vec![json!({"role": "user", "content": "ok"})];
        let out = uc
            .execute(original.clone(), &key(), &sid())
            .await
            .expect("ok");
        assert_eq!(
            out, original,
            "topic below min_topic_chars returns the messages unchanged (fail-open)"
        );
    }

    /// An embedder that yields no vector makes `execute` return the original messages
    /// rather than an error.
    #[tokio::test]
    async fn enrich_fail_opens_when_embedder_returns_none() {
        let facts = InMemoryFacts::default();
        let sessions = InMemorySessions::default();
        let embedder = NoneEmbedder;
        let reranker = EmptyReranker;
        let clock = FixedClock(now_ts());
        let retrieval = RetrievalConfig::default();
        let heat = HeatConfig::default();
        let uc = EnrichRequest {
            facts: &facts,
            sessions: &sessions,
            embedder: &embedder,
            reranker: &reranker,
            clock: &clock,
            retrieval_cfg: &retrieval,
            heat_cfg: &heat,
        };
        let original =
            vec![json!({"role": "user", "content": "explain rust ownership and borrowing"})];
        let out = uc
            .execute(original.clone(), &key(), &sid())
            .await
            .expect("ok");
        assert_eq!(
            out, original,
            "embedder None must fail-open to the original messages (no <smos-memory> block)"
        );
    }

    /// When a hit reaches the reranker and the reranker returns nothing, `execute` fails
    /// closed with `UseCaseError::Provider`. The HTTP 503 mapping mentioned in the assertion
    /// message happens outside this file.
    #[tokio::test]
    async fn enrich_fail_closes_with_provider_err_when_reranker_returns_empty() {
        let facts = InMemoryFacts::default();
        // Presumably makes the vector search return this hit; defined in `crate::testkit`.
        facts.script_search_hits(vec![survivable_hit()]);
        let sessions = InMemorySessions::default();
        let embedder = ConstantEmbedder(vec![0.1, 0.2, 0.3]);
        let reranker = EmptyReranker;
        let clock = FixedClock(now_ts());
        let retrieval = RetrievalConfig::default();
        let heat = HeatConfig::default();
        let uc = EnrichRequest {
            facts: &facts,
            sessions: &sessions,
            embedder: &embedder,
            reranker: &reranker,
            clock: &clock,
            retrieval_cfg: &retrieval,
            heat_cfg: &heat,
        };
        let original =
            vec![json!({"role": "user", "content": "explain rust ownership and borrowing"})];
        let result = uc.execute(original, &key(), &sid()).await;
        assert!(
            matches!(result, Err(UseCaseError::Provider(_))),
            "an empty rerank result must fail-closed as UseCaseError::Provider (HTTP 503)"
        );
    }
}