use std::time::Duration;
use smos_domain::chat::ToolCall;
use smos_domain::config::{ConfidenceConfig, ExtractionConfig};
use smos_domain::{MemoryKey, SessionId};
use crate::errors::{ProviderError, UseCaseError};
use crate::helpers::noise_filter;
use crate::ports::{
Clock, Delay, EmbeddingProvider, FactRepository, LlmExtractor, SessionRepository,
};
/// Minimum number of characters (counted after trimming) that the cleaned
/// response plus its rendered tool calls must contain before extraction is
/// attempted.
pub const MIN_INPUT_CHARS: usize = 15;

/// Maximum number of extractor calls made for a single response.
const EXTRACTION_ATTEMPTS: u32 = 3;

/// Wait inserted before each retry: `BACKOFF[i]` follows failed attempt `i`
/// (0-based).
///
/// The array length is derived from `EXTRACTION_ATTEMPTS`, so changing the
/// attempt count without updating this schedule is a compile error rather
/// than a silently missing (or dead) delay.
const BACKOFF: [Duration; EXTRACTION_ATTEMPTS as usize - 1] =
    [Duration::from_secs(1), Duration::from_secs(2)];
/// Use case that extracts facts from an assistant response and its tool calls,
/// persists the new ones, and registers them as pending on the session.
///
/// All collaborators are borrowed for `'a`. Trait bounds are placed on the
/// `impl` block rather than on the struct itself.
pub struct ExtractFactsFromResponse<'a, FR, SR, EP, LE, C, D> {
/// Fact storage (must implement `FactRepository` for `execute` to be callable).
pub facts: &'a FR,
/// Session storage; newly persisted fact ids are passed to its `add_pending`.
pub sessions: &'a SR,
/// Embedding provider. Not used in this file's visible code; presumably
/// consumed by the `persist`/`dedup` submodules (TODO confirm).
pub embedder: &'a EP,
/// Extractor whose `extract_facts` is called, with retries, on the input.
pub extractor: &'a LE,
/// Time source. Not used in this file's visible code; presumably consumed
/// when persisting facts (TODO confirm).
pub clock: &'a C,
/// Delay abstraction awaited for the backoff between extraction attempts.
pub delay: &'a D,
/// Confidence settings. Not read in this file's visible code; presumably used
/// by the `persist`/`dedup` submodules (TODO confirm).
pub confidence_cfg: &'a ConfidenceConfig,
/// Extraction settings. Not read in this file's visible code (TODO confirm
/// where it is consumed).
pub extraction_cfg: &'a ExtractionConfig,
/// Feature flag: when `false`, `execute` returns `Ok(0)` without doing any work.
pub enable_response_extraction: bool,
}
impl<'a, FR, SR, EP, LE, C, D> ExtractFactsFromResponse<'a, FR, SR, EP, LE, C, D>
where
    FR: FactRepository,
    SR: SessionRepository,
    EP: EmbeddingProvider,
    LE: LlmExtractor,
    C: Clock,
    D: Delay,
{
    /// Extracts facts from an assistant response (plus its tool calls), persists
    /// the new ones, and registers their ids as pending on `session_id`.
    ///
    /// The response text is passed through [`noise_filter::clean`], and the
    /// rendered tool calls (see [`format_tool_calls`]) are appended before
    /// extraction.
    ///
    /// Returns the number of newly persisted fact ids. `Ok(0)` is returned
    /// without calling the extractor in two cases: response extraction is
    /// disabled, or the trimmed input has fewer than [`MIN_INPUT_CHARS`]
    /// characters. `Ok(0)` is also returned when extraction yields no facts
    /// or persisting yields no new ids.
    ///
    /// # Errors
    ///
    /// Propagates an extractor error other than `ProviderError::Unavailable`
    /// once all `EXTRACTION_ATTEMPTS` attempts are used up. Also propagates any
    /// error from `persist_facts` or `add_pending`.
    pub async fn execute(
        &self,
        content: &str,
        tool_calls: &[ToolCall],
        memory_key: &MemoryKey,
        session_id: &SessionId,
    ) -> Result<usize, UseCaseError> {
        if !self.enable_response_extraction {
            return Ok(0);
        }

        let mut input = noise_filter::clean(content);
        input.push_str(&format_tool_calls(tool_calls));

        // Measure exactly what the threshold is compared against (trimmed,
        // counted in chars) so the debug log agrees with the decision it reports.
        let char_count = input.trim().chars().count();
        if char_count < MIN_INPUT_CHARS {
            tracing::debug!(
                len = char_count,
                "extraction skipped: input below MIN_INPUT_CHARS"
            );
            return Ok(0);
        }

        let raw_facts = self.extract_with_retries(&input, tool_calls).await?;
        if raw_facts.is_empty() {
            return Ok(0);
        }

        let new_ids = self
            .persist_facts(&raw_facts, memory_key, session_id)
            .await?;
        if !new_ids.is_empty() {
            self.sessions.add_pending(session_id, &new_ids).await?;
        }
        Ok(new_ids.len())
    }

    /// Calls the extractor up to `EXTRACTION_ATTEMPTS` times and returns the
    /// first non-empty list of facts.
    ///
    /// An empty result, or an error other than `ProviderError::Unavailable`,
    /// leads to another attempt after that attempt's backoff.
    /// `ProviderError::Unavailable` is treated as a graceful skip and returns
    /// `Ok(vec![])` immediately. If every attempt comes back empty, the result
    /// is `Ok(vec![])`.
    ///
    /// # Errors
    ///
    /// Returns the final attempt's error, converted into [`UseCaseError`], when
    /// that attempt fails with anything other than `Unavailable`.
    async fn extract_with_retries(
        &self,
        input: &str,
        tool_calls: &[ToolCall],
    ) -> Result<Vec<String>, UseCaseError> {
        for attempt in 0..EXTRACTION_ATTEMPTS {
            match self.extractor.extract_facts(input, tool_calls).await {
                Ok(facts) if !facts.is_empty() => return Ok(facts),
                Ok(_) => self.maybe_sleep(attempt).await,
                Err(ProviderError::Unavailable(msg)) => {
                    tracing::warn!(error = %msg, "extractor unavailable; skipping (graceful)");
                    return Ok(Vec::new());
                }
                Err(e) if attempt + 1 < EXTRACTION_ATTEMPTS => {
                    tracing::warn!(attempt = attempt + 1, error = %e, "extraction failed; retrying");
                    self.maybe_sleep(attempt).await;
                }
                Err(e) => return Err(e.into()),
            }
        }
        Ok(Vec::new())
    }

    /// Waits out the backoff that follows failed `attempt` (0-based).
    ///
    /// Nothing is awaited after the final attempt, because no retry follows
    /// it. Nothing is awaited either when `BACKOFF` has no entry for `attempt`.
    async fn maybe_sleep(&self, attempt: u32) {
        // Guard explicitly instead of relying on `BACKOFF` being exactly one
        // entry shorter than the attempt count.
        if attempt + 1 >= EXTRACTION_ATTEMPTS {
            return;
        }
        // u32 -> usize is a widening conversion on all supported targets.
        if let Some(delay) = BACKOFF.get(attempt as usize) {
            self.delay.delay(*delay).await;
        }
    }
}
/// Renders tool calls as a plain-text appendix to the extraction input.
///
/// Returns an empty string when there are no calls. Otherwise it returns a
/// `"\n\nTool calls:"` header followed by one `"\n- name(arguments)"` line per
/// call, in input order. Both fields are formatted with their `Display` impls.
pub fn format_tool_calls(tool_calls: &[ToolCall]) -> String {
    use std::fmt::Write as _;

    if tool_calls.is_empty() {
        return String::new();
    }
    let mut out = String::from("\n\nTool calls:");
    for call in tool_calls {
        // Write straight into `out` rather than allocating a temporary String
        // per call via `format!`.
        write!(out, "\n- {}({})", call.name, call.arguments)
            .expect("formatting into a String cannot fail");
    }
    out
}
/// Submodule whose contents are not visible here; presumably fact
/// deduplication helpers, judging by the name (TODO confirm).
pub mod dedup;
/// Submodule whose contents are not visible here. `persist_facts`, called by
/// `execute`, is not defined in this file and presumably lives here (verify).
pub mod persist;
// Unit tests for this use case, compiled only in `cfg(test)` builds.
#[cfg(test)]
mod tests;