//! Context window sizing, token estimation, and conversation truncation.

use crate::serve::templates::ChatMessage;
use serde::{Deserialize, Serialize};

/// Token budget for a model's context window.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ContextWindow {
    /// Total number of tokens the model can process (input plus output).
    pub max_tokens: usize,
    /// Tokens reserved for the model's response.
    pub output_reserve: usize,
}

impl ContextWindow {
    /// Creates a window with the given total size and output reservation.
    #[must_use]
    pub const fn new(max_tokens: usize, output_reserve: usize) -> Self {
        Self { max_tokens, output_reserve }
    }

    /// Tokens available for input after reserving space for the model's output.
    #[must_use]
    pub const fn available_input(&self) -> usize {
        self.max_tokens.saturating_sub(self.output_reserve)
    }

    /// Known model families as (name substrings that must all match, total
    /// window size, output reserve). More specific patterns come before their
    /// prefixes so that, e.g., "gpt-4-turbo" is not captured by the bare
    /// "gpt-4" entry.
    const MODEL_WINDOWS: &[(&[&str], usize, usize)] = &[
        (&["gpt-4-turbo"], 128_000, 4096),
        (&["gpt-4o"], 128_000, 4096),
        (&["gpt-4-32k"], 32_768, 4096),
        (&["gpt-4"], 8_192, 2048),
        (&["gpt-3.5-turbo-16k"], 16_384, 4096),
        (&["gpt-3.5"], 4_096, 1024),
        (&["claude-3"], 200_000, 4096),
        (&["claude-2"], 200_000, 4096),
        (&["claude"], 100_000, 4096),
        (&["llama-3"], 8_192, 2048),
        (&["llama-2", "32k"], 32_768, 4096),
        (&["llama"], 4_096, 1024),
        (&["mixtral"], 32_768, 4096),
        (&["mistral"], 8_192, 2048),
    ];

    /// Looks up the context window for a model name, falling back to the
    /// conservative default when the model is not recognized.
    #[must_use]
    pub fn for_model(model: &str) -> Self {
        let lower = model.to_lowercase();
        Self::MODEL_WINDOWS
            .iter()
            .find(|(pats, _, _)| pats.iter().all(|p| lower.contains(p)))
            .map_or_else(Self::default, |&(_, max, reserve)| Self::new(max, reserve))
    }
}

impl Default for ContextWindow {
    fn default() -> Self {
        Self::new(4_096, 1024)
    }
}
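
// Usage sketch: how `for_model` resolves a model name against MODEL_WINDOWS.
// The expectations below simply restate the table entries above; the module
// and test names are illustrative.
#[cfg(test)]
mod context_window_usage_sketch {
    use super::ContextWindow;

    #[test]
    fn picks_most_specific_entry_and_falls_back_to_default() {
        // "gpt-4-turbo" is listed before the bare "gpt-4" pattern, so the
        // 128k window wins over the 8k one; matching is case-insensitive.
        let turbo = ContextWindow::for_model("GPT-4-Turbo-2024-04-09");
        assert_eq!((turbo.max_tokens, turbo.output_reserve), (128_000, 4096));
        assert_eq!(turbo.available_input(), 128_000 - 4096);

        // Unrecognized models get the conservative 4k default.
        let unknown = ContextWindow::for_model("totally-unknown-model");
        assert_eq!(unknown, ContextWindow::default());
    }
}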

/// Character-count based token estimator that approximates token counts from
/// text length, since exact tokenizers are model specific.
pub struct TokenEstimator {
    /// Average characters per token; 4.0 is a common rule of thumb for English.
    chars_per_token: f64,
}

impl TokenEstimator {
    /// Creates an estimator with the default ratio of 4 characters per token.
    #[must_use]
    pub fn new() -> Self {
        Self { chars_per_token: 4.0 }
    }

    /// Creates an estimator with a custom characters-per-token ratio.
    #[must_use]
    pub fn with_ratio(chars_per_token: f64) -> Self {
        Self { chars_per_token }
    }

    /// Estimates the number of tokens in `text`, rounding up. A non-positive
    /// ratio falls back to counting one token per byte.
    #[must_use]
    pub fn estimate(&self, text: &str) -> usize {
        if self.chars_per_token <= 0.0 {
            return text.len();
        }
        (text.len() as f64 / self.chars_per_token).ceil() as usize
    }

    /// Estimates the total tokens for a list of chat messages.
    #[must_use]
    pub fn estimate_messages(&self, messages: &[ChatMessage]) -> usize {
        let mut total = 0;
        for msg in messages {
            // Approximate per-message overhead for role tags and separators.
            total += 4;
            total += self.estimate(&msg.content);
        }
        total
    }
}

impl Default for TokenEstimator {
    fn default() -> Self {
        Self::new()
    }
}
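
// Usage sketch for the character-based heuristic: the expected values follow
// directly from `estimate`'s ceiling division; the module and test names are
// illustrative.
#[cfg(test)]
mod token_estimator_usage_sketch {
    use super::TokenEstimator;

    #[test]
    fn estimates_follow_the_configured_ratio() {
        // Default ratio: roughly 4 characters per token, rounded up (11 / 4 -> 3).
        assert_eq!(TokenEstimator::new().estimate("hello world"), 3);

        // Custom ratio of 2 characters per token (11 / 2 -> 6).
        assert_eq!(TokenEstimator::with_ratio(2.0).estimate("hello world"), 6);

        // Non-positive ratios fall back to one token per byte.
        assert_eq!(TokenEstimator::with_ratio(0.0).estimate("hello world"), 11);
    }
}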

/// Strategy applied when a conversation exceeds the available input budget.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum TruncationStrategy {
    /// Keep the most recent messages, dropping the oldest first.
    #[default]
    SlidingWindow,
    /// Keep the start and end of the conversation, dropping from the middle.
    MiddleOut,
    /// Refuse to truncate and return an error instead.
    Error,
}

/// Configuration for [`ContextManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextConfig {
    /// Token budget for the target model.
    pub window: ContextWindow,
    /// How to handle conversations that exceed the budget.
    pub strategy: TruncationStrategy,
    /// Whether system messages are always kept during truncation.
    pub preserve_system: bool,
    /// Minimum number of recent messages to retain.
    pub min_messages: usize,
}

impl Default for ContextConfig {
    fn default() -> Self {
        Self {
            window: ContextWindow::default(),
            strategy: TruncationStrategy::SlidingWindow,
            preserve_system: true,
            min_messages: 2,
        }
    }
}

impl ContextConfig {
    /// Builds a config sized for the given model, keeping the other defaults.
    #[must_use]
    pub fn for_model(model: &str) -> Self {
        Self { window: ContextWindow::for_model(model), ..Default::default() }
    }
}
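
// Usage sketch: building a config for a model and overriding the truncation
// strategy with struct-update syntax; the module and test names are
// illustrative.
#[cfg(test)]
mod context_config_usage_sketch {
    use super::{ContextConfig, TruncationStrategy};

    #[test]
    fn for_model_keeps_the_other_defaults() {
        let config = ContextConfig {
            strategy: TruncationStrategy::Error,
            ..ContextConfig::for_model("mixtral-8x7b")
        };
        // The window comes from the "mixtral" entry; everything else is default.
        assert_eq!(config.window.max_tokens, 32_768);
        assert!(config.preserve_system);
        assert_eq!(config.min_messages, 2);
        assert_eq!(config.strategy, TruncationStrategy::Error);
    }
}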

/// Manages a conversation against a model's token budget, estimating usage
/// and truncating messages when they no longer fit.
pub struct ContextManager {
    config: ContextConfig,
    estimator: TokenEstimator,
}

impl ContextManager {
    /// Creates a manager with the given configuration and the default estimator.
    #[must_use]
    pub fn new(config: ContextConfig) -> Self {
        Self { config, estimator: TokenEstimator::new() }
    }

    /// Creates a manager configured for the given model name.
    #[must_use]
    pub fn for_model(model: &str) -> Self {
        Self::new(ContextConfig::for_model(model))
    }

    /// Returns `true` if the messages fit within the available input budget.
    #[must_use]
    pub fn fits(&self, messages: &[ChatMessage]) -> bool {
        let tokens = self.estimator.estimate_messages(messages);
        tokens <= self.config.window.available_input()
    }

    /// Estimates the token count of the messages.
    #[must_use]
    pub fn estimate_tokens(&self, messages: &[ChatMessage]) -> usize {
        self.estimator.estimate_messages(messages)
    }

    /// Tokens available for input after the output reserve.
    #[must_use]
    pub fn available_tokens(&self) -> usize {
        self.config.window.available_input()
    }

    /// Truncates `messages` so they fit the available input budget, using the
    /// configured [`TruncationStrategy`].
    ///
    /// # Errors
    ///
    /// Returns [`ContextError::ExceedsLimit`] when the strategy is
    /// [`TruncationStrategy::Error`] and the messages do not fit.
    pub fn truncate(&self, messages: &[ChatMessage]) -> Result<Vec<ChatMessage>, ContextError> {
        let available = self.config.window.available_input();
        let current = self.estimator.estimate_messages(messages);

        if current <= available {
            return Ok(messages.to_vec());
        }

        match self.config.strategy {
            TruncationStrategy::Error => {
                Err(ContextError::ExceedsLimit { tokens: current, limit: available })
            }
            TruncationStrategy::SlidingWindow => {
                Ok(self.truncate_sliding_window(messages, available))
            }
            TruncationStrategy::MiddleOut => Ok(self.truncate_middle_out(messages, available)),
        }
    }

    /// Sliding-window truncation: keeps system messages (when configured) and
    /// then the most recent messages that still fit in the budget.
    fn truncate_sliding_window(
        &self,
        messages: &[ChatMessage],
        available: usize,
    ) -> Vec<ChatMessage> {
        let mut result = Vec::new();
        let mut tokens_used = 0;

        // Split off system messages when they must always be preserved.
        let (system_msg, other_msgs): (Vec<_>, Vec<_>) = if self.config.preserve_system {
            messages.iter().partition(|m| matches!(m.role, crate::serve::templates::Role::System))
        } else {
            (vec![], messages.iter().collect())
        };

        // System messages take budget priority.
        for msg in &system_msg {
            let msg_tokens = self.estimator.estimate(&msg.content) + 4;
            if tokens_used + msg_tokens <= available {
                result.push((*msg).clone());
                tokens_used += msg_tokens;
            }
        }

        // Walk the remaining messages from newest to oldest, keeping each one
        // that fits; stop at the first message that does not fit once at least
        // `min_messages` have been kept.
        let mut recent_msgs: Vec<ChatMessage> = Vec::new();
        for msg in other_msgs.into_iter().rev() {
            let msg_tokens = self.estimator.estimate(&msg.content) + 4;
            if tokens_used + msg_tokens <= available {
                recent_msgs.push(msg.clone());
                tokens_used += msg_tokens;
            } else if recent_msgs.len() >= self.config.min_messages {
                break;
            }
        }

        // Restore chronological order before appending.
        recent_msgs.reverse();
        result.extend(recent_msgs);

        result
    }

    /// Middle-out truncation: keeps the first and last messages, then fills
    /// the remaining budget from the end of the conversation, dropping from
    /// the middle.
    fn truncate_middle_out(&self, messages: &[ChatMessage], available: usize) -> Vec<ChatMessage> {
        if messages.len() <= 2 {
            return messages.to_vec();
        }

        let mut result = Vec::new();
        let mut tokens_used = 0;

        // Always keep the first message.
        let first = &messages[0];
        let first_tokens = self.estimator.estimate(&first.content) + 4;
        result.push(first.clone());
        tokens_used += first_tokens;

        // Reserve budget for the last message now; it is appended at the end.
        let last = &messages[messages.len() - 1];
        let last_tokens = self.estimator.estimate(&last.content) + 4;
        tokens_used += last_tokens;

        // Fill whatever budget remains with the newest middle messages.
        let middle = &messages[1..messages.len() - 1];
        let mut kept_from_end: Vec<ChatMessage> = Vec::new();

        for msg in middle.iter().rev() {
            let msg_tokens = self.estimator.estimate(&msg.content) + 4;
            if tokens_used + msg_tokens <= available {
                kept_from_end.push(msg.clone());
                tokens_used += msg_tokens;
            } else {
                break;
            }
        }

        // Restore chronological order for the kept middle, then append the last message.
        kept_from_end.reverse();
        result.extend(kept_from_end);
        result.push(last.clone());

        result
    }
}

impl Default for ContextManager {
    fn default() -> Self {
        Self::new(ContextConfig::default())
    }
}
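
// Usage sketch: wiring the pieces together for a specific model. It only
// exercises the budget accounting and the no-op truncation path, since it
// does not construct `ChatMessage` values; the module and test names are
// illustrative.
#[cfg(test)]
mod context_manager_usage_sketch {
    use super::ContextManager;

    #[test]
    fn budget_reflects_the_model_window() {
        let manager = ContextManager::for_model("claude-3-opus");
        // 200k window minus the 4096-token output reserve.
        assert_eq!(manager.available_tokens(), 200_000 - 4096);

        // An empty conversation trivially fits and truncates to itself.
        assert!(manager.fits(&[]));
        assert_eq!(manager.estimate_tokens(&[]), 0);
        assert!(manager.truncate(&[]).unwrap().is_empty());
    }
}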

/// Errors returned by [`ContextManager::truncate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextError {
    /// The messages exceed the token limit and the strategy forbids truncation.
    ExceedsLimit { tokens: usize, limit: usize },
}

impl std::fmt::Display for ContextError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ExceedsLimit { tokens, limit } => {
                write!(f, "Context exceeds limit: {} tokens, max {} tokens", tokens, limit)
            }
        }
    }
}

impl std::error::Error for ContextError {}

#[cfg(test)]
#[allow(non_snake_case)]
#[path = "context_tests.rs"]
mod tests;

#[cfg(test)]
#[path = "context_contract_tests.rs"]
mod contract_tests;