// sim_lib_mcp/router.rs

1#[cfg(feature = "stream")]
2use sim_codec_mcp::CANCELLED;
3use sim_codec_mcp::{
4    CAPABILITY_DENIED, EXECUTION_ERROR, INTERNAL_ERROR, INVALID_PARAMS, INVALID_REQUEST,
5    METHOD_NOT_FOUND, McpEnvelope, McpError, McpErrorEnvelope, McpNotification, McpRequest,
6    McpResponse, RATE_LIMITED, envelope_to_expr, expr_to_envelope,
7};
8use sim_kernel::{Cx, Error, Expr, Result};
9
10use crate::McpSession;
11use crate::methods::{core, prompts, resources, tools};
12use crate::session::McpBoundaryLimit;
13
14/// Optional single reply envelope produced by routing one request.
15pub type RouterReply = Option<McpEnvelope>;
16/// Ordered reply envelopes (notifications then final response) for one request.
17pub type RouterReplies = Vec<McpEnvelope>;
18
19struct DispatchReply {
20    result: Expr,
21    notifications: RouterReplies,
22}
23
24/// Dispatches MCP request/notification envelopes against a session.
25pub struct McpRouter {
26    session: McpSession,
27}
28
29impl McpRouter {
30    /// Creates a router bound to `session`.
31    pub fn new(session: McpSession) -> Self {
32        Self { session }
33    }
34
35    /// Creates a router over a permissive fixture session.
36    pub fn fixture() -> Self {
37        Self::new(McpSession::fixture())
38    }
39
40    /// Returns the underlying session.
41    pub fn session(&self) -> &McpSession {
42        &self.session
43    }
44
45    /// Returns a mutable reference to the underlying session.
46    pub fn session_mut(&mut self) -> &mut McpSession {
47        &mut self.session
48    }
49
50    /// Decodes one envelope [`Expr`], routes it, and encodes the final reply.
51    pub fn handle_expr(&mut self, cx: &mut Cx, expr: Expr) -> Result<Option<Expr>> {
52        let envelope = match expr_to_envelope(&expr) {
53            Ok(envelope) => envelope,
54            Err(error) => {
55                return Ok(Some(envelope_to_expr(&invalid_request_error(
56                    Expr::Nil,
57                    error,
58                ))));
59            }
60        };
61        Ok(self
62            .handle(cx, envelope)?
63            .map(|reply| envelope_to_expr(&reply)))
64    }
65
66    /// Decodes one envelope [`Expr`], routes it, and encodes every reply
67    /// (notifications and the final response) in order.
68    pub fn handle_exprs(&mut self, cx: &mut Cx, expr: Expr) -> Result<Vec<Expr>> {
69        let envelope = match expr_to_envelope(&expr) {
70            Ok(envelope) => envelope,
71            Err(error) => {
72                return Ok(vec![envelope_to_expr(&invalid_request_error(
73                    Expr::Nil,
74                    error,
75                ))]);
76            }
77        };
78        Ok(self
79            .handle_many(cx, envelope)?
80            .into_iter()
81            .map(|reply| envelope_to_expr(&reply))
82            .collect())
83    }
84
85    /// Routes `envelope` and returns only the final response or error reply.
86    pub fn handle(&mut self, cx: &mut Cx, envelope: McpEnvelope) -> Result<RouterReply> {
87        Ok(self
88            .handle_many(cx, envelope)?
89            .into_iter()
90            .rev()
91            .find(is_final_reply))
92    }
93
94    /// Routes `envelope` and returns every reply it produces, in order.
95    pub fn handle_many(&mut self, cx: &mut Cx, envelope: McpEnvelope) -> Result<RouterReplies> {
96        match envelope {
97            McpEnvelope::Request(request) => self.handle_request(cx, request),
98            McpEnvelope::Notification(notification) => {
99                self.handle_notification(cx, notification);
100                Ok(Vec::new())
101            }
102            McpEnvelope::Response(_) | McpEnvelope::Error(_) => Ok(Vec::new()),
103        }
104    }
105
106    fn handle_request(&mut self, cx: &mut Cx, request: McpRequest) -> Result<RouterReplies> {
107        let id = request.id.clone();
108        #[cfg(feature = "cassette")]
109        let request_envelope = McpEnvelope::Request(request.clone());
110        if let Err(limit) = self.session.admit_request(&id) {
111            let reply = boundary_limit_error(id, limit);
112            #[cfg(feature = "cassette")]
113            self.audit_request(&request.method, "boundary", limit.audit_outcome());
114            return Ok(vec![reply]);
115        }
116        #[cfg(feature = "cassette")]
117        if let Some(replies) = self.replay_request(&request_envelope, &request.method)? {
118            self.session.end_request(&id);
119            return Ok(replies);
120        }
121        #[cfg(feature = "stream")]
122        if self.session.request_cancelled(&id) {
123            self.session.end_request(&id);
124            return Ok(vec![cancelled_error(id, "request cancelled")]);
125        }
126        let result = self.dispatch_request(cx, &request.method, request.params);
127        #[cfg(feature = "stream")]
128        let cancelled = self.session.request_cancelled(&id);
129        self.session.end_request(&id);
130        #[cfg(feature = "stream")]
131        if cancelled {
132            return Ok(vec![cancelled_error(id, "request cancelled")]);
133        }
134        let replies = match result {
135            Ok(reply) => {
136                let mut replies = reply.notifications;
137                replies.push(McpEnvelope::Response(McpResponse {
138                    id,
139                    result: reply.result,
140                }));
141                replies
142            }
143            Err(error) if matches!(error, Error::UnknownFunction { .. }) => {
144                vec![method_not_found_error(id, error)]
145            }
146            Err(error) => vec![error_response(id, error)],
147        };
148        #[cfg(feature = "cassette")]
149        self.record_request(&request_envelope, &request.method, &replies)?;
150        Ok(replies)
151    }
152
153    fn handle_notification(&mut self, _cx: &mut Cx, notification: McpNotification) {
154        match notification.method.as_str() {
155            "initialized" | "notifications/initialized" => {
156                core::initialized(&mut self.session);
157            }
158            "shutdown" => {
159                core::shutdown(&mut self.session);
160            }
161            #[cfg(feature = "stream")]
162            "notifications/cancelled" => {
163                let _ = crate::stream::apply_cancel_notification(
164                    &mut self.session,
165                    notification.params,
166                );
167            }
168            _ => {}
169        }
170    }
171
172    fn dispatch_request(
173        &mut self,
174        cx: &mut Cx,
175        method: &str,
176        params: Expr,
177    ) -> Result<DispatchReply> {
178        let (result, notifications) = match method {
179            "initialize" => (core::initialize(&mut self.session, params)?, Vec::new()),
180            "initialized" | "notifications/initialized" => {
181                (core::initialized(&mut self.session), Vec::new())
182            }
183            "ping" => (core::ping(), Vec::new()),
184            "shutdown" => (core::shutdown(&mut self.session), Vec::new()),
185            "resources/list" => (resources::list(cx, &self.session)?, Vec::new()),
186            "resources/read" => (resources::read(cx, &self.session, params)?, Vec::new()),
187            "prompts/list" => (prompts::list(cx, &self.session)?, Vec::new()),
188            "prompts/get" => (prompts::get(cx, &self.session, params)?, Vec::new()),
189            "tools/list" => (tools::list(cx, &self.session)?, Vec::new()),
190            #[cfg(feature = "stream")]
191            "tools/call" => {
192                let progress_token = crate::stream::progress_token_from_params(&params);
193                tools::call_with_stream(cx, &mut self.session, params, progress_token)?
194            }
195            #[cfg(not(feature = "stream"))]
196            "tools/call" => (tools::call(cx, &self.session, params)?, Vec::new()),
197            _ => Err(Error::UnknownFunction {
198                function: sim_kernel::Symbol::new(method.to_owned()),
199            })?,
200        };
201        Ok(DispatchReply {
202            result,
203            notifications,
204        })
205    }
206}
207
208fn is_final_reply(envelope: &McpEnvelope) -> bool {
209    matches!(envelope, McpEnvelope::Response(_) | McpEnvelope::Error(_))
210}
211
212fn invalid_request_error(id: Expr, error: Error) -> McpEnvelope {
213    error_envelope(id, INVALID_REQUEST, "invalid request", error.to_string())
214}
215
216fn method_not_found_error(id: Expr, error: Error) -> McpEnvelope {
217    error_envelope(id, METHOD_NOT_FOUND, "method not found", error.to_string())
218}
219
220fn error_response(id: Expr, error: Error) -> McpEnvelope {
221    if let Some(data) = crate::uri::not_found_error_data(&error) {
222        return error_envelope_data(id, INVALID_PARAMS, "not found", data);
223    }
224    let (code, message) = error_code_and_message(&error);
225    error_envelope(id, code, message, error.to_string())
226}
227
/// Builds a `CANCELLED` error envelope with `detail` as its data string.
#[cfg(feature = "stream")]
fn cancelled_error(id: Expr, detail: &str) -> McpEnvelope {
    let detail = String::from(detail);
    error_envelope(id, CANCELLED, "cancelled", detail)
}
232
233fn boundary_limit_error(id: Expr, limit: McpBoundaryLimit) -> McpEnvelope {
234    match limit {
235        McpBoundaryLimit::Deadline => error_envelope(
236            id,
237            EXECUTION_ERROR,
238            "deadline exceeded",
239            "deadline exceeded".to_owned(),
240        ),
241        McpBoundaryLimit::Rate => error_envelope(
242            id,
243            RATE_LIMITED,
244            "rate limited",
245            "rate limit exceeded".to_owned(),
246        ),
247        McpBoundaryLimit::ActiveRequests => error_envelope(
248            id,
249            RATE_LIMITED,
250            "rate limited",
251            "active request limit exceeded".to_owned(),
252        ),
253    }
254}
255
256fn error_envelope(id: Expr, code: i64, message: &str, detail: String) -> McpEnvelope {
257    error_envelope_data(id, code, message, Expr::String(detail))
258}
259
#[cfg(feature = "cassette")]
impl McpRouter {
    /// Looks `request` up in the session cassette, if one is attached.
    ///
    /// Returns `Ok(None)` when there is no cassette or it holds no replies for
    /// `request`. On a hit, auditable methods get an audit entry with outcome
    /// `"replay"`.
    fn replay_request(
        &mut self,
        request: &McpEnvelope,
        method: &str,
    ) -> Result<Option<RouterReplies>> {
        if let Some(cassette) = self.session.cassette_mut() {
            if let Some(replies) = cassette.replay(request)? {
                if let Some(operation) = auditable_operation(method) {
                    cassette.record_audit(method, operation, "replay");
                }
                return Ok(Some(replies));
            }
        }
        Ok(None)
    }

    /// Stores the `request`/`replies` exchange in the session cassette, if one
    /// is attached, and audits auditable methods with an `"ok"`/`"error"`
    /// outcome.
    fn record_request(
        &mut self,
        request: &McpEnvelope,
        method: &str,
        replies: &[McpEnvelope],
    ) -> Result<()> {
        if let Some(cassette) = self.session.cassette_mut() {
            cassette.record_exchange(request, replies)?;
            if let Some(operation) = auditable_operation(method) {
                let outcome = reply_outcome(replies);
                cassette.record_audit(method, operation, outcome);
            }
        }
        Ok(())
    }

    /// Writes one audit entry to the session cassette; a no-op without one.
    fn audit_request(&mut self, method: &str, operation: &str, outcome: &str) {
        let Some(cassette) = self.session.cassette_mut() else {
            return;
        };
        cassette.record_audit(method, operation, outcome);
    }
}
301
/// Maps an MCP method to the audit operation label it is recorded under, or
/// `None` when the method is not audited.
#[cfg(feature = "cassette")]
fn auditable_operation(method: &str) -> Option<&'static str> {
    let operation = match method {
        "tools/call" => "tools/call",
        "resources/read" => "resources/read",
        "prompts/get" => "prompts/get",
        "sampling/createMessage" => "sampling",
        _ => return None,
    };
    Some(operation)
}
312
/// Summarizes a reply sequence for auditing: `"error"` if any reply is an error
/// envelope, `"ok"` otherwise (including for an empty sequence).
#[cfg(feature = "cassette")]
fn reply_outcome(replies: &[McpEnvelope]) -> &'static str {
    let failed = replies
        .iter()
        .any(|reply| matches!(reply, McpEnvelope::Error(_)));
    if failed { "error" } else { "ok" }
}
324
325impl McpBoundaryLimit {
326    #[cfg(feature = "cassette")]
327    fn audit_outcome(self) -> &'static str {
328        match self {
329            Self::Deadline => "deadline-denied",
330            Self::Rate => "rate-denied",
331            Self::ActiveRequests => "active-denied",
332        }
333    }
334}
335
336fn error_envelope_data(id: Expr, code: i64, message: &str, data: Expr) -> McpEnvelope {
337    McpEnvelope::Error(McpErrorEnvelope {
338        id,
339        error: McpError {
340            code,
341            message: message.to_owned(),
342            data,
343        },
344    })
345}
346
347fn error_code_and_message(error: &Error) -> (i64, &'static str) {
348    match error {
349        Error::CapabilityDenied { .. } | Error::TrustDenied { .. } => {
350            (CAPABILITY_DENIED, "capability denied")
351        }
352        Error::TypeMismatch { .. }
353        | Error::WrongShape { .. }
354        | Error::NoMatchingOverload { .. } => (INVALID_PARAMS, "invalid params"),
355        Error::UnknownFunction { .. } => (METHOD_NOT_FOUND, "method not found"),
356        Error::HostError(_) | Error::PoisonedLock(_) => (INTERNAL_ERROR, "internal error"),
357        _ => (EXECUTION_ERROR, "execution error"),
358    }
359}