use crate::{
handler::{Handler, ToolUseBlockEvent, ToolUseBlockKind},
hook::ToolCall,
};
use std::sync::{Arc, Mutex};
/// Per-block scratch state used while a single tool-use block is streaming.
///
/// `ToolCallCollector` uses this as its handler `Scope`: the id and name are
/// recorded when a block starts, and raw input JSON fragments are accumulated
/// here until the block stops.
#[derive(Default, Debug)]
pub struct CollectorState {
    /// Id of the tool call currently being streamed, if any.
    current_id: Option<String>,
    /// Name of the tool invoked by the current call, if any.
    current_name: Option<String>,
    /// Concatenation of the input JSON fragments received so far.
    input_json_buffer: String,
}
/// Collects completed tool calls from a stream of tool-use block events.
///
/// The calls live behind an `Arc<Mutex<_>>`, so cloning the collector yields
/// another handle to the same shared list. This lets one clone be registered
/// as an event handler while another is kept to read the results.
#[derive(Clone)]
pub struct ToolCallCollector {
    // Shared list of tool calls completed so far, in the order their blocks stopped.
    collected: Arc<Mutex<Vec<ToolCall>>>,
}
impl ToolCallCollector {
    /// Creates a collector backed by a new, empty shared list.
    ///
    /// Clones of the returned collector all observe the same list.
    pub fn new() -> Self {
        let collected = Arc::new(Mutex::new(Vec::new()));
        Self { collected }
    }

    /// Removes and returns every collected tool call, leaving the list empty.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn take_collected(&self) -> Vec<ToolCall> {
        let mut calls = self.collected.lock().unwrap();
        std::mem::take(&mut *calls)
    }

    /// Returns a copy of the collected tool calls without removing them.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn collected(&self) -> Vec<ToolCall> {
        let calls = self.collected.lock().unwrap();
        calls.to_vec()
    }

    /// Reports whether any collected tool call has not yet been taken or cleared.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn has_pending_calls(&self) -> bool {
        let calls = self.collected.lock().unwrap();
        !calls.is_empty()
    }

    /// Discards all collected tool calls.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex has been poisoned.
    pub fn clear(&self) {
        let mut calls = self.collected.lock().unwrap();
        calls.clear();
    }
}
impl Default for ToolCallCollector {
fn default() -> Self {
Self::new()
}
}
impl Handler<ToolUseBlockKind> for ToolCallCollector {
    type Scope = CollectorState;

    /// Accumulates one streamed tool-use block and records it when the block stops.
    ///
    /// - `Start` records the call's id and name and resets the input buffer.
    /// - `InputJsonDelta` appends one fragment of the input JSON to the buffer.
    /// - `Stop` parses the buffered input, pushes a [`ToolCall`] into the shared
    ///   list (only if a preceding `Start` set both id and name), and resets
    ///   the per-block state.
    ///
    /// Input parsing is best-effort. An empty or whitespace-only buffer becomes
    /// an empty JSON object, and input that is not valid JSON becomes
    /// `serde_json::Value::Null`.
    fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
        match event {
            ToolUseBlockEvent::Start(start) => {
                scope.current_id = Some(start.id.clone());
                scope.current_name = Some(start.name.clone());
                scope.input_json_buffer.clear();
            }
            ToolUseBlockEvent::InputJsonDelta(delta) => {
                scope.input_json_buffer.push_str(delta);
            }
            ToolUseBlockEvent::Stop(_stop) => {
                if let (Some(id), Some(name)) = (scope.current_id.take(), scope.current_name.take())
                {
                    let raw_input = scope.input_json_buffer.trim();
                    let input = if raw_input.is_empty() {
                        // A tool invoked with no arguments may finish without any
                        // input fragments. `""` is not valid JSON, so parsing it
                        // would yield `Null`; record an empty argument object instead.
                        serde_json::Value::Object(serde_json::Map::new())
                    } else {
                        serde_json::from_str(raw_input).unwrap_or(serde_json::Value::Null)
                    };
                    let tool_call = ToolCall { id, name, input };
                    self.collected.lock().unwrap().push(tool_call);
                }
                // Reset even when no call was recorded, so stray fragments cannot
                // leak into the next block.
                scope.input_json_buffer.clear();
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timeline::Timeline;
    use crate::timeline::event::Event;

    /// A single tool call whose input JSON arrives split across two deltas
    /// must be reassembled and parsed.
    #[test]
    fn test_collect_single_tool_call() {
        let collector = ToolCallCollector::new();
        let mut timeline = Timeline::new();
        timeline.on_tool_use_block(collector.clone());

        let events = [
            Event::tool_use_start(0, "tool_123", "get_weather"),
            Event::tool_input_delta(0, r#"{"city":"#),
            Event::tool_input_delta(0, r#""Tokyo"}"#),
            Event::tool_use_stop(0),
        ];
        for event in &events {
            timeline.dispatch(event);
        }

        let calls = collector.take_collected();
        assert_eq!(calls.len(), 1);
        let call = &calls[0];
        assert_eq!(call.id, "tool_123");
        assert_eq!(call.name, "get_weather");
        assert_eq!(call.input["city"], "Tokyo");
    }

    /// Two back-to-back tool-use blocks must yield two calls, in stream order.
    #[test]
    fn test_collect_multiple_tool_calls() {
        let collector = ToolCallCollector::new();
        let mut timeline = Timeline::new();
        timeline.on_tool_use_block(collector.clone());

        let events = [
            Event::tool_use_start(0, "call_1", "tool_a"),
            Event::tool_input_delta(0, r#"{"a":1}"#),
            Event::tool_use_stop(0),
            Event::tool_use_start(1, "call_2", "tool_b"),
            Event::tool_input_delta(1, r#"{"b":2}"#),
            Event::tool_use_stop(1),
        ];
        for event in &events {
            timeline.dispatch(event);
        }

        let calls = collector.take_collected();
        let names: Vec<&str> = calls.iter().map(|call| call.name.as_str()).collect();
        assert_eq!(names, ["tool_a", "tool_b"]);
    }
}