pmcp/utils/batching.rs
1//! Message batching and debouncing utilities.
2
3use crate::error::Result;
4use crate::types::Notification;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::sync::{mpsc, Mutex};
10#[cfg(not(target_arch = "wasm32"))]
11use tokio::time::{interval, sleep};
12use tracing::{debug, trace};
13
/// Configuration for [`MessageBatcher`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchingConfig {
    /// Pending-notification count at which `MessageBatcher::add` flushes the
    /// batch immediately.
    pub max_batch_size: usize,
    /// Period of the flush timer started by `MessageBatcher::start_timer`.
    pub max_wait_time: Duration,
    /// Method names that are intended to be batched (empty means batch all).
    ///
    /// NOTE: this list is not currently consulted by `MessageBatcher`; every
    /// notification passed to `MessageBatcher::add` is batched regardless.
    pub batched_methods: Vec<String>,
}
24
25impl Default for BatchingConfig {
26 fn default() -> Self {
27 Self {
28 max_batch_size: 10,
29 max_wait_time: Duration::from_millis(100),
30 batched_methods: vec![],
31 }
32 }
33}
34
35/// Message batcher that groups notifications.
36pub struct MessageBatcher {
37 config: BatchingConfig,
38 pending: Arc<Mutex<Vec<Notification>>>,
39 tx: mpsc::Sender<Vec<Notification>>,
40 rx: Arc<Mutex<mpsc::Receiver<Vec<Notification>>>>,
41}
42
43impl std::fmt::Debug for MessageBatcher {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("MessageBatcher")
46 .field("config", &self.config)
47 .finish_non_exhaustive()
48 }
49}
50
51impl MessageBatcher {
52 /// Create a new message batcher.
53 ///
54 /// # Examples
55 ///
56 /// ```rust
57 /// use pmcp::utils::{MessageBatcher, BatchingConfig};
58 /// use std::time::Duration;
59 ///
60 /// // Default configuration
61 /// let batcher = MessageBatcher::new(BatchingConfig::default());
62 ///
63 /// // Custom configuration for high-throughput scenarios
64 /// let config = BatchingConfig {
65 /// max_batch_size: 50,
66 /// max_wait_time: Duration::from_millis(200),
67 /// batched_methods: vec!["logs.add".to_string(), "progress.update".to_string()],
68 /// };
69 /// let high_throughput_batcher = MessageBatcher::new(config);
70 ///
71 /// // Configuration for low-latency scenarios
72 /// let low_latency_config = BatchingConfig {
73 /// max_batch_size: 5,
74 /// max_wait_time: Duration::from_millis(10),
75 /// batched_methods: vec![],
76 /// };
77 /// let low_latency_batcher = MessageBatcher::new(low_latency_config);
78 /// ```
79 pub fn new(config: BatchingConfig) -> Self {
80 let (tx, rx) = mpsc::channel(100);
81 Self {
82 config,
83 pending: Arc::new(Mutex::new(Vec::new())),
84 tx,
85 rx: Arc::new(Mutex::new(rx)),
86 }
87 }
88
89 /// Add a notification to the batch.
90 ///
91 /// # Examples
92 ///
93 /// ```rust
94 /// use pmcp::utils::{MessageBatcher, BatchingConfig};
95 /// use pmcp::types::{Notification, ClientNotification, ServerNotification};
96 /// use std::time::Duration;
97 ///
98 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
99 /// let config = BatchingConfig {
100 /// max_batch_size: 3,
101 /// max_wait_time: Duration::from_millis(100),
102 /// batched_methods: vec![],
103 /// };
104 /// let batcher = MessageBatcher::new(config);
105 ///
106 /// // Add various notifications
107 /// batcher.add(Notification::Client(ClientNotification::Initialized)).await?;
108 /// batcher.add(Notification::Client(ClientNotification::RootsListChanged)).await?;
109 ///
110 /// // This will trigger immediate batch send (max_batch_size reached)
111 /// batcher.add(Notification::Server(ServerNotification::ToolsChanged)).await?;
112 /// # Ok(())
113 /// # }
114 /// ```
115 pub async fn add(&self, notification: Notification) -> Result<()> {
116 let mut pending = self.pending.lock().await;
117 pending.push(notification);
118
119 if pending.len() >= self.config.max_batch_size {
120 let batch = std::mem::take(&mut *pending);
121 drop(pending);
122 self.tx
123 .send(batch)
124 .await
125 .map_err(|_| crate::error::Error::Internal("Failed to send batch".to_string()))?;
126 } else {
127 drop(pending);
128 }
129
130 Ok(())
131 }
132
133 /// Start the batching timer.
134 ///
135 /// # Examples
136 ///
137 /// ```rust
138 /// use pmcp::utils::{MessageBatcher, BatchingConfig};
139 /// use pmcp::types::{Notification, ClientNotification};
140 /// use std::time::Duration;
141 ///
142 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
143 /// let config = BatchingConfig {
144 /// max_batch_size: 10,
145 /// max_wait_time: Duration::from_millis(50),
146 /// batched_methods: vec![],
147 /// };
148 /// let batcher = MessageBatcher::new(config);
149 ///
150 /// // Start the timer to automatically flush batches
151 /// batcher.start_timer();
152 ///
153 /// // Add notifications that won't reach max_batch_size
154 /// batcher.add(Notification::Client(ClientNotification::Initialized)).await?;
155 ///
156 /// // Timer will ensure this gets sent after max_wait_time
157 /// # Ok(())
158 /// # }
159 /// ```
160 pub fn start_timer(&self) {
161 let pending = self.pending.clone();
162 let tx = self.tx.clone();
163 let max_wait = self.config.max_wait_time;
164
165 tokio::spawn(async move {
166 let mut ticker = interval(max_wait);
167 loop {
168 ticker.tick().await;
169
170 let mut pending_guard = pending.lock().await;
171 if !pending_guard.is_empty() {
172 let batch = std::mem::take(&mut *pending_guard);
173 drop(pending_guard);
174 if tx.send(batch).await.is_err() {
175 break;
176 }
177 }
178 }
179 });
180 }
181
182 /// Receive the next batch of notifications.
183 ///
184 /// # Examples
185 ///
186 /// ```rust
187 /// use pmcp::utils::{MessageBatcher, BatchingConfig};
188 /// use pmcp::types::{Notification, ClientNotification};
189 /// use std::time::Duration;
190 ///
191 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
192 /// let config = BatchingConfig {
193 /// max_batch_size: 2,
194 /// max_wait_time: Duration::from_millis(100),
195 /// batched_methods: vec![],
196 /// };
197 /// let batcher = MessageBatcher::new(config);
198 /// batcher.start_timer();
199 ///
200 /// // Add notifications
201 /// batcher.add(Notification::Client(ClientNotification::Initialized)).await?;
202 /// batcher.add(Notification::Client(ClientNotification::RootsListChanged)).await?;
203 ///
204 /// // Receive the batch (triggered by max_batch_size)
205 /// if let Some(batch) = batcher.receive_batch().await {
206 /// println!("Received batch with {} notifications", batch.len());
207 /// for notification in batch {
208 /// // Process each notification
209 /// }
210 /// }
211 /// # Ok(())
212 /// # }
213 /// ```
214 pub async fn receive_batch(&self) -> Option<Vec<Notification>> {
215 self.rx.lock().await.recv().await
216 }
217}
218
/// Configuration for [`MessageDebouncer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DebouncingConfig {
    /// Quiet period after the most recent `MessageDebouncer::add` for a key
    /// before that key's notification is delivered. It applies to keys that
    /// have no entry in `debounced_methods`.
    pub wait_time: Duration,
    /// Per-key overrides of `wait_time`.
    ///
    /// Entries are looked up by the exact key passed to
    /// `MessageDebouncer::add`. They only act as per-method settings when
    /// callers use method names (e.g. `"progress.update"`) as debounce keys.
    pub debounced_methods: HashMap<String, Duration>,
}
227
228impl Default for DebouncingConfig {
229 fn default() -> Self {
230 Self {
231 wait_time: Duration::from_millis(50),
232 debounced_methods: HashMap::new(),
233 }
234 }
235}
236
237/// Message debouncer that delays and coalesces rapid notifications.
238pub struct MessageDebouncer {
239 config: DebouncingConfig,
240 pending: Arc<Mutex<HashMap<String, Notification>>>,
241 timers: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
242 tx: mpsc::Sender<Notification>,
243 rx: Arc<Mutex<mpsc::Receiver<Notification>>>,
244}
245
246impl std::fmt::Debug for MessageDebouncer {
247 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248 f.debug_struct("MessageDebouncer")
249 .field("config", &self.config)
250 .finish_non_exhaustive()
251 }
252}
253
254impl MessageDebouncer {
255 /// Create a new message debouncer.
256 ///
257 /// # Examples
258 ///
259 /// ```rust
260 /// use pmcp::utils::{MessageDebouncer, DebouncingConfig};
261 /// use std::time::Duration;
262 /// use std::collections::HashMap;
263 ///
264 /// // Default configuration
265 /// let debouncer = MessageDebouncer::new(DebouncingConfig::default());
266 ///
267 /// // Custom configuration with per-method settings
268 /// let mut config = DebouncingConfig {
269 /// wait_time: Duration::from_millis(100),
270 /// debounced_methods: HashMap::new(),
271 /// };
272 ///
273 /// // Different debounce times for different notification types
274 /// config.debounced_methods.insert(
275 /// "progress.update".to_string(),
276 /// Duration::from_millis(50)
277 /// );
278 /// config.debounced_methods.insert(
279 /// "file.changed".to_string(),
280 /// Duration::from_millis(500)
281 /// );
282 ///
283 /// let custom_debouncer = MessageDebouncer::new(config);
284 /// ```
285 pub fn new(config: DebouncingConfig) -> Self {
286 let (tx, rx) = mpsc::channel(100);
287 Self {
288 config,
289 pending: Arc::new(Mutex::new(HashMap::new())),
290 timers: Arc::new(Mutex::new(HashMap::new())),
291 tx,
292 rx: Arc::new(Mutex::new(rx)),
293 }
294 }
295
296 /// Add a notification to be debounced.
297 ///
298 /// # Examples
299 ///
300 /// ```rust
301 /// use pmcp::utils::{MessageDebouncer, DebouncingConfig};
302 /// use pmcp::types::{Notification, ServerNotification};
303 /// use std::time::Duration;
304 ///
305 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
306 /// let debouncer = MessageDebouncer::new(DebouncingConfig {
307 /// wait_time: Duration::from_millis(100),
308 /// debounced_methods: Default::default(),
309 /// });
310 ///
311 /// // Rapid file change notifications
312 /// debouncer.add(
313 /// "file:/path/to/file.rs".to_string(),
314 /// Notification::Server(ServerNotification::ResourcesChanged)
315 /// ).await?;
316 ///
317 /// // Another change to same file within debounce window
318 /// tokio::time::sleep(Duration::from_millis(50)).await;
319 /// debouncer.add(
320 /// "file:/path/to/file.rs".to_string(),
321 /// Notification::Server(ServerNotification::ResourcesChanged)
322 /// ).await?;
323 ///
324 /// // Only the last notification will be sent after debounce period
325 /// # Ok(())
326 /// # }
327 /// ```
328 pub async fn add(&self, key: String, notification: Notification) -> Result<()> {
329 trace!("Debouncing notification with key: {}", key);
330
331 let wait_time = self
332 .config
333 .debounced_methods
334 .get(&key)
335 .copied()
336 .unwrap_or(self.config.wait_time);
337
338 // Store the latest notification
339 {
340 let mut pending = self.pending.lock().await;
341 pending.insert(key.clone(), notification);
342 }
343
344 // Cancel existing timer
345 {
346 let mut timers = self.timers.lock().await;
347 if let Some(handle) = timers.remove(&key) {
348 handle.abort();
349 }
350 }
351
352 // Start new timer
353 let pending = self.pending.clone();
354 let tx = self.tx.clone();
355 let timers = self.timers.clone();
356 let key_clone = key.clone();
357
358 let handle = tokio::spawn(async move {
359 sleep(wait_time).await;
360
361 // Send the notification
362 let notification = {
363 let mut pending = pending.lock().await;
364 pending.remove(&key_clone)
365 };
366
367 if let Some(notification) = notification {
368 debug!("Sending debounced notification: {}", key_clone);
369 let _ = tx.send(notification).await;
370 }
371
372 // Remove timer handle
373 let mut timers = timers.lock().await;
374 timers.remove(&key_clone);
375 });
376
377 // Store timer handle
378 {
379 let mut timers = self.timers.lock().await;
380 timers.insert(key, handle);
381 }
382
383 Ok(())
384 }
385
386 /// Receive the next debounced notification.
387 ///
388 /// # Examples
389 ///
390 /// ```rust
391 /// use pmcp::utils::{MessageDebouncer, DebouncingConfig};
392 /// use pmcp::types::{Notification, ServerNotification};
393 /// use std::time::Duration;
394 ///
395 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
396 /// let debouncer = MessageDebouncer::new(DebouncingConfig::default());
397 ///
398 /// // Spawn a task to receive notifications
399 /// tokio::spawn(async move {
400 /// while let Some(notification) = debouncer.receive().await {
401 /// match notification {
402 /// Notification::Server(ServerNotification::ResourcesChanged) => {
403 /// println!("Resources changed (after debounce)");
404 /// }
405 /// _ => {}
406 /// }
407 /// }
408 /// });
409 /// # Ok(())
410 /// # }
411 /// ```
412 pub async fn receive(&self) -> Option<Notification> {
413 self.rx.lock().await.recv().await
414 }
415
416 /// Flush all pending notifications immediately.
417 ///
418 /// # Examples
419 ///
420 /// ```rust
421 /// use pmcp::utils::{MessageDebouncer, DebouncingConfig};
422 /// use pmcp::types::{Notification, ServerNotification, ClientNotification};
423 /// use std::time::Duration;
424 ///
425 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
426 /// let debouncer = MessageDebouncer::new(DebouncingConfig {
427 /// wait_time: Duration::from_secs(5), // Long debounce
428 /// debounced_methods: Default::default(),
429 /// });
430 ///
431 /// // Add several notifications
432 /// debouncer.add(
433 /// "resource1".to_string(),
434 /// Notification::Server(ServerNotification::ResourcesChanged)
435 /// ).await?;
436 /// debouncer.add(
437 /// "resource2".to_string(),
438 /// Notification::Client(ClientNotification::RootsListChanged)
439 /// ).await?;
440 ///
441 /// // Flush immediately instead of waiting for debounce
442 /// let pending = debouncer.flush().await;
443 /// println!("Flushed {} pending notifications", pending.len());
444 ///
445 /// // Process all flushed notifications
446 /// for notification in pending {
447 /// // Handle each notification immediately
448 /// }
449 /// # Ok(())
450 /// # }
451 /// ```
452 pub async fn flush(&self) -> Vec<Notification> {
453 // Cancel all timers
454 {
455 let mut timers = self.timers.lock().await;
456 for (_, handle) in timers.drain() {
457 handle.abort();
458 }
459 }
460
461 // Get all pending notifications
462 let mut pending = self.pending.lock().await;
463 pending.drain().map(|(_, v)| v).collect()
464 }
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ClientNotification;

    /// Upper bound for waits that are expected to succeed, so a regression
    /// fails the test instead of hanging it.
    const GENEROUS_TIMEOUT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn test_message_batcher() {
        let config = BatchingConfig {
            max_batch_size: 2,
            max_wait_time: Duration::from_millis(10),
            batched_methods: vec![],
        };

        let batcher = MessageBatcher::new(config);
        batcher.start_timer();

        // Add two notifications
        let notif1 = Notification::Client(ClientNotification::Initialized);
        let notif2 = Notification::Client(ClientNotification::RootsListChanged);

        batcher.add(notif1).await.unwrap();
        batcher.add(notif2).await.unwrap();

        // Should receive batch immediately (max size reached)
        let batch = tokio::time::timeout(GENEROUS_TIMEOUT, batcher.receive_batch())
            .await
            .expect("a full batch should be flushed promptly");
        assert!(batch.is_some());
        assert_eq!(batch.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_batcher_timer_flushes_partial_batch() {
        let batcher = MessageBatcher::new(BatchingConfig {
            max_batch_size: 10,
            max_wait_time: Duration::from_millis(10),
            batched_methods: vec![],
        });
        batcher.start_timer();

        batcher
            .add(Notification::Client(ClientNotification::Initialized))
            .await
            .unwrap();

        // Far below max_batch_size, so only the timer can flush this.
        let batch = tokio::time::timeout(GENEROUS_TIMEOUT, batcher.receive_batch())
            .await
            .expect("timer should flush the partial batch")
            .expect("batch channel should be open");
        assert_eq!(batch.len(), 1);
    }

    #[tokio::test]
    async fn test_message_debouncer() {
        let config = DebouncingConfig {
            wait_time: Duration::from_millis(10),
            debounced_methods: HashMap::new(),
        };

        let debouncer = MessageDebouncer::new(config);

        // Add notifications with same key
        let notif1 = Notification::Client(ClientNotification::Initialized);
        let notif2 = Notification::Client(ClientNotification::RootsListChanged);

        debouncer.add("test".to_string(), notif1).await.unwrap();
        tokio::time::sleep(Duration::from_millis(5)).await;
        debouncer.add("test".to_string(), notif2).await.unwrap();

        // Should only receive the last one after debounce period
        let received = tokio::time::timeout(GENEROUS_TIMEOUT, debouncer.receive())
            .await
            .expect("debounced notification should arrive");
        assert!(received.is_some());

        // Should be the second notification
        assert!(matches!(
            received.unwrap(),
            Notification::Client(ClientNotification::RootsListChanged)
        ));

        // The superseded notification must not be left behind.
        assert!(debouncer.flush().await.is_empty());
    }

    #[tokio::test]
    async fn test_debouncer_keys_are_independent() {
        let debouncer = MessageDebouncer::new(DebouncingConfig {
            wait_time: Duration::from_millis(5),
            debounced_methods: HashMap::new(),
        });

        debouncer
            .add(
                "a".to_string(),
                Notification::Client(ClientNotification::Initialized),
            )
            .await
            .unwrap();
        debouncer
            .add(
                "b".to_string(),
                Notification::Client(ClientNotification::RootsListChanged),
            )
            .await
            .unwrap();

        let mut initialized = 0;
        let mut roots_changed = 0;
        for _ in 0..2 {
            let received = tokio::time::timeout(GENEROUS_TIMEOUT, debouncer.receive())
                .await
                .expect("each key should deliver its own notification")
                .expect("debounce channel should be open");
            match received {
                Notification::Client(ClientNotification::Initialized) => initialized += 1,
                Notification::Client(ClientNotification::RootsListChanged) => {
                    roots_changed += 1;
                }
                _ => panic!("unexpected notification received"),
            }
        }
        assert_eq!((initialized, roots_changed), (1, 1));
    }

    #[tokio::test]
    async fn test_debouncer_per_key_wait_time_and_flush() {
        let mut debounced_methods = HashMap::new();
        debounced_methods.insert("slow".to_string(), Duration::from_secs(60));
        let debouncer = MessageDebouncer::new(DebouncingConfig {
            wait_time: Duration::from_millis(1),
            debounced_methods,
        });

        debouncer
            .add(
                "slow".to_string(),
                Notification::Client(ClientNotification::Initialized),
            )
            .await
            .unwrap();

        // The 60 s per-key override applies instead of the 1 ms default.
        let early = tokio::time::timeout(Duration::from_millis(30), debouncer.receive()).await;
        assert!(early.is_err(), "per-key wait time should delay delivery");

        // Flushing hands back the pending notification and cancels its timer.
        let flushed = debouncer.flush().await;
        assert_eq!(flushed.len(), 1);
        assert!(matches!(
            flushed[0],
            Notification::Client(ClientNotification::Initialized)
        ));
        assert!(debouncer.flush().await.is_empty());
    }
}