Skip to main content

tower_mcp/
router.rs

1//! MCP Router - routes requests to tools, resources, and prompts
2//!
3//! The router implements Tower's `Service` trait, making it composable with
4//! standard tower middleware.
5
6use std::collections::{HashMap, HashSet};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::{Arc, RwLock};
10use std::task::{Context, Poll};
11
12use tower_service::Service;
13
14use crate::async_task::TaskStore;
15use crate::context::{
16    CancellationToken, ClientRequesterHandle, NotificationSender, RequestContext,
17    ServerNotification,
18};
19use crate::error::{Error, JsonRpcError, Result};
20use crate::filter::{PromptFilter, ResourceFilter, ToolFilter};
21use crate::prompt::Prompt;
22use crate::protocol::*;
23use crate::resource::{Resource, ResourceTemplate};
24use crate::session::SessionState;
25use crate::tool::Tool;
26
27/// Type alias for completion handler function
28pub type CompletionHandler = Arc<
29    dyn Fn(CompleteParams) -> Pin<Box<dyn Future<Output = Result<CompleteResult>> + Send>>
30        + Send
31        + Sync,
32>;
33
34/// MCP Router that dispatches requests to registered handlers
35///
36/// Implements `tower::Service<McpRequest>` for middleware composition.
37///
38/// # Example
39///
40/// ```rust
41/// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
42/// use schemars::JsonSchema;
43/// use serde::Deserialize;
44///
45/// #[derive(Debug, Deserialize, JsonSchema)]
46/// struct Input { value: String }
47///
48/// let tool = ToolBuilder::new("echo")
49///     .description("Echo input")
50///     .handler(|i: Input| async move { Ok(CallToolResult::text(i.value)) })
51///     .build()
52///     .unwrap();
53///
54/// let router = McpRouter::new()
55///     .server_info("my-server", "1.0.0")
56///     .tool(tool);
57/// ```
58#[derive(Clone)]
59pub struct McpRouter {
60    inner: Arc<McpRouterInner>,
61    session: SessionState,
62}
63
64impl std::fmt::Debug for McpRouter {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("McpRouter")
67            .field("server_name", &self.inner.server_name)
68            .field("server_version", &self.inner.server_version)
69            .field("tools_count", &self.inner.tools.len())
70            .field("resources_count", &self.inner.resources.len())
71            .field("prompts_count", &self.inner.prompts.len())
72            .field("session_phase", &self.session.phase())
73            .finish()
74    }
75}
76
77/// Inner configuration that is shared across clones
78#[derive(Clone)]
79struct McpRouterInner {
80    server_name: String,
81    server_version: String,
82    /// Human-readable title for the server
83    server_title: Option<String>,
84    /// Description of the server
85    server_description: Option<String>,
86    /// Icons for the server
87    server_icons: Option<Vec<ToolIcon>>,
88    /// URL of the server's website
89    server_website_url: Option<String>,
90    instructions: Option<String>,
91    tools: HashMap<String, Arc<Tool>>,
92    resources: HashMap<String, Arc<Resource>>,
93    /// Resource templates for dynamic resource matching (keyed by uri_template)
94    resource_templates: Vec<Arc<ResourceTemplate>>,
95    prompts: HashMap<String, Arc<Prompt>>,
96    /// In-flight requests for cancellation tracking (shared across clones)
97    in_flight: Arc<RwLock<HashMap<RequestId, CancellationToken>>>,
98    /// Channel for sending notifications to connected clients
99    notification_tx: Option<NotificationSender>,
100    /// Handle for sending requests to the client (for sampling, etc.)
101    client_requester: Option<ClientRequesterHandle>,
102    /// Task store for async operations
103    task_store: TaskStore,
104    /// Subscribed resource URIs
105    subscriptions: Arc<RwLock<HashSet<String>>>,
106    /// Handler for completion requests
107    completion_handler: Option<CompletionHandler>,
108    /// Filter for tools based on session state
109    tool_filter: Option<ToolFilter>,
110    /// Filter for resources based on session state
111    resource_filter: Option<ResourceFilter>,
112    /// Filter for prompts based on session state
113    prompt_filter: Option<PromptFilter>,
114    /// Router-level extensions (for state and middleware data)
115    extensions: Arc<crate::context::Extensions>,
116}
117
118impl McpRouter {
119    /// Create a new MCP router
120    pub fn new() -> Self {
121        Self {
122            inner: Arc::new(McpRouterInner {
123                server_name: "tower-mcp".to_string(),
124                server_version: env!("CARGO_PKG_VERSION").to_string(),
125                server_title: None,
126                server_description: None,
127                server_icons: None,
128                server_website_url: None,
129                instructions: None,
130                tools: HashMap::new(),
131                resources: HashMap::new(),
132                resource_templates: Vec::new(),
133                prompts: HashMap::new(),
134                in_flight: Arc::new(RwLock::new(HashMap::new())),
135                notification_tx: None,
136                client_requester: None,
137                task_store: TaskStore::new(),
138                subscriptions: Arc::new(RwLock::new(HashSet::new())),
139                extensions: Arc::new(crate::context::Extensions::new()),
140                completion_handler: None,
141                tool_filter: None,
142                resource_filter: None,
143                prompt_filter: None,
144            }),
145            session: SessionState::new(),
146        }
147    }
148
149    /// Create a clone with fresh session state.
150    ///
151    /// Use this when creating a new logical session (e.g., per HTTP connection).
152    /// The router configuration (tools, resources, prompts) is shared, but the
153    /// session state (phase, extensions) is independent.
154    ///
155    /// This is typically called by transports when establishing a new client session.
156    pub fn with_fresh_session(&self) -> Self {
157        Self {
158            inner: self.inner.clone(),
159            session: SessionState::new(),
160        }
161    }
162
163    /// Get access to the task store for async operations
164    pub fn task_store(&self) -> &TaskStore {
165        &self.inner.task_store
166    }
167
168    /// Set the notification sender for progress reporting
169    ///
170    /// This is typically called by the transport layer to receive notifications.
171    pub fn with_notification_sender(mut self, tx: NotificationSender) -> Self {
172        Arc::make_mut(&mut self.inner).notification_tx = Some(tx);
173        self
174    }
175
176    /// Get the notification sender (if configured)
177    pub fn notification_sender(&self) -> Option<&NotificationSender> {
178        self.inner.notification_tx.as_ref()
179    }
180
181    /// Set the client requester for server-to-client requests (sampling, etc.)
182    ///
183    /// This is typically called by bidirectional transports (WebSocket, stdio)
184    /// to enable tool handlers to send requests to the client.
185    pub fn with_client_requester(mut self, requester: ClientRequesterHandle) -> Self {
186        Arc::make_mut(&mut self.inner).client_requester = Some(requester);
187        self
188    }
189
190    /// Get the client requester (if configured)
191    pub fn client_requester(&self) -> Option<&ClientRequesterHandle> {
192        self.inner.client_requester.as_ref()
193    }
194
195    /// Add router-level state that handlers can access via the `Extension<T>` extractor.
196    ///
197    /// This is the recommended way to share state across all tools, resources, and prompts
198    /// in a router. The state is available to handlers via the [`crate::extract::Extension`]
199    /// extractor.
200    ///
201    /// # Example
202    ///
203    /// ```rust
204    /// use std::sync::Arc;
205    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
206    /// use tower_mcp::extract::{Extension, Json};
207    /// use schemars::JsonSchema;
208    /// use serde::Deserialize;
209    ///
210    /// #[derive(Clone)]
211    /// struct AppState {
212    ///     db_url: String,
213    /// }
214    ///
215    /// #[derive(Deserialize, JsonSchema)]
216    /// struct QueryInput {
217    ///     sql: String,
218    /// }
219    ///
220    /// let state = Arc::new(AppState { db_url: "postgres://...".into() });
221    ///
222    /// // Tool extracts state via Extension<T>
223    /// let query_tool = ToolBuilder::new("query")
224    ///     .description("Run a database query")
225    ///     .extractor_handler_typed::<_, _, _, QueryInput>(
226    ///         (),
227    ///         |Extension(state): Extension<Arc<AppState>>, Json(input): Json<QueryInput>| async move {
228    ///             Ok(CallToolResult::text(format!("Query on {}: {}", state.db_url, input.sql)))
229    ///         },
230    ///     )
231    ///     .build()
232    ///     .unwrap();
233    ///
234    /// let router = McpRouter::new()
235    ///     .with_state(state)  // State is now available to all handlers
236    ///     .tool(query_tool);
237    /// ```
238    pub fn with_state<T: Clone + Send + Sync + 'static>(mut self, state: T) -> Self {
239        let inner = Arc::make_mut(&mut self.inner);
240        Arc::make_mut(&mut inner.extensions).insert(state);
241        self
242    }
243
244    /// Add an extension value that handlers can access via the `Extension<T>` extractor.
245    ///
246    /// This is a more general form of `with_state()` for when you need multiple
247    /// typed values available to handlers.
248    pub fn with_extension<T: Clone + Send + Sync + 'static>(self, value: T) -> Self {
249        self.with_state(value)
250    }
251
252    /// Get the router's extensions.
253    pub fn extensions(&self) -> &crate::context::Extensions {
254        &self.inner.extensions
255    }
256
257    /// Create a request context for tracking a request
258    ///
259    /// This registers the request for cancellation tracking and sets up
260    /// progress reporting, client requests, and router extensions if configured.
261    pub fn create_context(
262        &self,
263        request_id: RequestId,
264        progress_token: Option<ProgressToken>,
265    ) -> RequestContext {
266        let ctx = RequestContext::new(request_id.clone());
267
268        // Set up progress token if provided
269        let ctx = if let Some(token) = progress_token {
270            ctx.with_progress_token(token)
271        } else {
272            ctx
273        };
274
275        // Set up notification sender if configured
276        let ctx = if let Some(tx) = &self.inner.notification_tx {
277            ctx.with_notification_sender(tx.clone())
278        } else {
279            ctx
280        };
281
282        // Set up client requester if configured (for sampling support)
283        let ctx = if let Some(requester) = &self.inner.client_requester {
284            ctx.with_client_requester(requester.clone())
285        } else {
286            ctx
287        };
288
289        // Include router extensions (for with_state() and middleware data)
290        let ctx = ctx.with_extensions(self.inner.extensions.clone());
291
292        // Register for cancellation tracking
293        let token = ctx.cancellation_token();
294        if let Ok(mut in_flight) = self.inner.in_flight.write() {
295            in_flight.insert(request_id, token);
296        }
297
298        ctx
299    }
300
301    /// Remove a request from tracking (called when request completes)
302    pub fn complete_request(&self, request_id: &RequestId) {
303        if let Ok(mut in_flight) = self.inner.in_flight.write() {
304            in_flight.remove(request_id);
305        }
306    }
307
308    /// Cancel a tracked request
309    fn cancel_request(&self, request_id: &RequestId) -> bool {
310        let Ok(in_flight) = self.inner.in_flight.read() else {
311            return false;
312        };
313        let Some(token) = in_flight.get(request_id) else {
314            return false;
315        };
316        token.cancel();
317        true
318    }
319
320    /// Set server info
321    pub fn server_info(mut self, name: impl Into<String>, version: impl Into<String>) -> Self {
322        let inner = Arc::make_mut(&mut self.inner);
323        inner.server_name = name.into();
324        inner.server_version = version.into();
325        self
326    }
327
328    /// Set instructions for LLMs describing how to use this server
329    pub fn instructions(mut self, instructions: impl Into<String>) -> Self {
330        Arc::make_mut(&mut self.inner).instructions = Some(instructions.into());
331        self
332    }
333
334    /// Set a human-readable title for the server
335    pub fn server_title(mut self, title: impl Into<String>) -> Self {
336        Arc::make_mut(&mut self.inner).server_title = Some(title.into());
337        self
338    }
339
340    /// Set the server description
341    pub fn server_description(mut self, description: impl Into<String>) -> Self {
342        Arc::make_mut(&mut self.inner).server_description = Some(description.into());
343        self
344    }
345
346    /// Set icons for the server
347    pub fn server_icons(mut self, icons: Vec<ToolIcon>) -> Self {
348        Arc::make_mut(&mut self.inner).server_icons = Some(icons);
349        self
350    }
351
352    /// Set the server's website URL
353    pub fn server_website_url(mut self, url: impl Into<String>) -> Self {
354        Arc::make_mut(&mut self.inner).server_website_url = Some(url.into());
355        self
356    }
357
358    /// Register a tool
359    pub fn tool(mut self, tool: Tool) -> Self {
360        Arc::make_mut(&mut self.inner)
361            .tools
362            .insert(tool.name.clone(), Arc::new(tool));
363        self
364    }
365
366    /// Register a resource
367    pub fn resource(mut self, resource: Resource) -> Self {
368        Arc::make_mut(&mut self.inner)
369            .resources
370            .insert(resource.uri.clone(), Arc::new(resource));
371        self
372    }
373
374    /// Register a resource template
375    ///
376    /// Resource templates allow dynamic resources to be matched by URI pattern.
377    /// When a client requests a resource URI that doesn't match any static
378    /// resource, the router tries to match it against registered templates.
379    ///
380    /// # Example
381    ///
382    /// ```rust
383    /// use tower_mcp::{McpRouter, ResourceTemplateBuilder};
384    /// use tower_mcp::protocol::{ReadResourceResult, ResourceContent};
385    /// use std::collections::HashMap;
386    ///
387    /// let template = ResourceTemplateBuilder::new("file:///{path}")
388    ///     .name("Project Files")
389    ///     .handler(|uri: String, vars: HashMap<String, String>| async move {
390    ///         let path = vars.get("path").unwrap_or(&String::new()).clone();
391    ///         Ok(ReadResourceResult {
392    ///             contents: vec![ResourceContent {
393    ///                 uri,
394    ///                 mime_type: Some("text/plain".to_string()),
395    ///                 text: Some(format!("Contents of {}", path)),
396    ///                 blob: None,
397    ///             }],
398    ///         })
399    ///     });
400    ///
401    /// let router = McpRouter::new()
402    ///     .resource_template(template);
403    /// ```
404    pub fn resource_template(mut self, template: ResourceTemplate) -> Self {
405        Arc::make_mut(&mut self.inner)
406            .resource_templates
407            .push(Arc::new(template));
408        self
409    }
410
411    /// Register a prompt
412    pub fn prompt(mut self, prompt: Prompt) -> Self {
413        Arc::make_mut(&mut self.inner)
414            .prompts
415            .insert(prompt.name.clone(), Arc::new(prompt));
416        self
417    }
418
419    /// Register multiple tools at once.
420    ///
421    /// # Example
422    ///
423    /// ```rust
424    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
425    /// use schemars::JsonSchema;
426    /// use serde::Deserialize;
427    ///
428    /// #[derive(Debug, Deserialize, JsonSchema)]
429    /// struct Input { value: String }
430    ///
431    /// let tools = vec![
432    ///     ToolBuilder::new("a")
433    ///         .description("Tool A")
434    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
435    ///         .build().unwrap(),
436    ///     ToolBuilder::new("b")
437    ///         .description("Tool B")
438    ///         .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
439    ///         .build().unwrap(),
440    /// ];
441    ///
442    /// let router = McpRouter::new().tools(tools);
443    /// ```
444    pub fn tools(self, tools: impl IntoIterator<Item = Tool>) -> Self {
445        tools
446            .into_iter()
447            .fold(self, |router, tool| router.tool(tool))
448    }
449
450    /// Register multiple resources at once.
451    ///
452    /// # Example
453    ///
454    /// ```rust
455    /// use tower_mcp::{McpRouter, ResourceBuilder};
456    ///
457    /// let resources = vec![
458    ///     ResourceBuilder::new("file:///a.txt")
459    ///         .name("File A")
460    ///         .text("contents a"),
461    ///     ResourceBuilder::new("file:///b.txt")
462    ///         .name("File B")
463    ///         .text("contents b"),
464    /// ];
465    ///
466    /// let router = McpRouter::new().resources(resources);
467    /// ```
468    pub fn resources(self, resources: impl IntoIterator<Item = Resource>) -> Self {
469        resources
470            .into_iter()
471            .fold(self, |router, resource| router.resource(resource))
472    }
473
474    /// Register multiple prompts at once.
475    ///
476    /// # Example
477    ///
478    /// ```rust
479    /// use tower_mcp::{McpRouter, PromptBuilder};
480    ///
481    /// let prompts = vec![
482    ///     PromptBuilder::new("greet")
483    ///         .description("Greet someone")
484    ///         .user_message("Hello!"),
485    ///     PromptBuilder::new("farewell")
486    ///         .description("Say goodbye")
487    ///         .user_message("Goodbye!"),
488    /// ];
489    ///
490    /// let router = McpRouter::new().prompts(prompts);
491    /// ```
492    pub fn prompts(self, prompts: impl IntoIterator<Item = Prompt>) -> Self {
493        prompts
494            .into_iter()
495            .fold(self, |router, prompt| router.prompt(prompt))
496    }
497
498    /// Merge another router's capabilities into this one.
499    ///
500    /// This combines all tools, resources, resource templates, and prompts from
501    /// the other router into this router. Uses "last wins" semantics for conflicts,
502    /// meaning if both routers have a tool/resource/prompt with the same name,
503    /// the one from `other` will replace the one in `self`.
504    ///
505    /// Server info, instructions, filters, and other router-level configuration
506    /// are NOT merged - only the root router's settings are used.
507    ///
508    /// # Example
509    ///
510    /// ```rust
511    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, ResourceBuilder};
512    /// use schemars::JsonSchema;
513    /// use serde::Deserialize;
514    ///
515    /// #[derive(Debug, Deserialize, JsonSchema)]
516    /// struct Input { value: String }
517    ///
518    /// // Create a router with database tools
519    /// let db_tools = McpRouter::new()
520    ///     .tool(
521    ///         ToolBuilder::new("query")
522    ///             .description("Query the database")
523    ///             .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
524    ///             .build()
525    ///             .unwrap()
526    ///     );
527    ///
528    /// // Create a router with API tools
529    /// let api_tools = McpRouter::new()
530    ///     .tool(
531    ///         ToolBuilder::new("fetch")
532    ///             .description("Fetch from API")
533    ///             .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
534    ///             .build()
535    ///             .unwrap()
536    ///     );
537    ///
538    /// // Merge them together
539    /// let router = McpRouter::new()
540    ///     .server_info("combined", "1.0")
541    ///     .merge(db_tools)
542    ///     .merge(api_tools);
543    /// ```
544    pub fn merge(mut self, other: McpRouter) -> Self {
545        let inner = Arc::make_mut(&mut self.inner);
546        let other_inner = other.inner;
547
548        // Merge tools (last wins)
549        for (name, tool) in &other_inner.tools {
550            inner.tools.insert(name.clone(), tool.clone());
551        }
552
553        // Merge resources (last wins)
554        for (uri, resource) in &other_inner.resources {
555            inner.resources.insert(uri.clone(), resource.clone());
556        }
557
558        // Merge resource templates (append - no deduplication since templates
559        // can have complex matching behavior)
560        for template in &other_inner.resource_templates {
561            inner.resource_templates.push(template.clone());
562        }
563
564        // Merge prompts (last wins)
565        for (name, prompt) in &other_inner.prompts {
566            inner.prompts.insert(name.clone(), prompt.clone());
567        }
568
569        self
570    }
571
572    /// Nest another router's capabilities under a prefix.
573    ///
574    /// This is similar to `merge()`, but all tool names from the nested router
575    /// are prefixed with the given string and a dot separator. For example,
576    /// nesting with prefix "db" will turn a tool named "query" into "db.query".
577    ///
578    /// Resources, resource templates, and prompts are merged without modification
579    /// since they use URIs rather than simple names for identification.
580    ///
581    /// # Example
582    ///
583    /// ```rust
584    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult};
585    /// use schemars::JsonSchema;
586    /// use serde::Deserialize;
587    ///
588    /// #[derive(Debug, Deserialize, JsonSchema)]
589    /// struct Input { value: String }
590    ///
591    /// // Create a router with database tools
592    /// let db_tools = McpRouter::new()
593    ///     .tool(
594    ///         ToolBuilder::new("query")
595    ///             .description("Query the database")
596    ///             .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
597    ///             .build()
598    ///             .unwrap()
599    ///     )
600    ///     .tool(
601    ///         ToolBuilder::new("insert")
602    ///             .description("Insert into database")
603    ///             .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
604    ///             .build()
605    ///             .unwrap()
606    ///     );
607    ///
608    /// // Nest under "db" prefix - tools become "db.query" and "db.insert"
609    /// let router = McpRouter::new()
610    ///     .server_info("combined", "1.0")
611    ///     .nest("db", db_tools);
612    /// ```
613    pub fn nest(mut self, prefix: impl Into<String>, other: McpRouter) -> Self {
614        let prefix = prefix.into();
615        let inner = Arc::make_mut(&mut self.inner);
616        let other_inner = other.inner;
617
618        // Nest tools with prefix
619        for tool in other_inner.tools.values() {
620            let prefixed_tool = tool.with_name_prefix(&prefix);
621            inner
622                .tools
623                .insert(prefixed_tool.name.clone(), Arc::new(prefixed_tool));
624        }
625
626        // Merge resources (no prefix - URIs are already namespaced)
627        for (uri, resource) in &other_inner.resources {
628            inner.resources.insert(uri.clone(), resource.clone());
629        }
630
631        // Merge resource templates (no prefix)
632        for template in &other_inner.resource_templates {
633            inner.resource_templates.push(template.clone());
634        }
635
636        // Merge prompts (no prefix - could be added in future if needed)
637        for (name, prompt) in &other_inner.prompts {
638            inner.prompts.insert(name.clone(), prompt.clone());
639        }
640
641        self
642    }
643
644    /// Register a completion handler for `completion/complete` requests.
645    ///
646    /// The handler receives `CompleteParams` containing the reference (prompt or resource)
647    /// and the argument being completed, and should return completion suggestions.
648    ///
649    /// # Example
650    ///
651    /// ```rust
652    /// use tower_mcp::{McpRouter, CompleteResult};
653    /// use tower_mcp::protocol::{CompleteParams, CompletionReference};
654    ///
655    /// let router = McpRouter::new()
656    ///     .completion_handler(|params: CompleteParams| async move {
657    ///         // Provide completions based on the reference and argument
658    ///         match params.reference {
659    ///             CompletionReference::Prompt { name } => {
660    ///                 // Return prompt argument completions
661    ///                 Ok(CompleteResult::new(vec!["option1".to_string(), "option2".to_string()]))
662    ///             }
663    ///             CompletionReference::Resource { uri } => {
664    ///                 // Return resource URI completions
665    ///                 Ok(CompleteResult::new(vec![]))
666    ///             }
667    ///         }
668    ///     });
669    /// ```
670    pub fn completion_handler<F, Fut>(mut self, handler: F) -> Self
671    where
672        F: Fn(CompleteParams) -> Fut + Send + Sync + 'static,
673        Fut: Future<Output = Result<CompleteResult>> + Send + 'static,
674    {
675        Arc::make_mut(&mut self.inner).completion_handler =
676            Some(Arc::new(move |params| Box::pin(handler(params))));
677        self
678    }
679
680    /// Set a filter for tools based on session state.
681    ///
682    /// The filter determines which tools are visible to each session. Tools that
683    /// don't pass the filter will not appear in `tools/list` responses and will
684    /// return an error if called directly.
685    ///
686    /// # Example
687    ///
688    /// ```rust
689    /// use tower_mcp::{McpRouter, ToolBuilder, CallToolResult, CapabilityFilter, Tool, Filterable};
690    /// use schemars::JsonSchema;
691    /// use serde::Deserialize;
692    ///
693    /// #[derive(Debug, Deserialize, JsonSchema)]
694    /// struct Input { value: String }
695    ///
696    /// let public_tool = ToolBuilder::new("public")
697    ///     .description("Available to everyone")
698    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
699    ///     .build()
700    ///     .unwrap();
701    ///
702    /// let admin_tool = ToolBuilder::new("admin")
703    ///     .description("Admin only")
704    ///     .handler(|i: Input| async move { Ok(CallToolResult::text(&i.value)) })
705    ///     .build()
706    ///     .unwrap();
707    ///
708    /// let router = McpRouter::new()
709    ///     .tool(public_tool)
710    ///     .tool(admin_tool)
711    ///     .tool_filter(CapabilityFilter::new(|_session, tool: &Tool| {
712    ///         // In real code, check session.extensions() for auth claims
713    ///         tool.name() != "admin"
714    ///     }));
715    /// ```
716    pub fn tool_filter(mut self, filter: ToolFilter) -> Self {
717        Arc::make_mut(&mut self.inner).tool_filter = Some(filter);
718        self
719    }
720
721    /// Set a filter for resources based on session state.
722    ///
723    /// The filter receives the current session state and each resource, returning
724    /// `true` if the resource should be visible to this session. Resources that
725    /// don't pass the filter will not appear in `resources/list` responses and will
726    /// return an error if read directly.
727    ///
728    /// # Example
729    ///
730    /// ```rust
731    /// use tower_mcp::{McpRouter, ResourceBuilder, ReadResourceResult, CapabilityFilter, Resource, Filterable};
732    ///
733    /// let public_resource = ResourceBuilder::new("file:///public.txt")
734    ///     .name("Public File")
735    ///     .description("Available to everyone")
736    ///     .text("public content");
737    ///
738    /// let secret_resource = ResourceBuilder::new("file:///secret.txt")
739    ///     .name("Secret File")
740    ///     .description("Admin only")
741    ///     .text("secret content");
742    ///
743    /// let router = McpRouter::new()
744    ///     .resource(public_resource)
745    ///     .resource(secret_resource)
746    ///     .resource_filter(CapabilityFilter::new(|_session, resource: &Resource| {
747    ///         // In real code, check session.extensions() for auth claims
748    ///         !resource.name().contains("Secret")
749    ///     }));
750    /// ```
751    pub fn resource_filter(mut self, filter: ResourceFilter) -> Self {
752        Arc::make_mut(&mut self.inner).resource_filter = Some(filter);
753        self
754    }
755
756    /// Set a filter for prompts based on session state.
757    ///
758    /// The filter receives the current session state and each prompt, returning
759    /// `true` if the prompt should be visible to this session. Prompts that
760    /// don't pass the filter will not appear in `prompts/list` responses and will
761    /// return an error if accessed directly.
762    ///
763    /// # Example
764    ///
765    /// ```rust
766    /// use tower_mcp::{McpRouter, PromptBuilder, CapabilityFilter, Prompt, Filterable};
767    ///
768    /// let public_prompt = PromptBuilder::new("greeting")
769    ///     .description("A friendly greeting")
770    ///     .user_message("Hello!");
771    ///
772    /// let admin_prompt = PromptBuilder::new("system_debug")
773    ///     .description("Admin debugging prompt")
774    ///     .user_message("Debug info");
775    ///
776    /// let router = McpRouter::new()
777    ///     .prompt(public_prompt)
778    ///     .prompt(admin_prompt)
779    ///     .prompt_filter(CapabilityFilter::new(|_session, prompt: &Prompt| {
780    ///         // In real code, check session.extensions() for auth claims
781    ///         !prompt.name().contains("system")
782    ///     }));
783    /// ```
784    pub fn prompt_filter(mut self, filter: PromptFilter) -> Self {
785        Arc::make_mut(&mut self.inner).prompt_filter = Some(filter);
786        self
787    }
788
789    /// Get access to the session state
790    pub fn session(&self) -> &SessionState {
791        &self.session
792    }
793
794    /// Send a log message notification to the client
795    ///
796    /// This sends a `notifications/message` notification with the given parameters.
797    /// Returns `true` if the notification was sent, `false` if no notification channel
798    /// is configured.
799    ///
800    /// # Example
801    ///
802    /// ```rust,ignore
803    /// use tower_mcp::protocol::{LogLevel, LoggingMessageParams};
804    ///
805    /// // Simple info message
806    /// router.log(LoggingMessageParams::new(LogLevel::Info).with_data(
807    ///     serde_json::json!({"message": "Operation completed"})
808    /// ));
809    ///
810    /// // Error with logger name
811    /// router.log(LoggingMessageParams::new(LogLevel::Error)
812    ///     .with_logger("database")
813    ///     .with_data(serde_json::json!({"error": "Connection failed"})));
814    /// ```
815    pub fn log(&self, params: LoggingMessageParams) -> bool {
816        let Some(tx) = &self.inner.notification_tx else {
817            return false;
818        };
819        tx.try_send(ServerNotification::LogMessage(params)).is_ok()
820    }
821
822    /// Send an info-level log message
823    ///
824    /// Convenience method for sending an info log with optional data.
825    pub fn log_info(&self, message: &str) -> bool {
826        self.log(
827            LoggingMessageParams::new(LogLevel::Info)
828                .with_data(serde_json::json!({ "message": message })),
829        )
830    }
831
832    /// Send a warning-level log message
833    pub fn log_warning(&self, message: &str) -> bool {
834        self.log(
835            LoggingMessageParams::new(LogLevel::Warning)
836                .with_data(serde_json::json!({ "message": message })),
837        )
838    }
839
840    /// Send an error-level log message
841    pub fn log_error(&self, message: &str) -> bool {
842        self.log(
843            LoggingMessageParams::new(LogLevel::Error)
844                .with_data(serde_json::json!({ "message": message })),
845        )
846    }
847
848    /// Send a debug-level log message
849    pub fn log_debug(&self, message: &str) -> bool {
850        self.log(
851            LoggingMessageParams::new(LogLevel::Debug)
852                .with_data(serde_json::json!({ "message": message })),
853        )
854    }
855
856    /// Check if a resource URI is currently subscribed
857    pub fn is_subscribed(&self, uri: &str) -> bool {
858        if let Ok(subs) = self.inner.subscriptions.read() {
859            return subs.contains(uri);
860        }
861        false
862    }
863
864    /// Get a list of all subscribed resource URIs
865    pub fn subscribed_uris(&self) -> Vec<String> {
866        if let Ok(subs) = self.inner.subscriptions.read() {
867            return subs.iter().cloned().collect();
868        }
869        Vec::new()
870    }
871
872    /// Subscribe to a resource URI
873    fn subscribe(&self, uri: &str) -> bool {
874        if let Ok(mut subs) = self.inner.subscriptions.write() {
875            return subs.insert(uri.to_string());
876        }
877        false
878    }
879
880    /// Unsubscribe from a resource URI
881    fn unsubscribe(&self, uri: &str) -> bool {
882        if let Ok(mut subs) = self.inner.subscriptions.write() {
883            return subs.remove(uri);
884        }
885        false
886    }
887
888    /// Notify clients that a subscribed resource has been updated
889    ///
890    /// Only sends the notification if the resource is currently subscribed.
891    /// Returns `true` if the notification was sent.
892    pub fn notify_resource_updated(&self, uri: &str) -> bool {
893        // Only notify if the resource is subscribed
894        if !self.is_subscribed(uri) {
895            return false;
896        }
897
898        let Some(tx) = &self.inner.notification_tx else {
899            return false;
900        };
901        tx.try_send(ServerNotification::ResourceUpdated {
902            uri: uri.to_string(),
903        })
904        .is_ok()
905    }
906
907    /// Notify clients that the list of available resources has changed
908    ///
909    /// Returns `true` if the notification was sent.
910    pub fn notify_resources_list_changed(&self) -> bool {
911        let Some(tx) = &self.inner.notification_tx else {
912            return false;
913        };
914        tx.try_send(ServerNotification::ResourcesListChanged)
915            .is_ok()
916    }
917
918    /// Get server capabilities based on registered handlers
919    fn capabilities(&self) -> ServerCapabilities {
920        let has_resources =
921            !self.inner.resources.is_empty() || !self.inner.resource_templates.is_empty();
922
923        ServerCapabilities {
924            tools: if self.inner.tools.is_empty() {
925                None
926            } else {
927                Some(ToolsCapability::default())
928            },
929            resources: if has_resources {
930                Some(ResourcesCapability {
931                    subscribe: true,
932                    ..Default::default()
933                })
934            } else {
935                None
936            },
937            prompts: if self.inner.prompts.is_empty() {
938                None
939            } else {
940                Some(PromptsCapability::default())
941            },
942            // Always advertise logging capability when notification channel is configured
943            logging: if self.inner.notification_tx.is_some() {
944                Some(LoggingCapability::default())
945            } else {
946                None
947            },
948            // Tasks capability is always available
949            tasks: Some(TasksCapability::default()),
950            // Completions capability when a handler is registered
951            completions: if self.inner.completion_handler.is_some() {
952                Some(CompletionsCapability::default())
953            } else {
954                None
955            },
956        }
957    }
958
959    /// Handle an MCP request
960    async fn handle(&self, request_id: RequestId, request: McpRequest) -> Result<McpResponse> {
961        // Enforce session state - reject requests before initialization
962        let method = request.method_name();
963        if !self.session.is_request_allowed(method) {
964            tracing::warn!(
965                method = %method,
966                phase = ?self.session.phase(),
967                "Request rejected: session not initialized"
968            );
969            return Err(Error::JsonRpc(JsonRpcError::invalid_request(format!(
970                "Session not initialized. Only 'initialize' and 'ping' are allowed before initialization. Got: {}",
971                method
972            ))));
973        }
974
975        match request {
976            McpRequest::Initialize(params) => {
977                tracing::info!(
978                    client = %params.client_info.name,
979                    version = %params.client_info.version,
980                    "Client initializing"
981                );
982
983                // Protocol version negotiation: respond with same version if supported,
984                // otherwise respond with our latest supported version
985                let protocol_version = if crate::protocol::SUPPORTED_PROTOCOL_VERSIONS
986                    .contains(&params.protocol_version.as_str())
987                {
988                    params.protocol_version
989                } else {
990                    crate::protocol::LATEST_PROTOCOL_VERSION.to_string()
991                };
992
993                // Transition session state to Initializing
994                self.session.mark_initializing();
995
996                Ok(McpResponse::Initialize(InitializeResult {
997                    protocol_version,
998                    capabilities: self.capabilities(),
999                    server_info: Implementation {
1000                        name: self.inner.server_name.clone(),
1001                        version: self.inner.server_version.clone(),
1002                        title: self.inner.server_title.clone(),
1003                        description: self.inner.server_description.clone(),
1004                        icons: self.inner.server_icons.clone(),
1005                        website_url: self.inner.server_website_url.clone(),
1006                    },
1007                    instructions: self.inner.instructions.clone(),
1008                }))
1009            }
1010
1011            McpRequest::ListTools(_params) => {
1012                let tools: Vec<ToolDefinition> = self
1013                    .inner
1014                    .tools
1015                    .values()
1016                    .filter(|t| {
1017                        // Apply tool filter if configured
1018                        self.inner
1019                            .tool_filter
1020                            .as_ref()
1021                            .map(|f| f.is_visible(&self.session, t))
1022                            .unwrap_or(true)
1023                    })
1024                    .map(|t| t.definition())
1025                    .collect();
1026
1027                Ok(McpResponse::ListTools(ListToolsResult {
1028                    tools,
1029                    next_cursor: None,
1030                }))
1031            }
1032
1033            McpRequest::CallTool(params) => {
1034                let tool =
1035                    self.inner.tools.get(&params.name).ok_or_else(|| {
1036                        Error::JsonRpc(JsonRpcError::method_not_found(&params.name))
1037                    })?;
1038
1039                // Check tool filter if configured
1040                if let Some(filter) = &self.inner.tool_filter {
1041                    if !filter.is_visible(&self.session, tool) {
1042                        return Err(filter.denial_error(&params.name));
1043                    }
1044                }
1045
1046                // Extract progress token from request metadata
1047                let progress_token = params.meta.and_then(|m| m.progress_token);
1048                let ctx = self.create_context(request_id, progress_token);
1049
1050                tracing::debug!(tool = %params.name, "Calling tool");
1051                let result = tool.call_with_context(ctx, params.arguments).await;
1052
1053                Ok(McpResponse::CallTool(result))
1054            }
1055
1056            McpRequest::ListResources(_params) => {
1057                let resources: Vec<ResourceDefinition> = self
1058                    .inner
1059                    .resources
1060                    .values()
1061                    .filter(|r| {
1062                        // Apply resource filter if configured
1063                        self.inner
1064                            .resource_filter
1065                            .as_ref()
1066                            .map(|f| f.is_visible(&self.session, r))
1067                            .unwrap_or(true)
1068                    })
1069                    .map(|r| r.definition())
1070                    .collect();
1071
1072                Ok(McpResponse::ListResources(ListResourcesResult {
1073                    resources,
1074                    next_cursor: None,
1075                }))
1076            }
1077
1078            McpRequest::ListResourceTemplates(_params) => {
1079                let resource_templates: Vec<ResourceTemplateDefinition> = self
1080                    .inner
1081                    .resource_templates
1082                    .iter()
1083                    .map(|t| t.definition())
1084                    .collect();
1085
1086                Ok(McpResponse::ListResourceTemplates(
1087                    ListResourceTemplatesResult {
1088                        resource_templates,
1089                        next_cursor: None,
1090                    },
1091                ))
1092            }
1093
1094            McpRequest::ReadResource(params) => {
1095                // First, try to find a static resource
1096                if let Some(resource) = self.inner.resources.get(&params.uri) {
1097                    // Check resource filter if configured
1098                    if let Some(filter) = &self.inner.resource_filter {
1099                        if !filter.is_visible(&self.session, resource) {
1100                            return Err(filter.denial_error(&params.uri));
1101                        }
1102                    }
1103
1104                    tracing::debug!(uri = %params.uri, "Reading static resource");
1105                    let result = resource.read().await;
1106                    return Ok(McpResponse::ReadResource(result));
1107                }
1108
1109                // If no static resource found, try to match against templates
1110                for template in &self.inner.resource_templates {
1111                    if let Some(variables) = template.match_uri(&params.uri) {
1112                        tracing::debug!(
1113                            uri = %params.uri,
1114                            template = %template.uri_template,
1115                            "Reading resource via template"
1116                        );
1117                        let result = template.read(&params.uri, variables).await?;
1118                        return Ok(McpResponse::ReadResource(result));
1119                    }
1120                }
1121
1122                // No match found
1123                Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1124                    &params.uri,
1125                )))
1126            }
1127
1128            McpRequest::SubscribeResource(params) => {
1129                // Verify the resource exists
1130                if !self.inner.resources.contains_key(&params.uri) {
1131                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1132                        &params.uri,
1133                    )));
1134                }
1135
1136                tracing::debug!(uri = %params.uri, "Subscribing to resource");
1137                self.subscribe(&params.uri);
1138
1139                Ok(McpResponse::SubscribeResource(EmptyResult {}))
1140            }
1141
1142            McpRequest::UnsubscribeResource(params) => {
1143                // Verify the resource exists
1144                if !self.inner.resources.contains_key(&params.uri) {
1145                    return Err(Error::JsonRpc(JsonRpcError::resource_not_found(
1146                        &params.uri,
1147                    )));
1148                }
1149
1150                tracing::debug!(uri = %params.uri, "Unsubscribing from resource");
1151                self.unsubscribe(&params.uri);
1152
1153                Ok(McpResponse::UnsubscribeResource(EmptyResult {}))
1154            }
1155
1156            McpRequest::ListPrompts(_params) => {
1157                let prompts: Vec<PromptDefinition> = self
1158                    .inner
1159                    .prompts
1160                    .values()
1161                    .filter(|p| {
1162                        // Apply prompt filter if configured
1163                        self.inner
1164                            .prompt_filter
1165                            .as_ref()
1166                            .map(|f| f.is_visible(&self.session, p))
1167                            .unwrap_or(true)
1168                    })
1169                    .map(|p| p.definition())
1170                    .collect();
1171
1172                Ok(McpResponse::ListPrompts(ListPromptsResult {
1173                    prompts,
1174                    next_cursor: None,
1175                }))
1176            }
1177
1178            McpRequest::GetPrompt(params) => {
1179                let prompt = self.inner.prompts.get(&params.name).ok_or_else(|| {
1180                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1181                        "Prompt not found: {}",
1182                        params.name
1183                    )))
1184                })?;
1185
1186                // Check prompt filter if configured
1187                if let Some(filter) = &self.inner.prompt_filter {
1188                    if !filter.is_visible(&self.session, prompt) {
1189                        return Err(filter.denial_error(&params.name));
1190                    }
1191                }
1192
1193                tracing::debug!(name = %params.name, "Getting prompt");
1194                let result = prompt.get(params.arguments).await?;
1195
1196                Ok(McpResponse::GetPrompt(result))
1197            }
1198
1199            McpRequest::Ping => Ok(McpResponse::Pong(EmptyResult {})),
1200
1201            McpRequest::EnqueueTask(params) => {
1202                // Verify the tool exists
1203                let tool = self.inner.tools.get(&params.tool_name).ok_or_else(|| {
1204                    Error::JsonRpc(JsonRpcError::method_not_found(&format!(
1205                        "Tool not found: {}",
1206                        params.tool_name
1207                    )))
1208                })?;
1209
1210                // Create the task
1211                let (task_id, cancellation_token) = self.inner.task_store.create_task(
1212                    &params.tool_name,
1213                    params.arguments.clone(),
1214                    params.ttl,
1215                );
1216
1217                tracing::info!(task_id = %task_id, tool = %params.tool_name, "Enqueued async task");
1218
1219                // Create a context for the async task execution
1220                let ctx = self.create_context(request_id, None);
1221
1222                // Spawn the task execution in the background
1223                let task_store = self.inner.task_store.clone();
1224                let tool = tool.clone();
1225                let arguments = params.arguments;
1226                let task_id_clone = task_id.clone();
1227
1228                tokio::spawn(async move {
1229                    // Check for cancellation before starting
1230                    if cancellation_token.is_cancelled() {
1231                        tracing::debug!(task_id = %task_id_clone, "Task cancelled before execution");
1232                        return;
1233                    }
1234
1235                    // Execute the tool
1236                    let result = tool.call_with_context(ctx, arguments).await;
1237
1238                    if cancellation_token.is_cancelled() {
1239                        tracing::debug!(task_id = %task_id_clone, "Task cancelled during execution");
1240                    } else if result.is_error {
1241                        // Tool returned an error result
1242                        let error_msg = result.first_text().unwrap_or("Tool execution failed");
1243                        task_store.fail_task(&task_id_clone, error_msg);
1244                        tracing::warn!(task_id = %task_id_clone, error = %error_msg, "Task failed");
1245                    } else {
1246                        task_store.complete_task(&task_id_clone, result);
1247                        tracing::debug!(task_id = %task_id_clone, "Task completed successfully");
1248                    }
1249                });
1250
1251                Ok(McpResponse::EnqueueTask(EnqueueTaskResult {
1252                    task_id,
1253                    status: TaskStatus::Working,
1254                    poll_interval: Some(2),
1255                }))
1256            }
1257
1258            McpRequest::ListTasks(params) => {
1259                let tasks = self.inner.task_store.list_tasks(params.status);
1260
1261                Ok(McpResponse::ListTasks(ListTasksResult {
1262                    tasks,
1263                    next_cursor: None,
1264                }))
1265            }
1266
1267            McpRequest::GetTaskInfo(params) => {
1268                let task = self
1269                    .inner
1270                    .task_store
1271                    .get_task(&params.task_id)
1272                    .ok_or_else(|| {
1273                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1274                            "Task not found: {}",
1275                            params.task_id
1276                        )))
1277                    })?;
1278
1279                Ok(McpResponse::GetTaskInfo(task))
1280            }
1281
1282            McpRequest::GetTaskResult(params) => {
1283                let (status, result, error) = self
1284                    .inner
1285                    .task_store
1286                    .get_task_full(&params.task_id)
1287                    .ok_or_else(|| {
1288                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1289                            "Task not found: {}",
1290                            params.task_id
1291                        )))
1292                    })?;
1293
1294                Ok(McpResponse::GetTaskResult(GetTaskResultResult {
1295                    task_id: params.task_id,
1296                    status,
1297                    result,
1298                    error,
1299                }))
1300            }
1301
1302            McpRequest::CancelTask(params) => {
1303                let status = self
1304                    .inner
1305                    .task_store
1306                    .cancel_task(&params.task_id, params.reason.as_deref())
1307                    .ok_or_else(|| {
1308                        Error::JsonRpc(JsonRpcError::invalid_params(format!(
1309                            "Task not found: {}",
1310                            params.task_id
1311                        )))
1312                    })?;
1313
1314                let cancelled = status == TaskStatus::Cancelled;
1315
1316                Ok(McpResponse::CancelTask(CancelTaskResult {
1317                    cancelled,
1318                    status,
1319                }))
1320            }
1321
1322            McpRequest::SetLoggingLevel(params) => {
1323                // Store the log level for filtering outgoing log notifications
1324                // For now, we just accept the request - actual filtering would be
1325                // implemented in the notification sending logic
1326                tracing::debug!(level = ?params.level, "Client set logging level");
1327                Ok(McpResponse::SetLoggingLevel(EmptyResult {}))
1328            }
1329
1330            McpRequest::Complete(params) => {
1331                tracing::debug!(
1332                    reference = ?params.reference,
1333                    argument = %params.argument.name,
1334                    "Completion request"
1335                );
1336
1337                // Delegate to registered completion handler if available
1338                if let Some(ref handler) = self.inner.completion_handler {
1339                    let result = handler(params).await?;
1340                    Ok(McpResponse::Complete(result))
1341                } else {
1342                    // No completion handler registered, return empty completions
1343                    Ok(McpResponse::Complete(CompleteResult::new(vec![])))
1344                }
1345            }
1346
1347            McpRequest::Unknown { method, .. } => {
1348                Err(Error::JsonRpc(JsonRpcError::method_not_found(&method)))
1349            }
1350        }
1351    }
1352
1353    /// Handle an MCP notification (no response expected)
1354    pub fn handle_notification(&self, notification: McpNotification) {
1355        match notification {
1356            McpNotification::Initialized => {
1357                if self.session.mark_initialized() {
1358                    tracing::info!("Session initialized, entering operation phase");
1359                } else {
1360                    tracing::warn!(
1361                        "Received initialized notification in unexpected state: {:?}",
1362                        self.session.phase()
1363                    );
1364                }
1365            }
1366            McpNotification::Cancelled(params) => {
1367                if self.cancel_request(&params.request_id) {
1368                    tracing::info!(
1369                        request_id = ?params.request_id,
1370                        reason = ?params.reason,
1371                        "Request cancelled"
1372                    );
1373                } else {
1374                    tracing::debug!(
1375                        request_id = ?params.request_id,
1376                        reason = ?params.reason,
1377                        "Cancellation requested for unknown request"
1378                    );
1379                }
1380            }
1381            McpNotification::Progress(params) => {
1382                tracing::trace!(
1383                    token = ?params.progress_token,
1384                    progress = params.progress,
1385                    total = ?params.total,
1386                    "Progress notification"
1387                );
1388                // Progress notifications from client are unusual but valid
1389            }
1390            McpNotification::RootsListChanged => {
1391                tracing::info!("Client roots list changed");
1392                // Server should re-request roots if needed
1393                // This is handled by the application layer
1394            }
1395            McpNotification::Unknown { method, .. } => {
1396                tracing::debug!(method = %method, "Unknown notification received");
1397            }
1398        }
1399    }
1400}
1401
1402impl Default for McpRouter {
1403    fn default() -> Self {
1404        Self::new()
1405    }
1406}
1407
1408// =============================================================================
1409// Tower Service implementation
1410// =============================================================================
1411
1412// Re-export Extensions from context for backwards compatibility
1413pub use crate::context::Extensions;
1414
/// Request type for the tower Service implementation
///
/// One decoded MCP request together with the JSON-RPC id it arrived with.
#[derive(Debug, Clone)]
pub struct RouterRequest {
    /// JSON-RPC id of the request; echoed back as the id of the `RouterResponse`.
    pub id: RequestId,
    /// The MCP request to dispatch.
    pub inner: McpRequest,
    /// Type-map for passing data (e.g., `TokenClaims`) through middleware.
    ///
    /// NOTE(review): `McpRouter`'s `Service::call` passes only `id` and `inner` to
    /// its request handler, so these extensions are not visible to handlers from
    /// there. Confirm whether they are meant to be forwarded.
    pub extensions: Extensions,
}
1423
/// Response type for the tower Service implementation
///
/// Pairs the originating request id with either a successful MCP response or
/// the JSON-RPC error it was mapped to. Use `into_jsonrpc` to turn it into a
/// [`JsonRpcResponse`].
#[derive(Debug, Clone)]
pub struct RouterResponse {
    /// Id of the request this response answers.
    pub id: RequestId,
    /// `Ok` with the typed MCP response, or `Err` with the JSON-RPC error to report.
    pub inner: std::result::Result<McpResponse, JsonRpcError>,
}
1430
1431impl RouterResponse {
1432    /// Convert to JSON-RPC response
1433    pub fn into_jsonrpc(self) -> JsonRpcResponse {
1434        match self.inner {
1435            Ok(response) => match serde_json::to_value(response) {
1436                Ok(result) => JsonRpcResponse::result(self.id, result),
1437                Err(e) => {
1438                    tracing::error!(error = %e, "Failed to serialize response");
1439                    JsonRpcResponse::error(
1440                        Some(self.id),
1441                        JsonRpcError::internal_error(format!("Serialization error: {}", e)),
1442                    )
1443                }
1444            },
1445            Err(error) => JsonRpcResponse::error(Some(self.id), error),
1446        }
1447    }
1448}
1449
1450impl Service<RouterRequest> for McpRouter {
1451    type Response = RouterResponse;
1452    type Error = std::convert::Infallible; // Errors are in the response
1453    type Future =
1454        Pin<Box<dyn Future<Output = std::result::Result<Self::Response, Self::Error>> + Send>>;
1455
1456    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
1457        Poll::Ready(Ok(()))
1458    }
1459
1460    fn call(&mut self, req: RouterRequest) -> Self::Future {
1461        let router = self.clone();
1462        let request_id = req.id.clone();
1463        Box::pin(async move {
1464            let result = router.handle(req.id, req.inner).await;
1465            // Clean up tracking after request completes
1466            router.complete_request(&request_id);
1467            Ok(RouterResponse {
1468                id: request_id,
1469                // Map tower-mcp errors to JSON-RPC errors:
1470                // - Error::JsonRpc: forwarded as-is (preserves original code)
1471                // - Error::Tool: mapped to -32603 (Internal Error)
1472                // - All others: mapped to -32603 (Internal Error)
1473                inner: result.map_err(|e| match e {
1474                    Error::JsonRpc(err) => err,
1475                    Error::Tool(err) => JsonRpcError::internal_error(err.to_string()),
1476                    e => JsonRpcError::internal_error(e.to_string()),
1477                }),
1478            })
1479        })
1480    }
1481}
1482
1483#[cfg(test)]
1484mod tests {
1485    use super::*;
1486    use crate::extract::{Context, Json};
1487    use crate::jsonrpc::JsonRpcService;
1488    use crate::tool::ToolBuilder;
1489    use schemars::JsonSchema;
1490    use serde::Deserialize;
1491    use tower::ServiceExt;
1492
    // Arguments for the `add` test tool: the two integers to sum.
    // Plain `//` comments are used on purpose: `///` doc comments would become
    // descriptions in the schemars-generated input schema.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct AddInput {
        // Left operand.
        a: i64,
        // Right operand.
        b: i64,
    }
1498
1499    /// Helper to initialize a router for testing
1500    async fn init_router(router: &mut McpRouter) {
1501        // Send initialize request
1502        let init_req = RouterRequest {
1503            id: RequestId::Number(0),
1504            inner: McpRequest::Initialize(InitializeParams {
1505                protocol_version: "2025-11-25".to_string(),
1506                capabilities: ClientCapabilities {
1507                    roots: None,
1508                    sampling: None,
1509                    elicitation: None,
1510                },
1511                client_info: Implementation {
1512                    name: "test".to_string(),
1513                    version: "1.0".to_string(),
1514                    ..Default::default()
1515                },
1516            }),
1517            extensions: Extensions::new(),
1518        };
1519        let _ = router.ready().await.unwrap().call(init_req).await.unwrap();
1520        // Send initialized notification
1521        router.handle_notification(McpNotification::Initialized);
1522    }
1523
1524    #[tokio::test]
1525    async fn test_router_list_tools() {
1526        let add_tool = ToolBuilder::new("add")
1527            .description("Add two numbers")
1528            .handler(|input: AddInput| async move {
1529                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1530            })
1531            .build()
1532            .expect("valid tool name");
1533
1534        let mut router = McpRouter::new().tool(add_tool);
1535
1536        // Initialize session first
1537        init_router(&mut router).await;
1538
1539        let req = RouterRequest {
1540            id: RequestId::Number(1),
1541            inner: McpRequest::ListTools(ListToolsParams::default()),
1542            extensions: Extensions::new(),
1543        };
1544
1545        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1546
1547        match resp.inner {
1548            Ok(McpResponse::ListTools(result)) => {
1549                assert_eq!(result.tools.len(), 1);
1550                assert_eq!(result.tools[0].name, "add");
1551            }
1552            _ => panic!("Expected ListTools response"),
1553        }
1554    }
1555
1556    #[tokio::test]
1557    async fn test_router_call_tool() {
1558        let add_tool = ToolBuilder::new("add")
1559            .description("Add two numbers")
1560            .handler(|input: AddInput| async move {
1561                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1562            })
1563            .build()
1564            .expect("valid tool name");
1565
1566        let mut router = McpRouter::new().tool(add_tool);
1567
1568        // Initialize session first
1569        init_router(&mut router).await;
1570
1571        let req = RouterRequest {
1572            id: RequestId::Number(1),
1573            inner: McpRequest::CallTool(CallToolParams {
1574                name: "add".to_string(),
1575                arguments: serde_json::json!({"a": 2, "b": 3}),
1576                meta: None,
1577            }),
1578            extensions: Extensions::new(),
1579        };
1580
1581        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1582
1583        match resp.inner {
1584            Ok(McpResponse::CallTool(result)) => {
1585                assert!(!result.is_error);
1586                // Check the text content
1587                match &result.content[0] {
1588                    Content::Text { text, .. } => assert_eq!(text, "5"),
1589                    _ => panic!("Expected text content"),
1590                }
1591            }
1592            _ => panic!("Expected CallTool response"),
1593        }
1594    }
1595
1596    /// Helper to initialize a JsonRpcService for testing
1597    async fn init_jsonrpc_service(service: &mut JsonRpcService<McpRouter>, router: &McpRouter) {
1598        let init_req = JsonRpcRequest::new(0, "initialize").with_params(serde_json::json!({
1599            "protocolVersion": "2025-11-25",
1600            "capabilities": {},
1601            "clientInfo": { "name": "test", "version": "1.0" }
1602        }));
1603        let _ = service.call_single(init_req).await.unwrap();
1604        router.handle_notification(McpNotification::Initialized);
1605    }
1606
1607    #[tokio::test]
1608    async fn test_jsonrpc_service() {
1609        let add_tool = ToolBuilder::new("add")
1610            .description("Add two numbers")
1611            .handler(|input: AddInput| async move {
1612                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1613            })
1614            .build()
1615            .expect("valid tool name");
1616
1617        let router = McpRouter::new().tool(add_tool);
1618        let mut service = JsonRpcService::new(router.clone());
1619
1620        // Initialize session first
1621        init_jsonrpc_service(&mut service, &router).await;
1622
1623        let req = JsonRpcRequest::new(1, "tools/list");
1624
1625        let resp = service.call_single(req).await.unwrap();
1626
1627        match resp {
1628            JsonRpcResponse::Result(r) => {
1629                assert_eq!(r.id, RequestId::Number(1));
1630                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1631                assert_eq!(tools.len(), 1);
1632            }
1633            JsonRpcResponse::Error(_) => panic!("Expected success response"),
1634        }
1635    }
1636
1637    #[tokio::test]
1638    async fn test_batch_request() {
1639        let add_tool = ToolBuilder::new("add")
1640            .description("Add two numbers")
1641            .handler(|input: AddInput| async move {
1642                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1643            })
1644            .build()
1645            .expect("valid tool name");
1646
1647        let router = McpRouter::new().tool(add_tool);
1648        let mut service = JsonRpcService::new(router.clone());
1649
1650        // Initialize session first
1651        init_jsonrpc_service(&mut service, &router).await;
1652
1653        // Create a batch of requests
1654        let requests = vec![
1655            JsonRpcRequest::new(1, "tools/list"),
1656            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1657                "name": "add",
1658                "arguments": {"a": 10, "b": 20}
1659            })),
1660            JsonRpcRequest::new(3, "ping"),
1661        ];
1662
1663        let responses = service.call_batch(requests).await.unwrap();
1664
1665        assert_eq!(responses.len(), 3);
1666
1667        // Check first response (tools/list)
1668        match &responses[0] {
1669            JsonRpcResponse::Result(r) => {
1670                assert_eq!(r.id, RequestId::Number(1));
1671                let tools = r.result.get("tools").unwrap().as_array().unwrap();
1672                assert_eq!(tools.len(), 1);
1673            }
1674            JsonRpcResponse::Error(_) => panic!("Expected success for tools/list"),
1675        }
1676
1677        // Check second response (tools/call)
1678        match &responses[1] {
1679            JsonRpcResponse::Result(r) => {
1680                assert_eq!(r.id, RequestId::Number(2));
1681                let content = r.result.get("content").unwrap().as_array().unwrap();
1682                let text = content[0].get("text").unwrap().as_str().unwrap();
1683                assert_eq!(text, "30");
1684            }
1685            JsonRpcResponse::Error(_) => panic!("Expected success for tools/call"),
1686        }
1687
1688        // Check third response (ping)
1689        match &responses[2] {
1690            JsonRpcResponse::Result(r) => {
1691                assert_eq!(r.id, RequestId::Number(3));
1692            }
1693            JsonRpcResponse::Error(_) => panic!("Expected success for ping"),
1694        }
1695    }
1696
1697    #[tokio::test]
1698    async fn test_empty_batch_error() {
1699        let router = McpRouter::new();
1700        let mut service = JsonRpcService::new(router);
1701
1702        let result = service.call_batch(vec![]).await;
1703        assert!(result.is_err());
1704    }
1705
1706    // =========================================================================
1707    // Progress Token Tests
1708    // =========================================================================
1709
1710    #[tokio::test]
1711    async fn test_progress_token_extraction() {
1712        use crate::context::{ServerNotification, notification_channel};
1713        use crate::protocol::ProgressToken;
1714        use std::sync::Arc;
1715        use std::sync::atomic::{AtomicBool, Ordering};
1716
1717        // Track whether progress was reported
1718        let progress_reported = Arc::new(AtomicBool::new(false));
1719        let progress_ref = progress_reported.clone();
1720
1721        // Create a tool that reports progress
1722        let tool = ToolBuilder::new("progress_tool")
1723            .description("Tool that reports progress")
1724            .extractor_handler_typed::<_, _, _, AddInput>(
1725                (),
1726                move |ctx: Context, Json(_input): Json<AddInput>| {
1727                    let reported = progress_ref.clone();
1728                    async move {
1729                        // Report progress - this should work if token was extracted
1730                        ctx.report_progress(50.0, Some(100.0), Some("Halfway"))
1731                            .await;
1732                        reported.store(true, Ordering::SeqCst);
1733                        Ok(CallToolResult::text("done"))
1734                    }
1735                },
1736            )
1737            .build()
1738            .expect("valid tool name");
1739
1740        // Set up notification channel
1741        let (tx, mut rx) = notification_channel(10);
1742        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1743        let mut service = JsonRpcService::new(router.clone());
1744
1745        // Initialize
1746        init_jsonrpc_service(&mut service, &router).await;
1747
1748        // Call tool WITH progress token in _meta
1749        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1750            "name": "progress_tool",
1751            "arguments": {"a": 1, "b": 2},
1752            "_meta": {
1753                "progressToken": "test-token-123"
1754            }
1755        }));
1756
1757        let resp = service.call_single(req).await.unwrap();
1758
1759        // Verify the tool was called successfully
1760        match resp {
1761            JsonRpcResponse::Result(_) => {}
1762            JsonRpcResponse::Error(e) => panic!("Expected success, got error: {:?}", e),
1763        }
1764
1765        // Verify progress was reported by handler
1766        assert!(progress_reported.load(Ordering::SeqCst));
1767
1768        // Verify progress notification was sent through channel
1769        let notification = rx.try_recv().expect("Expected progress notification");
1770        match notification {
1771            ServerNotification::Progress(params) => {
1772                assert_eq!(
1773                    params.progress_token,
1774                    ProgressToken::String("test-token-123".to_string())
1775                );
1776                assert_eq!(params.progress, 50.0);
1777                assert_eq!(params.total, Some(100.0));
1778                assert_eq!(params.message.as_deref(), Some("Halfway"));
1779            }
1780            _ => panic!("Expected Progress notification"),
1781        }
1782    }
1783
1784    #[tokio::test]
1785    async fn test_tool_call_without_progress_token() {
1786        use crate::context::notification_channel;
1787        use std::sync::Arc;
1788        use std::sync::atomic::{AtomicBool, Ordering};
1789
1790        let progress_attempted = Arc::new(AtomicBool::new(false));
1791        let progress_ref = progress_attempted.clone();
1792
1793        let tool = ToolBuilder::new("no_token_tool")
1794            .description("Tool that tries to report progress without token")
1795            .extractor_handler_typed::<_, _, _, AddInput>(
1796                (),
1797                move |ctx: Context, Json(_input): Json<AddInput>| {
1798                    let attempted = progress_ref.clone();
1799                    async move {
1800                        // Try to report progress - should be a no-op without token
1801                        ctx.report_progress(50.0, Some(100.0), None).await;
1802                        attempted.store(true, Ordering::SeqCst);
1803                        Ok(CallToolResult::text("done"))
1804                    }
1805                },
1806            )
1807            .build()
1808            .expect("valid tool name");
1809
1810        let (tx, mut rx) = notification_channel(10);
1811        let router = McpRouter::new().with_notification_sender(tx).tool(tool);
1812        let mut service = JsonRpcService::new(router.clone());
1813
1814        init_jsonrpc_service(&mut service, &router).await;
1815
1816        // Call tool WITHOUT progress token
1817        let req = JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1818            "name": "no_token_tool",
1819            "arguments": {"a": 1, "b": 2}
1820        }));
1821
1822        let resp = service.call_single(req).await.unwrap();
1823        assert!(matches!(resp, JsonRpcResponse::Result(_)));
1824
1825        // Handler was called
1826        assert!(progress_attempted.load(Ordering::SeqCst));
1827
1828        // But no notification was sent (no progress token)
1829        assert!(rx.try_recv().is_err());
1830    }
1831
1832    #[tokio::test]
1833    async fn test_batch_errors_returned_not_dropped() {
1834        let add_tool = ToolBuilder::new("add")
1835            .description("Add two numbers")
1836            .handler(|input: AddInput| async move {
1837                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
1838            })
1839            .build()
1840            .expect("valid tool name");
1841
1842        let router = McpRouter::new().tool(add_tool);
1843        let mut service = JsonRpcService::new(router.clone());
1844
1845        init_jsonrpc_service(&mut service, &router).await;
1846
1847        // Create a batch with one valid and one invalid request
1848        let requests = vec![
1849            // Valid request
1850            JsonRpcRequest::new(1, "tools/call").with_params(serde_json::json!({
1851                "name": "add",
1852                "arguments": {"a": 10, "b": 20}
1853            })),
1854            // Invalid request - tool doesn't exist
1855            JsonRpcRequest::new(2, "tools/call").with_params(serde_json::json!({
1856                "name": "nonexistent_tool",
1857                "arguments": {}
1858            })),
1859            // Another valid request
1860            JsonRpcRequest::new(3, "ping"),
1861        ];
1862
1863        let responses = service.call_batch(requests).await.unwrap();
1864
1865        // All three requests should have responses (errors are not dropped)
1866        assert_eq!(responses.len(), 3);
1867
1868        // First should be success
1869        match &responses[0] {
1870            JsonRpcResponse::Result(r) => {
1871                assert_eq!(r.id, RequestId::Number(1));
1872            }
1873            JsonRpcResponse::Error(_) => panic!("Expected success for first request"),
1874        }
1875
1876        // Second should be an error (tool not found)
1877        match &responses[1] {
1878            JsonRpcResponse::Error(e) => {
1879                assert_eq!(e.id, Some(RequestId::Number(2)));
1880                // Error should indicate method not found
1881                assert!(e.error.message.contains("not found") || e.error.code == -32601);
1882            }
1883            JsonRpcResponse::Result(_) => panic!("Expected error for second request"),
1884        }
1885
1886        // Third should be success
1887        match &responses[2] {
1888            JsonRpcResponse::Result(r) => {
1889                assert_eq!(r.id, RequestId::Number(3));
1890            }
1891            JsonRpcResponse::Error(_) => panic!("Expected success for third request"),
1892        }
1893    }
1894
1895    // =========================================================================
1896    // Resource Template Tests
1897    // =========================================================================
1898
1899    #[tokio::test]
1900    async fn test_list_resource_templates() {
1901        use crate::resource::ResourceTemplateBuilder;
1902        use std::collections::HashMap;
1903
1904        let template = ResourceTemplateBuilder::new("file:///{path}")
1905            .name("Project Files")
1906            .description("Access project files")
1907            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1908                Ok(ReadResourceResult {
1909                    contents: vec![ResourceContent {
1910                        uri,
1911                        mime_type: None,
1912                        text: None,
1913                        blob: None,
1914                    }],
1915                })
1916            });
1917
1918        let mut router = McpRouter::new().resource_template(template);
1919
1920        // Initialize session
1921        init_router(&mut router).await;
1922
1923        let req = RouterRequest {
1924            id: RequestId::Number(1),
1925            inner: McpRequest::ListResourceTemplates(ListResourceTemplatesParams::default()),
1926            extensions: Extensions::new(),
1927        };
1928
1929        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1930
1931        match resp.inner {
1932            Ok(McpResponse::ListResourceTemplates(result)) => {
1933                assert_eq!(result.resource_templates.len(), 1);
1934                assert_eq!(result.resource_templates[0].uri_template, "file:///{path}");
1935                assert_eq!(result.resource_templates[0].name, "Project Files");
1936            }
1937            _ => panic!("Expected ListResourceTemplates response"),
1938        }
1939    }
1940
1941    #[tokio::test]
1942    async fn test_read_resource_via_template() {
1943        use crate::resource::ResourceTemplateBuilder;
1944        use std::collections::HashMap;
1945
1946        let template = ResourceTemplateBuilder::new("db://users/{id}")
1947            .name("User Records")
1948            .handler(|uri: String, vars: HashMap<String, String>| async move {
1949                let id = vars.get("id").unwrap().clone();
1950                Ok(ReadResourceResult {
1951                    contents: vec![ResourceContent {
1952                        uri,
1953                        mime_type: Some("application/json".to_string()),
1954                        text: Some(format!(r#"{{"id": "{}"}}"#, id)),
1955                        blob: None,
1956                    }],
1957                })
1958            });
1959
1960        let mut router = McpRouter::new().resource_template(template);
1961
1962        // Initialize session
1963        init_router(&mut router).await;
1964
1965        // Read a resource that matches the template
1966        let req = RouterRequest {
1967            id: RequestId::Number(1),
1968            inner: McpRequest::ReadResource(ReadResourceParams {
1969                uri: "db://users/123".to_string(),
1970            }),
1971            extensions: Extensions::new(),
1972        };
1973
1974        let resp = router.ready().await.unwrap().call(req).await.unwrap();
1975
1976        match resp.inner {
1977            Ok(McpResponse::ReadResource(result)) => {
1978                assert_eq!(result.contents.len(), 1);
1979                assert_eq!(result.contents[0].uri, "db://users/123");
1980                assert!(result.contents[0].text.as_ref().unwrap().contains("123"));
1981            }
1982            _ => panic!("Expected ReadResource response"),
1983        }
1984    }
1985
1986    #[tokio::test]
1987    async fn test_static_resource_takes_precedence_over_template() {
1988        use crate::resource::{ResourceBuilder, ResourceTemplateBuilder};
1989        use std::collections::HashMap;
1990
1991        // Template that would match the same URI
1992        let template = ResourceTemplateBuilder::new("file:///{path}")
1993            .name("Files Template")
1994            .handler(|uri: String, _vars: HashMap<String, String>| async move {
1995                Ok(ReadResourceResult {
1996                    contents: vec![ResourceContent {
1997                        uri,
1998                        mime_type: None,
1999                        text: Some("from template".to_string()),
2000                        blob: None,
2001                    }],
2002                })
2003            });
2004
2005        // Static resource with exact URI
2006        let static_resource = ResourceBuilder::new("file:///README.md")
2007            .name("README")
2008            .text("from static resource");
2009
2010        let mut router = McpRouter::new()
2011            .resource_template(template)
2012            .resource(static_resource);
2013
2014        // Initialize session
2015        init_router(&mut router).await;
2016
2017        // Read the static resource - should NOT go through template
2018        let req = RouterRequest {
2019            id: RequestId::Number(1),
2020            inner: McpRequest::ReadResource(ReadResourceParams {
2021                uri: "file:///README.md".to_string(),
2022            }),
2023            extensions: Extensions::new(),
2024        };
2025
2026        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2027
2028        match resp.inner {
2029            Ok(McpResponse::ReadResource(result)) => {
2030                // Should get static resource, not template
2031                assert_eq!(
2032                    result.contents[0].text.as_deref(),
2033                    Some("from static resource")
2034                );
2035            }
2036            _ => panic!("Expected ReadResource response"),
2037        }
2038    }
2039
2040    #[tokio::test]
2041    async fn test_resource_not_found_when_no_match() {
2042        use crate::resource::ResourceTemplateBuilder;
2043        use std::collections::HashMap;
2044
2045        let template = ResourceTemplateBuilder::new("db://users/{id}")
2046            .name("Users")
2047            .handler(|uri: String, _vars: HashMap<String, String>| async move {
2048                Ok(ReadResourceResult {
2049                    contents: vec![ResourceContent {
2050                        uri,
2051                        mime_type: None,
2052                        text: None,
2053                        blob: None,
2054                    }],
2055                })
2056            });
2057
2058        let mut router = McpRouter::new().resource_template(template);
2059
2060        // Initialize session
2061        init_router(&mut router).await;
2062
2063        // Try to read a URI that doesn't match any resource or template
2064        let req = RouterRequest {
2065            id: RequestId::Number(1),
2066            inner: McpRequest::ReadResource(ReadResourceParams {
2067                uri: "db://posts/123".to_string(),
2068            }),
2069            extensions: Extensions::new(),
2070        };
2071
2072        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2073
2074        match resp.inner {
2075            Err(err) => {
2076                assert!(err.message.contains("not found"));
2077            }
2078            Ok(_) => panic!("Expected error for non-matching URI"),
2079        }
2080    }
2081
2082    #[tokio::test]
2083    async fn test_capabilities_include_resources_with_only_templates() {
2084        use crate::resource::ResourceTemplateBuilder;
2085        use std::collections::HashMap;
2086
2087        let template = ResourceTemplateBuilder::new("file:///{path}")
2088            .name("Files")
2089            .handler(|uri: String, _vars: HashMap<String, String>| async move {
2090                Ok(ReadResourceResult {
2091                    contents: vec![ResourceContent {
2092                        uri,
2093                        mime_type: None,
2094                        text: None,
2095                        blob: None,
2096                    }],
2097                })
2098            });
2099
2100        let mut router = McpRouter::new().resource_template(template);
2101
2102        // Send initialize request and check capabilities
2103        let init_req = RouterRequest {
2104            id: RequestId::Number(0),
2105            inner: McpRequest::Initialize(InitializeParams {
2106                protocol_version: "2025-11-25".to_string(),
2107                capabilities: ClientCapabilities {
2108                    roots: None,
2109                    sampling: None,
2110                    elicitation: None,
2111                },
2112                client_info: Implementation {
2113                    name: "test".to_string(),
2114                    version: "1.0".to_string(),
2115                    ..Default::default()
2116                },
2117            }),
2118            extensions: Extensions::new(),
2119        };
2120        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2121
2122        match resp.inner {
2123            Ok(McpResponse::Initialize(result)) => {
2124                // Should have resources capability even though only templates registered
2125                assert!(result.capabilities.resources.is_some());
2126            }
2127            _ => panic!("Expected Initialize response"),
2128        }
2129    }
2130
2131    // =========================================================================
2132    // Logging Notification Tests
2133    // =========================================================================
2134
2135    #[tokio::test]
2136    async fn test_log_sends_notification() {
2137        use crate::context::notification_channel;
2138
2139        let (tx, mut rx) = notification_channel(10);
2140        let router = McpRouter::new().with_notification_sender(tx);
2141
2142        // Send an info log
2143        let sent = router.log_info("Test message");
2144        assert!(sent);
2145
2146        // Should receive the notification
2147        let notification = rx.try_recv().unwrap();
2148        match notification {
2149            ServerNotification::LogMessage(params) => {
2150                assert_eq!(params.level, LogLevel::Info);
2151                let data = params.data.unwrap();
2152                assert_eq!(
2153                    data.get("message").unwrap().as_str().unwrap(),
2154                    "Test message"
2155                );
2156            }
2157            _ => panic!("Expected LogMessage notification"),
2158        }
2159    }
2160
2161    #[tokio::test]
2162    async fn test_log_with_custom_params() {
2163        use crate::context::notification_channel;
2164
2165        let (tx, mut rx) = notification_channel(10);
2166        let router = McpRouter::new().with_notification_sender(tx);
2167
2168        // Send a custom log message
2169        let params = LoggingMessageParams::new(LogLevel::Error)
2170            .with_logger("database")
2171            .with_data(serde_json::json!({
2172                "error": "Connection failed",
2173                "host": "localhost"
2174            }));
2175
2176        let sent = router.log(params);
2177        assert!(sent);
2178
2179        let notification = rx.try_recv().unwrap();
2180        match notification {
2181            ServerNotification::LogMessage(params) => {
2182                assert_eq!(params.level, LogLevel::Error);
2183                assert_eq!(params.logger.as_deref(), Some("database"));
2184                let data = params.data.unwrap();
2185                assert_eq!(
2186                    data.get("error").unwrap().as_str().unwrap(),
2187                    "Connection failed"
2188                );
2189            }
2190            _ => panic!("Expected LogMessage notification"),
2191        }
2192    }
2193
2194    #[tokio::test]
2195    async fn test_log_without_channel_returns_false() {
2196        // Router without notification channel
2197        let router = McpRouter::new();
2198
2199        // Should return false when no channel configured
2200        assert!(!router.log_info("Test"));
2201        assert!(!router.log_warning("Test"));
2202        assert!(!router.log_error("Test"));
2203        assert!(!router.log_debug("Test"));
2204    }
2205
2206    #[tokio::test]
2207    async fn test_logging_capability_with_channel() {
2208        use crate::context::notification_channel;
2209
2210        let (tx, _rx) = notification_channel(10);
2211        let mut router = McpRouter::new().with_notification_sender(tx);
2212
2213        // Initialize and check capabilities
2214        let init_req = RouterRequest {
2215            id: RequestId::Number(0),
2216            inner: McpRequest::Initialize(InitializeParams {
2217                protocol_version: "2025-11-25".to_string(),
2218                capabilities: ClientCapabilities {
2219                    roots: None,
2220                    sampling: None,
2221                    elicitation: None,
2222                },
2223                client_info: Implementation {
2224                    name: "test".to_string(),
2225                    version: "1.0".to_string(),
2226                    ..Default::default()
2227                },
2228            }),
2229            extensions: Extensions::new(),
2230        };
2231        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2232
2233        match resp.inner {
2234            Ok(McpResponse::Initialize(result)) => {
2235                // Should have logging capability when notification channel is set
2236                assert!(result.capabilities.logging.is_some());
2237            }
2238            _ => panic!("Expected Initialize response"),
2239        }
2240    }
2241
2242    #[tokio::test]
2243    async fn test_no_logging_capability_without_channel() {
2244        let mut router = McpRouter::new();
2245
2246        // Initialize and check capabilities
2247        let init_req = RouterRequest {
2248            id: RequestId::Number(0),
2249            inner: McpRequest::Initialize(InitializeParams {
2250                protocol_version: "2025-11-25".to_string(),
2251                capabilities: ClientCapabilities {
2252                    roots: None,
2253                    sampling: None,
2254                    elicitation: None,
2255                },
2256                client_info: Implementation {
2257                    name: "test".to_string(),
2258                    version: "1.0".to_string(),
2259                    ..Default::default()
2260                },
2261            }),
2262            extensions: Extensions::new(),
2263        };
2264        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2265
2266        match resp.inner {
2267            Ok(McpResponse::Initialize(result)) => {
2268                // Should NOT have logging capability without notification channel
2269                assert!(result.capabilities.logging.is_none());
2270            }
2271            _ => panic!("Expected Initialize response"),
2272        }
2273    }
2274
2275    // =========================================================================
2276    // Task Lifecycle Tests
2277    // =========================================================================
2278
2279    #[tokio::test]
2280    async fn test_enqueue_task() {
2281        let add_tool = ToolBuilder::new("add")
2282            .description("Add two numbers")
2283            .handler(|input: AddInput| async move {
2284                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2285            })
2286            .build()
2287            .expect("valid tool name");
2288
2289        let mut router = McpRouter::new().tool(add_tool);
2290        init_router(&mut router).await;
2291
2292        let req = RouterRequest {
2293            id: RequestId::Number(1),
2294            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2295                tool_name: "add".to_string(),
2296                arguments: serde_json::json!({"a": 5, "b": 10}),
2297                ttl: None,
2298            }),
2299            extensions: Extensions::new(),
2300        };
2301
2302        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2303
2304        match resp.inner {
2305            Ok(McpResponse::EnqueueTask(result)) => {
2306                assert!(result.task_id.starts_with("task-"));
2307                assert_eq!(result.status, TaskStatus::Working);
2308            }
2309            _ => panic!("Expected EnqueueTask response"),
2310        }
2311    }
2312
2313    #[tokio::test]
2314    async fn test_list_tasks_empty() {
2315        let mut router = McpRouter::new();
2316        init_router(&mut router).await;
2317
2318        let req = RouterRequest {
2319            id: RequestId::Number(1),
2320            inner: McpRequest::ListTasks(ListTasksParams::default()),
2321            extensions: Extensions::new(),
2322        };
2323
2324        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2325
2326        match resp.inner {
2327            Ok(McpResponse::ListTasks(result)) => {
2328                assert!(result.tasks.is_empty());
2329            }
2330            _ => panic!("Expected ListTasks response"),
2331        }
2332    }
2333
2334    #[tokio::test]
2335    async fn test_task_lifecycle_complete() {
2336        let add_tool = ToolBuilder::new("add")
2337            .description("Add two numbers")
2338            .handler(|input: AddInput| async move {
2339                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2340            })
2341            .build()
2342            .expect("valid tool name");
2343
2344        let mut router = McpRouter::new().tool(add_tool);
2345        init_router(&mut router).await;
2346
2347        // Enqueue task
2348        let req = RouterRequest {
2349            id: RequestId::Number(1),
2350            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2351                tool_name: "add".to_string(),
2352                arguments: serde_json::json!({"a": 7, "b": 8}),
2353                ttl: None,
2354            }),
2355            extensions: Extensions::new(),
2356        };
2357
2358        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2359        let task_id = match resp.inner {
2360            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2361            _ => panic!("Expected EnqueueTask response"),
2362        };
2363
2364        // Wait for task to complete
2365        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
2366
2367        // Get task result
2368        let req = RouterRequest {
2369            id: RequestId::Number(2),
2370            inner: McpRequest::GetTaskResult(GetTaskResultParams {
2371                task_id: task_id.clone(),
2372            }),
2373            extensions: Extensions::new(),
2374        };
2375
2376        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2377
2378        match resp.inner {
2379            Ok(McpResponse::GetTaskResult(result)) => {
2380                assert_eq!(result.task_id, task_id);
2381                assert_eq!(result.status, TaskStatus::Completed);
2382                assert!(result.result.is_some());
2383                assert!(result.error.is_none());
2384
2385                // Check the result content
2386                let tool_result = result.result.unwrap();
2387                match &tool_result.content[0] {
2388                    Content::Text { text, .. } => assert_eq!(text, "15"),
2389                    _ => panic!("Expected text content"),
2390                }
2391            }
2392            _ => panic!("Expected GetTaskResult response"),
2393        }
2394    }
2395
2396    #[tokio::test]
2397    async fn test_task_cancellation() {
2398        // Use a slow tool to test cancellation
2399        let slow_tool = ToolBuilder::new("slow")
2400            .description("Slow tool")
2401            .handler(|_input: serde_json::Value| async move {
2402                tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
2403                Ok(CallToolResult::text("done"))
2404            })
2405            .build()
2406            .expect("valid tool name");
2407
2408        let mut router = McpRouter::new().tool(slow_tool);
2409        init_router(&mut router).await;
2410
2411        // Enqueue task
2412        let req = RouterRequest {
2413            id: RequestId::Number(1),
2414            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2415                tool_name: "slow".to_string(),
2416                arguments: serde_json::json!({}),
2417                ttl: None,
2418            }),
2419            extensions: Extensions::new(),
2420        };
2421
2422        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2423        let task_id = match resp.inner {
2424            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2425            _ => panic!("Expected EnqueueTask response"),
2426        };
2427
2428        // Cancel the task
2429        let req = RouterRequest {
2430            id: RequestId::Number(2),
2431            inner: McpRequest::CancelTask(CancelTaskParams {
2432                task_id: task_id.clone(),
2433                reason: Some("Test cancellation".to_string()),
2434            }),
2435            extensions: Extensions::new(),
2436        };
2437
2438        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2439
2440        match resp.inner {
2441            Ok(McpResponse::CancelTask(result)) => {
2442                assert!(result.cancelled);
2443                assert_eq!(result.status, TaskStatus::Cancelled);
2444            }
2445            _ => panic!("Expected CancelTask response"),
2446        }
2447    }
2448
2449    #[tokio::test]
2450    async fn test_get_task_info() {
2451        let add_tool = ToolBuilder::new("add")
2452            .description("Add two numbers")
2453            .handler(|input: AddInput| async move {
2454                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
2455            })
2456            .build()
2457            .expect("valid tool name");
2458
2459        let mut router = McpRouter::new().tool(add_tool);
2460        init_router(&mut router).await;
2461
2462        // Enqueue task
2463        let req = RouterRequest {
2464            id: RequestId::Number(1),
2465            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2466                tool_name: "add".to_string(),
2467                arguments: serde_json::json!({"a": 1, "b": 2}),
2468                ttl: Some(600),
2469            }),
2470            extensions: Extensions::new(),
2471        };
2472
2473        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2474        let task_id = match resp.inner {
2475            Ok(McpResponse::EnqueueTask(result)) => result.task_id,
2476            _ => panic!("Expected EnqueueTask response"),
2477        };
2478
2479        // Get task info
2480        let req = RouterRequest {
2481            id: RequestId::Number(2),
2482            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2483                task_id: task_id.clone(),
2484            }),
2485            extensions: Extensions::new(),
2486        };
2487
2488        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2489
2490        match resp.inner {
2491            Ok(McpResponse::GetTaskInfo(info)) => {
2492                assert_eq!(info.task_id, task_id);
2493                assert!(info.created_at.contains('T')); // ISO 8601
2494                assert_eq!(info.ttl, Some(600));
2495            }
2496            _ => panic!("Expected GetTaskInfo response"),
2497        }
2498    }
2499
2500    #[tokio::test]
2501    async fn test_enqueue_nonexistent_tool() {
2502        let mut router = McpRouter::new();
2503        init_router(&mut router).await;
2504
2505        let req = RouterRequest {
2506            id: RequestId::Number(1),
2507            inner: McpRequest::EnqueueTask(EnqueueTaskParams {
2508                tool_name: "nonexistent".to_string(),
2509                arguments: serde_json::json!({}),
2510                ttl: None,
2511            }),
2512            extensions: Extensions::new(),
2513        };
2514
2515        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2516
2517        match resp.inner {
2518            Err(e) => {
2519                assert!(e.message.contains("not found"));
2520            }
2521            _ => panic!("Expected error response"),
2522        }
2523    }
2524
2525    #[tokio::test]
2526    async fn test_get_nonexistent_task() {
2527        let mut router = McpRouter::new();
2528        init_router(&mut router).await;
2529
2530        let req = RouterRequest {
2531            id: RequestId::Number(1),
2532            inner: McpRequest::GetTaskInfo(GetTaskInfoParams {
2533                task_id: "task-999".to_string(),
2534            }),
2535            extensions: Extensions::new(),
2536        };
2537
2538        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2539
2540        match resp.inner {
2541            Err(e) => {
2542                assert!(e.message.contains("not found"));
2543            }
2544            _ => panic!("Expected error response"),
2545        }
2546    }
2547
2548    // =========================================================================
2549    // Resource Subscription Tests
2550    // =========================================================================
2551
2552    #[tokio::test]
2553    async fn test_subscribe_to_resource() {
2554        use crate::resource::ResourceBuilder;
2555
2556        let resource = ResourceBuilder::new("file:///test.txt")
2557            .name("Test File")
2558            .text("Hello");
2559
2560        let mut router = McpRouter::new().resource(resource);
2561        init_router(&mut router).await;
2562
2563        // Subscribe to the resource
2564        let req = RouterRequest {
2565            id: RequestId::Number(1),
2566            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2567                uri: "file:///test.txt".to_string(),
2568            }),
2569            extensions: Extensions::new(),
2570        };
2571
2572        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2573
2574        match resp.inner {
2575            Ok(McpResponse::SubscribeResource(_)) => {
2576                // Should be subscribed now
2577                assert!(router.is_subscribed("file:///test.txt"));
2578            }
2579            _ => panic!("Expected SubscribeResource response"),
2580        }
2581    }
2582
2583    #[tokio::test]
2584    async fn test_unsubscribe_from_resource() {
2585        use crate::resource::ResourceBuilder;
2586
2587        let resource = ResourceBuilder::new("file:///test.txt")
2588            .name("Test File")
2589            .text("Hello");
2590
2591        let mut router = McpRouter::new().resource(resource);
2592        init_router(&mut router).await;
2593
2594        // Subscribe first
2595        let req = RouterRequest {
2596            id: RequestId::Number(1),
2597            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2598                uri: "file:///test.txt".to_string(),
2599            }),
2600            extensions: Extensions::new(),
2601        };
2602        let _ = router.ready().await.unwrap().call(req).await.unwrap();
2603        assert!(router.is_subscribed("file:///test.txt"));
2604
2605        // Now unsubscribe
2606        let req = RouterRequest {
2607            id: RequestId::Number(2),
2608            inner: McpRequest::UnsubscribeResource(UnsubscribeResourceParams {
2609                uri: "file:///test.txt".to_string(),
2610            }),
2611            extensions: Extensions::new(),
2612        };
2613
2614        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2615
2616        match resp.inner {
2617            Ok(McpResponse::UnsubscribeResource(_)) => {
2618                // Should no longer be subscribed
2619                assert!(!router.is_subscribed("file:///test.txt"));
2620            }
2621            _ => panic!("Expected UnsubscribeResource response"),
2622        }
2623    }
2624
2625    #[tokio::test]
2626    async fn test_subscribe_nonexistent_resource() {
2627        let mut router = McpRouter::new();
2628        init_router(&mut router).await;
2629
2630        let req = RouterRequest {
2631            id: RequestId::Number(1),
2632            inner: McpRequest::SubscribeResource(SubscribeResourceParams {
2633                uri: "file:///nonexistent.txt".to_string(),
2634            }),
2635            extensions: Extensions::new(),
2636        };
2637
2638        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2639
2640        match resp.inner {
2641            Err(e) => {
2642                assert!(e.message.contains("not found"));
2643            }
2644            _ => panic!("Expected error response"),
2645        }
2646    }
2647
2648    #[tokio::test]
2649    async fn test_notify_resource_updated() {
2650        use crate::context::notification_channel;
2651        use crate::resource::ResourceBuilder;
2652
2653        let (tx, mut rx) = notification_channel(10);
2654
2655        let resource = ResourceBuilder::new("file:///test.txt")
2656            .name("Test File")
2657            .text("Hello");
2658
2659        let router = McpRouter::new()
2660            .resource(resource)
2661            .with_notification_sender(tx);
2662
2663        // First, manually subscribe (simulate subscription)
2664        router.subscribe("file:///test.txt");
2665
2666        // Now notify
2667        let sent = router.notify_resource_updated("file:///test.txt");
2668        assert!(sent);
2669
2670        // Check the notification was sent
2671        let notification = rx.try_recv().unwrap();
2672        match notification {
2673            ServerNotification::ResourceUpdated { uri } => {
2674                assert_eq!(uri, "file:///test.txt");
2675            }
2676            _ => panic!("Expected ResourceUpdated notification"),
2677        }
2678    }
2679
2680    #[tokio::test]
2681    async fn test_notify_resource_updated_not_subscribed() {
2682        use crate::context::notification_channel;
2683        use crate::resource::ResourceBuilder;
2684
2685        let (tx, mut rx) = notification_channel(10);
2686
2687        let resource = ResourceBuilder::new("file:///test.txt")
2688            .name("Test File")
2689            .text("Hello");
2690
2691        let router = McpRouter::new()
2692            .resource(resource)
2693            .with_notification_sender(tx);
2694
2695        // Try to notify without subscribing
2696        let sent = router.notify_resource_updated("file:///test.txt");
2697        assert!(!sent); // Should not send because not subscribed
2698
2699        // Channel should be empty
2700        assert!(rx.try_recv().is_err());
2701    }
2702
2703    #[tokio::test]
2704    async fn test_notify_resources_list_changed() {
2705        use crate::context::notification_channel;
2706
2707        let (tx, mut rx) = notification_channel(10);
2708        let router = McpRouter::new().with_notification_sender(tx);
2709
2710        let sent = router.notify_resources_list_changed();
2711        assert!(sent);
2712
2713        let notification = rx.try_recv().unwrap();
2714        match notification {
2715            ServerNotification::ResourcesListChanged => {}
2716            _ => panic!("Expected ResourcesListChanged notification"),
2717        }
2718    }
2719
2720    #[tokio::test]
2721    async fn test_subscribed_uris() {
2722        use crate::resource::ResourceBuilder;
2723
2724        let resource1 = ResourceBuilder::new("file:///a.txt").name("A").text("A");
2725
2726        let resource2 = ResourceBuilder::new("file:///b.txt").name("B").text("B");
2727
2728        let router = McpRouter::new().resource(resource1).resource(resource2);
2729
2730        // Subscribe to both
2731        router.subscribe("file:///a.txt");
2732        router.subscribe("file:///b.txt");
2733
2734        let uris = router.subscribed_uris();
2735        assert_eq!(uris.len(), 2);
2736        assert!(uris.contains(&"file:///a.txt".to_string()));
2737        assert!(uris.contains(&"file:///b.txt".to_string()));
2738    }
2739
2740    #[tokio::test]
2741    async fn test_subscription_capability_advertised() {
2742        use crate::resource::ResourceBuilder;
2743
2744        let resource = ResourceBuilder::new("file:///test.txt")
2745            .name("Test")
2746            .text("Hello");
2747
2748        let mut router = McpRouter::new().resource(resource);
2749
2750        // Initialize and check capabilities
2751        let init_req = RouterRequest {
2752            id: RequestId::Number(0),
2753            inner: McpRequest::Initialize(InitializeParams {
2754                protocol_version: "2025-11-25".to_string(),
2755                capabilities: ClientCapabilities {
2756                    roots: None,
2757                    sampling: None,
2758                    elicitation: None,
2759                },
2760                client_info: Implementation {
2761                    name: "test".to_string(),
2762                    version: "1.0".to_string(),
2763                    ..Default::default()
2764                },
2765            }),
2766            extensions: Extensions::new(),
2767        };
2768        let resp = router.ready().await.unwrap().call(init_req).await.unwrap();
2769
2770        match resp.inner {
2771            Ok(McpResponse::Initialize(result)) => {
2772                // Should have resources capability with subscribe enabled
2773                let resources_cap = result.capabilities.resources.unwrap();
2774                assert!(resources_cap.subscribe);
2775            }
2776            _ => panic!("Expected Initialize response"),
2777        }
2778    }
2779
2780    #[tokio::test]
2781    async fn test_completion_handler() {
2782        let router = McpRouter::new()
2783            .server_info("test", "1.0")
2784            .completion_handler(|params: CompleteParams| async move {
2785                // Return suggestions based on the argument value
2786                let prefix = &params.argument.value;
2787                let suggestions: Vec<String> = vec!["alpha", "beta", "gamma"]
2788                    .into_iter()
2789                    .filter(|s| s.starts_with(prefix))
2790                    .map(String::from)
2791                    .collect();
2792                Ok(CompleteResult::new(suggestions))
2793            });
2794
2795        // Initialize
2796        let init_req = RouterRequest {
2797            id: RequestId::Number(0),
2798            inner: McpRequest::Initialize(InitializeParams {
2799                protocol_version: "2025-11-25".to_string(),
2800                capabilities: ClientCapabilities::default(),
2801                client_info: Implementation {
2802                    name: "test".to_string(),
2803                    version: "1.0".to_string(),
2804                    ..Default::default()
2805                },
2806            }),
2807            extensions: Extensions::new(),
2808        };
2809        let resp = router
2810            .clone()
2811            .ready()
2812            .await
2813            .unwrap()
2814            .call(init_req)
2815            .await
2816            .unwrap();
2817
2818        // Check that completions capability is advertised
2819        match resp.inner {
2820            Ok(McpResponse::Initialize(result)) => {
2821                assert!(result.capabilities.completions.is_some());
2822            }
2823            _ => panic!("Expected Initialize response"),
2824        }
2825
2826        // Send initialized notification
2827        router.handle_notification(McpNotification::Initialized);
2828
2829        // Test completion request
2830        let complete_req = RouterRequest {
2831            id: RequestId::Number(1),
2832            inner: McpRequest::Complete(CompleteParams {
2833                reference: CompletionReference::prompt("test-prompt"),
2834                argument: CompletionArgument::new("query", "al"),
2835            }),
2836            extensions: Extensions::new(),
2837        };
2838        let resp = router
2839            .clone()
2840            .ready()
2841            .await
2842            .unwrap()
2843            .call(complete_req)
2844            .await
2845            .unwrap();
2846
2847        match resp.inner {
2848            Ok(McpResponse::Complete(result)) => {
2849                assert_eq!(result.completion.values, vec!["alpha"]);
2850            }
2851            _ => panic!("Expected Complete response"),
2852        }
2853    }
2854
2855    #[tokio::test]
2856    async fn test_completion_without_handler_returns_empty() {
2857        let router = McpRouter::new().server_info("test", "1.0");
2858
2859        // Initialize
2860        let init_req = RouterRequest {
2861            id: RequestId::Number(0),
2862            inner: McpRequest::Initialize(InitializeParams {
2863                protocol_version: "2025-11-25".to_string(),
2864                capabilities: ClientCapabilities::default(),
2865                client_info: Implementation {
2866                    name: "test".to_string(),
2867                    version: "1.0".to_string(),
2868                    ..Default::default()
2869                },
2870            }),
2871            extensions: Extensions::new(),
2872        };
2873        let resp = router
2874            .clone()
2875            .ready()
2876            .await
2877            .unwrap()
2878            .call(init_req)
2879            .await
2880            .unwrap();
2881
2882        // Check that completions capability is NOT advertised
2883        match resp.inner {
2884            Ok(McpResponse::Initialize(result)) => {
2885                assert!(result.capabilities.completions.is_none());
2886            }
2887            _ => panic!("Expected Initialize response"),
2888        }
2889
2890        // Send initialized notification
2891        router.handle_notification(McpNotification::Initialized);
2892
2893        // Test completion request still works but returns empty
2894        let complete_req = RouterRequest {
2895            id: RequestId::Number(1),
2896            inner: McpRequest::Complete(CompleteParams {
2897                reference: CompletionReference::prompt("test-prompt"),
2898                argument: CompletionArgument::new("query", "al"),
2899            }),
2900            extensions: Extensions::new(),
2901        };
2902        let resp = router
2903            .clone()
2904            .ready()
2905            .await
2906            .unwrap()
2907            .call(complete_req)
2908            .await
2909            .unwrap();
2910
2911        match resp.inner {
2912            Ok(McpResponse::Complete(result)) => {
2913                assert!(result.completion.values.is_empty());
2914            }
2915            _ => panic!("Expected Complete response"),
2916        }
2917    }
2918
2919    #[tokio::test]
2920    async fn test_tool_filter_list() {
2921        use crate::filter::CapabilityFilter;
2922        use crate::tool::Tool;
2923
2924        let public_tool = ToolBuilder::new("public")
2925            .description("Public tool")
2926            .handler(|_: AddInput| async move { Ok(CallToolResult::text("public")) })
2927            .build()
2928            .expect("valid tool name");
2929
2930        let admin_tool = ToolBuilder::new("admin")
2931            .description("Admin tool")
2932            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2933            .build()
2934            .expect("valid tool name");
2935
2936        let mut router = McpRouter::new()
2937            .tool(public_tool)
2938            .tool(admin_tool)
2939            .tool_filter(CapabilityFilter::new(|_, tool: &Tool| tool.name != "admin"));
2940
2941        // Initialize session
2942        init_router(&mut router).await;
2943
2944        let req = RouterRequest {
2945            id: RequestId::Number(1),
2946            inner: McpRequest::ListTools(ListToolsParams::default()),
2947            extensions: Extensions::new(),
2948        };
2949
2950        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2951
2952        match resp.inner {
2953            Ok(McpResponse::ListTools(result)) => {
2954                // Only public tool should be visible
2955                assert_eq!(result.tools.len(), 1);
2956                assert_eq!(result.tools[0].name, "public");
2957            }
2958            _ => panic!("Expected ListTools response"),
2959        }
2960    }
2961
2962    #[tokio::test]
2963    async fn test_tool_filter_call_denied() {
2964        use crate::filter::CapabilityFilter;
2965        use crate::tool::Tool;
2966
2967        let admin_tool = ToolBuilder::new("admin")
2968            .description("Admin tool")
2969            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
2970            .build()
2971            .expect("valid tool name");
2972
2973        let mut router = McpRouter::new()
2974            .tool(admin_tool)
2975            .tool_filter(CapabilityFilter::new(|_, _: &Tool| false)); // Deny all
2976
2977        // Initialize session
2978        init_router(&mut router).await;
2979
2980        let req = RouterRequest {
2981            id: RequestId::Number(1),
2982            inner: McpRequest::CallTool(CallToolParams {
2983                name: "admin".to_string(),
2984                arguments: serde_json::json!({"a": 1, "b": 2}),
2985                meta: None,
2986            }),
2987            extensions: Extensions::new(),
2988        };
2989
2990        let resp = router.ready().await.unwrap().call(req).await.unwrap();
2991
2992        // Should get method not found error (default denial behavior)
2993        match resp.inner {
2994            Err(e) => {
2995                assert_eq!(e.code, -32601); // Method not found
2996            }
2997            _ => panic!("Expected JsonRpc error"),
2998        }
2999    }
3000
3001    #[tokio::test]
3002    async fn test_tool_filter_call_allowed() {
3003        use crate::filter::CapabilityFilter;
3004        use crate::tool::Tool;
3005
3006        let public_tool = ToolBuilder::new("public")
3007            .description("Public tool")
3008            .handler(|input: AddInput| async move {
3009                Ok(CallToolResult::text(format!("{}", input.a + input.b)))
3010            })
3011            .build()
3012            .expect("valid tool name");
3013
3014        let mut router = McpRouter::new()
3015            .tool(public_tool)
3016            .tool_filter(CapabilityFilter::new(|_, _: &Tool| true)); // Allow all
3017
3018        // Initialize session
3019        init_router(&mut router).await;
3020
3021        let req = RouterRequest {
3022            id: RequestId::Number(1),
3023            inner: McpRequest::CallTool(CallToolParams {
3024                name: "public".to_string(),
3025                arguments: serde_json::json!({"a": 1, "b": 2}),
3026                meta: None,
3027            }),
3028            extensions: Extensions::new(),
3029        };
3030
3031        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3032
3033        match resp.inner {
3034            Ok(McpResponse::CallTool(result)) => {
3035                assert!(!result.is_error);
3036            }
3037            _ => panic!("Expected CallTool response"),
3038        }
3039    }
3040
3041    #[tokio::test]
3042    async fn test_tool_filter_custom_denial() {
3043        use crate::filter::{CapabilityFilter, DenialBehavior};
3044        use crate::tool::Tool;
3045
3046        let admin_tool = ToolBuilder::new("admin")
3047            .description("Admin tool")
3048            .handler(|_: AddInput| async move { Ok(CallToolResult::text("admin")) })
3049            .build()
3050            .expect("valid tool name");
3051
3052        let mut router = McpRouter::new().tool(admin_tool).tool_filter(
3053            CapabilityFilter::new(|_, _: &Tool| false)
3054                .denial_behavior(DenialBehavior::Unauthorized),
3055        );
3056
3057        // Initialize session
3058        init_router(&mut router).await;
3059
3060        let req = RouterRequest {
3061            id: RequestId::Number(1),
3062            inner: McpRequest::CallTool(CallToolParams {
3063                name: "admin".to_string(),
3064                arguments: serde_json::json!({"a": 1, "b": 2}),
3065                meta: None,
3066            }),
3067            extensions: Extensions::new(),
3068        };
3069
3070        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3071
3072        // Should get forbidden error
3073        match resp.inner {
3074            Err(e) => {
3075                assert_eq!(e.code, -32007); // Forbidden
3076                assert!(e.message.contains("Unauthorized"));
3077            }
3078            _ => panic!("Expected JsonRpc error"),
3079        }
3080    }
3081
3082    #[tokio::test]
3083    async fn test_resource_filter_list() {
3084        use crate::filter::CapabilityFilter;
3085        use crate::resource::{Resource, ResourceBuilder};
3086
3087        let public_resource = ResourceBuilder::new("file:///public.txt")
3088            .name("Public File")
3089            .text("public content");
3090
3091        let secret_resource = ResourceBuilder::new("file:///secret.txt")
3092            .name("Secret File")
3093            .text("secret content");
3094
3095        let mut router = McpRouter::new()
3096            .resource(public_resource)
3097            .resource(secret_resource)
3098            .resource_filter(CapabilityFilter::new(|_, r: &Resource| {
3099                !r.name.contains("Secret")
3100            }));
3101
3102        // Initialize session
3103        init_router(&mut router).await;
3104
3105        let req = RouterRequest {
3106            id: RequestId::Number(1),
3107            inner: McpRequest::ListResources(ListResourcesParams::default()),
3108            extensions: Extensions::new(),
3109        };
3110
3111        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3112
3113        match resp.inner {
3114            Ok(McpResponse::ListResources(result)) => {
3115                // Should only see public resource
3116                assert_eq!(result.resources.len(), 1);
3117                assert_eq!(result.resources[0].name, "Public File");
3118            }
3119            _ => panic!("Expected ListResources response"),
3120        }
3121    }
3122
3123    #[tokio::test]
3124    async fn test_resource_filter_read_denied() {
3125        use crate::filter::CapabilityFilter;
3126        use crate::resource::{Resource, ResourceBuilder};
3127
3128        let secret_resource = ResourceBuilder::new("file:///secret.txt")
3129            .name("Secret File")
3130            .text("secret content");
3131
3132        let mut router = McpRouter::new()
3133            .resource(secret_resource)
3134            .resource_filter(CapabilityFilter::new(|_, _: &Resource| false)); // Deny all
3135
3136        // Initialize session
3137        init_router(&mut router).await;
3138
3139        let req = RouterRequest {
3140            id: RequestId::Number(1),
3141            inner: McpRequest::ReadResource(ReadResourceParams {
3142                uri: "file:///secret.txt".to_string(),
3143            }),
3144            extensions: Extensions::new(),
3145        };
3146
3147        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3148
3149        // Should get method not found error (default denial behavior)
3150        match resp.inner {
3151            Err(e) => {
3152                assert_eq!(e.code, -32601); // Method not found
3153            }
3154            _ => panic!("Expected JsonRpc error"),
3155        }
3156    }
3157
3158    #[tokio::test]
3159    async fn test_resource_filter_read_allowed() {
3160        use crate::filter::CapabilityFilter;
3161        use crate::resource::{Resource, ResourceBuilder};
3162
3163        let public_resource = ResourceBuilder::new("file:///public.txt")
3164            .name("Public File")
3165            .text("public content");
3166
3167        let mut router = McpRouter::new()
3168            .resource(public_resource)
3169            .resource_filter(CapabilityFilter::new(|_, _: &Resource| true)); // Allow all
3170
3171        // Initialize session
3172        init_router(&mut router).await;
3173
3174        let req = RouterRequest {
3175            id: RequestId::Number(1),
3176            inner: McpRequest::ReadResource(ReadResourceParams {
3177                uri: "file:///public.txt".to_string(),
3178            }),
3179            extensions: Extensions::new(),
3180        };
3181
3182        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3183
3184        match resp.inner {
3185            Ok(McpResponse::ReadResource(result)) => {
3186                assert_eq!(result.contents.len(), 1);
3187                assert_eq!(result.contents[0].text.as_deref(), Some("public content"));
3188            }
3189            _ => panic!("Expected ReadResource response"),
3190        }
3191    }
3192
3193    #[tokio::test]
3194    async fn test_resource_filter_custom_denial() {
3195        use crate::filter::{CapabilityFilter, DenialBehavior};
3196        use crate::resource::{Resource, ResourceBuilder};
3197
3198        let secret_resource = ResourceBuilder::new("file:///secret.txt")
3199            .name("Secret File")
3200            .text("secret content");
3201
3202        let mut router = McpRouter::new().resource(secret_resource).resource_filter(
3203            CapabilityFilter::new(|_, _: &Resource| false)
3204                .denial_behavior(DenialBehavior::Unauthorized),
3205        );
3206
3207        // Initialize session
3208        init_router(&mut router).await;
3209
3210        let req = RouterRequest {
3211            id: RequestId::Number(1),
3212            inner: McpRequest::ReadResource(ReadResourceParams {
3213                uri: "file:///secret.txt".to_string(),
3214            }),
3215            extensions: Extensions::new(),
3216        };
3217
3218        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3219
3220        // Should get forbidden error
3221        match resp.inner {
3222            Err(e) => {
3223                assert_eq!(e.code, -32007); // Forbidden
3224                assert!(e.message.contains("Unauthorized"));
3225            }
3226            _ => panic!("Expected JsonRpc error"),
3227        }
3228    }
3229
3230    #[tokio::test]
3231    async fn test_prompt_filter_list() {
3232        use crate::filter::CapabilityFilter;
3233        use crate::prompt::{Prompt, PromptBuilder};
3234
3235        let public_prompt = PromptBuilder::new("greeting")
3236            .description("A greeting")
3237            .user_message("Hello!");
3238
3239        let admin_prompt = PromptBuilder::new("system_debug")
3240            .description("Admin prompt")
3241            .user_message("Debug");
3242
3243        let mut router = McpRouter::new()
3244            .prompt(public_prompt)
3245            .prompt(admin_prompt)
3246            .prompt_filter(CapabilityFilter::new(|_, p: &Prompt| {
3247                !p.name.contains("system")
3248            }));
3249
3250        // Initialize session
3251        init_router(&mut router).await;
3252
3253        let req = RouterRequest {
3254            id: RequestId::Number(1),
3255            inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3256            extensions: Extensions::new(),
3257        };
3258
3259        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3260
3261        match resp.inner {
3262            Ok(McpResponse::ListPrompts(result)) => {
3263                // Should only see public prompt
3264                assert_eq!(result.prompts.len(), 1);
3265                assert_eq!(result.prompts[0].name, "greeting");
3266            }
3267            _ => panic!("Expected ListPrompts response"),
3268        }
3269    }
3270
3271    #[tokio::test]
3272    async fn test_prompt_filter_get_denied() {
3273        use crate::filter::CapabilityFilter;
3274        use crate::prompt::{Prompt, PromptBuilder};
3275        use std::collections::HashMap;
3276
3277        let admin_prompt = PromptBuilder::new("system_debug")
3278            .description("Admin prompt")
3279            .user_message("Debug");
3280
3281        let mut router = McpRouter::new()
3282            .prompt(admin_prompt)
3283            .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| false)); // Deny all
3284
3285        // Initialize session
3286        init_router(&mut router).await;
3287
3288        let req = RouterRequest {
3289            id: RequestId::Number(1),
3290            inner: McpRequest::GetPrompt(GetPromptParams {
3291                name: "system_debug".to_string(),
3292                arguments: HashMap::new(),
3293            }),
3294            extensions: Extensions::new(),
3295        };
3296
3297        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3298
3299        // Should get method not found error (default denial behavior)
3300        match resp.inner {
3301            Err(e) => {
3302                assert_eq!(e.code, -32601); // Method not found
3303            }
3304            _ => panic!("Expected JsonRpc error"),
3305        }
3306    }
3307
3308    #[tokio::test]
3309    async fn test_prompt_filter_get_allowed() {
3310        use crate::filter::CapabilityFilter;
3311        use crate::prompt::{Prompt, PromptBuilder};
3312        use std::collections::HashMap;
3313
3314        let public_prompt = PromptBuilder::new("greeting")
3315            .description("A greeting")
3316            .user_message("Hello!");
3317
3318        let mut router = McpRouter::new()
3319            .prompt(public_prompt)
3320            .prompt_filter(CapabilityFilter::new(|_, _: &Prompt| true)); // Allow all
3321
3322        // Initialize session
3323        init_router(&mut router).await;
3324
3325        let req = RouterRequest {
3326            id: RequestId::Number(1),
3327            inner: McpRequest::GetPrompt(GetPromptParams {
3328                name: "greeting".to_string(),
3329                arguments: HashMap::new(),
3330            }),
3331            extensions: Extensions::new(),
3332        };
3333
3334        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3335
3336        match resp.inner {
3337            Ok(McpResponse::GetPrompt(result)) => {
3338                assert_eq!(result.messages.len(), 1);
3339            }
3340            _ => panic!("Expected GetPrompt response"),
3341        }
3342    }
3343
3344    #[tokio::test]
3345    async fn test_prompt_filter_custom_denial() {
3346        use crate::filter::{CapabilityFilter, DenialBehavior};
3347        use crate::prompt::{Prompt, PromptBuilder};
3348        use std::collections::HashMap;
3349
3350        let admin_prompt = PromptBuilder::new("system_debug")
3351            .description("Admin prompt")
3352            .user_message("Debug");
3353
3354        let mut router = McpRouter::new().prompt(admin_prompt).prompt_filter(
3355            CapabilityFilter::new(|_, _: &Prompt| false)
3356                .denial_behavior(DenialBehavior::Unauthorized),
3357        );
3358
3359        // Initialize session
3360        init_router(&mut router).await;
3361
3362        let req = RouterRequest {
3363            id: RequestId::Number(1),
3364            inner: McpRequest::GetPrompt(GetPromptParams {
3365                name: "system_debug".to_string(),
3366                arguments: HashMap::new(),
3367            }),
3368            extensions: Extensions::new(),
3369        };
3370
3371        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3372
3373        // Should get forbidden error
3374        match resp.inner {
3375            Err(e) => {
3376                assert_eq!(e.code, -32007); // Forbidden
3377                assert!(e.message.contains("Unauthorized"));
3378            }
3379            _ => panic!("Expected JsonRpc error"),
3380        }
3381    }
3382
3383    // =========================================================================
3384    // Router Composition Tests (merge/nest)
3385    // =========================================================================
3386
    /// Minimal tool input shared by the router composition (merge/nest) tests.
    ///
    /// Deserialized via serde from a JSON object with a single `value` string field,
    /// e.g. `{"value": "hello world"}`.
    #[derive(Debug, Deserialize, JsonSchema)]
    struct StringInput {
        /// Arbitrary text. The `echo` handler in `test_router_nest_call_prefixed_tool`
        /// returns it verbatim; the other handlers in this section ignore it.
        value: String,
    }
3391
3392    #[tokio::test]
3393    async fn test_router_merge_tools() {
3394        // Create first router with a tool
3395        let tool_a = ToolBuilder::new("tool_a")
3396            .description("Tool A")
3397            .handler(|_: StringInput| async move { Ok(CallToolResult::text("A")) })
3398            .build()
3399            .unwrap();
3400
3401        let router_a = McpRouter::new().tool(tool_a);
3402
3403        // Create second router with different tools
3404        let tool_b = ToolBuilder::new("tool_b")
3405            .description("Tool B")
3406            .handler(|_: StringInput| async move { Ok(CallToolResult::text("B")) })
3407            .build()
3408            .unwrap();
3409        let tool_c = ToolBuilder::new("tool_c")
3410            .description("Tool C")
3411            .handler(|_: StringInput| async move { Ok(CallToolResult::text("C")) })
3412            .build()
3413            .unwrap();
3414
3415        let router_b = McpRouter::new().tool(tool_b).tool(tool_c);
3416
3417        // Merge them
3418        let mut merged = McpRouter::new()
3419            .server_info("merged", "1.0")
3420            .merge(router_a)
3421            .merge(router_b);
3422
3423        init_router(&mut merged).await;
3424
3425        // List tools
3426        let req = RouterRequest {
3427            id: RequestId::Number(1),
3428            inner: McpRequest::ListTools(ListToolsParams::default()),
3429            extensions: Extensions::new(),
3430        };
3431
3432        let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3433
3434        match resp.inner {
3435            Ok(McpResponse::ListTools(result)) => {
3436                assert_eq!(result.tools.len(), 3);
3437                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3438                assert!(names.contains(&"tool_a"));
3439                assert!(names.contains(&"tool_b"));
3440                assert!(names.contains(&"tool_c"));
3441            }
3442            _ => panic!("Expected ListTools response"),
3443        }
3444    }
3445
3446    #[tokio::test]
3447    async fn test_router_merge_overwrites_duplicates() {
3448        // Create first router with a tool
3449        let tool_v1 = ToolBuilder::new("shared")
3450            .description("Version 1")
3451            .handler(|_: StringInput| async move { Ok(CallToolResult::text("v1")) })
3452            .build()
3453            .unwrap();
3454
3455        let router_a = McpRouter::new().tool(tool_v1);
3456
3457        // Create second router with same tool name but different description
3458        let tool_v2 = ToolBuilder::new("shared")
3459            .description("Version 2")
3460            .handler(|_: StringInput| async move { Ok(CallToolResult::text("v2")) })
3461            .build()
3462            .unwrap();
3463
3464        let router_b = McpRouter::new().tool(tool_v2);
3465
3466        // Merge - second should win
3467        let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3468
3469        init_router(&mut merged).await;
3470
3471        let req = RouterRequest {
3472            id: RequestId::Number(1),
3473            inner: McpRequest::ListTools(ListToolsParams::default()),
3474            extensions: Extensions::new(),
3475        };
3476
3477        let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3478
3479        match resp.inner {
3480            Ok(McpResponse::ListTools(result)) => {
3481                assert_eq!(result.tools.len(), 1);
3482                assert_eq!(result.tools[0].name, "shared");
3483                assert_eq!(result.tools[0].description.as_deref(), Some("Version 2"));
3484            }
3485            _ => panic!("Expected ListTools response"),
3486        }
3487    }
3488
3489    #[tokio::test]
3490    async fn test_router_merge_resources() {
3491        use crate::resource::ResourceBuilder;
3492
3493        // Create routers with different resources
3494        let router_a = McpRouter::new().resource(
3495            ResourceBuilder::new("file:///a.txt")
3496                .name("File A")
3497                .text("content a"),
3498        );
3499
3500        let router_b = McpRouter::new().resource(
3501            ResourceBuilder::new("file:///b.txt")
3502                .name("File B")
3503                .text("content b"),
3504        );
3505
3506        let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3507
3508        init_router(&mut merged).await;
3509
3510        let req = RouterRequest {
3511            id: RequestId::Number(1),
3512            inner: McpRequest::ListResources(ListResourcesParams::default()),
3513            extensions: Extensions::new(),
3514        };
3515
3516        let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3517
3518        match resp.inner {
3519            Ok(McpResponse::ListResources(result)) => {
3520                assert_eq!(result.resources.len(), 2);
3521                let uris: Vec<&str> = result.resources.iter().map(|r| r.uri.as_str()).collect();
3522                assert!(uris.contains(&"file:///a.txt"));
3523                assert!(uris.contains(&"file:///b.txt"));
3524            }
3525            _ => panic!("Expected ListResources response"),
3526        }
3527    }
3528
3529    #[tokio::test]
3530    async fn test_router_merge_prompts() {
3531        use crate::prompt::PromptBuilder;
3532
3533        let router_a =
3534            McpRouter::new().prompt(PromptBuilder::new("prompt_a").user_message("Hello A"));
3535
3536        let router_b =
3537            McpRouter::new().prompt(PromptBuilder::new("prompt_b").user_message("Hello B"));
3538
3539        let mut merged = McpRouter::new().merge(router_a).merge(router_b);
3540
3541        init_router(&mut merged).await;
3542
3543        let req = RouterRequest {
3544            id: RequestId::Number(1),
3545            inner: McpRequest::ListPrompts(ListPromptsParams::default()),
3546            extensions: Extensions::new(),
3547        };
3548
3549        let resp = merged.ready().await.unwrap().call(req).await.unwrap();
3550
3551        match resp.inner {
3552            Ok(McpResponse::ListPrompts(result)) => {
3553                assert_eq!(result.prompts.len(), 2);
3554                let names: Vec<&str> = result.prompts.iter().map(|p| p.name.as_str()).collect();
3555                assert!(names.contains(&"prompt_a"));
3556                assert!(names.contains(&"prompt_b"));
3557            }
3558            _ => panic!("Expected ListPrompts response"),
3559        }
3560    }
3561
3562    #[tokio::test]
3563    async fn test_router_nest_prefixes_tools() {
3564        // Create a router with tools
3565        let tool_query = ToolBuilder::new("query")
3566            .description("Query the database")
3567            .handler(|_: StringInput| async move { Ok(CallToolResult::text("query result")) })
3568            .build()
3569            .unwrap();
3570        let tool_insert = ToolBuilder::new("insert")
3571            .description("Insert into database")
3572            .handler(|_: StringInput| async move { Ok(CallToolResult::text("insert result")) })
3573            .build()
3574            .unwrap();
3575
3576        let db_router = McpRouter::new().tool(tool_query).tool(tool_insert);
3577
3578        // Nest under "db" prefix
3579        let mut router = McpRouter::new()
3580            .server_info("nested", "1.0")
3581            .nest("db", db_router);
3582
3583        init_router(&mut router).await;
3584
3585        let req = RouterRequest {
3586            id: RequestId::Number(1),
3587            inner: McpRequest::ListTools(ListToolsParams::default()),
3588            extensions: Extensions::new(),
3589        };
3590
3591        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3592
3593        match resp.inner {
3594            Ok(McpResponse::ListTools(result)) => {
3595                assert_eq!(result.tools.len(), 2);
3596                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3597                assert!(names.contains(&"db.query"));
3598                assert!(names.contains(&"db.insert"));
3599            }
3600            _ => panic!("Expected ListTools response"),
3601        }
3602    }
3603
3604    #[tokio::test]
3605    async fn test_router_nest_call_prefixed_tool() {
3606        let tool = ToolBuilder::new("echo")
3607            .description("Echo input")
3608            .handler(|input: StringInput| async move { Ok(CallToolResult::text(&input.value)) })
3609            .build()
3610            .unwrap();
3611
3612        let nested_router = McpRouter::new().tool(tool);
3613
3614        let mut router = McpRouter::new().nest("api", nested_router);
3615
3616        init_router(&mut router).await;
3617
3618        // Call the prefixed tool
3619        let req = RouterRequest {
3620            id: RequestId::Number(1),
3621            inner: McpRequest::CallTool(CallToolParams {
3622                name: "api.echo".to_string(),
3623                arguments: serde_json::json!({"value": "hello world"}),
3624                meta: None,
3625            }),
3626            extensions: Extensions::new(),
3627        };
3628
3629        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3630
3631        match resp.inner {
3632            Ok(McpResponse::CallTool(result)) => {
3633                assert!(!result.is_error);
3634                match &result.content[0] {
3635                    Content::Text { text, .. } => assert_eq!(text, "hello world"),
3636                    _ => panic!("Expected text content"),
3637                }
3638            }
3639            _ => panic!("Expected CallTool response"),
3640        }
3641    }
3642
3643    #[tokio::test]
3644    async fn test_router_multiple_nests() {
3645        let db_tool = ToolBuilder::new("query")
3646            .description("Database query")
3647            .handler(|_: StringInput| async move { Ok(CallToolResult::text("db")) })
3648            .build()
3649            .unwrap();
3650
3651        let api_tool = ToolBuilder::new("fetch")
3652            .description("API fetch")
3653            .handler(|_: StringInput| async move { Ok(CallToolResult::text("api")) })
3654            .build()
3655            .unwrap();
3656
3657        let db_router = McpRouter::new().tool(db_tool);
3658        let api_router = McpRouter::new().tool(api_tool);
3659
3660        let mut router = McpRouter::new()
3661            .nest("db", db_router)
3662            .nest("api", api_router);
3663
3664        init_router(&mut router).await;
3665
3666        let req = RouterRequest {
3667            id: RequestId::Number(1),
3668            inner: McpRequest::ListTools(ListToolsParams::default()),
3669            extensions: Extensions::new(),
3670        };
3671
3672        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3673
3674        match resp.inner {
3675            Ok(McpResponse::ListTools(result)) => {
3676                assert_eq!(result.tools.len(), 2);
3677                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3678                assert!(names.contains(&"db.query"));
3679                assert!(names.contains(&"api.fetch"));
3680            }
3681            _ => panic!("Expected ListTools response"),
3682        }
3683    }
3684
3685    #[tokio::test]
3686    async fn test_router_merge_and_nest_combined() {
3687        // Test combining merge and nest
3688        let tool_a = ToolBuilder::new("local")
3689            .description("Local tool")
3690            .handler(|_: StringInput| async move { Ok(CallToolResult::text("local")) })
3691            .build()
3692            .unwrap();
3693
3694        let nested_tool = ToolBuilder::new("remote")
3695            .description("Remote tool")
3696            .handler(|_: StringInput| async move { Ok(CallToolResult::text("remote")) })
3697            .build()
3698            .unwrap();
3699
3700        let nested_router = McpRouter::new().tool(nested_tool);
3701
3702        let mut router = McpRouter::new()
3703            .tool(tool_a)
3704            .nest("external", nested_router);
3705
3706        init_router(&mut router).await;
3707
3708        let req = RouterRequest {
3709            id: RequestId::Number(1),
3710            inner: McpRequest::ListTools(ListToolsParams::default()),
3711            extensions: Extensions::new(),
3712        };
3713
3714        let resp = router.ready().await.unwrap().call(req).await.unwrap();
3715
3716        match resp.inner {
3717            Ok(McpResponse::ListTools(result)) => {
3718                assert_eq!(result.tools.len(), 2);
3719                let names: Vec<&str> = result.tools.iter().map(|t| t.name.as_str()).collect();
3720                assert!(names.contains(&"local"));
3721                assert!(names.contains(&"external.remote"));
3722            }
3723            _ => panic!("Expected ListTools response"),
3724        }
3725    }
3726
3727    #[tokio::test]
3728    async fn test_router_merge_preserves_server_info() {
3729        let child_router = McpRouter::new()
3730            .server_info("child", "2.0")
3731            .instructions("Child instructions");
3732
3733        let mut router = McpRouter::new()
3734            .server_info("parent", "1.0")
3735            .instructions("Parent instructions")
3736            .merge(child_router);
3737
3738        init_router(&mut router).await;
3739
3740        // Initialize response should have parent's server info
3741        let init_req = RouterRequest {
3742            id: RequestId::Number(99),
3743            inner: McpRequest::Initialize(InitializeParams {
3744                protocol_version: "2025-11-25".to_string(),
3745                capabilities: ClientCapabilities::default(),
3746                client_info: Implementation {
3747                    name: "test".to_string(),
3748                    version: "1.0".to_string(),
3749                    ..Default::default()
3750                },
3751            }),
3752            extensions: Extensions::new(),
3753        };
3754
3755        // Create fresh router for this test since we need to call initialize
3756        let child_router2 = McpRouter::new().server_info("child", "2.0");
3757        let mut fresh_router = McpRouter::new()
3758            .server_info("parent", "1.0")
3759            .merge(child_router2);
3760
3761        let resp = fresh_router
3762            .ready()
3763            .await
3764            .unwrap()
3765            .call(init_req)
3766            .await
3767            .unwrap();
3768
3769        match resp.inner {
3770            Ok(McpResponse::Initialize(result)) => {
3771                assert_eq!(result.server_info.name, "parent");
3772                assert_eq!(result.server_info.version, "1.0");
3773            }
3774            _ => panic!("Expected Initialize response"),
3775        }
3776    }
3777}