llm_worker/timeline/
tool_call_collector.rs

1//! ToolCallCollector - ツール呼び出し収集用ハンドラ
2//!
3//! TimelineのToolUseBlockHandler として登録され、
4//! ストリーム中のToolUseブロックを収集する。
5
6use crate::{
7    handler::{Handler, ToolUseBlockEvent, ToolUseBlockKind},
8    hook::ToolCall,
9};
10use std::sync::{Arc, Mutex};
11
12/// ToolUseブロックから収集したツール呼び出し情報を保持
13///
14/// ToolCallCollectorのHandler実装で使用するスコープ型
15#[derive(Debug, Default)]
16pub struct CollectorState {
17    /// 現在のツール呼び出し情報 (ブロック進行中)
18    current_id: Option<String>,
19    current_name: Option<String>,
20    /// 蓄積中のJSON入力
21    input_json_buffer: String,
22}
23
24/// ToolCallCollector - ToolUseブロックハンドラ
25///
26/// Timelineに登録してToolUseブロックイベントを受信し、
27/// 完了したToolCallを収集する。
28#[derive(Clone)]
29pub struct ToolCallCollector {
30    /// 収集されたToolCall
31    collected: Arc<Mutex<Vec<ToolCall>>>,
32}
33
34impl ToolCallCollector {
35    /// 新しいToolCallCollectorを作成
36    pub fn new() -> Self {
37        Self {
38            collected: Arc::new(Mutex::new(Vec::new())),
39        }
40    }
41
42    /// 収集されたToolCallを取得してクリア
43    pub fn take_collected(&self) -> Vec<ToolCall> {
44        let mut guard = self.collected.lock().unwrap();
45        std::mem::take(&mut *guard)
46    }
47
48    /// 収集されたToolCallの参照を取得
49    pub fn collected(&self) -> Vec<ToolCall> {
50        self.collected.lock().unwrap().clone()
51    }
52
53    /// 収集されたToolCallがあるかどうか
54    pub fn has_pending_calls(&self) -> bool {
55        !self.collected.lock().unwrap().is_empty()
56    }
57
58    /// 収集をクリア
59    pub fn clear(&self) {
60        self.collected.lock().unwrap().clear();
61    }
62}
63
64impl Default for ToolCallCollector {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl Handler<ToolUseBlockKind> for ToolCallCollector {
71    type Scope = CollectorState;
72
73    fn on_event(&mut self, scope: &mut Self::Scope, event: &ToolUseBlockEvent) {
74        match event {
75            ToolUseBlockEvent::Start(start) => {
76                scope.current_id = Some(start.id.clone());
77                scope.current_name = Some(start.name.clone());
78                scope.input_json_buffer.clear();
79            }
80            ToolUseBlockEvent::InputJsonDelta(delta) => {
81                scope.input_json_buffer.push_str(delta);
82            }
83            ToolUseBlockEvent::Stop(_stop) => {
84                // ブロック完了時にToolCallを確定
85                if let (Some(id), Some(name)) = (scope.current_id.take(), scope.current_name.take())
86                {
87                    let input = serde_json::from_str(&scope.input_json_buffer)
88                        .unwrap_or(serde_json::Value::Null);
89
90                    let tool_call = ToolCall { id, name, input };
91
92                    self.collected.lock().unwrap().push(tool_call);
93                }
94                scope.input_json_buffer.clear();
95            }
96        }
97    }
98}
99
100#[cfg(test)]
101mod tests {
102    use super::*;
103    use crate::timeline::Timeline;
104    use crate::timeline::event::Event;
105
106    #[test]
107    fn test_collect_single_tool_call() {
108        let collector = ToolCallCollector::new();
109        let mut timeline = Timeline::new();
110        timeline.on_tool_use_block(collector.clone());
111
112        // ToolUseブロックのイベントシーケンスをディスパッチ
113        timeline.dispatch(&Event::tool_use_start(0, "tool_123", "get_weather"));
114        timeline.dispatch(&Event::tool_input_delta(0, r#"{"city":"#));
115        timeline.dispatch(&Event::tool_input_delta(0, r#""Tokyo"}"#));
116        timeline.dispatch(&Event::tool_use_stop(0));
117
118        // 収集されたToolCallを確認
119        let calls = collector.take_collected();
120        assert_eq!(calls.len(), 1);
121        assert_eq!(calls[0].id, "tool_123");
122        assert_eq!(calls[0].name, "get_weather");
123        assert_eq!(calls[0].input["city"], "Tokyo");
124    }
125
126    #[test]
127    fn test_collect_multiple_tool_calls() {
128        let collector = ToolCallCollector::new();
129        let mut timeline = Timeline::new();
130        timeline.on_tool_use_block(collector.clone());
131
132        // 1つ目のToolCall
133        timeline.dispatch(&Event::tool_use_start(0, "call_1", "tool_a"));
134        timeline.dispatch(&Event::tool_input_delta(0, r#"{"a":1}"#));
135        timeline.dispatch(&Event::tool_use_stop(0));
136
137        // 2つ目のToolCall
138        timeline.dispatch(&Event::tool_use_start(1, "call_2", "tool_b"));
139        timeline.dispatch(&Event::tool_input_delta(1, r#"{"b":2}"#));
140        timeline.dispatch(&Event::tool_use_stop(1));
141
142        let calls = collector.take_collected();
143        assert_eq!(calls.len(), 2);
144        assert_eq!(calls[0].name, "tool_a");
145        assert_eq!(calls[1].name, "tool_b");
146    }
147}