// modeldriveprotocol_client/transport.rs
use std::collections::HashMap;

use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use http::Request;
use serde_json::{json, Value};
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::Message;

use crate::error::MdpClientError;
use crate::protocol::{ClientToServerMessage, ServerToClientMessage};
14
/// Path, appended to the server URL, under which the HTTP loop endpoints
/// (`/connect`, `/poll`, `/send`, `/disconnect`) are addressed by default.
const DEFAULT_HTTP_LOOP_PATH: &str = "/mdp/http-loop";
/// Request header that carries the session id on HTTP loop `/send` and
/// `/disconnect` calls.
const SESSION_HEADER: &str = "x-mdp-session-id";
17
18#[async_trait]
19pub trait ClientTransport: Send {
20 async fn connect(
21 &mut self,
22 ) -> Result<mpsc::UnboundedReceiver<ServerToClientMessage>, MdpClientError>;
23 async fn send(&mut self, message: ClientToServerMessage) -> Result<(), MdpClientError>;
24 async fn close(&mut self) -> Result<(), MdpClientError>;
25}
26
27pub struct WebSocketClientTransport {
28 server_url: String,
29 headers: HashMap<String, String>,
30 writer: Option<
31 futures_util::stream::SplitSink<
32 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
33 Message,
34 >,
35 >,
36 read_task: Option<JoinHandle<()>>,
37}
38
39impl WebSocketClientTransport {
40 pub fn new(server_url: impl Into<String>, headers: Option<HashMap<String, String>>) -> Self {
41 Self {
42 server_url: server_url.into(),
43 headers: headers.unwrap_or_default(),
44 writer: None,
45 read_task: None,
46 }
47 }
48}
49
50#[async_trait]
51impl ClientTransport for WebSocketClientTransport {
52 async fn connect(
53 &mut self,
54 ) -> Result<mpsc::UnboundedReceiver<ServerToClientMessage>, MdpClientError> {
55 let mut request = Request::builder().uri(&self.server_url);
56 for (key, value) in &self.headers {
57 request = request.header(key, value);
58 }
59 let request = request
60 .body(())
61 .map_err(|error| MdpClientError::Transport(error.to_string()))?;
62
63 let (stream, _) = connect_async(request).await?;
64 let (writer, mut reader) = stream.split();
65 self.writer = Some(writer);
66
67 let (sender, receiver) = mpsc::unbounded_channel();
68 self.read_task = Some(tokio::spawn(async move {
69 while let Some(frame) = reader.next().await {
70 let Ok(frame) = frame else {
71 break;
72 };
73
74 match frame {
75 Message::Text(text) => {
76 let Ok(message) = ServerToClientMessage::from_text(&text) else {
77 continue;
78 };
79 if sender.send(message).is_err() {
80 break;
81 }
82 }
83 Message::Binary(payload) => {
84 let Ok(text) = String::from_utf8(payload.to_vec()) else {
85 continue;
86 };
87 let Ok(message) = ServerToClientMessage::from_text(&text) else {
88 continue;
89 };
90 if sender.send(message).is_err() {
91 break;
92 }
93 }
94 Message::Close(_) => break,
95 _ => continue,
96 }
97 }
98 }));
99
100 Ok(receiver)
101 }
102
103 async fn send(&mut self, message: ClientToServerMessage) -> Result<(), MdpClientError> {
104 let Some(writer) = &mut self.writer else {
105 return Err(MdpClientError::NotConnected);
106 };
107 writer
108 .send(Message::Text(serde_json::to_string(&message)?.into()))
109 .await?;
110 Ok(())
111 }
112
113 async fn close(&mut self) -> Result<(), MdpClientError> {
114 if let Some(writer) = &mut self.writer {
115 writer.close().await?;
116 }
117 self.writer = None;
118 if let Some(task) = self.read_task.take() {
119 task.abort();
120 }
121 Ok(())
122 }
123}
124
125pub struct HttpLoopClientTransport {
126 server_url: String,
127 endpoint_path: String,
128 headers: HashMap<String, String>,
129 poll_wait_ms: u64,
130 client: reqwest::Client,
131 session_id: Option<String>,
132 poll_task: Option<JoinHandle<()>>,
133}
134
135impl HttpLoopClientTransport {
136 pub fn new(server_url: impl Into<String>, headers: Option<HashMap<String, String>>) -> Self {
137 Self {
138 server_url: server_url.into(),
139 endpoint_path: DEFAULT_HTTP_LOOP_PATH.to_string(),
140 headers: headers.unwrap_or_default(),
141 poll_wait_ms: 25_000,
142 client: reqwest::Client::new(),
143 session_id: None,
144 poll_task: None,
145 }
146 }
147
148 fn endpoint_url(&self, suffix: &str) -> String {
149 format!(
150 "{}{}{}",
151 self.server_url.trim_end_matches('/'),
152 self.endpoint_path,
153 suffix
154 )
155 }
156}
157
158#[async_trait]
159impl ClientTransport for HttpLoopClientTransport {
160 async fn connect(
161 &mut self,
162 ) -> Result<mpsc::UnboundedReceiver<ServerToClientMessage>, MdpClientError> {
163 let response = self
164 .client
165 .post(self.endpoint_url("/connect"))
166 .headers(reqwest::header::HeaderMap::new())
167 .json(&json!({}))
168 .send()
169 .await?;
170 let response = response.error_for_status()?;
171 let payload: Value = response.json().await?;
172 let session_id = payload
173 .get("sessionId")
174 .and_then(Value::as_str)
175 .ok_or_else(|| MdpClientError::Protocol("invalid HTTP loop handshake response".to_string()))?
176 .to_string();
177
178 self.session_id = Some(session_id.clone());
179 let client = self.client.clone();
180 let base_url = self.server_url.clone();
181 let endpoint_path = self.endpoint_path.clone();
182 let wait_ms = self.poll_wait_ms;
183 let headers = self.headers.clone();
184
185 let (sender, receiver) = mpsc::unbounded_channel();
186 self.poll_task = Some(tokio::spawn(async move {
187 loop {
188 let response = client
189 .get(format!(
190 "{}{}{}",
191 base_url.trim_end_matches('/'),
192 endpoint_path,
193 "/poll"
194 ))
195 .headers(headers_to_reqwest(&headers))
196 .query(&[("sessionId", session_id.as_str()), ("waitMs", &wait_ms.to_string())])
197 .send()
198 .await;
199
200 let Ok(response) = response else {
201 break;
202 };
203
204 if response.status() == reqwest::StatusCode::NO_CONTENT {
205 continue;
206 }
207
208 let Ok(response) = response.error_for_status() else {
209 break;
210 };
211
212 let Ok(payload) = response.json::<Value>().await else {
213 break;
214 };
215
216 let Some(message) = payload.get("message").cloned() else {
217 continue;
218 };
219
220 let Ok(message) = ServerToClientMessage::from_value(message) else {
221 continue;
222 };
223
224 if sender.send(message).is_err() {
225 break;
226 }
227 }
228 }));
229
230 Ok(receiver)
231 }
232
233 async fn send(&mut self, message: ClientToServerMessage) -> Result<(), MdpClientError> {
234 let Some(session_id) = &self.session_id else {
235 return Err(MdpClientError::NotConnected);
236 };
237 self.client
238 .post(self.endpoint_url("/send"))
239 .headers(headers_to_reqwest(&self.headers))
240 .header(SESSION_HEADER, session_id)
241 .json(&json!({ "message": message }))
242 .send()
243 .await?
244 .error_for_status()?;
245 Ok(())
246 }
247
248 async fn close(&mut self) -> Result<(), MdpClientError> {
249 if let Some(task) = self.poll_task.take() {
250 task.abort();
251 }
252 if let Some(session_id) = self.session_id.take() {
253 let _ = self
254 .client
255 .post(self.endpoint_url("/disconnect"))
256 .headers(headers_to_reqwest(&self.headers))
257 .header(SESSION_HEADER, session_id)
258 .json(&json!({}))
259 .send()
260 .await;
261 }
262 Ok(())
263 }
264}
265
266fn headers_to_reqwest(headers: &HashMap<String, String>) -> reqwest::header::HeaderMap {
267 let mut map = reqwest::header::HeaderMap::new();
268 for (key, value) in headers {
269 if let (Ok(name), Ok(value)) = (
270 reqwest::header::HeaderName::from_bytes(key.as_bytes()),
271 reqwest::header::HeaderValue::from_str(value),
272 ) {
273 map.insert(name, value);
274 }
275 }
276 map
277}