1use std::cell::RefCell;
26use std::sync::{Arc, Mutex};
27
28use serde_json::{json, Value as JsonValue};
29
30pub type OutboundFn = Arc<dyn Fn(JsonValue) + Send + Sync>;
37
38#[derive(Clone)]
43pub struct ProgressBus {
44 outbound: OutboundFn,
45 last_progress: Arc<Mutex<std::collections::HashMap<String, f64>>>,
46}
47
48impl std::fmt::Debug for ProgressBus {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.debug_struct("ProgressBus").finish_non_exhaustive()
51 }
52}
53
54impl ProgressBus {
55 pub fn new(outbound: OutboundFn) -> Self {
56 Self {
57 outbound,
58 last_progress: Arc::new(Mutex::new(std::collections::HashMap::new())),
59 }
60 }
61
62 pub fn from_mpsc(tx: tokio::sync::mpsc::UnboundedSender<JsonValue>) -> Self {
65 Self::new(Arc::new(move |message| {
66 let _ = tx.send(message);
67 }))
68 }
69
70 pub fn report(
78 &self,
79 token: &JsonValue,
80 progress: f64,
81 total: Option<f64>,
82 message: Option<String>,
83 ) -> bool {
84 if !is_valid_progress_token(token) {
85 return false;
86 }
87 if !progress.is_finite() {
88 return false;
89 }
90 if let Some(total) = total {
91 if !total.is_finite() {
92 return false;
93 }
94 }
95 let key = canonical_token(token);
96 {
97 let mut last = self.last_progress.lock().expect("progress map poisoned");
98 if let Some(previous) = last.get(&key).copied() {
99 if progress <= previous {
100 return false;
101 }
102 }
103 last.insert(key, progress);
104 }
105 let mut params = serde_json::Map::new();
106 params.insert("progressToken".to_string(), token.clone());
107 params.insert("progress".to_string(), json!(progress));
108 if let Some(total) = total {
109 params.insert("total".to_string(), json!(total));
110 }
111 if let Some(message) = message {
112 params.insert("message".to_string(), JsonValue::String(message));
113 }
114 (self.outbound)(crate::jsonrpc::notification(
115 "notifications/progress",
116 JsonValue::Object(params),
117 ));
118 true
119 }
120}
121
122#[derive(Clone, Debug)]
127pub struct ProgressContext {
128 pub bus: ProgressBus,
129 pub token: JsonValue,
130}
131
132impl ProgressContext {
133 pub fn new(bus: ProgressBus, token: JsonValue) -> Self {
134 Self { bus, token }
135 }
136
137 pub fn report(&self, progress: f64, total: Option<f64>, message: Option<String>) -> bool {
138 self.bus.report(&self.token, progress, total, message)
139 }
140}
141
142tokio::task_local! {
143 static CURRENT_CONTEXT: ProgressContext;
152}
153
154thread_local! {
155 static ACTIVE_BUS: RefCell<Option<ProgressBus>> = const { RefCell::new(None) };
156}
157
158pub async fn scope_context<F>(ctx: Option<ProgressContext>, future: F) -> F::Output
164where
165 F: std::future::Future,
166{
167 match ctx {
168 Some(ctx) => CURRENT_CONTEXT.scope(ctx, future).await,
169 None => future.await,
170 }
171}
172
173pub fn current_context() -> Option<ProgressContext> {
175 CURRENT_CONTEXT.try_with(|ctx| ctx.clone()).ok()
176}
177
178pub fn install_active_bus(bus: Option<ProgressBus>) -> Option<ProgressBus> {
183 ACTIVE_BUS.with(|cell| std::mem::replace(&mut *cell.borrow_mut(), bus))
184}
185
186pub fn active_bus() -> Option<ProgressBus> {
188 ACTIVE_BUS.with(|cell| cell.borrow().clone())
189}
190
191pub struct ActiveBusGuard {
194 previous: Option<ProgressBus>,
195}
196
197impl ActiveBusGuard {
198 pub fn install(bus: Option<ProgressBus>) -> Self {
199 Self {
200 previous: install_active_bus(bus),
201 }
202 }
203}
204
205impl Drop for ActiveBusGuard {
206 fn drop(&mut self) {
207 install_active_bus(self.previous.take());
208 }
209}
210
211pub fn is_valid_progress_token(value: &JsonValue) -> bool {
215 matches!(value, JsonValue::String(_) | JsonValue::Number(_))
216}
217
218fn canonical_token(value: &JsonValue) -> String {
221 if let Some(s) = value.as_str() {
222 return s.to_string();
223 }
224 if let Some(n) = value.as_i64() {
225 return n.to_string();
226 }
227 if let Some(n) = value.as_u64() {
228 return n.to_string();
229 }
230 if let Some(n) = value.as_f64() {
231 return n.to_string();
232 }
233 value.to_string()
234}
235
#[cfg(test)]
mod tests {
    // `Arc`, `Mutex`, `json!` and the module's items all come in via the glob.
    use super::*;

    fn capturing_bus() -> (ProgressBus, Arc<Mutex<Vec<JsonValue>>>) {
        let captured: Arc<Mutex<Vec<JsonValue>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_for_sink = captured.clone();
        let bus = ProgressBus::new(Arc::new(move |message| {
            captured_for_sink
                .lock()
                .expect("captured progress poisoned")
                .push(message);
        }));
        (bus, captured)
    }

    #[test]
    fn reports_progress_with_monotonic_check() {
        let (bus, captured) = capturing_bus();
        assert!(bus.report(&json!("tok"), 0.25, Some(1.0), Some("a".into())));
        assert!(bus.report(&json!("tok"), 0.5, Some(1.0), None));
        assert!(!bus.report(&json!("tok"), 0.5, Some(1.0), None));
        assert!(!bus.report(&json!("tok"), 0.4, Some(1.0), None));
        let captured = captured.lock().unwrap();
        assert_eq!(captured.len(), 2);
        assert_eq!(captured[0]["method"], json!("notifications/progress"));
        assert_eq!(captured[0]["params"]["progressToken"], json!("tok"));
        assert_eq!(captured[0]["params"]["progress"], json!(0.25));
        assert_eq!(captured[0]["params"]["total"], json!(1.0));
        assert_eq!(captured[0]["params"]["message"], json!("a"));
        assert!(captured[1]["params"].get("message").is_none());
    }

    #[test]
    fn reports_progress_for_numeric_token_independently() {
        let (bus, captured) = capturing_bus();
        assert!(bus.report(&json!(1), 0.1, None, None));
        assert!(bus.report(&json!("tok"), 0.05, None, None));
        let captured = captured.lock().unwrap();
        assert_eq!(captured.len(), 2);
    }

    #[test]
    fn rejects_non_finite_or_invalid_token() {
        let (bus, captured) = capturing_bus();
        assert!(!bus.report(&JsonValue::Null, 0.1, None, None));
        assert!(!bus.report(&json!(true), 0.1, None, None));
        assert!(!bus.report(&json!("tok"), f64::NAN, None, None));
        assert!(!bus.report(&json!("tok"), 0.1, Some(f64::INFINITY), None));
        assert!(captured.lock().unwrap().is_empty());
    }

    #[test]
    fn progress_context_reports_with_its_own_token() {
        let (bus, captured) = capturing_bus();
        let ctx = ProgressContext::new(bus, json!(7));
        assert!(ctx.report(1.0, Some(2.0), None));
        assert!(!ctx.report(1.0, Some(2.0), None));
        let captured = captured.lock().unwrap();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0]["params"]["progressToken"], json!(7));
    }

    #[test]
    fn active_bus_guard_restores_previous_bus_on_drop() {
        assert!(active_bus().is_none());
        let (outer_bus, outer_captured) = capturing_bus();
        let (inner_bus, inner_captured) = capturing_bus();
        {
            let _outer = ActiveBusGuard::install(Some(outer_bus));
            {
                let _inner = ActiveBusGuard::install(Some(inner_bus));
                assert!(active_bus()
                    .expect("inner bus installed")
                    .report(&json!("t"), 1.0, None, None));
            }
            // If the inner bus were still active this duplicate value would be
            // rejected, so success proves the outer bus was restored.
            assert!(active_bus()
                .expect("outer bus restored")
                .report(&json!("t"), 1.0, None, None));
        }
        assert!(active_bus().is_none());
        assert_eq!(inner_captured.lock().unwrap().len(), 1);
        assert_eq!(outer_captured.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn scope_context_is_visible_inside_and_absent_outside() {
        assert!(current_context().is_none());
        let (bus, _) = capturing_bus();
        let ctx = ProgressContext::new(bus, json!("tok"));
        scope_context(Some(ctx), async {
            assert!(current_context().is_some());
        })
        .await;
        assert!(current_context().is_none());
    }

    #[tokio::test]
    async fn scope_context_isolates_concurrent_tasks() {
        let (bus, captured) = capturing_bus();
        let ctx_a = ProgressContext::new(bus.clone(), json!("a"));
        let ctx_b = ProgressContext::new(bus, json!("b"));
        let task_a = scope_context(Some(ctx_a), async {
            tokio::task::yield_now().await;
            current_context().unwrap().token.clone()
        });
        let task_b = scope_context(Some(ctx_b), async {
            tokio::task::yield_now().await;
            current_context().unwrap().token.clone()
        });
        let (a, b) = tokio::join!(task_a, task_b);
        assert_eq!(a, json!("a"));
        assert_eq!(b, json!("b"));
        assert!(captured.lock().unwrap().is_empty());
    }
}