use crate::session::Message;
/// Compaction action recommended by [`ContextMonitor::suggest_strategy`].
///
/// `Eq` is derived alongside `PartialEq` because every payload is a plain `usize`,
/// so equality is total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompactionStrategy {
    /// Returned when the estimated token count does not exceed the monitor's threshold.
    None,
    /// Summarization is suggested; `keep_recent` is a count of recent messages
    /// (presumably the ones the caller should keep verbatim — the consumer is not shown here).
    Summarize { keep_recent: usize },
    /// Truncation is suggested; `keep_recent` is a count of recent messages
    /// (presumably the ones the caller should retain — the consumer is not shown here).
    Truncate { keep_recent: usize },
}
/// Tracks a token budget and decides when a message history should be compacted.
///
/// Derives `Debug` and `Clone` so monitors can be logged and copied into other components.
#[derive(Debug, Clone)]
pub struct ContextMonitor {
    /// Token capacity that estimates are compared against.
    context_limit: usize,
    /// Multiplied by `context_limit` to give the token count that must be exceeded
    /// before compaction is requested.
    threshold: f64,
}
impl ContextMonitor {
    /// Heuristic tokens charged per whitespace-separated word.
    const TOKENS_PER_WORD: f64 = 1.3;
    /// Flat per-message token overhead added by the heuristic.
    const TOKENS_PER_MESSAGE: f64 = 4.0;
    /// Fraction of `context_limit` above which truncation is suggested.
    const TRUNCATE_RATIO: f64 = 0.95;
    /// Fraction of `context_limit` above which summarization keeps fewer recent messages.
    const AGGRESSIVE_SUMMARY_RATIO: f64 = 0.85;

    /// Creates a monitor for a context of `context_limit` tokens.
    ///
    /// `threshold` is multiplied by `context_limit` to obtain the token count that
    /// must be exceeded before compaction is requested (e.g. `0.80` means 80 % of the
    /// limit). Neither argument is validated.
    pub fn new(context_limit: usize, threshold: f64) -> Self {
        Self {
            context_limit,
            threshold,
        }
    }

    /// Estimates the token count of `messages` without running a real tokenizer.
    ///
    /// Each message is charged `1.3` tokens per whitespace-separated word plus a flat
    /// `4` tokens. The per-message value is truncated toward zero before summing, so an
    /// empty message still costs 4 tokens and an empty slice costs 0.
    pub fn estimate_tokens(messages: &[Message]) -> usize {
        messages
            .iter()
            .map(|msg| {
                let word_count = msg.content.split_whitespace().count();
                (word_count as f64 * Self::TOKENS_PER_WORD + Self::TOKENS_PER_MESSAGE) as usize
            })
            .sum()
    }

    /// Returns `true` when the estimated token count of `messages` strictly exceeds
    /// `threshold * context_limit`.
    pub fn needs_compaction(&self, messages: &[Message]) -> bool {
        self.exceeds_threshold(Self::estimate_tokens(messages))
    }

    /// Suggests how to compact `messages`.
    ///
    /// Returns [`CompactionStrategy::None`] exactly when [`Self::needs_compaction`]
    /// would return `false`. Otherwise the choice depends on the fraction of
    /// `context_limit` in use:
    /// - above 95 %: `Truncate { keep_recent: 3 }`
    /// - above 85 %: `Summarize { keep_recent: 5 }`
    /// - otherwise: `Summarize { keep_recent: 8 }`
    pub fn suggest_strategy(&self, messages: &[Message]) -> CompactionStrategy {
        let estimated = Self::estimate_tokens(messages);
        // Reuse the exact trigger from `needs_compaction` so both methods always agree.
        // A standalone ratio comparison here yields NaN (0.0 / 0.0) for an empty history
        // with `context_limit == 0`, which slipped past every branch into `Summarize`.
        if !self.exceeds_threshold(estimated) {
            return CompactionStrategy::None;
        }
        // With `context_limit == 0` any history reaching this point gives +inf,
        // which selects the truncate branch.
        let ratio = estimated as f64 / self.context_limit as f64;
        if ratio > Self::TRUNCATE_RATIO {
            CompactionStrategy::Truncate { keep_recent: 3 }
        } else if ratio > Self::AGGRESSIVE_SUMMARY_RATIO {
            CompactionStrategy::Summarize { keep_recent: 5 }
        } else {
            CompactionStrategy::Summarize { keep_recent: 8 }
        }
    }

    /// Single compaction trigger shared by `needs_compaction` and `suggest_strategy`.
    fn exceeds_threshold(&self, estimated: usize) -> bool {
        estimated as f64 > self.threshold * self.context_limit as f64
    }
}
impl Default for ContextMonitor {
    /// Builds a monitor with a 100 000-token limit and a 0.80 threshold,
    /// so compaction is requested once the estimate exceeds 80 000 tokens.
    fn default() -> Self {
        Self::new(100_000, 0.80)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten-word message; estimates to 10 * 1.3 + 4 = 17 tokens.
    const TEN_WORDS: &str = "one two three four five six seven eight nine ten";

    fn make_message(content: &str) -> Message {
        Message::user(content)
    }

    #[test]
    fn test_estimate_tokens_empty_messages() {
        let messages: Vec<Message> = vec![];
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 0);
    }

    #[test]
    fn test_estimate_tokens_single_message() {
        let messages = vec![make_message("Hello world")];
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 6);
    }

    #[test]
    fn test_estimate_tokens_multiple_messages() {
        let messages = vec![
            make_message("Hello world"),
            make_message("How are you today"),
        ];
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 6 + 9);
    }

    #[test]
    fn test_estimate_tokens_empty_content() {
        let messages = vec![make_message("")];
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 4);
    }

    #[test]
    fn test_estimate_tokens_long_content() {
        let messages = vec![make_message(TEN_WORDS)];
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 17);
    }

    #[test]
    fn test_word_count_with_extra_whitespace() {
        // Leading/trailing spaces, runs of spaces, tabs and newlines must not
        // create phantom words.
        let messages = vec![make_message("  hello   \t world\n ")];
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 6);
    }

    #[test]
    fn test_needs_compaction_below_threshold() {
        let monitor = ContextMonitor::new(1000, 0.80);
        let messages = vec![make_message("Hello")];
        assert!(!monitor.needs_compaction(&messages));
    }

    #[test]
    fn test_needs_compaction_above_threshold() {
        let monitor = ContextMonitor::new(100, 0.80);
        let messages: Vec<Message> = (0..5).map(|_| make_message(TEN_WORDS)).collect();
        assert!(monitor.needs_compaction(&messages));
    }

    #[test]
    fn test_needs_compaction_exactly_at_threshold() {
        // 3 * 17 + 14 + 15 = 80 tokens == 0.80 * 100: not strictly above, so no compaction.
        let monitor = ContextMonitor::new(100, 0.80);
        let mut messages: Vec<Message> = (0..3).map(|_| make_message(TEN_WORDS)).collect();
        messages.push(make_message("one two three four five six seven eight"));
        messages.push(make_message("one two three four five six seven eight nine"));
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 80);
        assert!(!monitor.needs_compaction(&messages));
        assert_eq!(
            monitor.suggest_strategy(&messages),
            CompactionStrategy::None
        );
    }

    #[test]
    fn test_strategy_below_threshold() {
        let monitor = ContextMonitor::new(100_000, 0.80);
        let messages = vec![make_message("Hello world")];
        assert_eq!(
            monitor.suggest_strategy(&messages),
            CompactionStrategy::None
        );
    }

    #[test]
    fn test_strategy_above_threshold_below_85() {
        let monitor = ContextMonitor::new(100, 0.80);
        let mut messages: Vec<Message> = (0..4).map(|_| make_message(TEN_WORDS)).collect();
        messages.push(make_message("one two three four five six seven eight"));
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 82);
        assert_eq!(
            monitor.suggest_strategy(&messages),
            CompactionStrategy::Summarize { keep_recent: 8 }
        );
    }

    #[test]
    fn test_strategy_above_85() {
        let monitor = ContextMonitor::new(100, 0.80);
        let mut messages: Vec<Message> = (0..5).map(|_| make_message(TEN_WORDS)).collect();
        messages.push(make_message("a b c d"));
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 94);
        assert_eq!(
            monitor.suggest_strategy(&messages),
            CompactionStrategy::Summarize { keep_recent: 5 }
        );
    }

    #[test]
    fn test_strategy_above_95() {
        let monitor = ContextMonitor::new(100, 0.80);
        let messages: Vec<Message> = (0..6).map(|_| make_message(TEN_WORDS)).collect();
        assert_eq!(ContextMonitor::estimate_tokens(&messages), 102);
        assert_eq!(
            monitor.suggest_strategy(&messages),
            CompactionStrategy::Truncate { keep_recent: 3 }
        );
    }

    #[test]
    fn test_empty_message_list_strategy() {
        let monitor = ContextMonitor::new(100_000, 0.80);
        assert_eq!(monitor.suggest_strategy(&[]), CompactionStrategy::None);
        assert!(!monitor.needs_compaction(&[]));
    }

    #[test]
    fn test_single_message_no_compaction() {
        let monitor = ContextMonitor::new(100_000, 0.80);
        let messages = vec![make_message("Just one message here")];
        assert!(!monitor.needs_compaction(&messages));
        assert_eq!(
            monitor.suggest_strategy(&messages),
            CompactionStrategy::None
        );
    }

    #[test]
    fn test_custom_threshold() {
        let monitor = ContextMonitor::new(100, 0.10);
        let messages = vec![make_message("Hello world")];
        assert!(!monitor.needs_compaction(&messages));
        let messages = vec![make_message("Hello world"), make_message("Hello world")];
        assert!(monitor.needs_compaction(&messages));
    }

    #[test]
    fn test_default_values() {
        let monitor = ContextMonitor::default();
        let messages = vec![make_message("Hello"), make_message("World")];
        assert!(!monitor.needs_compaction(&messages));
        assert_eq!(
            monitor.suggest_strategy(&messages),
            CompactionStrategy::None
        );
    }
}