use std::sync::Arc;
use std::time::Duration;
use rmcp::{
model::{
LoggingLevel, LoggingMessageNotificationParam, NumberOrString, ProgressNotificationParam,
ProgressToken,
},
service::Peer,
RoleServer,
};
use tokio::sync::Mutex;
use tokio::time::Instant;
/// Minimum spacing between two forwarded progress notifications (half a second).
const MIN_GAP: Duration = Duration::from_micros(500_000);
/// Destination for progress updates and free-form text messages.
///
/// Kept behind a trait so `ProgressReporter` can be driven by a fake sink in
/// tests instead of a live MCP peer.
#[async_trait::async_trait]
pub trait ProgressSink: Send + Sync {
    /// Emits a free-form text message.
    async fn emit_text(&self, text: &str);

    /// Emits progress `step` (out of `total`, when known) for progress `token`.
    async fn emit_progress(&self, step: f64, total: Option<f64>, token: &NumberOrString);
}
/// `ProgressSink` implementation that forwards notifications to the MCP peer.
pub struct PeerSink {
    /// rmcp server-side peer handle used to send the notifications.
    pub peer: Peer<RoleServer>,
}
#[async_trait::async_trait]
impl ProgressSink for PeerSink {
    /// Sends a progress notification for `token`, attaching `total` when known.
    ///
    /// Any error returned by the peer is discarded.
    async fn emit_progress(&self, step: f64, total: Option<f64>, token: &NumberOrString) {
        let base = ProgressNotificationParam::new(ProgressToken(token.clone()), step);
        let notification = match total {
            Some(limit) => base.with_total(limit),
            None => base,
        };
        // Best-effort delivery: the send result is intentionally ignored.
        let _ = self.peer.notify_progress(notification).await;
    }

    /// Sends `text` as an `Info`-level logging notification from logger "codescout".
    ///
    /// Any error returned by the peer is discarded.
    async fn emit_text(&self, text: &str) {
        let message = LoggingMessageNotificationParam {
            level: LoggingLevel::Info,
            logger: Some("codescout".to_string()),
            data: serde_json::Value::String(text.to_string()),
        };
        let _ = self.peer.notify_logging_message(message).await;
    }
}
/// Reports progress for a single progress token through a `ProgressSink`.
///
/// Progress updates are rate-limited using the time of the last forwarded
/// update; text messages are passed straight through.
pub struct ProgressReporter {
    /// Where notifications go: a `PeerSink` from `new`, or any sink via the test-only `with_sink`.
    sink: Arc<dyn ProgressSink>,
    /// Token attached to every progress notification.
    token: NumberOrString,
    /// Instant of the last forwarded progress update; `None` until the first one.
    last_emit: Mutex<Option<Instant>>,
}
impl ProgressReporter {
pub fn new(peer: Peer<RoleServer>, token: NumberOrString) -> Arc<Self> {
Arc::new(Self {
sink: Arc::new(PeerSink { peer }),
token,
last_emit: Mutex::new(None),
})
}
#[cfg(test)]
pub fn with_sink(sink: Arc<dyn ProgressSink>, token: NumberOrString) -> Arc<Self> {
Arc::new(Self {
sink,
token,
last_emit: Mutex::new(None),
})
}
pub async fn report(&self, step: u32, total: Option<u32>) {
let now = Instant::now();
{
let mut g = self.last_emit.lock().await;
match *g {
Some(t) if now.duration_since(t) < MIN_GAP => return,
_ => {
*g = Some(now);
}
}
}
self.sink
.emit_progress(step as f64, total.map(|t| t as f64), &self.token)
.await;
}
pub async fn report_text(&self, text: &str) {
self.sink.emit_text(text).await;
}
}
#[cfg(test)]
pub(crate) mod test_support {
    use std::sync::atomic::{AtomicU32, Ordering};

    use crate::tools::progress::ProgressSink;
    use rmcp::model::NumberOrString;

    /// Test double that only counts how many times each sink method is called.
    #[derive(Default)]
    pub struct CountingSink {
        /// Number of `emit_progress` calls received.
        pub progress_calls: AtomicU32,
        /// Number of `emit_text` calls received.
        pub text_calls: AtomicU32,
    }

    impl CountingSink {
        /// Increments one of the call counters.
        fn bump(counter: &AtomicU32) {
            counter.fetch_add(1, Ordering::Relaxed);
        }
    }

    #[async_trait::async_trait]
    impl ProgressSink for CountingSink {
        async fn emit_progress(&self, _step: f64, _total: Option<f64>, _token: &NumberOrString) {
            Self::bump(&self.progress_calls);
        }

        async fn emit_text(&self, _text: &str) {
            Self::bump(&self.text_calls);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::test_support::CountingSink;
    use super::*;
    use std::sync::atomic::Ordering;

    /// Builds a reporter for `token` backed by a fresh counting sink.
    fn counting_reporter(token: NumberOrString) -> (Arc<CountingSink>, Arc<ProgressReporter>) {
        let counter = Arc::new(CountingSink::default());
        let reporter = ProgressReporter::with_sink(counter.clone(), token);
        (counter, reporter)
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn report_throttles_to_one_per_window() {
        let (counter, reporter) = counting_reporter(NumberOrString::Number(1));
        for step in 0..100 {
            reporter.report(step, Some(100)).await;
            tokio::time::advance(Duration::from_millis(9)).await;
        }
        let emitted = counter.progress_calls.load(Ordering::SeqCst);
        assert!(matches!(emitted, 1..=2), "expected 1–2 emissions, got {}", emitted);
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn report_allows_after_window_elapsed() {
        let (counter, reporter) = counting_reporter(NumberOrString::Number(2));
        reporter.report(1, None).await;
        tokio::time::advance(Duration::from_millis(600)).await;
        reporter.report(2, None).await;
        assert_eq!(counter.progress_calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn report_text_is_not_throttled() {
        let (counter, reporter) = counting_reporter(NumberOrString::Number(3));
        for _ in 0..10 {
            reporter.report_text("hi").await;
        }
        assert_eq!(counter.text_calls.load(Ordering::SeqCst), 10);
    }
}