Skip to main content

limit_llm/
handoff.rs

1use crate::error::LlmError;
2use crate::types::{Message, Role};
3use tiktoken_rs::cl100k_base;
4
5pub struct ModelHandoff {
6    tokenizer: tiktoken_rs::CoreBPE,
7}
8
9impl Default for ModelHandoff {
10    fn default() -> Self {
11        Self::new()
12    }
13}
14
15impl ModelHandoff {
16    pub fn new() -> Self {
17        Self {
18            tokenizer: cl100k_base().expect("Failed to load tokenizer"),
19        }
20    }
21
22    pub fn count_tokens(&self, text: &str) -> usize {
23        self.tokenizer.encode_with_special_tokens(text).len()
24    }
25
26    pub fn count_message_tokens(&self, message: &Message) -> usize {
27        let mut total = message
28            .content
29            .as_ref()
30            .map(|c| self.count_tokens(&c.to_text()))
31            .unwrap_or(0);
32
33        // Add role overhead (4 tokens for message format)
34        total += 4;
35
36        // Add tool_calls if present
37        if let Some(tool_calls) = &message.tool_calls {
38            for call in tool_calls {
39                total += self.count_tokens(&call.id);
40                total += self.count_tokens(&call.function.name);
41                total += self.count_tokens(&call.function.arguments);
42            }
43        }
44
45        total
46    }
47
48    pub fn count_total_tokens(&self, messages: &[Message]) -> usize {
49        messages.iter().map(|m| self.count_message_tokens(m)).sum()
50    }
51
52    pub fn compact_messages(&self, messages: &[Message], target_tokens: usize) -> Vec<Message> {
53        // Always keep system message if present
54        let system_msg = messages.iter().find(|m| matches!(m.role, Role::System));
55
56        // Count tokens without system message
57        let non_system: Vec<_> = messages
58            .iter()
59            .filter(|m| !matches!(m.role, Role::System))
60            .cloned()
61            .collect();
62
63        let mut compacted = Vec::new();
64
65        // Add system message first if exists
66        if let Some(sys) = system_msg {
67            compacted.push(sys.clone());
68        }
69
70        // Calculate target for non-system messages
71        let system_tokens = compacted
72            .iter()
73            .map(|m| self.count_message_tokens(m))
74            .sum::<usize>();
75
76        // Reserve 20% of budget for safety, minimum 100 tokens
77        let safety_buffer = (target_tokens / 5).max(100);
78        let remaining_budget = target_tokens.saturating_sub(system_tokens + safety_buffer);
79
80        // Keep last N messages that fit within remaining budget
81        let mut selected = Vec::new();
82        let mut current_tokens = 0;
83
84        for msg in non_system.iter().rev() {
85            let msg_tokens = self.count_message_tokens(msg);
86
87            if current_tokens + msg_tokens <= remaining_budget {
88                current_tokens += msg_tokens;
89                selected.push(msg.clone());
90            } else {
91                break;
92            }
93        }
94
95        // Reverse to maintain original order
96        selected.reverse();
97        compacted.extend(selected);
98
99        compacted
100    }
101
102    pub fn handoff_to_model(
103        &self,
104        _from_model: &str,
105        to_model: &str,
106        messages: &[Message],
107    ) -> Result<Vec<Message>, LlmError> {
108        // Define context windows for different models (approximate)
109        let target_tokens = match to_model {
110            "claude-3-5-sonnet-20241022" => 200000,
111            "claude-3-5-haiku-20241022" => 200000,
112            "claude-3-opus-20240229" => 200000,
113            "claude-3-sonnet-20240229" => 200000,
114            "claude-3-haiku-20240307" => 200000,
115            _ => 200000, // Default to 200K context window
116        };
117
118        let current_tokens = self.count_total_tokens(messages);
119
120        if current_tokens > target_tokens * 9 / 10 {
121            Ok(self.compact_messages(messages, target_tokens))
122        } else {
123            Ok(messages.to_vec())
124        }
125    }
126
127    pub fn find_cut_point(&self, messages: &[Message], keep_recent_tokens: usize) -> Option<usize> {
128        if messages.is_empty() {
129            return None;
130        }
131
132        let non_system: Vec<_> = messages
133            .iter()
134            .enumerate()
135            .filter(|(_, m)| !matches!(m.role, Role::System))
136            .collect();
137
138        if non_system.is_empty() {
139            return None;
140        }
141
142        let mut accumulated = 0;
143        for (idx, msg) in non_system.iter().rev() {
144            accumulated += self.count_message_tokens(msg);
145
146            if accumulated >= keep_recent_tokens {
147                let cut_idx = self.find_valid_cut_point(&non_system, *idx);
148                return Some(cut_idx);
149            }
150        }
151
152        Some(0)
153    }
154
155    fn find_valid_cut_point(&self, non_system: &[(usize, &Message)], min_idx: usize) -> usize {
156        for (idx, msg) in non_system.iter() {
157            if *idx >= min_idx && matches!(msg.role, Role::User) {
158                return *idx;
159            }
160        }
161
162        min_idx
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{FunctionCall, ToolCall};

    /// Builds a plain-text message with no tool metadata or cache control.
    fn text_message(role: Role, text: &str) -> Message {
        Message {
            role,
            content: Some(crate::MessageContent::text(text)),
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    /// `User` for even `i`, `Assistant` for odd `i`.
    fn alternating_role(i: usize) -> Role {
        if i % 2 == 0 {
            Role::User
        } else {
            Role::Assistant
        }
    }

    #[test]
    fn test_count_tokens_simple() {
        let handoff = ModelHandoff::new();
        let tokens = handoff.count_tokens("Hello, world!");
        assert!(tokens > 0);
        assert!(tokens < 10);
    }

    #[test]
    fn test_count_message_tokens() {
        let handoff = ModelHandoff::new();
        let msg = text_message(Role::User, "Hello, world!");
        let tokens = handoff.count_message_tokens(&msg);
        assert!(tokens > 4); // Content tokens + role overhead
    }

    #[test]
    fn test_count_message_tokens_with_tool_calls() {
        let handoff = ModelHandoff::new();
        let msg = Message {
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({"arg": "value"}).to_string(),
                },
            }]),
            ..text_message(Role::Assistant, "")
        };
        let tokens = handoff.count_message_tokens(&msg);
        assert!(tokens > 10);
    }

    #[test]
    fn test_count_total_tokens() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_message(Role::User, "Hello"),
            text_message(Role::Assistant, "Hi there!"),
        ];
        let total = handoff.count_total_tokens(&messages);
        assert!(total > 0);
        // The total is exactly the sum of the per-message estimates.
        let per_message: usize = messages
            .iter()
            .map(|m| handoff.count_message_tokens(m))
            .sum();
        assert_eq!(total, per_message);
    }

    #[test]
    fn test_compact_messages_preserves_system() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_message(Role::System, "You are a helpful assistant."),
            text_message(Role::User, "Hello"),
        ];
        let compacted = handoff.compact_messages(&messages, 500);
        // The system message must lead the result whenever one is present,
        // no matter how many other messages survive compaction.
        assert!(!compacted.is_empty());
        assert!(matches!(compacted[0].role, Role::System));
    }

    #[test]
    fn test_compact_messages_keeps_recent() {
        let handoff = ModelHandoff::new();
        let mut messages = vec![text_message(Role::System, "System")];

        // Add 100 messages
        messages.extend(
            (0..100).map(|i| text_message(alternating_role(i), &format!("Message {}", i))),
        );

        // Compact to small budget
        let compacted = handoff.compact_messages(&messages, 500);

        // Should have system + recent messages
        assert!(compacted.len() < messages.len());
        assert!(matches!(compacted[0].role, Role::System));

        // Last message should be preserved
        assert_eq!(
            compacted.last().unwrap().content,
            Some(crate::MessageContent::text("Message 99"))
        );
    }

    #[test]
    fn test_handoff_to_model_no_compaction_needed() {
        let handoff = ModelHandoff::new();
        let messages = vec![text_message(Role::User, "Hello")];

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        let handoff_messages = result.unwrap();
        assert_eq!(handoff_messages.len(), messages.len());
    }

    #[test]
    fn test_handoff_to_model_compacts_when_needed() {
        let handoff = ModelHandoff::new();
        let mut messages = vec![text_message(Role::System, "System")];

        // Create 5000 messages with substantial content to exceed 200K context
        messages.extend((0..5000).map(|i| {
            text_message(
                alternating_role(i),
                &format!(
                    "This is message number {}. It contains significantly more content to ensure we exceed the context window limit. Each message should be approximately 50-60 tokens in length when encoded with the cl100k_base tokenizer. This allows us to test the compaction functionality effectively. ",
                    i
                ),
            )
        }));

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        let handoff_messages = result.unwrap();

        // Should have compacted
        assert!(handoff_messages.len() < messages.len());
        assert!(matches!(handoff_messages[0].role, Role::System));
    }

    #[test]
    fn test_token_count_accuracy_within_10_percent() {
        const TOLERANCE_PERCENT: usize = 10;

        let handoff = ModelHandoff::new();
        let text = "The quick brown fox jumps over the lazy dog. ";

        let counted = handoff.count_tokens(text);

        let expected: usize = 11;
        let tolerance = expected * TOLERANCE_PERCENT / 100;

        assert!(
            counted.abs_diff(expected) <= tolerance,
            "Token count {} not within {}% of expected {}",
            counted,
            TOLERANCE_PERCENT,
            expected
        );
    }

    #[test]
    fn test_find_cut_point_basic() {
        let handoff = ModelHandoff::new();

        let messages: Vec<Message> = (0..10)
            .map(|i| {
                text_message(
                    alternating_role(i),
                    &format!("Message {} with some content to make it longer", i),
                )
            })
            .collect();

        let cut = handoff.find_cut_point(&messages, 50);
        assert!(cut.is_some());
        let cut_idx = cut.unwrap();
        assert!(cut_idx > 0);
        assert!(cut_idx < messages.len());
    }

    #[test]
    fn test_find_cut_point_empty_messages() {
        let handoff = ModelHandoff::new();
        let messages: Vec<Message> = vec![];

        let cut = handoff.find_cut_point(&messages, 100);
        assert!(cut.is_none());
    }

    #[test]
    fn test_find_cut_point_all_fit() {
        let handoff = ModelHandoff::new();

        let messages = vec![
            text_message(Role::User, "Short"),
            text_message(Role::Assistant, "Hi"),
        ];

        let cut = handoff.find_cut_point(&messages, 1000);
        assert_eq!(cut, Some(0));
    }

    #[test]
    fn test_find_cut_point_prefers_user_message() {
        let handoff = ModelHandoff::new();

        let mut messages = vec![];
        for _ in 0..5 {
            messages.push(text_message(
                Role::User,
                "This is a user message with enough content",
            ));
            messages.push(text_message(Role::Assistant, "Assistant reply"));
        }

        let cut = handoff.find_cut_point(&messages, 30).unwrap();
        assert!(matches!(messages[cut].role, Role::User));
    }
}