mcp_sdk_rs/session/mod.rs
1use crate::{
2 client::{ClientHandler, DefaultClientHandler},
3 error::Error,
4 protocol::Response,
5 transport::{stdio::StdioTransport, Message, Transport},
6};
7use futures::StreamExt;
8use std::sync::Arc;
9use tokio::{
10 process::Command,
11 sync::{
12 mpsc::{UnboundedReceiver, UnboundedSender},
13 Mutex,
14 },
15};
16
17pub enum Session {
18 Local {
19 handler: Option<Arc<dyn ClientHandler>>,
20 command: Command,
21 receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
22 sender: Arc<UnboundedSender<Message>>,
23 },
24 Remote {
25 handler: Option<Arc<dyn ClientHandler>>,
26 transport: Arc<dyn Transport>,
27 receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
28 sender: Arc<UnboundedSender<Message>>,
29 },
30}
31impl Session {
32 /// Start the session and listen for messages
33 pub async fn start(self) -> Result<(), Error> {
34 match self {
35 Session::Local {
36 handler,
37 command,
38 receiver,
39 sender,
40 } => {
41 // spawn the child process - wrap ProcessManager to ensure cleanup
42 let pm = Arc::new(tokio::sync::Mutex::new(
43 crate::process::ProcessManager::new(),
44 ));
45 let (output_tx, output_rx) = tokio::sync::mpsc::channel(100);
46 let process_tx = {
47 let mut manager = pm.lock().await;
48 manager
49 .start_process(command, output_tx.clone())
50 .await
51 .expect("a spawned subprocess")
52 };
53
54 let transport = Arc::new(StdioTransport::new(output_rx, process_tx));
55 let handler = handler.unwrap_or(Arc::new(DefaultClientHandler));
56 let t = transport.clone();
57
58 // Clone ProcessManager for cleanup tasks
59 let pm_for_receiver_task = pm.clone();
60 let pm_for_sender_task = pm.clone();
61
62 // listen for messages from the server
63 tokio::spawn(async move {
64 let mut stream = t.receive();
65 while let Some(result) = stream.next().await {
66 match result {
67 Ok(message) => match &message {
68 Message::Request(r) => {
69 let res = handler
70 .handle_request(r.method.clone(), r.params.clone())
71 .await;
72 if t.send(Message::Response(Response::success(
73 r.id.clone(),
74 Some(res.unwrap()),
75 )))
76 .await
77 .is_err()
78 {
79 break;
80 }
81 }
82 Message::Response(_) => {
83 if sender.send(message).is_err() {
84 break;
85 }
86 }
87 Message::Notification(n) => {
88 if handler
89 .handle_notification(n.method.clone(), n.params.clone())
90 .await
91 .is_err()
92 {
93 break;
94 }
95 }
96 },
97 Err(_) => break,
98 }
99 }
100 // Clean up the process when the receiver stream ends
101 let mut manager = pm_for_receiver_task.lock().await;
102 manager.shutdown().await;
103 });
104 // listen for messages to send to the server
105 let rx_clone = receiver.clone();
106 let tx_clone = transport.clone();
107 tokio::spawn(async move {
108 let mut stream = rx_clone.lock().await;
109 while let Some(message) = stream.recv().await {
110 if tx_clone.send(message).await.is_err() {
111 break;
112 }
113 }
114 // Clean up the process when the sender stream ends
115 let mut manager = pm_for_sender_task.lock().await;
116 manager.shutdown().await;
117 });
118
119 Ok(())
120 }
121 Session::Remote {
122 handler,
123 transport,
124 receiver,
125 sender,
126 } => {
127 let t = transport.clone();
128 let handler = handler.unwrap_or(Arc::new(DefaultClientHandler));
129 // listen for messages from the server
130 tokio::spawn(async move {
131 let mut stream = t.receive();
132 while let Some(result) = stream.next().await {
133 match result {
134 Ok(message) => match &message {
135 Message::Request(r) => {
136 let res = handler
137 .handle_request(r.method.clone(), r.params.clone())
138 .await;
139 if t.send(Message::Response(Response::success(
140 r.id.clone(),
141 Some(res.unwrap()),
142 )))
143 .await
144 .is_err()
145 {
146 break;
147 }
148 }
149 Message::Response(_) => {
150 if sender.send(message).is_err() {
151 break;
152 }
153 }
154 Message::Notification(n) => {
155 if handler
156 .handle_notification(n.method.clone(), n.params.clone())
157 .await
158 .is_err()
159 {
160 break;
161 }
162 }
163 },
164 Err(_) => break,
165 }
166 }
167 });
168 // listen for messages to send to the server
169 let rx_clone = receiver.clone();
170 let tx_clone = transport.clone();
171 tokio::spawn(async move {
172 let mut stream = rx_clone.lock().await;
173 while let Some(message) = stream.recv().await {
174 if tx_clone.send(message).await.is_err() {
175 break;
176 }
177 }
178 });
179 Ok(())
180 }
181 }
182 }
183}