use std::convert::Infallible;
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
use axum::response::{IntoResponse, Response};
use serde_json::{json, Value};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
/// Cloneable handle used to push Server-Sent Events onto an unbounded channel.
///
/// Each item carried by the channel is a `Result<SseEvent, Infallible>`; since
/// the error type is `Infallible`, every item is in practice an `Ok` event.
#[derive(Clone)]
pub struct UiStream {
/// Sending half of the unbounded channel that the event methods write to.
tx: mpsc::UnboundedSender<Result<SseEvent, Infallible>>,
}
/// Creates a fresh unbounded channel for SSE events.
///
/// Returns a [`UiStream`] wrapping the sending half together with the raw
/// receiving half, in that order.
pub fn channel() -> (UiStream, mpsc::UnboundedReceiver<Result<SseEvent, Infallible>>) {
    // The item type is inferred from the declared return type.
    let (sender, receiver) = mpsc::unbounded_channel();
    let stream = UiStream { tx: sender };
    (stream, receiver)
}
impl UiStream {
fn part(&self, value: Value) {
let body = serde_json::to_string(&value).unwrap_or_else(|_| "{}".to_string());
let _ = self.tx.send(Ok(SseEvent::default().data(body)));
}
pub fn start(&self, message_id: &str) {
self.part(json!({ "type": "start", "messageId": message_id }));
}
pub fn start_step(&self) {
self.part(json!({ "type": "start-step" }));
}
pub fn text_start(&self, id: &str) {
self.part(json!({ "type": "text-start", "id": id }));
}
pub fn text_delta(&self, id: &str, delta: &str) {
self.part(json!({ "type": "text-delta", "id": id, "delta": delta }));
}
pub fn text_end(&self, id: &str) {
self.part(json!({ "type": "text-end", "id": id }));
}
pub fn reasoning_start(&self, id: &str) {
self.part(json!({ "type": "reasoning-start", "id": id }));
}
pub fn reasoning_delta(&self, id: &str, delta: &str) {
self.part(json!({ "type": "reasoning-delta", "id": id, "delta": delta }));
}
pub fn reasoning_end(&self, id: &str) {
self.part(json!({ "type": "reasoning-end", "id": id }));
}
pub fn tool_input_available(&self, tool_call_id: &str, tool_name: &str, input: Value) {
self.part(json!({
"type": "tool-input-available",
"toolCallId": tool_call_id,
"toolName": tool_name,
"input": input,
}));
}
pub fn tool_output_available(&self, tool_call_id: &str, output: Value) {
self.part(json!({
"type": "tool-output-available",
"toolCallId": tool_call_id,
"output": output,
}));
}
pub fn tool_output_error(&self, tool_call_id: &str, error_text: &str) {
self.part(json!({
"type": "tool-output-error",
"toolCallId": tool_call_id,
"errorText": error_text,
}));
}
pub fn data(&self, name: &str, data: Value) {
self.part(json!({ "type": format!("data-{name}"), "data": data }));
}
pub fn error(&self, error_text: &str) {
self.part(json!({ "type": "error", "errorText": error_text }));
}
pub fn finish_step(&self) {
self.part(json!({ "type": "finish-step" }));
}
pub fn finish(&self) {
self.part(json!({ "type": "finish" }));
}
pub fn done(&self) {
let _ = self.tx.send(Ok(SseEvent::default().data("[DONE]")));
}
}
/// Turns the receiving half of the event channel into an HTTP response.
///
/// The receiver is wrapped as a stream and served as Server-Sent Events with
/// the default keep-alive settings. The response also carries the header
/// `x-vercel-ai-ui-message-stream: v1`.
pub fn response(rx: mpsc::UnboundedReceiver<Result<SseEvent, Infallible>>) -> Response {
    let headers = [("x-vercel-ai-ui-message-stream", "v1")];
    let body = Sse::new(UnboundedReceiverStream::new(rx)).keep_alive(KeepAlive::default());
    (headers, body).into_response()
}