Skip to main content

agent_client_protocol/mcp_server/
tool_fn.rs

1//! Runtime-neutral helpers for registering function-backed MCP tools.
2
3use futures::{
4    SinkExt, StreamExt,
5    channel::{mpsc, oneshot},
6    future::BoxFuture,
7};
8use schemars::JsonSchema;
9use serde::{Serialize, de::DeserializeOwned};
10
11use crate::{ConnectionTo, Error, Role, RunWithConnectionTo};
12
13use super::{McpConnectionTo, McpTool};
14
15struct ToolCall<P, R, MyRole: Role> {
16    params: P,
17    mcp_connection: McpConnectionTo<MyRole>,
18    result_tx: futures::channel::oneshot::Sender<Result<R, Error>>,
19}
20
21struct ToolFnMutResponder<F, P, R, Counterpart: Role> {
22    func: F,
23    call_rx: mpsc::Receiver<ToolCall<P, R, Counterpart>>,
24    tool_future_fn: Box<
25        dyn for<'a> Fn(
26                &'a mut F,
27                P,
28                McpConnectionTo<Counterpart>,
29            ) -> BoxFuture<'a, Result<R, Error>>
30            + Send,
31    >,
32}
33
34impl<F, P, R, Counterpart, Counterpart1> RunWithConnectionTo<Counterpart1>
35    for ToolFnMutResponder<F, P, R, Counterpart>
36where
37    Counterpart: Role,
38    Counterpart1: Role,
39    P: Send,
40    R: Send,
41    F: Send,
42{
43    async fn run_with_connection_to(
44        self,
45        _connection: ConnectionTo<Counterpart1>,
46    ) -> Result<(), Error> {
47        let ToolFnMutResponder {
48            mut func,
49            mut call_rx,
50            tool_future_fn,
51        } = self;
52        while let Some(ToolCall {
53            params,
54            mcp_connection,
55            result_tx,
56        }) = call_rx.next().await
57        {
58            let result = tool_future_fn(&mut func, params, mcp_connection).await;
59            result_tx
60                .send(result)
61                .map_err(|_| crate::util::internal_error("failed to send MCP result"))?;
62        }
63        Ok(())
64    }
65}
66
67struct ToolFnResponder<F, P, R, Counterpart: Role> {
68    func: F,
69    call_rx: mpsc::Receiver<ToolCall<P, R, Counterpart>>,
70    tool_future_fn: Box<
71        dyn for<'a> Fn(&'a F, P, McpConnectionTo<Counterpart>) -> BoxFuture<'a, Result<R, Error>>
72            + Send
73            + Sync,
74    >,
75}
76
/// Runs a `ToolFnResponder` by feeding its call queue to
/// `crate::util::process_stream_concurrently`.
impl<F, P, R, Counterpart, Counterpart1> RunWithConnectionTo<Counterpart1>
    for ToolFnResponder<F, P, R, Counterpart>
where
    Counterpart: Role,
    Counterpart1: Role,
    P: Send,
    R: Send,
    F: Send + Sync,
{
    /// Drain the call queue, handing each call to `process_stream_concurrently`.
    ///
    /// `func` is only ever borrowed by shared reference (`&F`), which is what lets calls
    /// be processed concurrently (the `tool_fn` docs call this the "concurrent
    /// responder"). The connection argument is unused. The per-call closure always
    /// returns `Ok(())`, so a tool's own error is reported to its caller only through
    /// `result_tx`.
    async fn run_with_connection_to(
        self,
        _connection: ConnectionTo<Counterpart1>,
    ) -> Result<(), Error> {
        let ToolFnResponder {
            func,
            call_rx,
            tool_future_fn,
        } = self;
        crate::util::process_stream_concurrently(
            call_rx,
            async |tool_call| {
                // `hack` pins `func`, `tool_future_fn`, and the returned future to one named
                // lifetime `'a`. The stored adapter is higher-ranked (`for<'a>`); here it is
                // used at a single concrete `'a`, and the result is a `BoxFuture<'a, ()>`.
                // NOTE(review): the name suggests this is a workaround for the compiler
                // failing to infer these borrows inside the async closure — confirm before
                // trying to inline or simplify it.
                fn hack<'a, F, P, R, MyRole>(
                    func: &'a F,
                    params: P,
                    mcp_connection: McpConnectionTo<MyRole>,
                    tool_future_fn: &'a (
                            dyn Fn(
                        &'a F,
                        P,
                        McpConnectionTo<MyRole>,
                    ) -> BoxFuture<'a, Result<R, Error>>
                                + Send
                                + Sync
                        ),
                    result_tx: oneshot::Sender<Result<R, Error>>,
                ) -> BoxFuture<'a, ()>
                where
                    MyRole: Role,
                    P: Send,
                    R: Send,
                    F: Send + Sync,
                {
                    Box::pin(async move {
                        let result = tool_future_fn(func, params, mcp_connection).await;
                        // `send` fails only if the caller dropped its receiver (it stopped
                        // waiting); that is deliberately ignored.
                        drop(result_tx.send(result));
                    })
                }

                let ToolCall {
                    params,
                    mcp_connection,
                    result_tx,
                } = tool_call;

                hack(&func, params, mcp_connection, &*tool_future_fn, result_tx).await;
                Ok(())
            },
            // NOTE(review): presumably tells the helper how to invoke the per-item closure
            // `a` on item `b` and box the resulting future — confirm against
            // `crate::util::process_stream_concurrently`.
            |a, b| Box::pin(a(b)),
        )
        .await
    }
}
139
140struct ToolFnTool<P, Ret, R: Role> {
141    name: String,
142    description: String,
143    call_tx: mpsc::Sender<ToolCall<P, Ret, R>>,
144}
145
146impl<P, Ret, R> McpTool<R> for ToolFnTool<P, Ret, R>
147where
148    R: Role,
149    P: JsonSchema + DeserializeOwned + 'static + Send,
150    Ret: JsonSchema + Serialize + 'static + Send,
151{
152    type Input = P;
153    type Output = Ret;
154
155    fn name(&self) -> String {
156        self.name.clone()
157    }
158
159    fn description(&self) -> String {
160        self.description.clone()
161    }
162
163    async fn call_tool(&self, params: P, mcp_connection: McpConnectionTo<R>) -> Result<Ret, Error> {
164        let (result_tx, result_rx) = oneshot::channel();
165
166        self.call_tx
167            .clone()
168            .send(ToolCall {
169                params,
170                mcp_connection,
171                result_tx,
172            })
173            .await
174            .map_err(crate::util::internal_error)?;
175
176        result_rx.await.map_err(crate::util::internal_error)?
177    }
178}
179
180/// Create a "single-threaded" function-backed MCP tool and its responder.
181///
182/// Only one invocation of the tool can be running at a time.
183pub fn tool_fn_mut<P, Ret, F, Counterpart>(
184    name: impl ToString,
185    description: impl ToString,
186    func: F,
187    tool_future_fn: impl for<'a> Fn(
188        &'a mut F,
189        P,
190        McpConnectionTo<Counterpart>,
191    ) -> BoxFuture<'a, Result<Ret, Error>>
192    + Send
193    + 'static,
194) -> (
195    impl McpTool<Counterpart> + 'static,
196    impl RunWithConnectionTo<Counterpart>,
197)
198where
199    Counterpart: Role,
200    P: JsonSchema + DeserializeOwned + 'static + Send,
201    Ret: JsonSchema + Serialize + 'static + Send,
202    F: AsyncFnMut(P, McpConnectionTo<Counterpart>) -> Result<Ret, Error> + Send,
203{
204    let (call_tx, call_rx) = mpsc::channel(128);
205    (
206        ToolFnTool {
207            name: name.to_string(),
208            description: description.to_string(),
209            call_tx,
210        },
211        ToolFnMutResponder {
212            func,
213            call_rx,
214            tool_future_fn: Box::new(tool_future_fn),
215        },
216    )
217}
218
219/// Create a stateless function-backed MCP tool and its concurrent responder.
220pub fn tool_fn<P, Ret, F, Counterpart>(
221    name: impl ToString,
222    description: impl ToString,
223    func: F,
224    tool_future_fn: impl for<'a> Fn(
225        &'a F,
226        P,
227        McpConnectionTo<Counterpart>,
228    ) -> BoxFuture<'a, Result<Ret, Error>>
229    + Send
230    + Sync
231    + 'static,
232) -> (
233    impl McpTool<Counterpart> + 'static,
234    impl RunWithConnectionTo<Counterpart>,
235)
236where
237    Counterpart: Role,
238    P: JsonSchema + DeserializeOwned + 'static + Send,
239    Ret: JsonSchema + Serialize + 'static + Send,
240    F: AsyncFn(P, McpConnectionTo<Counterpart>) -> Result<Ret, Error> + Send + Sync + 'static,
241{
242    let (call_tx, call_rx) = mpsc::channel(128);
243    (
244        ToolFnTool {
245            name: name.to_string(),
246            description: description.to_string(),
247            call_tx,
248        },
249        ToolFnResponder {
250            func,
251            call_rx,
252            tool_future_fn: Box::new(tool_future_fn),
253        },
254    )
255}