use std::collections::HashSet;
use std::sync::Arc;

use async_trait::async_trait;

use crate::traits::context_manager::ContextManager;
use crate::traits::provider::Provider;
use crate::types::agent_state::AgentState;
use crate::types::completion::{CompletionRequest, ResponseContent};
use crate::types::message::{Message, MessageRole};
/// Rough token estimate for a message list: ~4 characters per token,
/// with a minimum of one token charged per message.
fn estimate_tokens(messages: &[Message]) -> usize {
    messages
        .iter()
        .map(|m| 1 + m.content.len() / 4)
        .sum()
}
/// Context manager that prunes low-priority messages using a fixed
/// scoring heuristic — no LLM calls involved.
pub struct RuleBasedCompressor {
    // Fraction of the context window (clamped to 0.0..=1.0) above which
    // pruning kicks in.
    threshold: f64,
    // How many of the most recent non-system messages are shielded from
    // removal (they score higher than older ones).
    recent_count: usize,
}
impl RuleBasedCompressor {
    /// Creates a compressor that triggers once estimated usage exceeds
    /// `threshold` (clamped to `[0.0, 1.0]`) of the context window,
    /// shielding the last `recent_count` non-system messages.
    #[must_use]
    pub fn new(threshold: f64, recent_count: usize) -> Self {
        let threshold = threshold.clamp(0.0, 1.0);
        Self {
            threshold,
            recent_count,
        }
    }

    /// Assigns a removal-priority score to a message; lower scores are
    /// pruned first. System messages score infinity so they can never
    /// be selected for removal.
    fn score_message(msg: &Message, is_recent: bool) -> f64 {
        if msg.role == MessageRole::System {
            f64::INFINITY
        } else if is_recent {
            0.9
        } else if msg.tool_call_id.is_some() || msg.role == MessageRole::Tool {
            0.7
        } else {
            0.3
        }
    }
}
impl Default for RuleBasedCompressor {
    /// Defaults: prune at 85% of the context window and protect the 3
    /// most recent non-system messages.
    fn default() -> Self {
        Self::new(0.85, 3)
    }
}
#[async_trait]
impl ContextManager for RuleBasedCompressor {
#[allow(
clippy::cast_possible_truncation,
clippy::cast_sign_loss,
clippy::cast_precision_loss
)]
async fn prepare(
&self,
messages: &mut Vec<Message>,
context_window: usize,
state: &mut AgentState,
) {
let max_tokens = (context_window as f64 * self.threshold) as usize;
if estimate_tokens(messages) <= max_tokens {
return;
}
let total_non_system = messages
.iter()
.filter(|m| m.role != MessageRole::System)
.count();
let recent_start = total_non_system.saturating_sub(self.recent_count);
let mut scored: Vec<(usize, f64, usize)> = Vec::new(); let mut non_system_idx = 0usize;
for (i, msg) in messages.iter().enumerate() {
if msg.role == MessageRole::System {
continue;
}
let is_recent = non_system_idx >= recent_start;
let tokens = msg.content.len() / 4 + 1;
scored.push((i, Self::score_message(msg, is_recent), tokens));
non_system_idx += 1;
}
scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let mut projected_tokens = estimate_tokens(messages);
let mut remove_set: Vec<usize> = Vec::new();
for &(idx, score, tokens) in &scored {
if projected_tokens <= max_tokens {
break;
}
if score.is_infinite() {
continue; }
remove_set.push(idx);
projected_tokens = projected_tokens.saturating_sub(tokens);
}
if !remove_set.is_empty() {
remove_set.sort_unstable();
for &idx in remove_set.iter().rev() {
messages.remove(idx);
}
state.last_output_truncated = true;
}
}
}
/// Context manager that asks an LLM provider to summarize old history
/// into a single message when the context grows past a threshold.
pub struct LlmCompressor {
    // Provider used to generate the summary completion.
    provider: Arc<dyn Provider>,
    // System prompt sent with the summarization request.
    summary_prompt: String,
    // Fraction of the context window (0.0..=1.0) that triggers compression.
    threshold: f64,
    // Number of most-recent non-system messages left uncompressed.
    keep_recent: usize,
}
impl LlmCompressor {
    /// Default system prompt used for the summarization request.
    // Explicit `'static` keeps this compiling on toolchains that predate
    // lifetime elision in associated consts (stabilized in Rust 1.79).
    const DEFAULT_PROMPT: &'static str = "Summarize the following conversation messages \
        into a concise paragraph. Preserve key facts, decisions, and context. \
        Omit greetings and filler.";

    /// Creates a compressor that summarizes via `provider`, keeping the
    /// 4 most recent non-system messages and triggering once usage
    /// exceeds 80% of the context window.
    #[must_use]
    pub fn new(provider: Arc<dyn Provider>) -> Self {
        Self {
            provider,
            summary_prompt: Self::DEFAULT_PROMPT.to_string(),
            threshold: 0.80,
            keep_recent: 4,
        }
    }

    /// Replaces the summarization system prompt.
    #[must_use]
    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.summary_prompt = prompt.into();
        self
    }

    /// Sets the compression trigger threshold, clamped to `[0.0, 1.0]`.
    #[must_use]
    pub fn with_threshold(mut self, threshold: f64) -> Self {
        self.threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Sets how many recent non-system messages survive compression.
    #[must_use]
    pub fn with_keep_recent(mut self, count: usize) -> Self {
        self.keep_recent = count;
        self
    }
}
#[async_trait]
impl ContextManager for LlmCompressor {
    /// Compresses history once estimated usage exceeds
    /// `threshold * context_window`: everything except the system
    /// messages and the `keep_recent` tail is replaced by a single
    /// LLM-generated summary message. On provider failure a counting
    /// fallback summary is used instead; the conversation is still
    /// compacted. Sets `state.last_output_truncated` when compression
    /// happens.
    #[allow(
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss,
        clippy::cast_precision_loss
    )]
    async fn prepare(
        &self,
        messages: &mut Vec<Message>,
        context_window: usize,
        state: &mut AgentState,
    ) {
        let max_tokens = (context_window as f64 * self.threshold) as usize;
        if estimate_tokens(messages) <= max_tokens {
            return;
        }
        // Nothing to fold into a summary: leave the history untouched.
        let non_system_count = messages
            .iter()
            .filter(|m| m.role != MessageRole::System)
            .count();
        if non_system_count <= self.keep_recent {
            return;
        }
        // Take ownership once and partition by role instead of cloning
        // every message twice; relative order is preserved in both halves.
        let (system_msgs, mut non_system): (Vec<Message>, Vec<Message>) =
            std::mem::take(messages)
                .into_iter()
                .partition(|m| m.role == MessageRole::System);
        // `non_system_count > keep_recent` was checked above, so this
        // subtraction cannot underflow.
        let split_at = non_system.len() - self.keep_recent;
        let recent_messages = non_system.split_off(split_at);
        let old_messages = non_system;
        // Flatten the old history into a role-prefixed transcript for
        // the summarization prompt.
        let old_text: String = old_messages
            .iter()
            .map(|m| format!("{:?}: {}", m.role, m.content))
            .collect::<Vec<_>>()
            .join("\n");
        let req = CompletionRequest {
            model: self.provider.model_info().name.clone(),
            messages: vec![
                Message {
                    role: MessageRole::System,
                    content: self.summary_prompt.clone(),
                    tool_call_id: None,
                },
                Message {
                    role: MessageRole::User,
                    content: old_text,
                    tool_call_id: None,
                },
            ],
            tools: vec![],
            max_tokens: Some(500),
            temperature: Some(0.3),
            response_format: None,
            stream: false,
        };
        let summary_text = match self.provider.complete(req).await {
            Ok(response) => match response.content {
                ResponseContent::Text(text) => text,
                // A summarization request should never yield tool calls;
                // degrade gracefully rather than drop history silently.
                ResponseContent::ToolCalls(_) => {
                    tracing::warn!("LlmCompressor: provider returned tool calls instead of text");
                    Self::fallback_summary(&old_messages)
                }
            },
            Err(e) => {
                tracing::warn!("LlmCompressor: summarization failed ({e}), using fallback");
                Self::fallback_summary(&old_messages)
            }
        };
        let summary_msg = Message {
            role: MessageRole::Assistant,
            content: format!("[Context Summary] {summary_text}"),
            tool_call_id: None,
        };
        // Rebuild: system messages, then the summary, then the recent tail.
        messages.extend(system_msgs);
        messages.push(summary_msg);
        messages.extend(recent_messages);
        state.last_output_truncated = true;
    }
}
impl LlmCompressor {
    /// Summary text used when the provider cannot produce one: simply
    /// states how many messages were dropped.
    fn fallback_summary(old_messages: &[Message]) -> String {
        let count = old_messages.len();
        format!("{count} earlier messages were removed to save context space.")
    }
}
/// Two-tier context manager: an optional LLM summarization pass
/// followed by rule-based pruning as a safety net.
pub struct TieredCompressor {
    // Shared recent-message protection count for both tiers.
    recent_count: usize,
    // Always-on second tier.
    rule_compressor: RuleBasedCompressor,
    // Optional first tier, enabled via `with_llm`.
    llm_compressor: Option<LlmCompressor>,
}
impl TieredCompressor {
    /// Builds a rule-only tiered compressor that protects the last
    /// `recent_count` non-system messages.
    #[must_use]
    pub fn new(recent_count: usize) -> Self {
        let rule_compressor = RuleBasedCompressor::new(0.85, recent_count);
        Self {
            recent_count,
            rule_compressor,
            llm_compressor: None,
        }
    }

    /// Enables the LLM summarization tier, configured to keep the same
    /// number of recent messages as the rule tier.
    #[must_use]
    pub fn with_llm(mut self, provider: Arc<dyn Provider>) -> Self {
        let llm = LlmCompressor::new(provider).with_keep_recent(self.recent_count);
        self.llm_compressor = Some(llm);
        self
    }
}
#[async_trait]
impl ContextManager for TieredCompressor {
    /// Runs the summarization tier first (when configured), then the
    /// rule-based tier, so rule pruning acts on the already-summarized
    /// history.
    async fn prepare(
        &self,
        messages: &mut Vec<Message>,
        context_window: usize,
        state: &mut AgentState,
    ) {
        // Tier 1: semantic compression via the LLM, if one was attached.
        match self.llm_compressor.as_ref() {
            Some(llm) => llm.prepare(messages, context_window, state).await,
            None => {}
        }
        // Tier 2: rule-based pruning always runs as a backstop.
        self.rule_compressor
            .prepare(messages, context_window, state)
            .await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::CompletionStream;

    /// Builds a plain message with no tool-call id.
    fn msg(role: MessageRole, content: &str) -> Message {
        Message {
            role,
            content: content.to_string(),
            tool_call_id: None,
        }
    }

    /// Builds a tool-result message carrying a dummy tool-call id.
    fn tool_msg(content: &str) -> Message {
        Message {
            role: MessageRole::Tool,
            content: content.to_string(),
            tool_call_id: Some("call_1".to_string()),
        }
    }

    /// Fresh agent state with a 128k-token context window.
    fn default_state() -> AgentState {
        AgentState::new(ModelTier::Medium, 128_000)
    }

    // An under-budget conversation must pass through untouched and the
    // truncation flag must remain unset.
    #[tokio::test]
    async fn test_rule_compressor_no_pruning_under_threshold() {
        let comp = RuleBasedCompressor::default();
        let mut msgs = vec![
            msg(MessageRole::System, "system"),
            msg(MessageRole::User, "hello"),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 100_000, &mut state).await;
        assert_eq!(msgs.len(), 2);
        assert!(!state.last_output_truncated);
    }

    // Over budget: the lowest-scored (older, non-recent) messages are
    // removed while the system message survives.
    #[tokio::test]
    async fn test_rule_compressor_removes_lowest_scored() {
        let comp = RuleBasedCompressor::new(0.85, 1);
        let mut msgs = vec![
            msg(MessageRole::System, "system"),
            msg(MessageRole::User, &"old1 ".repeat(500)),
            msg(MessageRole::Assistant, &"old2 ".repeat(500)),
            tool_msg(&"tool ".repeat(500)),
            msg(MessageRole::User, &"recent ".repeat(500)),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 800, &mut state).await;
        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(msgs.len() < 5, "should have removed some messages");
        assert!(state.last_output_truncated);
    }

    // System messages score infinity and are never removed, even when
    // they alone exceed the budget.
    #[tokio::test]
    async fn test_rule_compressor_never_removes_system() {
        let comp = RuleBasedCompressor::new(0.5, 0);
        let mut msgs = vec![
            msg(MessageRole::System, &"sys ".repeat(1000)),
            msg(MessageRole::User, "tiny"),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 100, &mut state).await;
        assert!(msgs.iter().any(|m| m.role == MessageRole::System));
    }

    // Any pruning must raise `last_output_truncated` on the agent state.
    #[tokio::test]
    async fn test_rule_compressor_updates_state() {
        let comp = RuleBasedCompressor::new(0.5, 0);
        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"x".repeat(4000)),
            msg(MessageRole::Assistant, &"y".repeat(4000)),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 1000, &mut state).await;
        assert!(state.last_output_truncated);
    }

    // Provider stub that always returns a fixed text summary.
    struct MockSummarizer {
        info: ModelInfo,
    }

    impl MockSummarizer {
        fn new() -> Self {
            Self {
                info: ModelInfo::new(
                    "mock-summarizer",
                    ModelTier::Small,
                    4096,
                    false,
                    false,
                    false,
                ),
            }
        }
    }

    #[async_trait]
    impl Provider for MockSummarizer {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            Ok(CompletionResponse {
                content: ResponseContent::Text(
                    "User asked about Rust. Assistant explained traits.".to_string(),
                ),
                usage: Usage {
                    prompt_tokens: 50,
                    completion_tokens: 20,
                    total_tokens: 70,
                },
            })
        }
        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            // Streaming is never exercised by these tests.
            unimplemented!()
        }
        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    // Provider stub whose completions always fail, to exercise the
    // fallback-summary path.
    struct FailingSummarizer {
        info: ModelInfo,
    }

    impl FailingSummarizer {
        fn new() -> Self {
            Self {
                info: ModelInfo::new("failing", ModelTier::Small, 4096, false, false, false),
            }
        }
    }

    #[async_trait]
    impl Provider for FailingSummarizer {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            Err(crate::Error::Provider {
                message: "network error".into(),
                status_code: None,
            })
        }
        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            // Streaming is never exercised by these tests.
            unimplemented!()
        }
        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    // Old messages collapse into one summary; the system message and the
    // `keep_recent` tail survive.
    #[tokio::test]
    async fn test_llm_compressor_summarizes_old_messages() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(2);
        let mut msgs = vec![
            msg(MessageRole::System, "You are helpful"),
            msg(MessageRole::User, &"old question ".repeat(500)),
            msg(MessageRole::Assistant, &"old answer ".repeat(500)),
            msg(MessageRole::User, "recent question"),
            msg(MessageRole::Assistant, "recent answer"),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 800, &mut state).await;
        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(
            msgs[1].content.contains("[Context Summary]"),
            "should have summary: {}",
            msgs[1].content
        );
        // 1 system + 1 summary + 2 recent.
        assert_eq!(msgs.len(), 4);
        assert!(state.last_output_truncated);
    }

    // Under budget: the provider is not consulted and nothing changes.
    #[tokio::test]
    async fn test_llm_compressor_no_compression_under_threshold() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(2);
        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, "hi"),
            msg(MessageRole::Assistant, "hello"),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 100_000, &mut state).await;
        assert_eq!(msgs.len(), 3);
        assert!(!state.last_output_truncated);
    }

    // A custom prompt still yields a summary message in position 1.
    #[tokio::test]
    async fn test_llm_compressor_custom_prompt() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = LlmCompressor::new(provider)
            .with_prompt("Custom prompt")
            .with_keep_recent(1);
        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(2000)),
            msg(MessageRole::User, "recent"),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 500, &mut state).await;
        assert!(msgs[1].content.contains("[Context Summary]"));
    }

    // Provider failure: compression still happens, using the counting
    // fallback summary text.
    #[tokio::test]
    async fn test_llm_compressor_fallback_on_failure() {
        let provider: Arc<dyn Provider> = Arc::new(FailingSummarizer::new());
        let comp = LlmCompressor::new(provider).with_keep_recent(1);
        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(2000)),
            msg(MessageRole::Assistant, &"old ".repeat(2000)),
            msg(MessageRole::User, "recent"),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 500, &mut state).await;
        assert!(msgs[1].content.contains("[Context Summary]"));
        assert!(msgs[1].content.contains("removed to save context"));
        assert!(state.last_output_truncated);
    }

    // Without an LLM tier, the tiered compressor behaves like the rule
    // compressor alone.
    #[tokio::test]
    async fn test_tiered_compressor_rule_only() {
        let comp = TieredCompressor::new(2);
        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old1 ".repeat(500)),
            msg(MessageRole::Assistant, &"old2 ".repeat(500)),
            msg(MessageRole::User, &"recent1 ".repeat(500)),
            msg(MessageRole::Assistant, &"recent2 ".repeat(500)),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 1500, &mut state).await;
        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(msgs.len() < 5);
        assert!(state.last_output_truncated);
    }

    // With an LLM tier attached, the summarization pass runs and leaves
    // a summary message in the history.
    #[tokio::test]
    async fn test_tiered_compressor_with_llm() {
        let provider: Arc<dyn Provider> = Arc::new(MockSummarizer::new());
        let comp = TieredCompressor::new(2).with_llm(provider);
        let mut msgs = vec![
            msg(MessageRole::System, "sys"),
            msg(MessageRole::User, &"old ".repeat(1000)),
            msg(MessageRole::Assistant, &"old ".repeat(1000)),
            msg(MessageRole::User, "recent1"),
            msg(MessageRole::Assistant, "recent2"),
        ];
        let mut state = default_state();
        comp.prepare(&mut msgs, 800, &mut state).await;
        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(
            msgs.iter().any(|m| m.content.contains("[Context Summary]")),
            "should have LLM summary"
        );
        assert!(state.last_output_truncated);
    }

    // A long alternating conversation must be pruned down to within the
    // token budget while the system message is preserved.
    #[tokio::test]
    async fn test_rule_compressor_50_messages_within_budget() {
        let comp = RuleBasedCompressor::new(0.85, 5);
        let mut msgs = vec![msg(MessageRole::System, "You are a helpful assistant")];
        for i in 0..50 {
            msgs.push(msg(
                if i % 2 == 0 {
                    MessageRole::User
                } else {
                    MessageRole::Assistant
                },
                &format!("Message number {i}: {}", "content ".repeat(100)),
            ));
        }
        let mut state = default_state();
        let window = 2000;
        comp.prepare(&mut msgs, window, &mut state).await;
        // Recompute with the same heuristic the compressor uses.
        let tokens: usize = msgs.iter().map(|m| m.content.len() / 4 + 1).sum();
        let max = (window as f64 * 0.85) as usize;
        assert!(tokens <= max, "should be within budget: {tokens} <= {max}");
        assert_eq!(msgs[0].role, MessageRole::System);
        assert!(state.last_output_truncated);
    }
}