1use std::{
2 collections::HashMap,
3 future::Future,
4 io,
5 pin::Pin,
6 sync::{Arc, Mutex},
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use bytes::Bytes;
12use futures::{
13 channel::{mpsc, oneshot},
14 future::{AbortHandle, Abortable, Either},
15 ready, FutureExt, SinkExt, Stream, StreamExt,
16};
17use h2::{client::SendRequest, ext::Protocol, Ping, PingPong, RecvStream, SendStream};
18use http::{HeaderValue, Method, StatusCode, Version};
19use hyper::{
20 upgrade::{OnUpgrade, Upgraded},
21 Request, Response,
22};
23use hyper_util::rt::TokioIo;
24use tokio::io::AsyncWriteExt;
25use tokio_util::io::ReaderStream;
26use uuid::Uuid;
27
28use crate::{
29 utils::{HeaderMapExt, RequestExt},
30 Body, Error,
31};
32
33#[derive(Clone)]
35pub struct DeviceManager {
36 devices: Arc<Mutex<HashMap<String, DeviceEntry>>>,
37}
38
39impl DeviceManager {
40 pub fn new() -> Self {
42 Self {
43 devices: Arc::new(Mutex::new(HashMap::new())),
44 }
45 }
46
47 pub fn add(
49 &self,
50 device_id: &str,
51 session_id: Uuid,
52 handle: DeviceHandle,
53 ) -> Option<DeviceHandle> {
54 self.devices
55 .lock()
56 .unwrap()
57 .insert(device_id.to_string(), DeviceEntry::new(session_id, handle))
58 .map(|old| old.into_handle())
59 }
60
61 pub fn remove(&self, device_id: &str, session_id: Option<Uuid>) -> Option<DeviceHandle> {
63 let mut devices = self.devices.lock().unwrap();
64
65 let entry = devices.get(device_id)?;
66
67 if let Some(session_id) = session_id {
68 if session_id != entry.session_id {
69 return None;
70 }
71 }
72
73 devices.remove(device_id).map(|entry| entry.into_handle())
74 }
75
76 pub fn get(&self, device_id: &str) -> Option<DeviceHandle> {
78 self.devices
79 .lock()
80 .unwrap()
81 .get(device_id)
82 .map(|entry| entry.handle())
83 .cloned()
84 }
85}
86
87struct DeviceEntry {
89 session_id: Uuid,
90 handle: DeviceHandle,
91}
92
93impl DeviceEntry {
94 fn new(session_id: Uuid, handle: DeviceHandle) -> Self {
96 Self { session_id, handle }
97 }
98
99 fn handle(&self) -> &DeviceHandle {
101 &self.handle
102 }
103
104 fn into_handle(self) -> DeviceHandle {
106 self.handle
107 }
108}
109
/// Delay between keep-alive PINGs sent to a device.
const PING_INTERVAL: Duration = Duration::from_millis(10_000);
/// Maximum time to wait for a PONG before the keep-alive reports a connection timeout.
const PONG_TIMEOUT: Duration = Duration::from_millis(20_000);
112
113type Connection = Pin<Box<dyn Future<Output = Result<(), Error>> + Send>>;
115
116pub struct DeviceConnection {
120 connection: Abortable<Connection>,
121}
122
123impl DeviceConnection {
124 pub async fn new<F, E>(connection: F) -> Result<(Self, DeviceHandle), Error>
126 where
127 F: Future<Output = Result<Upgraded, E>>,
128 E: Into<Error>,
129 {
130 let connection = connection
131 .await
132 .map(TokioIo::new)
133 .map_err(|err| err.into())?;
134
135 let (h2, mut connection) = h2::client::handshake(connection).await?;
136
137 let ping_pong = connection
138 .ping_pong()
139 .expect("unable to get connection ping-pong");
140
141 let keep_alive = KeepAlive::new(ping_pong, PING_INTERVAL, PONG_TIMEOUT);
142
143 let connection: Connection = Box::pin(async move {
144 let keep_alive = keep_alive.run();
145
146 futures::pin_mut!(connection);
147 futures::pin_mut!(keep_alive);
148
149 let select = futures::future::select(connection, keep_alive);
150
151 match select.await {
152 Either::Left((res, _)) => res.map_err(Error::from),
153 Either::Right((res, connection)) => {
154 if res.is_err() {
155 res
156 } else {
157 connection.await.map_err(Error::from)
158 }
159 }
160 }
161 });
162
163 let (connection, abort) = futures::future::abortable(connection);
164
165 let (request_tx, mut request_rx) = mpsc::channel::<DeviceRequest>(4);
166
167 tokio::spawn(async move {
168 while let Some(request) = request_rx.next().await {
169 request.spawn_send(h2.clone());
170 }
171 });
172
173 let connection = Self { connection };
174
175 let handle = DeviceHandle { request_tx, abort };
176
177 Ok((connection, handle))
178 }
179}
180
181impl Future for DeviceConnection {
182 type Output = Result<(), Error>;
183
184 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
185 let res = match ready!(self.connection.poll_unpin(cx)) {
186 Ok(Ok(())) => Ok(()),
187 Ok(Err(err)) => Err(err),
188 Err(_) => Ok(()),
189 };
190
191 Poll::Ready(res)
192 }
193}
194
195#[derive(Clone)]
197pub struct DeviceHandle {
198 request_tx: mpsc::Sender<DeviceRequest>,
199 abort: AbortHandle,
200}
201
202impl DeviceHandle {
203 pub async fn send_request(&mut self, request: Request<Body>) -> Result<Response<Body>, Error> {
206 let (request, response_rx) = DeviceRequest::new(request);
207
208 self.request_tx.send(request).await.unwrap_or_default();
209
210 response_rx.await
211 }
212
213 pub fn close(&self) {
215 self.abort.abort();
216 }
217}
218
219struct KeepAlive {
221 inner: PingPong,
222 interval: Duration,
223 timeout: Duration,
224}
225
226impl KeepAlive {
227 fn new(ping_pong: PingPong, interval: Duration, timeout: Duration) -> Self {
229 Self {
230 inner: ping_pong,
231 interval,
232 timeout,
233 }
234 }
235
236 async fn run(mut self) -> Result<(), Error> {
238 let mut next_ping = Instant::now() + self.interval;
239
240 loop {
241 tokio::time::sleep_until(next_ping.into()).await;
242
243 next_ping = Instant::now() + self.interval;
244
245 let pong = tokio::time::timeout(self.timeout, self.inner.ping(Ping::opaque()));
246
247 match pong.await {
248 Ok(Ok(_)) => (),
249 Ok(Err(err)) => {
250 if let Some(err) = err.get_io() {
253 if err.kind() == io::ErrorKind::BrokenPipe {
254 return Ok(());
255 }
256 }
257
258 return Err(err.into());
259 }
260 Err(_) => return Err(Error::from_static_msg("connection timeout")),
261 }
262 }
263 }
264}
265
266struct DeviceRequest {
268 request: Request<Body>,
269 response_tx: DeviceResponseTx,
270}
271
272impl DeviceRequest {
273 fn new(request: Request<Body>) -> (Self, DeviceResponseRx) {
275 let (response_tx, response_rx) = oneshot::channel();
276
277 let response_tx = DeviceResponseTx { inner: response_tx };
278 let response_rx = DeviceResponseRx { inner: response_rx };
279
280 let request = Self {
281 request,
282 response_tx,
283 };
284
285 (request, response_rx)
286 }
287
288 fn spawn_send(self, channel: SendRequest<Bytes>) {
290 tokio::spawn(self.send(channel));
291 }
292
293 async fn send(self, channel: SendRequest<Bytes>) {
295 let response = Self::send_internal(self.request, channel).await;
296
297 self.response_tx.send(response);
298 }
299
300 async fn send_internal(
303 request: Request<Body>,
304 channel: SendRequest<Bytes>,
305 ) -> Result<Response<Body>, Error> {
306 let method = request.method();
307 let headers = request.headers();
308
309 if method == Method::CONNECT || headers.is_connection_upgrade() {
310 Self::send_connect_request(request, channel).await
311 } else {
312 Self::send_standard_request(request, channel).await
313 }
314 }
315
316 async fn send_connect_request(
319 request: Request<Body>,
320 channel: SendRequest<Bytes>,
321 ) -> Result<Response<Body>, Error> {
322 let version = request.version();
323 let h2_request = request.to_h2_request();
324 let connect_protocol = h2_request.extensions().get::<Protocol>().cloned();
325
326 debug_assert!(h2_request.method() == Method::CONNECT);
327
328 if connect_protocol.is_some() && !channel.is_extended_connect_protocol_enabled() {
329 return Err(Error::from_static_msg(
330 "device does not support connection upgrades",
331 ));
332 }
333
334 let (response, request_body_tx) = channel.ready().await?.send_request(h2_request, false)?;
335
336 let (mut parts, response_body) = response.await?.into_parts();
337
338 parts.version = version;
339
340 parts.headers.remove_hop_by_hop_headers();
341 parts.extensions.clear();
342
343 let response_body_rx = Body::from_stream(ReceiveBody::new(response_body));
344
345 if parts.status.is_success() {
346 let upgrade = hyper::upgrade::on(request);
347
348 tokio::spawn(async move {
349 if let Err(err) =
350 Self::handle_upgrade(upgrade, request_body_tx, response_body_rx).await
351 {
352 warn!("connection upgrade failed: {err}");
353 }
354 });
355
356 if version < Version::HTTP_2 {
360 if let Some(protocol) = connect_protocol {
361 parts.status = StatusCode::SWITCHING_PROTOCOLS;
362
363 let connection = HeaderValue::from_static("upgrade");
364 let protocol = HeaderValue::from_str(protocol.as_str());
365
366 parts.headers.insert("connection", connection);
367 parts.headers.insert("upgrade", protocol.unwrap());
368 }
369 }
370
371 Ok(Response::from_parts(parts, Body::empty()))
372 } else {
373 Ok(Response::from_parts(parts, response_body_rx))
374 }
375 }
376
377 async fn send_standard_request(
380 request: Request<Body>,
381 channel: SendRequest<Bytes>,
382 ) -> Result<Response<Body>, Error> {
383 let version = request.version();
384 let h2_request = request.to_h2_request();
385
386 debug_assert!(h2_request.method() != Method::CONNECT);
387
388 let (response, request_body_tx) = channel.ready().await?.send_request(h2_request, false)?;
389
390 let body = request.into_body();
391
392 tokio::spawn(async move {
393 if let Err(err) = SendBody::new(body, request_body_tx).await {
394 warn!("unable to send request body: {err}");
395 }
396 });
397
398 let (mut parts, response_body) = response.await?.into_parts();
399
400 parts.version = version;
401
402 parts.headers.remove_hop_by_hop_headers();
403 parts.extensions.clear();
404
405 let body = Body::from_stream(ReceiveBody::new(response_body));
406
407 Ok(Response::from_parts(parts, body))
408 }
409
410 async fn handle_upgrade(
412 upgrade: OnUpgrade,
413 request_body_tx: SendStream<Bytes>,
414 mut response_body_rx: Body,
415 ) -> Result<(), Error> {
416 let upgraded = upgrade.await.map(TokioIo::new).map_err(Error::from_cause)?;
417
418 let (reader, mut writer) = tokio::io::split(upgraded);
419
420 let request_body_rx = ReaderStream::new(reader);
421
422 let device_to_client = async move {
423 while let Some(chunk) = response_body_rx.next().await.transpose()? {
424 writer.write_all(&chunk).await?;
425 }
426
427 writer.shutdown().await.map_err(Error::from)
428 };
429
430 let client_to_device = SendBody::new(request_body_rx, request_body_tx);
431
432 let (d2c_res, c2d_res) = futures::future::join(device_to_client, client_to_device).await;
433
434 d2c_res?;
435 c2d_res?;
436
437 Ok(())
438 }
439}
440
441struct DeviceResponseRx {
443 inner: oneshot::Receiver<Result<Response<Body>, Error>>,
444}
445
446impl Future for DeviceResponseRx {
447 type Output = Result<Response<Body>, Error>;
448
449 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
450 match ready!(self.inner.poll_unpin(cx)) {
451 Ok(res) => Poll::Ready(res),
452 Err(_) => Poll::Ready(Err(Error::from_static_msg("device disconnected"))),
453 }
454 }
455}
456
457struct DeviceResponseTx {
459 inner: oneshot::Sender<Result<Response<Body>, Error>>,
460}
461
462impl DeviceResponseTx {
463 fn send(self, response: Result<Response<Body>, Error>) {
465 self.inner.send(response).unwrap_or_default();
466 }
467}
468
469struct ReceiveBody {
471 inner: RecvStream,
472}
473
474impl ReceiveBody {
475 fn new(h2: RecvStream) -> Self {
477 Self { inner: h2 }
478 }
479}
480
481impl Stream for ReceiveBody {
482 type Item = io::Result<Bytes>;
483
484 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
485 if let Some(item) = ready!(self.inner.poll_data(cx)) {
486 if let Err(err) = &item {
487 if err.is_reset() || err.is_go_away() {
488 return Poll::Ready(None);
489 }
490 }
491
492 let data = item.map_err(|err| io::Error::new(io::ErrorKind::Other, err.to_string()))?;
493
494 self.inner
495 .flow_control()
496 .release_capacity(data.len())
497 .unwrap();
498
499 Poll::Ready(Some(Ok(data)))
500 } else {
501 Poll::Ready(None)
502 }
503 }
504}
505
506struct SendBody<B> {
509 channel: SendStream<Bytes>,
510 body: B,
511 chunk: Option<Bytes>,
512}
513
514impl<B> SendBody<B> {
515 fn new(body: B, channel: SendStream<Bytes>) -> Self {
517 Self {
518 channel,
519 body,
520 chunk: None,
521 }
522 }
523
524 fn poll_capacity(
526 &mut self,
527 cx: &mut Context<'_>,
528 required: usize,
529 ) -> Poll<Result<usize, Error>> {
530 let mut capacity = self.channel.capacity();
531
532 while capacity == 0 {
533 self.channel.reserve_capacity(required);
535
536 capacity = ready!(self.channel.poll_capacity(cx)).ok_or_else(|| {
537 Error::from_static_msg("unable to allocate HTTP2 channel capacity")
538 })??;
539 }
540
541 Poll::Ready(Ok(capacity))
542 }
543}
544
545impl<B, E> SendBody<B>
546where
547 B: Stream<Item = Result<Bytes, E>> + Unpin,
548 E: Into<Error>,
549{
550 fn poll_next_chunk(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<Bytes>, Error>> {
552 if let Some(chunk) = self.chunk.take() {
553 return Poll::Ready(Ok(Some(chunk)));
554 }
555
556 match ready!(self.body.poll_next_unpin(cx)) {
557 Some(Ok(chunk)) => Poll::Ready(Ok(Some(chunk))),
558 Some(Err(err)) => Poll::Ready(Err(err.into())),
559 None => Poll::Ready(Ok(None)),
560 }
561 }
562}
563
564impl<B, E> Future for SendBody<B>
565where
566 B: Stream<Item = Result<Bytes, E>> + Unpin,
567 E: Into<Error>,
568{
569 type Output = Result<(), Error>;
570
571 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
572 while let Some(mut chunk) = ready!(self.poll_next_chunk(cx))? {
573 if chunk.is_empty() {
574 continue;
575 }
576
577 if let Poll::Ready(capacity) = self.poll_capacity(cx, chunk.len()) {
578 let take = capacity?.min(chunk.len());
579
580 self.channel.send_data(chunk.split_to(take), false)?;
581
582 if !chunk.is_empty() {
583 self.chunk = Some(chunk);
584 }
585 } else {
586 self.chunk = Some(chunk);
588
589 return Poll::Pending;
590 }
591 }
592
593 if let Err(err) = self.channel.send_data(Bytes::new(), true) {
594 if err.is_reset() || err.is_go_away() {
596 Poll::Ready(Ok(()))
597 } else if err.reason().is_some() || err.is_io() || err.is_remote() {
598 Poll::Ready(Err(err.into()))
599 } else {
600 Poll::Ready(Ok(()))
601 }
602 } else {
603 Poll::Ready(Ok(()))
604 }
605 }
606}