// edgehog_device_runtime_forwarder/connection/mod.rs
pub mod http;
10pub mod websocket;
11
12use std::ops::Deref;
13
14use async_trait::async_trait;
15use thiserror::Error as ThisError;
16use tokio::sync::mpsc::Sender;
17use tokio::task::{JoinError, JoinHandle};
18use tokio_tungstenite::tungstenite::Error as TungError;
19use tracing::{error, instrument, trace};
20
21use crate::messages::{Id, ProtoMessage, ProtocolError, WebSocketMessage as ProtoWebSocketMessage};
22
/// Capacity of the bounded channels that carry WebSocket messages.
///
/// The tests in this module pass it to `tokio::sync::mpsc::channel` when creating the channel
/// wrapped by [`WriteHandle::Ws`]; other users presumably live outside this file.
pub(crate) const WS_CHANNEL_SIZE: usize = 50;
26
27#[non_exhaustive]
29#[derive(displaydoc::Display, ThisError, Debug)]
30pub enum ConnectionError {
31 Channel(&'static str),
33 Http(#[from] reqwest::Error),
35 Protobuf(#[from] ProtocolError),
37 JoinError(#[from] JoinError),
39 WrongProtocol,
41 WebSocket(#[from] Box<TungError>),
43 Connecting,
45}
46
/// Write side of a connection, stored in a [`ConnectionHandle`].
///
/// [`ConnectionHandle::send`] matches on this value to decide whether a message can be
/// forwarded to the connection.
#[derive(Debug)]
pub(crate) enum WriteHandle {
    /// HTTP connection: carries no sender, so [`ConnectionHandle::send`] rejects it with
    /// [`ConnectionError::Channel`].
    Http,
    /// WebSocket connection: sender into which [`ConnectionHandle::send`] pushes the
    /// WebSocket message extracted from a [`ProtoMessage`].
    Ws(Sender<ProtoWebSocketMessage>),
}
55
/// Handle to a spawned connection task, paired with the means to write to that connection.
///
/// Values are created by [`Connection::spawn`].
#[derive(Debug)]
pub(crate) struct ConnectionHandle {
    /// Join handle of the tokio task running the connection.
    pub(crate) handle: JoinHandle<()>,
    /// Write side of the connection, used by [`ConnectionHandle::send`].
    pub(crate) connection: WriteHandle,
}
64
65impl Deref for ConnectionHandle {
66 type Target = JoinHandle<()>;
67
68 fn deref(&self) -> &Self::Target {
69 &self.handle
70 }
71}
72
73impl ConnectionHandle {
74 #[instrument]
77 pub(crate) async fn send(&self, msg: ProtoMessage) -> Result<(), ConnectionError> {
78 match &self.connection {
79 WriteHandle::Http => Err(ConnectionError::Channel(
80 "sending messages over a channel is only allowed for WebSocket connections",
81 )),
82 WriteHandle::Ws(tx_con) => {
83 let message = msg.into_ws().ok_or(ConnectionError::WrongProtocol)?.message;
84 tx_con.send(message).await.map_err(|_| {
85 ConnectionError::Channel(
86 "error while sending messages to the ConnectionsManager",
87 )
88 })
89 }
90 }
91 }
92}
93
/// An established connection that can be polled for messages.
///
/// [`Connection::task`] calls [`Transport::next`] in a loop and forwards every returned
/// message over its channel.
#[async_trait]
pub(crate) trait Transport {
    /// Retrieve the next message produced by the connection identified by `id`.
    ///
    /// `Ok(None)` ends the polling loop in [`Connection::task`]; an error aborts it.
    async fn next(&mut self, id: &Id) -> Result<Option<ProtoMessage>, ConnectionError>;
}
101
/// Builder turning itself into a [`Transport`].
///
/// It is the state held by a [`Connection`] until [`Connection::task`] builds the transport.
#[async_trait]
pub(crate) trait TransportBuilder {
    /// Transport produced by [`TransportBuilder::build`].
    type Connection: Transport;

    /// Consume the builder and create the transport for the connection identified by `id`.
    ///
    /// `tx_ws` is a clone of the sender that [`Connection::task`] uses to forward messages.
    async fn build(
        self,
        id: &Id,
        tx_ws: Sender<ProtoMessage>,
    ) -> Result<Self::Connection, ConnectionError>;
}
114
/// A connection that has not been started yet, generic over the state used to build it.
///
/// The methods that run the connection require `T` to implement [`TransportBuilder`].
#[derive(Debug)]
pub(crate) struct Connection<T> {
    /// Identifier of the connection, passed to the transport on every call.
    id: Id,
    /// Sender over which the messages returned by the transport are forwarded.
    tx_ws: Sender<ProtoMessage>,
    /// State consumed by [`Connection::task`] to build the transport.
    state: T,
}
123
124impl<T> Connection<T> {
125 pub(crate) fn new(id: Id, tx_ws: Sender<ProtoMessage>, state: T) -> Self {
127 Self { id, tx_ws, state }
128 }
129
130 #[instrument(skip_all)]
132 pub(crate) fn spawn(self, write_handle: WriteHandle) -> ConnectionHandle
133 where
134 T: TransportBuilder + Send + 'static,
135 <T as TransportBuilder>::Connection: Send,
136 {
137 let handle = tokio::spawn(async move { self.spawn_inner().await });
139
140 ConnectionHandle {
141 handle,
142 connection: write_handle,
143 }
144 }
145
146 #[instrument(skip_all, fields(id = %self.id))]
147 async fn spawn_inner(self)
148 where
149 T: TransportBuilder + Send + 'static,
150 <T as TransportBuilder>::Connection: Send,
151 {
152 if let Err(err) = self.task().await {
153 error!("connection task failed with error {err:?}");
154 }
155 }
156
157 #[instrument(skip_all)]
160 pub(crate) async fn task(self) -> Result<(), ConnectionError>
161 where
162 T: TransportBuilder,
163 {
164 let mut connection = self.state.build(&self.id, self.tx_ws.clone()).await?;
166 trace!("connection {} created", self.id);
167
168 while let Some(proto_msg) = connection.next(&self.id).await? {
169 self.tx_ws.send(proto_msg).await.map_err(|_| {
170 ConnectionError::Channel(
171 "error while sending generic message to the ConnectionsManager",
172 )
173 })?;
174 }
175
176 Ok(())
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::{
        http::Http, ConnectionError, ConnectionHandle, Id, ProtoMessage, ProtoWebSocketMessage,
        Transport, WriteHandle, WS_CHANNEL_SIZE,
    };

    use crate::messages::{
        Http as ProtoHttp, HttpMessage as ProtoHttpMessage, HttpRequest as ProtoHttpRequest,
        WebSocket as ProtoWebSocket,
    };

    use http::header::CONTENT_TYPE;
    use http::HeaderValue;
    use httpmock::MockServer;
    use tokio::sync::mpsc::channel;
    use tokio_tungstenite::tungstenite::Bytes;
    use url::Url;

    /// Task doing nothing, only needed to obtain a `JoinHandle` for a [`ConnectionHandle`].
    async fn empty_task() {}

    /// Identifier shared by every message and connection created in these tests.
    fn test_id() -> Id {
        Id::try_from(b"1234".to_vec()).unwrap()
    }

    /// Binary WebSocket message carrying the payload `b"message"`.
    fn binary_ws_msg() -> ProtoMessage {
        ProtoMessage::WebSocket(ProtoWebSocket {
            socket_id: test_id(),
            message: ProtoWebSocketMessage::Binary(Bytes::from_static(b"message")),
        })
    }

    /// Build a GET request without headers and body, targeting the path, query and port of
    /// `url`.
    fn create_http_req_proto(url: Url) -> ProtoHttpRequest {
        ProtoHttpRequest {
            method: http::Method::GET,
            path: url.path().trim_start_matches('/').to_string(),
            query_string: url.query().unwrap_or_default().to_string(),
            headers: http::HeaderMap::new(),
            body: Vec::new(),
            port: url.port().expect("nonexistent port"),
        }
    }

    /// Wrap the request built by [`create_http_req_proto`] into an HTTP [`ProtoMessage`].
    fn create_http_req_msg_proto(url: &str) -> ProtoMessage {
        let url = Url::parse(url).expect("failed to parse Url");

        ProtoMessage::Http(ProtoHttp::new(
            test_id(),
            ProtoHttpMessage::Request(create_http_req_proto(url)),
        ))
    }

    #[tokio::test]
    async fn test_con_handle_send() {
        let (tx, mut rx) = channel::<ProtoWebSocketMessage>(WS_CHANNEL_SIZE);

        let con_handle = ConnectionHandle {
            handle: tokio::spawn(empty_task()),
            connection: WriteHandle::Ws(tx),
        };

        // `expect` reports the actual error on failure, unlike `assert!(res.is_ok())`.
        con_handle
            .send(binary_ws_msg())
            .await
            .expect("failed to send the WebSocket message");

        let res = rx.recv().await.expect("channel error");
        let expected_res = ProtoWebSocketMessage::Binary(Bytes::from_static(b"message"));

        assert_eq!(res, expected_res);
    }

    #[tokio::test]
    async fn test_con_handle_send_error() {
        // An HTTP write handle cannot be used to send messages.
        let con_handle = ConnectionHandle {
            handle: tokio::spawn(empty_task()),
            connection: WriteHandle::Http,
        };

        let res = con_handle.send(binary_ws_msg()).await;

        assert!(matches!(res, Err(ConnectionError::Channel(_))));

        // A WebSocket write handle only accepts WebSocket messages.
        let (tx, _rx) = channel::<ProtoWebSocketMessage>(WS_CHANNEL_SIZE);
        let con_handle = ConnectionHandle {
            handle: tokio::spawn(empty_task()),
            connection: WriteHandle::Ws(tx),
        };

        let proto_msg = create_http_req_msg_proto("https://host:8080/path?session=abcd");
        let res = con_handle.send(proto_msg).await;

        assert!(matches!(res, Err(ConnectionError::WrongProtocol)));
    }

    #[tokio::test]
    async fn next_http() {
        let mock_server = MockServer::start();

        let mock_http_req = mock_server.mock(|when, then| {
            when.method(httpmock::Method::GET)
                .path("/path")
                .query_param("session", "abcd");
            then.status(200)
                .header("content-type", "text/html")
                .body("body");
        });

        let url = mock_server.url("/path?session=abcd");

        let url = Url::parse(&url).expect("failed to parse Url");
        let http_req = create_http_req_proto(url);

        let mut http = Http::new(
            http_req
                .request_builder()
                .expect("failed to retrieve request builder"),
        );

        let id = test_id();

        // The first call performs the request and returns the response.
        let res = http.next(&id).await.unwrap().unwrap();

        mock_http_req.assert();

        let proto_msg = res.into_http().unwrap();
        assert_eq!(proto_msg.request_id, id);

        let res = proto_msg.http_msg.into_res().unwrap();
        assert_eq!(res.status_code, 200);
        assert_eq!(res.body, b"body");
        assert_eq!(
            res.headers.get(CONTENT_TYPE).unwrap(),
            HeaderValue::from_static("text/html")
        );

        // Once the response has been returned there is nothing left to read.
        let res = http.next(&id).await;
        assert!(res.unwrap().is_none());
    }
}