1#![allow(missing_docs)]
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use uuid::Uuid;
8
9use super::{Tool, ToolOutput};
10use crate::error::Error;
11
/// MCP protocol revision reported as `protocolVersion` in every `initialize`
/// response. It is sent unconditionally; the version requested by the client
/// is not consulted.
const PROTOCOL_VERSION: &str = "2025-11-25";
13
/// Incoming JSON-RPC message as decoded from a request body.
#[derive(Debug, Deserialize)]
struct JsonRpcRequest {
    /// Protocol version tag sent by the client; accepted but never checked.
    // Nothing reads this field after deserialization, hence the allow.
    #[allow(dead_code)]
    jsonrpc: Option<String>,
    /// Method name used for dispatch.
    method: String,
    /// Method parameters, if the client sent any.
    #[serde(default)]
    params: Option<Value>,
    /// Request id echoed back in the response; `None` when absent (or `null`).
    id: Option<Value>,
}
25
/// Outgoing JSON-RPC 2.0 response envelope.
///
/// `result` and `error` are omitted from the JSON when `None`; the constructors
/// on this type always set exactly one of them.
#[derive(Debug, Serialize)]
struct JsonRpcResponse {
    /// Always the literal `"2.0"`.
    jsonrpc: &'static str,
    /// Success payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Value>,
    /// Error payload.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<JsonRpcError>,
    /// Id of the request being answered (`null` when it had none or failed to parse).
    id: Value,
}
35
/// JSON-RPC error object carried in [`JsonRpcResponse::error`].
#[derive(Debug, Serialize)]
struct JsonRpcError {
    /// Numeric error code (e.g. `METHOD_NOT_FOUND`, `UNAUTHORIZED`).
    code: i64,
    /// Human-readable description.
    message: String,
}
41
42impl JsonRpcResponse {
43 fn success(id: Value, result: Value) -> Self {
44 Self {
45 jsonrpc: "2.0",
46 result: Some(result),
47 error: None,
48 id,
49 }
50 }
51
52 fn error(id: Value, code: i64, message: impl Into<String>) -> Self {
53 Self {
54 jsonrpc: "2.0",
55 result: None,
56 error: Some(JsonRpcError {
57 code,
58 message: message.into(),
59 }),
60 id,
61 }
62 }
63}
64
/// JSON-RPC 2.0 "Method not found" error code.
const METHOD_NOT_FOUND: i64 = -32601;
/// JSON-RPC 2.0 "Invalid params" error code.
const INVALID_PARAMS: i64 = -32602;
/// JSON-RPC 2.0 "Internal error" error code.
const INTERNAL_ERROR: i64 = -32603;
69
/// Identity and feature switches for [`McpServer`].
#[derive(Debug, Clone)]
pub struct McpServerConfig {
    /// Reported as `serverInfo.name` in the `initialize` response.
    pub name: String,
    /// Reported as `serverInfo.version` in the `initialize` response.
    pub version: String,
    /// When false, `tools/list` returns an empty list and no `tools`
    /// capability is advertised.
    pub expose_tools: bool,
    /// When false, `resources/list` returns an empty list and no `resources`
    /// capability is advertised.
    pub expose_resources: bool,
    /// Not consulted by `McpServer` in this file; no prompt methods are routed.
    pub expose_prompts: bool,
}
81
82impl Default for McpServerConfig {
83 fn default() -> Self {
84 Self {
85 name: "heartbit".into(),
86 version: env!("CARGO_PKG_VERSION").into(),
87 expose_tools: true,
88 expose_resources: true,
89 expose_prompts: false,
90 }
91 }
92}
93
/// Catalogue entry for a resource served by [`McpServer`].
///
/// Serialized with camelCase keys (so `mime_type` becomes `mimeType`); the
/// optional fields are omitted when `None` and default to `None` when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerResource {
    /// Resource URI; `resources/read` only accepts URIs listed in the catalogue.
    pub uri: String,
    /// Short name shown to clients.
    pub name: String,
    /// Optional free-form description.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Optional MIME type of the resource.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mime_type: Option<String>,
}
105
/// Produces the contents of a resource for `resources/read`.
///
/// It is called with the requested URI and returns one `(mime_type, text)`
/// pair per content item. Each pair becomes a `contents` entry, with
/// `mimeType` included only when `Some`. An `Err` is reported to the client
/// as an internal error.
pub type ResourceReader =
    Arc<dyn Fn(&str) -> Result<Vec<(Option<String>, String)>, Error> + Send + Sync>;
109
/// Per-request authorization hook: `(method, session_id, auth_header) -> allowed`.
///
/// `session_id` and `auth_header` are passed through exactly as the caller
/// supplied them to `handle_request_with_auth`. Returning `false` answers the
/// request with an `UNAUTHORIZED` error.
pub type AuthCallback = Arc<dyn Fn(&str, Option<&str>, Option<&str>) -> bool + Send + Sync>;
116
/// Error code returned when the auth callback rejects a request. It lies in
/// JSON-RPC's implementation-defined server-error range (-32000..=-32099).
const UNAUTHORIZED: i64 = -32001;
120
/// Upper bound on tracked session ids. When the table is full, registering a
/// new session first evicts an existing one.
const MAX_SESSIONS: usize = 256;
128
/// Minimal MCP server. `handle_request*` take a JSON-RPC request body and
/// return the response body plus a session id; the transport is outside this type.
pub struct McpServer {
    /// Server identity and exposure switches.
    config: McpServerConfig,
    /// Tools served via `tools/list` and `tools/call`.
    tools: Vec<Arc<dyn Tool>>,
    /// Resource catalogue served via `resources/list`; also the allow-list
    /// consulted by `resources/read`.
    resources: Vec<ServerResource>,
    /// Produces resource contents for `resources/read`.
    resource_reader: Option<ResourceReader>,
    /// Known session ids, used as a set (the `()` value carries nothing) and
    /// capped at `MAX_SESSIONS` entries.
    sessions: parking_lot::RwLock<HashMap<String, ()>>,
    /// Optional authorization hook run before dispatching parsed requests.
    auth_callback: Option<AuthCallback>,
}
156
157impl McpServer {
158 pub fn new(config: McpServerConfig) -> Self {
159 Self {
160 config,
161 tools: Vec::new(),
162 resources: Vec::new(),
163 resource_reader: None,
164 sessions: parking_lot::RwLock::new(HashMap::new()),
167 auth_callback: None,
168 }
169 }
170
171 pub fn with_tools(mut self, tools: Vec<Arc<dyn Tool>>) -> Self {
173 self.tools = tools;
174 self
175 }
176
177 pub fn with_resources(
179 mut self,
180 resources: Vec<ServerResource>,
181 reader: ResourceReader,
182 ) -> Self {
183 self.resources = resources;
184 self.resource_reader = Some(reader);
185 self
186 }
187
188 pub fn with_auth_callback(mut self, callback: AuthCallback) -> Self {
195 self.auth_callback = Some(callback);
196 self
197 }
198
199 fn ensure_session(&self, session_id: Option<&str>) -> String {
201 if let Some(sid) = session_id
202 && self.sessions.read().contains_key(sid)
203 {
204 return sid.to_string();
205 }
206 let new_sid = Uuid::new_v4().to_string();
207 let mut sessions = self.sessions.write();
208 if sessions.len() >= MAX_SESSIONS
213 && let Some(victim) = sessions.keys().next().cloned()
214 {
215 sessions.remove(&victim);
216 }
217 sessions.insert(new_sid.clone(), ());
218 new_sid
219 }
220
221 pub async fn handle_request(&self, body: &str, session_id: Option<&str>) -> (String, String) {
227 self.handle_request_with_auth(body, session_id, None).await
228 }
229
230 pub async fn handle_request_with_auth(
234 &self,
235 body: &str,
236 session_id: Option<&str>,
237 auth_header: Option<&str>,
238 ) -> (String, String) {
239 let sid = self.ensure_session(session_id);
240
241 let response = match serde_json::from_str::<JsonRpcRequest>(body) {
242 Ok(req) => {
243 if let Some(ref cb) = self.auth_callback
246 && !cb(&req.method, session_id, auth_header)
247 {
248 let id = req.id.clone().unwrap_or(Value::Null);
249 let err = JsonRpcResponse::error(id, UNAUTHORIZED, "Unauthorized");
250 serde_json::to_string(&err).unwrap_or_default()
251 } else {
252 self.route(req).await
253 }
254 }
255 Err(e) => {
256 let err = JsonRpcResponse::error(Value::Null, -32700, format!("Parse error: {e}"));
257 serde_json::to_string(&err).unwrap_or_default()
258 }
259 };
260
261 (response, sid)
262 }
263
264 async fn route(&self, req: JsonRpcRequest) -> String {
265 let id = req.id.clone().unwrap_or(Value::Null);
266 let result = match req.method.as_str() {
267 "initialize" => self.handle_initialize(&id),
268 "ping" => Ok(JsonRpcResponse::success(id.clone(), serde_json::json!({}))),
269 "tools/list" => self.handle_tools_list(&id, req.params.as_ref()),
270 "tools/call" => self.handle_tools_call(&id, req.params.as_ref()).await,
271 "resources/list" => self.handle_resources_list(&id, req.params.as_ref()),
272 "resources/read" => self.handle_resources_read(&id, req.params.as_ref()),
273 _ if req.method.starts_with("notifications/") => {
274 return String::new();
276 }
277 _ => Ok(JsonRpcResponse::error(
278 id.clone(),
279 METHOD_NOT_FOUND,
280 format!("Method not found: {}", req.method),
281 )),
282 };
283
284 match result {
285 Ok(resp) => serde_json::to_string(&resp).unwrap_or_default(),
286 Err(e) => {
287 let resp = JsonRpcResponse::error(id, INTERNAL_ERROR, e.to_string());
288 serde_json::to_string(&resp).unwrap_or_default()
289 }
290 }
291 }
292
293 fn handle_initialize(&self, id: &Value) -> Result<JsonRpcResponse, Error> {
294 let mut capabilities = serde_json::json!({});
295
296 if self.config.expose_tools && !self.tools.is_empty() {
297 capabilities["tools"] = serde_json::json!({ "listChanged": false });
298 }
299 if self.config.expose_resources && !self.resources.is_empty() {
300 capabilities["resources"] =
301 serde_json::json!({ "subscribe": false, "listChanged": false });
302 }
303
304 Ok(JsonRpcResponse::success(
305 id.clone(),
306 serde_json::json!({
307 "protocolVersion": PROTOCOL_VERSION,
308 "capabilities": capabilities,
309 "serverInfo": {
310 "name": self.config.name,
311 "version": self.config.version
312 }
313 }),
314 ))
315 }
316
317 fn handle_tools_list(
318 &self,
319 id: &Value,
320 _params: Option<&Value>,
321 ) -> Result<JsonRpcResponse, Error> {
322 if !self.config.expose_tools {
323 return Ok(JsonRpcResponse::success(
324 id.clone(),
325 serde_json::json!({ "tools": [] }),
326 ));
327 }
328
329 let tools: Vec<Value> = self
330 .tools
331 .iter()
332 .map(|t| {
333 let def = t.definition();
334 serde_json::json!({
335 "name": def.name,
336 "description": def.description,
337 "inputSchema": def.input_schema,
338 })
339 })
340 .collect();
341
342 Ok(JsonRpcResponse::success(
343 id.clone(),
344 serde_json::json!({ "tools": tools }),
345 ))
346 }
347
348 async fn handle_tools_call(
349 &self,
350 id: &Value,
351 params: Option<&Value>,
352 ) -> Result<JsonRpcResponse, Error> {
353 let params = params.ok_or_else(|| Error::Mcp("Missing params for tools/call".into()))?;
354 let name = params
355 .get("name")
356 .and_then(|v| v.as_str())
357 .ok_or_else(|| Error::Mcp("Missing 'name' in tools/call params".into()))?;
358 let arguments = params
359 .get("arguments")
360 .cloned()
361 .unwrap_or(serde_json::json!({}));
362
363 let tool = self
364 .tools
365 .iter()
366 .find(|t| t.definition().name == name)
367 .ok_or_else(|| Error::Mcp(format!("Tool not found: {name}")))?;
368
369 match tool
374 .execute(&crate::ExecutionContext::default(), arguments)
375 .await
376 {
377 Ok(output) => Ok(JsonRpcResponse::success(
378 id.clone(),
379 tool_output_to_mcp(output),
380 )),
381 Err(e) => Ok(JsonRpcResponse::success(
382 id.clone(),
383 serde_json::json!({
384 "content": [{"type": "text", "text": e.to_string()}],
385 "isError": true
386 }),
387 )),
388 }
389 }
390
391 fn handle_resources_list(
392 &self,
393 id: &Value,
394 _params: Option<&Value>,
395 ) -> Result<JsonRpcResponse, Error> {
396 if !self.config.expose_resources {
397 return Ok(JsonRpcResponse::success(
398 id.clone(),
399 serde_json::json!({ "resources": [] }),
400 ));
401 }
402
403 let resources: Vec<Value> = self
404 .resources
405 .iter()
406 .map(|r| serde_json::to_value(r).unwrap_or_default())
407 .collect();
408
409 Ok(JsonRpcResponse::success(
410 id.clone(),
411 serde_json::json!({ "resources": resources }),
412 ))
413 }
414
415 fn handle_resources_read(
416 &self,
417 id: &Value,
418 params: Option<&Value>,
419 ) -> Result<JsonRpcResponse, Error> {
420 let params =
421 params.ok_or_else(|| Error::Mcp("Missing params for resources/read".into()))?;
422 let uri = params
423 .get("uri")
424 .and_then(|v| v.as_str())
425 .ok_or_else(|| Error::Mcp("Missing 'uri' in resources/read params".into()))?;
426
427 if !self.resources.iter().any(|r| r.uri == uri) {
429 return Ok(JsonRpcResponse::error(
430 id.clone(),
431 INVALID_PARAMS,
432 format!("Resource not found: {uri}"),
433 ));
434 }
435
436 let reader = self
437 .resource_reader
438 .as_ref()
439 .ok_or_else(|| Error::Mcp("No resource reader configured".into()))?;
440
441 match reader(uri) {
442 Ok(contents) => {
443 let content_values: Vec<Value> = contents
444 .into_iter()
445 .map(|(mime, text)| {
446 let mut obj = serde_json::json!({
447 "uri": uri,
448 "text": text,
449 });
450 if let Some(m) = mime {
451 obj["mimeType"] = Value::String(m);
452 }
453 obj
454 })
455 .collect();
456 Ok(JsonRpcResponse::success(
457 id.clone(),
458 serde_json::json!({ "contents": content_values }),
459 ))
460 }
461 Err(e) => Ok(JsonRpcResponse::error(
462 id.clone(),
463 INTERNAL_ERROR,
464 e.to_string(),
465 )),
466 }
467 }
468}
469
470fn tool_output_to_mcp(output: ToolOutput) -> Value {
471 serde_json::json!({
472 "content": [{"type": "text", "text": output.content}],
473 "isError": output.is_error
474 })
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;

    use crate::llm::types::ToolDefinition;
    use serde_json::json;

    /// Boxed future returned by `Tool::execute`.
    type ToolFuture<'a> = Pin<Box<dyn Future<Output = Result<ToolOutput, Error>> + Send + 'a>>;

    /// Serialises `req`, hands it to `server` with the given session id, and
    /// returns the raw response body together with the session id handed back.
    async fn dispatch(server: &McpServer, req: &Value, session: Option<&str>) -> (String, String) {
        let body = serde_json::to_string(req).unwrap();
        server.handle_request(&body, session).await
    }

    /// Sends `req` without a session id and parses the response body as JSON.
    async fn call(server: &McpServer, req: &Value) -> Value {
        let (body, _sid) = dispatch(server, req, None).await;
        serde_json::from_str(&body).unwrap()
    }

    /// Tool that returns its `text` argument (or `"no text"`).
    struct EchoTool;

    impl Tool for EchoTool {
        fn definition(&self) -> ToolDefinition {
            let input_schema = json!({
                "type": "object",
                "properties": {"text": {"type": "string"}},
                "required": ["text"]
            });
            ToolDefinition {
                name: "echo".into(),
                description: "Echo input".into(),
                input_schema,
            }
        }

        fn execute(&self, _ctx: &crate::ExecutionContext, input: Value) -> ToolFuture<'_> {
            Box::pin(async move {
                let echoed = input
                    .get("text")
                    .and_then(Value::as_str)
                    .unwrap_or("no text");
                Ok(ToolOutput::success(echoed))
            })
        }
    }

    /// Tool whose execution always returns an error.
    struct FailTool;

    impl Tool for FailTool {
        fn definition(&self) -> ToolDefinition {
            ToolDefinition {
                name: "fail".into(),
                description: "Always fails".into(),
                input_schema: json!({"type": "object"}),
            }
        }

        fn execute(&self, _ctx: &crate::ExecutionContext, _input: Value) -> ToolFuture<'_> {
            Box::pin(async move { Err(Error::Mcp("intentional failure".into())) })
        }
    }

    /// Server with the echo and fail tools plus two readable resources.
    fn make_server() -> McpServer {
        let echo_tool: Arc<dyn Tool> = Arc::new(EchoTool);
        let fail_tool: Arc<dyn Tool> = Arc::new(FailTool);

        let task = ServerResource {
            uri: "heartbit://tasks/123".into(),
            name: "task_123".into(),
            description: Some("Task result".into()),
            mime_type: Some("text/plain".into()),
        };
        let config = ServerResource {
            uri: "heartbit://config".into(),
            name: "config".into(),
            description: None,
            mime_type: None,
        };

        McpServer::new(McpServerConfig::default())
            .with_tools(vec![echo_tool, fail_tool])
            .with_resources(
                vec![task, config],
                Arc::new(|uri: &str| match uri {
                    "heartbit://tasks/123" => {
                        Ok(vec![(Some("text/plain".into()), "Task completed!".into())])
                    }
                    "heartbit://config" => Ok(vec![(None, "key=value".into())]),
                    _ => Err(Error::Mcp(format!("Unknown resource: {uri}"))),
                }),
            )
    }

    #[tokio::test]
    async fn initialize_returns_capabilities() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "initialize",
            "params": {
                "protocolVersion": "2025-11-25",
                "capabilities": {},
                "clientInfo": {"name": "test", "version": "1.0"}
            },
            "id": 1
        });

        let (body, sid) = dispatch(&server, &req, None).await;
        let parsed: Value = serde_json::from_str(&body).unwrap();
        let result = &parsed["result"];

        assert_eq!(result["protocolVersion"], "2025-11-25");
        assert!(result["capabilities"]["tools"].is_object());
        assert!(result["capabilities"]["resources"].is_object());
        assert_eq!(result["serverInfo"]["name"], "heartbit");
        assert!(!sid.is_empty());
    }

    #[tokio::test]
    async fn initialize_no_tools_capability_when_empty() {
        let server = McpServer::new(McpServerConfig::default());
        let req = json!({
            "jsonrpc": "2.0",
            "method": "initialize",
            "params": {},
            "id": 1
        });

        let parsed = call(&server, &req).await;
        let capabilities = &parsed["result"]["capabilities"];

        assert!(capabilities["tools"].is_null());
        assert!(capabilities["resources"].is_null());
    }

    #[tokio::test]
    async fn ping_returns_empty_result() {
        let server = make_server();
        let parsed = call(&server, &json!({"jsonrpc": "2.0", "method": "ping", "id": 42})).await;
        assert_eq!(parsed["result"], json!({}));
        assert_eq!(parsed["id"], 42);
    }

    #[tokio::test]
    async fn tools_list_returns_all_tools() {
        let server = make_server();
        let parsed = call(&server, &json!({"jsonrpc": "2.0", "method": "tools/list", "id": 1})).await;

        let tools = parsed["result"]["tools"].as_array().unwrap();
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0]["name"], "echo");
        assert_eq!(tools[1]["name"], "fail");
        assert!(tools[0]["inputSchema"]["properties"]["text"].is_object());
    }

    #[tokio::test]
    async fn tools_list_empty_when_disabled() {
        let server = McpServer::new(McpServerConfig {
            expose_tools: false,
            ..Default::default()
        })
        .with_tools(vec![Arc::new(EchoTool)]);

        let parsed = call(&server, &json!({"jsonrpc": "2.0", "method": "tools/list", "id": 1})).await;
        assert_eq!(parsed["result"]["tools"].as_array().unwrap().len(), 0);
    }

    #[tokio::test]
    async fn tools_call_echo() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "echo", "arguments": {"text": "hello world"}},
            "id": 1
        });
        let parsed = call(&server, &req).await;

        let first = &parsed["result"]["content"][0];
        assert_eq!(first["type"], "text");
        assert_eq!(first["text"], "hello world");
        assert_eq!(parsed["result"]["isError"], false);
    }

    #[tokio::test]
    async fn tools_call_fail_returns_error_content() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "fail", "arguments": {}},
            "id": 1
        });
        let parsed = call(&server, &req).await;

        assert_eq!(parsed["result"]["isError"], true);
        let text = parsed["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("intentional failure"));
    }

    #[tokio::test]
    async fn tools_call_not_found() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "nonexistent", "arguments": {}},
            "id": 1
        });
        let parsed = call(&server, &req).await;

        let message = parsed["error"]["message"].as_str().unwrap();
        assert!(message.contains("not found"));
    }

    #[tokio::test]
    async fn tools_call_missing_params() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "tools/call",
            "id": 1
        });
        let parsed = call(&server, &req).await;
        assert!(parsed["error"].is_object());
    }

    #[tokio::test]
    async fn resources_list_returns_all() {
        let server = make_server();
        let parsed = call(
            &server,
            &json!({"jsonrpc": "2.0", "method": "resources/list", "id": 1}),
        )
        .await;

        let resources = parsed["result"]["resources"].as_array().unwrap();
        assert_eq!(resources.len(), 2);
        let first = &resources[0];
        assert_eq!(first["uri"], "heartbit://tasks/123");
        assert_eq!(first["name"], "task_123");
        assert_eq!(first["mimeType"], "text/plain");
    }

    #[tokio::test]
    async fn resources_list_empty_when_disabled() {
        let hidden = ServerResource {
            uri: "test://x".into(),
            name: "x".into(),
            description: None,
            mime_type: None,
        };
        let server = McpServer::new(McpServerConfig {
            expose_resources: false,
            ..Default::default()
        })
        .with_resources(vec![hidden], Arc::new(|_| Ok(vec![])));

        let parsed = call(
            &server,
            &json!({"jsonrpc": "2.0", "method": "resources/list", "id": 1}),
        )
        .await;
        assert_eq!(parsed["result"]["resources"].as_array().unwrap().len(), 0);
    }

    #[tokio::test]
    async fn resources_read_success() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "resources/read",
            "params": {"uri": "heartbit://tasks/123"},
            "id": 1
        });
        let parsed = call(&server, &req).await;

        let contents = parsed["result"]["contents"].as_array().unwrap();
        assert_eq!(contents.len(), 1);
        let only = &contents[0];
        assert_eq!(only["uri"], "heartbit://tasks/123");
        assert_eq!(only["text"], "Task completed!");
        assert_eq!(only["mimeType"], "text/plain");
    }

    #[tokio::test]
    async fn resources_read_not_found() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "resources/read",
            "params": {"uri": "heartbit://nonexistent"},
            "id": 1
        });
        let parsed = call(&server, &req).await;

        let message = parsed["error"]["message"].as_str().unwrap();
        assert!(message.contains("not found"));
    }

    #[tokio::test]
    async fn resources_read_missing_uri() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "resources/read",
            "params": {},
            "id": 1
        });
        let parsed = call(&server, &req).await;
        assert!(parsed["error"].is_object());
    }

    #[tokio::test]
    async fn unknown_method_returns_error() {
        let server = make_server();
        let parsed = call(&server, &json!({"jsonrpc": "2.0", "method": "foobar", "id": 1})).await;
        assert_eq!(parsed["error"]["code"], METHOD_NOT_FOUND);
    }

    #[tokio::test]
    async fn notification_returns_empty_string() {
        let server = make_server();
        let req = json!({
            "jsonrpc": "2.0",
            "method": "notifications/initialized"
        });
        let (body, _sid) = dispatch(&server, &req, None).await;
        assert!(body.is_empty());
    }

    #[tokio::test]
    async fn invalid_json_returns_parse_error() {
        let server = make_server();
        let (body, _sid) = server.handle_request("not json", None).await;
        let parsed: Value = serde_json::from_str(&body).unwrap();
        assert_eq!(parsed["error"]["code"], -32700);
    }

    #[tokio::test]
    async fn session_id_created_on_first_request() {
        let server = make_server();
        let ping = json!({"jsonrpc": "2.0", "method": "ping", "id": 1});

        let (_, first_sid) = dispatch(&server, &ping, None).await;
        assert!(!first_sid.is_empty());

        let (_, second_sid) = dispatch(&server, &ping, Some(first_sid.as_str())).await;
        assert_eq!(first_sid, second_sid);
    }

    #[tokio::test]
    async fn unknown_session_creates_new() {
        let server = make_server();
        let ping = json!({"jsonrpc": "2.0", "method": "ping", "id": 1});
        let (_, sid) = dispatch(&server, &ping, Some("bad-session")).await;
        assert_ne!(sid, "bad-session");
    }

    #[test]
    fn tool_output_success_to_mcp() {
        let mcp = tool_output_to_mcp(ToolOutput::success("hello"));
        let block = &mcp["content"][0];
        assert_eq!(block["type"], "text");
        assert_eq!(block["text"], "hello");
        assert_eq!(mcp["isError"], false);
    }

    #[test]
    fn tool_output_error_to_mcp() {
        let mcp = tool_output_to_mcp(ToolOutput::error("bad"));
        assert_eq!(mcp["content"][0]["text"], "bad");
        assert_eq!(mcp["isError"], true);
    }

    #[test]
    fn config_defaults() {
        let McpServerConfig {
            name,
            expose_tools,
            expose_resources,
            expose_prompts,
            ..
        } = McpServerConfig::default();
        assert_eq!(name, "heartbit");
        assert!(expose_tools);
        assert!(expose_resources);
        assert!(!expose_prompts);
    }

    #[test]
    fn server_resource_serde_roundtrip() {
        let original = ServerResource {
            uri: "heartbit://tasks/1".into(),
            name: "task_1".into(),
            description: Some("A task".into()),
            mime_type: Some("application/json".into()),
        };
        let encoded = serde_json::to_value(&original).unwrap();
        assert_eq!(encoded["uri"], "heartbit://tasks/1");
        assert_eq!(encoded["mimeType"], "application/json");

        let decoded: ServerResource = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded.name, "task_1");
    }

    #[test]
    fn server_resource_minimal() {
        let decoded: ServerResource =
            serde_json::from_value(json!({"uri": "test://x", "name": "x"})).unwrap();
        assert!(decoded.description.is_none());
        assert!(decoded.mime_type.is_none());
    }

    #[tokio::test]
    async fn auth_callback_rejects_when_returning_false() {
        let echo_tool: Arc<dyn Tool> = Arc::new(EchoTool);
        let server = McpServer::new(McpServerConfig::default())
            .with_tools(vec![echo_tool])
            .with_auth_callback(Arc::new(|_method, _sid, _auth| false));

        let req = json!({
            "jsonrpc": "2.0",
            "method": "tools/call",
            "id": 7,
            "params": {"name": "echo", "arguments": {"text": "should not run"}}
        });
        let parsed = call(&server, &req).await;

        assert!(parsed["error"].is_object(), "expected error response");
        let code = parsed["error"]["code"].as_i64().unwrap_or_default();
        assert_eq!(code, UNAUTHORIZED, "expected 'Unauthorized' code");
        assert!(
            parsed["result"].is_null(),
            "result must be absent on auth failure"
        );
    }

    #[tokio::test]
    async fn auth_callback_allows_when_returning_true() {
        let echo_tool: Arc<dyn Tool> = Arc::new(EchoTool);
        let server = McpServer::new(McpServerConfig::default())
            .with_tools(vec![echo_tool])
            .with_auth_callback(Arc::new(|_method, _sid, _auth| true));

        let req = json!({
            "jsonrpc": "2.0",
            "method": "tools/call",
            "id": 8,
            "params": {"name": "echo", "arguments": {"text": "ok"}}
        });
        let parsed = call(&server, &req).await;

        assert!(parsed["error"].is_null(), "expected success: {parsed}");
        let text = parsed["result"]["content"][0]["text"]
            .as_str()
            .unwrap_or_default();
        assert!(text.contains("ok"));
    }

    #[tokio::test]
    async fn session_map_is_bounded() {
        let server = McpServer::new(McpServerConfig::default());
        {
            let mut sessions = server.sessions.write();
            sessions.extend((0..MAX_SESSIONS).map(|i| (format!("sid-{i}"), ())));
            assert_eq!(sessions.len(), MAX_SESSIONS);
        }

        let _ = server.ensure_session(None);

        let sessions = server.sessions.read();
        assert!(
            sessions.len() <= MAX_SESSIONS,
            "session map exceeded MAX_SESSIONS = {MAX_SESSIONS}: {}",
            sessions.len()
        );
    }
}
1039}