1#[cfg(feature = "rustls")]
2use std::io::Result as IoResult;
3
4#[cfg(feature = "rustls")]
5use compio_buf::{BufResult, IoBuf, IoBufMut};
6#[cfg(feature = "rustls")]
7use compio_io::{AsyncRead, AsyncWrite};
8#[cfg(feature = "rustls")]
9use compio_tls::TlsStream;
10
/// A stream over transport `S` that is either used as-is or wrapped in TLS.
#[cfg(feature = "rustls")]
#[derive(Debug)]
// `TlsStream<S>` is large relative to a bare `S`; it is kept inline rather
// than boxed so constructing a TLS stream does not need an extra allocation.
#[allow(clippy::large_enum_variant)]
pub enum MaybeTlsStream<S> {
    /// The underlying stream, with no TLS layer.
    Plain(S),
    /// The underlying stream wrapped in a [`TlsStream`].
    // No per-variant `cfg` is needed: the whole enum is already gated on
    // `feature = "rustls"`.
    Tls(TlsStream<S>),
}
22
// The whole impl is gated on `rustls`, so `Tls` always exists here; the
// previous `cfg(not(feature = "rustls"))` fallback in `is_tls` was
// unreachable and has been removed.
#[cfg(feature = "rustls")]
impl<S> MaybeTlsStream<S> {
    /// Wraps `stream` in the [`MaybeTlsStream::Plain`] variant.
    pub fn plain(stream: S) -> Self {
        Self::Plain(stream)
    }

    /// Wraps a [`TlsStream`] in the [`MaybeTlsStream::Tls`] variant.
    pub fn tls(stream: TlsStream<S>) -> Self {
        Self::Tls(stream)
    }

    /// Returns `true` if this is the [`MaybeTlsStream::Tls`] variant.
    pub fn is_tls(&self) -> bool {
        matches!(self, Self::Tls(_))
    }
}
45
/// Reads are forwarded to whichever stream this value wraps.
#[cfg(feature = "rustls")]
impl<S> AsyncRead for MaybeTlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + 'static,
{
    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
        // Delegate to the wrapped stream; the buffer ownership round-trips
        // through the inner `read` unchanged.
        match self {
            Self::Plain(inner) => inner.read(buf).await,
            Self::Tls(inner) => inner.read(buf).await,
        }
    }
}
59
/// Writes, flushes and shutdowns are forwarded to whichever stream this
/// value wraps.
#[cfg(feature = "rustls")]
impl<S> AsyncWrite for MaybeTlsStream<S>
where
    S: AsyncRead + AsyncWrite + Unpin + 'static,
{
    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
        match self {
            Self::Plain(inner) => inner.write(buf).await,
            Self::Tls(inner) => inner.write(buf).await,
        }
    }

    async fn flush(&mut self) -> IoResult<()> {
        match self {
            Self::Plain(inner) => inner.flush().await,
            Self::Tls(inner) => inner.flush().await,
        }
    }

    async fn shutdown(&mut self) -> IoResult<()> {
        match self {
            Self::Plain(inner) => inner.shutdown().await,
            Self::Tls(inner) => inner.shutdown().await,
        }
    }
}
88}