1use crate::error::Error;
2use mcp_sdk_rs::{
3 client::Client as McpClient,
4 session::Session,
5 transport::{websocket::WebSocketTransport, Message},
6 types::ServerCapabilities,
7 Implementation, LoggingLevel, Prompt, PromptMessage, Resource, ResourceContents,
8 ResourceTemplate, Tool, ToolResult,
9};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::{collections::HashMap, fmt, process::Stdio, sync::Arc};
13use tokio::{
14 process::Command,
15 sync::{
16 mpsc::{UnboundedReceiver, UnboundedSender},
17 Mutex,
18 },
19};
20
21#[derive(Default, Deserialize)]
22pub struct LocalPeerBuilder {
23 name: String,
24 description: String,
25 cmd: String,
26 args: Vec<String>,
27 env: HashMap<String, String>,
28}
29impl LocalPeerBuilder {
30 pub fn new() -> LocalPeerBuilder {
31 LocalPeerBuilder::default()
32 }
33 pub fn with_name(mut self, name: String) -> LocalPeerBuilder {
34 self.name = name;
35 self
36 }
37 pub fn with_description(mut self, description: String) -> LocalPeerBuilder {
38 self.description = description;
39 self
40 }
41 pub fn with_cmd(mut self, cmd: String) -> LocalPeerBuilder {
42 self.cmd = cmd;
43 self
44 }
45 pub fn with_args(mut self, args: Vec<String>) -> LocalPeerBuilder {
46 self.args = args;
47 self
48 }
49 pub fn with_env(mut self, env: HashMap<String, String>) -> LocalPeerBuilder {
50 self.env = env;
51 self
52 }
53 pub async fn build(self) -> Result<Peer, Error> {
54 let mut command = Command::new(self.cmd.clone());
55 command.args(self.args.clone());
56 command.envs(self.env.clone());
57 command.stdin(Stdio::piped());
58 command.stdout(Stdio::piped());
59 let (request_tx, request_rx): (UnboundedSender<Message>, UnboundedReceiver<Message>) =
60 tokio::sync::mpsc::unbounded_channel();
61 let (response_tx, response_rx): (UnboundedSender<Message>, UnboundedReceiver<Message>) =
62 tokio::sync::mpsc::unbounded_channel();
63 let session = Session::Local {
64 handler: None,
65 command: command,
66 receiver: Arc::new(Mutex::new(request_rx)),
67 sender: Arc::new(response_tx),
68 };
69 session.start().await.map_err(|_| Error::Internal)?;
70 let client = McpClient::new(request_tx, response_rx);
71 let implementation = Implementation {
72 name: "commune".to_string(),
73 version: env!("CARGO_PKG_VERSION").to_string(),
74 };
75 let caps = client
76 .initialize(implementation, None)
77 .await
78 .map_err(|_| Error::ClientInitialization)?;
79 log::debug!(
80 "connected to local peer '{}'; capabilities: {:?}",
81 self.name,
82 caps
83 );
84 Ok(Peer::Local {
85 name: self.name,
86 description: self.description,
87 cmd: self.cmd,
88 args: self.args,
89 env: self.env,
90 capabilities: caps,
91 client: Some(client),
92 })
93 }
94}
95
96#[derive(Default, Deserialize)]
97pub struct RemotePeerBuilder {
98 name: String,
99 url: String,
100 description: String,
101}
102impl RemotePeerBuilder {
103 pub fn new() -> RemotePeerBuilder {
104 RemotePeerBuilder::default()
105 }
106 pub fn with_name(mut self, name: String) -> RemotePeerBuilder {
107 self.name = name;
108 self
109 }
110 pub fn with_url(mut self, url: String) -> RemotePeerBuilder {
111 self.url = url;
112 self
113 }
114 pub fn with_description(mut self, description: String) -> RemotePeerBuilder {
115 self.description = description;
116 self
117 }
118 pub async fn build(self) -> Result<Peer, Error> {
119 let transport = WebSocketTransport::new(self.url.as_str())
120 .await
121 .map_err(|_| Error::Internal)?;
122 let (request_tx, request_rx): (UnboundedSender<Message>, UnboundedReceiver<Message>) =
123 tokio::sync::mpsc::unbounded_channel();
124 let (response_tx, response_rx): (UnboundedSender<Message>, UnboundedReceiver<Message>) =
125 tokio::sync::mpsc::unbounded_channel();
126 let session = Session::Remote {
127 handler: None,
128 transport: Arc::new(transport),
129 receiver: Arc::new(Mutex::new(request_rx)),
130 sender: Arc::new(response_tx),
131 };
132 session
134 .start()
135 .await
136 .map_err(|_| Error::ClientInitialization)?;
137 let client = McpClient::new(request_tx, response_rx);
138 let implementation = Implementation {
139 name: "commune".to_string(),
140 version: env!("CARGO_PKG_VERSION").to_string(),
141 };
142 let caps = client
143 .initialize(implementation, None)
144 .await
145 .map_err(|_| Error::ClientInitialization)?;
146 log::debug!(
147 "connected to peer '{}' @ {}; capabilities: {:?}",
148 self.name,
149 self.url,
150 caps
151 );
152 Ok(Peer::Remote {
153 name: self.name,
154 url: self.url,
155 description: self.description,
156 capabilities: caps,
157 client: Some(client),
158 })
159 }
160}
161
162#[derive(Clone, Debug, Serialize, Deserialize)]
163pub enum Peer {
164 Local {
165 name: String,
166 description: String,
167 cmd: String,
168 args: Vec<String>,
169 env: HashMap<String, String>,
170 capabilities: ServerCapabilities,
171 #[serde(skip)]
172 client: Option<McpClient>,
173 },
174 Remote {
175 name: String,
176 description: String,
177 url: String,
178 capabilities: ServerCapabilities,
179 #[serde(skip)]
180 client: Option<McpClient>,
181 },
182}
183
184impl Peer {
185 pub async fn list_tools(&self) -> Result<Vec<Tool>, Error> {
187 match self {
188 Peer::Local {
189 name: _,
190 description: _,
191 cmd: _,
192 args: _,
193 env: _,
194 capabilities,
195 client: _,
196 } => {
197 if capabilities.tools.is_some() {
198 let res = self
199 .paginated_request("tools")
200 .await
201 .map_err(|_| Error::Internal)?;
202 let tools: Vec<Tool> = res
203 .into_iter()
204 .map(|r| serde_json::from_value(r).unwrap())
205 .collect();
206 Ok(tools)
207 } else {
208 Err(Error::Unsupported)
209 }
210 }
211 Peer::Remote {
212 name: _,
213 description: _,
214 url: _,
215 capabilities,
216 client: _,
217 } => {
218 if capabilities.tools.is_some() {
219 let res = self
220 .paginated_request("tools")
221 .await
222 .map_err(|_| Error::Internal)?;
223 let tools: Vec<Tool> = res
224 .into_iter()
225 .map(|r| serde_json::from_value(r).unwrap())
226 .collect();
227 Ok(tools)
228 } else {
229 Err(Error::Unsupported)
230 }
231 }
232 }
233 }
234 pub async fn call_tool(&self, name: &str, params: Option<Value>) -> Result<ToolResult, Error> {
236 match self {
237 Peer::Local {
238 name: _,
239 description: _,
240 cmd: _,
241 args: _,
242 env: _,
243 capabilities,
244 client,
245 } => {
246 if capabilities.tools.is_some() {
247 if let Some(c) = &client {
248 let val = c
249 .request(
250 "tools/call",
251 Some(json!({
252 "name": name,
253 "arguments": params.unwrap_or(json!({}))
254 })),
255 )
256 .await
257 .map_err(|e| Error::McpClient(format!("{e}")))?;
258 let tr: ToolResult =
259 serde_json::from_value(val).expect("an mcp formatted tool result");
260 Ok(tr)
261 } else {
262 Err(Error::UninitializedClient)
263 }
264 } else {
265 Err(Error::Unsupported)
266 }
267 }
268 Peer::Remote {
269 name: _,
270 description: _,
271 url: _,
272 capabilities,
273 client,
274 } => {
275 if capabilities.tools.is_some() {
276 if let Some(c) = &client {
277 let val = c
278 .request(
279 "tools/call",
280 Some(json!({
281 "name": name,
282 "arguments": params.unwrap_or(json!({}))
283 })),
284 )
285 .await
286 .map_err(|e| Error::McpClient(format!("{e}")))?;
287 let tr: ToolResult =
288 serde_json::from_value(val).expect("an mcp formatted tool result");
289 Ok(tr)
290 } else {
291 Err(Error::UninitializedClient)
292 }
293 } else {
294 Err(Error::Unsupported)
295 }
296 }
297 }
298 }
299 pub async fn list_resources(&self) -> Result<Vec<Resource>, Error> {
301 match self {
302 Peer::Local {
303 name: _,
304 description: _,
305 cmd: _,
306 args: _,
307 env: _,
308 capabilities,
309 client: _,
310 } => {
311 if capabilities.resources.is_some() {
312 let res = self.paginated_request("resources").await?;
313 let resources: Vec<Resource> = res
314 .into_iter()
315 .map(|r| serde_json::from_value(r).unwrap())
316 .collect();
317 Ok(resources)
318 } else {
319 Err(Error::Unsupported)
320 }
321 }
322 Peer::Remote {
323 name: _,
324 description: _,
325 url: _,
326 capabilities,
327 client: _,
328 } => {
329 if capabilities.resources.is_some() {
330 let res = self.paginated_request("resources").await?;
331 let resources: Vec<Resource> = res
332 .into_iter()
333 .map(|r| serde_json::from_value(r).unwrap())
334 .collect();
335 Ok(resources)
336 } else {
337 Err(Error::Unsupported)
338 }
339 }
340 }
341 }
342 pub async fn get_resource(&self, uri: &str) -> Result<Vec<ResourceContents>, Error> {
344 match self {
345 Peer::Local {
346 name: _,
347 description: _,
348 cmd: _,
349 args: _,
350 env: _,
351 capabilities,
352 client,
353 } => {
354 if capabilities.resources.is_some() {
355 if let Some(c) = &client {
356 let value = c
357 .request("resources/read", Some(json!({"uri": uri})))
358 .await
359 .map_err(|_| Error::McpClient("failed to read resource".to_string()))?;
360 let resource_obj: HashMap<String, Value> =
361 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
362 if let Some(val) = resource_obj.get("contents") {
363 let contents: Vec<ResourceContents> =
364 serde_json::from_value(val.clone())
365 .map_err(|_| Error::InvalidResponse)?;
366 Ok(contents)
367 } else {
368 Ok(vec![])
369 }
370 } else {
371 Err(Error::UninitializedClient)
372 }
373 } else {
374 Err(Error::Unsupported)
375 }
376 }
377 Peer::Remote {
378 name: _,
379 description: _,
380 url: _,
381 capabilities,
382 client,
383 } => {
384 if capabilities.resources.is_some() {
385 if let Some(c) = &client {
386 let value = c
387 .request("resources/read", Some(json!({"uri": uri})))
388 .await
389 .map_err(|_| Error::McpClient("failed to read resource".to_string()))?;
390 let resource_obj: HashMap<String, Value> =
391 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
392 if let Some(val) = resource_obj.get("contents") {
393 let contents: Vec<ResourceContents> =
394 serde_json::from_value(val.clone())
395 .map_err(|_| Error::InvalidResponse)?;
396 Ok(contents)
397 } else {
398 Ok(vec![])
399 }
400 } else {
401 Err(Error::UninitializedClient)
402 }
403 } else {
404 Err(Error::Unsupported)
405 }
406 }
407 }
408 }
409 pub async fn list_resource_templates(&self) -> Result<Vec<ResourceTemplate>, Error> {
411 match self {
412 Peer::Local {
413 name: _,
414 description: _,
415 cmd: _,
416 args: _,
417 env: _,
418 capabilities,
419 client,
420 } => {
421 if capabilities.resources.is_some() {
422 if let Some(c) = &client {
423 let value =
424 c.request("resources/templates/list", None)
425 .await
426 .map_err(|_| {
427 Error::McpClient("failed to list templates".to_string())
428 })?;
429 let template_obj: HashMap<String, Value> =
430 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
431 if let Some(val) = template_obj.get("resourceTemplates") {
432 let contents: Vec<ResourceTemplate> =
433 serde_json::from_value(val.clone())
434 .map_err(|_| Error::InvalidResponse)?;
435 Ok(contents)
436 } else {
437 Ok(vec![])
438 }
439 } else {
440 Err(Error::UninitializedClient)
441 }
442 } else {
443 Err(Error::Unsupported)
444 }
445 }
446 Peer::Remote {
447 name: _,
448 description: _,
449 url: _,
450 capabilities,
451 client,
452 } => {
453 if capabilities.resources.is_some() {
454 if let Some(c) = &client {
455 let value =
456 c.request("resources/templates/list", None)
457 .await
458 .map_err(|_| {
459 Error::McpClient("failed to list templates".to_string())
460 })?;
461 let template_obj: HashMap<String, Value> =
462 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
463 if let Some(val) = template_obj.get("resourceTemplates") {
464 let contents: Vec<ResourceTemplate> =
465 serde_json::from_value(val.clone())
466 .map_err(|_| Error::InvalidResponse)?;
467 Ok(contents)
468 } else {
469 Ok(vec![])
470 }
471 } else {
472 Err(Error::UninitializedClient)
473 }
474 } else {
475 Err(Error::Unsupported)
476 }
477 }
478 }
479 }
480 pub async fn subscribe(&self, uri: &str) -> Result<(), Error> {
482 match self {
483 Peer::Local {
484 name: _,
485 description: _,
486 cmd: _,
487 args: _,
488 env: _,
489 capabilities: _,
490 client,
491 } => {
492 if let Some(c) = &client {
493 c.subscribe(uri).await.map_err(|_| {
494 Error::McpClient("failed to subscribe to update notifications".to_string())
495 })?;
496 }
497 Ok(())
498 }
499 Peer::Remote {
500 name: _,
501 description: _,
502 url: _,
503 capabilities: _,
504 client,
505 } => {
506 if let Some(c) = &client {
507 c.subscribe(uri).await.map_err(|_| {
508 Error::McpClient("failed to subscribe to update notifications".to_string())
509 })?;
510 }
511 Ok(())
512 }
513 }
514 }
515 pub async fn list_prompts(&self) -> Result<Vec<Prompt>, Error> {
517 match self {
518 Peer::Local {
519 name: _,
520 description: _,
521 cmd: _,
522 args: _,
523 env: _,
524 capabilities,
525 client: _,
526 } => {
527 if capabilities.prompts.is_some() {
528 let res = self.paginated_request("prompts").await?;
529 let prompts: Vec<Prompt> = res
530 .into_iter()
531 .map(|r| serde_json::from_value(r).unwrap())
532 .collect();
533 Ok(prompts)
534 } else {
535 Err(Error::Unsupported)
536 }
537 }
538 Peer::Remote {
539 name: _,
540 description: _,
541 url: _,
542 capabilities,
543 client: _,
544 } => {
545 if capabilities.prompts.is_some() {
546 let res = self.paginated_request("prompts").await?;
547 let prompts: Vec<Prompt> = res
548 .into_iter()
549 .map(|r| serde_json::from_value(r).unwrap())
550 .collect();
551 Ok(prompts)
552 } else {
553 Err(Error::Unsupported)
554 }
555 }
556 }
557 }
558 pub async fn get_prompt(
560 &self,
561 name: &str,
562 args: Option<Value>,
563 ) -> Result<Vec<PromptMessage>, Error> {
564 match self {
565 Peer::Local {
566 name: _,
567 description: _,
568 cmd: _,
569 args: _,
570 env: _,
571 capabilities,
572 client,
573 } => {
574 if capabilities.prompts.is_some() {
575 if let Some(c) = &client {
576 let value = c
577 .request(
578 "prompts/get",
579 Some(json!({"name": name, "arguments": args})),
580 )
581 .await
582 .map_err(|_| Error::McpClient("failed to get prompt".to_string()))?;
583 let prompt_obj: HashMap<String, Value> =
584 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
585 if let Some(val) = prompt_obj.get("messages") {
586 let prompt: Vec<PromptMessage> = serde_json::from_value(val.clone())
587 .map_err(|_| Error::InvalidResponse)?;
588 Ok(prompt)
589 } else {
590 Err(Error::InvalidResponse)
591 }
592 } else {
593 Err(Error::UninitializedClient)
594 }
595 } else {
596 Err(Error::Unsupported)
597 }
598 }
599 Peer::Remote {
600 name: _,
601 description: _,
602 url: _,
603 capabilities,
604 client,
605 } => {
606 if capabilities.prompts.is_some() {
607 if let Some(c) = &client {
608 let value = c
609 .request(
610 "prompts/get",
611 Some(json!({"name": name, "arguments": args})),
612 )
613 .await
614 .map_err(|_| Error::McpClient("failed to get prompt".to_string()))?;
615 let prompt_obj: HashMap<String, Value> =
616 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
617 if let Some(val) = prompt_obj.get("messages") {
618 let prompt: Vec<PromptMessage> = serde_json::from_value(val.clone())
619 .map_err(|_| Error::InvalidResponse)?;
620 Ok(prompt)
621 } else {
622 Err(Error::InvalidResponse)
623 }
624 } else {
625 Err(Error::UninitializedClient)
626 }
627 } else {
628 Err(Error::Unsupported)
629 }
630 }
631 }
632 }
633
634 pub async fn set_log_level(&self, level: LoggingLevel) -> Result<(), Error> {
635 match self {
636 Peer::Local {
637 name: _,
638 description: _,
639 cmd: _,
640 args: _,
641 env: _,
642 capabilities,
643 client,
644 } => {
645 if capabilities.logging.is_some() {
646 if let Some(c) = &client {
647 c.set_log_level(level)
648 .await
649 .map_err(|_| Error::McpClient("failed to set log level".to_string()))
650 } else {
651 Err(Error::UninitializedClient)
652 }
653 } else {
654 Err(Error::Unsupported)
655 }
656 }
657 Peer::Remote {
658 name: _,
659 description: _,
660 url: _,
661 capabilities,
662 client,
663 } => {
664 if capabilities.logging.is_some() {
665 if let Some(c) = &client {
666 c.set_log_level(level)
667 .await
668 .map_err(|_| Error::McpClient("failed to set log level".to_string()))
669 } else {
670 Err(Error::UninitializedClient)
671 }
672 } else {
673 Err(Error::Unsupported)
674 }
675 }
676 }
677 }
678
679 async fn paginated_request(&self, thing: &str) -> Result<Vec<Value>, Error> {
681 match self {
682 Peer::Local {
683 name: _,
684 description: _,
685 cmd: _,
686 args: _,
687 env: _,
688 capabilities: _,
689 client,
690 } => {
691 if let Some(client) = &client {
692 let mut res: Vec<Value> = vec![];
693 let mut next_cursor: Option<String> = None;
694 let path = format!("{}/list", thing);
695 let value = client.request(path.as_str(), None).await.map_err(|_| {
696 Error::McpClient("failed to perform paginated request".to_string())
697 })?;
698 let resp_obj: HashMap<String, Value> =
699 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
700 if let Some(val) = resp_obj.get(thing) {
701 if let Some(arr) = val.clone().as_array_mut() {
702 res.append(arr);
703 }
704 if let Some(nc_val) = resp_obj.get("nextCursor") {
705 if let Some(nc) = nc_val.as_str() {
706 next_cursor = Some(nc.to_string());
707 }
708 }
709 }
710 while let Some(ref c) = next_cursor {
711 let value = client
712 .request(path.as_str(), Some(json!({ "cursor": c })))
713 .await
714 .map_err(|_| {
715 Error::McpClient("failed to perform paginated request".to_string())
716 })?;
717 let resp_obj: HashMap<String, Value> =
718 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
719 if let Some(val) = resp_obj.get(thing) {
720 if let Some(arr) = val.clone().as_array_mut() {
721 res.append(arr);
722 }
723 if let Some(nc_val) = resp_obj.get("nextCursor") {
724 if let Some(nc) = nc_val.as_str() {
725 next_cursor = Some(nc.to_string());
726 } else {
727 next_cursor = None;
728 }
729 } else {
730 next_cursor = None;
731 }
732 }
733 }
734 Ok(res)
735 } else {
736 Err(Error::UninitializedClient)
737 }
738 }
739 Peer::Remote {
740 name: _,
741 description: _,
742 url: _,
743 capabilities: _,
744 client,
745 } => {
746 if let Some(client) = &client {
747 let mut res: Vec<Value> = vec![];
748 let mut next_cursor: Option<String> = None;
749 let path = format!("{}/list", thing);
750 let value = client.request(path.as_str(), None).await.map_err(|_| {
751 Error::McpClient("failed to perform paginated request".to_string())
752 })?;
753 let resp_obj: HashMap<String, Value> =
754 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
755 if let Some(val) = resp_obj.get(thing) {
756 if let Some(arr) = val.clone().as_array_mut() {
757 res.append(arr);
758 }
759 if let Some(nc_val) = resp_obj.get("nextCursor") {
760 if let Some(nc) = nc_val.as_str() {
761 next_cursor = Some(nc.to_string());
762 }
763 }
764 }
765 while let Some(ref c) = next_cursor {
766 let value = client
767 .request(path.as_str(), Some(json!({ "cursor": c })))
768 .await
769 .map_err(|_| {
770 Error::McpClient("failed to perform paginated request".to_string())
771 })?;
772 let resp_obj: HashMap<String, Value> =
773 serde_json::from_value(value).map_err(|_| Error::InvalidResponse)?;
774 if let Some(val) = resp_obj.get(thing) {
775 if let Some(arr) = val.clone().as_array_mut() {
776 res.append(arr);
777 }
778 if let Some(nc_val) = resp_obj.get("nextCursor") {
779 if let Some(nc) = nc_val.as_str() {
780 next_cursor = Some(nc.to_string());
781 } else {
782 next_cursor = None;
783 }
784 } else {
785 next_cursor = None;
786 }
787 }
788 }
789 Ok(res)
790 } else {
791 Err(Error::UninitializedClient)
792 }
793 }
794 }
795 }
796}
797
798impl PartialEq for Peer {
799 fn eq(&self, other: &Self) -> bool {
800 match self {
801 Peer::Local {
802 name: _,
803 description: _,
804 cmd,
805 args: _,
806 env: _,
807 capabilities: _,
808 client: _,
809 } => match other {
810 Peer::Local {
811 name: _,
812 description: _,
813 cmd: other_cmd,
814 args: _,
815 env: _,
816 capabilities: _,
817 client: _,
818 } => cmd == other_cmd,
819 Peer::Remote {
820 name: _,
821 description: _,
822 url: _,
823 capabilities: _,
824 client: _,
825 } => false,
826 },
827 Peer::Remote {
828 name: _,
829 description: _,
830 url,
831 capabilities: _,
832 client: _,
833 } => match other {
834 Peer::Local {
835 name: _,
836 description: _,
837 cmd: _,
838 args: _,
839 env: _,
840 capabilities: _,
841 client: _,
842 } => false,
843 Peer::Remote {
844 name: _,
845 description: _,
846 url: other_url,
847 capabilities: _,
848 client: _,
849 } => url == other_url,
850 },
851 }
852 }
853}
854
855pub struct PeerResource {
856 pub peer: Peer,
857 pub resource: Resource,
858}
859
860impl fmt::Display for PeerResource {
861 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
862 write!(f, "{}", self.resource.name)
863 }
864}
865
866pub struct PeerPrompt {
867 pub peer: Peer,
868 pub prompt: Prompt,
869}
870
871impl fmt::Display for PeerPrompt {
872 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
873 write!(f, "{}", self.prompt.name)
874 }
875}