use crate::engine::Engine;
use crate::error::{FlowError, Result};
use crate::record::Record;
use crate::stats::StatsCounters;
use std::collections::HashMap;
use std::hash::Hasher;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use tokio::net::UdpSocket;
/// First byte of every frame (0x54, ASCII `'T'`); `decode_frame` rejects anything else.
const MAGIC: u8 = b'T';
/// Frame version without an auth tag.
const VERSION_V1: u8 = 1;
/// Frame version carrying an `AUTH_HASH_BYTES`-byte auth tag after the record count.
const VERSION_V2: u8 = 2;
/// Wire TTL value meaning "never expires" (decoded as `expire_at = i64::MAX`).
const TTL_NONE: u32 = 0;
/// Longest key, in bytes, accepted by the decoder.
const MAX_KEY_BYTES: usize = 4_096;
/// Longest value, in bytes, accepted by the decoder (64 KiB). Note the wire length field is a
/// u16, so it can never actually exceed this.
const MAX_VAL_BYTES: usize = 65_536;
/// Most records a single frame may declare; the decoder rejects larger counts.
const MAX_FRAME_RECORDS: usize = 1_024;
/// Length of the V2 auth tag in bytes.
const AUTH_HASH_BYTES: usize = 8;
/// Computes the auth tag carried by V2 frames.
///
/// The tag is `DefaultHasher` fed, in order, with two fixed 64-bit seeds, the raw bytes of
/// `api_key`, the big-endian record `count`, and `records_payload`. The 64-bit digest is
/// returned big-endian. Byte fields go through `Hasher::write`, which adds no length prefix, so
/// the fields are simply concatenated.
///
/// NOTE(review): the std docs state `DefaultHasher`'s algorithm is unspecified and its hashes
/// should not be relied upon across Rust releases. Client and server built with different
/// toolchains could therefore disagree on tags. It is also not a cryptographic MAC, and the
/// seeds are public constants. Replacing it changes the wire format, so it is left as is.
fn compute_auth_tag(api_key: &str, count: u16, records_payload: &[u8]) -> [u8; AUTH_HASH_BYTES] {
    // The seed bytes spell ASCII "UPD_AUTH" and "FLOWDBKE"; they are mixed in before the secret.
    const SEEDS: [u64; 2] = [0x5550_445f_4155_5448, 0x464c_4f57_4442_4b45];

    let mut hasher = std::hash::DefaultHasher::new();
    SEEDS.iter().for_each(|&seed| hasher.write_u64(seed));

    let count_bytes = count.to_be_bytes();
    let fields: [&[u8]; 3] = [api_key.as_bytes(), &count_bytes, records_payload];
    for field in fields {
        hasher.write(field);
    }

    hasher.finish().to_be_bytes()
}
/// Reads a big-endian `u16` at byte offset `pos` of `data`.
///
/// Returns the value and the offset just past it. Returns `None` if fewer than 2 bytes remain,
/// including when `pos` is beyond the end of `data` or so large that `pos + 2` would overflow.
fn read_u16(data: &[u8], pos: usize) -> Option<(u16, usize)> {
    // checked_add: a plain `pos + 2` overflows (and panics in debug builds) for huge `pos`.
    let end = pos.checked_add(2)?;
    let bytes: [u8; 2] = data.get(pos..end)?.try_into().ok()?;
    Some((u16::from_be_bytes(bytes), end))
}
/// Reads a big-endian `u32` at byte offset `pos` of `data`.
///
/// Returns the value and the offset just past it. Returns `None` if fewer than 4 bytes remain,
/// including when `pos` is beyond the end of `data` or so large that `pos + 4` would overflow.
fn read_u32(data: &[u8], pos: usize) -> Option<(u32, usize)> {
    // checked_add: a plain `pos + 4` overflows (and panics in debug builds) for huge `pos`.
    let end = pos.checked_add(4)?;
    let bytes: [u8; 4] = data.get(pos..end)?.try_into().ok()?;
    Some((u32::from_be_bytes(bytes), end))
}
/// Reads a big-endian `i64` at byte offset `pos` of `data`.
///
/// Returns the value and the offset just past it. Returns `None` if fewer than 8 bytes remain,
/// including when `pos` is beyond the end of `data` or so large that `pos + 8` would overflow.
fn read_i64(data: &[u8], pos: usize) -> Option<(i64, usize)> {
    // checked_add: a plain `pos + 8` overflows (and panics in debug builds) for huge `pos`.
    let end = pos.checked_add(8)?;
    let bytes: [u8; 8] = data.get(pos..end)?.try_into().ok()?;
    Some((i64::from_be_bytes(bytes), end))
}
/// Decodes one record starting at byte offset `pos` of `data`.
///
/// Record layout (integers big-endian):
/// `key_len: u16 | key: key_len bytes | ts: i64 | ttl: u32 | val_len: u16 | value: val_len bytes`
///
/// Returns the record and the offset just past it. Returns `None` if `data` ends early, if the
/// key is longer than `MAX_KEY_BYTES`, or if the value is longer than `MAX_VAL_BYTES`.
///
/// A `ttl` of `TTL_NONE` yields `expire_at = i64::MAX`. Any other `ttl` yields
/// `ts + ttl * 1_000_000`; presumably `ttl` is in seconds and `ts` in microseconds (inferred
/// from the factor; confirm against `Record`). `ts` is taken from the network, so the addition
/// saturates at `i64::MAX` rather than overflowing. An overflow would panic in debug builds or
/// wrap to a negative expiry in release builds.
fn read_record(data: &[u8], pos: usize) -> Option<(Record, usize)> {
    let (key_len, pos) = read_u16(data, pos)?;
    let key_len = usize::from(key_len);
    if key_len > MAX_KEY_BYTES {
        return None;
    }
    let key_end = pos.checked_add(key_len)?;
    let key = data.get(pos..key_end)?.to_vec();

    let (ts, pos) = read_i64(data, key_end)?;
    let (ttl, pos) = read_u32(data, pos)?;

    let (val_len, pos) = read_u16(data, pos)?;
    let val_len = usize::from(val_len);
    // A u16 length is at most 65_535, below MAX_VAL_BYTES (64 * 1024), so this cannot fire
    // today. It is kept so the limit still holds if the field width or the limit changes.
    if val_len > MAX_VAL_BYTES {
        return None;
    }
    let val_end = pos.checked_add(val_len)?;
    let value = data.get(pos..val_end)?.to_vec();

    let expire_at = if ttl == TTL_NONE {
        i64::MAX
    } else {
        // u32::MAX * 1_000_000 (~4.3e15) fits in i64; only the addition can overflow.
        ts.saturating_add(i64::from(ttl) * 1_000_000)
    };

    Some((
        Record {
            key,
            ts,
            expire_at,
            value,
        },
        val_end,
    ))
}
/// Decodes one UDP frame into its records, authenticating it when the server has an API key.
///
/// Frame layout (integers big-endian):
/// `MAGIC: u8 | version: u8 | count: u16 | auth tag: AUTH_HASH_BYTES bytes (V2 only) | records`
/// where each of the `count` records is parsed by `read_record`.
///
/// Accepted combinations of frame version and server `api_key`:
/// - `VERSION_V1` with `api_key == None`: accepted without authentication.
/// - `VERSION_V2` with `api_key == Some(key)`: accepted only if the tag equals
///   `compute_auth_tag(key, count, rest)`, where `rest` is every byte after the tag.
///
/// Every other combination, including unknown versions, is rejected. Bytes after the last
/// declared record are ignored; for V2 frames they are still covered by the tag.
///
/// NOTE(review): `compute_auth_tag` is built on `DefaultHasher`, which is not a cryptographic
/// MAC.
///
/// # Errors
/// Returns `FlowError::Other` in these cases:
/// - the frame is shorter than its 4-byte header or has the wrong magic byte;
/// - it declares more than `MAX_FRAME_RECORDS` records;
/// - it fails the version/auth rules above;
/// - a declared record is truncated or exceeds the per-record size limits.
pub fn decode_frame(data: &[u8], api_key: Option<&str>) -> Result<Vec<Record>> {
    if data.len() < 4 {
        return Err(FlowError::Other("frame too short".into()));
    }
    if data[0] != MAGIC {
        return Err(FlowError::Other(format!("invalid magic: {:#x}", data[0])));
    }
    let version = data[1];
    // Cannot fail after the length check above. It is mapped to an error rather than unwrapped
    // so this parser of untrusted input has no panic path.
    let (raw_count, mut pos) =
        read_u16(data, 2).ok_or_else(|| FlowError::Other("frame too short".into()))?;
    let count = usize::from(raw_count);
    // Checked before authentication, so oversized counts are rejected without hashing.
    if count > MAX_FRAME_RECORDS {
        return Err(FlowError::Other(format!(
            "frame record count too large: {} (max {})",
            count, MAX_FRAME_RECORDS
        )));
    }

    match (version, api_key) {
        (VERSION_V1, None) => {}
        (VERSION_V1, Some(_)) => {
            return Err(FlowError::Other(
                "authentication required; upgrade client to V2 protocol".into(),
            ));
        }
        (VERSION_V2, Some(key)) => {
            let tag_end = pos + AUTH_HASH_BYTES;
            let received_tag = data
                .get(pos..tag_end)
                .and_then(|bytes| <&[u8; AUTH_HASH_BYTES]>::try_from(bytes).ok())
                .ok_or_else(|| FlowError::Other("frame too short for v2 auth tag".into()))?;
            let expected_tag = compute_auth_tag(key, raw_count, &data[tag_end..]);
            if !tags_equal(received_tag, &expected_tag) {
                return Err(FlowError::Other("authentication failed: invalid key hash".into()));
            }
            pos = tag_end;
        }
        (VERSION_V2, None) => {
            return Err(FlowError::Other(
                "v2 frame received but server has no api_key configured".into(),
            ));
        }
        _ => {
            return Err(FlowError::Other(format!(
                "unsupported version: {}",
                version
            )));
        }
    }

    let mut records = Vec::with_capacity(count);
    for _ in 0..count {
        let (rec, next) =
            read_record(data, pos).ok_or_else(|| FlowError::Other("truncated record".into()))?;
        pos = next;
        records.push(rec);
    }
    Ok(records)
}

/// Compares two auth tags without stopping at the first differing byte.
///
/// `==` on arrays may short-circuit, which can reveal through response timing how many leading
/// bytes of a forged tag are correct. `black_box` is a best-effort hint that discourages the
/// optimizer from reintroducing an early exit; it is not a hard constant-time guarantee.
fn tags_equal(a: &[u8; AUTH_HASH_BYTES], b: &[u8; AUTH_HASH_BYTES]) -> bool {
    let diff = a.iter().zip(b.iter()).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    std::hint::black_box(diff) == 0
}
/// Encodes `records` into one frame that `decode_frame` can parse.
///
/// Without `api_key` a `VERSION_V1` frame is produced. With `api_key` a `VERSION_V2` frame is
/// produced. Its `AUTH_HASH_BYTES`-byte tag sits right after the record count and equals
/// `compute_auth_tag(key, count, rest)`, where `rest` is every byte after the tag.
///
/// Each record is written as `key_len: u16 | key | ts: i64 | ttl: u32 | val_len: u16 | value`
/// (big-endian). See `encode_ttl` for how `expire_at` becomes the wire `ttl`.
///
/// This function does not enforce the decoder's limits. `decode_frame` rejects frames with more
/// than `MAX_FRAME_RECORDS` records or keys longer than `MAX_KEY_BYTES`.
///
/// # Panics
/// Panics if there are more than `u16::MAX` records, or if a key or value is longer than
/// `u16::MAX` bytes. These cannot be represented in the 16-bit length fields, and truncating
/// them (as a plain `as u16` would) silently corrupts the frame.
pub fn encode_frame(records: &[Record], api_key: Option<&str>) -> Vec<u8> {
    let count = u16::try_from(records.len())
        .expect("encode_frame: more than u16::MAX records in one frame");

    let mut buf = Vec::with_capacity(4 + AUTH_HASH_BYTES + 64 * records.len());
    buf.push(MAGIC);
    buf.push(if api_key.is_some() {
        VERSION_V2
    } else {
        VERSION_V1
    });
    buf.extend_from_slice(&count.to_be_bytes());

    // Reserve room for the tag; it is filled in once the bytes it covers have been written.
    let tag_pos = buf.len();
    if api_key.is_some() {
        buf.extend_from_slice(&[0u8; AUTH_HASH_BYTES]);
    }

    for rec in records {
        let key_len =
            u16::try_from(rec.key.len()).expect("encode_frame: key longer than u16::MAX bytes");
        let val_len = u16::try_from(rec.value.len())
            .expect("encode_frame: value longer than u16::MAX bytes");
        buf.extend_from_slice(&key_len.to_be_bytes());
        buf.extend_from_slice(&rec.key);
        buf.extend_from_slice(&rec.ts.to_be_bytes());
        buf.extend_from_slice(&encode_ttl(rec.ts, rec.expire_at).to_be_bytes());
        buf.extend_from_slice(&val_len.to_be_bytes());
        buf.extend_from_slice(&rec.value);
    }

    if let Some(key) = api_key {
        let tag_end = tag_pos + AUTH_HASH_BYTES;
        let tag = compute_auth_tag(key, count, &buf[tag_end..]);
        buf[tag_pos..tag_end].copy_from_slice(&tag);
    }
    buf
}

/// Converts an absolute expiry into the wire `ttl` (whole 1_000_000-unit steps after `ts`).
///
/// `expire_at == i64::MAX` maps to `TTL_NONE`. Any other expiry is clamped to `1..=u32::MAX`:
/// - A gap shorter than one step, or an expiry before `ts`, is sent as 1. A plain division would
///   give 0 (`TTL_NONE`, never expires) or, after `as u32`, wrap a negative gap to a huge TTL.
/// - Gaps that do not fit in a `u32` are sent as `u32::MAX`.
fn encode_ttl(ts: i64, expire_at: i64) -> u32 {
    if expire_at == i64::MAX {
        return TTL_NONE;
    }
    let whole_steps = expire_at.saturating_sub(ts) / 1_000_000;
    u32::try_from(whole_steps.max(1)).unwrap_or(u32::MAX)
}
/// Token bucket for per-source-IP rate limiting in the UDP listener.
///
/// Capacity equals `rate`, so the allowed burst is one second's worth of tokens. The bucket
/// refills continuously at `rate` tokens per second and starts full.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; one is taken per accepted packet.
    tokens: f64,
    /// Latest instant for which refill has been accounted. The UDP listener also reads it to
    /// evict idle buckets.
    last_refill: std::time::Instant,
    /// Refill rate in tokens per second; also the bucket's capacity.
    rate: f64,
}
/// Bucket-map size above which the UDP listener runs an eviction pass.
const MAX_RATE_LIMIT_ENTRIES: usize = 100_000;
impl TokenBucket {
    /// Creates a full bucket allowing `rate_per_sec` tokens per second.
    fn new(rate_per_sec: u32) -> Self {
        let rate = f64::from(rate_per_sec);
        Self {
            tokens: rate,
            last_refill: std::time::Instant::now(),
            rate,
        }
    }

    /// Refills for the time elapsed since the last refill, then tries to take one token.
    ///
    /// Returns `true` if a token was available and consumed. A `now` earlier than
    /// `last_refill` adds nothing and does not move the refill clock backwards. The listener
    /// captures `now` before creating a new bucket, so this case does occur.
    fn try_consume(&mut self, now: std::time::Instant) -> bool {
        // Explicitly saturating: std documents that `Instant - Instant` may panic again in
        // future releases when the right-hand side is later.
        let elapsed = now.saturating_duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * self.rate).min(self.rate);
        self.last_refill = self.last_refill.max(now);
        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            true
        } else {
            false
        }
    }
}
/// Runs the UDP ingest loop. It binds `addr`, then receives datagrams forever, decoding each
/// with `decode_frame` and passing the records to `engine.write_batch`.
///
/// * `max_packet_size`: size of the single receive buffer reused for every datagram.
/// * `api_key`: forwarded to `decode_frame`. When `Some`, only V2 frames with a matching auth
///   tag are accepted.
/// * `rate_limit_per_ip`: refill rate and burst size of the per-source-IP token bucket. `0`
///   disables rate limiting.
///
/// Packets rejected by the rate limiter or by `decode_frame` increment
/// `stats.udp_packets_dropped`. Packets that pass the rate limiter increment
/// `stats.udp_packets_received` before decoding. Engine write errors and receive errors are
/// logged and the loop continues.
///
/// # Errors
/// Returns an error only if binding the socket fails; otherwise the future never completes.
pub async fn start_udp_listener(
    engine: Arc<Engine>,
    stats: Arc<StatsCounters>,
    addr: SocketAddr,
    max_packet_size: usize,
    api_key: Option<String>,
    rate_limit_per_ip: u32,
) -> Result<()> {
    let socket = UdpSocket::bind(addr).await?;
    let mut buf = vec![0u8; max_packet_size];
    let mut rate_limits: HashMap<IpAddr, TokenBucket> = HashMap::new();
    let mut last_cleanup = std::time::Instant::now();
    loop {
        let (len, src) = match socket.recv_from(&mut buf).await {
            Ok(received) => received,
            Err(e) => {
                tracing::warn!("UDP recv error: {}", e);
                continue;
            }
        };

        if rate_limit_per_ip > 0
            && !rate_limit_allows(
                &mut rate_limits,
                &mut last_cleanup,
                src.ip(),
                std::time::Instant::now(),
                rate_limit_per_ip,
            )
        {
            stats
                .udp_packets_dropped
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            continue;
        }

        stats
            .udp_packets_received
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        match decode_frame(&buf[..len], api_key.as_deref()) {
            Ok(records) => {
                if let Err(e) = engine.write_batch(&records).await {
                    tracing::warn!("UDP write error: {}", e);
                }
            }
            Err(e) => {
                tracing::debug!("UDP decode error: {}", e);
                stats
                    .udp_packets_dropped
                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            }
        }
    }
}

/// Charges one packet from `ip` to its token bucket, evicting stale buckets as needed.
///
/// Returns `false` when the packet should be dropped.
///
/// Eviction runs every `CLEANUP_INTERVAL`, or immediately once the map holds more than
/// `MAX_RATE_LIMIT_ENTRIES` buckets. It removes buckets idle for longer than `IDLE_TIMEOUT`. If
/// the map was over capacity and age-based eviction cannot shrink it to at most half capacity,
/// all buckets are reset. This happens, for example, during a flood from many distinct (possibly
/// spoofed) source addresses that are all recent. Without the reset the map would grow without
/// bound and the O(n) scan would run on every packet. The half-capacity hysteresis keeps that
/// scan amortized O(1) per packet. The trade-off is that every source gets a fresh burst after a
/// reset.
fn rate_limit_allows(
    rate_limits: &mut HashMap<IpAddr, TokenBucket>,
    last_cleanup: &mut std::time::Instant,
    ip: IpAddr,
    now: std::time::Instant,
    rate_per_sec: u32,
) -> bool {
    const CLEANUP_INTERVAL: std::time::Duration = std::time::Duration::from_secs(30);
    const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

    let over_capacity = rate_limits.len() > MAX_RATE_LIMIT_ENTRIES;
    if over_capacity || now.saturating_duration_since(*last_cleanup) >= CLEANUP_INTERVAL {
        // `Instant - Duration` panics if the result precedes the clock's origin, which can
        // happen shortly after boot on some platforms. In that case no bucket can be idle
        // long enough to evict.
        if let Some(cutoff) = now.checked_sub(IDLE_TIMEOUT) {
            rate_limits.retain(|_, bucket| bucket.last_refill >= cutoff);
        }
        if over_capacity && rate_limits.len() > MAX_RATE_LIMIT_ENTRIES / 2 {
            rate_limits.clear();
        }
        *last_cleanup = now;
    }

    rate_limits
        .entry(ip)
        .or_insert_with(|| TokenBucket::new(rate_per_sec))
        .try_consume(now)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: a small record with no TTL.
    fn sample_record() -> Record {
        Record {
            key: "key".into(),
            ts: 100,
            expire_at: i64::MAX,
            value: b"val".to_vec(),
        }
    }

    #[test]
    fn test_encode_decode_single() {
        let rec = Record {
            key: "test-key".into(),
            ts: 1234567890,
            expire_at: 1234567890 + 3600 * 1_000_000,
            value: b"hello".to_vec(),
        };
        let encoded = encode_frame(std::slice::from_ref(&rec), None);
        let decoded = decode_frame(&encoded, None).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].key, b"test-key");
        assert_eq!(decoded[0].ts, 1234567890);
        assert_eq!(decoded[0].value, b"hello");
        assert!(decoded[0].expire_at < i64::MAX);
    }

    #[test]
    fn test_encode_decode_no_ttl() {
        let rec = Record {
            key: "key".into(),
            ts: 100,
            expire_at: i64::MAX,
            value: b"val".to_vec(),
        };
        let encoded = encode_frame(&[rec], None);
        let decoded = decode_frame(&encoded, None).unwrap();
        assert_eq!(decoded[0].expire_at, i64::MAX);
    }

    #[test]
    fn test_encode_decode_batch() {
        let recs = vec![
            Record {
                key: "a".into(),
                ts: 100,
                expire_at: i64::MAX,
                value: b"v1".to_vec(),
            },
            Record {
                key: "b".into(),
                ts: 200,
                expire_at: i64::MAX,
                value: b"v2".to_vec(),
            },
        ];
        let encoded = encode_frame(&recs, None);
        let decoded = decode_frame(&encoded, None).unwrap();
        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].key, b"a");
        assert_eq!(decoded[1].key, b"b");
    }

    #[test]
    fn test_decode_corrupt_magic() {
        let rec = Record {
            key: "key".into(),
            ts: 100,
            expire_at: i64::MAX,
            value: b"val".to_vec(),
        };
        let mut encoded = encode_frame(&[rec], None);
        encoded[0] = 0x00;
        assert!(decode_frame(&encoded, None).is_err());
    }

    #[test]
    fn test_decode_truncated() {
        assert!(decode_frame(&[MAGIC, VERSION_V1], None).is_err());
        assert!(decode_frame(&[MAGIC, VERSION_V1, 0x00, 0x01], None).is_err());
    }

    #[test]
    fn test_read_u16_with_position() {
        let data = [0x01, 0x02, 0x03, 0x04];
        let (v, pos) = read_u16(&data, 0).unwrap();
        assert_eq!(v, 0x0102);
        assert_eq!(pos, 2);
        let (v, pos) = read_u16(&data, 2).unwrap();
        assert_eq!(v, 0x0304);
        assert_eq!(pos, 4);
    }

    #[test]
    fn test_read_u16_oob() {
        assert!(read_u16(&[0x01], 0).is_none());
        assert!(read_u16(&[0x01, 0x02], 1).is_none());
    }

    #[test]
    fn test_read_u32_with_position() {
        let data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
        let (v, pos) = read_u32(&data, 0).unwrap();
        assert_eq!(v, 0x01020304);
        assert_eq!(pos, 4);
        let (v, pos) = read_u32(&data, 4).unwrap();
        assert_eq!(v, 0x05060708);
        assert_eq!(pos, 8);
    }

    #[test]
    fn test_read_u32_oob() {
        assert!(read_u32(&[0; 3], 0).is_none());
        assert!(read_u32(&[0; 4], 1).is_none());
    }

    #[test]
    fn test_read_i64_with_position() {
        let n: i64 = -0x0102030405060708;
        let bytes = n.to_be_bytes();
        let (v, pos) = read_i64(&bytes, 0).unwrap();
        assert_eq!(v, n);
        assert_eq!(pos, 8);
    }

    #[test]
    fn test_read_i64_oob() {
        assert!(read_i64(&[0; 7], 0).is_none());
        assert!(read_i64(&[0; 8], 1).is_none());
    }

    #[test]
    fn test_token_bucket_rate_limiting() {
        let mut bucket = TokenBucket::new(10);
        let start = std::time::Instant::now();
        for _ in 0..10 {
            assert!(bucket.try_consume(start), "should allow 10 initial tokens");
        }
        assert!(!bucket.try_consume(start), "should deny 11th token");
        let later = start + std::time::Duration::from_millis(200);
        assert!(bucket.try_consume(later), "should allow after refill");
        assert!(bucket.try_consume(later), "should allow second refilled token");
        assert!(!bucket.try_consume(later), "should deny third (no more refill)");
    }

    #[test]
    fn test_auth_tag_deterministic() {
        let tag1 = compute_auth_tag("secret", 5, b"payload");
        let tag2 = compute_auth_tag("secret", 5, b"payload");
        assert_eq!(tag1, tag2, "same inputs must produce same tag");
        let tag3 = compute_auth_tag("secret", 6, b"payload");
        assert_ne!(tag1, tag3, "different count must produce different tag");
        let tag4 = compute_auth_tag("wrong", 5, b"payload");
        assert_ne!(tag1, tag4, "different key must produce different tag");
    }

    #[test]
    fn test_v2_frame_rejected_when_too_short_for_auth_tag() {
        let frame = [MAGIC, VERSION_V2, 0x00, 0x01];
        assert!(decode_frame(&frame, Some("key")).is_err());
    }

    #[test]
    fn test_encode_decode_v2_roundtrip() {
        let rec = sample_record();
        let encoded = encode_frame(std::slice::from_ref(&rec), Some("secret"));
        assert_eq!(encoded[1], VERSION_V2);
        let decoded = decode_frame(&encoded, Some("secret")).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].key, b"key");
        assert_eq!(decoded[0].ts, 100);
        assert_eq!(decoded[0].value, b"val");
        assert_eq!(decoded[0].expire_at, i64::MAX);
    }

    #[test]
    fn test_v2_wrong_key_rejected() {
        let encoded = encode_frame(&[sample_record()], Some("secret"));
        assert!(decode_frame(&encoded, Some("other")).is_err());
    }

    #[test]
    fn test_v2_tampered_payload_rejected() {
        let mut encoded = encode_frame(&[sample_record()], Some("secret"));
        let last = encoded.len() - 1;
        encoded[last] ^= 0xff;
        assert!(decode_frame(&encoded, Some("secret")).is_err());
    }

    #[test]
    fn test_v1_frame_rejected_when_key_required() {
        let encoded = encode_frame(&[sample_record()], None);
        assert!(decode_frame(&encoded, Some("secret")).is_err());
    }

    #[test]
    fn test_v2_frame_rejected_without_server_key() {
        let encoded = encode_frame(&[sample_record()], Some("secret"));
        assert!(decode_frame(&encoded, None).is_err());
    }

    #[test]
    fn test_unsupported_version_rejected() {
        assert!(decode_frame(&[MAGIC, 0x03, 0x00, 0x00], None).is_err());
        assert!(decode_frame(&[MAGIC, 0x03, 0x00, 0x00], Some("secret")).is_err());
    }

    #[test]
    fn test_record_count_limit_enforced() {
        let too_many = u16::try_from(MAX_FRAME_RECORDS + 1).unwrap().to_be_bytes();
        let frame = [MAGIC, VERSION_V1, too_many[0], too_many[1]];
        assert!(decode_frame(&frame, None).is_err());
    }

    #[test]
    fn test_empty_frames_decode_to_no_records() {
        let decoded = decode_frame(&[MAGIC, VERSION_V1, 0x00, 0x00], None).unwrap();
        assert!(decoded.is_empty());
        let v1 = encode_frame(&[], None);
        assert!(decode_frame(&v1, None).unwrap().is_empty());
        let v2 = encode_frame(&[], Some("secret"));
        assert!(decode_frame(&v2, Some("secret")).unwrap().is_empty());
    }

    #[test]
    fn test_oversized_key_rejected() {
        let key_len = u16::try_from(MAX_KEY_BYTES + 1).unwrap();
        let mut frame = vec![MAGIC, VERSION_V1, 0x00, 0x01];
        frame.extend_from_slice(&key_len.to_be_bytes());
        frame.extend_from_slice(&vec![b'k'; MAX_KEY_BYTES + 1]);
        frame.extend_from_slice(&0i64.to_be_bytes());
        frame.extend_from_slice(&0u32.to_be_bytes());
        frame.extend_from_slice(&0u16.to_be_bytes());
        assert!(decode_frame(&frame, None).is_err());
    }
}