1use bamboo_agent_core::{Message, Role};
8use std::collections::HashSet;
9
10#[derive(Debug, Clone)]
15pub struct MessageSegment {
16 pub messages: Vec<Message>,
18 pub tool_call_ids: HashSet<String>,
20 pub is_tool_chain: bool,
22 pub token_estimate: u32,
24}
25
26impl MessageSegment {
27 pub fn from_message(message: Message) -> Self {
29 let tool_call_ids = extract_tool_call_ids(&message);
30 let is_tool_chain = !tool_call_ids.is_empty();
31 Self {
32 messages: vec![message],
33 tool_call_ids,
34 is_tool_chain,
35 token_estimate: 0, }
37 }
38
39 pub fn merge(&mut self, other: MessageSegment) {
41 self.messages.extend(other.messages);
42 self.tool_call_ids.extend(other.tool_call_ids);
43 self.is_tool_chain = !self.tool_call_ids.is_empty();
44 self.token_estimate += other.token_estimate;
45 }
46
47 pub fn contains_tool_result(&self, tool_call_id: &str) -> bool {
49 self.messages
50 .iter()
51 .any(|m| m.role == Role::Tool && m.tool_call_id.as_deref() == Some(tool_call_id))
52 }
53
54 pub fn contains_tool_call(&self, tool_call_id: &str) -> bool {
56 self.messages.iter().any(|m| {
57 m.role == Role::Assistant
58 && m.tool_calls
59 .as_ref()
60 .is_some_and(|tc| tc.iter().any(|c| c.id == tool_call_id))
61 })
62 }
63
64 pub fn get_missing_results(&self) -> Vec<&str> {
66 self.tool_call_ids
67 .iter()
68 .filter(|id| !self.contains_tool_result(id))
69 .map(|id| id.as_str())
70 .collect()
71 }
72}
73
74fn extract_tool_call_ids(message: &Message) -> HashSet<String> {
76 let mut ids = HashSet::new();
77
78 if let Some(ref id) = message.tool_call_id {
80 ids.insert(id.clone());
81 }
82
83 if let Some(ref calls) = message.tool_calls {
85 for call in calls {
86 ids.insert(call.id.clone());
87 }
88 }
89
90 ids
91}
92
/// Stateless splitter that turns a flat list of conversation messages into
/// [`MessageSegment`]s.
///
/// Each assistant tool-call message is grouped with the tool results that answer it.
#[derive(Debug)]
pub struct MessageSegmenter;
103
104impl MessageSegmenter {
105 pub fn new() -> Self {
107 Self
108 }
109
110 pub fn segment(&self, messages: Vec<Message>) -> Vec<MessageSegment> {
114 let mut segments: Vec<MessageSegment> = Vec::new();
115 let mut current_segment: Option<MessageSegment> = None;
116 let mut pending_tool_calls: HashSet<String> = HashSet::new();
117
118 for message in messages {
119 match message.role {
120 Role::System => {
122 continue;
124 }
125
126 Role::User | Role::Tool => {
128 if let Some(ref mut seg) = current_segment {
129 if message.role == Role::Tool {
131 if let Some(ref tool_call_id) = message.tool_call_id {
132 let tool_call_id = tool_call_id.clone();
133 if pending_tool_calls.contains(&tool_call_id) {
134 seg.messages.push(message);
135 pending_tool_calls.remove(&tool_call_id);
136
137 if pending_tool_calls.is_empty() {
139 if let Some(seg) = current_segment.take() {
140 segments.push(seg);
141 }
142 }
143 continue;
144 }
145 }
146 }
147
148 if !pending_tool_calls.is_empty() {
150 tracing::warn!(
153 "Incomplete tool chain for tool calls: {:?}",
154 pending_tool_calls
155 );
156 pending_tool_calls.clear();
157 }
158 if let Some(seg) = current_segment.take() {
159 segments.push(seg);
160 }
161 }
162
163 if message.role == Role::Tool {
165 tracing::warn!(
167 "Orphan tool result without preceding tool call: {:?}",
168 message.tool_call_id
169 );
170 }
172 segments.push(MessageSegment::from_message(message));
173 }
174
175 Role::Assistant => {
177 let has_tool_calls = message
180 .tool_calls
181 .as_ref()
182 .is_some_and(|calls| !calls.is_empty());
183
184 if !has_tool_calls {
185 if let Some(seg) = current_segment.take() {
187 if !pending_tool_calls.is_empty() {
188 tracing::warn!(
189 "Tool chain interrupted by assistant message: {:?}",
190 pending_tool_calls
191 );
192 pending_tool_calls.clear();
193 }
194 segments.push(seg);
195 }
196 segments.push(MessageSegment::from_message(message));
198 } else {
199 if let Some(seg) = current_segment.take() {
201 if !pending_tool_calls.is_empty() {
202 tracing::warn!(
203 "Tool chain interrupted by new tool call: {:?}",
204 pending_tool_calls
205 );
206 pending_tool_calls.clear();
207 }
208 segments.push(seg);
209 }
210
211 let mut new_seg = MessageSegment::from_message(message.clone());
213
214 if let Some(ref calls) = message.tool_calls {
216 for call in calls {
217 pending_tool_calls.insert(call.id.clone());
218 }
219 new_seg.is_tool_chain = true;
220 }
221
222 current_segment = Some(new_seg);
223 }
224 }
225 }
226 }
227
228 if let Some(seg) = current_segment.take() {
230 if !pending_tool_calls.is_empty() {
231 tracing::warn!(
232 "Session ended with incomplete tool chain: {:?}",
233 pending_tool_calls
234 );
235 pending_tool_calls.clear();
236 }
237 segments.push(seg);
238 }
239
240 segments
241 }
242
243 pub fn segment_with_system(
247 &self,
248 messages: Vec<Message>,
249 ) -> (Vec<Message>, Vec<MessageSegment>) {
250 let system_messages: Vec<Message> = messages
251 .iter()
252 .filter(|m| m.role == Role::System)
253 .cloned()
254 .collect();
255
256 let non_system: Vec<Message> = messages
257 .into_iter()
258 .filter(|m| m.role != Role::System)
259 .collect();
260
261 let segments = self.segment(non_system);
262 (system_messages, segments)
263 }
264}
265
266impl Default for MessageSegmenter {
267 fn default() -> Self {
268 Self::new()
269 }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::Message;
    use bamboo_agent_core::{FunctionCall, ToolCall};

    /// Builds a `"function"`-typed tool call fixture.
    fn create_tool_call(id: &str, name: &str, args: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: args.to_string(),
            },
        }
    }

    #[test]
    fn segments_simple_conversation() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Hello"),
            Message::assistant("Hi there", None),
            Message::user("How are you?"),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 3, "Expected 3 separate segments");
        assert!(!segments[0].is_tool_chain);
        assert!(!segments[1].is_tool_chain);
        assert!(!segments[2].is_tool_chain);
    }

    #[test]
    fn segments_tool_call_chain() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Search for something"),
            Message::assistant(
                "Let me search",
                Some(vec![create_tool_call(
                    "call_1",
                    "search",
                    r#"{"q":"test"}"#,
                )]),
            ),
            Message::tool_result("call_1", "Here are the results..."),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 2, "Expected 2 segments (user + tool chain)");
        assert!(!segments[0].is_tool_chain);
        assert!(segments[1].is_tool_chain);
        assert_eq!(segments[1].messages.len(), 2);
    }

    #[test]
    fn segments_multiple_tool_calls() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Do multiple things"),
            Message::assistant(
                "I'll help",
                Some(vec![
                    create_tool_call("call_1", "search", r#"{"q":"a"}"#),
                    create_tool_call("call_2", "read", r#"{"file":"test.txt"}"#),
                ]),
            ),
            Message::tool_result("call_1", "Search results..."),
            Message::tool_result("call_2", "File contents..."),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 2);
        assert!(segments[1].is_tool_chain);
        assert_eq!(segments[1].messages.len(), 3);
        assert_eq!(segments[1].tool_call_ids.len(), 2);
    }

    #[test]
    fn handles_orphan_tool_result() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Hello"),
            Message::tool_result("orphan_call", "Some result"),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 2);
        assert_eq!(segments[1].messages.len(), 1);
        assert!(segments[1].contains_tool_result("orphan_call"));
    }

    #[test]
    fn handles_system_messages_separately() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::system("You are helpful"),
            Message::user("Hello"),
            Message::assistant("Hi", None),
        ];

        let (system, segments) = segmenter.segment_with_system(messages);

        assert_eq!(system.len(), 1);
        assert_eq!(segments.len(), 2);
    }

    #[test]
    fn segment_with_system_collects_interleaved_system_messages() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::system("Rule 1"),
            Message::user("Hello"),
            Message::system("Rule 2"),
            Message::assistant("Hi", None),
        ];

        let (system, segments) = segmenter.segment_with_system(messages);

        assert_eq!(system.len(), 2);
        assert!(system.iter().all(|m| m.role == Role::System));
        assert_eq!(segments.len(), 2);
    }

    #[test]
    fn segments_multiple_interleaved_tool_chains() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("First task"),
            Message::assistant(
                "Doing first",
                Some(vec![create_tool_call("call_1", "search", "{}")]),
            ),
            Message::tool_result("call_1", "Result 1"),
            Message::user("Second task"),
            Message::assistant(
                "Doing second",
                Some(vec![create_tool_call("call_2", "read", "{}")]),
            ),
            Message::tool_result("call_2", "Result 2"),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 4);
        assert!(segments[1].is_tool_chain);
        assert!(segments[3].is_tool_chain);
    }

    #[test]
    fn empty_messages_produces_empty_segments() {
        let segmenter = MessageSegmenter::new();
        let segments = segmenter.segment(vec![]);
        assert!(segments.is_empty());
    }

    #[test]
    fn handles_incomplete_tool_chain_interrupted_by_user() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Search for something"),
            Message::assistant(
                "Let me search",
                Some(vec![create_tool_call("call_1", "search", "{}")]),
            ),
            Message::user("Actually, never mind"),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 3);
        assert!(segments[1].is_tool_chain);
        assert_eq!(segments[1].messages.len(), 1);
        assert_eq!(segments[1].tool_call_ids.len(), 1);
    }

    #[test]
    fn handles_tool_chain_interrupted_by_new_tool_call() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Task 1"),
            Message::assistant(
                "Doing task 1",
                Some(vec![create_tool_call("call_1", "search", "{}")]),
            ),
            Message::assistant(
                "Let me try a different approach",
                Some(vec![create_tool_call("call_2", "read", "{}")]),
            ),
            Message::tool_result("call_2", "Result 2"),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 3);
        assert!(segments[1].is_tool_chain);
        assert_eq!(segments[1].messages.len(), 1);
        assert!(segments[2].is_tool_chain);
        assert_eq!(segments[2].messages.len(), 2);
    }

    #[test]
    fn handles_tool_chain_interrupted_by_assistant_text() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Search for something"),
            Message::assistant(
                "Let me search",
                Some(vec![create_tool_call("call_1", "search", "{}")]),
            ),
            Message::assistant("I changed my mind", None),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 3);
        assert!(segments[1].is_tool_chain);
        assert_eq!(segments[1].messages.len(), 1);
        assert!(!segments[2].is_tool_chain);
    }

    #[test]
    fn handles_tool_result_with_unknown_id_during_open_chain() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Search"),
            Message::assistant(
                "Let me search",
                Some(vec![create_tool_call("call_1", "search", "{}")]),
            ),
            Message::tool_result("call_unknown", "Stray result"),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 3);
        // The interrupted chain keeps only the assistant message.
        assert_eq!(segments[1].messages.len(), 1);
        assert!(segments[1].contains_tool_call("call_1"));
        // The stray result becomes its own segment.
        assert_eq!(segments[2].messages.len(), 1);
        assert!(segments[2].contains_tool_result("call_unknown"));
    }

    #[test]
    fn pending_tool_calls_cleared_after_interruption() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Task 1"),
            Message::assistant(
                "Doing task 1",
                Some(vec![create_tool_call("call_1", "search", "{}")]),
            ),
            Message::user("Task 2"),
            Message::assistant(
                "Doing task 2",
                Some(vec![create_tool_call("call_2", "read", "{}")]),
            ),
            Message::tool_result("call_2", "Result 2"),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 4);
        assert!(segments[1].is_tool_chain);
        assert_eq!(segments[1].tool_call_ids.len(), 1);
        assert!(segments[1].tool_call_ids.contains("call_1"));
        assert!(segments[3].is_tool_chain);
        assert_eq!(segments[3].tool_call_ids.len(), 1);
        assert!(segments[3].tool_call_ids.contains("call_2"));
        assert!(!segments[3].tool_call_ids.contains("call_1"));
    }

    #[test]
    fn reports_missing_tool_results() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Do two things"),
            Message::assistant(
                "On it",
                Some(vec![
                    create_tool_call("call_1", "search", "{}"),
                    create_tool_call("call_2", "read", "{}"),
                ]),
            ),
            Message::tool_result("call_1", "Result 1"),
        ];

        let segments = segmenter.segment(messages);

        assert_eq!(segments.len(), 2);
        assert_eq!(segments[1].messages.len(), 2);
        assert_eq!(segments[1].get_missing_results(), vec!["call_2"]);
    }

    #[test]
    fn complete_chain_has_no_missing_results() {
        let segmenter = MessageSegmenter::new();
        let messages = vec![
            Message::user("Search"),
            Message::assistant(
                "Let me search",
                Some(vec![create_tool_call("call_1", "search", "{}")]),
            ),
            Message::tool_result("call_1", "Result 1"),
        ];

        let segments = segmenter.segment(messages);

        assert!(segments[1].get_missing_results().is_empty());
        assert!(segments[1].contains_tool_call("call_1"));
        assert!(!segments[1].contains_tool_call("call_2"));
        assert!(segments[1].contains_tool_result("call_1"));
    }

    #[test]
    fn merge_combines_messages_and_tool_call_ids() {
        let mut first = MessageSegment::from_message(Message::user("Hello"));
        assert!(!first.is_tool_chain);
        let second = MessageSegment::from_message(Message::assistant(
            "Let me search",
            Some(vec![create_tool_call("call_1", "search", "{}")]),
        ));

        first.merge(second);

        assert_eq!(first.messages.len(), 2);
        assert!(first.is_tool_chain);
        assert!(first.tool_call_ids.contains("call_1"));
        assert_eq!(first.token_estimate, 0);
    }
}