compio_ws/
stream.rs

1#[cfg(feature = "rustls")]
2use std::io::Result as IoResult;
3
4#[cfg(feature = "rustls")]
5use compio_buf::{BufResult, IoBuf, IoBufMut};
6#[cfg(feature = "rustls")]
7use compio_io::{AsyncRead, AsyncWrite};
8#[cfg(feature = "rustls")]
9use compio_tls::TlsStream;
10
/// A stream that is either a plain (unencrypted) transport `S`, such as TCP, or a
/// TLS-encrypted session layered on top of that same transport.
///
/// Only available with the `rustls` feature enabled.
#[cfg(feature = "rustls")]
#[derive(Debug)]
// The `Tls` variant is much larger than `Plain` (which is what triggers this lint).
// NOTE(review): presumably silenced on purpose to avoid a heap allocation per
// connection that boxing `TlsStream<S>` would introduce — confirm.
#[allow(clippy::large_enum_variant)]
pub enum MaybeTlsStream<S> {
    /// Plain, unencrypted stream.
    Plain(S),
    /// TLS-encrypted stream wrapping the underlying transport `S`.
    ///
    /// No per-variant `cfg` is needed: the whole enum is already gated on `rustls`.
    Tls(TlsStream<S>),
}
22
#[cfg(feature = "rustls")]
impl<S> MaybeTlsStream<S> {
    /// Wraps `stream` as an unencrypted [`MaybeTlsStream::Plain`].
    pub fn plain(stream: S) -> Self {
        Self::Plain(stream)
    }

    /// Wraps a TLS stream as [`MaybeTlsStream::Tls`].
    pub fn tls(stream: TlsStream<S>) -> Self {
        Self::Tls(stream)
    }

    /// Returns `true` if this is the TLS-encrypted variant, `false` for a plain stream.
    pub fn is_tls(&self) -> bool {
        // This impl only exists when `rustls` is enabled, so the `Tls` variant is always
        // present here and a simple pattern test is sufficient.
        matches!(self, Self::Tls(_))
    }
}
45
/// Delegates reads to whichever stream (plain or TLS) is wrapped.
///
/// NOTE(review): `S` also has to be `AsyncWrite + Unpin + 'static`, presumably because
/// `TlsStream<S>`'s own `AsyncRead` impl requires it — confirm against `compio_tls`.
#[cfg(feature = "rustls")]
impl<S> AsyncRead for MaybeTlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + 'static,
{
    /// Reads into `buf` from the active variant, returning the byte count and the buffer.
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        // The impl is already gated on `rustls`, so no per-arm `cfg` is needed.
        match self {
            Self::Plain(stream) => stream.read(buf).await,
            Self::Tls(stream) => stream.read(buf).await,
        }
    }
}
59
/// Delegates writes, flushes and shutdown to whichever stream (plain or TLS) is wrapped.
///
/// The impl is gated on `rustls` as a whole, so the individual match arms need no `cfg`.
#[cfg(feature = "rustls")]
impl<S> AsyncWrite for MaybeTlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + 'static,
{
    /// Writes from `buf` to the active variant, returning the byte count and the buffer.
    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
        match self {
            Self::Plain(stream) => stream.write(buf).await,
            Self::Tls(stream) => stream.write(buf).await,
        }
    }

    /// Flushes the active variant.
    async fn flush(&mut self) -> IoResult<()> {
        match self {
            Self::Plain(stream) => stream.flush().await,
            Self::Tls(stream) => stream.flush().await,
        }
    }

    /// Shuts down the active variant by calling its own `shutdown`.
    async fn shutdown(&mut self) -> IoResult<()> {
        match self {
            Self::Plain(stream) => stream.shutdown().await,
            Self::Tls(stream) => stream.shutdown().await,
        }
    }
}