/// Policy that decides which conversation tokens are discarded when a
/// `ContextWindow` holds more tokens than its budget allows.
///
/// Note: `ContextWindow::truncate_to_fit` in this file currently handles
/// `SlidingWindow` and `Summarize` exactly like `TruncateLeft` (tokens are
/// drained from the front of the conversation); no summarization is performed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
/// Drop tokens from the front of the conversation (the earliest appended).
TruncateLeft,
/// Drop tokens from the back of the conversation (the latest appended).
TruncateRight,
/// Currently handled the same way as `TruncateLeft`.
SlidingWindow,
/// Currently handled the same way as `TruncateLeft`; no summary is produced.
Summarize,
}
/// Error returned by fallible `ContextWindow` operations (in this file:
/// `set_system_prompt`). Wraps a human-readable message.
#[derive(Debug)]
pub struct ContextError(String);

impl std::fmt::Display for ContextError {
    /// Renders the error as `ContextError: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ContextError(message) = self;
        f.write_str("ContextError: ")?;
        f.write_str(message)
    }
}

impl std::error::Error for ContextError {}
pub struct ContextWindow {
pub max_tokens: usize,
pub system_tokens: Vec<u32>,
pub conversation: Vec<u32>,
pub strategy: TruncationStrategy,
}
impl ContextWindow {
    /// Creates an empty window with a budget of `max_tokens` tokens and the
    /// given truncation `strategy`.
    pub fn new(max_tokens: usize, strategy: TruncationStrategy) -> Self {
        Self {
            max_tokens,
            system_tokens: Vec::new(),
            conversation: Vec::new(),
            strategy,
        }
    }

    /// Replaces the system prompt, then trims the conversation so the whole
    /// window still fits in `max_tokens`.
    ///
    /// # Errors
    ///
    /// Returns a [`ContextError`] and leaves the window unchanged when
    /// `tokens` alone is longer than `max_tokens`.
    pub fn set_system_prompt(&mut self, tokens: Vec<u32>) -> Result<(), ContextError> {
        if tokens.len() > self.max_tokens {
            return Err(ContextError(format!(
                "system prompt ({} tokens) exceeds max_tokens ({})",
                tokens.len(),
                self.max_tokens
            )));
        }
        self.system_tokens = tokens;
        // A longer prompt leaves less room for the conversation.
        self.truncate_to_fit();
        Ok(())
    }

    /// Appends `tokens` to the conversation and truncates to the budget.
    ///
    /// Returns how many of the tokens in `tokens` are still present in the
    /// window after truncation. With a left-truncating strategy older tokens
    /// are evicted first, so the new tokens only start being dropped once the
    /// whole previous conversation has been removed; with
    /// [`TruncationStrategy::TruncateRight`] the new tokens are the first to
    /// be cut.
    pub fn append(&mut self, tokens: &[u32]) -> usize {
        let previous_len = self.conversation.len();
        self.conversation.extend_from_slice(tokens);
        let removed = self.truncate_to_fit();

        // Work out how many of the removed tokens came from `tokens` rather
        // than from the conversation that was already stored.
        let dropped_from_new = match self.strategy {
            TruncationStrategy::TruncateLeft
            | TruncationStrategy::SlidingWindow
            | TruncationStrategy::Summarize => removed.saturating_sub(previous_len),
            TruncationStrategy::TruncateRight => removed.min(tokens.len()),
        };
        // `removed <= previous_len + tokens.len()`, so this cannot underflow.
        tokens.len() - dropped_from_new
    }

    /// Trims the conversation so that system prompt plus conversation fit in
    /// `max_tokens`, and returns the number of conversation tokens removed
    /// (0 when everything already fits).
    ///
    /// `TruncateRight` removes tokens from the end of the conversation; every
    /// other strategy removes them from the front.
    pub fn truncate_to_fit(&mut self) -> usize {
        let capacity_for_conv = self.max_tokens.saturating_sub(self.system_tokens.len());
        let excess = self.conversation.len().saturating_sub(capacity_for_conv);
        if excess == 0 {
            return 0;
        }
        match self.strategy {
            TruncationStrategy::TruncateLeft
            | TruncationStrategy::SlidingWindow
            | TruncationStrategy::Summarize => {
                self.conversation.drain(..excess);
            }
            TruncationStrategy::TruncateRight => {
                self.conversation.truncate(capacity_for_conv);
            }
        }
        excess
    }

    /// Returns a freshly allocated vector holding the system prompt tokens
    /// followed by the conversation tokens.
    pub fn tokens(&self) -> Vec<u32> {
        let mut result = Vec::with_capacity(self.len());
        result.extend_from_slice(&self.system_tokens);
        result.extend_from_slice(&self.conversation);
        result
    }

    /// Total number of tokens currently held (system prompt + conversation).
    pub fn len(&self) -> usize {
        self.system_tokens.len() + self.conversation.len()
    }

    /// Returns `true` when there are neither system nor conversation tokens.
    pub fn is_empty(&self) -> bool {
        self.system_tokens.is_empty() && self.conversation.is_empty()
    }

    /// Number of tokens that can still be added before the budget is reached
    /// (0 when the window is at or over budget).
    pub fn remaining_capacity(&self) -> usize {
        self.max_tokens.saturating_sub(self.len())
    }

    /// Returns `true` when the token count has reached (or exceeds)
    /// `max_tokens`.
    pub fn is_at_limit(&self) -> bool {
        self.len() >= self.max_tokens
    }

    /// Removes every conversation token; the system prompt is kept.
    pub fn clear_conversation(&mut self) {
        self.conversation.clear();
    }

    /// Fraction of the budget in use, computed as `len() / max_tokens`.
    ///
    /// Returns `0.0` when `max_tokens` is zero to avoid dividing by zero.
    pub fn utilization(&self) -> f32 {
        if self.max_tokens == 0 {
            return 0.0;
        }
        // Lossy for very large counts, which is acceptable for a ratio.
        self.len() as f32 / self.max_tokens as f32
    }
}
/// One recorded turn of a conversation together with its token ids.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationTurn {
    /// Speaker label supplied by the caller (e.g. `"user"`, `"assistant"`).
    pub role: String,
    /// Content string supplied by the caller for this turn.
    pub content: String,
    /// Token ids supplied by the caller for this turn.
    pub token_ids: Vec<u32>,
}
/// Conversation history paired with a token-budgeted `ContextWindow`.
///
/// `turns` records every turn passed to `add_turn`, while `window` holds the
/// token sequence, which may be truncated to stay within the window's budget.
pub struct ConversationContext {
/// Token buffer that enforces the token budget.
window: ContextWindow,
/// Turns in insertion order; not trimmed when the window drops tokens.
turns: Vec<ConversationTurn>,
}
impl ConversationContext {
pub fn new(max_tokens: usize) -> Self {
Self {
window: ContextWindow::new(max_tokens, TruncationStrategy::TruncateLeft),
turns: Vec::new(),
}
}
pub fn add_turn(&mut self, role: &str, content: &str, token_ids: Vec<u32>) {
self.window.append(&token_ids);
self.turns.push(ConversationTurn {
role: role.to_string(),
content: content.to_string(),
token_ids,
});
}
pub fn build_tokens(&self) -> Vec<u32> {
self.window.tokens()
}
pub fn turn_count(&self) -> usize {
self.turns.len()
}
pub fn total_tokens(&self) -> usize {
self.window.len()
}
pub fn is_full(&self) -> bool {
self.window.is_at_limit()
}
pub fn clear(&mut self) {
self.turns.clear();
self.window.clear_conversation();
}
pub fn last_turn(&self) -> Option<&ConversationTurn> {
self.turns.last()
}
pub fn utilization(&self) -> f32 {
self.window.utilization()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Appending fewer tokens than the budget keeps all of them.
    #[test]
    fn test_context_window_append_within_limit() {
        let mut ctx_window = ContextWindow::new(100, TruncationStrategy::TruncateLeft);
        let input: [u32; 5] = [1, 2, 3, 4, 5];
        let appended = ctx_window.append(&input);
        assert!(appended > 0, "should append tokens when within limit");
        assert_eq!(ctx_window.conversation.len(), 5);
        assert_eq!(ctx_window.len(), 5);
    }

    /// Left truncation evicts the earliest tokens and keeps the newest.
    #[test]
    fn test_context_window_truncate_left() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateLeft);
        ctx_window.append(&[1, 2, 3, 4, 5]);
        assert_eq!(ctx_window.conversation.len(), 5);

        ctx_window.append(&[6, 7]);
        let conv = &ctx_window.conversation;
        assert_eq!(conv.len(), 5, "should still be at max after truncation");
        let newest = conv.last().copied().expect("must have tokens");
        assert_eq!(newest, 7, "newest token should be 7");
        assert!(
            conv.iter().all(|&tok| tok != 1),
            "token 1 should have been truncated"
        );
    }

    /// Right truncation keeps the earliest tokens and rejects the overflow.
    #[test]
    fn test_context_window_truncate_right() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateRight);
        for batch in [&[1u32, 2, 3, 4, 5][..], &[6u32, 7][..]].iter() {
            ctx_window.append(batch);
        }
        let conv = &ctx_window.conversation;
        assert_eq!(conv.len(), 5);
        assert_eq!(conv[0], 1, "token 1 should be preserved with TruncateRight");
        assert!(
            conv.iter().all(|&tok| tok != 6),
            "token 6 should have been truncated"
        );
    }

    /// Truncation only ever affects the conversation, never the system prompt.
    #[test]
    fn test_context_window_system_prompt_preserved() {
        let mut ctx_window = ContextWindow::new(10, TruncationStrategy::TruncateLeft);
        let prompt_result = ctx_window.set_system_prompt(vec![100, 200, 300]);
        prompt_result.expect("system prompt should fit");

        ctx_window.append(&[1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(ctx_window.len(), 10);

        ctx_window.append(&[8, 9]);
        let all_tokens = ctx_window.tokens();
        assert_eq!(all_tokens.len(), 10);
        assert_eq!(all_tokens[0], 100, "system token 0 must be preserved");
        assert_eq!(all_tokens[1], 200, "system token 1 must be preserved");
        assert_eq!(all_tokens[2], 300, "system token 2 must be preserved");
    }

    /// Remaining capacity accounts for both conversation and system tokens.
    #[test]
    fn test_context_window_remaining_capacity() {
        let mut ctx_window = ContextWindow::new(20, TruncationStrategy::TruncateLeft);
        assert_eq!(ctx_window.remaining_capacity(), 20);

        ctx_window.append(&[1, 2, 3]);
        assert_eq!(ctx_window.remaining_capacity(), 17);

        let prompt_result = ctx_window.set_system_prompt(vec![10, 20]);
        prompt_result.expect("fits");
        assert_eq!(ctx_window.remaining_capacity(), 15);
    }

    /// A system prompt longer than the whole budget is rejected.
    #[test]
    fn test_context_window_system_prompt_too_large() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateLeft);
        let oversized: Vec<u32> = vec![1, 2, 3, 4, 5, 6];
        let outcome = ctx_window.set_system_prompt(oversized);
        assert!(
            outcome.is_err(),
            "system prompt larger than max_tokens should error"
        );
    }

    /// Turns are recorded in order and their tokens counted.
    #[test]
    fn test_conversation_context_add_turn() {
        let mut context = ConversationContext::new(200);
        context.add_turn("user", "Hello!", vec![10, 20, 30]);
        context.add_turn("assistant", "Hi there!", vec![40, 50, 60, 70]);

        assert_eq!(context.turn_count(), 2);
        assert_eq!(context.total_tokens(), 7, "3 + 4 = 7 tokens total");

        let latest = context.last_turn().expect("must have a last turn");
        assert_eq!(latest.role, "assistant");
        assert_eq!(latest.content, "Hi there!");
    }

    /// Built token sequence concatenates turn tokens in insertion order.
    #[test]
    fn test_conversation_context_build_tokens() {
        let mut context = ConversationContext::new(100);
        context.add_turn("user", "A", vec![1, 2]);
        context.add_turn("assistant", "B", vec![3, 4, 5]);

        let built = context.build_tokens();
        let expected: Vec<u32> = vec![1, 2, 3, 4, 5];
        assert_eq!(built, expected, "tokens should be in turn order");
    }

    /// Utilization goes 0.0 -> 0.5 -> 1.0 as the window fills.
    #[test]
    fn test_context_utilization() {
        let mut ctx_window = ContextWindow::new(100, TruncationStrategy::TruncateLeft);
        let fifty: Vec<u32> = (0u32..50).collect();

        assert!(
            (ctx_window.utilization() - 0.0).abs() < f32::EPSILON,
            "empty window has 0.0 utilization"
        );

        ctx_window.append(&fifty);
        assert!(
            (ctx_window.utilization() - 0.5).abs() < f32::EPSILON,
            "50/100 = 0.5 utilization"
        );

        ctx_window.append(&fifty);
        assert!(
            (ctx_window.utilization() - 1.0).abs() < f32::EPSILON,
            "full window = 1.0 utilization"
        );
    }
}