use std::sync::Arc;
use serde_json::Value as JsonValue;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use crate::EmbeddedDatabase;
use super::progress::{self, ChannelProgressSink, ProgressEvent, ProgressSink};
use super::tools::{call_tool, ToolOutcome};
/// Runs the tool `name` with `args` on Tokio's blocking thread pool and streams its
/// progress events.
///
/// A [`ChannelProgressSink`] is installed via [`progress::set_sink`] for the duration of
/// the [`call_tool`] invocation, so progress emitted by the tool is forwarded to the
/// returned receiver. The sink is removed with [`progress::clear_sink`] when the call
/// finishes, whether it returns normally or panics.
///
/// # Returns
/// A pair of:
/// - the receiving end of the progress channel;
/// - a [`JoinHandle`] resolving to the tool's [`ToolOutcome`]. If the tool panicked,
///   it resolves to a `JoinError` instead.
///
/// # Panics
/// Panics if called outside a Tokio runtime, because it uses
/// [`tokio::task::spawn_blocking`].
pub fn call_tool_streaming(
    db: Option<Arc<EmbeddedDatabase>>,
    name: String,
    args: JsonValue,
) -> (mpsc::UnboundedReceiver<ProgressEvent>, JoinHandle<ToolOutcome>) {
    /// Calls `progress::clear_sink()` when dropped, including during panic unwinding.
    struct ClearSinkOnDrop;

    impl Drop for ClearSinkOnDrop {
        fn drop(&mut self) {
            progress::clear_sink();
        }
    }

    let (sink, rx) = ChannelProgressSink::channel();
    let sink: Arc<dyn ProgressSink> = Arc::new(sink);
    let handle = tokio::task::spawn_blocking(move || {
        progress::set_sink(sink);
        // Tokio catches panics in blocking tasks and reuses the pool thread. Without
        // this guard, a panicking tool would leave the sink installed. If the sink is
        // stored per thread (presumably thread-local -- confirm in `progress`), it
        // would keep the channel sender alive, so `rx` would never close.
        let _clear_sink = ClearSinkOnDrop;
        call_tool(db.as_deref(), &name, args)
    });
    (rx, handle)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::progress::emit;
    use crate::mcp::tools::ToolOutcome;
    use serde_json::json;

    /// Stand-in for a real tool.
    ///
    /// Emits one progress event per step `0..n`, where `n = args["n"]` (or 0 if
    /// missing or not a `u64`). Each event carries `n` as its total. Then reports how
    /// many events were emitted.
    fn fake_streaming_handler(args: JsonValue) -> ToolOutcome {
        let count = args["n"].as_u64().unwrap_or(0);
        let total = count as f64;
        (0..count).for_each(|step| {
            emit(ProgressEvent::new(step as f64).with_total(total));
        });
        ToolOutcome::ok(json!({ "emitted": count }))
    }

    /// Reproduces the sink wiring of `call_tool_streaming` around the fake handler.
    ///
    /// Checks that every emitted event reaches the receiver, in order, before the
    /// channel closes.
    #[tokio::test]
    async fn channel_drains_emitted_events() {
        let (channel_sink, mut receiver) = ChannelProgressSink::channel();
        let shared_sink: Arc<dyn ProgressSink> = Arc::new(channel_sink);

        let worker = tokio::task::spawn_blocking(move || {
            progress::set_sink(shared_sink);
            let result = fake_streaming_handler(json!({ "n": 3 }));
            progress::clear_sink();
            result
        });

        // `recv` yields `None` only once the channel closes. This drain therefore
        // relies on the sink being released after the worker finishes.
        let mut received = Vec::new();
        while let Some(event) = receiver.recv().await {
            received.push(event);
        }

        let result = worker.await.unwrap();
        assert!(!result.is_error);
        assert_eq!(received.len(), 3);
        assert_eq!(received[0].progress, 0.0);
        assert_eq!(received[2].progress, 2.0);
    }
}