1use crate::error::LlmError;
2use crate::types::{Message, Role};
3use tiktoken_rs::cl100k_base;
4
/// Handles model-to-model conversation handoffs: counts tokens with the
/// `cl100k_base` tokenizer and compacts message histories that would not fit
/// in the target model's context window.
pub struct ModelHandoff {
    // BPE tokenizer used for all token counting.
    // NOTE(review): cl100k_base is OpenAI's tokenizer; counts for Claude
    // models are approximations — confirm that is acceptable for budgeting.
    tokenizer: tiktoken_rs::CoreBPE,
}
8
impl Default for ModelHandoff {
    /// Equivalent to [`ModelHandoff::new`]; like `new`, panics if the
    /// tokenizer data fails to load.
    fn default() -> Self {
        Self::new()
    }
}
14
15impl ModelHandoff {
16 pub fn new() -> Self {
17 Self {
18 tokenizer: cl100k_base().expect("Failed to load tokenizer"),
19 }
20 }
21
22 pub fn count_tokens(&self, text: &str) -> usize {
23 self.tokenizer.encode_with_special_tokens(text).len()
24 }
25
26 pub fn count_message_tokens(&self, message: &Message) -> usize {
27 let mut total = message
28 .content
29 .as_ref()
30 .map(|c| self.count_tokens(c))
31 .unwrap_or(0);
32
33 total += 4;
35
36 if let Some(tool_calls) = &message.tool_calls {
38 for call in tool_calls {
39 total += self.count_tokens(&call.id);
40 total += self.count_tokens(&call.function.name);
41 total += self.count_tokens(&call.function.arguments);
42 }
43 }
44
45 total
46 }
47
48 pub fn count_total_tokens(&self, messages: &[Message]) -> usize {
49 messages.iter().map(|m| self.count_message_tokens(m)).sum()
50 }
51
52 pub fn compact_messages(&self, messages: &[Message], target_tokens: usize) -> Vec<Message> {
53 let system_msg = messages.iter().find(|m| matches!(m.role, Role::System));
55
56 let non_system: Vec<_> = messages
58 .iter()
59 .filter(|m| !matches!(m.role, Role::System))
60 .cloned()
61 .collect();
62
63 let mut compacted = Vec::new();
64
65 if let Some(sys) = system_msg {
67 compacted.push(sys.clone());
68 }
69
70 let system_tokens = compacted
72 .iter()
73 .map(|m| self.count_message_tokens(m))
74 .sum::<usize>();
75
76 let safety_buffer = (target_tokens / 5).max(100);
78 let remaining_budget = target_tokens.saturating_sub(system_tokens + safety_buffer);
79
80 let mut selected = Vec::new();
82 let mut current_tokens = 0;
83
84 for msg in non_system.iter().rev() {
85 let msg_tokens = self.count_message_tokens(msg);
86
87 if current_tokens + msg_tokens <= remaining_budget {
88 current_tokens += msg_tokens;
89 selected.push(msg.clone());
90 } else {
91 break;
92 }
93 }
94
95 selected.reverse();
97 compacted.extend(selected);
98
99 compacted
100 }
101
102 pub fn handoff_to_model(
103 &self,
104 _from_model: &str,
105 to_model: &str,
106 messages: &[Message],
107 ) -> Result<Vec<Message>, LlmError> {
108 let target_tokens = match to_model {
110 "claude-3-5-sonnet-20241022" => 200000,
111 "claude-3-5-haiku-20241022" => 200000,
112 "claude-3-opus-20240229" => 200000,
113 "claude-3-sonnet-20240229" => 200000,
114 "claude-3-haiku-20240307" => 200000,
115 _ => 200000, };
117
118 let current_tokens = self.count_total_tokens(messages);
119
120 if current_tokens > target_tokens * 9 / 10 {
122 Ok(self.compact_messages(messages, target_tokens))
123 } else {
124 Ok(messages.to_vec())
125 }
126 }
127}
128
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::FunctionCall;
    use crate::types::ToolCall;

    // Sanity check: a short sentence encodes to a small, nonzero count.
    #[test]
    fn test_count_tokens_simple() {
        let handoff = ModelHandoff::new();
        let tokens = handoff.count_tokens("Hello, world!");
        assert!(tokens > 0);
        assert!(tokens < 10);
    }

    // A content-bearing message costs more than the fixed per-message overhead.
    #[test]
    fn test_count_message_tokens() {
        let handoff = ModelHandoff::new();
        let msg = Message {
            role: Role::User,
            content: Some("Hello, world!".to_string()),
            tool_calls: None,
            tool_call_id: None,
        };
        let tokens = handoff.count_message_tokens(&msg);
        assert!(tokens > 4);
    }

    // Tool-call id, name, and arguments all contribute to the count.
    #[test]
    fn test_count_message_tokens_with_tool_calls() {
        let handoff = ModelHandoff::new();
        let msg = Message {
            role: Role::Assistant,
            content: Some("".to_string()),
            tool_calls: Some(vec![ToolCall {
                id: "call_123".to_string(),
                tool_type: "function".to_string(),
                function: FunctionCall {
                    name: "test_tool".to_string(),
                    arguments: serde_json::json!({"arg": "value"}).to_string(),
                },
            }]),
            tool_call_id: None,
        };
        let tokens = handoff.count_message_tokens(&msg);
        assert!(tokens > 10);
    }

    #[test]
    fn test_count_total_tokens() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            Message {
                role: Role::User,
                content: Some("Hello".to_string()),
                tool_calls: None,
                tool_call_id: None,
            },
            Message {
                role: Role::Assistant,
                content: Some("Hi there!".to_string()),
                tool_calls: None,
                tool_call_id: None,
            },
        ];
        let total = handoff.count_total_tokens(&messages);
        assert!(total > 0);
    }

    // The system message, when present, must survive compaction at index 0.
    #[test]
    fn test_compact_messages_preserves_system() {
        let handoff = ModelHandoff::new();
        let messages = vec![
            Message {
                role: Role::System,
                content: Some("You are a helpful assistant.".to_string()),
                tool_calls: None,
                tool_call_id: None,
            },
            Message {
                role: Role::User,
                content: Some("Hello".to_string()),
                tool_calls: None,
                tool_call_id: None,
            },
        ];
        let compacted = handoff.compact_messages(&messages, 500);
        assert!(!compacted.is_empty());
        if compacted.len() > 1 {
            assert!(matches!(compacted[0].role, Role::System));
        }
    }

    // Compaction keeps the NEWEST messages: the last message of the input
    // must still be the last message of the output.
    #[test]
    fn test_compact_messages_keeps_recent() {
        let handoff = ModelHandoff::new();
        let mut messages = vec![Message {
            role: Role::System,
            content: Some("System".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        for i in 0..100 {
            messages.push(Message {
                role: if i % 2 == 0 {
                    Role::User
                } else {
                    Role::Assistant
                },
                content: Some(format!("Message {}", i)),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        let compacted = handoff.compact_messages(&messages, 500);

        assert!(compacted.len() < messages.len());
        assert!(matches!(compacted[0].role, Role::System));

        assert_eq!(
            compacted.last().unwrap().content,
            Some("Message 99".to_string())
        );
    }

    // A tiny conversation is far under the 90% threshold, so the handoff
    // must pass it through unchanged.
    #[test]
    fn test_handoff_to_model_no_compaction_needed() {
        let handoff = ModelHandoff::new();
        let messages = vec![Message {
            role: Role::User,
            content: Some("Hello".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        let handoff_messages = result.unwrap();
        assert_eq!(handoff_messages.len(), messages.len());
    }

    // 5000 messages of ~50-60 tokens each (~250k+ tokens) must exceed the
    // 200k window and trigger compaction, still keeping the system message.
    #[test]
    fn test_handoff_to_model_compacts_when_needed() {
        let handoff = ModelHandoff::new();
        let mut messages = vec![Message {
            role: Role::System,
            content: Some("System".to_string()),
            tool_calls: None,
            tool_call_id: None,
        }];

        for i in 0..5000 {
            messages.push(Message {
                role: if i % 2 == 0 {
                    Role::User
                } else {
                    Role::Assistant
                },
                content: Some(format!(
                    "This is message number {}. It contains significantly more content to ensure we exceed the context window limit. Each message should be approximately 50-60 tokens in length when encoded with the cl100k_base tokenizer. This allows us to test the compaction functionality effectively. ",
                    i
                )),
                tool_calls: None,
                tool_call_id: None,
            });
        }

        let result = handoff.handoff_to_model(
            "claude-3-5-sonnet-20241022",
            "claude-3-5-haiku-20241022",
            &messages,
        );

        assert!(result.is_ok());
        let handoff_messages = result.unwrap();

        assert!(handoff_messages.len() < messages.len());
        assert!(matches!(handoff_messages[0].role, Role::System));
    }

    // Renamed from `..._within_5_percent`: the tolerance actually applied is
    // 10% (`0.10`), and the failure message already reported 10%. The name
    // now matches the code.
    #[test]
    fn test_token_count_accuracy_within_10_percent() {
        let handoff = ModelHandoff::new();
        let text = "The quick brown fox jumps over the lazy dog. ";

        let counted = handoff.count_tokens(text);

        let expected = 11;
        let tolerance = (expected as f64 * 0.10) as i32;

        assert!(
            (counted as i32 - expected).abs() <= tolerance,
            "Token count {} not within {}% of expected {}",
            counted,
            10,
            expected
        );
    }
}