use crate::types::{Artifact, Task, TaskStatus};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
/// Capacity handed to `broadcast::channel` whenever a per-task channel is created.
///
/// Per tokio's broadcast documentation this bounds how many unread events are
/// retained; a receiver that falls further behind observes `RecvError::Lagged`.
const CHANNEL_CAPACITY: usize = 64;
/// An event carried on a task's stream.
///
/// The enum is `#[serde(untagged)]`: a value serializes as the bare inner
/// payload with no wrapper, and deserialization tries the variants in the
/// order they are declared here, keeping the first one that parses.
///
/// NOTE(review): both update structs default their `kind` field when it is
/// absent, so `kind` alone does not select a variant on input; selection
/// relies on each payload's required fields. Confirm that `Task` (declared in
/// `crate::types`, not visible here) cannot successfully parse an update
/// payload, since it is tried first.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum StreamEvent {
/// A complete `Task` value.
Task(Task),
/// A change of task status; the only variant that can be final.
StatusUpdate(TaskStatusUpdateEvent),
/// A new or extended artifact for the task.
ArtifactUpdate(TaskArtifactUpdateEvent),
}
impl StreamEvent {
pub fn is_final(&self) -> bool {
match self {
StreamEvent::StatusUpdate(s) => s.final_event,
StreamEvent::Task(_) | StreamEvent::ArtifactUpdate(_) => false,
}
}
}
/// Payload announcing a task's current status.
///
/// Field names are converted to camelCase on the wire (`taskId`, `contextId`),
/// except `final_event`, which is explicitly renamed to `final`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskStatusUpdateEvent {
/// Identifier of the task this update belongs to.
pub task_id: String,
/// Identifier of the context the task belongs to.
pub context_id: String,
/// Event discriminator. Falls back to `"status-update"` when missing from
/// the input; the value is otherwise stored and emitted as given, with no
/// validation.
#[serde(default = "status_update_kind")]
pub kind: String,
/// The status being reported.
pub status: TaskStatus,
/// Whether this is the last event of the stream. Serialized as `final`
/// and has no serde default, so it must be present in the input.
#[serde(rename = "final")]
pub final_event: bool,
/// Free-form extra data. Defaults to empty on input and is left out of the
/// output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub metadata: HashMap<String, Value>,
}
/// Serde default for `TaskStatusUpdateEvent::kind`.
fn status_update_kind() -> String {
    String::from("status-update")
}
/// Payload delivering an artifact, or a chunk of one, for a task.
///
/// Field names are converted to camelCase on the wire (`taskId`, `contextId`,
/// `lastChunk`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TaskArtifactUpdateEvent {
/// Identifier of the task this update belongs to.
pub task_id: String,
/// Identifier of the context the task belongs to.
pub context_id: String,
/// Event discriminator. Falls back to `"artifact-update"` when missing from
/// the input; the value is otherwise stored and emitted as given, with no
/// validation.
#[serde(default = "artifact_update_kind")]
pub kind: String,
/// The artifact content carried by this event.
pub artifact: Artifact,
/// Defaults to `false` when absent from the input.
/// NOTE(review): presumably marks the content as a continuation of an
/// artifact sent earlier; confirm against the protocol definition.
#[serde(default)]
pub append: bool,
/// Defaults to `false` when absent from the input.
/// NOTE(review): presumably marks the last chunk of a chunked artifact;
/// confirm against the protocol definition.
#[serde(default)]
pub last_chunk: bool,
/// Free-form extra data. Defaults to empty on input and is left out of the
/// output when empty.
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
pub metadata: HashMap<String, Value>,
}
/// Serde default for `TaskArtifactUpdateEvent::kind`.
fn artifact_update_kind() -> String {
    String::from("artifact-update")
}
/// Registry of broadcast channels, one per task id.
///
/// The map lives behind an `Arc`, so a clone of the bus is another handle to
/// the same set of channels. The `Default` value is a bus with no channels.
#[derive(Clone, Default)]
pub struct EventBus {
// Task id -> sending half of that task's broadcast channel.
inner: Arc<RwLock<HashMap<String, broadcast::Sender<StreamEvent>>>>,
}
impl EventBus {
    /// Creates a bus with no channels.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a receiver for events published under `task_id`.
    ///
    /// The task's channel is created on first subscription with
    /// `CHANNEL_CAPACITY` slots. As with any tokio broadcast receiver, only
    /// events sent after this call are observed.
    pub async fn subscribe(&self, task_id: &str) -> broadcast::Receiver<StreamEvent> {
        let mut guard = self.inner.write().await;
        guard
            .entry(task_id.to_string())
            .or_insert_with(|| broadcast::channel(CHANNEL_CAPACITY).0)
            .subscribe()
    }

    /// Delivers `event` to the current subscribers of `task_id`.
    ///
    /// If nobody has subscribed to the task, the event is dropped and no
    /// channel is created: a channel made here could have no receiver yet, so
    /// the event would be undeliverable and the map entry would only
    /// accumulate.
    ///
    /// The send happens while the write lock is held. `broadcast::Sender::send`
    /// is synchronous, so no `.await` occurs under the lock, and publishers are
    /// serialized: once a final event has been sent and its channel removed, no
    /// other publisher can still deliver on that channel.
    ///
    /// The channel is removed after a final event (receivers then see
    /// `RecvError::Closed` once they have drained it) and also when the send
    /// finds no active receivers, since such a channel can never deliver
    /// anything.
    pub async fn publish(&self, task_id: &str, event: StreamEvent) {
        let is_final = event.is_final();
        let mut guard = self.inner.write().await;
        let delivered = match guard.get(task_id) {
            // `send` fails only when there are no active receivers.
            Some(sender) => sender.send(event).is_ok(),
            None => return,
        };
        if is_final || !delivered {
            guard.remove(task_id);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bridge::task_with_status;
    use crate::types::TaskState;
    use chrono::Utc;
    use tokio::time::{timeout, Duration};

    /// Upper bound on every `recv` so a regression fails fast instead of
    /// hanging the test run.
    const RECV_TIMEOUT: Duration = Duration::from_secs(1);

    /// An event published after subscribing reaches the subscriber.
    #[tokio::test]
    async fn subscriber_receives_published_events() {
        let bus = EventBus::new();
        let mut rx = bus.subscribe("t-1").await;
        let task = task_with_status(
            "t-1".into(),
            "ctx".into(),
            TaskState::Submitted,
            vec![],
            vec![],
        );
        bus.publish("t-1", StreamEvent::Task(task)).await;
        let event = timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("recv timeout")
            .expect("recv ok");
        assert!(matches!(event, StreamEvent::Task(_)));
    }

    /// A final status update is delivered, after which the receiver sees the
    /// channel as closed.
    #[tokio::test]
    async fn final_event_closes_channel() {
        let bus = EventBus::new();
        let mut rx = bus.subscribe("t-2").await;
        bus.publish(
            "t-2",
            StreamEvent::StatusUpdate(TaskStatusUpdateEvent {
                task_id: "t-2".into(),
                context_id: "ctx".into(),
                kind: status_update_kind(),
                status: TaskStatus {
                    state: TaskState::Completed,
                    message: None,
                    timestamp: Utc::now(),
                },
                final_event: true,
                metadata: HashMap::new(),
            }),
        )
        .await;
        // Both receives are bounded: if the channel were left open, the second
        // `recv` would otherwise wait forever.
        let _ = timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("recv timeout")
            .expect("event");
        let res = timeout(RECV_TIMEOUT, rx.recv())
            .await
            .expect("recv timeout");
        assert!(matches!(res, Err(broadcast::error::RecvError::Closed)));
    }

    /// A status update serializes flat (no enum wrapper) with camelCase keys
    /// and the `final` rename. Only the serialization direction is checked.
    // Nothing here awaits, so a plain synchronous test is sufficient.
    #[test]
    fn stream_event_round_trip_serializes() {
        let event = StreamEvent::StatusUpdate(TaskStatusUpdateEvent {
            task_id: "t".into(),
            context_id: "c".into(),
            kind: "status-update".into(),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: Utc::now(),
            },
            final_event: false,
            metadata: HashMap::new(),
        });
        let v = serde_json::to_value(&event).unwrap();
        assert_eq!(v["kind"], "status-update");
        assert_eq!(v["taskId"], "t");
        assert_eq!(v["final"], false);
    }
}