pub mod compressor;
pub use echo_core::compression::{CompressionInput, CompressionOutput, ContextCompressor};
use crate::compression::compressor::SlidingWindowCompressor;
use echo_core::error::Result;
use echo_core::llm::types::{Message, MessageContent};
use echo_core::tokenizer::{HeuristicTokenizer, Tokenizer};
use std::sync::Arc;
/// A protected message pulled out of the history before compression, plus the
/// bookkeeping needed to put it back at the same distance from the end.
struct ProtectedMessage {
/// The protected message itself.
message: Message,
/// Number of compressible (non-protected) messages that followed this
/// message in the original history.
compressible_after: usize,
/// Number of protected messages that followed this message in the original
/// history.
protected_after: usize,
}
/// Before/after figures describing a single compression pass.
pub struct ForceCompressStats {
/// Number of messages held before compression (protected ones included).
pub before_count: usize,
/// Number of messages held after compression and re-insertion of the
/// protected messages.
pub after_count: usize,
/// Length of the compressor's `evicted` output.
pub evicted: usize,
/// Token estimate of the history before compression.
pub before_tokens: usize,
/// Token estimate of the history after compression.
pub after_tokens: usize,
}
/// Outcome of [`ContextManager::prepare`].
pub struct PrepareResult {
/// A copy of the manager's history as it stands after `prepare` finished.
pub messages: Vec<Message>,
/// Statistics of the compression pass, or `None` when no compression ran.
pub compressed: Option<ForceCompressStats>,
}
/// Owns a conversation history and keeps it within a token budget and a
/// message-count cap.
pub struct ContextManager {
/// The conversation history, oldest message first.
messages: Vec<Message>,
/// Optional compressor used by `prepare` and `force_compress`.
compressor: Option<Box<dyn ContextCompressor>>,
/// Token budget: `prepare` compresses once the estimate exceeds it, and it
/// is forwarded to the compressor as `token_limit`.
token_limit: usize,
/// Tokenizer used for token estimates.
tokenizer: Arc<dyn Tokenizer>,
/// Substrings that mark a message as protected: a message whose text
/// contains any of them is kept out of compression and hard-cap eviction.
protected_markers: Vec<String>,
/// Message-count cap checked on every `push`.
max_messages: usize,
}
impl ContextManager {
/// Starts a [`ContextManagerBuilder`] with the given token budget.
///
/// Every other setting starts unset: no compressor, no tokenizer, no message
/// cap, and an empty initial history.
pub fn builder(token_limit: usize) -> ContextManagerBuilder {
    let initial_messages = Vec::new();
    ContextManagerBuilder {
        initial_messages,
        token_limit,
        max_messages: None,
        tokenizer: None,
        compressor: None,
    }
}
/// Appends `message` to the history and enforces the message-count cap.
pub fn push(&mut self, message: Message) {
    self.messages.push(message);
    let over_cap = self.messages.len() > self.max_messages;
    if over_cap {
        self.apply_hard_cap();
    }
}
/// Evicts the oldest evictable messages until the history is back at
/// `max_messages`.
///
/// The leading run of system messages and every message matching a protected
/// marker are never evicted, so the history can stay above the cap when there
/// are not enough eviction candidates. A history that consists solely of
/// system messages is left untouched.
fn apply_hard_cap(&mut self) {
    let cap = self.max_messages;
    let total = self.messages.len();
    if total <= cap {
        return;
    }
    // Everything before the first non-system message is the system preamble.
    // When there is no non-system message at all, nothing may be evicted.
    let Some(first_non_system) = self.messages.iter().position(|m| m.role != "system") else {
        return;
    };
    let excess = total - cap;
    // Decide in one forward pass which messages survive; the oldest
    // candidates are evicted first.
    let mut evicted = 0usize;
    let mut keep = Vec::with_capacity(total);
    for (i, msg) in self.messages.iter().enumerate() {
        let evict = i >= first_non_system && evicted < excess && !self.is_protected(msg);
        if evict {
            evicted += 1;
        }
        keep.push(!evict);
    }
    if evicted == 0 {
        return;
    }
    tracing::warn!(
        total = total,
        cap = cap,
        evicted = evicted,
        "Message count exceeded hard cap, applying sliding window degradation (preserving protected messages)"
    );
    // `retain` visits each element exactly once, in order, so the flags line
    // up with the messages.
    let mut keep = keep.into_iter();
    self.messages.retain(|_| keep.next().unwrap_or(true));
}
/// Appends every message from `messages` to the history, then enforces the
/// message-count cap once for the whole batch, mirroring what `push` does for
/// a single message.
pub fn push_many(&mut self, messages: impl IntoIterator<Item = Message>) {
    self.messages.extend(messages);
    // `apply_hard_cap` returns immediately while the history is at or under
    // the cap, so no separate length check is needed here.
    self.apply_hard_cap();
}
/// Borrows the current history, oldest message first.
pub fn messages(&self) -> &[Message] {
    self.messages.as_slice()
}
/// Replaces the whole history with `messages`.
///
/// The message-count cap is not checked here.
pub fn set_messages(&mut self, messages: Vec<Message>) {
self.messages = messages;
}
pub fn token_estimate(&self) -> usize {
Self::estimate_tokens(&self.messages, &*self.tokenizer)
}
/// Borrows the tokenizer used for token estimates.
pub fn tokenizer(&self) -> &dyn Tokenizer {
    self.tokenizer.as_ref()
}
/// Replaces the tokenizer used for token estimates.
pub fn set_tokenizer(&mut self, tokenizer: Arc<dyn Tokenizer>) {
self.tokenizer = tokenizer;
}
/// Removes every message from the history, system messages included.
pub fn clear(&mut self) {
self.messages.clear();
}
/// Registers `marker` as a protected-message marker; duplicates are ignored.
///
/// Markers are matched as substrings, so an empty marker matches every
/// message that has text content.
pub fn add_protected_marker(&mut self, marker: String) {
    let already_registered = self.protected_markers.iter().any(|known| *known == marker);
    if already_registered {
        return;
    }
    self.protected_markers.push(marker);
}
/// Returns `true` when `message` has text content containing at least one
/// registered protected marker.
///
/// Messages without text content are never protected, and nothing is
/// protected while no marker is registered.
fn is_protected(&self, message: &Message) -> bool {
    // Checked first so the text is not even extracted when no marker exists.
    if self.protected_markers.is_empty() {
        return false;
    }
    message
        .content
        .as_text()
        .is_some_and(|text| self.protected_markers.iter().any(|marker| text.contains(marker)))
}
/// Splits `messages` into the compressible messages and the protected ones.
///
/// Both lists keep their original relative order. Each protected message
/// records how many compressible and how many protected messages followed it,
/// which is what `merge_protected` uses to re-insert it.
fn split_protected(&self, messages: Vec<Message>) -> (Vec<Message>, Vec<ProtectedMessage>) {
    let mut plain: Vec<Message> = Vec::new();
    // Each protected message is paired with the number of compressible
    // messages seen before it, i.e. the length of `plain` at that moment.
    let mut pinned: Vec<(usize, Message)> = Vec::new();
    for candidate in messages {
        if self.is_protected(&candidate) {
            pinned.push((plain.len(), candidate));
        } else {
            plain.push(candidate);
        }
    }

    let plain_total = plain.len();
    let pinned_total = pinned.len();
    let mut anchored = Vec::with_capacity(pinned_total);
    for (rank, (plain_before, message)) in pinned.into_iter().enumerate() {
        anchored.push(ProtectedMessage {
            message,
            compressible_after: plain_total.saturating_sub(plain_before),
            protected_after: pinned_total.saturating_sub(rank + 1),
        });
    }
    (plain, anchored)
}
/// Re-inserts `protected` messages into the `compressed` history.
///
/// Each protected message is placed at the same distance from the end of the
/// history that it had before compression. When the compressor dropped so
/// many trailing messages that this distance no longer fits, the message is
/// placed as early as possible, but never in front of the leading run of
/// system messages of `compressed`, so the system prompt stays first.
/// Protected messages keep their relative order.
fn merge_protected(compressed: Vec<Message>, protected: Vec<ProtectedMessage>) -> Vec<Message> {
    if protected.is_empty() {
        return compressed;
    }
    let mut result = compressed;
    // Lowest index a protected message may take: just after the system preamble.
    let system_prefix = result.iter().take_while(|m| m.role == "system").count();
    result.reserve(protected.len());
    // Insert from last to first so every offset is measured against a tail
    // that already contains the later protected messages.
    for protected_msg in protected.into_iter().rev() {
        let trailing_slots = protected_msg.compressible_after + protected_msg.protected_after;
        let insert_at = result
            .len()
            .saturating_sub(trailing_slots)
            .max(system_prefix);
        result.insert(insert_at, protected_msg.message);
    }
    result
}
/// Installs `compressor`, replacing any previously configured one.
pub fn set_compressor(&mut self, compressor: impl ContextCompressor + 'static) {
    let boxed: Box<dyn ContextCompressor> = Box::new(compressor);
    self.compressor = Some(boxed);
}
/// Drops the configured compressor, if any.
pub fn remove_compressor(&mut self) {
    drop(self.compressor.take());
}
/// Returns `true` when a compressor is currently configured.
pub fn has_compressor(&self) -> bool {
self.compressor.is_some()
}
/// Compresses the history right away, regardless of the current token count.
///
/// Uses the configured compressor when there is one; otherwise falls back to
/// a `SlidingWindowCompressor` built with `fallback_window`. Protected
/// messages are held back from the compressor and re-inserted afterwards.
///
/// # Errors
///
/// Propagates the compressor's error. The history is only replaced after the
/// compressor succeeded, so it is unchanged on failure.
pub async fn force_compress(&mut self, fallback_window: usize) -> Result<ForceCompressStats> {
    let before_tokens = self.token_estimate();
    let before_count = self.messages.len();

    let (compressible, protected) = self.split_protected(self.messages.clone());
    let input = CompressionInput {
        messages: compressible,
        token_limit: self.token_limit,
        current_query: None,
    };
    let output = match &self.compressor {
        Some(configured) => configured.compress(input).await?,
        None => {
            SlidingWindowCompressor::new(fallback_window)
                .compress(input)
                .await?
        }
    };

    let evicted = output.evicted.len();
    self.messages = Self::merge_protected(output.messages, protected);
    let after_count = self.messages.len();
    let after_tokens = self.token_estimate();
    Ok(ForceCompressStats {
        before_count,
        after_count,
        evicted,
        before_tokens,
        after_tokens,
    })
}
/// Compresses the history right away with the caller-supplied `compressor`,
/// ignoring whichever compressor is configured on the manager.
///
/// Protected messages are held back from the compressor and re-inserted
/// afterwards.
///
/// # Errors
///
/// Propagates the compressor's error. The history is only replaced after the
/// compressor succeeded, so it is unchanged on failure.
pub async fn force_compress_with(
    &mut self,
    compressor: &dyn ContextCompressor,
) -> Result<ForceCompressStats> {
    let before_tokens = self.token_estimate();
    let before_count = self.messages.len();

    let (compressible, protected) = self.split_protected(self.messages.clone());
    let input = CompressionInput {
        messages: compressible,
        token_limit: self.token_limit,
        current_query: None,
    };
    let output = compressor.compress(input).await?;

    let evicted = output.evicted.len();
    self.messages = Self::merge_protected(output.messages, protected);
    let after_count = self.messages.len();
    let after_tokens = self.token_estimate();
    Ok(ForceCompressStats {
        before_count,
        after_count,
        evicted,
        before_tokens,
        after_tokens,
    })
}
/// Sets the system prompt.
///
/// Overwrites the content of the first message whose role is `"system"`; when
/// the history has no system message, a new one is inserted at the front.
pub fn update_system(&mut self, new_system_prompt: String) {
    match self.messages.iter_mut().find(|m| m.role == "system") {
        Some(existing) => existing.content = MessageContent::Text(new_system_prompt),
        None => self.messages.insert(0, Message::system(new_system_prompt)),
    }
}
/// Returns the history to send, compressing it first when it is over budget.
///
/// Compression runs only when a compressor is configured and the token
/// estimate exceeds `token_limit`. Protected messages are held back from the
/// compressor and re-inserted afterwards. `current_query` is forwarded to the
/// compressor.
///
/// The returned [`PrepareResult`] carries a copy of the (possibly compressed)
/// history, and the statistics of the compression pass when one ran.
///
/// # Errors
///
/// Propagates the compressor's error. The history is only replaced after the
/// compressor succeeded, so it is unchanged on failure.
pub async fn prepare(&mut self, current_query: Option<&str>) -> Result<PrepareResult> {
    let mut compressed = None;
    if let Some(compressor) = &self.compressor {
        // Tokenize the history once; the same figure serves as the threshold
        // check and as the "before" statistic.
        let before_tokens = self.token_estimate();
        if before_tokens > self.token_limit {
            let before_count = self.messages.len();
            let (compressible, protected) = self.split_protected(self.messages.clone());
            let output = compressor
                .compress(CompressionInput {
                    messages: compressible,
                    token_limit: self.token_limit,
                    current_query: current_query.map(String::from),
                })
                .await?;
            let evicted = output.evicted.len();
            self.messages = Self::merge_protected(output.messages, protected);
            compressed = Some(ForceCompressStats {
                before_count,
                after_count: self.messages.len(),
                evicted,
                before_tokens,
                after_tokens: self.token_estimate(),
            });
        }
    }
    Ok(PrepareResult {
        messages: self.messages.clone(),
        compressed,
    })
}
/// Sums the tokenizer's count over the text content of `messages`.
///
/// Messages without text content contribute nothing to the total.
fn estimate_tokens(messages: &[Message], tokenizer: &dyn Tokenizer) -> usize {
    let mut total = 0usize;
    for message in messages {
        if let Some(text) = message.content.as_text() {
            total += tokenizer.count_tokens(&text);
        }
    }
    total
}
}
/// Builder for [`ContextManager`]; obtained from `ContextManager::builder`.
pub struct ContextManagerBuilder {
/// Token budget handed to the manager.
token_limit: usize,
/// Compressor to install, if any.
compressor: Option<Box<dyn ContextCompressor>>,
/// Messages the manager starts with.
initial_messages: Vec<Message>,
/// Tokenizer override; `build` falls back to `HeuristicTokenizer` when unset.
tokenizer: Option<Arc<dyn Tokenizer>>,
/// Message-count cap override; `build` falls back to 200 when unset.
max_messages: Option<usize>,
}
impl ContextManagerBuilder {
    /// Sets the compressor the manager will use.
    pub fn compressor(self, c: impl ContextCompressor + 'static) -> Self {
        Self {
            compressor: Some(Box::new(c)),
            ..self
        }
    }

    /// Appends a system message with `system_prompt` to the initial history.
    pub fn with_system(mut self, system_prompt: String) -> Self {
        let system_message = Message::system(system_prompt);
        self.initial_messages.push(system_message);
        self
    }

    /// Sets the tokenizer used for token estimates.
    pub fn tokenizer(self, tokenizer: Arc<dyn Tokenizer>) -> Self {
        Self {
            tokenizer: Some(tokenizer),
            ..self
        }
    }

    /// Sets the message-count cap.
    pub fn max_messages(self, max: usize) -> Self {
        Self {
            max_messages: Some(max),
            ..self
        }
    }

    /// Builds the [`ContextManager`].
    ///
    /// An unset tokenizer defaults to `HeuristicTokenizer` and an unset
    /// message cap defaults to 200. The manager starts with no protected
    /// markers.
    pub fn build(self) -> ContextManager {
        let Self {
            token_limit,
            compressor,
            initial_messages,
            tokenizer,
            max_messages,
        } = self;
        ContextManager {
            messages: initial_messages,
            compressor,
            token_limit,
            tokenizer: tokenizer.unwrap_or_else(|| Arc::new(HeuristicTokenizer)),
            protected_markers: Vec::new(),
            max_messages: max_messages.unwrap_or(200),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::compression::compressor::SlidingWindowCompressor;
use echo_core::error::Result;
// Smoke test: pushes a system message plus six user/assistant pairs, runs
// `prepare`, and prints the history before and after. It has no assertions,
// so it only fails when `prepare` returns an error.
#[tokio::test]
async fn test_sliding_window_compressor() -> Result<()> {
println!("=== Example 1: Sliding window compression ===");
let mut ctx = ContextManager::builder(200)
.compressor(SlidingWindowCompressor::new(4))
.build();
ctx.push(Message::system("You are an assistant.".to_string()));
for i in 1..=6 {
ctx.push(Message::user(format!("用户消息 {}", i)));
ctx.push(Message::assistant(format!("助手回复 {}", i)));
}
// Prints the message count before compression.
println!("压缩前消息数:{}", ctx.messages().len());
let result = ctx.prepare(None).await?;
let messages = result.messages;
// Prints the message count after compression, then each remaining message.
println!("压缩后消息数:{}", messages.len());
for m in &messages {
println!(" [{}] {}", m.role, m.content.as_text_ref().unwrap_or(""));
}
Ok(())
}
// A protected message must survive forced compression and stay in front of
// the messages that followed it, right after the system prompt.
#[tokio::test]
async fn test_protected_messages_keep_relative_position_after_compression() -> Result<()> {
    let mut ctx = ContextManager::builder(10)
        .compressor(SlidingWindowCompressor::new(2))
        .build();
    ctx.add_protected_marker("<skill>".to_string());

    let history = [
        Message::system("system".to_string()),
        Message::user("old user".to_string()),
        Message::assistant("old assistant".to_string()),
        Message::user("<skill> protected".to_string()),
        Message::assistant("recent assistant".to_string()),
        Message::user("latest user".to_string()),
    ];
    for message in history {
        ctx.push(message);
    }

    let stats = ctx.force_compress(2).await?;
    assert!(stats.after_count >= 3);

    let mut rendered: Vec<(String, String)> = Vec::new();
    for m in ctx.messages() {
        let text = m.content.as_text_ref().unwrap_or("").to_string();
        rendered.push((m.role.clone(), text));
    }
    let expected: Vec<(String, String)> = [
        ("system", "system"),
        ("user", "<skill> protected"),
        ("assistant", "recent assistant"),
        ("user", "latest user"),
    ]
    .into_iter()
    .map(|(role, text)| (role.to_string(), text.to_string()))
    .collect();
    assert_eq!(rendered, expected);
    Ok(())
}
}