Skip to main content

agentic_memory_mcp/streaming/
progress.rs

1//! Progress token handling for long-running operations.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::{mpsc, RwLock};
7
8use crate::types::{JsonRpcNotification, McpResult, ProgressParams, ProgressToken};
9
/// Bookkeeping kept for a single operation while it is being tracked.
///
/// One value of this type is stored per progress token in the tracker's
/// map of active operations.
#[derive(Debug)]
struct ProgressState {
    /// Expected total amount of work, when the caller supplied one at start.
    total: Option<f64>,
    /// Most recently recorded progress value; a new operation begins at `0.0`.
    current: f64,
    /// Whether cancellation has been requested for this operation.
    cancelled: bool,
}
17
18/// Tracks progress for long-running operations and sends notifications.
19pub struct ProgressTracker {
20    active: Arc<RwLock<HashMap<String, ProgressState>>>,
21    notification_tx: mpsc::Sender<JsonRpcNotification>,
22}
23
24impl ProgressTracker {
25    /// Create a new progress tracker with a notification channel.
26    pub fn new(notification_tx: mpsc::Sender<JsonRpcNotification>) -> Self {
27        Self {
28            active: Arc::new(RwLock::new(HashMap::new())),
29            notification_tx,
30        }
31    }
32
33    /// Start tracking a new operation. Returns a unique token.
34    pub async fn start(&self, total: Option<f64>) -> String {
35        let token = uuid::Uuid::new_v4().to_string();
36        let state = ProgressState {
37            total,
38            current: 0.0,
39            cancelled: false,
40        };
41        self.active.write().await.insert(token.clone(), state);
42        token
43    }
44
45    /// Update the progress of an operation.
46    pub async fn update(&self, token: &str, current: f64) -> McpResult<()> {
47        let total = {
48            let mut active = self.active.write().await;
49            if let Some(state) = active.get_mut(token) {
50                state.current = current;
51                state.total
52            } else {
53                return Ok(());
54            }
55        };
56
57        let params = ProgressParams {
58            progress_token: ProgressToken::String(token.to_string()),
59            progress: current,
60            total,
61        };
62
63        let notification = JsonRpcNotification::new(
64            "notifications/progress".to_string(),
65            Some(serde_json::to_value(params).unwrap_or_default()),
66        );
67
68        let _ = self.notification_tx.send(notification).await;
69        Ok(())
70    }
71
72    /// Mark an operation as cancelled.
73    pub async fn cancel(&self, token: &str) {
74        let mut active = self.active.write().await;
75        if let Some(state) = active.get_mut(token) {
76            state.cancelled = true;
77        }
78    }
79
80    /// Complete and remove an operation.
81    pub async fn complete(&self, token: &str) {
82        self.active.write().await.remove(token);
83    }
84
85    /// Check if an operation has been cancelled.
86    pub async fn is_cancelled(&self, token: &str) -> bool {
87        self.active
88            .read()
89            .await
90            .get(token)
91            .map(|s| s.cancelled)
92            .unwrap_or(true)
93    }
94}