pmcp/utils/
batching.rs

1//! Message batching and debouncing utilities.
2
3use crate::error::Result;
4use crate::types::Notification;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8#[cfg(not(target_arch = "wasm32"))]
9use tokio::sync::{mpsc, Mutex};
10#[cfg(not(target_arch = "wasm32"))]
11use tokio::time::{interval, sleep};
12use tracing::{debug, trace};
13
/// Tunable limits that control how notifications are grouped into batches.
#[derive(Clone, Debug)]
pub struct BatchingConfig {
    /// Upper bound on how many notifications a single batch may hold.
    pub max_batch_size: usize,
    /// Longest time a partially filled batch may wait before it is sent.
    pub max_wait_time: Duration,
    /// Names of the methods to batch; an empty list means all methods are batched.
    pub batched_methods: Vec<String>,
}
24
25impl Default for BatchingConfig {
26    fn default() -> Self {
27        Self {
28            max_batch_size: 10,
29            max_wait_time: Duration::from_millis(100),
30            batched_methods: vec![],
31        }
32    }
33}
34
35/// Message batcher that groups notifications.
36pub struct MessageBatcher {
37    config: BatchingConfig,
38    pending: Arc<Mutex<Vec<Notification>>>,
39    tx: mpsc::Sender<Vec<Notification>>,
40    rx: Arc<Mutex<mpsc::Receiver<Vec<Notification>>>>,
41}
42
43impl std::fmt::Debug for MessageBatcher {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("MessageBatcher")
46            .field("config", &self.config)
47            .finish_non_exhaustive()
48    }
49}
50
51impl MessageBatcher {
52    /// Create a new message batcher.
53    ///
54    /// # Examples
55    ///
56    /// ```rust
57    /// use pmcp::utils::{MessageBatcher, BatchingConfig};
58    /// use std::time::Duration;
59    ///
60    /// // Default configuration
61    /// let batcher = MessageBatcher::new(BatchingConfig::default());
62    ///
63    /// // Custom configuration for high-throughput scenarios
64    /// let config = BatchingConfig {
65    ///     max_batch_size: 50,
66    ///     max_wait_time: Duration::from_millis(200),
67    ///     batched_methods: vec!["logs.add".to_string(), "progress.update".to_string()],
68    /// };
69    /// let high_throughput_batcher = MessageBatcher::new(config);
70    ///
71    /// // Configuration for low-latency scenarios
72    /// let low_latency_config = BatchingConfig {
73    ///     max_batch_size: 5,
74    ///     max_wait_time: Duration::from_millis(10),
75    ///     batched_methods: vec![],
76    /// };
77    /// let low_latency_batcher = MessageBatcher::new(low_latency_config);
78    /// ```
79    pub fn new(config: BatchingConfig) -> Self {
80        let (tx, rx) = mpsc::channel(100);
81        Self {
82            config,
83            pending: Arc::new(Mutex::new(Vec::new())),
84            tx,
85            rx: Arc::new(Mutex::new(rx)),
86        }
87    }
88
89    /// Add a notification to the batch.
90    ///
91    /// # Examples
92    ///
93    /// ```rust
94    /// use pmcp::utils::{MessageBatcher, BatchingConfig};
95    /// use pmcp::types::{Notification, ClientNotification, ServerNotification};
96    /// use std::time::Duration;
97    ///
98    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
99    /// let config = BatchingConfig {
100    ///     max_batch_size: 3,
101    ///     max_wait_time: Duration::from_millis(100),
102    ///     batched_methods: vec![],
103    /// };
104    /// let batcher = MessageBatcher::new(config);
105    ///
106    /// // Add various notifications
107    /// batcher.add(Notification::Client(ClientNotification::Initialized)).await?;
108    /// batcher.add(Notification::Client(ClientNotification::RootsListChanged)).await?;
109    ///
110    /// // This will trigger immediate batch send (max_batch_size reached)
111    /// batcher.add(Notification::Server(ServerNotification::ToolsChanged)).await?;
112    /// # Ok(())
113    /// # }
114    /// ```
115    pub async fn add(&self, notification: Notification) -> Result<()> {
116        let mut pending = self.pending.lock().await;
117        pending.push(notification);
118
119        if pending.len() >= self.config.max_batch_size {
120            let batch = std::mem::take(&mut *pending);
121            drop(pending);
122            self.tx
123                .send(batch)
124                .await
125                .map_err(|_| crate::error::Error::Internal("Failed to send batch".to_string()))?;
126        } else {
127            drop(pending);
128        }
129
130        Ok(())
131    }
132
133    /// Start the batching timer.
134    ///
135    /// # Examples
136    ///
137    /// ```rust
138    /// use pmcp::utils::{MessageBatcher, BatchingConfig};
139    /// use pmcp::types::{Notification, ClientNotification};
140    /// use std::time::Duration;
141    ///
142    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
143    /// let config = BatchingConfig {
144    ///     max_batch_size: 10,
145    ///     max_wait_time: Duration::from_millis(50),
146    ///     batched_methods: vec![],
147    /// };
148    /// let batcher = MessageBatcher::new(config);
149    ///
150    /// // Start the timer to automatically flush batches
151    /// batcher.start_timer();
152    ///
153    /// // Add notifications that won't reach max_batch_size
154    /// batcher.add(Notification::Client(ClientNotification::Initialized)).await?;
155    ///
156    /// // Timer will ensure this gets sent after max_wait_time
157    /// # Ok(())
158    /// # }
159    /// ```
160    pub fn start_timer(&self) {
161        let pending = self.pending.clone();
162        let tx = self.tx.clone();
163        let max_wait = self.config.max_wait_time;
164
165        tokio::spawn(async move {
166            let mut ticker = interval(max_wait);
167            loop {
168                ticker.tick().await;
169
170                let mut pending_guard = pending.lock().await;
171                if !pending_guard.is_empty() {
172                    let batch = std::mem::take(&mut *pending_guard);
173                    drop(pending_guard);
174                    if tx.send(batch).await.is_err() {
175                        break;
176                    }
177                }
178            }
179        });
180    }
181
182    /// Receive the next batch of notifications.
183    ///
184    /// # Examples
185    ///
186    /// ```rust
187    /// use pmcp::utils::{MessageBatcher, BatchingConfig};
188    /// use pmcp::types::{Notification, ClientNotification};
189    /// use std::time::Duration;
190    ///
191    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
192    /// let config = BatchingConfig {
193    ///     max_batch_size: 2,
194    ///     max_wait_time: Duration::from_millis(100),
195    ///     batched_methods: vec![],
196    /// };
197    /// let batcher = MessageBatcher::new(config);
198    /// batcher.start_timer();
199    ///
200    /// // Add notifications
201    /// batcher.add(Notification::Client(ClientNotification::Initialized)).await?;
202    /// batcher.add(Notification::Client(ClientNotification::RootsListChanged)).await?;
203    ///
204    /// // Receive the batch (triggered by max_batch_size)
205    /// if let Some(batch) = batcher.receive_batch().await {
206    ///     println!("Received batch with {} notifications", batch.len());
207    ///     for notification in batch {
208    ///         // Process each notification
209    ///     }
210    /// }
211    /// # Ok(())
212    /// # }
213    /// ```
214    pub async fn receive_batch(&self) -> Option<Vec<Notification>> {
215        self.rx.lock().await.recv().await
216    }
217}
218
/// Settings that control how long notifications are held back by the debouncer.
#[derive(Clone, Debug)]
pub struct DebouncingConfig {
    /// Quiet period that must pass after the latest update before it is sent.
    pub wait_time: Duration,
    /// Per-method wait times (method -> wait time) that take precedence over
    /// `wait_time`.
    pub debounced_methods: HashMap<String, Duration>,
}
227
228impl Default for DebouncingConfig {
229    fn default() -> Self {
230        Self {
231            wait_time: Duration::from_millis(50),
232            debounced_methods: HashMap::new(),
233        }
234    }
235}
236
237/// Message debouncer that delays and coalesces rapid notifications.
238pub struct MessageDebouncer {
239    config: DebouncingConfig,
240    pending: Arc<Mutex<HashMap<String, Notification>>>,
241    timers: Arc<Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>,
242    tx: mpsc::Sender<Notification>,
243    rx: Arc<Mutex<mpsc::Receiver<Notification>>>,
244}
245
246impl std::fmt::Debug for MessageDebouncer {
247    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
248        f.debug_struct("MessageDebouncer")
249            .field("config", &self.config)
250            .finish_non_exhaustive()
251    }
252}
253
254impl MessageDebouncer {
255    /// Create a new message debouncer.
256    ///
257    /// # Examples
258    ///
259    /// ```rust
260    /// use pmcp::utils::{MessageDebouncer, DebouncingConfig};
261    /// use std::time::Duration;
262    /// use std::collections::HashMap;
263    ///
264    /// // Default configuration
265    /// let debouncer = MessageDebouncer::new(DebouncingConfig::default());
266    ///
267    /// // Custom configuration with per-method settings
268    /// let mut config = DebouncingConfig {
269    ///     wait_time: Duration::from_millis(100),
270    ///     debounced_methods: HashMap::new(),
271    /// };
272    ///
273    /// // Different debounce times for different notification types
274    /// config.debounced_methods.insert(
275    ///     "progress.update".to_string(),
276    ///     Duration::from_millis(50)
277    /// );
278    /// config.debounced_methods.insert(
279    ///     "file.changed".to_string(),
280    ///     Duration::from_millis(500)
281    /// );
282    ///
283    /// let custom_debouncer = MessageDebouncer::new(config);
284    /// ```
285    pub fn new(config: DebouncingConfig) -> Self {
286        let (tx, rx) = mpsc::channel(100);
287        Self {
288            config,
289            pending: Arc::new(Mutex::new(HashMap::new())),
290            timers: Arc::new(Mutex::new(HashMap::new())),
291            tx,
292            rx: Arc::new(Mutex::new(rx)),
293        }
294    }
295
296    /// Add a notification to be debounced.
297    ///
298    /// # Examples
299    ///
300    /// ```rust
301    /// use pmcp::utils::{MessageDebouncer, DebouncingConfig};
302    /// use pmcp::types::{Notification, ServerNotification};
303    /// use std::time::Duration;
304    ///
305    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
306    /// let debouncer = MessageDebouncer::new(DebouncingConfig {
307    ///     wait_time: Duration::from_millis(100),
308    ///     debounced_methods: Default::default(),
309    /// });
310    ///
311    /// // Rapid file change notifications
312    /// debouncer.add(
313    ///     "file:/path/to/file.rs".to_string(),
314    ///     Notification::Server(ServerNotification::ResourcesChanged)
315    /// ).await?;
316    ///
317    /// // Another change to same file within debounce window
318    /// tokio::time::sleep(Duration::from_millis(50)).await;
319    /// debouncer.add(
320    ///     "file:/path/to/file.rs".to_string(),
321    ///     Notification::Server(ServerNotification::ResourcesChanged)
322    /// ).await?;
323    ///
324    /// // Only the last notification will be sent after debounce period
325    /// # Ok(())
326    /// # }
327    /// ```
328    pub async fn add(&self, key: String, notification: Notification) -> Result<()> {
329        trace!("Debouncing notification with key: {}", key);
330
331        let wait_time = self
332            .config
333            .debounced_methods
334            .get(&key)
335            .copied()
336            .unwrap_or(self.config.wait_time);
337
338        // Store the latest notification
339        {
340            let mut pending = self.pending.lock().await;
341            pending.insert(key.clone(), notification);
342        }
343
344        // Cancel existing timer
345        {
346            let mut timers = self.timers.lock().await;
347            if let Some(handle) = timers.remove(&key) {
348                handle.abort();
349            }
350        }
351
352        // Start new timer
353        let pending = self.pending.clone();
354        let tx = self.tx.clone();
355        let timers = self.timers.clone();
356        let key_clone = key.clone();
357
358        let handle = tokio::spawn(async move {
359            sleep(wait_time).await;
360
361            // Send the notification
362            let notification = {
363                let mut pending = pending.lock().await;
364                pending.remove(&key_clone)
365            };
366
367            if let Some(notification) = notification {
368                debug!("Sending debounced notification: {}", key_clone);
369                let _ = tx.send(notification).await;
370            }
371
372            // Remove timer handle
373            let mut timers = timers.lock().await;
374            timers.remove(&key_clone);
375        });
376
377        // Store timer handle
378        {
379            let mut timers = self.timers.lock().await;
380            timers.insert(key, handle);
381        }
382
383        Ok(())
384    }
385
386    /// Receive the next debounced notification.
387    ///
388    /// # Examples
389    ///
390    /// ```rust
391    /// use pmcp::utils::{MessageDebouncer, DebouncingConfig};
392    /// use pmcp::types::{Notification, ServerNotification};
393    /// use std::time::Duration;
394    ///
395    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
396    /// let debouncer = MessageDebouncer::new(DebouncingConfig::default());
397    ///
398    /// // Spawn a task to receive notifications
399    /// tokio::spawn(async move {
400    ///     while let Some(notification) = debouncer.receive().await {
401    ///         match notification {
402    ///             Notification::Server(ServerNotification::ResourcesChanged) => {
403    ///                 println!("Resources changed (after debounce)");
404    ///             }
405    ///             _ => {}
406    ///         }
407    ///     }
408    /// });
409    /// # Ok(())
410    /// # }
411    /// ```
412    pub async fn receive(&self) -> Option<Notification> {
413        self.rx.lock().await.recv().await
414    }
415
416    /// Flush all pending notifications immediately.
417    ///
418    /// # Examples
419    ///
420    /// ```rust
421    /// use pmcp::utils::{MessageDebouncer, DebouncingConfig};
422    /// use pmcp::types::{Notification, ServerNotification, ClientNotification};
423    /// use std::time::Duration;
424    ///
425    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
426    /// let debouncer = MessageDebouncer::new(DebouncingConfig {
427    ///     wait_time: Duration::from_secs(5), // Long debounce
428    ///     debounced_methods: Default::default(),
429    /// });
430    ///
431    /// // Add several notifications
432    /// debouncer.add(
433    ///     "resource1".to_string(),
434    ///     Notification::Server(ServerNotification::ResourcesChanged)
435    /// ).await?;
436    /// debouncer.add(
437    ///     "resource2".to_string(),
438    ///     Notification::Client(ClientNotification::RootsListChanged)
439    /// ).await?;
440    ///
441    /// // Flush immediately instead of waiting for debounce
442    /// let pending = debouncer.flush().await;
443    /// println!("Flushed {} pending notifications", pending.len());
444    ///
445    /// // Process all flushed notifications
446    /// for notification in pending {
447    ///     // Handle each notification immediately
448    /// }
449    /// # Ok(())
450    /// # }
451    /// ```
452    pub async fn flush(&self) -> Vec<Notification> {
453        // Cancel all timers
454        {
455            let mut timers = self.timers.lock().await;
456            for (_, handle) in timers.drain() {
457                handle.abort();
458            }
459        }
460
461        // Get all pending notifications
462        let mut pending = self.pending.lock().await;
463        pending.drain().map(|(_, v)| v).collect()
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ClientNotification;

    /// Filling the buffer up to `max_batch_size` must produce one full batch.
    #[tokio::test]
    async fn test_message_batcher() {
        let batcher = MessageBatcher::new(BatchingConfig {
            max_batch_size: 2,
            max_wait_time: Duration::from_millis(10),
            batched_methods: vec![],
        });
        batcher.start_timer();

        // Two notifications reach the configured batch size.
        batcher
            .add(Notification::Client(ClientNotification::Initialized))
            .await
            .unwrap();
        batcher
            .add(Notification::Client(ClientNotification::RootsListChanged))
            .await
            .unwrap();

        // The size limit was hit, so a batch is already queued.
        let batch = batcher.receive_batch().await;
        assert!(batch.is_some());
        assert_eq!(batch.unwrap().len(), 2);
    }

    /// Two updates under the same key inside the debounce window must
    /// collapse into the later one.
    #[tokio::test]
    async fn test_message_debouncer() {
        let debouncer = MessageDebouncer::new(DebouncingConfig {
            wait_time: Duration::from_millis(10),
            debounced_methods: HashMap::new(),
        });

        let key = "test";
        let first = Notification::Client(ClientNotification::Initialized);
        let second = Notification::Client(ClientNotification::RootsListChanged);

        // The second update arrives before the first one's wait time is over.
        debouncer.add(key.to_string(), first).await.unwrap();
        tokio::time::sleep(Duration::from_millis(5)).await;
        debouncer.add(key.to_string(), second).await.unwrap();

        // Exactly the most recent notification should come out.
        let received = debouncer.receive().await;
        assert!(received.is_some());

        match received.unwrap() {
            Notification::Client(ClientNotification::RootsListChanged) => {},
            _ => panic!("Wrong notification received"),
        }
    }
}