Skip to main content

edgehog_device_runtime_forwarder/connection/
mod.rs

1// Copyright 2024 SECO Mind Srl
2// SPDX-License-Identifier: Apache-2.0
3
4//! Manage a single connection.
5//!
6//! A connection is responsible for sending and receiving data through a WebSocket connection from
7//! and to the [`ConnectionsManager`](crate::connections_manager::ConnectionsManager).
8
9pub mod http;
10pub mod websocket;
11
12use std::ops::Deref;
13
14use async_trait::async_trait;
15use thiserror::Error as ThisError;
16use tokio::sync::mpsc::Sender;
17use tokio::task::{JoinError, JoinHandle};
18use tokio_tungstenite::tungstenite::Error as TungError;
19use tracing::{error, instrument, trace};
20
21use crate::messages::{Id, ProtoMessage, ProtocolError, WebSocketMessage as ProtoWebSocketMessage};
22
/// Capacity of the bounded channel through which the
/// [Connections Manager](crate::connections_manager::ConnectionsManager) forwards messages to a
/// device WebSocket connection.
pub(crate) const WS_CHANNEL_SIZE: usize = 50;
26
27/// Connection errors.
28#[non_exhaustive]
29#[derive(displaydoc::Display, ThisError, Debug)]
30pub enum ConnectionError {
31    /// Channel error.
32    Channel(&'static str),
33    /// Reqwest error.
34    Http(#[from] reqwest::Error),
35    /// Protobuf error.
36    Protobuf(#[from] ProtocolError),
37    /// Failed to Join a task handle.
38    JoinError(#[from] JoinError),
39    /// Message sent to the wrong protocol
40    WrongProtocol,
41    /// Error when receiving message on WebSocket connection, `{0}`.
42    WebSocket(#[from] Box<TungError>),
43    /// Trying to poll while still connecting.
44    Connecting,
45}
46
47/// Enum storing the write side of the channel used by the
48/// [Connections Manager](crate::connections_manager::ConnectionsManager) to send WebSocket
49/// messages to the respective connection that will handle it.
50#[derive(Debug)]
51pub(crate) enum WriteHandle {
52    Http,
53    Ws(Sender<ProtoWebSocketMessage>),
54}
55
56/// Handle to the task spawned to handle a [`Connection`].
57#[derive(Debug)]
58pub(crate) struct ConnectionHandle {
59    /// Handle of the task managing the connection.
60    pub(crate) handle: JoinHandle<()>,
61    /// Handle necessary to send messages to the tokio task managing the connection.
62    pub(crate) connection: WriteHandle,
63}
64
65impl Deref for ConnectionHandle {
66    type Target = JoinHandle<()>;
67
68    fn deref(&self) -> &Self::Target {
69        &self.handle
70    }
71}
72
73impl ConnectionHandle {
74    /// Once the connections manager receives a WebSocket message, it sends a message to the
75    /// respective tokio task handling that connection.
76    #[instrument]
77    pub(crate) async fn send(&self, msg: ProtoMessage) -> Result<(), ConnectionError> {
78        match &self.connection {
79            WriteHandle::Http => Err(ConnectionError::Channel(
80                "sending messages over a channel is only allowed for WebSocket connections",
81            )),
82            WriteHandle::Ws(tx_con) => {
83                let message = msg.into_ws().ok_or(ConnectionError::WrongProtocol)?.message;
84                tx_con.send(message).await.map_err(|_| {
85                    ConnectionError::Channel(
86                        "error while sending messages to the ConnectionsManager",
87                    )
88                })
89            }
90        }
91    }
92}
93
94/// For each Connection implementing a given transport protocol (e.g., [`Http`], [`WebSocket`]), it
95/// provides a method returning a [`protocol message`](ProtoMessage) to send to the
96/// [`ConnectionsManager`](crate::collection::ConnectionsManager).
97#[async_trait]
98pub(crate) trait Transport {
99    async fn next(&mut self, id: &Id) -> Result<Option<ProtoMessage>, ConnectionError>;
100}
101
102/// Trait used by each transport builder (e.g., [`HttpBuilder`], [`WebSocketBuilder`]) to build the
103/// respective transport protocol struct.
104#[async_trait]
105pub(crate) trait TransportBuilder {
106    type Connection: Transport;
107
108    async fn build(
109        self,
110        id: &Id,
111        tx_ws: Sender<ProtoMessage>,
112    ) -> Result<Self::Connection, ConnectionError>;
113}
114
115/// Struct containing the connection information necessary to communicate with the
116/// [`ConnectionsManager`](crate::collection::ConnectionsManager).
117#[derive(Debug)]
118pub(crate) struct Connection<T> {
119    id: Id,
120    tx_ws: Sender<ProtoMessage>,
121    state: T,
122}
123
124impl<T> Connection<T> {
125    /// Initialize a new connection.
126    pub(crate) fn new(id: Id, tx_ws: Sender<ProtoMessage>, state: T) -> Self {
127        Self { id, tx_ws, state }
128    }
129
130    /// Spawn the task responsible for handling the connection.
131    #[instrument(skip_all)]
132    pub(crate) fn spawn(self, write_handle: WriteHandle) -> ConnectionHandle
133    where
134        T: TransportBuilder + Send + 'static,
135        <T as TransportBuilder>::Connection: Send,
136    {
137        // spawn a task responsible for notifying when new data is available
138        let handle = tokio::spawn(async move { self.spawn_inner().await });
139
140        ConnectionHandle {
141            handle,
142            connection: write_handle,
143        }
144    }
145
146    #[instrument(skip_all, fields(id = %self.id))]
147    async fn spawn_inner(self)
148    where
149        T: TransportBuilder + Send + 'static,
150        <T as TransportBuilder>::Connection: Send,
151    {
152        if let Err(err) = self.task().await {
153            error!("connection task failed with error {err:?}");
154        }
155    }
156
157    /// Build the [`Transport`] and send protocol messages to the
158    /// [ConnectionsManager](crate::connections_manager::ConnectionsManager).
159    #[instrument(skip_all)]
160    pub(crate) async fn task(self) -> Result<(), ConnectionError>
161    where
162        T: TransportBuilder,
163    {
164        // create a connection (either HTTP or WebSocket) which implements the Transport trait
165        let mut connection = self.state.build(&self.id, self.tx_ws.clone()).await?;
166        trace!("connection {} created", self.id);
167
168        while let Some(proto_msg) = connection.next(&self.id).await? {
169            self.tx_ws.send(proto_msg).await.map_err(|_| {
170                ConnectionError::Channel(
171                    "error while sending generic message to the ConnectionsManager",
172                )
173            })?;
174        }
175
176        Ok(())
177    }
178}
179
180#[cfg(test)]
181mod tests {
182    use super::{
183        http::Http, ConnectionError, ConnectionHandle, Id, ProtoMessage, ProtoWebSocketMessage,
184        Transport, WriteHandle, WS_CHANNEL_SIZE,
185    };
186
187    use crate::messages::{
188        Http as ProtoHttp, HttpMessage as ProtoHttpMessage, HttpRequest as ProtoHttpRequest,
189        WebSocket as ProtoWebSocket,
190    };
191
192    use http::header::CONTENT_TYPE;
193    use http::HeaderValue;
194    use httpmock::MockServer;
195    use tokio::sync::mpsc::channel;
196    use tokio_tungstenite::tungstenite::Bytes;
197    use url::Url;
198
199    async fn empty_task() {}
200
201    fn create_http_req_proto(url: Url) -> ProtoHttpRequest {
202        ProtoHttpRequest {
203            method: http::Method::GET,
204            path: url.path().trim_start_matches('/').to_string(),
205            query_string: url.query().unwrap_or_default().to_string(),
206            headers: http::HeaderMap::new(),
207            body: Vec::new(),
208            port: url.port().expect("nonexistent port"),
209        }
210    }
211
212    fn create_http_req_msg_proto(url: &str) -> ProtoMessage {
213        let url = Url::parse(url).expect("failed to pars Url");
214
215        ProtoMessage::Http(ProtoHttp::new(
216            Id::try_from(b"1234".to_vec()).unwrap(),
217            ProtoHttpMessage::Request(create_http_req_proto(url)),
218        ))
219    }
220
221    #[tokio::test]
222    async fn test_con_handle_send() {
223        let (tx, mut rx) = channel::<ProtoWebSocketMessage>(WS_CHANNEL_SIZE);
224
225        let con_handle = ConnectionHandle {
226            handle: tokio::spawn(empty_task()),
227            connection: WriteHandle::Ws(tx),
228        };
229
230        let proto_msg = ProtoMessage::WebSocket(ProtoWebSocket {
231            socket_id: Id::try_from(b"1234".to_vec()).unwrap(),
232            message: ProtoWebSocketMessage::Binary(Bytes::from_static(b"message")),
233        });
234
235        let res = con_handle.send(proto_msg).await;
236
237        assert!(res.is_ok());
238
239        let res = rx.recv().await.expect("channel error");
240        let expected_res = ProtoWebSocketMessage::Binary(Bytes::from_static(b"message"));
241
242        assert_eq!(res, expected_res);
243    }
244
245    #[tokio::test]
246    async fn test_con_handle_send_error() {
247        // send() cannot be used in case the write handle is Http
248        let con_handle = ConnectionHandle {
249            handle: tokio::spawn(empty_task()),
250            connection: WriteHandle::Http,
251        };
252
253        let proto_msg = ProtoMessage::WebSocket(ProtoWebSocket {
254            socket_id: Id::try_from(b"1234".to_vec()).unwrap(),
255            message: ProtoWebSocketMessage::Binary(Bytes::from_static(b"message")),
256        });
257
258        let res = con_handle.send(proto_msg).await;
259
260        assert!(matches!(res, Err(ConnectionError::Channel(_))));
261
262        // an error is returned in case the proto message is not of WebSocket type
263        let (tx, _rx) = channel::<ProtoWebSocketMessage>(WS_CHANNEL_SIZE);
264        let con_handle = ConnectionHandle {
265            handle: tokio::spawn(empty_task()),
266            connection: WriteHandle::Ws(tx),
267        };
268
269        let proto_msg = create_http_req_msg_proto("https://host:8080/path?session=abcd");
270        let res = con_handle.send(proto_msg).await;
271
272        assert!(matches!(res, Err(ConnectionError::WrongProtocol)));
273    }
274
275    #[tokio::test]
276    async fn next_http() {
277        let mock_server = MockServer::start();
278
279        let mock_http_req = mock_server.mock(|when, then| {
280            when.method(httpmock::Method::GET)
281                .path("/path")
282                .query_param("session", "abcd");
283            then.status(200)
284                .header("content-type", "text/html")
285                .body("body");
286        });
287
288        let url = mock_server.url("/path?session=abcd");
289
290        let url = Url::parse(&url).expect("failed to parse Url");
291        let http_rep = create_http_req_proto(url);
292
293        let mut http = Http::new(
294            http_rep
295                .request_builder()
296                .expect("failed to retrieve request builder"),
297        );
298
299        let id = Id::try_from(b"1234".to_vec()).unwrap();
300
301        let res = http.next(&id).await.unwrap().unwrap();
302
303        // check that there has been an HTTP call with the specified information
304        mock_http_req.assert();
305
306        let proto_msg = res.into_http().unwrap();
307        assert_eq!(proto_msg.request_id, id);
308
309        let res = proto_msg.http_msg.into_res().unwrap();
310        assert_eq!(res.status_code, 200);
311        assert_eq!(res.body, b"body");
312        assert_eq!(
313            res.headers.get(CONTENT_TYPE).unwrap(),
314            HeaderValue::from_static("text/html")
315        );
316
317        // calling a second time next() on an http should return Ok(None)
318        let res = http.next(&id).await;
319        assert!(res.unwrap().is_none());
320    }
321}