mixtape_core/
conversation.rs1use crate::types::Message;
21
/// Token budget of the model a conversation is sent to.
///
/// Cheap to copy; passed by value to [`ConversationManager`] methods so each
/// strategy can decide how much history fits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ContextLimits {
    /// Maximum number of tokens the model's context window can hold.
    pub max_context_tokens: usize,
}

impl ContextLimits {
    /// Creates limits for a context window of `max_context_tokens` tokens.
    ///
    /// This is a `const fn`, so limits can be declared as constants.
    #[must_use]
    pub const fn new(max_context_tokens: usize) -> Self {
        Self { max_context_tokens }
    }
}
38
/// Snapshot of how much of the context window a conversation currently occupies.
///
/// Produced by `ConversationManager::context_usage`.
#[derive(Debug, Clone, PartialEq)]
pub struct ContextUsage {
    /// Estimated tokens of the messages selected as context.
    pub context_tokens: usize,
    /// Number of messages in the full stored history.
    pub total_messages: usize,
    /// Number of messages selected for the context window.
    pub context_messages: usize,
    /// The token limit the usage was measured against.
    pub max_context_tokens: usize,
    /// `context_tokens / max_context_tokens` as a *fraction* (0.5 means half
    /// full) despite the field name; `0.0` when `max_context_tokens` is 0.
    /// Can exceed 1.0 for managers that do not trim by tokens.
    pub usage_percentage: f32,
}
53
54pub type TokenEstimator<'a> = &'a dyn Fn(&[Message]) -> usize;
58
59pub trait ConversationManager: Send + Sync {
65 fn add_message(&mut self, message: Message);
67
68 fn messages_for_context(
77 &self,
78 limits: ContextLimits,
79 estimate_tokens: TokenEstimator<'_>,
80 ) -> Vec<Message>;
81
82 fn all_messages(&self) -> &[Message];
84
85 fn hydrate(&mut self, messages: Vec<Message>);
87
88 fn clear(&mut self);
90
91 fn context_usage(
93 &self,
94 limits: ContextLimits,
95 estimate_tokens: TokenEstimator<'_>,
96 ) -> ContextUsage {
97 let context_messages = self.messages_for_context(limits, estimate_tokens);
98 let context_tokens = estimate_tokens(&context_messages);
99 let max_context_tokens = limits.max_context_tokens;
100
101 ContextUsage {
102 context_tokens,
103 total_messages: self.all_messages().len(),
104 context_messages: context_messages.len(),
105 max_context_tokens,
106 usage_percentage: if max_context_tokens > 0 {
107 context_tokens as f32 / max_context_tokens as f32
108 } else {
109 0.0
110 },
111 }
112 }
113}
114
115#[derive(Debug, Clone)]
135pub struct SlidingWindowConversationManager {
136 messages: Vec<Message>,
137 system_prompt_reserve: f32,
139 response_reserve: f32,
141}
142
143impl Default for SlidingWindowConversationManager {
144 fn default() -> Self {
145 Self::new()
146 }
147}
148
149impl SlidingWindowConversationManager {
150 pub fn new() -> Self {
156 Self {
157 messages: Vec::new(),
158 system_prompt_reserve: 0.10,
159 response_reserve: 0.20,
160 }
161 }
162
163 pub fn with_reserve(system_prompt_reserve: f32, response_reserve: f32) -> Self {
169 Self {
170 messages: Vec::new(),
171 system_prompt_reserve: system_prompt_reserve.clamp(0.0, 0.5),
172 response_reserve: response_reserve.clamp(0.0, 0.5),
173 }
174 }
175
176 fn available_tokens(&self, limits: ContextLimits) -> usize {
178 let max = limits.max_context_tokens;
179 let reserved = (max as f32 * (self.system_prompt_reserve + self.response_reserve)) as usize;
180 max.saturating_sub(reserved)
181 }
182}
183
184impl ConversationManager for SlidingWindowConversationManager {
185 fn add_message(&mut self, message: Message) {
186 self.messages.push(message);
187 }
188
189 fn messages_for_context(
190 &self,
191 limits: ContextLimits,
192 estimate_tokens: TokenEstimator<'_>,
193 ) -> Vec<Message> {
194 let available = self.available_tokens(limits);
195
196 let mut result = Vec::new();
198 let mut total_tokens = 0;
199
200 for message in self.messages.iter().rev() {
201 let msg_tokens = estimate_tokens(std::slice::from_ref(message));
202
203 if total_tokens + msg_tokens <= available {
204 result.push(message.clone());
205 total_tokens += msg_tokens;
206 } else {
207 break;
209 }
210 }
211
212 result.reverse();
214 result
215 }
216
217 fn all_messages(&self) -> &[Message] {
218 &self.messages
219 }
220
221 fn hydrate(&mut self, messages: Vec<Message>) {
222 self.messages = messages;
223 }
224
225 fn clear(&mut self) {
226 self.messages.clear();
227 }
228}
229
230#[derive(Debug, Clone)]
246pub struct SimpleConversationManager {
247 messages: Vec<Message>,
248 max_messages: usize,
249}
250
251impl SimpleConversationManager {
252 pub fn new(max_messages: usize) -> Self {
254 Self {
255 messages: Vec::new(),
256 max_messages,
257 }
258 }
259}
260
261impl ConversationManager for SimpleConversationManager {
262 fn add_message(&mut self, message: Message) {
263 self.messages.push(message);
264 }
265
266 fn messages_for_context(
267 &self,
268 _limits: ContextLimits,
269 _estimate_tokens: TokenEstimator<'_>,
270 ) -> Vec<Message> {
271 let start = self.messages.len().saturating_sub(self.max_messages);
272 self.messages[start..].to_vec()
273 }
274
275 fn all_messages(&self) -> &[Message] {
276 &self.messages
277 }
278
279 fn hydrate(&mut self, messages: Vec<Message>) {
280 self.messages = messages;
281 }
282
283 fn clear(&mut self) {
284 self.messages.clear();
285 }
286}
287
288#[derive(Debug, Clone, Default)]
303pub struct NoOpConversationManager {
304 messages: Vec<Message>,
305}
306
307impl NoOpConversationManager {
308 pub fn new() -> Self {
310 Self {
311 messages: Vec::new(),
312 }
313 }
314}
315
316impl ConversationManager for NoOpConversationManager {
317 fn add_message(&mut self, message: Message) {
318 self.messages.push(message);
319 }
320
321 fn messages_for_context(
322 &self,
323 _limits: ContextLimits,
324 _estimate_tokens: TokenEstimator<'_>,
325 ) -> Vec<Message> {
326 self.messages.clone()
327 }
328
329 fn all_messages(&self) -> &[Message] {
330 &self.messages
331 }
332
333 fn hydrate(&mut self, messages: Vec<Message>) {
334 self.messages = messages;
335 }
336
337 fn clear(&mut self) {
338 self.messages.clear();
339 }
340}
341
342pub type BoxedConversationManager = Box<dyn ConversationManager>;
344
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, Role};

    /// Builds a user message holding a single text block.
    fn make_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: vec![ContentBlock::Text(text.to_string())],
        }
    }

    /// Deterministic stand-in for a tokenizer: text length plus 4 tokens of
    /// per-message overhead.
    fn estimate_tokens(messages: &[Message]) -> usize {
        messages.iter().map(|m| m.text().len() + 4).sum()
    }

    #[test]
    fn test_sliding_window_basic() {
        let mut manager = SlidingWindowConversationManager::new();
        let limits = ContextLimits::new(1000);

        manager.add_message(make_message("Hello"));
        manager.add_message(make_message("World"));

        let context = manager.messages_for_context(limits, &estimate_tokens);
        assert_eq!(context.len(), 2);
    }

    #[test]
    fn test_sliding_window_truncates() {
        // No reserves, so the whole 50-token window is available to messages.
        let mut manager = SlidingWindowConversationManager::with_reserve(0.0, 0.0);
        let limits = ContextLimits::new(50);

        // The long messages cost 30 tokens each and "Short" costs 9: the two
        // newest fit (39 tokens), adding the oldest would need 69.
        manager.add_message(make_message("This is a long message one"));
        manager.add_message(make_message("This is a long message two"));
        manager.add_message(make_message("Short"));

        let context = manager.messages_for_context(limits, &estimate_tokens);
        assert_eq!(context.len(), 2);
        assert_eq!(context[0].text(), "This is a long message two");
        assert_eq!(context[1].text(), "Short");
        // Trimming affects only the context, never the stored history.
        assert_eq!(manager.all_messages().len(), 3);
    }

    #[test]
    fn test_sliding_window_default_reserves() {
        // Default reserves hold back 10% + 20%, leaving 70 of 100 tokens.
        let mut manager = SlidingWindowConversationManager::new();
        let limits = ContextLimits::new(100);

        manager.add_message(make_message(&"a".repeat(66))); // exactly 70 tokens
        assert_eq!(manager.messages_for_context(limits, &estimate_tokens).len(), 1);

        manager.add_message(make_message("b")); // 5 more tokens: 75 no longer fits
        let context = manager.messages_for_context(limits, &estimate_tokens);
        assert_eq!(context.len(), 1);
        assert_eq!(context[0].text(), "b");
    }

    #[test]
    fn test_sliding_window_newest_too_large() {
        let mut manager = SlidingWindowConversationManager::with_reserve(0.0, 0.0);
        let limits = ContextLimits::new(10);

        manager.add_message(make_message("Hi")); // 6 tokens
        manager.add_message(make_message("This message is too long")); // 28 tokens

        // The window never skips over the newest message, so nothing is returned.
        assert!(manager.messages_for_context(limits, &estimate_tokens).is_empty());
    }

    #[test]
    fn test_sliding_window_reserves_are_clamped() {
        // 0.9 is clamped to 0.5, leaving 50 of 100 tokens (not 10).
        let mut manager = SlidingWindowConversationManager::with_reserve(0.9, 0.0);
        manager.add_message(make_message(&"a".repeat(40))); // 44 tokens

        let context = manager.messages_for_context(ContextLimits::new(100), &estimate_tokens);
        assert_eq!(context.len(), 1);
    }

    #[test]
    fn test_sliding_window_hydrate() {
        let mut manager = SlidingWindowConversationManager::new();

        let messages = vec![
            make_message("One"),
            make_message("Two"),
            make_message("Three"),
        ];

        manager.hydrate(messages);
        assert_eq!(manager.all_messages().len(), 3);
        assert_eq!(manager.all_messages()[2].text(), "Three");
    }

    #[test]
    fn test_simple_manager_limits() {
        let mut manager = SimpleConversationManager::new(2);
        let limits = ContextLimits::new(10000);

        manager.add_message(make_message("One"));
        manager.add_message(make_message("Two"));
        manager.add_message(make_message("Three"));
        manager.add_message(make_message("Four"));

        // The full history is kept...
        assert_eq!(manager.all_messages().len(), 4);

        // ...but only the last two messages are sent.
        let context = manager.messages_for_context(limits, &estimate_tokens);
        assert_eq!(context.len(), 2);
        assert_eq!(context[0].text(), "Three");
        assert_eq!(context[1].text(), "Four");
    }

    #[test]
    fn test_simple_manager_zero_limit() {
        let mut manager = SimpleConversationManager::new(0);
        manager.add_message(make_message("One"));

        let context = manager.messages_for_context(ContextLimits::new(10000), &estimate_tokens);
        assert!(context.is_empty());
        assert_eq!(manager.all_messages().len(), 1);
    }

    #[test]
    fn test_noop_manager() {
        let mut manager = NoOpConversationManager::new();
        let limits = ContextLimits::new(10000);

        manager.add_message(make_message("One"));
        manager.add_message(make_message("Two"));
        manager.add_message(make_message("Three"));

        let context = manager.messages_for_context(limits, &estimate_tokens);
        assert_eq!(context.len(), 3);
    }

    #[test]
    fn test_context_usage() {
        let mut manager = SlidingWindowConversationManager::new();
        let limits = ContextLimits::new(1000);

        manager.add_message(make_message("Hello"));
        manager.add_message(make_message("World"));

        let usage = manager.context_usage(limits, &estimate_tokens);
        assert_eq!(usage.total_messages, 2);
        assert_eq!(usage.context_messages, 2);
        assert_eq!(usage.context_tokens, 18);
        assert_eq!(usage.max_context_tokens, 1000);
        assert!(usage.usage_percentage > 0.0);
        assert!(usage.usage_percentage < 1.0);
    }

    #[test]
    fn test_context_usage_zero_limit() {
        let mut manager = NoOpConversationManager::new();
        manager.add_message(make_message("Hello"));

        // A zero limit must not divide by zero.
        let usage = manager.context_usage(ContextLimits::new(0), &estimate_tokens);
        assert_eq!(usage.context_tokens, 9);
        assert_eq!(usage.usage_percentage, 0.0);
    }

    #[test]
    fn test_boxed_manager() {
        let mut manager: BoxedConversationManager = Box::new(SimpleConversationManager::new(1));
        manager.add_message(make_message("One"));
        manager.add_message(make_message("Two"));

        let context = manager.messages_for_context(ContextLimits::new(10000), &estimate_tokens);
        assert_eq!(context.len(), 1);
        assert_eq!(context[0].text(), "Two");
    }

    #[test]
    fn test_clear() {
        let mut manager = SlidingWindowConversationManager::new();

        manager.add_message(make_message("Hello"));
        manager.add_message(make_message("World"));
        assert_eq!(manager.all_messages().len(), 2);

        manager.clear();
        assert_eq!(manager.all_messages().len(), 0);
    }
}