use std::cell::RefCell;
use std::sync::{Arc, Mutex};
use serde_json::{json, Value as JsonValue};
/// Sink that receives every outgoing message built by [`ProgressBus`].
///
/// Shared behind an `Arc` and required to be `Send + Sync` so a bus can be
/// cloned freely and used from any thread.
pub type OutboundFn = Arc<dyn Fn(JsonValue) + Sync + Send>;
/// Emits `notifications/progress` messages through a caller-supplied sink and
/// rejects any report whose progress does not strictly increase for its token.
///
/// Cloning is cheap: clones share both the sink and the per-token progress state.
#[derive(Clone)]
pub struct ProgressBus {
    /// Destination for every notification the bus produces.
    outbound: OutboundFn,
    /// Last accepted progress value per canonicalised token, shared across clones.
    last_progress: Arc<Mutex<std::collections::HashMap<String, f64>>>,
}
impl std::fmt::Debug for ProgressBus {
    /// Renders as `ProgressBus { .. }`. The sink is a `dyn Fn` with no `Debug`
    /// impl, and the progress map is internal bookkeeping, so no fields are listed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ProgressBus");
        builder.finish_non_exhaustive()
    }
}
impl ProgressBus {
    /// Creates a bus that passes every outgoing notification to `outbound`.
    pub fn new(outbound: OutboundFn) -> Self {
        Self {
            outbound,
            last_progress: Arc::new(Mutex::new(std::collections::HashMap::new())),
        }
    }

    /// Creates a bus that forwards notifications into an unbounded Tokio channel.
    ///
    /// Send errors (the receiver has been dropped) are deliberately ignored:
    /// progress reporting is best-effort and must not fail the caller.
    pub fn from_mpsc(tx: tokio::sync::mpsc::UnboundedSender<JsonValue>) -> Self {
        Self::new(Arc::new(move |message| {
            let _ = tx.send(message);
        }))
    }

    /// Sends a `notifications/progress` notification for `token`.
    ///
    /// `total` and `message` are added to the params only when they are `Some`.
    ///
    /// Returns `true` when the notification was handed to the outbound sink, and
    /// `false` when it was dropped because:
    /// - `token` is not a JSON string or number,
    /// - `progress` or `total` is NaN or infinite, or
    /// - `progress` is not strictly greater than the last value accepted for the
    ///   same token.
    pub fn report(
        &self,
        token: &JsonValue,
        progress: f64,
        total: Option<f64>,
        message: Option<String>,
    ) -> bool {
        let total_is_invalid = matches!(total, Some(t) if !t.is_finite());
        if !is_valid_progress_token(token) || !progress.is_finite() || total_is_invalid {
            return false;
        }

        let key = canonical_token(token);
        {
            let mut last = self.lock_last_progress();
            if matches!(last.get(&key), Some(&previous) if progress <= previous) {
                return false;
            }
            last.insert(key, progress);
        }
        // The lock is released before the sink runs, so a slow sink cannot stall
        // other reporters. A sink that reports again cannot deadlock on the
        // (non-reentrant) mutex either.

        let mut params = serde_json::Map::new();
        params.insert("progressToken".to_string(), token.clone());
        params.insert("progress".to_string(), json!(progress));
        if let Some(total) = total {
            params.insert("total".to_string(), json!(total));
        }
        if let Some(message) = message {
            params.insert("message".to_string(), JsonValue::String(message));
        }
        (self.outbound)(crate::jsonrpc::notification(
            "notifications/progress",
            JsonValue::Object(params),
        ));
        true
    }

    /// Discards the monotonicity state recorded for `token`.
    ///
    /// `report` remembers the last accepted value of every token it sees and never
    /// evicts it by itself. Call this once the work tracked by `token` has finished,
    /// so the map does not grow for the whole lifetime of the bus. The next `report`
    /// for the same token is then accepted regardless of earlier values.
    ///
    /// Returns `true` if state for `token` was present.
    pub fn forget(&self, token: &JsonValue) -> bool {
        if !is_valid_progress_token(token) {
            return false;
        }
        let key = canonical_token(token);
        self.lock_last_progress().remove(&key).is_some()
    }

    /// Locks the per-token progress map.
    ///
    /// A poisoned lock is recovered rather than propagated. The map only stores
    /// plain `f64` values written in single `insert`/`remove` calls, so it is
    /// still consistent after a panicking holder.
    fn lock_last_progress(
        &self,
    ) -> std::sync::MutexGuard<'_, std::collections::HashMap<String, f64>> {
        self.last_progress
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// A [`ProgressBus`] bundled with the progress token it reports under. Holders
/// can then report progress without passing the token around separately.
#[derive(Clone, Debug)]
pub struct ProgressContext {
    /// Bus that notifications are sent through.
    pub bus: ProgressBus,
    /// Progress token attached to every report made through this context.
    pub token: JsonValue,
}
impl ProgressContext {
    /// Bundles `bus` and `token` into a context.
    pub fn new(bus: ProgressBus, token: JsonValue) -> Self {
        ProgressContext { token, bus }
    }

    /// Reports progress under this context's token. See [`ProgressBus::report`]
    /// for the meaning of the arguments and of the returned flag.
    pub fn report(&self, progress: f64, total: Option<f64>, message: Option<String>) -> bool {
        let Self { bus, token } = self;
        bus.report(token, progress, total, message)
    }
}
tokio::task_local! {
/// Progress context of the current Tokio task. It is set for the duration of
/// `scope_context(Some(..), ..)` and read via `current_context`.
static CURRENT_CONTEXT: ProgressContext;
}
thread_local! {
/// Bus installed for the current OS thread via `install_active_bus` or
/// `ActiveBusGuard`, and `None` until something is installed. Because it is
/// per-thread, it does not follow an async task that resumes on another thread.
static ACTIVE_BUS: RefCell<Option<ProgressBus>> = const { RefCell::new(None) };
}
/// Awaits `future` with `ctx` installed as the task-local progress context.
///
/// When `ctx` is `None`, `future` is awaited directly and runs under whatever
/// context (if any) the caller already has.
pub async fn scope_context<F>(ctx: Option<ProgressContext>, future: F) -> F::Output
where
    F: std::future::Future,
{
    if let Some(installed) = ctx {
        return CURRENT_CONTEXT.scope(installed, future).await;
    }
    future.await
}
pub fn current_context() -> Option<ProgressContext> {
CURRENT_CONTEXT.try_with(|ctx| ctx.clone()).ok()
}
/// Makes `bus` this thread's active bus (or clears it when `None`) and returns
/// the bus it replaced.
pub fn install_active_bus(bus: Option<ProgressBus>) -> Option<ProgressBus> {
    ACTIVE_BUS.with(move |slot| slot.replace(bus))
}
/// Returns a clone of this thread's active bus, if one is installed.
pub fn active_bus() -> Option<ProgressBus> {
    ACTIVE_BUS.with(|slot| slot.borrow().as_ref().cloned())
}
/// Scope guard that installs a bus as this thread's active bus and reinstalls
/// the previously active bus when dropped.
///
/// The active bus lives in a thread-local slot, so the guard should be dropped
/// on the thread that created it. Holding it across an `.await` on a
/// multi-threaded runtime can restore the previous bus on a different thread.
#[must_use = "the previous active bus is restored as soon as the guard is dropped"]
pub struct ActiveBusGuard {
    /// Bus that was active before `install`; put back on drop.
    previous: Option<ProgressBus>,
}
impl ActiveBusGuard {
    /// Installs `bus` as the active bus (or clears the slot when `None`) until the
    /// returned guard is dropped.
    pub fn install(bus: Option<ProgressBus>) -> Self {
        Self {
            previous: install_active_bus(bus),
        }
    }
}
impl Drop for ActiveBusGuard {
    fn drop(&mut self) {
        // The value returned here is the bus this guard installed; dropping it is intended.
        install_active_bus(self.previous.take());
    }
}
pub fn is_valid_progress_token(value: &JsonValue) -> bool {
matches!(value, JsonValue::String(_) | JsonValue::Number(_))
}
fn canonical_token(value: &JsonValue) -> String {
if let Some(s) = value.as_str() {
return s.to_string();
}
if let Some(n) = value.as_i64() {
return n.to_string();
}
if let Some(n) = value.as_u64() {
return n.to_string();
}
if let Some(n) = value.as_f64() {
return n.to_string();
}
value.to_string()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bus whose sink appends every message to the returned vector.
    fn capturing_bus() -> (ProgressBus, Arc<Mutex<Vec<JsonValue>>>) {
        let captured: Arc<Mutex<Vec<JsonValue>>> = Arc::new(Mutex::new(Vec::new()));
        let captured_for_sink = captured.clone();
        let bus = ProgressBus::new(Arc::new(move |message| {
            captured_for_sink
                .lock()
                .expect("captured progress poisoned")
                .push(message);
        }));
        (bus, captured)
    }

    /// True when this thread's active bus is a clone of `bus` (same shared state).
    fn is_active(bus: &ProgressBus) -> bool {
        match active_bus() {
            Some(current) => Arc::ptr_eq(&current.last_progress, &bus.last_progress),
            None => false,
        }
    }

    #[test]
    fn reports_progress_with_monotonic_check() {
        let (bus, captured) = capturing_bus();
        assert!(bus.report(&json!("tok"), 0.25, Some(1.0), Some("a".into())));
        assert!(bus.report(&json!("tok"), 0.5, Some(1.0), None));
        assert!(!bus.report(&json!("tok"), 0.5, Some(1.0), None));
        assert!(!bus.report(&json!("tok"), 0.4, Some(1.0), None));
        let captured = captured.lock().unwrap();
        assert_eq!(captured.len(), 2);
        assert_eq!(captured[0]["method"], json!("notifications/progress"));
        assert_eq!(captured[0]["params"]["progressToken"], json!("tok"));
        assert_eq!(captured[0]["params"]["progress"], json!(0.25));
        assert_eq!(captured[0]["params"]["total"], json!(1.0));
        assert_eq!(captured[0]["params"]["message"], json!("a"));
        assert!(captured[1]["params"].get("message").is_none());
    }

    #[test]
    fn reports_progress_for_numeric_token_independently() {
        let (bus, captured) = capturing_bus();
        assert!(bus.report(&json!(1), 0.1, None, None));
        assert!(bus.report(&json!("tok"), 0.05, None, None));
        let captured = captured.lock().unwrap();
        assert_eq!(captured.len(), 2);
    }

    #[test]
    fn rejects_non_finite_or_invalid_token() {
        let (bus, captured) = capturing_bus();
        assert!(!bus.report(&JsonValue::Null, 0.1, None, None));
        assert!(!bus.report(&json!(true), 0.1, None, None));
        assert!(!bus.report(&json!("tok"), f64::NAN, None, None));
        assert!(!bus.report(&json!("tok"), 0.1, Some(f64::INFINITY), None));
        assert!(captured.lock().unwrap().is_empty());
    }

    #[test]
    fn active_bus_guard_restores_previous_bus_on_drop() {
        // Each libtest test runs on its own thread, so the slot starts empty.
        let (outer, _) = capturing_bus();
        let (inner, _) = capturing_bus();
        assert!(active_bus().is_none());
        {
            let _outer_guard = ActiveBusGuard::install(Some(outer.clone()));
            assert!(is_active(&outer));
            {
                let _inner_guard = ActiveBusGuard::install(Some(inner.clone()));
                assert!(is_active(&inner));
                {
                    let _cleared = ActiveBusGuard::install(None);
                    assert!(active_bus().is_none());
                }
                assert!(is_active(&inner));
            }
            assert!(is_active(&outer));
        }
        assert!(active_bus().is_none());
    }

    #[tokio::test]
    async fn scope_context_is_visible_inside_and_absent_outside() {
        assert!(current_context().is_none());
        let (bus, _) = capturing_bus();
        let ctx = ProgressContext::new(bus, json!("tok"));
        scope_context(Some(ctx), async {
            assert!(current_context().is_some());
        })
        .await;
        assert!(current_context().is_none());
    }

    #[tokio::test]
    async fn scope_context_isolates_concurrent_tasks() {
        let (bus, captured) = capturing_bus();
        let ctx_a = ProgressContext::new(bus.clone(), json!("a"));
        let ctx_b = ProgressContext::new(bus, json!("b"));
        let task_a = scope_context(Some(ctx_a), async {
            tokio::task::yield_now().await;
            current_context().unwrap().token.clone()
        });
        let task_b = scope_context(Some(ctx_b), async {
            tokio::task::yield_now().await;
            current_context().unwrap().token.clone()
        });
        let (a, b) = tokio::join!(task_a, task_b);
        assert_eq!(a, json!("a"));
        assert_eq!(b, json!("b"));
        assert!(captured.lock().unwrap().is_empty());
    }
}