Skip to main content

harn_vm/
mcp_progress.rs

1//! MCP `notifications/progress` plumbing — server-to-client progress
2//! updates emitted from a long-running tool handler.
3//!
4//! The MCP spec (2025-11-25) lets a client opt into progress updates by
5//! attaching `_meta.progressToken` to a request. While the matching tool
6//! is in flight, the server may emit any number of
7//! `notifications/progress` notifications carrying the same token. The
8//! server must not emit progress for requests without a token, and
9//! progress values must strictly increase per token.
10//!
11//! Mechanics mirror the elicitation bus
12//! (`crate::mcp_elicit::ElicitationBus`): a per-connection [`ProgressBus`]
13//! wraps the transport's outbound JSON sink (installed thread-locally
14//! via [`install_active_bus`]), and a per-call [`ProgressContext`] is
15//! bound for the duration of a tool handler future via [`scope_context`]
16//! (a tokio task-local) so that helpers — notably the
17//! `mcp_report_progress` stdlib builtin — can find the right token
18//! without taking it as an explicit argument. The split between
19//! thread-local bus and task-local context matters: adapters spawn
20//! concurrent tool calls onto a shared `LocalSet`, so a thread-local
21//! context would race across awaits.
22//!
23//! Spec: <https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/progress>
24
25use std::cell::RefCell;
26use std::sync::{Arc, Mutex};
27
28use serde_json::{json, Value as JsonValue};
29
30/// Outbound JSON sink for progress notifications.
31///
32/// We accept a closure rather than a concrete `mpsc::UnboundedSender` so
33/// HTTP transports — which feed an `axum::Sse` stream via a wrapping
34/// closure — can install the same kind of bus as stdio without an
35/// adapter shim.
36pub type OutboundFn = Arc<dyn Fn(JsonValue) + Send + Sync>;
37
38/// Per-connection progress notifier.
39///
40/// Cheap to clone — every clone shares the same outbound sink and the
41/// same "last-progress" map used to enforce monotonicity.
42#[derive(Clone)]
43pub struct ProgressBus {
44    outbound: OutboundFn,
45    last_progress: Arc<Mutex<std::collections::HashMap<String, f64>>>,
46}
47
48impl std::fmt::Debug for ProgressBus {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        f.debug_struct("ProgressBus").finish_non_exhaustive()
51    }
52}
53
54impl ProgressBus {
55    pub fn new(outbound: OutboundFn) -> Self {
56        Self {
57            outbound,
58            last_progress: Arc::new(Mutex::new(std::collections::HashMap::new())),
59        }
60    }
61
62    /// Convenience constructor that wraps a `tokio::sync::mpsc`
63    /// unbounded sender — the shape used by the stdio MCP server.
64    pub fn from_mpsc(tx: tokio::sync::mpsc::UnboundedSender<JsonValue>) -> Self {
65        Self::new(Arc::new(move |message| {
66            let _ = tx.send(message);
67        }))
68    }
69
70    /// Emit a `notifications/progress` notification for `token`.
71    ///
72    /// Per spec the `progress` value MUST monotonically increase across
73    /// calls with the same token. We silently drop any update that would
74    /// regress so a buggy handler can't violate the contract on the
75    /// wire; the user-visible builtin returns `false` in that case so
76    /// scripts can detect it if they care.
77    pub fn report(
78        &self,
79        token: &JsonValue,
80        progress: f64,
81        total: Option<f64>,
82        message: Option<String>,
83    ) -> bool {
84        if !is_valid_progress_token(token) {
85            return false;
86        }
87        if !progress.is_finite() {
88            return false;
89        }
90        if let Some(total) = total {
91            if !total.is_finite() {
92                return false;
93            }
94        }
95        let key = canonical_token(token);
96        {
97            let mut last = self.last_progress.lock().expect("progress map poisoned");
98            if let Some(previous) = last.get(&key).copied() {
99                if progress <= previous {
100                    return false;
101                }
102            }
103            last.insert(key, progress);
104        }
105        let mut params = serde_json::Map::new();
106        params.insert("progressToken".to_string(), token.clone());
107        params.insert("progress".to_string(), json!(progress));
108        if let Some(total) = total {
109            params.insert("total".to_string(), json!(total));
110        }
111        if let Some(message) = message {
112            params.insert("message".to_string(), JsonValue::String(message));
113        }
114        (self.outbound)(crate::jsonrpc::notification(
115            "notifications/progress",
116            JsonValue::Object(params),
117        ));
118        true
119    }
120}
121
122/// Per-call progress context — the bus plus the token the current
123/// request supplied. Bound for the lifetime of a tool handler future
124/// via [`scope_context`] so that `mcp_report_progress(...)` can find it
125/// without explicit threading.
126#[derive(Clone, Debug)]
127pub struct ProgressContext {
128    pub bus: ProgressBus,
129    pub token: JsonValue,
130}
131
132impl ProgressContext {
133    pub fn new(bus: ProgressBus, token: JsonValue) -> Self {
134        Self { bus, token }
135    }
136
137    pub fn report(&self, progress: f64, total: Option<f64>, message: Option<String>) -> bool {
138        self.bus.report(&self.token, progress, total, message)
139    }
140}
141
142tokio::task_local! {
143    /// Per-call progress context — the token plus the bus. Bound for
144    /// the lifetime of a single tool handler future via
145    /// [`scope_context`]. We use a tokio task-local rather than a
146    /// thread-local because adapters (e.g. `harn-serve`) spawn
147    /// concurrent tool calls onto a shared `LocalSet`; a thread-local
148    /// would let one task's await yield to another that overwrites the
149    /// context, which is the exact race tokio task-locals are designed
150    /// to avoid.
151    static CURRENT_CONTEXT: ProgressContext;
152}
153
154thread_local! {
155    static ACTIVE_BUS: RefCell<Option<ProgressBus>> = const { RefCell::new(None) };
156}
157
158/// Run `future` with `ctx` installed as the active progress context.
159/// When `ctx` is `None`, the future runs without a context (so
160/// [`current_context`] returns `None` from inside it). Use this rather
161/// than installing into a thread-local: tool handlers are async and
162/// concurrent on the same OS thread, so we need per-task scoping.
163pub async fn scope_context<F>(ctx: Option<ProgressContext>, future: F) -> F::Output
164where
165    F: std::future::Future,
166{
167    match ctx {
168        Some(ctx) => CURRENT_CONTEXT.scope(ctx, future).await,
169        None => future.await,
170    }
171}
172
173/// Snapshot the progress context for the current task, if any.
174pub fn current_context() -> Option<ProgressContext> {
175    CURRENT_CONTEXT.try_with(|ctx| ctx.clone()).ok()
176}
177
178/// Install a connection-scoped [`ProgressBus`] for the current thread.
179/// Connections are single-threaded (the dispatch loop owns one OS
180/// thread or LocalSet), so a thread-local is the right scope here:
181/// every per-call task derived from this connection sees the same bus.
182pub fn install_active_bus(bus: Option<ProgressBus>) -> Option<ProgressBus> {
183    ACTIVE_BUS.with(|cell| std::mem::replace(&mut *cell.borrow_mut(), bus))
184}
185
186/// Snapshot the active connection-scoped progress bus, if any.
187pub fn active_bus() -> Option<ProgressBus> {
188    ACTIVE_BUS.with(|cell| cell.borrow().clone())
189}
190
191/// RAII guard for [`install_active_bus`]. Connection-scoped, so a
192/// thread-local guard is correct here.
193pub struct ActiveBusGuard {
194    previous: Option<ProgressBus>,
195}
196
197impl ActiveBusGuard {
198    pub fn install(bus: Option<ProgressBus>) -> Self {
199        Self {
200            previous: install_active_bus(bus),
201        }
202    }
203}
204
205impl Drop for ActiveBusGuard {
206    fn drop(&mut self) {
207        install_active_bus(self.previous.take());
208    }
209}
210
211/// MCP progress tokens are constrained to strings or numbers (no nulls,
212/// objects, arrays, or booleans). Validate at the boundary so we never
213/// echo a malformed token back to the client.
214pub fn is_valid_progress_token(value: &JsonValue) -> bool {
215    matches!(value, JsonValue::String(_) | JsonValue::Number(_))
216}
217
218/// Coerce JSON-RPC tokens (strings or numbers) into a single string key
219/// for the per-token monotonicity check.
220fn canonical_token(value: &JsonValue) -> String {
221    if let Some(s) = value.as_str() {
222        return s.to_string();
223    }
224    if let Some(n) = value.as_i64() {
225        return n.to_string();
226    }
227    if let Some(n) = value.as_u64() {
228        return n.to_string();
229    }
230    if let Some(n) = value.as_f64() {
231        return n.to_string();
232    }
233    value.to_string()
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// A bus whose sink records every outbound message, together with a
    /// handle to the recorded messages.
    fn recording_bus() -> (ProgressBus, Arc<Mutex<Vec<JsonValue>>>) {
        let sent: Arc<Mutex<Vec<JsonValue>>> = Arc::default();
        let sink = Arc::clone(&sent);
        let bus = ProgressBus::new(Arc::new(move |message| {
            sink.lock().expect("captured progress poisoned").push(message);
        }));
        (bus, sent)
    }

    #[test]
    fn reports_progress_with_monotonic_check() {
        let (bus, sent) = recording_bus();
        let token = json!("tok");
        assert!(bus.report(&token, 0.25, Some(1.0), Some("a".into())));
        assert!(bus.report(&token, 0.5, Some(1.0), None));
        // A repeat and a decrease both count as regressions.
        assert!(!bus.report(&token, 0.5, Some(1.0), None));
        assert!(!bus.report(&token, 0.4, Some(1.0), None));

        let sent = sent.lock().unwrap();
        assert_eq!(sent.len(), 2);
        let (first, second) = (&sent[0], &sent[1]);
        assert_eq!(first["method"], json!("notifications/progress"));
        let params = &first["params"];
        assert_eq!(params["progressToken"], json!("tok"));
        assert_eq!(params["progress"], json!(0.25));
        assert_eq!(params["total"], json!(1.0));
        assert_eq!(params["message"], json!("a"));
        assert!(second["params"].get("message").is_none());
    }

    #[test]
    fn reports_progress_for_numeric_token_independently() {
        let (bus, sent) = recording_bus();
        assert!(bus.report(&json!(1), 0.1, None, None));
        // A different token has its own sequence, so a smaller value is fine.
        assert!(bus.report(&json!("tok"), 0.05, None, None));
        assert_eq!(sent.lock().unwrap().len(), 2);
    }

    #[test]
    fn rejects_non_finite_or_invalid_token() {
        let (bus, sent) = recording_bus();
        let outcomes = [
            bus.report(&JsonValue::Null, 0.1, None, None),
            bus.report(&json!(true), 0.1, None, None),
            bus.report(&json!("tok"), f64::NAN, None, None),
            bus.report(&json!("tok"), 0.1, Some(f64::INFINITY), None),
        ];
        assert!(outcomes.iter().all(|accepted| !accepted));
        assert!(sent.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn scope_context_is_visible_inside_and_absent_outside() {
        assert!(current_context().is_none());
        let (bus, _) = recording_bus();
        let ctx = ProgressContext::new(bus, json!("tok"));
        scope_context(Some(ctx), async {
            assert!(current_context().is_some());
        })
        .await;
        assert!(current_context().is_none());
    }

    #[tokio::test]
    async fn scope_context_isolates_concurrent_tasks() {
        let (bus, sent) = recording_bus();
        let first = scope_context(Some(ProgressContext::new(bus.clone(), json!("a"))), async {
            tokio::task::yield_now().await;
            current_context().unwrap().token
        });
        let second = scope_context(Some(ProgressContext::new(bus, json!("b"))), async {
            tokio::task::yield_now().await;
            current_context().unwrap().token
        });
        let (seen_first, seen_second) = tokio::join!(first, second);
        assert_eq!(seen_first, json!("a"));
        assert_eq!(seen_second, json!("b"));
        assert!(sent.lock().unwrap().is_empty());
    }
}