blaze_ssl_async/
stream.rs

//! SSL Stream wrapper around the tokio [TcpStream]
//!
//! ```rust,no_run
//! // BlazeStream is a wrapper over tokio TcpStream
//! use blaze_ssl_async::BlazeStream;
//!
//! // Tokio read/write extension traits used for read_exact, write_all and flush
//! use tokio::io::{AsyncReadExt, AsyncWriteExt};
//!
//! #[tokio::main]
//! async fn main() -> std::io::Result<()> {
//!     // BlazeStream::connect takes in any value that implements ToSocketAddrs
//!     // some common implementations are "HOST:PORT" and ("HOST", PORT)
//!     let mut stream = BlazeStream::connect(("159.153.64.175", 42127)).await?;
//!
//!     // Read from the stream just as you would from a normal TcpStream
//!     let mut buf = [0u8; 12];
//!     stream.read_exact(&mut buf).await?;
//!     // Write the bytes back
//!     stream.write_all(&buf).await?;
//!     // You **MUST** flush BlazeSSL streams, otherwise written data is never
//!     // sent to the peer (reading from the stream also flushes automatically)
//!     stream.flush().await?;
//!
//!     Ok(())
//! }
//! ```
//!
29use super::{
30    crypto::rc4::*,
31    msg::{codec::*, deframer::MessageDeframer, types::*, AlertError, Message},
32};
33use crate::{handshake::Handshaking, listener::BlazeServerContext};
34use std::{
35    cmp,
36    io::{self, ErrorKind},
37    pin::Pin,
38    sync::Arc,
39    task::{ready, Context, Poll},
40};
41use tokio::{
42    io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
43    net::{TcpStream, ToSocketAddrs},
44};
45
/// Wrapper around [TcpStream] providing SSL encryption
///
/// Application data written to this stream is buffered in memory and only
/// framed (and encrypted, once an encryptor is set) into raw records when the
/// stream is flushed. Incoming records are de-framed, decrypted when a
/// decryptor is set, and their application data is buffered until read.
pub struct BlazeStream {
    /// Underlying stream target
    stream: TcpStream,

    /// Message deframer for de-framing messages from the read stream
    deframer: MessageDeframer,

    /// Decryptor for decrypting messages if the stream is encrypted
    /// (`None` means incoming messages are passed through undecrypted)
    pub(crate) decryptor: Option<Rc4Decryptor>,
    /// Encryptor for encrypting messages if the stream should be encrypted
    /// (`None` means outgoing messages are written unencrypted)
    pub(crate) encryptor: Option<Rc4Encryptor>,

    /// Application data received from the peer that has not yet been
    /// handed out to a reader
    app_read_buffer: Vec<u8>,
    /// Application data written by the user that has not yet been framed
    /// into a message (it is framed and written to the stream when the
    /// connection is flushed)
    app_write_buffer: Vec<u8>,

    /// Buffer for the raw (framed and possibly encrypted) record bytes
    /// that are waiting to be written to the underlying stream
    write_buffer: Vec<u8>,

    /// Whether the stream is stopped. Set to `true` whenever an alert is
    /// written or an alert message is received; once set, reading and
    /// writing application data fail with a "Connection closed" error
    pub(crate) stopped: bool,
}
72
73impl BlazeStream {
74    /// Connects to a remote address creating a client blaze stream
75    /// to that address.
76    ///
77    /// # Arguments
78    /// * `addr` - The address to connect to
79    pub async fn connect<A: ToSocketAddrs>(addr: A) -> std::io::Result<Self> {
80        let stream = TcpStream::connect(addr).await?;
81        let mut stream = Self::new(stream);
82
83        // Complete the client handshake
84        if let Err(err) = Handshaking::create_client(&mut stream).await {
85            // Ensure the stream is correctly flushed and shutdown on error
86            _ = stream.shutdown().await;
87            return Err(err);
88        }
89
90        Ok(stream)
91    }
92
93    /// Accepts the connection of `stream` as a client connected
94    /// to a server using the provided `data`
95    ///
96    /// ## Arguments
97    /// * `context` - The server context to use
98    pub async fn accept(
99        stream: TcpStream,
100        context: Arc<BlazeServerContext>,
101    ) -> std::io::Result<Self> {
102        let mut stream = Self::new(stream);
103
104        // Complete the server handshake
105        if let Err(err) = Handshaking::create_server(&mut stream, context).await {
106            // Ensure the stream is correctly flushed and shutdown on error
107            _ = stream.shutdown().await;
108            return Err(err);
109        }
110
111        Ok(stream)
112    }
113
114    /// Returns a reference to the underlying stream
115    pub fn get_ref(&self) -> &TcpStream {
116        &self.stream
117    }
118
119    /// Returns a mutable reference to the underlying stream
120    pub fn get_mut(&mut self) -> &mut TcpStream {
121        &mut self.stream
122    }
123
124    /// Returns the underlying stream that this BlazeStream
125    /// is wrapping
126    pub fn into_inner(self) -> TcpStream {
127        self.stream
128    }
129
130    /// Wraps the provided `stream` with a [BlazeStream] preparing
131    /// it to be used with a handshake state
132    fn new(stream: TcpStream) -> Self {
133        Self {
134            stream,
135            deframer: MessageDeframer::new(),
136            decryptor: None,
137            encryptor: None,
138            app_write_buffer: Vec::new(),
139            app_read_buffer: Vec::new(),
140            write_buffer: Vec::new(),
141            stopped: false,
142        }
143    }
144
145    /// Polls for the next message to be recieved. Decryptes encrypted messages
146    /// and handles alert messages.
147    ///
148    /// # Arguments
149    /// * cx - The polling context
150    pub(crate) fn poll_next_message(
151        &mut self,
152        cx: &mut Context<'_>,
153    ) -> Poll<std::io::Result<Message>> {
154        loop {
155            if let Some(mut message) = self.deframer.next() {
156                // Ensure the protocol version is SSLv3
157                if !message.protocol_version.is_valid() {
158                    // Write the error alert message
159                    self.write_alert(AlertError::fatal(AlertDescription::HandshakeFailure));
160
161                    return Poll::Ready(Err(std::io::Error::new(
162                        ErrorKind::Other,
163                        "Unsupported SSL version",
164                    )));
165                }
166
167                // Decrypt message if encryption is enabled
168                self.try_decrypt_message(&mut message)?;
169                return Poll::Ready(Ok(message));
170            }
171
172            // Poll reading data from the stream
173            ready!(self.deframer.poll_read(&mut self.stream, cx))?;
174        }
175    }
176
177    /// Attempts to decrypt the provied `message` if there is a decryptor set
178    fn try_decrypt_message(&mut self, message: &mut Message) -> std::io::Result<()> {
179        let decryptor = match &mut self.decryptor {
180            Some(value) => value,
181            None => return Ok(()),
182        };
183
184        if decryptor.decrypt(message) {
185            return Ok(());
186        }
187
188        // Write the error alert message
189        self.write_alert(AlertError::fatal(AlertDescription::BadRecordMac));
190
191        Err(std::io::Error::new(ErrorKind::Other, "Bad record mac"))
192    }
193
194    /// Triggers a shutdown by sending a CloseNotify alert
195    ///
196    /// # Arguments
197    /// * cx - The polling context
198    fn poll_shutdown_priv(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
199        // Send the alert if not already stopping
200        if !self.stopped {
201            // Send the shutdown close notify
202            self.write_alert(AlertError::warning(AlertDescription::CloseNotify));
203        }
204
205        // Flush any data before shutdown
206        self.poll_flush_priv(cx)
207    }
208
209    /// Fragments the provided message and encrypts the contents if
210    /// encryption is available writing the output to the underlying
211    /// stream
212    ///
213    /// # Arguments
214    /// * message - The message to write
215    pub(crate) fn write_message(&mut self, message: Message) {
216        for mut msg in message.fragment() {
217            if let Some(writer) = &mut self.encryptor {
218                writer.encrypt(&mut msg)
219            }
220
221            msg.encode(&mut self.write_buffer);
222        }
223    }
224
225    /// Writes an alert message and updates the stopped state
226    ///
227    /// # Arguments
228    /// * alert - The alert to write
229    pub(crate) fn write_alert(&mut self, alert: AlertError) {
230        let mut payload = Vec::new();
231        alert.encode(&mut payload);
232
233        let message = Message::new(MessageType::Alert, payload);
234
235        // Internally handle the alert being sent
236        self.write_message(message);
237
238        // Handle stopping from an alert
239        self.stopped = true;
240    }
241
242    /// Writes the provided bytes as application data to the
243    /// app write buffer
244    fn write_app_data(&mut self, buf: &[u8]) -> io::Result<usize> {
245        if self.stopped {
246            return Err(io_closed());
247        };
248        self.app_write_buffer.extend_from_slice(buf);
249        Ok(buf.len())
250    }
251
252    /// Polls reading application data from the app
253    fn poll_read_priv(
254        &mut self,
255        cx: &mut Context<'_>,
256        buf: &mut ReadBuf<'_>,
257    ) -> Poll<io::Result<()>> {
258        // Poll flushing the write buffer before attempting to read
259        ready!(self.poll_flush_priv(cx))?;
260
261        if self.stopped {
262            return Poll::Ready(Err(io_closed()));
263        }
264
265        // Poll for app data from the stream
266        let count = ready!(self.poll_app_data(cx))?;
267
268        // Handle already stopped streams
269        if self.stopped {
270            return Poll::Ready(Err(io_closed()));
271        }
272
273        // Calculate the amount to read based on the buf size and the amount stored
274        let read = cmp::min(buf.remaining(), count);
275        if read > 0 {
276            // Provide the data and replace the stored slice
277            let new_buffer = self.app_read_buffer.split_off(read);
278            buf.put_slice(&self.app_read_buffer);
279            self.app_read_buffer = new_buffer;
280        }
281
282        Poll::Ready(Ok(()))
283    }
284
285    /// Polls flushing all the data for this stream that includes app data
286    /// and the write buffer. This involves writing everything to the write
287    /// buffer and then writing all the data to the stream and attempting
288    /// to flush the stream
289    pub(crate) fn poll_flush_priv(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
290        // Write any written app data as a message to the write buffer
291        if !self.app_write_buffer.is_empty() {
292            let message = Message::new(
293                MessageType::ApplicationData,
294                self.app_write_buffer.split_off(0),
295            );
296
297            self.write_message(message);
298        }
299
300        // Try flushing the internal write buffer
301        let mut write_count: usize = 0;
302        while !self.write_buffer.is_empty() {
303            let stream = Pin::new(&mut self.stream);
304            let count = ready!(stream.poll_write(cx, &self.write_buffer))?;
305            if count > 0 {
306                self.write_buffer = self.write_buffer.split_off(count);
307                write_count += count;
308            }
309        }
310
311        // Skip flushing if we haven't written any data
312        if write_count == 0 {
313            return Poll::Ready(Ok(()));
314        }
315
316        // Try flush the underlying stream
317        Pin::new(&mut self.stream).poll_flush(cx)
318    }
319
320    /// Polls for application data or returns the already present amount of application
321    /// data stored in this stream, Collects application data by polling for messages
322    fn poll_app_data(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
323        let buffer_len = self.app_read_buffer.len();
324
325        // Early return if the buffer is not yet empty
326        if buffer_len != 0 {
327            return Poll::Ready(Ok(buffer_len));
328        }
329
330        // Poll for the next message
331        let message = ready!(self.poll_next_message(cx))?;
332
333        match message.message_type {
334            // Handle errors from the client
335            MessageType::Alert => {
336                let alert = AlertError::from_message(&message);
337
338                // Stop the stream
339                self.stopped = true;
340
341                // On error ready 0 bytes
342                Poll::Ready(Err(io::Error::new(ErrorKind::Other, alert)))
343            }
344
345            // Handle application data
346            MessageType::ApplicationData => {
347                let payload = message.payload;
348                self.app_read_buffer.extend_from_slice(&payload);
349                Poll::Ready(Ok(payload.len()))
350            }
351
352            // Unexpected message kind
353            _ => {
354                self.write_alert(AlertError::fatal(AlertDescription::UnexpectedMessage));
355
356                Poll::Ready(Err(io::Error::new(
357                    ErrorKind::Other,
358                    "Expected application data but got something else",
359                )))
360            }
361        }
362    }
363}
364
365impl AsyncRead for BlazeStream {
366    /// Read polling handled by internal poll_read_priv
367    ///
368    /// # Arguments
369    /// * cx - The polling context
370    /// * buf - The read buffer to read to
371    fn poll_read(
372        self: Pin<&mut Self>,
373        cx: &mut Context<'_>,
374        buf: &mut ReadBuf<'_>,
375    ) -> Poll<io::Result<()>> {
376        self.get_mut().poll_read_priv(cx, buf)
377    }
378}
379
380impl AsyncWrite for BlazeStream {
381    /// Writing polling is always ready as the data is written
382    /// directly to a vec buffer
383    ///
384    /// # Arguments
385    /// * _cx - The polling context
386    /// * buf - The slice of bytes to write as app data
387    fn poll_write(
388        self: Pin<&mut Self>,
389        _cx: &mut Context<'_>,
390        buf: &[u8],
391    ) -> Poll<io::Result<usize>> {
392        Poll::Ready(self.get_mut().write_app_data(buf))
393    }
394
395    /// Polls the internal flushing funciton
396    ///
397    /// # Arguments
398    /// * cx - The polling context
399    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
400        self.get_mut().poll_flush_priv(cx)
401    }
402
403    /// Polls the internal shutdown function
404    ///
405    /// # Arguments
406    /// * cx - The polling context
407    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
408        self.get_mut().poll_shutdown_priv(cx)
409    }
410}
411
412/// Creates an error indicating that the stream is closed
413fn io_closed() -> io::Error {
414    io::Error::new(ErrorKind::UnexpectedEof, "Connection closed")
415}