Skip to main content

imp_core/
context.rs

1use std::collections::HashMap;
2
3use imp_llm::{truncate_chars_with_suffix, ContentBlock, Message, Model};
4
5fn truncate_for_display(text: &str, max_chars: usize) -> String {
6    truncate_chars_with_suffix(text, max_chars, "...")
7}
8
/// Estimated context-window usage for a conversation, as produced by
/// [`context_usage`].
///
/// All token figures are heuristic estimates, not tokenizer counts.
#[derive(Debug, Clone, PartialEq)]
pub struct ContextUsage {
    /// Estimated number of tokens occupied by the messages.
    pub used: u32,
    /// The model's context window size, in tokens.
    pub limit: u32,
    /// `used / limit`, or `0.0` when `limit` is zero. Not clamped: it exceeds
    /// `1.0` once the estimate is larger than the window.
    pub ratio: f64,
}
16
/// Fast approximate token count for `text`.
///
/// Applies the "~4 characters per token" rule of thumb for English, but
/// measures UTF-8 *bytes* (`str::len`), not `char`s. Non-ASCII text therefore
/// gets a proportionally higher estimate than a char-based count would give.
///
/// Never wraps: inputs so large that `len / 4` exceeds `u32::MAX` saturate at
/// `u32::MAX`.
pub fn estimate_tokens(text: &str) -> u32 {
    // Divide in `usize` and clamp before narrowing. A bare `as u32` on the
    // byte length would silently wrap for inputs of 4 GiB or more.
    let tokens = text.len() / 4;
    tokens.min(u32::MAX as usize) as u32
}
21
22/// Estimate total context usage for a message list.
23pub fn context_usage(messages: &[Message], model: &Model) -> ContextUsage {
24    let used: u32 = messages
25        .iter()
26        .map(|m| {
27            let json = serde_json::to_string(m).unwrap_or_default();
28            estimate_tokens(&json)
29        })
30        .sum();
31    let limit = model.meta.context_window;
32    let ratio = if limit > 0 {
33        used as f64 / limit as f64
34    } else {
35        0.0
36    };
37    ContextUsage { used, limit, ratio }
38}
39
40/// Replace old tool result content with lightweight placeholders.
41///
42/// A "turn" is one assistant message plus its following tool results.
43/// Keeps the last `keep_recent_turns` turns fully intact. For older turns,
44/// tool result content is replaced with a summary placeholder preserving
45/// the tool name, a truncated summary of args, and the byte count.
46pub fn mask_observations(messages: &mut [Message], keep_recent_turns: usize) {
47    // Identify turn boundaries — each assistant message starts a new turn.
48    let turn_starts: Vec<usize> = messages
49        .iter()
50        .enumerate()
51        .filter(|(_, m)| m.is_assistant())
52        .map(|(i, _)| i)
53        .collect();
54
55    if turn_starts.len() <= keep_recent_turns {
56        return;
57    }
58
59    // Everything before this message index gets masked.
60    let cutoff_turn = turn_starts.len() - keep_recent_turns;
61    let cutoff_msg_idx = turn_starts[cutoff_turn];
62
63    // Build a map of tool_call_id → args summary from assistant ToolCall blocks
64    // in the region we're about to mask.
65    let mut args_map: HashMap<String, String> = HashMap::new();
66    for msg in &messages[..cutoff_msg_idx] {
67        if let Message::Assistant(assistant) = msg {
68            for block in &assistant.content {
69                if let ContentBlock::ToolCall { id, arguments, .. } = block {
70                    let args_json = serde_json::to_string(arguments).unwrap_or_default();
71                    let summary = truncate_for_display(&args_json, 100);
72                    args_map.insert(id.clone(), summary);
73                }
74            }
75        }
76    }
77
78    // Replace tool result content with placeholders.
79    for msg in &mut messages[..cutoff_msg_idx] {
80        if let Message::ToolResult(ref mut result) = msg {
81            let byte_count: usize = result
82                .content
83                .iter()
84                .map(|b| match b {
85                    ContentBlock::Text { text } => text.len(),
86                    _ => 0,
87                })
88                .sum();
89
90            let args_summary = args_map
91                .get(&result.tool_call_id)
92                .map(|s| s.as_str())
93                .unwrap_or("");
94
95            let placeholder = format!(
96                "[Output omitted — ran {}({}), returned {} bytes]",
97                result.tool_name, args_summary, byte_count
98            );
99            result.content = vec![ContentBlock::Text { text: placeholder }];
100        }
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use std::pin::Pin;
    use std::sync::Arc;

    use async_trait::async_trait;
    use futures_core::Stream;
    use imp_llm::model::{Capabilities, ModelMeta, ModelPricing};
    use imp_llm::provider::Provider;
    use imp_llm::{AssistantMessage, RequestOptions, StopReason, StreamEvent, ToolResultMessage};

    // -- helpers --

    fn make_user(text: &str) -> Message {
        Message::user(text)
    }

    fn make_assistant_tool_call(
        call_id: &str,
        tool_name: &str,
        args: serde_json::Value,
    ) -> Message {
        Message::Assistant(AssistantMessage {
            content: vec![ContentBlock::ToolCall {
                id: call_id.into(),
                name: tool_name.into(),
                arguments: args,
            }],
            usage: None,
            stop_reason: StopReason::ToolUse,
            timestamp: 1000,
        })
    }

    fn make_assistant_text(text: &str) -> Message {
        Message::Assistant(AssistantMessage {
            content: vec![ContentBlock::Text { text: text.into() }],
            usage: None,
            stop_reason: StopReason::EndTurn,
            timestamp: 1000,
        })
    }

    fn make_tool_result(call_id: &str, tool_name: &str, output: &str) -> Message {
        Message::ToolResult(ToolResultMessage {
            tool_call_id: call_id.into(),
            tool_name: tool_name.into(),
            content: vec![ContentBlock::Text {
                text: output.into(),
            }],
            is_error: false,
            details: serde_json::Value::Null,
            timestamp: 1000,
        })
    }

    fn tool_result_text(msg: &Message) -> &str {
        match msg {
            Message::ToolResult(tr) => match &tr.content[0] {
                ContentBlock::Text { text } => text.as_str(),
                _ => panic!("expected text block"),
            },
            _ => panic!("expected ToolResult"),
        }
    }

    /// Minimal provider that never streams anything. Used for context_usage tests.
    struct NullProvider;

    #[async_trait]
    impl Provider for NullProvider {
        fn stream(
            &self,
            _model: &Model,
            _context: imp_llm::Context,
            _options: RequestOptions,
            _api_key: &str,
        ) -> Pin<Box<dyn Stream<Item = imp_llm::Result<StreamEvent>> + Send>> {
            Box::pin(futures::stream::empty())
        }

        async fn resolve_auth(
            &self,
            _auth: &imp_llm::auth::AuthStore,
        ) -> imp_llm::Result<imp_llm::auth::ApiKey> {
            Ok("test".into())
        }

        fn id(&self) -> &str {
            "null"
        }

        fn models(&self) -> &[ModelMeta] {
            &[]
        }
    }

    fn test_model() -> Model {
        Model {
            meta: ModelMeta {
                id: "test".into(),
                provider: "test".into(),
                name: "Test".into(),
                context_window: 100_000,
                max_output_tokens: 4096,
                pricing: ModelPricing::default(),
                capabilities: Capabilities::default(),
            },
            provider: Arc::new(NullProvider),
        }
    }

    // -- token estimation --

    #[test]
    fn estimate_tokens_rough_accuracy_for_english() {
        // "The quick brown fox jumps over the lazy dog" is 43 bytes.
        // Real tokenizers produce ~10 tokens for this sentence.
        // Our estimate: 43 / 4 = 10 (integer division). Within 2x of 10 ✓
        let text = "The quick brown fox jumps over the lazy dog";
        let est = estimate_tokens(text);
        let actual_approx = 10u32;
        assert!(
            est <= actual_approx * 2 && est * 2 >= actual_approx,
            "estimate {est} should be within 2x of ~{actual_approx}"
        );
    }

    #[test]
    fn estimate_tokens_longer_text() {
        // The `\` line continuations drop the newline and leading indentation,
        // leaving roughly 390 bytes of ASCII prose, i.e. an estimate of about 97.
        // The assertion only checks a broad sanity range rather than pinning a
        // tokenizer-specific count.
        let text = "Rust is a multi-paradigm programming language designed for performance \
                    and safety, especially safe concurrency. Rust is syntactically similar to C++ \
                    but can guarantee memory safety by using a borrow checker to validate references. \
                    Rust achieves memory safety without garbage collection, and reference counting \
                    is optional. Rust was originally designed by Graydon Hoare at Mozilla Research.";
        let est = estimate_tokens(text);
        assert!(est > 40 && est < 200, "estimate {est} out of range");
    }

    // -- observation masking --

    #[test]
    fn mask_observations_20_turns_keeps_last_10() {
        let mut messages = vec![make_user("initial prompt")];

        for i in 0..20 {
            let call_id = format!("call_{i}");
            messages.push(make_assistant_tool_call(
                &call_id,
                "read_file",
                serde_json::json!({"path": format!("/tmp/file_{i}.rs")}),
            ));
            messages.push(make_tool_result(
                &call_id,
                "read_file",
                &format!("Contents of file {i} — some long output here"),
            ));
        }
        // 1 user + 20*(assistant+tool_result) = 41 messages total

        mask_observations(&mut messages, 10);

        // First 10 turns are messages[1..21] — tool results at indices 2,4,6,...,20
        for i in 0..10 {
            let tr_idx = 2 + i * 2; // tool result indices: 2, 4, 6, ..., 20
            let text = tool_result_text(&messages[tr_idx]);
            assert!(
                text.starts_with("[Output omitted"),
                "Turn {i} tool result should be masked, got: {text}"
            );
        }

        // Last 10 turns are messages[21..41] — tool results at 22,24,...,40
        for i in 10..20 {
            let tr_idx = 2 + i * 2;
            let text = tool_result_text(&messages[tr_idx]);
            assert!(
                text.starts_with("Contents of file"),
                "Turn {i} tool result should be intact, got: {text}"
            );
        }
    }

    #[test]
    fn masking_preserves_user_messages() {
        let mut messages = vec![make_user("Hello, help me with this task")];

        for i in 0..5 {
            let call_id = format!("call_{i}");
            messages.push(make_assistant_tool_call(
                &call_id,
                "bash",
                serde_json::json!({"command": format!("ls /tmp/{i}")}),
            ));
            messages.push(make_tool_result(
                &call_id,
                "bash",
                &format!("file_{i}.txt\nmore_output_{i}"),
            ));
        }

        mask_observations(&mut messages, 2);

        // User message at index 0 must be preserved verbatim.
        if let Message::User(u) = &messages[0] {
            if let ContentBlock::Text { text } = &u.content[0] {
                assert_eq!(text, "Hello, help me with this task");
            } else {
                panic!("expected Text block in user message");
            }
        } else {
            panic!("expected User message at index 0");
        }
    }

    #[test]
    fn masking_preserves_assistant_text_and_tool_call_args() {
        let mut messages = vec![make_user("do stuff")];

        for i in 0..4 {
            let call_id = format!("call_{i}");
            let args = serde_json::json!({"command": format!("echo {i}")});
            messages.push(make_assistant_tool_call(&call_id, "bash", args));
            messages.push(make_tool_result(&call_id, "bash", &format!("output {i}")));
        }
        messages.push(make_assistant_text("All done!"));

        // Keep last 1 turn (the final text-only assistant). That means 4 tool turns get masked.
        mask_observations(&mut messages, 1);

        // Check all assistant messages are fully preserved.
        for msg in &messages {
            if let Message::Assistant(a) = msg {
                for block in &a.content {
                    match block {
                        ContentBlock::ToolCall {
                            name, arguments, ..
                        } => {
                            assert_eq!(name, "bash");
                            assert!(arguments.get("command").is_some());
                        }
                        ContentBlock::Text { text } => {
                            assert_eq!(text, "All done!");
                        }
                        _ => {}
                    }
                }
            }
        }

        // Tool results in old turns are masked but preserve tool_call_id, tool_name, is_error.
        let tool_results: Vec<&ToolResultMessage> = messages
            .iter()
            .filter_map(|m| {
                if let Message::ToolResult(tr) = m {
                    Some(tr)
                } else {
                    None
                }
            })
            .collect();

        for tr in &tool_results {
            assert_eq!(tr.tool_name, "bash");
            assert!(!tr.is_error);
            assert!(!tr.tool_call_id.is_empty());
        }
    }

    #[test]
    fn masking_preserves_error_flag() {
        let mut messages = vec![make_user("run it")];
        messages.push(make_assistant_tool_call(
            "c1",
            "bash",
            serde_json::json!({"command": "false"}),
        ));
        messages.push(Message::ToolResult(ToolResultMessage {
            tool_call_id: "c1".into(),
            tool_name: "bash".into(),
            content: vec![ContentBlock::Text {
                text: "exit status 1".into(),
            }],
            is_error: true,
            details: serde_json::Value::Null,
            timestamp: 1000,
        }));
        messages.push(make_assistant_text("It failed."));

        // Keep only the final text turn → the failing tool turn is masked.
        mask_observations(&mut messages, 1);

        match &messages[2] {
            Message::ToolResult(tr) => {
                assert!(tr.is_error, "error flag must survive masking");
                assert_eq!(tr.tool_call_id, "c1");
            }
            _ => panic!("expected ToolResult"),
        }
        let text = tool_result_text(&messages[2]);
        assert!(text.starts_with("[Output omitted"), "got: {text}");
        // "exit status 1" is 13 bytes.
        assert!(text.contains("13 bytes"), "got: {text}");
    }

    #[test]
    fn mask_observations_includes_args_summary() {
        let mut messages = vec![make_user("do stuff")];

        let args = serde_json::json!({"path": "/src/main.rs", "line": 42});
        messages.push(make_assistant_tool_call("c1", "read_file", args));
        messages.push(make_tool_result("c1", "read_file", "fn main() {}"));

        messages.push(make_assistant_text("done"));

        // Keep only the last turn (text-only), so the tool turn gets masked.
        mask_observations(&mut messages, 1);

        let text = tool_result_text(&messages[2]);
        assert!(text.contains("read_file"), "should contain tool name");
        assert!(text.contains("/src/main.rs"), "should contain args summary");
        assert!(text.contains("bytes"), "should contain byte count");
    }

    #[test]
    fn mask_observations_handles_multibyte_args_without_panicking() {
        let mut messages = vec![make_user("do stuff")];

        let long_text = format!("{}—bbb", "a".repeat(86));
        messages.push(make_assistant_tool_call(
            "c1",
            "edit",
            serde_json::json!({"newText": long_text}),
        ));
        messages.push(make_tool_result("c1", "edit", "ok"));
        messages.push(make_assistant_text("done"));

        mask_observations(&mut messages, 1);

        let text = tool_result_text(&messages[2]);
        assert!(text.starts_with("[Output omitted"));
        assert!(text.contains("..."));
    }

    #[test]
    fn mask_observations_noop_when_few_turns() {
        let mut messages = vec![
            make_user("hi"),
            make_assistant_tool_call("c1", "bash", serde_json::json!({"command": "ls"})),
            make_tool_result("c1", "bash", "out"),
            make_assistant_text("hello"),
        ];
        let original = serde_json::to_string(&messages).unwrap();

        mask_observations(&mut messages, 10);

        // Nothing should change — only 2 turns, window is 10. Compare the full
        // serialized content, not just the length, so in-place edits are caught.
        assert_eq!(serde_json::to_string(&messages).unwrap(), original);
    }

    // -- context usage --

    #[test]
    fn context_usage_basic_calculation() {
        let model = test_model();
        let messages = vec![make_user("Hello world"), make_assistant_text("Hi there!")];

        let usage = context_usage(&messages, &model);

        assert!(usage.used > 0, "should estimate > 0 tokens");
        assert_eq!(usage.limit, 100_000);
        assert!(usage.ratio > 0.0, "ratio should be positive");
        assert!(usage.ratio < 1.0, "ratio should be < 1 for small messages");
    }

    #[test]
    fn context_usage_masked_vs_unmasked() {
        let model = test_model();

        let mut messages = vec![make_user("prompt")];
        for i in 0..10 {
            let call_id = format!("c{i}");
            let big_output = "x".repeat(2000);
            messages.push(make_assistant_tool_call(
                &call_id,
                "bash",
                serde_json::json!({"cmd": "ls"}),
            ));
            messages.push(make_tool_result(&call_id, "bash", &big_output));
        }

        let usage_before = context_usage(&messages, &model);

        mask_observations(&mut messages, 2);

        let usage_after = context_usage(&messages, &model);

        assert!(
            usage_after.used < usage_before.used,
            "masking should reduce token count: before={}, after={}",
            usage_before.used,
            usage_after.used
        );
    }

    // -- edge case tests --

    #[test]
    fn estimate_tokens_empty_string() {
        assert_eq!(estimate_tokens(""), 0);
    }

    #[test]
    fn context_usage_with_zero_messages() {
        let model = test_model();
        let messages: Vec<Message> = vec![];

        let usage = context_usage(&messages, &model);

        assert_eq!(usage.used, 0);
        assert_eq!(usage.ratio, 0.0);
        assert_eq!(usage.limit, 100_000);
    }

    #[test]
    fn context_usage_near_limit() {
        // Create a message with enough text to approach the limit.
        let big_text = "a".repeat(400);
        let messages = vec![make_user(&big_text)];

        // Compute estimated tokens for this message, then set context_window = estimated + 1
        // so ratio is just under 1.0.
        let json = serde_json::to_string(&messages[0]).unwrap();
        let estimated = estimate_tokens(&json);
        let window = estimated + 1;

        let model = Model {
            meta: ModelMeta {
                id: "test".into(),
                provider: "test".into(),
                name: "Test".into(),
                context_window: window,
                max_output_tokens: 4096,
                pricing: ModelPricing::default(),
                capabilities: Capabilities::default(),
            },
            provider: Arc::new(NullProvider),
        };

        let usage = context_usage(&messages, &model);

        assert!(usage.ratio > 0.95, "ratio {} should be > 0.95", usage.ratio);
        assert!(usage.ratio < 1.0, "ratio {} should be < 1.0", usage.ratio);
    }

    #[test]
    fn mask_observations_replaces_content_with_placeholder() {
        let mut messages = vec![make_user("prompt")];
        let args = serde_json::json!({"path": "/src/lib.rs"});
        messages.push(make_assistant_tool_call("c1", "read_file", args));
        messages.push(make_tool_result(
            "c1",
            "read_file",
            "fn main() { println!(\"hello\"); }",
        ));
        // Second turn stays recent.
        messages.push(make_assistant_text("Done reading."));

        // Keep only last 1 turn → the tool turn gets masked.
        mask_observations(&mut messages, 1);

        let text = tool_result_text(&messages[2]);
        // Verify exact placeholder format.
        assert!(
            text.starts_with("[Output omitted — ran read_file("),
            "placeholder should start correctly, got: {text}"
        );
        assert!(
            text.contains("/src/lib.rs"),
            "placeholder should contain args summary, got: {text}"
        );
        assert!(
            text.ends_with("bytes]"),
            "placeholder should end with byte count, got: {text}"
        );
        // Verify byte count matches original content length.
        let original_len = "fn main() { println!(\"hello\"); }".len();
        assert!(
            text.contains(&format!("{original_len} bytes")),
            "placeholder should contain correct byte count {original_len}, got: {text}"
        );
    }

    #[test]
    fn mask_observations_preserves_all_assistant_reasoning() {
        let mut messages = vec![make_user("help me refactor")];

        // Turn 0: assistant text + tool call.
        messages.push(Message::Assistant(AssistantMessage {
            content: vec![
                ContentBlock::Text {
                    text: "Let me read the file first.".into(),
                },
                ContentBlock::ToolCall {
                    id: "c0".into(),
                    name: "read".into(),
                    arguments: serde_json::json!({"path": "a.rs"}),
                },
            ],
            usage: None,
            stop_reason: StopReason::ToolUse,
            timestamp: 1000,
        }));
        messages.push(make_tool_result("c0", "read", "file contents A"));

        // Turn 1: assistant reasoning text.
        messages.push(make_assistant_text(
            "I see the issue — the struct is missing a field.",
        ));

        // Turn 2: another tool call.
        messages.push(make_assistant_tool_call(
            "c2",
            "edit",
            serde_json::json!({"file": "a.rs"}),
        ));
        messages.push(make_tool_result("c2", "edit", "ok"));

        // Keep last 1 turn → turns 0 and 1 get masked.
        mask_observations(&mut messages, 1);

        // The old tool output is masked; the recent one is untouched.
        assert!(
            tool_result_text(&messages[2]).starts_with("[Output omitted"),
            "turn 0 tool result should be masked"
        );
        assert_eq!(tool_result_text(&messages[5]), "ok");

        // Collect ALL assistant text blocks — they should all be intact.
        let assistant_texts: Vec<&str> = messages
            .iter()
            .filter_map(|m| {
                if let Message::Assistant(a) = m {
                    Some(a.content.iter().filter_map(|b| {
                        if let ContentBlock::Text { text } = b {
                            Some(text.as_str())
                        } else {
                            None
                        }
                    }))
                } else {
                    None
                }
            })
            .flatten()
            .collect();

        assert!(
            assistant_texts.contains(&"Let me read the file first."),
            "early assistant reasoning must survive masking"
        );
        assert!(
            assistant_texts.contains(&"I see the issue — the struct is missing a field."),
            "mid-conversation assistant reasoning must survive masking"
        );
    }
}
626            assistant_texts.contains(&"I see the issue — the struct is missing a field."),
627            "mid-conversation assistant reasoning must survive masking"
628        );
629    }
630}