//! `limit_llm/handoff.rs` — token counting and message-history compaction
//! for handing a conversation off between LLM context windows.
1use crate::error::LlmError;
2use crate::types::{Message, Role};
3use tiktoken_rs::cl100k_base;
4
/// Manages conversation handoff between models: counts tokens with the
/// `cl100k_base` BPE and compacts message history to fit a target
/// context window when necessary.
pub struct ModelHandoff {
    // BPE tokenizer used for all token counting in this struct.
    tokenizer: tiktoken_rs::CoreBPE,
}
8
9impl Default for ModelHandoff {
10    fn default() -> Self {
11        Self::new()
12    }
13}
14
15impl ModelHandoff {
16    pub fn new() -> Self {
17        Self {
18            tokenizer: cl100k_base().expect("Failed to load tokenizer"),
19        }
20    }
21
22    pub fn count_tokens(&self, text: &str) -> usize {
23        self.tokenizer.encode_with_special_tokens(text).len()
24    }
25
26    pub fn count_message_tokens(&self, message: &Message) -> usize {
27        let mut total = message
28            .content
29            .as_ref()
30            .map(|c| self.count_tokens(c))
31            .unwrap_or(0);
32
33        // Add role overhead (4 tokens for message format)
34        total += 4;
35
36        // Add tool_calls if present
37        if let Some(tool_calls) = &message.tool_calls {
38            for call in tool_calls {
39                total += self.count_tokens(&call.id);
40                total += self.count_tokens(&call.function.name);
41                total += self.count_tokens(&call.function.arguments);
42            }
43        }
44
45        total
46    }
47
48    pub fn count_total_tokens(&self, messages: &[Message]) -> usize {
49        messages.iter().map(|m| self.count_message_tokens(m)).sum()
50    }
51
52    pub fn compact_messages(&self, messages: &[Message], target_tokens: usize) -> Vec<Message> {
53        // Always keep system message if present
54        let system_msg = messages.iter().find(|m| matches!(m.role, Role::System));
55
56        // Count tokens without system message
57        let non_system: Vec<_> = messages
58            .iter()
59            .filter(|m| !matches!(m.role, Role::System))
60            .cloned()
61            .collect();
62
63        let mut compacted = Vec::new();
64
65        // Add system message first if exists
66        if let Some(sys) = system_msg {
67            compacted.push(sys.clone());
68        }
69
70        // Calculate target for non-system messages
71        let system_tokens = compacted
72            .iter()
73            .map(|m| self.count_message_tokens(m))
74            .sum::<usize>();
75
76        // Reserve 20% of budget for safety, minimum 100 tokens
77        let safety_buffer = (target_tokens / 5).max(100);
78        let remaining_budget = target_tokens.saturating_sub(system_tokens + safety_buffer);
79
80        // Keep last N messages that fit within remaining budget
81        let mut selected = Vec::new();
82        let mut current_tokens = 0;
83
84        for msg in non_system.iter().rev() {
85            let msg_tokens = self.count_message_tokens(msg);
86
87            if current_tokens + msg_tokens <= remaining_budget {
88                current_tokens += msg_tokens;
89                selected.push(msg.clone());
90            } else {
91                break;
92            }
93        }
94
95        // Reverse to maintain original order
96        selected.reverse();
97        compacted.extend(selected);
98
99        compacted
100    }
101
102    pub fn handoff_to_model(
103        &self,
104        _from_model: &str,
105        to_model: &str,
106        messages: &[Message],
107    ) -> Result<Vec<Message>, LlmError> {
108        // Define context windows for different models (approximate)
109        let target_tokens = match to_model {
110            "claude-3-5-sonnet-20241022" => 200000,
111            "claude-3-5-haiku-20241022" => 200000,
112            "claude-3-opus-20240229" => 200000,
113            "claude-3-sonnet-20240229" => 200000,
114            "claude-3-haiku-20240307" => 200000,
115            _ => 200000, // Default to 200K context window
116        };
117
118        let current_tokens = self.count_total_tokens(messages);
119
120        // Only compact if we're over the target (with 10% buffer for safety)
121        if current_tokens > target_tokens * 9 / 10 {
122            Ok(self.compact_messages(messages, target_tokens))
123        } else {
124            Ok(messages.to_vec())
125        }
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::FunctionCall;
    use crate::types::ToolCall;

    /// Builds a plain text message with no tool calls.
    fn text_msg(role: Role, content: &str) -> Message {
        Message {
            role,
            content: Some(content.to_string()),
            tool_calls: None,
            tool_call_id: None,
        }
    }

    #[test]
    fn test_count_tokens_simple() {
        let handoff = ModelHandoff::new();
        let count = handoff.count_tokens("Hello, world!");
        assert!(count > 0);
        assert!(count < 10);
    }

    #[test]
    fn test_count_message_tokens() {
        let handoff = ModelHandoff::new();
        let msg = text_msg(Role::User, "Hello, world!");
        // Content tokens plus the 4-token role overhead.
        assert!(handoff.count_message_tokens(&msg) > 4);
    }

    #[test]
    fn test_count_message_tokens_with_tool_calls() {
        let handoff = ModelHandoff::new();
        let call = ToolCall {
            id: "call_123".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: "test_tool".to_string(),
                arguments: serde_json::json!({"arg": "value"}).to_string(),
            },
        };
        let msg = Message {
            role: Role::Assistant,
            content: Some("".to_string()),
            tool_calls: Some(vec![call]),
            tool_call_id: None,
        };
        assert!(handoff.count_message_tokens(&msg) > 10);
    }

    #[test]
    fn test_count_total_tokens() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_msg(Role::User, "Hello"),
            text_msg(Role::Assistant, "Hi there!"),
        ];
        assert!(handoff.count_total_tokens(&messages) > 0);
    }

    #[test]
    fn test_compact_messages_preserves_system() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            text_msg(Role::System, "You are a helpful assistant."),
            text_msg(Role::User, "Hello"),
        ];
        let compacted = handoff.compact_messages(&messages, 500);
        assert!(!compacted.is_empty());
        if compacted.len() > 1 {
            assert!(matches!(compacted[0].role, Role::System));
        }
    }

    #[test]
    fn test_compact_messages_keeps_recent() {
        let handoff = ModelHandoff::new();
        let mut messages = vec![text_msg(Role::System, "System")];

        // Add 100 alternating user/assistant messages.
        for i in 0..100 {
            let role = if i % 2 == 0 {
                Role::User
            } else {
                Role::Assistant
            };
            messages.push(text_msg(role, &format!("Message {}", i)));
        }

        // Compact to a small budget.
        let compacted = handoff.compact_messages(&messages, 500);

        // System message survives and something was dropped.
        assert!(compacted.len() < messages.len());
        assert!(matches!(compacted[0].role, Role::System));

        // The newest message must still be present at the end.
        assert_eq!(
            compacted.last().unwrap().content,
            Some("Message 99".to_string())
        );
    }

    #[test]
    fn test_handoff_to_model_no_compaction_needed() {
        let handoff = ModelHandoff::new();
        let messages = vec![text_msg(Role::User, "Hello")];

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), messages.len());
    }

    #[test]
    fn test_handoff_to_model_compacts_when_needed() {
        let handoff = ModelHandoff::new();
        let mut messages = vec![text_msg(Role::System, "System")];

        // Create 5000 messages with substantial content to exceed 200K context
        for i in 0..5000 {
            let role = if i % 2 == 0 {
                Role::User
            } else {
                Role::Assistant
            };
            messages.push(text_msg(
                role,
                &format!(
                    "This is message number {}. It contains significantly more content to ensure we exceed the context window limit. Each message should be approximately 50-60 tokens in length when encoded with the cl100k_base tokenizer. This allows us to test the compaction functionality effectively. ",
                    i
                ),
            ));
        }

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        let handoff_messages = result.unwrap();

        // Compaction must have happened and kept the system message first.
        assert!(handoff_messages.len() < messages.len());
        assert!(matches!(handoff_messages[0].role, Role::System));
    }

    #[test]
    fn test_token_count_accuracy_within_5_percent() {
        let handoff = ModelHandoff::new();
        let text = "The quick brown fox jumps over the lazy dog. ";

        let counted = handoff.count_tokens(text);

        // Expected value based on cl100k_base tokenizer, with a 10%
        // tolerance for tokenizer variations.
        let expected = 11;
        let tolerance = (expected as f64 * 0.10) as i32;

        assert!(
            (counted as i32 - expected).abs() <= tolerance,
            "Token count {} not within {}% of expected {}",
            counted,
            10,
            expected
        );
    }
}