use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as B64};
use crate::e2e::crypto::aead::NONCE_LEN;
use crate::e2e::error::{E2eError, Result};
/// First whitespace-separated token of every handshake body; `parse` returns
/// `Ok(None)` for bodies that do not start with it.
pub const CTCP_TAG: &str = "RPEE2E";
/// Protocol version written as the `v=` field; `parse` rejects any other value.
pub const PROTO_VERSION: u8 = 1;
/// Minimum time between two permitted outgoing handshakes to the same peer.
const RATE_LIMIT_WINDOW: Duration = Duration::from_secs(30);
/// Sliding window over which incoming handshakes from one peer are counted.
const INCOMING_WINDOW: Duration = Duration::from_secs(60);
/// Incoming handshakes accepted per peer inside `INCOMING_WINDOW`; the next one
/// triggers `INCOMING_BACKOFF`.
const INCOMING_MAX_PER_WINDOW: usize = 3;
/// How long every incoming handshake from a peer is refused after it exceeded
/// `INCOMING_MAX_PER_WINDOW`.
const INCOMING_BACKOFF: Duration = Duration::from_secs(5 * 60);
/// A `KEYREQ` handshake message, as produced by `encode_keyreq` and decoded by `parse`.
#[derive(Debug, Clone)]
pub struct KeyReq {
    /// Channel name, carried verbatim as `c=`.
    pub channel: String,
    /// 32-byte public key (`p=`); presumably the sender's long-term identity key — TODO confirm.
    pub pubkey: [u8; 32],
    /// 32-byte X25519 public key (`e=`), bound into the signed payload.
    pub eph_x25519: [u8; 32],
    /// 16-byte nonce (`n=`), bound into the signed payload.
    pub nonce: [u8; 16],
    /// 64-byte signature (`s=`); expected to cover `signed_keyreq_payload(..)` (verified elsewhere).
    pub sig: [u8; 64],
}
/// A `KEYRSP` handshake message, as produced by `encode_keyrsp` and decoded by `parse`.
#[derive(Debug, Clone)]
pub struct KeyRsp {
    /// Channel name, carried verbatim as `c=`.
    pub channel: String,
    /// 32-byte public key of the responder (`p=`).
    pub pubkey: [u8; 32],
    /// 32-byte ephemeral public key (`e=`) used for the key wrap — presumably X25519, confirm.
    pub ephemeral_pub: [u8; 32],
    /// AEAD nonce for the wrapped key (`wn=`), `NONCE_LEN` bytes.
    pub wrap_nonce: [u8; NONCE_LEN],
    /// Wrapped-key ciphertext (`w=`); variable length.
    pub wrap_ct: Vec<u8>,
    /// 16-byte nonce (`n=`), bound into the signed payload.
    pub nonce: [u8; 16],
    /// 64-byte signature (`s=`); expected to cover `signed_keyrsp_payload(..)` (verified elsewhere).
    pub sig: [u8; 64],
}
/// Builds the bytes a `KEYREQ` signature covers:
/// `KEYREQ:` + channel + `:` + pubkey + `:` + eph_x25519 + `:` + nonce.
fn sig_payload_keyreq(
    channel: &str,
    pubkey: &[u8; 32],
    eph_x25519: &[u8; 32],
    nonce: &[u8; 16],
) -> Vec<u8> {
    const TAG: &[u8] = b"KEYREQ:";
    // Signed fields in order; consecutive fields are joined by a single ':'.
    let parts: [&[u8]; 4] = [channel.as_bytes(), pubkey, eph_x25519, nonce];
    let total = TAG.len() + parts.iter().map(|p| p.len()).sum::<usize>() + (parts.len() - 1);
    let mut payload = Vec::with_capacity(total);
    payload.extend_from_slice(TAG);
    for (idx, part) in parts.iter().enumerate() {
        if idx != 0 {
            payload.push(b':');
        }
        payload.extend_from_slice(part);
    }
    payload
}
/// A `REKEY` handshake message, as produced by `encode_keyrekey` and decoded by `parse`.
#[derive(Debug, Clone)]
pub struct KeyRekey {
    /// Channel name, carried verbatim as `c=`.
    pub channel: String,
    /// 32-byte public key of the sender (`p=`).
    pub pubkey: [u8; 32],
    /// 32-byte ephemeral public key (`e=`) used for the key wrap — presumably X25519, confirm.
    pub eph_pub: [u8; 32],
    /// AEAD nonce for the wrapped key (`wn=`), `NONCE_LEN` bytes.
    pub wrap_nonce: [u8; NONCE_LEN],
    /// Wrapped-key ciphertext (`w=`); variable length.
    pub wrap_ct: Vec<u8>,
    /// 16-byte nonce (`n=`), bound into the signed payload.
    pub nonce: [u8; 16],
    /// 64-byte signature (`s=`); expected to cover `signed_keyrekey_payload(..)` (verified elsewhere).
    pub sig: [u8; 64],
}
/// Builds the bytes a `KEYRSP` signature covers:
/// `KEYRSP:` + channel + `:` + pubkey + `:` + eph_pub + `:` + wrap_nonce + `:` + wrap_ct + `:` + nonce.
///
/// NOTE(review): `channel` and `wrap_ct` are both variable-length and only
/// ':'-delimited, so the encoding is unambiguous only while `channel` contains
/// no ':'. IRC channel names normally cannot, but nothing here enforces it.
fn sig_payload_keyrsp(
    channel: &str,
    pubkey: &[u8; 32],
    eph_pub: &[u8; 32],
    wrap_nonce: &[u8; NONCE_LEN],
    wrap_ct: &[u8],
    nonce: &[u8; 16],
) -> Vec<u8> {
    let parts: [&[u8]; 6] = [channel.as_bytes(), pubkey, eph_pub, wrap_nonce, wrap_ct, nonce];
    let capacity = 16 + channel.len() + 32 + 32 + NONCE_LEN + wrap_ct.len() + 16;
    let mut payload = Vec::with_capacity(capacity);
    payload.extend_from_slice(b"KEYRSP:");
    for part in parts {
        payload.extend_from_slice(part);
        payload.push(b':');
    }
    // Drop the separator appended after the final field.
    payload.pop();
    payload
}
fn sig_payload_keyrekey(
channel: &str,
pubkey: &[u8; 32],
eph_pub: &[u8; 32],
wrap_nonce: &[u8; NONCE_LEN],
wrap_ct: &[u8],
nonce: &[u8; 16],
) -> Vec<u8> {
let mut v = Vec::with_capacity(8 + channel.len() + 32 + 32 + NONCE_LEN + wrap_ct.len() + 16);
v.extend_from_slice(b"REKEY:");
v.extend_from_slice(channel.as_bytes());
v.push(b':');
v.extend_from_slice(pubkey);
v.push(b':');
v.extend_from_slice(eph_pub);
v.push(b':');
v.extend_from_slice(wrap_nonce);
v.push(b':');
v.extend_from_slice(wrap_ct);
v.push(b':');
v.extend_from_slice(nonce);
v
}
/// Serializes a [`KeyReq`] into its single-line CTCP body (without `\x01` framing).
///
/// Binary fields are URL-safe base64 without padding; the channel is written verbatim.
#[must_use]
pub fn encode_keyreq(req: &KeyReq) -> String {
    let pubkey = b64_encode(req.pubkey);
    let eph = b64_encode(req.eph_x25519);
    let nonce = b64_encode(req.nonce);
    let sig = b64_encode(req.sig);
    format!(
        "{CTCP_TAG} KEYREQ v={PROTO_VERSION} c={chan} p={pubkey} e={eph} n={nonce} s={sig}",
        chan = req.channel,
    )
}
/// Serializes a [`KeyRsp`] into its single-line CTCP body (without `\x01` framing).
///
/// Binary fields are URL-safe base64 without padding; the channel is written verbatim.
#[must_use]
pub fn encode_keyrsp(rsp: &KeyRsp) -> String {
    let pub_b64 = b64_encode(rsp.pubkey);
    let eph_b64 = b64_encode(rsp.ephemeral_pub);
    let wnonce_b64 = b64_encode(rsp.wrap_nonce);
    let wrap_b64 = B64.encode(&rsp.wrap_ct);
    let nonce_b64 = b64_encode(rsp.nonce);
    let sig_b64 = b64_encode(rsp.sig);
    format!(
        "{CTCP_TAG} KEYRSP v={PROTO_VERSION} c={} p={pub_b64} e={eph_b64} wn={wnonce_b64} w={wrap_b64} n={nonce_b64} s={sig_b64}",
        rsp.channel,
    )
}
#[must_use]
pub fn encode_keyrekey(rk: &KeyRekey) -> String {
format!(
"{CTCP_TAG} REKEY v={PROTO_VERSION} c={chan} p={pub_} e={eph} wn={wnonce} w={wrap} n={nonce} s={sig}",
chan = rk.channel,
pub_ = b64_encode(rk.pubkey),
eph = b64_encode(rk.eph_pub),
wnonce = b64_encode(rk.wrap_nonce),
wrap = B64.encode(&rk.wrap_ct),
nonce = b64_encode(rk.nonce),
sig = b64_encode(rk.sig),
)
}
/// A decoded handshake message, one variant per message type `parse` understands.
#[derive(Debug)]
pub enum HandshakeMsg {
    /// `KEYREQ`
    Req(KeyReq),
    /// `KEYRSP`
    Rsp(KeyRsp),
    /// `REKEY`
    Rekey(KeyRekey),
}
/// Parses a CTCP body (without `\x01` framing) into a handshake message.
///
/// Returns `Ok(None)` when the first whitespace-separated token is not
/// `CTCP_TAG`, so callers can tell non-handshake traffic apart from a
/// malformed handshake.
///
/// # Errors
///
/// * `E2eError::Wire` if any `key=value` field is repeated.
/// * `E2eError::Handshake` if the type or `v=` field is missing, `v` is not a
///   `u8` or not `PROTO_VERSION`, the type is unknown, or a required field is
///   missing or malformed.
pub fn parse(body: &str) -> Result<Option<HandshakeMsg>> {
    let mut tokens = body.split_whitespace();
    match tokens.next() {
        Some(tag) if tag == CTCP_TAG => {}
        _ => return Ok(None),
    }
    let Some(kind) = tokens.next() else {
        return Err(E2eError::Handshake("missing type".into()));
    };
    let fields: Vec<&str> = tokens.collect();
    let kv = parse_kv(&fields)?;
    let raw_version = kv
        .get("v")
        .ok_or_else(|| E2eError::Handshake("missing v".into()))?;
    let version: u8 = raw_version
        .parse()
        .map_err(|err| E2eError::Handshake(format!("bad v: {err}")))?;
    if version != PROTO_VERSION {
        return Err(E2eError::Handshake(format!("unsupported version {version}")));
    }
    let msg = match kind {
        "KEYREQ" => HandshakeMsg::Req(parse_keyreq(&kv)?),
        "KEYRSP" => HandshakeMsg::Rsp(parse_keyrsp(&kv)?),
        "REKEY" => HandshakeMsg::Rekey(parse_keyrekey(&kv)?),
        other => return Err(E2eError::Handshake(format!("unknown type {other}"))),
    };
    Ok(Some(msg))
}
/// Looks up a required field in a parsed handshake key/value map.
///
/// # Errors
///
/// Returns `E2eError::Handshake("missing <key>")` when `key` is absent. The
/// message previously contained only the bare key (e.g. `"c"`), which was
/// unreadable in logs and inconsistent with the "missing v" / "missing type"
/// wording `parse` uses for its own required fields.
fn kv_get<'a>(kv: &'a HashMap<&'a str, &'a str>, key: &'static str) -> Result<&'a str> {
    kv.get(key)
        .copied()
        .ok_or_else(|| E2eError::Handshake(format!("missing {key}")))
}
/// Decodes the `wn=` field into an AEAD nonce of exactly `NONCE_LEN` bytes.
///
/// # Errors
///
/// `E2eError::Handshake` if the field is missing, is not valid base64, or has
/// the wrong decoded length.
fn parse_wrap_nonce(kv: &HashMap<&str, &str>) -> Result<[u8; NONCE_LEN]> {
    let decoded = b64_decode(kv_get(kv, "wn")?)?;
    <[u8; NONCE_LEN]>::try_from(decoded.as_slice()).map_err(|_| {
        E2eError::Handshake(format!("wn len {} != {NONCE_LEN}", decoded.len()))
    })
}
/// Decodes the variable-length `w=` (wrapped key ciphertext) field.
///
/// # Errors
///
/// `E2eError::Handshake` if the field is missing or is not valid base64.
fn parse_wrap_ct(kv: &HashMap<&str, &str>) -> Result<Vec<u8>> {
    let encoded = kv_get(kv, "w")?;
    B64.decode(encoded)
        .map_err(|err| E2eError::Handshake(format!("bad wrap b64: {err}")))
}
/// Builds a [`KeyReq`] from parsed fields, checking them in wire order
/// (`c`, `p`, `e`, `n`, `s`) and returning the first error encountered.
fn parse_keyreq(kv: &HashMap<&str, &str>) -> Result<KeyReq> {
    let channel = kv_get(kv, "c")?.to_owned();
    let pubkey = b64_32(kv_get(kv, "p")?)?;
    let eph_x25519 = b64_32(kv_get(kv, "e")?)?;
    let nonce = b64_16(kv_get(kv, "n")?)?;
    let sig = b64_64(kv_get(kv, "s")?)?;
    Ok(KeyReq {
        channel,
        pubkey,
        eph_x25519,
        nonce,
        sig,
    })
}
/// Builds a [`KeyRsp`] from parsed fields, checking them in wire order
/// (`c`, `p`, `e`, `wn`, `w`, `n`, `s`) and returning the first error encountered.
fn parse_keyrsp(kv: &HashMap<&str, &str>) -> Result<KeyRsp> {
    let channel = kv_get(kv, "c")?.to_owned();
    let pubkey = b64_32(kv_get(kv, "p")?)?;
    let ephemeral_pub = b64_32(kv_get(kv, "e")?)?;
    let wrap_nonce = parse_wrap_nonce(kv)?;
    let wrap_ct = parse_wrap_ct(kv)?;
    let nonce = b64_16(kv_get(kv, "n")?)?;
    let sig = b64_64(kv_get(kv, "s")?)?;
    Ok(KeyRsp {
        channel,
        pubkey,
        ephemeral_pub,
        wrap_nonce,
        wrap_ct,
        nonce,
        sig,
    })
}
/// Builds a [`KeyRekey`] from parsed fields, checking them in wire order
/// (`c`, `p`, `e`, `wn`, `w`, `n`, `s`) and returning the first error encountered.
fn parse_keyrekey(kv: &HashMap<&str, &str>) -> Result<KeyRekey> {
    let channel = kv_get(kv, "c")?.to_owned();
    let pubkey = b64_32(kv_get(kv, "p")?)?;
    let eph_pub = b64_32(kv_get(kv, "e")?)?;
    let wrap_nonce = parse_wrap_nonce(kv)?;
    let wrap_ct = parse_wrap_ct(kv)?;
    let nonce = b64_16(kv_get(kv, "n")?)?;
    let sig = b64_64(kv_get(kv, "s")?)?;
    Ok(KeyRekey {
        channel,
        pubkey,
        eph_pub,
        wrap_nonce,
        wrap_ct,
        nonce,
        sig,
    })
}
/// Splits `key=value` tokens into a map, splitting each token at its first `=`.
///
/// Tokens without any `=` are skipped silently.
///
/// # Errors
///
/// `E2eError::Wire("duplicate key: <k>")` if the same key appears twice.
fn parse_kv<'a>(fields: &'a [&'a str]) -> Result<HashMap<&'a str, &'a str>> {
    let mut map = HashMap::with_capacity(fields.len());
    for (key, value) in fields.iter().copied().filter_map(|field| field.split_once('=')) {
        if map.insert(key, value).is_some() {
            return Err(E2eError::Wire(format!("duplicate key: {key}")));
        }
    }
    Ok(map)
}
/// Encodes a fixed-size byte array as URL-safe base64 without padding.
fn b64_encode<const N: usize>(bytes: [u8; N]) -> String {
    B64.encode(bytes.as_slice())
}
/// Decodes URL-safe, unpadded base64.
///
/// # Errors
///
/// `E2eError::Handshake("bad b64: …")` if `s` is not valid base64.
fn b64_decode(s: &str) -> Result<Vec<u8>> {
    let decoded = B64
        .decode(s)
        .map_err(|err| E2eError::Handshake(format!("bad b64: {err}")))?;
    Ok(decoded)
}
/// Decodes base64 that must yield exactly `N` bytes.
///
/// # Errors
///
/// `E2eError::Handshake` if `s` is not valid base64 or decodes to a length other than `N`.
fn b64_fixed<const N: usize>(s: &str) -> Result<[u8; N]> {
    let decoded = b64_decode(s)?;
    <[u8; N]>::try_from(decoded.as_slice()).map_err(|_| {
        E2eError::Handshake(format!("expected {N} bytes, got {}", decoded.len()))
    })
}
/// Decodes a 32-byte base64 field (public keys).
fn b64_32(s: &str) -> Result<[u8; 32]> {
    b64_fixed::<32>(s)
}
/// Decodes a 16-byte base64 field (message nonces).
fn b64_16(s: &str) -> Result<[u8; 16]> {
    b64_fixed::<16>(s)
}
/// Decodes a 64-byte base64 field (signatures).
fn b64_64(s: &str) -> Result<[u8; 64]> {
    b64_fixed::<64>(s)
}
/// Incoming-handshake throttling state for a single peer.
#[derive(Debug, Default)]
struct IncomingBucket {
    // Timestamps of accepted requests, oldest at the front. Entries older than
    // INCOMING_WINDOW are dropped lazily by `RateLimiter::allow_incoming`.
    recent: VecDeque<Instant>,
    // While set and still in the future, every request from this peer is refused.
    backoff_until: Option<Instant>,
}
/// Per-peer throttling for handshake traffic.
///
/// Outgoing and incoming directions are tracked independently, keyed by the
/// peer-handle string passed to `allow_outgoing` / `allow_incoming` (the tests
/// use `~user@host`-style handles).
#[derive(Debug, Default)]
pub struct RateLimiter {
    // Time of the last permitted outgoing handshake, per peer handle.
    last_sent: HashMap<String, Instant>,
    // Sliding-window / backoff state for incoming handshakes, per peer handle.
    incoming: HashMap<String, IncomingBucket>,
}
impl RateLimiter {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn allow_outgoing(&mut self, peer_handle: &str) -> bool {
let now = Instant::now();
if let Some(ts) = self.last_sent.get(peer_handle)
&& now.duration_since(*ts) < RATE_LIMIT_WINDOW
{
return false;
}
self.last_sent.insert(peer_handle.to_string(), now);
true
}
pub fn allow_incoming(&mut self, peer_handle: &str) -> bool {
let now = Instant::now();
let bucket = self.incoming.entry(peer_handle.to_string()).or_default();
if let Some(until) = bucket.backoff_until {
if now < until {
return false;
}
bucket.backoff_until = None;
bucket.recent.clear();
}
while let Some(front) = bucket.recent.front() {
if now.duration_since(*front) > INCOMING_WINDOW {
bucket.recent.pop_front();
} else {
break;
}
}
if bucket.recent.len() >= INCOMING_MAX_PER_WINDOW {
bucket.backoff_until = Some(now + INCOMING_BACKOFF);
return false;
}
bucket.recent.push_back(now);
true
}
#[cfg(test)]
fn force_expire_backoff(&mut self, peer_handle: &str) {
if let Some(bucket) = self.incoming.get_mut(peer_handle) {
bucket.backoff_until = None;
bucket.recent.clear();
}
}
}
/// Returns the byte string a `KEYREQ` signature is expected to cover.
///
/// Layout: `KEYREQ:` + channel + `:` + pubkey + `:` + eph_x25519 + `:` + nonce.
/// Signers and verifiers must both build the payload through this function so
/// the bytes match exactly.
#[must_use]
pub fn signed_keyreq_payload(
    channel: &str,
    pubkey: &[u8; 32],
    eph_x25519: &[u8; 32],
    nonce: &[u8; 16],
) -> Vec<u8> {
    sig_payload_keyreq(channel, pubkey, eph_x25519, nonce)
}
/// Returns the byte string a `KEYRSP` signature is expected to cover.
///
/// Layout: `KEYRSP:` + channel + `:` + pubkey + `:` + eph_pub + `:` + wrap_nonce
/// + `:` + wrap_ct + `:` + nonce. The ephemeral key and the wrapped ciphertext are
/// both bound, so neither can be swapped without invalidating the signature.
#[must_use]
pub fn signed_keyrsp_payload(
    channel: &str,
    pubkey: &[u8; 32],
    eph_pub: &[u8; 32],
    wrap_nonce: &[u8; NONCE_LEN],
    wrap_ct: &[u8],
    nonce: &[u8; 16],
) -> Vec<u8> {
    sig_payload_keyrsp(channel, pubkey, eph_pub, wrap_nonce, wrap_ct, nonce)
}
/// Returns the byte string a `REKEY` signature is expected to cover.
///
/// Layout matches `signed_keyrsp_payload` except for the `REKEY:` prefix, which
/// keeps REKEY and KEYRSP signatures from being interchangeable.
#[must_use]
pub fn signed_keyrekey_payload(
    channel: &str,
    pubkey: &[u8; 32],
    eph_pub: &[u8; 32],
    wrap_nonce: &[u8; NONCE_LEN],
    wrap_ct: &[u8],
    nonce: &[u8; 16],
) -> Vec<u8> {
    sig_payload_keyrekey(channel, pubkey, eph_pub, wrap_nonce, wrap_ct, nonce)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Sample messages use a distinct fill byte per field so that a field mix-up
    // in encode/parse shows up as an assertion failure.
    fn sample_req() -> KeyReq {
        KeyReq {
            channel: "#x".into(),
            pubkey: [1; 32],
            eph_x25519: [9; 32],
            nonce: [2; 16],
            sig: [3; 64],
        }
    }

    fn sample_rsp() -> KeyRsp {
        KeyRsp {
            channel: "#x".into(),
            pubkey: [12; 32],
            ephemeral_pub: [4; 32],
            wrap_nonce: [5; NONCE_LEN],
            wrap_ct: vec![6, 7, 8, 9],
            nonce: [10; 16],
            sig: [11; 64],
        }
    }

    fn sample_rekey() -> KeyRekey {
        KeyRekey {
            channel: "#x".into(),
            pubkey: [13; 32],
            eph_pub: [14; 32],
            wrap_nonce: [15; NONCE_LEN],
            wrap_ct: vec![16, 17, 18, 19, 20],
            nonce: [21; 16],
            sig: [22; 64],
        }
    }

    // Same prefix as the original KEYRSP line-limit test: a long IPv6 host mask
    // on a NOTICE, i.e. close to the worst case a server relays.
    const LONG_PREFIX: &str =
        ":nick!^prostatut@2a14:7584:44e4:7af6:c219:38d4:e5b7:1c63 NOTICE kofany_ :";

    #[test]
    fn keyreq_roundtrip() {
        let req = sample_req();
        let enc = encode_keyreq(&req);
        let parsed = parse(&enc).unwrap().unwrap();
        match parsed {
            HandshakeMsg::Req(r) => {
                assert_eq!(r.channel, req.channel);
                assert_eq!(r.pubkey, req.pubkey);
                assert_eq!(r.eph_x25519, req.eph_x25519);
                assert_eq!(r.nonce, req.nonce);
                assert_eq!(r.sig, req.sig);
            }
            HandshakeMsg::Rsp(_) | HandshakeMsg::Rekey(_) => panic!("expected Req"),
        }
    }

    #[test]
    fn keyrsp_roundtrip() {
        let rsp = sample_rsp();
        let enc = encode_keyrsp(&rsp);
        let parsed = parse(&enc).unwrap().unwrap();
        match parsed {
            HandshakeMsg::Rsp(r) => {
                assert_eq!(r.channel, rsp.channel);
                assert_eq!(r.pubkey, rsp.pubkey);
                assert_eq!(r.ephemeral_pub, rsp.ephemeral_pub);
                assert_eq!(r.wrap_nonce, rsp.wrap_nonce);
                assert_eq!(r.wrap_ct, rsp.wrap_ct);
                assert_eq!(r.nonce, rsp.nonce);
                assert_eq!(r.sig, rsp.sig);
            }
            HandshakeMsg::Req(_) | HandshakeMsg::Rekey(_) => panic!("expected Rsp"),
        }
    }

    #[test]
    fn keyrekey_roundtrip() {
        let rk = sample_rekey();
        let enc = encode_keyrekey(&rk);
        let parsed = parse(&enc).unwrap().unwrap();
        match parsed {
            HandshakeMsg::Rekey(r) => {
                assert_eq!(r.channel, rk.channel);
                assert_eq!(r.pubkey, rk.pubkey);
                assert_eq!(r.eph_pub, rk.eph_pub);
                assert_eq!(r.wrap_nonce, rk.wrap_nonce);
                assert_eq!(r.wrap_ct, rk.wrap_ct);
                assert_eq!(r.nonce, rk.nonce);
                assert_eq!(r.sig, rk.sig);
            }
            HandshakeMsg::Req(_) | HandshakeMsg::Rsp(_) => panic!("expected Rekey"),
        }
    }

    #[test]
    fn keyrekey_sig_payload_binds_eph_and_ct() {
        let p1 =
            signed_keyrekey_payload("#x", &[1; 32], &[2; 32], &[3; NONCE_LEN], &[4, 5], &[6; 16]);
        let p2 =
            signed_keyrekey_payload("#x", &[1; 32], &[9; 32], &[3; NONCE_LEN], &[4, 5], &[6; 16]);
        let p3 =
            signed_keyrekey_payload("#x", &[1; 32], &[2; 32], &[3; NONCE_LEN], &[4, 6], &[6; 16]);
        assert_ne!(p1, p2);
        assert_ne!(p1, p3);
    }

    #[test]
    fn keyrsp_sig_payload_binds_eph_and_ct() {
        let p1 =
            signed_keyrsp_payload("#x", &[1; 32], &[2; 32], &[3; NONCE_LEN], &[4, 5], &[6; 16]);
        let p2 =
            signed_keyrsp_payload("#x", &[1; 32], &[9; 32], &[3; NONCE_LEN], &[4, 5], &[6; 16]);
        let p3 =
            signed_keyrsp_payload("#x", &[1; 32], &[2; 32], &[3; NONCE_LEN], &[4, 6], &[6; 16]);
        assert_ne!(p1, p2);
        assert_ne!(p1, p3);
    }

    #[test]
    fn keyrsp_and_rekey_sig_payloads_are_domain_separated() {
        // Identical field values must still produce different signed bytes, so a
        // KEYRSP signature can never be replayed as a REKEY (or vice versa).
        let rsp =
            signed_keyrsp_payload("#x", &[1; 32], &[2; 32], &[3; NONCE_LEN], &[4, 5], &[6; 16]);
        let rk =
            signed_keyrekey_payload("#x", &[1; 32], &[2; 32], &[3; NONCE_LEN], &[4, 5], &[6; 16]);
        assert_ne!(rsp, rk);
    }

    #[test]
    fn parse_non_rpee2e_returns_none() {
        assert!(parse("SOMETHING ELSE").unwrap().is_none());
        assert!(parse("").unwrap().is_none());
    }

    #[test]
    fn parse_kv_rejects_duplicate_key() {
        let line = format!(
            "{CTCP_TAG} KEYREQ v=1 c=#a c=#b p={p} e={e} n={n} s={s}",
            p = b64_encode([0u8; 32]),
            e = b64_encode([0u8; 32]),
            n = b64_encode([0u8; 16]),
            s = b64_encode([0u8; 64]),
        );
        match parse(&line) {
            Err(E2eError::Wire(msg)) => {
                // Match the key name together with the prefix: a bare `'c'` check
                // passes vacuously because "duplicate" itself contains a 'c'.
                assert!(
                    msg.contains("duplicate key: c"),
                    "expected 'duplicate key: c' in error, got: {msg}"
                );
            }
            other => panic!("expected Err(Wire(duplicate key: c)), got {other:?}"),
        }
    }

    #[test]
    fn parse_kv_rejects_duplicate_key_in_keyrsp() {
        let line = format!(
            "{CTCP_TAG} KEYRSP v=1 c=#x p={p} e={e} wn={wn} w={w} w={w2} n={n} s={s}",
            p = b64_encode([0u8; 32]),
            e = b64_encode([0u8; 32]),
            wn = b64_encode([0u8; NONCE_LEN]),
            w = B64.encode([0u8; 4]),
            w2 = B64.encode([1u8; 4]),
            n = b64_encode([0u8; 16]),
            s = b64_encode([0u8; 64]),
        );
        match parse(&line) {
            Err(E2eError::Wire(msg)) => assert!(
                msg.contains("duplicate key: w"),
                "expected 'duplicate key: w' in error, got: {msg}"
            ),
            other => panic!("expected Err(Wire(duplicate key: w)), got {other:?}"),
        }
    }

    #[test]
    fn parse_rejects_unknown_version() {
        let line = format!(
            "{CTCP_TAG} KEYREQ v=9 c=#x p={p} e={e} n={n} s={s}",
            p = b64_encode([0u8; 32]),
            e = b64_encode([0u8; 32]),
            n = b64_encode([0u8; 16]),
            s = b64_encode([0u8; 64]),
        );
        assert!(parse(&line).is_err());
    }

    #[test]
    fn parse_rejects_unknown_type() {
        let line = format!("{CTCP_TAG} BOGUS v=1");
        assert!(matches!(parse(&line), Err(E2eError::Handshake(_))));
    }

    #[test]
    fn parse_rejects_missing_field() {
        // KEYREQ without the `s=` signature field.
        let line = format!(
            "{CTCP_TAG} KEYREQ v=1 c=#x p={p} e={e} n={n}",
            p = b64_encode([0u8; 32]),
            e = b64_encode([0u8; 32]),
            n = b64_encode([0u8; 16]),
        );
        assert!(matches!(parse(&line), Err(E2eError::Handshake(_))));
    }

    #[test]
    fn parse_rejects_wrong_length_field() {
        // 31-byte public key where exactly 32 bytes are required.
        let line = format!(
            "{CTCP_TAG} KEYREQ v=1 c=#x p={p} e={e} n={n} s={s}",
            p = b64_encode([0u8; 31]),
            e = b64_encode([0u8; 32]),
            n = b64_encode([0u8; 16]),
            s = b64_encode([0u8; 64]),
        );
        assert!(matches!(parse(&line), Err(E2eError::Handshake(_))));
    }

    #[test]
    fn keyrsp_fits_under_irc_line_limit_with_long_prefix() {
        let rsp = KeyRsp {
            channel: "#irc.al".into(),
            pubkey: [12; 32],
            ephemeral_pub: [4; 32],
            wrap_nonce: [5; NONCE_LEN],
            wrap_ct: vec![6; 48],
            nonce: [10; 16],
            sig: [11; 64],
        };
        let body = format!("\x01{}\x01", encode_keyrsp(&rsp));
        let line_len = format!("{LONG_PREFIX}{body}\r\n").len();
        assert!(line_len <= 512, "KEYRSP line too long: {line_len} bytes");
    }

    #[test]
    fn keyrekey_fits_under_irc_line_limit_with_long_prefix() {
        let rk = KeyRekey {
            channel: "#irc.al".into(),
            pubkey: [13; 32],
            eph_pub: [14; 32],
            wrap_nonce: [15; NONCE_LEN],
            wrap_ct: vec![16; 48],
            nonce: [21; 16],
            sig: [22; 64],
        };
        let body = format!("\x01{}\x01", encode_keyrekey(&rk));
        let line_len = format!("{LONG_PREFIX}{body}\r\n").len();
        assert!(line_len <= 512, "REKEY line too long: {line_len} bytes");
    }

    #[test]
    fn rate_limiter_blocks_within_window() {
        let mut rl = RateLimiter::new();
        assert!(rl.allow_outgoing("~bob@host"));
        assert!(!rl.allow_outgoing("~bob@host"));
        assert!(rl.allow_outgoing("~alice@host"));
    }

    #[test]
    fn allow_incoming_permits_first_three_then_backoffs() {
        let mut rl = RateLimiter::new();
        assert!(rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~bob@host"));
        assert!(!rl.allow_incoming("~bob@host"));
        assert!(!rl.allow_incoming("~bob@host"));
        assert!(!rl.allow_incoming("~bob@host"));
        let bucket = rl.incoming.get("~bob@host").expect("bucket present");
        assert!(bucket.backoff_until.is_some());
    }

    #[test]
    fn allow_incoming_backoff_expires_after_window() {
        let mut rl = RateLimiter::new();
        assert!(rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~bob@host"));
        assert!(!rl.allow_incoming("~bob@host"));
        rl.force_expire_backoff("~bob@host");
        assert!(rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~bob@host"));
        assert!(!rl.allow_incoming("~bob@host"));
    }

    #[test]
    fn allow_incoming_independent_per_peer() {
        let mut rl = RateLimiter::new();
        assert!(rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~bob@host"));
        assert!(!rl.allow_incoming("~bob@host"));
        assert!(rl.allow_incoming("~alice@host"));
        assert!(rl.allow_incoming("~alice@host"));
        assert!(rl.allow_incoming("~alice@host"));
        assert!(!rl.allow_incoming("~alice@host"));
        assert!(!rl.allow_incoming("~bob@host"));
    }

    #[test]
    fn keyreq_sig_payload_binds_eph_x25519() {
        let p1 = signed_keyreq_payload("#x", &[1; 32], &[9; 32], &[2; 16]);
        let p2 = signed_keyreq_payload("#x", &[1; 32], &[8; 32], &[2; 16]);
        assert_ne!(p1, p2);
    }
}