// rustic_ai/mcp.rs

1use std::pin::Pin;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_stream::try_stream;
6use async_trait::async_trait;
7use eventsource_stream::Eventsource;
8use futures::lock::Mutex;
9use futures::stream::StreamExt;
10use reqwest::header::HeaderMap;
11use reqwest::{Client, Url};
12use serde::Deserialize;
13use serde_json::{Value, json};
14use uuid::Uuid;
15
16use crate::tools::{RunContext, ToolDefinition, ToolError, ToolKind, Toolset};
17
/// MCP (Model Context Protocol) client that speaks JSON-RPC over HTTP and exposes the
/// server's tools as a [`Toolset`].
///
/// JSON-RPC requests are POSTed to `url`. An optional, separate `events_url` is read as a
/// server-sent-events stream of notifications. Tool, resource and prompt listings may be
/// cached; because the caches live behind `Arc`, clones of a client share the same caches.
#[derive(Clone, Debug)]
pub struct McpServerStreamableHttp {
    /// JSON-RPC endpoint that every request is POSTed to.
    url: Url,
    /// Headers added to every request.
    headers: HeaderMap,
    /// Request timeout configured for this client.
    timeout: Duration,
    /// Optional prefix; when set, tool names are exposed as `{prefix}__{name}`.
    tool_prefix: Option<String>,
    /// HTTP client used for all requests.
    client: Client,
    /// Server-sent-events endpoint for notifications, if configured.
    events_url: Option<Url>,
    /// Whether `tools/list` results are cached.
    cache_tools: bool,
    /// Whether `resources/list` results are cached.
    cache_resources: bool,
    /// Whether `prompts/list` results are cached.
    cache_prompts: bool,
    /// Cached tool definitions; `None` means not fetched yet or invalidated.
    cached_tools: Arc<Mutex<Option<Vec<ToolDefinition>>>>,
    /// Cached resources; `None` means not fetched yet or invalidated.
    cached_resources: Arc<Mutex<Option<Vec<McpResource>>>>,
    /// Cached prompts; `None` means not fetched yet or invalidated.
    cached_prompts: Arc<Mutex<Option<Vec<McpPrompt>>>>,
}
33
34impl McpServerStreamableHttp {
35    pub fn new(url: impl AsRef<str>) -> Result<Self, ToolError> {
36        let url = Url::parse(url.as_ref())
37            .map_err(|e| ToolError::Toolset(format!("invalid MCP URL: {e}")))?;
38        let timeout = Duration::from_secs(10);
39        let client = Client::builder()
40            .timeout(timeout)
41            .build()
42            .map_err(|e| ToolError::Toolset(format!("failed to build HTTP client: {e}")))?;
43        Ok(Self {
44            url,
45            headers: HeaderMap::new(),
46            timeout,
47            tool_prefix: None,
48            client,
49            events_url: None,
50            cache_tools: true,
51            cache_resources: true,
52            cache_prompts: true,
53            cached_tools: Arc::new(Mutex::new(None)),
54            cached_resources: Arc::new(Mutex::new(None)),
55            cached_prompts: Arc::new(Mutex::new(None)),
56        })
57    }
58
59    pub fn with_headers(mut self, headers: HeaderMap) -> Self {
60        self.headers = headers;
61        self
62    }
63
64    pub fn with_timeout(mut self, timeout: Duration) -> Self {
65        self.timeout = timeout;
66        self.client = Client::builder()
67            .timeout(timeout)
68            .build()
69            .unwrap_or_else(|_| Client::new());
70        self
71    }
72
73    pub fn with_tool_prefix(mut self, prefix: impl Into<String>) -> Self {
74        self.tool_prefix = Some(prefix.into());
75        self
76    }
77
78    pub fn with_events_url(mut self, url: impl AsRef<str>) -> Result<Self, ToolError> {
79        self.events_url = Some(
80            Url::parse(url.as_ref())
81                .map_err(|e| ToolError::Toolset(format!("invalid MCP events URL: {e}")))?,
82        );
83        Ok(self)
84    }
85
86    pub fn cache_tools(mut self, enabled: bool) -> Self {
87        self.cache_tools = enabled;
88        self
89    }
90
91    pub fn cache_resources(mut self, enabled: bool) -> Self {
92        self.cache_resources = enabled;
93        self
94    }
95
96    pub fn cache_prompts(mut self, enabled: bool) -> Self {
97        self.cache_prompts = enabled;
98        self
99    }
100
101    pub async fn invalidate_tools_cache(&self) {
102        *self.cached_tools.lock().await = None;
103    }
104
105    pub async fn invalidate_resources_cache(&self) {
106        *self.cached_resources.lock().await = None;
107    }
108
109    pub async fn invalidate_prompts_cache(&self) {
110        *self.cached_prompts.lock().await = None;
111    }
112
113    async fn rpc(&self, method: &str, params: Value) -> Result<Value, ToolError> {
114        let request_id = Uuid::new_v4().to_string();
115        let payload = json!({
116            "jsonrpc": "2.0",
117            "id": request_id,
118            "method": method,
119            "params": params,
120        });
121        let response = self
122            .client
123            .post(self.url.clone())
124            .headers(self.headers.clone())
125            .json(&payload)
126            .send()
127            .await
128            .map_err(|e| ToolError::Toolset(format!("MCP request failed: {e}")))?;
129
130        let status = response.status();
131        let value: Value = response
132            .json()
133            .await
134            .map_err(|e| ToolError::Toolset(format!("MCP response parse failed: {e}")))?;
135
136        if let Some(error) = value.get("error") {
137            return Err(ToolError::Toolset(format!(
138                "MCP error (status {status}): {error}"
139            )));
140        }
141        value
142            .get("result")
143            .cloned()
144            .ok_or_else(|| ToolError::Toolset("MCP response missing result".to_string()))
145    }
146
147    fn prefix_name(&self, name: &str) -> String {
148        if let Some(prefix) = &self.tool_prefix {
149            format!("{}__{}", prefix, name)
150        } else {
151            name.to_string()
152        }
153    }
154
155    fn unprefix_name<'a>(&self, name: &'a str) -> &'a str {
156        if let Some(prefix) = &self.tool_prefix {
157            let expected = format!("{}__", prefix);
158            name.strip_prefix(&expected).unwrap_or(name)
159        } else {
160            name
161        }
162    }
163
164    pub async fn list_resources(&self) -> Result<Vec<McpResource>, ToolError> {
165        if self.cache_resources
166            && let Some(cached) = self.cached_resources.lock().await.clone()
167        {
168            return Ok(cached);
169        }
170
171        let result = self.rpc("resources/list", json!({})).await?;
172        let resources: RpcResourcesList = serde_json::from_value(result)
173            .map_err(|e| ToolError::Toolset(format!("invalid MCP resources list: {e}")))?;
174        if self.cache_resources {
175            *self.cached_resources.lock().await = Some(resources.resources.clone());
176        }
177        Ok(resources.resources)
178    }
179
180    pub async fn list_resource_templates(&self) -> Result<Vec<McpResourceTemplate>, ToolError> {
181        let result = self.rpc("resources/templates/list", json!({})).await?;
182        let templates: RpcResourceTemplatesList = serde_json::from_value(result)
183            .map_err(|e| ToolError::Toolset(format!("invalid MCP resource templates list: {e}")))?;
184        Ok(templates.resource_templates)
185    }
186
187    pub async fn read_resource(&self, uri: &str) -> Result<Value, ToolError> {
188        let result = self.rpc("resources/read", json!({ "uri": uri })).await?;
189        Ok(result)
190    }
191
192    pub async fn list_prompts(&self) -> Result<Vec<McpPrompt>, ToolError> {
193        if self.cache_prompts
194            && let Some(cached) = self.cached_prompts.lock().await.clone()
195        {
196            return Ok(cached);
197        }
198
199        let result = self.rpc("prompts/list", json!({})).await?;
200        let prompts: RpcPromptsList = serde_json::from_value(result)
201            .map_err(|e| ToolError::Toolset(format!("invalid MCP prompts list: {e}")))?;
202        if self.cache_prompts {
203            *self.cached_prompts.lock().await = Some(prompts.prompts.clone());
204        }
205        Ok(prompts.prompts)
206    }
207
208    pub async fn get_prompt(
209        &self,
210        name: &str,
211        arguments: Option<Value>,
212    ) -> Result<Vec<McpPromptMessage>, ToolError> {
213        let mut params = json!({ "name": name });
214        if let Some(arguments) = arguments
215            && let Value::Object(map) = &mut params
216        {
217            map.insert("arguments".to_string(), arguments);
218        }
219        let result = self.rpc("prompts/get", params).await?;
220        let prompt: RpcPromptGet = serde_json::from_value(result)
221            .map_err(|e| ToolError::Toolset(format!("invalid MCP prompt: {e}")))?;
222        Ok(prompt.messages)
223    }
224
225    pub async fn sample(&self, params: Value) -> Result<Value, ToolError> {
226        self.rpc("sampling/createMessage", params).await
227    }
228
229    pub async fn notifications(&self) -> Result<McpNotificationStream, ToolError> {
230        let events_url = self
231            .events_url
232            .clone()
233            .ok_or_else(|| ToolError::Toolset("MCP events URL not configured".to_string()))?;
234
235        let response = self
236            .client
237            .get(events_url)
238            .headers(self.headers.clone())
239            .send()
240            .await
241            .map_err(|e| ToolError::Toolset(format!("MCP events request failed: {e}")))?;
242
243        if !response.status().is_success() {
244            return Err(ToolError::Toolset(format!(
245                "MCP events error status {}",
246                response.status()
247            )));
248        }
249
250        let mut event_stream = response.bytes_stream().eventsource();
251        let cached_tools = Arc::clone(&self.cached_tools);
252        let cached_resources = Arc::clone(&self.cached_resources);
253        let cached_prompts = Arc::clone(&self.cached_prompts);
254
255        let stream = try_stream! {
256            while let Some(event) = event_stream.next().await {
257                let event = event.map_err(|e| ToolError::Toolset(format!("MCP events stream error: {e}")))?;
258                let notification: McpNotification = serde_json::from_str(&event.data)
259                    .map_err(|e| ToolError::Toolset(format!("MCP notification parse error: {e}")))?;
260
261                match notification.method.as_str() {
262                    "notifications/tools/list_changed" => {
263                        *cached_tools.lock().await = None;
264                    }
265                    "notifications/resources/list_changed" => {
266                        *cached_resources.lock().await = None;
267                    }
268                    "notifications/prompts/list_changed" => {
269                        *cached_prompts.lock().await = None;
270                    }
271                    _ => {}
272                }
273
274                yield notification;
275            }
276        };
277
278        Ok(Box::pin(stream))
279    }
280}
281
282#[derive(Debug, Deserialize)]
283struct RpcToolsList {
284    tools: Vec<RpcTool>,
285}
286
287#[derive(Debug, Deserialize)]
288struct RpcTool {
289    name: String,
290    description: Option<String>,
291    #[serde(rename = "inputSchema")]
292    input_schema: Value,
293    meta: Option<Value>,
294    annotations: Option<Value>,
295    #[serde(rename = "outputSchema")]
296    output_schema: Option<Value>,
297}
298
299#[derive(Debug, Deserialize)]
300struct RpcResourcesList {
301    resources: Vec<McpResource>,
302}
303
304#[derive(Debug, Deserialize)]
305struct RpcResourceTemplatesList {
306    #[serde(rename = "resourceTemplates")]
307    resource_templates: Vec<McpResourceTemplate>,
308}
309
310#[derive(Debug, Deserialize)]
311struct RpcPromptsList {
312    prompts: Vec<McpPrompt>,
313}
314
315#[derive(Debug, Deserialize)]
316struct RpcPromptGet {
317    messages: Vec<McpPromptMessage>,
318}
319
320#[derive(Clone, Debug, Deserialize)]
321pub struct McpResource {
322    pub uri: String,
323    pub name: Option<String>,
324    pub description: Option<String>,
325    #[serde(rename = "mimeType")]
326    pub mime_type: Option<String>,
327    pub metadata: Option<Value>,
328}
329
330#[derive(Clone, Debug, Deserialize)]
331pub struct McpResourceTemplate {
332    pub name: String,
333    pub description: Option<String>,
334    pub uri_template: Option<String>,
335    pub metadata: Option<Value>,
336}
337
/// A prompt advertised by `prompts/list`.
#[derive(Clone, Debug, Deserialize)]
pub struct McpPrompt {
    /// Name to pass to `get_prompt`.
    pub name: String,
    pub description: Option<String>,
    /// Declared arguments; `None` when the server lists none.
    pub arguments: Option<Vec<McpPromptArgument>>,
}

/// One argument accepted by an [`McpPrompt`].
#[derive(Clone, Debug, Deserialize)]
pub struct McpPromptArgument {
    pub name: String,
    pub description: Option<String>,
    /// Whether the argument must be supplied; `None` when the server omits the flag.
    pub required: Option<bool>,
}

/// A message returned by `prompts/get`.
#[derive(Clone, Debug, Deserialize)]
pub struct McpPromptMessage {
    /// Role string exactly as sent by the server (e.g. `"user"`).
    pub role: String,
    /// Message content, left as raw JSON.
    pub content: Value,
}

/// A JSON-RPC notification received on the events stream.
#[derive(Clone, Debug, Deserialize)]
pub struct McpNotification {
    /// Notification method, e.g. `notifications/tools/list_changed`.
    pub method: String,
    /// Notification parameters, left as raw JSON; `None` when absent.
    pub params: Option<Value>,
}
363
364pub type McpNotificationStream =
365    Pin<Box<dyn futures::stream::Stream<Item = Result<McpNotification, ToolError>> + Send>>;
366
367#[async_trait]
368impl<Deps> Toolset<Deps> for McpServerStreamableHttp
369where
370    Deps: Send + Sync,
371{
372    async fn list_tools(&self, _ctx: &RunContext<Deps>) -> Result<Vec<ToolDefinition>, ToolError> {
373        if self.cache_tools
374            && let Some(cached) = self.cached_tools.lock().await.clone()
375        {
376            return Ok(cached);
377        }
378
379        let result = self.rpc("tools/list", json!({})).await?;
380        let tools: RpcToolsList = serde_json::from_value(result)
381            .map_err(|e| ToolError::Toolset(format!("invalid MCP tools list: {e}")))?;
382        let mapped: Vec<ToolDefinition> = tools
383            .tools
384            .into_iter()
385            .map(|tool| {
386                let mut def = ToolDefinition::new(
387                    self.prefix_name(&tool.name),
388                    tool.description,
389                    tool.input_schema,
390                );
391                def.kind = ToolKind::Function;
392                def.metadata = Some(json!({
393                    "meta": tool.meta,
394                    "annotations": tool.annotations,
395                    "output_schema": tool.output_schema,
396                }));
397                def
398            })
399            .collect();
400
401        if self.cache_tools {
402            *self.cached_tools.lock().await = Some(mapped.clone());
403        }
404
405        Ok(mapped)
406    }
407
408    async fn call_tool(
409        &self,
410        _ctx: &RunContext<Deps>,
411        name: &str,
412        args: Value,
413    ) -> Result<Value, ToolError> {
414        let name = self.unprefix_name(name).to_string();
415        let result = self
416            .rpc("tools/call", json!({"name": name, "arguments": args}))
417            .await?;
418
419        if let Some(structured) = result.get("structuredContent") {
420            return Ok(structured.clone());
421        }
422
423        if let Some(content) = result.get("content")
424            && let Some(array) = content.as_array()
425            && array.len() == 1
426            && let Some(text) = array[0].get("text").and_then(|v| v.as_str())
427        {
428            return Ok(Value::String(text.to_string()));
429        }
430
431        Ok(result)
432    }
433
434    fn name(&self) -> &str {
435        "mcp-http"
436    }
437}
438
// Tests drive the client against a minimal hand-written HTTP/1.1 JSON-RPC server that
// listens on an ephemeral localhost port.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    // These explicit imports take precedence over the `super::*` glob. Here `Mutex` is
    // therefore the blocking `std::sync::Mutex` that guards the mock server's counters;
    // none of its guards are held across an `.await` in this module.
    use std::sync::{Arc, Mutex};
    use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
    use tokio::net::TcpListener;

    /// Per-method call counters and last-seen arguments recorded by the mock server.
    #[derive(Default)]
    struct RpcState {
        tool_list_calls: usize,
        tool_call_calls: usize,
        last_tool_name: Option<String>,
        resource_list_calls: usize,
        resource_template_calls: usize,
        resource_read_calls: usize,
        prompt_list_calls: usize,
        prompt_get_calls: usize,
        last_resource_uri: Option<String>,
    }

    /// Starts the mock JSON-RPC server.
    ///
    /// Returns its base URL (`http://127.0.0.1:PORT`) and the accept loop's handle, which
    /// tests abort when done. Each accepted connection gets its own task, which serves
    /// requests one after another on that connection. Every reply is `200 OK` with a JSON
    /// body; unknown methods get a JSON-RPC `error` object, not an HTTP error status.
    async fn spawn_rpc_server(
        state: Arc<Mutex<RpcState>>,
    ) -> Result<(String, tokio::task::JoinHandle<()>), ToolError> {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .map_err(|e| ToolError::Toolset(format!("bind failed: {e}")))?;
        let addr = listener
            .local_addr()
            .map_err(|e| ToolError::Toolset(format!("addr failed: {e}")))?;
        let base_url = format!("http://{}", addr);

        let handle = tokio::spawn(async move {
            loop {
                let (stream, _) = match listener.accept().await {
                    Ok(pair) => pair,
                    Err(_) => break,
                };
                let state = Arc::clone(&state);
                tokio::spawn(async move {
                    let mut reader = BufReader::new(stream);
                    loop {
                        // Read the request head up to the blank CRLF line. Only
                        // `Content-Length` is kept; its name is matched case-insensitively.
                        let mut content_length: usize = 0;
                        let mut saw_header = false;
                        loop {
                            let mut line = String::new();
                            let bytes = reader.read_line(&mut line).await.unwrap_or(0);
                            // EOF (or a read error mapped to 0): stop serving this connection.
                            if bytes == 0 {
                                return;
                            }
                            if line == "\r\n" {
                                break;
                            }
                            saw_header = true;
                            let lower = line.to_ascii_lowercase();
                            if lower.starts_with("content-length:")
                                && let Some(value) = line.split(':').nth(1)
                            {
                                content_length = value.trim().parse().unwrap_or(0);
                            }
                        }

                        // The first line was already blank (no request line or headers):
                        // stop serving this connection.
                        if !saw_header {
                            return;
                        }

                        let mut body = vec![0u8; content_length];
                        if content_length > 0 && reader.read_exact(&mut body).await.is_err() {
                            return;
                        }

                        // An unparseable body becomes `Null`, so `method` is "" and the
                        // "unknown method" arm below answers it.
                        let request: Value = serde_json::from_slice(&body).unwrap_or(Value::Null);
                        let method = request.get("method").and_then(Value::as_str).unwrap_or("");
                        let id = request.get("id").cloned().unwrap_or(Value::Null);
                        let params = request.get("params").cloned().unwrap_or(Value::Null);

                        let result = match method {
                            "tools/list" => {
                                let mut guard = state.lock().unwrap();
                                guard.tool_list_calls += 1;
                                json!({
                                    "tools": [{
                                        "name": "echo",
                                        "description": "Echo input",
                                        "inputSchema": {"type": "object", "properties": {}},
                                        "meta": {"version": 1},
                                        "annotations": {"note": "test"},
                                        "outputSchema": {"type": "object"}
                                    }]
                                })
                            }
                            "tools/call" => {
                                let mut guard = state.lock().unwrap();
                                guard.tool_call_calls += 1;
                                guard.last_tool_name = params
                                    .get("name")
                                    .and_then(Value::as_str)
                                    .map(|v| v.to_string());
                                // The tool name selects which result shape is returned.
                                match guard.last_tool_name.as_deref() {
                                    Some("structured") => {
                                        json!({ "structuredContent": { "ok": true } })
                                    }
                                    Some("text") => {
                                        json!({ "content": [ { "text": "hi" } ] })
                                    }
                                    _ => json!({ "value": 42 }),
                                }
                            }
                            "resources/list" => {
                                let mut guard = state.lock().unwrap();
                                guard.resource_list_calls += 1;
                                json!({ "resources": [{
                                    "uri": "file:///tmp/example.txt",
                                    "name": "example",
                                    "description": "example resource",
                                    "mimeType": "text/plain",
                                    "metadata": {"version": 1}
                                }] })
                            }
                            "resources/templates/list" => {
                                let mut guard = state.lock().unwrap();
                                guard.resource_template_calls += 1;
                                json!({ "resourceTemplates": [{
                                    "name": "example-template",
                                    "description": "example template",
                                    "uriTemplate": "file:///tmp/{name}",
                                    "metadata": {"source": "test"}
                                }] })
                            }
                            "resources/read" => {
                                let mut guard = state.lock().unwrap();
                                guard.resource_read_calls += 1;
                                guard.last_resource_uri = params
                                    .get("uri")
                                    .and_then(Value::as_str)
                                    .map(|v| v.to_string());
                                json!({ "contents": [{
                                    "uri": guard.last_resource_uri.clone().unwrap_or_default(),
                                    "text": "hello"
                                }] })
                            }
                            "prompts/list" => {
                                let mut guard = state.lock().unwrap();
                                guard.prompt_list_calls += 1;
                                json!({ "prompts": [{
                                    "name": "example-prompt",
                                    "description": "example prompt",
                                    "arguments": [{"name": "topic", "required": true}]
                                }] })
                            }
                            "prompts/get" => {
                                let mut guard = state.lock().unwrap();
                                guard.prompt_get_calls += 1;
                                json!({ "messages": [{
                                    "role": "user",
                                    "content": "hi"
                                }] })
                            }
                            _ => {
                                json!({ "error": { "message": "unknown method" } })
                            }
                        };

                        // Echo the request id. An `error` key in `result` is lifted into a
                        // JSON-RPC error response; anything else becomes `result`.
                        let response = if result.get("error").is_some() {
                            json!({ "jsonrpc": "2.0", "id": id, "error": result["error"] })
                        } else {
                            json!({ "jsonrpc": "2.0", "id": id, "result": result })
                        };
                        let response_bytes = response.to_string();
                        let stream = reader.get_mut();
                        // Write errors are deliberately ignored.
                        let _ = stream
                            .write_all(
                                format!(
                                    "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
                                    response_bytes.len(),
                                    response_bytes
                                )
                                .as_bytes(),
                            )
                            .await;
                    }
                });
            }
        });

        Ok((base_url, handle))
    }

    /// `list_tools` prefixes names, keeps metadata, serves the second call from the cache,
    /// and refetches after an explicit invalidation.
    #[tokio::test]
    async fn list_tools_applies_prefix_and_caches() {
        let state = Arc::new(Mutex::new(RpcState::default()));
        let (base_url, handle) = spawn_rpc_server(Arc::clone(&state)).await.expect("server");

        let client = McpServerStreamableHttp::new(base_url)
            .expect("client")
            .with_tool_prefix("remote");

        // A RunContext is required by the Toolset API; the model is never contacted here.
        let ctx = RunContext {
            run_id: "run".to_string(),
            deps: Arc::new(()),
            model: Arc::new(crate::providers::openai::OpenAIChatModel::new(
                "gpt-test",
                "key".to_string(),
                Url::parse("https://example.com/").expect("url"),
                None,
            )),
            usage: crate::usage::RunUsage::default(),
            prompt: None,
            messages: Arc::new(Vec::new()),
            tool_call_id: None,
            tool_name: None,
        };

        let tools = client.list_tools(&ctx).await.expect("tools");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "remote__echo");
        assert!(tools[0].metadata.is_some());

        let tools_again = client.list_tools(&ctx).await.expect("tools again");
        assert_eq!(tools_again.len(), 1);

        let calls = state.lock().unwrap().tool_list_calls;
        assert_eq!(calls, 1, "tools list should be cached");

        client.invalidate_tools_cache().await;
        let _ = client
            .list_tools(&ctx)
            .await
            .expect("tools after invalidate");
        let calls = state.lock().unwrap().tool_list_calls;
        assert_eq!(calls, 2, "tools list should be refreshed");

        handle.abort();
    }

    /// `call_tool` strips the prefix, prefers `structuredContent`, and unwraps a single
    /// text content item into a plain string.
    #[tokio::test]
    async fn call_tool_returns_structured_or_text() {
        let state = Arc::new(Mutex::new(RpcState::default()));
        let (base_url, handle) = spawn_rpc_server(Arc::clone(&state)).await.expect("server");

        let client = McpServerStreamableHttp::new(base_url)
            .expect("client")
            .with_tool_prefix("remote");

        let ctx = RunContext {
            run_id: "run".to_string(),
            deps: Arc::new(()),
            model: Arc::new(crate::providers::openai::OpenAIChatModel::new(
                "gpt-test",
                "key".to_string(),
                Url::parse("https://example.com/").expect("url"),
                None,
            )),
            usage: crate::usage::RunUsage::default(),
            prompt: None,
            messages: Arc::new(Vec::new()),
            tool_call_id: None,
            tool_name: None,
        };

        let structured = client
            .call_tool(&ctx, "remote__structured", json!({}))
            .await
            .expect("structured");
        assert_eq!(structured, json!({"ok": true}));

        let text = client
            .call_tool(&ctx, "remote__text", json!({}))
            .await
            .expect("text");
        assert_eq!(text, Value::String("hi".to_string()));

        let calls = state.lock().unwrap().tool_call_calls;
        assert_eq!(calls, 2);

        handle.abort();
    }

    /// Resource listings are cached and refetched after invalidation.
    #[tokio::test]
    async fn list_resources_caches_and_invalidates() {
        let state = Arc::new(Mutex::new(RpcState::default()));
        let (base_url, handle) = spawn_rpc_server(Arc::clone(&state)).await.expect("server");

        let client = McpServerStreamableHttp::new(base_url)
            .expect("client")
            .cache_resources(true);

        let resources = client.list_resources().await.expect("resources");
        assert_eq!(resources.len(), 1);
        let resources_again = client.list_resources().await.expect("resources again");
        assert_eq!(resources_again.len(), 1);

        let calls = state.lock().unwrap().resource_list_calls;
        assert_eq!(calls, 1);

        client.invalidate_resources_cache().await;
        let _ = client.list_resources().await.expect("resources refreshed");
        let calls = state.lock().unwrap().resource_list_calls;
        assert_eq!(calls, 2);

        handle.abort();
    }

    /// Prompt listings are cached; `get_prompt` returns the server's messages.
    #[tokio::test]
    async fn list_prompts_and_get_prompt() {
        let state = Arc::new(Mutex::new(RpcState::default()));
        let (base_url, handle) = spawn_rpc_server(Arc::clone(&state)).await.expect("server");

        let client = McpServerStreamableHttp::new(base_url)
            .expect("client")
            .cache_prompts(true);

        let prompts = client.list_prompts().await.expect("prompts");
        assert_eq!(prompts.len(), 1);
        let prompts_again = client.list_prompts().await.expect("prompts again");
        assert_eq!(prompts_again.len(), 1);

        let calls = state.lock().unwrap().prompt_list_calls;
        assert_eq!(calls, 1);

        let messages = client
            .get_prompt("example-prompt", None)
            .await
            .expect("prompt messages");
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].role, "user");

        let calls = state.lock().unwrap().prompt_get_calls;
        assert_eq!(calls, 1);

        handle.abort();
    }

    /// Templates are listed, and `read_resource` forwards the URI and returns the raw
    /// result.
    #[tokio::test]
    async fn list_templates_and_read_resource() {
        let state = Arc::new(Mutex::new(RpcState::default()));
        let (base_url, handle) = spawn_rpc_server(Arc::clone(&state)).await.expect("server");

        let client = McpServerStreamableHttp::new(base_url).expect("client");
        let templates = client.list_resource_templates().await.expect("templates");
        assert_eq!(templates.len(), 1);
        assert_eq!(templates[0].name, "example-template");

        let content = client
            .read_resource("file:///tmp/example.txt")
            .await
            .expect("read resource");
        let text = content
            .get("contents")
            .and_then(|value| value.as_array())
            .and_then(|items| items.first())
            .and_then(|item| item.get("text"))
            .and_then(|value| value.as_str())
            .expect("text");
        assert_eq!(text, "hello");

        let state = state.lock().unwrap();
        assert_eq!(state.resource_template_calls, 1);
        assert_eq!(state.resource_read_calls, 1);
        assert_eq!(
            state.last_resource_uri.as_deref(),
            Some("file:///tmp/example.txt")
        );

        handle.abort();
    }

    /// With prompt caching disabled, every `list_prompts` call reaches the server.
    #[tokio::test]
    async fn cache_prompts_can_be_disabled() {
        let state = Arc::new(Mutex::new(RpcState::default()));
        let (base_url, handle) = spawn_rpc_server(Arc::clone(&state)).await.expect("server");

        let client = McpServerStreamableHttp::new(base_url)
            .expect("client")
            .cache_prompts(false);

        let _ = client.list_prompts().await.expect("prompts");
        let _ = client.list_prompts().await.expect("prompts again");

        let calls = state.lock().unwrap().prompt_list_calls;
        assert_eq!(calls, 2);

        handle.abort();
    }
}