Skip to main content

modeldriveprotocol_client/
transport.rs

1use std::collections::HashMap;
2
3use async_trait::async_trait;
4use futures_util::{SinkExt, StreamExt};
5use http::Request;
6use serde_json::{json, Value};
7use tokio::sync::mpsc;
8use tokio::task::JoinHandle;
9use tokio_tungstenite::connect_async;
10use tokio_tungstenite::tungstenite::Message;
11
12use crate::error::MdpClientError;
13use crate::protocol::{ClientToServerMessage, ServerToClientMessage};
14
/// Path, appended to the server URL, under which the HTTP-loop endpoints
/// (`/connect`, `/poll`, `/send`, `/disconnect`) are addressed by default.
const DEFAULT_HTTP_LOOP_PATH: &str = "/mdp/http-loop";

/// Request header that carries the HTTP-loop session id on `/send` and `/disconnect`.
const SESSION_HEADER: &str = "x-mdp-session-id";
17
18#[async_trait]
19pub trait ClientTransport: Send {
20    async fn connect(
21        &mut self,
22    ) -> Result<mpsc::UnboundedReceiver<ServerToClientMessage>, MdpClientError>;
23    async fn send(&mut self, message: ClientToServerMessage) -> Result<(), MdpClientError>;
24    async fn close(&mut self) -> Result<(), MdpClientError>;
25}
26
27pub struct WebSocketClientTransport {
28    server_url: String,
29    headers: HashMap<String, String>,
30    writer: Option<
31        futures_util::stream::SplitSink<
32            tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
33            Message,
34        >,
35    >,
36    read_task: Option<JoinHandle<()>>,
37}
38
39impl WebSocketClientTransport {
40    pub fn new(server_url: impl Into<String>, headers: Option<HashMap<String, String>>) -> Self {
41        Self {
42            server_url: server_url.into(),
43            headers: headers.unwrap_or_default(),
44            writer: None,
45            read_task: None,
46        }
47    }
48}
49
50#[async_trait]
51impl ClientTransport for WebSocketClientTransport {
52    async fn connect(
53        &mut self,
54    ) -> Result<mpsc::UnboundedReceiver<ServerToClientMessage>, MdpClientError> {
55        let mut request = Request::builder().uri(&self.server_url);
56        for (key, value) in &self.headers {
57            request = request.header(key, value);
58        }
59        let request = request
60            .body(())
61            .map_err(|error| MdpClientError::Transport(error.to_string()))?;
62
63        let (stream, _) = connect_async(request).await?;
64        let (writer, mut reader) = stream.split();
65        self.writer = Some(writer);
66
67        let (sender, receiver) = mpsc::unbounded_channel();
68        self.read_task = Some(tokio::spawn(async move {
69            while let Some(frame) = reader.next().await {
70                let Ok(frame) = frame else {
71                    break;
72                };
73
74                match frame {
75                    Message::Text(text) => {
76                        let Ok(message) = ServerToClientMessage::from_text(&text) else {
77                            continue;
78                        };
79                        if sender.send(message).is_err() {
80                            break;
81                        }
82                    }
83                    Message::Binary(payload) => {
84                        let Ok(text) = String::from_utf8(payload.to_vec()) else {
85                            continue;
86                        };
87                        let Ok(message) = ServerToClientMessage::from_text(&text) else {
88                            continue;
89                        };
90                        if sender.send(message).is_err() {
91                            break;
92                        }
93                    }
94                    Message::Close(_) => break,
95                    _ => continue,
96                }
97            }
98        }));
99
100        Ok(receiver)
101    }
102
103    async fn send(&mut self, message: ClientToServerMessage) -> Result<(), MdpClientError> {
104        let Some(writer) = &mut self.writer else {
105            return Err(MdpClientError::NotConnected);
106        };
107        writer
108            .send(Message::Text(serde_json::to_string(&message)?.into()))
109            .await?;
110        Ok(())
111    }
112
113    async fn close(&mut self) -> Result<(), MdpClientError> {
114        if let Some(writer) = &mut self.writer {
115            writer.close().await?;
116        }
117        self.writer = None;
118        if let Some(task) = self.read_task.take() {
119            task.abort();
120        }
121        Ok(())
122    }
123}
124
125pub struct HttpLoopClientTransport {
126    server_url: String,
127    endpoint_path: String,
128    headers: HashMap<String, String>,
129    poll_wait_ms: u64,
130    client: reqwest::Client,
131    session_id: Option<String>,
132    poll_task: Option<JoinHandle<()>>,
133}
134
135impl HttpLoopClientTransport {
136    pub fn new(server_url: impl Into<String>, headers: Option<HashMap<String, String>>) -> Self {
137        Self {
138            server_url: server_url.into(),
139            endpoint_path: DEFAULT_HTTP_LOOP_PATH.to_string(),
140            headers: headers.unwrap_or_default(),
141            poll_wait_ms: 25_000,
142            client: reqwest::Client::new(),
143            session_id: None,
144            poll_task: None,
145        }
146    }
147
148    fn endpoint_url(&self, suffix: &str) -> String {
149        format!(
150            "{}{}{}",
151            self.server_url.trim_end_matches('/'),
152            self.endpoint_path,
153            suffix
154        )
155    }
156}
157
158#[async_trait]
159impl ClientTransport for HttpLoopClientTransport {
160    async fn connect(
161        &mut self,
162    ) -> Result<mpsc::UnboundedReceiver<ServerToClientMessage>, MdpClientError> {
163        let response = self
164            .client
165            .post(self.endpoint_url("/connect"))
166            .headers(reqwest::header::HeaderMap::new())
167            .json(&json!({}))
168            .send()
169            .await?;
170        let response = response.error_for_status()?;
171        let payload: Value = response.json().await?;
172        let session_id = payload
173            .get("sessionId")
174            .and_then(Value::as_str)
175            .ok_or_else(|| MdpClientError::Protocol("invalid HTTP loop handshake response".to_string()))?
176            .to_string();
177
178        self.session_id = Some(session_id.clone());
179        let client = self.client.clone();
180        let base_url = self.server_url.clone();
181        let endpoint_path = self.endpoint_path.clone();
182        let wait_ms = self.poll_wait_ms;
183        let headers = self.headers.clone();
184
185        let (sender, receiver) = mpsc::unbounded_channel();
186        self.poll_task = Some(tokio::spawn(async move {
187            loop {
188                let response = client
189                    .get(format!(
190                        "{}{}{}",
191                        base_url.trim_end_matches('/'),
192                        endpoint_path,
193                        "/poll"
194                    ))
195                    .headers(headers_to_reqwest(&headers))
196                    .query(&[("sessionId", session_id.as_str()), ("waitMs", &wait_ms.to_string())])
197                    .send()
198                    .await;
199
200                let Ok(response) = response else {
201                    break;
202                };
203
204                if response.status() == reqwest::StatusCode::NO_CONTENT {
205                    continue;
206                }
207
208                let Ok(response) = response.error_for_status() else {
209                    break;
210                };
211
212                let Ok(payload) = response.json::<Value>().await else {
213                    break;
214                };
215
216                let Some(message) = payload.get("message").cloned() else {
217                    continue;
218                };
219
220                let Ok(message) = ServerToClientMessage::from_value(message) else {
221                    continue;
222                };
223
224                if sender.send(message).is_err() {
225                    break;
226                }
227            }
228        }));
229
230        Ok(receiver)
231    }
232
233    async fn send(&mut self, message: ClientToServerMessage) -> Result<(), MdpClientError> {
234        let Some(session_id) = &self.session_id else {
235            return Err(MdpClientError::NotConnected);
236        };
237        self.client
238            .post(self.endpoint_url("/send"))
239            .headers(headers_to_reqwest(&self.headers))
240            .header(SESSION_HEADER, session_id)
241            .json(&json!({ "message": message }))
242            .send()
243            .await?
244            .error_for_status()?;
245        Ok(())
246    }
247
248    async fn close(&mut self) -> Result<(), MdpClientError> {
249        if let Some(task) = self.poll_task.take() {
250            task.abort();
251        }
252        if let Some(session_id) = self.session_id.take() {
253            let _ = self
254                .client
255                .post(self.endpoint_url("/disconnect"))
256                .headers(headers_to_reqwest(&self.headers))
257                .header(SESSION_HEADER, session_id)
258                .json(&json!({}))
259                .send()
260                .await;
261        }
262        Ok(())
263    }
264}
265
266fn headers_to_reqwest(headers: &HashMap<String, String>) -> reqwest::header::HeaderMap {
267    let mut map = reqwest::header::HeaderMap::new();
268    for (key, value) in headers {
269        if let (Ok(name), Ok(value)) = (
270            reqwest::header::HeaderName::from_bytes(key.as_bytes()),
271            reqwest::header::HeaderValue::from_str(value),
272        ) {
273            map.insert(name, value);
274        }
275    }
276    map
277}