#![cfg(feature = "ws")]
pub(crate) mod ws_route {
    //! WebSocket transport for streaming procedures.
    //!
    //! Clients send JSON-encoded `crate::wire::WsMessage`s: `Subscribe` starts a
    //! stream for a named procedure under a client-chosen `id`, and `Unsubscribe`
    //! cancels it. Every frame sent back is a JSON envelope of the form
    //! `{"type": "data" | "error" | "end", "payload": { "id": <id>, ... }}`.

    use std::collections::HashMap;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
    use axum::response::Response;
    use futures::stream::{AbortHandle, Abortable};
    use futures::{SinkExt, StreamExt};

    use crate::procedure::{ProcedureBody, ProcedureDescriptor, StreamFrame};
    use crate::router::ProcKindRuntime;

    /// Sending half of the per-socket queue that the writer task drains.
    type FrameSender = tokio::sync::mpsc::UnboundedSender<Message>;

    /// Bookkeeping for one subscription started on a socket.
    struct ActiveSubscription {
        /// Cancels the subscription's stream; its pump task then sends `end`.
        abort: AbortHandle,
        /// Set by the pump task when it exits, so the entry can be pruned.
        finished: Arc<AtomicBool>,
    }

    /// Builds an axum handler that upgrades the request to a WebSocket and serves
    /// the streaming procedures in `descriptors` over it.
    ///
    /// The returned closure only captures the `Arc`, so cloning it (and each
    /// invocation) shares the descriptor list rather than copying it.
    pub fn ws_handler(
        descriptors: Arc<Vec<ProcedureDescriptor>>,
    ) -> impl Fn(WebSocketUpgrade) -> futures::future::BoxFuture<'static, Response>
           + Clone
           + Send
           + Sync
           + 'static {
        move |upgrade: WebSocketUpgrade| {
            let descriptors = Arc::clone(&descriptors);
            Box::pin(async move {
                upgrade.on_upgrade(move |socket| handle_socket(socket, descriptors))
            })
        }
    }

    /// Drives one upgraded socket until the client sends `Close`, the stream
    /// ends, or a transport error is read.
    ///
    /// Incoming text frames are decoded as `WsMessage`s; other non-`Close`
    /// frames are ignored. All outgoing frames are funnelled through a single
    /// writer task, so concurrent subscriptions never touch the socket sink
    /// directly. On exit every still-running subscription is aborted and the
    /// writer task is cancelled (frames still queued at that point are dropped).
    async fn handle_socket(socket: WebSocket, descriptors: Arc<Vec<ProcedureDescriptor>>) {
        let mut active: HashMap<u64, ActiveSubscription> = HashMap::new();
        let (mut sink, mut source) = socket.split();
        let (frame_tx, mut frame_rx) = tokio::sync::mpsc::unbounded_channel::<Message>();

        let writer_handle = tokio::spawn(async move {
            while let Some(msg) = frame_rx.recv().await {
                if sink.send(msg).await.is_err() {
                    break;
                }
            }
        });

        while let Some(Ok(msg)) = source.next().await {
            let text = match msg {
                Message::Text(t) => t,
                Message::Close(_) => break,
                _ => continue,
            };
            let parsed: Result<crate::wire::WsMessage<serde_json::Value, serde_json::Value>, _> =
                serde_json::from_str(&text);
            let parsed = match parsed {
                Ok(p) => p,
                Err(e) => {
                    // The request id is unknown when decoding fails, so report against id 0.
                    send_json(
                        &frame_tx,
                        &error_envelope(
                            0,
                            "decode_error".into(),
                            serde_json::json!({ "message": e.to_string() }),
                        ),
                    );
                    continue;
                }
            };
            match parsed {
                crate::wire::WsMessage::Subscribe {
                    id,
                    procedure,
                    input,
                } => {
                    let Some(desc) = descriptors.iter().find(|d| d.name == procedure) else {
                        reject_subscription(
                            &frame_tx,
                            id,
                            "not_found",
                            serde_json::json!({ "procedure": procedure }),
                        );
                        continue;
                    };
                    let stream_handler = match &desc.body {
                        ProcedureBody::Stream(h) => h.clone(),
                        ProcedureBody::Unary(_) => {
                            debug_assert_ne!(
                                desc.kind,
                                ProcKindRuntime::Subscription,
                                "subscription kind paired with unary body"
                            );
                            reject_subscription(
                                &frame_tx,
                                id,
                                "not_subscription",
                                serde_json::Value::Null,
                            );
                            continue;
                        }
                    };

                    // Forget subscriptions whose streams already ended, so `active`
                    // only holds entries for streams that may still be running.
                    active.retain(|_, sub| !sub.finished.load(Ordering::Acquire));
                    // Overwriting the entry for a reused id would orphan the old
                    // stream (no Unsubscribe or disconnect could ever cancel it), so
                    // cancel it first. The cancelled stream still emits its own `end`
                    // frame for this id.
                    if let Some(previous) = active.remove(&id) {
                        previous.abort.abort();
                    }

                    let stream = stream_handler(input);
                    let (abort_handle, abort_reg) = AbortHandle::new_pair();
                    let finished = Arc::new(AtomicBool::new(false));
                    active.insert(
                        id,
                        ActiveSubscription {
                            abort: abort_handle,
                            finished: Arc::clone(&finished),
                        },
                    );
                    let frame_tx = frame_tx.clone();
                    let abortable = Abortable::new(stream, abort_reg);
                    tokio::spawn(async move {
                        futures::pin_mut!(abortable);
                        let mut writer_open = true;
                        // An abort makes `next()` yield `None`, so a cancelled
                        // subscription falls through to the `end` frame below.
                        while let Some(frame) = abortable.next().await {
                            let envelope = match frame {
                                StreamFrame::Data(value) => serde_json::json!({
                                    "type": "data",
                                    "payload": { "id": id, "value": value }
                                }),
                                StreamFrame::Error { code, payload } => error_envelope(
                                    id,
                                    serde_json::json!(code),
                                    serde_json::json!(payload),
                                ),
                            };
                            if !send_json(&frame_tx, &envelope) {
                                // The writer is gone; nothing more can reach the client.
                                writer_open = false;
                                break;
                            }
                        }
                        if writer_open {
                            send_json(&frame_tx, &end_envelope(id));
                        }
                        finished.store(true, Ordering::Release);
                    });
                }
                crate::wire::WsMessage::Unsubscribe { id } => {
                    if let Some(sub) = active.remove(&id) {
                        sub.abort.abort();
                    }
                }
                _ => {}
            }
        }

        for (_, sub) in active.drain() {
            sub.abort.abort();
        }
        writer_handle.abort();
    }

    /// Refuses subscription `id`: sends an `error` frame carrying `code` and
    /// `payload`, then an `end` frame for the same id (the frame a finished
    /// stream would send).
    fn reject_subscription(
        frame_tx: &FrameSender,
        id: u64,
        code: &str,
        payload: serde_json::Value,
    ) {
        send_json(frame_tx, &error_envelope(id, code.into(), payload));
        send_json(frame_tx, &end_envelope(id));
    }

    /// Builds `{"type": "error", "payload": {"id": id, "err": {"code": code, "payload": payload}}}`.
    fn error_envelope(id: u64, code: serde_json::Value, payload: serde_json::Value) -> serde_json::Value {
        serde_json::json!({
            "type": "error",
            "payload": {
                "id": id,
                "err": { "code": code, "payload": payload }
            }
        })
    }

    /// Builds `{"type": "end", "payload": {"id": id}}`.
    fn end_envelope(id: u64) -> serde_json::Value {
        serde_json::json!({
            "type": "end",
            "payload": { "id": id }
        })
    }

    /// Serialises `envelope` and queues it on the socket's writer channel.
    ///
    /// Returns `false` when the channel's receiver has been dropped (the writer
    /// task is gone), meaning nothing further can be delivered to the client.
    fn send_json(frame_tx: &FrameSender, envelope: &serde_json::Value) -> bool {
        // Serialising a `serde_json::Value` cannot fail: all of its map keys are strings.
        let text = serde_json::to_string(envelope).expect("serde_json::Value always serialises");
        frame_tx.send(Message::Text(text)).is_ok()
    }
}