use std::io::{self, Write};
use std::sync::atomic::{AtomicU64, Ordering};
use serde_json::json;
/// Process-wide source of progress-token ids; every call to
/// [`next_progress_token`] claims the next value.
static PROGRESS_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Allocates a new, process-unique progress token.
///
/// The token is the next counter value rendered as 16 zero-padded lowercase
/// hexadecimal digits, so tokens are fixed-width and later tokens compare
/// greater than earlier ones.
#[must_use]
pub fn next_progress_token() -> String {
    // `Relaxed` is sufficient: `fetch_add` is atomic, so each caller still
    // observes a distinct value; nothing else is synchronized via this counter.
    let id = PROGRESS_COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{:016x}", id)
}
/// Writes one JSON-RPC `notifications/progress` notification to `out`.
///
/// The notification is serialized as compact JSON on a single line, terminated
/// by `\n`, and `out` is flushed afterwards so the receiver sees it right away.
///
/// # Parameters
/// - `out`: destination writer.
/// - `token`: value sent as `params.progressToken`.
/// - `progress`: value sent as `params.progress`.
/// - `total`: value sent as `params.total`.
/// - `message`: value sent as `params.message`; the field is omitted entirely
///   when `message` is empty.
///
/// # Errors
/// Returns any I/O error produced while writing to or flushing `out`.
pub fn emit_progress<W: Write>(
    out: &mut W,
    token: &str,
    progress: u32,
    total: u32,
    message: &str,
) -> io::Result<()> {
    let mut params = json!({
        "progressToken": token,
        "progress": progress,
        "total": total
    });
    // Only include `message` when there is something to say.
    if !message.is_empty() {
        params["message"] = json!(message);
    }
    let notif = json!({
        "jsonrpc": "2.0",
        "method": "notifications/progress",
        "params": params
    });
    // `Value`'s `Display` produces the same compact, escaped JSON as
    // `serde_json::to_string`, but cannot fail, so no panic path is needed.
    // Escaping guarantees the notification occupies exactly one line.
    writeln!(out, "{notif}")?;
    out.flush()
}
/// Emits a series of progress notifications for a single operation.
///
/// Each reporter holds one token (allocated when it is constructed) and a
/// running progress count bounded by `total`.
// The `Write` bound lives on the `impl` blocks that need it rather than on the
// struct definition itself.
#[derive(Debug)]
pub struct ProgressReporter<'w, W> {
    /// Writer that receives the emitted notifications.
    out: &'w mut W,
    /// Token attached to every notification from this reporter.
    token: String,
    /// Current progress count; starts at 0 and is kept at or below `total`.
    current: u32,
    /// Total units of work reported in each notification.
    total: u32,
}
impl<'w, W: Write> ProgressReporter<'w, W> {
    /// Creates a reporter that writes to `out`, starting at progress 0 of
    /// `total`, with a fresh token from [`next_progress_token`].
    #[must_use]
    pub fn new(out: &'w mut W, total: u32) -> Self {
        Self {
            out,
            token: next_progress_token(),
            current: 0,
            total,
        }
    }

    /// Returns the token shared by every notification from this reporter.
    #[must_use]
    pub fn token(&self) -> &str {
        &self.token
    }

    /// Advances progress by one unit and emits a notification.
    ///
    /// Progress is clamped to `total`, so extra steps re-report `total`
    /// instead of exceeding it.
    ///
    /// # Errors
    /// Returns any I/O error from writing the notification.
    pub fn step(&mut self, message: &str) -> io::Result<()> {
        self.current = self.current.saturating_add(1).min(self.total);
        emit_progress(self.out, &self.token, self.current, self.total, message)
    }

    /// Marks the operation finished by emitting `progress == total`.
    ///
    /// # Errors
    /// Returns any I/O error from writing the notification.
    pub fn complete(&mut self, message: &str) -> io::Result<()> {
        // Record completion so that a later `step` cannot report a smaller
        // progress value than the one emitted here (progress must not regress).
        self.current = self.total;
        emit_progress(self.out, &self.token, self.current, self.total, message)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    /// Parses a buffer that holds exactly one newline-terminated notification.
    fn parse_line(buf: &[u8]) -> Value {
        let text = String::from_utf8_lossy(buf);
        serde_json::from_str(text.trim()).expect("valid JSON notification")
    }

    /// Splits a buffer of newline-delimited notifications into its raw lines.
    fn split_lines(buf: Vec<u8>) -> Vec<String> {
        String::from_utf8(buf)
            .unwrap()
            .trim_end()
            .split('\n')
            .map(str::to_owned)
            .collect()
    }

    #[test]
    fn next_progress_token_returns_sixteen_hex_digits() {
        let token = next_progress_token();
        let all_hex = token.chars().all(|c| c.is_ascii_hexdigit());
        assert_eq!(token.len(), 16, "token: {token}");
        assert!(all_hex, "token: {token}");
    }

    #[test]
    fn next_progress_token_is_unique_on_each_call() {
        let first = next_progress_token();
        let second = next_progress_token();
        assert_ne!(first, second);
    }

    #[test]
    fn next_progress_token_is_monotonically_increasing() {
        let as_number = |t: String| u64::from_str_radix(&t, 16).unwrap();
        let earlier = as_number(next_progress_token());
        let later = as_number(next_progress_token());
        assert!(later > earlier);
    }

    #[test]
    fn emit_progress_writes_valid_jsonrpc_notification() {
        let mut buf: Vec<u8> = Vec::new();
        emit_progress(&mut buf, "tok-1", 33, 100, "step 1").unwrap();
        let notif = parse_line(&buf);
        let params = &notif["params"];
        assert_eq!(notif["jsonrpc"], "2.0");
        assert_eq!(notif["method"], "notifications/progress");
        assert_eq!(params["progressToken"], "tok-1");
        assert_eq!(params["progress"], 33);
        assert_eq!(params["total"], 100);
        assert_eq!(params["message"], "step 1");
    }

    #[test]
    fn emit_progress_omits_message_field_when_empty() {
        let mut buf: Vec<u8> = Vec::new();
        emit_progress(&mut buf, "tok-2", 0, 10, "").unwrap();
        let notif = parse_line(&buf);
        let has_message = notif["params"].get("message").is_some();
        assert!(!has_message, "message should be absent");
    }

    #[test]
    fn emit_progress_terminates_with_newline() {
        let mut buf: Vec<u8> = Vec::new();
        emit_progress(&mut buf, "tok-3", 1, 1, "done").unwrap();
        let text = String::from_utf8(buf).unwrap();
        assert!(text.ends_with('\n'));
    }

    #[test]
    fn emit_progress_progress_equals_total_signals_completion() {
        let mut buf: Vec<u8> = Vec::new();
        emit_progress(&mut buf, "tok-4", 5, 5, "Complete").unwrap();
        let params = parse_line(&buf)["params"].clone();
        assert_eq!(params["progress"], params["total"]);
    }

    #[test]
    fn progress_reporter_step_advances_progress_by_one() {
        let mut buf: Vec<u8> = Vec::new();
        ProgressReporter::new(&mut buf, 3).step("step 1").unwrap();
        let notif = parse_line(&buf);
        let params = &notif["params"];
        assert_eq!(params["progress"], 1);
        assert_eq!(params["total"], 3);
        assert_eq!(params["message"], "step 1");
    }

    #[test]
    fn progress_reporter_emits_one_notification_per_step() {
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut reporter = ProgressReporter::new(&mut buf, 2);
            for label in ["a", "b"] {
                reporter.step(label).unwrap();
            }
        }
        assert_eq!(split_lines(buf).len(), 2);
    }

    #[test]
    fn progress_reporter_complete_sets_progress_to_total() {
        let mut buf: Vec<u8> = Vec::new();
        ProgressReporter::new(&mut buf, 5).complete("all done").unwrap();
        let notif = parse_line(&buf);
        let params = &notif["params"];
        assert_eq!(params["progress"], 5);
        assert_eq!(params["total"], 5);
    }

    #[test]
    fn progress_reporter_step_does_not_exceed_total() {
        let mut buf: Vec<u8> = Vec::new();
        {
            let mut reporter = ProgressReporter::new(&mut buf, 1);
            reporter.step("first").unwrap();
            reporter.step("overflow attempt").unwrap();
        }
        let lines = split_lines(buf);
        let second: Value = serde_json::from_str(&lines[1]).unwrap();
        let params = &second["params"];
        let (p, t) = (
            params["progress"].as_u64().unwrap(),
            params["total"].as_u64().unwrap(),
        );
        assert_eq!(p, t, "progress must not exceed total");
    }

    #[test]
    fn progress_reporter_all_steps_share_same_token() {
        let mut buf: Vec<u8> = Vec::new();
        let token = {
            let mut reporter = ProgressReporter::new(&mut buf, 2);
            let captured = reporter.token().to_owned();
            reporter.step("a").unwrap();
            reporter.step("b").unwrap();
            captured
        };
        for line in split_lines(buf) {
            let notif: Value = serde_json::from_str(&line).unwrap();
            assert_eq!(notif["params"]["progressToken"], token.as_str());
        }
    }
}