use std::convert::Infallible;
use std::sync::Arc;
use std::time::Duration;
use axum::response::sse::{Event, KeepAlive, Sse};
use futures_util::stream::{self, Stream, StreamExt as _};
use serde::{Deserialize, Serialize};
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use crate::threads::types::Run;
/// Number of [`RunEvent`]s the run event broadcast channel buffers.
///
/// Passed to `broadcast::channel` by [`new_run_event_channel`]; a receiver that falls more
/// than this many events behind misses the oldest ones.
pub const RUN_EVENT_BROADCAST_CAPACITY: usize = 256;
/// An event about a single run, published on the run event broadcast channel
/// (see `RunEventSender`).
///
/// Serialized adjacently tagged: the variant name in snake_case goes under `event_type`
/// and the payload under `data`, e.g.
/// `{"event_type": "message_delta", "data": {"run_id": "...", "content": "..."}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "event_type", content = "data", rename_all = "snake_case")]
pub enum RunEvent {
/// Run snapshot, streamed to clients as `thread.run.created`.
Created(Run),
/// Run snapshot, streamed to clients as `thread.run.in_progress`.
InProgress(Run),
/// Message text for a run, streamed to clients as `thread.message.delta`.
MessageDelta {
/// Id of the run this text belongs to; `RunEvent::run_id` returns it for routing.
run_id: String,
/// The message text. NOTE(review): presumably only the newly produced fragment
/// (a delta), not the full message so far — confirm with the producer.
content: String,
},
/// Terminal event: run snapshot, streamed to clients as `thread.run.completed`.
Completed(Run),
/// Terminal event: run snapshot, streamed to clients as `thread.run.failed`.
Failed(Run),
}
impl RunEvent {
pub fn sse_event_name(&self) -> &'static str {
match self {
RunEvent::Created(_) => "thread.run.created",
RunEvent::InProgress(_) => "thread.run.in_progress",
RunEvent::MessageDelta { .. } => "thread.message.delta",
RunEvent::Completed(_) => "thread.run.completed",
RunEvent::Failed(_) => "thread.run.failed",
}
}
pub fn is_terminal(&self) -> bool {
matches!(self, RunEvent::Completed(_) | RunEvent::Failed(_))
}
pub fn run_id(&self) -> &str {
match self {
RunEvent::Created(r)
| RunEvent::InProgress(r)
| RunEvent::Completed(r)
| RunEvent::Failed(r) => &r.id,
RunEvent::MessageDelta { run_id, .. } => run_id,
}
}
}
/// Shared handle used to publish [`RunEvent`]s to every subscribed receiver.
pub type RunEventSender = Arc<broadcast::Sender<RunEvent>>;

/// Creates the run event broadcast channel with [`RUN_EVENT_BROADCAST_CAPACITY`] slots.
///
/// Returns the sender wrapped in an [`Arc`] together with the channel's first receiver.
/// Further receivers are created with `Sender::subscribe`.
pub fn new_run_event_channel() -> (RunEventSender, broadcast::Receiver<RunEvent>) {
    let (sender, receiver) = broadcast::channel::<RunEvent>(RUN_EVENT_BROADCAST_CAPACITY);
    (RunEventSender::new(sender), receiver)
}
/// Builds an SSE response that streams the events of a single run.
///
/// The broadcast receiver is subscribed when this function is called, so only events sent
/// on `event_tx` after the call are observed. Events whose [`RunEvent::run_id`] differs
/// from `run_id` are skipped, not forwarded. Each matching event is sent with the SSE
/// `event:` name from [`RunEvent::sse_event_name`] and its JSON encoding as `data:`.
///
/// The stream ends after the run's terminal event ([`RunEvent::is_terminal`]) or when the
/// channel closes (every sender dropped). In both cases a final `done` event carrying
/// `[DONE]` is sent. While idle, a `keep-alive` SSE comment is sent every 15 seconds.
pub fn build_run_sse_stream(
    event_tx: &RunEventSender,
    run_id: String,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    let rx = event_tx.subscribe();

    // Forward this run's events until its terminal event or until the channel closes.
    // `run_id` lives in the unfold state, so it is not re-cloned on every poll.
    let run_events = stream::unfold(
        (BroadcastStream::new(rx), run_id, false),
        |(mut events, run_id, finished)| async move {
            if finished {
                return None;
            }
            loop {
                let event = match events.next().await {
                    // Every sender was dropped: nothing more can arrive.
                    None => return None,
                    // The receiver fell behind and the oldest events were overwritten.
                    // They cannot be recovered; keep reading from what is still buffered.
                    // NOTE(review): if the terminal event was among the lost ones, the
                    // stream stays open until the channel closes.
                    Some(Err(_lagged)) => continue,
                    Some(Ok(event)) => event,
                };
                // The channel is shared by all runs. Other runs' traffic is ignored
                // instead of being surfaced to this client as a placeholder event.
                if event.run_id() != run_id {
                    continue;
                }
                // Best effort: an event that cannot be encoded is dropped rather than
                // replaced with an empty placeholder event.
                let data = match serde_json::to_string(&event) {
                    Ok(json) => json,
                    Err(_) => continue,
                };
                let is_terminal = event.is_terminal();
                let sse_event = Event::default().event(event.sse_event_name()).data(data);
                return Some((
                    Ok::<Event, Infallible>(sse_event),
                    (events, run_id, is_terminal),
                ));
            }
        },
    );

    // Close with the `[DONE]` sentinel whichever way `run_events` ended: after the terminal
    // event as well as on channel close.
    let done = stream::once(async {
        Ok::<Event, Infallible>(Event::default().event("done").data("[DONE]"))
    });

    Sse::new(run_events.chain(done)).keep_alive(
        KeepAlive::new()
            .interval(Duration::from_secs(15))
            .text("keep-alive"),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::threads::types::{Run, RunStatus};

    /// Builds a queued `Run` fixture with the given id.
    fn make_run(id: &str) -> Run {
        Run {
            id: id.to_string(),
            object: "thread.run".to_string(),
            created_at: 1_000_000,
            thread_id: "thread_test".to_string(),
            status: RunStatus::Queued,
            model: "test-model".to_string(),
            last_error: None,
        }
    }

    /// Serializes `event` to a JSON string (the wire path) and parses it back into a
    /// generic value for field-level assertions.
    fn to_json(event: &RunEvent) -> serde_json::Value {
        let json_str = serde_json::to_string(event).expect("serialize");
        serde_json::from_str(&json_str).expect("parse")
    }

    #[test]
    fn run_event_serialize_created() {
        let run = make_run("run_create_test");
        let val = to_json(&RunEvent::Created(run.clone()));
        assert_eq!(val["event_type"], "created");
        assert_eq!(val["data"]["id"], run.id);
        assert_eq!(val["data"]["status"], "queued");
    }

    #[test]
    fn run_event_serialize_completed() {
        let mut run = make_run("run_complete_test");
        run.status = RunStatus::Completed;
        let val = to_json(&RunEvent::Completed(run.clone()));
        assert_eq!(val["event_type"], "completed");
        assert_eq!(val["data"]["id"], run.id);
        assert_eq!(val["data"]["status"], "completed");
    }

    #[test]
    fn run_event_message_delta() {
        let event = RunEvent::MessageDelta {
            run_id: "run_delta_test".to_string(),
            content: "Hello, world!".to_string(),
        };
        let val = to_json(&event);
        assert_eq!(val["event_type"], "message_delta");
        assert_eq!(val["data"]["run_id"], "run_delta_test");
        assert_eq!(val["data"]["content"], "Hello, world!");
    }

    #[test]
    fn run_event_message_delta_deserialize() {
        let json = r#"{"event_type":"message_delta","data":{"run_id":"run_in","content":"hi"}}"#;
        let event: RunEvent = serde_json::from_str(json).expect("deserialize");
        match event {
            RunEvent::MessageDelta { run_id, content } => {
                assert_eq!(run_id, "run_in");
                assert_eq!(content, "hi");
            }
            other => panic!("unexpected event: {:?}", other),
        }
    }

    #[test]
    fn run_event_is_terminal_variants() {
        let run = make_run("r");
        assert!(!RunEvent::Created(run.clone()).is_terminal());
        assert!(!RunEvent::InProgress(run.clone()).is_terminal());
        assert!(!RunEvent::MessageDelta {
            run_id: "r".into(),
            content: "x".into()
        }
        .is_terminal());
        assert!(RunEvent::Completed(run.clone()).is_terminal());
        assert!(RunEvent::Failed(run).is_terminal());
    }

    #[test]
    fn run_event_sse_names() {
        let run = make_run("r");
        assert_eq!(
            RunEvent::Created(run.clone()).sse_event_name(),
            "thread.run.created"
        );
        assert_eq!(
            RunEvent::InProgress(run.clone()).sse_event_name(),
            "thread.run.in_progress"
        );
        assert_eq!(
            RunEvent::MessageDelta {
                run_id: "r".into(),
                content: "x".into()
            }
            .sse_event_name(),
            "thread.message.delta"
        );
        assert_eq!(
            RunEvent::Completed(run.clone()).sse_event_name(),
            "thread.run.completed"
        );
        assert_eq!(RunEvent::Failed(run).sse_event_name(), "thread.run.failed");
    }

    /// `run_id()` is the routing key the SSE stream filters on; it must be correct for
    /// every variant.
    #[test]
    fn run_event_run_id_all_variants() {
        let run = make_run("run_route");
        assert_eq!(RunEvent::Created(run.clone()).run_id(), "run_route");
        assert_eq!(RunEvent::InProgress(run.clone()).run_id(), "run_route");
        assert_eq!(RunEvent::Completed(run.clone()).run_id(), "run_route");
        assert_eq!(RunEvent::Failed(run).run_id(), "run_route");
        let delta = RunEvent::MessageDelta {
            run_id: "run_delta".into(),
            content: "x".into(),
        };
        assert_eq!(delta.run_id(), "run_delta");
    }

    #[tokio::test]
    async fn run_event_channel_roundtrip() {
        let (tx, mut rx) = new_run_event_channel();
        let run = make_run("run_channel");
        tx.send(RunEvent::Created(run.clone())).expect("send");
        match rx.recv().await.expect("recv") {
            RunEvent::Created(r) => assert_eq!(r.id, run.id),
            other => panic!("unexpected event type: {:?}", other),
        }
    }
}