use cipher::StreamCipher;
use crate::crypto::{AesCtr256, HANDSHAKE_LEN, ProtoTag, SKIP_LEN, make_cipher};
/// Byte length of the AES-256 key sliced out of the handshake (256 bits).
const PREKEY_LEN: usize = 256 / 8;
/// Byte length of the AES-CTR IV sliced out of the handshake right after the key (128 bits).
const IV_LEN: usize = 128 / 8;
/// Splits an encrypted byte stream into chunks that each hold one complete packet.
///
/// A decrypted mirror of the buffered ciphertext is kept so packet lengths can be read
/// from plaintext headers, while the chunks handed back to callers stay encrypted.
pub struct MsgSplitter {
    /// Framing variant that decides how packet lengths are parsed.
    proto: ProtoTag,
    /// When true, input is passed through unchanged instead of being split.
    disabled: bool,
    /// Stream decryptor used to recover plaintext packet headers.
    dec: AesCtr256,
    /// Ciphertext that has been received but not yet emitted as a complete packet.
    cipher_buf: Vec<u8>,
    /// Decrypted counterpart of `cipher_buf`; both are extended, drained and cleared together.
    plain_buf: Vec<u8>,
}
impl MsgSplitter {
    /// Creates a splitter for a stream encrypted with the key/IV embedded in `relay_init`.
    ///
    /// The key is `relay_init[SKIP_LEN..SKIP_LEN + PREKEY_LEN]` and the IV is the `IV_LEN`
    /// bytes that immediately follow it. Before any payload is decrypted, the keystream is
    /// advanced by `HANDSHAKE_LEN` bytes (presumably because the handshake itself was
    /// encrypted with that leading part of the keystream).
    pub fn new(relay_init: &[u8; HANDSHAKE_LEN], proto: ProtoTag) -> Self {
        let key_end = SKIP_LEN + PREKEY_LEN;
        let relay_enc_key = &relay_init[SKIP_LEN..key_end];
        let relay_enc_iv = &relay_init[key_end..key_end + IV_LEN];
        let mut dec = make_cipher(relay_enc_key, relay_enc_iv);
        // Burn HANDSHAKE_LEN bytes of keystream so payload decryption starts after it.
        let mut dummy = [0u8; HANDSHAKE_LEN];
        dec.apply_keystream(&mut dummy);
        Self {
            dec,
            proto,
            cipher_buf: Vec::new(),
            plain_buf: Vec::new(),
            disabled: false,
        }
    }

    /// Feeds `encrypted` into the splitter and returns every packet that is now complete.
    ///
    /// Each returned chunk is the original ciphertext of exactly one packet (header
    /// included). Bytes of a trailing, incomplete packet are buffered for later calls.
    /// If a header announces a zero-length payload, all remaining buffered bytes are
    /// returned as a single chunk and every later call passes its input through unsplit.
    /// Empty input yields an empty vector.
    pub fn split(&mut self, encrypted: &[u8]) -> Vec<Vec<u8>> {
        if encrypted.is_empty() {
            return Vec::new();
        }
        if self.disabled {
            return vec![encrypted.to_vec()];
        }

        // Decrypt straight into the tail of `plain_buf` instead of via a temporary copy;
        // the keystream is consumed in the same order as before.
        let start = self.plain_buf.len();
        self.plain_buf.extend_from_slice(encrypted);
        self.dec.apply_keystream(&mut self.plain_buf[start..]);
        self.cipher_buf.extend_from_slice(encrypted);

        let mut parts = Vec::new();
        let mut consumed = 0usize;
        while let Some(len) = self.next_packet_len(consumed) {
            if len == 0 {
                // Zero-length header: framing can no longer be trusted, so hand back
                // everything still buffered and switch to pass-through mode.
                parts.push(self.cipher_buf[consumed..].to_vec());
                self.cipher_buf.clear();
                self.plain_buf.clear();
                self.disabled = true;
                return parts;
            }
            let end = consumed + len;
            parts.push(self.cipher_buf[consumed..end].to_vec());
            consumed = end;
        }

        // Drop the packets that were emitted; keep the incomplete remainder.
        if consumed != 0 {
            self.cipher_buf.drain(..consumed);
            self.plain_buf.drain(..consumed);
        }
        parts
    }

    /// Returns whatever ciphertext is still buffered as a single chunk (empty if none).
    ///
    /// The buffered bytes always belong to an incomplete packet, because `split` drains
    /// complete ones. After they are emitted, the decryptor is positioned inside that
    /// packet's body. Any further input therefore cannot be parsed as a header, so the
    /// splitter switches to pass-through mode instead of misframing the rest of the stream.
    pub fn flush(&mut self) -> Vec<Vec<u8>> {
        if self.cipher_buf.is_empty() {
            return Vec::new();
        }
        let tail = std::mem::take(&mut self.cipher_buf);
        self.plain_buf.clear();
        self.disabled = true;
        vec![tail]
    }

    /// Length of the packet starting at `offset` in the buffered plaintext.
    ///
    /// Returns `None` if there is nothing at `offset` or the packet is not fully
    /// buffered yet. Returns `Some(0)` for a zero-length payload. Otherwise returns
    /// `Some(total)`, where `total` includes the header.
    fn next_packet_len(&self, offset: usize) -> Option<usize> {
        let plain = self.plain_buf.get(offset..)?;
        if plain.is_empty() {
            return None;
        }
        match self.proto {
            ProtoTag::Abridged => Self::abridged_len(plain),
            ProtoTag::Intermediate | ProtoTag::PaddedIntermediate => Self::intermediate_len(plain),
        }
    }

    /// Parses an abridged-framing header at the start of the non-empty `plain`.
    ///
    /// The payload length is counted in 4-byte units, and the high bit of the first byte
    /// is ignored (presumably a quick-ack flag). A first byte of `0x7F`/`0xFF` means the
    /// length is in the next 3 bytes (little-endian, 4-byte header). Otherwise the low
    /// 7 bits of the first byte are the length (1-byte header).
    fn abridged_len(plain: &[u8]) -> Option<usize> {
        let first = plain[0];
        let (payload_len, header_len) = if first == 0x7F || first == 0xFF {
            if plain.len() < 4 {
                // Extended header not fully buffered yet.
                return None;
            }
            let words = u32::from_le_bytes([plain[1], plain[2], plain[3], 0]) as usize;
            (words * 4, 4)
        } else {
            ((first & 0x7F) as usize * 4, 1)
        };
        if payload_len == 0 {
            return Some(0);
        }
        let total = header_len + payload_len;
        (plain.len() >= total).then_some(total)
    }

    /// Parses an intermediate-framing header at the start of `plain`.
    ///
    /// The header is a 4-byte little-endian payload length whose top bit is masked off.
    /// The returned total includes the 4 header bytes.
    fn intermediate_len(plain: &[u8]) -> Option<usize> {
        if plain.len() < 4 {
            return None;
        }
        let raw = u32::from_le_bytes([plain[0], plain[1], plain[2], plain[3]]);
        let payload_len = (raw & 0x7FFF_FFFF) as usize;
        if payload_len == 0 {
            return Some(0);
        }
        let total = 4 + payload_len;
        (plain.len() >= total).then_some(total)
    }
}