use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Settings that describe how a conversation context should be compacted.
///
/// The `Default` implementation in this file supplies the baseline values
/// quoted on each field below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionConfig {
/// Token budget for the context (default: 200_000).
pub max_tokens: usize,
/// Ratio associated with the token budget (default: 0.8).
/// NOTE(review): presumably compaction triggers at `max_tokens * threshold_ratio` —
/// the consumer is not visible here; confirm.
pub threshold_ratio: f64,
/// Which compaction strategy to use (default: `SmartSummarize`).
pub strategy: CompactionStrategy,
/// Whether system messages should survive compaction (default: `true`).
/// NOTE(review): presumably applies to messages whose role is "system" — confirm with the compactor.
pub preserve_system_messages: bool,
/// Number of most recent messages to keep (default: 5).
/// NOTE(review): presumably forwarded to `ContextHistory::apply_truncation` — confirm.
pub preserve_recent_count: usize,
/// Compression level (default: 3).
/// NOTE(review): presumably forwarded to `compress_context`, which currently ignores it — confirm.
pub compression_level: i32,
}
impl Default for CompactionConfig {
fn default() -> Self {
Self {
max_tokens: 200_000, threshold_ratio: 0.8,
strategy: CompactionStrategy::SmartSummarize,
preserve_system_messages: true,
preserve_recent_count: 5,
compression_level: 3,
}
}
}
/// The available context-compaction strategies.
///
/// The `Display` implementation in this file renders each variant as a
/// lowercase snake_case name (for example `smart_summarize`).
/// NOTE(review): the code that dispatches on these variants is not visible
/// here, so the per-variant notes below are assumptions to confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompactionStrategy {
/// Displayed as `truncate`. Presumably drops old messages outright — confirm.
Truncate,
/// Displayed as `summarize`. Presumably replaces messages with a summary — confirm.
Summarize,
/// Displayed as `smart_summarize`; this is the default strategy in `CompactionConfig`.
SmartSummarize,
/// Displayed as `sliding_window`. Presumably keeps a fixed-size window of recent messages — confirm.
SlidingWindow,
/// Displayed as `hybrid`. Presumably combines several of the strategies above — confirm.
Hybrid,
}
impl std::fmt::Display for CompactionStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CompactionStrategy::Truncate => write!(f, "truncate"),
CompactionStrategy::Summarize => write!(f, "summarize"),
CompactionStrategy::SmartSummarize => write!(f, "smart_summarize"),
CompactionStrategy::SlidingWindow => write!(f, "sliding_window"),
CompactionStrategy::Hybrid => write!(f, "hybrid"),
}
}
}
/// Outcome report of a compaction attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionResult {
/// Whether any compaction took place; `no_compaction` sets this to `false`.
pub compacted: bool,
/// Token count before compaction; the denominator used by `savings_percentage`.
pub original_tokens: usize,
/// Token count after compaction; the numerator used by `savings_percentage`.
pub final_tokens: usize,
/// Number of messages removed.
pub messages_removed: usize,
/// Number of messages summarized.
pub messages_summarized: usize,
/// Bytes saved by the compaction.
/// NOTE(review): how this is measured is not established in this file — confirm.
pub bytes_saved: usize,
/// Compression ratio; `no_compaction` reports 1.0.
/// NOTE(review): presumably `final_tokens / original_tokens` — confirm with the producer.
pub compression_ratio: f64,
/// Strategy that produced this result (`no_compaction` fills in `Truncate` as a placeholder).
pub strategy_used: CompactionStrategy,
/// Optional summary text; `None` when nothing was compacted via `no_compaction`.
pub summary: Option<String>,
}
impl CompactionResult {
pub fn no_compaction(token_count: usize) -> Self {
Self {
compacted: false,
original_tokens: token_count,
final_tokens: token_count,
messages_removed: 0,
messages_summarized: 0,
bytes_saved: 0,
compression_ratio: 1.0,
strategy_used: CompactionStrategy::Truncate,
summary: None,
}
}
pub fn savings_percentage(&self) -> f64 {
if self.original_tokens == 0 {
return 0.0;
}
(1.0 - (self.final_tokens as f64 / self.original_tokens as f64)) * 100.0
}
}
/// Asynchronous interface for a component that owns a context and can compact it.
///
/// Implementors must be `Send + Sync`. No implementation is visible in this
/// file, so the notes below describe the signatures rather than verified behavior.
#[async_trait]
pub trait ContextCompactor: Send + Sync {
/// Compacts the context according to `config` and reports the outcome as a `CompactionResult`.
async fn compact(&mut self, config: &CompactionConfig) -> Result<CompactionResult>;
/// Reports whether compaction is needed for the given `threshold_tokens`.
/// NOTE(review): presumably true once the token count exceeds the threshold — confirm with implementors.
async fn needs_compaction(&self, threshold_tokens: usize) -> bool;
/// Returns the current token count of the context.
async fn get_token_count(&self) -> Result<usize>;
/// Returns the current size of the context.
/// NOTE(review): the unit (bytes? messages?) is not established here — confirm.
async fn get_context_size(&self) -> Result<usize>;
/// Clears the context.
async fn clear_context(&mut self) -> Result<()>;
}
/// A single message held in a `ContextHistory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMessage {
/// Role of the message author; `ContextHistory::score_importance` gives the value "system" a bonus.
pub role: String,
/// Message text.
pub content: String,
/// Token estimate; `ContextMessage::new` sets it to the content's byte length divided by 4.
pub token_count: usize,
/// Creation time; `ContextMessage::new` sets it to the current UTC time.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Importance score; `new` starts at 0.5 and `with_importance` clamps to the range 0.0..=1.0.
pub importance: f64,
/// When `true`, `apply_truncation` and `apply_sliding_window` never remove this message.
pub preserve: bool,
}
impl ContextMessage {
    /// Creates a message stamped with the current UTC time.
    ///
    /// The token count is a rough estimate: the content's byte length
    /// divided by four (integer division). Importance starts at 0.5 and the
    /// message is not marked as preserved.
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        let content = content.into();
        Self {
            role: role.into(),
            // Must be computed before `content` is moved into the struct.
            token_count: content.len() / 4,
            content,
            timestamp: chrono::Utc::now(),
            importance: 0.5,
            preserve: false,
        }
    }

    /// Returns the message with its importance set to `importance`, clamped to `0.0..=1.0`.
    pub fn with_importance(self, importance: f64) -> Self {
        Self {
            importance: importance.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Returns the message marked as preserved.
    pub fn preserve(self) -> Self {
        Self {
            preserve: true,
            ..self
        }
    }
}
/// Ordered list of context messages together with a running token total.
#[derive(Debug, Clone, Default)]
pub struct ContextHistory {
// Messages in insertion order; `add_message` appends at the end.
messages: Vec<ContextMessage>,
// Running total of token counts, updated by `add_message`, `clear` and the compaction helpers.
total_tokens: usize,
}
impl ContextHistory {
pub fn new() -> Self {
Self::default()
}
pub fn add_message(&mut self, message: ContextMessage) {
self.total_tokens += message.token_count;
self.messages.push(message);
}
pub fn token_count(&self) -> usize {
self.total_tokens
}
pub fn message_count(&self) -> usize {
self.messages.len()
}
pub fn messages(&self) -> &[ContextMessage] {
&self.messages
}
pub fn clear(&mut self) {
self.messages.clear();
self.total_tokens = 0;
}
pub fn apply_truncation(&mut self, target_tokens: usize, preserve_recent: usize) -> usize {
if self.total_tokens <= target_tokens {
return 0;
}
let mut removed = 0;
let preserve_from = self.messages.len().saturating_sub(preserve_recent);
while self.total_tokens > target_tokens && !self.messages.is_empty() {
let remove_idx = self.messages.iter().enumerate().find_map(|(i, m)| {
if i < preserve_from && !m.preserve {
Some(i)
} else {
None
}
});
match remove_idx {
Some(idx) => {
let msg = self.messages.remove(idx);
self.total_tokens = self.total_tokens.saturating_sub(msg.token_count);
removed += 1;
}
None => break, }
}
removed
}
pub fn apply_sliding_window(&mut self, window_size: usize) -> usize {
if self.messages.len() <= window_size {
return 0;
}
let to_remove = self.messages.len() - window_size;
let (preserved, removable): (Vec<_>, Vec<_>) =
self.messages.drain(..to_remove).partition(|m| m.preserve);
let preserved_tokens: usize = preserved.iter().map(|m| m.token_count).sum();
let removed_tokens: usize = removable.iter().map(|m| m.token_count).sum();
for msg in preserved.into_iter().rev() {
self.messages.insert(0, msg);
}
self.total_tokens = self.total_tokens.saturating_sub(removed_tokens);
self.total_tokens += preserved_tokens;
removable.len()
}
pub fn score_importance(&mut self) {
let message_count = self.messages.len();
for (i, msg) in self.messages.iter_mut().enumerate() {
let mut score = 0.5;
if msg.role == "system" {
score += 0.3;
}
let recency = i as f64 / message_count.max(1) as f64;
score += recency * 0.2;
if msg.content.len() > 500 {
score += 0.1;
}
if msg.content.contains("```") {
score += 0.1;
}
if msg.content.to_lowercase().contains("error") {
score += 0.1;
}
msg.importance = score.clamp(0.0, 1.0);
}
}
}
/// Placeholder compressor: returns an owned copy of `data` unchanged.
///
/// `level` is accepted but currently ignored, and the function always
/// returns `Ok`.
pub fn compress_context(data: &[u8], level: i32) -> Result<Vec<u8>> {
    // Explicitly discard the level so the unused parameter is intentional.
    let _ = level;
    let passthrough = data.to_owned();
    Ok(passthrough)
}
/// Placeholder decompressor: returns an owned copy of `data` unchanged.
///
/// Mirrors `compress_context`, which is currently a pass-through as well;
/// the function always returns `Ok`.
pub fn decompress_context(data: &[u8]) -> Result<Vec<u8>> {
    let passthrough = data.to_owned();
    Ok(passthrough)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder methods set the role, importance and preserve flag.
    #[test]
    fn test_context_message_creation() {
        let built = ContextMessage::new("user", "Hello, world!");
        let built = built.with_importance(0.8);
        let built = built.preserve();

        assert_eq!(built.role, "user");
        assert_eq!(built.importance, 0.8);
        assert!(built.preserve);
    }

    /// Adding messages grows both the message count and the token total.
    #[test]
    fn test_context_history() {
        let exchange = [("user", "Hello"), ("assistant", "Hi there!")];
        let mut log = ContextHistory::new();
        for &(role, text) in exchange.iter() {
            log.add_message(ContextMessage::new(role, text));
        }

        assert_eq!(log.message_count(), 2);
        assert!(log.token_count() > 0);
    }

    /// Truncating far below the current total removes at least one message.
    #[test]
    fn test_truncation() {
        let mut log = ContextHistory::new();
        (0..10)
            .map(|i| format!("This is a longer message number {} with more content", i))
            .for_each(|body| log.add_message(ContextMessage::new("user", body)));

        let (count_before, tokens_before) = (log.message_count(), log.token_count());
        let dropped = log.apply_truncation(50, 3);

        assert!(
            dropped > 0,
            "Expected to remove messages, original_tokens={}, target=50",
            tokens_before
        );
        assert!(log.message_count() < count_before);
    }

    /// A window of five over ten unpinned messages drops exactly five.
    #[test]
    fn test_sliding_window() {
        let mut log = ContextHistory::new();
        (0..10)
            .map(|i| format!("Message {}", i))
            .for_each(|body| log.add_message(ContextMessage::new("user", body)));

        let dropped = log.apply_sliding_window(5);

        assert_eq!(dropped, 5);
        assert_eq!(log.message_count(), 5);
    }

    /// Going from 100k to 7k tokens is reported as a 93% saving.
    #[test]
    fn test_compaction_result_savings() {
        let summary = Some("Summary of conversation".to_string());
        let outcome = CompactionResult {
            compacted: true,
            original_tokens: 100_000,
            final_tokens: 7_000,
            messages_removed: 50,
            messages_summarized: 10,
            bytes_saved: 500_000,
            compression_ratio: 0.07,
            strategy_used: CompactionStrategy::SmartSummarize,
            summary,
        };

        assert_eq!(outcome.savings_percentage(), 93.0);
    }
}