construct/mcp_server/
progress.rs1use crate::mcp_server::session::ProgressEvent;
32use crate::tools::progress::{ProgressSink, ProgressToken};
33use serde_json::{Value, json};
34use std::sync::atomic::{AtomicU64, Ordering};
35use tokio::sync::broadcast;
36use tokio::sync::mpsc::UnboundedSender;
37
38pub struct McpProgressSink {
48 tx: UnboundedSender<Value>,
49 requested_token: Option<u64>,
50 counter: AtomicU64,
51 session_events: Option<broadcast::Sender<ProgressEvent>>,
52 tool_name: Option<String>,
53}
54
55impl McpProgressSink {
56 pub fn new(tx: UnboundedSender<Value>, requested_token: Option<u64>) -> Self {
60 Self {
61 tx,
62 requested_token,
63 counter: AtomicU64::new(0),
64 session_events: None,
65 tool_name: None,
66 }
67 }
68
69 pub fn with_session(
72 tx: UnboundedSender<Value>,
73 requested_token: Option<u64>,
74 session_events: broadcast::Sender<ProgressEvent>,
75 tool_name: impl Into<String>,
76 ) -> Self {
77 Self {
78 tx,
79 requested_token,
80 counter: AtomicU64::new(0),
81 session_events: Some(session_events),
82 tool_name: Some(tool_name.into()),
83 }
84 }
85}
86
87impl ProgressSink for McpProgressSink {
88 fn new_token(&self) -> ProgressToken {
89 if let Some(t) = self.requested_token {
90 return ProgressToken(t);
91 }
92 ProgressToken(self.counter.fetch_add(1, Ordering::Relaxed))
93 }
94
95 fn notify(
96 &self,
97 token: ProgressToken,
98 progress: u64,
99 total: Option<u64>,
100 message: Option<&str>,
101 ) {
102 let mut params = json!({
103 "progressToken": token.value(),
104 "progress": progress,
105 });
106 if let Some(total) = total {
107 params["total"] = json!(total);
108 }
109 if let Some(msg) = message {
110 params["message"] = json!(msg);
111 }
112 let envelope = json!({
113 "jsonrpc": "2.0",
114 "method": "notifications/progress",
115 "params": params,
116 });
117 let _ = self.tx.send(envelope);
119
120 if let Some(bus) = &self.session_events {
123 let ev = ProgressEvent::new(
124 token.value(),
125 progress,
126 total,
127 message.map(str::to_string),
128 self.tool_name.clone(),
129 );
130 let _ = bus.send(ev);
131 }
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// A requested token is echoed back and appears in the emitted envelope
    /// together with all supplied fields.
    #[tokio::test]
    async fn emits_progress_envelope_with_requested_token() {
        let (envelope_tx, mut envelope_rx) = mpsc::unbounded_channel();
        let sink = McpProgressSink::new(envelope_tx, Some(42));

        let token = sink.new_token();
        assert_eq!(token.value(), 42);

        sink.notify(token, 1, Some(3), Some("one of three"));

        let envelope = envelope_rx.recv().await.unwrap();
        assert_eq!(envelope["method"], "notifications/progress");
        let params = &envelope["params"];
        assert_eq!(params["progressToken"], 42);
        assert_eq!(params["progress"], 1);
        assert_eq!(params["total"], 3);
        assert_eq!(params["message"], "one of three");
    }

    /// Without a requested token, consecutive tokens are distinct.
    #[tokio::test]
    async fn mints_token_when_none_supplied() {
        let (envelope_tx, _envelope_rx) = mpsc::unbounded_channel();
        let sink = McpProgressSink::new(envelope_tx, None);

        let first = sink.new_token();
        let second = sink.new_token();
        assert_ne!(first.value(), second.value());
    }

    /// A sink built with a session bus publishes each report there too.
    #[tokio::test]
    async fn dual_fanout_publishes_to_session_broadcast() {
        let (envelope_tx, _envelope_rx) = mpsc::unbounded_channel();
        let (bus_tx, mut bus_rx) = broadcast::channel(8);
        let sink = McpProgressSink::with_session(envelope_tx, Some(9), bus_tx, "notion");

        let token = sink.new_token();
        sink.notify(token, 2, Some(5), Some("doing a thing"));

        let event = bus_rx.recv().await.unwrap();
        assert_eq!(event.token, 9);
        assert_eq!(event.progress, 2);
        assert_eq!(event.total, Some(5));
        assert_eq!(event.message.as_deref(), Some("doing a thing"));
        assert_eq!(event.tool.as_deref(), Some("notion"));
        assert!(!event.timestamp.is_empty());
    }

    /// Reporting with no bus subscribers must not panic.
    #[tokio::test]
    async fn dual_fanout_swallows_no_subscriber_error() {
        let (envelope_tx, _envelope_rx) = mpsc::unbounded_channel();
        let (bus_tx, bus_rx) = broadcast::channel(8);
        // Drop the only receiver so the bus has no subscribers left.
        drop(bus_rx);
        let sink = McpProgressSink::with_session(envelope_tx, None, bus_tx, "notion");

        let token = sink.new_token();
        sink.notify(token, 1, None, None);
    }
}