use crate::types::Message;
/// Token budget that a conversation manager must respect when selecting
/// messages for the model context.
///
/// This is a small `Copy` value type, so it also derives `PartialEq`/`Eq`
/// to allow direct comparison (for example in `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ContextLimits {
    /// Upper bound on the number of tokens the context may hold.
    pub max_context_tokens: usize,
}

impl ContextLimits {
    /// Creates limits with the given maximum number of context tokens.
    ///
    /// Declared `const` so that limits can be defined as constants.
    pub const fn new(max_context_tokens: usize) -> Self {
        Self { max_context_tokens }
    }
}
/// Snapshot of how much of the context budget a conversation occupies.
///
/// The field descriptions below reflect how
/// `ConversationManager::context_usage` fills this struct in.
#[derive(Clone, Debug)]
pub struct ContextUsage {
    /// Estimated token count of the messages selected for the context.
    pub context_tokens: usize,
    /// Number of messages stored by the manager overall.
    pub total_messages: usize,
    /// Number of messages selected for the context.
    pub context_messages: usize,
    /// The token limit the usage was measured against.
    pub max_context_tokens: usize,
    /// `context_tokens / max_context_tokens` as a fraction (not scaled to
    /// 0–100), or `0.0` when the limit is zero.
    pub usage_percentage: f32,
}
/// Borrowed callback that returns a token estimate for a slice of messages.
///
/// Managers call it both on single-message slices and on whole selections.
pub type TokenEstimator<'a> = &'a dyn Fn(&[Message]) -> usize;
/// Storage and context-selection strategy for a conversation history.
///
/// Implementors must be `Send + Sync`.
pub trait ConversationManager: Send + Sync {
    /// Records one more message in the history.
    fn add_message(&mut self, message: Message);

    /// Returns owned copies of the messages that should be placed in the
    /// model context, given `limits` and the `estimate_tokens` callback.
    fn messages_for_context(
        &self,
        limits: ContextLimits,
        estimate_tokens: TokenEstimator<'_>,
    ) -> Vec<Message>;

    /// Borrows every stored message, whether or not it would be selected
    /// for the context.
    fn all_messages(&self) -> &[Message];

    /// Loads `messages` into the manager as its history.
    fn hydrate(&mut self, messages: Vec<Message>);

    /// Discards the stored history.
    fn clear(&mut self);

    /// Summarises how much of the budget the selected context consumes.
    ///
    /// `usage_percentage` is the fraction `context_tokens / max_context_tokens`
    /// (not multiplied by 100). It is `0.0` when the limit is zero, and it can
    /// exceed `1.0` if `messages_for_context` returns more than the limit.
    fn context_usage(
        &self,
        limits: ContextLimits,
        estimate_tokens: TokenEstimator<'_>,
    ) -> ContextUsage {
        let selected = self.messages_for_context(limits, estimate_tokens);
        let used_tokens = estimate_tokens(&selected);
        let ceiling = limits.max_context_tokens;

        // Guard the division: a zero limit reports zero usage.
        let ratio = match ceiling {
            0 => 0.0,
            _ => used_tokens as f32 / ceiling as f32,
        };

        ContextUsage {
            context_tokens: used_tokens,
            total_messages: self.all_messages().len(),
            context_messages: selected.len(),
            max_context_tokens: ceiling,
            usage_percentage: ratio,
        }
    }
}
/// Conversation manager that budgets the context by tokens, holding back a
/// fraction of the limit before any message is counted.
#[derive(Clone, Debug)]
pub struct SlidingWindowConversationManager {
    /// Stored history, in the order messages were added or hydrated.
    messages: Vec<Message>,
    /// Fraction of the token limit held back (named for the system prompt).
    system_prompt_reserve: f32,
    /// Fraction of the token limit held back (named for the response).
    response_reserve: f32,
}

impl Default for SlidingWindowConversationManager {
    /// Equivalent to [`SlidingWindowConversationManager::new`].
    fn default() -> Self {
        SlidingWindowConversationManager::new()
    }
}

impl SlidingWindowConversationManager {
    /// Creates an empty manager with reserves of 10% and 20% of the limit.
    pub fn new() -> Self {
        // Both defaults lie inside the clamp range of `with_reserve`, so
        // they are stored unchanged.
        Self::with_reserve(0.10, 0.20)
    }

    /// Creates an empty manager with custom reserve fractions.
    ///
    /// Each fraction is clamped to `0.0..=0.5` independently.
    pub fn with_reserve(system_prompt_reserve: f32, response_reserve: f32) -> Self {
        let system_prompt_reserve = system_prompt_reserve.clamp(0.0, 0.5);
        let response_reserve = response_reserve.clamp(0.0, 0.5);
        Self {
            messages: Vec::new(),
            system_prompt_reserve,
            response_reserve,
        }
    }

    /// Number of tokens left for messages once both reserves are subtracted
    /// from `limits.max_context_tokens`.
    fn available_tokens(&self, limits: ContextLimits) -> usize {
        let ceiling = limits.max_context_tokens;
        let reserved_fraction = self.system_prompt_reserve + self.response_reserve;
        // The float-to-integer cast truncates toward zero.
        let reserved_tokens = (ceiling as f32 * reserved_fraction) as usize;
        ceiling.saturating_sub(reserved_tokens)
    }
}
impl ConversationManager for SlidingWindowConversationManager {
    /// Appends `message` to the end of the history.
    fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Returns the longest run of most recent messages whose combined
    /// estimate fits in the available budget, in their stored order.
    ///
    /// Messages are measured one at a time, newest first, and selection stops
    /// at the first message that does not fit. If the newest message alone
    /// exceeds the budget, the result is empty.
    fn messages_for_context(
        &self,
        limits: ContextLimits,
        estimate_tokens: TokenEstimator<'_>,
    ) -> Vec<Message> {
        let available = self.available_tokens(limits);
        let mut total_tokens: usize = 0;
        // Index of the oldest message that still fits; `len()` means none.
        let mut start = self.messages.len();

        for (index, message) in self.messages.iter().enumerate().rev() {
            let msg_tokens = estimate_tokens(std::slice::from_ref(message));
            // `checked_add` keeps a very large estimate from overflowing: a
            // plain `+` panics in debug builds and wraps in release builds,
            // where the wrapped sum could wrongly pass the budget check.
            match total_tokens.checked_add(msg_tokens) {
                Some(sum) if sum <= available => {
                    total_tokens = sum;
                    start = index;
                }
                _ => break,
            }
        }

        // The selection is always a contiguous suffix, so copy it directly
        // rather than pushing in reverse and flipping afterwards.
        self.messages[start..].to_vec()
    }

    /// Borrows the full stored history.
    fn all_messages(&self) -> &[Message] {
        &self.messages
    }

    /// Replaces the stored history with `messages`.
    fn hydrate(&mut self, messages: Vec<Message>) {
        self.messages = messages;
    }

    /// Removes every stored message.
    fn clear(&mut self) {
        self.messages.clear();
    }
}
/// Conversation manager that bounds the context by message count.
#[derive(Clone, Debug)]
pub struct SimpleConversationManager {
    /// Stored history, in the order messages were added or hydrated.
    messages: Vec<Message>,
    /// Maximum number of messages handed out for the context.
    max_messages: usize,
}

impl SimpleConversationManager {
    /// Creates an empty manager that exposes at most `max_messages` messages
    /// as context.
    pub fn new(max_messages: usize) -> Self {
        let messages = Vec::new();
        Self {
            messages,
            max_messages,
        }
    }
}
impl ConversationManager for SimpleConversationManager {
    /// Appends `message` to the end of the history.
    fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Returns copies of the last `max_messages` stored messages, or all of
    /// them when fewer are stored.
    ///
    /// The token limits and the estimator are ignored by this manager.
    fn messages_for_context(
        &self,
        _limits: ContextLimits,
        _estimate_tokens: TokenEstimator<'_>,
    ) -> Vec<Message> {
        // Number of leading (oldest) messages that fall outside the window.
        let skipped = self.messages.len().saturating_sub(self.max_messages);
        self.messages.iter().skip(skipped).cloned().collect()
    }

    /// Borrows the full stored history.
    fn all_messages(&self) -> &[Message] {
        self.messages.as_slice()
    }

    /// Replaces the stored history with `messages`.
    fn hydrate(&mut self, messages: Vec<Message>) {
        self.messages = messages;
    }

    /// Removes every stored message.
    fn clear(&mut self) {
        self.messages.clear();
    }
}
/// Conversation manager that stores messages and applies no trimming.
#[derive(Clone, Debug, Default)]
pub struct NoOpConversationManager {
    /// Stored history, in the order messages were added or hydrated.
    messages: Vec<Message>,
}

impl NoOpConversationManager {
    /// Creates a manager with an empty history.
    pub fn new() -> Self {
        // The derived `Default` yields an empty `Vec`, same as `Vec::new()`.
        Self::default()
    }
}
impl ConversationManager for NoOpConversationManager {
    /// Appends `message` to the end of the history.
    fn add_message(&mut self, message: Message) {
        self.messages.push(message);
    }

    /// Returns a copy of the entire history.
    ///
    /// The token limits and the estimator are ignored by this manager.
    fn messages_for_context(
        &self,
        _limits: ContextLimits,
        _estimate_tokens: TokenEstimator<'_>,
    ) -> Vec<Message> {
        self.messages.to_vec()
    }

    /// Borrows the full stored history.
    fn all_messages(&self) -> &[Message] {
        self.messages.as_slice()
    }

    /// Replaces the stored history with `messages`.
    fn hydrate(&mut self, messages: Vec<Message>) {
        self.messages = messages;
    }

    /// Removes every stored message.
    fn clear(&mut self) {
        self.messages.clear();
    }
}
/// Owned, dynamically dispatched conversation manager.
pub type BoxedConversationManager = Box<dyn ConversationManager>;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, Role};

    /// Builds a user message holding a single text block.
    fn make_message(text: &str) -> Message {
        let content = vec![ContentBlock::Text(text.to_string())];
        Message {
            role: Role::User,
            content,
        }
    }

    /// Toy estimator: each message costs its text length plus four.
    fn estimate_tokens(messages: &[Message]) -> usize {
        let mut total = 0;
        for message in messages {
            total += message.text().len() + 4;
        }
        total
    }

    #[test]
    fn test_sliding_window_basic() {
        let mut manager = SlidingWindowConversationManager::new();
        for text in ["Hello", "World"] {
            manager.add_message(make_message(text));
        }

        let context = manager.messages_for_context(ContextLimits::new(1000), &estimate_tokens);
        assert_eq!(context.len(), 2);
    }

    #[test]
    fn test_sliding_window_truncates() {
        // No reserve, so the whole 50-token limit is available to messages.
        let mut manager = SlidingWindowConversationManager::with_reserve(0.0, 0.0);
        for text in [
            "This is a long message one",
            "This is a long message two",
            "Short",
        ] {
            manager.add_message(make_message(text));
        }

        let context = manager.messages_for_context(ContextLimits::new(50), &estimate_tokens);
        assert!(context.len() < 3);
        let newest = context.last().unwrap();
        assert_eq!(newest.text(), "Short");
    }

    #[test]
    fn test_sliding_window_hydrate() {
        let mut manager = SlidingWindowConversationManager::new();
        let history: Vec<Message> = ["One", "Two", "Three"]
            .iter()
            .copied()
            .map(make_message)
            .collect();

        manager.hydrate(history);
        assert_eq!(manager.all_messages().len(), 3);
    }

    #[test]
    fn test_simple_manager_limits() {
        let mut manager = SimpleConversationManager::new(2);
        for text in ["One", "Two", "Three", "Four"] {
            manager.add_message(make_message(text));
        }
        // Everything is stored; only the context view is limited.
        assert_eq!(manager.all_messages().len(), 4);

        let context = manager.messages_for_context(ContextLimits::new(10000), &estimate_tokens);
        assert_eq!(context.len(), 2);
        assert_eq!(context[0].text(), "Three");
        assert_eq!(context[1].text(), "Four");
    }

    #[test]
    fn test_noop_manager() {
        let mut manager = NoOpConversationManager::new();
        for text in ["One", "Two", "Three"] {
            manager.add_message(make_message(text));
        }

        let context = manager.messages_for_context(ContextLimits::new(10000), &estimate_tokens);
        assert_eq!(context.len(), 3);
    }

    #[test]
    fn test_context_usage() {
        let mut manager = SlidingWindowConversationManager::new();
        for text in ["Hello", "World"] {
            manager.add_message(make_message(text));
        }

        let usage = manager.context_usage(ContextLimits::new(1000), &estimate_tokens);
        assert_eq!(usage.total_messages, 2);
        assert_eq!(usage.context_messages, 2);
        assert!(usage.usage_percentage > 0.0);
        assert!(usage.usage_percentage < 1.0);
    }

    #[test]
    fn test_clear() {
        let mut manager = SlidingWindowConversationManager::new();
        for text in ["Hello", "World"] {
            manager.add_message(make_message(text));
        }
        assert_eq!(manager.all_messages().len(), 2);

        manager.clear();
        assert!(manager.all_messages().is_empty());
    }
}