use std::sync::Arc;
use std::time::Duration;
use parking_lot::Mutex;
use axum::{
Router,
extract::{
State,
ws::{CloseFrame, Message, Utf8Bytes, WebSocket, WebSocketUpgrade, close_code},
},
http::Method,
response::Response,
routing::get,
};
use dynamo_runtime::engine::{AsyncEngineContextProvider, RequestStream};
use dynamo_runtime::pipeline::Context;
use futures::{SinkExt, StreamExt, stream::SplitSink};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
/// Capacity of the bounded `mpsc` channel that carries client events into the
/// engine's request stream.
const REQUEST_CHANNEL_CAPACITY: usize = 64;
/// Time limit applied while closing the WebSocket sink after the writer has
/// been dropped.
const CLOSE_DRAIN_TIMEOUT: Duration = Duration::from_secs(5);
use super::{RouteDoc, service_v2};
use crate::discovery::ModelManagerError;
use crate::types::RealtimeBidirectionalEngine;
use dynamo_protocols::types::realtime::{
EventType, RealtimeAPIError, RealtimeClientEvent, RealtimeClientEventSessionUpdate,
RealtimeServerEvent, RealtimeServerEventError, RealtimeServerEventSessionCreated, Session,
};
use uuid::Uuid;
/// Builds the router that serves the realtime WebSocket endpoint.
///
/// `path` overrides the mount point; when it is `None` the endpoint is
/// registered at `/v1/realtime`. Returns the route documentation entry
/// (a single `GET` on the chosen path) together with the router itself.
pub fn realtime_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let realtime_path = match path {
        Some(custom) => custom,
        None => String::from("/v1/realtime"),
    };

    let route_docs = vec![RouteDoc::new(Method::GET, &realtime_path)];
    let ws_router = Router::new()
        .route(&realtime_path, get(realtime_ws_handler))
        .with_state(state);

    (route_docs, ws_router)
}
async fn realtime_ws_handler(
State(state): State<Arc<service_v2::State>>,
upgrade: WebSocketUpgrade,
) -> Response {
upgrade.on_upgrade(move |socket| handle_socket(socket, state))
}
async fn handle_socket(socket: WebSocket, state: Arc<service_v2::State>) {
let close_reason: Arc<Mutex<Option<Message>>> = Arc::new(Mutex::new(None));
let (ws_tx, mut ws_rx) = socket.split();
let mut writer = ScopedWsWriter::new(ws_tx, close_reason.clone());
let session_created = RealtimeServerEvent::SessionCreated(RealtimeServerEventSessionCreated {
event_id: format!("event_{}", Uuid::new_v4()),
session: Session::RealtimeSession(Box::default()),
});
let session_created_payload = match serde_json::to_string(&session_created) {
Ok(s) => s,
Err(err) => {
tracing::error!(%err, "/v1/realtime serializing session.created failed");
*close_reason.lock() = Some(close_message(
close_code::ERROR,
"internal error preparing session.created",
));
return;
}
};
if let Err(err) = writer
.send(Message::Text(Utf8Bytes::from(session_created_payload)))
.await
{
tracing::debug!(%err, "/v1/realtime client disconnected before session.created");
return;
}
let Some((engine, session_update)) =
select_engine(&mut ws_rx, &mut *writer, state.as_ref()).await
else {
return;
};
let (req_tx, req_rx) = mpsc::channel::<RealtimeClientEvent>(REQUEST_CHANNEL_CAPACITY);
if req_tx
.send(RealtimeClientEvent::SessionUpdate(session_update))
.await
.is_err()
{
tracing::debug!("/v1/realtime engine receiver dropped before session.update delivered");
return;
}
let request_stream = Box::pin(ReceiverStream::new(req_rx));
let input = RequestStream::new(request_stream, Context::new(()).context());
let mut response_stream = match engine.generate(input).await {
Ok(s) => s,
Err(err) => {
tracing::error!(%err, "/v1/realtime engine.generate() failed");
*close_reason.lock() = Some(close_message(
close_code::ERROR,
&format!("engine error: {err}"),
));
return;
}
};
let resp_ctx = response_stream.context();
let outbound = tokio::spawn(async move {
let mut writer = writer;
while let Some(annotated) = response_stream.next().await {
let event = if let Some(event) = annotated.data {
event
} else if let Some(err) = annotated.error {
RealtimeServerEvent::Error(RealtimeServerEventError {
event_id: format!("event_{}", Uuid::new_v4()),
error: RealtimeAPIError {
r#type: "server_error".to_string(),
code: None,
message: err.to_string(),
param: None,
event_id: None,
},
})
} else {
continue;
};
let frame_payload = match serde_json::to_string(&event) {
Ok(s) => s,
Err(err) => {
tracing::warn!(%err, "/v1/realtime serializing response chunk failed");
continue;
}
};
if writer
.send(Message::Text(Utf8Bytes::from(frame_payload)))
.await
.is_err()
{
tracing::debug!("/v1/realtime client disconnected during response");
break;
}
}
});
while let Some(msg) = ws_rx.next().await {
let msg = match msg {
Ok(m) => m,
Err(err) => {
tracing::debug!(%err, "/v1/realtime inbound frame error; treating as disconnect");
break;
}
};
match msg {
Message::Text(text) => {
match serde_json::from_str::<RealtimeClientEvent>(text.as_str()) {
Ok(event) => {
if req_tx.send(event).await.is_err() {
tracing::debug!("/v1/realtime engine receiver dropped; ending inbound");
break;
}
}
Err(err) => {
tracing::warn!(%err, "/v1/realtime malformed JSON frame; closing");
*close_reason.lock() =
Some(close_message(close_code::INVALID, "malformed JSON frame"));
break;
}
}
}
Message::Binary(_) => {
tracing::warn!("/v1/realtime received binary frame; not supported in this slice");
*close_reason.lock() = Some(close_message(
close_code::UNSUPPORTED,
"binary frames not supported",
));
break;
}
Message::Close(_) => break,
Message::Ping(_) | Message::Pong(_) => {} }
}
resp_ctx.stop_generating();
drop(req_tx);
let _ = outbound.await;
}
/// Builds a WebSocket close frame carrying `code` and `reason`.
///
/// A close frame is a control frame, and RFC 6455 limits control-frame
/// payloads to 125 bytes. Two of those bytes hold the status code, so the
/// reason is truncated to at most 123 bytes, cut on a UTF-8 character boundary
/// so the result is still valid text. Reasons that already fit are passed
/// through unchanged.
fn close_message(code: u16, reason: &str) -> Message {
    /// 125-byte control-frame payload limit minus the 2-byte status code.
    const MAX_CLOSE_REASON_BYTES: usize = 123;

    let mut end = reason.len().min(MAX_CLOSE_REASON_BYTES);
    // Index 0 is always a char boundary, so this loop terminates.
    while !reason.is_char_boundary(end) {
        end -= 1;
    }

    Message::Close(Some(CloseFrame {
        code,
        reason: Utf8Bytes::from(reason[..end].to_string()),
    }))
}
/// Owner of the WebSocket write half that sends a close frame when dropped.
///
/// It dereferences to the underlying sink for normal sends. On drop it takes
/// the sink and spawns a task that sends the message stored in `close_reason`
/// (or a normal "stream complete" close when none is stored) and then closes
/// the sink.
struct ScopedWsWriter {
// Write half of the socket; `Some` until `Drop` takes it.
ws_tx: Option<SplitSink<WebSocket, Message>>,
// Shared slot holding the close message to send on drop, if one was set.
close_reason: Arc<Mutex<Option<Message>>>,
}
impl ScopedWsWriter {
    /// Takes ownership of the socket's write half and remembers the shared
    /// slot from which the close message is read when the writer is dropped.
    fn new(
        ws_tx: SplitSink<WebSocket, Message>,
        close_reason: Arc<Mutex<Option<Message>>>,
    ) -> Self {
        let ws_tx = Some(ws_tx);
        Self { ws_tx, close_reason }
    }
}
impl std::ops::Deref for ScopedWsWriter {
    type Target = SplitSink<WebSocket, Message>;

    /// Borrows the wrapped sink.
    ///
    /// # Panics
    /// Panics if the sink has already been taken, which only `Drop` does.
    fn deref(&self) -> &Self::Target {
        match &self.ws_tx {
            Some(sink) => sink,
            None => {
                panic!("ScopedWsWriter sink only taken by Drop; no other consumer should exist")
            }
        }
    }
}
impl std::ops::DerefMut for ScopedWsWriter {
    /// Mutably borrows the wrapped sink so callers can send through it.
    ///
    /// # Panics
    /// Panics if the sink has already been taken, which only `Drop` does.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match &mut self.ws_tx {
            Some(sink) => sink,
            None => {
                panic!("ScopedWsWriter sink only taken by Drop; no other consumer should exist")
            }
        }
    }
}
impl Drop for ScopedWsWriter {
fn drop(&mut self) {
let Some(mut ws_tx) = self.ws_tx.take() else {
return;
};
let close_reason = self.close_reason.clone();
tokio::spawn(async move {
let msg = close_reason
.lock()
.take()
.unwrap_or_else(|| close_message(close_code::NORMAL, "stream complete"));
let _ = ws_tx.send(msg).await;
let _ = tokio::time::timeout(CLOSE_DRAIN_TIMEOUT, ws_tx.close()).await;
});
}
}
/// Reads client frames until one is a `session.update` that resolves to an
/// engine, and returns that engine together with the triggering update.
///
/// Every frame that cannot be used gets an error event written to `ws_tx` and
/// the loop keeps waiting:
/// - text that is not a valid `RealtimeClientEvent`, or a binary frame;
/// - any event other than `session.update`;
/// - a transcription session, or a realtime session without a non-empty model;
/// - a model the manager reports as not found, unavailable, or otherwise fails
///   to look up.
///
/// Ping and pong frames are skipped silently. Returns `None` when the inbound
/// stream ends, yields an error, or delivers a close frame.
async fn select_engine<S, T>(
    ws_rx: &mut S,
    ws_tx: &mut T,
    state: &service_v2::State,
) -> Option<(
    RealtimeBidirectionalEngine,
    RealtimeClientEventSessionUpdate,
)>
where
    S: futures::Stream<Item = Result<Message, axum::Error>> + Unpin,
    T: futures::Sink<Message, Error = axum::Error> + Unpin,
{
    loop {
        // `?` ends the selection with `None` once the inbound stream is exhausted.
        let frame = match ws_rx.next().await? {
            Ok(frame) => frame,
            Err(err) => {
                tracing::debug!(%err, "/v1/realtime inbound error during engine selection");
                return None;
            }
        };

        // Only text frames carry events; everything else is handled here.
        let text = match frame {
            Message::Text(text) => text,
            Message::Binary(_) => {
                tracing::debug!("/v1/realtime binary frame during engine selection");
                send_error_event(
                    ws_tx,
                    "invalid_request",
                    "binary frames not supported",
                    None,
                )
                .await;
                continue;
            }
            Message::Close(_) => return None,
            Message::Ping(_) | Message::Pong(_) => continue,
        };

        let event = match serde_json::from_str::<RealtimeClientEvent>(text.as_str()) {
            Ok(parsed) => parsed,
            Err(err) => {
                tracing::debug!(%err, "/v1/realtime malformed JSON during engine selection");
                send_error_event(ws_tx, "invalid_request", "malformed JSON frame", None).await;
                continue;
            }
        };

        let session_update = match event {
            RealtimeClientEvent::SessionUpdate(update) => update,
            unexpected => {
                tracing::debug!(
                    event = unexpected.event_type(),
                    "/v1/realtime expected session.update before engine selection"
                );
                send_error_event(
                    ws_tx,
                    "invalid_request",
                    "expected session.update before engine is selected",
                    Some("session.update"),
                )
                .await;
                continue;
            }
        };

        let requested_model = match &session_update.session {
            Session::RealtimeTranscriptionSession(_) => {
                send_error_event(
                    ws_tx,
                    "unsupported_session_type",
                    "session.type 'transcription' is not yet supported (only 'realtime' is supported)",
                    Some("session.type"),
                )
                .await;
                continue;
            }
            // An empty model string is treated the same as a missing one.
            Session::RealtimeSession(realtime) => {
                realtime.model.as_deref().filter(|name| !name.is_empty())
            }
        };
        let model_name = match requested_model {
            Some(name) => name,
            None => {
                send_error_event(
                    ws_tx,
                    "invalid_request",
                    "session.model required",
                    Some("session.model"),
                )
                .await;
                continue;
            }
        };

        // Success returns immediately; every failure becomes a (code, message)
        // pair reported against `session.model` before waiting for a new frame.
        let (code, message) = match state.manager().get_realtime_engine(model_name) {
            Ok(engine) => return Some((engine, session_update)),
            Err(ModelManagerError::ModelNotFound(_)) => {
                ("model_not_found", format!("unknown model: {model_name}"))
            }
            Err(ModelManagerError::ModelUnavailable(_)) => (
                "model_unavailable",
                format!("model unavailable: {model_name}"),
            ),
            Err(err) => {
                tracing::error!(%err, "/v1/realtime engine lookup failed");
                ("server_error", err.to_string())
            }
        };
        send_error_event(ws_tx, code, &message, Some("session.model")).await;
    }
}
async fn send_error_event<S>(ws_tx: &mut S, code: &str, message: &str, param: Option<&str>)
where
S: futures::Sink<Message, Error = axum::Error> + Unpin,
{
let event = RealtimeServerEvent::Error(RealtimeServerEventError {
event_id: format!("event_{}", Uuid::new_v4()),
error: RealtimeAPIError {
r#type: "invalid_request_error".to_string(),
code: Some(code.to_string()),
message: message.to_string(),
param: param.map(|s| s.to_string()),
event_id: None,
},
});
let payload = match serde_json::to_string(&event) {
Ok(s) => s,
Err(err) => {
tracing::warn!(%err, "/v1/realtime serializing error event failed");
return;
}
};
let _ = ws_tx.send(Message::Text(Utf8Bytes::from(payload))).await;
}