1use std::{
9 io::{self, Read, Write},
10 marker::PhantomData,
11 mem,
12};
13
14use openssl::ssl::{ErrorCode, SslStream};
15
16use crate::{
17 frame::{Frame, FrameBuilder},
18 Blocking, NonBlocking,
19};
20
/// Size in bytes of the scratch buffer used for each read, and the initial
/// capacity reserved for the receive and transmit buffers.
const BUF_SIZE: usize = 1 << 10;
22
23pub struct Secure<S, FB>
25where
26 S: io::Read + io::Write,
27 FB: FrameBuilder,
28{
29 inner: SslStream<S>,
30 rx_buf: Vec<u8>,
31 tx_buf: Vec<u8>,
32 phantom: PhantomData<FB>,
33}
34
35impl<S, FB> Secure<S, FB>
36where
37 S: io::Read + io::Write,
38 FB: FrameBuilder,
39{
40 pub fn new(stream: SslStream<S>) -> Secure<S, FB> {
42 Secure {
43 inner: stream,
44 rx_buf: Vec::<u8>::with_capacity(BUF_SIZE),
45 tx_buf: Vec::<u8>::with_capacity(BUF_SIZE),
46 phantom: PhantomData,
47 }
48 }
49}
50
51impl<S, FB> Blocking for Secure<S, FB>
52where
53 S: io::Read + io::Write,
54 FB: FrameBuilder,
55{
56 fn b_recv(&mut self) -> io::Result<Box<dyn Frame>> {
57 match FB::from_bytes(&mut self.rx_buf) {
59 Some(boxed_frame) => {
60 debug!("Complete frame read");
61 return Ok(boxed_frame);
62 }
63 None => {}
64 };
65
66 loop {
67 let mut buf = [0u8; BUF_SIZE];
68 let read_result = self.inner.read(&mut buf);
69 if read_result.is_err() {
70 let err = read_result.unwrap_err();
71 return Err(err);
72 }
73
74 let num_read = read_result.unwrap();
75 trace!("Read {} byte(s)", num_read);
76 self.rx_buf.extend_from_slice(&buf[0..num_read]);
77
78 match FB::from_bytes(&mut self.rx_buf) {
79 Some(boxed_frame) => {
80 debug!("Complete frame read");
81 return Ok(boxed_frame);
82 }
83 None => {}
84 };
85 }
86 }
87
88 fn b_send(&mut self, frame: &dyn Frame) -> io::Result<()> {
89 let out_buf = frame.to_bytes();
90 let write_result = self.inner.write(&out_buf[..]);
91 if write_result.is_err() {
92 let err = write_result.unwrap_err();
93 return Err(err);
94 }
95
96 trace!("Wrote {} byte(s)", write_result.unwrap());
97
98 Ok(())
99 }
100}
101
102impl<S, FB> NonBlocking for Secure<S, FB>
103where
104 S: io::Read + io::Write,
105 FB: FrameBuilder,
106{
107 fn nb_recv(&mut self) -> io::Result<Vec<Box<dyn Frame>>> {
108 loop {
109 let mut buf = [0u8; BUF_SIZE];
110 let read_result = self.inner.ssl_read(&mut buf);
111 if read_result.is_err() {
112 let e = read_result.unwrap_err();
113 match e.code() {
114 ErrorCode::ZERO_RETURN => {
115 return Err(io::Error::new(
116 io::ErrorKind::UnexpectedEof,
117 "UnexpectedEof",
118 ));
119 }
120 ErrorCode::WANT_READ => {
121 break;
122 }
123
124 ErrorCode::WANT_WRITE => {
125 return Err(io::Error::new(io::ErrorKind::Other, "WantWrite"));
126 }
127
128 ErrorCode::SYSCALL => {
129 return Err(io::Error::new(io::ErrorKind::Other, "Syscall"));
130 }
131
132 ErrorCode::SSL => {
133 return Err(io::Error::new(io::ErrorKind::Other, "SSL"));
134 }
135 _ => {
136 return Err(io::Error::new(
138 io::ErrorKind::Other,
139 "Unknown error during ssl_read",
140 ));
141 }
142 };
143 }
144
145 let num_read = read_result.unwrap();
146 trace!("Read {} byte(s)", num_read);
147 self.rx_buf.extend_from_slice(&buf[0..num_read]);
148 }
149
150 let mut ret_buf = Vec::<Box<dyn Frame>>::with_capacity(5);
151 while let Some(boxed_frame) = FB::from_bytes(&mut self.rx_buf) {
152 info!("Complete frame read");
153 ret_buf.push(boxed_frame);
154 }
155
156 if ret_buf.len() > 0 {
157 info!("Read {} frame(s)", ret_buf.len());
158 return Ok(ret_buf);
159 }
160
161 Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock"))
162 }
163
164 fn nb_send(&mut self, frame: &dyn Frame) -> io::Result<()> {
165 self.tx_buf.extend_from_slice(&frame.to_bytes()[..]);
166
167 let mut out_buf = Vec::<u8>::with_capacity(BUF_SIZE);
168 mem::swap(&mut self.tx_buf, &mut out_buf);
169
170 let write_result = self.inner.ssl_write(&out_buf[..]);
171 if write_result.is_err() {
172 let err = write_result.unwrap_err();
173 match err.code() {
174 ErrorCode::ZERO_RETURN => {
175 return Err(io::Error::new(
176 io::ErrorKind::UnexpectedEof,
177 "UnexpectedEof",
178 ));
179 }
180 ErrorCode::WANT_WRITE => {
181 return Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock"));
182 }
183 ErrorCode::SYSCALL => {
184 return Err(io::Error::new(io::ErrorKind::Other, "Syscall"));
185 }
186
187 ErrorCode::SSL => {
188 return Err(io::Error::new(io::ErrorKind::Other, "SSL"));
189 }
190 _ => {
191 return Err(io::Error::new(
193 io::ErrorKind::Other,
194 "Unknown error during ssl_write",
195 ));
196 }
197 };
198 }
199
200 let num_written = write_result.unwrap();
201 if num_written == 0 {
202 return Err(io::Error::new(io::ErrorKind::Other, "Write returned zero"));
203 }
204
205 trace!(
206 "Tried to write {} byte(s) wrote {} byte(s)",
207 out_buf.len(),
208 num_written
209 );
210
211 if num_written < out_buf.len() {
212 let out_buf_len = out_buf.len();
213 self.tx_buf
214 .extend_from_slice(&out_buf[num_written..out_buf_len]);
215
216 return Err(io::Error::new(io::ErrorKind::WouldBlock, "WouldBlock"));
217 }
218
219 Ok(())
220 }
221}