1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
use crate::msg::deframer::DeframeState;

use super::{
    crypto::{rc4::*, MacGenerator},
    data::BlazeServerData,
    handshake::HandshakingWrapper,
    msg::{codec::*, deframer::MessageDeframer, types::*, AlertMessage, Message},
};
use std::{
    cmp,
    fmt::Display,
    io::{self, ErrorKind},
    net::SocketAddr,
    pin::Pin,
    sync::Arc,
    task::{ready, Context, Poll},
};
use tokio::{
    io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf},
    net::{TcpListener, TcpStream, ToSocketAddrs},
};

/// SSL connection layered on top of a [`TcpStream`].
///
/// Application writes are buffered until the stream is flushed. At that point
/// they are framed into messages (encrypted once an encryptor has been set)
/// and written to the socket. Incoming messages are de-framed, decrypted when
/// a decryptor has been set, and buffered for the application to read.
pub struct BlazeStream {
    /// The wrapped TCP connection that raw record bytes travel over
    stream: TcpStream,

    /// Splits raw bytes read from `stream` into individual messages
    deframer: MessageDeframer,

    /// Set once incoming messages should be decrypted
    decryptor: Option<Rc4Decryptor>,
    /// Set once outgoing messages should be encrypted
    encryptor: Option<Rc4Encryptor>,

    /// Application data received from the peer that the reader has
    /// not consumed yet
    app_read_buffer: Vec<u8>,
    /// Application data written by the caller that has not been framed
    /// into a message yet (this happens when the stream is flushed)
    app_write_buffer: Vec<u8>,

    /// Encoded record bytes waiting to be written to `stream`
    write_buffer: Vec<u8>,

    /// Whether the connection has been closed (locally or by an alert)
    stopped: bool,
}

/// Selects which side of the handshake a stream performs.
///
/// A `Server` stream handshakes as the server and carries the server data it
/// needs; a `Client` stream handshakes as the connecting client.
pub(crate) enum StreamType {
    /// Stream produced by a server listener, holding the shared server data
    Server { data: Arc<BlazeServerData> },
    /// Stream that connects out to a remote server
    Client,
}

impl StreamType {
    /// Reports whether this is the client side of a connection
    pub fn is_client(&self) -> bool {
        match self {
            Self::Client => true,
            Self::Server { .. } => false,
        }
    }

    /// Borrows the server data attached to a server stream
    ///
    /// # Panics
    /// Panics when called on a client stream, which has no server data
    pub fn server_data(&self) -> &BlazeServerData {
        if let Self::Server { data } = self {
            data
        } else {
            panic!("Tried to access server data on client stream")
        }
    }
}

impl BlazeStream {
    /// Connects to a remote address creating a client blaze stream
    /// to that address.
    ///
    /// # Arguments
    /// * addr - The address to connect to
    ///
    /// # Errors
    /// Fails if the TCP connection cannot be established or the
    /// handshake fails
    pub async fn connect<A: ToSocketAddrs>(addr: A) -> BlazeResult<Self> {
        let stream = TcpStream::connect(addr).await?;
        Self::new(stream, StreamType::Client).await
    }

    /// Creates a new blaze stream wrapping the provided value with
    /// the provided stream type and completes the handshake
    ///
    /// # Arguments
    /// * value - The underlying stream to wrap with SSL
    /// * ty - The type of stream
    async fn new(value: TcpStream, ty: StreamType) -> BlazeResult<Self> {
        // Wrap the stream in a blaze stream
        let stream = Self {
            stream: value,
            deframer: MessageDeframer::new(),
            decryptor: None,
            encryptor: None,
            app_write_buffer: Vec::new(),
            app_read_buffer: Vec::new(),
            write_buffer: Vec::new(),
            stopped: false,
        };

        // Wrap the blaze stream and complete the handshake
        let mut wrapper = HandshakingWrapper::new(stream, ty);
        let result = wrapper.handshake().await;
        let mut stream = wrapper.into_inner();
        if let Err(err) = result {
            // Try flushing any remaining messages (e.g. queued alerts); the flush
            // still writes queued records on a stopped stream, and its own
            // error result is ignored here
            stream.flush().await.ok();
            return Err(err);
        }

        // Return the unwrapped stream
        Ok(stream)
    }

    /// Creates a new RC4 encryptor from the provided key and mac
    /// generator assigning the stream encryptor to it
    ///
    /// # Arguments
    /// * key - The key to use
    /// * mac - The mac generator to use
    pub(crate) fn set_encryptor(&mut self, key: Rc4, mac: MacGenerator) {
        self.encryptor = Some(Rc4Encryptor::new(key, mac))
    }

    /// Creates a new RC4 decryptor from the provided key and mac
    /// generator assigning the stream decryptor to it
    ///
    /// # Arguments
    /// * key - The key to use
    /// * mac - The mac generator to use
    pub(crate) fn set_decryptor(&mut self, key: Rc4, mac: MacGenerator) {
        self.decryptor = Some(Rc4Decryptor::new(key, mac))
    }

    /// Polls for the next message to be received. Decrypts encrypted messages
    /// and handles alert messages.
    ///
    /// # Arguments
    /// * cx - The polling context
    pub(crate) fn poll_next_message(&mut self, cx: &mut Context<'_>) -> Poll<BlazeResult<Message>> {
        loop {
            // Stopped streams immediately result in an error
            if self.stopped {
                return Poll::Ready(Err(BlazeError::Stopped));
            }

            let mut message = match self.deframer.next() {
                // We have a next frame available from the deframer
                Some(message) => message,
                // We need to keep reading from the stream
                None => {
                    // Poll reading data from the stream
                    ready!(self.deframer.poll_read(&mut self.stream, cx))?;

                    // Attempt to deframe messages from the stream
                    let state = self.deframer.deframe();
                    match state {
                        // The stream is invalid close the connection
                        DeframeState::Invalid => {
                            return Poll::Ready(Err(
                                self.alert_fatal(AlertDescription::IllegalParameter)
                            ));
                        }
                        // More data is required we must continue polling
                        DeframeState::Incomplete => continue,
                    }
                }
            };

            // Decrypt message if encryption is enabled
            if let Some(decryptor) = &mut self.decryptor {
                if !decryptor.decrypt(&mut message) {
                    // Handle failed decryption due to invalid MAC field
                    return Poll::Ready(Err(self.alert_fatal(AlertDescription::BadRecordMac)));
                }
            }

            // Alert messages are turned into errors, everything else is returned
            return Poll::Ready(if let MessageType::Alert = message.message_type {
                Err(self.handle_alert_message(message))
            } else {
                Ok(message)
            });
        }
    }

    /// Handles received alert messages first parsing the message and then
    /// handling it based on its type and returning the respective error
    /// for the type. Every alert stops the stream.
    ///
    /// # Arguments
    /// * message - The raw alert message
    fn handle_alert_message(&mut self, message: Message) -> BlazeError {
        // Attempt to read the message
        let mut reader = Reader::new(&message.payload);
        let description = AlertMessage::decode(&mut reader)
            .map(|value| value.1)
            .unwrap_or_else(|| AlertDescription::Unknown(0));

        // All alerts result in shutdown
        self.stopped = true;

        // Handle close notify messages as non errors
        if matches!(description, AlertDescription::CloseNotify) {
            BlazeError::Stopped
        } else {
            // All error alerts are considered to be fatal in this implementation
            BlazeError::Alert(description)
        }
    }

    /// Sets the stopped state to true and queues the close
    /// notify alert if the stream is not already stopped
    fn shutdown(&mut self) {
        if !self.stopped {
            self.alert(&AlertDescription::CloseNotify);
            self.stopped = true;
        }
    }

    /// Gracefully shuts the stream down: any buffered application data is
    /// framed first, then a CloseNotify alert is queued, everything queued is
    /// written out and finally the underlying stream's write side is shut down.
    /// Repeated calls only finish writing and shutting down.
    ///
    /// # Arguments
    /// * cx - The polling context
    fn poll_shutdown_priv(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if !self.stopped {
            // Frame pending app data ahead of the CloseNotify so the peer
            // receives it before the close
            self.frame_app_data();
            self.shutdown();
        }

        // Deliver everything queued, including the CloseNotify alert
        ready!(self.poll_write_buffer(cx))?;

        Pin::new(&mut self.stream).poll_shutdown(cx)
    }

    /// Fragments the provided message and encrypts the contents if
    /// encryption is available, appending the encoded output to the
    /// write buffer (it is written to the stream on flush)
    ///
    /// # Arguments
    /// * message - The message to write
    pub(crate) fn write_message(&mut self, message: Message) {
        for msg in message.fragment() {
            let msg = if let Some(writer) = &mut self.encryptor {
                writer.encrypt(msg)
            } else {
                Message {
                    message_type: msg.message_type,
                    payload: msg.payload.to_vec(),
                }
            };
            let bytes = msg.encode();
            self.write_buffer.extend_from_slice(&bytes);
        }
    }

    /// Queues an alert message in the write buffer
    ///
    /// # Arguments
    /// * alert - The alert to write
    pub(crate) fn alert(&mut self, alert: &AlertDescription) {
        let message = Message {
            message_type: MessageType::Alert,
            payload: alert.encode_vec(),
        };
        self.write_message(message);
    }

    /// Handles a fatal alert where an unexpected message was received
    /// returning the error created
    pub(crate) fn fatal_unexpected(&mut self) -> BlazeError {
        self.alert_fatal(AlertDescription::UnexpectedMessage)
    }

    /// Handles a fatal alert where an illegal parameter was received
    /// returning the error created
    pub(crate) fn fatal_illegal(&mut self) -> BlazeError {
        self.alert_fatal(AlertDescription::IllegalParameter)
    }

    /// Queues a fatal alert, marks the stream as stopped and returns a
    /// BlazeError for the alert.
    ///
    /// No CloseNotify is queued after it: the fatal alert itself
    /// terminates the connection.
    ///
    /// # Arguments
    /// * alert - The fatal alert
    fn alert_fatal(&mut self, alert: AlertDescription) -> BlazeError {
        self.alert(&alert);
        self.stopped = true;
        BlazeError::Alert(alert)
    }

    /// Writes the provided bytes as application data to the
    /// app write buffer
    ///
    /// # Arguments
    /// * buf - The buffer to write
    ///
    /// # Errors
    /// Fails with the "closed" error if the stream is stopped
    fn write_app_data(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.stopped {
            return Err(io_closed());
        };
        self.app_write_buffer.extend_from_slice(buf);
        Ok(buf.len())
    }

    /// Polls reading application data from the stream into `buf`
    ///
    /// # Arguments
    /// * cx -  The polling context
    /// * buf - The buffer to read data into
    fn poll_read_priv(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Poll flushing the write buffer before attempting to read. This also
        // reports an error for already stopped streams
        ready!(self.poll_flush_priv(cx))?;

        // A full buffer cannot take any data, so don't wait on the network
        if buf.remaining() == 0 {
            return Poll::Ready(Ok(()));
        }

        // Poll for app data from the stream (only succeeds while not stopped)
        let available = ready!(self.poll_app_data(cx))?;

        // Hand over as much as fits and keep the rest for the next read
        let read = cmp::min(buf.remaining(), available);
        buf.put_slice(&self.app_read_buffer[..read]);
        self.app_read_buffer.drain(..read);

        Poll::Ready(Ok(()))
    }

    /// Moves any buffered application data into the write buffer as a
    /// single ApplicationData message
    fn frame_app_data(&mut self) {
        if self.app_write_buffer.is_empty() {
            return;
        }
        let message = Message {
            message_type: MessageType::ApplicationData,
            payload: std::mem::take(&mut self.app_write_buffer),
        };
        self.write_message(message);
    }

    /// Writes the whole write buffer to the underlying stream
    ///
    /// # Errors
    /// Fails with `WriteZero` if the stream stops accepting bytes
    fn poll_write_buffer(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        while !self.write_buffer.is_empty() {
            let stream = Pin::new(&mut self.stream);
            let count = ready!(stream.poll_write(cx, &self.write_buffer))?;
            if count == 0 {
                // Without this check a zero-length write would loop forever
                return Poll::Ready(Err(io::Error::new(
                    ErrorKind::WriteZero,
                    "failed to write buffered data to stream",
                )));
            }
            self.write_buffer.drain(..count);
        }
        Poll::Ready(Ok(()))
    }

    /// Polls flushing all the data for this stream. While the stream is
    /// running, buffered app data is framed into the write buffer first.
    /// Queued records are then written to the stream and the stream is
    /// flushed.
    ///
    /// Records that were queued before the stream stopped (such as alerts)
    /// are still written out, but the flush then reports the "closed" error
    /// for a stopped stream.
    ///
    /// # Arguments
    /// * cx - The polling context
    fn poll_flush_priv(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // New application data is only framed while the stream is running
        if !self.stopped {
            self.frame_app_data();
        }

        ready!(self.poll_write_buffer(cx))?;

        // Always flush the underlying stream; a previous poll may have
        // written everything but left this flush pending
        ready!(Pin::new(&mut self.stream).poll_flush(cx))?;

        if self.stopped {
            return Poll::Ready(Err(io_closed()));
        }
        Poll::Ready(Ok(()))
    }

    /// Polls for application data or returns the already present amount of
    /// application data stored in this stream. Collects application data by
    /// polling for messages, skipping zero-length records so that `Ok(0)` is
    /// never reported (readers would treat it as EOF)
    ///
    /// # Arguments
    /// * cx - The polling context
    fn poll_app_data(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<usize>> {
        loop {
            if self.stopped {
                return Poll::Ready(Err(io_closed()));
            }

            // Early return if the buffer is not yet empty
            let buffered = self.app_read_buffer.len();
            if buffered != 0 {
                return Poll::Ready(Ok(buffered));
            }

            // Poll for the next message
            let message = match ready!(self.poll_next_message(cx)) {
                Ok(value) => value,
                Err(_) => {
                    return Poll::Ready(Err(io::Error::new(
                        ErrorKind::ConnectionAborted,
                        "SSL Failure",
                    )))
                }
            };

            // The alert message type is already handled in message polling so
            // receiving any message that isn't application data is an error
            if !matches!(message.message_type, MessageType::ApplicationData) {
                self.alert_fatal(AlertDescription::UnexpectedMessage);
                return Poll::Ready(Err(io::Error::new(
                    ErrorKind::Other,
                    "Expected application data but got something else",
                )));
            }

            // An empty payload leaves the buffer empty and the loop polls again
            self.app_read_buffer.extend_from_slice(&message.payload);
        }
    }

    /// Returns a reference to the underlying stream
    pub fn get_ref(&self) -> &TcpStream {
        &self.stream
    }

    /// Returns a mutable reference to the underlying stream
    pub fn get_mut(&mut self) -> &mut TcpStream {
        &mut self.stream
    }

    /// Returns the underlying stream that this BlazeStream
    /// is wrapping
    pub fn into_inner(self) -> TcpStream {
        self.stream
    }
}

impl AsyncRead for BlazeStream {
    /// Unpins the stream and hands reading off to `poll_read_priv`
    ///
    /// # Arguments
    /// * cx - The polling context
    /// * buf - Destination for the application data that is read
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let this = Pin::into_inner(self);
        this.poll_read_priv(cx, buf)
    }
}

impl AsyncWrite for BlazeStream {
    /// Appends `buf` to the pending application data. This is always
    /// ready because nothing touches the socket until a flush.
    ///
    /// # Arguments
    /// * _cx - The polling context (unused)
    /// * buf - Application bytes to buffer
    fn poll_write(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = Pin::into_inner(self);
        let outcome = this.write_app_data(buf);
        Poll::Ready(outcome)
    }

    /// Unpins the stream and hands flushing off to `poll_flush_priv`
    ///
    /// # Arguments
    /// * cx - The polling context
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::into_inner(self);
        this.poll_flush_priv(cx)
    }

    /// Unpins the stream and hands shutdown off to `poll_shutdown_priv`
    ///
    /// # Arguments
    /// * cx - The polling context
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = Pin::into_inner(self);
        this.poll_shutdown_priv(cx)
    }
}

/// Accepts incoming TCP connections and turns them into SSL
/// server streams
pub struct BlazeListener {
    /// Listener that connections are accepted from
    listener: TcpListener,
    /// Server data shared with every stream this listener creates
    data: Arc<BlazeServerData>,
}

impl BlazeListener {
    /// Swaps the server private key and certificate handed to
    /// connections accepted after this call
    ///
    /// # Arguments
    /// * data - The replacement server data
    pub fn set_server_data(&mut self, data: Arc<BlazeServerData>) {
        self.data = data;
    }

    /// Binds a TcpListener and wraps it in a BlazeListener that uses
    /// default server data
    ///
    /// # Arguments
    /// * addr - The addr(s) to attempt to bind on
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<BlazeListener> {
        TcpListener::bind(addr).await.map(|listener| BlazeListener {
            listener,
            data: Arc::default(),
        })
    }

    /// Accepts the next TcpStream without performing the handshake.
    ///
    /// Handshaking here would stall the accept loop, so a BlazeAccept is
    /// returned instead. Call `finish_accept` on it inside a spawned task,
    /// or use `blocking_accept` to handshake immediately.
    pub async fn accept(&self) -> io::Result<BlazeAccept> {
        let (stream, addr) = self.listener.accept().await?;
        let data = Arc::clone(&self.data);
        Ok(BlazeAccept { stream, addr, data })
    }

    /// Alternative to `accept` that performs the handshake straight away.
    /// No further connections are accepted until that handshake completes.
    pub async fn blocking_accept(&self) -> BlazeResult<(BlazeStream, SocketAddr)> {
        let (tcp, addr) = self.listener.accept().await?;
        let ty = StreamType::Server {
            data: Arc::clone(&self.data),
        };
        let stream = BlazeStream::new(tcp, ty).await?;
        Ok((stream, addr))
    }
}

/// A connection taken from the listener whose handshake has not run yet.
/// Call `finish_accept` to turn it into a BlazeStream.
pub struct BlazeAccept {
    /// The accepted TCP connection
    stream: TcpStream,
    /// Address of the remote peer
    addr: SocketAddr,
    /// Server data used when handshaking this connection
    data: Arc<BlazeServerData>,
}

impl BlazeAccept {
    /// Runs the server handshake for this connection and returns the
    /// resulting stream together with the peer address. Call this from a
    /// separately spawned task so the accept loop is not held up.
    pub async fn finish_accept(self) -> BlazeResult<(BlazeStream, SocketAddr)> {
        let Self { stream, addr, data } = self;
        let stream = BlazeStream::new(stream, StreamType::Server { data }).await?;
        Ok((stream, addr))
    }
}

/// Builds the I/O error reported when an operation is attempted on a
/// stream that has already been stopped
fn io_closed() -> io::Error {
    let kind = ErrorKind::Other;
    io::Error::new(kind, "Stream already closed")
}

/// Errors produced while handshaking with, or communicating over,
/// a blaze stream
#[derive(Debug)]
pub enum BlazeError {
    /// Failure from the underlying I/O
    IO(io::Error),
    /// A fatal alert that was either sent to or received from the peer
    Alert(AlertDescription),
    /// The stream is stopped (closed locally or by a CloseNotify alert)
    Stopped,
}

impl std::error::Error for BlazeError {}

impl Display for BlazeError {
    /// Formats the error as a single-line message. Alerts are written
    /// without a trailing newline, like the other variants, so the text
    /// composes cleanly inside log lines and other messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            BlazeError::IO(err) => Display::fmt(err, f),
            BlazeError::Alert(desc) => write!(f, "Fatal alert: {:?}", desc),
            BlazeError::Stopped => f.write_str("Connection stopped"),
        }
    }
}

impl From<io::Error> for BlazeError {
    fn from(err: io::Error) -> Self {
        BlazeError::IO(err)
    }
}

/// Shorthand for a `Result` whose error type is BlazeError
pub(crate) type BlazeResult<T> = std::result::Result<T, BlazeError>;