use super::listeners::McpListenerRegistry;
use axum::extract::ws::{Message, WebSocket};
use dashmap::DashMap;
use futures::stream::SplitSink;
use objectiveai_sdk::client_objectiveai_mcp::server_response;
use std::sync::Arc;
use tokio::sync::{Mutex, oneshot};
pub type SharedSink = Arc<Mutex<SplitSink<WebSocket, Message>>>;
/// Set of agent-completion ids observed so far.
///
/// Backed by a `DashSet`, so ids can be recorded and queried through `&self`.
pub struct SessionTracker {
    ids: dashmap::DashSet<String>,
}

impl SessionTracker {
    /// Creates an empty tracker, already wrapped in an `Arc` for sharing.
    pub fn new() -> Arc<Self> {
        let ids = dashmap::DashSet::new();
        Arc::new(Self { ids })
    }

    /// Records every agent-completion id reported by `chunk`.
    ///
    /// Each id is stored as an owned `String`; ids already present are
    /// simply re-inserted (set semantics).
    pub fn observe<C>(&self, chunk: &C)
    where
        C: objectiveai_sdk::agent::completions::response::streaming::AgentCompletionIds,
    {
        chunk
            .agent_completion_ids()
            .into_iter()
            .map(|completion_id| completion_id.to_string())
            .for_each(|owned| {
                self.ids.insert(owned);
            });
    }

    /// Returns `true` if `id` has been recorded by [`SessionTracker::observe`].
    pub fn contains(&self, id: &str) -> bool {
        self.ids.contains(id)
    }
}
/// Shared table of outstanding requests, each mapping a string key to the
/// `oneshot` sender that will deliver its `server_response::Response`
/// (presumably keyed by request id — confirm with the callers).
pub type PendingRequests = Arc<DashMap<String, oneshot::Sender<server_response::Response>>>;

/// Returns a new, empty [`PendingRequests`] table.
pub fn new_pending_requests() -> PendingRequests {
    Arc::default()
}
/// A websocket send half paired with its table of pending requests.
///
/// Cloning is cheap: both fields are `Arc` aliases, so clones share the same
/// sink and the same pending-request table.
#[derive(Clone)]
pub struct ReverseChannel {
    /// Shared write half of the websocket connection.
    pub sink: SharedSink,
    /// Senders for requests still awaiting a `server_response::Response`
    /// (presumably resolved when the matching reply is read — confirm).
    pub pending: PendingRequests,
}
pub type ReverseChannelRegistry = Arc<DashMap<String, ReverseChannel>>;
pub fn new_reverse_channel_registry() -> ReverseChannelRegistry {
Arc::new(DashMap::new())
}
/// Settings bundle used when attaching reverse channels.
#[derive(Clone)]
pub struct ReverseAttachConfig {
    /// Shared id -> channel registry that attached connections are added to.
    pub registry: ReverseChannelRegistry,
    /// A TCP port number (presumably the local MCP server's port — confirm at
    /// the use site).
    pub mcp_port: u16,
    /// Registry of MCP listeners (semantics defined in `super::listeners`).
    pub mcp_listeners: McpListenerRegistry,
}
/// Shared state behind a `ReverseAttachGuard`: the channel to publish and the
/// ids under which this handle has published it.
pub struct ReverseAttachHandle {
    /// Registry that `register` inserts into.
    registry: ReverseChannelRegistry,
    /// Channel value stored under every id this handle registers.
    channel: ReverseChannel,
    /// Ids passed to `register`, in call order (duplicates are kept). The list
    /// is drained when the owning `ReverseAttachGuard` is dropped.
    registered: std::sync::Mutex<Vec<String>>,
}
impl std::fmt::Debug for ReverseAttachHandle {
    /// Formats the handle, reporting how many ids it has registered.
    ///
    /// The id list is read with `try_lock`, so formatting never blocks (and
    /// cannot self-deadlock if the caller already holds the lock). When the
    /// lock is currently held, the count is shown as `<locked>` instead of a
    /// misleading sentinel number. A poisoned lock is still read, because the
    /// `Vec` itself remains intact.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ReverseAttachHandle");
        match self.registered.try_lock() {
            Ok(ids) => out.field("registered_count", &ids.len()),
            Err(std::sync::TryLockError::Poisoned(poisoned)) => {
                out.field("registered_count", &poisoned.into_inner().len())
            }
            Err(std::sync::TryLockError::WouldBlock) => {
                out.field("registered_count", &format_args!("<locked>"))
            }
        };
        out.finish_non_exhaustive()
    }
}
impl ReverseAttachHandle {
    /// Publishes this handle's channel under `id` and records `id` for later
    /// cleanup.
    ///
    /// Any existing registry entry for `id` is overwritten.
    ///
    /// # Panics
    /// Panics if the internal id-list mutex is poisoned.
    pub fn register(&self, id: String) {
        let channel = self.channel.clone();
        self.registry.insert(id.clone(), channel);
        let mut recorded = self.registered.lock().unwrap();
        recorded.push(id);
    }

    /// Returns a snapshot of every id passed to `register`, in call order.
    ///
    /// # Panics
    /// Panics if the internal id-list mutex is poisoned.
    pub fn registered_ids(&self) -> Vec<String> {
        let recorded = self.registered.lock().unwrap();
        recorded.to_vec()
    }

    /// Borrows the channel this handle registers.
    pub fn channel(&self) -> &ReverseChannel {
        &self.channel
    }
}
/// Owner of a [`ReverseAttachHandle`]. Its `Drop` impl cleans up the registry
/// entries recorded by the handle.
pub struct ReverseAttachGuard {
    handle: Arc<ReverseAttachHandle>,
}

impl ReverseAttachGuard {
    /// Creates a guard whose handle will publish `sink` and `pending` (as one
    /// [`ReverseChannel`]) into `registry`. Nothing is registered yet.
    pub fn new(
        registry: ReverseChannelRegistry,
        sink: SharedSink,
        pending: PendingRequests,
    ) -> Self {
        let channel = ReverseChannel { sink, pending };
        Self {
            handle: Arc::new(ReverseAttachHandle {
                registry,
                channel,
                registered: std::sync::Mutex::default(),
            }),
        }
    }

    /// Returns another shared reference to the guarded handle.
    pub fn handle(&self) -> Arc<ReverseAttachHandle> {
        Arc::clone(&self.handle)
    }
}
impl Drop for ReverseAttachGuard {
fn drop(&mut self) {
let ids = std::mem::take(&mut *self.handle.registered.lock().unwrap());
for id in ids {
self.handle.registry.remove(&id);
}
}
}