1use crate::{
12 protocol::{JsonRpcNotification, ProgressToken},
13 types::{
14 messages::{LogMessageNotification, ProgressNotification, ResourceUpdatedNotification},
15 LoggingLevel,
16 },
17};
18use serde::Serialize;
19use std::sync::Arc;
20use tokio::sync::mpsc;
21
22#[derive(Clone)]
40pub struct NotificationSender {
41 tx: mpsc::Sender<JsonRpcNotification>,
42}
43
44impl NotificationSender {
45 pub fn new(tx: mpsc::Sender<JsonRpcNotification>) -> Self {
47 Self { tx }
48 }
49
50 pub fn channel(buffer: usize) -> (Self, NotificationReceiver) {
52 let (tx, rx) = mpsc::channel(buffer);
53 (Self { tx }, NotificationReceiver { rx })
54 }
55
56 pub async fn send(&self, notification: JsonRpcNotification) -> Result<(), SendError> {
58 self.tx
59 .send(notification)
60 .await
61 .map_err(|_| SendError::ChannelClosed)
62 }
63
64 pub async fn notify<T: Serialize>(&self, method: &str, params: T) -> Result<(), SendError> {
66 let params = serde_json::to_value(params).map_err(SendError::Serialize)?;
67 self.send(JsonRpcNotification::new(method, Some(params)))
68 .await
69 }
70
71 pub async fn resource_updated(&self, uri: impl Into<String>) -> Result<(), SendError> {
77 self.notify(
78 "notifications/resources/updated",
79 ResourceUpdatedNotification { uri: uri.into() },
80 )
81 .await
82 }
83
84 pub async fn resources_list_changed(&self) -> Result<(), SendError> {
88 self.send(JsonRpcNotification::new(
89 "notifications/resources/list_changed",
90 None,
91 ))
92 .await
93 }
94
95 pub async fn tools_list_changed(&self) -> Result<(), SendError> {
101 self.send(JsonRpcNotification::new(
102 "notifications/tools/list_changed",
103 None,
104 ))
105 .await
106 }
107
108 pub async fn prompts_list_changed(&self) -> Result<(), SendError> {
114 self.send(JsonRpcNotification::new(
115 "notifications/prompts/list_changed",
116 None,
117 ))
118 .await
119 }
120
121 pub async fn log(
125 &self,
126 level: LoggingLevel,
127 logger: Option<String>,
128 data: impl Serialize,
129 ) -> Result<(), SendError> {
130 let data = serde_json::to_value(data).map_err(SendError::Serialize)?;
131 self.notify(
132 "notifications/message",
133 LogMessageNotification {
134 level,
135 logger,
136 data,
137 },
138 )
139 .await
140 }
141
142 pub async fn log_debug(
144 &self,
145 logger: impl Into<String>,
146 message: impl Into<String>,
147 ) -> Result<(), SendError> {
148 self.log(LoggingLevel::Debug, Some(logger.into()), message.into())
149 .await
150 }
151
152 pub async fn log_info(
154 &self,
155 logger: impl Into<String>,
156 message: impl Into<String>,
157 ) -> Result<(), SendError> {
158 self.log(LoggingLevel::Info, Some(logger.into()), message.into())
159 .await
160 }
161
162 pub async fn log_warning(
164 &self,
165 logger: impl Into<String>,
166 message: impl Into<String>,
167 ) -> Result<(), SendError> {
168 self.log(LoggingLevel::Warning, Some(logger.into()), message.into())
169 .await
170 }
171
172 pub async fn log_error(
174 &self,
175 logger: impl Into<String>,
176 message: impl Into<String>,
177 ) -> Result<(), SendError> {
178 self.log(LoggingLevel::Error, Some(logger.into()), message.into())
179 .await
180 }
181
182 pub async fn progress(
189 &self,
190 progress_token: impl Into<ProgressToken>,
191 progress: f64,
192 total: Option<f64>,
193 message: Option<String>,
194 ) -> Result<(), SendError> {
195 self.notify(
196 "notifications/progress",
197 ProgressNotification {
198 progress_token: progress_token.into(),
199 progress,
200 total,
201 message,
202 },
203 )
204 .await
205 }
206
207 pub async fn progress_with_message(
209 &self,
210 progress_token: impl Into<ProgressToken>,
211 progress: f64,
212 total: f64,
213 message: impl Into<String>,
214 ) -> Result<(), SendError> {
215 self.progress(progress_token, progress, Some(total), Some(message.into()))
216 .await
217 }
218}
219
220pub struct NotificationReceiver {
222 rx: mpsc::Receiver<JsonRpcNotification>,
223}
224
225impl NotificationReceiver {
226 pub async fn recv(&mut self) -> Option<JsonRpcNotification> {
228 self.rx.recv().await
229 }
230
231 pub fn try_recv(&mut self) -> Result<JsonRpcNotification, mpsc::error::TryRecvError> {
233 self.rx.try_recv()
234 }
235}
236
237#[derive(Debug, thiserror::Error)]
239pub enum SendError {
240 #[error("Notification channel closed")]
241 ChannelClosed,
242 #[error("Failed to serialize notification: {0}")]
243 Serialize(serde_json::Error),
244}
245
/// A [`NotificationSender`] behind an `Arc`, for sharing one handle among
/// multiple owners.
pub type SharedNotificationSender = Arc<NotificationSender>;
250
251impl From<i64> for crate::protocol::ProgressToken {
252 fn from(n: i64) -> Self {
253 crate::protocol::ProgressToken::Number(n)
254 }
255}
256
257impl From<String> for crate::protocol::ProgressToken {
258 fn from(s: String) -> Self {
259 crate::protocol::ProgressToken::String(s)
260 }
261}
262
263impl From<&str> for crate::protocol::ProgressToken {
264 fn from(s: &str) -> Self {
265 crate::protocol::ProgressToken::String(s.to_owned())
266 }
267}