use futures::{stream::SelectAll, StreamExt};
use serde_json::Value;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio_stream::wrappers::ReceiverStream;
use crate::{
ChatRequest, ChatResponse, Notification, NotificationContent, Response, Success, Token,
ToolCall,
};
/// Shared notification plumbing for types that publish [`Notification`]s over
/// an optional Tokio mpsc channel.
///
/// Implementors supply the two accessors ([`get_outgoing_channel`] and
/// [`get_channel_name`]); every other method has a default implementation
/// built on top of them.
///
/// [`get_outgoing_channel`]: NotificationHandler::get_outgoing_channel
/// [`get_channel_name`]: NotificationHandler::get_channel_name
pub trait NotificationHandler {
    /// Returns the sender used for outgoing notifications.
    ///
    /// The default methods treat `None` as "no channel": [`notify`] returns
    /// `false` and the forwarding helpers do nothing.
    ///
    /// [`notify`]: NotificationHandler::notify
    fn get_outgoing_channel(&self) -> &Option<Sender<Notification>>;

    /// Returns the name that [`notify`] clones into every [`Notification`] it
    /// builds (first argument of `Notification::new`).
    ///
    /// [`notify`]: NotificationHandler::notify
    fn get_channel_name(&self) -> &String;

    /// Wraps `content` in a [`Notification`] and sends it on the outgoing
    /// channel.
    ///
    /// Returns `true` when the send succeeded. Returns `false` when there is
    /// no outgoing channel, or when the send failed; a failed send is also
    /// logged through `tracing::error!`.
    async fn notify(&self, content: NotificationContent) -> bool {
        let Some(notification_channel) = self.get_outgoing_channel() else {
            return false;
        };

        let notification = Notification::new(self.get_channel_name().clone(), content);
        match notification_channel.send(notification).await {
            Ok(()) => true,
            Err(e) => {
                tracing::error!(error = %e, "Failed sending notification");
                false
            }
        }
    }

    /// Spawns a detached Tokio task that relays everything received on
    /// `from_channel` to the outgoing channel.
    ///
    /// The task ends when `from_channel` yields `None` or when a send on the
    /// outgoing channel fails. If there is no outgoing channel, nothing is
    /// spawned and `from_channel` is dropped.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    fn forward_notifications(&self, mut from_channel: Receiver<Notification>) {
        if let Some(notification_channel) = self.get_outgoing_channel() {
            // The task must own its sender because it outlives `&self`.
            let to_sender = notification_channel.clone();
            tokio::spawn(async move {
                while let Some(msg) = from_channel.recv().await {
                    // NOTE(review): `msg` is already a `Notification`, so `unwrap` here
                    // is a method of that project type (not `Option::unwrap`); its
                    // definition is not visible in this file -- confirm what it does.
                    if to_sender.send(msg.unwrap()).await.is_err() {
                        break;
                    }
                }
            });
        }
    }

    /// Merges several receivers into one stream and spawns a detached Tokio
    /// task that relays every item to the outgoing channel.
    ///
    /// The task ends when the merged stream is exhausted or when a send on
    /// the outgoing channel fails. If there is no outgoing channel, nothing
    /// is spawned and `channels` is dropped unconsumed.
    ///
    /// # Panics
    ///
    /// `tokio::spawn` panics when called outside of a Tokio runtime.
    fn forward_multiple_notifications<I>(&self, channels: I)
    where
        I: IntoIterator<Item = Receiver<Notification>>,
    {
        let Some(notification_channel) = self.get_outgoing_channel() else {
            return;
        };
        // The task must own its sender because it outlives `&self`.
        let to_sender = notification_channel.clone();

        let mut merged = SelectAll::new();
        for rx in channels {
            merged.push(ReceiverStream::new(rx));
        }

        tokio::spawn(async move {
            while let Some(notification) = merged.next().await {
                // NOTE(review): unlike `forward_notifications`, items are relayed
                // without calling `Notification::unwrap` -- confirm this is intended.
                if to_sender.send(notification).await.is_err() {
                    break;
                }
            }
        });
    }

    /// Sends `NotificationContent::Done(success, resp)` via [`NotificationHandler::notify`].
    async fn notify_done(&self, success: Success, resp: Response) -> bool {
        self.notify(NotificationContent::Done(success, resp)).await
    }

    /// Sends `NotificationContent::PromptRequest(req)` via [`NotificationHandler::notify`].
    async fn notify_prompt_request(&self, req: ChatRequest) -> bool {
        self.notify(NotificationContent::PromptRequest(req)).await
    }

    /// Sends `NotificationContent::PromptSuccessResult(resp)` via [`NotificationHandler::notify`].
    async fn notify_prompt_success(&self, resp: ChatResponse) -> bool {
        self.notify(NotificationContent::PromptSuccessResult(resp))
            .await
    }

    /// Sends `NotificationContent::PromptErrorResult(error_message)` via [`NotificationHandler::notify`].
    async fn notify_prompt_error(&self, error_message: String) -> bool {
        self.notify(NotificationContent::PromptErrorResult(error_message))
            .await
    }

    /// Sends `NotificationContent::ToolCallRequest(tool_call)` via [`NotificationHandler::notify`].
    async fn notify_tool_request(&self, tool_call: ToolCall) -> bool {
        self.notify(NotificationContent::ToolCallRequest(tool_call))
            .await
    }

    /// Sends `NotificationContent::ToolCallSuccessResult(tool_result)` via [`NotificationHandler::notify`].
    async fn notify_tool_success(&self, tool_result: String) -> bool {
        self.notify(NotificationContent::ToolCallSuccessResult(tool_result))
            .await
    }

    /// Sends `NotificationContent::ToolCallErrorResult(error_message)` via [`NotificationHandler::notify`].
    async fn notify_tool_error(&self, error_message: String) -> bool {
        self.notify(NotificationContent::ToolCallErrorResult(error_message))
            .await
    }

    /// Sends `NotificationContent::Token(token)` via [`NotificationHandler::notify`].
    async fn notify_token(&self, token: Token) -> bool {
        self.notify(NotificationContent::Token(token)).await
    }

    /// Sends `NotificationContent::McpToolNotification(notification)` via [`NotificationHandler::notify`].
    async fn notify_mcp_tool_notification(&self, notification: String) -> bool {
        self.notify(NotificationContent::McpToolNotification(notification))
            .await
    }

    /// Sends `NotificationContent::Custom(custom_val)` via [`NotificationHandler::notify`].
    async fn notify_custom(&self, custom_val: Value) -> bool {
        self.notify(NotificationContent::Custom(custom_val)).await
    }
}