//! rho_core/compaction.rs — context compaction: token estimation, tool-output
//! pruning, auto-compaction transform, and summary-prompt construction.
1use crate::types::{Content, Message, Model, UserContent};
2
3/// Estimate token count using chars/4 heuristic.
4pub fn estimate_tokens(messages: &[Message]) -> usize {
5    let mut chars = 0;
6    for msg in messages {
7        match msg {
8            Message::User { content, .. } => match content {
9                UserContent::Text(t) => chars += t.len(),
10                UserContent::Blocks(blocks) => {
11                    for block in blocks {
12                        chars += content_chars(block);
13                    }
14                }
15            },
16            Message::Assistant { content, .. } => {
17                for block in content {
18                    chars += content_chars(block);
19                }
20            }
21            Message::ToolResult { content, .. } => {
22                for block in content {
23                    chars += content_chars(block);
24                }
25            }
26        }
27    }
28    chars / 4
29}
30
31fn content_chars(c: &Content) -> usize {
32    match c {
33        Content::Text { text } => text.len(),
34        Content::Thinking { thinking } => thinking.len(),
35        Content::ToolCall { arguments, .. } => arguments.to_string().len() + 50, // overhead
36        Content::Image { data, .. } => data.len(),
37    }
38}
39
40/// Replace old tool result content with "[pruned]".
41/// Keeps the last `keep_recent` messages intact.
42pub fn prune_tool_outputs(messages: &[Message], keep_recent: usize) -> Vec<Message> {
43    let len = messages.len();
44    let prune_boundary = len.saturating_sub(keep_recent);
45
46    messages
47        .iter()
48        .enumerate()
49        .map(|(i, msg)| {
50            if i < prune_boundary {
51                match msg {
52                    Message::ToolResult {
53                        tool_call_id,
54                        tool_name,
55                        is_error,
56                        timestamp,
57                        ..
58                    } => Message::ToolResult {
59                        tool_call_id: tool_call_id.clone(),
60                        tool_name: tool_name.clone(),
61                        content: vec![Content::Text {
62                            text: "[pruned]".into(),
63                        }],
64                        is_error: *is_error,
65                        timestamp: *timestamp,
66                    },
67                    other => other.clone(),
68                }
69            } else {
70                msg.clone()
71            }
72        })
73        .collect()
74}
75
76/// Build the transform function for AgentLoopConfig.
77///
78/// Auto-compaction strategy:
79/// 1. If under threshold * context_window, return as-is
80/// 2. Prune tool outputs older than last 6 messages
81/// 3. If still over, prune more aggressively (keep last 2)
82pub fn make_compaction_transform(
83    threshold: f64,
84) -> Box<dyn Fn(&[Message], &Model) -> (Vec<Message>, Option<CompactionResult>) + Send + Sync> {
85    Box::new(move |messages: &[Message], model: &Model| {
86        let estimated = estimate_tokens(messages);
87        let limit = (threshold * model.context_window as f64) as usize;
88
89        if estimated <= limit {
90            return (messages.to_vec(), None);
91        }
92
93        // Phase 1: Prune tool outputs, keep last 6
94        let pruned = prune_tool_outputs(messages, 6);
95        let new_estimate = estimate_tokens(&pruned);
96
97        if new_estimate <= limit {
98            return (
99                pruned,
100                Some(CompactionResult {
101                    original_estimate: estimated,
102                    compacted_estimate: new_estimate,
103                    messages_pruned: messages.len().saturating_sub(6),
104                }),
105            );
106        }
107
108        // Phase 2: More aggressive — keep last 2
109        let pruned = prune_tool_outputs(messages, 2);
110        let new_estimate = estimate_tokens(&pruned);
111
112        (
113            pruned,
114            Some(CompactionResult {
115                original_estimate: estimated,
116                compacted_estimate: new_estimate,
117                messages_pruned: messages.len().saturating_sub(2),
118            }),
119        )
120    })
121}
122
/// Outcome statistics for one auto-compaction pass, returned alongside the
/// compacted message list by the transform built in `make_compaction_transform`.
#[derive(Debug, Clone)]
pub struct CompactionResult {
    /// Token estimate (chars/4 heuristic) before compaction.
    pub original_estimate: usize,
    /// Token estimate after compaction.
    pub compacted_estimate: usize,
    /// Count of messages in the pruned region (see transform for how it is computed).
    pub messages_pruned: usize,
}
129
130/// Build a summary prompt for manual /compact.
131pub fn build_summary_prompt(messages: &[Message]) -> String {
132    let mut context = String::new();
133    for msg in messages {
134        match msg {
135            Message::User { content, .. } => {
136                context.push_str("User: ");
137                match content {
138                    UserContent::Text(t) => context.push_str(t),
139                    UserContent::Blocks(blocks) => {
140                        for b in blocks {
141                            if let Content::Text { text } = b {
142                                context.push_str(text);
143                            }
144                        }
145                    }
146                }
147                context.push('\n');
148            }
149            Message::Assistant { content, .. } => {
150                context.push_str("Assistant: ");
151                for b in content {
152                    if let Content::Text { text } = b {
153                        context.push_str(text);
154                    }
155                }
156                context.push('\n');
157            }
158            Message::ToolResult {
159                tool_name, content, ..
160            } => {
161                context.push_str(&format!("[Tool: {}] ", tool_name));
162                for b in content {
163                    if let Content::Text { text } = b {
164                        // Truncate long tool results in summary
165                        if text.len() > 500 {
166                            context.push_str(&text[..500]);
167                            context.push_str("...");
168                        } else {
169                            context.push_str(text);
170                        }
171                    }
172                }
173                context.push('\n');
174            }
175        }
176    }
177
178    format!(
179        "Summarize the following conversation concisely, preserving key decisions, \
180         file paths, code changes, and important context. Be brief but complete.\n\n\
181         ---\n{}\n---\n\nProvide a concise summary:",
182        context
183    )
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Usage;

    fn make_user(text: &str) -> Message {
        Message::User {
            content: UserContent::Text(text.to_string()),
            timestamp: 0,
        }
    }

    fn make_assistant(text: &str) -> Message {
        Message::Assistant {
            content: vec![Content::Text {
                text: text.to_string(),
            }],
            model: "test".to_string(),
            usage: Usage::default(),
            stop_reason: crate::types::StopReason::Stop,
            timestamp: 0,
        }
    }

    fn make_tool_result(name: &str, text: &str) -> Message {
        Message::ToolResult {
            tool_call_id: "id".to_string(),
            tool_name: name.to_string(),
            content: vec![Content::Text {
                text: text.to_string(),
            }],
            is_error: false,
            timestamp: 0,
        }
    }

    /// Text of the first content block of a ToolResult; panics on any other shape.
    fn tool_result_text(msg: &Message) -> &str {
        match msg {
            Message::ToolResult { content, .. } => match &content[0] {
                Content::Text { text } => text,
                _ => panic!("expected text"),
            },
            _ => panic!("expected tool result"),
        }
    }

    #[test]
    fn estimate_tokens_basic() {
        // "hello world" is 11 chars; 11 / 4 == 2 with integer division.
        assert_eq!(estimate_tokens(&[make_user("hello world")]), 2);
    }

    #[test]
    fn prune_tool_outputs_keeps_recent() {
        let msgs = vec![
            make_user("q1"),
            make_tool_result("read", "long content here"),
            make_assistant("a1"),
            make_user("q2"),
            make_tool_result("read", "more content"),
            make_assistant("a2"),
        ];
        let pruned = prune_tool_outputs(&msgs, 3);

        // First 3 messages fall in the prune zone; the tool result is replaced.
        assert_eq!(tool_result_text(&pruned[1]), "[pruned]");
        // Last 3 messages are kept intact.
        assert_eq!(tool_result_text(&pruned[4]), "more content");
    }

    #[test]
    fn compaction_transform_no_op_under_threshold() {
        let model = Model {
            id: "test".to_string(),
            name: "test".to_string(),
            provider: "test".to_string(),
            base_url: String::new(),
            reasoning: false,
            context_window: 200_000,
            max_tokens: 8192,
        };
        let transform = make_compaction_transform(0.8);

        let (result, compaction) = transform(&[make_user("short")], &model);
        assert_eq!(result.len(), 1);
        assert!(compaction.is_none());
    }

    #[test]
    fn build_summary_prompt_includes_messages() {
        let prompt = build_summary_prompt(&[
            make_user("Find auth module"),
            make_assistant("I found it in src/auth.rs"),
        ]);
        assert!(prompt.contains("Find auth module"));
        assert!(prompt.contains("src/auth.rs"));
    }
}