simple_stream/
secure.rs

1// Copyright 2015 Nathan Sizemore <nathanrsizemore@gmail.com>
2//
3// This Source Code Form is subject to the terms of the
4// Mozilla Public License, v. 2.0. If a copy of the MPL was not
5// distributed with this file, You can obtain one at
6// http://mozilla.org/MPL/2.0/.
7
8use std::{
9    io::{self, Read, Write},
10    marker::PhantomData,
11    mem,
12};
13
14use openssl::ssl::{ErrorCode, SslStream};
15
16use crate::{
17    frame::{Frame, FrameBuilder},
18    Blocking, NonBlocking,
19};
20
21const BUF_SIZE: usize = 1024;
22
23/// OpenSSL backed stream.
24pub struct Secure<S, FB>
25where
26    S: io::Read + io::Write,
27    FB: FrameBuilder,
28{
29    inner: SslStream<S>,
30    rx_buf: Vec<u8>,
31    tx_buf: Vec<u8>,
32    phantom: PhantomData<FB>,
33}
34
35impl<S, FB> Secure<S, FB>
36where
37    S: io::Read + io::Write,
38    FB: FrameBuilder,
39{
40    /// Creates a new secured stream.
41    pub fn new(stream: SslStream<S>) -> Secure<S, FB> {
42        Secure {
43            inner: stream,
44            rx_buf: Vec::<u8>::with_capacity(BUF_SIZE),
45            tx_buf: Vec::<u8>::with_capacity(BUF_SIZE),
46            phantom: PhantomData,
47        }
48    }
49}
50
51impl<S, FB> Blocking for Secure<S, FB>
52where
53    S: io::Read + io::Write,
54    FB: FrameBuilder,
55{
56    fn b_recv(&mut self) -> io::Result<Box<dyn Frame>> {
57        // Empty anything that is in our buffer already from any previous reads
58        match FB::from_bytes(&mut self.rx_buf) {
59            Some(boxed_frame) => {
60                debug!("Complete frame read");
61                return Ok(boxed_frame);
62            }
63            None => {}
64        };
65
66        loop {
67            let mut buf = [0u8; BUF_SIZE];
68            let read_result = self.inner.read(&mut buf);
69            if read_result.is_err() {
70                let err = read_result.unwrap_err();
71                return Err(err);
72            }
73
74            let num_read = read_result.unwrap();
75            trace!("Read {} byte(s)", num_read);
76            self.rx_buf.extend_from_slice(&buf[0..num_read]);
77
78            match FB::from_bytes(&mut self.rx_buf) {
79                Some(boxed_frame) => {
80                    debug!("Complete frame read");
81                    return Ok(boxed_frame);
82                }
83                None => {}
84            };
85        }
86    }
87
88    fn b_send(&mut self, frame: &dyn Frame) -> io::Result<()> {
89        let out_buf = frame.to_bytes();
90        let write_result = self.inner.write(&out_buf[..]);
91        if write_result.is_err() {
92            let err = write_result.unwrap_err();
93            return Err(err);
94        }
95
96        trace!("Wrote {} byte(s)", write_result.unwrap());
97
98        Ok(())
99    }
100}
101
102impl<S, FB> NonBlocking for Secure<S, FB>
103where
104    S: io::Read + io::Write,
105    FB: FrameBuilder,
106{
107    fn nb_recv(&mut self) -> io::Result<Vec<Box<dyn Frame>>> {
108        loop {
109            let mut buf = [0u8; BUF_SIZE];
110            let read_result = self.inner.ssl_read(&mut buf);
111            if read_result.is_err() {
112                let e = read_result.unwrap_err();
113                match e.code() {
114                    ErrorCode::ZERO_RETURN => {
115                        return Err(io::Error::new(
116                            io::ErrorKind::UnexpectedEof,
117                            "UnexpectedEof",
118                        ));
119                    }
120                    ErrorCode::WANT_READ => {
121                        break;
122                    }
123
124                    ErrorCode::WANT_WRITE => {
125                        return Err(io::Error::new(io::ErrorKind::Other, "WantWrite"));
126                    }
127
128                    ErrorCode::SYSCALL => {
129                        return Err(io::Error::new(io::ErrorKind::Other, "Syscall"));
130                    }
131
132                    ErrorCode::SSL => {
133                        return Err(io::Error::new(io::ErrorKind::Other, "SSL"));
134                    }
135                    _ => {
136                        // Other error types should not be thrown from this operation
137                        return Err(io::Error::new(
138                            io::ErrorKind::Other,
139                            "Unknown error during ssl_read",
140                        ));
141                    }
142                };
143            }
144
145            let num_read = read_result.unwrap();
146            trace!("Read {} byte(s)", num_read);
147            self.rx_buf.extend_from_slice(&buf[0..num_read]);
148        }
149
150        let mut ret_buf = Vec::<Box<dyn Frame>>::with_capacity(5);
151        while let Some(boxed_frame) = FB::from_bytes(&mut self.rx_buf) {
152            info!("Complete frame read");
153            ret_buf.push(boxed_frame);
154        }
155
156        if ret_buf.len() > 0 {
157            info!("Read {} frame(s)", ret_buf.len());
158            return Ok(ret_buf);
159        }
160
161        Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock"))
162    }
163
164    fn nb_send(&mut self, frame: &dyn Frame) -> io::Result<()> {
165        self.tx_buf.extend_from_slice(&frame.to_bytes()[..]);
166
167        let mut out_buf = Vec::<u8>::with_capacity(BUF_SIZE);
168        mem::swap(&mut self.tx_buf, &mut out_buf);
169
170        let write_result = self.inner.ssl_write(&out_buf[..]);
171        if write_result.is_err() {
172            let err = write_result.unwrap_err();
173            match err.code() {
174                ErrorCode::ZERO_RETURN => {
175                    return Err(io::Error::new(
176                        io::ErrorKind::UnexpectedEof,
177                        "UnexpectedEof",
178                    ));
179                }
180                ErrorCode::WANT_WRITE => {
181                    return Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock"));
182                }
183                ErrorCode::SYSCALL => {
184                    return Err(io::Error::new(io::ErrorKind::Other, "Syscall"));
185                }
186
187                ErrorCode::SSL => {
188                    return Err(io::Error::new(io::ErrorKind::Other, "SSL"));
189                }
190                _ => {
191                    // Other error types should not be thrown from this operation
192                    return Err(io::Error::new(
193                        io::ErrorKind::Other,
194                        "Unknown error during ssl_write",
195                    ));
196                }
197            };
198        }
199
200        let num_written = write_result.unwrap();
201        if num_written == 0 {
202            return Err(io::Error::new(io::ErrorKind::Other, "Write returned zero"));
203        }
204
205        trace!(
206            "Tried to write {} byte(s) wrote {} byte(s)",
207            out_buf.len(),
208            num_written
209        );
210
211        if num_written < out_buf.len() {
212            let out_buf_len = out_buf.len();
213            self.tx_buf
214                .extend_from_slice(&out_buf[num_written..out_buf_len]);
215
216            return Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock"));
217        }
218
219        Ok(())
220    }
221}