async_web_server/
tcp_or_tls.rs

1use crate::{HttpIncoming, TcpStream, TlsStream};
2use futures::prelude::*;
3use futures::stream::{FusedStream, SelectAll};
4use futures::StreamExt;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7
/// Reports whether a connection is carried over TLS.
pub trait IsTls {
    /// Returns `true` if this connection is TLS, `false` otherwise.
    fn is_tls(&self) -> bool;
}
11
12pub enum TcpOrTlsStream {
13    Tcp(TcpStream),
14    Tls(TlsStream),
15}
16
17impl IsTls for TcpOrTlsStream {
18    fn is_tls(&self) -> bool {
19        match self {
20            Self::Tcp(_) => false,
21            Self::Tls(_) => true,
22        }
23    }
24}
25
26pub struct TcpOrTlsIncoming {
27    incomings: SelectAll<Box<dyn Stream<Item = TcpOrTlsStream> + Unpin>>,
28}
29
30impl TcpOrTlsIncoming {
31    pub fn new() -> Self {
32        Self {
33            incomings: SelectAll::new(),
34        }
35    }
36    pub fn push(
37        &mut self,
38        incoming: impl Stream<Item = impl Into<TcpOrTlsStream>> + Unpin + 'static,
39    ) {
40        self.incomings
41            .push(Box::new(incoming.map(|stream| stream.into())))
42    }
43    pub fn merge(&mut self, other: Self) {
44        self.incomings.extend(other.incomings.into_iter())
45    }
46    pub fn http(self) -> HttpIncoming<TcpOrTlsStream, Self> {
47        HttpIncoming::new(self)
48    }
49}
50
51impl Unpin for TcpOrTlsIncoming {}
52
53impl Stream for TcpOrTlsIncoming {
54    type Item = TcpOrTlsStream;
55
56    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
57        self.incomings.poll_next_unpin(cx)
58    }
59}
60
61impl FusedStream for TcpOrTlsIncoming {
62    fn is_terminated(&self) -> bool {
63        self.incomings.is_terminated()
64    }
65}
66
67impl AsyncRead for TcpOrTlsStream {
68    fn poll_read(
69        self: Pin<&mut Self>,
70        cx: &mut Context<'_>,
71        buf: &mut [u8],
72    ) -> Poll<std::io::Result<usize>> {
73        match self.get_mut() {
74            TcpOrTlsStream::Tcp(tcp) => Pin::new(tcp).poll_read(cx, buf),
75            TcpOrTlsStream::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
76        }
77    }
78}
79
80impl AsyncWrite for TcpOrTlsStream {
81    fn poll_write(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84        buf: &[u8],
85    ) -> Poll<std::io::Result<usize>> {
86        match self.get_mut() {
87            TcpOrTlsStream::Tcp(tcp) => Pin::new(tcp).poll_write(cx, buf),
88            TcpOrTlsStream::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
89        }
90    }
91
92    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
93        match self.get_mut() {
94            TcpOrTlsStream::Tcp(tcp) => Pin::new(tcp).poll_flush(cx),
95            TcpOrTlsStream::Tls(tls) => Pin::new(tls).poll_flush(cx),
96        }
97    }
98
99    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
100        match self.get_mut() {
101            TcpOrTlsStream::Tcp(tcp) => Pin::new(tcp).poll_close(cx),
102            TcpOrTlsStream::Tls(tls) => Pin::new(tls).poll_close(cx),
103        }
104    }
105}
106
107impl From<TcpStream> for TcpOrTlsStream {
108    fn from(value: TcpStream) -> Self {
109        TcpOrTlsStream::Tcp(value)
110    }
111}
112
113impl From<TlsStream> for TcpOrTlsStream {
114    fn from(value: TlsStream) -> Self {
115        TcpOrTlsStream::Tls(value)
116    }
117}