use std::sync::Arc;
use axum::extract::FromRequestParts;
use axum::extract::ws::{CloseCode, CloseFrame, Message, WebSocket, close_code};
use axum::http::request::Parts;
use axum::response::Response;
use futures::{SinkExt, StreamExt};
use futures::stream::SplitStream;
use objectiveai_sdk::error::ResponseError;
use serde::Serialize;
pub use crate::objectiveai_mcp::{
PendingRequests, ReverseAttachConfig, ReverseAttachGuard,
ReverseAttachHandle, ReverseChannel, ReverseChannelRegistry,
SessionTracker, SharedSink, new_pending_requests,
new_reverse_channel_registry,
};
/// Wire transport selected for an incoming request.
///
/// The request is classified by inspecting its `Upgrade` header: a WebSocket
/// upgrade selects `WebSocket`, anything else falls back to `Sse`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Transport {
    /// Server-sent events; used whenever the request is not a WebSocket upgrade.
    Sse,
    /// WebSocket; used when the `Upgrade` header equals `websocket` (ASCII case-insensitive).
    WebSocket,
}
impl<S> FromRequestParts<S> for Transport
where
    S: Send + Sync,
{
    type Rejection = std::convert::Infallible;

    /// Classifies the request as a WebSocket upgrade or a plain SSE request.
    ///
    /// Extraction never fails. A missing `Upgrade` header, or one that is not
    /// valid visible ASCII, is treated as SSE.
    async fn from_request_parts(parts: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
        let upgrade_header = parts.headers.get(axum::http::header::UPGRADE);
        let requested_protocol = upgrade_header.and_then(|value| value.to_str().ok());
        let wants_websocket = match requested_protocol {
            Some(protocol) => protocol.eq_ignore_ascii_case("websocket"),
            None => false,
        };
        let transport = if wants_websocket {
            Transport::WebSocket
        } else {
            Transport::Sse
        };
        Ok(transport)
    }
}
use serde::de::DeserializeOwned;
/// Waits for the first data frame on `socket` and deserializes it as JSON into `T`.
///
/// Ping and pong frames are skipped while waiting.
///
/// # Errors
///
/// Returns a `ResponseError` with code 400 when:
/// - the first data frame is binary rather than text,
/// - the peer sends a close frame or the stream ends before any data frame,
/// - the underlying WebSocket receive fails,
/// - the text frame is not valid JSON for `T`.
pub async fn recv_body_frame<T: DeserializeOwned>(
    socket: &mut WebSocket,
) -> Result<T, ResponseError> {
    let bad_request = |detail: String| ResponseError {
        code: 400,
        message: serde_json::Value::String(detail),
    };

    while let Some(incoming) = socket.recv().await {
        let body = match incoming {
            Ok(Message::Text(body)) => body,
            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue,
            Ok(Message::Binary(_)) => {
                return Err(bad_request(String::from(
                    "expected text body frame, got binary",
                )));
            }
            // A close frame is reported the same way as the stream simply ending.
            Ok(Message::Close(_)) => break,
            Err(e) => return Err(bad_request(format!("websocket recv error: {e}"))),
        };
        return serde_json::from_str::<T>(body.as_str())
            .map_err(|e| bad_request(format!("failed to deserialize body frame: {e}")));
    }

    Err(bad_request(String::from("peer closed before sending body")))
}
/// Best-effort: sends `err` as a JSON text frame, then a close frame carrying
/// `code` and an empty reason.
///
/// If `err` cannot be serialized, the literal `{}` is sent instead. Failures of
/// either send are ignored, and the close frame is attempted even if the error
/// frame could not be sent.
pub async fn send_error_and_close(socket: &mut WebSocket, err: &ResponseError, code: CloseCode) {
    let payload = match serde_json::to_string(err) {
        Ok(json) => json,
        Err(_) => String::from("{}"),
    };
    let goodbye = CloseFrame {
        code,
        reason: "".into(),
    };

    let _ = socket.send(Message::Text(payload.into())).await;
    let _ = socket.send(Message::Close(Some(goodbye))).await;
}
/// Best-effort fatal-error path for a split socket.
///
/// Sends `err` as a JSON text frame on the shared sink (falling back to `{}`
/// if serialization fails), then closes the connection with
/// `close_code::ERROR`. Send failures are ignored.
pub async fn fatal_setup_error_split(sink: &SharedSink, err: &ResponseError) {
    let payload = match serde_json::to_string(err) {
        Ok(json) => json,
        Err(_) => String::from("{}"),
    };

    let mut guard = sink.lock().await;
    let _ = guard.send(Message::Text(payload.into())).await;
    // Release the lock here: send_close_split locks the same sink.
    drop(guard);

    send_close_split(sink, close_code::ERROR).await;
}
/// Serializes `chunk` to JSON and sends it as one text frame on the shared sink.
///
/// The sink lock is taken only after serialization succeeds. A chunk that
/// fails to serialize is skipped: the failure is logged to stderr and `Ok(())`
/// is returned, so the caller's stream is not aborted by one bad chunk.
///
/// # Errors
///
/// Returns `Err(())` when sending the frame on the sink fails. The underlying
/// sink error is discarded.
pub async fn send_chunk_split<C: Serialize>(sink: &SharedSink, chunk: &C) -> Result<(), ()> {
    let json = match serde_json::to_string(chunk) {
        Ok(json) => json,
        Err(e) => {
            // Skip-and-continue is intentional, but log it so the dropped chunk is diagnosable.
            eprintln!("dropping streaming chunk that failed to serialize: {e}");
            return Ok(());
        }
    };
    let mut guard = sink.lock().await;
    let sent = guard.send(Message::Text(json.into())).await;
    sent.map_err(|_| ())
}
/// Best-effort: sends a close frame with `code` and an empty reason on the
/// shared sink. A send failure is ignored.
pub async fn send_close_split(sink: &SharedSink, code: CloseCode) {
    let close = Message::Close(Some(CloseFrame {
        code,
        reason: "".into(),
    }));
    let mut guard = sink.lock().await;
    let _ = guard.send(close).await;
}
/// Drives the receive half of a split streaming WebSocket.
///
/// The loop runs until one of the following happens:
/// - the stream ends,
/// - the peer sends a close frame,
/// - a receive error occurs (logged to stderr).
///
/// Each text frame is decoded as one of two shapes, tried in this order:
/// 1. a `ClientRequest`. Its only payload, `McpListChanged`, is published to
///    `mcp_listeners` once per id from `attach_handle.registered_ids()`. It is
///    then acknowledged with `ClientResponse::Ok { id }` from a spawned task.
///    No ack is sent if serializing the ack fails.
/// 2. a `ServerResponse`. It is delivered to the waiter removed from `pending`
///    under the response's id, or logged if no waiter is registered.
///
/// Binary frames and frames matching neither shape are logged and skipped.
/// Ping and pong frames are ignored.
pub async fn recv_loop(
    mut rx: SplitStream<WebSocket>,
    sink: SharedSink,
    pending: PendingRequests,
    mcp_listeners: crate::objectiveai_mcp::McpListenerRegistry,
    attach_handle: Arc<ReverseAttachHandle>,
) {
    use objectiveai_sdk::client_objectiveai_mcp::{
        client_request::{Payload as ClientPayload, Request as ClientRequest},
        client_response::Response as ClientResponse,
        server_response::Response as ServerResponse,
    };

    while let Some(incoming) = rx.next().await {
        let text = match incoming {
            Ok(Message::Text(text)) => text,
            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => continue,
            Ok(Message::Binary(_)) => {
                eprintln!("ignoring binary frame on streaming WS recv side");
                continue;
            }
            Ok(Message::Close(_)) => return,
            Err(e) => {
                eprintln!("streaming WS recv error: {e}");
                return;
            }
        };
        let raw = text.as_str();

        if let Ok(ClientRequest { id, payload }) = serde_json::from_str::<ClientRequest>(raw) {
            match payload {
                ClientPayload::McpListChanged(change) => {
                    for response_id in attach_handle.registered_ids() {
                        mcp_listeners.publish(&response_id, &change.mcp_kind, change.kind);
                    }
                    let Ok(ack) = serde_json::to_string(&ClientResponse::Ok { id }) else {
                        continue;
                    };
                    // The ack is written from a spawned task, so this loop never
                    // waits on the sink lock; the task is detached and its send
                    // result is ignored.
                    let ack_sink = sink.clone();
                    tokio::spawn(async move {
                        let mut guard = ack_sink.lock().await;
                        let _ = guard.send(Message::Text(ack.into())).await;
                    });
                }
            }
            continue;
        }

        match serde_json::from_str::<ServerResponse>(raw) {
            Ok(response) => match pending.remove(&response.id) {
                Some((_, waiter)) => {
                    let _ = waiter.send(response);
                }
                None => {
                    eprintln!(
                        "dropping server_response for unknown id {:?}",
                        response.id
                    );
                }
            },
            Err(_) => {
                eprintln!("dropping unparseable WS frame (matched neither client_request nor server_response)");
            }
        }
    }
}