use lz4_flex::{compress_prepend_size, decompress_size_prepended};
use postcard::{from_bytes, to_allocvec};
use crate::store::ZoneDiff;
use crate::zone::ZoneEntry;
/// Postcard payloads at least this many bytes long are LZ4-compressed by [`encode`];
/// shorter payloads are sent uncompressed.
pub const LZ4_THRESHOLD: usize = 64;
/// Frame tag: the rest of the frame is a raw postcard payload.
const TAG_RAW: u8 = 0x00;
/// Frame tag: the rest of the frame is an LZ4 block with its uncompressed size prepended.
const TAG_LZ4: u8 = 0x01;
/// Four-byte magic that opens every snapshot produced by [`encode_snapshot`].
const SNAPSHOT_MAGIC: [u8; 4] = *b"PLZN";
/// Snapshot format version, stored little-endian right after the magic;
/// [`decode_snapshot`] rejects any other value.
const SNAPSHOT_VERSION: u16 = 1;
/// Serializes `msg` with postcard and frames it behind a one-byte tag.
///
/// Payloads shorter than [`LZ4_THRESHOLD`] bytes are emitted as-is behind `TAG_RAW`.
/// Larger ones are LZ4-compressed with `compress_prepend_size`, which prepends the
/// uncompressed size, and emitted behind `TAG_LZ4`.
///
/// # Panics
///
/// Panics if postcard fails to serialize `msg`.
#[must_use]
pub fn encode<T: serde::Serialize>(msg: &T) -> Vec<u8> {
    let raw = to_allocvec(msg).expect("postcard serialization is infallible for owned data");
    let (tag, body) = if raw.len() >= LZ4_THRESHOLD {
        (TAG_LZ4, compress_prepend_size(&raw))
    } else {
        (TAG_RAW, raw)
    };
    let mut framed = Vec::with_capacity(1 + body.len());
    framed.push(tag);
    framed.extend_from_slice(&body);
    framed
}
/// Decodes a frame produced by [`encode`].
///
/// Returns `None` if any of these hold:
/// - `bytes` is empty or carries an unknown tag.
/// - The LZ4 body is malformed or claims an impossible uncompressed size.
/// - The payload does not deserialize as `T`.
///
/// `bytes` may come straight off the network. The LZ4 path therefore validates the 4-byte
/// little-endian size prefix written by `compress_prepend_size` before decompressing.
/// lz4_flex sizes its output buffer from that prefix up front, so an unchecked prefix would
/// let a five-byte frame request a ~4 GiB allocation. An LZ4 block cannot expand by more
/// than a factor of 255, so larger claims can never come from a genuine frame and are
/// rejected.
#[must_use]
pub fn decode<T: for<'de> serde::Deserialize<'de>>(bytes: &[u8]) -> Option<T> {
    // Upper bound on LZ4 block-format expansion (output bytes per compressed byte).
    const LZ4_MAX_RATIO: u64 = 255;

    let (&tag, rest) = bytes.split_first()?;
    match tag {
        // Deserialize directly from the input slice; no intermediate copy is needed.
        TAG_RAW => from_bytes(rest).ok(),
        TAG_LZ4 => {
            let claimed_len = match rest {
                [a, b, c, d, ..] => u32::from_le_bytes([*a, *b, *c, *d]),
                _ => return None,
            };
            let compressed_len = (rest.len() - 4) as u64;
            if u64::from(claimed_len) > compressed_len.saturating_mul(LZ4_MAX_RATIO) {
                return None;
            }
            let payload = decompress_size_prepended(rest).ok()?;
            from_bytes(&payload).ok()
        }
        _ => None,
    }
}
/// Serializes a full list of zone entries into the snapshot format.
///
/// Layout: the 4-byte magic `PLZN`, then the format version as a little-endian `u16`,
/// then the postcard encoding of `entries` compressed with zstd at level 3.
///
/// # Panics
///
/// Panics if postcard serialization or the in-memory zstd compression fails.
#[must_use]
pub fn encode_snapshot(entries: &[ZoneEntry]) -> Vec<u8> {
    let postcard_bytes = to_allocvec(entries).expect("postcard serialization is infallible");
    let zstd_body = zstd::encode_all(postcard_bytes.as_slice(), 3)
        .expect("zstd compression of in-memory buffer");
    let version = SNAPSHOT_VERSION.to_le_bytes();
    [&SNAPSHOT_MAGIC[..], &version[..], &zstd_body[..]].concat()
}
/// Parses a snapshot produced by [`encode_snapshot`].
///
/// Returns `None` in any of these cases:
/// - The input is shorter than the 6-byte header.
/// - The magic or version does not match.
/// - The body is not valid zstd.
/// - The decompressed bytes do not deserialize as a list of zone entries.
#[must_use]
pub fn decode_snapshot(bytes: &[u8]) -> Option<Vec<ZoneEntry>> {
    match bytes {
        [m0, m1, m2, m3, v0, v1, body @ ..]
            if [*m0, *m1, *m2, *m3] == SNAPSHOT_MAGIC
                && u16::from_le_bytes([*v0, *v1]) == SNAPSHOT_VERSION =>
        {
            let postcard_bytes = zstd::decode_all(body).ok()?;
            from_bytes(&postcard_bytes).ok()
        }
        _ => None,
    }
}
/// Messages sent from client to server, framed with [`encode`] / [`decode`].
///
/// NOTE(review): postcard encodes enum variants by index, so reordering or inserting
/// variants presumably breaks wire compatibility. Append new variants at the end.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
pub enum ClientMsg {
    /// Absolute position of `entity_id` at client time `ts_ms`.
    /// `PosTracker` sends this when the position or time delta does not fit the
    /// `DeltaPos` field ranges.
    FullPos { entity_id: u32, pos: [f32; 3], ts_ms: u32 },
    /// Position change since the last reported position, quantized by `PosTracker` as
    /// `round(delta * 1000)` per axis (millimetres if positions are metres).
    /// `dt_ms` is the milliseconds elapsed since the last reported timestamp.
    DeltaPos { entity_id: u32, dx: i16, dy: i16, dz: i16, dt_ms: u8 },
    /// NOTE(review): presumably reports that `entity_id` has not moved for
    /// `duration_ms`. Nothing in this module constructs it; confirm with callers.
    Stationary { entity_id: u32, duration_ms: u16 },
    /// Presumably asks the server for a full zone snapshot; confirm with the server handler.
    RequestSnapshot,
    /// Presumably acknowledges the server message carrying sequence number `seq`; confirm.
    Ack { seq: u16 },
    /// Latency probe tagged with `seq`.
    /// NOTE(review): presumably answered by a `ServerMsg::Pong` echoing `seq`.
    Ping { seq: u16 },
}
/// Messages sent from server to client, framed with [`encode`] / [`decode`].
///
/// NOTE(review): postcard encodes enum variants by index, so reordering or inserting
/// variants presumably breaks wire compatibility. Append new variants at the end.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
pub enum ServerMsg {
    /// A batch of zone diffs tagged with sequence number `seq`.
    /// NOTE(review): presumably what `ClientMsg::Ack` acknowledges; confirm.
    ZoneBatch { seq: u16, diffs: Vec<ZoneDiff> },
    /// Zone event `event` for `entity_id` in zone `zone_id` at time `ts_ms`.
    /// NOTE(review): time base (client or server clock) is not visible here.
    EntityEvent { entity_id: u32, event: ZoneEvent, zone_id: u32, ts_ms: u32 },
    /// Scan coverage for zone `zone_id`, plus detected holes in compact form.
    /// NOTE(review): confirm whether `coverage_pct` is whole percent or a finer
    /// fixed-point scale.
    ScanResult { zone_id: u32, coverage_pct: u16, holes: Vec<CompactHole> },
    /// Reply carrying `seq` and the server's timestamp in milliseconds.
    /// NOTE(review): presumably echoes `ClientMsg::Ping::seq`.
    Pong { seq: u16, server_ts_ms: u32 },
}
/// Zone event kind carried by `ServerMsg::EntityEvent`.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Copy, PartialEq)]
pub enum ZoneEvent {
    /// Presumably: the entity entered the zone.
    Enter,
    /// Presumably: the entity left the zone.
    Exit,
    /// Presumably: the entity has stayed in the zone for `ms` milliseconds; confirm.
    Dwell { ms: u32 },
}
/// Compact, `f32`-based form of a `crate::scan::Hole`, carried in `ServerMsg::ScanResult`.
///
/// Built by `CompactHole::from_hole`. It narrows the source center and size with
/// `as f32` and copies `depth` unchanged.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Copy, PartialEq)]
pub struct CompactHole {
    /// X component of the hole's center (`Hole::center[0]`).
    pub cx: f32,
    /// Y component of the hole's center (`Hole::center[1]`).
    pub cy: f32,
    /// Z component of the hole's center (`Hole::center[2]`).
    pub cz: f32,
    /// Hole size; the `_m` suffix suggests metres.
    pub size_m: f32,
    /// Copied from `Hole::depth`.
    /// NOTE(review): its meaning (e.g. subdivision level) is not visible here.
    pub depth: u8,
}
impl CompactHole {
    /// Builds the compact wire form of a scan hole.
    ///
    /// The center components and size are narrowed with `as f32`, which is lossy if the
    /// source values carry more precision. The depth is copied unchanged.
    pub fn from_hole(h: &crate::scan::Hole) -> Self {
        let crate::scan::Hole { center, size_m, depth, .. } = h;
        Self {
            depth: *depth,
            size_m: *size_m as f32,
            cz: center[2] as f32,
            cy: center[1] as f32,
            cx: center[0] as f32,
        }
    }
}
/// Coarse network-quality class used to choose update cadence and correction tolerance.
///
/// Variants run from best to worst link. Both [`NetworkTier::tick_ms`] and
/// [`NetworkTier::correction_threshold_m`] grow down the list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NetworkTier {
    Wifi,
    Lte,
    G3,
    G2Edge,
    Offline,
}

impl NetworkTier {
    /// Drop rates above this cap the tier at `G2Edge`.
    /// Presumably a fraction (10 %); confirm how callers compute `drop_rate`.
    const LOSSY_DROP_RATE: f64 = 0.1;

    /// Classifies a link from its average round-trip time (ms) and packet drop rate.
    ///
    /// RTT bands (inclusive):
    /// - `0..=50` → `Wifi`
    /// - `51..=100` → `Lte`
    /// - `101..=300` → `G3`
    /// - `301..=1000` → `G2Edge`
    /// - anything slower → `Offline`
    ///
    /// A drop rate above 0.1 caps the result at `G2Edge`. Loss only ever degrades the
    /// tier, so a link that is `Offline` by RTT stays `Offline`. A NaN `drop_rate` fails
    /// the comparison and is treated as lossless.
    pub fn from_rtt(avg_rtt_ms: u32, drop_rate: f64) -> Self {
        let by_rtt = match avg_rtt_ms {
            0..=50 => Self::Wifi,
            51..=100 => Self::Lte,
            101..=300 => Self::G3,
            301..=1000 => Self::G2Edge,
            _ => Self::Offline,
        };
        // Every tier other than `Offline` is at least as good as `G2Edge`, so capping
        // means "replace with `G2Edge` unless already worse".
        if drop_rate > Self::LOSSY_DROP_RATE && by_rtt != Self::Offline {
            Self::G2Edge
        } else {
            by_rtt
        }
    }

    /// Tick interval in milliseconds for this tier (presumably the update send cadence).
    ///
    /// `Offline` returns `u64::MAX` as a "never" sentinel.
    /// NOTE(review): callers turning this into a deadline should use checked or
    /// saturating arithmetic.
    pub fn tick_ms(&self) -> u64 {
        match self {
            Self::Wifi => 50,
            Self::Lte => 100,
            Self::G3 => 200,
            Self::G2Edge => 1000,
            Self::Offline => u64::MAX,
        }
    }

    /// Position error above which `PosTracker::needs_update` asks for a correction.
    ///
    /// The error is measured in the tracked position units; the `_m` suffix indicates
    /// metres. `Offline` returns `f64::MAX`, so no finite error exceeds it.
    pub fn correction_threshold_m(&self) -> f64 {
        match self {
            Self::Wifi => 0.05,
            Self::Lte => 0.20,
            Self::G3 => 0.50,
            Self::G2Edge => 2.0,
            Self::Offline => f64::MAX,
        }
    }
}
/// Client-side state for sending one entity's position updates compactly.
///
/// Holds the baseline position and timestamp that the next `DeltaPos` is measured from.
/// Also holds a velocity that `needs_update` uses to dead-reckon between updates.
pub struct PosTracker {
    /// Entity whose position this tracker reports.
    pub entity_id: u32,
    /// Network tier; supplies the correction threshold used by `needs_update`.
    pub tier: NetworkTier,
    /// Baseline position for the next delta and for dead-reckoning.
    last_sent_pos: [f32; 3],
    /// Client timestamp (ms) of the last message built; differences against it wrap.
    last_sent_ts: u32,
    /// Velocity (position units per second) used to extrapolate from `last_sent_pos`.
    /// NOTE(review): presumably supplied by the server; the field is public and set
    /// externally.
    pub server_vel: [f64; 3],
}
impl PosTracker {
    /// Quantization scale for `DeltaPos`: deltas are sent in 1/1000ths of a position unit.
    const DELTA_SCALE: f32 = 1000.0;

    /// Creates a tracker for `entity_id` with its baseline at the origin, time 0, and zero
    /// velocity. The first `build_pos_msg` call is therefore a delta from `[0, 0, 0]`.
    pub fn new(entity_id: u32, tier: NetworkTier) -> Self {
        Self {
            entity_id,
            tier,
            last_sent_pos: [0.0; 3],
            last_sent_ts: 0,
            server_vel: [0.0; 3],
        }
    }

    /// Builds the next position update for `pos` at client time `ts_ms` and advances the
    /// baseline.
    ///
    /// Emits `ClientMsg::DeltaPos` when both of these hold:
    /// - Every axis delta, quantized to thousandths (`round(delta * 1000)`), fits in an `i16`.
    /// - The elapsed time (wrapping `u32` subtraction) fits in a `u8`.
    ///
    /// Otherwise emits `ClientMsg::FullPos`.
    ///
    /// After a delta, the baseline becomes the reconstructed position
    /// (`previous baseline + quantized delta / 1000`), not `pos` itself. Presumably the
    /// receiver rebuilds positions by summing deltas, so it only ever knows the
    /// reconstructed value. Carrying the rounding residue into the next delta keeps the two
    /// sides in step. Using `pos` instead would let sub-quantum motion round to zero on
    /// every tick and drift without bound.
    pub fn build_pos_msg(&mut self, pos: [f32; 3], ts_ms: u32) -> ClientMsg {
        let quantize = |axis: usize| -> Option<i16> {
            // `as i32` saturates (and maps NaN to 0), so the range check below also catches
            // huge or infinite deltas. `abs()` would overflow on `i32::MIN`.
            let q = ((pos[axis] - self.last_sent_pos[axis]) * Self::DELTA_SCALE).round() as i32;
            if (i32::from(i16::MIN)..=i32::from(i16::MAX)).contains(&q) {
                Some(q as i16)
            } else {
                None
            }
        };
        let deltas = (quantize(0), quantize(1), quantize(2));
        let dt = ts_ms.wrapping_sub(self.last_sent_ts);
        let dt_ms = if dt <= u32::from(u8::MAX) { Some(dt as u8) } else { None };
        self.last_sent_ts = ts_ms;

        match (deltas, dt_ms) {
            ((Some(dx), Some(dy), Some(dz)), Some(dt_ms)) => {
                for (base, d) in self.last_sent_pos.iter_mut().zip([dx, dy, dz]) {
                    *base += f32::from(d) / Self::DELTA_SCALE;
                }
                ClientMsg::DeltaPos { entity_id: self.entity_id, dx, dy, dz, dt_ms }
            }
            _ => {
                self.last_sent_pos = pos;
                ClientMsg::FullPos { entity_id: self.entity_id, pos, ts_ms }
            }
        }
    }

    /// Dead-reckoning check: returns `true` when `actual` is further from the predicted
    /// position than the tier's correction threshold.
    ///
    /// The prediction extrapolates the baseline with `server_vel` (units per second) over
    /// the time elapsed since the last sent timestamp, using wrapping `u32` ms arithmetic.
    /// The distance is Euclidean.
    pub fn needs_update(&self, actual: [f32; 3], ts_ms: u32) -> bool {
        let dt = ts_ms.wrapping_sub(self.last_sent_ts) as f64 / 1000.0;
        let predicted: [f64; 3] =
            std::array::from_fn(|i| self.last_sent_pos[i] as f64 + self.server_vel[i] * dt);
        let err = (0..3)
            .map(|i| (actual[i] as f64 - predicted[i]).powi(2))
            .sum::<f64>()
            .sqrt();
        err > self.tier.correction_threshold_m()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::zone::Zone;

    #[test]
    fn small_message_is_uncompressed() {
        let msg = ClientMsg::Ping { seq: 7 };
        let bytes = encode(&msg);
        assert_eq!(bytes[0], TAG_RAW, "small message should skip compression");
        let back: ClientMsg = decode(&bytes).unwrap();
        assert_eq!(back, msg);
    }

    #[test]
    fn large_message_is_lz4() {
        let diffs: Vec<ZoneDiff> = (0..50)
            .map(|i| {
                ZoneDiff::Add(ZoneEntry::new(
                    i,
                    Zone::Cylinder { center: [10.0, 106.0], radius_m: 5.0, z_min: 0.0, z_max: 1.0 },
                ))
            })
            .collect();
        let msg = ServerMsg::ZoneBatch { seq: 1, diffs };
        let bytes = encode(&msg);
        assert_eq!(bytes[0], TAG_LZ4, "large message should be compressed");
        let back: ServerMsg = decode(&bytes).unwrap();
        assert_eq!(back, msg);
    }

    #[test]
    fn decode_rejects_garbage() {
        assert!(decode::<ClientMsg>(&[]).is_none());
        assert!(decode::<ClientMsg>(&[0xFF, 1, 2, 3]).is_none(), "unknown tag");
        assert!(decode::<ClientMsg>(&[TAG_LZ4, 0, 0]).is_none(), "bad lz4 payload");
    }

    #[test]
    fn snapshot_round_trip() {
        let entries: Vec<ZoneEntry> = (0..20)
            .map(|i| ZoneEntry::new(i, Zone::Aabb { min: [0.0, 0.0, 0.0], max: [1.0, 1.0, 1.0] }))
            .collect();
        let bytes = encode_snapshot(&entries);
        let back = decode_snapshot(&bytes).unwrap();
        assert_eq!(back, entries);
        assert!(decode_snapshot(b"not zstd").is_none(), "missing magic header");
    }

    #[test]
    fn snapshot_rejects_bad_header_or_body() {
        let good = encode_snapshot(&[]);
        assert_eq!(decode_snapshot(&good), Some(Vec::new()));
        assert!(decode_snapshot(&good[..5]).is_none(), "truncated header");

        let mut bad_magic = good.clone();
        bad_magic[0] ^= 0xFF;
        assert!(decode_snapshot(&bad_magic).is_none(), "wrong magic");

        let mut bad_version = good.clone();
        bad_version[4..6].copy_from_slice(&(SNAPSHOT_VERSION + 1).to_le_bytes());
        assert!(decode_snapshot(&bad_version).is_none(), "unsupported version");

        // A valid header followed by a non-zstd body must fail in the zstd stage.
        let not_zstd =
            [&SNAPSHOT_MAGIC[..], &SNAPSHOT_VERSION.to_le_bytes()[..], &b"not zstd"[..]].concat();
        assert!(decode_snapshot(&not_zstd).is_none(), "valid header, non-zstd body");
    }

    #[test]
    fn tier_classification() {
        assert_eq!(NetworkTier::from_rtt(20, 0.0), NetworkTier::Wifi);
        assert_eq!(NetworkTier::from_rtt(80, 0.0), NetworkTier::Lte);
        assert_eq!(NetworkTier::from_rtt(250, 0.0), NetworkTier::G3);
        assert_eq!(NetworkTier::from_rtt(800, 0.0), NetworkTier::G2Edge);
        assert_eq!(NetworkTier::from_rtt(5000, 0.0), NetworkTier::Offline);
        assert_eq!(NetworkTier::from_rtt(10, 0.5), NetworkTier::G2Edge);
    }

    #[test]
    fn delta_when_close_full_when_far() {
        let mut t = PosTracker::new(42, NetworkTier::Wifi);
        let m = t.build_pos_msg([1.0, 2.0, 0.5], 100);
        match m {
            ClientMsg::DeltaPos { entity_id, dx, dy, dz, dt_ms } => {
                assert_eq!(entity_id, 42);
                assert_eq!((dx, dy, dz), (1000, 2000, 500));
                assert_eq!(dt_ms, 100);
            }
            other => panic!("expected DeltaPos, got {other:?}"),
        }
        let m = t.build_pos_msg([500.0, 2.0, 0.5], 150);
        assert!(matches!(m, ClientMsg::FullPos { .. }), "large jump should be FullPos");
    }

    #[test]
    fn dead_reckoning_suppresses_small_drift() {
        let mut t = PosTracker::new(1, NetworkTier::G3);
        t.build_pos_msg([0.0, 0.0, 0.0], 0);
        t.server_vel = [1.0, 0.0, 0.0];
        assert!(!t.needs_update([1.2, 0.0, 0.0], 1000));
        assert!(t.needs_update([2.0, 0.0, 0.0], 1000));
    }

    #[test]
    fn compact_hole_from_hole() {
        let h = crate::scan::Hole { center: [1.5, 2.5, 3.5], size_m: 0.25, depth: 12 };
        let c = CompactHole::from_hole(&h);
        assert_eq!((c.cx, c.cy, c.cz, c.size_m, c.depth), (1.5, 2.5, 3.5, 0.25, 12));
    }
}