mcp_sdk_rs/session/mod.rs
1use crate::{
2 client::{ClientHandler, DefaultClientHandler},
3 error::Error,
4 protocol::Response,
5 transport::{stdio::StdioTransport, Message, Transport},
6};
7use futures::StreamExt;
8use std::sync::Arc;
9use tokio::{
10 process::Command,
11 sync::{
12 mpsc::{UnboundedReceiver, UnboundedSender},
13 Mutex,
14 },
15};
16
17// #[derive(Clone)]
18pub enum Session {
19 Local {
20 handler: Option<Arc<dyn ClientHandler>>,
21 command: Command,
22 receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
23 sender: Arc<UnboundedSender<Message>>,
24 },
25 Remote {
26 handler: Option<Arc<dyn ClientHandler>>,
27 transport: Arc<dyn Transport>,
28 receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
29 sender: Arc<UnboundedSender<Message>>,
30 },
31}
32impl Session {
33 /// Start the session and listen for messages
34 pub async fn start(self) -> Result<(), Error> {
35 match self {
36 Session::Local {
37 handler,
38 command,
39 receiver,
40 sender,
41 } => {
42 // spawn the child process
43 let mut pm = crate::process::ProcessManager::new();
44 let (output_tx, output_rx) = tokio::sync::mpsc::channel(100);
45 let process_tx = pm
46 .start_process(command, output_tx.clone())
47 .await
48 .expect("a spawned subprocess");
49
50 let transport = Arc::new(StdioTransport::new(output_rx, process_tx));
51 let handler = handler.unwrap_or(Arc::new(DefaultClientHandler));
52 let t = transport.clone();
53
54 // listen for messages from the server
55 tokio::spawn(async move {
56 let mut stream = t.receive();
57 while let Some(result) = stream.next().await {
58 match result {
59 Ok(message) => match &message {
60 Message::Request(r) => {
61 let res = handler
62 .handle_request(r.method.clone(), r.params.clone())
63 .await;
64 if t.send(Message::Response(Response::success(
65 r.id.clone(),
66 Some(res.unwrap()),
67 )))
68 .await
69 .is_err()
70 {
71 break;
72 }
73 }
74 Message::Response(_) => {
75 if sender.send(message).is_err() {
76 break;
77 }
78 }
79 Message::Notification(n) => {
80 if handler
81 .handle_notification(n.method.clone(), n.params.clone())
82 .await
83 .is_err()
84 {
85 break;
86 }
87 }
88 },
89 Err(_) => break,
90 }
91 }
92 });
93 // listen for messages to send to the server
94 let rx_clone = receiver.clone();
95 let tx_clone = transport.clone();
96 tokio::spawn(async move {
97 let mut stream = rx_clone.lock().await;
98 while let Some(message) = stream.recv().await {
99 tx_clone.send(message).await.unwrap();
100 }
101 });
102 Ok(())
103 }
104 Session::Remote {
105 handler,
106 transport,
107 receiver,
108 sender,
109 } => {
110 let t = transport.clone();
111 let handler = handler.unwrap_or(Arc::new(DefaultClientHandler));
112 // listen for messages from the server
113 tokio::spawn(async move {
114 let mut stream = t.receive();
115 while let Some(result) = stream.next().await {
116 match result {
117 Ok(message) => match &message {
118 Message::Request(r) => {
119 let res = handler
120 .handle_request(r.method.clone(), r.params.clone())
121 .await;
122 if t.send(Message::Response(Response::success(
123 r.id.clone(),
124 Some(res.unwrap()),
125 )))
126 .await
127 .is_err()
128 {
129 break;
130 }
131 }
132 Message::Response(_) => {
133 if sender.send(message).is_err() {
134 break;
135 }
136 }
137 Message::Notification(n) => {
138 if handler
139 .handle_notification(n.method.clone(), n.params.clone())
140 .await
141 .is_err()
142 {
143 break;
144 }
145 }
146 },
147 Err(_) => break,
148 }
149 }
150 });
151 // listen for messages to send to the server
152 let rx_clone = receiver.clone();
153 let tx_clone = transport.clone();
154 tokio::spawn(async move {
155 let mut stream = rx_clone.lock().await;
156 while let Some(message) = stream.recv().await {
157 tx_clone.send(message).await.unwrap();
158 }
159 });
160 Ok(())
161 }
162 }
163 }
164}