use crate::client_objectiveai_mcp::{client_request, client_response};
use futures::SinkExt;
use futures::stream::SplitSink;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::{Mutex, oneshot};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite};
pub(crate) type SharedSink = Arc<
Mutex<
SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tungstenite::Message,
>,
>,
>;
pub(crate) type PendingNotifies =
Arc<dashmap::DashMap<String, oneshot::Sender<client_response::Response>>>;
/// Sends notification requests over the client's WebSocket connection and
/// waits for the peer to accept or reject each one.
///
/// Cloning is cheap: both fields are reference-counted handles, so every
/// clone shares the same connection and the same pending-request table.
#[derive(Clone)]
pub struct Notifier {
    /// Response waiters for requests that are still awaiting a reply.
    pending: PendingNotifies,
    /// Shared, mutex-guarded write half of the WebSocket connection.
    sink: SharedSink,
}
impl Notifier {
    /// Creates a notifier that writes request frames to `sink` and registers
    /// a response waiter for each request in `pending`.
    ///
    /// NOTE(review): the component that reads the socket is presumably the one
    /// that completes (and removes) the `pending` entry whose key matches an
    /// incoming response id. Confirm against the reader task.
    pub(crate) fn new(sink: SharedSink, pending: PendingNotifies) -> Self {
        Self { sink, pending }
    }

    /// Sends an `AgentCompletionNotify` request and waits for the peer's reply.
    ///
    /// # Errors
    ///
    /// - `HttpError::NotifySerialize` if the request cannot be encoded as JSON.
    /// - `HttpError::NotifySend` if writing the frame to the socket fails.
    /// - `HttpError::NotifyChannelClosed` if the response sender is dropped
    ///   before a reply is delivered.
    /// - `HttpError::NotifyRejected` if the peer answers with an error response.
    pub async fn notify(
        &self,
        params: crate::agent::completions::request::AgentCompletionNotifyParams,
    ) -> Result<(), super::HttpError> {
        self.send(client_request::Payload::AgentCompletionNotify(params))
            .await
    }

    /// Sends an `McpListChanged` request and waits for the peer's reply.
    ///
    /// # Errors
    ///
    /// Same as [`Notifier::notify`].
    pub async fn notify_list_changed(
        &self,
        change: client_request::McpListChanged,
    ) -> Result<(), super::HttpError> {
        self.send(client_request::Payload::McpListChanged(change))
            .await
    }

    /// Sends `payload` under a fresh UUID request id and waits for the
    /// response carrying that id.
    ///
    /// The `pending` entry for the id is removed on every exit path. That
    /// includes the case where the caller drops this future while it is
    /// waiting for the sink lock, the socket write, or the reply. Without
    /// this, a cancelled call would leave a stale entry in the table.
    async fn send(
        &self,
        payload: client_request::Payload,
    ) -> Result<(), super::HttpError> {
        /// Drop guard that removes one `pending` entry when it goes out of scope.
        struct PendingEntry<'a> {
            pending: &'a PendingNotifies,
            id: &'a str,
        }

        impl Drop for PendingEntry<'_> {
            fn drop(&mut self) {
                // Harmless no-op if the response path already took the entry.
                self.pending.remove(self.id);
            }
        }

        let id = uuid::Uuid::new_v4().to_string();
        let request = client_request::Request {
            id: id.clone(),
            payload,
        };
        // Serialize before registering a waiter, so a failure here leaves
        // nothing behind in `pending`.
        let frame = serde_json::to_string(&request)
            .map_err(super::HttpError::NotifySerialize)?;

        // Register the waiter before the frame goes out. Otherwise a fast
        // reply could arrive for an id that is not yet in `pending`.
        let (tx, rx) = oneshot::channel();
        self.pending.insert(id.clone(), tx);
        // Must be a named binding (not `_`) so it lives until the end of the fn.
        let _entry = PendingEntry {
            pending: &self.pending,
            id: &id,
        };

        {
            // Scoped so the sink lock is released before waiting for the reply.
            let mut sink = self.sink.lock().await;
            sink.send(tungstenite::Message::Text(frame.into()))
                .await
                .map_err(super::HttpError::NotifySend)?;
        }

        let response = rx
            .await
            .map_err(|_| super::HttpError::NotifyChannelClosed)?;
        match response {
            client_response::Response::Ok { .. } => Ok(()),
            client_response::Response::Error { code, message, .. } => {
                Err(super::HttpError::NotifyRejected { code, message })
            }
        }
    }
}