use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
use futures::stream::SplitSink;
use objectiveai_sdk::client_objectiveai_mcp::server_response;
use std::sync::Arc;
use tokio::sync::{Mutex, oneshot};
/// Write half of an axum WebSocket, shared behind an `Arc` and a `tokio::sync::Mutex`
/// so several tasks can send `Message`s on the same socket (the async mutex may be held
/// across `.await` points while a send is in progress).
pub type SharedSink = Arc<Mutex<SplitSink<WebSocket, Message>>>;
/// Concurrent table of outstanding reverse requests: each `String` key maps to the
/// one-shot sender that will deliver the matching `server_response::Response`.
/// NOTE(review): keys are presumably request ids correlating replies — confirm at call sites.
pub type PendingRequests = Arc<DashMap<String, oneshot::Sender<server_response::Response>>>;
pub fn new_pending_requests() -> PendingRequests {
Arc::new(DashMap::new())
}
/// Everything needed to talk back to a connected peer over its WebSocket.
///
/// Cloning is cheap: both fields are `Arc`s, so clones share the same socket sink
/// and the same pending-request table.
#[derive(Clone)]
pub struct ReverseChannel {
// Shared, mutex-guarded write half of the WebSocket used to send outgoing messages.
pub sink: SharedSink,
// Outstanding requests awaiting a `server_response::Response`, keyed by `String`.
pub pending: PendingRequests,
}
/// Configuration for attaching a reverse channel.
#[derive(Clone)]
pub struct ReverseAttachConfig {
// NOTE(review): presumably the maximum time to wait for a reply over the reverse
// channel — confirm against the code that consumes this value.
pub reverse_channel_timeout: std::time::Duration,
}
/// Read-only view of an attached reverse channel together with its timeout setting.
///
/// Instances are created by `ReverseAttachGuard::new` and shared via `Arc`.
pub struct ReverseAttachHandle {
    channel: ReverseChannel,
    reverse_channel_timeout: std::time::Duration,
}

impl std::fmt::Debug for ReverseAttachHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The channel wraps socket and oneshot types that are not assumed to implement
        // `Debug`, so it is elided (signalled by `finish_non_exhaustive`). The timeout is
        // a plain `Duration` and is worth showing when diagnosing stalled requests.
        f.debug_struct("ReverseAttachHandle")
            .field("reverse_channel_timeout", &self.reverse_channel_timeout)
            .finish_non_exhaustive()
    }
}

impl ReverseAttachHandle {
    /// Borrows the underlying [`ReverseChannel`] (socket sink plus pending-request table).
    pub fn channel(&self) -> &ReverseChannel {
        &self.channel
    }

    /// Returns the timeout this handle was configured with.
    pub fn reverse_channel_timeout(&self) -> std::time::Duration {
        self.reverse_channel_timeout
    }
}
/// Owner of a shared [`ReverseAttachHandle`].
///
/// NOTE(review): no `Drop` impl is visible in this chunk; if the "guard" name implies
/// cleanup when it is dropped, confirm where that cleanup is implemented.
pub struct ReverseAttachGuard {
    handle: Arc<ReverseAttachHandle>,
}

impl ReverseAttachGuard {
    /// Packs the socket sink, the pending-request table and the timeout into a new,
    /// reference-counted [`ReverseAttachHandle`] owned by the returned guard.
    pub fn new(
        sink: SharedSink,
        pending: PendingRequests,
        reverse_channel_timeout: std::time::Duration,
    ) -> Self {
        let channel = ReverseChannel { sink, pending };
        let inner = ReverseAttachHandle {
            channel,
            reverse_channel_timeout,
        };
        Self {
            handle: Arc::new(inner),
        }
    }

    /// Hands out another `Arc` pointing at the same handle (a refcount bump, no copy).
    pub fn handle(&self) -> Arc<ReverseAttachHandle> {
        Arc::clone(&self.handle)
    }
}