edgehog_device_runtime_forwarder/
connections_manager.rs1use std::ops::ControlFlow;
7use std::sync::Arc;
8
9use backoff::{Error as BackoffError, ExponentialBackoff};
10use futures::{future, SinkExt, StreamExt, TryFutureExt};
11use thiserror::Error as ThisError;
12use tokio::net::TcpStream;
13use tokio::select;
14use tokio::sync::mpsc::{channel, Receiver};
15use tokio_tungstenite::connect_async_tls_with_config;
16use tokio_tungstenite::{
17 tungstenite::Error as TungError, tungstenite::Message as TungMessage, Connector,
18 MaybeTlsStream, WebSocketStream,
19};
20use tracing::{debug, error, info, instrument, trace, warn};
21use url::Url;
22
23use crate::collection::Connections;
24use crate::connection::ConnectionError;
25use crate::messages::{Id, ProtoMessage, ProtocolError};
26
/// Capacity of the bounded channel created in `ConnectionsManager::connect`.
///
/// The sending half is handed to `Connections`, the receiving half is drained by the manager to
/// forward messages over the WebSocket.
pub(crate) const CHANNEL_SIZE: usize = 50;
29
30#[derive(displaydoc::Display, ThisError, Debug)]
32#[non_exhaustive]
33pub enum Error {
34 WebSocket(#[from] Box<TungError>),
36 Protobuf(#[from] ProtocolError),
38 Connection(#[from] ConnectionError),
40 WrongMessage(Id),
42 ConnectionNotFound(Id),
44 IdAlreadyUsed(Id),
46 Unsupported,
48 TokenNotFound,
50 TokenAlreadyUsed(String),
52 BackOff(#[from] BackoffError<Box<Error>>),
54 Tls(#[from] rustls::Error),
56}
57
58#[derive(displaydoc::Display, ThisError, Debug)]
60pub struct Disconnected(#[from] pub Box<TungError>);
61
62pub type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
64
65#[derive(Debug)]
69pub struct ConnectionsManager {
70 pub(crate) connections: Connections,
72 pub(crate) ws_stream: WsStream,
74 pub(crate) rx_ws: Receiver<ProtoMessage>,
76 pub(crate) url: Url,
78 pub(crate) secure: bool,
80}
81
82impl ConnectionsManager {
83 #[instrument]
85 pub async fn connect(url: Url, secure: bool) -> Result<Self, Error> {
86 let connector = if secure {
88 let tls = edgehog_tls::config()?;
89
90 Connector::Rustls(Arc::new(tls))
91 } else {
92 Connector::Plain
93 };
94
95 #[cfg(test)]
98 if rustls::crypto::CryptoProvider::get_default().is_none() {
99 let _ = rustls::crypto::aws_lc_rs::default_provider()
100 .install_default()
101 .inspect_err(|_| tracing::error!("couldn't install default crypto provider"));
102 }
103
104 let ws_stream = Self::ws_connect(&url, connector).await?;
105
106 let (tx_ws, rx_ws) = channel(CHANNEL_SIZE);
111
112 let connections = Connections::new(tx_ws);
113
114 Ok(Self {
115 connections,
116 ws_stream,
117 rx_ws,
118 url,
119 secure,
120 })
121 }
122
123 #[instrument(skip_all)]
125 pub(crate) async fn ws_connect(
126 url: &Url,
127 connector: Connector,
128 ) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Error> {
129 let (ws_stream, http_res) =
131 backoff::future::retry(ExponentialBackoff::default(), || async {
132 debug!("creating WebSocket connection with {}", url);
133
134 let connector_cpy = connector.clone();
135
136 connect_async_tls_with_config(url, None, false, Some(connector_cpy))
138 .await
139 .map_err(|err| match err {
140 TungError::Http(http_res) if http_res.status().is_client_error() => {
141 error!(
142 "received HTTP client error ({}), stopping backoff",
143 http_res.status()
144 );
145
146 match get_token(url) {
147 Ok(token) => {
148 BackoffError::Permanent(Error::TokenAlreadyUsed(token))
149 }
150 Err(err) => BackoffError::Permanent(err),
151 }
152 }
153 err => {
154 debug!("try reconnecting with backoff after tungstenite error: {err}");
155 BackoffError::Transient {
156 err: Error::WebSocket(Box::new(err)),
157 retry_after: None,
158 }
159 }
160 })
161 })
162 .await?;
163
164 trace!("WebSocket response {http_res:?}");
165
166 Ok(ws_stream)
167 }
168
169 #[instrument(skip_all)]
171 pub async fn handle_connections(&mut self) -> Result<(), Disconnected> {
172 loop {
173 match self.event_loop().await {
174 Ok(ControlFlow::Continue(())) => {}
175 Ok(ControlFlow::Break(())) | Err(TungError::ConnectionClosed) => break,
178 Err(TungError::Capacity(err)) => {
180 error!("capacity exceeded: {err}");
181 break;
182 }
183 Err(TungError::AlreadyClosed) => {
184 error!("BUG: trying to read/write on an already closed WebSocket");
185 break;
186 }
187 Err(err) => {
190 return Err(Disconnected(Box::new(err)));
191 }
192 }
193 }
194
195 Ok(())
196 }
197
198 #[instrument(skip_all)]
204 pub(crate) async fn event_loop(&mut self) -> Result<ControlFlow<()>, TungError> {
205 let event = self.select_ws_event().await;
206
207 match event {
208 WebSocketEvents::Receive(msg) => {
210 future::ready(msg)
211 .and_then(|msg| self.handle_tung_msg(msg))
212 .await
213 }
214 WebSocketEvents::Send(tung_msg) => {
216 let msg = match tung_msg.encode() {
217 Ok(msg) => TungMessage::Binary(msg.into()),
218 Err(err) => {
219 error!("discard message due to {err:?}");
220 return Ok(ControlFlow::Continue(()));
221 }
222 };
223
224 self.send_to_ws(msg)
225 .await
226 .map(|_| ControlFlow::Continue(()))
227 }
228 }
229 }
230
231 #[instrument(skip_all)]
233 pub(crate) async fn select_ws_event(&mut self) -> WebSocketEvents {
234 select! {
235 res = self.ws_stream.next() => {
236 match res {
237 Some(msg) => {
238 trace!("received tungstenite message from Edgehog: {msg:?}");
239 WebSocketEvents::Receive(msg)
240 }
241 None => {
242 trace!("ws_stream next() returned None, connection already closed");
243 WebSocketEvents::Receive(Err(TungError::AlreadyClosed))
244 }
245 }
246 }
247 next = self.rx_ws.recv() => match next {
248 Some(msg) => {
249 trace!("proto message received from a device connection: {msg:?}");
250 WebSocketEvents::Send(Box::new(msg))
251 }
252 None => unreachable!("BUG: tx_ws channel should never be closed"),
253 }
254 }
255 }
256
257 #[instrument(skip_all)]
259 pub(crate) async fn send_to_ws(&mut self, tung_msg: TungMessage) -> Result<(), TungError> {
260 self.ws_stream.send(tung_msg).await
261 }
262
263 #[instrument(skip_all)]
265 pub(crate) async fn handle_tung_msg(
266 &mut self,
267 msg: TungMessage,
268 ) -> Result<ControlFlow<()>, TungError> {
269 match msg {
270 TungMessage::Ping(data) => {
271 debug!("received ping, sending pong");
272 let msg = TungMessage::Pong(data);
273 self.send_to_ws(msg).await?;
274 }
275 TungMessage::Pong(_) => debug!("received pong"),
276 TungMessage::Close(close_frame) => {
277 debug!("received close frame {close_frame:?}, closing active connections");
278 self.disconnect();
279 info!("closed every connection");
280 return Ok(ControlFlow::Break(()));
281 }
282 TungMessage::Text(data) => warn!("received Text WebSocket frame, {data}"),
284 TungMessage::Binary(bytes) => {
285 match ProtoMessage::decode(&bytes) {
286 Ok(proto_msg) => {
288 trace!("message received from Edgehog: {proto_msg:?}");
289 if let Err(err) = self.handle_proto_msg(proto_msg).await {
290 error!("failed to handle protobuf message due to {err:?}");
291 }
292 }
293 Err(err) => {
294 error!("failed to decode protobuf message due to {err:?}");
295 }
296 }
297 }
298 TungMessage::Frame(_) => error!("unhandled message type: {msg:?}"),
300 }
301
302 Ok(ControlFlow::Continue(()))
303 }
304
305 pub(crate) async fn handle_proto_msg(&mut self, proto_msg: ProtoMessage) -> Result<(), Error> {
307 self.connections.remove_terminated();
309
310 match proto_msg {
311 ProtoMessage::Http(http) => {
312 trace!("received HTTP message: {http:?}");
313 self.connections.handle_http(http)
314 }
315 ProtoMessage::WebSocket(ws) => {
316 trace!("received WebSocket frame: {ws:?}");
317 self.connections.handle_ws(ws).await
318 }
319 }
320 }
321
322 #[instrument(skip_all)]
324 pub async fn reconnect(&mut self) -> Result<(), Error> {
325 debug!("trying to reconnect");
326
327 let connector = if self.secure {
328 let tls = edgehog_tls::config()?;
329
330 Connector::Rustls(Arc::new(tls))
331 } else {
332 Connector::Plain
333 };
334
335 self.ws_stream = Self::ws_connect(&self.url, connector).await?;
336
337 info!("reconnected");
338 Ok(())
339 }
340
341 #[instrument(skip_all)]
343 pub(crate) fn disconnect(&mut self) {
344 self.connections.disconnect();
345 }
346}
347
348pub(crate) fn get_token(url: &Url) -> Result<String, Error> {
350 url.query()
351 .map(|s| s.trim_start_matches("session=").to_string())
352 .ok_or(Error::TokenNotFound)
353}
354
355pub(crate) enum WebSocketEvents {
357 Receive(Result<TungMessage, TungError>),
358 Send(Box<ProtoMessage>),
359}