Skip to main content

batuta/serve/
context.rs

//! Context Window Management
//!
//! Heuristic token estimation and context truncation.
//! Prevents silent failures when prompts exceed model context limits.
5
6use crate::serve::templates::ChatMessage;
7use serde::{Deserialize, Serialize};
8
9// ============================================================================
10// SERVE-CTX-001: Context Configuration
11// ============================================================================
12
13/// Known model context window sizes
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15pub struct ContextWindow {
16    /// Maximum context size in tokens
17    pub max_tokens: usize,
18    /// Reserved tokens for output
19    pub output_reserve: usize,
20}
21
22impl ContextWindow {
23    /// Create a new context window configuration
24    #[must_use]
25    pub const fn new(max_tokens: usize, output_reserve: usize) -> Self {
26        Self { max_tokens, output_reserve }
27    }
28
29    /// Available tokens for input after reserving output space
30    #[must_use]
31    pub const fn available_input(&self) -> usize {
32        self.max_tokens.saturating_sub(self.output_reserve)
33    }
34
35    /// Model name patterns mapped to (max_tokens, output_reserve).
36    /// Order matters: more specific patterns must come first.
37    /// Each entry uses ALL semantics: every pattern in the slice must match.
38    const MODEL_WINDOWS: &[(&[&str], usize, usize)] = &[
39        (&["gpt-4-turbo"], 128_000, 4096),
40        (&["gpt-4o"], 128_000, 4096),
41        (&["gpt-4-32k"], 32_768, 4096),
42        (&["gpt-4"], 8_192, 2048),
43        (&["gpt-3.5-turbo-16k"], 16_384, 4096),
44        (&["gpt-3.5"], 4_096, 1024),
45        (&["claude-3"], 200_000, 4096),
46        (&["claude-2"], 200_000, 4096),
47        (&["claude"], 100_000, 4096),
48        (&["llama-3"], 8_192, 2048),
49        (&["llama-2", "32k"], 32_768, 4096),
50        (&["llama"], 4_096, 1024),
51        (&["mixtral"], 32_768, 4096),
52        (&["mistral"], 8_192, 2048),
53    ];
54
55    /// Get context window for known model
56    #[must_use]
57    pub fn for_model(model: &str) -> Self {
58        let lower = model.to_lowercase();
59        Self::MODEL_WINDOWS
60            .iter()
61            .find(|(pats, _, _)| pats.iter().all(|p| lower.contains(p)))
62            .map_or_else(Self::default, |&(_, max, reserve)| Self::new(max, reserve))
63    }
64}
65
66impl Default for ContextWindow {
67    fn default() -> Self {
68        Self::new(4_096, 1024)
69    }
70}
71
72// ============================================================================
73// SERVE-CTX-002: Token Estimation
74// ============================================================================
75
/// Heuristic token estimator that approximates counts without a real tokenizer.
///
/// Divides the UTF-8 byte length of the text by a fixed ratio (4.0 by default, the
/// usual rule of thumb for English text). For accurate counts, use a proper tokenizer.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TokenEstimator {
    /// Bytes of UTF-8 text assumed per token (default: 4.0).
    chars_per_token: f64,
}
84
85impl TokenEstimator {
86    /// Create with default settings
87    #[must_use]
88    pub fn new() -> Self {
89        Self { chars_per_token: 4.0 }
90    }
91
92    /// Create with custom chars-per-token ratio
93    #[must_use]
94    pub fn with_ratio(chars_per_token: f64) -> Self {
95        Self { chars_per_token }
96    }
97
98    /// Estimate token count for a string
99    #[must_use]
100    pub fn estimate(&self, text: &str) -> usize {
101        if self.chars_per_token <= 0.0 {
102            return text.len();
103        }
104        (text.len() as f64 / self.chars_per_token).ceil() as usize
105    }
106
107    /// Estimate tokens for chat messages
108    #[must_use]
109    pub fn estimate_messages(&self, messages: &[ChatMessage]) -> usize {
110        let mut total = 0;
111        for msg in messages {
112            // Role tokens (approximately 3-4 tokens per message for formatting)
113            total += 4;
114            total += self.estimate(&msg.content);
115        }
116        total
117    }
118}
119
120impl Default for TokenEstimator {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126// ============================================================================
127// SERVE-CTX-003: Context Manager
128// ============================================================================
129
130/// Truncation strategy when context is exceeded
131#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
132pub enum TruncationStrategy {
133    /// Remove oldest messages first (sliding window)
134    #[default]
135    SlidingWindow,
136    /// Remove from the middle, keep first and last
137    MiddleOut,
138    /// Fail with error instead of truncating
139    Error,
140}
141
142/// Context management configuration
143#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct ContextConfig {
145    /// Context window settings
146    pub window: ContextWindow,
147    /// Truncation strategy
148    pub strategy: TruncationStrategy,
149    /// Always preserve system message
150    pub preserve_system: bool,
151    /// Minimum messages to keep
152    pub min_messages: usize,
153}
154
155impl Default for ContextConfig {
156    fn default() -> Self {
157        Self {
158            window: ContextWindow::default(),
159            strategy: TruncationStrategy::SlidingWindow,
160            preserve_system: true,
161            min_messages: 2,
162        }
163    }
164}
165
166impl ContextConfig {
167    /// Create config for a specific model
168    #[must_use]
169    pub fn for_model(model: &str) -> Self {
170        Self { window: ContextWindow::for_model(model), ..Default::default() }
171    }
172}
173
174/// Context manager for handling token limits
175pub struct ContextManager {
176    config: ContextConfig,
177    estimator: TokenEstimator,
178}
179
180impl ContextManager {
181    /// Create a new context manager
182    #[must_use]
183    pub fn new(config: ContextConfig) -> Self {
184        Self { config, estimator: TokenEstimator::new() }
185    }
186
187    /// Create for a specific model
188    #[must_use]
189    pub fn for_model(model: &str) -> Self {
190        Self::new(ContextConfig::for_model(model))
191    }
192
193    /// Check if messages fit within context window
194    #[must_use]
195    pub fn fits(&self, messages: &[ChatMessage]) -> bool {
196        let tokens = self.estimator.estimate_messages(messages);
197        tokens <= self.config.window.available_input()
198    }
199
200    /// Get estimated token count for messages
201    #[must_use]
202    pub fn estimate_tokens(&self, messages: &[ChatMessage]) -> usize {
203        self.estimator.estimate_messages(messages)
204    }
205
206    /// Get available token budget
207    #[must_use]
208    pub fn available_tokens(&self) -> usize {
209        self.config.window.available_input()
210    }
211
212    /// Truncate messages to fit within context window
213    ///
214    /// Returns truncated messages or error if strategy is `Error` and truncation needed.
215    pub fn truncate(&self, messages: &[ChatMessage]) -> Result<Vec<ChatMessage>, ContextError> {
216        let available = self.config.window.available_input();
217        let current = self.estimator.estimate_messages(messages);
218
219        if current <= available {
220            return Ok(messages.to_vec());
221        }
222
223        match self.config.strategy {
224            TruncationStrategy::Error => {
225                Err(ContextError::ExceedsLimit { tokens: current, limit: available })
226            }
227            TruncationStrategy::SlidingWindow => {
228                Ok(self.truncate_sliding_window(messages, available))
229            }
230            TruncationStrategy::MiddleOut => Ok(self.truncate_middle_out(messages, available)),
231        }
232    }
233
234    fn truncate_sliding_window(
235        &self,
236        messages: &[ChatMessage],
237        available: usize,
238    ) -> Vec<ChatMessage> {
239        let mut result = Vec::new();
240        let mut tokens_used = 0;
241
242        // Extract system message if preserving
243        let (system_msg, other_msgs): (Vec<_>, Vec<_>) = if self.config.preserve_system {
244            messages.iter().partition(|m| matches!(m.role, crate::serve::templates::Role::System))
245        } else {
246            (vec![], messages.iter().collect())
247        };
248
249        // Add system message first
250        for msg in &system_msg {
251            let msg_tokens = self.estimator.estimate(&msg.content) + 4;
252            if tokens_used + msg_tokens <= available {
253                result.push((*msg).clone());
254                tokens_used += msg_tokens;
255            }
256        }
257
258        // Add messages from the end (most recent first)
259        let mut recent_msgs: Vec<ChatMessage> = Vec::new();
260        for msg in other_msgs.into_iter().rev() {
261            let msg_tokens = self.estimator.estimate(&msg.content) + 4;
262            if tokens_used + msg_tokens <= available {
263                recent_msgs.push(msg.clone());
264                tokens_used += msg_tokens;
265            } else if recent_msgs.len() >= self.config.min_messages {
266                break;
267            }
268        }
269
270        // Reverse to restore chronological order
271        recent_msgs.reverse();
272        result.extend(recent_msgs);
273
274        result
275    }
276
277    fn truncate_middle_out(&self, messages: &[ChatMessage], available: usize) -> Vec<ChatMessage> {
278        if messages.len() <= 2 {
279            return messages.to_vec();
280        }
281
282        let mut result = Vec::new();
283        let mut tokens_used = 0;
284
285        // Always keep first message (often system)
286        let first = &messages[0];
287        let first_tokens = self.estimator.estimate(&first.content) + 4;
288        result.push(first.clone());
289        tokens_used += first_tokens;
290
291        // Always keep last message
292        let last = &messages[messages.len() - 1];
293        let last_tokens = self.estimator.estimate(&last.content) + 4;
294        tokens_used += last_tokens;
295
296        // Add messages from the end, working backwards
297        let middle = &messages[1..messages.len() - 1];
298        let mut kept_from_end: Vec<ChatMessage> = Vec::new();
299
300        for msg in middle.iter().rev() {
301            let msg_tokens = self.estimator.estimate(&msg.content) + 4;
302            if tokens_used + msg_tokens <= available {
303                kept_from_end.push(msg.clone());
304                tokens_used += msg_tokens;
305            } else {
306                break;
307            }
308        }
309
310        // Reverse and add
311        kept_from_end.reverse();
312        result.extend(kept_from_end);
313        result.push(last.clone());
314
315        result
316    }
317}
318
319impl Default for ContextManager {
320    fn default() -> Self {
321        Self::new(ContextConfig::default())
322    }
323}
324
/// Errors reported by [`ContextManager::truncate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextError {
    /// The messages exceed the budget and the strategy is [`TruncationStrategy::Error`].
    ExceedsLimit {
        /// Estimated token count of the rejected messages.
        tokens: usize,
        /// Input token budget that was exceeded.
        limit: usize,
    },
}
331
332impl std::fmt::Display for ContextError {
333    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334        match self {
335            Self::ExceedsLimit { tokens, limit } => {
336                write!(f, "Context exceeds limit: {} tokens, max {} tokens", tokens, limit)
337            }
338        }
339    }
340}
341
// `ContextError` carries only plain token counts and wraps no underlying error, so the
// trait's default methods (e.g. `source()` returning `None`) are sufficient.
impl std::error::Error for ContextError {}
343
344// ============================================================================
345// Tests
346// ============================================================================
347
// Unit tests live in sibling files and are compiled only for test builds.
#[cfg(test)]
// NOTE(review): presumably the test functions in context_tests.rs use non-snake-case
// names (e.g. spec IDs); confirm, and keep this allow scoped to that module.
#[allow(non_snake_case)]
#[path = "context_tests.rs"]
mod tests;

// Contract-level tests, kept in their own file.
#[cfg(test)]
#[path = "context_contract_tests.rs"]
mod contract_tests;