1use crate::error::LlmError;
2use crate::types::{Message, Role};
3use tiktoken_rs::cl100k_base;
4
5pub struct ModelHandoff {
6 tokenizer: tiktoken_rs::CoreBPE,
7}
8
9impl Default for ModelHandoff {
10 fn default() -> Self {
11 Self::new()
12 }
13}
14
15impl ModelHandoff {
16 pub fn new() -> Self {
17 Self {
18 tokenizer: cl100k_base().expect("Failed to load tokenizer"),
19 }
20 }
21
22 pub fn count_tokens(&self, text: &str) -> usize {
23 self.tokenizer.encode_with_special_tokens(text).len()
24 }
25
26 pub fn count_message_tokens(&self, message: &Message) -> usize {
27 let mut total = message
28 .content
29 .as_ref()
30 .map(|c| self.count_tokens(&c.to_text()))
31 .unwrap_or(0);
32
33 total += 4;
35
36 if let Some(tool_calls) = &message.tool_calls {
38 for call in tool_calls {
39 total += self.count_tokens(&call.id);
40 total += self.count_tokens(&call.function.name);
41 total += self.count_tokens(&call.function.arguments);
42 }
43 }
44
45 total
46 }
47
48 pub fn count_total_tokens(&self, messages: &[Message]) -> usize {
49 messages.iter().map(|m| self.count_message_tokens(m)).sum()
50 }
51
52 pub fn compact_messages(&self, messages: &[Message], target_tokens: usize) -> Vec<Message> {
53 let system_msg = messages.iter().find(|m| matches!(m.role, Role::System));
55
56 let non_system: Vec<_> = messages
58 .iter()
59 .filter(|m| !matches!(m.role, Role::System))
60 .cloned()
61 .collect();
62
63 let mut compacted = Vec::new();
64
65 if let Some(sys) = system_msg {
67 compacted.push(sys.clone());
68 }
69
70 let system_tokens = compacted
72 .iter()
73 .map(|m| self.count_message_tokens(m))
74 .sum::<usize>();
75
76 let safety_buffer = (target_tokens / 5).max(100);
78 let remaining_budget = target_tokens.saturating_sub(system_tokens + safety_buffer);
79
80 let mut selected = Vec::new();
82 let mut current_tokens = 0;
83
84 for msg in non_system.iter().rev() {
85 let msg_tokens = self.count_message_tokens(msg);
86
87 if current_tokens + msg_tokens <= remaining_budget {
88 current_tokens += msg_tokens;
89 selected.push(msg.clone());
90 } else {
91 break;
92 }
93 }
94
95 selected.reverse();
97 compacted.extend(selected);
98
99 compacted
100 }
101
102 pub fn handoff_to_model(
103 &self,
104 _from_model: &str,
105 to_model: &str,
106 messages: &[Message],
107 ) -> Result<Vec<Message>, LlmError> {
108 let target_tokens = match to_model {
110 "claude-3-5-sonnet-20241022" => 200000,
111 "claude-3-5-haiku-20241022" => 200000,
112 "claude-3-opus-20240229" => 200000,
113 "claude-3-sonnet-20240229" => 200000,
114 "claude-3-haiku-20240307" => 200000,
115 _ => 200000, };
117
118 let current_tokens = self.count_total_tokens(messages);
119
120 if current_tokens > target_tokens * 9 / 10 {
121 Ok(self.compact_messages(messages, target_tokens))
122 } else {
123 Ok(messages.to_vec())
124 }
125 }
126
127 pub fn find_cut_point(&self, messages: &[Message], keep_recent_tokens: usize) -> Option<usize> {
128 if messages.is_empty() {
129 return None;
130 }
131
132 let non_system: Vec<_> = messages
133 .iter()
134 .enumerate()
135 .filter(|(_, m)| !matches!(m.role, Role::System))
136 .collect();
137
138 if non_system.is_empty() {
139 return None;
140 }
141
142 let mut accumulated = 0;
143 for (idx, msg) in non_system.iter().rev() {
144 accumulated += self.count_message_tokens(msg);
145
146 if accumulated >= keep_recent_tokens {
147 let cut_idx = self.find_valid_cut_point(&non_system, *idx);
148 return Some(cut_idx);
149 }
150 }
151
152 Some(0)
153 }
154
155 fn find_valid_cut_point(&self, non_system: &[(usize, &Message)], min_idx: usize) -> usize {
156 for (idx, msg) in non_system.iter() {
157 if *idx >= min_idx && matches!(msg.role, Role::User) {
158 return *idx;
159 }
160 }
161
162 min_idx
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::FunctionCall;
    use crate::types::ToolCall;

    /// Builds a text-only message: no tool calls, no tool-call id, no cache control.
    fn text_message(role: Role, text: impl Into<String>) -> Message {
        Message {
            role,
            content: Some(crate::MessageContent::text(text.into())),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    /// Role pattern shared by the generated conversations: even turns are user turns,
    /// odd turns are assistant turns.
    fn alternating_role(index: usize) -> Role {
        if index % 2 == 0 {
            Role::User
        } else {
            Role::Assistant
        }
    }

    /// Generates `count` alternating user/assistant messages whose text comes from `body`.
    fn alternating_turns(count: usize, body: impl Fn(usize) -> String) -> Vec<Message> {
        (0..count)
            .map(|i| text_message(alternating_role(i), body(i)))
            .collect()
    }

    #[test]
    fn test_count_tokens_simple() {
        let tokens = ModelHandoff::new().count_tokens("Hello, world!");
        assert!(tokens > 0);
        assert!(tokens < 10);
    }

    #[test]
    fn test_count_message_tokens() {
        let handoff = ModelHandoff::new();
        let msg = text_message(Role::User, "Hello, world!");
        // Must exceed the flat per-message overhead on its own.
        assert!(handoff.count_message_tokens(&msg) > 4);
    }

    #[test]
    fn test_count_message_tokens_with_tool_calls() {
        let handoff = ModelHandoff::new();
        let call = ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "test_tool".to_string(),
                arguments: serde_json::json!({"arg": "value"}).to_string(),
            },
        };
        let msg = Message {
            tool_calls: Some(vec![call]),
            ..text_message(Role::Assistant, "")
        };
        assert!(handoff.count_message_tokens(&msg) > 10);
    }

    #[test]
    fn test_count_total_tokens() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi there!"),
        ];
        assert!(handoff.count_total_tokens(&messages) > 0);
    }

    #[test]
    fn test_compact_messages_preserves_system() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_message(Role::System, "You are a helpful assistant."),
            text_message(Role::User, "Hello"),
        ];

        let compacted = handoff.compact_messages(&messages, 500);

        assert!(!compacted.is_empty());
        if compacted.len() > 1 {
            assert!(matches!(compacted[0].role, Role::System));
        }
    }

    #[test]
    fn test_compact_messages_keeps_recent() {
        let handoff = ModelHandoff::new();
        let mut messages = vec![text_message(Role::System, "System")];
        messages.extend(alternating_turns(100, |i| format!("Message {}", i)));

        let compacted = handoff.compact_messages(&messages, 500);

        // Something was dropped, the system prompt leads, and the newest turn survives.
        assert!(compacted.len() < messages.len());
        assert!(matches!(compacted[0].role, Role::System));
        assert_eq!(
            compacted.last().unwrap().content,
            Some(crate::MessageContent::text("Message 99"))
        );
    }

    #[test]
    fn test_handoff_to_model_no_compaction_needed() {
        let handoff = ModelHandoff::new();
        let messages = vec![text_message(Role::User, "Hello")];

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), messages.len());
    }

    #[test]
    fn test_handoff_to_model_compacts_when_needed() {
        let handoff = ModelHandoff::new();
        let mut messages = vec![text_message(Role::System, "System")];
        // Enough long turns to push the conversation past the compaction threshold.
        messages.extend(alternating_turns(5000, |i| {
            format!(
                "This is message number {}. It contains significantly more content to ensure we exceed the context window limit. Each message should be approximately 50-60 tokens in length when encoded with the cl100k_base tokenizer. This allows us to test the compaction functionality effectively. ",
                i
            )
        }));

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        let handoff_messages = result.unwrap();
        assert!(handoff_messages.len() < messages.len());
        assert!(matches!(handoff_messages[0].role, Role::System));
    }

    // NOTE(review): the test name says 5 percent, but the tolerance applied below (and the
    // failure message) is 10%.
    #[test]
    fn test_token_count_accuracy_within_5_percent() {
        let handoff = ModelHandoff::new();
        let counted = handoff.count_tokens("The quick brown fox jumps over the lazy dog. ");

        let expected: i32 = 11;
        let tolerance = (f64::from(expected) * 0.10) as i32;
        let deviation = (counted as i32 - expected).abs();

        assert!(
            deviation <= tolerance,
            "Token count {} not within {}% of expected {}",
            counted,
            10,
            expected
        );
    }

    #[test]
    fn test_find_cut_point_basic() {
        let handoff = ModelHandoff::new();
        let messages = alternating_turns(10, |i| {
            format!("Message {} with some content to make it longer", i)
        });

        let cut = handoff.find_cut_point(&messages, 50);

        assert!(cut.is_some());
        let cut_idx = cut.unwrap();
        assert!(cut_idx > 0);
        assert!(cut_idx < messages.len());
    }

    #[test]
    fn test_find_cut_point_empty_messages() {
        let handoff = ModelHandoff::new();
        let messages: Vec<Message> = Vec::new();

        assert!(handoff.find_cut_point(&messages, 100).is_none());
    }

    #[test]
    fn test_find_cut_point_all_fit() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_message(Role::User, "Short"),
            text_message(Role::Assistant, "Hi"),
        ];

        assert_eq!(handoff.find_cut_point(&messages, 1000), Some(0));
    }

    #[test]
    fn test_find_cut_point_prefers_user_message() {
        let handoff = ModelHandoff::new();
        // Five user/assistant exchanges.
        let messages: Vec<Message> = (0..5)
            .flat_map(|_| {
                vec![
                    text_message(Role::User, "This is a user message with enough content"),
                    text_message(Role::Assistant, "Assistant reply"),
                ]
            })
            .collect();

        let cut = handoff.find_cut_point(&messages, 30).unwrap();
        assert!(matches!(messages[cut].role, Role::User));
    }
}