Skip to main content

bamboo_compression/
segmenter.rs

//! Message segmentation for budget management.
//!
//! Groups messages into atomic segments, ensuring tool-call chains stay together.
//! This prevents protocol errors that occur when tool results are included without
//! their corresponding tool calls.
6
7use bamboo_agent_core::{Message, Role};
8use std::collections::HashSet;
9
10/// A segment of conversation that should be treated as atomic during truncation.
11///
12/// Ensures tool-call relationships are preserved: an assistant's tool_call
13/// and the corresponding tool results form a single segment.
14#[derive(Debug, Clone)]
15pub struct MessageSegment {
16    /// Messages in this segment
17    pub messages: Vec<Message>,
18    /// Unique tool call IDs referenced in this segment
19    pub tool_call_ids: HashSet<String>,
20    /// Whether this segment contains a tool call chain (assistant call + results)
21    pub is_tool_chain: bool,
22    /// Approximate token count (for sorting/filtering)
23    pub token_estimate: u32,
24}
25
26impl MessageSegment {
27    /// Create a new segment containing a single message.
28    pub fn from_message(message: Message) -> Self {
29        let tool_call_ids = extract_tool_call_ids(&message);
30        let is_tool_chain = !tool_call_ids.is_empty();
31        Self {
32            messages: vec![message],
33            tool_call_ids,
34            is_tool_chain,
35            token_estimate: 0, // Will be set later by counter
36        }
37    }
38
39    /// Merge another segment into this one.
40    pub fn merge(&mut self, other: MessageSegment) {
41        self.messages.extend(other.messages);
42        self.tool_call_ids.extend(other.tool_call_ids);
43        self.is_tool_chain = !self.tool_call_ids.is_empty();
44        self.token_estimate += other.token_estimate;
45    }
46
47    /// Check if this segment contains a tool result for the given tool call ID.
48    pub fn contains_tool_result(&self, tool_call_id: &str) -> bool {
49        self.messages
50            .iter()
51            .any(|m| m.role == Role::Tool && m.tool_call_id.as_deref() == Some(tool_call_id))
52    }
53
54    /// Check if this segment contains the tool call (assistant message) for the given ID.
55    pub fn contains_tool_call(&self, tool_call_id: &str) -> bool {
56        self.messages.iter().any(|m| {
57            m.role == Role::Assistant
58                && m.tool_calls
59                    .as_ref()
60                    .is_some_and(|tc| tc.iter().any(|c| c.id == tool_call_id))
61        })
62    }
63
64    /// Get the IDs of tool calls that are missing their results in this segment.
65    pub fn get_missing_results(&self) -> Vec<&str> {
66        self.tool_call_ids
67            .iter()
68            .filter(|id| !self.contains_tool_result(id))
69            .map(|id| id.as_str())
70            .collect()
71    }
72}
73
74/// Extracts tool call IDs from a message.
75fn extract_tool_call_ids(message: &Message) -> HashSet<String> {
76    let mut ids = HashSet::new();
77
78    // Tool results reference a tool call
79    if let Some(ref id) = message.tool_call_id {
80        ids.insert(id.clone());
81    }
82
83    // Assistant messages with tool calls
84    if let Some(ref calls) = message.tool_calls {
85        for call in calls {
86            ids.insert(call.id.clone());
87        }
88    }
89
90    ids
91}
92
/// Splits a conversation into atomic units for budget management.
///
/// # Algorithm
///
/// 1. Walk the messages in order.
/// 2. An assistant message that carries tool calls opens a new segment.
/// 3. Tool results answering those calls are appended to the open segment,
///    which is closed once every call has been answered.
/// 4. Everything else (plain user/assistant messages, orphan tool results)
///    becomes a standalone segment; an unfinished chain is closed as-is.
#[derive(Debug)]
pub struct MessageSegmenter;
103
104impl MessageSegmenter {
105    /// Create a new segmenter.
106    pub fn new() -> Self {
107        Self
108    }
109
110    /// Segment messages, ensuring tool-call chains stay together.
111    ///
112    /// Returns segments in chronological order (oldest first).
113    pub fn segment(&self, messages: Vec<Message>) -> Vec<MessageSegment> {
114        let mut segments: Vec<MessageSegment> = Vec::new();
115        let mut current_segment: Option<MessageSegment> = None;
116        let mut pending_tool_calls: HashSet<String> = HashSet::new();
117
118        for message in messages {
119            match message.role {
120                // System messages are handled separately (always included)
121                Role::System => {
122                    // System messages don't go into segments - they're handled separately
123                    continue;
124                }
125
126                // User and Tool messages
127                Role::User | Role::Tool => {
128                    if let Some(ref mut seg) = current_segment {
129                        // Check if this is a tool result for a pending tool call
130                        if message.role == Role::Tool {
131                            if let Some(ref tool_call_id) = message.tool_call_id {
132                                let tool_call_id = tool_call_id.clone();
133                                if pending_tool_calls.contains(&tool_call_id) {
134                                    seg.messages.push(message);
135                                    pending_tool_calls.remove(&tool_call_id);
136
137                                    // If all results collected, close the segment
138                                    if pending_tool_calls.is_empty() {
139                                        if let Some(seg) = current_segment.take() {
140                                            segments.push(seg);
141                                        }
142                                    }
143                                    continue;
144                                }
145                            }
146                        }
147
148                        // Not part of tool chain - close current segment and start new
149                        if !pending_tool_calls.is_empty() {
150                            // We have an incomplete tool chain - log warning and continue
151                            // This can happen if tool execution was interrupted
152                            tracing::warn!(
153                                "Incomplete tool chain for tool calls: {:?}",
154                                pending_tool_calls
155                            );
156                            pending_tool_calls.clear();
157                        }
158                        if let Some(seg) = current_segment.take() {
159                            segments.push(seg);
160                        }
161                    }
162
163                    // Start new standalone segment for this message
164                    if message.role == Role::Tool {
165                        // Orphan tool result - this shouldn't happen but handle gracefully
166                        tracing::warn!(
167                            "Orphan tool result without preceding tool call: {:?}",
168                            message.tool_call_id
169                        );
170                        // Still create a segment for it to avoid losing data
171                    }
172                    segments.push(MessageSegment::from_message(message));
173                }
174
175                // Assistant messages
176                Role::Assistant => {
177                    // Check if this assistant is responding to a user message
178                    // (no tool calls = standalone message)
179                    let has_tool_calls = message
180                        .tool_calls
181                        .as_ref()
182                        .is_some_and(|calls| !calls.is_empty());
183
184                    if !has_tool_calls {
185                        // Close any pending segment
186                        if let Some(seg) = current_segment.take() {
187                            if !pending_tool_calls.is_empty() {
188                                tracing::warn!(
189                                    "Tool chain interrupted by assistant message: {:?}",
190                                    pending_tool_calls
191                                );
192                                pending_tool_calls.clear();
193                            }
194                            segments.push(seg);
195                        }
196                        // Create standalone segment
197                        segments.push(MessageSegment::from_message(message));
198                    } else {
199                        // Close any pending segment
200                        if let Some(seg) = current_segment.take() {
201                            if !pending_tool_calls.is_empty() {
202                                tracing::warn!(
203                                    "Tool chain interrupted by new tool call: {:?}",
204                                    pending_tool_calls
205                                );
206                                pending_tool_calls.clear();
207                            }
208                            segments.push(seg);
209                        }
210
211                        // Start new tool-call segment
212                        let mut new_seg = MessageSegment::from_message(message.clone());
213
214                        // Collect pending tool calls
215                        if let Some(ref calls) = message.tool_calls {
216                            for call in calls {
217                                pending_tool_calls.insert(call.id.clone());
218                            }
219                            new_seg.is_tool_chain = true;
220                        }
221
222                        current_segment = Some(new_seg);
223                    }
224                }
225            }
226        }
227
228        // Close any remaining segment
229        if let Some(seg) = current_segment.take() {
230            if !pending_tool_calls.is_empty() {
231                tracing::warn!(
232                    "Session ended with incomplete tool chain: {:?}",
233                    pending_tool_calls
234                );
235                pending_tool_calls.clear();
236            }
237            segments.push(seg);
238        }
239
240        segments
241    }
242
243    /// Segment messages including system messages in a separate collection.
244    ///
245    /// Returns (system_messages, segments) tuple.
246    pub fn segment_with_system(
247        &self,
248        messages: Vec<Message>,
249    ) -> (Vec<Message>, Vec<MessageSegment>) {
250        let system_messages: Vec<Message> = messages
251            .iter()
252            .filter(|m| m.role == Role::System)
253            .cloned()
254            .collect();
255
256        let non_system: Vec<Message> = messages
257            .into_iter()
258            .filter(|m| m.role != Role::System)
259            .collect();
260
261        let segments = self.segment(non_system);
262        (system_messages, segments)
263    }
264}
265
266impl Default for MessageSegmenter {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::Message;
    use bamboo_agent_core::{FunctionCall, ToolCall};

    /// Builds a `"function"`-type tool call with the given ID, name and raw arguments.
    fn create_tool_call(id: &str, name: &str, args: &str) -> ToolCall {
        let function = FunctionCall {
            name: name.to_string(),
            arguments: args.to_string(),
        };
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function,
        }
    }

    /// Assistant message that issues the given tool calls.
    fn assistant_calling(text: &str, calls: Vec<ToolCall>) -> Message {
        Message::assistant(text, Some(calls))
    }

    /// Runs a fresh segmenter over `messages`.
    fn segment_all(messages: Vec<Message>) -> Vec<MessageSegment> {
        MessageSegmenter::new().segment(messages)
    }

    #[test]
    fn segments_simple_conversation() {
        let segments = segment_all(vec![
            Message::user("Hello"),
            Message::assistant("Hi there", None),
            Message::user("How are you?"),
        ]);

        assert_eq!(segments.len(), 3, "Expected 3 separate segments");
        for index in 0..3 {
            assert!(!segments[index].is_tool_chain);
        }
    }

    #[test]
    fn segments_tool_call_chain() {
        let segments = segment_all(vec![
            Message::user("Search for something"),
            assistant_calling(
                "Let me search",
                vec![create_tool_call("call_1", "search", r#"{"q":"test"}"#)],
            ),
            Message::tool_result("call_1", "Here are the results..."),
        ]);

        assert_eq!(segments.len(), 2, "Expected 2 segments (user + tool chain)");
        assert!(!segments[0].is_tool_chain);
        assert!(segments[1].is_tool_chain);
        // The chain holds the assistant message and its tool result.
        assert_eq!(segments[1].messages.len(), 2);
    }

    #[test]
    fn segments_multiple_tool_calls() {
        let segments = segment_all(vec![
            Message::user("Do multiple things"),
            assistant_calling(
                "I'll help",
                vec![
                    create_tool_call("call_1", "search", r#"{"q":"a"}"#),
                    create_tool_call("call_2", "read", r#"{"file":"test.txt"}"#),
                ],
            ),
            Message::tool_result("call_1", "Search results..."),
            Message::tool_result("call_2", "File contents..."),
        ]);

        assert_eq!(segments.len(), 2);
        let chain = &segments[1];
        assert!(chain.is_tool_chain);
        // One assistant message plus two results.
        assert_eq!(chain.messages.len(), 3);
        assert_eq!(chain.tool_call_ids.len(), 2);
    }

    #[test]
    fn handles_orphan_tool_result() {
        let segments = segment_all(vec![
            Message::user("Hello"),
            Message::tool_result("orphan_call", "Some result"),
        ]);

        assert_eq!(segments.len(), 2);
        // The orphan tool result is kept in a segment of its own.
        assert_eq!(segments[1].messages.len(), 1);
    }

    #[test]
    fn handles_system_messages_separately() {
        let conversation = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi", None),
        ];

        let (system, segments) = MessageSegmenter::new().segment_with_system(conversation);

        assert_eq!(system.len(), 1);
        assert_eq!(segments.len(), 2);
    }

    #[test]
    fn segments_multiple_interleaved_tool_chains() {
        let segments = segment_all(vec![
            Message::user("First task"),
            assistant_calling(
                "Doing first",
                vec![create_tool_call("call_1", "search", "{}")],
            ),
            Message::tool_result("call_1", "Result 1"),
            Message::user("Second task"),
            assistant_calling(
                "Doing second",
                vec![create_tool_call("call_2", "read", "{}")],
            ),
            Message::tool_result("call_2", "Result 2"),
        ]);

        // Layout: user, chain 1, user, chain 2.
        assert_eq!(segments.len(), 4);
        assert!(segments[1].is_tool_chain);
        assert!(segments[3].is_tool_chain);
    }

    #[test]
    fn empty_messages_produces_empty_segments() {
        let segments = segment_all(Vec::new());
        assert!(segments.is_empty());
    }

    #[test]
    fn handles_incomplete_tool_chain_interrupted_by_user() {
        let segments = segment_all(vec![
            Message::user("Search for something"),
            assistant_calling(
                "Let me search",
                vec![create_tool_call("call_1", "search", "{}")],
            ),
            // call_1 never gets a result: the user follows up instead.
            Message::user("Actually, never mind"),
        ]);

        // Layout: user, unfinished chain, user.
        assert_eq!(segments.len(), 3);
        let unfinished = &segments[1];
        assert!(unfinished.is_tool_chain);
        assert_eq!(unfinished.messages.len(), 1); // assistant only, no result
        assert_eq!(unfinished.tool_call_ids.len(), 1);
    }

    #[test]
    fn handles_tool_chain_interrupted_by_new_tool_call() {
        let segments = segment_all(vec![
            Message::user("Task 1"),
            assistant_calling(
                "Doing task 1",
                vec![create_tool_call("call_1", "search", "{}")],
            ),
            // call_1 never gets a result: the assistant issues another call.
            assistant_calling(
                "Let me try a different approach",
                vec![create_tool_call("call_2", "read", "{}")],
            ),
            Message::tool_result("call_2", "Result 2"),
        ]);

        // Layout: user, interrupted chain (call_1), complete chain (call_2 + result).
        assert_eq!(segments.len(), 3);
        assert!(segments[1].is_tool_chain);
        assert_eq!(segments[1].messages.len(), 1); // assistant with call_1 only
        assert!(segments[2].is_tool_chain);
        assert_eq!(segments[2].messages.len(), 2); // assistant with call_2 + result
    }

    #[test]
    fn handles_tool_chain_interrupted_by_assistant_text() {
        let segments = segment_all(vec![
            Message::user("Search for something"),
            assistant_calling(
                "Let me search",
                vec![create_tool_call("call_1", "search", "{}")],
            ),
            // No tool result: the assistant replies with plain text instead.
            Message::assistant("I changed my mind", None),
        ]);

        assert_eq!(segments.len(), 3);
        assert!(segments[1].is_tool_chain);
        assert_eq!(segments[1].messages.len(), 1); // assistant with call_1
        assert!(!segments[2].is_tool_chain); // plain-text assistant
    }

    #[test]
    fn pending_tool_calls_cleared_after_interruption() {
        // Once a chain is interrupted, its pending IDs must be dropped so that
        // later segments do not inherit stale IDs.
        let segments = segment_all(vec![
            Message::user("Task 1"),
            assistant_calling(
                "Doing task 1",
                vec![create_tool_call("call_1", "search", "{}")],
            ),
            // Interrupted by the user.
            Message::user("Task 2"),
            // A new tool call with a different ID.
            assistant_calling(
                "Doing task 2",
                vec![create_tool_call("call_2", "read", "{}")],
            ),
            Message::tool_result("call_2", "Result 2"),
        ]);

        assert_eq!(segments.len(), 4);

        // Segment 1: the interrupted call_1 chain.
        let interrupted = &segments[1];
        assert!(interrupted.is_tool_chain);
        assert_eq!(interrupted.tool_call_ids.len(), 1);
        assert!(interrupted.tool_call_ids.contains("call_1"));

        // Segment 3: the complete call_2 chain, free of call_1.
        let completed = &segments[3];
        assert!(completed.is_tool_chain);
        assert_eq!(completed.tool_call_ids.len(), 1);
        assert!(completed.tool_call_ids.contains("call_2"));
        assert!(!completed.tool_call_ids.contains("call_1"));
    }
}
525}