use std::io;
use std::ops::Range;
use super::base::Payload;
use super::enums::ContentType;
use super::message::PlainMessage;
use crate::error::{Error, PeerMisbehaved};
use crate::msgs::codec;
use crate::msgs::message::{MessageError, OpaqueMessage};
use crate::record_layer::{Decrypted, RecordLayer};
use crate::ProtocolVersion;
/// Splits buffered record-layer bytes into messages, joining handshake messages
/// that are spread over several records (or several QUIC `push()` calls).
#[derive(Default)]
pub struct MessageDeframer {
    /// Backing storage; only `buf[..used]` holds data that has been received.
    buf: Vec<u8>,
    /// Number of meaningful bytes at the front of `buf`.
    used: usize,
    /// Bookkeeping for a handshake payload currently being joined, if any.
    joining_hs: Option<HandshakePayloadMeta>,
    /// Set after an unrecoverable framing error; `pop()` then always fails.
    desynced: bool,
}
impl MessageDeframer {
pub fn pop(&mut self, record_layer: &mut RecordLayer) -> Result<Option<Deframed>, Error> {
if self.desynced {
return Err(Error::CorruptMessage);
} else if self.used == 0 {
return Ok(None);
}
let expected_len = loop {
let start = match &self.joining_hs {
Some(meta) => {
match meta.expected_len {
Some(len) if len <= meta.payload.len() => break len,
_ if meta.quic => return Ok(None),
_ => meta.message.end,
}
}
None => 0,
};
let mut rd = codec::Reader::init(&self.buf[start..self.used]);
let m = match OpaqueMessage::read(&mut rd) {
Ok(m) => m,
Err(MessageError::TooShortForHeader | MessageError::TooShortForLength) => {
return Ok(None)
}
Err(_) => {
self.desynced = true;
return Err(Error::CorruptMessage);
}
};
let end = start + rd.used();
if m.typ == ContentType::ChangeCipherSpec && self.joining_hs.is_none() {
self.discard(end);
return Ok(Some(Deframed {
want_close_before_decrypt: false,
aligned: true,
trial_decryption_finished: false,
message: m.into_plain_message(),
}));
}
let msg = match record_layer.decrypt_incoming(m) {
Ok(Some(decrypted)) => {
let Decrypted {
want_close_before_decrypt,
plaintext,
} = decrypted;
debug_assert!(!want_close_before_decrypt);
plaintext
}
Ok(None) if self.joining_hs.is_some() => {
self.desynced = true;
return Err(
PeerMisbehaved::RejectedEarlyDataInterleavedWithHandshakeMessage.into(),
);
}
Ok(None) => {
self.discard(end);
continue;
}
Err(e) => return Err(e),
};
if self.joining_hs.is_some() && msg.typ != ContentType::Handshake {
self.desynced = true;
return Err(PeerMisbehaved::MessageInterleavedWithHandshakeMessage.into());
}
if msg.typ != ContentType::Handshake {
self.discard(start + rd.used());
return Ok(Some(Deframed {
want_close_before_decrypt: false,
aligned: true,
trial_decryption_finished: false,
message: msg,
}));
}
match self.append_hs(msg.version, &msg.payload.0, end, false)? {
HandshakePayloadState::Blocked => return Ok(None),
HandshakePayloadState::Complete(len) => break len,
HandshakePayloadState::Continue => continue,
}
};
let meta = self.joining_hs.as_mut().unwrap();
let message = PlainMessage {
typ: ContentType::Handshake,
version: meta.version,
payload: Payload::new(&self.buf[meta.payload.start..meta.payload.start + expected_len]),
};
if meta.payload.len() > expected_len {
meta.payload.start += expected_len;
meta.expected_len = payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
} else {
let end = meta.message.end;
self.joining_hs = None;
self.discard(end);
}
Ok(Some(Deframed {
want_close_before_decrypt: false,
aligned: self.joining_hs.is_none(),
trial_decryption_finished: true,
message,
}))
}
#[cfg(feature = "quic")]
pub fn push(&mut self, version: ProtocolVersion, payload: &[u8]) -> Result<(), Error> {
if self.used > 0 && self.joining_hs.is_none() {
return Err(Error::General(
"cannot push QUIC messages into unrelated connection".into(),
));
} else if let Err(err) = self.prepare_read() {
return Err(Error::General(err.into()));
}
let end = self.used + payload.len();
self.append_hs(version, payload, end, true)?;
self.used = end;
Ok(())
}
fn append_hs(
&mut self,
version: ProtocolVersion,
payload: &[u8],
end: usize,
quic: bool,
) -> Result<HandshakePayloadState, Error> {
let meta = match &mut self.joining_hs {
Some(meta) => {
debug_assert_eq!(meta.quic, quic);
let dst = &mut self.buf[meta.payload.end..meta.payload.end + payload.len()];
dst.copy_from_slice(payload);
meta.message.end = end;
meta.payload.end += payload.len();
if meta.expected_len.is_none() {
meta.expected_len =
payload_size(&self.buf[meta.payload.start..meta.payload.end])?;
}
meta
}
None => {
let expected_len = payload_size(payload)?;
let dst = &mut self.buf[..payload.len()];
dst.copy_from_slice(payload);
self.joining_hs
.insert(HandshakePayloadMeta {
message: Range { start: 0, end },
payload: Range {
start: 0,
end: payload.len(),
},
version,
expected_len,
quic,
})
}
};
Ok(match meta.expected_len {
Some(len) if len <= meta.payload.len() => HandshakePayloadState::Complete(len),
_ => match self.used > meta.message.end {
true => HandshakePayloadState::Continue,
false => HandshakePayloadState::Blocked,
},
})
}
#[allow(clippy::comparison_chain)]
pub fn read(&mut self, rd: &mut dyn io::Read) -> io::Result<usize> {
if let Err(err) = self.prepare_read() {
return Err(io::Error::new(io::ErrorKind::InvalidData, err));
}
let new_bytes = rd.read(&mut self.buf[self.used..])?;
self.used += new_bytes;
Ok(new_bytes)
}
fn prepare_read(&mut self) -> Result<(), &'static str> {
let allow_max = match self.joining_hs {
Some(_) => MAX_HANDSHAKE_SIZE as usize,
None => OpaqueMessage::MAX_WIRE_SIZE,
};
if self.used >= allow_max {
return Err("message buffer full");
}
let need_capacity = Ord::min(allow_max, self.used + READ_SIZE);
if need_capacity > self.buf.len() {
self.buf.resize(need_capacity, 0);
} else if self.used == 0 || self.buf.len() > allow_max {
self.buf.resize(need_capacity, 0);
self.buf.shrink_to(need_capacity);
}
Ok(())
}
pub fn has_pending(&self) -> bool {
self.used > 0
}
fn discard(&mut self, taken: usize) {
#[allow(clippy::comparison_chain)]
if taken < self.used {
self.buf
.copy_within(taken..self.used, 0);
self.used -= taken;
} else if taken == self.used {
self.used = 0;
}
}
}
/// Outcome of adding bytes to the handshake payload being joined.
enum HandshakePayloadState {
    /// A whole handshake message of this many bytes (header included) is joined.
    Complete(usize),
    /// Still incomplete, but more buffered data remains to be processed.
    Continue,
    /// Still incomplete and nothing more is buffered; wait for more input.
    Blocked,
}
/// Bookkeeping for a handshake payload being joined inside the deframer's buffer.
struct HandshakePayloadMeta {
    /// `true` when the bytes arrive via QUIC `push()` rather than as parsed records.
    quic: bool,
    /// Version taken from the input that started this payload.
    version: ProtocolVersion,
    /// Buffer span up to the end of the last input consumed for this payload.
    message: Range<usize>,
    /// Buffer span holding the joined handshake bytes.
    payload: Range<usize>,
    /// Total length (header included) of the next handshake message, once its header
    /// has been seen.
    expected_len: Option<usize>,
}
fn payload_size(buf: &[u8]) -> Result<Option<usize>, Error> {
if buf.len() < HEADER_SIZE {
return Ok(None);
}
let (header, _) = buf.split_at(HEADER_SIZE);
match codec::u24::decode(&header[1..]) {
Some(len) if len.0 > MAX_HANDSHAKE_SIZE => {
Err(Error::CorruptMessagePayload(ContentType::Handshake))
}
Some(len) => Ok(Some(HEADER_SIZE + usize::from(len))),
_ => Ok(None),
}
}
/// A message returned by `MessageDeframer::pop()`.
#[derive(Debug)]
pub struct Deframed {
    /// Set to `false` in every value `MessageDeframer::pop()` constructs in this file.
    pub want_close_before_decrypt: bool,
    /// `true` for records returned whole. For joined handshake messages, `true` only when
    /// no further joined handshake bytes remain buffered after this one.
    pub aligned: bool,
    /// `true` for joined handshake messages, `false` for the change-cipher-spec and
    /// non-handshake records `pop()` returns.
    /// NOTE(review): how consumers interpret this flag is not visible here — confirm.
    pub trial_decryption_finished: bool,
    /// The deframed message; decrypted by the record layer unless it was returned as-is.
    pub message: PlainMessage,
}
/// Deframing error kinds.
///
/// NOTE(review): nothing in this file constructs this type; confirm whether it is still
/// used elsewhere before extending it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DeframerError {
    /// Presumably: a handshake message declared a payload above the size limit
    /// (inferred from the name — confirm).
    HandshakePayloadSizeTooLarge,
}
/// Handshake message header length: a 1-byte type followed by a 3-byte length.
const HEADER_SIZE: usize = 4;
/// Largest handshake body length accepted, and the buffer cap while joining handshake
/// data. The 3-byte length field could encode more; anything above this is rejected.
const MAX_HANDSHAKE_SIZE: u32 = 65_535;
/// Maximum amount by which the read buffer grows before each read.
const READ_SIZE: usize = 4 * 1024;
#[cfg(test)]
mod tests {
    use super::MessageDeframer;
    use crate::msgs::message::{Message, OpaqueMessage};
    use crate::record_layer::RecordLayer;
    use crate::{ContentType, Error};
    use std::io;

    const FIRST_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.1.bin");
    const SECOND_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-test.2.bin");
    const EMPTY_APPLICATIONDATA_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-empty-applicationdata.bin");
    const INVALID_EMPTY_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-empty.bin");
    const INVALID_CONTENTTYPE_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-contenttype.bin");
    const INVALID_VERSION_MESSAGE: &[u8] =
        include_bytes!("../testdata/deframer-invalid-version.bin");
    const INVALID_LENGTH_MESSAGE: &[u8] = include_bytes!("../testdata/deframer-invalid-length.bin");

    /// Feed `bytes` to the deframer in a single `read()` call.
    fn input_bytes(d: &mut MessageDeframer, bytes: &[u8]) -> io::Result<usize> {
        let mut rd = io::Cursor::new(bytes);
        d.read(&mut rd)
    }

    /// Feed `bytes1` followed by `bytes2` to the deframer in a single `read()` call.
    fn input_bytes_concat(
        d: &mut MessageDeframer,
        bytes1: &[u8],
        bytes2: &[u8],
    ) -> io::Result<usize> {
        input_bytes(d, &[bytes1, bytes2].concat())
    }

    /// A reader that scribbles over the destination buffer and then fails, so tests can
    /// check that bytes written by a failed read are not treated as received.
    struct ErrorRead {
        error: Option<io::Error>,
    }

    impl ErrorRead {
        fn new(error: io::Error) -> Self {
            Self { error: Some(error) }
        }
    }

    impl io::Read for ErrorRead {
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            for (i, b) in buf.iter_mut().enumerate() {
                // Wrapping byte pattern; the truncation is intentional.
                *b = i as u8;
            }
            Err(self.error.take().unwrap())
        }
    }

    /// Perform one failing read and check that the error reaches the caller.
    fn input_error(d: &mut MessageDeframer) {
        let mut rd = ErrorRead::new(io::Error::from(io::ErrorKind::TimedOut));
        d.read(&mut rd).expect_err("error not propagated");
    }

    /// Feed `bytes` one byte per `read()` call, checking the buffered count throughout.
    fn input_whole_incremental(d: &mut MessageDeframer, bytes: &[u8]) {
        let before = d.used;
        for byte in bytes.chunks(1) {
            assert_len(1, input_bytes(d, byte));
            assert!(d.has_pending());
        }
        assert_eq!(before + bytes.len(), d.used);
    }

    /// Assert that a read succeeded with exactly `want` bytes, reporting the I/O error
    /// otherwise.
    fn assert_len(want: usize, got: io::Result<usize>) {
        match got {
            Ok(gotval) => assert_eq!(gotval, want),
            Err(err) => panic!("read failed, expected {:?} bytes: {}", want, err),
        }
    }

    /// Create a deframer that has received all of `bytes` in one `read()`.
    fn deframer_with(bytes: &[u8]) -> MessageDeframer {
        let mut d = MessageDeframer::default();
        assert_len(bytes.len(), input_bytes(&mut d, bytes));
        d
    }

    /// Assert that the next `pop()` fails with `Error::CorruptMessage`.
    fn assert_pop_corrupt(d: &mut MessageDeframer, rl: &mut RecordLayer) {
        assert!(matches!(d.pop(rl).unwrap_err(), Error::CorruptMessage));
    }

    fn pop_first(d: &mut MessageDeframer, rl: &mut RecordLayer) {
        let m = d.pop(rl).unwrap().unwrap().message;
        assert_eq!(m.typ, ContentType::Handshake);
        Message::try_from(m).unwrap();
    }

    fn pop_second(d: &mut MessageDeframer, rl: &mut RecordLayer) {
        let m = d.pop(rl).unwrap().unwrap().message;
        assert_eq!(m.typ, ContentType::Alert);
        Message::try_from(m).unwrap();
    }

    #[test]
    fn check_incremental() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(!d.desynced);
    }

    #[test]
    fn check_incremental_2() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        input_whole_incremental(&mut d, FIRST_MESSAGE);
        assert!(d.has_pending());
        input_whole_incremental(&mut d, SECOND_MESSAGE);
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(d.has_pending());
        pop_second(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(!d.desynced);
    }

    #[test]
    fn check_whole() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert!(d.has_pending());

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(!d.desynced);
    }

    #[test]
    fn check_whole_2() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(FIRST_MESSAGE.len(), input_bytes(&mut d, FIRST_MESSAGE));
        assert_len(SECOND_MESSAGE.len(), input_bytes(&mut d, SECOND_MESSAGE));

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        pop_second(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(!d.desynced);
    }

    #[test]
    fn test_two_in_one_read() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, FIRST_MESSAGE, SECOND_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        pop_second(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(!d.desynced);
    }

    #[test]
    fn test_two_in_one_read_shortest_first() {
        let mut d = MessageDeframer::default();
        assert!(!d.has_pending());
        assert_len(
            FIRST_MESSAGE.len() + SECOND_MESSAGE.len(),
            input_bytes_concat(&mut d, SECOND_MESSAGE, FIRST_MESSAGE),
        );

        let mut rl = RecordLayer::new();
        pop_second(&mut d, &mut rl);
        pop_first(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(!d.desynced);
    }

    #[test]
    fn test_incremental_with_nonfatal_read_error() {
        let mut d = MessageDeframer::default();
        assert_len(3, input_bytes(&mut d, &FIRST_MESSAGE[..3]));
        input_error(&mut d);
        assert_len(
            FIRST_MESSAGE.len() - 3,
            input_bytes(&mut d, &FIRST_MESSAGE[3..]),
        );

        let mut rl = RecordLayer::new();
        pop_first(&mut d, &mut rl);
        assert!(!d.has_pending());
        assert!(!d.desynced);
    }

    #[test]
    fn test_invalid_contenttype_errors() {
        let mut d = deframer_with(INVALID_CONTENTTYPE_MESSAGE);
        let mut rl = RecordLayer::new();
        assert_pop_corrupt(&mut d, &mut rl);
    }

    #[test]
    fn test_invalid_version_errors() {
        let mut d = deframer_with(INVALID_VERSION_MESSAGE);
        let mut rl = RecordLayer::new();
        assert_pop_corrupt(&mut d, &mut rl);
    }

    #[test]
    fn test_invalid_length_errors() {
        let mut d = deframer_with(INVALID_LENGTH_MESSAGE);
        let mut rl = RecordLayer::new();
        assert_pop_corrupt(&mut d, &mut rl);
    }

    #[test]
    fn test_empty_applicationdata() {
        let mut d = deframer_with(EMPTY_APPLICATIONDATA_MESSAGE);
        let mut rl = RecordLayer::new();
        let m = d.pop(&mut rl).unwrap().unwrap().message;
        assert_eq!(m.typ, ContentType::ApplicationData);
        assert_eq!(m.payload.0.len(), 0);
        assert!(!d.has_pending());
        assert!(!d.desynced);
    }

    #[test]
    fn test_invalid_empty_errors() {
        let mut d = deframer_with(INVALID_EMPTY_MESSAGE);
        let mut rl = RecordLayer::new();
        assert_pop_corrupt(&mut d, &mut rl);
        // Once desynchronised, the deframer keeps failing.
        assert_pop_corrupt(&mut d, &mut rl);
    }

    #[test]
    fn test_limited_buffer() {
        const PAYLOAD_LEN: usize = 16_384;
        const RECORD_HEADER_LEN: usize = 5;
        const READ_STEP: usize = 4096;

        let mut message = Vec::with_capacity(RECORD_HEADER_LEN + PAYLOAD_LEN);
        message.push(0x17); // content type: application data
        message.extend(&[0x03, 0x04]); // record version bytes
        message.extend((PAYLOAD_LEN as u16).to_be_bytes()); // record length
        message.extend(&[0; PAYLOAD_LEN]);

        let mut d = MessageDeframer::default();
        // The buffer grows by at most `READ_STEP` bytes per read...
        for _ in 0..4 {
            assert_len(READ_STEP, input_bytes(&mut d, &message));
        }
        // ...until it reaches the maximum record size...
        assert_len(
            OpaqueMessage::MAX_WIRE_SIZE - 4 * READ_STEP,
            input_bytes(&mut d, &message),
        );
        // ...after which further reads are refused.
        assert!(input_bytes(&mut d, &message).is_err());
    }
}