mcp_sdk_rs/session/mod.rs

1use crate::{
2    client::{ClientHandler, DefaultClientHandler},
3    error::Error,
4    protocol::Response,
5    transport::{stdio::StdioTransport, Message, Transport},
6};
7use futures::StreamExt;
8use std::sync::Arc;
9use tokio::{
10    process::Command,
11    sync::{
12        mpsc::{UnboundedReceiver, UnboundedSender},
13        Mutex,
14    },
15};
16
17// #[derive(Clone)]
18pub enum Session {
19    Local {
20        handler: Option<Arc<dyn ClientHandler>>,
21        command: Command,
22        receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
23        sender: Arc<UnboundedSender<Message>>,
24    },
25    Remote {
26        handler: Option<Arc<dyn ClientHandler>>,
27        transport: Arc<dyn Transport>,
28        receiver: Arc<Mutex<UnboundedReceiver<Message>>>,
29        sender: Arc<UnboundedSender<Message>>,
30    },
31}
32impl Session {
33    /// Start the session and listen for messages
34    pub async fn start(self) -> Result<(), Error> {
35        match self {
36            Session::Local {
37                handler,
38                command,
39                receiver,
40                sender,
41            } => {
42                // spawn the child process
43                let mut pm = crate::process::ProcessManager::new();
44                let (output_tx, output_rx) = tokio::sync::mpsc::channel(100);
45                let process_tx = pm
46                    .start_process(command, output_tx.clone())
47                    .await
48                    .expect("a spawned subprocess");
49
50                let transport = Arc::new(StdioTransport::new(output_rx, process_tx));
51                let handler = handler.unwrap_or(Arc::new(DefaultClientHandler));
52                let t = transport.clone();
53
54                // listen for messages from the server
55                tokio::spawn(async move {
56                    let mut stream = t.receive();
57                    while let Some(result) = stream.next().await {
58                        match result {
59                            Ok(message) => match &message {
60                                Message::Request(r) => {
61                                    let res = handler
62                                        .handle_request(r.method.clone(), r.params.clone())
63                                        .await;
64                                    if t.send(Message::Response(Response::success(
65                                        r.id.clone(),
66                                        Some(res.unwrap()),
67                                    )))
68                                    .await
69                                    .is_err()
70                                    {
71                                        break;
72                                    }
73                                }
74                                Message::Response(_) => {
75                                    if sender.send(message).is_err() {
76                                        break;
77                                    }
78                                }
79                                Message::Notification(n) => {
80                                    if handler
81                                        .handle_notification(n.method.clone(), n.params.clone())
82                                        .await
83                                        .is_err()
84                                    {
85                                        break;
86                                    }
87                                }
88                            },
89                            Err(_) => break,
90                        }
91                    }
92                });
93                // listen for messages to send to the server
94                let rx_clone = receiver.clone();
95                let tx_clone = transport.clone();
96                tokio::spawn(async move {
97                    let mut stream = rx_clone.lock().await;
98                    while let Some(message) = stream.recv().await {
99                        tx_clone.send(message).await.unwrap();
100                    }
101                });
102                Ok(())
103            }
104            Session::Remote {
105                handler,
106                transport,
107                receiver,
108                sender,
109            } => {
110                let t = transport.clone();
111                let handler = handler.unwrap_or(Arc::new(DefaultClientHandler));
112                // listen for messages from the server
113                tokio::spawn(async move {
114                    let mut stream = t.receive();
115                    while let Some(result) = stream.next().await {
116                        match result {
117                            Ok(message) => match &message {
118                                Message::Request(r) => {
119                                    let res = handler
120                                        .handle_request(r.method.clone(), r.params.clone())
121                                        .await;
122                                    if t.send(Message::Response(Response::success(
123                                        r.id.clone(),
124                                        Some(res.unwrap()),
125                                    )))
126                                    .await
127                                    .is_err()
128                                    {
129                                        break;
130                                    }
131                                }
132                                Message::Response(_) => {
133                                    if sender.send(message).is_err() {
134                                        break;
135                                    }
136                                }
137                                Message::Notification(n) => {
138                                    if handler
139                                        .handle_notification(n.method.clone(), n.params.clone())
140                                        .await
141                                        .is_err()
142                                    {
143                                        break;
144                                    }
145                                }
146                            },
147                            Err(_) => break,
148                        }
149                    }
150                });
151                // listen for messages to send to the server
152                let rx_clone = receiver.clone();
153                let tx_clone = transport.clone();
154                tokio::spawn(async move {
155                    let mut stream = rx_clone.lock().await;
156                    while let Some(message) = stream.recv().await {
157                        tx_clone.send(message).await.unwrap();
158                    }
159                });
160                Ok(())
161            }
162        }
163    }
164}