use std::collections::VecDeque;
use crate::enums::ProtocolVersion;
use crate::msgs::base::Payload;
use crate::msgs::codec;
use crate::msgs::enums::ContentType;
use crate::msgs::handshake::HandshakeMessagePayload;
use crate::msgs::message::{Message, MessagePayload, PlainMessage};
/// Largest handshake body length accepted; a header advertising more is a decode error.
const MAX_HANDSHAKE_SIZE: u32 = 65_535;
/// Length of a handshake message header: one type byte followed by a 24-bit body length.
const HEADER_SIZE: usize = 1 + 3;
/// Buffers the payloads of handshake-typed records and splits or joins them into complete
/// handshake messages, which are then parsed one at a time.
pub struct HandshakeJoiner {
    /// Protocol version used when parsing buffered messages.
    version: ProtocolVersion,

    /// Byte lengths (header included) of the complete messages at the front of `buf`, in order.
    /// Their sum can be smaller than `buf.len()` when `buf` ends with a partial message.
    sizes: VecDeque<usize>,

    /// Concatenated payload bytes of received handshake records that have not been popped yet.
    buf: Vec<u8>,
}
impl HandshakeJoiner {
    /// Makes a new, empty joiner.
    ///
    /// Messages are parsed as TLS 1.2 until [`Self::push`] sees a TLS 1.3 record.
    pub fn new() -> Self {
        Self {
            buf: Vec::new(),
            sizes: VecDeque::new(),
            version: ProtocolVersion::TLSv1_2,
        }
    }

    /// Appends the payload of `msg` to the buffer and records any handshake messages it completes.
    ///
    /// Returns `Ok(true)` if the buffer now ends exactly on a message boundary, meaning it holds
    /// only complete messages. An empty buffer also counts. Returns `Ok(false)` if the buffer
    /// ends with a partially-received message. Complete messages are taken out with
    /// [`Self::pop`].
    ///
    /// # Errors
    ///
    /// - `JoinerError::Unwanted(msg)` if `msg.typ` is not `ContentType::Handshake`; `msg` is
    ///   handed back and the joiner is left untouched.
    /// - `JoinerError::Decode` if a buffered message header advertises a body longer than
    ///   `MAX_HANDSHAKE_SIZE`.
    pub fn push(&mut self, msg: PlainMessage) -> Result<bool, JoinerError> {
        if msg.typ != ContentType::Handshake {
            return Err(JoinerError::Unwanted(msg));
        }

        // When nothing is buffered, adopt the payload's allocation instead of copying it.
        if self.buf.is_empty() {
            self.buf = msg.payload.0;
        } else {
            self.buf.extend_from_slice(&msg.payload.0);
        }

        // Once a TLS 1.3 record has been seen, all later messages are parsed as TLS 1.3.
        if msg.version == ProtocolVersion::TLSv1_3 {
            self.version = msg.version;
        }

        // `sizes` covers a prefix of `buf`; scan the uncovered suffix for newly-complete messages.
        let mut complete = self.sizes.iter().copied().sum();
        while let Some(size) = payload_size(&self.buf[complete..])? {
            self.sizes.push_back(size);
            complete += size;
        }

        Ok(complete == self.buf.len())
    }

    /// Removes the first complete message from the buffer and parses it.
    ///
    /// Returns `Ok(None)` if no complete message is buffered.
    ///
    /// # Errors
    ///
    /// `JoinerError::Decode` if the first message cannot be parsed as a
    /// `HandshakeMessagePayload` for the current protocol version. The message's bytes are
    /// discarded whether or not parsing succeeds, so `sizes` stays in step with `buf`.
    pub fn pop(&mut self) -> Result<Option<Message>, JoinerError> {
        let len = match self.sizes.pop_front() {
            Some(len) => len,
            None => return Ok(None),
        };

        // Detach the message's bytes *before* parsing. If a parse failure left them in `buf`
        // after `sizes` had already dropped their entry, every later offset would be wrong.
        // When the buffer holds exactly this one message, take the whole allocation
        // rather than copying it out and shifting nothing.
        let encoded: Vec<u8> = if len == self.buf.len() {
            std::mem::take(&mut self.buf)
        } else {
            self.buf.drain(..len).collect()
        };

        let mut rd = codec::Reader::init(&encoded);
        let parsed = HandshakeMessagePayload::read_version(&mut rd, self.version)
            .ok_or(JoinerError::Decode)?;

        Ok(Some(Message {
            version: self.version,
            payload: MessagePayload::Handshake {
                parsed,
                encoded: Payload::new(encoded),
            },
        }))
    }
}
/// Reports whether `buf` starts with a complete handshake message.
///
/// Returns `Ok(Some(n))` with the message's total length `n` (header included) if it does.
/// Returns `Ok(None)` if `buf` is too short to hold the header or the advertised body.
///
/// # Errors
///
/// `JoinerError::Decode` if the header advertises a body longer than `MAX_HANDSHAKE_SIZE`.
fn payload_size(buf: &[u8]) -> Result<Option<usize>, JoinerError> {
    let header = match buf.get(..HEADER_SIZE) {
        Some(header) => header,
        None => return Ok(None),
    };

    // Header bytes 1..4 carry the body length as a 24-bit integer.
    let body_len = match codec::u24::decode(&header[1..]) {
        Some(body_len) => body_len,
        None => return Ok(None),
    };
    if body_len.0 > MAX_HANDSHAKE_SIZE {
        return Err(JoinerError::Decode);
    }

    let body_len = usize::from(body_len);
    // Cannot underflow: `buf.get(..HEADER_SIZE)` succeeded above.
    let received = buf.len() - HEADER_SIZE;
    if received < body_len {
        Ok(None)
    } else {
        Ok(Some(HEADER_SIZE + body_len))
    }
}
/// Reasons [`HandshakeJoiner`] refuses input or fails to produce a message.
#[derive(Debug)]
pub enum JoinerError {
    /// A buffered handshake message advertised an oversized length or failed to parse.
    Decode,
    /// The record given to `push` was not of type `ContentType::Handshake`; it is handed back
    /// unchanged.
    Unwanted(PlainMessage),
}
#[cfg(test)]
mod tests {
    use super::HandshakeJoiner;
    use crate::enums::ProtocolVersion;
    use crate::msgs::base::Payload;
    use crate::msgs::codec::Codec;
    use crate::msgs::enums::{ContentType, HandshakeType};
    use crate::msgs::handshake::{HandshakeMessagePayload, HandshakePayload};
    use crate::msgs::message::{Message, MessagePayload, PlainMessage};

    /// Builds a TLS 1.2 record of content type `typ` carrying `bytes`.
    fn record(typ: ContentType, bytes: &[u8]) -> PlainMessage {
        PlainMessage {
            typ,
            version: ProtocolVersion::TLSv1_2,
            payload: Payload::new(bytes.to_vec()),
        }
    }

    /// Builds a TLS 1.2 handshake record carrying `bytes`.
    fn handshake(bytes: &[u8]) -> PlainMessage {
        record(ContentType::Handshake, bytes)
    }

    /// Wraps a handshake payload in a TLS 1.2 message and converts it to a `PlainMessage`.
    fn expected(hmp: HandshakeMessagePayload) -> PlainMessage {
        Message {
            version: ProtocolVersion::TLSv1_2,
            payload: MessagePayload::handshake(hmp),
        }
        .into()
    }

    /// Pops one message and checks its content type, version and encoding against `expect`.
    fn pop_eq(expect: &PlainMessage, hj: &mut HandshakeJoiner) {
        let got = hj.pop().unwrap().unwrap();
        assert_eq!(got.payload.content_type(), expect.typ);
        assert_eq!(got.version, expect.version);

        let mut got_bytes = Vec::new();
        got.payload.encode(&mut got_bytes);
        let mut expect_bytes = Vec::new();
        expect.payload.encode(&mut expect_bytes);
        assert_eq!(got_bytes, expect_bytes);
    }

    #[test]
    fn want() {
        let mut hj = HandshakeJoiner::new();
        hj.push(handshake(b"\x00\x00\x00\x00")).unwrap();
        hj.push(record(ContentType::Alert, b"ponytown"))
            .unwrap_err();
    }

    #[test]
    fn split() {
        let mut hj = HandshakeJoiner::new();
        // Two empty HelloRequest messages coalesced into a single record.
        let complete = hj
            .push(handshake(b"\x00\x00\x00\x00\x00\x00\x00\x00"))
            .unwrap();
        assert!(complete);

        let expect = expected(HandshakeMessagePayload {
            typ: HandshakeType::HelloRequest,
            payload: HandshakePayload::HelloRequest,
        });
        for _ in 0..2 {
            pop_eq(&expect, &mut hj);
        }
    }

    #[test]
    fn broken() {
        let mut hj = HandshakeJoiner::new();
        // A complete message whose two-byte body fails to parse.
        hj.push(handshake(b"\x01\x00\x00\x02\xff\xff"))
            .unwrap();
        hj.pop().unwrap_err();
    }

    #[test]
    fn join() {
        let mut hj = HandshakeJoiner::new();
        // One 16-byte Finished message delivered across three records.
        hj.push(handshake(b"\x14\x00\x00\x10\x00\x01\x02\x03\x04"))
            .unwrap();
        assert!(!hj
            .push(handshake(b"\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e"))
            .unwrap());
        assert!(hj.push(handshake(b"\x0f")).unwrap());

        let body: Vec<u8> = (0u8..16).collect();
        let expect = expected(HandshakeMessagePayload {
            typ: HandshakeType::Finished,
            payload: HandshakePayload::Finished(Payload::new(body)),
        });
        pop_eq(&expect, &mut hj);
    }

    #[test]
    fn test_rejects_giant_certs() {
        let mut hj = HandshakeJoiner::new();
        // The header advertises a 0x010004-byte body, which exceeds the accepted maximum.
        hj.push(handshake(b"\x0b\x01\x00\x04\x01\x00\x01\x00\xff\xfe"))
            .unwrap_err();
    }
}