compio_tls/
maybe.rs

1use std::{borrow::Cow, io};
2
3use compio_buf::{BufResult, IoBuf, IoBufMut};
4use compio_io::{AsyncRead, AsyncWrite};
5
6use crate::TlsStream;
7
8#[derive(Debug)]
9#[allow(clippy::large_enum_variant)]
10enum MaybeTlsStreamInner<S> {
11    /// Plain, unencrypted stream
12    Plain(S),
13    /// TLS-encrypted stream
14    Tls(TlsStream<S>),
15}
16
17/// Stream that can be either plain TCP or TLS-encrypted
18#[derive(Debug)]
19pub struct MaybeTlsStream<S>(MaybeTlsStreamInner<S>);
20
21impl<S> MaybeTlsStream<S> {
22    /// Create an unencrypted stream.
23    pub fn new_plain(stream: S) -> Self {
24        Self(MaybeTlsStreamInner::Plain(stream))
25    }
26
27    /// Create a TLS-encrypted stream.
28    pub fn new_tls(stream: TlsStream<S>) -> Self {
29        Self(MaybeTlsStreamInner::Tls(stream))
30    }
31
32    /// Whether the stream is TLS-encrypted.
33    pub fn is_tls(&self) -> bool {
34        matches!(self.0, MaybeTlsStreamInner::Tls(_))
35    }
36
37    /// Returns the negotiated ALPN protocol.
38    pub fn negotiated_alpn(&self) -> Option<Cow<'_, [u8]>> {
39        match &self.0 {
40            MaybeTlsStreamInner::Plain(_) => None,
41            MaybeTlsStreamInner::Tls(s) => s.negotiated_alpn(),
42        }
43    }
44}
45
46impl<S> AsyncRead for MaybeTlsStream<S>
47where
48    S: AsyncRead + AsyncWrite + Unpin + 'static,
49{
50    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
51        match &mut self.0 {
52            MaybeTlsStreamInner::Plain(stream) => stream.read(buf).await,
53            MaybeTlsStreamInner::Tls(stream) => stream.read(buf).await,
54        }
55    }
56}
57
58impl<S> AsyncWrite for MaybeTlsStream<S>
59where
60    S: AsyncRead + AsyncWrite + Unpin + 'static,
61{
62    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
63        match &mut self.0 {
64            MaybeTlsStreamInner::Plain(stream) => stream.write(buf).await,
65            MaybeTlsStreamInner::Tls(stream) => stream.write(buf).await,
66        }
67    }
68
69    async fn flush(&mut self) -> io::Result<()> {
70        match &mut self.0 {
71            MaybeTlsStreamInner::Plain(stream) => stream.flush().await,
72            MaybeTlsStreamInner::Tls(stream) => stream.flush().await,
73        }
74    }
75
76    async fn shutdown(&mut self) -> io::Result<()> {
77        match &mut self.0 {
78            MaybeTlsStreamInner::Plain(stream) => stream.shutdown().await,
79            MaybeTlsStreamInner::Tls(stream) => stream.shutdown().await,
80        }
81    }
82}