1use std::collections::HashMap;
2
3use imp_llm::{truncate_chars_with_suffix, ContentBlock, Message, Model};
4
5fn truncate_for_display(text: &str, max_chars: usize) -> String {
6 truncate_chars_with_suffix(text, max_chars, "...")
7}
8
/// Snapshot of how much of a model's context window a conversation occupies.
#[derive(Clone, Debug)]
pub struct ContextUsage {
    /// Estimated number of tokens consumed by the conversation.
    pub used: u32,
    /// Size of the model's context window, in tokens.
    pub limit: u32,
    /// `used / limit` as a fraction; `0.0` when `limit` is zero.
    pub ratio: f64,
}
16
/// Roughly estimates how many tokens `text` occupies.
///
/// Uses the heuristic of one token per four bytes of UTF-8, rounding down.
/// The result saturates at `u32::MAX` instead of wrapping for inputs whose
/// estimate does not fit in a `u32`.
pub fn estimate_tokens(text: &str) -> u32 {
    const BYTES_PER_TOKEN: usize = 4;

    // Divide in `usize` first and clamp afterwards: casting the raw byte
    // length straight to `u32` would silently truncate for inputs of 4 GiB or
    // more and yield a wildly low estimate.
    let estimate = text.len() / BYTES_PER_TOKEN;
    // The clamp guarantees the value fits, so this cast cannot truncate.
    estimate.min(u32::MAX as usize) as u32
}
21
22pub fn context_usage(messages: &[Message], model: &Model) -> ContextUsage {
24 let used: u32 = messages
25 .iter()
26 .map(|m| {
27 let json = serde_json::to_string(m).unwrap_or_default();
28 estimate_tokens(&json)
29 })
30 .sum();
31 let limit = model.meta.context_window;
32 let ratio = if limit > 0 {
33 used as f64 / limit as f64
34 } else {
35 0.0
36 };
37 ContextUsage { used, limit, ratio }
38}
39
40pub fn mask_observations(messages: &mut [Message], keep_recent_turns: usize) {
47 let turn_starts: Vec<usize> = messages
49 .iter()
50 .enumerate()
51 .filter(|(_, m)| m.is_assistant())
52 .map(|(i, _)| i)
53 .collect();
54
55 if turn_starts.len() <= keep_recent_turns {
56 return;
57 }
58
59 let cutoff_turn = turn_starts.len() - keep_recent_turns;
61 let cutoff_msg_idx = turn_starts[cutoff_turn];
62
63 let mut args_map: HashMap<String, String> = HashMap::new();
66 for msg in &messages[..cutoff_msg_idx] {
67 if let Message::Assistant(assistant) = msg {
68 for block in &assistant.content {
69 if let ContentBlock::ToolCall { id, arguments, .. } = block {
70 let args_json = serde_json::to_string(arguments).unwrap_or_default();
71 let summary = truncate_for_display(&args_json, 100);
72 args_map.insert(id.clone(), summary);
73 }
74 }
75 }
76 }
77
78 for msg in &mut messages[..cutoff_msg_idx] {
80 if let Message::ToolResult(ref mut result) = msg {
81 let byte_count: usize = result
82 .content
83 .iter()
84 .map(|b| match b {
85 ContentBlock::Text { text } => text.len(),
86 _ => 0,
87 })
88 .sum();
89
90 let args_summary = args_map
91 .get(&result.tool_call_id)
92 .map(|s| s.as_str())
93 .unwrap_or("");
94
95 let placeholder = format!(
96 "[Output omitted — ran {}({}), returned {} bytes]",
97 result.tool_name, args_summary, byte_count
98 );
99 result.content = vec![ContentBlock::Text { text: placeholder }];
100 }
101 }
102}
103
/// Unit tests for token estimation, context-usage accounting and observation
/// masking.
#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::sync::Arc;

    use async_trait::async_trait;
    use futures_core::Stream;
    use imp_llm::model::{Capabilities, ModelMeta, ModelPricing};
    use imp_llm::provider::Provider;
    use imp_llm::{AssistantMessage, RequestOptions, StopReason, StreamEvent, ToolResultMessage};

    // ---------------------------------------------------------------------
    // Fixtures
    // ---------------------------------------------------------------------

    /// Builds a user message containing `text`.
    fn make_user(text: &str) -> Message {
        Message::user(text)
    }

    /// Builds an assistant message consisting of a single tool call.
    fn make_assistant_tool_call(
        call_id: &str,
        tool_name: &str,
        args: serde_json::Value,
    ) -> Message {
        Message::Assistant(AssistantMessage {
            content: vec![ContentBlock::ToolCall {
                id: call_id.into(),
                name: tool_name.into(),
                arguments: args,
            }],
            usage: None,
            stop_reason: StopReason::ToolUse,
            timestamp: 1000,
        })
    }

    /// Builds an assistant message consisting of a single text block.
    fn make_assistant_text(text: &str) -> Message {
        Message::Assistant(AssistantMessage {
            content: vec![ContentBlock::Text { text: text.into() }],
            usage: None,
            stop_reason: StopReason::EndTurn,
            timestamp: 1000,
        })
    }

    /// Builds a successful tool result whose only content is `output`.
    fn make_tool_result(call_id: &str, tool_name: &str, output: &str) -> Message {
        Message::ToolResult(ToolResultMessage {
            tool_call_id: call_id.into(),
            tool_name: tool_name.into(),
            content: vec![ContentBlock::Text {
                text: output.into(),
            }],
            is_error: false,
            details: serde_json::Value::Null,
            timestamp: 1000,
        })
    }

    /// Returns the text of the first content block of a tool result.
    ///
    /// Panics if `msg` is not a tool result or its first block is not text.
    fn tool_result_text(msg: &Message) -> &str {
        match msg {
            Message::ToolResult(tr) => match &tr.content[0] {
                ContentBlock::Text { text } => text.as_str(),
                _ => panic!("expected text block"),
            },
            _ => panic!("expected ToolResult"),
        }
    }

    /// Provider stub: streams nothing and exposes no models. It exists only
    /// so that a `Model` value can be constructed in tests.
    struct NullProvider;

    #[async_trait]
    impl Provider for NullProvider {
        fn stream(
            &self,
            _model: &Model,
            _context: imp_llm::Context,
            _options: RequestOptions,
            _api_key: &str,
        ) -> Pin<Box<dyn Stream<Item = imp_llm::Result<StreamEvent>> + Send>> {
            Box::pin(futures::stream::empty())
        }

        async fn resolve_auth(
            &self,
            _auth: &imp_llm::auth::AuthStore,
        ) -> imp_llm::Result<imp_llm::auth::ApiKey> {
            Ok("test".into())
        }

        fn id(&self) -> &str {
            "null"
        }

        fn models(&self) -> &[ModelMeta] {
            &[]
        }
    }

    /// Builds a test model with the given context window size.
    fn model_with_window(context_window: u32) -> Model {
        Model {
            meta: ModelMeta {
                id: "test".into(),
                provider: "test".into(),
                name: "Test".into(),
                context_window,
                max_output_tokens: 4096,
                pricing: ModelPricing::default(),
                capabilities: Capabilities::default(),
            },
            provider: Arc::new(NullProvider),
        }
    }

    /// Builds a test model with a 100k-token context window.
    fn test_model() -> Model {
        model_with_window(100_000)
    }

    // ---------------------------------------------------------------------
    // estimate_tokens
    // ---------------------------------------------------------------------

    #[test]
    fn estimate_tokens_rough_accuracy_for_english() {
        let text = "The quick brown fox jumps over the lazy dog";
        let est = estimate_tokens(text);
        let actual_approx = 10u32;
        assert!(
            est <= actual_approx * 2 && est * 2 >= actual_approx,
            "estimate {est} should be within 2x of ~{actual_approx}"
        );
    }

    #[test]
    fn estimate_tokens_longer_text() {
        let text = "Rust is a multi-paradigm programming language designed for performance \
            and safety, especially safe concurrency. Rust is syntactically similar to C++ \
            but can guarantee memory safety by using a borrow checker to validate references. \
            Rust achieves memory safety without garbage collection, and reference counting \
            is optional. Rust was originally designed by Graydon Hoare at Mozilla Research.";
        let est = estimate_tokens(text);
        assert!(est > 40 && est < 200, "estimate {est} out of range");
    }

    #[test]
    fn estimate_tokens_empty_string() {
        assert_eq!(estimate_tokens(""), 0);
    }

    // ---------------------------------------------------------------------
    // mask_observations
    // ---------------------------------------------------------------------

    #[test]
    fn mask_observations_20_turns_keeps_last_10() {
        let mut messages = Vec::new();
        messages.push(make_user("initial prompt"));

        for i in 0..20 {
            let call_id = format!("call_{i}");
            messages.push(make_assistant_tool_call(
                &call_id,
                "read_file",
                serde_json::json!({"path": format!("/tmp/file_{i}.rs")}),
            ));
            messages.push(make_tool_result(
                &call_id,
                "read_file",
                &format!("Contents of file {i} — some long output here"),
            ));
        }
        mask_observations(&mut messages, 10);

        // Layout: [user, (assistant, result) x 20], so the result of turn `i`
        // sits at index 2 + 2 * i.
        for i in 0..10 {
            let tr_idx = 2 + i * 2;
            let text = tool_result_text(&messages[tr_idx]);
            assert!(
                text.starts_with("[Output omitted"),
                "Turn {i} tool result should be masked, got: {text}"
            );
        }

        for i in 10..20 {
            let tr_idx = 2 + i * 2;
            let text = tool_result_text(&messages[tr_idx]);
            assert!(
                text.starts_with("Contents of file"),
                "Turn {i} tool result should be intact, got: {text}"
            );
        }
    }

    #[test]
    fn masking_preserves_user_messages() {
        let mut messages = Vec::new();
        messages.push(make_user("Hello, help me with this task"));

        for i in 0..5 {
            let call_id = format!("call_{i}");
            messages.push(make_assistant_tool_call(
                &call_id,
                "bash",
                serde_json::json!({"command": format!("ls /tmp/{i}")}),
            ));
            messages.push(make_tool_result(
                &call_id,
                "bash",
                &format!("file_{i}.txt\nmore_output_{i}"),
            ));
        }

        mask_observations(&mut messages, 2);

        if let Message::User(u) = &messages[0] {
            if let ContentBlock::Text { text } = &u.content[0] {
                assert_eq!(text, "Hello, help me with this task");
            } else {
                panic!("expected Text block in user message");
            }
        } else {
            panic!("expected User message at index 0");
        }
    }

    #[test]
    fn masking_preserves_assistant_text_and_tool_call_args() {
        let mut messages = Vec::new();
        messages.push(make_user("do stuff"));

        for i in 0..4 {
            let call_id = format!("call_{i}");
            let args = serde_json::json!({"command": format!("echo {i}")});
            messages.push(make_assistant_tool_call(&call_id, "bash", args));
            messages.push(make_tool_result(&call_id, "bash", &format!("output {i}")));
        }
        messages.push(make_assistant_text("All done!"));

        mask_observations(&mut messages, 1);

        // Assistant content must be untouched.
        for msg in &messages {
            if let Message::Assistant(a) = msg {
                for block in &a.content {
                    match block {
                        ContentBlock::ToolCall {
                            name, arguments, ..
                        } => {
                            assert_eq!(name, "bash");
                            assert!(arguments.get("command").is_some());
                        }
                        ContentBlock::Text { text } => {
                            assert_eq!(text, "All done!");
                        }
                        _ => {}
                    }
                }
            }
        }

        // Tool result metadata must survive even when the content is masked.
        let tool_results: Vec<&ToolResultMessage> = messages
            .iter()
            .filter_map(|m| {
                if let Message::ToolResult(tr) = m {
                    Some(tr)
                } else {
                    None
                }
            })
            .collect();

        for tr in &tool_results {
            assert_eq!(tr.tool_name, "bash");
            assert!(!tr.is_error);
            assert!(!tr.tool_call_id.is_empty());
        }
    }

    #[test]
    fn mask_observations_includes_args_summary() {
        let mut messages = Vec::new();
        messages.push(make_user("do stuff"));

        let args = serde_json::json!({"path": "/src/main.rs", "line": 42});
        messages.push(make_assistant_tool_call("c1", "read_file", args));
        messages.push(make_tool_result("c1", "read_file", "fn main() {}"));

        messages.push(make_assistant_text("done"));

        mask_observations(&mut messages, 1);

        let text = tool_result_text(&messages[2]);
        assert!(text.contains("read_file"), "should contain tool name");
        assert!(text.contains("/src/main.rs"), "should contain args summary");
        assert!(text.contains("bytes"), "should contain byte count");
    }

    #[test]
    fn mask_observations_handles_multibyte_args_without_panicking() {
        let mut messages = vec![make_user("do stuff")];

        // Places a multi-byte character near the 100-character display limit
        // used for the argument summary.
        let long_text = format!("{}—bbb", "a".repeat(86));
        messages.push(make_assistant_tool_call(
            "c1",
            "edit",
            serde_json::json!({"newText": long_text}),
        ));
        messages.push(make_tool_result("c1", "edit", "ok"));
        messages.push(make_assistant_text("done"));

        mask_observations(&mut messages, 1);

        let text = tool_result_text(&messages[2]);
        assert!(text.starts_with("[Output omitted"));
        assert!(text.contains("..."));
    }

    #[test]
    fn mask_observations_noop_when_few_turns() {
        let mut messages = vec![make_user("hi"), make_assistant_text("hello")];
        let original_len = messages.len();
        // Masking never changes the number of messages, so a length check
        // alone cannot fail; compare the serialized content as well.
        let before = serde_json::to_string(&messages).expect("messages should serialize");

        mask_observations(&mut messages, 10);

        assert_eq!(messages.len(), original_len);
        let after = serde_json::to_string(&messages).expect("messages should serialize");
        assert_eq!(after, before, "messages must be left untouched");
    }

    #[test]
    fn mask_observations_replaces_content_with_placeholder() {
        let mut messages = vec![make_user("prompt")];
        let args = serde_json::json!({"path": "/src/lib.rs"});
        messages.push(make_assistant_tool_call("c1", "read_file", args));
        messages.push(make_tool_result(
            "c1",
            "read_file",
            "fn main() { println!(\"hello\"); }",
        ));
        messages.push(make_assistant_text("Done reading."));

        mask_observations(&mut messages, 1);

        let text = tool_result_text(&messages[2]);
        assert!(
            text.starts_with("[Output omitted — ran read_file("),
            "placeholder should start correctly, got: {text}"
        );
        assert!(
            text.contains("/src/lib.rs"),
            "placeholder should contain args summary, got: {text}"
        );
        assert!(
            text.ends_with("bytes]"),
            "placeholder should end with byte count, got: {text}"
        );
        let original_len = "fn main() { println!(\"hello\"); }".len();
        assert!(
            text.contains(&format!("{original_len} bytes")),
            "placeholder should contain correct byte count {original_len}, got: {text}"
        );
    }

    #[test]
    fn mask_observations_preserves_all_assistant_reasoning() {
        let mut messages = vec![make_user("help me refactor")];

        // Turn 0: reasoning text followed by a tool call in the same message.
        messages.push(Message::Assistant(AssistantMessage {
            content: vec![
                ContentBlock::Text {
                    text: "Let me read the file first.".into(),
                },
                ContentBlock::ToolCall {
                    id: "c0".into(),
                    name: "read".into(),
                    arguments: serde_json::json!({"path": "a.rs"}),
                },
            ],
            usage: None,
            stop_reason: StopReason::ToolUse,
            timestamp: 1000,
        }));
        messages.push(make_tool_result("c0", "read", "file contents A"));

        // Turn 1: pure reasoning.
        messages.push(make_assistant_text(
            "I see the issue — the struct is missing a field.",
        ));

        // Turn 2: the most recent turn, which is kept intact.
        messages.push(make_assistant_tool_call(
            "c2",
            "edit",
            serde_json::json!({"file": "a.rs"}),
        ));
        messages.push(make_tool_result("c2", "edit", "ok"));

        mask_observations(&mut messages, 1);

        let assistant_texts: Vec<&str> = messages
            .iter()
            .filter_map(|m| match m {
                Message::Assistant(a) => Some(a),
                _ => None,
            })
            .flat_map(|a| a.content.iter())
            .filter_map(|b| match b {
                ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect();

        assert!(
            assistant_texts.contains(&"Let me read the file first."),
            "early assistant reasoning must survive masking"
        );
        assert!(
            assistant_texts.contains(&"I see the issue — the struct is missing a field."),
            "mid-conversation assistant reasoning must survive masking"
        );
    }

    // ---------------------------------------------------------------------
    // context_usage
    // ---------------------------------------------------------------------

    #[test]
    fn context_usage_basic_calculation() {
        let model = test_model();
        let messages = vec![make_user("Hello world"), make_assistant_text("Hi there!")];

        let usage = context_usage(&messages, &model);

        assert!(usage.used > 0, "should estimate > 0 tokens");
        assert_eq!(usage.limit, 100_000);
        assert!(usage.ratio > 0.0, "ratio should be positive");
        assert!(usage.ratio < 1.0, "ratio should be < 1 for small messages");
    }

    #[test]
    fn context_usage_masked_vs_unmasked() {
        let model = test_model();

        let mut messages = Vec::new();
        messages.push(make_user("prompt"));
        for i in 0..10 {
            let call_id = format!("c{i}");
            let big_output = "x".repeat(2000);
            messages.push(make_assistant_tool_call(
                &call_id,
                "bash",
                serde_json::json!({"cmd": "ls"}),
            ));
            messages.push(make_tool_result(&call_id, "bash", &big_output));
        }

        let usage_before = context_usage(&messages, &model);

        mask_observations(&mut messages, 2);

        let usage_after = context_usage(&messages, &model);

        assert!(
            usage_after.used < usage_before.used,
            "masking should reduce token count: before={}, after={}",
            usage_before.used,
            usage_after.used
        );
    }

    #[test]
    fn context_usage_with_zero_messages() {
        let model = test_model();
        let messages: Vec<Message> = vec![];

        let usage = context_usage(&messages, &model);

        assert_eq!(usage.used, 0);
        assert_eq!(usage.ratio, 0.0);
        assert_eq!(usage.limit, 100_000);
    }

    #[test]
    fn context_usage_near_limit() {
        let big_text = "a".repeat(400);
        let messages = vec![make_user(&big_text)];

        // Size the window one token above the estimate so that the ratio
        // lands just below 1.0.
        let json = serde_json::to_string(&messages[0]).unwrap();
        let estimated = estimate_tokens(&json);
        let model = model_with_window(estimated + 1);

        let usage = context_usage(&messages, &model);

        assert!(usage.ratio > 0.95, "ratio {} should be > 0.95", usage.ratio);
        assert!(usage.ratio < 1.0, "ratio {} should be < 1.0", usage.ratio);
    }
}