use cipher::StreamCipher;
use crate::crypto::{AesCtr256, HANDSHAKE_LEN, ProtoTag, SKIP_LEN, make_cipher};
/// Length of the key slice read from the relay init packet, starting right after `SKIP_LEN`.
const PREKEY_LEN: usize = 32;
/// Length of the IV slice that immediately follows the key in the relay init packet.
const IV_LEN: usize = 16;
/// Re-frames an encrypted byte stream along its packet boundaries.
///
/// Incoming ciphertext is decrypted with a private keystream only so that the
/// frame length headers can be read; the chunks handed back to the caller are
/// always the original, still-encrypted bytes.
pub struct MsgSplitter {
/// Decryptor used solely to recover plaintext length headers.
dec: AesCtr256,
/// Framing scheme that decides how length headers are parsed.
proto: ProtoTag,
/// Ciphertext received but not yet returned as part of a complete frame.
cipher_buf: Vec<u8>,
/// Decrypted mirror of `cipher_buf`; the two buffers always have equal length.
plain_buf: Vec<u8>,
/// Set once a zero-length frame is seen; from then on input is passed through unsplit.
disabled: bool,
}
impl MsgSplitter {
    /// Creates a splitter whose decryptor is keyed from `relay_init`.
    ///
    /// The key is `relay_init[SKIP_LEN..SKIP_LEN + PREKEY_LEN]` and the IV is the
    /// `IV_LEN` bytes that follow it. The keystream is then advanced by
    /// `HANDSHAKE_LEN` bytes, so the first byte later passed to
    /// [`split`](Self::split) is decrypted at keystream offset `HANDSHAKE_LEN`
    /// (presumably because the init packet itself consumed those bytes — confirm
    /// against the sender).
    pub fn new(relay_init: &[u8; HANDSHAKE_LEN], proto: ProtoTag) -> Self {
        let relay_enc_key = &relay_init[SKIP_LEN..SKIP_LEN + PREKEY_LEN];
        let relay_enc_iv = &relay_init[SKIP_LEN + PREKEY_LEN..SKIP_LEN + PREKEY_LEN + IV_LEN];
        let mut dec = make_cipher(relay_enc_key, relay_enc_iv);
        // Burn the keystream bytes that cover the handshake.
        let mut dummy = [0u8; HANDSHAKE_LEN];
        dec.apply_keystream(&mut dummy);
        Self {
            dec,
            proto,
            cipher_buf: Vec::new(),
            plain_buf: Vec::new(),
            disabled: false,
        }
    }

    /// Feeds the next chunk of ciphertext and returns every frame completed so far.
    ///
    /// Each returned chunk is the original ciphertext of exactly one frame
    /// (header included). Bytes of an incomplete trailing frame stay buffered
    /// until more data arrives or [`flush`](Self::flush) is called. An empty
    /// `encrypted` returns an empty vector.
    ///
    /// If a frame header declares a zero-length payload, splitting is disabled
    /// permanently: everything still buffered is returned as a single chunk and
    /// every later call returns its input unchanged as one chunk.
    ///
    /// NOTE(review): a garbage length (up to ~64 MiB for Abridged, ~2 GiB for
    /// Intermediate) keeps all input buffered until that many bytes arrive or
    /// `flush` is called; consider a sanity cap.
    pub fn split(&mut self, encrypted: &[u8]) -> Vec<Vec<u8>> {
        if encrypted.is_empty() {
            return Vec::new();
        }
        if self.disabled {
            return vec![encrypted.to_vec()];
        }

        // Decrypt in place at the tail of the plaintext mirror; the keystream
        // advances by exactly the number of bytes appended to `cipher_buf`.
        let start = self.plain_buf.len();
        self.plain_buf.extend_from_slice(encrypted);
        self.dec.apply_keystream(&mut self.plain_buf[start..]);
        self.cipher_buf.extend_from_slice(encrypted);

        let mut parts = Vec::new();
        // Walk frames by offset and shift the buffers once at the end, instead
        // of draining per frame (which is quadratic in the number of frames).
        let mut consumed = 0;
        while let Some(len) = self.next_packet_len(&self.plain_buf[consumed..]) {
            if len == 0 {
                parts.push(self.cipher_buf[consumed..].to_vec());
                // Splitting never resumes, so release the buffers entirely.
                self.cipher_buf = Vec::new();
                self.plain_buf = Vec::new();
                self.disabled = true;
                return parts;
            }
            parts.push(self.cipher_buf[consumed..consumed + len].to_vec());
            consumed += len;
        }
        self.cipher_buf.drain(..consumed);
        self.plain_buf.drain(..consumed);
        parts
    }

    /// Returns any buffered, incomplete frame data as a single chunk and clears the buffers.
    ///
    /// Returns an empty vector when nothing is buffered. The keystream position
    /// is not reset, but bytes passed to [`split`](Self::split) afterwards are
    /// parsed as if they begin at a frame boundary.
    pub fn flush(&mut self) -> Vec<Vec<u8>> {
        if self.cipher_buf.is_empty() {
            return Vec::new();
        }
        self.plain_buf.clear();
        vec![std::mem::take(&mut self.cipher_buf)]
    }

    /// Size in bytes (header included) of the frame at the start of `plain`.
    ///
    /// Returns `None` if more bytes are needed to know or to complete the frame,
    /// and `Some(0)` if the header declares an empty payload.
    fn next_packet_len(&self, plain: &[u8]) -> Option<usize> {
        if plain.is_empty() {
            return None;
        }
        match self.proto {
            ProtoTag::Abridged => Self::abridged_len(plain),
            ProtoTag::Intermediate | ProtoTag::PaddedIntermediate => Self::intermediate_len(plain),
        }
    }

    /// Abridged framing: the low 7 bits of the first byte give the payload length
    /// in 4-byte words. The value 0x7F (with or without the high bit) escapes to a
    /// 3-byte little-endian word count in the next three bytes. The high bit of
    /// the first byte is ignored for the length.
    fn abridged_len(plain: &[u8]) -> Option<usize> {
        let first = *plain.first()?;
        let (payload_len, header_len) = if (first & 0x7F) == 0x7F {
            let ext = plain.get(1..4)?;
            let words = u32::from_le_bytes([ext[0], ext[1], ext[2], 0]) as usize;
            (words * 4, 4)
        } else {
            ((first & 0x7F) as usize * 4, 1)
        };
        if payload_len == 0 {
            return Some(0);
        }
        let total = header_len + payload_len;
        (plain.len() >= total).then_some(total)
    }

    /// Intermediate / padded-intermediate framing: a 4-byte little-endian payload
    /// length whose top bit is masked off before use.
    fn intermediate_len(plain: &[u8]) -> Option<usize> {
        let header = plain.get(..4)?;
        let raw = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
        let payload_len = (raw & 0x7FFF_FFFF) as usize;
        if payload_len == 0 {
            return Some(0);
        }
        let total = 4 + payload_len;
        (plain.len() >= total).then_some(total)
    }
}