mcp_sdk_rs/session/mod.rs

1use crate::{
2    client::{ClientHandler, DefaultClientHandler},
3    error::Error,
4    protocol::Response,
5    transport::{stdio::StdioTransport, Message, Transport},
6};
7use futures::StreamExt;
8use std::sync::Arc;
9use tokio::{
10    process::Command,
11    sync::{
12        mpsc::{UnboundedReceiver, UnboundedSender},
13        Mutex,
14    },
15};
16
17pub enum Session {
18    Local {
19        handler: Option<Arc<dyn ClientHandler>>,
20        command: Command,
21        receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
22        sender: Arc<UnboundedSender<Message>>,
23    },
24    Remote {
25        handler: Option<Arc<dyn ClientHandler>>,
26        transport: Arc<dyn Transport>,
27        receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
28        sender: Arc<UnboundedSender<Message>>,
29    },
30}
31impl Session {
32    /// Start the session and listen for messages
33    pub async fn start(self) -> Result<(), Error> {
34        match self {
35            Session::Local {
36                handler,
37                command,
38                receiver,
39                sender,
40            } => {
41                // spawn the child process - wrap ProcessManager to ensure cleanup
42                let pm = Arc::new(tokio::sync::Mutex::new(
43                    crate::process::ProcessManager::new(),
44                ));
45                let (output_tx, output_rx) = tokio::sync::mpsc::channel(100);
46                let process_tx = {
47                    let mut manager = pm.lock().await;
48                    manager
49                        .start_process(command, output_tx.clone())
50                        .await
51                        .expect("a spawned subprocess")
52                };
53
54                let transport = Arc::new(StdioTransport::new(output_rx, process_tx));
55                let handler = handler.unwrap_or(Arc::new(DefaultClientHandler));
56                let t = transport.clone();
57
58                // Clone ProcessManager for cleanup tasks
59                let pm_for_receiver_task = pm.clone();
60                let pm_for_sender_task = pm.clone();
61
62                // listen for messages from the server
63                tokio::spawn(async move {
64                    let mut stream = t.receive();
65                    while let Some(result) = stream.next().await {
66                        match result {
67                            Ok(message) => match &message {
68                                Message::Request(r) => {
69                                    let res = handler
70                                        .handle_request(r.method.clone(), r.params.clone())
71                                        .await;
72                                    if t.send(Message::Response(Response::success(
73                                        r.id.clone(),
74                                        Some(res.unwrap()),
75                                    )))
76                                    .await
77                                    .is_err()
78                                    {
79                                        break;
80                                    }
81                                }
82                                Message::Response(_) => {
83                                    if sender.send(message).is_err() {
84                                        break;
85                                    }
86                                }
87                                Message::Notification(n) => {
88                                    if handler
89                                        .handle_notification(n.method.clone(), n.params.clone())
90                                        .await
91                                        .is_err()
92                                    {
93                                        break;
94                                    }
95                                }
96                            },
97                            Err(_) => break,
98                        }
99                    }
100                    // Clean up the process when the receiver stream ends
101                    let mut manager = pm_for_receiver_task.lock().await;
102                    manager.shutdown().await;
103                });
104                // listen for messages to send to the server
105                let rx_clone = receiver.clone();
106                let tx_clone = transport.clone();
107                tokio::spawn(async move {
108                    let mut stream = rx_clone.lock().await;
109                    while let Some(message) = stream.recv().await {
110                        if tx_clone.send(message).await.is_err() {
111                            break;
112                        }
113                    }
114                    // Clean up the process when the sender stream ends
115                    let mut manager = pm_for_sender_task.lock().await;
116                    manager.shutdown().await;
117                });
118
119                Ok(())
120            }
121            Session::Remote {
122                handler,
123                transport,
124                receiver,
125                sender,
126            } => {
127                let t = transport.clone();
128                let handler = handler.unwrap_or(Arc::new(DefaultClientHandler));
129                // listen for messages from the server
130                tokio::spawn(async move {
131                    let mut stream = t.receive();
132                    while let Some(result) = stream.next().await {
133                        match result {
134                            Ok(message) => match &message {
135                                Message::Request(r) => {
136                                    let res = handler
137                                        .handle_request(r.method.clone(), r.params.clone())
138                                        .await;
139                                    if t.send(Message::Response(Response::success(
140                                        r.id.clone(),
141                                        Some(res.unwrap()),
142                                    )))
143                                    .await
144                                    .is_err()
145                                    {
146                                        break;
147                                    }
148                                }
149                                Message::Response(_) => {
150                                    if sender.send(message).is_err() {
151                                        break;
152                                    }
153                                }
154                                Message::Notification(n) => {
155                                    if handler
156                                        .handle_notification(n.method.clone(), n.params.clone())
157                                        .await
158                                        .is_err()
159                                    {
160                                        break;
161                                    }
162                                }
163                            },
164                            Err(_) => break,
165                        }
166                    }
167                });
168                // listen for messages to send to the server
169                let rx_clone = receiver.clone();
170                let tx_clone = transport.clone();
171                tokio::spawn(async move {
172                    let mut stream = rx_clone.lock().await;
173                    while let Some(message) = stream.recv().await {
174                        if tx_clone.send(message).await.is_err() {
175                            break;
176                        }
177                    }
178                });
179                Ok(())
180            }
181        }
182    }
183}