1use crate::crypto::HashAlgorithm;
2use crate::handshake::{HandshakeSide, HandshakingWrapper};
3use crate::msg::{
4 AlertDescription, AlertMessage, BorrowedMessage, Certificate, Codec, Message, MessageDeframer,
5 MessageType, OpaqueMessage, Reader,
6};
7use crypto::rc4::Rc4;
8use crypto::symmetriccipher::SynchronousStreamCipher;
9use lazy_static::lazy_static;
10use rsa::RsaPrivateKey;
11use std::cmp;
12use std::io::{self, ErrorKind, Read, Write};
13
14lazy_static! {
15 pub static ref SERVER_KEY: RsaPrivateKey = {
17 use rsa::pkcs8::DecodePrivateKey;
18 use rsa::RsaPrivateKey;
19
20 let key_pem = include_str!("key.pem");
21 RsaPrivateKey::from_pkcs8_pem(key_pem)
22 .expect("Failed to load redirector private key")
23 };
24
25 pub static ref SERVER_CERTIFICATE: Certificate = {
27 use pem;
28 let cert_pem = include_str!("cert.pem");
29 let cert_bytes = pem::parse(cert_pem)
30 .expect("Unable to parse server certificate")
31 .contents;
32 Certificate(cert_bytes)
33 };
34}
35
36pub struct BlazeStream<S> {
39 pub(crate) stream: S,
41
42 deframer: MessageDeframer,
44
45 pub(crate) read_processor: ReadProcessor,
47 pub(crate) write_processor: WriteProcessor,
49
50 read_buffer: Vec<u8>,
52 write_buffer: Vec<u8>,
55
56 stopped: bool,
58}
59
60impl<S> BlazeStream<S> {
61 pub fn get_ref(&self) -> &S {
63 return &self.stream;
64 }
65
66 pub fn get_mut(&mut self) -> &mut S {
68 return &mut self.stream;
69 }
70}
71
72#[derive(Debug)]
73pub enum BlazeError {
74 IO(io::Error),
75 FatalAlert(AlertDescription),
76 Stopped,
77 Unsupported,
78}
79
80impl From<io::Error> for BlazeError {
81 fn from(err: io::Error) -> Self {
82 BlazeError::IO(err)
83 }
84}
85
86pub type BlazeResult<T> = Result<T, BlazeError>;
87
/// Which side of the handshake a new stream should perform.
#[derive(Debug)]
pub enum StreamMode {
    /// Perform the handshake as the server.
    Server,
    /// Perform the handshake as the client.
    Client,
}
96
97impl<S> BlazeStream<S>
98where
99 S: Read + Write,
100{
101 pub fn new(value: S, mode: StreamMode) -> BlazeResult<Self> {
102 let stream = Self {
103 stream: value,
104 deframer: MessageDeframer::new(),
105 read_processor: ReadProcessor::None,
106 write_processor: WriteProcessor::None,
107 write_buffer: Vec::new(),
108 read_buffer: Vec::new(),
109 stopped: false,
110 };
111 let wrapper = HandshakingWrapper::new(
112 stream,
113 match mode {
114 StreamMode::Server => HandshakeSide::Server,
115 StreamMode::Client => HandshakeSide::Client,
116 },
117 );
118 wrapper.handshake()
119 }
120
121 pub fn next_message(&mut self) -> BlazeResult<Message> {
124 loop {
125 if self.stopped {
126 return Err(BlazeError::Stopped);
127 }
128
129 if let Some(message) = self.deframer.next() {
130 let message = self
131 .read_processor
132 .process(message)
133 .map_err(|err| match err {
134 DecryptError::InvalidMac => {
135 self.alert_fatal(AlertDescription::BadRecordMac)
136 }
137 })?;
138 if message.message_type == MessageType::Alert {
139 let mut reader = Reader::new(&message.payload);
140 if let Some(message) = AlertMessage::decode(&mut reader) {
141 self.handle_alert(message.1)?;
142 continue;
143 } else {
144 return Err(self.handle_fatal(AlertDescription::Unknown(0)));
145 }
146 }
147
148 return Ok(message);
149 }
150 if !self.deframer.read(&mut self.stream)? {
151 return Err(self.alert_fatal(AlertDescription::IllegalParameter));
152 }
153 }
154 }
155
156 pub fn shutdown(&mut self) -> BlazeResult<()> {
158 self.alert(AlertDescription::CloseNotify)
159 }
160
161 pub fn handle_alert(&mut self, alert: AlertDescription) -> BlazeResult<()> {
163 match alert {
164 AlertDescription::CloseNotify => {
165 let _ = self.flush();
167 self.stopped = true;
168 Ok(())
169 }
170 _ => Err(BlazeError::FatalAlert(alert)),
171 }
172 }
173
174 pub fn handle_fatal(&mut self, alert: AlertDescription) -> BlazeError {
176 self.stopped = true;
177 return BlazeError::FatalAlert(alert);
178 }
179
180 pub fn write_message(&mut self, message: Message) -> io::Result<()> {
184 for msg in message.fragment() {
185 let msg = self.write_processor.process(msg);
186 let bytes = msg.encode();
187 self.stream.write(&bytes)?;
188 }
189 Ok(())
190 }
191
192 pub fn alert(&mut self, alert: AlertDescription) -> BlazeResult<()> {
194 let message = Message {
195 message_type: MessageType::Alert,
196 payload: alert.encode_vec(),
197 };
198 self.handle_alert(alert)?;
200 self.write_message(message)?;
201 Ok(())
202 }
203
204 pub fn fatal_unexpected(&mut self) -> BlazeError {
205 self.alert_fatal(AlertDescription::UnexpectedMessage)
206 }
207
208 pub fn fatal_illegal(&mut self) -> BlazeError {
209 self.alert_fatal(AlertDescription::IllegalParameter)
210 }
211
212 pub fn alert_fatal(&mut self, alert: AlertDescription) -> BlazeError {
213 let message = Message {
214 message_type: MessageType::Alert,
215 payload: alert.encode_vec(),
216 };
217 let _ = self.write_message(message);
218 self.handle_fatal(alert)
220 }
221
222 pub fn fill_app_data(&mut self) -> io::Result<usize> {
225 if self.stopped {
226 return Err(io_closed());
227 }
228 let buffer_len = self.read_buffer.len();
229 let count = if buffer_len == 0 {
230 let message = self
231 .next_message()
232 .map_err(|_| io::Error::new(ErrorKind::ConnectionAborted, "Ssl Failure"))?;
233
234 if message.message_type != MessageType::ApplicationData {
235 self.alert_fatal(AlertDescription::UnexpectedMessage);
237 return Ok(0);
238 }
239
240 let payload = message.payload;
241 self.read_buffer.extend_from_slice(&payload);
242 payload.len()
243 } else {
244 buffer_len
245 };
246 Ok(count)
247 }
248}
249
/// Builds the error returned when an operation is attempted on a stream that
/// has already been stopped.
fn io_closed() -> io::Error {
    const MESSAGE: &str = "Stream already closed";
    io::Error::new(ErrorKind::Other, MESSAGE)
}
254
255impl<S> Write for BlazeStream<S>
256where
257 S: Read + Write,
258{
259 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
260 if self.stopped {
261 return Err(io_closed());
262 }
263 self.write_buffer.extend_from_slice(buf);
264 Ok(buf.len())
265 }
266
267 fn flush(&mut self) -> io::Result<()> {
268 if self.stopped {
269 return Err(io_closed());
270 }
271 let message = Message {
272 message_type: MessageType::ApplicationData,
273 payload: self.write_buffer.split_off(0),
274 };
275 self.write_message(message)?;
276 self.stream.flush()
277 }
278}
279
280impl<S> Read for BlazeStream<S>
281where
282 S: Read + Write,
283{
284 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
285 let count = self.fill_app_data()?;
286 if self.stopped {
287 return Err(io_closed());
288 }
289
290 let read = cmp::min(buf.len(), count);
291 if read > 0 {
292 let new_buffer = self.read_buffer.split_off(read);
293 buf[..read].copy_from_slice(&self.read_buffer);
294 self.read_buffer = new_buffer;
295 }
296 Ok(read)
297 }
298}
299
300pub enum WriteProcessor {
303 None,
305 RC4 {
307 alg: HashAlgorithm,
308 key: Rc4,
309 mac_secret: Vec<u8>,
310 seq: u64,
311 },
312}
313
314impl WriteProcessor {
315 pub fn process(&mut self, message: BorrowedMessage) -> OpaqueMessage {
321 match self {
322 WriteProcessor::None => OpaqueMessage {
324 message_type: message.message_type,
325 payload: message.payload.to_vec(),
326 },
327 WriteProcessor::RC4 {
329 alg,
330 key,
331 mac_secret,
332 seq,
333 } => {
334 let mut payload = message.payload.to_vec();
335
336 alg.append_mac(&mut payload, mac_secret, message.message_type.value(), seq);
337
338 let mut payload_enc = vec![0u8; payload.len()];
339 key.process(&payload, &mut payload_enc);
340
341 *seq += 1;
342
343 OpaqueMessage {
344 message_type: message.message_type,
345 payload: payload_enc,
346 }
347 }
348 }
349 }
350}
351
352pub enum ReadProcessor {
355 None,
357 RC4 {
359 alg: HashAlgorithm,
360 key: Rc4,
361 mac_secret: Vec<u8>,
362 seq: u64,
363 },
364}
365
/// Reasons an incoming message can be rejected by a `ReadProcessor`.
#[derive(Debug)]
pub enum DecryptError {
    /// The MAC attached to the message did not verify.
    InvalidMac,
}
372
373type DecryptResult<T> = Result<T, DecryptError>;
374
375impl ReadProcessor {
376 pub fn process(&mut self, message: OpaqueMessage) -> DecryptResult<Message> {
377 Ok(match self {
378 ReadProcessor::None => Message {
380 message_type: message.message_type,
381 payload: message.payload,
382 },
383 ReadProcessor::RC4 {
385 alg,
386 key,
387 mac_secret,
388 seq,
389 } => {
390 let mut payload_and_mac = vec![0u8; message.payload.len()];
391 key.process(&message.payload, &mut payload_and_mac);
392
393 let mac_start = payload_and_mac.len() - alg.hash_length();
394 let payload = &payload_and_mac[..mac_start];
395 let mac = &payload_and_mac[mac_start..];
396
397 {
398 let valid_mac = alg.compare_mac(
399 mac,
400 mac_secret,
401 message.message_type.value(),
402 &payload,
403 seq,
404 );
405 if !valid_mac {
406 return Err(DecryptError::InvalidMac);
407 }
408 }
409
410 *seq += 1;
411
412 Message {
413 message_type: message.message_type,
414 payload: payload.to_vec(),
415 }
416 }
417 })
418 }
419}