use crate::llm::{LlmProvider, Message};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use tracing::{info, warn};
/// Strategy applied when the conversation approaches the character budget
/// configured in [`ContextConfig`].
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ContextStrategy {
    /// Drop older messages, keeping the head and the most recent ones.
    Truncate,
    /// Replace older messages with an LLM-generated summary (the default).
    /// If summarisation fails, `ContextManager::manage` falls back to
    /// truncation.
    #[default]
    Summarize,
}
/// Tuning knobs for [`ContextManager`].
///
/// Container-level `#[serde(default)]` means any field missing from the
/// serialized form falls back to the `Default` impl below.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(default)]
pub struct ContextConfig {
    /// How to shrink the conversation once the budget threshold is crossed.
    pub strategy: ContextStrategy,
    /// Approximate context budget, measured in characters of message text.
    pub budget_chars: usize,
    /// Fraction of `budget_chars` (0.0–1.0) at which management triggers.
    pub threshold: f64,
    /// Token cap intended for the summarisation call — NOTE(review): not
    /// read anywhere in this file; confirm it is wired through to the
    /// provider elsewhere.
    pub summary_max_tokens: u32,
}
impl Default for ContextConfig {
    /// Defaults: summarise once the conversation passes 80% of a
    /// 400 000-character budget, with summaries capped at 800 tokens.
    fn default() -> Self {
        Self {
            strategy: ContextStrategy::default(),
            summary_max_tokens: 800,
            threshold: 0.80,
            budget_chars: 400_000,
        }
    }
}
/// Keeps a conversation within a character budget by truncating or
/// summarising older messages according to the configured strategy.
pub struct ContextManager {
    /// Active configuration (strategy, budget, trigger threshold).
    pub config: ContextConfig,
    /// Number of successful summarisation passes performed so far.
    pub summarize_count: u32,
}
impl ContextManager {
pub fn new(config: ContextConfig) -> Self {
ContextManager {
config,
summarize_count: 0,
}
}
pub async fn manage(
&mut self,
messages: &mut Vec<Message>,
llm: &dyn LlmProvider,
) -> Result<()> {
let total_chars: usize = messages
.iter()
.map(|m| m.text_content().map(|s| s.len()).unwrap_or(0))
.sum();
let trigger_chars = (self.config.budget_chars as f64 * self.config.threshold) as usize;
if total_chars <= trigger_chars {
return Ok(());
}
info!(
total_chars,
budget_chars = self.config.budget_chars,
"Context window approaching limit — applying {:?} strategy",
self.config.strategy,
);
match self.config.strategy {
ContextStrategy::Truncate => {
truncate_messages(messages, self.config.budget_chars);
}
ContextStrategy::Summarize => {
match self.try_summarize(messages, llm).await {
Ok(()) => {
self.summarize_count += 1;
info!(
summarize_count = self.summarize_count,
"Context summarised successfully"
);
}
Err(e) => {
warn!(
error = %e,
"Context summarisation failed — falling back to truncation"
);
truncate_messages(messages, self.config.budget_chars);
}
}
}
}
Ok(())
}
pub async fn try_summarize(
&self,
messages: &mut Vec<Message>,
llm: &dyn LlmProvider,
) -> Result<()> {
if messages.len() < 4 {
truncate_messages(messages, self.config.budget_chars);
return Ok(());
}
let non_system_count = messages.len().saturating_sub(1); let summarise_count = (non_system_count / 2).max(2);
let summarise_end = 1 + summarise_count;
let context_text: String = messages[1..summarise_end]
.iter()
.filter_map(|m| {
m.text_content().map(|text| {
let role_label = match m.role {
crate::llm::Role::User => "User",
crate::llm::Role::Assistant => "Assistant",
crate::llm::Role::System => "System",
crate::llm::Role::Tool => "Tool",
};
format!("[{role_label}]: {text}")
})
})
.collect::<Vec<_>>()
.join("\n\n");
let summarisation_prompt = format!(
"The following is an excerpt from a coding session conversation. \
Summarise the key context, decisions, file changes, tool calls, \
and current state in 500 words or less. Be factual and concise — \
this summary will be injected back into the conversation so the \
agent can continue working without losing important context.\n\n\
--- Conversation excerpt ---\n\n{context_text}\n\n--- End excerpt ---"
);
let summary_request = vec![Message::user(&summarisation_prompt)];
let summary_response = llm.chat_completion(&summary_request, &[]).await?;
let summary_text = summary_response
.content
.unwrap_or_else(|| "[Context summary unavailable]".to_string());
let summary_message = Message::user(format!(
"[Context Summary — condensed from {} messages]\n\n{}",
summarise_count, summary_text
));
let rest = messages.split_off(summarise_end); messages.truncate(1); messages.push(summary_message);
messages.extend(rest);
Ok(())
}
}
/// Best-effort truncation of `messages` to roughly `budget_chars` of text.
///
/// Always keeps the first two messages (the first is assumed to be the
/// system prompt — TODO confirm callers uphold this), then retains as many
/// of the most recent messages as fit in the remaining budget. When anything
/// is dropped, a marker message recording the number of truncated messages
/// is inserted between the head and the kept tail.
///
/// NOTE(review): the head messages and the marker text are not counted
/// against the budget, so the result can exceed `budget_chars` when the head
/// alone is oversized.
pub fn truncate_messages(messages: &mut Vec<Message>, budget_chars: usize) {
    if messages.is_empty() {
        return;
    }
    let text_len = |m: &Message| m.text_content().map(|s| s.len()).unwrap_or(0);
    let total_chars: usize = messages.iter().map(text_len).sum();
    if total_chars <= budget_chars {
        return;
    }
    // Head: system prompt plus the first conversational message, if present.
    // (The previous emptiness re-check here was redundant — we returned above.)
    let head_len = messages.len().min(2);
    let mut keep: Vec<Message> = messages[..head_len].to_vec();
    let mut chars: usize = keep.iter().map(text_len).sum();
    // Tail: walk backwards from the newest message, keeping whatever fits.
    let mut tail = Vec::new();
    for m in messages.iter().skip(2).rev() {
        let mc = text_len(m);
        if chars + mc > budget_chars {
            break;
        }
        tail.push(m.clone());
        chars += mc;
    }
    tail.reverse();
    let truncated_count = messages.len() - (keep.len() + tail.len());
    if truncated_count > 0 {
        warn!("Truncating {} messages for context window", truncated_count);
        keep.push(Message::user(format!(
            "[... {} messages truncated for context window ...]",
            truncated_count
        )));
    }
    keep.extend(tail);
    *messages = keep;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::Message;

    /// Builds a conversation: one system prompt followed by `n` alternating
    /// user/assistant messages, each padded to ~50k characters.
    fn make_messages(n: usize) -> Vec<Message> {
        let padding = "x".repeat(50_000);
        let mut conversation = vec![Message::system("You are a helpful assistant.")];
        for i in 0..n {
            let msg = if i % 2 == 0 {
                Message::user(format!("user message {}: {}", i, padding))
            } else {
                Message::assistant(Some(format!("assistant reply {}: {}", i, padding)), None)
            };
            conversation.push(msg);
        }
        conversation
    }

    #[test]
    fn test_truncate_noop_when_under_budget() {
        let mut conversation = vec![
            Message::system("system"),
            Message::user("hello"),
            Message::assistant(Some("hi".to_string()), None),
        ];
        let len_before = conversation.len();
        truncate_messages(&mut conversation, 400_000);
        assert_eq!(
            conversation.len(),
            len_before,
            "should not truncate when under budget"
        );
    }

    #[test]
    fn test_truncate_keeps_system_and_recent() {
        let mut conversation = make_messages(9);
        let system_content = conversation[0].text_content().unwrap();
        truncate_messages(&mut conversation, 400_000);
        assert_eq!(
            conversation[0].text_content().unwrap(),
            system_content,
            "system prompt must be preserved"
        );
        let total: usize = conversation
            .iter()
            .filter_map(|m| m.text_content().map(|s| s.len()))
            .sum();
        assert!(
            total <= 400_000,
            "total chars after truncation should be ≤ budget, got {total}"
        );
    }

    #[test]
    fn test_truncate_inserts_marker() {
        let mut conversation = make_messages(9);
        truncate_messages(&mut conversation, 400_000);
        let marker_present = conversation.iter().any(|m| {
            matches!(m.text_content(), Some(t) if t.contains("messages truncated"))
        });
        assert!(marker_present, "truncated marker message should be present");
    }

    #[test]
    fn test_context_config_defaults() {
        let cfg = ContextConfig::default();
        assert_eq!(cfg.strategy, ContextStrategy::Summarize);
        assert_eq!(cfg.budget_chars, 400_000);
        assert_eq!(cfg.summary_max_tokens, 800);
        assert!((cfg.threshold - 0.80).abs() < 1e-9);
    }

    #[tokio::test]
    async fn test_context_manager_truncate_noop_under_threshold() {
        // Provider that fails the test if the manager ever calls the LLM.
        struct SilentProvider;
        #[async_trait::async_trait]
        impl crate::llm::LlmProvider for SilentProvider {
            async fn chat_completion(
                &self,
                _messages: &[crate::llm::Message],
                _tools: &[crate::llm::ToolDefinition],
            ) -> Result<crate::llm::LlmResponse> {
                panic!("should not be called");
            }
        }
        let config = ContextConfig {
            strategy: ContextStrategy::Truncate,
            budget_chars: 400_000,
            threshold: 0.80,
            summary_max_tokens: 800,
        };
        let mut manager = ContextManager::new(config);
        let mut conversation = vec![Message::system("system"), Message::user("hello")];
        let len_before = conversation.len();
        manager
            .manage(&mut conversation, &SilentProvider)
            .await
            .unwrap();
        assert_eq!(conversation.len(), len_before, "no-op when under threshold");
    }

    #[tokio::test]
    async fn test_context_manager_truncate_strategy() {
        // The truncate strategy must never reach the LLM.
        struct SilentProvider;
        #[async_trait::async_trait]
        impl crate::llm::LlmProvider for SilentProvider {
            async fn chat_completion(
                &self,
                _: &[crate::llm::Message],
                _: &[crate::llm::ToolDefinition],
            ) -> Result<crate::llm::LlmResponse> {
                panic!("truncate should not call LLM");
            }
        }
        let config = ContextConfig {
            strategy: ContextStrategy::Truncate,
            budget_chars: 100_000,
            threshold: 0.80,
            summary_max_tokens: 800,
        };
        let mut manager = ContextManager::new(config);
        let mut conversation = vec![
            Message::system("system"),
            Message::user("x".repeat(50_000)),
            Message::assistant(Some("y".repeat(50_000)), None),
            Message::user("z".repeat(50_000)),
        ];
        manager
            .manage(&mut conversation, &SilentProvider)
            .await
            .unwrap();
        let total: usize = conversation
            .iter()
            .filter_map(|m| m.text_content().map(|s| s.len()))
            .sum();
        assert!(
            total <= 100_000,
            "should fit within budget after truncation, got {total}"
        );
    }
}