#![allow(dead_code)]
use base64::{Engine as _, engine::general_purpose::STANDARD as B64};
use crate::e2e::error::{E2eError, Result};
use crate::e2e::{MAX_CHUNKS, PROTO};
/// Marker that opens every encrypted wire line. [`WireChunk::parse`] returns
/// `Ok(None)` for lines that do not start with it.
pub const WIRE_PREFIX: &str = "+RPE2E01";
/// Message identifier: 8 raw bytes, written on the wire as 16 hex characters.
pub type MsgId = [u8; 8];
/// One chunk of an encrypted message, in the form carried on a single wire line.
///
/// See [`WireChunk::encode`] and [`WireChunk::parse`] for the textual layout.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WireChunk {
    /// Message identifier (presumably shared by every chunk of one message).
    pub msgid: MsgId,
    /// Header timestamp. NOTE(review): the test fixture value looks like Unix
    /// seconds; the unit is set by callers — confirm.
    pub ts: i64,
    /// 1-based index of this chunk; `encode`/`parse` require `1 <= part <= total`.
    pub part: u8,
    /// Number of chunks; `encode`/`parse` require `1 <= total <= MAX_CHUNKS`.
    pub total: u8,
    /// 24-byte nonce, transmitted base64-encoded.
    pub nonce: [u8; 24],
    /// Ciphertext bytes, transmitted base64-encoded.
    pub ciphertext: Vec<u8>,
}
impl WireChunk {
pub fn encode(&self) -> Result<String> {
if self.total == 0 || self.total > MAX_CHUNKS {
return Err(E2eError::ChunkLimit(self.total));
}
if self.part == 0 || self.part > self.total {
return Err(E2eError::Wire(format!(
"invalid part/total: {}/{}",
self.part, self.total
)));
}
let msgid = hex::encode(self.msgid);
let nonce_b64 = B64.encode(self.nonce);
let ct_b64 = B64.encode(&self.ciphertext);
Ok(format!(
"{WIRE_PREFIX} {msgid} {ts} {part}/{total} {nonce_b64}:{ct_b64}",
ts = self.ts,
part = self.part,
total = self.total,
))
}
pub fn parse(line: &str) -> Result<Option<Self>> {
let rest = match line.strip_prefix(WIRE_PREFIX) {
Some(r) => r.trim_start(),
None => return Ok(None),
};
let mut fields = rest.split_whitespace();
let msgid_hex = fields
.next()
.ok_or_else(|| E2eError::Wire("missing msgid".into()))?;
let ts_str = fields
.next()
.ok_or_else(|| E2eError::Wire("missing ts".into()))?;
let parttot = fields
.next()
.ok_or_else(|| E2eError::Wire("missing part/total".into()))?;
let body = fields
.next()
.ok_or_else(|| E2eError::Wire("missing body".into()))?;
if fields.next().is_some() {
return Err(E2eError::Wire("extra fields".into()));
}
if msgid_hex.len() != 16 {
return Err(E2eError::Wire("msgid must be 16 hex chars".into()));
}
let msgid_vec = hex::decode(msgid_hex)?;
let mut msgid = [0u8; 8];
msgid.copy_from_slice(&msgid_vec);
let ts: i64 = ts_str
.parse()
.map_err(|e| E2eError::Wire(format!("bad ts: {e}")))?;
let (p, t) = parttot
.split_once('/')
.ok_or_else(|| E2eError::Wire("part/total missing slash".into()))?;
let part: u8 = p
.parse()
.map_err(|e| E2eError::Wire(format!("bad part: {e}")))?;
let total: u8 = t
.parse()
.map_err(|e| E2eError::Wire(format!("bad total: {e}")))?;
if total == 0 || total > MAX_CHUNKS || part == 0 || part > total {
return Err(E2eError::Wire(format!("bad part/total {part}/{total}")));
}
let (nonce_b64, ct_b64) = body
.split_once(':')
.ok_or_else(|| E2eError::Wire("missing nonce:ct separator".into()))?;
let nonce_vec = B64.decode(nonce_b64)?;
if nonce_vec.len() != 24 {
return Err(E2eError::Wire(format!(
"nonce must be 24 bytes, got {}",
nonce_vec.len()
)));
}
let mut nonce = [0u8; 24];
nonce.copy_from_slice(&nonce_vec);
let ciphertext = B64.decode(ct_b64)?;
Ok(Some(Self {
msgid,
ts,
part,
total,
nonce,
ciphertext,
}))
}
}
/// Builds the associated data (AAD) for a chunk from its channel and header fields.
///
/// The layout is the raw bytes of `PROTO`, followed by five fields. Each field
/// is written as its length (big-endian `u16`) and then its bytes:
/// 1. `channel` (UTF-8 bytes)
/// 2. `msgid` (8 bytes)
/// 3. `ts` (big-endian `i64`, 8 bytes)
/// 4. `part` (1 byte)
/// 5. `total` (1 byte)
///
/// A `channel` longer than `u16::MAX` bytes gets a length prefix saturated to
/// `u16::MAX`, but all of its bytes are still appended. Changing this would
/// change the AAD bytes. The encoding remains unambiguous because `PROTO` and
/// every field after `channel` have fixed widths, so the total length alone
/// determines the channel length.
pub fn build_aad(channel: &str, msgid: MsgId, ts: i64, part: u8, total: u8) -> Vec<u8> {
    /// Appends `bytes`, preceded by its length as a big-endian `u16`
    /// (saturating at `u16::MAX`).
    fn push_field(aad: &mut Vec<u8>, bytes: &[u8]) {
        let len = u16::try_from(bytes.len()).unwrap_or(u16::MAX);
        aad.extend_from_slice(&len.to_be_bytes());
        aad.extend_from_slice(bytes);
    }

    // Five u16 length prefixes plus the fixed-width fields (msgid, ts, part, total).
    const FIXED_LEN: usize = 5 * 2 + 8 + 8 + 1 + 1;

    // Derive the capacity from PROTO instead of hard-coding its length.
    let mut aad = Vec::with_capacity(PROTO.len() + FIXED_LEN + channel.len());
    aad.extend_from_slice(PROTO.as_bytes());
    push_field(&mut aad, channel.as_bytes());
    push_field(&mut aad, &msgid);
    push_field(&mut aad, &ts.to_be_bytes());
    push_field(&mut aad, &[part]);
    push_field(&mut aad, &[total]);
    aad
}
pub fn fresh_msgid() -> MsgId {
let mut id = [0u8; 8];
rand::fill(&mut id);
id
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid single-part chunk used as the baseline for most tests.
    fn sample_chunk() -> WireChunk {
        WireChunk {
            msgid: [0xab; 8],
            ts: 1_712_000_000,
            part: 1,
            total: 1,
            nonce: [0x42; 24],
            ciphertext: vec![0xde, 0xad, 0xbe, 0xef],
        }
    }

    /// `sample_chunk()` encoded, with its `1/1` part/total field replaced by `parttot`.
    fn encoded_with_parttot(parttot: &str) -> String {
        sample_chunk()
            .encode()
            .unwrap()
            .replacen(" 1/1 ", &format!(" {parttot} "), 1)
    }

    #[test]
    fn encode_starts_with_prefix() {
        let enc = sample_chunk().encode().unwrap();
        assert!(enc.starts_with(WIRE_PREFIX));
    }

    #[test]
    fn encode_roundtrip() {
        let c = sample_chunk();
        let enc = c.encode().unwrap();
        let parsed = WireChunk::parse(&enc).unwrap().unwrap();
        assert_eq!(parsed, c);
    }

    #[test]
    fn parse_cleartext_returns_none() {
        assert_eq!(WireChunk::parse("hello world").unwrap(), None);
        assert_eq!(WireChunk::parse("").unwrap(), None);
    }

    #[test]
    fn encode_rejects_invalid_part_total() {
        let mut c = sample_chunk();
        c.total = 0;
        assert!(matches!(c.encode(), Err(E2eError::ChunkLimit(0))));
        c.total = MAX_CHUNKS + 1;
        assert!(matches!(c.encode(), Err(E2eError::ChunkLimit(_))));
        c.total = 3;
        c.part = 0;
        assert!(c.encode().is_err());
        c.part = 4;
        assert!(c.encode().is_err());
        // The MAX_CHUNKS upper bound is inclusive.
        c.total = MAX_CHUNKS;
        c.part = MAX_CHUNKS;
        assert!(c.encode().is_ok());
    }

    #[test]
    fn parse_rejects_invalid_part_total() {
        // Sanity check: the unmodified field parses.
        assert!(WireChunk::parse(&encoded_with_parttot("1/1")).unwrap().is_some());
        for bad in ["0/1", "1/0", "2/1", "1-1", "x/1", "1/256"] {
            assert!(
                WireChunk::parse(&encoded_with_parttot(bad)).is_err(),
                "{bad} should be rejected"
            );
        }
        let over = format!("1/{}", MAX_CHUNKS + 1);
        assert!(WireChunk::parse(&encoded_with_parttot(&over)).is_err());
    }

    #[test]
    fn parse_rejects_malformed_fields() {
        let enc = sample_chunk().encode().unwrap();
        // Trailing extra field.
        assert!(WireChunk::parse(&format!("{enc} extra")).is_err());
        // msgid of the wrong length, or not hex.
        assert!(WireChunk::parse(&enc.replacen("abababababababab", "abab", 1)).is_err());
        assert!(
            WireChunk::parse(&enc.replacen("abababababababab", "zzzzzzzzzzzzzzzz", 1)).is_err()
        );
        // Non-numeric timestamp.
        assert!(WireChunk::parse(&enc.replacen("1712000000", "yesterday", 1)).is_err());
        // Missing nonce:ciphertext separator.
        assert!(WireChunk::parse(&enc.replacen(':', "", 1)).is_err());
        // Prefix with no fields at all.
        assert!(WireChunk::parse(WIRE_PREFIX).is_err());
    }

    #[test]
    fn parse_rejects_bad_nonce_length() {
        let bad = "+RPE2E01 abababababababab 1712000000 1/1 YWJj:ZGVm";
        assert!(WireChunk::parse(bad).is_err());
    }

    #[test]
    fn build_aad_is_deterministic() {
        let a = build_aad("#chan", [1; 8], 100, 1, 3);
        let b = build_aad("#chan", [1; 8], 100, 1, 3);
        assert_eq!(a, b);
    }

    #[test]
    fn build_aad_sensitive_to_every_field() {
        let base = build_aad("#chan", [1; 8], 100, 1, 3);
        assert_ne!(base, build_aad("#other", [1; 8], 100, 1, 3));
        assert_ne!(base, build_aad("#chan", [2; 8], 100, 1, 3));
        assert_ne!(base, build_aad("#chan", [1; 8], 101, 1, 3));
        assert_ne!(base, build_aad("#chan", [1; 8], 100, 2, 3));
        assert_ne!(base, build_aad("#chan", [1; 8], 100, 1, 4));
    }

    #[test]
    fn build_aad_golden_vector() {
        let got = build_aad("#chan", [1u8; 8], 100, 1, 3);
        let expected: Vec<u8> = vec![
            // PROTO bytes ("RPE2E01")
            0x52, 0x50, 0x45, 0x32, 0x45, 0x30, 0x31,
            // channel: len 5, "#chan"
            0x00, 0x05, 0x23, 0x63, 0x68, 0x61, 0x6e,
            // msgid: len 8, [1; 8]
            0x00, 0x08, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
            // ts: len 8, 100 as big-endian i64
            0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x64,
            // part: len 1, 1
            0x00, 0x01, 0x01,
            // total: len 1, 3
            0x00, 0x01, 0x03,
        ];
        assert_eq!(got.len(), 40, "AAD length mismatch");
        assert_eq!(got, expected, "AAD golden byte sequence mismatch");
    }

    #[test]
    fn build_aad_length_prefix_rejects_colon_ambiguity() {
        let a = build_aad("#a:b", [1; 8], 100, 1, 3);
        let b = build_aad("#a", [1; 8], 100, 1, 3);
        assert_ne!(a, b);
        assert_ne!(a.len(), b.len());
    }

    #[test]
    fn build_aad_saturates_oversized_channel_length() {
        let chan = "#".repeat(usize::from(u16::MAX) + 10);
        let aad = build_aad(&chan, [1; 8], 100, 1, 3);
        let off = PROTO.len();
        assert_eq!(&aad[off..off + 2], &u16::MAX.to_be_bytes());
        // Five 2-byte prefixes minus the channel's, plus 8 + 8 + 1 + 1 fixed bytes.
        assert_eq!(aad.len(), PROTO.len() + 2 + chan.len() + 28);
    }

    #[test]
    fn fresh_msgid_is_random_ish() {
        let a = fresh_msgid();
        let b = fresh_msgid();
        assert_ne!(a, b);
    }
}