use crate::core::sm::config::{SmInferenceConfig, SmRoundsConfig};
use crate::core::sm::providers::{ChatMessage, LlmProvider, SmLlmError};
use super::compaction::{estimate_tokens, fold_rounds, resummarise};
use super::model::{Round, SmConversation, ToolTrace};
use super::persist::{ConversationStore, ConversationStoreError};
// Role tags written into the `role` field of each `ChatMessage` produced by
// `SmContextEngine::assemble_working_prompt`.
/// Role of the leading message that carries the system prompt, summary and memory recall.
const ROLE_SYSTEM: &str = "system";
/// Role of user turns, both replayed recent rounds and the new incoming message.
const ROLE_USER: &str = "user";
/// Role of replayed assistant turns from recent rounds.
const ROLE_ASSISTANT: &str = "assistant";
/// Errors returned by [`SmContextEngine`] operations.
#[derive(Debug, thiserror::Error)]
pub enum SmContextError {
/// An LLM call made while folding or re-summarising context failed.
#[error("context compaction failed: {0}")]
Compaction(#[from] SmLlmError),
/// Loading or saving the conversation through the [`ConversationStore`] failed.
#[error("context persistence failed: {0}")]
Persist(#[from] ConversationStoreError),
}
/// Rolling context for a single conversation: the in-memory conversation state,
/// the limits that trigger compaction, and the store used to persist it.
pub struct SmContextEngine {
/// Identifier handed to the store when loading and saving this conversation.
conv_id: String,
/// In-memory conversation state (recent rounds, compressed context, token estimate).
conversation: SmConversation,
/// Number of recent rounds allowed before compaction is triggered.
window: usize,
/// Token-estimate ceiling; exceeding it also triggers compaction.
token_budget: usize,
/// Token-estimate threshold above which the compressed context is re-summarised
/// right after a fold; `0` disables that check.
compressed_max_tokens: usize,
/// Persistence backend used to load and save the conversation.
store: ConversationStore,
}
impl SmContextEngine {
/// Opens the context engine for one conversation.
///
/// Builds a [`ConversationStore`] from `data_root`, loads the conversation
/// identified by `conv_id` from it, copies the compaction limits out of the
/// supplied configs, and refreshes the token estimate before returning.
///
/// # Errors
///
/// Propagates any failure from `ConversationStore::load`.
pub fn open(
    conv_id: impl Into<String>,
    data_root: impl Into<std::path::PathBuf>,
    inference: &SmInferenceConfig,
    rounds: &SmRoundsConfig,
) -> Result<Self, SmContextError> {
    let conv_id: String = conv_id.into();
    let store = ConversationStore::new(data_root);
    let conversation = store.load(&conv_id)?;

    // Limits are stored as `usize` so they compare directly with lengths
    // and the token estimate.
    let window = rounds.window as usize;
    let token_budget = inference.context_token_budget as usize;
    let compressed_max_tokens = inference.compressed_context_max_tokens as usize;

    let mut opened = Self {
        conv_id,
        conversation,
        window,
        token_budget,
        compressed_max_tokens,
        store,
    };
    // Recompute rather than rely on whatever estimate was loaded.
    opened.recompute_estimate();
    Ok(opened)
}
pub fn conversation(&self) -> &SmConversation {
&self.conversation
}
/// Returns the identifier of the conversation this engine manages.
pub fn conv_id(&self) -> &str {
    self.conv_id.as_str()
}
/// Appends a completed round, compacts until the context fits, and persists.
///
/// The new round is pushed onto the recent-rounds queue and the token estimate
/// refreshed. While the context still needs compaction and more than one
/// recent round remains, the oldest round is folded into the compressed
/// context. `converge_within_budget` then handles the case where a single
/// round is still over the limits. Finally the conversation is saved.
///
/// Returns the number of rounds folded by the eviction loop in this call.
///
/// # Errors
///
/// Propagates LLM failures from compaction and store failures from saving.
/// On a compaction error the conversation is not saved.
pub async fn record(
    &mut self,
    provider: &dyn LlmProvider,
    compaction_model: &str,
    user: impl Into<String>,
    assistant: impl Into<String>,
    ts: chrono::DateTime<chrono::Utc>,
    tool_calls: Vec<ToolTrace>,
) -> Result<usize, SmContextError> {
    self.conversation
        .recent_rounds
        .push_back(Round::new(user, assistant, ts, tool_calls));
    self.conversation.total_rounds += 1;
    self.recompute_estimate();

    // NOTE(review): a round folded inside `converge_within_budget` is not
    // included in this count — confirm that is what callers expect.
    let mut folded_rounds: usize = 0;
    loop {
        // Always keep at least the newest round here; the convergence step
        // decides whether even that one has to go.
        if !self.should_compact() || self.conversation.recent_rounds.len() <= 1 {
            break;
        }
        self.compact_once(provider, compaction_model).await?;
        folded_rounds += 1;
    }

    self.converge_within_budget(provider, compaction_model)
        .await?;

    self.store.save(&self.conv_id, &self.conversation)?;
    Ok(folded_rounds)
}
/// Appends a completed round and persists the conversation without running
/// any compaction, so no LLM call is made.
///
/// The token estimate is refreshed, but the window and budget limits are not
/// enforced by this method.
///
/// # Errors
///
/// Propagates store failures from saving.
pub fn record_without_compaction(
    &mut self,
    user: impl Into<String>,
    assistant: impl Into<String>,
    ts: chrono::DateTime<chrono::Utc>,
    tool_calls: Vec<ToolTrace>,
) -> Result<(), SmContextError> {
    self.conversation
        .recent_rounds
        .push_back(Round::new(user, assistant, ts, tool_calls));
    self.conversation.total_rounds += 1;
    self.recompute_estimate();

    self.store.save(&self.conv_id, &self.conversation)?;
    Ok(())
}
/// Keeps shrinking the context until it no longer needs compaction or no
/// further progress can be made.
///
/// Each pass:
/// 1. re-summarises the compressed context if it is non-blank;
/// 2. if compaction is still needed and exactly one recent round is left,
///    folds that round via `compact_once`;
/// 3. stops once no recent rounds remain and the token estimate did not drop
///    during the pass, leaving the context over budget on a best-effort basis.
///
/// # Errors
///
/// Propagates LLM failures from re-summarising or folding.
async fn converge_within_budget(
    &mut self,
    provider: &dyn LlmProvider,
    compaction_model: &str,
) -> Result<(), SmContextError> {
    loop {
        if !self.should_compact() {
            return Ok(());
        }
        let estimate_at_start = self.conversation.token_estimate;

        // Step 1: squeeze the existing summary, if there is one.
        let has_summary = !self.conversation.compressed_context.trim().is_empty();
        if has_summary {
            let shrunk = resummarise(
                provider,
                compaction_model,
                &self.conversation.compressed_context,
            )
            .await?;
            self.conversation.compressed_context = shrunk.text;
            self.recompute_estimate();
        }

        // Step 2: still over the limits with a single round left — fold it too.
        if self.conversation.recent_rounds.len() == 1 && self.should_compact() {
            self.compact_once(provider, compaction_model).await?;
        }

        // Step 3: nothing left to fold and the estimate did not shrink, so
        // further passes would only repeat the same LLM call.
        let no_progress = self.conversation.token_estimate >= estimate_at_start;
        if no_progress && self.conversation.recent_rounds.is_empty() {
            tracing::debug!(
                conv_id = %self.conv_id,
                token_estimate = self.conversation.token_estimate,
                token_budget = self.token_budget,
                "context convergence stalled; persisting best-effort over-budget context"
            );
            return Ok(());
        }
    }
}
/// Reports whether the context currently breaks either limit: more recent
/// rounds than `window`, or a token estimate above `token_budget`.
fn should_compact(&self) -> bool {
    let conv = &self.conversation;
    let over_window = conv.recent_rounds.len() > self.window;
    let over_budget = conv.token_estimate > self.token_budget;
    over_window || over_budget
}
/// Folds the oldest recent round into the compressed context.
///
/// Does nothing when there are no recent rounds. After a successful fold the
/// round is removed from the queue, the compressed context is replaced with
/// the fold result, and the token estimate is refreshed. If
/// `compressed_max_tokens` is non-zero and the new compressed context exceeds
/// it, the compressed context is re-summarised once more.
///
/// # Errors
///
/// Propagates LLM failures from folding or re-summarising. If the fold itself
/// fails, the conversation is left exactly as it was: the round stays in the
/// queue, so it is neither lost nor dropped from a later save.
async fn compact_once(
    &mut self,
    provider: &dyn LlmProvider,
    compaction_model: &str,
) -> Result<(), SmContextError> {
    // Peek instead of popping: the round must survive a failed (or cancelled)
    // fold, otherwise it would vanish without ever reaching the summary.
    let Some(oldest) = self.conversation.recent_rounds.front() else {
        return Ok(());
    };
    let resp = fold_rounds(
        provider,
        compaction_model,
        &self.conversation.compressed_context,
        std::array::from_ref(oldest),
    )
    .await?;

    // The fold succeeded, so the round is now represented in the summary and
    // can be removed from the queue.
    self.conversation.recent_rounds.pop_front();
    self.conversation.compressed_context = resp.text;
    self.recompute_estimate();

    // Keep the summary itself bounded; a cap of 0 disables this check.
    if self.compressed_max_tokens > 0
        && estimate_tokens(self.conversation.compressed_context.len())
            > self.compressed_max_tokens
    {
        let resp = resummarise(
            provider,
            compaction_model,
            &self.conversation.compressed_context,
        )
        .await?;
        self.conversation.compressed_context = resp.text;
        self.recompute_estimate();
    }
    Ok(())
}
fn recompute_estimate(&mut self) {
let chars = self.conversation.compressed_context.len()
+ self
.conversation
.recent_rounds
.iter()
.map(Round::char_len)
.sum::<usize>();
self.conversation.token_estimate = estimate_tokens(chars);
}
/// Builds the chat message list for the next model call.
///
/// The result contains, in order:
/// 1. one system message joining (with a blank line between them) the
///    non-blank parts among `system_prompt`, the compressed context, and
///    `memory_recall` — omitted entirely if all three are blank or absent;
/// 2. a user/assistant message pair for every recent round, oldest first;
/// 3. a final user message carrying `message`.
pub fn assemble_working_prompt(
    &self,
    system_prompt: &str,
    memory_recall: Option<&str>,
    message: &str,
) -> Vec<ChatMessage> {
    let chat = |role: &str, content: String| ChatMessage {
        role: role.to_string(),
        content,
    };

    // Collect the parts of the system preamble, skipping blank ones.
    let summary = &self.conversation.compressed_context;
    let mut preamble: Vec<String> = Vec::new();
    if !system_prompt.trim().is_empty() {
        preamble.push(system_prompt.to_owned());
    }
    if !summary.trim().is_empty() {
        preamble.push(format!("Earlier in this conversation: {}", summary));
    }
    match memory_recall {
        Some(recall) if !recall.trim().is_empty() => {
            preamble.push(format!("Relevant memory: {recall}"));
        }
        _ => {}
    }

    let mut prompt: Vec<ChatMessage> = Vec::new();
    if !preamble.is_empty() {
        prompt.push(chat(ROLE_SYSTEM, preamble.join("\n\n")));
    }

    // Replay the retained rounds verbatim, oldest first.
    prompt.extend(self.conversation.recent_rounds.iter().flat_map(|round| {
        [
            chat(ROLE_USER, round.user.clone()),
            chat(ROLE_ASSISTANT, round.assistant.clone()),
        ]
    }));

    prompt.push(chat(ROLE_USER, message.to_string()));
    prompt
}
}
// Unit tests for this module; compiled only for test builds, with the module
// body loaded from `engine_tests.rs` through the `#[path]` attribute.
#[cfg(test)]
#[path = "engine_tests.rs"]
mod tests;