use std::collections::VecDeque;
/// Author of a message held in the context window.
///
/// Each variant maps to a fixed weight in `Role::default_importance`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Role {
/// Carries the highest default importance weight.
System,
/// Weighted above `Tool` and `Assistant`, below `System`.
User,
/// Carries the lowest default importance weight.
Assistant,
/// Weighted between `User` and `Assistant`.
Tool,
}
impl Role {
pub fn default_importance(&self) -> f64 {
match self {
Role::System => 1.0,
Role::User => 0.8,
Role::Tool => 0.6,
Role::Assistant => 0.5,
}
}
}
/// A single message stored in a `ContextWindow`.
#[derive(Debug, Clone)]
pub struct ContextMessage {
/// Who authored the message; determines its base importance weight.
pub role: Role,
/// Message text.
pub content: String,
/// Estimated token cost of `content`; `ContextMessage::new` fills this in
/// via `estimate_tokens`.
pub token_count: usize,
/// Turn index given at construction; `ContextWindow` assigns these sequentially.
pub turn: usize,
/// When `true`, the truncation strategies skip this message instead of removing it.
pub pinned: bool,
}
impl ContextMessage {
    /// Builds an unpinned message for `turn`.
    ///
    /// The token cost is estimated from the content with `estimate_tokens`.
    pub fn new(role: Role, content: impl Into<String>, turn: usize) -> Self {
        let content = content.into();
        let token_count = estimate_tokens(&content);
        Self {
            role,
            content,
            token_count,
            turn,
            pinned: false,
        }
    }

    /// Builds a message exactly like [`ContextMessage::new`], but flagged as pinned.
    pub fn pinned(role: Role, content: impl Into<String>, turn: usize) -> Self {
        Self {
            pinned: true,
            ..Self::new(role, content, turn)
        }
    }

    /// Scores the message for priority-based eviction; lower scores go first.
    ///
    /// The score is the role's default weight multiplied by a recency factor
    /// in `[0.5, 1.0]` that grows linearly with `turn / max_turn`. When
    /// `max_turn` is 0 the recency factor is 1.0.
    ///
    /// The turn ratio is clamped to 1.0, so a message whose `turn` lies beyond
    /// `max_turn` (for example when the caller passes a stale maximum) never
    /// scores above its role weight.
    pub fn importance(&self, max_turn: usize) -> f64 {
        let recency = if max_turn == 0 {
            1.0
        } else {
            let turn_ratio = (self.turn as f64 / max_turn as f64).min(1.0);
            0.5 + 0.5 * turn_ratio
        };
        self.role.default_importance() * recency
    }
}
/// Budget and eviction settings for a `ContextWindow`.
#[derive(Debug, Clone)]
pub struct WindowConfig {
/// Total token budget before the response reservation is subtracted.
pub max_tokens: usize,
/// Tokens held back from `max_tokens`; the subtraction saturates at zero,
/// so a reservation larger than `max_tokens` leaves a usable budget of 0.
pub reserved_for_response: usize,
/// Strategy applied when stored content exceeds the usable budget.
pub truncation_strategy: TruncationStrategy,
}
impl Default for WindowConfig {
fn default() -> Self {
Self {
max_tokens: 4096,
reserved_for_response: 512,
truncation_strategy: TruncationStrategy::SlidingWindow,
}
}
}
/// How a `ContextWindow` sheds content once it exceeds its usable budget.
///
/// Every strategy leaves pinned messages in place.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
/// Removes the oldest unpinned messages first.
SlidingWindow,
/// Removes the unpinned message with the lowest `ContextMessage::importance` first.
PriorityBased,
/// Folds the oldest unpinned messages into a `ContextSummary`.
Compress,
}
/// Record of messages that the `Compress` strategy replaced with summary text.
#[derive(Debug, Clone)]
pub struct ContextSummary {
/// Summary text emitted by `ContextWindow::build_context` in place of the messages.
pub text: String,
/// Estimated token cost of `text`; counted towards the window's total.
pub token_count: usize,
/// Number of messages this summary replaced.
pub messages_compressed: usize,
/// Combined `token_count` of the replaced messages.
pub original_tokens: usize,
}
/// Point-in-time snapshot of a window's budget usage, returned by `ContextWindow::state`.
#[derive(Debug, Clone)]
pub struct WindowState {
/// Tokens currently held by messages and summaries combined.
pub total_tokens: usize,
/// Usable budget still free (0 when the window is full or over budget).
pub available_tokens: usize,
/// Number of stored messages; summaries are not counted.
pub message_count: usize,
/// `true` when `total_tokens` exceeds the usable budget.
pub is_overflowing: bool,
/// `total_tokens` divided by the usable budget, capped at 1.0.
pub fill_ratio: f64,
}
/// Token-budgeted buffer of conversation messages.
pub struct ContextWindow {
// Budget and truncation settings fixed at construction.
config: WindowConfig,
// Stored messages in insertion order (oldest at the front).
messages: VecDeque<ContextMessage>,
// Summaries produced by the `Compress` strategy, oldest first.
summaries: Vec<ContextSummary>,
// Turn number that the next added message will receive.
next_turn: usize,
}
impl ContextWindow {
    /// Maximum number of characters a single message contributes to a summary.
    const SUMMARY_EXCERPT_CHARS: usize = 80;

    /// Creates an empty window governed by `config`; turn numbering starts at 0.
    pub fn new(config: WindowConfig) -> Self {
        Self {
            config,
            messages: VecDeque::new(),
            summaries: Vec::new(),
            next_turn: 0,
        }
    }

    /// Appends an evictable message stamped with the next turn number.
    ///
    /// If the window then exceeds its usable budget, the configured
    /// truncation strategy runs immediately.
    pub fn add_message(&mut self, role: Role, content: impl Into<String>) {
        let msg = ContextMessage::new(role, content, self.next_turn);
        self.push_and_enforce(msg);
    }

    /// Appends a pinned message stamped with the next turn number.
    ///
    /// Pinned messages are never evicted. Like [`ContextWindow::add_message`],
    /// this enforces the budget right away, so unpinned content is shed to
    /// make room. The window can still end up over budget when the pinned
    /// messages alone exceed it.
    pub fn add_pinned(&mut self, role: Role, content: impl Into<String>) {
        let msg = ContextMessage::pinned(role, content, self.next_turn);
        self.push_and_enforce(msg);
    }

    /// Sum of the estimated token costs of all messages and summaries.
    pub fn total_tokens(&self) -> usize {
        let from_messages: usize = self.messages.iter().map(|m| m.token_count).sum();
        let from_summaries: usize = self.summaries.iter().map(|s| s.token_count).sum();
        from_messages + from_summaries
    }

    /// Usable budget that is still free; 0 when the window is full or over budget.
    pub fn available_tokens(&self) -> usize {
        self.usable_tokens().saturating_sub(self.total_tokens())
    }

    /// Returns `true` when stored content exceeds the usable budget.
    pub fn is_overflowing(&self) -> bool {
        self.total_tokens() > self.usable_tokens()
    }

    /// Fraction of the usable budget in use, capped at 1.0.
    ///
    /// Returns 1.0 when the usable budget is zero.
    pub fn fill_ratio(&self) -> f64 {
        let usable = self.usable_tokens();
        if usable == 0 {
            return 1.0;
        }
        (self.total_tokens() as f64 / usable as f64).min(1.0)
    }

    /// Captures the current budget usage as a [`WindowState`] snapshot.
    pub fn state(&self) -> WindowState {
        WindowState {
            total_tokens: self.total_tokens(),
            available_tokens: self.available_tokens(),
            message_count: self.messages.len(),
            is_overflowing: self.is_overflowing(),
            fill_ratio: self.fill_ratio(),
        }
    }

    /// Number of stored messages (summaries are not counted).
    pub fn message_count(&self) -> usize {
        self.messages.len()
    }

    /// Iterates over the stored messages, oldest first.
    pub fn messages(&self) -> impl Iterator<Item = &ContextMessage> {
        self.messages.iter()
    }

    /// Summaries produced so far, oldest first.
    pub fn summaries(&self) -> &[ContextSummary] {
        &self.summaries
    }

    /// Renders the window as newline-separated text.
    ///
    /// Summaries come first as `[Summary] <text>` lines, followed by one
    /// `[<role>] <content>` line per message. An empty window yields an
    /// empty string.
    pub fn build_context(&self) -> String {
        let summary_lines = self
            .summaries
            .iter()
            .map(|s| format!("[Summary] {}", s.text));
        let message_lines = self.messages.iter().map(|m| {
            let label = match m.role {
                Role::System => "system",
                Role::User => "user",
                Role::Assistant => "assistant",
                Role::Tool => "tool",
            };
            format!("[{}] {}", label, m.content)
        });
        summary_lines
            .chain(message_lines)
            .collect::<Vec<String>>()
            .join("\n")
    }

    /// Applies the configured truncation strategy on demand.
    ///
    /// Does nothing when the window is within budget, whichever strategy is
    /// configured.
    pub fn truncate(&mut self) {
        self.apply_truncation();
    }

    /// Removes every message and summary and restarts turn numbering at 0.
    pub fn clear(&mut self) {
        self.messages.clear();
        self.summaries.clear();
        self.next_turn = 0;
    }

    /// The configuration this window was created with.
    pub fn config(&self) -> &WindowConfig {
        &self.config
    }

    /// Token budget left after subtracting the response reservation (saturating).
    fn usable_tokens(&self) -> usize {
        self.config
            .max_tokens
            .saturating_sub(self.config.reserved_for_response)
    }

    /// Stores `msg`, advances the turn counter and enforces the budget.
    fn push_and_enforce(&mut self, msg: ContextMessage) {
        self.next_turn += 1;
        self.messages.push_back(msg);
        if self.is_overflowing() {
            self.apply_truncation();
        }
    }

    /// Dispatches to the strategy selected in the configuration.
    fn apply_truncation(&mut self) {
        match self.config.truncation_strategy {
            TruncationStrategy::SlidingWindow => self.truncate_sliding(),
            TruncationStrategy::PriorityBased => self.truncate_priority(),
            TruncationStrategy::Compress => self.truncate_compress(),
        }
    }

    /// Drops the oldest unpinned message until the window fits or only
    /// pinned messages remain.
    fn truncate_sliding(&mut self) {
        while self.is_overflowing() {
            match self.messages.iter().position(|m| !m.pinned) {
                Some(oldest) => {
                    self.messages.remove(oldest);
                }
                None => break,
            }
        }
    }

    /// Drops the least important unpinned message until the window fits or
    /// only pinned messages remain. Ties go to the oldest message.
    fn truncate_priority(&mut self) {
        let max_turn = self.next_turn.saturating_sub(1);
        while self.is_overflowing() {
            // `min_by` returns the first of several equal minima, which keeps
            // the "oldest wins ties" behaviour of a strict `<` scan.
            let victim = self
                .messages
                .iter()
                .enumerate()
                .filter(|(_, m)| !m.pinned)
                .min_by(|(_, a), (_, b)| {
                    a.importance(max_turn)
                        .partial_cmp(&b.importance(max_turn))
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .map(|(idx, _)| idx);
            match victim {
                Some(idx) => {
                    self.messages.remove(idx);
                }
                None => break,
            }
        }
    }

    /// Folds the oldest half of the unpinned messages into a summary, then
    /// guarantees the budget.
    ///
    /// Does nothing when the window is within budget. If the window is still
    /// over budget after summarising, the oldest summaries are evicted first
    /// and then the oldest unpinned messages.
    fn truncate_compress(&mut self) {
        if !self.is_overflowing() {
            return;
        }
        let evictable = self.messages.iter().filter(|m| !m.pinned).count();
        let quota = evictable / 2;
        if quota > 0 {
            self.summarize_oldest(quota);
        }
        // Summaries describe content older than any stored message, so they
        // are the first thing to go when space is still short.
        while self.is_overflowing() && !self.summaries.is_empty() {
            self.summaries.remove(0);
        }
        self.truncate_sliding();
    }

    /// Replaces the `quota` oldest unpinned messages with one summary.
    ///
    /// The replacement is skipped when the summary would cost at least as
    /// many tokens as the messages it stands in for.
    fn summarize_oldest(&mut self, quota: usize) {
        let mut excerpts: Vec<String> = Vec::with_capacity(quota);
        let mut original_tokens: usize = 0;
        for msg in self.messages.iter().filter(|m| !m.pinned).take(quota) {
            excerpts.push(Self::excerpt(&msg.content));
            original_tokens += msg.token_count;
        }
        let compressed = excerpts.len();
        let text = format!(
            "Previous conversation ({} messages): {}",
            compressed,
            excerpts.join(" | ")
        );
        let token_count = estimate_tokens(&text);
        if token_count >= original_tokens {
            return;
        }

        // One in-order pass that drops exactly the messages summarised above.
        let mut remaining = compressed;
        self.messages.retain(|m| {
            if m.pinned || remaining == 0 {
                true
            } else {
                remaining -= 1;
                false
            }
        });
        self.summaries.push(ContextSummary {
            text,
            token_count,
            messages_compressed: compressed,
            original_tokens,
        });
    }

    /// Returns `text` unchanged when it is short enough; otherwise returns
    /// its first `SUMMARY_EXCERPT_CHARS` characters followed by `...`.
    fn excerpt(text: &str) -> String {
        // Cutting at a `char_indices` offset keeps the slice on a UTF-8 boundary.
        match text.char_indices().nth(Self::SUMMARY_EXCERPT_CHARS) {
            Some((cut, _)) => format!("{}...", &text[..cut]),
            None => text.to_owned(),
        }
    }
}
/// Roughly estimates how many tokens `text` occupies.
///
/// Uses a four-bytes-per-token heuristic on the UTF-8 byte length (not the
/// character count), rounded down. Empty text costs 0 tokens; any non-empty
/// text costs at least 1.
pub fn estimate_tokens(text: &str) -> usize {
    match text.len() {
        0 => 0,
        byte_len => std::cmp::max(byte_len / 4, 1),
    }
}
pub fn estimate_tokens_batch(texts: &[&str]) -> Vec<usize> {
texts.iter().map(|t| estimate_tokens(t)).collect()
}
// Unit tests covering token estimation, message scoring, budget accounting
// and the three truncation strategies.
#[cfg(test)]
mod tests {
use super::*;
// Window built from `WindowConfig::default()` (4096 max tokens, 512 reserved,
// sliding-window truncation).
fn default_window() -> ContextWindow {
ContextWindow::new(WindowConfig::default())
}
// Sliding-window strategy with 80 usable tokens (100 max - 20 reserved).
fn small_window() -> ContextWindow {
ContextWindow::new(WindowConfig {
max_tokens: 100,
reserved_for_response: 20,
truncation_strategy: TruncationStrategy::SlidingWindow,
})
}
// Priority-based strategy with 80 usable tokens.
fn priority_window() -> ContextWindow {
ContextWindow::new(WindowConfig {
max_tokens: 100,
reserved_for_response: 20,
truncation_strategy: TruncationStrategy::PriorityBased,
})
}
// Compress strategy with 80 usable tokens.
fn compress_window() -> ContextWindow {
ContextWindow::new(WindowConfig {
max_tokens: 100,
reserved_for_response: 20,
truncation_strategy: TruncationStrategy::Compress,
})
}
#[test]
fn test_estimate_tokens_empty() {
assert_eq!(estimate_tokens(""), 0);
}
// Non-empty text never estimates to zero tokens.
#[test]
fn test_estimate_tokens_short() {
assert_eq!(estimate_tokens("Hi"), 1);
}
// 20 bytes at four bytes per token.
#[test]
fn test_estimate_tokens_longer() {
let text = "a]".repeat(10); assert_eq!(estimate_tokens(&text), 5);
}
#[test]
fn test_estimate_tokens_batch_length() {
let result = estimate_tokens_batch(&["hello", "world", "foo"]);
assert_eq!(result.len(), 3);
}
#[test]
fn test_system_highest_importance() {
assert!(Role::System.default_importance() > Role::User.default_importance());
assert!(Role::System.default_importance() > Role::Assistant.default_importance());
}
#[test]
fn test_user_higher_than_assistant() {
assert!(Role::User.default_importance() > Role::Assistant.default_importance());
}
#[test]
fn test_message_auto_token_count() {
let msg = ContextMessage::new(Role::User, "Hello world, this is a test!", 0);
assert!(msg.token_count > 0);
}
#[test]
fn test_pinned_message_flag() {
let msg = ContextMessage::pinned(Role::System, "You are a helpful assistant.", 0);
assert!(msg.pinned);
}
// Same role, different turns: the later message must score higher.
#[test]
fn test_message_importance_increases_with_turn() {
let early = ContextMessage::new(Role::User, "early", 0);
let late = ContextMessage::new(Role::User, "late", 10);
assert!(late.importance(10) > early.importance(10));
}
// Same turn, different roles: the role weight decides.
#[test]
fn test_message_importance_system_higher_than_assistant() {
let sys = ContextMessage::new(Role::System, "sys", 5);
let asst = ContextMessage::new(Role::Assistant, "asst", 5);
assert!(sys.importance(10) > asst.importance(10));
}
#[test]
fn test_empty_window_state() {
let w = default_window();
assert_eq!(w.total_tokens(), 0);
assert_eq!(w.message_count(), 0);
assert!(!w.is_overflowing());
}
#[test]
fn test_add_message_increments_count() {
let mut w = default_window();
w.add_message(Role::User, "Hello");
assert_eq!(w.message_count(), 1);
}
#[test]
fn test_add_message_increases_tokens() {
let mut w = default_window();
w.add_message(Role::User, "Hello there");
assert!(w.total_tokens() > 0);
}
#[test]
fn test_clear_resets_window() {
let mut w = default_window();
w.add_message(Role::User, "test");
w.clear();
assert_eq!(w.total_tokens(), 0);
assert_eq!(w.message_count(), 0);
}
// Despite the name, this asserts that the truncation triggered by
// `add_message` leaves the window within budget.
#[test]
fn test_overflow_detected() {
let mut w = small_window(); for _ in 0..50 {
w.add_message(
Role::User,
"This is a somewhat long message for testing overflow detection.",
);
}
assert!(!w.is_overflowing());
}
#[test]
fn test_available_tokens_decreases() {
let mut w = small_window();
let before = w.available_tokens();
w.add_message(Role::User, "Some content here");
let after = w.available_tokens();
assert!(after < before);
}
// Checks only the budget and that some messages were dropped; it does not
// verify which messages were removed.
#[test]
fn test_sliding_window_removes_oldest() {
let mut w = small_window();
w.add_message(Role::User, "First message with some content");
w.add_message(Role::User, "Second message with some content");
for i in 0..30 {
w.add_message(Role::User, format!("Message number {} with content", i));
}
assert!(!w.is_overflowing());
assert!(w.message_count() < 32);
}
#[test]
fn test_sliding_window_preserves_pinned() {
let mut w = small_window();
w.add_pinned(Role::System, "Pinned system prompt that must stay");
for i in 0..30 {
w.add_message(Role::User, format!("Overflow message {}", i));
}
let has_pinned = w.messages().any(|m| m.pinned);
assert!(has_pinned, "pinned messages should survive truncation");
}
// Checks only the budget; it does not verify which messages were removed.
#[test]
fn test_priority_truncation_removes_low_importance() {
let mut w = priority_window();
w.add_message(
Role::Assistant,
"Low importance assistant response from early turn",
);
w.add_message(Role::System, "High importance system message");
for i in 0..30 {
w.add_message(Role::User, format!("User message {}", i));
}
assert!(!w.is_overflowing());
}
#[test]
fn test_priority_preserves_pinned() {
let mut w = priority_window();
w.add_pinned(Role::System, "Pinned instruction");
for i in 0..30 {
w.add_message(Role::User, format!("Filler message {}", i));
}
let has_pinned = w.messages().any(|m| m.pinned);
assert!(has_pinned);
}
// Accepts either outcome: a summary was recorded or messages were dropped.
#[test]
fn test_compress_creates_summary() {
let mut w = compress_window();
for i in 0..30 {
w.add_message(Role::User, format!("Message about topic {}", i));
}
assert!(
!w.summaries().is_empty() || w.message_count() < 30,
"compression should either create summaries or reduce messages"
);
}
// Passes vacuously when no summary was produced.
#[test]
fn test_compress_summary_has_metadata() {
let mut w = compress_window();
for i in 0..30 {
w.add_message(Role::User, format!("Message about topic {}", i));
}
if let Some(summary) = w.summaries().first() {
assert!(summary.messages_compressed > 0);
assert!(summary.original_tokens > 0);
assert!(summary.token_count > 0);
}
}
#[test]
fn test_build_context_includes_messages() {
let mut w = default_window();
w.add_message(Role::User, "Hello AI");
let ctx = w.build_context();
assert!(ctx.contains("Hello AI"));
assert!(ctx.contains("[user]"));
}
// Passes vacuously when no summary was produced.
#[test]
fn test_build_context_includes_summaries() {
let mut w = compress_window();
for i in 0..30 {
w.add_message(Role::User, format!("Message {}", i));
}
let ctx = w.build_context();
if !w.summaries().is_empty() {
assert!(ctx.contains("[Summary]"));
}
}
#[test]
fn test_build_context_empty_window() {
let w = default_window();
let ctx = w.build_context();
assert!(ctx.is_empty());
}
#[test]
fn test_fill_ratio_empty_is_zero() {
let w = default_window();
assert!((w.fill_ratio() - 0.0).abs() < 1e-10);
}
#[test]
fn test_fill_ratio_increases_with_messages() {
let mut w = default_window();
let before = w.fill_ratio();
w.add_message(Role::User, "Some content to fill the window");
let after = w.fill_ratio();
assert!(after > before);
}
// Pinned messages are never evicted, so the stored tokens exceed the usable
// budget here and the ratio must be capped.
#[test]
fn test_fill_ratio_capped_at_one() {
let mut w = small_window();
for _ in 0..100 {
w.add_pinned(Role::User, "lots of pinned content that cannot be removed");
}
assert!(w.fill_ratio() <= 1.0);
}
#[test]
fn test_state_snapshot_fields() {
let mut w = default_window();
w.add_message(Role::User, "test");
let state = w.state();
assert!(state.total_tokens > 0);
assert_eq!(state.message_count, 1);
assert!(!state.is_overflowing);
}
#[test]
fn test_config_accessor() {
let w = default_window();
assert_eq!(w.config().max_tokens, 4096);
assert_eq!(w.config().reserved_for_response, 512);
}
#[test]
fn test_window_config_default() {
let config = WindowConfig::default();
assert_eq!(config.max_tokens, 4096);
assert_eq!(config.reserved_for_response, 512);
assert_eq!(
config.truncation_strategy,
TruncationStrategy::SlidingWindow
);
}
#[test]
fn test_multi_turn_conversation() {
let mut w = default_window();
for i in 0..5 {
w.add_message(Role::User, format!("User turn {}", i));
w.add_message(Role::Assistant, format!("Assistant turn {}", i));
}
assert_eq!(w.message_count(), 10);
}
#[test]
fn test_turn_numbers_are_sequential() {
let mut w = default_window();
w.add_message(Role::User, "First");
w.add_message(Role::Assistant, "Second");
w.add_message(Role::User, "Third");
let turns: Vec<usize> = w.messages().map(|m| m.turn).collect();
assert_eq!(turns, vec![0, 1, 2]);
}
// With a zero budget, any non-empty message overflows and is evicted at once.
#[test]
fn test_zero_max_tokens_always_overflows() {
let mut w = ContextWindow::new(WindowConfig {
max_tokens: 0,
reserved_for_response: 0,
truncation_strategy: TruncationStrategy::SlidingWindow,
});
w.add_message(Role::User, "Hello");
assert_eq!(
w.message_count(),
0,
"should truncate everything with 0 budget"
);
}
// The usable budget saturates at 0 when the reservation exceeds `max_tokens`.
#[test]
fn test_reserved_larger_than_max() {
let w = ContextWindow::new(WindowConfig {
max_tokens: 10,
reserved_for_response: 20,
truncation_strategy: TruncationStrategy::SlidingWindow,
});
assert_eq!(w.available_tokens(), 0);
}
}