mcp_host/managers/
progress.rs

//! MCP progress tracking for long-running operations.
//!
//! Provides progress notifications keyed by unique tokens and tracks per-operation
//! state. Stale operations are removed when [`ProgressTracker::cleanup_stale`] is
//! called. Tokens are formatted as `"{op_type}-{unix_timestamp}-{counter}"`.
6
7use dashmap::DashMap;
8use serde_json::json;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
12use tokio::sync::mpsc;
13
14use crate::transport::traits::JsonRpcNotification;
15
/// Bookkeeping for one in-flight operation, stored under its progress token.
#[derive(Debug, Clone)]
struct OperationState {
    /// Kind of operation, e.g. "stalint" or "dictator".
    // Only surfaced through the derived `Debug` output; the tracker never reads
    // it directly, which is why the dead-code lint is silenced.
    #[allow(dead_code)]
    op_type: String,
    /// Identifier of the operation; set to the same string as its token.
    // Redundant with the map key and never read directly; silenced for the same
    // reason as `op_type`.
    #[allow(dead_code)]
    op_id: String,
    /// Number of items the operation expects to process.
    total: u32,
    /// Items processed so far; the tracker keeps this at or below `total`.
    current: u32,
    /// Moment the operation was registered; used to age out stale entries.
    start_time: Instant,
}
32
33/// Progress token for tracking an operation
34pub type ProgressToken = String;
35
36/// Progress tracker for long-running MCP operations
37///
38/// # Example
39///
40/// ```rust,ignore
41/// let tracker = ProgressTracker::new(notification_tx);
42/// let token = tracker.start("my-operation", 100);
43/// tracker.progress(&token, 50);  // 50% complete
44/// tracker.finish(&token);        // 100% complete
45/// ```
46#[derive(Clone)]
47pub struct ProgressTracker {
48    /// Global counter for unique progress tokens
49    token_counter: Arc<AtomicU64>,
50    /// Active operations: token -> state
51    operations: Arc<DashMap<String, OperationState>>,
52    /// Notification channel sender
53    notif_tx: mpsc::UnboundedSender<JsonRpcNotification>,
54}
55
56impl ProgressTracker {
57    /// Create a new progress tracker
58    ///
59    /// # Arguments
60    ///
61    /// * `notif_tx` - Notification channel for sending progress updates
62    pub fn new(notif_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
63        Self {
64            token_counter: Arc::new(AtomicU64::new(0)),
65            operations: Arc::new(DashMap::new()),
66            notif_tx,
67        }
68    }
69
70    /// Generate a unique progress token: "{op_type}-{timestamp}-{counter}"
71    fn generate_token(&self, op_type: &str) -> ProgressToken {
72        let counter = self.token_counter.fetch_add(1, Ordering::SeqCst);
73        let timestamp = SystemTime::now()
74            .duration_since(UNIX_EPOCH)
75            .unwrap_or_default()
76            .as_secs();
77        format!("{op_type}-{timestamp}-{counter}")
78    }
79
80    /// Start a new operation and return its unique progress token
81    ///
82    /// Sends initial progress notification (0%)
83    ///
84    /// # Arguments
85    ///
86    /// * `op_type` - Operation type identifier (e.g., "lint", "fix")
87    /// * `total` - Total number of items to process
88    ///
89    /// # Returns
90    ///
91    /// Unique progress token to use for updates
92    pub fn start(&self, op_type: &str, total: u32) -> ProgressToken {
93        let token = self.generate_token(op_type);
94        let op_id = token.clone();
95
96        let op_state = OperationState {
97            op_type: op_type.to_string(),
98            op_id,
99            total,
100            current: 0,
101            start_time: Instant::now(),
102        };
103
104        self.operations.insert(token.clone(), op_state);
105
106        // Send initial progress notification (0%)
107        self.send_notification(&token, 0, total);
108
109        token
110    }
111
112    /// Update progress for an operation (0..=total)
113    ///
114    /// # Arguments
115    ///
116    /// * `token` - Progress token from `start()`
117    /// * `current` - Current progress (will be clamped to [0, total])
118    ///
119    /// # Returns
120    ///
121    /// `true` if operation is still active, `false` if unknown/expired
122    pub fn progress(&self, token: &ProgressToken, current: u32) -> bool {
123        if let Some(mut op) = self.operations.get_mut(token) {
124            // Clamp progress to [0, total]
125            op.current = current.min(op.total);
126            let progress = op.current;
127            let total = op.total;
128            drop(op); // Release lock before sending notification
129
130            self.send_notification(token, progress, total);
131            true
132        } else {
133            false
134        }
135    }
136
137    /// Mark operation complete (final notification at 100%)
138    ///
139    /// Removes the operation from tracking and sends final progress notification.
140    ///
141    /// # Arguments
142    ///
143    /// * `token` - Progress token from `start()`
144    pub fn finish(&self, token: &ProgressToken) {
145        if let Some((_key, op)) = self.operations.remove(token) {
146            self.send_notification(token, op.total, op.total);
147        }
148    }
149
150    /// Send progress notification to client
151    fn send_notification(&self, token: &ProgressToken, progress: u32, total: u32) {
152        let notification = JsonRpcNotification::new(
153            "notifications/progress",
154            Some(json!({
155                "progressToken": token,
156                "progress": progress,
157                "total": total
158            })),
159        );
160
161        let _ = self.notif_tx.send(notification);
162    }
163
164    /// Clean up stale operations (older than 10 minutes)
165    ///
166    /// Call periodically from background tasks. Only performs cleanup
167    /// if at least 60 seconds have passed since last cleanup.
168    pub fn cleanup_stale(&self) {
169        let now = Instant::now();
170        let timeout = Duration::from_secs(600); // 10 minutes
171
172        self.operations
173            .retain(|_token, op| now.duration_since(op.start_time) < timeout);
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Tracker whose notification receiver is dropped immediately.
    ///
    /// Every send fails silently, so use this when notifications are irrelevant
    /// to the test.
    fn create_tracker() -> ProgressTracker {
        let (tx, _rx) = mpsc::unbounded_channel();
        ProgressTracker::new(tx)
    }

    /// Tracker plus the receiving end of its notification channel.
    fn create_tracker_with_rx() -> (
        ProgressTracker,
        mpsc::UnboundedReceiver<JsonRpcNotification>,
    ) {
        let (tx, rx) = mpsc::unbounded_channel();
        (ProgressTracker::new(tx), rx)
    }

    /// Drain and count the notifications currently queued in `rx`.
    fn drain_count(rx: &mut mpsc::UnboundedReceiver<JsonRpcNotification>) -> usize {
        std::iter::from_fn(|| rx.try_recv().ok()).count()
    }

    #[test]
    fn test_token_format() {
        let tracker = create_tracker();
        let token = tracker.generate_token("stalint");

        // Format: "{op_type}-{timestamp}-{counter}"
        let parts: Vec<&str> = token.split('-').collect();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0], "stalint");
        // parts[1] is unix timestamp
        assert!(parts[1].parse::<u64>().is_ok());
        // parts[2] is counter
        assert!(parts[2].parse::<u64>().is_ok());
    }

    #[test]
    fn test_token_uniqueness() {
        let tracker = create_tracker();
        let token1 = tracker.generate_token("stalint");
        let token2 = tracker.generate_token("stalint");

        assert_ne!(token1, token2);
    }

    #[test]
    fn test_start_registers_operation() {
        let tracker = create_tracker();
        let token = tracker.start("stalint", 100);

        assert!(tracker.operations.contains_key(&token));
        let op = tracker.operations.get(&token).unwrap();
        assert_eq!(op.current, 0);
        assert_eq!(op.total, 100);
    }

    #[test]
    fn test_progress_updates() {
        let tracker = create_tracker();
        let token = tracker.start("dictator", 50);

        // Update to 50%
        assert!(tracker.progress(&token, 25));
        let op = tracker.operations.get(&token).unwrap();
        assert_eq!(op.current, 25);
        assert_eq!(op.total, 50);
        drop(op);

        // Update to 100%
        assert!(tracker.progress(&token, 50));
        let op = tracker.operations.get(&token).unwrap();
        assert_eq!(op.current, 50);
        assert_eq!(op.total, 50);
    }

    #[test]
    fn test_progress_clamps_to_total() {
        let tracker = create_tracker();
        let token = tracker.start("supremecourt", 10);

        // Try to set beyond total
        assert!(tracker.progress(&token, 99));
        let op = tracker.operations.get(&token).unwrap();
        assert_eq!(op.current, 10); // Clamped to total
        assert_eq!(op.total, 10);
    }

    #[test]
    fn test_finish_removes_operation() {
        let tracker = create_tracker();
        let token = tracker.start("stalint", 100);

        assert!(tracker.operations.contains_key(&token));
        tracker.finish(&token);
        assert!(!tracker.operations.contains_key(&token));
    }

    #[test]
    fn test_progress_unknown_token() {
        let tracker = create_tracker();

        // Unknown token should return false
        assert!(!tracker.progress(&"unknown-token".to_string(), 50));
    }

    #[test]
    fn test_progress_after_finish_returns_false() {
        let tracker = create_tracker();
        let token = tracker.start("stalint", 10);
        tracker.finish(&token);

        assert!(!tracker.progress(&token, 5));
    }

    #[test]
    fn test_lifecycle_sends_notifications() {
        let (tracker, mut rx) = create_tracker_with_rx();

        let token = tracker.start("stalint", 100);
        assert_eq!(drain_count(&mut rx), 1); // initial 0%

        assert!(tracker.progress(&token, 50));
        assert_eq!(drain_count(&mut rx), 1); // 50%

        tracker.finish(&token);
        assert_eq!(drain_count(&mut rx), 1); // final 100%
    }

    #[test]
    fn test_finish_unknown_token_sends_nothing() {
        let (tracker, mut rx) = create_tracker_with_rx();

        tracker.finish(&"unknown-token".to_string());
        assert_eq!(drain_count(&mut rx), 0);
    }

    #[test]
    fn test_multiple_concurrent_operations() {
        let tracker = create_tracker();
        let token1 = tracker.start("stalint", 100);
        let token2 = tracker.start("dictator", 50);

        assert!(tracker.operations.contains_key(&token1));
        assert!(tracker.operations.contains_key(&token2));

        // Update both independently
        assert!(tracker.progress(&token1, 50));
        assert!(tracker.progress(&token2, 25));

        let op1 = tracker.operations.get(&token1).unwrap();
        let op2 = tracker.operations.get(&token2).unwrap();

        assert_eq!((op1.current, op1.total), (50, 100));
        assert_eq!((op2.current, op2.total), (25, 50));
        drop(op1);
        drop(op2);

        // Finish one
        tracker.finish(&token1);
        assert!(!tracker.operations.contains_key(&token1));
        assert!(tracker.operations.contains_key(&token2));
    }

    #[test]
    fn test_cleanup_keeps_fresh_operations() {
        let tracker = create_tracker();
        let token = tracker.start("test", 100);

        tracker.cleanup_stale();

        assert!(tracker.operations.contains_key(&token));
    }

    #[test]
    fn test_cleanup_stale() {
        let tracker = create_tracker();
        let token = tracker.start("test", 100);

        assert!(tracker.operations.contains_key(&token));

        // `Instant - Duration` panics if the monotonic clock is too close to its
        // origin (e.g. shortly after boot on some platforms), so back-date safely.
        let Some(aged) = Instant::now().checked_sub(Duration::from_secs(700)) else {
            return; // cannot construct an 11+ minute old Instant on this host
        };
        if let Some(mut op) = tracker.operations.get_mut(&token) {
            op.start_time = aged; // 11+ minutes ago
        }

        tracker.cleanup_stale();

        // Should be removed
        assert!(!tracker.operations.contains_key(&token));
    }
}