1use futures::{
4 SinkExt, StreamExt,
5 channel::{mpsc, oneshot},
6 future::BoxFuture,
7};
8use schemars::JsonSchema;
9use serde::{Serialize, de::DeserializeOwned};
10
11use crate::{ConnectionTo, Error, Role, RunWithConnectionTo};
12
13use super::{McpConnectionTo, McpTool};
14
15struct ToolCall<P, R, MyRole: Role> {
16 params: P,
17 mcp_connection: McpConnectionTo<MyRole>,
18 result_tx: futures::channel::oneshot::Sender<Result<R, Error>>,
19}
20
21struct ToolFnMutResponder<F, P, R, Counterpart: Role> {
22 func: F,
23 call_rx: mpsc::Receiver<ToolCall<P, R, Counterpart>>,
24 tool_future_fn: Box<
25 dyn for<'a> Fn(
26 &'a mut F,
27 P,
28 McpConnectionTo<Counterpart>,
29 ) -> BoxFuture<'a, Result<R, Error>>
30 + Send,
31 >,
32}
33
34impl<F, P, R, Counterpart, Counterpart1> RunWithConnectionTo<Counterpart1>
35 for ToolFnMutResponder<F, P, R, Counterpart>
36where
37 Counterpart: Role,
38 Counterpart1: Role,
39 P: Send,
40 R: Send,
41 F: Send,
42{
43 async fn run_with_connection_to(
44 self,
45 _connection: ConnectionTo<Counterpart1>,
46 ) -> Result<(), Error> {
47 let ToolFnMutResponder {
48 mut func,
49 mut call_rx,
50 tool_future_fn,
51 } = self;
52 while let Some(ToolCall {
53 params,
54 mcp_connection,
55 result_tx,
56 }) = call_rx.next().await
57 {
58 let result = tool_future_fn(&mut func, params, mcp_connection).await;
59 result_tx
60 .send(result)
61 .map_err(|_| crate::util::internal_error("failed to send MCP result"))?;
62 }
63 Ok(())
64 }
65}
66
67struct ToolFnResponder<F, P, R, Counterpart: Role> {
68 func: F,
69 call_rx: mpsc::Receiver<ToolCall<P, R, Counterpart>>,
70 tool_future_fn: Box<
71 dyn for<'a> Fn(&'a F, P, McpConnectionTo<Counterpart>) -> BoxFuture<'a, Result<R, Error>>
72 + Send
73 + Sync,
74 >,
75}
76
/// Runs a [`tool_fn`] responder: takes queued `ToolCall`s off `call_rx` and answers each one.
///
/// Every call is dispatched through `crate::util::process_stream_concurrently`.
/// NOTE(review): going by its name that helper may run several calls at once —
/// confirm in `crate::util`. That reading fits `func` only ever being borrowed
/// shared (`&F`) and the `F: Send + Sync` bound below.
impl<F, P, R, Counterpart, Counterpart1> RunWithConnectionTo<Counterpart1>
    for ToolFnResponder<F, P, R, Counterpart>
where
    Counterpart: Role,
    Counterpart1: Role,
    P: Send,
    R: Send,
    F: Send + Sync,
{
    /// Hands `call_rx` and a per-call handler to `process_stream_concurrently`
    /// and returns whatever that helper returns. The per-call handler itself
    /// always yields `Ok(())`. `_connection` is not used.
    async fn run_with_connection_to(
        self,
        _connection: ConnectionTo<Counterpart1>,
    ) -> Result<(), Error> {
        let ToolFnResponder {
            func,
            call_rx,
            tool_future_fn,
        } = self;
        crate::util::process_stream_concurrently(
            call_rx,
            // Per-call handler: runs the tool for one `ToolCall` and delivers the result.
            async |tool_call| {
                // Runs one call and sends its result back. The single explicit
                // lifetime `'a` ties together the borrow of `func`, the borrow of
                // `tool_future_fn`, and the returned boxed future.
                // NOTE(review): presumably a workaround for lifetime inference
                // inside the async closure (hence the name) — confirm before
                // trying to inline or simplify it.
                fn hack<'a, F, P, R, MyRole>(
                    func: &'a F,
                    params: P,
                    mcp_connection: McpConnectionTo<MyRole>,
                    tool_future_fn: &'a (
                        dyn Fn(
                            &'a F,
                            P,
                            McpConnectionTo<MyRole>,
                        ) -> BoxFuture<'a, Result<R, Error>>
                        + Send
                        + Sync
                    ),
                    result_tx: oneshot::Sender<Result<R, Error>>,
                ) -> BoxFuture<'a, ()>
                where
                    MyRole: Role,
                    P: Send,
                    R: Send,
                    F: Send + Sync,
                {
                    Box::pin(async move {
                        let result = tool_future_fn(func, params, mcp_connection).await;
                        // `send` fails only if the receiver (created in `call_tool`)
                        // was dropped, i.e. the caller stopped waiting; the result is
                        // then discarded and the responder carries on.
                        drop(result_tx.send(result));
                    })
                }

                let ToolCall {
                    params,
                    mcp_connection,
                    result_tx,
                } = tool_call;

                hack(&func, params, mcp_connection, &*tool_future_fn, result_tx).await;
                Ok(())
            },
            // NOTE(review): adapter required by the helper; from its shape it applies
            // the handler `a` to item `b` and boxes the resulting future — confirm
            // against the signature of `process_stream_concurrently`.
            |a, b| Box::pin(a(b)),
        )
        .await
    }
}
139
140struct ToolFnTool<P, Ret, R: Role> {
141 name: String,
142 description: String,
143 call_tx: mpsc::Sender<ToolCall<P, Ret, R>>,
144}
145
146impl<P, Ret, R> McpTool<R> for ToolFnTool<P, Ret, R>
147where
148 R: Role,
149 P: JsonSchema + DeserializeOwned + 'static + Send,
150 Ret: JsonSchema + Serialize + 'static + Send,
151{
152 type Input = P;
153 type Output = Ret;
154
155 fn name(&self) -> String {
156 self.name.clone()
157 }
158
159 fn description(&self) -> String {
160 self.description.clone()
161 }
162
163 async fn call_tool(&self, params: P, mcp_connection: McpConnectionTo<R>) -> Result<Ret, Error> {
164 let (result_tx, result_rx) = oneshot::channel();
165
166 self.call_tx
167 .clone()
168 .send(ToolCall {
169 params,
170 mcp_connection,
171 result_tx,
172 })
173 .await
174 .map_err(crate::util::internal_error)?;
175
176 result_rx.await.map_err(crate::util::internal_error)?
177 }
178}
179
180pub fn tool_fn_mut<P, Ret, F, Counterpart>(
184 name: impl ToString,
185 description: impl ToString,
186 func: F,
187 tool_future_fn: impl for<'a> Fn(
188 &'a mut F,
189 P,
190 McpConnectionTo<Counterpart>,
191 ) -> BoxFuture<'a, Result<Ret, Error>>
192 + Send
193 + 'static,
194) -> (
195 impl McpTool<Counterpart> + 'static,
196 impl RunWithConnectionTo<Counterpart>,
197)
198where
199 Counterpart: Role,
200 P: JsonSchema + DeserializeOwned + 'static + Send,
201 Ret: JsonSchema + Serialize + 'static + Send,
202 F: AsyncFnMut(P, McpConnectionTo<Counterpart>) -> Result<Ret, Error> + Send,
203{
204 let (call_tx, call_rx) = mpsc::channel(128);
205 (
206 ToolFnTool {
207 name: name.to_string(),
208 description: description.to_string(),
209 call_tx,
210 },
211 ToolFnMutResponder {
212 func,
213 call_rx,
214 tool_future_fn: Box::new(tool_future_fn),
215 },
216 )
217}
218
219pub fn tool_fn<P, Ret, F, Counterpart>(
221 name: impl ToString,
222 description: impl ToString,
223 func: F,
224 tool_future_fn: impl for<'a> Fn(
225 &'a F,
226 P,
227 McpConnectionTo<Counterpart>,
228 ) -> BoxFuture<'a, Result<Ret, Error>>
229 + Send
230 + Sync
231 + 'static,
232) -> (
233 impl McpTool<Counterpart> + 'static,
234 impl RunWithConnectionTo<Counterpart>,
235)
236where
237 Counterpart: Role,
238 P: JsonSchema + DeserializeOwned + 'static + Send,
239 Ret: JsonSchema + Serialize + 'static + Send,
240 F: AsyncFn(P, McpConnectionTo<Counterpart>) -> Result<Ret, Error> + Send + Sync + 'static,
241{
242 let (call_tx, call_rx) = mpsc::channel(128);
243 (
244 ToolFnTool {
245 name: name.to_string(),
246 description: description.to_string(),
247 call_tx,
248 },
249 ToolFnResponder {
250 func,
251 call_rx,
252 tool_future_fn: Box::new(tool_future_fn),
253 },
254 )
255}