1use std::cell::RefCell;
21use std::collections::HashMap;
22use std::sync::{Arc, Mutex, OnceLock, RwLock};
23
24use serde::{Deserialize, Serialize};
25
26use crate::tool_annotations::ToolKind;
27use crate::value::VmValue;
28
29#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum ToolCallStatus {
33 Pending,
35 InProgress,
37 Completed,
39 Failed,
41}
42
43#[derive(Clone, Debug, Serialize, Deserialize)]
46#[serde(tag = "type", rename_all = "snake_case")]
47pub enum AgentEvent {
48 AgentMessageChunk {
49 session_id: String,
50 content: String,
51 },
52 AgentThoughtChunk {
53 session_id: String,
54 content: String,
55 },
56 ToolCall {
57 session_id: String,
58 tool_call_id: String,
59 tool_name: String,
60 kind: Option<ToolKind>,
61 status: ToolCallStatus,
62 raw_input: serde_json::Value,
63 },
64 ToolCallUpdate {
65 session_id: String,
66 tool_call_id: String,
67 tool_name: String,
68 status: ToolCallStatus,
69 raw_output: Option<serde_json::Value>,
70 error: Option<String>,
71 },
72 Plan {
73 session_id: String,
74 plan: serde_json::Value,
75 },
76 TurnStart {
78 session_id: String,
79 iteration: usize,
80 },
81 TurnEnd {
82 session_id: String,
83 iteration: usize,
84 turn_info: serde_json::Value,
85 },
86 FeedbackInjected {
87 session_id: String,
88 kind: String,
89 content: String,
90 },
91}
92
93impl AgentEvent {
94 pub fn session_id(&self) -> &str {
95 match self {
96 Self::AgentMessageChunk { session_id, .. }
97 | Self::AgentThoughtChunk { session_id, .. }
98 | Self::ToolCall { session_id, .. }
99 | Self::ToolCallUpdate { session_id, .. }
100 | Self::Plan { session_id, .. }
101 | Self::TurnStart { session_id, .. }
102 | Self::TurnEnd { session_id, .. }
103 | Self::FeedbackInjected { session_id, .. } => session_id,
104 }
105 }
106}
107
108pub trait AgentEventSink: Send + Sync {
111 fn handle_event(&self, event: &AgentEvent);
112}
113
114pub struct MultiSink {
116 sinks: Mutex<Vec<Arc<dyn AgentEventSink>>>,
117}
118
119impl MultiSink {
120 pub fn new() -> Self {
121 Self {
122 sinks: Mutex::new(Vec::new()),
123 }
124 }
125 pub fn push(&self, sink: Arc<dyn AgentEventSink>) {
126 self.sinks.lock().expect("sink mutex poisoned").push(sink);
127 }
128 pub fn len(&self) -> usize {
129 self.sinks.lock().expect("sink mutex poisoned").len()
130 }
131 pub fn is_empty(&self) -> bool {
132 self.len() == 0
133 }
134}
135
136impl Default for MultiSink {
137 fn default() -> Self {
138 Self::new()
139 }
140}
141
142impl AgentEventSink for MultiSink {
143 fn handle_event(&self, event: &AgentEvent) {
144 let sinks = self.sinks.lock().expect("sink mutex poisoned").clone();
145 for sink in sinks {
146 sink.handle_event(event);
147 }
148 }
149}
150
151type ExternalSinkRegistry = RwLock<HashMap<String, Vec<Arc<dyn AgentEventSink>>>>;
154
155fn external_sinks() -> &'static ExternalSinkRegistry {
156 static REGISTRY: OnceLock<ExternalSinkRegistry> = OnceLock::new();
157 REGISTRY.get_or_init(|| RwLock::new(HashMap::new()))
158}
159
160thread_local! {
166 static CLOSURE_SUBSCRIBERS: RefCell<HashMap<String, Vec<VmValue>>> =
167 RefCell::new(HashMap::new());
168}
169
170pub fn register_sink(session_id: impl Into<String>, sink: Arc<dyn AgentEventSink>) {
172 let session_id = session_id.into();
173 let mut reg = external_sinks().write().expect("sink registry poisoned");
174 reg.entry(session_id).or_default().push(sink);
175}
176
177pub fn register_closure_subscriber(session_id: impl Into<String>, closure: VmValue) {
178 let session_id = session_id.into();
179 CLOSURE_SUBSCRIBERS.with(|reg| {
180 reg.borrow_mut()
181 .entry(session_id)
182 .or_default()
183 .push(closure);
184 });
185}
186
187pub fn closure_subscribers_for(session_id: &str) -> Vec<VmValue> {
188 CLOSURE_SUBSCRIBERS.with(|reg| reg.borrow().get(session_id).cloned().unwrap_or_default())
189}
190
191pub fn clear_session_sinks(session_id: &str) {
192 external_sinks()
193 .write()
194 .expect("sink registry poisoned")
195 .remove(session_id);
196 CLOSURE_SUBSCRIBERS.with(|reg| {
197 reg.borrow_mut().remove(session_id);
198 });
199}
200
201pub fn reset_all_sinks() {
202 external_sinks()
203 .write()
204 .expect("sink registry poisoned")
205 .clear();
206 CLOSURE_SUBSCRIBERS.with(|reg| {
207 reg.borrow_mut().clear();
208 });
209}
210
211pub fn emit_event(event: &AgentEvent) {
215 let sinks: Vec<Arc<dyn AgentEventSink>> = {
216 let reg = external_sinks().read().expect("sink registry poisoned");
217 reg.get(event.session_id()).cloned().unwrap_or_default()
218 };
219 for sink in sinks {
220 sink.handle_event(event);
221 }
222}
223
224pub fn session_external_sink_count(session_id: &str) -> usize {
225 external_sinks()
226 .read()
227 .expect("sink registry poisoned")
228 .get(session_id)
229 .map(|v| v.len())
230 .unwrap_or(0)
231}
232
233pub fn session_closure_subscriber_count(session_id: &str) -> usize {
234 CLOSURE_SUBSCRIBERS.with(|reg| reg.borrow().get(session_id).map(|v| v.len()).unwrap_or(0))
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Sink that counts how many events it has received.
    struct CountingSink(Arc<AtomicUsize>);
    impl AgentEventSink for CountingSink {
        fn handle_event(&self, _event: &AgentEvent) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Sink that appends its label to a shared log on every event, so a test
    /// can observe the order in which sinks were invoked.
    struct LabelSink {
        label: &'static str,
        log: Arc<Mutex<Vec<&'static str>>>,
    }
    impl AgentEventSink for LabelSink {
        fn handle_event(&self, _event: &AgentEvent) {
            self.log
                .lock()
                .expect("log mutex poisoned")
                .push(self.label);
        }
    }

    #[test]
    fn multi_sink_fans_out_in_order() {
        let multi = MultiSink::new();
        assert!(multi.is_empty());

        let a = Arc::new(AtomicUsize::new(0));
        let b = Arc::new(AtomicUsize::new(0));
        let log: Arc<Mutex<Vec<&'static str>>> = Arc::new(Mutex::new(Vec::new()));
        multi.push(Arc::new(CountingSink(a.clone())));
        multi.push(Arc::new(LabelSink {
            label: "first",
            log: log.clone(),
        }));
        multi.push(Arc::new(CountingSink(b.clone())));
        multi.push(Arc::new(LabelSink {
            label: "second",
            log: log.clone(),
        }));
        assert_eq!(multi.len(), 4);

        let event = AgentEvent::TurnStart {
            session_id: "s1".into(),
            iteration: 1,
        };
        multi.handle_event(&event);

        // Every sink saw the event exactly once...
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);
        // ...and the labelled sinks ran in push order.
        let seen = log.lock().expect("log mutex poisoned").clone();
        assert_eq!(seen, vec!["first", "second"]);
    }

    #[test]
    fn session_scoped_sink_routing() {
        // The external registry is process-global and tests in one binary run
        // in parallel, so use ids unique to this test and clean up only those
        // sessions rather than wiping the whole registry.
        const SESSION_A: &str = "agent-events-routing-test-a";
        const SESSION_B: &str = "agent-events-routing-test-b";
        clear_session_sinks(SESSION_A);
        clear_session_sinks(SESSION_B);

        let a = Arc::new(AtomicUsize::new(0));
        let b = Arc::new(AtomicUsize::new(0));
        register_sink(SESSION_A, Arc::new(CountingSink(a.clone())));
        register_sink(SESSION_B, Arc::new(CountingSink(b.clone())));

        emit_event(&AgentEvent::TurnStart {
            session_id: SESSION_A.into(),
            iteration: 0,
        });
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 0);

        emit_event(&AgentEvent::TurnEnd {
            session_id: SESSION_B.into(),
            iteration: 0,
            turn_info: serde_json::json!({}),
        });
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        clear_session_sinks(SESSION_A);
        assert_eq!(session_external_sink_count(SESSION_A), 0);
        assert_eq!(session_external_sink_count(SESSION_B), 1);

        clear_session_sinks(SESSION_B);
        assert_eq!(session_external_sink_count(SESSION_B), 0);
    }

    #[test]
    fn tool_call_status_serde() {
        let cases = [
            (ToolCallStatus::Pending, "\"pending\""),
            (ToolCallStatus::InProgress, "\"in_progress\""),
            (ToolCallStatus::Completed, "\"completed\""),
            (ToolCallStatus::Failed, "\"failed\""),
        ];
        for (status, wire) in cases {
            assert_eq!(serde_json::to_string(&status).unwrap(), wire);
            // The snake_case form must also parse back to the same variant.
            assert_eq!(
                serde_json::from_str::<ToolCallStatus>(wire).unwrap(),
                status
            );
        }
    }
}