use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock, RwLock};
use serde::{Deserialize, Serialize};
use crate::tool_annotations::ToolKind;
use crate::value::VmValue;
/// Lifecycle status attached to tool-call events in [`AgentEvent`].
///
/// Serialized by serde as a bare snake_case string.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolCallStatus {
    /// Serialized as `"pending"`.
    Pending,
    /// Serialized as `"in_progress"`.
    InProgress,
    /// Serialized as `"completed"`.
    Completed,
    /// Serialized as `"failed"`.
    Failed,
}
/// An event produced during an agent session and delivered to [`AgentEventSink`]s.
///
/// Serialized by serde as an internally tagged object. A `"type"` key holds the
/// snake_case variant name (for example `"turn_start"` or `"tool_call_update"`) next to
/// the variant's own fields. Serde emits those fields in declaration order, so
/// reordering them changes the serialized JSON text.
///
/// Every variant carries the `session_id` it belongs to (see [`AgentEvent::session_id`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum AgentEvent {
    /// A piece of agent message content.
    AgentMessageChunk {
        session_id: String,
        content: String,
    },
    /// A piece of agent "thought" content.
    AgentThoughtChunk {
        session_id: String,
        content: String,
    },
    /// A tool invocation identified by `tool_call_id`, carrying its raw JSON input.
    ToolCall {
        session_id: String,
        tool_call_id: String,
        tool_name: String,
        kind: Option<ToolKind>,
        status: ToolCallStatus,
        raw_input: serde_json::Value,
    },
    /// A status for the tool call identified by `tool_call_id`, optionally with raw JSON
    /// output and/or an error message.
    ///
    /// NOTE(review): presumably refers to an earlier `ToolCall` with the same id — confirm.
    ToolCallUpdate {
        session_id: String,
        tool_call_id: String,
        tool_name: String,
        status: ToolCallStatus,
        raw_output: Option<serde_json::Value>,
        error: Option<String>,
    },
    /// An agent plan, carried as arbitrary JSON.
    Plan {
        session_id: String,
        plan: serde_json::Value,
    },
    /// Marks the start of a turn; `iteration` presumably counts turns — TODO confirm.
    TurnStart {
        session_id: String,
        iteration: usize,
    },
    /// Marks the end of a turn; `turn_info` is arbitrary JSON.
    TurnEnd {
        session_id: String,
        iteration: usize,
        turn_info: serde_json::Value,
    },
    /// Feedback injected into the session; `kind` is a free-form string label.
    FeedbackInjected {
        session_id: String,
        kind: String,
        content: String,
    },
}
impl AgentEvent {
pub fn session_id(&self) -> &str {
match self {
Self::AgentMessageChunk { session_id, .. }
| Self::AgentThoughtChunk { session_id, .. }
| Self::ToolCall { session_id, .. }
| Self::ToolCallUpdate { session_id, .. }
| Self::Plan { session_id, .. }
| Self::TurnStart { session_id, .. }
| Self::TurnEnd { session_id, .. }
| Self::FeedbackInjected { session_id, .. } => session_id,
}
}
}
/// A receiver of [`AgentEvent`]s.
///
/// Sinks are stored as `Arc<dyn AgentEventSink>` inside a process-wide registry, which
/// is why implementors must be both `Sync` and `Send`.
pub trait AgentEventSink: Sync + Send {
    /// Handles one event, received by reference.
    fn handle_event(&self, event: &AgentEvent);
}
/// A sink that forwards each event to every child sink, in the order they were pushed.
///
/// The child list sits behind a [`Mutex`], so children can be added through `&self`.
#[derive(Default)]
pub struct MultiSink {
    sinks: Mutex<Vec<Arc<dyn AgentEventSink>>>,
}

impl MultiSink {
    /// Creates a `MultiSink` with no children.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `sink` after all existing children.
    ///
    /// # Panics
    /// Panics if the child-list mutex is poisoned.
    pub fn push(&self, sink: Arc<dyn AgentEventSink>) {
        let mut children = self.sinks.lock().expect("sink mutex poisoned");
        children.push(sink);
    }

    /// Returns the number of child sinks.
    ///
    /// # Panics
    /// Panics if the child-list mutex is poisoned.
    pub fn len(&self) -> usize {
        let children = self.sinks.lock().expect("sink mutex poisoned");
        children.len()
    }

    /// Returns `true` when no child sinks have been pushed.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl AgentEventSink for MultiSink {
    fn handle_event(&self, event: &AgentEvent) {
        // Snapshot the children and release the lock before dispatching. A child can then
        // push onto this MultiSink without deadlocking.
        let children: Vec<_> = self
            .sinks
            .lock()
            .expect("sink mutex poisoned")
            .iter()
            .cloned()
            .collect();
        children
            .into_iter()
            .for_each(|child| child.handle_event(event));
    }
}
/// Process-wide map from session id to the external sinks registered for that session.
type ExternalSinkRegistry = RwLock<HashMap<String, Vec<Arc<dyn AgentEventSink>>>>;

/// Returns the global [`ExternalSinkRegistry`], creating it empty on first use.
fn external_sinks() -> &'static ExternalSinkRegistry {
    static REGISTRY: OnceLock<ExternalSinkRegistry> = OnceLock::new();
    REGISTRY.get_or_init(ExternalSinkRegistry::default)
}
thread_local! {
    /// Per-thread map from session id to the `VmValue` closures subscribed to it.
    ///
    /// NOTE(review): kept thread-local rather than in the global registry, presumably
    /// because `VmValue` is not `Send`/`Sync` — confirm.
    static CLOSURE_SUBSCRIBERS: RefCell<HashMap<String, Vec<VmValue>>> = RefCell::default();
}
/// Registers `sink` to receive events emitted for `session_id`.
///
/// Sinks for the same session accumulate. [`emit_event`] invokes them in registration
/// order until they are removed by [`clear_session_sinks`] or [`reset_all_sinks`].
///
/// # Panics
/// Panics if the sink registry lock is poisoned.
pub fn register_sink(session_id: impl Into<String>, sink: Arc<dyn AgentEventSink>) {
    let key: String = session_id.into();
    external_sinks()
        .write()
        .expect("sink registry poisoned")
        .entry(key)
        .or_default()
        .push(sink);
}
/// Registers `closure` as a subscriber for `session_id` on the current thread.
///
/// The table is thread-local. The subscriber is visible only to lookups made on the
/// same thread.
pub fn register_closure_subscriber(session_id: impl Into<String>, closure: VmValue) {
    let key: String = session_id.into();
    CLOSURE_SUBSCRIBERS.with(|subscribers| {
        let mut table = subscribers.borrow_mut();
        table.entry(key).or_default().push(closure);
    });
}
/// Returns a copy of this thread's closure subscribers for `session_id`.
///
/// Subscribers are returned in registration order. The result is empty when none are
/// registered.
pub fn closure_subscribers_for(session_id: &str) -> Vec<VmValue> {
    CLOSURE_SUBSCRIBERS.with(|subscribers| match subscribers.borrow().get(session_id) {
        Some(list) => list.clone(),
        None => Vec::new(),
    })
}
/// Removes every external sink and closure subscriber registered for `session_id`.
///
/// External sinks are removed from the process-wide registry. Closure subscribers are
/// removed only from the calling thread's thread-local table.
///
/// The removed entries are dropped only after the registry lock or `RefCell` borrow has
/// been released. A sink or closure whose `Drop` re-enters this module (for example via
/// [`emit_event`] or [`register_sink`]) therefore cannot deadlock on the `RwLock` or
/// panic with a `BorrowMutError`. This matches `emit_event`, which never runs sink code
/// while holding the lock.
///
/// # Panics
/// Panics if the sink registry lock is poisoned.
pub fn clear_session_sinks(session_id: &str) {
    // The write guard is a temporary of this `let` statement. It is released at the `;`,
    // before the removed sinks are dropped below.
    let removed_sinks = external_sinks()
        .write()
        .expect("sink registry poisoned")
        .remove(session_id);
    drop(removed_sinks);

    // The `RefMut` ends with the closure. The removed values are dropped outside the borrow.
    let removed_closures =
        CLOSURE_SUBSCRIBERS.with(|subscribers| subscribers.borrow_mut().remove(session_id));
    drop(removed_closures);
}
/// Removes all external sinks for every session, plus all closure subscribers on the
/// calling thread.
///
/// Closure subscribers live in a thread-local table, so subscribers registered on other
/// threads are not affected.
///
/// As in [`clear_session_sinks`], the removed entries are moved out and dropped only after
/// the registry lock or `RefCell` borrow is released. Re-entrant `Drop` impls therefore
/// cannot deadlock or trigger a `BorrowMutError`.
///
/// # Panics
/// Panics if the sink registry lock is poisoned.
pub fn reset_all_sinks() {
    // `mem::take` swaps in an empty map. The write guard (a statement temporary) is
    // released before the old map, and the sinks it holds, are dropped.
    let removed_sinks =
        std::mem::take(&mut *external_sinks().write().expect("sink registry poisoned"));
    drop(removed_sinks);

    let removed_closures =
        CLOSURE_SUBSCRIBERS.with(|subscribers| std::mem::take(&mut *subscribers.borrow_mut()));
    drop(removed_closures);
}
/// Delivers `event` to every external sink registered for the event's session.
///
/// Sinks are called in registration order. The sink list is snapshotted and the registry
/// lock released before any sink runs, so a sink may register or clear sinks without
/// deadlocking. Closure subscribers are not invoked here.
///
/// # Panics
/// Panics if the sink registry lock is poisoned.
pub fn emit_event(event: &AgentEvent) {
    // The read guard is a temporary of this `let` statement and is released at its end.
    let snapshot: Vec<Arc<dyn AgentEventSink>> = match external_sinks()
        .read()
        .expect("sink registry poisoned")
        .get(event.session_id())
    {
        Some(registered) => registered.clone(),
        None => Vec::new(),
    };
    snapshot
        .into_iter()
        .for_each(|sink| sink.handle_event(event));
}
/// Returns how many external sinks are registered for `session_id` (0 if none).
///
/// # Panics
/// Panics if the sink registry lock is poisoned.
pub fn session_external_sink_count(session_id: &str) -> usize {
    let registry = external_sinks().read().expect("sink registry poisoned");
    registry.get(session_id).map_or(0, Vec::len)
}
/// Returns how many closure subscribers the current thread has for `session_id`
/// (0 if none).
pub fn session_closure_subscriber_count(session_id: &str) -> usize {
    CLOSURE_SUBSCRIBERS.with(|subscribers| {
        let table = subscribers.borrow();
        table.get(session_id).map_or(0, Vec::len)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Sink that counts how many events it has received.
    struct CountingSink(Arc<AtomicUsize>);

    impl AgentEventSink for CountingSink {
        fn handle_event(&self, _event: &AgentEvent) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Sink that appends its label to a shared log, so tests can check dispatch order.
    struct RecordingSink {
        label: &'static str,
        log: Arc<Mutex<Vec<&'static str>>>,
    }

    impl AgentEventSink for RecordingSink {
        fn handle_event(&self, _event: &AgentEvent) {
            self.log.lock().expect("log mutex poisoned").push(self.label);
        }
    }

    #[test]
    fn multi_sink_fans_out_in_order() {
        assert!(MultiSink::default().is_empty());

        let multi = MultiSink::new();
        let log = Arc::new(Mutex::new(Vec::new()));
        for label in ["first", "second", "third"] {
            multi.push(Arc::new(RecordingSink {
                label,
                log: Arc::clone(&log),
            }));
        }
        assert_eq!(multi.len(), 3);
        assert!(!multi.is_empty());

        let event = AgentEvent::TurnStart {
            session_id: "s1".into(),
            iteration: 1,
        };
        multi.handle_event(&event);
        multi.handle_event(&event);

        // Each event must reach every child exactly once, in push order.
        assert_eq!(
            *log.lock().expect("log mutex poisoned"),
            ["first", "second", "third", "first", "second", "third"]
        );
    }

    #[test]
    fn session_scoped_sink_routing() {
        // The external registry is process-global and the test harness runs tests in
        // parallel. Use session ids unique to this test and clean up only those, instead
        // of `reset_all_sinks`, which would wipe other tests' registrations.
        const SESSION_A: &str = "routing-test-session-a";
        const SESSION_B: &str = "routing-test-session-b";
        clear_session_sinks(SESSION_A);
        clear_session_sinks(SESSION_B);

        let a = Arc::new(AtomicUsize::new(0));
        let b = Arc::new(AtomicUsize::new(0));
        register_sink(SESSION_A, Arc::new(CountingSink(a.clone())));
        register_sink(SESSION_B, Arc::new(CountingSink(b.clone())));

        emit_event(&AgentEvent::TurnStart {
            session_id: SESSION_A.into(),
            iteration: 0,
        });
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 0);

        emit_event(&AgentEvent::TurnEnd {
            session_id: SESSION_B.into(),
            iteration: 0,
            turn_info: serde_json::json!({}),
        });
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        // Emitting for a session with no sinks is a no-op.
        emit_event(&AgentEvent::TurnStart {
            session_id: "routing-test-unregistered".into(),
            iteration: 0,
        });
        assert_eq!(a.load(Ordering::SeqCst), 1);
        assert_eq!(b.load(Ordering::SeqCst), 1);

        clear_session_sinks(SESSION_A);
        assert_eq!(session_external_sink_count(SESSION_A), 0);
        assert_eq!(session_external_sink_count(SESSION_B), 1);

        clear_session_sinks(SESSION_B);
        assert_eq!(session_external_sink_count(SESSION_B), 0);
    }

    #[test]
    fn tool_call_status_serde() {
        let cases = [
            (ToolCallStatus::Pending, "\"pending\""),
            (ToolCallStatus::InProgress, "\"in_progress\""),
            (ToolCallStatus::Completed, "\"completed\""),
            (ToolCallStatus::Failed, "\"failed\""),
        ];
        for (status, json) in cases {
            assert_eq!(serde_json::to_string(&status).unwrap(), json);
            assert_eq!(
                serde_json::from_str::<ToolCallStatus>(json).unwrap(),
                status
            );
        }
    }

    #[test]
    fn agent_event_is_internally_tagged() {
        let event = AgentEvent::TurnStart {
            session_id: "s1".into(),
            iteration: 3,
        };
        let value = serde_json::to_value(&event).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"type": "turn_start", "session_id": "s1", "iteration": 3})
        );
        let back: AgentEvent = serde_json::from_value(value).unwrap();
        assert_eq!(back.session_id(), "s1");
    }
}