use std::sync::Arc;
use syncable_ag_ui_core::{
BaseEvent, Event, InterruptInfo, JsonValue, MessageId, Role, RunFinishedEvent,
RunFinishedOutcome, RunId, RunStartedEvent, TextMessageContentEvent, TextMessageEndEvent,
TextMessageStartEvent, ThreadId, ToolCallArgsEvent, ToolCallEndEvent, ToolCallId,
ToolCallStartEvent,
};
use tokio::sync::{RwLock, broadcast};
/// Translates agent activity into AG-UI protocol events and publishes them on a tokio
/// [`broadcast`] channel.
///
/// Every field is a channel sender or an `Arc`, so clones of a bridge share the same run,
/// message and step state and publish to the same subscribers.
#[derive(Clone)]
pub struct EventBridge {
    // --- state handed in by the constructor's caller ---
    /// Conversation thread that every run-level event is tagged with.
    thread_id: Arc<RwLock<ThreadId>>,
    /// Id of the active run, or `None` when no run is in progress.
    run_id: Arc<RwLock<Option<RunId>>>,
    /// Sending half of the channel every event is published on.
    event_tx: broadcast::Sender<Event<JsonValue>>,

    // --- state created and owned by the bridge itself ---
    /// Id of the assistant message currently open for streaming, if any.
    current_message_id: Arc<RwLock<Option<MessageId>>>,
    /// Name of the most recently started, not yet finished step, if any.
    current_step_name: Arc<RwLock<Option<String>>>,
}
impl EventBridge {
pub fn new(
event_tx: broadcast::Sender<Event<JsonValue>>,
thread_id: Arc<RwLock<ThreadId>>,
run_id: Arc<RwLock<Option<RunId>>>,
) -> Self {
Self {
event_tx,
thread_id,
run_id,
current_message_id: Arc::new(RwLock::new(None)),
current_step_name: Arc::new(RwLock::new(None)),
}
}
fn emit(&self, event: Event<JsonValue>) {
let _ = self.event_tx.send(event);
}
pub async fn start_run(&self) {
let thread_id = self.thread_id.read().await.clone();
let run_id = RunId::random();
*self.run_id.write().await = Some(run_id.clone());
self.emit(Event::RunStarted(RunStartedEvent {
base: BaseEvent::with_current_timestamp(),
thread_id,
run_id,
}));
}
pub async fn finish_run(&self) {
let thread_id = self.thread_id.read().await.clone();
let run_id = self.run_id.write().await.take();
let Some(run_id) = run_id else {
return; };
self.emit(Event::RunFinished(RunFinishedEvent {
base: BaseEvent::with_current_timestamp(),
thread_id,
run_id,
outcome: Some(RunFinishedOutcome::Success),
result: None,
interrupt: None,
}));
}
pub async fn finish_run_with_error(&self, message: &str) {
let _run_id = self.run_id.write().await.take();
self.emit(Event::RunError(syncable_ag_ui_core::RunErrorEvent {
base: BaseEvent::with_current_timestamp(),
message: message.to_string(),
code: None,
}));
}
pub async fn interrupt(&self, reason: Option<&str>, payload: Option<serde_json::Value>) {
let thread_id = self.thread_id.read().await.clone();
let run_id = self.run_id.write().await.take();
let Some(run_id) = run_id else {
return; };
let mut info = InterruptInfo::new();
if let Some(r) = reason {
info = info.with_reason(r);
}
if let Some(p) = payload {
info = info.with_payload(p);
}
self.emit(Event::RunFinished(RunFinishedEvent {
base: BaseEvent::with_current_timestamp(),
thread_id,
run_id,
outcome: Some(RunFinishedOutcome::Interrupt),
result: None,
interrupt: Some(info),
}));
}
pub async fn interrupt_with_id(
&self,
id: &str,
reason: Option<&str>,
payload: Option<serde_json::Value>,
) {
let thread_id = self.thread_id.read().await.clone();
let run_id = self.run_id.write().await.take();
let Some(run_id) = run_id else {
return; };
let mut info = InterruptInfo::new().with_id(id);
if let Some(r) = reason {
info = info.with_reason(r);
}
if let Some(p) = payload {
info = info.with_payload(p);
}
self.emit(Event::RunFinished(RunFinishedEvent {
base: BaseEvent::with_current_timestamp(),
thread_id,
run_id,
outcome: Some(RunFinishedOutcome::Interrupt),
result: None,
interrupt: Some(info),
}));
}
pub async fn start_message(&self) -> MessageId {
let message_id = MessageId::random();
*self.current_message_id.write().await = Some(message_id.clone());
self.emit(Event::TextMessageStart(TextMessageStartEvent {
base: BaseEvent::with_current_timestamp(),
message_id: message_id.clone(),
role: Role::Assistant,
}));
message_id
}
pub async fn emit_text_chunk(&self, delta: &str) {
let message_id = self.current_message_id.read().await.clone();
if let Some(message_id) = message_id {
self.emit(Event::TextMessageContent(
TextMessageContentEvent::new_unchecked(message_id, delta),
));
}
}
pub async fn end_message(&self) {
let message_id = self.current_message_id.write().await.take();
if let Some(message_id) = message_id {
self.emit(Event::TextMessageEnd(TextMessageEndEvent {
base: BaseEvent::with_current_timestamp(),
message_id,
}));
}
}
pub async fn emit_message(&self, content: &str) {
let _message_id = self.start_message().await;
self.emit_text_chunk(content).await;
self.end_message().await;
}
pub async fn start_tool_call(&self, name: &str, args: &JsonValue) -> ToolCallId {
let tool_call_id = ToolCallId::random();
let message_id = {
let mut current = self.current_message_id.write().await;
if current.is_none() {
*current = Some(MessageId::random());
}
current.clone().unwrap()
};
self.emit(Event::ToolCallStart(ToolCallStartEvent {
base: BaseEvent::with_current_timestamp(),
tool_call_id: tool_call_id.clone(),
tool_call_name: name.to_string(),
parent_message_id: Some(message_id),
}));
if !args.is_null() {
if let Ok(args_str) = serde_json::to_string(args) {
self.emit(Event::ToolCallArgs(ToolCallArgsEvent {
base: BaseEvent::with_current_timestamp(),
tool_call_id: tool_call_id.clone(),
delta: args_str,
}));
}
}
tool_call_id
}
pub async fn emit_tool_args_chunk(&self, tool_call_id: &ToolCallId, delta: &str) {
self.emit(Event::ToolCallArgs(ToolCallArgsEvent {
base: BaseEvent::with_current_timestamp(),
tool_call_id: tool_call_id.clone(),
delta: delta.to_string(),
}));
}
pub async fn end_tool_call(&self, tool_call_id: &ToolCallId) {
self.emit(Event::ToolCallEnd(ToolCallEndEvent {
base: BaseEvent::with_current_timestamp(),
tool_call_id: tool_call_id.clone(),
}));
}
pub async fn emit_tool_call(&self, name: &str, args: &JsonValue) {
let tool_call_id = self.start_tool_call(name, args).await;
self.end_tool_call(&tool_call_id).await;
}
pub async fn emit_state_snapshot(&self, state: JsonValue) {
self.emit(Event::StateSnapshot(
syncable_ag_ui_core::StateSnapshotEvent {
base: BaseEvent::with_current_timestamp(),
snapshot: state,
},
));
}
pub async fn emit_state_delta(&self, delta: Vec<JsonValue>) {
self.emit(Event::StateDelta(syncable_ag_ui_core::StateDeltaEvent {
base: BaseEvent::with_current_timestamp(),
delta,
}));
}
pub async fn start_thinking(&self, title: Option<&str>) {
self.emit(Event::ThinkingStart(
syncable_ag_ui_core::ThinkingStartEvent {
base: BaseEvent::with_current_timestamp(),
title: title.map(|s| s.to_string()),
},
));
}
pub async fn end_thinking(&self) {
self.emit(Event::ThinkingEnd(syncable_ag_ui_core::ThinkingEndEvent {
base: BaseEvent::with_current_timestamp(),
}));
}
pub async fn start_step(&self, name: &str) {
*self.current_step_name.write().await = Some(name.to_string());
self.emit(Event::StepStarted(syncable_ag_ui_core::StepStartedEvent {
base: BaseEvent::with_current_timestamp(),
step_name: name.to_string(),
}));
}
pub async fn end_step(&self) {
let step_name = self
.current_step_name
.write()
.await
.take()
.unwrap_or_else(|| "unknown".to_string());
self.emit(Event::StepFinished(
syncable_ag_ui_core::StepFinishedEvent {
base: BaseEvent::with_current_timestamp(),
step_name,
},
));
}
pub async fn emit_custom(&self, name: &str, value: JsonValue) {
self.emit(Event::Custom(syncable_ag_ui_core::CustomEvent {
base: BaseEvent::with_current_timestamp(),
name: name.to_string(),
value,
}));
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Subscriber half of the bridge's event channel.
    type Rx = broadcast::Receiver<Event<JsonValue>>;

    /// Builds a bridge and a subscriber that sees every event the bridge emits.
    ///
    /// The bridge has a fresh thread id and no active run.
    fn create_bridge_with_rx() -> (EventBridge, Rx) {
        let (tx, rx) = broadcast::channel(100);
        let bridge = EventBridge::new(
            tx,
            Arc::new(RwLock::new(ThreadId::random())),
            Arc::new(RwLock::new(None)),
        );
        (bridge, rx)
    }

    /// Builds a bridge with no live subscriber; everything it emits is dropped.
    fn create_bridge() -> EventBridge {
        create_bridge_with_rx().0
    }

    /// Pops the next already-sent event without waiting.
    ///
    /// A missing event therefore fails the test instead of hanging it.
    fn next_event(rx: &mut Rx) -> Event<JsonValue> {
        rx.try_recv().expect("Should receive event")
    }

    /// Asserts that nothing further has been emitted.
    fn assert_no_event(rx: &mut Rx) {
        assert!(rx.try_recv().is_err(), "Expected no further events");
    }

    /// Asserts that `event` is a `RunFinished` with an interrupt outcome and details.
    fn assert_interrupted(event: Event<JsonValue>) {
        match event {
            Event::RunFinished(e) => {
                assert!(matches!(e.outcome, Some(RunFinishedOutcome::Interrupt)));
                assert!(e.interrupt.is_some());
            }
            _ => panic!("Expected RunFinished"),
        }
    }

    #[tokio::test]
    async fn test_start_and_finish_run() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.start_run().await;
        assert!(bridge.run_id.read().await.is_some());
        assert!(matches!(next_event(&mut rx), Event::RunStarted(_)));
        bridge.finish_run().await;
        assert!(bridge.run_id.read().await.is_none());
        match next_event(&mut rx) {
            Event::RunFinished(e) => {
                assert!(matches!(e.outcome, Some(RunFinishedOutcome::Success)));
                assert!(e.interrupt.is_none());
            }
            _ => panic!("Expected RunFinished"),
        }
        assert_no_event(&mut rx);
    }

    #[tokio::test]
    async fn test_finish_run_without_run_emits_nothing() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.finish_run().await;
        assert!(bridge.run_id.read().await.is_none());
        assert_no_event(&mut rx);
    }

    #[tokio::test]
    async fn test_message_lifecycle() {
        let bridge = create_bridge();
        let _msg_id = bridge.start_message().await;
        assert!(bridge.current_message_id.read().await.is_some());
        bridge.emit_text_chunk("Hello").await;
        bridge.end_message().await;
        assert!(bridge.current_message_id.read().await.is_none());
    }

    #[tokio::test]
    async fn test_emit_complete_message() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.emit_message("Hello, world!").await;
        assert!(matches!(next_event(&mut rx), Event::TextMessageStart(_)));
        assert!(matches!(next_event(&mut rx), Event::TextMessageContent(_)));
        assert!(matches!(next_event(&mut rx), Event::TextMessageEnd(_)));
        assert_no_event(&mut rx);
        assert!(bridge.current_message_id.read().await.is_none());
    }

    #[tokio::test]
    async fn test_tool_call() {
        let (bridge, mut rx) = create_bridge_with_rx();
        let tool_id = bridge
            .start_tool_call("test", &serde_json::json!({"key": "value"}))
            .await;
        match next_event(&mut rx) {
            Event::ToolCallStart(e) => {
                assert_eq!(e.tool_call_name, "test");
                assert!(e.parent_message_id.is_some());
            }
            _ => panic!("Expected ToolCallStart"),
        }
        match next_event(&mut rx) {
            Event::ToolCallArgs(e) => assert_eq!(e.delta, r#"{"key":"value"}"#),
            _ => panic!("Expected ToolCallArgs"),
        }
        bridge.emit_tool_args_chunk(&tool_id, "more args").await;
        match next_event(&mut rx) {
            Event::ToolCallArgs(e) => assert_eq!(e.delta, "more args"),
            _ => panic!("Expected ToolCallArgs"),
        }
        bridge.end_tool_call(&tool_id).await;
        assert!(matches!(next_event(&mut rx), Event::ToolCallEnd(_)));
        assert_no_event(&mut rx);
    }

    #[tokio::test]
    async fn test_tool_call_with_null_args_skips_args_event() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge
            .emit_tool_call("noop", &serde_json::Value::Null)
            .await;
        assert!(matches!(next_event(&mut rx), Event::ToolCallStart(_)));
        assert!(matches!(next_event(&mut rx), Event::ToolCallEnd(_)));
        assert_no_event(&mut rx);
    }

    #[tokio::test]
    async fn test_interrupt() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.start_run().await;
        assert!(bridge.run_id.read().await.is_some());
        assert!(matches!(next_event(&mut rx), Event::RunStarted(_)));
        bridge.interrupt(Some("file_write"), None).await;
        assert!(bridge.run_id.read().await.is_none());
        assert_interrupted(next_event(&mut rx));
        assert_no_event(&mut rx);
    }

    #[tokio::test]
    async fn test_interrupt_with_payload() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.start_run().await;
        assert!(matches!(next_event(&mut rx), Event::RunStarted(_)));
        bridge
            .interrupt(
                Some("deployment"),
                Some(serde_json::json!({"file": "main.rs", "action": "write"})),
            )
            .await;
        assert!(bridge.run_id.read().await.is_none());
        assert_interrupted(next_event(&mut rx));
    }

    #[tokio::test]
    async fn test_interrupt_with_id() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.start_run().await;
        assert!(matches!(next_event(&mut rx), Event::RunStarted(_)));
        bridge
            .interrupt_with_id("int-123", Some("deployment"), None)
            .await;
        assert!(bridge.run_id.read().await.is_none());
        assert_interrupted(next_event(&mut rx));
    }

    #[tokio::test]
    async fn test_interrupt_without_run() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.interrupt(Some("test"), None).await;
        assert!(bridge.run_id.read().await.is_none());
        assert_no_event(&mut rx);
    }

    #[tokio::test]
    async fn test_events_received_by_subscriber() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.start_run().await;
        assert!(
            matches!(next_event(&mut rx), Event::RunStarted(_)),
            "Expected RunStarted event"
        );
        bridge.emit_message("Hello").await;
        assert!(
            matches!(next_event(&mut rx), Event::TextMessageStart(_)),
            "Expected TextMessageStart"
        );
        assert!(
            matches!(next_event(&mut rx), Event::TextMessageContent(_)),
            "Expected TextMessageContent"
        );
        assert!(
            matches!(next_event(&mut rx), Event::TextMessageEnd(_)),
            "Expected TextMessageEnd"
        );
    }

    #[tokio::test]
    async fn test_step_and_thinking_events() {
        let (bridge, mut rx) = create_bridge_with_rx();
        bridge.start_step("processing").await;
        match next_event(&mut rx) {
            Event::StepStarted(e) => assert_eq!(e.step_name, "processing"),
            _ => panic!("Expected StepStarted"),
        }
        bridge.start_thinking(Some("Analyzing")).await;
        assert!(
            matches!(next_event(&mut rx), Event::ThinkingStart(_)),
            "Expected ThinkingStart"
        );
        bridge.end_thinking().await;
        assert!(
            matches!(next_event(&mut rx), Event::ThinkingEnd(_)),
            "Expected ThinkingEnd"
        );
        bridge.end_step().await;
        match next_event(&mut rx) {
            Event::StepFinished(e) => assert_eq!(e.step_name, "processing"),
            _ => panic!("Expected StepFinished"),
        }
        assert!(bridge.current_step_name.read().await.is_none());
    }

    #[tokio::test]
    async fn test_state_snapshot_event() {
        let (bridge, mut rx) = create_bridge_with_rx();
        let state = serde_json::json!({
            "model": "gpt-4",
            "turn_count": 5
        });
        bridge.emit_state_snapshot(state).await;
        match next_event(&mut rx) {
            Event::StateSnapshot(e) => {
                assert_eq!(e.snapshot["model"], "gpt-4");
                assert_eq!(e.snapshot["turn_count"], 5);
            }
            _ => panic!("Expected StateSnapshot"),
        }
    }
}