1use crate::error::LlmError;
2use crate::types::{Message, Role};
3use tiktoken_rs::cl100k_base;
4
5pub struct ModelHandoff {
6 tokenizer: tiktoken_rs::CoreBPE,
7}
8
9impl Default for ModelHandoff {
10 fn default() -> Self {
11 Self::new()
12 }
13}
14
15impl ModelHandoff {
16 pub fn new() -> Self {
17 Self {
18 tokenizer: cl100k_base().expect("Failed to load tokenizer"),
19 }
20 }
21
22 pub fn count_tokens(&self, text: &str) -> usize {
23 self.tokenizer.encode_with_special_tokens(text).len()
24 }
25
26 pub fn count_message_tokens(&self, message: &Message) -> usize {
27 let mut total = message
28 .content
29 .as_ref()
30 .map(|c| self.count_tokens(c))
31 .unwrap_or(0);
32
33 total += 4;
35
36 if let Some(tool_calls) = &message.tool_calls {
38 for call in tool_calls {
39 total += self.count_tokens(&call.id);
40 total += self.count_tokens(&call.function.name);
41 total += self.count_tokens(&call.function.arguments);
42 }
43 }
44
45 total
46 }
47
48 pub fn count_total_tokens(&self, messages: &[Message]) -> usize {
49 messages.iter().map(|m| self.count_message_tokens(m)).sum()
50 }
51
52 pub fn compact_messages(&self, messages: &[Message], target_tokens: usize) -> Vec<Message> {
53 let system_msg = messages.iter().find(|m| matches!(m.role, Role::System));
55
56 let non_system: Vec<_> = messages
58 .iter()
59 .filter(|m| !matches!(m.role, Role::System))
60 .cloned()
61 .collect();
62
63 let mut compacted = Vec::new();
64
65 if let Some(sys) = system_msg {
67 compacted.push(sys.clone());
68 }
69
70 let system_tokens = compacted
72 .iter()
73 .map(|m| self.count_message_tokens(m))
74 .sum::<usize>();
75
76 let safety_buffer = (target_tokens / 5).max(100);
78 let remaining_budget = target_tokens.saturating_sub(system_tokens + safety_buffer);
79
80 let mut selected = Vec::new();
82 let mut current_tokens = 0;
83
84 for msg in non_system.iter().rev() {
85 let msg_tokens = self.count_message_tokens(msg);
86
87 if current_tokens + msg_tokens <= remaining_budget {
88 current_tokens += msg_tokens;
89 selected.push(msg.clone());
90 } else {
91 break;
92 }
93 }
94
95 selected.reverse();
97 compacted.extend(selected);
98
99 compacted
100 }
101
102 pub fn handoff_to_model(
103 &self,
104 _from_model: &str,
105 to_model: &str,
106 messages: &[Message],
107 ) -> Result<Vec<Message>, LlmError> {
108 let target_tokens = match to_model {
110 "claude-3-5-sonnet-20241022" => 200000,
111 "claude-3-5-haiku-20241022" => 200000,
112 "claude-3-opus-20240229" => 200000,
113 "claude-3-sonnet-20240229" => 200000,
114 "claude-3-haiku-20240307" => 200000,
115 _ => 200000, };
117
118 let current_tokens = self.count_total_tokens(messages);
119
120 if current_tokens > target_tokens * 9 / 10 {
121 Ok(self.compact_messages(messages, target_tokens))
122 } else {
123 Ok(messages.to_vec())
124 }
125 }
126
127 pub fn find_cut_point(&self, messages: &[Message], keep_recent_tokens: usize) -> Option<usize> {
128 if messages.is_empty() {
129 return None;
130 }
131
132 let non_system: Vec<_> = messages
133 .iter()
134 .enumerate()
135 .filter(|(_, m)| !matches!(m.role, Role::System))
136 .collect();
137
138 if non_system.is_empty() {
139 return None;
140 }
141
142 let mut accumulated = 0;
143 for (idx, msg) in non_system.iter().rev() {
144 accumulated += self.count_message_tokens(msg);
145
146 if accumulated >= keep_recent_tokens {
147 let cut_idx = self.find_valid_cut_point(&non_system, *idx);
148 return Some(cut_idx);
149 }
150 }
151
152 Some(0)
153 }
154
155 fn find_valid_cut_point(&self, non_system: &[(usize, &Message)], min_idx: usize) -> usize {
156 for (idx, msg) in non_system.iter() {
157 if *idx >= min_idx && matches!(msg.role, Role::User) {
158 return *idx;
159 }
160 }
161
162 min_idx
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{FunctionCall, ToolCall};

    /// Builds a text-only message: no tool calls, tool-call id, or cache control.
    fn text_message(role: Role, content: impl Into<String>) -> Message {
        Message {
            role,
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    /// Role of the `i`-th message in a user-first alternating conversation.
    fn alternating_role(i: usize) -> Role {
        if i % 2 == 0 {
            Role::User
        } else {
            Role::Assistant
        }
    }

    /// A `"System"` prompt followed by `count` alternating user/assistant
    /// messages whose text is produced by `content(i)`.
    fn conversation(count: usize, content: impl Fn(usize) -> String) -> Vec<Message> {
        std::iter::once(text_message(Role::System, "System"))
            .chain((0..count).map(|i| text_message(alternating_role(i), content(i))))
            .collect()
    }

    #[test]
    fn test_count_tokens_simple() {
        let handoff = ModelHandoff::new();
        let tokens = handoff.count_tokens("Hello, world!");
        assert!(tokens > 0);
        assert!(tokens < 10);
    }

    #[test]
    fn test_count_message_tokens() {
        let handoff = ModelHandoff::new();
        let msg = text_message(Role::User, "Hello, world!");
        let tokens = handoff.count_message_tokens(&msg);
        assert!(tokens > 4);
        // Content tokens plus the fixed per-message overhead.
        assert_eq!(tokens, handoff.count_tokens("Hello, world!") + 4);
    }

    #[test]
    fn test_count_message_tokens_with_tool_calls() {
        let handoff = ModelHandoff::new();
        let msg = Message {
            role: Role::Assistant,
            content: Some("".to_string()),
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({"arg": "value"}).to_string(),
                },
            }]),
            tool_call_id: None,
            cache_control: None,
        };
        let tokens = handoff.count_message_tokens(&msg);
        assert!(tokens > 10);
    }

    #[test]
    fn test_count_total_tokens() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi there!"),
        ];
        let total = handoff.count_total_tokens(&messages);
        assert!(total > 0);
        let summed: usize = messages
            .iter()
            .map(|m| handoff.count_message_tokens(m))
            .sum();
        assert_eq!(total, summed);
    }

    #[test]
    fn test_compact_messages_empty_input() {
        let handoff = ModelHandoff::new();
        assert!(handoff.compact_messages(&[], 500).is_empty());
    }

    #[test]
    fn test_compact_messages_preserves_system() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_message(Role::System, "You are a helpful assistant."),
            text_message(Role::User, "Hello"),
        ];
        let compacted = handoff.compact_messages(&messages, 500);
        // Both messages fit comfortably in a 500-token target.
        assert_eq!(compacted.len(), 2);
        assert!(matches!(compacted[0].role, Role::System));
        assert!(matches!(compacted[1].role, Role::User));
    }

    #[test]
    fn test_compact_messages_keeps_recent() {
        let handoff = ModelHandoff::new();
        let messages = conversation(100, |i| format!("Message {}", i));

        let compacted = handoff.compact_messages(&messages, 500);

        assert!(compacted.len() < messages.len());
        assert!(matches!(compacted[0].role, Role::System));
        assert_eq!(
            compacted.last().unwrap().content,
            Some("Message 99".to_string())
        );
    }

    #[test]
    fn test_handoff_to_model_no_compaction_needed() {
        let handoff = ModelHandoff::new();
        let messages = vec![text_message(Role::User, "Hello")];

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        let handoff_messages = result.unwrap();
        assert_eq!(handoff_messages.len(), messages.len());
        assert_eq!(handoff_messages[0].content, messages[0].content);
    }

    #[test]
    fn test_handoff_to_model_compacts_when_needed() {
        let handoff = ModelHandoff::new();
        let messages = conversation(5000, |i| {
            format!(
                "This is message number {}. It contains significantly more content to ensure we exceed the context window limit. Each message should be approximately 50-60 tokens in length when encoded with the cl100k_base tokenizer. This allows us to test the compaction functionality effectively. ",
                i
            )
        });

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        let handoff_messages = result.unwrap();

        assert!(handoff_messages.len() < messages.len());
        assert!(matches!(handoff_messages[0].role, Role::System));
        // The compacted history must actually fit in the 200k-token window.
        assert!(handoff.count_total_tokens(&handoff_messages) <= 200_000);
    }

    #[test]
    fn test_token_count_accuracy_within_10_percent() {
        let handoff = ModelHandoff::new();
        let text = "The quick brown fox jumps over the lazy dog. ";

        let counted = handoff.count_tokens(text);

        let expected: usize = 11;
        let tolerance_pct = 10.0;
        let tolerance = (expected as f64 * tolerance_pct / 100.0) as usize;

        assert!(
            counted.abs_diff(expected) <= tolerance,
            "Token count {} not within {}% of expected {}",
            counted,
            tolerance_pct,
            expected
        );
    }

    #[test]
    fn test_find_cut_point_basic() {
        let handoff = ModelHandoff::new();

        let messages: Vec<Message> = (0..10)
            .map(|i| {
                text_message(
                    alternating_role(i),
                    format!("Message {} with some content to make it longer", i),
                )
            })
            .collect();

        let cut = handoff.find_cut_point(&messages, 50);
        assert!(cut.is_some());
        let cut_idx = cut.unwrap();
        assert!(cut_idx > 0);
        assert!(cut_idx < messages.len());
    }

    #[test]
    fn test_find_cut_point_empty_messages() {
        let handoff = ModelHandoff::new();
        let messages: Vec<Message> = vec![];

        let cut = handoff.find_cut_point(&messages, 100);
        assert!(cut.is_none());
    }

    #[test]
    fn test_find_cut_point_only_system_messages() {
        let handoff = ModelHandoff::new();
        let messages = vec![text_message(Role::System, "System")];

        assert!(handoff.find_cut_point(&messages, 100).is_none());
    }

    #[test]
    fn test_find_cut_point_all_fit() {
        let handoff = ModelHandoff::new();

        let messages = vec![
            text_message(Role::User, "Short"),
            text_message(Role::Assistant, "Hi"),
        ];

        let cut = handoff.find_cut_point(&messages, 1000);
        assert_eq!(cut, Some(0));
    }

    #[test]
    fn test_find_cut_point_prefers_user_message() {
        let handoff = ModelHandoff::new();

        let mut messages = vec![];
        for _ in 0..5 {
            messages.push(text_message(
                Role::User,
                "This is a user message with enough content",
            ));
            messages.push(text_message(Role::Assistant, "Assistant reply"));
        }

        let cut = handoff.find_cut_point(&messages, 30).unwrap();
        assert!(matches!(messages[cut].role, Role::User));
    }
}
472}