1use std::collections::HashMap;
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::atomic::{AtomicU64, Ordering};
7use std::sync::{Arc, RwLock};
8use std::time::{Duration, Instant};
9
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
13
14use crate::error::Error;
15use crate::llm::types::ToolDefinition;
16use crate::tool::{Tool, ToolOutput};
17
/// MCP protocol revision string; presumably sent during the `initialize`
/// handshake (that code is not visible here — confirm).
const PROTOCOL_VERSION: &str = "2025-11-25";
/// Per-call timeout; used by the stdio transport's `rpc` and `notify`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);

/// Outgoing JSON-RPC 2.0 request envelope.
///
/// Field order is the key order of the serialized JSON.
#[derive(Debug, Serialize)]
struct JsonRpcRequest {
    /// Always `"2.0"` at the construction sites in this file.
    jsonrpc: &'static str,
    method: String,
    /// Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
    /// Correlates the server's response with this request.
    id: u64,
}

/// Outgoing JSON-RPC 2.0 notification (a request without `id`; no reply expected).
#[derive(Debug, Serialize)]
struct JsonRpcNotification {
    /// Always `"2.0"` at the construction sites in this file.
    jsonrpc: &'static str,
    method: String,
    /// Omitted from the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}

/// Incoming JSON-RPC response. Only `result` and `error` are read; other
/// members (`jsonrpc`, `id`, ...) are ignored during deserialization.
///
/// NOTE: serde maps a literal `"result": null` to `None`, so a null result is
/// indistinguishable from an absent one at this level.
#[derive(Debug, Deserialize)]
struct JsonRpcResponse {
    result: Option<Value>,
    error: Option<JsonRpcError>,
}

/// JSON-RPC error object; any `data` member is ignored.
#[derive(Debug, Deserialize)]
struct JsonRpcError {
    code: i64,
    message: String,
}
51
/// One tool entry from a `tools/list` result (`inputSchema` on the wire).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct McpToolDef {
    name: String,
    /// Server-supplied, untrusted text; sanitized before use.
    #[serde(default)]
    description: Option<String>,
    /// JSON Schema for the tool's arguments; may be absent.
    #[serde(default)]
    input_schema: Option<Value>,
}

/// A page of a `tools/list` result.
#[derive(Debug, Deserialize)]
struct McpToolsListResult {
    tools: Vec<McpToolDef>,
    /// Pagination cursor for the next page; `None` on the last page.
    #[serde(default, rename = "nextCursor")]
    next_cursor: Option<String>,
}

/// One content block of a `tools/call` result, tagged by `type`.
#[derive(Debug, Deserialize)]
struct McpContent {
    #[serde(rename = "type")]
    content_type: String,
    /// Present for `"text"` blocks.
    #[serde(default)]
    text: Option<String>,
}

/// Result of `tools/call`; `isError` defaults to `false` when omitted.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct McpCallToolResult {
    content: Vec<McpContent>,
    #[serde(default)]
    is_error: bool,
}
86
/// Subset of the server `capabilities` object from the `initialize` result.
/// Every field defaults when the server omits it.
#[derive(Debug, Default, Deserialize)]
// Deserialized for completeness; some fields may never be read.
#[allow(dead_code)]
struct ServerCapabilities {
    #[serde(default)]
    resources: Option<ResourcesCapability>,
    #[serde(default)]
    prompts: Option<PromptsCapability>,
    /// Kept as raw JSON; only its presence is meaningful here.
    #[serde(default)]
    logging: Option<Value>,
}

/// `capabilities.resources` (`listChanged` on the wire).
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
// Deserialized for completeness; some fields may never be read.
#[allow(dead_code)]
struct ResourcesCapability {
    #[serde(default)]
    subscribe: bool,
    #[serde(default)]
    list_changed: bool,
}

/// `capabilities.prompts` (`listChanged` on the wire).
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
// Deserialized for completeness; some fields may never be read.
#[allow(dead_code)]
struct PromptsCapability {
    #[serde(default)]
    list_changed: bool,
}

/// Result of the `initialize` request (`serverInfo` on the wire).
#[derive(Debug, Default, Deserialize)]
#[serde(rename_all = "camelCase")]
// Deserialized for completeness; some fields may never be read.
#[allow(dead_code)]
struct InitializeResult {
    #[serde(default)]
    capabilities: ServerCapabilities,
    #[serde(default)]
    server_info: Option<Value>,
}
127
/// A resource advertised by `resources/list` (`mimeType` on the wire).
/// Optional fields are omitted when re-serialized as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpResourceDef {
    pub uri: String,
    pub name: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
}

/// A page of a `resources/list` result.
#[derive(Debug, Deserialize)]
struct McpResourcesListResult {
    resources: Vec<McpResourceDef>,
    /// Pagination cursor for the next page; `None` on the last page.
    #[serde(default, rename = "nextCursor")]
    next_cursor: Option<String>,
}

/// One item of a `resources/read` result. Expect either `text` or `blob`
/// (presumably base64 for binary content — per MCP, not enforced here).
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpResourceContent {
    pub uri: String,
    #[serde(default)]
    pub mime_type: Option<String>,
    #[serde(default)]
    pub text: Option<String>,
    #[serde(default)]
    pub blob: Option<String>,
}

/// Result of `resources/read`.
#[derive(Debug, Deserialize)]
struct McpResourceReadResult {
    contents: Vec<McpResourceContent>,
}
166
/// A prompt template advertised by `prompts/list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpPromptDef {
    pub name: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Omitted when re-serialized if empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub arguments: Vec<McpPromptArgument>,
}

/// One declared argument of a prompt; `required` defaults to `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpPromptArgument {
    pub name: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    #[serde(default)]
    pub required: bool,
}

/// A page of a `prompts/list` result.
#[derive(Debug, Deserialize)]
struct McpPromptsListResult {
    prompts: Vec<McpPromptDef>,
    /// Pagination cursor for the next page; `None` on the last page.
    #[serde(default, rename = "nextCursor")]
    next_cursor: Option<String>,
}

/// One message of a `prompts/get` result.
#[derive(Debug, Clone, Deserialize)]
pub struct McpPromptMessage {
    pub role: String,
    pub content: McpPromptMessageContent,
}

/// Content of a prompt message, tagged by `type`; `text` is set for text content.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpPromptMessageContent {
    #[serde(rename = "type")]
    pub content_type: String,
    #[serde(default)]
    pub text: Option<String>,
}

/// Result of `prompts/get`.
#[derive(Debug, Deserialize)]
// `description` is deserialized for completeness and may never be read.
#[allow(dead_code)]
struct McpPromptGetResult {
    #[serde(default)]
    description: Option<String>,
    messages: Vec<McpPromptMessage>,
}
220
221fn handle_log_notification(value: &Value) {
233 fn sanitize_log_field(s: &str) -> String {
236 const MAX: usize = 4 * 1024;
237 let mut out = String::with_capacity(s.len().min(MAX));
238 for c in s.chars() {
239 if out.len() >= MAX {
240 out.push_str("…[truncated]");
241 break;
242 }
243 if c.is_control() {
244 out.push(' ');
245 } else {
246 out.push(c);
247 }
248 }
249 out
250 }
251 if let Some(params) = value.get("params") {
252 let level = params
253 .get("level")
254 .and_then(|v| v.as_str())
255 .unwrap_or("info");
256 let logger_raw = params
257 .get("logger")
258 .and_then(|v| v.as_str())
259 .unwrap_or("mcp");
260 let data_raw = params.get("data").and_then(|v| v.as_str()).unwrap_or("");
261 let logger = sanitize_log_field(logger_raw);
262 let data = sanitize_log_field(data_raw);
263 match level {
264 "error" | "critical" | "alert" | "emergency" => {
265 tracing::error!(target: "mcp_server", logger = %logger, "{data}");
266 }
267 "warning" => {
268 tracing::warn!(target: "mcp_server", logger = %logger, "{data}");
269 }
270 "debug" => {
271 tracing::debug!(target: "mcp_server", logger = %logger, "{data}");
272 }
273 _ => {
274 tracing::info!(target: "mcp_server", logger = %logger, "{data}");
275 }
276 }
277 }
278}
279
280fn extract_sse_events(body: &str) -> Result<Vec<String>, Error> {
288 let mut events: Vec<String> = Vec::new();
289 let mut current_lines: Vec<&str> = Vec::new();
290
291 for line in body.lines() {
292 if line.trim().is_empty() {
293 if !current_lines.is_empty() {
295 events.push(current_lines.join("\n"));
296 current_lines.clear();
297 }
298 } else if let Some(rest) = line.strip_prefix("data:") {
299 let data = rest.strip_prefix(' ').unwrap_or(rest);
301 current_lines.push(data);
302 }
303 }
304
305 if !current_lines.is_empty() {
307 events.push(current_lines.join("\n"));
308 }
309
310 if events.is_empty() {
311 return Err(Error::Mcp("No data field in SSE response".into()));
312 }
313 Ok(events)
314}
315
316fn find_rpc_response(events: &[String], expected_id: u64) -> Result<String, Error> {
327 let mut null_id_error: Option<String> = None;
328 for event in events {
329 if let Ok(value) = serde_json::from_str::<Value>(event) {
330 if value.get("method").and_then(|m| m.as_str()) == Some("notifications/message") {
332 handle_log_notification(&value);
333 continue;
334 }
335 if value.get("id").and_then(|v| v.as_u64()) == Some(expected_id) {
336 return Ok(event.clone());
337 }
338 if value.get("id").map(|v| v.is_null()).unwrap_or(false)
341 && value.get("error").is_some()
342 && null_id_error.is_none()
343 {
344 null_id_error = Some(event.clone());
345 }
346 }
347 }
348 if let Some(ev) = null_id_error {
349 return Ok(ev);
350 }
351 Err(Error::Mcp(format!(
352 "No JSON-RPC response with id={expected_id} found in SSE stream (F-MCP-5)"
353 )))
354}
355
356fn mcp_result_to_tool_output(result: McpCallToolResult) -> ToolOutput {
357 let non_text_count = result
358 .content
359 .iter()
360 .filter(|c| c.content_type != "text")
361 .count();
362 let text: String = result
363 .content
364 .iter()
365 .filter_map(|c| {
366 if c.content_type == "text" {
367 c.text.as_deref()
368 } else {
369 None
370 }
371 })
372 .collect::<Vec<_>>()
373 .join("\n");
374
375 let output = if text.is_empty() && non_text_count > 0 {
376 format!(
377 "[MCP server returned {non_text_count} non-text content block(s) that cannot be displayed]"
378 )
379 } else {
380 text
381 };
382
383 if result.is_error {
384 ToolOutput::error(output)
385 } else {
386 ToolOutput::success(output)
387 }
388}
389
/// Byte budget for a sanitized tool description (see `sanitize_description`).
/// The result can exceed it by at most one character plus the `…[truncated]`
/// marker, because the check runs before each character is appended.
const MCP_DESCRIPTION_MAX_BYTES: usize = 4 * 1024;
397
398fn mcp_tool_to_definition(tool: &McpToolDef) -> ToolDefinition {
399 let raw_desc = tool.description.clone().unwrap_or_default();
400 ToolDefinition {
401 name: tool.name.clone(),
402 description: sanitize_description(&raw_desc),
408 input_schema: tool
409 .input_schema
410 .clone()
411 .unwrap_or_else(|| serde_json::json!({"type": "object"})),
412 }
413}
414
415fn redact_idp_body(body: &str) -> String {
422 static REDACTORS: std::sync::LazyLock<[(regex::Regex, &'static str); 3]> =
430 std::sync::LazyLock::new(|| {
431 [
432 (
433 regex::Regex::new(r"eyJ[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+\.[A-Za-z0-9_\-]+")
434 .expect("static jwt pattern"),
435 "[redacted-jwt]",
436 ),
437 (
438 regex::Regex::new(r"(?i)bearer\s+[A-Za-z0-9_\-\.=]+")
439 .expect("static bearer pattern"),
440 "[redacted-bearer]",
441 ),
442 (
443 regex::Regex::new(
444 r#"(?i)("(?:access|id|refresh|subject)_token"\s*:\s*")[^"]+"#,
445 )
446 .expect("static token-field pattern"),
447 "$1[redacted]",
448 ),
449 ]
450 });
451 let mut out = std::borrow::Cow::Borrowed(body);
452 for (re, repl) in REDACTORS.iter() {
453 match re.replace_all(&out, *repl) {
454 std::borrow::Cow::Borrowed(_) => {}
455 std::borrow::Cow::Owned(s) => out = std::borrow::Cow::Owned(s),
456 }
457 }
458 out.into_owned()
459}
460
461fn sanitize_description(s: &str) -> String {
465 let mut out = String::with_capacity(s.len().min(MCP_DESCRIPTION_MAX_BYTES));
466 for c in s.chars() {
467 if out.len() >= MCP_DESCRIPTION_MAX_BYTES {
468 out.push_str("…[truncated]");
469 break;
470 }
471 if c.is_control() {
473 out.push(' ');
474 } else {
475 out.push(c);
476 }
477 }
478 out
479}
480
481fn process_rpc_response(json_str: &str) -> Result<Value, Error> {
485 let rpc_response: JsonRpcResponse = serde_json::from_str(json_str)?;
486
487 if let Some(err) = rpc_response.error {
488 const MCP_ERROR_MESSAGE_MAX_BYTES: usize = 1024;
496 let truncated = if err.message.len() > MCP_ERROR_MESSAGE_MAX_BYTES {
497 let cut = crate::tool::builtins::floor_char_boundary(
498 &err.message,
499 MCP_ERROR_MESSAGE_MAX_BYTES,
500 );
501 format!("{}…[truncated]", &err.message[..cut])
502 } else {
503 err.message
504 };
505 return Err(Error::Mcp(format!(
506 "[mcp_server_error code={}] {}",
507 err.code, truncated
508 )));
509 }
510
511 rpc_response
512 .result
513 .ok_or_else(|| Error::Mcp("Response missing both result and error".into()))
514}
515
/// Maximum size in bytes of a single newline-delimited message read from an
/// MCP stdio server; `read_stdio_response` errors once a line would exceed it.
const MCP_STDIO_LINE_MAX_BYTES: u64 = 16 * 1024 * 1024;
521
522async fn read_stdio_response<R: tokio::io::AsyncBufRead + Unpin>(
528 reader: &mut R,
529 expected_id: u64,
530) -> Result<String, Error> {
531 use tokio::io::AsyncBufReadExt;
532 let mut buf = String::new();
533 loop {
534 buf.clear();
535 let max_bytes = MCP_STDIO_LINE_MAX_BYTES as usize;
540 let mut total: usize = 0;
541 let mut got_eof = true;
542 loop {
543 let chunk = reader
544 .fill_buf()
545 .await
546 .map_err(|e| Error::Mcp(format!("stdio read error: {e}")))?;
547 if chunk.is_empty() {
548 break; }
550 got_eof = false;
551 let nl_pos = chunk.iter().position(|&b| b == b'\n');
552 let take = nl_pos.map(|i| i + 1).unwrap_or(chunk.len());
553 if total.saturating_add(take) > max_bytes {
554 return Err(Error::Mcp(format!(
555 "MCP stdio line exceeded cap of {MCP_STDIO_LINE_MAX_BYTES} bytes (F-MCP-4)"
556 )));
557 }
558 buf.push_str(&String::from_utf8_lossy(&chunk[..take]));
560 total += take;
561 reader.consume(take);
562 if nl_pos.is_some() {
563 break;
564 }
565 }
566 if got_eof && buf.is_empty() {
567 return Err(Error::Mcp("MCP stdio server closed unexpectedly".into()));
568 }
569 let trimmed = buf.trim();
570 if trimmed.is_empty() {
571 continue;
572 }
573
574 let value: Value = match serde_json::from_str(trimmed) {
576 Ok(v) => v,
577 Err(_) => continue,
578 };
579
580 match value.get("id") {
582 None | Some(&Value::Null) => {
583 if value.get("method").and_then(|m| m.as_str()) == Some("notifications/message") {
584 handle_log_notification(&value);
585 }
586 continue;
587 }
588 _ => {}
589 }
590
591 if value.get("id").and_then(|v| v.as_u64()) == Some(expected_id) {
592 return Ok(trimmed.to_string());
593 }
594 }
596}
597
/// Source of per-user, per-tenant authorization header values (implementations
/// in this file return strings such as `"Bearer <token>"`).
pub trait AuthProvider: Send + Sync {
    /// Produce the header value for `user_id` in `tenant_id`, or `None` when no
    /// header should be sent.
    fn auth_header_for<'a>(
        &'a self,
        user_id: &'a str,
        tenant_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + 'a>>;

    /// Like [`auth_header_for`](Self::auth_header_for), but for a specific
    /// target resource and scope set.
    ///
    /// The default implementation ignores `resource` and `scopes` and
    /// delegates to `auth_header_for`.
    fn auth_header_for_resource<'a>(
        &'a self,
        user_id: &'a str,
        tenant_id: &'a str,
        _resource: Option<&'a str>,
        _scopes: Option<&'a [String]>,
    ) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + 'a>> {
        self.auth_header_for(user_id, tenant_id)
    }

    /// Synchronous check whether credentials exist for this user/tenant.
    /// Defaults to `true`; implementations override it when they can tell.
    fn has_credentials(&self, _user_id: &str, _tenant_id: &str) -> bool {
        true
    }
}
638
639pub struct StaticAuthProvider {
641 header: Option<String>,
642}
643
644impl StaticAuthProvider {
645 pub fn new(header: Option<String>) -> Self {
646 Self { header }
647 }
648}
649
650impl AuthProvider for StaticAuthProvider {
651 fn auth_header_for<'a>(
652 &'a self,
653 _user_id: &'a str,
654 _tenant_id: &'a str,
655 ) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + 'a>> {
656 Box::pin(async move { Ok(self.header.clone()) })
657 }
658}
659
660pub struct DirectAuthProvider {
663 tokens: HashMap<String, String>,
664}
665
666impl DirectAuthProvider {
667 pub fn new(tokens: HashMap<String, String>) -> Self {
668 Self { tokens }
669 }
670}
671
672impl AuthProvider for DirectAuthProvider {
673 fn auth_header_for<'a>(
674 &'a self,
675 _user_id: &'a str,
676 _tenant_id: &'a str,
677 ) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + 'a>> {
678 Box::pin(async { Ok(None) })
681 }
682
683 fn auth_header_for_resource<'a>(
684 &'a self,
685 _user_id: &'a str,
686 _tenant_id: &'a str,
687 resource: Option<&'a str>,
688 _scopes: Option<&'a [String]>,
689 ) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + 'a>> {
690 Box::pin(async move {
691 Ok(
692 resource
693 .and_then(|url| self.tokens.get(url).map(|token| format!("Bearer {token}"))),
694 )
695 })
696 }
697
698 fn has_credentials(&self, _user_id: &str, _tenant_id: &str) -> bool {
699 !self.tokens.is_empty()
700 }
701}
702
/// Produces the authorization header value for a request, with any
/// user/tenant/resource context already bound into the implementor.
pub trait AuthResolver: Send + Sync {
    /// Return the header value, or `None` for an unauthenticated request.
    fn resolve(&self) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + '_>>;
}
714
715pub struct StaticAuthResolver(pub Option<String>);
717
718impl AuthResolver for StaticAuthResolver {
719 fn resolve(&self) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + '_>> {
720 Box::pin(async move { Ok(self.0.clone()) })
721 }
722}
723
724pub struct DynamicAuthResolver {
729 provider: Arc<dyn AuthProvider>,
730 user_id: String,
731 tenant_id: String,
732 resource: Option<String>,
733 scopes: Option<Vec<String>>,
734}
735
736impl DynamicAuthResolver {
737 pub fn new(
738 provider: Arc<dyn AuthProvider>,
739 user_id: impl Into<String>,
740 tenant_id: impl Into<String>,
741 ) -> Self {
742 Self {
743 provider,
744 user_id: user_id.into(),
745 tenant_id: tenant_id.into(),
746 resource: None,
747 scopes: None,
748 }
749 }
750
751 pub fn with_resource(mut self, resource: Option<String>) -> Self {
753 self.resource = resource;
754 self
755 }
756
757 pub fn with_scopes(mut self, scopes: Option<Vec<String>>) -> Self {
759 self.scopes = scopes;
760 self
761 }
762}
763
764impl AuthResolver for DynamicAuthResolver {
765 fn resolve(&self) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + '_>> {
766 Box::pin(async move {
767 self.provider
768 .auth_header_for_resource(
769 &self.user_id,
770 &self.tenant_id,
771 self.resource.as_deref(),
772 self.scopes.as_deref(),
773 )
774 .await
775 })
776 }
777}
778
/// Header carrying the tenant id on every request to the token endpoint.
const TENANT_ID_HEADER: &str = "X-Tenant-ID";

/// [`AuthProvider`] that obtains per-user access tokens through OAuth 2.0
/// Token Exchange (RFC 8693). The user's subject token is exchanged with an
/// agent token as the actor token.
pub struct TokenExchangeAuthProvider {
    /// Client for the token endpoint; `new_unchecked` builds it with a
    /// timeout and redirects disabled.
    client: reqwest::Client,
    /// Token endpoint URL (grants are POSTed here as form data).
    exchange_url: String,
    client_id: String,
    client_secret: String,
    /// When set, the agent token is fetched with a `client_credentials` grant
    /// for this tenant instead of using `agent_token` directly.
    tenant_id: Option<String>,
    /// Static agent (actor) token, used when `tenant_id` is `None`.
    agent_token: String,
    /// Scopes requested for the agent token; `openid` is used when empty.
    scopes: Vec<String>,
    /// Cached fetched agent token and the instant after which it is refreshed.
    agent_token_cache: RwLock<Option<(String, Instant)>>,
    /// Subject tokens keyed by `"{tenant_id}:{user_id}"`; the map can be shared
    /// with (and populated by) the caller.
    user_tokens: Arc<RwLock<HashMap<String, String>>>,
    /// Cached exchange results per [`TokenCacheKey`], each with the instant
    /// after which it is no longer served.
    token_cache: RwLock<HashMap<TokenCacheKey, (String, Instant)>>,
}

/// Key for `token_cache`. `resource` is empty when no resource was requested.
/// `scopes` is the sorted, comma-joined scope list, or empty.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct TokenCacheKey {
    tenant_id: String,
    user_id: String,
    resource: String,
    scopes: String,
}

/// Fields read from an OAuth token endpoint response; others are ignored.
#[derive(Deserialize)]
struct TokenExchangeResponse {
    access_token: String,
    /// Lifetime in seconds; the callers default it to 300 and cap it at 3600.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Header scheme for the token exchange path (defaults to `Bearer` there).
    #[serde(default)]
    token_type: Option<String>,
}

/// Request timeout configured on the token endpoint client.
const TOKEN_EXCHANGE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
835
836impl TokenExchangeAuthProvider {
837 pub fn new(
847 exchange_url: impl Into<String>,
848 client_id: impl Into<String>,
849 client_secret: impl Into<String>,
850 agent_token: impl Into<String>,
851 ) -> Self {
852 let exchange_url: String = exchange_url.into();
853 if let Err(e) =
854 crate::http::validate_url_sync(&exchange_url, crate::http::IpPolicy::default())
855 {
856 tracing::error!(
857 error = %e,
858 exchange_url = %exchange_url,
859 "TokenExchangeAuthProvider::new: invalid exchange_url; \
860 the OAuth exchange will fail at request time. \
861 Consider TokenExchangeAuthProvider::try_new for a graceful Result."
862 );
863 }
864 Self::new_unchecked(exchange_url, client_id, client_secret, agent_token)
865 }
866
867 pub fn try_new(
874 exchange_url: impl Into<String>,
875 client_id: impl Into<String>,
876 client_secret: impl Into<String>,
877 agent_token: impl Into<String>,
878 ) -> Result<Self, Error> {
879 let exchange_url: String = exchange_url.into();
880 crate::http::validate_url_sync(&exchange_url, crate::http::IpPolicy::default())
881 .map_err(|e| Error::Mcp(format!("invalid exchange_url: {e}")))?;
882 Ok(Self::new_unchecked(
883 exchange_url,
884 client_id,
885 client_secret,
886 agent_token,
887 ))
888 }
889
890 fn new_unchecked(
891 exchange_url: String,
892 client_id: impl Into<String>,
893 client_secret: impl Into<String>,
894 agent_token: impl Into<String>,
895 ) -> Self {
896 Self {
897 client: reqwest::Client::builder()
898 .timeout(TOKEN_EXCHANGE_TIMEOUT)
899 .redirect(reqwest::redirect::Policy::none())
900 .build()
901 .unwrap_or_default(),
902 exchange_url,
903 client_id: client_id.into(),
904 client_secret: client_secret.into(),
905 tenant_id: None,
906 agent_token: agent_token.into(),
907 scopes: Vec::new(),
908 agent_token_cache: RwLock::new(None),
909 user_tokens: Arc::new(RwLock::new(HashMap::new())),
910 token_cache: RwLock::new(HashMap::new()),
911 }
912 }
913
914 pub fn with_tenant_id(mut self, tenant_id: Option<String>) -> Self {
916 self.tenant_id = tenant_id;
917 self
918 }
919
920 pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
922 self.scopes = scopes;
923 self
924 }
925
926 pub fn with_user_tokens(mut self, tokens: Arc<RwLock<HashMap<String, String>>>) -> Self {
928 self.user_tokens = tokens;
929 self
930 }
931
932 pub fn user_tokens(&self) -> &Arc<RwLock<HashMap<String, String>>> {
934 &self.user_tokens
935 }
936
937 async fn ensure_valid_agent_token(&self) -> Result<String, Error> {
943 {
945 let cache = self
946 .agent_token_cache
947 .read()
948 .map_err(|e| Error::Mcp(format!("agent_token_cache lock poisoned: {e}")))?;
949 if let Some((token, expires_at)) = &*cache
950 && Instant::now() < *expires_at
951 {
952 return Ok(token.clone());
953 }
954 }
955 if let Some(tenant_id) = &self.tenant_id {
957 let scope = if self.scopes.is_empty() {
958 "openid".to_string()
959 } else {
960 self.scopes.join(" ")
961 };
962 let response = self
963 .client
964 .post(&self.exchange_url)
965 .header(TENANT_ID_HEADER, tenant_id)
966 .form(&[
967 ("grant_type", "client_credentials"),
968 ("client_id", &self.client_id),
969 ("client_secret", &self.client_secret),
970 ("scope", &scope),
971 ])
972 .send()
973 .await
974 .map_err(|e| Error::Mcp(format!("Agent token fetch failed: {e}")))?;
975
976 let status = response.status();
977 if !status.is_success() {
978 let body = response.text().await.unwrap_or_default();
979 let body = redact_idp_body(&body);
982 let cut = crate::tool::builtins::floor_char_boundary(&body, 512);
983 return Err(Error::Mcp(format!(
984 "Agent token fetch failed (HTTP {status}): {}",
985 &body[..cut]
986 )));
987 }
988
989 let resp: TokenExchangeResponse = response
990 .json()
991 .await
992 .map_err(|e| Error::Mcp(format!("Agent token response parse error: {e}")))?;
993
994 let ttl = resp.expires_in.unwrap_or(300).min(3600).saturating_sub(30);
995 let expires_at = Instant::now() + Duration::from_secs(ttl);
996 *self
997 .agent_token_cache
998 .write()
999 .map_err(|e| Error::Mcp(format!("agent_token_cache lock poisoned: {e}")))? =
1000 Some((resp.access_token.clone(), expires_at));
1001 return Ok(resp.access_token);
1002 }
1003 Ok(self.agent_token.clone())
1005 }
1006}
1007
1008impl AuthProvider for TokenExchangeAuthProvider {
1009 fn auth_header_for<'a>(
1010 &'a self,
1011 user_id: &'a str,
1012 tenant_id: &'a str,
1013 ) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + 'a>> {
1014 Box::pin(async move {
1015 let cache_key = TokenCacheKey {
1019 tenant_id: tenant_id.to_string(),
1020 user_id: user_id.to_string(),
1021 resource: String::new(),
1022 scopes: String::new(),
1023 };
1024 if let Ok(cache) = self.token_cache.read()
1025 && let Some((token, expires_at)) = cache.get(&cache_key)
1026 && Instant::now() < *expires_at
1027 {
1028 return Ok(Some(format!("Bearer {token}")));
1029 }
1030
1031 let token_key = format!("{tenant_id}:{user_id}");
1032 let subject_token = {
1033 let tokens = self
1034 .user_tokens
1035 .read()
1036 .map_err(|e| Error::Mcp(format!("user_tokens lock poisoned: {e}")))?;
1037 tokens.get(&token_key).cloned().ok_or_else(|| {
1038 Error::Mcp(format!(
1039 "No subject token found for user '{user_id}' in tenant '{tenant_id}'"
1040 ))
1041 })?
1042 };
1043
1044 let agent_token = self.ensure_valid_agent_token().await?;
1045 let response = self
1046 .client
1047 .post(&self.exchange_url)
1048 .header(TENANT_ID_HEADER, tenant_id)
1049 .form(&[
1050 (
1051 "grant_type",
1052 "urn:ietf:params:oauth:grant-type:token-exchange",
1053 ),
1054 ("subject_token", &subject_token),
1055 (
1056 "subject_token_type",
1057 "urn:ietf:params:oauth:token-type:access_token",
1058 ),
1059 ("actor_token", &agent_token),
1060 (
1061 "actor_token_type",
1062 "urn:ietf:params:oauth:token-type:access_token",
1063 ),
1064 ("client_id", &self.client_id),
1065 ("client_secret", &self.client_secret),
1066 ])
1067 .send()
1068 .await
1069 .map_err(|e| Error::Mcp(format!("Token exchange request failed: {e}")))?;
1070
1071 let status = response.status();
1072 if !status.is_success() {
1073 let body = response.text().await.unwrap_or_default();
1074 let cut = crate::tool::builtins::floor_char_boundary(&body, 512);
1076 return Err(Error::Mcp(format!(
1077 "Token exchange failed (HTTP {status}): {}",
1078 &body[..cut]
1079 )));
1080 }
1081
1082 let token_response: TokenExchangeResponse = response
1083 .json()
1084 .await
1085 .map_err(|e| Error::Mcp(format!("Token exchange response parse error: {e}")))?;
1086
1087 let ttl = token_response.expires_in.unwrap_or(300).min(3600);
1089 let now = Instant::now();
1091 let expires_at = now + Duration::from_secs(ttl.saturating_sub(30));
1092 if let Ok(mut cache) = self.token_cache.write() {
1093 cache.retain(|_, (_, exp)| now < *exp);
1095 cache.insert(cache_key, (token_response.access_token.clone(), expires_at));
1096 }
1097
1098 let token_type = token_response.token_type.as_deref().unwrap_or("Bearer");
1099 Ok(Some(format!(
1100 "{token_type} {}",
1101 token_response.access_token
1102 )))
1103 })
1104 }
1105
1106 fn has_credentials(&self, user_id: &str, tenant_id: &str) -> bool {
1107 let token_key = format!("{tenant_id}:{user_id}");
1108 self.user_tokens
1109 .read()
1110 .map(|tokens| tokens.contains_key(&token_key))
1111 .unwrap_or(false)
1112 }
1113
1114 fn auth_header_for_resource<'a>(
1115 &'a self,
1116 user_id: &'a str,
1117 tenant_id: &'a str,
1118 resource: Option<&'a str>,
1119 scopes: Option<&'a [String]>,
1120 ) -> Pin<Box<dyn Future<Output = Result<Option<String>, Error>> + Send + 'a>> {
1121 Box::pin(async move {
1122 let resource_key = resource.unwrap_or("");
1124 let scopes_key = scopes
1125 .map(|s| {
1126 let mut sorted = s.to_vec();
1127 sorted.sort();
1128 sorted.join(",")
1129 })
1130 .unwrap_or_default();
1131 let cache_key = TokenCacheKey {
1135 tenant_id: tenant_id.to_string(),
1136 user_id: user_id.to_string(),
1137 resource: resource_key.to_string(),
1138 scopes: scopes_key.clone(),
1139 };
1140
1141 if let Ok(cache) = self.token_cache.read()
1143 && let Some((token, expires_at)) = cache.get(&cache_key)
1144 && Instant::now() < *expires_at
1145 {
1146 return Ok(Some(format!("Bearer {token}")));
1147 }
1148
1149 let token_key = format!("{tenant_id}:{user_id}");
1150 let subject_token = {
1151 let tokens = self
1152 .user_tokens
1153 .read()
1154 .map_err(|e| Error::Mcp(format!("user_tokens lock poisoned: {e}")))?;
1155 tokens.get(&token_key).cloned().ok_or_else(|| {
1156 Error::Mcp(format!(
1157 "No subject token found for user '{user_id}' in tenant '{tenant_id}'"
1158 ))
1159 })?
1160 };
1161
1162 let agent_token = self.ensure_valid_agent_token().await?;
1163
1164 let mut form_params: Vec<(&str, String)> = vec![
1166 (
1167 "grant_type",
1168 "urn:ietf:params:oauth:grant-type:token-exchange".into(),
1169 ),
1170 ("subject_token", subject_token),
1171 (
1172 "subject_token_type",
1173 "urn:ietf:params:oauth:token-type:access_token".into(),
1174 ),
1175 ("actor_token", agent_token),
1176 (
1177 "actor_token_type",
1178 "urn:ietf:params:oauth:token-type:access_token".into(),
1179 ),
1180 ("client_id", self.client_id.clone()),
1181 ("client_secret", self.client_secret.clone()),
1182 ];
1183 if let Some(r) = resource {
1184 form_params.push(("resource", r.to_string()));
1185 }
1186 if let Some(s) = scopes
1187 && !s.is_empty()
1188 {
1189 form_params.push(("scope", s.join(" ")));
1190 }
1191
1192 let response = self
1193 .client
1194 .post(&self.exchange_url)
1195 .header(TENANT_ID_HEADER, tenant_id)
1196 .form(&form_params)
1197 .send()
1198 .await
1199 .map_err(|e| Error::Mcp(format!("Token exchange request failed: {e}")))?;
1200
1201 let status = response.status();
1202 if !status.is_success() {
1203 let body = response.text().await.unwrap_or_default();
1204 let body = redact_idp_body(&body);
1207 let cut = crate::tool::builtins::floor_char_boundary(&body, 512);
1208 return Err(Error::Mcp(format!(
1209 "Token exchange failed (HTTP {status}): {}",
1210 &body[..cut]
1211 )));
1212 }
1213
1214 let token_response: TokenExchangeResponse = response
1215 .json()
1216 .await
1217 .map_err(|e| Error::Mcp(format!("Token exchange response parse error: {e}")))?;
1218
1219 let ttl = token_response.expires_in.unwrap_or(300).min(3600);
1220 let now = Instant::now();
1221 let expires_at = now + Duration::from_secs(ttl.saturating_sub(30));
1222 if let Ok(mut cache) = self.token_cache.write() {
1223 cache.retain(|_, (_, exp)| now < *exp);
1224 cache.insert(cache_key, (token_response.access_token.clone(), expires_at));
1225 }
1226
1227 let token_type = token_response.token_type.as_deref().unwrap_or("Bearer");
1228 Ok(Some(format!(
1229 "{token_type} {}",
1230 token_response.access_token
1231 )))
1232 })
1233 }
1234}
1235
/// HTTP transport: every JSON-RPC message is POSTed to `endpoint`, and replies
/// arrive as plain JSON or as an SSE stream.
struct HttpTransport {
    /// HTTP client shared by all requests on this transport.
    client: reqwest::Client,
    /// URL every message is POSTed to.
    endpoint: String,
    /// Latest `Mcp-Session-Id` returned by the server; sent on later requests.
    session_id: RwLock<Option<String>>,
    /// Counter supplying JSON-RPC request ids.
    next_id: AtomicU64,
    /// Default `Authorization` header value; a per-call override takes precedence.
    auth_header: Option<String>,
}
1246
1247impl HttpTransport {
1248 fn next_id(&self) -> u64 {
1249 self.next_id.fetch_add(1, Ordering::Relaxed)
1250 }
1251
1252 fn read_session_id(&self) -> Result<Option<String>, Error> {
1254 Ok(self
1255 .session_id
1256 .read()
1257 .map_err(|e| Error::Mcp(format!("Lock poisoned: {e}")))?
1258 .clone())
1259 }
1260
1261 fn update_session_id(&self, response: &reqwest::Response) -> Result<(), Error> {
1263 if let Some(new_sid) = response
1264 .headers()
1265 .get("Mcp-Session-Id")
1266 .and_then(|v| v.to_str().ok())
1267 {
1268 *self
1269 .session_id
1270 .write()
1271 .map_err(|e| Error::Mcp(format!("Lock poisoned: {e}")))? =
1272 Some(new_sid.to_string());
1273 }
1274 Ok(())
1275 }
1276
1277 async fn rpc(
1278 &self,
1279 method: &str,
1280 params: Option<Value>,
1281 auth_override: Option<&str>,
1282 ) -> Result<Value, Error> {
1283 let id = self.next_id();
1284 let request = JsonRpcRequest {
1285 jsonrpc: "2.0",
1286 method: method.to_string(),
1287 params,
1288 id,
1289 };
1290
1291 let mut builder = self
1292 .client
1293 .post(&self.endpoint)
1294 .header("Accept", "application/json, text/event-stream")
1295 .json(&request);
1296
1297 if let Some(sid) = self.read_session_id()? {
1298 builder = builder.header("Mcp-Session-Id", sid);
1299 }
1300 let effective_auth = auth_override.or(self.auth_header.as_deref());
1302 if let Some(auth) = effective_auth {
1303 builder = builder.header("Authorization", auth);
1304 }
1305
1306 let response = builder.send().await?;
1307 self.update_session_id(&response)?;
1308
1309 let status = response.status();
1310 let content_type = response
1311 .headers()
1312 .get("content-type")
1313 .and_then(|v| v.to_str().ok())
1314 .unwrap_or("")
1315 .to_string();
1316 const MCP_HTTP_BODY_MAX_BYTES: usize = 16 * 1024 * 1024;
1321 let body = crate::http::read_text_capped(response, MCP_HTTP_BODY_MAX_BYTES).await?;
1322
1323 if !status.is_success() {
1324 return Err(Error::Mcp(format!("HTTP {}: {}", status.as_u16(), body)));
1325 }
1326
1327 let json_str = if content_type.contains("text/event-stream") {
1328 let events = extract_sse_events(&body)?;
1329 find_rpc_response(&events, id)?
1330 } else {
1331 body
1332 };
1333
1334 process_rpc_response(&json_str)
1335 }
1336
1337 async fn notify(
1338 &self,
1339 method: &str,
1340 params: Option<Value>,
1341 auth_override: Option<&str>,
1342 ) -> Result<(), Error> {
1343 let notification = JsonRpcNotification {
1344 jsonrpc: "2.0",
1345 method: method.to_string(),
1346 params,
1347 };
1348
1349 let mut builder = self
1350 .client
1351 .post(&self.endpoint)
1352 .header("Accept", "application/json, text/event-stream")
1353 .json(¬ification);
1354
1355 if let Some(sid) = self.read_session_id()? {
1356 builder = builder.header("Mcp-Session-Id", sid);
1357 }
1358 let effective_auth = auth_override.or(self.auth_header.as_deref());
1359 if let Some(auth) = effective_auth {
1360 builder = builder.header("Authorization", auth);
1361 }
1362
1363 let response = builder.send().await?;
1364 self.update_session_id(&response)?;
1365
1366 let status = response.status();
1367 if !status.is_success() {
1368 let body = response.text().await?;
1369 return Err(Error::Mcp(format!(
1370 "Notification HTTP {}: {}",
1371 status.as_u16(),
1372 body
1373 )));
1374 }
1375
1376 let _ = response.bytes().await;
1378
1379 Ok(())
1380 }
1381}
1382
/// Pipes and process handle of a spawned MCP stdio server.
struct StdioIo {
    /// Server's stdin; each request/notification is written as one line.
    stdin: tokio::process::ChildStdin,
    /// Buffered server stdout, read line by line for responses.
    reader: tokio::io::BufReader<tokio::process::ChildStdout>,
    /// Keeps the child handle alive as long as the transport. Whether dropping
    /// it terminates the child depends on how it was spawned (not visible
    /// here — confirm `kill_on_drop`).
    _process: tokio::process::Child,
}

/// Newline-delimited JSON-RPC transport over a child process's stdio.
struct StdioTransport {
    /// Held across a request's write *and* its response read, so concurrent
    /// calls cannot interleave on the pipes.
    io: tokio::sync::Mutex<StdioIo>,
    /// Counter supplying JSON-RPC request ids.
    next_id: AtomicU64,
}
1403
1404impl StdioTransport {
1405 fn next_id(&self) -> u64 {
1406 self.next_id.fetch_add(1, Ordering::Relaxed)
1407 }
1408
1409 async fn rpc(&self, method: &str, params: Option<Value>) -> Result<Value, Error> {
1410 let id = self.next_id();
1411 let request = JsonRpcRequest {
1412 jsonrpc: "2.0",
1413 method: method.to_string(),
1414 params,
1415 id,
1416 };
1417 let line = serde_json::to_string(&request)? + "\n";
1418
1419 let mut io = self.io.lock().await;
1422 let json_str = tokio::time::timeout(REQUEST_TIMEOUT, async {
1423 io.stdin
1424 .write_all(line.as_bytes())
1425 .await
1426 .map_err(|e| Error::Mcp(format!("stdio write error: {e}")))?;
1427 io.stdin
1428 .flush()
1429 .await
1430 .map_err(|e| Error::Mcp(format!("stdio flush error: {e}")))?;
1431 read_stdio_response(&mut io.reader, id).await
1432 })
1433 .await
1434 .map_err(|_| {
1435 Error::Mcp(format!(
1436 "MCP stdio server timed out after {}s for request {id}",
1437 REQUEST_TIMEOUT.as_secs()
1438 ))
1439 })??;
1440 process_rpc_response(&json_str)
1441 }
1442
1443 async fn notify(&self, method: &str, params: Option<Value>) -> Result<(), Error> {
1444 let notification = JsonRpcNotification {
1445 jsonrpc: "2.0",
1446 method: method.to_string(),
1447 params,
1448 };
1449 let line = serde_json::to_string(¬ification)? + "\n";
1450
1451 let mut io = self.io.lock().await;
1452 tokio::time::timeout(REQUEST_TIMEOUT, async {
1453 io.stdin
1454 .write_all(line.as_bytes())
1455 .await
1456 .map_err(|e| Error::Mcp(format!("stdio write error: {e}")))?;
1457 io.stdin
1458 .flush()
1459 .await
1460 .map_err(|e| Error::Mcp(format!("stdio flush error: {e}")))?;
1461 Ok::<(), Error>(())
1462 })
1463 .await
1464 .map_err(|_| {
1465 Error::Mcp(format!(
1466 "MCP stdio notification timed out after {}s",
1467 REQUEST_TIMEOUT.as_secs()
1468 ))
1469 })??;
1470 Ok(())
1471 }
1472}
1473
/// Wire transport behind an MCP connection: HTTP or a spawned stdio process.
enum Transport {
    /// HTTP transport (created by `McpClient::connect_http`).
    Http(HttpTransport),
    /// Stdio transport (created by `McpClient::connect_stdio`).
    /// NOTE(review): boxed, presumably to keep the enum small — confirm.
    Stdio(Box<StdioTransport>),
}
1481
1482impl Transport {
1483 async fn rpc(&self, method: &str, params: Option<Value>) -> Result<Value, Error> {
1484 self.rpc_with_auth(method, params, None).await
1485 }
1486
1487 async fn rpc_with_auth(
1488 &self,
1489 method: &str,
1490 params: Option<Value>,
1491 auth_override: Option<&str>,
1492 ) -> Result<Value, Error> {
1493 match self {
1494 Transport::Http(t) => t.rpc(method, params, auth_override).await,
1495 Transport::Stdio(t) => t.rpc(method, params).await,
1497 }
1498 }
1499
1500 async fn notify(&self, method: &str, params: Option<Value>) -> Result<(), Error> {
1501 self.notify_with_auth(method, params, None).await
1502 }
1503
1504 async fn notify_with_auth(
1505 &self,
1506 method: &str,
1507 params: Option<Value>,
1508 auth_override: Option<&str>,
1509 ) -> Result<(), Error> {
1510 match self {
1511 Transport::Http(t) => t.notify(method, params, auth_override).await,
1512 Transport::Stdio(t) => t.notify(method, params).await,
1513 }
1514 }
1515
1516 async fn call_tool_with_auth(
1518 &self,
1519 name: &str,
1520 arguments: Value,
1521 auth_override: Option<&str>,
1522 ) -> Result<ToolOutput, Error> {
1523 let arguments = if arguments.is_null() {
1526 serde_json::json!({})
1527 } else {
1528 arguments
1529 };
1530 let params = serde_json::json!({
1531 "name": name,
1532 "arguments": arguments,
1533 });
1534
1535 let result_value = self
1536 .rpc_with_auth("tools/call", Some(params), auth_override)
1537 .await?;
1538 let result: McpCallToolResult = serde_json::from_value(result_value)?;
1539 Ok(mcp_result_to_tool_output(result))
1540 }
1541}
1542
/// One tool advertised by an MCP server, adapted to the crate's [`Tool`] trait.
struct McpTool {
    /// Connection to the server that owns the tool (shared with sibling tools).
    transport: Arc<Transport>,
    /// Definition shown to the model, converted from the server's `tools/list` entry.
    def: ToolDefinition,
    /// Optional per-call source of an `Authorization` override.
    /// `None` uses the transport's default header.
    auth_resolver: Option<Arc<dyn AuthResolver>>,
}
1552
1553impl Tool for McpTool {
1554 fn definition(&self) -> ToolDefinition {
1555 self.def.clone()
1556 }
1557
1558 fn execute(
1559 &self,
1560 _ctx: &crate::ExecutionContext,
1561 input: Value,
1562 ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
1563 Box::pin(async move {
1564 let auth = if let Some(resolver) = &self.auth_resolver {
1565 resolver.resolve().await?
1566 } else {
1567 None
1568 };
1569 match self
1570 .transport
1571 .call_tool_with_auth(&self.def.name, input, auth.as_deref())
1572 .await
1573 {
1574 Ok(output) => Ok(output),
1575 Err(e) => {
1576 tracing::warn!(
1577 tool = %self.def.name,
1578 error = %e,
1579 "MCP tool call failed"
1580 );
1581 Ok(ToolOutput::error(e.to_string()))
1582 }
1583 }
1584 })
1585 }
1586}
1587
/// Exposes one MCP resource as a tool that takes no arguments and performs `resources/read`.
struct McpResourceTool {
    /// Connection to the server that listed the resource.
    transport: Arc<Transport>,
    /// The resource entry as returned by `resources/list`.
    resource: McpResourceDef,
    /// Tool name shown to the model (`mcp_resource_` + sanitized resource name).
    tool_name: String,
    /// Optional per-call source of an `Authorization` override.
    auth_resolver: Option<Arc<dyn AuthResolver>>,
}
1602
1603impl Tool for McpResourceTool {
1604 fn definition(&self) -> ToolDefinition {
1605 let desc = self
1606 .resource
1607 .description
1608 .clone()
1609 .unwrap_or_else(|| format!("Read MCP resource: {}", self.resource.uri));
1610 ToolDefinition {
1611 name: self.tool_name.clone(),
1612 description: desc,
1613 input_schema: serde_json::json!({
1614 "type": "object",
1615 "properties": {},
1616 }),
1617 }
1618 }
1619
1620 fn execute(
1621 &self,
1622 _ctx: &crate::ExecutionContext,
1623 _input: Value,
1624 ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
1625 Box::pin(async move {
1626 const ALLOWED_SCHEMES: &[&str] = &["mcp", "https", "http", "resource", "memory"];
1632 let scheme = self
1633 .resource
1634 .uri
1635 .split(':')
1636 .next()
1637 .unwrap_or("")
1638 .to_ascii_lowercase();
1639 if !ALLOWED_SCHEMES.iter().any(|s| *s == scheme) {
1640 return Ok(ToolOutput::error(format!(
1641 "MCP resource URI scheme {scheme:?} is not allowed; \
1642 refused (F-MCP-10). uri={}",
1643 self.resource.uri
1644 )));
1645 }
1646 let auth = if let Some(resolver) = &self.auth_resolver {
1647 resolver.resolve().await?
1648 } else {
1649 None
1650 };
1651 let params = serde_json::json!({ "uri": self.resource.uri });
1652 match self
1653 .transport
1654 .rpc_with_auth("resources/read", Some(params), auth.as_deref())
1655 .await
1656 {
1657 Ok(value) => {
1658 let result: McpResourceReadResult = serde_json::from_value(value)?;
1659 let text: String = result
1660 .contents
1661 .iter()
1662 .filter_map(|c| c.text.as_deref())
1663 .collect::<Vec<_>>()
1664 .join("\n");
1665 if text.is_empty() {
1666 Ok(ToolOutput::success(format!(
1667 "[Resource {} returned no text content]",
1668 self.resource.uri
1669 )))
1670 } else {
1671 Ok(ToolOutput::success(text))
1672 }
1673 }
1674 Err(e) => {
1675 tracing::warn!(
1676 resource = %self.resource.uri,
1677 error = %e,
1678 "MCP resource read failed"
1679 );
1680 Ok(ToolOutput::error(e.to_string()))
1681 }
1682 }
1683 })
1684 }
1685}
1686
/// Exposes one MCP prompt as a tool that performs `prompts/get`.
struct McpPromptTool {
    /// Connection to the server that listed the prompt.
    transport: Arc<Transport>,
    /// The prompt entry as returned by `prompts/list`.
    prompt: McpPromptDef,
    /// Tool name shown to the model (`mcp_prompt_` + sanitized prompt name).
    tool_name: String,
    /// Optional per-call source of an `Authorization` override.
    auth_resolver: Option<Arc<dyn AuthResolver>>,
}
1696
1697impl Tool for McpPromptTool {
1698 fn definition(&self) -> ToolDefinition {
1699 let desc = self
1700 .prompt
1701 .description
1702 .clone()
1703 .unwrap_or_else(|| format!("Get MCP prompt: {}", self.prompt.name));
1704 let mut properties = serde_json::Map::new();
1706 let mut required = Vec::new();
1707 for arg in &self.prompt.arguments {
1708 let mut prop = serde_json::Map::new();
1709 prop.insert("type".into(), serde_json::json!("string"));
1710 if let Some(desc) = &arg.description {
1711 prop.insert("description".into(), serde_json::json!(desc));
1712 }
1713 properties.insert(arg.name.clone(), Value::Object(prop));
1714 if arg.required {
1715 required.push(serde_json::json!(arg.name));
1716 }
1717 }
1718 let mut schema = serde_json::json!({
1719 "type": "object",
1720 "properties": properties,
1721 });
1722 if !required.is_empty() {
1723 schema["required"] = Value::Array(required);
1724 }
1725 ToolDefinition {
1726 name: self.tool_name.clone(),
1727 description: desc,
1728 input_schema: schema,
1729 }
1730 }
1731
1732 fn execute(
1733 &self,
1734 _ctx: &crate::ExecutionContext,
1735 input: Value,
1736 ) -> Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + '_>> {
1737 Box::pin(async move {
1738 let auth = if let Some(resolver) = &self.auth_resolver {
1739 resolver.resolve().await?
1740 } else {
1741 None
1742 };
1743 let arguments = if input.is_null() || input.as_object().is_some_and(|m| m.is_empty()) {
1744 None
1745 } else {
1746 Some(input)
1747 };
1748 let mut params = serde_json::json!({ "name": self.prompt.name });
1749 if let Some(args) = arguments {
1750 params["arguments"] = args;
1751 }
1752 match self
1753 .transport
1754 .rpc_with_auth("prompts/get", Some(params), auth.as_deref())
1755 .await
1756 {
1757 Ok(value) => {
1758 let result: McpPromptGetResult = serde_json::from_value(value)?;
1759 let text: String = result
1760 .messages
1761 .iter()
1762 .map(|m| {
1763 let content = m.content.text.as_deref().unwrap_or("");
1764 format!("[{}] {}", m.role, content)
1765 })
1766 .collect::<Vec<_>>()
1767 .join("\n");
1768 Ok(ToolOutput::success(text))
1769 }
1770 Err(e) => {
1771 tracing::warn!(
1772 prompt = %self.prompt.name,
1773 error = %e,
1774 "MCP prompt get failed"
1775 );
1776 Ok(ToolOutput::error(e.to_string()))
1777 }
1778 }
1779 })
1780 }
1781}
1782
/// A server-initiated sampling request, deserialized from camelCase JSON.
///
/// NOTE(review): presumably the params of MCP's `sampling/createMessage`. The code that
/// parses and dispatches it is not visible here — confirm.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingRequest {
    /// Messages supplied by the server.
    pub messages: Vec<SamplingMessage>,
    /// Optional model preferences (JSON key `modelPreferences`); `None` when absent.
    #[serde(default)]
    pub model_preferences: Option<SamplingModelPreferences>,
    /// Optional system prompt (JSON key `systemPrompt`); `None` when absent.
    #[serde(default)]
    pub system_prompt: Option<String>,
    /// Optional token limit (JSON key `maxTokens`); `None` when absent.
    #[serde(default)]
    pub max_tokens: Option<u32>,
}
1799
/// One message in a sampling conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingMessage {
    /// Speaker role as sent by the server (a free-form string; not validated here).
    pub role: String,
    /// Message content.
    pub content: SamplingContent,
}
1806
/// Content block of a sampling message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingContent {
    /// Content kind, serialized as the JSON key `type`.
    #[serde(rename = "type")]
    pub content_type: String,
    /// Text payload. Defaults to `None` when the field is absent (e.g. non-text content).
    #[serde(default)]
    pub text: Option<String>,
}
1815
/// Model preferences attached to a sampling request. Only `hints` is parsed.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingModelPreferences {
    /// Model hints, in the order the server sent them; empty when absent.
    #[serde(default)]
    pub hints: Vec<SamplingModelHint>,
}
1826
/// A single model hint from the server.
#[derive(Debug, Clone, Deserialize)]
pub struct SamplingModelHint {
    /// Suggested model name; `None` when absent.
    #[serde(default)]
    pub name: Option<String>,
}
1836
/// Reply to a sampling request, serialized with camelCase keys.
///
/// NOTE(review): nothing in this chunk constructs this type; `allow(dead_code)` suppresses the
/// resulting warning — confirm whether it is still needed.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
#[allow(dead_code)]
struct SamplingResponse {
    /// Role of the generated message.
    role: String,
    /// Generated content.
    content: SamplingContent,
    /// Name of the model that produced the content.
    model: String,
}
1846
/// Shared callback that answers sampling requests.
///
/// It receives the parsed [`SamplingRequest`] and returns a boxed `Send` future. The handler
/// itself must be `Send + Sync`.
///
/// NOTE(review): the meaning of the `(String, String)` success value is not visible here.
/// Presumably it is (generated text, model name), matching `SamplingResponse`'s
/// `content`/`model` fields — confirm at the call site.
pub type SamplingHandler = Arc<
    dyn Fn(SamplingRequest) -> Pin<Box<dyn Future<Output = Result<(String, String), Error>> + Send>>
        + Send
        + Sync,
>;
1855
/// Maps a server-supplied name onto the character set accepted for tool names.
///
/// Every `char` that is not an ASCII letter, ASCII digit or `_` becomes a single `_`, so the
/// output has as many characters as the input.
///
/// The check must be ASCII-only: LLM tool APIs restrict names to `[a-zA-Z0-9_-]`, and
/// `char::is_alphanumeric` would let letters such as `é` or `名` through and produce names the
/// model API rejects. `-` is also replaced, as before, to keep existing names stable.
fn sanitize_tool_name(name: &str) -> String {
    name.chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect()
}
1868
/// A root (a URI plus optional display name) configured with `McpClient::with_roots`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpRoot {
    /// Root URI.
    pub uri: String,
    /// Optional display name. It is omitted from serialized output when `None` and defaults to
    /// `None` when missing on input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
1879
/// A connected MCP client together with everything discovered during the handshake.
pub struct McpClient {
    /// Shared connection, handed to every tool produced by the `into_*` methods.
    transport: Arc<Transport>,
    /// Tools from every page of `tools/list`.
    tools: Vec<McpToolDef>,
    /// Resources from `resources/list`. Empty when the server did not advertise the capability
    /// or listing failed.
    resources: Vec<McpResourceDef>,
    /// Prompts from `prompts/list`. Empty when the server did not advertise the capability or
    /// listing failed.
    prompts: Vec<McpPromptDef>,
    /// Capabilities parsed from the `initialize` result.
    capabilities: ServerCapabilities,
    /// Handler set by `with_sampling`; `None` by default.
    sampling_handler: Option<SamplingHandler>,
    /// Roots set by `with_roots`; empty by default.
    roots: Vec<McpRoot>,
}
1896
1897impl McpClient {
1898 pub fn roots(&self) -> &[McpRoot] {
1900 &self.roots
1901 }
1902
1903 pub async fn connect(endpoint: &str) -> Result<Self, Error> {
1907 Self::connect_http(endpoint, None).await
1908 }
1909
1910 pub async fn connect_with_auth(
1916 endpoint: &str,
1917 auth_header: impl Into<String>,
1918 ) -> Result<Self, Error> {
1919 Self::connect_http(endpoint, Some(auth_header.into())).await
1920 }
1921
1922 pub fn with_sampling(mut self, handler: SamplingHandler) -> Self {
1926 self.sampling_handler = Some(handler);
1927 self
1928 }
1929
1930 pub fn with_roots(mut self, roots: Vec<McpRoot>) -> Self {
1932 self.roots = roots;
1933 self
1934 }
1935
1936 pub async fn send_roots_changed(&self) -> Result<(), Error> {
1938 self.transport
1939 .notify("notifications/roots/list_changed", None)
1940 .await
1941 }
1942
1943 pub async fn connect_stdio(
1949 command: &str,
1950 args: &[String],
1951 env: &HashMap<String, String>,
1952 ) -> Result<Self, Error> {
1953 let mut cmd = tokio::process::Command::new(command);
1954 cmd.args(args)
1955 .envs(env.iter())
1956 .stdin(std::process::Stdio::piped())
1957 .stdout(std::process::Stdio::piped())
1958 .stderr(std::process::Stdio::piped())
1959 .kill_on_drop(true);
1960
1961 let mut child = cmd.spawn().map_err(|e| {
1962 Error::Mcp(format!("Failed to spawn MCP stdio server '{command}': {e}"))
1963 })?;
1964
1965 let stdin = child
1966 .stdin
1967 .take()
1968 .ok_or_else(|| Error::Mcp("Failed to capture stdin of MCP server".into()))?;
1969 let stdout = child
1970 .stdout
1971 .take()
1972 .ok_or_else(|| Error::Mcp("Failed to capture stdout of MCP server".into()))?;
1973
1974 if let Some(stderr) = child.stderr.take() {
1976 tokio::spawn(async move {
1977 let mut reader = tokio::io::BufReader::new(stderr);
1978 let mut line = String::new();
1979 loop {
1980 line.clear();
1981 match reader.read_line(&mut line).await {
1982 Ok(0) | Err(_) => break,
1983 Ok(_) => {
1984 let trimmed = line.trim();
1985 if !trimmed.is_empty() {
1986 tracing::debug!(
1987 target: "mcp_stdio_stderr",
1988 "{}",
1989 trimmed
1990 );
1991 }
1992 }
1993 }
1994 }
1995 });
1996 }
1997
1998 let transport = Arc::new(Transport::Stdio(Box::new(StdioTransport {
1999 io: tokio::sync::Mutex::new(StdioIo {
2000 stdin,
2001 reader: tokio::io::BufReader::new(stdout),
2002 _process: child,
2003 }),
2004 next_id: AtomicU64::new(0),
2005 })));
2006
2007 Self::handshake_and_discover(transport).await
2008 }
2009
2010 async fn connect_http(endpoint: &str, auth_header: Option<String>) -> Result<Self, Error> {
2011 let safe = crate::http::SafeUrl::parse(endpoint, crate::http::IpPolicy::default()).await?;
2019
2020 let client = reqwest::Client::builder()
2021 .timeout(REQUEST_TIMEOUT)
2022 .redirect(reqwest::redirect::Policy::none())
2026 .build()?;
2027
2028 let transport = Arc::new(Transport::Http(HttpTransport {
2029 client,
2030 endpoint: safe.as_str().to_string(),
2031 session_id: RwLock::new(None),
2032 next_id: AtomicU64::new(0),
2033 auth_header,
2034 }));
2035
2036 Self::handshake_and_discover(transport).await
2037 }
2038
2039 async fn handshake_and_discover(transport: Arc<Transport>) -> Result<Self, Error> {
2041 let init_result = transport
2053 .rpc(
2054 "initialize",
2055 Some(serde_json::json!({
2056 "protocolVersion": PROTOCOL_VERSION,
2057 "capabilities": {
2058 "roots": { "listChanged": true }
2059 },
2060 "clientInfo": {
2061 "name": "heartbit",
2062 "version": env!("CARGO_PKG_VERSION")
2063 }
2064 })),
2065 )
2066 .await?;
2067
2068 let init: InitializeResult = serde_json::from_value(init_result).unwrap_or_default();
2069
2070 transport.notify("notifications/initialized", None).await?;
2071
2072 let mut all_tools = Vec::new();
2074 let mut cursor: Option<String> = None;
2075 loop {
2076 let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
2077 let tools_result = transport.rpc("tools/list", params).await?;
2078 let page: McpToolsListResult = serde_json::from_value(tools_result)?;
2079 all_tools.extend(page.tools);
2080 cursor = page.next_cursor;
2081 if cursor.is_none() {
2082 break;
2083 }
2084 }
2085
2086 let mut all_resources = Vec::new();
2088 if init.capabilities.resources.is_some() {
2089 let mut cursor: Option<String> = None;
2090 loop {
2091 let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
2092 match transport.rpc("resources/list", params).await {
2093 Ok(value) => {
2094 let page: McpResourcesListResult = serde_json::from_value(value)?;
2095 all_resources.extend(page.resources);
2096 cursor = page.next_cursor;
2097 if cursor.is_none() {
2098 break;
2099 }
2100 }
2101 Err(e) => {
2102 tracing::warn!(error = %e, "resources/list failed, skipping resource discovery");
2103 break;
2104 }
2105 }
2106 }
2107 }
2108
2109 let mut all_prompts = Vec::new();
2111 if init.capabilities.prompts.is_some() {
2112 let mut cursor: Option<String> = None;
2113 loop {
2114 let params = cursor.as_ref().map(|c| serde_json::json!({"cursor": c}));
2115 match transport.rpc("prompts/list", params).await {
2116 Ok(value) => {
2117 let page: McpPromptsListResult = serde_json::from_value(value)?;
2118 all_prompts.extend(page.prompts);
2119 cursor = page.next_cursor;
2120 if cursor.is_none() {
2121 break;
2122 }
2123 }
2124 Err(e) => {
2125 tracing::warn!(error = %e, "prompts/list failed, skipping prompt discovery");
2126 break;
2127 }
2128 }
2129 }
2130 }
2131
2132 Ok(Self {
2133 transport,
2134 tools: all_tools,
2135 resources: all_resources,
2136 prompts: all_prompts,
2137 capabilities: init.capabilities,
2138 sampling_handler: None,
2139 roots: Vec::new(),
2140 })
2141 }
2142
2143 pub fn tool_definitions(&self) -> Vec<ToolDefinition> {
2148 self.tools.iter().map(mcp_tool_to_definition).collect()
2149 }
2150
2151 pub fn resource_definitions(&self) -> &[McpResourceDef] {
2153 &self.resources
2154 }
2155
2156 pub fn prompt_definitions(&self) -> &[McpPromptDef] {
2158 &self.prompts
2159 }
2160
2161 pub fn supports_resource_subscribe(&self) -> bool {
2163 self.capabilities
2164 .resources
2165 .as_ref()
2166 .is_some_and(|r| r.subscribe)
2167 }
2168
2169 pub async fn resource_read(&self, uri: &str) -> Result<Vec<McpResourceContent>, Error> {
2171 let params = serde_json::json!({ "uri": uri });
2172 let value = self.transport.rpc("resources/read", Some(params)).await?;
2173 let result: McpResourceReadResult = serde_json::from_value(value)?;
2174 Ok(result.contents)
2175 }
2176
2177 pub async fn set_log_level(&self, level: &str) -> Result<(), Error> {
2179 let params = serde_json::json!({ "level": level });
2180 self.transport.rpc("logging/setLevel", Some(params)).await?;
2181 Ok(())
2182 }
2183
2184 pub async fn resource_subscribe(&self, uri: &str) -> Result<(), Error> {
2186 let params = serde_json::json!({ "uri": uri });
2187 self.transport
2188 .rpc("resources/subscribe", Some(params))
2189 .await?;
2190 Ok(())
2191 }
2192
2193 pub async fn prompt_get(
2195 &self,
2196 name: &str,
2197 arguments: Option<Value>,
2198 ) -> Result<Vec<McpPromptMessage>, Error> {
2199 let mut params = serde_json::json!({ "name": name });
2200 if let Some(args) = arguments {
2201 params["arguments"] = args;
2202 }
2203 let value = self.transport.rpc("prompts/get", Some(params)).await?;
2204 let result: McpPromptGetResult = serde_json::from_value(value)?;
2205 Ok(result.messages)
2206 }
2207
2208 pub fn into_tools(self) -> Vec<Arc<dyn Tool>> {
2210 self.stamp_tools(None)
2211 }
2212
2213 pub fn into_tools_with_auth(self, resolver: Arc<dyn AuthResolver>) -> Vec<Arc<dyn Tool>> {
2218 self.stamp_tools(Some(resolver))
2219 }
2220
2221 fn stamp_tools(self, resolver: Option<Arc<dyn AuthResolver>>) -> Vec<Arc<dyn Tool>> {
2222 let transport = self.transport;
2223 self.tools
2224 .into_iter()
2225 .map(|t| {
2226 let tool: Arc<dyn Tool> = Arc::new(McpTool {
2227 transport: Arc::clone(&transport),
2228 def: mcp_tool_to_definition(&t),
2229 auth_resolver: resolver.clone(),
2230 });
2231 tool
2232 })
2233 .collect()
2234 }
2235
2236 pub fn into_resource_tools(&self) -> Vec<Arc<dyn Tool>> {
2240 self.stamp_resource_tools(None)
2241 }
2242
2243 fn stamp_resource_tools(&self, resolver: Option<Arc<dyn AuthResolver>>) -> Vec<Arc<dyn Tool>> {
2244 self.resources
2245 .iter()
2246 .map(|r| {
2247 let tool_name = format!("mcp_resource_{}", sanitize_tool_name(&r.name));
2248 let tool: Arc<dyn Tool> = Arc::new(McpResourceTool {
2249 transport: Arc::clone(&self.transport),
2250 resource: r.clone(),
2251 tool_name,
2252 auth_resolver: resolver.clone(),
2253 });
2254 tool
2255 })
2256 .collect()
2257 }
2258
2259 pub fn into_prompt_tools(&self) -> Vec<Arc<dyn Tool>> {
2263 self.stamp_prompt_tools(None)
2264 }
2265
2266 fn stamp_prompt_tools(&self, resolver: Option<Arc<dyn AuthResolver>>) -> Vec<Arc<dyn Tool>> {
2267 self.prompts
2268 .iter()
2269 .map(|p| {
2270 let tool_name = format!("mcp_prompt_{}", sanitize_tool_name(&p.name));
2271 let tool: Arc<dyn Tool> = Arc::new(McpPromptTool {
2272 transport: Arc::clone(&self.transport),
2273 prompt: p.clone(),
2274 tool_name,
2275 auth_resolver: resolver.clone(),
2276 });
2277 tool
2278 })
2279 .collect()
2280 }
2281
2282 pub fn into_all_tools(self) -> Vec<Arc<dyn Tool>> {
2284 Self::stamp_all_tools_inner(
2285 &self.transport,
2286 &self.tools,
2287 &self.resources,
2288 &self.prompts,
2289 None,
2290 )
2291 }
2292
2293 pub fn into_all_tools_with_auth(self, resolver: Arc<dyn AuthResolver>) -> Vec<Arc<dyn Tool>> {
2295 Self::stamp_all_tools_inner(
2296 &self.transport,
2297 &self.tools,
2298 &self.resources,
2299 &self.prompts,
2300 Some(resolver),
2301 )
2302 }
2303
2304 fn stamp_all_tools_inner(
2305 transport: &Arc<Transport>,
2306 tools: &[McpToolDef],
2307 resources: &[McpResourceDef],
2308 prompts: &[McpPromptDef],
2309 resolver: Option<Arc<dyn AuthResolver>>,
2310 ) -> Vec<Arc<dyn Tool>> {
2311 let mut all: Vec<Arc<dyn Tool>> = tools
2312 .iter()
2313 .map(|t| -> Arc<dyn Tool> {
2314 Arc::new(McpTool {
2315 transport: Arc::clone(transport),
2316 def: mcp_tool_to_definition(t),
2317 auth_resolver: resolver.clone(),
2318 })
2319 })
2320 .collect();
2321 for r in resources {
2322 let tool_name = format!("mcp_resource_{}", sanitize_tool_name(&r.name));
2323 all.push(Arc::new(McpResourceTool {
2324 transport: Arc::clone(transport),
2325 resource: r.clone(),
2326 tool_name,
2327 auth_resolver: resolver.clone(),
2328 }));
2329 }
2330 for p in prompts {
2331 let tool_name = format!("mcp_prompt_{}", sanitize_tool_name(&p.name));
2332 all.push(Arc::new(McpPromptTool {
2333 transport: Arc::clone(transport),
2334 prompt: p.clone(),
2335 tool_name,
2336 auth_resolver: resolver.clone(),
2337 }));
2338 }
2339 all
2340 }
2341
2342 fn into_pool_parts(
2345 self,
2346 ) -> (
2347 Arc<Transport>,
2348 Vec<McpToolDef>,
2349 Vec<McpResourceDef>,
2350 Vec<McpPromptDef>,
2351 ) {
2352 (self.transport, self.tools, self.resources, self.prompts)
2353 }
2354}
2355
/// Cached connection and discovery results for one pooled MCP server URL.
struct PoolEntry {
    /// Connection shared by every tool built from this entry.
    transport: Arc<Transport>,
    /// Tools discovered at connect time.
    tools: Vec<McpToolDef>,
    /// Resources discovered at connect time.
    resources: Vec<McpResourceDef>,
    /// Prompts discovered at connect time.
    prompts: Vec<McpPromptDef>,
}
2365
/// Cache of HTTP MCP connections keyed by endpoint URL.
///
/// Tools for different callers can share one transport while each supplies its own
/// `AuthResolver` (see `tools_for_user`).
pub struct McpTransportPool {
    /// URL → cached entry. This is a blocking `std` lock, so its guards must not be held across
    /// `.await`.
    pool: RwLock<HashMap<String, PoolEntry>>,
}
2374
2375impl McpTransportPool {
2376 pub fn new() -> Self {
2377 Self {
2378 pool: RwLock::new(HashMap::new()),
2379 }
2380 }
2381
2382 pub async fn get_or_connect(
2387 &self,
2388 url: &str,
2389 static_auth: Option<String>,
2390 ) -> Result<Vec<ToolDefinition>, Error> {
2391 {
2393 let pool = self
2394 .pool
2395 .read()
2396 .map_err(|e| Error::Mcp(format!("transport pool lock poisoned: {e}")))?;
2397 if let Some(entry) = pool.get(url) {
2398 return Ok(entry.tools.iter().map(mcp_tool_to_definition).collect());
2399 }
2400 }
2401
2402 let client = McpClient::connect_http(url, static_auth).await?;
2404 let (transport, tools, resources, prompts) = client.into_pool_parts();
2405 let defs: Vec<ToolDefinition> = tools.iter().map(mcp_tool_to_definition).collect();
2406
2407 let entry = PoolEntry {
2408 transport,
2409 tools,
2410 resources,
2411 prompts,
2412 };
2413
2414 let mut pool = self
2415 .pool
2416 .write()
2417 .map_err(|e| Error::Mcp(format!("transport pool lock poisoned: {e}")))?;
2418 pool.insert(url.to_string(), entry);
2419
2420 Ok(defs)
2421 }
2422
2423 pub fn tools_for_user(
2427 &self,
2428 url: &str,
2429 resolver: Arc<dyn AuthResolver>,
2430 ) -> Result<Option<Vec<Arc<dyn Tool>>>, Error> {
2431 let pool = self
2432 .pool
2433 .read()
2434 .map_err(|e| Error::Mcp(format!("transport pool lock poisoned: {e}")))?;
2435 let entry = match pool.get(url) {
2436 Some(e) => e,
2437 None => return Ok(None),
2438 };
2439
2440 let resolver = Some(resolver);
2441 let mut all: Vec<Arc<dyn Tool>> = entry
2442 .tools
2443 .iter()
2444 .map(|t| -> Arc<dyn Tool> {
2445 Arc::new(McpTool {
2446 transport: Arc::clone(&entry.transport),
2447 def: mcp_tool_to_definition(t),
2448 auth_resolver: resolver.clone(),
2449 })
2450 })
2451 .collect();
2452 for r in &entry.resources {
2453 let tool_name = format!("mcp_resource_{}", sanitize_tool_name(&r.name));
2454 all.push(Arc::new(McpResourceTool {
2455 transport: Arc::clone(&entry.transport),
2456 resource: r.clone(),
2457 tool_name,
2458 auth_resolver: resolver.clone(),
2459 }));
2460 }
2461 for p in &entry.prompts {
2462 let tool_name = format!("mcp_prompt_{}", sanitize_tool_name(&p.name));
2463 all.push(Arc::new(McpPromptTool {
2464 transport: Arc::clone(&entry.transport),
2465 prompt: p.clone(),
2466 tool_name,
2467 auth_resolver: resolver.clone(),
2468 }));
2469 }
2470 Ok(Some(all))
2471 }
2472
2473 pub fn contains(&self, url: &str) -> bool {
2475 self.pool
2476 .read()
2477 .map(|p| p.contains_key(url))
2478 .unwrap_or(false)
2479 }
2480}
2481
2482impl Default for McpTransportPool {
2483 fn default() -> Self {
2484 Self::new()
2485 }
2486}
2487
2488#[cfg(test)]
2489mod tests {
2490 use super::*;
2491 use serde_json::json;
2492
2493 #[test]
2496 fn jsonrpc_request_serialization() {
2497 let req = JsonRpcRequest {
2498 jsonrpc: "2.0",
2499 method: "tools/list".to_string(),
2500 params: Some(json!({"cursor": null})),
2501 id: 42,
2502 };
2503 let json = serde_json::to_value(&req).unwrap();
2504 assert_eq!(json["jsonrpc"], "2.0");
2505 assert_eq!(json["method"], "tools/list");
2506 assert_eq!(json["id"], 42);
2507 assert!(json.get("params").is_some());
2508 }
2509
2510 #[test]
2511 fn jsonrpc_request_null_params_omitted() {
2512 let req = JsonRpcRequest {
2513 jsonrpc: "2.0",
2514 method: "tools/list".to_string(),
2515 params: None,
2516 id: 1,
2517 };
2518 let json = serde_json::to_value(&req).unwrap();
2519 assert!(json.get("params").is_none());
2520 }
2521
2522 #[test]
2523 fn jsonrpc_notification_has_no_id() {
2524 let notif = JsonRpcNotification {
2525 jsonrpc: "2.0",
2526 method: "notifications/initialized".to_string(),
2527 params: None,
2528 };
2529 let json = serde_json::to_value(¬if).unwrap();
2530 assert_eq!(json["jsonrpc"], "2.0");
2531 assert_eq!(json["method"], "notifications/initialized");
2532 assert!(json.get("id").is_none());
2533 assert!(json.get("params").is_none());
2534 }
2535
2536 #[test]
2537 fn jsonrpc_response_parses_result() {
2538 let json_str = r#"{"jsonrpc":"2.0","result":{"tools":[]},"id":1}"#;
2539 let response: JsonRpcResponse = serde_json::from_str(json_str).unwrap();
2540 assert!(response.result.is_some());
2541 assert!(response.error.is_none());
2542 assert_eq!(response.result.unwrap(), json!({"tools": []}));
2543 }
2544
2545 #[test]
2546 fn jsonrpc_response_parses_error() {
2547 let json_str =
2548 r#"{"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":1}"#;
2549 let response: JsonRpcResponse = serde_json::from_str(json_str).unwrap();
2550 assert!(response.result.is_none());
2551 let err = response.error.unwrap();
2552 assert_eq!(err.code, -32601);
2553 assert_eq!(err.message, "Method not found");
2554 }
2555
2556 #[test]
2559 fn sse_basic_extraction() {
2560 let body = "event: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{},\"id\":1}\n\n";
2561 let events = extract_sse_events(body).unwrap();
2562 assert_eq!(events.len(), 1);
2563 assert_eq!(events[0], r#"{"jsonrpc":"2.0","result":{},"id":1}"#);
2564 }
2565
2566 #[test]
2567 fn sse_no_data_field_errors() {
2568 let body = "event: message\n\n";
2569 let err = extract_sse_events(body).unwrap_err();
2570 assert!(matches!(err, Error::Mcp(_)));
2571 assert!(err.to_string().contains("No data field"));
2572 }
2573
2574 #[test]
2575 fn sse_no_space_after_colon() {
2576 let body = "data:{\"result\":\"ok\"}\n";
2577 let events = extract_sse_events(body).unwrap();
2578 assert_eq!(events.len(), 1);
2579 assert_eq!(events[0], r#"{"result":"ok"}"#);
2580 }
2581
2582 #[test]
2583 fn sse_multiple_events_extracted() {
2584 let body =
2585 "event: message\ndata: {\"first\": true}\n\nevent: message\ndata: {\"last\": true}\n\n";
2586 let events = extract_sse_events(body).unwrap();
2587 assert_eq!(events.len(), 2);
2588 assert_eq!(events[0], r#"{"first": true}"#);
2589 assert_eq!(events[1], r#"{"last": true}"#);
2590 }
2591
2592 #[test]
2593 fn sse_multi_line_data_concatenated() {
2594 let body = "data: first line\ndata: second line\n\n";
2595 let events = extract_sse_events(body).unwrap();
2596 assert_eq!(events.len(), 1);
2597 assert_eq!(events[0], "first line\nsecond line");
2598 }
2599
2600 #[test]
2603 fn find_response_matches_by_id() {
2604 let events = vec![
2605 r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{}}"#.to_string(),
2606 r#"{"jsonrpc":"2.0","result":{"tools":[]},"id":5}"#.to_string(),
2607 ];
2608 let result = find_rpc_response(&events, 5).unwrap();
2609 assert!(result.contains(r#""id":5"#));
2610 assert!(result.contains(r#""result""#));
2611 }
2612
2613 #[test]
2618 fn find_response_rejects_mismatched_id() {
2619 let events = vec![r#"{"jsonrpc":"2.0","result":{},"id":99}"#.to_string()];
2620 let err = find_rpc_response(&events, 1).unwrap_err();
2621 assert!(matches!(err, Error::Mcp(_)));
2622 }
2623
2624 #[test]
2627 fn find_response_accepts_null_id_error_only() {
2628 let events = vec![
2629 r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"parse"},"id":null}"#.to_string(),
2630 ];
2631 let result = find_rpc_response(&events, 1).unwrap();
2632 assert!(result.contains("error"));
2633 }
2634
2635 #[test]
2638 fn mcp_tools_list_parsing() {
2639 let json = json!({
2640 "tools": [
2641 {
2642 "name": "read_file",
2643 "description": "Read a file from disk",
2644 "inputSchema": {
2645 "type": "object",
2646 "properties": {
2647 "path": {"type": "string"}
2648 },
2649 "required": ["path"]
2650 }
2651 },
2652 {
2653 "name": "list_dir",
2654 "description": "List directory contents",
2655 "inputSchema": {"type": "object"}
2656 }
2657 ]
2658 });
2659
2660 let result: McpToolsListResult = serde_json::from_value(json).unwrap();
2661 assert_eq!(result.tools.len(), 2);
2662 assert_eq!(result.tools[0].name, "read_file");
2663 assert_eq!(
2664 result.tools[0].description.as_deref(),
2665 Some("Read a file from disk")
2666 );
2667 assert!(result.tools[0].input_schema.is_some());
2668 assert_eq!(result.tools[1].name, "list_dir");
2669 }
2670
2671 #[test]
2672 fn mcp_tool_to_definition_mapping() {
2673 let mcp_def = McpToolDef {
2674 name: "search".into(),
2675 description: Some("Search for files".into()),
2676 input_schema: Some(json!({
2677 "type": "object",
2678 "properties": {"query": {"type": "string"}}
2679 })),
2680 };
2681
2682 let def = mcp_tool_to_definition(&mcp_def);
2683 assert_eq!(def.name, "search");
2684 assert_eq!(def.description, "Search for files");
2685 assert_eq!(
2686 def.input_schema,
2687 json!({"type": "object", "properties": {"query": {"type": "string"}}})
2688 );
2689 }
2690
2691 #[test]
2692 fn mcp_tool_defaults_for_missing_fields() {
2693 let json = json!({"name": "minimal"});
2694 let mcp_def: McpToolDef = serde_json::from_value(json).unwrap();
2695 assert!(mcp_def.description.is_none());
2696 assert!(mcp_def.input_schema.is_none());
2697
2698 let def = mcp_tool_to_definition(&mcp_def);
2699 assert_eq!(def.name, "minimal");
2700 assert_eq!(def.description, "");
2701 assert_eq!(def.input_schema, json!({"type": "object"}));
2702 }
2703
2704 #[test]
2707 fn tool_result_success() {
2708 let result = McpCallToolResult {
2709 content: vec![McpContent {
2710 content_type: "text".into(),
2711 text: Some("file contents here".into()),
2712 }],
2713 is_error: false,
2714 };
2715
2716 let output = mcp_result_to_tool_output(result);
2717 assert_eq!(output.content, "file contents here");
2718 assert!(!output.is_error);
2719 }
2720
/// `is_error: true` on the MCP result propagates to the tool output,
/// and the text content is kept as the output content.
#[test]
fn tool_result_error() {
    let block = McpContent {
        content_type: "text".into(),
        text: Some("permission denied".into()),
    };
    let output = mcp_result_to_tool_output(McpCallToolResult {
        content: vec![block],
        is_error: true,
    });

    assert_eq!(output.content, "permission denied");
    assert!(output.is_error);
}
2735
/// Several text blocks are concatenated in order, separated by `\n`.
#[test]
fn tool_result_multi_text_joined() {
    let content: Vec<McpContent> = ["line one", "line two", "line three"]
        .iter()
        .map(|&line| McpContent {
            content_type: "text".into(),
            text: Some(line.into()),
        })
        .collect();
    let result = McpCallToolResult {
        content,
        is_error: false,
    };

    let output = mcp_result_to_tool_output(result);
    assert_eq!(output.content, "line one\nline two\nline three");
}
2759
/// Non-text blocks placed between text blocks are dropped.
/// The surviving text is joined with `\n`.
#[test]
fn tool_result_images_skipped() {
    let text_block = |body: &str| McpContent {
        content_type: "text".into(),
        text: Some(body.into()),
    };
    let image_block = McpContent {
        content_type: "image".into(),
        text: None,
    };
    let result = McpCallToolResult {
        content: vec![text_block("caption"), image_block, text_block("more text")],
        is_error: false,
    };

    let output = mcp_result_to_tool_output(result);
    assert_eq!(output.content, "caption\nmore text");
}
2783
/// A wire-format `tools/call` result (camelCase `isError`) deserializes into
/// `McpCallToolResult`.
#[test]
fn tool_result_parses_from_json() {
    let payload = json!({
        "content": [{"type": "text", "text": "hello from mcp"}],
        "isError": false
    });

    let parsed: McpCallToolResult = serde_json::from_value(payload).unwrap();
    assert_eq!(parsed.content.len(), 1);
    let first = &parsed.content[0];
    assert_eq!(first.text.as_deref(), Some("hello from mcp"));
    assert!(!parsed.is_error);
}
2798
/// When `isError` is absent from the payload, it deserializes as `false`.
#[test]
fn tool_result_is_error_defaults_false() {
    let payload = json!({"content": [{"type": "text", "text": "ok"}]});
    let parsed: McpCallToolResult = serde_json::from_value(payload).unwrap();
    assert!(!parsed.is_error);
}
2810
/// When no block carries text, the output contains a placeholder that counts
/// the non-text blocks, and the output is not flagged as an error.
#[test]
fn tool_result_non_text_only_shows_placeholder() {
    let content: Vec<McpContent> = ["image", "resource"]
        .iter()
        .map(|&kind| McpContent {
            content_type: kind.into(),
            text: None,
        })
        .collect();
    let output = mcp_result_to_tool_output(McpCallToolResult {
        content,
        is_error: false,
    });

    assert!(output.content.contains("2 non-text content block(s)"));
    assert!(!output.is_error);
}
2831
/// When at least one text block exists, only the text is returned and no
/// placeholder is emitted for the non-text block.
#[test]
fn tool_result_mixed_text_and_non_text_returns_text() {
    let text = McpContent {
        content_type: "text".into(),
        text: Some("real text".into()),
    };
    let image = McpContent {
        content_type: "image".into(),
        text: None,
    };

    let output = mcp_result_to_tool_output(McpCallToolResult {
        content: vec![text, image],
        is_error: false,
    });
    assert_eq!(output.content, "real text");
}
2852
/// A response carrying `result` yields that `result` value unchanged.
#[test]
fn process_rpc_response_success() {
    let raw = r#"{"jsonrpc":"2.0","result":{"tools":[]},"id":1}"#;
    let expected = json!({"tools": []});
    assert_eq!(process_rpc_response(raw).unwrap(), expected);
}
2861
/// A JSON-RPC `error` becomes an `Err` whose message contains the
/// `[mcp_server_error` tag, the numeric code, and the server's message.
#[test]
fn process_rpc_response_error_is_tagged() {
    let raw =
        r#"{"jsonrpc":"2.0","error":{"code":-32601,"message":"Method not found"},"id":1}"#;
    let rendered = process_rpc_response(raw).unwrap_err().to_string();

    for needle in ["[mcp_server_error", "code=-32601", "Method not found"] {
        assert!(rendered.contains(needle), "missing {needle:?}: {rendered}");
    }
}
2874
/// An 8 KiB server error message is truncated: the rendered error has a
/// truncation marker and stays under 2048 bytes.
#[test]
fn process_rpc_response_error_truncates_long_message() {
    let oversized = "X".repeat(8 * 1024);
    let raw = format!(
        r#"{{"jsonrpc":"2.0","error":{{"code":-32000,"message":"{oversized}"}},"id":1}}"#
    );

    let rendered = process_rpc_response(&raw).unwrap_err().to_string();
    assert!(
        rendered.contains("…[truncated]"),
        "missing truncation marker: {rendered}"
    );
    let byte_len = rendered.len();
    assert!(byte_len < 2048, "error message not bounded: {byte_len} bytes");
}
2892
/// A response with neither `result` nor `error` is rejected with an
/// explanatory message.
#[test]
fn process_rpc_response_missing_both() {
    let message = process_rpc_response(r#"{"jsonrpc":"2.0","id":1}"#)
        .unwrap_err()
        .to_string();
    assert!(message.contains("missing both result and error"));
}
2899
2900 #[tokio::test]
2903 async fn read_stdio_response_finds_matching_id() {
2904 let (mut tx, rx) = tokio::io::duplex(4096);
2905 let mut reader = tokio::io::BufReader::new(rx);
2906
2907 tokio::spawn(async move {
2908 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"result\":{\"ok\":true},\"id\":1}\n")
2909 .await
2910 .unwrap();
2911 });
2912
2913 let response = read_stdio_response(&mut reader, 1).await.unwrap();
2914 assert!(response.contains("\"id\":1"));
2915 assert!(response.contains("\"ok\":true"));
2916 }
2917
2918 #[tokio::test]
2919 async fn read_stdio_response_skips_notifications() {
2920 let (mut tx, rx) = tokio::io::duplex(4096);
2921 let mut reader = tokio::io::BufReader::new(rx);
2922
2923 tokio::spawn(async move {
2924 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/progress\"}\n")
2926 .await
2927 .unwrap();
2928 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"result\":{\"tools\":[]},\"id\":1}\n")
2929 .await
2930 .unwrap();
2931 });
2932
2933 let response = read_stdio_response(&mut reader, 1).await.unwrap();
2934 assert!(response.contains("\"id\":1"));
2935 assert!(response.contains("\"tools\""));
2936 }
2937
2938 #[tokio::test]
2939 async fn read_stdio_response_skips_null_id() {
2940 let (mut tx, rx) = tokio::io::duplex(4096);
2941 let mut reader = tokio::io::BufReader::new(rx);
2942
2943 tokio::spawn(async move {
2944 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":null}\n")
2946 .await
2947 .unwrap();
2948 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"result\":{\"found\":true},\"id\":2}\n")
2949 .await
2950 .unwrap();
2951 });
2952
2953 let response = read_stdio_response(&mut reader, 2).await.unwrap();
2954 assert!(response.contains("\"id\":2"));
2955 assert!(response.contains("\"found\":true"));
2956 }
2957
2958 #[tokio::test]
2959 async fn read_stdio_response_skips_non_json() {
2960 let (mut tx, rx) = tokio::io::duplex(4096);
2961 let mut reader = tokio::io::BufReader::new(rx);
2962
2963 tokio::spawn(async move {
2964 tx.write_all(b"[DEBUG] initializing server...\n")
2966 .await
2967 .unwrap();
2968 tx.write_all(b"\n").await.unwrap(); tx.write_all(b"{\"jsonrpc\":\"2.0\",\"result\":{},\"id\":0}\n")
2970 .await
2971 .unwrap();
2972 });
2973
2974 let response = read_stdio_response(&mut reader, 0).await.unwrap();
2975 assert!(response.contains("\"id\":0"));
2976 }
2977
2978 #[tokio::test]
2979 async fn read_stdio_response_eof_errors() {
2980 let (tx, rx) = tokio::io::duplex(4096);
2981 let mut reader = tokio::io::BufReader::new(rx);
2982
2983 drop(tx);
2985
2986 let err = read_stdio_response(&mut reader, 0).await.unwrap_err();
2987 assert!(
2988 err.to_string().contains("closed unexpectedly"),
2989 "error: {err}"
2990 );
2991 }
2992
2993 #[tokio::test]
2994 async fn read_stdio_response_skips_wrong_id() {
2995 let (mut tx, rx) = tokio::io::duplex(4096);
2996 let mut reader = tokio::io::BufReader::new(rx);
2997
2998 tokio::spawn(async move {
2999 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"result\":{\"wrong\":true},\"id\":99}\n")
3001 .await
3002 .unwrap();
3003 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"result\":{\"right\":true},\"id\":3}\n")
3004 .await
3005 .unwrap();
3006 });
3007
3008 let response = read_stdio_response(&mut reader, 3).await.unwrap();
3009 assert!(response.contains("\"right\":true"));
3010 }
3011
3012 #[tokio::test]
3013 async fn read_stdio_response_timeout_prevents_hang() {
3014 let (_tx, rx) = tokio::io::duplex(4096);
3016 let mut reader = tokio::io::BufReader::new(rx);
3017
3018 let result = tokio::time::timeout(
3019 Duration::from_millis(50),
3020 read_stdio_response(&mut reader, 0),
3021 )
3022 .await;
3023
3024 assert!(result.is_err(), "should have timed out");
3025 }
3026
/// Starting from zero, `HttpTransport::next_id` hands out consecutive ids.
#[test]
fn http_transport_next_id_is_monotonic() {
    let transport = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };

    for expected in 0..3 {
        assert_eq!(transport.next_id(), expected);
    }
}
3043
/// `McpTool::definition` returns exactly the `ToolDefinition` the tool was built with.
#[test]
fn mcp_tool_returns_correct_definition() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let expected_def = ToolDefinition {
        name: "read_file".into(),
        description: "Read a file".into(),
        input_schema: json!({
            "type": "object",
            "properties": {"path": {"type": "string"}}
        }),
    };

    let tool = McpTool {
        transport: Arc::new(Transport::Http(http)),
        def: expected_def.clone(),
        auth_resolver: None,
    };

    assert_eq!(tool.definition(), expected_def);
}
3074
3075 #[tokio::test]
3078 async fn static_auth_provider_returns_header() {
3079 let provider = StaticAuthProvider::new(Some("Bearer xyz".to_string()));
3080 let result = provider.auth_header_for("user1", "tenant1").await.unwrap();
3081 assert_eq!(result, Some("Bearer xyz".to_string()));
3082 }
3083
3084 #[tokio::test]
3085 async fn static_auth_provider_returns_none() {
3086 let provider = StaticAuthProvider::new(None);
3087 let result = provider.auth_header_for("user1", "tenant1").await.unwrap();
3088 assert_eq!(result, None);
3089 }
3090
3091 #[tokio::test]
3092 async fn static_auth_provider_ignores_user_tenant() {
3093 let provider = StaticAuthProvider::new(Some("Bearer abc".to_string()));
3094 let r1 = provider.auth_header_for("alice", "acme").await.unwrap();
3095 let r2 = provider.auth_header_for("bob", "globex").await.unwrap();
3096 assert_eq!(r1, r2);
3097 assert_eq!(r1, Some("Bearer abc".to_string()));
3098 }
3099
3100 #[tokio::test]
3101 async fn token_exchange_provider_missing_user_token() {
3102 let user_tokens = Arc::new(std::sync::RwLock::new(HashMap::<String, String>::new()));
3103 let provider = TokenExchangeAuthProvider::new(
3104 "https://idp.example.com/token",
3105 "client-id",
3106 "client-secret",
3107 "agent-token-xyz",
3108 )
3109 .with_user_tokens(user_tokens);
3110
3111 let result = provider.auth_header_for("unknown-user", "tenant1").await;
3112 assert!(result.is_err());
3113 let err_msg = result.unwrap_err().to_string();
3114 assert!(
3115 err_msg.contains("unknown-user"),
3116 "error should mention the user_id: {err_msg}"
3117 );
3118 }
3119
3120 #[tokio::test]
3121 async fn mcp_tool_execute_catches_network_errors() {
3122 let transport = Arc::new(Transport::Http(HttpTransport {
3123 client: reqwest::Client::new(),
3124 endpoint: "http://127.0.0.1:1".to_string(), session_id: RwLock::new(None),
3126 next_id: AtomicU64::new(0),
3127 auth_header: None,
3128 }));
3129
3130 let tool = McpTool {
3131 transport,
3132 def: ToolDefinition {
3133 name: "test_tool".into(),
3134 description: "test".into(),
3135 input_schema: json!({"type": "object"}),
3136 },
3137 auth_resolver: None,
3138 };
3139
3140 let result = tool
3143 .execute(&crate::ExecutionContext::default(), json!({}))
3144 .await
3145 .unwrap();
3146 assert!(result.is_error);
3147 assert!(!result.content.is_empty());
3148 }
3149
/// A fully populated `initialize` result exposes the resources capability
/// flags and a prompts capability.
#[test]
fn server_capabilities_parses_full() {
    let payload = json!({
        "capabilities": {
            "resources": { "subscribe": true, "listChanged": true },
            "prompts": { "listChanged": false },
            "logging": {},
            "tools": { "listChanged": true }
        },
        "serverInfo": { "name": "test-server", "version": "1.0" }
    });

    let parsed: InitializeResult = serde_json::from_value(payload).unwrap();
    let caps = parsed.capabilities;
    assert!(caps.resources.is_some());
    let resources = caps.resources.unwrap();
    assert!(resources.subscribe);
    assert!(resources.list_changed);
    assert!(caps.prompts.is_some());
}
3170
/// An empty `capabilities` object leaves the resources and prompts
/// capabilities unset.
#[test]
fn server_capabilities_parses_empty() {
    let parsed: InitializeResult =
        serde_json::from_value(json!({ "capabilities": {} })).unwrap();
    let caps = &parsed.capabilities;
    assert!(caps.resources.is_none());
    assert!(caps.prompts.is_none());
}
3180
/// A missing `capabilities` key still deserializes, with default
/// (absent) capabilities.
#[test]
fn server_capabilities_defaults_on_missing() {
    let parsed: InitializeResult = serde_json::from_value(json!({})).unwrap();
    let caps = &parsed.capabilities;
    assert!(caps.resources.is_none());
    assert!(caps.prompts.is_none());
}
3188
/// An empty `resources` capability object is present, with both flags
/// defaulting to `false`; prompts stays absent.
#[test]
fn server_capabilities_resources_only() {
    let json = json!({
        "capabilities": {
            "resources": {}
        }
    });
    let result: InitializeResult = serde_json::from_value(json).unwrap();
    assert!(result.capabilities.resources.is_some());
    let res = result.capabilities.resources.unwrap();
    assert!(!res.subscribe);
    assert!(!res.list_changed);
    assert!(result.capabilities.prompts.is_none());
}
3203
/// `McpResourceDef` serializes with the expected keys and deserializes back
/// with its fields intact.
#[test]
fn resource_def_serde_roundtrip() {
    let original = McpResourceDef {
        uri: "file:///README.md".into(),
        name: "README".into(),
        description: Some("Project readme".into()),
        mime_type: Some("text/markdown".into()),
    };

    let encoded = serde_json::to_value(&original).unwrap();
    assert_eq!(encoded["uri"], "file:///README.md");
    assert_eq!(encoded["name"], "README");

    let decoded: McpResourceDef = serde_json::from_value(encoded).unwrap();
    assert_eq!(decoded.uri, "file:///README.md");
    assert_eq!(decoded.mime_type.as_deref(), Some("text/markdown"));
}
3221
/// A resource with only `uri` and `name` deserializes, leaving the
/// optional fields `None`.
#[test]
fn resource_def_minimal() {
    let parsed: McpResourceDef =
        serde_json::from_value(json!({"uri": "test://x", "name": "x"})).unwrap();
    assert_eq!(parsed.uri, "test://x");
    assert!(parsed.description.is_none());
    assert!(parsed.mime_type.is_none());
}
3230
/// `resources/list` parses a mix of full and minimal entries.
/// `nextCursor` is `None` when absent.
#[test]
fn resources_list_result_parsing() {
    let payload = json!({
        "resources": [
            {
                "uri": "file:///config.toml",
                "name": "config",
                "description": "App configuration",
                "mimeType": "application/toml"
            },
            {
                "uri": "db://users/schema",
                "name": "users_schema"
            }
        ]
    });

    let listing: McpResourcesListResult = serde_json::from_value(payload).unwrap();
    assert_eq!(listing.resources.len(), 2);

    let config = &listing.resources[0];
    assert_eq!(config.uri, "file:///config.toml");
    assert_eq!(config.name, "config");
    assert_eq!(config.mime_type.as_deref(), Some("application/toml"));

    assert_eq!(listing.resources[1].name, "users_schema");
    assert!(listing.next_cursor.is_none());
}
3258
/// The camelCase `nextCursor` pagination token is captured.
#[test]
fn resources_list_with_cursor() {
    let payload = json!({
        "resources": [{"uri": "a://1", "name": "one"}],
        "nextCursor": "page2"
    });
    let listing: McpResourcesListResult = serde_json::from_value(payload).unwrap();
    assert_eq!(listing.resources.len(), 1);
    assert_eq!(listing.next_cursor.as_deref(), Some("page2"));
}
3269
/// A text resource content entry parses `uri`, `mimeType` and `text`,
/// leaving `blob` unset.
#[test]
fn resource_content_parsing() {
    let payload = json!({
        "uri": "file:///README.md",
        "mimeType": "text/markdown",
        "text": "# Hello World"
    });

    let parsed: McpResourceContent = serde_json::from_value(payload).unwrap();
    assert_eq!(parsed.uri, "file:///README.md");
    assert_eq!(parsed.mime_type.as_deref(), Some("text/markdown"));
    assert_eq!(parsed.text.as_deref(), Some("# Hello World"));
    assert!(parsed.blob.is_none());
}
3283
/// A `resources/read` result with several content entries keeps them in order.
#[test]
fn resource_read_result_parsing() {
    let payload = json!({
        "contents": [
            {"uri": "file:///a.txt", "text": "content A"},
            {"uri": "file:///b.txt", "text": "content B"}
        ]
    });

    let read: McpResourceReadResult = serde_json::from_value(payload).unwrap();
    assert_eq!(read.contents.len(), 2);
    let first_text = read.contents[0].text.as_deref();
    assert_eq!(first_text, Some("content A"));
}
3296
/// `McpPromptDef` with one required argument survives a serialize/deserialize
/// round trip.
#[test]
fn prompt_def_serde_roundtrip() {
    let text_arg = McpPromptArgument {
        name: "text".into(),
        description: Some("Text to summarize".into()),
        required: true,
    };
    let original = McpPromptDef {
        name: "summarize".into(),
        description: Some("Summarize text".into()),
        arguments: vec![text_arg],
    };

    let encoded = serde_json::to_value(&original).unwrap();
    assert_eq!(encoded["name"], "summarize");

    let decoded: McpPromptDef = serde_json::from_value(encoded).unwrap();
    assert_eq!(decoded.arguments.len(), 1);
    assert!(decoded.arguments[0].required);
}
3316
/// A prompt with only `name` deserializes with no description and an
/// empty argument list.
#[test]
fn prompt_def_minimal() {
    let parsed: McpPromptDef = serde_json::from_value(json!({"name": "greet"})).unwrap();
    assert_eq!(parsed.name, "greet");
    assert!(parsed.description.is_none());
    assert!(parsed.arguments.is_empty());
}
3325
/// `prompts/list` parses prompt entries with their arguments' `required` flags.
#[test]
fn prompts_list_result_parsing() {
    let payload = json!({
        "prompts": [
            {
                "name": "code_review",
                "description": "Review code for issues",
                "arguments": [
                    {"name": "code", "description": "Code to review", "required": true},
                    {"name": "language", "description": "Programming language", "required": false}
                ]
            }
        ]
    });

    let listing: McpPromptsListResult = serde_json::from_value(payload).unwrap();
    assert_eq!(listing.prompts.len(), 1);

    let review = &listing.prompts[0];
    assert_eq!(review.name, "code_review");
    assert_eq!(review.arguments.len(), 2);
    assert!(review.arguments[0].required);
    assert!(!review.arguments[1].required);
}
3347
/// A `prompts/get` result parses its messages, including each role and the
/// nested text content.
#[test]
fn prompt_get_result_parsing() {
    let payload = json!({
        "description": "A helpful prompt",
        "messages": [
            {
                "role": "user",
                "content": {"type": "text", "text": "Please help me with this code"}
            },
            {
                "role": "assistant",
                "content": {"type": "text", "text": "I'd be happy to help!"}
            }
        ]
    });

    let fetched: McpPromptGetResult = serde_json::from_value(payload).unwrap();
    assert_eq!(fetched.messages.len(), 2);

    let opening = &fetched.messages[0];
    assert_eq!(opening.role, "user");
    assert_eq!(
        opening.content.text.as_deref(),
        Some("Please help me with this code")
    );
    assert_eq!(fetched.messages[1].role, "assistant");
}
3372
/// Names made only of ASCII letters, digits and `_` pass through unchanged.
#[test]
fn sanitize_tool_name_alphanumeric() {
    for name in ["hello_world", "test123"] {
        assert_eq!(sanitize_tool_name(name), name);
    }
}
3380
/// Hyphens, slashes, dots and spaces are each replaced with `_`.
#[test]
fn sanitize_tool_name_special_chars() {
    let cases = [
        ("my-resource", "my_resource"),
        ("path/to/thing", "path_to_thing"),
        ("file.txt", "file_txt"),
        ("a b c", "a_b_c"),
    ];
    for (input, expected) in cases {
        assert_eq!(sanitize_tool_name(input), expected);
    }
}
3388
/// A resource tool exposes its configured name and the resource's
/// description, with an empty-object input schema.
#[test]
fn resource_tool_definition() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let readme = McpResourceDef {
        uri: "file:///README.md".into(),
        name: "readme".into(),
        description: Some("Project readme".into()),
        mime_type: None,
    };
    let tool = McpResourceTool {
        transport: Arc::new(Transport::Http(http)),
        resource: readme,
        tool_name: "mcp_resource_readme".into(),
        auth_resolver: None,
    };

    let def = tool.definition();
    assert_eq!(def.name, "mcp_resource_readme");
    assert_eq!(def.description, "Project readme");
    assert_eq!(def.input_schema, json!({"type": "object", "properties": {}}));
}
3421
/// If the resource has no description, the generated tool description
/// mentions the resource URI.
#[test]
fn resource_tool_definition_default_description() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let users = McpResourceDef {
        uri: "db://users".into(),
        name: "users".into(),
        description: None,
        mime_type: None,
    };
    let tool = McpResourceTool {
        transport: Arc::new(Transport::Http(http)),
        resource: users,
        tool_name: "mcp_resource_users".into(),
        auth_resolver: None,
    };

    assert!(tool.definition().description.contains("db://users"));
}
3447
/// Prompt arguments become schema properties carrying their descriptions.
/// Only arguments marked `required` appear in the schema's `required` list.
#[test]
fn prompt_tool_definition_with_args() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let code_arg = McpPromptArgument {
        name: "code".into(),
        description: Some("Code to review".into()),
        required: true,
    };
    let language_arg = McpPromptArgument {
        name: "language".into(),
        description: None,
        required: false,
    };
    let tool = McpPromptTool {
        transport: Arc::new(Transport::Http(http)),
        prompt: McpPromptDef {
            name: "review".into(),
            description: Some("Code review".into()),
            arguments: vec![code_arg, language_arg],
        },
        tool_name: "mcp_prompt_review".into(),
        auth_resolver: None,
    };

    let def = tool.definition();
    assert_eq!(def.name, "mcp_prompt_review");
    assert_eq!(def.description, "Code review");

    let schema = &def.input_schema;
    let code_prop = &schema["properties"]["code"];
    assert!(code_prop.is_object());
    assert_eq!(code_prop["description"], "Code to review");

    let required = &schema["required"];
    assert_eq!(*required, json!(["code"]));
    // Implied by the equality above; kept to state the intent that optional
    // arguments never appear in `required`.
    let required_list = required.as_array().unwrap();
    assert!(!required_list.contains(&json!("language")));
}
3500
/// A prompt with no arguments and no description gets a description that
/// mentions the prompt name. Its schema has no `required` key.
#[test]
fn prompt_tool_definition_no_args() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let greet = McpPromptDef {
        name: "greet".into(),
        description: None,
        arguments: vec![],
    };
    let tool = McpPromptTool {
        transport: Arc::new(Transport::Http(http)),
        prompt: greet,
        tool_name: "mcp_prompt_greet".into(),
        auth_resolver: None,
    };

    let def = tool.definition();
    assert_eq!(def.name, "mcp_prompt_greet");
    assert!(def.description.contains("greet"));
    assert!(def.input_schema.get("required").is_none());
}
3528
/// Each resource becomes a tool named `mcp_resource_<sanitized name>`,
/// in order, and keeps the resource description.
#[test]
fn into_resource_tools_creates_correct_names() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let readme = McpResourceDef {
        uri: "file:///a.txt".into(),
        name: "readme-file".into(),
        description: None,
        mime_type: None,
    };
    let schema = McpResourceDef {
        uri: "db://schema".into(),
        name: "db schema".into(),
        description: Some("Database schema".into()),
        mime_type: None,
    };
    let client = McpClient {
        transport: Arc::new(Transport::Http(http)),
        tools: vec![],
        resources: vec![readme, schema],
        prompts: vec![],
        capabilities: ServerCapabilities::default(),
        sampling_handler: None,
        roots: Vec::new(),
    };

    let tools = client.into_resource_tools();
    assert_eq!(tools.len(), 2);
    assert_eq!(tools[0].definition().name, "mcp_resource_readme_file");
    let second = tools[1].definition();
    assert_eq!(second.name, "mcp_resource_db_schema");
    assert_eq!(second.description, "Database schema");
}
3570
/// Each prompt becomes a tool named `mcp_prompt_<sanitized name>`.
#[test]
fn into_prompt_tools_creates_correct_names() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let review = McpPromptDef {
        name: "code-review".into(),
        description: Some("Review code".into()),
        arguments: vec![],
    };
    let client = McpClient {
        transport: Arc::new(Transport::Http(http)),
        tools: vec![],
        resources: vec![],
        prompts: vec![review],
        capabilities: ServerCapabilities::default(),
        sampling_handler: None,
        roots: Vec::new(),
    };

    let tools = client.into_prompt_tools();
    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0].definition().name, "mcp_prompt_code_review");
}
3599
/// `into_all_tools` returns the server tools, resource tools and prompt
/// tools together, one per definition.
#[test]
fn into_all_tools_combines_everything() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let client = McpClient {
        transport: Arc::new(Transport::Http(http)),
        tools: vec![McpToolDef {
            name: "read_file".into(),
            description: Some("Read a file".into()),
            input_schema: Some(json!({"type": "object"})),
        }],
        resources: vec![McpResourceDef {
            uri: "file:///a.txt".into(),
            name: "readme".into(),
            description: None,
            mime_type: None,
        }],
        prompts: vec![McpPromptDef {
            name: "greet".into(),
            description: None,
            arguments: vec![],
        }],
        capabilities: ServerCapabilities::default(),
        sampling_handler: None,
        roots: Vec::new(),
    };

    let all = client.into_all_tools();
    assert_eq!(all.len(), 3);

    let names: Vec<String> = all.iter().map(|t| t.definition().name).collect();
    for expected in ["read_file", "mcp_resource_readme", "mcp_prompt_greet"] {
        assert!(names.contains(&expected.to_string()));
    }
}
3640
/// With default (empty) server capabilities, resource subscription is
/// reported as unsupported.
#[test]
fn supports_resource_subscribe_false_by_default() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let client = McpClient {
        transport: Arc::new(Transport::Http(http)),
        tools: vec![],
        resources: vec![],
        prompts: vec![],
        capabilities: ServerCapabilities::default(),
        sampling_handler: None,
        roots: Vec::new(),
    };

    assert!(!client.supports_resource_subscribe());
}
3661
/// Resource subscription is reported as supported when the server's
/// resources capability has `subscribe: true`.
#[test]
fn supports_resource_subscribe_when_advertised() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let advertised = ServerCapabilities {
        resources: Some(ResourcesCapability {
            subscribe: true,
            list_changed: false,
        }),
        ..Default::default()
    };
    let client = McpClient {
        transport: Arc::new(Transport::Http(http)),
        tools: vec![],
        resources: vec![],
        prompts: vec![],
        capabilities: advertised,
        sampling_handler: None,
        roots: Vec::new(),
    };

    assert!(client.supports_resource_subscribe());
}
3688
/// A full `sampling/createMessage` request parses messages, the system
/// prompt, the token limit and the model hints.
#[test]
fn sampling_request_parsing() {
    let payload = json!({
        "messages": [
            {
                "role": "user",
                "content": {"type": "text", "text": "What is 2+2?"}
            }
        ],
        "modelPreferences": {
            "hints": [{"name": "claude-sonnet-4-6-20250610"}]
        },
        "systemPrompt": "You are a math helper",
        "maxTokens": 100
    });

    let req: SamplingRequest = serde_json::from_value(payload).unwrap();
    assert_eq!(req.messages.len(), 1);
    let question = &req.messages[0];
    assert_eq!(question.role, "user");
    assert_eq!(question.content.text.as_deref(), Some("What is 2+2?"));
    assert_eq!(req.system_prompt.as_deref(), Some("You are a math helper"));
    assert_eq!(req.max_tokens, Some(100));

    let preferences = req.model_preferences.unwrap();
    assert_eq!(
        preferences.hints[0].name.as_deref(),
        Some("claude-sonnet-4-6-20250610")
    );
}
3718
/// A sampling request with only `messages` leaves every optional field `None`.
#[test]
fn sampling_request_minimal() {
    let payload = json!({
        "messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}]
    });

    let req: SamplingRequest = serde_json::from_value(payload).unwrap();
    assert_eq!(req.messages.len(), 1);
    assert!(req.model_preferences.is_none());
    assert!(req.system_prompt.is_none());
    assert!(req.max_tokens.is_none());
}
3730
/// `SamplingResponse` serializes its content type under the wire key `type`,
/// alongside `role`, `text` and `model`.
#[test]
fn sampling_response_serialization() {
    let content = SamplingContent {
        content_type: "text".into(),
        text: Some("4".into()),
    };
    let resp = SamplingResponse {
        role: "assistant".into(),
        content,
        model: "claude-sonnet-4-6-20250610".into(),
    };

    let encoded = serde_json::to_value(&resp).unwrap();
    assert_eq!(encoded["role"], "assistant");
    let body = &encoded["content"];
    assert_eq!(body["type"], "text");
    assert_eq!(body["text"], "4");
    assert_eq!(encoded["model"], "claude-sonnet-4-6-20250610");
}
3747
/// `SamplingMessage` keeps its role and text through a serialize/deserialize
/// round trip.
#[test]
fn sampling_message_serde_roundtrip() {
    let original = SamplingMessage {
        role: "user".into(),
        content: SamplingContent {
            content_type: "text".into(),
            text: Some("hello".into()),
        },
    };

    let decoded: SamplingMessage =
        serde_json::from_value(serde_json::to_value(&original).unwrap()).unwrap();
    assert_eq!(decoded.role, "user");
    assert_eq!(decoded.content.text.as_deref(), Some("hello"));
}
3762
/// A client starts without a sampling handler; `with_sampling` installs one.
#[test]
fn with_sampling_sets_handler() {
    let http = HttpTransport {
        client: reqwest::Client::new(),
        endpoint: "http://unused".to_string(),
        session_id: RwLock::new(None),
        next_id: AtomicU64::new(0),
        auth_header: None,
    };
    let bare_client = McpClient {
        transport: Arc::new(Transport::Http(http)),
        tools: vec![],
        resources: vec![],
        prompts: vec![],
        capabilities: ServerCapabilities::default(),
        sampling_handler: None,
        roots: Vec::new(),
    };
    assert!(bare_client.sampling_handler.is_none());

    let handler: SamplingHandler =
        Arc::new(|_req| Box::pin(async move { Ok(("response".into(), "model".into())) }));
    let sampling_client = bare_client.with_sampling(handler);
    assert!(sampling_client.sampling_handler.is_some());
}
3788
/// Smoke test: an info-level `notifications/message` with a logger name is
/// handled without panicking. There is no return value to assert on.
#[test]
fn handle_log_notification_info() {
    let params = json!({"level": "info", "logger": "test-server", "data": "Server started"});
    let value = json!({
        "jsonrpc": "2.0",
        "method": "notifications/message",
        "params": params
    });
    handle_log_notification(&value);
}
3801
/// Smoke test: an error-level notification without a `logger` field is
/// handled without panicking.
#[test]
fn handle_log_notification_error() {
    let params = json!({"level": "error", "data": "Something went wrong"});
    let value = json!({
        "jsonrpc": "2.0",
        "method": "notifications/message",
        "params": params
    });
    handle_log_notification(&value);
}
3811
/// Smoke test: a log notification with no `params` object is tolerated
/// and does not panic.
#[test]
fn handle_log_notification_missing_params() {
    let value = json!({"jsonrpc": "2.0", "method": "notifications/message"});
    handle_log_notification(&value);
}
3817
/// Among buffered events, a leading log notification is skipped.
/// The response with the requested id is returned.
#[test]
fn find_rpc_response_skips_log_notifications() {
    let log_event =
        r#"{"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info","data":"log"}}"#;
    let reply_event = r#"{"jsonrpc":"2.0","result":{"ok":true},"id":1}"#;
    let events = vec![log_event.to_string(), reply_event.to_string()];

    let found = find_rpc_response(&events, 1).unwrap();
    assert!(found.contains("\"id\":1"));
}
3827
/// `McpRoot` serializes `uri` and `name` and deserializes back.
#[test]
fn mcp_root_serde_roundtrip() {
    let original = McpRoot {
        uri: "file:///workspace/project".into(),
        name: Some("project".into()),
    };

    let encoded = serde_json::to_value(&original).unwrap();
    assert_eq!(encoded["uri"], "file:///workspace/project");
    assert_eq!(encoded["name"], "project");

    let decoded: McpRoot = serde_json::from_value(encoded).unwrap();
    assert_eq!(decoded.uri, "file:///workspace/project");
}
3842
/// A root with only `uri` deserializes with `name` set to `None`.
#[test]
fn mcp_root_minimal() {
    let parsed: McpRoot = serde_json::from_value(json!({"uri": "file:///tmp"})).unwrap();
    assert_eq!(parsed.uri, "file:///tmp");
    assert!(parsed.name.is_none());
}
3850
3851 #[test]
3852 fn mcp_root_name_omitted_when_none() {
3853 let root = McpRoot {
3854 uri: "file:///x".into(),
3855 name: None,
3856 };
3857 let json = serde_json::to_string(&root).unwrap();
3858 assert!(!json.contains("name"));
3859 }
3860
3861 #[test]
3862 fn with_roots_sets_roots() {
3863 let transport = Arc::new(Transport::Http(HttpTransport {
3864 client: reqwest::Client::new(),
3865 endpoint: "http://unused".to_string(),
3866 session_id: RwLock::new(None),
3867 next_id: AtomicU64::new(0),
3868 auth_header: None,
3869 }));
3870 let client = McpClient {
3871 transport,
3872 tools: vec![],
3873 resources: vec![],
3874 prompts: vec![],
3875 capabilities: ServerCapabilities::default(),
3876 sampling_handler: None,
3877 roots: Vec::new(),
3878 };
3879 assert!(client.roots().is_empty());
3880
3881 let client = client.with_roots(vec![McpRoot {
3882 uri: "file:///workspace".into(),
3883 name: Some("workspace".into()),
3884 }]);
3885 assert_eq!(client.roots().len(), 1);
3886 assert_eq!(client.roots()[0].uri, "file:///workspace");
3887 }
3888
3889 #[tokio::test]
3890 async fn read_stdio_response_forwards_log_notifications() {
3891 let (mut tx, rx) = tokio::io::duplex(4096);
3892 let mut reader = tokio::io::BufReader::new(rx);
3893
3894 tokio::spawn(async move {
3895 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"method\":\"notifications/message\",\"params\":{\"level\":\"info\",\"data\":\"test log\"}}\n")
3897 .await
3898 .unwrap();
3899 tx.write_all(b"{\"jsonrpc\":\"2.0\",\"result\":{\"ok\":true},\"id\":1}\n")
3900 .await
3901 .unwrap();
3902 });
3903
3904 let response = read_stdio_response(&mut reader, 1).await.unwrap();
3905 assert!(response.contains("\"id\":1"));
3906 assert!(response.contains("\"ok\":true"));
3907 }
3908
3909 #[tokio::test]
3912 async fn static_auth_resolver_returns_header() {
3913 let resolver = StaticAuthResolver(Some("Bearer xyz".into()));
3914 let result = resolver.resolve().await.unwrap();
3915 assert_eq!(result, Some("Bearer xyz".to_string()));
3916 }
3917
3918 #[tokio::test]
3919 async fn static_auth_resolver_returns_none() {
3920 let resolver = StaticAuthResolver(None);
3921 let result = resolver.resolve().await.unwrap();
3922 assert_eq!(result, None);
3923 }
3924
3925 #[tokio::test]
3926 async fn dynamic_auth_resolver_calls_provider() {
3927 let provider = Arc::new(StaticAuthProvider::new(Some("Bearer dynamic".into())));
3928 let resolver = DynamicAuthResolver::new(provider, "user1", "tenant1");
3929 let result = resolver.resolve().await.unwrap();
3930 assert_eq!(result, Some("Bearer dynamic".to_string()));
3931 }
3932
3933 #[tokio::test]
3934 async fn dynamic_auth_resolver_with_resource_and_scopes() {
3935 let provider = Arc::new(StaticAuthProvider::new(Some("Bearer scoped".into())));
3936 let resolver = DynamicAuthResolver::new(provider, "user1", "tenant1")
3937 .with_resource(Some("https://gmail.googleapis.com".into()))
3938 .with_scopes(Some(vec!["gmail.readonly".into()]));
3939 let result = resolver.resolve().await.unwrap();
3941 assert_eq!(result, Some("Bearer scoped".to_string()));
3942 }
3943
3944 #[tokio::test]
3945 async fn auth_header_for_resource_default_delegates() {
3946 let provider = StaticAuthProvider::new(Some("Bearer base".into()));
3947 let result = provider
3948 .auth_header_for_resource(
3949 "user1",
3950 "tenant1",
3951 Some("https://resource.example.com"),
3952 Some(&["scope1".into()]),
3953 )
3954 .await
3955 .unwrap();
3956 assert_eq!(result, Some("Bearer base".to_string()));
3958 }
3959
3960 #[tokio::test]
3963 async fn mcp_tool_with_resolver_injects_auth() {
3964 let transport = Arc::new(Transport::Http(HttpTransport {
3966 client: reqwest::Client::new(),
3967 endpoint: "http://127.0.0.1:1".to_string(),
3968 session_id: RwLock::new(None),
3969 next_id: AtomicU64::new(0),
3970 auth_header: None,
3971 }));
3972
3973 let resolver: Arc<dyn AuthResolver> =
3974 Arc::new(StaticAuthResolver(Some("Bearer user-token".into())));
3975 let tool = McpTool {
3976 transport,
3977 def: ToolDefinition {
3978 name: "test_tool".into(),
3979 description: "test".into(),
3980 input_schema: json!({"type": "object"}),
3981 },
3982 auth_resolver: Some(resolver),
3983 };
3984
3985 let result = tool
3987 .execute(&crate::ExecutionContext::default(), json!({}))
3988 .await
3989 .unwrap();
3990 assert!(result.is_error);
3991 }
3992
3993 #[tokio::test]
3994 async fn mcp_tool_without_resolver_uses_transport_default() {
3995 let transport = Arc::new(Transport::Http(HttpTransport {
3996 client: reqwest::Client::new(),
3997 endpoint: "http://127.0.0.1:1".to_string(),
3998 session_id: RwLock::new(None),
3999 next_id: AtomicU64::new(0),
4000 auth_header: Some("Bearer static".into()),
4001 }));
4002
4003 let tool = McpTool {
4004 transport,
4005 def: ToolDefinition {
4006 name: "test_tool".into(),
4007 description: "test".into(),
4008 input_schema: json!({"type": "object"}),
4009 },
4010 auth_resolver: None,
4011 };
4012
4013 let result = tool
4015 .execute(&crate::ExecutionContext::default(), json!({}))
4016 .await
4017 .unwrap();
4018 assert!(result.is_error);
4019 }
4020
4021 #[test]
4024 fn transport_pool_new_is_empty() {
4025 let pool = McpTransportPool::new();
4026 assert!(!pool.contains("http://example.com/mcp"));
4027 }
4028
4029 #[test]
4030 fn transport_pool_tools_for_user_returns_none_for_unknown_url() {
4031 let pool = McpTransportPool::new();
4032 let resolver: Arc<dyn AuthResolver> = Arc::new(StaticAuthResolver(None));
4033 let result = pool
4034 .tools_for_user("http://unknown.example.com/mcp", resolver)
4035 .unwrap();
4036 assert!(result.is_none());
4037 }
4038
4039 #[test]
4040 fn transport_pool_default_trait() {
4041 let pool = McpTransportPool::default();
4042 assert!(!pool.contains("http://example.com/mcp"));
4043 }
4044
4045 #[test]
4048 fn into_tools_with_auth_stamps_resolver() {
4049 let transport = Arc::new(Transport::Http(HttpTransport {
4050 client: reqwest::Client::new(),
4051 endpoint: "http://unused".to_string(),
4052 session_id: RwLock::new(None),
4053 next_id: AtomicU64::new(0),
4054 auth_header: None,
4055 }));
4056
4057 let client = McpClient {
4058 transport,
4059 tools: vec![McpToolDef {
4060 name: "read_file".into(),
4061 description: Some("Read a file".into()),
4062 input_schema: Some(json!({"type": "object"})),
4063 }],
4064 resources: vec![],
4065 prompts: vec![],
4066 capabilities: ServerCapabilities::default(),
4067 sampling_handler: None,
4068 roots: Vec::new(),
4069 };
4070
4071 let resolver: Arc<dyn AuthResolver> =
4072 Arc::new(StaticAuthResolver(Some("Bearer user".into())));
4073 let tools = client.into_tools_with_auth(resolver);
4074 assert_eq!(tools.len(), 1);
4075 assert_eq!(tools[0].definition().name, "read_file");
4076 }
4077
4078 #[test]
4081 fn static_auth_provider_always_has_credentials() {
4082 let provider = StaticAuthProvider::new(Some("Bearer x".into()));
4083 assert!(provider.has_credentials("u", "t"));
4084 let provider = StaticAuthProvider::new(None);
4085 assert!(provider.has_credentials("u", "t"));
4086 }
4087
4088 #[test]
4089 fn token_exchange_has_credentials_checks_user_tokens() {
4090 let user_tokens = Arc::new(std::sync::RwLock::new(HashMap::<String, String>::new()));
4091 let provider = TokenExchangeAuthProvider::new(
4092 "https://auth.example.com/token",
4093 "client_id",
4094 "client_secret",
4095 "agent_token",
4096 )
4097 .with_user_tokens(Arc::clone(&user_tokens));
4098
4099 assert!(!provider.has_credentials("alice", "acme"));
4101
4102 user_tokens
4104 .write()
4105 .unwrap()
4106 .insert("acme:alice".to_string(), "jwt-alice".to_string());
4107 assert!(provider.has_credentials("alice", "acme"));
4108
4109 assert!(!provider.has_credentials("bob", "acme"));
4111 }
4112
4113 #[tokio::test]
4116 async fn direct_auth_provider_auth_header_for_returns_none() {
4117 let mut tokens = HashMap::new();
4118 tokens.insert("http://mcp.example.com".to_string(), "tok_abc".to_string());
4119 let provider = DirectAuthProvider::new(tokens);
4120 let result = provider.auth_header_for("user1", "tenant1").await.unwrap();
4121 assert!(result.is_none());
4122 }
4123
4124 #[tokio::test]
4125 async fn direct_auth_provider_returns_token_for_known_url() {
4126 let mut tokens = HashMap::new();
4127 tokens.insert("http://mcp.example.com".to_string(), "tok_abc".to_string());
4128 let provider = DirectAuthProvider::new(tokens);
4129 let result = provider
4130 .auth_header_for_resource("u", "t", Some("http://mcp.example.com"), None)
4131 .await
4132 .unwrap();
4133 assert_eq!(result.as_deref(), Some("Bearer tok_abc"));
4134 }
4135
4136 #[tokio::test]
4137 async fn direct_auth_provider_returns_none_for_unknown_url() {
4138 let mut tokens = HashMap::new();
4139 tokens.insert("http://mcp.example.com".to_string(), "tok_abc".to_string());
4140 let provider = DirectAuthProvider::new(tokens);
4141 let result = provider
4142 .auth_header_for_resource("u", "t", Some("http://other.example.com"), None)
4143 .await
4144 .unwrap();
4145 assert!(result.is_none());
4146 }
4147
4148 #[tokio::test]
4149 async fn direct_auth_provider_returns_none_for_no_resource() {
4150 let mut tokens = HashMap::new();
4151 tokens.insert("http://mcp.example.com".to_string(), "tok_abc".to_string());
4152 let provider = DirectAuthProvider::new(tokens);
4153 let result = provider
4154 .auth_header_for_resource("u", "t", None, None)
4155 .await
4156 .unwrap();
4157 assert!(result.is_none());
4158 }
4159
4160 #[test]
4161 fn direct_auth_provider_has_credentials_non_empty() {
4162 let mut tokens = HashMap::new();
4163 tokens.insert("http://mcp.example.com".to_string(), "tok_abc".to_string());
4164 let provider = DirectAuthProvider::new(tokens);
4165 assert!(provider.has_credentials("u", "t"));
4166 }
4167
4168 #[test]
4169 fn direct_auth_provider_has_credentials_empty() {
4170 let provider = DirectAuthProvider::new(HashMap::new());
4171 assert!(!provider.has_credentials("u", "t"));
4172 }
4173
4174 #[test]
4175 fn into_all_tools_with_auth_stamps_resolver() {
4176 let transport = Arc::new(Transport::Http(HttpTransport {
4177 client: reqwest::Client::new(),
4178 endpoint: "http://unused".to_string(),
4179 session_id: RwLock::new(None),
4180 next_id: AtomicU64::new(0),
4181 auth_header: None,
4182 }));
4183
4184 let client = McpClient {
4185 transport,
4186 tools: vec![McpToolDef {
4187 name: "tool1".into(),
4188 description: None,
4189 input_schema: None,
4190 }],
4191 resources: vec![McpResourceDef {
4192 uri: "file:///a.txt".into(),
4193 name: "readme".into(),
4194 description: None,
4195 mime_type: None,
4196 }],
4197 prompts: vec![McpPromptDef {
4198 name: "greet".into(),
4199 description: None,
4200 arguments: vec![],
4201 }],
4202 capabilities: ServerCapabilities::default(),
4203 sampling_handler: None,
4204 roots: Vec::new(),
4205 };
4206
4207 let resolver: Arc<dyn AuthResolver> =
4208 Arc::new(StaticAuthResolver(Some("Bearer user".into())));
4209 let all = client.into_all_tools_with_auth(resolver);
4210 assert_eq!(all.len(), 3);
4211 let names: Vec<String> = all.iter().map(|t| t.definition().name).collect();
4212 assert!(names.contains(&"tool1".to_string()));
4213 assert!(names.contains(&"mcp_resource_readme".to_string()));
4214 assert!(names.contains(&"mcp_prompt_greet".to_string()));
4215 }
4216
4217 #[tokio::test]
4223 async fn connect_http_rejects_loopback_url() {
4224 let result = McpClient::connect_with_auth("http://127.0.0.1/", "Bearer secret").await;
4225 assert!(result.is_err(), "loopback URL must be rejected pre-connect");
4226 let msg = result.err().expect("must be Err").to_string();
4227 assert!(
4228 msg.contains("private")
4229 || msg.contains("loopback")
4230 || msg.contains("refused")
4231 || msg.contains("/127."),
4232 "error should mention SSRF rejection; got: {msg}"
4233 );
4234 }
4235
4236 #[tokio::test]
4240 async fn connect_http_rejects_file_scheme() {
4241 let result = McpClient::connect("file:///etc/passwd").await;
4242 assert!(result.is_err(), "file:// scheme must be rejected");
4243 let msg = result.err().expect("must be Err").to_string();
4244 assert!(
4245 msg.contains("scheme") || msg.contains("file"),
4246 "error should mention scheme; got: {msg}"
4247 );
4248 }
4249
4250 #[tokio::test]
4253 async fn connect_http_rejects_aws_metadata_url() {
4254 let result = McpClient::connect("http://169.254.169.254/").await;
4255 assert!(result.is_err(), "metadata URL must be rejected pre-connect");
4256 }
4257}