blitz_ws/
stream.rs

//! Convenience wrapper for streams to switch between plain TCP and TLS at runtime.
//!
//! There is no dependency on any particular TLS implementation: libraries such as
//! `native_tls` or `openssl` will work as long as they provide a TLS stream
//! implementing the standard `Read + Write` traits.
6
7#[cfg(feature = "__rustls-tls")]
8use std::ops::Deref;
9use std::{
10    fmt::Debug,
11    io::{Read, Result as IoResult, Write},
12    net::TcpStream,
13};
14
15#[cfg(feature = "native-tls")]
16use native_tls_crate::TlsStream;
17#[cfg(feature = "__rustls-tls")]
18use rustls::StreamOwned;
19
/// Stream mode, either plain TCP or TLS.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    /// Plain mode (`ws://` URL).
    Plain,
    /// TLS mode (`wss://` URL).
    Tls,
}
28
/// Abstraction over streams whose `TCP_NODELAY` socket option can be toggled.
///
/// Implemented in this module for [`TcpStream`] and, behind the matching Cargo
/// features, for TLS streams wrapping any stream that itself implements `NoDelay`.
pub trait NoDelay {
    /// Sets the `TCP_NODELAY` option of the underlying socket to `no_delay`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported while changing the socket option.
    fn set_nodelay(&mut self, no_delay: bool) -> std::io::Result<()>;
}
34
35impl NoDelay for TcpStream {
36    fn set_nodelay(&mut self, no_delay: bool) -> IoResult<()> {
37        TcpStream::set_nodelay(self, no_delay)
38    }
39}
40
/// Delegates to the socket wrapped inside a `native-tls` stream.
#[cfg(feature = "native-tls")]
impl<S> NoDelay for TlsStream<S>
where
    S: Read + Write + NoDelay,
{
    fn set_nodelay(&mut self, no_delay: bool) -> IoResult<()> {
        let socket: &mut S = self.get_mut();
        socket.set_nodelay(no_delay)
    }
}
47
/// Delegates to the socket wrapped inside a `rustls` stream.
#[cfg(feature = "__rustls-tls")]
impl<S, SD, T> NoDelay for StreamOwned<S, T>
where
    S: Deref<Target = rustls::ConnectionCommon<SD>>,
    SD: rustls::SideData,
    T: Read + Write + NoDelay,
{
    fn set_nodelay(&mut self, no_delay: bool) -> IoResult<()> {
        // The option is applied to the wrapped transport stored in `sock`.
        NoDelay::set_nodelay(&mut self.sock, no_delay)
    }
}
59
/// A stream that is either plain or wrapped in TLS, selected at runtime.
///
/// Which TLS variants exist depends on the enabled Cargo features
/// (`native-tls`, `__rustls-tls`).
// The TLS variants are much larger than `Plain` (which is what this lint reports);
// they are stored inline rather than boxed. NOTE(review): confirm this is intentional.
#[allow(clippy::large_enum_variant)]
#[non_exhaustive]
pub enum SimplifiedStream<S: Read + Write> {
    /// Unencrypted socket stream.
    Plain(S),

    /// Encrypted socket stream using `native-tls`.
    #[cfg(feature = "native-tls")]
    NativeTls(TlsStream<S>),

    /// Encrypted socket stream using `rustls`.
    #[cfg(feature = "__rustls-tls")]
    Rustls(StreamOwned<rustls::ClientConnection, S>),
}
75
76impl<S: Read + Write + Debug> Debug for SimplifiedStream<S> {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        match self {
79            Self::Plain(s) => f.debug_tuple("SimplifiedStream::Plain").field(s).finish(),
80
81            #[cfg(feature = "native-tls")]
82            Self::NativeTls(s) => f.debug_tuple("SimplifiedStream::NativeTls").field(s).finish(),
83
84            #[cfg(feature = "__rustls-tls")]
85            Self::Rustls(s) => {
86                struct RustlsStreamDebug<'a, S: Read + Write>(
87                    &'a rustls::StreamOwned<rustls::ClientConnection, S>,
88                );
89
90                impl<S: Read + Write + Debug> Debug for RustlsStreamDebug<'_, S> {
91                    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92                        f.debug_struct("StreamOwned")
93                            .field("conn", &self.0.conn)
94                            .field("sock", &self.0.sock)
95                            .finish()
96                    }
97                }
98
99                f.debug_tuple("SimplifiedStream::Rustls").field(&RustlsStreamDebug(s)).finish()
100            }
101        }
102    }
103}
104
105impl<S: Read + Write> Read for SimplifiedStream<S> {
106    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
107        match self {
108            Self::Plain(ref mut s) => s.read(buf),
109            #[cfg(feature = "native-tls")]
110            Self::NativeTls(ref mut s) => s.read(buf),
111            #[cfg(feature = "__rustls-tls")]
112            Self::Rustls(ref mut s) => s.read(buf),
113        }
114    }
115}
116
117impl<S: Read + Write> Write for SimplifiedStream<S> {
118    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
119        match self {
120            Self::Plain(ref mut s) => s.write(buf),
121            #[cfg(feature = "native-tls")]
122            Self::NativeTls(ref mut s) => s.write(buf),
123            #[cfg(feature = "__rustls-tls")]
124            Self::Rustls(ref mut s) => s.write(buf),
125        }
126    }
127
128    fn flush(&mut self) -> IoResult<()> {
129        match self {
130            Self::Plain(ref mut s) => s.flush(),
131            #[cfg(feature = "native-tls")]
132            Self::NativeTls(ref mut s) => s.flush(),
133            #[cfg(feature = "__rustls-tls")]
134            Self::Rustls(ref mut s) => s.flush(),
135        }
136    }
137}
138
139impl<S: Read + Write + NoDelay> NoDelay for SimplifiedStream<S> {
140    fn set_nodelay(&mut self, no_delay: bool) -> IoResult<()> {
141        match self {
142            Self::Plain(ref mut s) => s.set_nodelay(no_delay),
143            #[cfg(feature = "native-tls")]
144            Self::NativeTls(ref mut s) => s.set_nodelay(no_delay),
145            #[cfg(feature = "__rustls-tls")]
146            Self::Rustls(ref mut s) => s.set_nodelay(no_delay),
147        }
148    }
149}