blaze_ssl_async/stream.rs
1//! SSL Stream wrapper around the tokio [TcpStream]
2//!
3//! ```rust,no_run
4//! // BlazeStream is a wrapper over tokio TcpStream
5//! use blaze_ssl_async::BlazeStream;
6//!
7//! // Tokio read write extensions used for read_exact and write_all
8//! use tokio::io::{AsyncReadExt, AsyncWriteExt};
9//!
10//! #[tokio::main]
11//! async fn main() -> std::io::Result<()> {
12//! // BlazeStream::connect takes in any value that implements ToSocketAddrs
13//! // some common implementations are "HOST:PORT" and ("HOST", PORT)
14//! let mut stream = BlazeStream::connect(("159.153.64.175", 42127)).await?;
15//!
16//! // TODO... Read from the stream as you would a normal TcpStream
17//! let mut buf = [0u8; 12];
18//! stream.read_exact(&mut buf).await?;
19//! // Write the bytes back
20//! stream.write_all(&buf).await?;
//! // You **MUST** flush BlazeSSL streams or else the data will never
//! // be sent to the peer (attempting to read will automatically flush)
23//! stream.flush().await?;
24//!
25//! Ok(())
26//! }
27//! ```
28//!
29use super::{
30 crypto::rc4::*,
31 msg::{codec::*, deframer::MessageDeframer, types::*, AlertError, Message},
32};
33use crate::{handshake::Handshaking, listener::BlazeServerContext};
34use std::{
35 cmp,
36 io::{self, ErrorKind},
37 pin::Pin,
38 sync::Arc,
39 task::{ready, Context, Poll},
40};
41use tokio::{
42 io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
43 net::{TcpStream, ToSocketAddrs},
44};
45
46/// Wrapper around [TcpStream] providing SSL encryption
47pub struct BlazeStream {
48 /// Underlying stream target
49 stream: TcpStream,
50
51 /// Message deframer for de-framing messages from the read stream
52 deframer: MessageDeframer,
53
54 /// Decryptor for decrypting messages if the stream is encrypted
55 pub(crate) decryptor: Option<Rc4Decryptor>,
56 /// Encryptor for encrypting messages if the stream should be encrypted
57 pub(crate) encryptor: Option<Rc4Encryptor>,
58
59 /// Buffer for input that is read from the application layer
60 app_read_buffer: Vec<u8>,
61 /// Buffer for output written to the application layer
62 /// (Written to stream when connection is flushed)
63 app_write_buffer: Vec<u8>,
64
65 /// Buffer for the raw packet contents that are going to be
66 /// written to the stream
67 write_buffer: Vec<u8>,
68
69 /// State determining whether the stream is stopped
70 pub(crate) stopped: bool,
71}
72
73impl BlazeStream {
74 /// Connects to a remote address creating a client blaze stream
75 /// to that address.
76 ///
77 /// # Arguments
78 /// * `addr` - The address to connect to
79 pub async fn connect<A: ToSocketAddrs>(addr: A) -> std::io::Result<Self> {
80 let stream = TcpStream::connect(addr).await?;
81 let mut stream = Self::new(stream);
82
83 // Complete the client handshake
84 if let Err(err) = Handshaking::create_client(&mut stream).await {
85 // Ensure the stream is correctly flushed and shutdown on error
86 _ = stream.shutdown().await;
87 return Err(err);
88 }
89
90 Ok(stream)
91 }
92
93 /// Accepts the connection of `stream` as a client connected
94 /// to a server using the provided `data`
95 ///
96 /// ## Arguments
97 /// * `context` - The server context to use
98 pub async fn accept(
99 stream: TcpStream,
100 context: Arc<BlazeServerContext>,
101 ) -> std::io::Result<Self> {
102 let mut stream = Self::new(stream);
103
104 // Complete the server handshake
105 if let Err(err) = Handshaking::create_server(&mut stream, context).await {
106 // Ensure the stream is correctly flushed and shutdown on error
107 _ = stream.shutdown().await;
108 return Err(err);
109 }
110
111 Ok(stream)
112 }
113
114 /// Returns a reference to the underlying stream
115 pub fn get_ref(&self) -> &TcpStream {
116 &self.stream
117 }
118
119 /// Returns a mutable reference to the underlying stream
120 pub fn get_mut(&mut self) -> &mut TcpStream {
121 &mut self.stream
122 }
123
124 /// Returns the underlying stream that this BlazeStream
125 /// is wrapping
126 pub fn into_inner(self) -> TcpStream {
127 self.stream
128 }
129
130 /// Wraps the provided `stream` with a [BlazeStream] preparing
131 /// it to be used with a handshake state
132 fn new(stream: TcpStream) -> Self {
133 Self {
134 stream,
135 deframer: MessageDeframer::new(),
136 decryptor: None,
137 encryptor: None,
138 app_write_buffer: Vec::new(),
139 app_read_buffer: Vec::new(),
140 write_buffer: Vec::new(),
141 stopped: false,
142 }
143 }
144
145 /// Polls for the next message to be recieved. Decryptes encrypted messages
146 /// and handles alert messages.
147 ///
148 /// # Arguments
149 /// * cx - The polling context
150 pub(crate) fn poll_next_message(
151 &mut self,
152 cx: &mut Context<'_>,
153 ) -> Poll<std::io::Result<Message>> {
154 loop {
155 if let Some(mut message) = self.deframer.next() {
156 // Ensure the protocol version is SSLv3
157 if !message.protocol_version.is_valid() {
158 // Write the error alert message
159 self.write_alert(AlertError::fatal(AlertDescription::HandshakeFailure));
160
161 return Poll::Ready(Err(std::io::Error::new(
162 ErrorKind::Other,
163 "Unsupported SSL version",
164 )));
165 }
166
167 // Decrypt message if encryption is enabled
168 self.try_decrypt_message(&mut message)?;
169 return Poll::Ready(Ok(message));
170 }
171
172 // Poll reading data from the stream
173 ready!(self.deframer.poll_read(&mut self.stream, cx))?;
174 }
175 }
176
177 /// Attempts to decrypt the provied `message` if there is a decryptor set
178 fn try_decrypt_message(&mut self, message: &mut Message) -> std::io::Result<()> {
179 let decryptor = match &mut self.decryptor {
180 Some(value) => value,
181 None => return Ok(()),
182 };
183
184 if decryptor.decrypt(message) {
185 return Ok(());
186 }
187
188 // Write the error alert message
189 self.write_alert(AlertError::fatal(AlertDescription::BadRecordMac));
190
191 Err(std::io::Error::new(ErrorKind::Other, "Bad record mac"))
192 }
193
194 /// Triggers a shutdown by sending a CloseNotify alert
195 ///
196 /// # Arguments
197 /// * cx - The polling context
198 fn poll_shutdown_priv(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
199 // Send the alert if not already stopping
200 if !self.stopped {
201 // Send the shutdown close notify
202 self.write_alert(AlertError::warning(AlertDescription::CloseNotify));
203 }
204
205 // Flush any data before shutdown
206 self.poll_flush_priv(cx)
207 }
208
209 /// Fragments the provided message and encrypts the contents if
210 /// encryption is available writing the output to the underlying
211 /// stream
212 ///
213 /// # Arguments
214 /// * message - The message to write
215 pub(crate) fn write_message(&mut self, message: Message) {
216 for mut msg in message.fragment() {
217 if let Some(writer) = &mut self.encryptor {
218 writer.encrypt(&mut msg)
219 }
220
221 msg.encode(&mut self.write_buffer);
222 }
223 }
224
225 /// Writes an alert message and updates the stopped state
226 ///
227 /// # Arguments
228 /// * alert - The alert to write
229 pub(crate) fn write_alert(&mut self, alert: AlertError) {
230 let mut payload = Vec::new();
231 alert.encode(&mut payload);
232
233 let message = Message::new(MessageType::Alert, payload);
234
235 // Internally handle the alert being sent
236 self.write_message(message);
237
238 // Handle stopping from an alert
239 self.stopped = true;
240 }
241
242 /// Writes the provided bytes as application data to the
243 /// app write buffer
244 fn write_app_data(&mut self, buf: &[u8]) -> io::Result<usize> {
245 if self.stopped {
246 return Err(io_closed());
247 };
248 self.app_write_buffer.extend_from_slice(buf);
249 Ok(buf.len())
250 }
251
252 /// Polls reading application data from the app
253 fn poll_read_priv(
254 &mut self,
255 cx: &mut Context<'_>,
256 buf: &mut ReadBuf<'_>,
257 ) -> Poll<io::Result<()>> {
258 // Poll flushing the write buffer before attempting to read
259 ready!(self.poll_flush_priv(cx))?;
260
261 if self.stopped {
262 return Poll::Ready(Err(io_closed()));
263 }
264
265 // Poll for app data from the stream
266 let count = ready!(self.poll_app_data(cx))?;
267
268 // Handle already stopped streams
269 if self.stopped {
270 return Poll::Ready(Err(io_closed()));
271 }
272
273 // Calculate the amount to read based on the buf size and the amount stored
274 let read = cmp::min(buf.remaining(), count);
275 if read > 0 {
276 // Provide the data and replace the stored slice
277 let new_buffer = self.app_read_buffer.split_off(read);
278 buf.put_slice(&self.app_read_buffer);
279 self.app_read_buffer = new_buffer;
280 }
281
282 Poll::Ready(Ok(()))
283 }
284
285 /// Polls flushing all the data for this stream that includes app data
286 /// and the write buffer. This involves writing everything to the write
287 /// buffer and then writing all the data to the stream and attempting
288 /// to flush the stream
289 pub(crate) fn poll_flush_priv(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
290 // Write any written app data as a message to the write buffer
291 if !self.app_write_buffer.is_empty() {
292 let message = Message::new(
293 MessageType::ApplicationData,
294 self.app_write_buffer.split_off(0),
295 );
296
297 self.write_message(message);
298 }
299
300 // Try flushing the internal write buffer
301 let mut write_count: usize = 0;
302 while !self.write_buffer.is_empty() {
303 let stream = Pin::new(&mut self.stream);
304 let count = ready!(stream.poll_write(cx, &self.write_buffer))?;
305 if count > 0 {
306 self.write_buffer = self.write_buffer.split_off(count);
307 write_count += count;
308 }
309 }
310
311 // Skip flushing if we haven't written any data
312 if write_count == 0 {
313 return Poll::Ready(Ok(()));
314 }
315
316 // Try flush the underlying stream
317 Pin::new(&mut self.stream).poll_flush(cx)
318 }
319
320 /// Polls for application data or returns the already present amount of application
321 /// data stored in this stream, Collects application data by polling for messages
322 fn poll_app_data(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
323 let buffer_len = self.app_read_buffer.len();
324
325 // Early return if the buffer is not yet empty
326 if buffer_len != 0 {
327 return Poll::Ready(Ok(buffer_len));
328 }
329
330 // Poll for the next message
331 let message = ready!(self.poll_next_message(cx))?;
332
333 match message.message_type {
334 // Handle errors from the client
335 MessageType::Alert => {
336 let alert = AlertError::from_message(&message);
337
338 // Stop the stream
339 self.stopped = true;
340
341 // On error ready 0 bytes
342 Poll::Ready(Err(io::Error::new(ErrorKind::Other, alert)))
343 }
344
345 // Handle application data
346 MessageType::ApplicationData => {
347 let payload = message.payload;
348 self.app_read_buffer.extend_from_slice(&payload);
349 Poll::Ready(Ok(payload.len()))
350 }
351
352 // Unexpected message kind
353 _ => {
354 self.write_alert(AlertError::fatal(AlertDescription::UnexpectedMessage));
355
356 Poll::Ready(Err(io::Error::new(
357 ErrorKind::Other,
358 "Expected application data but got something else",
359 )))
360 }
361 }
362 }
363}
364
365impl AsyncRead for BlazeStream {
366 /// Read polling handled by internal poll_read_priv
367 ///
368 /// # Arguments
369 /// * cx - The polling context
370 /// * buf - The read buffer to read to
371 fn poll_read(
372 self: Pin<&mut Self>,
373 cx: &mut Context<'_>,
374 buf: &mut ReadBuf<'_>,
375 ) -> Poll<io::Result<()>> {
376 self.get_mut().poll_read_priv(cx, buf)
377 }
378}
379
380impl AsyncWrite for BlazeStream {
381 /// Writing polling is always ready as the data is written
382 /// directly to a vec buffer
383 ///
384 /// # Arguments
385 /// * _cx - The polling context
386 /// * buf - The slice of bytes to write as app data
387 fn poll_write(
388 self: Pin<&mut Self>,
389 _cx: &mut Context<'_>,
390 buf: &[u8],
391 ) -> Poll<io::Result<usize>> {
392 Poll::Ready(self.get_mut().write_app_data(buf))
393 }
394
395 /// Polls the internal flushing funciton
396 ///
397 /// # Arguments
398 /// * cx - The polling context
399 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
400 self.get_mut().poll_flush_priv(cx)
401 }
402
403 /// Polls the internal shutdown function
404 ///
405 /// # Arguments
406 /// * cx - The polling context
407 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
408 self.get_mut().poll_shutdown_priv(cx)
409 }
410}
411
/// Builds the error returned by reads and application writes once the
/// stream has been stopped
fn io_closed() -> io::Error {
    let kind = ErrorKind::UnexpectedEof;
    io::Error::new(kind, "Connection closed")
}