construct/mcp_server/
progress_wrap.rs1use crate::tools::progress::{ProgressHandle, ProgressSink};
16use crate::tools::traits::{Tool, ToolResult};
17use async_trait::async_trait;
18use std::sync::Arc;
19
20pub struct ProgressEnvelope {
22 inner: Arc<dyn Tool>,
23 start_message: String,
24 finish_message: String,
25}
26
27impl ProgressEnvelope {
28 pub fn new(inner: Arc<dyn Tool>, start_message: &str, finish_message: &str) -> Self {
29 Self {
30 inner,
31 start_message: start_message.to_string(),
32 finish_message: finish_message.to_string(),
33 }
34 }
35
36 pub fn into_arc(self) -> Arc<dyn Tool> {
38 Arc::new(self)
39 }
40}
41
42#[async_trait]
43impl Tool for ProgressEnvelope {
44 fn name(&self) -> &str {
45 self.inner.name()
46 }
47
48 fn description(&self) -> &str {
49 self.inner.description()
50 }
51
52 fn parameters_schema(&self) -> serde_json::Value {
53 self.inner.parameters_schema()
54 }
55
56 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
57 self.inner.execute(args).await
58 }
59
60 async fn execute_with_progress(
61 &self,
62 args: serde_json::Value,
63 sink: &dyn ProgressSink,
64 ) -> anyhow::Result<ToolResult> {
65 let handle = ProgressHandle::new(sink, Some(2));
66 handle.update(1, Some(&self.start_message));
67 let result = self.inner.execute(args).await;
68 handle.update(2, Some(&self.finish_message));
69 result
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::progress::ProgressToken;
    use std::sync::Mutex;

    /// Sink that records every `(progress, message)` pair it is notified with.
    #[derive(Default)]
    struct RecordingSink {
        events: Mutex<Vec<(u64, Option<String>)>>,
    }

    impl ProgressSink for RecordingSink {
        fn notify(
            &self,
            _t: ProgressToken,
            progress: u64,
            _total: Option<u64>,
            message: Option<&str>,
        ) {
            self.events
                .lock()
                .unwrap()
                .push((progress, message.map(str::to_string)));
        }
    }

    /// Minimal inner tool. `execute` fails when `fail` is set; otherwise it
    /// succeeds with output `"inner"`. Name and description differ so the
    /// forwarding test can tell them apart.
    struct StubTool {
        fail: bool,
    }

    #[async_trait]
    impl Tool for StubTool {
        fn name(&self) -> &str {
            "stub"
        }
        fn description(&self) -> &str {
            "stub description"
        }
        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({ "type": "object" })
        }
        async fn execute(&self, _args: serde_json::Value) -> anyhow::Result<ToolResult> {
            if self.fail {
                anyhow::bail!("inner failure");
            }
            Ok(ToolResult {
                success: true,
                output: "inner".into(),
                error: None,
            })
        }
    }

    /// Events expected from an envelope built with `"starting"` / `"done"`.
    fn bookends() -> Vec<(u64, Option<String>)> {
        vec![
            (1, Some("starting".to_string())),
            (2, Some("done".to_string())),
        ]
    }

    #[tokio::test]
    async fn envelope_emits_bookend_progress() {
        let sink = RecordingSink::default();
        let wrapped =
            ProgressEnvelope::new(Arc::new(StubTool { fail: false }), "starting", "done");
        let r = wrapped
            .execute_with_progress(serde_json::json!({}), &sink)
            .await
            .unwrap();
        assert_eq!(r.output, "inner");
        assert_eq!(*sink.events.lock().unwrap(), bookends());
    }

    #[tokio::test]
    async fn envelope_emits_finish_and_propagates_inner_error() {
        let sink = RecordingSink::default();
        let wrapped =
            ProgressEnvelope::new(Arc::new(StubTool { fail: true }), "starting", "done");
        // `match` instead of `unwrap_err`, which would require `ToolResult: Debug`.
        match wrapped
            .execute_with_progress(serde_json::json!({}), &sink)
            .await
        {
            Ok(_) => panic!("the inner tool's error should propagate"),
            Err(err) => assert!(err.to_string().contains("inner failure")),
        }
        // The finish notification is still emitted on the failure path.
        assert_eq!(*sink.events.lock().unwrap(), bookends());
    }

    // Plain `#[test]`: nothing here is awaited, so no runtime is needed.
    #[test]
    fn envelope_forwards_spec_and_name() {
        let wrapped = ProgressEnvelope::new(Arc::new(StubTool { fail: false }), "a", "b");
        assert_eq!(wrapped.name(), "stub");
        assert_eq!(wrapped.description(), "stub description");
        assert_eq!(
            wrapped.parameters_schema(),
            serde_json::json!({ "type": "object" })
        );
    }
}