// gcdevproxy/device.rs

use std::{
    collections::HashMap,
    future::Future,
    io,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
    time::{Duration, Instant},
};

use bytes::Bytes;
use futures::{
    channel::{mpsc, oneshot},
    future::{AbortHandle, Abortable, Either},
    ready, FutureExt, SinkExt, Stream, StreamExt,
};
use h2::{client::SendRequest, ext::Protocol, Ping, PingPong, RecvStream, SendStream};
use http::{
    header::{CONNECTION, UPGRADE},
    HeaderValue, Method, StatusCode, Version,
};
use hyper::{
    upgrade::{OnUpgrade, Upgraded},
    Request, Response,
};
use hyper_util::rt::TokioIo;
use tokio::io::AsyncWriteExt;
use tokio_util::io::ReaderStream;
use uuid::Uuid;

use crate::{
    utils::{HeaderMapExt, RequestExt},
    Body, Error,
};

33/// Device manager.
34#[derive(Clone)]
35pub struct DeviceManager {
36    devices: Arc<Mutex<HashMap<String, DeviceEntry>>>,
37}
38
39impl DeviceManager {
40    /// Create a new device manager.
41    pub fn new() -> Self {
42        Self {
43            devices: Arc::new(Mutex::new(HashMap::new())),
44        }
45    }
46
47    /// Add a given device.
48    pub fn add(
49        &self,
50        device_id: &str,
51        session_id: Uuid,
52        handle: DeviceHandle,
53    ) -> Option<DeviceHandle> {
54        self.devices
55            .lock()
56            .unwrap()
57            .insert(device_id.to_string(), DeviceEntry::new(session_id, handle))
58            .map(|old| old.into_handle())
59    }
60
61    /// Remove device with a given ID.
62    pub fn remove(&self, device_id: &str, session_id: Option<Uuid>) -> Option<DeviceHandle> {
63        let mut devices = self.devices.lock().unwrap();
64
65        let entry = devices.get(device_id)?;
66
67        if let Some(session_id) = session_id {
68            if session_id != entry.session_id {
69                return None;
70            }
71        }
72
73        devices.remove(device_id).map(|entry| entry.into_handle())
74    }
75
76    /// Get device with a given ID.
77    pub fn get(&self, device_id: &str) -> Option<DeviceHandle> {
78        self.devices
79            .lock()
80            .unwrap()
81            .get(device_id)
82            .map(|entry| entry.handle())
83            .cloned()
84    }
85}
86
87/// Device manager entry.
88struct DeviceEntry {
89    session_id: Uuid,
90    handle: DeviceHandle,
91}
92
93impl DeviceEntry {
94    /// Create a new device manager entry.
95    fn new(session_id: Uuid, handle: DeviceHandle) -> Self {
96        Self { session_id, handle }
97    }
98
99    /// Get the device handle.
100    fn handle(&self) -> &DeviceHandle {
101        &self.handle
102    }
103
104    /// Get the device handle.
105    fn into_handle(self) -> DeviceHandle {
106        self.handle
107    }
108}
109
110const PING_INTERVAL: Duration = Duration::from_secs(10);
111const PONG_TIMEOUT: Duration = Duration::from_secs(20);
112
113/// Typle alias.
114type Connection = Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>;
115
116/// Future representing a device connection.
117///
118/// The future will be resolved when the corresponding connection gets closed.
119pub struct DeviceConnection {
120    connection: Abortable<Connection>,
121}
122
123impl DeviceConnection {
124    /// Create a new device connection.
125    pub async fn new<F, E>(connection: F) -> Result<(Self, DeviceHandle), Error>
126    where
127        F: Future<Output = Result<Upgraded, E>>,
128        E: Into<Error>,
129    {
130        let connection = connection
131            .await
132            .map(TokioIo::new)
133            .map_err(|err| err.into())?;
134
135        let (h2, mut connection) = h2::client::handshake(connection).await?;
136
137        let ping_pong = connection
138            .ping_pong()
139            .expect("unable to get connection ping-pong");
140
141        let keep_alive = KeepAlive::new(ping_pong, PING_INTERVAL, PONG_TIMEOUT);
142
143        let connection: Connection = Box::pin(async move {
144            let keep_alive = keep_alive.run();
145
146            futures::pin_mut!(connection);
147            futures::pin_mut!(keep_alive);
148
149            let select = futures::future::select(connection, keep_alive);
150
151            match select.await {
152                Either::Left((res, _)) => res.map_err(Error::from),
153                Either::Right((res, connection)) => {
154                    if res.is_err() {
155                        res
156                    } else {
157                        connection.await.map_err(Error::from)
158                    }
159                }
160            }
161        });
162
163        let (connection, abort) = futures::future::abortable(connection);
164
165        let (request_tx, mut request_rx) = mpsc::channel::<DeviceRequest>(4);
166
167        tokio::spawn(async move {
168            while let Some(request) = request_rx.next().await {
169                request.spawn_send(h2.clone());
170            }
171        });
172
173        let connection = Self { connection };
174
175        let handle = DeviceHandle { request_tx, abort };
176
177        Ok((connection, handle))
178    }
179}
180
181impl Future for DeviceConnection {
182    type Output = Result<(), Error>;
183
184    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
185        let res = match ready!(self.connection.poll_unpin(cx)) {
186            Ok(Ok(())) => Ok(()),
187            Ok(Err(err)) => Err(err),
188            Err(_) => Ok(()),
189        };
190
191        Poll::Ready(res)
192    }
193}
194
195/// Device handle.
196#[derive(Clone)]
197pub struct DeviceHandle {
198    request_tx: mpsc::Sender<DeviceRequest>,
199    abort: AbortHandle,
200}
201
202impl DeviceHandle {
203    /// Send a given request to the connected device and return a device
204    /// response.
205    pub async fn send_request(&mut self, request: Request<Body>) -> Result<Response<Body>, Error> {
206        let (request, response_rx) = DeviceRequest::new(request);
207
208        self.request_tx.send(request).await.unwrap_or_default();
209
210        response_rx.await
211    }
212
213    /// Close the connection.
214    pub fn close(&self) {
215        self.abort.abort();
216    }
217}
218
219/// Keep-alive handler.
220struct KeepAlive {
221    inner: PingPong,
222    interval: Duration,
223    timeout: Duration,
224}
225
226impl KeepAlive {
227    /// Create a new keep-alive handler.
228    fn new(ping_pong: PingPong, interval: Duration, timeout: Duration) -> Self {
229        Self {
230            inner: ping_pong,
231            interval,
232            timeout,
233        }
234    }
235
236    /// Run the handler.
237    async fn run(mut self) -> Result<(), Error> {
238        let mut next_ping = Instant::now() + self.interval;
239
240        loop {
241            tokio::time::sleep_until(next_ping.into()).await;
242
243            next_ping = Instant::now() + self.interval;
244
245            let pong = tokio::time::timeout(self.timeout, self.inner.ping(Ping::opaque()));
246
247            match pong.await {
248                Ok(Ok(_)) => (),
249                Ok(Err(err)) => {
250                    // do not return any error if the connection was normally
251                    // closed by the remote peer
252                    if let Some(err) = err.get_io() {
253                        if err.kind() == io::ErrorKind::BrokenPipe {
254                            return Ok(());
255                        }
256                    }
257
258                    return Err(err.into());
259                }
260                Err(_) => return Err(Error::from_static_msg("connection timeout")),
261            }
262        }
263    }
264}
265
266/// Request wrapper for requests that shall be sent to a device.
267struct DeviceRequest {
268    request: Request<Body>,
269    response_tx: DeviceResponseTx,
270}
271
272impl DeviceRequest {
273    /// Create a new device request and an associated response future.
274    fn new(request: Request<Body>) -> (Self, DeviceResponseRx) {
275        let (response_tx, response_rx) = oneshot::channel();
276
277        let response_tx = DeviceResponseTx { inner: response_tx };
278        let response_rx = DeviceResponseRx { inner: response_rx };
279
280        let request = Self {
281            request,
282            response_tx,
283        };
284
285        (request, response_rx)
286    }
287
288    /// Send the request into a given device channel in a background task.
289    fn spawn_send(self, channel: SendRequest<Bytes>) {
290        tokio::spawn(self.send(channel));
291    }
292
293    /// Send the request into a given device channel.
294    async fn send(self, channel: SendRequest<Bytes>) {
295        let response = Self::send_internal(self.request, channel).await;
296
297        self.response_tx.send(response);
298    }
299
300    /// Helper function for sending a given HTTP request into a given device
301    /// channel.
302    async fn send_internal(
303        request: Request<Body>,
304        channel: SendRequest<Bytes>,
305    ) -> Result<Response<Body>, Error> {
306        let method = request.method();
307        let headers = request.headers();
308
309        if method == Method::CONNECT || headers.is_connection_upgrade() {
310            Self::send_connect_request(request, channel).await
311        } else {
312            Self::send_standard_request(request, channel).await
313        }
314    }
315
316    /// Helper function for sending a given HTTP request into a given device
317    /// channel.
318    async fn send_connect_request(
319        request: Request<Body>,
320        channel: SendRequest<Bytes>,
321    ) -> Result<Response<Body>, Error> {
322        let version = request.version();
323        let h2_request = request.to_h2_request();
324        let connect_protocol = h2_request.extensions().get::<Protocol>().cloned();
325
326        debug_assert!(h2_request.method() == Method::CONNECT);
327
328        if connect_protocol.is_some() && !channel.is_extended_connect_protocol_enabled() {
329            return Err(Error::from_static_msg(
330                "device does not support connection upgrades",
331            ));
332        }
333
334        let (response, request_body_tx) = channel.ready().await?.send_request(h2_request, false)?;
335
336        let (mut parts, response_body) = response.await?.into_parts();
337
338        parts.version = version;
339
340        parts.headers.remove_hop_by_hop_headers();
341        parts.extensions.clear();
342
343        let response_body_rx = Body::from_stream(ReceiveBody::new(response_body));
344
345        if parts.status.is_success() {
346            let upgrade = hyper::upgrade::on(request);
347
348            tokio::spawn(async move {
349                if let Err(err) =
350                    Self::handle_upgrade(upgrade, request_body_tx, response_body_rx).await
351                {
352                    warn!("connection upgrade failed: {err}");
353                }
354            });
355
356            // NOTE: This may lead to unintentional protocol switch in case
357            // when the device treats it as a standard GET request and
358            // responds with HTTP 2xx.
359            if version < Version::HTTP_2 {
360                if let Some(protocol) = connect_protocol {
361                    parts.status = StatusCode::SWITCHING_PROTOCOLS;
362
363                    let connection = HeaderValue::from_static("upgrade");
364                    let protocol = HeaderValue::from_str(protocol.as_str());
365
366                    parts.headers.insert("connection", connection);
367                    parts.headers.insert("upgrade", protocol.unwrap());
368                }
369            }
370
371            Ok(Response::from_parts(parts, Body::empty()))
372        } else {
373            Ok(Response::from_parts(parts, response_body_rx))
374        }
375    }
376
377    /// Helper function for sending a given HTTP request into a given device
378    /// channel.
379    async fn send_standard_request(
380        request: Request<Body>,
381        channel: SendRequest<Bytes>,
382    ) -> Result<Response<Body>, Error> {
383        let version = request.version();
384        let h2_request = request.to_h2_request();
385
386        debug_assert!(h2_request.method() != Method::CONNECT);
387
388        let (response, request_body_tx) = channel.ready().await?.send_request(h2_request, false)?;
389
390        let body = request.into_body();
391
392        tokio::spawn(async move {
393            if let Err(err) = SendBody::new(body, request_body_tx).await {
394                warn!("unable to send request body: {err}");
395            }
396        });
397
398        let (mut parts, response_body) = response.await?.into_parts();
399
400        parts.version = version;
401
402        parts.headers.remove_hop_by_hop_headers();
403        parts.extensions.clear();
404
405        let body = Body::from_stream(ReceiveBody::new(response_body));
406
407        Ok(Response::from_parts(parts, body))
408    }
409
410    /// Handle communication after connection upgrade.
411    async fn handle_upgrade(
412        upgrade: OnUpgrade,
413        request_body_tx: SendStream<Bytes>,
414        mut response_body_rx: Body,
415    ) -> Result<(), Error> {
416        let upgraded = upgrade.await.map(TokioIo::new).map_err(Error::from_cause)?;
417
418        let (reader, mut writer) = tokio::io::split(upgraded);
419
420        let request_body_rx = ReaderStream::new(reader);
421
422        let device_to_client = async move {
423            while let Some(chunk) = response_body_rx.next().await.transpose()? {
424                writer.write_all(&chunk).await?;
425            }
426
427            writer.shutdown().await.map_err(Error::from)
428        };
429
430        let client_to_device = SendBody::new(request_body_rx, request_body_tx);
431
432        let (d2c_res, c2d_res) = futures::future::join(device_to_client, client_to_device).await;
433
434        d2c_res?;
435        c2d_res?;
436
437        Ok(())
438    }
439}
440
441/// Future that will be resolved into a device response.
442struct DeviceResponseRx {
443    inner: oneshot::Receiver<Result<Response<Body>, Error>>,
444}
445
446impl Future for DeviceResponseRx {
447    type Output = Result<Response<Body>, Error>;
448
449    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
450        match ready!(self.inner.poll_unpin(cx)) {
451            Ok(res) => Poll::Ready(res),
452            Err(_) => Poll::Ready(Err(Error::from_static_msg("device disconnected"))),
453        }
454    }
455}
456
457/// Resolver for the device response future.
458struct DeviceResponseTx {
459    inner: oneshot::Sender<Result<Response<Body>, Error>>,
460}
461
462impl DeviceResponseTx {
463    /// Resolve the device response future.
464    fn send(self, response: Result<Response<Body>, Error>) {
465        self.inner.send(response).unwrap_or_default();
466    }
467}
468
469/// Stream that will handle receiving of an HTTP2 body.
470struct ReceiveBody {
471    inner: RecvStream,
472}
473
474impl ReceiveBody {
475    /// Create a new body stream.
476    fn new(h2: RecvStream) -> Self {
477        Self { inner: h2 }
478    }
479}
480
481impl Stream for ReceiveBody {
482    type Item = io::Result<Bytes>;
483
484    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
485        if let Some(item) = ready!(self.inner.poll_data(cx)) {
486            if let Err(err) = &item {
487                if err.is_reset() || err.is_go_away() {
488                    return Poll::Ready(None);
489                }
490            }
491
492            let data = item.map_err(|err| io::Error::new(io::ErrorKind::Other, err.to_string()))?;
493
494            self.inner
495                .flow_control()
496                .release_capacity(data.len())
497                .unwrap();
498
499            Poll::Ready(Some(Ok(data)))
500        } else {
501            Poll::Ready(None)
502        }
503    }
504}
505
506/// Future that will drive sending of a request/response body into an HTTP2
507/// channel.
508struct SendBody<B> {
509    channel: SendStream<Bytes>,
510    body: B,
511    chunk: Option<Bytes>,
512}
513
514impl<B> SendBody<B> {
515    /// Create a new body sender.
516    fn new(body: B, channel: SendStream<Bytes>) -> Self {
517        Self {
518            channel,
519            body,
520            chunk: None,
521        }
522    }
523
524    /// Poll channel send capacity.
525    fn poll_capacity(
526        &mut self,
527        cx: &mut Context<'_>,
528        required: usize,
529    ) -> Poll<Result<usize, Error>> {
530        let mut capacity = self.channel.capacity();
531
532        while capacity == 0 {
533            // ask the channel for additional send capacity
534            self.channel.reserve_capacity(required);
535
536            capacity = ready!(self.channel.poll_capacity(cx)).ok_or_else(|| {
537                Error::from_static_msg("unable to allocate HTTP2 channel capacity")
538            })??;
539        }
540
541        Poll::Ready(Ok(capacity))
542    }
543}
544
545impl<B, E> SendBody<B>
546where
547    B: Stream<Item = Result<Bytes, E>> + Unpin,
548    E: Into<Error>,
549{
550    /// Poll the next chunk to be sent.
551    fn poll_next_chunk(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<Bytes>, Error>> {
552        if let Some(chunk) = self.chunk.take() {
553            return Poll::Ready(Ok(Some(chunk)));
554        }
555
556        match ready!(self.body.poll_next_unpin(cx)) {
557            Some(Ok(chunk)) => Poll::Ready(Ok(Some(chunk))),
558            Some(Err(err)) => Poll::Ready(Err(err.into())),
559            None => Poll::Ready(Ok(None)),
560        }
561    }
562}
563
564impl<B, E> Future for SendBody<B>
565where
566    B: Stream<Item = Result<Bytes, E>> + Unpin,
567    E: Into<Error>,
568{
569    type Output = Result<(), Error>;
570
571    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
572        while let Some(mut chunk) = ready!(self.poll_next_chunk(cx))? {
573            if chunk.is_empty() {
574                continue;
575            }
576
577            if let Poll::Ready(capacity) = self.poll_capacity(cx, chunk.len()) {
578                let take = capacity?.min(chunk.len());
579
580                self.channel.send_data(chunk.split_to(take), false)?;
581
582                if !chunk.is_empty() {
583                    self.chunk = Some(chunk);
584                }
585            } else {
586                // we'll use the chunk next time
587                self.chunk = Some(chunk);
588
589                return Poll::Pending;
590            }
591        }
592
593        if let Err(err) = self.channel.send_data(Bytes::new(), true) {
594            // return Ok if we aren't able to send EOF into a closed stream
595            if err.is_reset() || err.is_go_away() {
596                Poll::Ready(Ok(()))
597            } else if err.reason().is_some() || err.is_io() || err.is_remote() {
598                Poll::Ready(Err(err.into()))
599            } else {
600                Poll::Ready(Ok(()))
601            }
602        } else {
603            Poll::Ready(Ok(()))
604        }
605    }
606}