use std::sync::Arc;
use axum::extract::FromRequestParts;
use axum::extract::ws::{CloseCode, CloseFrame, Message, WebSocket, close_code};
use axum::http::request::Parts;
use axum::response::Response;
use futures::{SinkExt, StreamExt};
use futures::stream::{SplitSink, SplitStream};
use objectiveai_sdk::error::ResponseError;
use serde::Serialize;
use tokio::sync::{Mutex, oneshot};
use crate::error::ResponseErrorExt;
pub type SharedSink = Arc<Mutex<SplitSink<WebSocket, Message>>>;
/// Concurrent set of agent-completion ids that have been observed via
/// [`SessionTracker::observe`]. `recv_loop` consults it to reject notify
/// requests whose `response_id` was never seen.
pub struct SessionTracker {
    /// Every id recorded by `observe`, stored as owned strings.
    ids: dashmap::DashSet<String>,
}

impl SessionTracker {
    /// Creates an empty tracker, already wrapped in an `Arc` for sharing.
    pub fn new() -> Arc<Self> {
        let ids = dashmap::DashSet::new();
        Arc::new(Self { ids })
    }

    /// Records every id yielded by `chunk.agent_completion_ids()`.
    pub fn observe<C>(&self, chunk: &C)
    where
        C: objectiveai_sdk::agent::completions::response::streaming::AgentCompletionIds,
    {
        chunk
            .agent_completion_ids()
            .into_iter()
            .map(|id| id.to_string())
            .for_each(|owned| {
                self.ids.insert(owned);
            });
    }

    /// Returns `true` if `id` has been recorded by a previous `observe` call.
    pub fn contains(&self, id: &str) -> bool {
        self.ids.contains(id)
    }
}
/// Transport used for a streaming response, chosen by the request's
/// `Upgrade` header (see the `FromRequestParts` impl for `Transport`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Transport {
    /// Presumably server-sent events; chosen whenever the request is not a
    /// WebSocket upgrade.
    Sse,
    /// The request carried `Upgrade: websocket` (case-insensitive).
    WebSocket,
}
/// Infallible extractor: yields [`Transport::WebSocket`] when the request's
/// `Upgrade` header equals `websocket` (ASCII case-insensitive), otherwise
/// [`Transport::Sse`]. A missing or non-UTF-8-visible header counts as SSE.
impl<S> FromRequestParts<S> for Transport
where
    S: Send + Sync,
{
    type Rejection = std::convert::Infallible;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let upgrade = parts
            .headers
            .get(axum::http::header::UPGRADE)
            .and_then(|value| value.to_str().ok());
        let wants_websocket = matches!(
            upgrade,
            Some(value) if value.eq_ignore_ascii_case("websocket")
        );
        let transport = if wants_websocket {
            Transport::WebSocket
        } else {
            Transport::Sse
        };
        Ok(transport)
    }
}
pub fn bad_request(message: &str) -> Response {
ResponseError {
code: 400,
message: serde_json::Value::String(message.to_string()),
}
.into_response()
}
use serde::de::DeserializeOwned;
/// Waits for the first text frame on `socket` and deserializes it as `T`.
///
/// Ping and pong frames are skipped. Every other outcome is an error with
/// code `400`:
/// - a binary frame,
/// - a close frame or end of stream,
/// - a transport error,
/// - JSON that does not deserialize into `T`.
pub async fn recv_body_frame<T: DeserializeOwned>(
    socket: &mut WebSocket,
) -> Result<T, ResponseError> {
    // All failures share the same shape: code 400 plus a string message.
    fn reject(message: String) -> ResponseError {
        ResponseError {
            code: 400,
            message: serde_json::Value::String(message),
        }
    }

    let body = loop {
        match socket.recv().await {
            // Keep-alive traffic may arrive before the body; ignore it.
            Some(Ok(Message::Ping(_) | Message::Pong(_))) => {}
            Some(Ok(Message::Text(text))) => break text,
            Some(Ok(Message::Binary(_))) => {
                return Err(reject("expected text body frame, got binary".to_string()));
            }
            Some(Ok(Message::Close(_))) | None => {
                return Err(reject("peer closed before sending body".to_string()));
            }
            Some(Err(error)) => {
                return Err(reject(format!("websocket recv error: {error}")));
            }
        }
    };

    serde_json::from_str::<T>(body.as_str())
        .map_err(|error| reject(format!("failed to deserialize body frame: {error}")))
}
/// Sends `err` as a JSON text frame, then a close frame with `code` and an
/// empty reason.
///
/// Both sends are best effort and their failures are ignored. If `err` cannot
/// be serialized, the literal `{}` is sent instead.
pub async fn send_error_and_close(socket: &mut WebSocket, err: &ResponseError, code: CloseCode) {
    let payload = match serde_json::to_string(err) {
        Ok(json) => json,
        Err(_) => "{}".to_owned(),
    };
    let closing = Message::Close(Some(CloseFrame {
        code,
        reason: "".into(),
    }));
    let _ = socket.send(Message::Text(payload.into())).await;
    let _ = socket.send(closing).await;
}
pub async fn fatal_setup_error(socket: &mut WebSocket, err: &ResponseError) {
send_error_and_close(socket, err, close_code::ERROR).await;
}
pub async fn fatal_setup_error_split(sink: &SharedSink, err: &ResponseError) {
let frame = serde_json::to_string(err).unwrap_or_else(|_| String::from("{}"));
{
let mut guard = sink.lock().await;
let _ = guard.send(Message::Text(frame.into())).await;
}
send_close_split(sink, close_code::ERROR).await;
}
/// Serializes `chunk` to JSON and sends it as a text frame on the shared sink.
///
/// Returns `Err(())` when the sink rejects the frame. A chunk that fails to
/// serialize is logged to stderr and skipped, and the function returns
/// `Ok(())` for it. This keeps the original contract that a serialization
/// failure is not reported as a send failure.
pub async fn send_chunk_split<C: Serialize>(sink: &SharedSink, chunk: &C) -> Result<(), ()> {
    let json = match serde_json::to_string(chunk) {
        Ok(json) => json,
        Err(e) => {
            // Previously swallowed silently; surface it so dropped chunks are visible.
            eprintln!("skipping streaming chunk that failed to serialize: {e}");
            return Ok(());
        }
    };
    let mut guard = sink.lock().await;
    guard.send(Message::Text(json.into())).await.map_err(|_| ())
}
/// Sends a close frame carrying `code` and an empty reason on the shared sink.
/// A send failure is ignored.
pub async fn send_close_split(sink: &SharedSink, code: CloseCode) {
    let frame = Message::Close(Some(CloseFrame {
        code,
        reason: "".into(),
    }));
    let mut writer = sink.lock().await;
    let _ = writer.send(frame).await;
}
/// Requests sent via `send_server_request` that still await a reply, keyed by
/// request id. Each value is the sender that `recv_loop` completes when a
/// `server_response::Response` with that id arrives.
pub type PendingRequests = Arc<
    dashmap::DashMap<
        String,
        oneshot::Sender<objectiveai_sdk::client_objectiveai_mcp::server_response::Response>,
    >,
>;

/// Creates an empty [`PendingRequests`] map.
pub fn new_pending_requests() -> PendingRequests {
    PendingRequests::default()
}
/// A shared sink paired with its map of outstanding requests; this is the
/// value stored in a [`ReverseChannelRegistry`].
#[derive(Clone)]
pub struct ReverseChannel {
    /// Where outgoing frames for this channel are written.
    pub sink: SharedSink,
    /// Replies awaited on this channel, keyed by request id.
    pub pending: PendingRequests,
}

/// Shared map from an id (inserted via `ReverseAttachHandle::register`) to
/// the [`ReverseChannel`] that should serve it.
pub type ReverseChannelRegistry = Arc<dashmap::DashMap<String, ReverseChannel>>;

/// Creates an empty [`ReverseChannelRegistry`].
pub fn new_reverse_channel_registry() -> ReverseChannelRegistry {
    ReverseChannelRegistry::default()
}
/// Cloneable bundle of the shared [`ReverseChannelRegistry`] and an API port.
/// Neither field is read within this section of the file.
#[derive(Clone)]
pub struct ReverseAttachConfig {
    /// Registry that reverse channels are published into.
    pub registry: ReverseChannelRegistry,
    /// API port number. NOTE(review): how it is used lives outside this
    /// section; confirm with callers.
    pub api_port: u16,
}
/// Handle that publishes ids into a [`ReverseChannelRegistry`], all mapped to
/// the same [`ReverseChannel`].
///
/// Every id registered is remembered, so that the owning [`ReverseAttachGuard`]
/// can remove it from the registry when the guard is dropped.
pub struct ReverseAttachHandle {
    registry: ReverseChannelRegistry,
    channel: ReverseChannel,
    /// Ids this handle has inserted into `registry`.
    registered: std::sync::Mutex<Vec<String>>,
}

impl std::fmt::Debug for ReverseAttachHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `try_lock` keeps Debug non-blocking. `usize::MAX` means the list could
        // not be read (lock busy or poisoned).
        let count = self
            .registered
            .try_lock()
            .map(|g| g.len())
            .unwrap_or(usize::MAX);
        f.debug_struct("ReverseAttachHandle")
            .field("registered_count", &count)
            .finish_non_exhaustive()
    }
}

impl ReverseAttachHandle {
    /// Maps `id` to this handle's channel in the registry, replacing any
    /// existing entry. It also records `id` so the guard's `Drop` removes it.
    ///
    /// The `registered` lock is held across the registry insert. Without this,
    /// a concurrent guard `Drop` could drain the list between the insert and
    /// the record, leaving `id` in the registry with nobody to remove it.
    /// A poisoned lock is recovered rather than panicking; the `Vec` stays
    /// valid either way.
    pub fn register(&self, id: String) {
        let mut registered = self
            .registered
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        self.registry.insert(id.clone(), self.channel.clone());
        registered.push(id);
    }
}
/// Owner of a [`ReverseAttachHandle`]. When dropped, it removes every id the
/// handle registered from the registry.
pub struct ReverseAttachGuard {
    handle: Arc<ReverseAttachHandle>,
}

impl ReverseAttachGuard {
    /// Creates a guard whose handle registers ids against a channel built from
    /// `sink` and `pending`, inside `registry`.
    pub fn new(
        registry: ReverseChannelRegistry,
        sink: SharedSink,
        pending: PendingRequests,
    ) -> Self {
        let handle = Arc::new(ReverseAttachHandle {
            registry,
            channel: ReverseChannel { sink, pending },
            registered: std::sync::Mutex::new(Vec::new()),
        });
        Self { handle }
    }

    /// Returns a shared pointer to the guard's handle, used to register ids.
    pub fn handle(&self) -> Arc<ReverseAttachHandle> {
        self.handle.clone()
    }
}

impl Drop for ReverseAttachGuard {
    fn drop(&mut self) {
        // Recover from a poisoned lock instead of panicking. A panic inside
        // `drop` while already unwinding would abort the process.
        let ids = std::mem::take(
            &mut *self
                .handle
                .registered
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        );
        let own_sink = &self.handle.channel.sink;
        for id in ids {
            // Only remove entries that still point at this guard's channel. If
            // another handle has since re-registered the same id, leave it.
            self.handle
                .registry
                .remove_if(&id, |_, channel| Arc::ptr_eq(&channel.sink, own_sink));
        }
    }
}
/// Sends `request` as a JSON text frame on `sink`. It returns a receiver that
/// resolves once `recv_loop` routes back a `server_response::Response` with
/// the same id through `pending`.
///
/// The pending entry is inserted before the frame is sent, so a fast reply
/// always finds its waiter. On a serialization or send failure the entry is
/// removed again and `Err(())` is returned.
pub async fn send_server_request(
    sink: &SharedSink,
    pending: &PendingRequests,
    request: objectiveai_sdk::client_objectiveai_mcp::server_request::Request,
) -> Result<
    oneshot::Receiver<objectiveai_sdk::client_objectiveai_mcp::server_response::Response>,
    (),
> {
    let request_id = request.id.clone();
    let (reply_tx, reply_rx) = oneshot::channel();
    pending.insert(request_id.clone(), reply_tx);

    let Ok(payload) = serde_json::to_string(&request) else {
        pending.remove(&request_id);
        return Err(());
    };

    let sent = {
        let mut writer = sink.lock().await;
        writer.send(Message::Text(payload.into())).await
    };
    if sent.is_err() {
        pending.remove(&request_id);
        return Err(());
    }
    Ok(reply_rx)
}
/// Drives the receive half of a streaming WebSocket. It stops when the stream
/// ends, a close frame arrives, or a transport error occurs.
///
/// Each text frame is tried, in order, as:
/// 1. a `client_request::Request`. An `AgentCompletionNotify` is answered on a
///    spawned task:
///    - if `tracker` has not observed its `response_id`, the reply is a `404`
///      error;
///    - otherwise `notify_fn` runs, and its outcome is sent back on `sink` as a
///      `client_response::Response`.
/// 2. a `server_response::Response`. It is delivered to the waiter registered
///    in `pending` under its id (see `send_server_request`).
///
/// Binary frames, and text frames matching neither shape, are logged to stderr
/// and dropped. Ping and pong frames are skipped.
///
/// On exit every still-pending request is dropped. Callers awaiting a reply
/// then get a `RecvError` instead of waiting forever, since no reply can arrive
/// once this loop stops reading.
pub async fn recv_loop<F, Fut>(
    mut rx: SplitStream<WebSocket>,
    tracker: Arc<SessionTracker>,
    sink: SharedSink,
    pending: PendingRequests,
    notify_fn: F,
) where
    F: Fn(objectiveai_sdk::agent::completions::request::AgentCompletionNotifyParams) -> Fut
        + Send
        + Sync
        + 'static,
    Fut: std::future::Future<Output = Result<(), crate::agent::completions::Error>>
        + Send
        + 'static,
{
    use objectiveai_sdk::client_objectiveai_mcp::{
        client_request::{Payload as ClientPayload, Request as ClientRequest},
        client_response::Response as ClientResponse,
        server_response::Response as ServerResponse,
    };

    // Shared with every spawned notify task.
    let notify_fn = Arc::new(notify_fn);

    while let Some(msg) = rx.next().await {
        let text = match msg {
            Ok(Message::Text(t)) => t,
            Ok(Message::Binary(_)) => {
                eprintln!("ignoring binary frame on streaming WS recv side");
                continue;
            }
            Ok(Message::Ping(_) | Message::Pong(_)) => continue,
            Ok(Message::Close(_)) => break,
            Err(e) => {
                eprintln!("streaming WS recv error: {e}");
                break;
            }
        };

        // 1. A request initiated by the client.
        if let Ok(request) = serde_json::from_str::<ClientRequest>(text.as_str()) {
            let ClientRequest { id, payload } = request;
            match payload {
                ClientPayload::AgentCompletionNotify(params) => {
                    let tracker = tracker.clone();
                    let sink = sink.clone();
                    let notify_fn = notify_fn.clone();
                    // Answer on a separate task so awaiting `notify_fn` does not
                    // block reading further frames.
                    tokio::spawn(async move {
                        let response: ClientResponse =
                            if !tracker.contains(&params.response_id) {
                                ClientResponse::Error {
                                    id,
                                    code: 404,
                                    message: serde_json::Value::String(format!(
                                        "response_id {:?} not from this stream",
                                        params.response_id
                                    )),
                                }
                            } else {
                                match (notify_fn)(params).await {
                                    Ok(()) => ClientResponse::Ok { id },
                                    Err(e) => {
                                        let inner = ResponseError::from(&e);
                                        ClientResponse::Error {
                                            id,
                                            code: inner.code,
                                            message: inner.message,
                                        }
                                    }
                                }
                            };
                        let frame = match serde_json::to_string(&response) {
                            Ok(s) => s,
                            Err(e) => {
                                eprintln!(
                                    "dropping notify response that failed to serialize: {e}"
                                );
                                return;
                            }
                        };
                        let mut guard = sink.lock().await;
                        // Best effort: a send failure is ignored.
                        let _ = guard.send(Message::Text(frame.into())).await;
                    });
                    continue;
                }
            }
        }

        // 2. A reply to a request sent via `send_server_request`.
        if let Ok(response) = serde_json::from_str::<ServerResponse>(text.as_str()) {
            match pending.remove(&response.id) {
                Some((_, tx)) => {
                    // The waiter may have given up; that is not an error here.
                    let _ = tx.send(response);
                }
                None => {
                    eprintln!(
                        "dropping server_response for unknown id {:?}",
                        response.id
                    );
                }
            }
            continue;
        }

        eprintln!("dropping unparseable WS frame (matched neither client_request nor server_response)");
    }

    // Nothing will read further frames, so no pending reply can ever be
    // delivered. Dropping the senders wakes their receivers with `RecvError`.
    pending.clear();
}