use std::sync::{Arc, OnceLock};
use parking_lot::Mutex;
use axum::{
Router,
extract::{
State,
ws::{CloseFrame, Message, Utf8Bytes, WebSocket, WebSocketUpgrade, close_code},
},
http::Method,
response::Response,
routing::get,
};
use dynamo_runtime::engine::{AsyncEngine, AsyncEngineContextProvider, RequestStream};
use dynamo_runtime::pipeline::Context;
use futures::{SinkExt, StreamExt};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
/// Capacity of the bounded channel carrying parsed inbound requests from the
/// WebSocket reader to the engine's request stream. When the channel is full, the
/// reader's `send(..).await` waits, so no further frames are read until the engine
/// drains it.
const REQUEST_CHANNEL_CAPACITY: usize = 64;
use super::{RouteDoc, service_v2};
use crate::engines::EchoBidirectionalEngine;
use crate::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
/// Process-wide engine shared by every `/v1/realtime` connection; can be set at most once.
static BIDIRECTIONAL_ENGINE: OnceLock<EchoBidirectionalEngine> = OnceLock::new();

/// Installs `engine` as the engine used by `/v1/realtime` connections.
///
/// # Errors
///
/// Returns `Err("realtime bidirectional engine already installed")` when an engine has
/// already been installed; in that case the supplied `engine` is dropped.
pub fn install_engine(engine: EchoBidirectionalEngine) -> Result<(), &'static str> {
    match BIDIRECTIONAL_ENGINE.set(engine) {
        Ok(()) => Ok(()),
        Err(_) => Err("realtime bidirectional engine already installed"),
    }
}
/// Installs the built-in echo engine as the `/v1/realtime` engine.
///
/// # Errors
///
/// Propagates the error from `install_engine` when an engine is already installed.
pub fn install_echo_engine() -> Result<(), &'static str> {
    let echo = EchoBidirectionalEngine {};
    install_engine(echo)
}
/// Builds the router serving the realtime WebSocket endpoint.
///
/// `path` overrides the mount point; `None` mounts the endpoint at `/v1/realtime`.
/// Returns the route documentation for the endpoint together with a router that
/// already has `state` attached.
pub fn realtime_router(
    state: Arc<service_v2::State>,
    path: Option<String>,
) -> (Vec<RouteDoc>, Router) {
    let endpoint = match path {
        Some(custom) => custom,
        None => "/v1/realtime".to_string(),
    };
    let route_docs = vec![RouteDoc::new(Method::GET, &endpoint)];
    let ws_router = Router::new()
        .route(&endpoint, get(realtime_ws_handler))
        .with_state(state);
    (route_docs, ws_router)
}
async fn realtime_ws_handler(
State(state): State<Arc<service_v2::State>>,
upgrade: WebSocketUpgrade,
) -> Response {
upgrade.on_upgrade(move |socket| handle_socket(socket, state))
}
/// Drives one `/v1/realtime` WebSocket session.
///
/// Inbound text frames are parsed as `NvCreateChatCompletionRequest` and forwarded
/// through a bounded channel (`REQUEST_CHANNEL_CAPACITY` entries) into the request
/// stream given to the installed engine's `generate`. A spawned task serializes each
/// item of the engine's response stream to JSON and sends it back as a text frame.
///
/// The session ends with a close frame carrying:
/// - `close_code::ERROR` when no engine is installed or `generate` fails;
/// - `close_code::INVALID` ("malformed JSON frame") when a text frame does not parse;
/// - `close_code::UNSUPPORTED` ("binary frames not supported") on a binary frame;
/// - `close_code::NORMAL` ("stream complete") otherwise.
///
/// `_state` is currently unused.
async fn handle_socket(mut socket: WebSocket, _state: Arc<service_v2::State>) {
    // Refuse the session up front if nothing was installed via `install_engine`.
    let Some(engine) = BIDIRECTIONAL_ENGINE.get() else {
        tracing::error!("/v1/realtime connection rejected: bidirectional engine not installed");
        let _ = socket
            .send(close_message(
                close_code::ERROR,
                "bidirectional engine not installed",
            ))
            .await;
        return;
    };
    // Split the socket so the outbound task can own the sink while this task keeps
    // reading frames.
    let (mut ws_tx, mut ws_rx) = socket.split();
    // Bounded channel: when it is full, `req_tx.send(..).await` below blocks the reader,
    // so no further frames are pulled until the engine drains requests.
    let (req_tx, req_rx) = mpsc::channel::<NvCreateChatCompletionRequest>(REQUEST_CHANNEL_CAPACITY);
    let request_stream = Box::pin(ReceiverStream::new(req_rx));
    // NOTE(review): a fresh request context is created here; whether stopping the
    // response context below also affects it depends on the engine -- confirm.
    let input = RequestStream::new(request_stream, Context::new(()).context());
    // Close frame chosen by the inbound loop. It is shared with the outbound task, which
    // sends it (or falls back to a NORMAL close) after the response stream ends.
    let close_reason: Arc<Mutex<Option<Message>>> = Arc::new(Mutex::new(None));
    let mut response_stream = match engine.generate(input).await {
        Ok(s) => s,
        Err(err) => {
            tracing::error!(%err, "/v1/realtime engine.generate() failed");
            let _ = ws_tx
                .send(close_message(
                    close_code::ERROR,
                    &format!("engine error: {err}"),
                ))
                .await;
            return;
        }
    };
    // Kept so the inbound side can ask the engine to stop once the session is over.
    let resp_ctx = response_stream.context();
    let outbound_close_reason = close_reason.clone();
    // Outbound task: forwards every response item to the client as a JSON text frame,
    // then sends the close frame.
    let outbound = tokio::spawn(async move {
        while let Some(annotated) = response_stream.next().await {
            // A chunk that fails to serialize is logged and skipped; the session goes on.
            let frame_payload = match serde_json::to_string(&annotated) {
                Ok(s) => s,
                Err(err) => {
                    tracing::warn!(%err, "/v1/realtime serializing response chunk failed");
                    continue;
                }
            };
            // A failed send means the client is gone; stop forwarding.
            if ws_tx
                .send(Message::Text(Utf8Bytes::from(frame_payload)))
                .await
                .is_err()
            {
                tracing::debug!("/v1/realtime client disconnected during response");
                break;
            }
        }
        // Read only after the response stream has ended. When the inbound loop ends the
        // session, it stores its reason before calling `stop_generating()`. The reason is
        // therefore visible here, assuming the engine ends the stream in response to that
        // stop (TODO confirm).
        let msg = outbound_close_reason
            .lock()
            .take()
            .unwrap_or_else(|| close_message(close_code::NORMAL, "stream complete"));
        // Send errors are ignored: the peer may already have closed or disconnected.
        let _ = ws_tx.send(msg).await;
        // Closing the sink is capped at 5 seconds so an unresponsive peer cannot keep
        // this task alive indefinitely.
        let _ = tokio::time::timeout(std::time::Duration::from_secs(5), ws_tx.close()).await;
    });
    // Inbound loop. It exits on client close, a read error, a malformed or binary frame,
    // or when the engine has dropped its request receiver.
    // NOTE(review): if the engine finishes its response stream first, this loop still
    // waits for the client's next frame or disconnect; there is no timeout here.
    while let Some(msg) = ws_rx.next().await {
        let msg = match msg {
            Ok(m) => m,
            Err(err) => {
                tracing::debug!(%err, "/v1/realtime inbound frame error; treating as disconnect");
                break;
            }
        };
        match msg {
            Message::Text(text) => {
                match serde_json::from_str::<NvCreateChatCompletionRequest>(text.as_str()) {
                    Ok(req) => {
                        if req_tx.send(req).await.is_err() {
                            tracing::debug!("/v1/realtime engine receiver dropped; ending inbound");
                            break;
                        }
                    }
                    Err(err) => {
                        tracing::warn!(%err, "/v1/realtime malformed JSON frame; closing");
                        // Stored before `stop_generating()` so the outbound task sees it.
                        *close_reason.lock() =
                            Some(close_message(close_code::INVALID, "malformed JSON frame"));
                        break;
                    }
                }
            }
            Message::Binary(_) => {
                tracing::warn!("/v1/realtime received binary frame; not supported in this slice");
                // Stored before `stop_generating()` so the outbound task sees it.
                *close_reason.lock() = Some(close_message(
                    close_code::UNSUPPORTED,
                    "binary frames not supported",
                ));
                break;
            }
            // Client-initiated close.
            Message::Close(_) => break,
            // Ping/Pong frames need no action in this loop.
            Message::Ping(_) | Message::Pong(_) => {}
        }
    }
    // Teardown: ask the engine to stop generating, end the request stream by dropping
    // the only sender, then wait for the outbound task to send its close frame.
    resp_ctx.stop_generating();
    drop(req_tx);
    // NOTE(review): if the engine never ends its response stream, this await never
    // returns. A panic in the outbound task (JoinError) is also discarded here.
    let _ = outbound.await;
}
/// Largest reason text, in bytes, that fits in a WebSocket close frame.
///
/// RFC 6455 §5.5 caps every control-frame payload at 125 bytes, and a close payload
/// spends the first two of those on the status code.
const MAX_CLOSE_REASON_BYTES: usize = 123;

/// Builds a close frame with status `code` and a human-readable `reason`.
///
/// `reason` is truncated, on a UTF-8 character boundary, to at most
/// `MAX_CLOSE_REASON_BYTES` bytes. Callers pass arbitrary text such as formatted engine
/// errors, and an oversized close frame is a protocol violation that peers may reject.
fn close_message(code: u16, reason: &str) -> Message {
    Message::Close(Some(CloseFrame {
        code,
        reason: Utf8Bytes::from(truncate_close_reason(reason).to_string()),
    }))
}

/// Returns the longest prefix of `reason` that is at most `MAX_CLOSE_REASON_BYTES`
/// bytes long and ends on a character boundary.
fn truncate_close_reason(reason: &str) -> &str {
    if reason.len() <= MAX_CLOSE_REASON_BYTES {
        return reason;
    }
    // Walk back to a char boundary so a multi-byte character is never split;
    // index 0 is always a boundary, so this terminates.
    let mut end = MAX_CLOSE_REASON_BYTES;
    while !reason.is_char_boundary(end) {
        end -= 1;
    }
    &reason[..end]
}