Skip to main content

oxibonsai_runtime/
context_manager.rs

1//! Context window management for multi-turn inference.
2//!
3//! Manages the token budget across conversation turns, supporting multiple
4//! truncation strategies when context exceeds the model's maximum sequence length.
5//!
6//! ## Truncation Strategies
7//!
8//! - [`TruncationStrategy::TruncateLeft`] — drops the oldest conversation tokens (default).
9//!   System prompt tokens are always preserved.
10//! - [`TruncationStrategy::TruncateRight`] — drops the newest conversation tokens.
11//! - [`TruncationStrategy::SlidingWindow`] — keeps system prompt + most recent tokens.
12//!   Equivalent to `TruncateLeft` in this implementation.
13//! - [`TruncationStrategy::Summarize`] — placeholder; falls back to `TruncateLeft`.
14//!
15//! ## Usage
16//!
17//! ```rust
18//! use oxibonsai_runtime::context_manager::{ContextWindow, TruncationStrategy};
19//!
20//! let mut window = ContextWindow::new(2048, TruncationStrategy::TruncateLeft);
21//! window.set_system_prompt(vec![1, 2, 3]).expect("system prompt fits");
22//! window.append(&[10, 20, 30]);
23//! let tokens = window.tokens();
24//! assert!(tokens.len() <= 2048);
25//! ```
26
27// ──────────────────────────────────────────────────────────────────
28// Truncation strategy
29// ──────────────────────────────────────────────────────────────────
30
/// Strategy for handling context that exceeds the maximum token budget.
///
/// Every strategy only ever drops *conversation* tokens; the system prompt is
/// left untouched. [`Default::default`] yields
/// [`TruncationStrategy::TruncateLeft`], matching the "(default)" wording below.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum TruncationStrategy {
    /// Drop oldest conversation tokens (default). System prompt is never removed.
    #[default]
    TruncateLeft,
    /// Drop newest conversation tokens. System prompt is never removed.
    TruncateRight,
    /// Keep system prompt plus the most recent conversation tokens.
    /// In practice identical to `TruncateLeft` for system-prompt-first layouts.
    SlidingWindow,
    /// Placeholder for future LLM-based summarisation. Falls back to `TruncateLeft`.
    Summarize,
}
44
45// ──────────────────────────────────────────────────────────────────
46// Context error
47// ──────────────────────────────────────────────────────────────────
48
/// Error produced when a context window operation cannot be carried out.
///
/// Wraps a human-readable description of the failure.
#[derive(Debug)]
pub struct ContextError(String);

impl std::fmt::Display for ContextError {
    /// Renders the error as `ContextError: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(message) = self;
        f.write_str("ContextError: ")?;
        f.write_str(message)
    }
}

impl std::error::Error for ContextError {}
60
61// ──────────────────────────────────────────────────────────────────
62// ContextWindow
63// ──────────────────────────────────────────────────────────────────
64
65/// A fixed-capacity token window with configurable truncation.
66///
67/// The window stores a protected *system prompt* (never truncated) and
68/// a mutable *conversation* segment. Together they must fit within
69/// `max_tokens`.
70pub struct ContextWindow {
71    /// Maximum total token count (system + conversation).
72    pub max_tokens: usize,
73    /// Tokens belonging to the system prompt — never truncated.
74    pub system_tokens: Vec<u32>,
75    /// Accumulated conversation tokens.
76    pub conversation: Vec<u32>,
77    /// How to truncate when the window is full.
78    pub strategy: TruncationStrategy,
79}
80
81impl ContextWindow {
82    /// Create a new empty context window.
83    pub fn new(max_tokens: usize, strategy: TruncationStrategy) -> Self {
84        Self {
85            max_tokens,
86            system_tokens: Vec::new(),
87            conversation: Vec::new(),
88            strategy,
89        }
90    }
91
92    /// Set the system prompt tokens.
93    ///
94    /// Returns an error if the system prompt alone exceeds `max_tokens`.
95    pub fn set_system_prompt(&mut self, tokens: Vec<u32>) -> Result<(), ContextError> {
96        if tokens.len() > self.max_tokens {
97            return Err(ContextError(format!(
98                "system prompt ({} tokens) exceeds max_tokens ({})",
99                tokens.len(),
100                self.max_tokens
101            )));
102        }
103        self.system_tokens = tokens;
104        // Truncate conversation if system prompt now leaves no room
105        self.truncate_to_fit();
106        Ok(())
107    }
108
109    /// Append tokens to the conversation.
110    ///
111    /// If the new tokens cause the window to overflow, truncation is applied
112    /// *before* appending as much of the new tokens as will fit.
113    ///
114    /// Returns the number of tokens actually appended.
115    pub fn append(&mut self, tokens: &[u32]) -> usize {
116        self.conversation.extend_from_slice(tokens);
117        let removed = self.truncate_to_fit();
118        // How many of the newly added tokens survived truncation
119        tokens.len().saturating_sub(removed)
120    }
121
122    /// Truncate the conversation segment to make the total fit within `max_tokens`.
123    ///
124    /// Applies the configured [`TruncationStrategy`].
125    /// Returns the number of tokens removed from the conversation.
126    pub fn truncate_to_fit(&mut self) -> usize {
127        let capacity_for_conv = self.max_tokens.saturating_sub(self.system_tokens.len());
128        if self.conversation.len() <= capacity_for_conv {
129            return 0;
130        }
131        let excess = self.conversation.len() - capacity_for_conv;
132
133        match self.strategy {
134            TruncationStrategy::TruncateLeft
135            | TruncationStrategy::SlidingWindow
136            | TruncationStrategy::Summarize => {
137                // Remove from the front (oldest tokens)
138                self.conversation.drain(0..excess);
139            }
140            TruncationStrategy::TruncateRight => {
141                // Remove from the back (newest tokens)
142                let new_len = self.conversation.len() - excess;
143                self.conversation.truncate(new_len);
144            }
145        }
146
147        excess
148    }
149
150    /// Concatenate system tokens and conversation tokens into a single flat vector.
151    ///
152    /// The result is always within `max_tokens`.
153    pub fn tokens(&self) -> Vec<u32> {
154        let mut result = Vec::with_capacity(self.system_tokens.len() + self.conversation.len());
155        result.extend_from_slice(&self.system_tokens);
156        result.extend_from_slice(&self.conversation);
157        result
158    }
159
160    /// Total token count (system + conversation).
161    pub fn len(&self) -> usize {
162        self.system_tokens.len() + self.conversation.len()
163    }
164
165    /// Returns `true` if both system and conversation are empty.
166    pub fn is_empty(&self) -> bool {
167        self.system_tokens.is_empty() && self.conversation.is_empty()
168    }
169
170    /// Number of additional tokens that can be appended before truncation.
171    pub fn remaining_capacity(&self) -> usize {
172        self.max_tokens.saturating_sub(self.len())
173    }
174
175    /// Returns `true` if the window is at or beyond its maximum capacity.
176    pub fn is_at_limit(&self) -> bool {
177        self.len() >= self.max_tokens
178    }
179
180    /// Clear all conversation tokens (system prompt is preserved).
181    pub fn clear_conversation(&mut self) {
182        self.conversation.clear();
183    }
184
185    /// Fraction of `max_tokens` currently in use: `len / max_tokens`.
186    ///
187    /// Returns 0.0 if `max_tokens` is zero.
188    pub fn utilization(&self) -> f32 {
189        if self.max_tokens == 0 {
190            return 0.0;
191        }
192        self.len() as f32 / self.max_tokens as f32
193    }
194}
195
196// ──────────────────────────────────────────────────────────────────
197// ConversationTurn
198// ──────────────────────────────────────────────────────────────────
199
/// A single turn in a multi-turn conversation.
///
/// Plain data record holding the role, the raw text and the token ids that
/// were supplied together for one turn. The struct itself does not check that
/// `token_ids` actually corresponds to `content`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConversationTurn {
    /// Role identifier (e.g., `"user"`, `"assistant"`, `"system"`).
    pub role: String,
    /// Raw text content of this turn.
    pub content: String,
    /// Pre-tokenised representation of `content`.
    pub token_ids: Vec<u32>,
}
209
210// ──────────────────────────────────────────────────────────────────
211// ConversationContext
212// ──────────────────────────────────────────────────────────────────
213
214/// A multi-turn conversation with automatic context window management.
215///
216/// Each added turn is stored with its role, content, and token ids.
217/// `build_tokens()` concatenates all turn token ids in order,
218/// respecting the underlying [`ContextWindow`]'s token budget.
219pub struct ConversationContext {
220    window: ContextWindow,
221    turns: Vec<ConversationTurn>,
222}
223
224impl ConversationContext {
225    /// Create a new conversation context with the given maximum token budget.
226    pub fn new(max_tokens: usize) -> Self {
227        Self {
228            window: ContextWindow::new(max_tokens, TruncationStrategy::TruncateLeft),
229            turns: Vec::new(),
230        }
231    }
232
233    /// Add a conversation turn.
234    ///
235    /// The turn's token ids are appended to the context window.
236    pub fn add_turn(&mut self, role: &str, content: &str, token_ids: Vec<u32>) {
237        self.window.append(&token_ids);
238        self.turns.push(ConversationTurn {
239            role: role.to_string(),
240            content: content.to_string(),
241            token_ids,
242        });
243    }
244
245    /// Build a flat token sequence from all turns, respecting the window budget.
246    ///
247    /// Concatenates token ids in turn order. The result is always within
248    /// `max_tokens` after truncation.
249    pub fn build_tokens(&self) -> Vec<u32> {
250        self.window.tokens()
251    }
252
253    /// Number of turns added to this conversation.
254    pub fn turn_count(&self) -> usize {
255        self.turns.len()
256    }
257
258    /// Total token count across the current context window (after truncation).
259    pub fn total_tokens(&self) -> usize {
260        self.window.len()
261    }
262
263    /// Returns `true` if the context window is at its maximum capacity.
264    pub fn is_full(&self) -> bool {
265        self.window.is_at_limit()
266    }
267
268    /// Clear all turns and reset the context window.
269    pub fn clear(&mut self) {
270        self.turns.clear();
271        self.window.clear_conversation();
272    }
273
274    /// Reference to the most recently added turn, if any.
275    pub fn last_turn(&self) -> Option<&ConversationTurn> {
276        self.turns.last()
277    }
278
279    /// Utilisation of the token budget: `total_tokens / max_tokens`.
280    pub fn utilization(&self) -> f32 {
281        self.window.utilization()
282    }
283}
284
285// ──────────────────────────────────────────────────────────────────
286// Tests
287// ──────────────────────────────────────────────────────────────────
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Appending fewer tokens than the budget keeps every one of them.
    #[test]
    fn test_context_window_append_within_limit() {
        let mut ctx_window = ContextWindow::new(100, TruncationStrategy::TruncateLeft);
        let kept = ctx_window.append(&[1, 2, 3, 4, 5]);
        assert!(kept > 0, "should append tokens when within limit");
        assert_eq!(ctx_window.conversation.len(), 5);
        assert_eq!(ctx_window.len(), 5);
    }

    /// With `TruncateLeft`, overflowing the window discards the oldest tokens.
    #[test]
    fn test_context_window_truncate_left() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateLeft);

        // Exactly fill the window first.
        ctx_window.append(&[1, 2, 3, 4, 5]);
        assert_eq!(ctx_window.conversation.len(), 5);

        // Two more tokens must push the two oldest out.
        ctx_window.append(&[6, 7]);
        assert_eq!(
            ctx_window.conversation.len(),
            5,
            "should still be at max after truncation"
        );

        // Newest token sits at the tail.
        let newest = ctx_window
            .conversation
            .last()
            .copied()
            .expect("must have tokens");
        assert_eq!(newest, 7, "newest token should be 7");

        // Oldest token is no longer present.
        assert!(
            ctx_window.conversation.iter().all(|&tok| tok != 1),
            "token 1 should have been truncated"
        );
    }

    /// With `TruncateRight`, overflowing the window discards the newest tokens.
    #[test]
    fn test_context_window_truncate_right() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateRight);
        for batch in [&[1u32, 2, 3, 4, 5][..], &[6, 7][..]] {
            ctx_window.append(batch);
        }

        // The late arrivals (6, 7) are dropped; the original five stay.
        assert_eq!(ctx_window.conversation.len(), 5);
        assert_eq!(
            ctx_window.conversation.first().copied(),
            Some(1),
            "token 1 should be preserved with TruncateRight"
        );
        assert!(
            ctx_window.conversation.iter().all(|&tok| tok != 6),
            "token 6 should have been truncated"
        );
    }

    /// System prompt tokens survive conversation truncation.
    #[test]
    fn test_context_window_system_prompt_preserved() {
        let system_prompt = vec![100, 200, 300];
        let mut ctx_window = ContextWindow::new(10, TruncationStrategy::TruncateLeft);
        ctx_window
            .set_system_prompt(system_prompt.clone())
            .expect("system prompt should fit");

        // Seven conversation tokens use up the rest of the budget.
        ctx_window.append(&[1, 2, 3, 4, 5, 6, 7]);
        assert_eq!(ctx_window.len(), 10);

        // Overflow the window; only conversation tokens may be dropped.
        ctx_window.append(&[8, 9]);
        let flat = ctx_window.tokens();
        assert_eq!(flat.len(), 10);
        for (idx, expected) in system_prompt.iter().copied().enumerate() {
            assert_eq!(
                flat[idx], expected,
                "system token {} must be preserved",
                idx
            );
        }
    }

    /// `remaining_capacity` accounts for both conversation and system tokens.
    #[test]
    fn test_context_window_remaining_capacity() {
        let mut ctx_window = ContextWindow::new(20, TruncationStrategy::TruncateLeft);
        assert_eq!(ctx_window.remaining_capacity(), 20);

        ctx_window.append(&[1, 2, 3]);
        assert_eq!(ctx_window.remaining_capacity(), 17);

        // 2 system + 3 conversation tokens leave 15 free slots.
        ctx_window.set_system_prompt(vec![10, 20]).expect("fits");
        assert_eq!(ctx_window.remaining_capacity(), 15);
    }

    /// A system prompt longer than the whole budget is rejected.
    #[test]
    fn test_context_window_system_prompt_too_large() {
        let mut ctx_window = ContextWindow::new(5, TruncationStrategy::TruncateLeft);
        let outcome = ctx_window.set_system_prompt(vec![1, 2, 3, 4, 5, 6]);
        assert!(
            outcome.is_err(),
            "system prompt larger than max_tokens should error"
        );
    }

    /// Adding turns updates the turn count, token total and last turn.
    #[test]
    fn test_conversation_context_add_turn() {
        let mut conversation = ConversationContext::new(200);
        conversation.add_turn("user", "Hello!", vec![10, 20, 30]);
        conversation.add_turn("assistant", "Hi there!", vec![40, 50, 60, 70]);

        assert_eq!(conversation.turn_count(), 2);
        assert_eq!(conversation.total_tokens(), 7, "3 + 4 = 7 tokens total");

        let newest_turn = conversation.last_turn().expect("must have a last turn");
        assert_eq!(newest_turn.role, "assistant");
        assert_eq!(newest_turn.content, "Hi there!");
    }

    /// `build_tokens` yields the turn tokens in insertion order.
    #[test]
    fn test_conversation_context_build_tokens() {
        let mut conversation = ConversationContext::new(100);
        conversation.add_turn("user", "A", vec![1, 2]);
        conversation.add_turn("assistant", "B", vec![3, 4, 5]);

        let expected: Vec<u32> = vec![1, 2, 3, 4, 5];
        assert_eq!(
            conversation.build_tokens(),
            expected,
            "tokens should be in turn order"
        );
    }

    /// Utilization goes 0.0 -> 0.5 -> 1.0 as a 100-token window fills up.
    #[test]
    fn test_context_utilization() {
        let close_to = |actual: f32, expected: f32| (actual - expected).abs() < f32::EPSILON;
        let half_budget: Vec<u32> = (0u32..50).collect();

        let mut ctx_window = ContextWindow::new(100, TruncationStrategy::TruncateLeft);
        assert!(
            close_to(ctx_window.utilization(), 0.0),
            "empty window has 0.0 utilization"
        );

        ctx_window.append(&half_budget);
        assert!(
            close_to(ctx_window.utilization(), 0.5),
            "50/100 = 0.5 utilization"
        );

        ctx_window.append(&half_budget);
        assert!(
            close_to(ctx_window.utilization(), 1.0),
            "full window = 1.0 utilization"
        );
    }
}