use std::sync::Arc;
use axum::extract::FromRequestParts;
use axum::extract::ws::{CloseCode, CloseFrame, Message, WebSocket, close_code};
use axum::http::request::Parts;
use axum::response::Response;
use futures::{SinkExt, StreamExt};
use futures::stream::SplitStream;
use objectiveai_sdk::error::ResponseError;
use serde::Serialize;
pub use crate::objectiveai_mcp::{
PendingRequests, ReverseAttachConfig, ReverseAttachGuard, ReverseAttachHandle, ReverseChannel,
SharedSink, new_pending_requests,
};
/// Transport a client asked for, inferred from the incoming request's headers.
///
/// Produced by this type's [`FromRequestParts`] implementation: a request carrying
/// `Upgrade: websocket` (ASCII case-insensitive) maps to [`Transport::WebSocket`],
/// anything else to [`Transport::Sse`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Transport {
    /// Non-WebSocket transport (presumably server-sent events); the fallback for every
    /// request that is not a WebSocket upgrade.
    Sse,
    /// The request asked to be upgraded to a WebSocket.
    WebSocket,
}
/// Classifies a request as WebSocket or SSE without consuming anything from it.
///
/// Only the `Upgrade` header is inspected. A value that is not valid visible ASCII, or
/// that is anything other than `websocket` (ASCII case-insensitive), yields
/// [`Transport::Sse`]. Extraction can never fail.
impl<S> FromRequestParts<S> for Transport
where
    S: Send + Sync,
{
    type Rejection = std::convert::Infallible;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let upgrade = parts.headers.get(axum::http::header::UPGRADE);
        let wants_websocket = upgrade
            .and_then(|value| value.to_str().ok())
            .is_some_and(|value| value.eq_ignore_ascii_case("websocket"));

        let transport = if wants_websocket {
            Transport::WebSocket
        } else {
            Transport::Sse
        };
        Ok(transport)
    }
}
use serde::de::DeserializeOwned;
/// Waits for the first text frame on `socket` and deserializes its JSON into `T`.
///
/// Ping and Pong frames received before the body are skipped.
///
/// # Errors
///
/// Returns a [`ResponseError`] with code `400` when:
/// - the next non-keep-alive frame is binary;
/// - the peer sends Close or the stream ends;
/// - the socket reports a receive error;
/// - the text does not deserialize as `T`.
pub async fn recv_body_frame<T: DeserializeOwned>(
    socket: &mut WebSocket,
) -> Result<T, ResponseError> {
    // Every failure in this function is surfaced to the caller as a 400.
    let bad_request = |reason: String| ResponseError {
        code: 400,
        message: serde_json::Value::String(reason),
    };

    let body = loop {
        match socket.recv().await {
            Some(Ok(Message::Text(body))) => break body,
            // Keep-alive traffic may precede the body; ignore it and keep waiting.
            Some(Ok(Message::Ping(_) | Message::Pong(_))) => {}
            Some(Ok(Message::Binary(_))) => {
                return Err(bad_request("expected text body frame, got binary".into()));
            }
            Some(Ok(Message::Close(_))) | None => {
                return Err(bad_request("peer closed before sending body".into()));
            }
            Some(Err(e)) => {
                return Err(bad_request(format!("websocket recv error: {e}")));
            }
        }
    };

    serde_json::from_str::<T>(body.as_str())
        .map_err(|e| bad_request(format!("failed to deserialize body frame: {e}")))
}
/// Sends `err` to the peer as a JSON text frame, then a Close frame carrying `code`.
///
/// Both sends are best-effort: failures are ignored. If `err` cannot be serialized,
/// the literal `{}` is sent in its place.
pub async fn send_error_and_close(socket: &mut WebSocket, err: &ResponseError, code: CloseCode) {
    let payload = match serde_json::to_string(err) {
        Ok(json) => json,
        Err(_) => String::from("{}"),
    };
    let _ = socket.send(Message::Text(payload.into())).await;

    let close = CloseFrame {
        code,
        reason: "".into(),
    };
    let _ = socket.send(Message::Close(Some(close))).await;
}
pub async fn fatal_setup_error_split(sink: &SharedSink, err: &ResponseError) {
let frame = serde_json::to_string(err).unwrap_or_else(|_| String::from("{}"));
{
let mut guard = sink.lock().await;
let _ = guard.send(Message::Text(frame.into())).await;
}
send_close_split(sink, close_code::ERROR).await;
}
/// Serializes `chunk` to JSON and writes it to the shared sink as a text frame.
///
/// Returns `Err(())` only when the sink rejects the write. A chunk that fails to
/// serialize is logged to stderr and skipped with `Ok(())`: the connection itself is
/// still usable, so callers should not treat it as dead.
pub async fn send_chunk_split<C: Serialize>(sink: &SharedSink, chunk: &C) -> Result<(), ()> {
    let json = match serde_json::to_string(chunk) {
        Ok(json) => json,
        Err(e) => {
            eprintln!("dropping streaming chunk that failed to serialize: {e}");
            return Ok(());
        }
    };
    let mut guard = sink.lock().await;
    guard.send(Message::Text(json.into())).await.map_err(|_| ())
}
/// Forwards server-initiated requests from `req_rx` to the peer as JSON text frames.
///
/// Runs until `req_rx` is closed and drained, or until a write to `sink` fails, at
/// which point it returns immediately. A request that cannot be serialized is logged
/// to stderr and skipped, so that a missing response is diagnosable instead of
/// vanishing silently.
pub async fn drain_reverse_channel(
    sink: SharedSink,
    mut req_rx: tokio::sync::mpsc::UnboundedReceiver<
        objectiveai_sdk::client_objectiveai_mcp::server_request::Request,
    >,
) {
    while let Some(req) = req_rx.recv().await {
        let frame = match serde_json::to_string(&req) {
            Ok(s) => s,
            Err(e) => {
                eprintln!("dropping server_request that failed to serialize: {e}");
                continue;
            }
        };
        // The guard is scoped to this iteration so other writers can interleave
        // between requests.
        let mut guard = sink.lock().await;
        if guard.send(Message::Text(frame.into())).await.is_err() {
            return;
        }
    }
}
/// Writes a Close frame with `code` and an empty reason to the shared sink.
///
/// Best-effort: a failed send is ignored.
pub async fn send_close_split(sink: &SharedSink, code: CloseCode) {
    let close_frame = Message::Close(Some(CloseFrame {
        code,
        reason: "".into(),
    }));
    let mut writer = sink.lock().await;
    let _ = writer.send(close_frame).await;
}
/// Reads frames from the receive half of a split WebSocket and routes them.
///
/// Runs until the stream ends, the peer sends Close, or a receive error occurs; the
/// error case is logged to stderr. Each text frame is routed as follows:
/// - a `ClientRequest` is answered synchronously via
///   `channel.deliver_client_request`, and the JSON reply is written to `sink` from a
///   freshly spawned task;
/// - a `ServerResponse` whose payload has an MCP kind goes to
///   `channel.deliver_response`;
/// - any other `ServerResponse` completes the `pending` entry with the same id, or is
///   logged if no entry exists.
///
/// Binary frames and text that parses as neither type are logged and skipped.
/// Ping/Pong frames are ignored.
///
/// NOTE(review): each reply goes through its own spawned task, so this function does
/// not enforce the relative order of replies on the wire.
pub async fn recv_loop(
    mut rx: SplitStream<WebSocket>,
    sink: SharedSink,
    pending: PendingRequests,
    channel: objectiveai_mcp_proxy::ReverseChannel,
) {
    use objectiveai_sdk::client_objectiveai_mcp::{
        client_request::Request as ClientRequest,
        server_response::Response as ServerResponse,
    };

    while let Some(incoming) = rx.next().await {
        let text = match incoming {
            Ok(Message::Text(text)) => text,
            Ok(Message::Ping(_) | Message::Pong(_)) => continue,
            Ok(Message::Binary(_)) => {
                eprintln!("ignoring binary frame on streaming WS recv side");
                continue;
            }
            Ok(Message::Close(_)) => return,
            Err(err) => {
                eprintln!("streaming WS recv error: {err}");
                return;
            }
        };
        let raw = text.as_str();

        // Client-originated request: answer it now and write the reply off-loop so this
        // reader never waits on the sink lock.
        if let Ok(request) = serde_json::from_str::<ClientRequest>(raw) {
            let reply = channel.deliver_client_request(request);
            if let Ok(frame) = serde_json::to_string(&reply) {
                let reply_sink = sink.clone();
                tokio::spawn(async move {
                    let mut writer = reply_sink.lock().await;
                    let _ = writer.send(Message::Text(frame.into())).await;
                });
            }
            continue;
        }

        let Ok(response) = serde_json::from_str::<ServerResponse>(raw) else {
            eprintln!("dropping unparseable WS frame (matched neither client_request nor server_response)");
            continue;
        };

        if response.payload.mcp_kind().is_some() {
            channel.deliver_response(response);
        } else if let Some((_, tx)) = pending.remove(&response.id) {
            let _ = tx.send(response);
        } else {
            eprintln!(
                "dropping server_response for unknown id {:?}",
                response.id
            );
        }
    }
}