/// Policy selecting which conversation tokens are discarded when a
/// [`ContextWindow`] holds more tokens than its budget allows.
///
/// As applied by `ContextWindow::truncate_to_fit`, only
/// [`TruncationStrategy::TruncateRight`] trims the end of the conversation;
/// every other variant trims the front.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TruncationStrategy {
    /// Discard tokens from the front of the conversation (the oldest ones).
    TruncateLeft,
    /// Discard tokens from the end of the conversation (the newest ones).
    TruncateRight,
    /// Handled exactly like `TruncateLeft` by `ContextWindow::truncate_to_fit`.
    SlidingWindow,
    /// Handled exactly like `TruncateLeft` by `ContextWindow::truncate_to_fit`;
    /// no summarization step is performed there.
    Summarize,
}
44
/// Error type wrapping a human-readable message describing why a
/// context-window operation was rejected.
#[derive(Debug)]
pub struct ContextError(String);

impl std::error::Error for ContextError {}

impl std::fmt::Display for ContextError {
    /// Writes the wrapped message prefixed with `ContextError: `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        write!(f, "ContextError: {}", message)
    }
}
60
61pub struct ContextWindow {
71 pub max_tokens: usize,
73 pub system_tokens: Vec<u32>,
75 pub conversation: Vec<u32>,
77 pub strategy: TruncationStrategy,
79}
80
81impl ContextWindow {
82 pub fn new(max_tokens: usize, strategy: TruncationStrategy) -> Self {
84 Self {
85 max_tokens,
86 system_tokens: Vec::new(),
87 conversation: Vec::new(),
88 strategy,
89 }
90 }
91
92 pub fn set_system_prompt(&mut self, tokens: Vec<u32>) -> Result<(), ContextError> {
96 if tokens.len() > self.max_tokens {
97 return Err(ContextError(format!(
98 "system prompt ({} tokens) exceeds max_tokens ({})",
99 tokens.len(),
100 self.max_tokens
101 )));
102 }
103 self.system_tokens = tokens;
104 self.truncate_to_fit();
106 Ok(())
107 }
108
109 pub fn append(&mut self, tokens: &[u32]) -> usize {
116 self.conversation.extend_from_slice(tokens);
117 let removed = self.truncate_to_fit();
118 tokens.len().saturating_sub(removed)
120 }
121
122 pub fn truncate_to_fit(&mut self) -> usize {
127 let capacity_for_conv = self.max_tokens.saturating_sub(self.system_tokens.len());
128 if self.conversation.len() <= capacity_for_conv {
129 return 0;
130 }
131 let excess = self.conversation.len() - capacity_for_conv;
132
133 match self.strategy {
134 TruncationStrategy::TruncateLeft
135 | TruncationStrategy::SlidingWindow
136 | TruncationStrategy::Summarize => {
137 self.conversation.drain(0..excess);
139 }
140 TruncationStrategy::TruncateRight => {
141 let new_len = self.conversation.len() - excess;
143 self.conversation.truncate(new_len);
144 }
145 }
146
147 excess
148 }
149
150 pub fn tokens(&self) -> Vec<u32> {
154 let mut result = Vec::with_capacity(self.system_tokens.len() + self.conversation.len());
155 result.extend_from_slice(&self.system_tokens);
156 result.extend_from_slice(&self.conversation);
157 result
158 }
159
160 pub fn len(&self) -> usize {
162 self.system_tokens.len() + self.conversation.len()
163 }
164
165 pub fn is_empty(&self) -> bool {
167 self.system_tokens.is_empty() && self.conversation.is_empty()
168 }
169
170 pub fn remaining_capacity(&self) -> usize {
172 self.max_tokens.saturating_sub(self.len())
173 }
174
175 pub fn is_at_limit(&self) -> bool {
177 self.len() >= self.max_tokens
178 }
179
180 pub fn clear_conversation(&mut self) {
182 self.conversation.clear();
183 }
184
185 pub fn utilization(&self) -> f32 {
189 if self.max_tokens == 0 {
190 return 0.0;
191 }
192 self.len() as f32 / self.max_tokens as f32
193 }
194}
195
/// One recorded turn of a conversation.
///
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived so turns can be logged,
/// copied and compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationTurn {
    /// Speaker label supplied by the caller (the tests in this file use
    /// `"user"` and `"assistant"`).
    pub role: String,
    /// Text of the turn.
    pub content: String,
    /// Token IDs supplied together with the turn; presumably the tokenized
    /// form of `content` (nothing in this file checks that relationship).
    pub token_ids: Vec<u32>,
}
209
210pub struct ConversationContext {
220 window: ContextWindow,
221 turns: Vec<ConversationTurn>,
222}
223
224impl ConversationContext {
225 pub fn new(max_tokens: usize) -> Self {
227 Self {
228 window: ContextWindow::new(max_tokens, TruncationStrategy::TruncateLeft),
229 turns: Vec::new(),
230 }
231 }
232
233 pub fn add_turn(&mut self, role: &str, content: &str, token_ids: Vec<u32>) {
237 self.window.append(&token_ids);
238 self.turns.push(ConversationTurn {
239 role: role.to_string(),
240 content: content.to_string(),
241 token_ids,
242 });
243 }
244
245 pub fn build_tokens(&self) -> Vec<u32> {
250 self.window.tokens()
251 }
252
253 pub fn turn_count(&self) -> usize {
255 self.turns.len()
256 }
257
258 pub fn total_tokens(&self) -> usize {
260 self.window.len()
261 }
262
263 pub fn is_full(&self) -> bool {
265 self.window.is_at_limit()
266 }
267
268 pub fn clear(&mut self) {
270 self.turns.clear();
271 self.window.clear_conversation();
272 }
273
274 pub fn last_turn(&self) -> Option<&ConversationTurn> {
276 self.turns.last()
277 }
278
279 pub fn utilization(&self) -> f32 {
281 self.window.utilization()
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;

    /// Appending fewer tokens than the budget keeps all of them.
    #[test]
    fn test_context_window_append_within_limit() {
        let mut ctx_window = ContextWindow::new(100, TruncationStrategy::TruncateLeft);
        let accepted = ctx_window.append(&[1, 2, 3, 4, 5]);
        assert!(accepted > 0, "should append tokens when within limit");
        assert_eq!(ctx_window.conversation.len(), 5);
        assert_eq!(ctx_window.len(), 5);
    }

    /// Left truncation evicts the oldest tokens and keeps the newest.
    #[test]
    fn test_context_window_truncate_left() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateLeft);
        ctx_window.append(&[1, 2, 3, 4, 5]);
        assert_eq!(ctx_window.conversation.len(), 5);

        ctx_window.append(&[6, 7]);
        let conversation = &ctx_window.conversation;
        assert_eq!(
            conversation.len(),
            5,
            "should still be at max after truncation"
        );
        let newest = *conversation.last().expect("must have tokens");
        assert_eq!(newest, 7, "newest token should be 7");
        assert!(
            conversation.iter().all(|&token| token != 1),
            "token 1 should have been truncated"
        );
    }

    /// Right truncation keeps the oldest tokens and drops the overflow.
    #[test]
    fn test_context_window_truncate_right() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateRight);
        ctx_window.append(&[1, 2, 3, 4, 5]);
        ctx_window.append(&[6, 7]);

        let conversation = &ctx_window.conversation;
        assert_eq!(conversation.len(), 5);
        assert_eq!(
            conversation[0], 1,
            "token 1 should be preserved with TruncateRight"
        );
        assert!(
            conversation.iter().all(|&token| token != 6),
            "token 6 should have been truncated"
        );
    }

    /// Truncation never touches the system prompt.
    #[test]
    fn test_context_window_system_prompt_preserved() {
        let mut ctx_window = ContextWindow::new(10, TruncationStrategy::TruncateLeft);
        let system_prompt = vec![100, 200, 300];
        ctx_window
            .set_system_prompt(system_prompt)
            .expect("system prompt should fit");

        ctx_window.append(&[1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(ctx_window.len(), 10);

        ctx_window.append(&[8, 9]);
        let all_tokens = ctx_window.tokens();
        assert_eq!(all_tokens.len(), 10);
        assert_eq!(all_tokens[0], 100, "system token 0 must be preserved");
        assert_eq!(all_tokens[1], 200, "system token 1 must be preserved");
        assert_eq!(all_tokens[2], 300, "system token 2 must be preserved");
    }

    /// Remaining capacity accounts for both conversation and system tokens.
    #[test]
    fn test_context_window_remaining_capacity() {
        let mut ctx_window = ContextWindow::new(20, TruncationStrategy::TruncateLeft);
        assert_eq!(ctx_window.remaining_capacity(), 20);

        ctx_window.append(&[1, 2, 3]);
        assert_eq!(ctx_window.remaining_capacity(), 17);

        ctx_window.set_system_prompt(vec![10, 20]).expect("fits");
        assert_eq!(ctx_window.remaining_capacity(), 15);
    }

    /// A system prompt longer than the whole budget is rejected.
    #[test]
    fn test_context_window_system_prompt_too_large() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateLeft);
        let outcome = ctx_window.set_system_prompt(vec![1, 2, 3, 4, 5, 6]);
        assert!(
            outcome.is_err(),
            "system prompt larger than max_tokens should error"
        );
    }

    /// Adding turns updates the turn count, token total and last turn.
    #[test]
    fn test_conversation_context_add_turn() {
        let mut context = ConversationContext::new(200);
        context.add_turn("user", "Hello!", vec![10, 20, 30]);
        context.add_turn("assistant", "Hi there!", vec![40, 50, 60, 70]);

        assert_eq!(context.turn_count(), 2);
        assert_eq!(context.total_tokens(), 7, "3 + 4 = 7 tokens total");

        let latest = context.last_turn().expect("must have a last turn");
        assert_eq!(latest.role, "assistant");
        assert_eq!(latest.content, "Hi there!");
    }

    /// Built tokens follow the order in which turns were added.
    #[test]
    fn test_conversation_context_build_tokens() {
        let mut context = ConversationContext::new(100);
        context.add_turn("user", "A", vec![1, 2]);
        context.add_turn("assistant", "B", vec![3, 4, 5]);

        let built = context.build_tokens();
        assert_eq!(
            built,
            vec![1, 2, 3, 4, 5],
            "tokens should be in turn order"
        );
    }

    /// Utilization moves from 0.0 through 0.5 to 1.0 as the window fills.
    #[test]
    fn test_context_utilization() {
        let mut ctx_window = ContextWindow::new(100, TruncationStrategy::TruncateLeft);
        let half_budget: Vec<u32> = (0u32..50).collect();

        let usage = ctx_window.utilization();
        assert!(
            (usage - 0.0).abs() < f32::EPSILON,
            "empty window has 0.0 utilization"
        );

        ctx_window.append(&half_budget);
        let usage = ctx_window.utilization();
        assert!(
            (usage - 0.5).abs() < f32::EPSILON,
            "50/100 = 0.5 utilization"
        );

        ctx_window.append(&half_budget);
        let usage = ctx_window.utilization();
        assert!(
            (usage - 1.0).abs() < f32::EPSILON,
            "full window = 1.0 utilization"
        );
    }
}