llm_worker/timeline/
tool_call_collector.rs1use crate::{
7 handler::{Handler, ToolUseBlockEvent, ToolUseBlockKind},
8 hook::ToolCall,
9};
10use std::sync::{Arc, Mutex};
11
/// Scratch state for the tool-use block currently being streamed.
///
/// Used as the `Scope` of [`ToolCallCollector`]'s handler. A block's `Start`
/// event fills in the id and name and resets the buffer. `Stop` takes the
/// id and name back out and clears the buffer.
#[derive(Default, Debug)]
pub struct CollectorState {
    /// Id carried by the current block's `Start` event (`None` outside a block).
    current_id: Option<String>,
    /// Tool name carried by the current block's `Start` event (`None` outside a block).
    current_name: Option<String>,
    /// The `InputJsonDelta` fragments received so far, concatenated in order.
    input_json_buffer: String,
}
23
24#[derive(Clone)]
29pub struct ToolCallCollector {
30 collected: Arc<Mutex<Vec<ToolCall>>>,
32}
33
34impl ToolCallCollector {
35 pub fn new() -> Self {
37 Self {
38 collected: Arc::new(Mutex::new(Vec::new())),
39 }
40 }
41
42 pub fn take_collected(&self) -> Vec<ToolCall> {
44 let mut guard = self.collected.lock().unwrap();
45 std::mem::take(&mut *guard)
46 }
47
48 pub fn collected(&self) -> Vec<ToolCall> {
50 self.collected.lock().unwrap().clone()
51 }
52
53 pub fn has_pending_calls(&self) -> bool {
55 !self.collected.lock().unwrap().is_empty()
56 }
57
58 pub fn clear(&self) {
60 self.collected.lock().unwrap().clear();
61 }
62}
63
64impl Default for ToolCallCollector {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl Handler<ToolUseBlockKind> for ToolCallCollector {
71 type Scope = CollectorState;
72
73 fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
74 match event {
75 ToolUseBlockEvent::Start(start) => {
76 scope.current_id = Some(start.id.clone());
77 scope.current_name = Some(start.name.clone());
78 scope.input_json_buffer.clear();
79 }
80 ToolUseBlockEvent::InputJsonDelta(delta) => {
81 scope.input_json_buffer.push_str(delta);
82 }
83 ToolUseBlockEvent::Stop(_stop) => {
84 if let (Some(id), Some(name)) = (scope.current_id.take(), scope.current_name.take())
86 {
87 let input = serde_json::from_str(&scope.input_json_buffer)
88 .unwrap_or(serde_json::Value::Null);
89
90 let tool_call = ToolCall { id, name, input };
91
92 self.collected.lock().unwrap().push(tool_call);
93 }
94 scope.input_json_buffer.clear();
95 }
96 }
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timeline::Timeline;
    use crate::timeline::event::Event;

    #[test]
    fn test_collect_single_tool_call() {
        let collector = ToolCallCollector::new();
        let mut timeline = Timeline::new();
        timeline.on_tool_use_block(collector.clone());

        timeline.dispatch(&Event::tool_use_start(0, "tool_123", "get_weather"));
        timeline.dispatch(&Event::tool_input_delta(0, r#"{"city":"#));
        timeline.dispatch(&Event::tool_input_delta(0, r#""Tokyo"}"#));
        timeline.dispatch(&Event::tool_use_stop(0));

        let calls = collector.take_collected();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "tool_123");
        assert_eq!(calls[0].name, "get_weather");
        assert_eq!(calls[0].input["city"], "Tokyo");
    }

    #[test]
    fn test_collect_multiple_tool_calls() {
        let collector = ToolCallCollector::new();
        let mut timeline = Timeline::new();
        timeline.on_tool_use_block(collector.clone());

        timeline.dispatch(&Event::tool_use_start(0, "call_1", "tool_a"));
        timeline.dispatch(&Event::tool_input_delta(0, r#"{"a":1}"#));
        timeline.dispatch(&Event::tool_use_stop(0));

        timeline.dispatch(&Event::tool_use_start(1, "call_2", "tool_b"));
        timeline.dispatch(&Event::tool_input_delta(1, r#"{"b":2}"#));
        timeline.dispatch(&Event::tool_use_stop(1));

        let calls = collector.take_collected();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "tool_a");
        assert_eq!(calls[1].name, "tool_b");
    }

    /// Malformed input must not drop the call; its input is recorded as `null`.
    #[test]
    fn test_invalid_json_input_is_recorded_as_null() {
        let collector = ToolCallCollector::new();
        let mut timeline = Timeline::new();
        timeline.on_tool_use_block(collector.clone());

        timeline.dispatch(&Event::tool_use_start(0, "call_bad", "tool_x"));
        timeline.dispatch(&Event::tool_input_delta(0, r#"{"unterminated":"#));
        timeline.dispatch(&Event::tool_use_stop(0));

        let calls = collector.take_collected();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "call_bad");
        assert!(calls[0].input.is_null());
    }

    /// `collected` returns a copy without draining; `take_collected` drains.
    #[test]
    fn test_collected_copies_and_take_collected_drains() {
        let collector = ToolCallCollector::new();
        let mut timeline = Timeline::new();
        timeline.on_tool_use_block(collector.clone());
        assert!(!collector.has_pending_calls());

        timeline.dispatch(&Event::tool_use_start(0, "call_1", "tool_a"));
        timeline.dispatch(&Event::tool_input_delta(0, r#"{"a":1}"#));
        timeline.dispatch(&Event::tool_use_stop(0));

        assert!(collector.has_pending_calls());
        assert_eq!(collector.collected().len(), 1);
        assert!(collector.has_pending_calls());

        assert_eq!(collector.take_collected().len(), 1);
        assert!(!collector.has_pending_calls());
        assert!(collector.take_collected().is_empty());
    }

    /// Clearing through any clone empties the storage shared by all clones.
    #[test]
    fn test_clear_discards_calls_for_all_clones() {
        let collector = ToolCallCollector::new();
        let mut timeline = Timeline::new();
        timeline.on_tool_use_block(collector.clone());

        timeline.dispatch(&Event::tool_use_start(0, "call_1", "tool_a"));
        timeline.dispatch(&Event::tool_input_delta(0, r#"{"a":1}"#));
        timeline.dispatch(&Event::tool_use_stop(0));
        assert!(collector.has_pending_calls());

        let other = collector.clone();
        other.clear();
        assert!(!collector.has_pending_calls());
        assert!(collector.collected().is_empty());
    }
}