// agentic_memory_mcp/streaming/progress.rs
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::{mpsc, RwLock};

use crate::types::{JsonRpcNotification, McpResult, ProgressParams, ProgressToken};
/// Book-keeping for one operation registered with `ProgressTracker`.
#[derive(Debug)]
struct ProgressState {
    /// Expected final progress value, if it was known when the operation started.
    total: Option<f64>,
    /// Most recent progress value recorded for the operation.
    // NOTE(review): this field is only ever written in this file; no reader is visible here.
    current: f64,
    /// Set once the operation has been cancelled.
    cancelled: bool,
}
18pub struct ProgressTracker {
20 active: Arc<RwLock<HashMap<String, ProgressState>>>,
21 notification_tx: mpsc::Sender<JsonRpcNotification>,
22}
23
24impl ProgressTracker {
25 pub fn new(notification_tx: mpsc::Sender<JsonRpcNotification>) -> Self {
27 Self {
28 active: Arc::new(RwLock::new(HashMap::new())),
29 notification_tx,
30 }
31 }
32
33 pub async fn start(&self, total: Option<f64>) -> String {
35 let token = uuid::Uuid::new_v4().to_string();
36 let state = ProgressState {
37 total,
38 current: 0.0,
39 cancelled: false,
40 };
41 self.active.write().await.insert(token.clone(), state);
42 token
43 }
44
45 pub async fn update(&self, token: &str, current: f64) -> McpResult<()> {
47 let total = {
48 let mut active = self.active.write().await;
49 if let Some(state) = active.get_mut(token) {
50 state.current = current;
51 state.total
52 } else {
53 return Ok(());
54 }
55 };
56
57 let params = ProgressParams {
58 progress_token: ProgressToken::String(token.to_string()),
59 progress: current,
60 total,
61 };
62
63 let notification = JsonRpcNotification::new(
64 "notifications/progress".to_string(),
65 Some(serde_json::to_value(params).unwrap_or_default()),
66 );
67
68 let _ = self.notification_tx.send(notification).await;
69 Ok(())
70 }
71
72 pub async fn cancel(&self, token: &str) {
74 let mut active = self.active.write().await;
75 if let Some(state) = active.get_mut(token) {
76 state.cancelled = true;
77 }
78 }
79
80 pub async fn complete(&self, token: &str) {
82 self.active.write().await.remove(token);
83 }
84
85 pub async fn is_cancelled(&self, token: &str) -> bool {
87 self.active
88 .read()
89 .await
90 .get(token)
91 .map(|s| s.cancelled)
92 .unwrap_or(true)
93 }
94}