use std::sync::Arc;
use smos_domain::chat::ToolCall;
use smos_domain::config::{ConfidenceConfig, ExtractionConfig};
use smos_domain::{MemoryKey, SessionId};
use crate::errors::UseCaseError;
use crate::ports::{
Clock, Delay, EmbeddingProvider, FactRepository, LlmExtractor, SessionRepository,
};
use crate::use_cases::extract_facts_from_response::ExtractFactsFromResponse;
/// One assistant message to be imported by [`ImportOpencodeSession::execute`].
#[derive(Debug, Clone, PartialEq)]
pub struct AssistantTurn {
    /// Identifier of the source message.
    pub message_id: String,
    /// Name of the agent that produced the turn; matched against the optional agent filter.
    pub agent: String,
    /// Text of the assistant response; handed to fact extraction and measured against
    /// `min_chars` (in `char`s).
    pub content: String,
    /// Tool calls made in this turn. A turn with at least one tool call is never skipped for
    /// having short `content`.
    pub tool_calls: Vec<ToolCall>,
}
/// Counters summarising one successful [`ImportOpencodeSession::execute`] run.
///
/// Derives `PartialEq`/`Eq` (all fields are plain values) so results can be compared
/// directly, e.g. `assert_eq!(stats, expected)`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ImportStats {
    /// String form (`SessionId::as_str`) of the session the turns were imported into.
    pub session_id: String,
    /// Turns that passed the skip rules and were run through fact extraction.
    pub turns_processed: usize,
    /// Turns dropped by the agent filter or by the minimum-length rule.
    pub turns_skipped: usize,
    /// Facts reported as newly created by extraction. Re-confirming an already stored fact
    /// does not increase this count.
    pub facts_extracted: usize,
}
/// Use case that replays a batch of assistant turns through fact extraction.
///
/// The ports are owned here. For every turn that is not skipped, they are lent by reference to
/// a freshly built [`ExtractFactsFromResponse`], which does the per-turn work.
pub struct ImportOpencodeSession<FR, SR, EP, LE, C, D> {
    /// Fact repository, lent to `ExtractFactsFromResponse`.
    pub facts: FR,
    /// Session repository. `execute` calls `get_or_create` on it before visiting any turn;
    /// it is also lent to `ExtractFactsFromResponse`.
    pub sessions: SR,
    /// Embedding provider, lent to `ExtractFactsFromResponse`.
    pub embedder: EP,
    /// LLM-backed extractor, lent to `ExtractFactsFromResponse`.
    pub extractor: LE,
    /// Clock, lent to `ExtractFactsFromResponse`.
    pub clock: C,
    /// Delay port, lent to `ExtractFactsFromResponse`.
    pub delay: D,
    /// Confidence settings, lent to `ExtractFactsFromResponse`.
    pub confidence_cfg: Arc<ConfidenceConfig>,
    /// Extraction settings, lent to `ExtractFactsFromResponse`.
    pub extraction_cfg: Arc<ExtractionConfig>,
    /// Forwarded unchanged to `ExtractFactsFromResponse`. The tests expect no facts to be
    /// stored when this is `false`.
    pub enable_response_extraction: bool,
    /// Minimum `content` length, counted in `char`s, below which a turn without tool calls
    /// is skipped.
    pub min_chars: usize,
}
impl<FR, SR, EP, LE, C, D> ImportOpencodeSession<FR, SR, EP, LE, C, D>
where
    FR: FactRepository,
    SR: SessionRepository,
    EP: EmbeddingProvider,
    LE: LlmExtractor,
    C: Clock,
    D: Delay,
{
    /// Imports `turns` into `memory_key`, attributing extracted facts to `session_id`.
    ///
    /// The session is fetched or created first, even when `turns` is empty. Turns are then
    /// visited in order:
    /// - A turn rejected by [`Self::should_skip`] only increments `turns_skipped`.
    /// - Every other turn goes through [`ExtractFactsFromResponse`]. The number it returns is
    ///   added to `facts_extracted`.
    ///
    /// `agent_filter`, when `Some`, restricts processing to turns whose `agent` appears in it.
    ///
    /// # Errors
    ///
    /// Returns the first error from `get_or_create` or from extracting a turn. Processing
    /// stops at the failing turn, and work already done for earlier turns is not undone here.
    /// Because the partial counters are lost with the `Err`, a warning identifying the failing
    /// turn and the progress so far is logged before the error is propagated.
    pub async fn execute(
        &self,
        turns: Vec<AssistantTurn>,
        memory_key: &MemoryKey,
        session_id: &SessionId,
        agent_filter: Option<&[String]>,
    ) -> Result<ImportStats, UseCaseError> {
        let mut stats = ImportStats {
            session_id: session_id.as_str().to_string(),
            ..Default::default()
        };
        // Register the session up front so it exists even when nothing gets imported.
        self.sessions.get_or_create(session_id, memory_key).await?;
        for turn in &turns {
            if self.should_skip(turn, agent_filter) {
                stats.turns_skipped += 1;
                continue;
            }
            let new_count = self
                .extract_turn(turn, memory_key, session_id)
                .await
                .inspect_err(|err| {
                    // `stats` is dropped with the error, so surface which turn failed and how
                    // far the import got; otherwise a multi-turn import fails anonymously.
                    tracing::warn!(
                        session = %session_id,
                        memory_key = %memory_key,
                        message_id = %turn.message_id,
                        agent = %turn.agent,
                        processed = stats.turns_processed,
                        skipped = stats.turns_skipped,
                        new_facts = stats.facts_extracted,
                        error = ?err,
                        "import aborted: fact extraction failed for turn"
                    );
                })?;
            // Counted only after extraction succeeds, so the warning above reports completed
            // turns. On the success path the totals are the same as counting beforehand.
            stats.turns_processed += 1;
            stats.facts_extracted += new_count;
        }
        tracing::info!(
            session = %session_id,
            memory_key = %memory_key,
            processed = stats.turns_processed,
            skipped = stats.turns_skipped,
            new_facts = stats.facts_extracted,
            "import complete"
        );
        Ok(stats)
    }

    /// Returns `true` when `turn` should not be imported. That happens when either:
    /// - `agent_filter` is `Some` and does not contain `turn.agent`, or
    /// - `turn.content` has fewer than `min_chars` characters and `turn.tool_calls` is empty.
    ///
    /// A short turn that carries tool calls is still imported.
    fn should_skip(&self, turn: &AssistantTurn, agent_filter: Option<&[String]>) -> bool {
        if let Some(filter) = agent_filter
            && !filter.iter().any(|a| a == &turn.agent)
        {
            return true;
        }
        // `chars()` counts Unicode scalar values, not bytes.
        let too_short = turn.content.chars().count() < self.min_chars;
        too_short && turn.tool_calls.is_empty()
    }

    /// Runs [`ExtractFactsFromResponse`] over one turn's content and tool calls, lending it
    /// this use case's ports and configuration. Returns the count it reports as new facts.
    async fn extract_turn(
        &self,
        turn: &AssistantTurn,
        memory_key: &MemoryKey,
        session_id: &SessionId,
    ) -> Result<usize, UseCaseError> {
        let extractor = ExtractFactsFromResponse {
            facts: &self.facts,
            sessions: &self.sessions,
            embedder: &self.embedder,
            extractor: &self.extractor,
            clock: &self.clock,
            delay: &self.delay,
            confidence_cfg: &self.confidence_cfg,
            extraction_cfg: &self.extraction_cfg,
            enable_response_extraction: self.enable_response_extraction,
        };
        extractor
            .execute(&turn.content, &turn.tool_calls, memory_key, session_id)
            .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testkit::{
        ConstantEmbedder, FixedClock, InMemoryFacts, NoOpDelay, ScriptedExtractor,
    };
    use smos_domain::{Fact, FactId, NewPendingRequest, SessionState, Timestamp};
    use std::sync::Mutex;
    use std::time::Duration;

    /// `SessionRepository` stub that only records whether `get_or_create` was called.
    ///
    /// Every other method succeeds without doing anything, returning `()` or an empty `Vec`.
    #[derive(Default, Clone)]
    struct RecordingSessions {
        // Behind an `Arc` so the clone handed to the use case and the fixture's own copy
        // observe the same flag.
        created: std::sync::Arc<Mutex<bool>>,
    }

    impl SessionRepository for RecordingSessions {
        async fn get_or_create(
            &self,
            id: &SessionId,
            _m: &MemoryKey,
        ) -> Result<SessionState, crate::errors::RepoError> {
            *self.created.lock().unwrap() = true;
            // NOTE(review): the requested memory key is ignored and "proj" is always used.
            // This matches `mk()`, which every test passes.
            Ok(SessionState::new(
                id.clone(),
                MemoryKey::from_raw("proj").unwrap(),
                Timestamp::from_unix_secs(1_700_000_000).unwrap(),
            ))
        }

        async fn add_pending(
            &self,
            _i: &SessionId,
            _ids: &[FactId],
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }

        async fn collect_expired(
            &self,
            _t: Duration,
        ) -> Result<Vec<(SessionId, SessionState)>, crate::errors::RepoError> {
            Ok(Vec::new())
        }

        async fn snapshot_all(
            &self,
        ) -> Result<Vec<(SessionId, SessionState)>, crate::errors::RepoError> {
            Ok(Vec::new())
        }

        async fn remove_pending_owned(
            &self,
            _i: &SessionId,
            _o: &[FactId],
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }

        async fn clear_session(&self, _i: &SessionId) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }

        async fn dedup_and_mark(
            &self,
            _i: &SessionId,
            _m: &MemoryKey,
            _c: &[FactId],
        ) -> Result<Vec<FactId>, crate::errors::RepoError> {
            Ok(Vec::new())
        }

        async fn save(
            &self,
            _i: &SessionId,
            _s: &SessionState,
        ) -> Result<(), crate::errors::RepoError> {
            Ok(())
        }
    }

    /// The memory key used by every test.
    fn mk() -> MemoryKey {
        MemoryKey::from_raw("proj").unwrap()
    }

    /// Deterministic session id: `sess_` followed by `tag` as 12 zero-padded hex digits.
    fn sid(tag: u8) -> SessionId {
        SessionId::from_raw(&format!("sess_{:012x}", tag as u64)).unwrap()
    }

    /// Test fixture holding the ports and configs used to assemble the use case.
    struct Fix {
        facts: InMemoryFacts,
        sessions: RecordingSessions,
        embedder: ConstantEmbedder,
        clock: FixedClock,
        cfg: ConfidenceConfig,
        extraction_cfg: ExtractionConfig,
    }

    impl Fix {
        /// Empty fact store, a constant 3-dimensional embedding, a fixed clock at
        /// unix time 1_700_000_000, and default configs.
        fn new() -> Self {
            Self {
                facts: InMemoryFacts::default(),
                sessions: RecordingSessions::default(),
                embedder: ConstantEmbedder(vec![0.1, 0.2, 0.3]),
                clock: FixedClock(Timestamp::from_unix_secs(1_700_000_000).unwrap()),
                cfg: ConfidenceConfig::default(),
                extraction_cfg: ExtractionConfig::default(),
            }
        }

        /// Builds an `ImportOpencodeSession` from clones of the fixture's parts, using the
        /// given `extractor` and `min_chars`, with response extraction enabled.
        ///
        /// NOTE(review): tests inspect `fix.facts` after running the built use case, so this
        /// assumes `InMemoryFacts` clones share storage. Confirm in `testkit`.
        fn build(
            &self,
            extractor: ScriptedExtractor,
            min_chars: usize,
        ) -> ImportOpencodeSession<
            InMemoryFacts,
            RecordingSessions,
            ConstantEmbedder,
            ScriptedExtractor,
            FixedClock,
            NoOpDelay,
        > {
            ImportOpencodeSession {
                facts: self.facts.clone(),
                sessions: self.sessions.clone(),
                embedder: ConstantEmbedder(self.embedder.0.clone()),
                extractor,
                clock: FixedClock(self.clock.0),
                delay: NoOpDelay,
                confidence_cfg: Arc::new(self.cfg.clone()),
                extraction_cfg: Arc::new(self.extraction_cfg.clone()),
                enable_response_extraction: true,
                min_chars,
            }
        }
    }

    /// An assistant turn with no tool calls; its `message_id` is `msg_<agent>`.
    fn turn(agent: &str, content: &str) -> AssistantTurn {
        AssistantTurn {
            message_id: format!("msg_{agent}"),
            agent: agent.to_string(),
            content: content.to_string(),
            tool_calls: Vec::new(),
        }
    }

    // Two eligible turns, each scripted to yield one new fact, give two new facts in total.
    #[tokio::test]
    async fn execute_imports_each_turn_and_counts_new_facts() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![
            Ok(vec!["fact one".to_string()]),
            Ok(vec!["fact two".to_string()]),
        ]);
        let import = fix.build(extractor, 15);
        let turns = vec![
            turn("head-of-development", "TTL=10 prevents refresh loop"),
            turn("head-of-development", "Auth uses JWT for tokens"),
        ];
        let stats = import.execute(turns, &mk(), &sid(1), None).await.unwrap();
        assert_eq!(stats.turns_processed, 2);
        assert_eq!(stats.turns_skipped, 0);
        assert_eq!(stats.facts_extracted, 2);
    }

    // "ok" (2 chars, below min_chars = 15) with no tool calls is skipped; the longer turn is
    // processed. Only one scripted response is supplied.
    #[tokio::test]
    async fn execute_skips_turns_below_min_chars_without_tool_calls() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![Ok(vec!["real fact".to_string()])]);
        let import = fix.build(extractor, 15);
        let turns = vec![
            turn("a", "ok"),
            turn("a", "TTL=10 prevents refresh loop"),
        ];
        let stats = import.execute(turns, &mk(), &sid(1), None).await.unwrap();
        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.turns_skipped, 1);
        assert_eq!(stats.facts_extracted, 1);
    }

    // A short turn is still processed when it carries a tool call.
    #[tokio::test]
    async fn execute_keeps_short_turn_when_it_has_tool_calls() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![Ok(vec!["from tool".to_string()])]);
        let import = fix.build(extractor, 15);
        let mut short_with_tool = turn("a", "ok");
        short_with_tool.tool_calls.push(ToolCall {
            name: "read_file".into(),
            arguments: smos_domain::chat::ToolArguments::from_json(r#"{"path":"auth.rs"}"#),
        });
        let stats = import
            .execute(vec![short_with_tool], &mk(), &sid(1), None)
            .await
            .unwrap();
        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.turns_skipped, 0);
        assert_eq!(stats.facts_extracted, 1);
    }

    // Turns from agents not in the filter ("dreaming") are skipped even when long enough.
    #[tokio::test]
    async fn execute_applies_agent_filter() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![
            Ok(vec!["hod fact".to_string()]),
            Ok(vec!["hod fact 2".to_string()]),
        ]);
        let import = fix.build(extractor, 15);
        let turns = vec![
            turn("head-of-development", "TTL=10 prevents refresh loop"),
            turn("dreaming", "Internal analysis content here"),
            turn("head-of-development", "Auth uses JWT for tokens"),
        ];
        let filter = vec!["head-of-development".to_string()];
        let stats = import
            .execute(turns, &mk(), &sid(1), Some(&filter))
            .await
            .unwrap();
        assert_eq!(stats.turns_processed, 2);
        assert_eq!(stats.turns_skipped, 1);
        assert_eq!(stats.facts_extracted, 2);
    }

    // get_or_create runs even when there are no turns at all.
    #[tokio::test]
    async fn execute_ensures_session_row_exists_before_first_turn() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![]);
        let import = fix.build(extractor, 15);
        let _ = import.execute(vec![], &mk(), &sid(7), None).await.unwrap();
        assert!(
            *fix.sessions.created.lock().unwrap(),
            "get_or_create must run even for an empty turn list"
        );
    }

    // With response extraction disabled, the turn still counts as processed, but nothing
    // is stored even though the script would return a fact.
    #[tokio::test]
    async fn execute_with_extraction_disabled_returns_zero_facts() {
        let fix = Fix::new();
        let extractor = ScriptedExtractor::new(vec![Ok(vec!["should not be stored".to_string()])]);
        let mut import = fix.build(extractor, 15);
        import.enable_response_extraction = false;
        let stats = import
            .execute(
                vec![turn("a", "TTL=10 prevents refresh loop")],
                &mk(),
                &sid(1),
                None,
            )
            .await
            .unwrap();
        assert_eq!(stats.turns_processed, 1);
        assert_eq!(stats.facts_extracted, 0);
        assert!(fix.facts.is_empty());
    }

    // A fact already seeded from session 1 and re-extracted in session 2 is confirmed rather
    // than counted as new: the new-fact count stays 0 and its source sessions grow to two.
    #[tokio::test]
    async fn execute_confirms_existing_fact_instead_of_counting_it_new() {
        let fix = Fix::new();
        let seeded_content = "shared fact content here";
        let first = Fact::new_pending(NewPendingRequest {
            content: seeded_content,
            memory_key: mk(),
            session: sid(1),
            embedding: smos_domain::Embedding::new(vec![1.0]).unwrap(),
            extracted_at: Timestamp::from_unix_secs(1_700_000_000).unwrap(),
            base_confidence: ConfidenceConfig::default().base,
        })
        .unwrap();
        let fid = first.id().clone();
        fix.facts.seed(first);
        let extractor = ScriptedExtractor::new(vec![Ok(vec![seeded_content.to_string()])]);
        let import = fix.build(extractor, 15);
        let stats = import
            .execute(vec![turn("a", seeded_content)], &mk(), &sid(2), None)
            .await
            .unwrap();
        assert_eq!(stats.facts_extracted, 0, "confirmation is not a new fact");
        let confirmed = fix.facts.get_clone(&fid).expect("fact present");
        assert_eq!(
            confirmed.source_sessions().distinct_count(),
            2,
            "provenance grew to two sessions"
        );
    }
}