mcp_host/managers/
progress.rs1use dashmap::DashMap;
8use serde_json::json;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
12use tokio::sync::mpsc;
13
14use crate::transport::traits::JsonRpcNotification;
15
/// Bookkeeping for one operation registered with `ProgressTracker::start`.
#[derive(Debug, Clone)]
struct OperationState {
    /// Operation kind passed to `start` (also the token's prefix).
    // Never read by the tracker; kept so it shows up in the derived `Debug`
    // output. The allow silences the dead-code lint, which ignores derived impls.
    #[allow(dead_code)]
    op_type: String,
    /// Copy of the progress token identifying this operation.
    // Duplicates the map key and is never read; retained for `Debug` output
    // only, hence the allow.
    #[allow(dead_code)]
    op_id: String,
    /// Total number of steps announced when the operation started.
    total: u32,
    /// Most recently recorded step (the tracker keeps it `<= total`).
    current: u32,
    /// When the operation was registered; used to detect stale operations.
    start_time: Instant,
}
32
/// Identifier returned by `ProgressTracker::start` and sent as the
/// `progressToken` field of every progress notification for that operation.
pub type ProgressToken = String;
35
36#[derive(Clone)]
47pub struct ProgressTracker {
48 token_counter: Arc<AtomicU64>,
50 operations: Arc<DashMap<String, OperationState>>,
52 notif_tx: mpsc::UnboundedSender<JsonRpcNotification>,
54}
55
56impl ProgressTracker {
57 pub fn new(notif_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
63 Self {
64 token_counter: Arc::new(AtomicU64::new(0)),
65 operations: Arc::new(DashMap::new()),
66 notif_tx,
67 }
68 }
69
70 fn generate_token(&self, op_type: &str) -> ProgressToken {
72 let counter = self.token_counter.fetch_add(1, Ordering::SeqCst);
73 let timestamp = SystemTime::now()
74 .duration_since(UNIX_EPOCH)
75 .unwrap_or_default()
76 .as_secs();
77 format!("{op_type}-{timestamp}-{counter}")
78 }
79
80 pub fn start(&self, op_type: &str, total: u32) -> ProgressToken {
93 let token = self.generate_token(op_type);
94 let op_id = token.clone();
95
96 let op_state = OperationState {
97 op_type: op_type.to_string(),
98 op_id,
99 total,
100 current: 0,
101 start_time: Instant::now(),
102 };
103
104 self.operations.insert(token.clone(), op_state);
105
106 self.send_notification(&token, 0, total);
108
109 token
110 }
111
112 pub fn progress(&self, token: &ProgressToken, current: u32) -> bool {
123 if let Some(mut op) = self.operations.get_mut(token) {
124 op.current = current.min(op.total);
126 let progress = op.current;
127 let total = op.total;
128 drop(op); self.send_notification(token, progress, total);
131 true
132 } else {
133 false
134 }
135 }
136
137 pub fn finish(&self, token: &ProgressToken) {
145 if let Some((_key, op)) = self.operations.remove(token) {
146 self.send_notification(token, op.total, op.total);
147 }
148 }
149
150 fn send_notification(&self, token: &ProgressToken, progress: u32, total: u32) {
152 let notification = JsonRpcNotification::new(
153 "notifications/progress",
154 Some(json!({
155 "progressToken": token,
156 "progress": progress,
157 "total": total
158 })),
159 );
160
161 let _ = self.notif_tx.send(notification);
162 }
163
164 pub fn cleanup_stale(&self) {
169 let now = Instant::now();
170 let timeout = Duration::from_secs(600); self.operations
173 .retain(|_token, op| now.duration_since(op.start_time) < timeout);
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tracker whose notification receiver is dropped when this helper
    /// returns. Sends then fail, and the tracker ignores that by design.
    fn create_tracker() -> ProgressTracker {
        let (tx, _rx) = mpsc::unbounded_channel();
        ProgressTracker::new(tx)
    }

    #[test]
    fn test_token_format() {
        let tracker = create_tracker();
        let token = tracker.generate_token("stalint");

        let parts: Vec<&str> = token.split('-').collect();
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0], "stalint");
        assert!(parts[1].parse::<u64>().is_ok());
        assert!(parts[2].parse::<u64>().is_ok());
    }

    #[test]
    fn test_token_uniqueness() {
        let tracker = create_tracker();
        let token1 = tracker.generate_token("stalint");
        let token2 = tracker.generate_token("stalint");

        assert_ne!(token1, token2);
    }

    #[test]
    fn test_start_registers_operation() {
        let tracker = create_tracker();
        let token = tracker.start("stalint", 100);

        assert!(tracker.operations.contains_key(&token));
        let op = tracker.operations.get(&token).unwrap();
        assert_eq!(op.current, 0);
        assert_eq!(op.total, 100);
    }

    #[test]
    fn test_progress_updates() {
        let tracker = create_tracker();
        let token = tracker.start("dictator", 50);

        assert!(tracker.progress(&token, 25));
        let op = tracker.operations.get(&token).unwrap();
        assert_eq!(op.current, 25);
        assert_eq!(op.total, 50);
        drop(op);

        assert!(tracker.progress(&token, 50));
        let op = tracker.operations.get(&token).unwrap();
        assert_eq!(op.current, 50);
        assert_eq!(op.total, 50);
    }

    #[test]
    fn test_progress_clamps_to_total() {
        let tracker = create_tracker();
        let token = tracker.start("supremecourt", 10);

        assert!(tracker.progress(&token, 99));
        let op = tracker.operations.get(&token).unwrap();
        assert_eq!(op.current, 10);
        assert_eq!(op.total, 10);
    }

    #[test]
    fn test_finish_removes_operation() {
        let tracker = create_tracker();
        let token = tracker.start("stalint", 100);

        assert!(tracker.operations.contains_key(&token));
        tracker.finish(&token);
        assert!(!tracker.operations.contains_key(&token));
    }

    #[test]
    fn test_finish_unknown_token_is_noop() {
        let tracker = create_tracker();
        let token = tracker.start("stalint", 100);

        tracker.finish(&"unknown-token".to_string());
        assert!(tracker.operations.contains_key(&token));
    }

    #[test]
    fn test_progress_after_finish_is_rejected() {
        let tracker = create_tracker();
        let token = tracker.start("stalint", 100);

        tracker.finish(&token);
        assert!(!tracker.progress(&token, 10));
    }

    #[test]
    fn test_progress_unknown_token() {
        let tracker = create_tracker();

        assert!(!tracker.progress(&"unknown-token".to_string(), 50));
    }

    #[test]
    fn test_multiple_concurrent_operations() {
        let tracker = create_tracker();
        let token1 = tracker.start("stalint", 100);
        let token2 = tracker.start("dictator", 50);

        assert!(tracker.operations.contains_key(&token1));
        assert!(tracker.operations.contains_key(&token2));

        assert!(tracker.progress(&token1, 50));
        assert!(tracker.progress(&token2, 25));

        let op1 = tracker.operations.get(&token1).unwrap();
        let op2 = tracker.operations.get(&token2).unwrap();
        assert_eq!((op1.current, op1.total), (50, 100));
        assert_eq!((op2.current, op2.total), (25, 50));
        drop(op1);
        drop(op2);

        tracker.finish(&token1);
        assert!(!tracker.operations.contains_key(&token1));
        assert!(tracker.operations.contains_key(&token2));
    }

    #[test]
    fn test_cleanup_stale() {
        let tracker = create_tracker();
        let stale = tracker.start("test", 100);
        let fresh = tracker.start("test", 100);

        assert!(tracker.operations.contains_key(&stale));

        // `Instant - Duration` panics when the result is not representable.
        // On some platforms (e.g. Windows, where `Instant` counts from boot)
        // that happens if the machine has been up for less than 700s. So
        // back-date with `checked_sub`, and skip the test when that is not
        // possible.
        let backdated = match Instant::now().checked_sub(Duration::from_secs(700)) {
            Some(instant) => instant,
            None => return,
        };
        tracker
            .operations
            .get_mut(&stale)
            .expect("operation registered by start")
            .start_time = backdated;

        tracker.cleanup_stale();

        assert!(!tracker.operations.contains_key(&stale));
        assert!(tracker.operations.contains_key(&fresh));
    }
}