use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::mcp_session::{McpEventKind, SessionState};
/// JSON-RPC `method` used for every progress notification built in this module.
pub const MCP_NOTIFICATION_PROGRESS_METHOD: &str = "notifications/progress";

/// `top_k` cutoff associated with progress reporting for the search-docs tool.
///
/// NOTE(review): presumably searches whose `top_k` reaches this value emit
/// progress; the comparison lives at the call site — confirm there.
pub const MCP_SEARCH_DOCS_PROGRESS_TOP_K_THRESHOLD: u32 = 100;

/// Item-count cutoff associated with progress reporting for batched "remember" calls.
///
/// NOTE(review): presumably batches larger than this report progress — confirm
/// the exact comparison at the call site.
pub const MCP_REMEMBER_BATCH_PROGRESS_ITEM_THRESHOLD: usize = 50;

/// Emission interval (in items) for batched "remember" progress.
///
/// NOTE(review): presumably one notification per this many processed items —
/// confirm at the call site.
pub const MCP_REMEMBER_BATCH_PROGRESS_EMIT_EVERY: usize = 25;
/// Client-supplied progress token read from a request's `_meta.progressToken`.
///
/// The wrapped JSON value is copied verbatim into `params.progressToken` of each
/// notification built by [`ProgressReporter::report`]. [`ProgressToken::from_meta`]
/// only yields string or number tokens; the public field lets callers wrap any
/// value directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProgressToken(pub serde_json::Value);

impl ProgressToken {
    /// Extracts the progress token from a request's `params` object.
    ///
    /// Returns `None` when `params` has no `_meta` entry, when `_meta` has no
    /// `progressToken` entry, or when the token is neither a JSON string nor a
    /// JSON number (`null`, booleans, arrays and objects are all rejected).
    pub fn from_meta(params: &serde_json::Value) -> Option<Self> {
        let candidate = params
            .get("_meta")
            .and_then(|meta| meta.get("progressToken"))?;
        // Only scalar string/number tokens are accepted; anything else,
        // including an explicit `null`, is treated as "no token supplied".
        let acceptable = candidate.is_string() || candidate.is_number();
        acceptable.then(|| Self(candidate.clone()))
    }

    /// Borrows the raw JSON value of the token.
    pub fn as_value(&self) -> &serde_json::Value {
        let Self(raw) = self;
        raw
    }
}
#[derive(Debug, Clone)]
pub struct ProgressReporter {
session: Arc<SessionState>,
token: ProgressToken,
}
impl ProgressReporter {
pub fn new(session: Arc<SessionState>, token: ProgressToken) -> Self {
Self { session, token }
}
pub fn token(&self) -> &ProgressToken {
&self.token
}
pub fn report(&self, progress: u64, total: Option<u64>, message: Option<&str>) -> u64 {
let mut params = serde_json::Map::with_capacity(4);
params.insert("progressToken".to_string(), self.token.0.clone());
params.insert(
"progress".to_string(),
serde_json::Value::Number(progress.into()),
);
if let Some(t) = total {
params.insert("total".to_string(), serde_json::Value::Number(t.into()));
}
if let Some(m) = message {
params.insert(
"message".to_string(),
serde_json::Value::String(m.to_string()),
);
}
let envelope = serde_json::json!({
"jsonrpc": "2.0",
"method": MCP_NOTIFICATION_PROGRESS_METHOD,
"params": serde_json::Value::Object(params),
});
self.session.publish_event(McpEventKind::Progress, envelope)
}
}
/// Reports progress through `reporter` when one is supplied; otherwise does nothing.
///
/// The id returned by [`ProgressReporter::report`] is discarded.
pub fn report_if_some(
    reporter: Option<&ProgressReporter>,
    progress: u64,
    total: Option<u64>,
    message: Option<&str>,
) {
    let active = match reporter {
        Some(active) => active,
        None => return,
    };
    let _event_id = active.report(progress, total, message);
}
#[cfg(test)]
mod tests {
    //! Unit tests for progress-token extraction and progress-notification publishing.
    //!
    //! Fix: the `from_meta` call sites previously read `¶ms`, an HTML-entity
    //! mangling of `&params` (`&para` → `¶`) that does not compile.
    use super::*;
    use serde_json::json;
    use solo_core::TenantId;

    /// Builds a fresh session for the default tenant with no extra configuration.
    fn fresh_session() -> Arc<SessionState> {
        Arc::new(SessionState::new(TenantId::default_tenant(), None))
    }

    #[test]
    fn progress_token_from_meta_reads_string_token() {
        let params = json!({
            "name": "memory_ingest_document",
            "_meta": { "progressToken": "abc-123" },
        });
        let token = ProgressToken::from_meta(&params).expect("string token");
        assert_eq!(token.as_value(), &json!("abc-123"));
    }

    #[test]
    fn progress_token_from_meta_reads_number_token() {
        let params = json!({
            "name": "memory_search_docs",
            "_meta": { "progressToken": 42 },
        });
        let token = ProgressToken::from_meta(&params).expect("number token");
        assert_eq!(token.as_value(), &json!(42));
    }

    #[test]
    fn progress_token_from_meta_returns_none_without_meta() {
        let params = json!({"name": "memory_inspect"});
        assert!(ProgressToken::from_meta(&params).is_none());
    }

    #[test]
    fn progress_token_from_meta_returns_none_for_null_token() {
        let params = json!({"_meta": {"progressToken": null}});
        assert!(ProgressToken::from_meta(&params).is_none());
    }

    #[test]
    fn progress_token_from_meta_rejects_object_token() {
        let params = json!({"_meta": {"progressToken": {"nope": 1}}});
        assert!(ProgressToken::from_meta(&params).is_none());
    }

    #[test]
    fn progress_token_from_meta_rejects_bool_and_array_tokens() {
        let bool_params = json!({"_meta": {"progressToken": true}});
        assert!(ProgressToken::from_meta(&bool_params).is_none());
        let array_params = json!({"_meta": {"progressToken": [1, 2]}});
        assert!(ProgressToken::from_meta(&array_params).is_none());
    }

    #[test]
    fn progress_token_from_meta_returns_none_when_meta_is_not_an_object() {
        // `Value::get` with a string key yields `None` for non-object values,
        // so a malformed `_meta` must not panic.
        let params = json!({"_meta": "not-an-object"});
        assert!(ProgressToken::from_meta(&params).is_none());
    }

    #[test]
    fn progress_reporter_publishes_spec_shape() {
        let session = fresh_session();
        let mut rx = session.subscribe_events();
        let reporter = ProgressReporter::new(session.clone(), ProgressToken(json!("tok-1")));
        assert_eq!(reporter.token().as_value(), &json!("tok-1"));
        let _id = reporter.report(3, Some(12), Some("crunching"));
        let event = rx.try_recv().expect("event published");
        assert_eq!(event.event, McpEventKind::Progress);
        assert_eq!(event.data["jsonrpc"], json!("2.0"));
        assert_eq!(event.data["method"], json!(MCP_NOTIFICATION_PROGRESS_METHOD));
        assert_eq!(event.data["params"]["progressToken"], json!("tok-1"));
        assert_eq!(event.data["params"]["progress"], json!(3));
        assert_eq!(event.data["params"]["total"], json!(12));
        assert_eq!(event.data["params"]["message"], json!("crunching"));
    }

    #[test]
    fn progress_reporter_omits_total_and_message_when_absent() {
        let session = fresh_session();
        let mut rx = session.subscribe_events();
        let reporter = ProgressReporter::new(session.clone(), ProgressToken(json!(7)));
        let _id = reporter.report(1, None, None);
        let event = rx.try_recv().expect("event published");
        assert_eq!(event.data["params"]["progressToken"], json!(7));
        assert_eq!(event.data["params"]["progress"], json!(1));
        assert!(
            event.data["params"].get("total").is_none(),
            "total key must be omitted when None"
        );
        assert!(
            event.data["params"].get("message").is_none(),
            "message key must be omitted when None"
        );
    }

    #[test]
    fn progress_reporter_emits_monotonic_event_ids() {
        let session = fresh_session();
        let reporter = ProgressReporter::new(session.clone(), ProgressToken(json!("t")));
        let id_a = reporter.report(1, None, None);
        let id_b = reporter.report(2, None, None);
        let id_c = reporter.report(3, None, None);
        assert!(id_a < id_b);
        assert!(id_b < id_c);
    }

    #[test]
    fn report_if_some_noop_when_none() {
        // Only verifies that a missing reporter is handled without panicking.
        report_if_some(None, 1, Some(2), Some("noop"));
    }

    #[test]
    fn report_if_some_publishes_when_present() {
        let session = fresh_session();
        let mut rx = session.subscribe_events();
        let reporter = ProgressReporter::new(session.clone(), ProgressToken(json!("t")));
        report_if_some(Some(&reporter), 1, Some(1), Some("hi"));
        let event = rx.try_recv().expect("event published");
        assert_eq!(event.data["params"]["progress"], json!(1));
    }
}