use std::collections::HashMap;
use std::sync::Arc;
use serde_json::Value;
use smos_domain::chat::{ToolArguments, ToolCall};
use smos_domain::config::{HeatConfig, RetrievalConfig};
use smos_domain::{MemoryKey, SessionId};
use crate::errors::UseCaseError;
use crate::helpers::person_router::{
PersonEntry, ProviderEntry, inject_persona_into_messages, load_persona_at, route_request,
};
use crate::helpers::session_marker;
use crate::ports::{
Clock, EmbeddingProvider, FactRepository, IdGenerator, LlmUpstream, RerankProvider,
SessionRepository,
};
use crate::types::{ChatRequest, ChatResponse, enrichment_messages_from_json};
use crate::use_cases::enrich_request::EnrichRequest;
/// Use case that serves one chat-completion request.
///
/// `execute` does the following:
/// - routes the requested model through `route_request`;
/// - optionally injects a persona into the messages;
/// - resolves (or mints) a session id;
/// - enriches the messages via [`EnrichRequest`];
/// - forwards the request to the upstream provider selected by the route.
///
/// The type parameters are the ports the use case depends on. Their trait bounds are declared
/// on the `impl` block, not here, so the struct itself can be built with any types.
pub struct HandleChatCompletion<FR, SR, EP, RP, LU, C, IG> {
    /// Fact repository; passed by reference to [`EnrichRequest`].
    pub facts: FR,
    /// Session repository; passed by reference to [`EnrichRequest`].
    pub sessions: SR,
    /// Embedding provider; passed by reference to [`EnrichRequest`].
    pub embedder: EP,
    /// Rerank provider; passed by reference to [`EnrichRequest`].
    pub reranker: RP,
    /// Upstream LLM client that receives the final, enriched request.
    pub upstream: LU,
    /// Clock; passed by reference to [`EnrichRequest`].
    pub clock: C,
    /// Mints a fresh session id when no session marker is detected in the request messages.
    pub id_generator: IG,
    /// Retrieval settings; passed by reference to [`EnrichRequest`].
    pub retrieval_cfg: Arc<RetrievalConfig>,
    /// Heat settings; passed by reference to [`EnrichRequest`].
    pub heat_cfg: Arc<HeatConfig>,
    /// Person table consulted by `route_request` when resolving `request.model`.
    /// NOTE(review): presumably keyed by person/model name — confirm against `route_request`.
    pub persons: Arc<HashMap<String, PersonEntry>>,
    /// Provider table consulted by `route_request` when resolving `request.model`.
    pub providers: Arc<Vec<ProviderEntry>>,
}
impl<FR, SR, EP, RP, LU, C, IG> HandleChatCompletion<FR, SR, EP, RP, LU, C, IG>
where
    FR: FactRepository,
    SR: SessionRepository,
    EP: EmbeddingProvider,
    RP: RerankProvider,
    LU: LlmUpstream,
    C: Clock,
    IG: IdGenerator,
{
    /// Handles one chat-completion request end to end.
    ///
    /// Pipeline, in order:
    /// 1. Resolve `request.model` with `route_request`, then overwrite it with the routed
    ///    upstream model name.
    /// 2. If the route carries a persona path and `load_persona_at` yields content, inject
    ///    that content into the request messages.
    /// 3. Reuse the session id found by `session_marker::detect_from_typed_messages`.
    ///    Otherwise mint a new one from `id_generator`.
    /// 4. Replace the messages with the output of [`EnrichRequest`].
    /// 5. Send the request to the upstream provider named by the route.
    ///
    /// Returns the upstream response together with the session id and memory key that were
    /// used, so the caller can persist or echo them.
    ///
    /// # Errors
    /// Propagates (via `?`) failures from routing, enrichment, and the upstream call.
    pub async fn execute(
        &self,
        mut request: ChatRequest,
    ) -> Result<(ChatResponse, SessionId, MemoryKey), UseCaseError> {
        let route = route_request(&request.model, &self.persons, &self.providers)?;
        let memory_key = route.memory_key;
        let provider_name = route.provider_name;
        request.model = route.upstream_model;

        // A persona is optional: no path on the route, or nothing loaded, means no injection.
        let persona = route
            .persona_path
            .as_ref()
            .and_then(|path| load_persona_at(path));
        if let Some(persona_content) = persona {
            inject_persona_into_messages(&mut request.messages, &persona_content);
        }

        // Session detection runs on a typed projection of the (possibly persona-injected)
        // messages; the projection is only needed for this lookup.
        let session_id = {
            let projection = enrichment_messages_from_json(&request.messages);
            match session_marker::detect_from_typed_messages(&projection) {
                Some(existing) => existing,
                None => self.id_generator.new_session_id(),
            }
        };

        // `enrich` consumes the message list, so move it out and put the enriched list back.
        let raw_messages = std::mem::take(&mut request.messages);
        request.messages = self
            .enrich(raw_messages, &memory_key, &session_id)
            .await?;

        let completion = self.upstream.complete(&provider_name, request).await?;
        Ok((completion, session_id, memory_key))
    }

    /// Runs [`EnrichRequest`] over `messages`, wiring in this handler's ports and configs.
    ///
    /// The `memory_key` and `session_id` are forwarded unchanged to `EnrichRequest::execute`.
    ///
    /// # Errors
    /// Returns whatever error `EnrichRequest::execute` reports.
    async fn enrich(
        &self,
        messages: Vec<Value>,
        memory_key: &MemoryKey,
        session_id: &SessionId,
    ) -> Result<Vec<Value>, UseCaseError> {
        EnrichRequest {
            facts: &self.facts,
            sessions: &self.sessions,
            embedder: &self.embedder,
            reranker: &self.reranker,
            clock: &self.clock,
            retrieval_cfg: &self.retrieval_cfg,
            heat_cfg: &self.heat_cfg,
        }
        .execute(messages, memory_key, session_id)
        .await
    }
}
/// Pulls the assistant text and tool calls out of an OpenAI-style chat-completion JSON body.
///
/// Only the first choice (`/choices/0/message`) is inspected.
///
/// - The text is the message's `content` when it is a JSON string. Otherwise it is an empty
///   string, e.g. when `content` is absent or `null`.
/// - The tool calls are the entries of the message's `tool_calls` array that
///   [`parse_openai_tool_call`] accepts. Malformed entries are skipped silently. A missing or
///   non-array `tool_calls` yields an empty vector.
pub fn extract_response_payload(value: &Value) -> (String, Vec<ToolCall>) {
    let message = value.pointer("/choices/0/message");

    let text = match message
        .and_then(|msg| msg.get("content"))
        .and_then(Value::as_str)
    {
        Some(body) => body.to_owned(),
        None => String::new(),
    };

    let calls: Vec<ToolCall> = match message
        .and_then(|msg| msg.get("tool_calls"))
        .and_then(Value::as_array)
    {
        Some(entries) => entries.iter().filter_map(parse_openai_tool_call).collect(),
        None => Vec::new(),
    };

    (text, calls)
}
/// Converts one entry of an OpenAI-style `tool_calls` array into a [`ToolCall`].
///
/// Returns `None` when the entry has no `function` field, or when `function.name` is missing
/// or not a string. Fields other than `function.name` and `function.arguments` (e.g. `id`,
/// `type`) are ignored.
///
/// `function.arguments` is turned into raw text for `ToolArguments::from_json`:
/// - a JSON string is taken verbatim;
/// - any other JSON value is re-serialized to text;
/// - a missing field becomes the literal text `null`.
fn parse_openai_tool_call(v: &Value) -> Option<ToolCall> {
    let func = v.get("function")?;
    let tool_name = func.get("name").and_then(Value::as_str)?;

    let raw_arguments = match func.get("arguments") {
        None => String::from("null"),
        Some(Value::String(text)) => text.to_owned(),
        // Serializing an in-memory `Value` should not fail; fall back to `null` regardless.
        Some(structured) => {
            serde_json::to_string(structured).unwrap_or_else(|_| String::from("null"))
        }
    };

    Some(ToolCall {
        name: tool_name.to_owned(),
        arguments: ToolArguments::from_json(raw_arguments),
    })
}