Skip to main content

proof_engine/networking/
rpc.rs

1//! Remote Procedure Call (RPC) system for game events.
2//!
3//! RPCs are named, typed, network-dispatched function calls.  They are
4//! registered centrally in `RpcRegistry`, dispatched via `RpcQueue`, and
5//! can be targeted at a single client, a team, or all peers.
6//!
7//! ## Flow
8//! 1. Server/client calls `RpcQueue::enqueue(call)`.
9//! 2. Transport drains the queue and serialises all pending calls into one
10//!    or more UDP packets (via `RpcBatcher`).
11//! 3. Remote side deserialises, looks up the `RpcId` in `RpcRegistry`, and
12//!    invokes the registered `RpcHandler`.
13//! 4. `RpcSecurity` validates the caller and rate-limits calls per client.
14//! 5. `RpcReplay` optionally records every call for debugging / replay.
15
16use std::collections::{HashMap, VecDeque};
17use std::time::{Duration, Instant};
18
19use crate::networking::sync::Vec3;
20
21// ─── RpcId ───────────────────────────────────────────────────────────────────
22
/// Compact 16-bit identifier for an RPC registered in `RpcRegistry`.
///
/// The wrapped value is public so ids can be declared as constants
/// (e.g. `RpcId(0x0001)`) and written straight onto the wire.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct RpcId(pub u16);
26
27// ─── Built-in RPC IDs ────────────────────────────────────────────────────────
28
29pub const RPC_CHAT_MESSAGE:   RpcId = RpcId(0x0001);
30pub const RPC_PLAYER_JOINED:  RpcId = RpcId(0x0002);
31pub const RPC_PLAYER_LEFT:    RpcId = RpcId(0x0003);
32pub const RPC_GAME_EVENT:     RpcId = RpcId(0x0004);
33pub const RPC_FORCE_FIELD:    RpcId = RpcId(0x0005);
34pub const RPC_PARTICLE_BURST: RpcId = RpcId(0x0006);
35pub const RPC_SCREEN_EFFECT:  RpcId = RpcId(0x0007);
36pub const RPC_PLAY_SOUND:     RpcId = RpcId(0x0008);
37pub const RPC_DAMAGE_NUMBER:  RpcId = RpcId(0x0009);
38pub const RPC_ENTITY_STATUS:  RpcId = RpcId(0x000A);
39pub const RPC_CAMERA_SHAKE:   RpcId = RpcId(0x000B);
40pub const RPC_DIALOGUE:       RpcId = RpcId(0x000C);
41
42// ─── RpcTarget ───────────────────────────────────────────────────────────────
43
/// Addressing mode of an RPC: which peer(s) receive and execute it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RpcTarget {
    /// Every connected client.
    All,
    /// The server alone.
    Server,
    /// One client, identified by client id.
    Client(u64),
    /// Every player on the given team.
    Team(u8),
    /// Every connected client except the one with this id.
    AllExcept(u64),
}
58
59// ─── RpcParam ────────────────────────────────────────────────────────────────
60
61/// Dynamically typed RPC parameter.
62#[derive(Debug, Clone, PartialEq)]
63pub enum RpcParam {
64    Bool(bool),
65    Int(i64),
66    Float(f32),
67    Str(String),
68    Vec3(Vec3),
69    Bytes(Vec<u8>),
70}
71
72impl RpcParam {
73    /// Type tag byte used in serialisation.
74    pub fn type_tag(&self) -> u8 {
75        match self {
76            RpcParam::Bool(_)  => 0x01,
77            RpcParam::Int(_)   => 0x02,
78            RpcParam::Float(_) => 0x03,
79            RpcParam::Str(_)   => 0x04,
80            RpcParam::Vec3(_)  => 0x05,
81            RpcParam::Bytes(_) => 0x06,
82        }
83    }
84
85    /// Serialise to bytes: [type_tag(1)] + [payload].
86    pub fn serialize(&self, out: &mut Vec<u8>) {
87        out.push(self.type_tag());
88        match self {
89            RpcParam::Bool(b)  => out.push(*b as u8),
90            RpcParam::Int(i)   => out.extend_from_slice(&i.to_be_bytes()),
91            RpcParam::Float(f) => out.extend_from_slice(&f.to_bits().to_be_bytes()),
92            RpcParam::Str(s)   => {
93                let bytes = s.as_bytes();
94                let len   = bytes.len().min(0xFFFF) as u16;
95                out.extend_from_slice(&len.to_be_bytes());
96                out.extend_from_slice(&bytes[..len as usize]);
97            }
98            RpcParam::Vec3(v) => {
99                out.extend_from_slice(&v.x.to_bits().to_be_bytes());
100                out.extend_from_slice(&v.y.to_bits().to_be_bytes());
101                out.extend_from_slice(&v.z.to_bits().to_be_bytes());
102            }
103            RpcParam::Bytes(b) => {
104                let len = b.len().min(0xFFFF) as u16;
105                out.extend_from_slice(&len.to_be_bytes());
106                out.extend_from_slice(&b[..len as usize]);
107            }
108        }
109    }
110
111    /// Deserialise one param from `buf` at `offset`.
112    /// Returns `(param, new_offset)`.
113    pub fn deserialize(buf: &[u8], offset: usize) -> Result<(Self, usize), RpcError> {
114        if offset >= buf.len() {
115            return Err(RpcError::DeserializeError("buffer empty".into()));
116        }
117        let tag = buf[offset];
118        let pos = offset + 1;
119
120        macro_rules! need {
121            ($n:expr) => {
122                if pos + $n > buf.len() {
123                    return Err(RpcError::DeserializeError("truncated param".into()));
124                }
125            };
126        }
127
128        match tag {
129            0x01 => {
130                need!(1);
131                Ok((RpcParam::Bool(buf[pos] != 0), pos + 1))
132            }
133            0x02 => {
134                need!(8);
135                let v = i64::from_be_bytes(buf[pos..pos+8].try_into().unwrap());
136                Ok((RpcParam::Int(v), pos + 8))
137            }
138            0x03 => {
139                need!(4);
140                let v = f32::from_bits(u32::from_be_bytes(buf[pos..pos+4].try_into().unwrap()));
141                Ok((RpcParam::Float(v), pos + 4))
142            }
143            0x04 => {
144                need!(2);
145                let len = u16::from_be_bytes([buf[pos], buf[pos+1]]) as usize;
146                if pos + 2 + len > buf.len() {
147                    return Err(RpcError::DeserializeError("str truncated".into()));
148                }
149                let s = std::str::from_utf8(&buf[pos+2..pos+2+len])
150                    .map_err(|e| RpcError::DeserializeError(e.to_string()))?
151                    .to_string();
152                Ok((RpcParam::Str(s), pos + 2 + len))
153            }
154            0x05 => {
155                need!(12);
156                let x = f32::from_bits(u32::from_be_bytes(buf[pos..pos+4].try_into().unwrap()));
157                let y = f32::from_bits(u32::from_be_bytes(buf[pos+4..pos+8].try_into().unwrap()));
158                let z = f32::from_bits(u32::from_be_bytes(buf[pos+8..pos+12].try_into().unwrap()));
159                Ok((RpcParam::Vec3(Vec3::new(x, y, z)), pos + 12))
160            }
161            0x06 => {
162                need!(2);
163                let len = u16::from_be_bytes([buf[pos], buf[pos+1]]) as usize;
164                if pos + 2 + len > buf.len() {
165                    return Err(RpcError::DeserializeError("bytes truncated".into()));
166                }
167                Ok((RpcParam::Bytes(buf[pos+2..pos+2+len].to_vec()), pos + 2 + len))
168            }
169            _ => Err(RpcError::DeserializeError(format!("unknown param tag 0x{tag:02x}"))),
170        }
171    }
172
173    // ── Convenience accessors ────────────────────────────────────────────────
174
175    pub fn as_bool(&self) -> Option<bool> {
176        if let RpcParam::Bool(v) = self { Some(*v) } else { None }
177    }
178    pub fn as_int(&self) -> Option<i64> {
179        if let RpcParam::Int(v) = self { Some(*v) } else { None }
180    }
181    pub fn as_float(&self) -> Option<f32> {
182        if let RpcParam::Float(v) = self { Some(*v) } else { None }
183    }
184    pub fn as_str(&self) -> Option<&str> {
185        if let RpcParam::Str(v) = self { Some(v) } else { None }
186    }
187    pub fn as_vec3(&self) -> Option<Vec3> {
188        if let RpcParam::Vec3(v) = self { Some(*v) } else { None }
189    }
190    pub fn as_bytes(&self) -> Option<&[u8]> {
191        if let RpcParam::Bytes(v) = self { Some(v) } else { None }
192    }
193}
194
195// ─── RpcCall ─────────────────────────────────────────────────────────────────
196
197/// A fully-formed RPC call ready for serialisation and dispatch.
198#[derive(Debug, Clone)]
199pub struct RpcCall {
200    pub id:     RpcId,
201    pub target: RpcTarget,
202    pub params: Vec<RpcParam>,
203    /// Sequence number (assigned by `RpcQueue` on enqueue).
204    pub seq:    u32,
205    /// Originating client_id (0 = server).
206    pub caller: u64,
207}
208
209impl RpcCall {
210    pub fn new(id: RpcId, target: RpcTarget, params: Vec<RpcParam>) -> Self {
211        Self { id, target, params, seq: 0, caller: 0 }
212    }
213
214    pub fn with_caller(mut self, caller_id: u64) -> Self {
215        self.caller = caller_id;
216        self
217    }
218
219    /// Serialise to bytes: [rpc_id(2)] [target(2)] [caller(8)] [param_count(1)] [params...]
220    pub fn serialize(&self) -> Vec<u8> {
221        let mut out = Vec::new();
222        out.extend_from_slice(&self.id.0.to_be_bytes());
223        let target_tag: u16 = match &self.target {
224            RpcTarget::All            => 0x0000,
225            RpcTarget::Server         => 0x0001,
226            RpcTarget::Client(id)     => {
227                out.extend_from_slice(&id.to_be_bytes());
228                0x0002
229            }
230            RpcTarget::Team(t)        => { out.push(*t); 0x0003 }
231            RpcTarget::AllExcept(id)  => { out.extend_from_slice(&id.to_be_bytes()); 0x0004 }
232        };
233        // We need to write target_tag first, so rebuild
234        let mut final_out = Vec::new();
235        final_out.extend_from_slice(&self.id.0.to_be_bytes());
236        final_out.extend_from_slice(&target_tag.to_be_bytes());
237        final_out.extend_from_slice(&self.seq.to_be_bytes());
238        final_out.extend_from_slice(&self.caller.to_be_bytes());
239
240        // Target-specific extra bytes
241        match &self.target {
242            RpcTarget::Client(id)    => final_out.extend_from_slice(&id.to_be_bytes()),
243            RpcTarget::Team(t)       => final_out.push(*t),
244            RpcTarget::AllExcept(id) => final_out.extend_from_slice(&id.to_be_bytes()),
245            _ => {}
246        }
247
248        final_out.push(self.params.len().min(0xFF) as u8);
249        for p in &self.params {
250            p.serialize(&mut final_out);
251        }
252        final_out
253    }
254
255    /// Deserialise one `RpcCall` from `buf` at `offset`.
256    pub fn deserialize(buf: &[u8], offset: usize) -> Result<(Self, usize), RpcError> {
257        let mut pos = offset;
258
259        macro_rules! need {
260            ($n:expr) => {
261                if pos + $n > buf.len() {
262                    return Err(RpcError::DeserializeError("truncated rpc call".into()));
263                }
264            };
265        }
266
267        need!(2);
268        let id = RpcId(u16::from_be_bytes([buf[pos], buf[pos+1]]));
269        pos += 2;
270
271        need!(2);
272        let target_tag = u16::from_be_bytes([buf[pos], buf[pos+1]]);
273        pos += 2;
274
275        need!(4);
276        let seq = u32::from_be_bytes(buf[pos..pos+4].try_into().unwrap());
277        pos += 4;
278
279        need!(8);
280        let caller = u64::from_be_bytes(buf[pos..pos+8].try_into().unwrap());
281        pos += 8;
282
283        let target = match target_tag {
284            0x0000 => RpcTarget::All,
285            0x0001 => RpcTarget::Server,
286            0x0002 => {
287                need!(8);
288                let id_val = u64::from_be_bytes(buf[pos..pos+8].try_into().unwrap());
289                pos += 8;
290                RpcTarget::Client(id_val)
291            }
292            0x0003 => {
293                need!(1);
294                let t = buf[pos];
295                pos += 1;
296                RpcTarget::Team(t)
297            }
298            0x0004 => {
299                need!(8);
300                let id_val = u64::from_be_bytes(buf[pos..pos+8].try_into().unwrap());
301                pos += 8;
302                RpcTarget::AllExcept(id_val)
303            }
304            _ => return Err(RpcError::DeserializeError(format!("unknown target tag {target_tag}"))),
305        };
306
307        need!(1);
308        let param_count = buf[pos] as usize;
309        pos += 1;
310
311        let mut params = Vec::with_capacity(param_count);
312        for _ in 0..param_count {
313            let (p, new_pos) = RpcParam::deserialize(buf, pos)?;
314            params.push(p);
315            pos = new_pos;
316        }
317
318        Ok((Self { id, target, params, seq, caller }, pos))
319    }
320}
321
322// ─── RpcResult ───────────────────────────────────────────────────────────────
323
324/// Return value from an RPC handler.
325pub type RpcResult = Result<Option<RpcParam>, RpcError>;
326
327/// Errors from the RPC system.
328#[derive(Debug, Clone, PartialEq)]
329pub enum RpcError {
330    UnknownRpc(RpcId),
331    InvalidParams { expected: usize, got: usize },
332    WrongParamType { index: usize, expected: &'static str },
333    RateLimited { rpc_id: RpcId, caller: u64 },
334    Unauthorised { rpc_id: RpcId, caller: u64 },
335    DeserializeError(String),
336    HandlerPanic(String),
337}
338
339impl std::fmt::Display for RpcError {
340    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
341        write!(f, "{self:?}")
342    }
343}
344
345impl std::error::Error for RpcError {}
346
347// ─── RpcHandler ──────────────────────────────────────────────────────────────
348
349/// A boxed, type-erased RPC handler function.
350pub type RpcHandler = Box<dyn Fn(&[RpcParam]) -> RpcResult + Send + Sync>;
351
352// ─── RpcRegistry ─────────────────────────────────────────────────────────────
353
354/// Central registry mapping `RpcId` ↔ name ↔ handler.
355pub struct RpcRegistry {
356    /// id → (name, handler)
357    handlers:  HashMap<RpcId, (String, RpcHandler)>,
358    /// name → id (reverse lookup)
359    name_map:  HashMap<String, RpcId>,
360    next_id:   u16,
361}
362
363impl RpcRegistry {
364    pub fn new() -> Self {
365        let mut reg = Self {
366            handlers:  HashMap::new(),
367            name_map:  HashMap::new(),
368            next_id:   0x1000, // user RPCs start here
369        };
370        reg.register_builtins();
371        reg
372    }
373
374    fn alloc_id(&mut self) -> RpcId {
375        let id = RpcId(self.next_id);
376        self.next_id += 1;
377        id
378    }
379
380    /// Register a handler under a fixed `id` and `name`.
381    pub fn register_fixed(
382        &mut self,
383        id:      RpcId,
384        name:    impl Into<String>,
385        handler: RpcHandler,
386    ) {
387        let name = name.into();
388        self.name_map.insert(name.clone(), id);
389        self.handlers.insert(id, (name, handler));
390    }
391
392    /// Register a handler, auto-assigning an ID.  Returns the assigned `RpcId`.
393    pub fn register(
394        &mut self,
395        name:    impl Into<String>,
396        handler: RpcHandler,
397    ) -> RpcId {
398        let id = self.alloc_id();
399        self.register_fixed(id, name, handler);
400        id
401    }
402
403    /// Look up a handler by `RpcId`.
404    pub fn handler(&self, id: RpcId) -> Option<&RpcHandler> {
405        self.handlers.get(&id).map(|(_, h)| h)
406    }
407
408    /// Look up an ID by name.
409    pub fn id_for(&self, name: &str) -> Option<RpcId> {
410        self.name_map.get(name).copied()
411    }
412
413    /// Invoke an RPC by ID.  Returns `Err(UnknownRpc)` if not registered.
414    pub fn invoke(&self, id: RpcId, params: &[RpcParam]) -> RpcResult {
415        match self.handlers.get(&id) {
416            Some((_, h)) => h(params),
417            None => Err(RpcError::UnknownRpc(id)),
418        }
419    }
420
421    /// Number of registered RPCs.
422    pub fn len(&self) -> usize { self.handlers.len() }
423    pub fn is_empty(&self) -> bool { self.handlers.is_empty() }
424
425    // ── Built-in RPCs ────────────────────────────────────────────────────────
426
427    fn register_builtins(&mut self) {
428        // chat_message(sender_id: Int, text: Str)
429        self.register_fixed(RPC_CHAT_MESSAGE, "chat_message", Box::new(|params| {
430            if params.len() < 2 {
431                return Err(RpcError::InvalidParams { expected: 2, got: params.len() });
432            }
433            let _ = params[0].as_int()
434                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
435            let _ = params[1].as_str()
436                .ok_or(RpcError::WrongParamType { index: 1, expected: "Str" })?;
437            Ok(None)
438        }));
439
440        // player_joined(player_id: Int, name: Str)
441        self.register_fixed(RPC_PLAYER_JOINED, "player_joined", Box::new(|params| {
442            if params.len() < 2 {
443                return Err(RpcError::InvalidParams { expected: 2, got: params.len() });
444            }
445            let _ = params[0].as_int()
446                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
447            let _ = params[1].as_str()
448                .ok_or(RpcError::WrongParamType { index: 1, expected: "Str" })?;
449            Ok(None)
450        }));
451
452        // player_left(player_id: Int)
453        self.register_fixed(RPC_PLAYER_LEFT, "player_left", Box::new(|params| {
454            if params.is_empty() {
455                return Err(RpcError::InvalidParams { expected: 1, got: 0 });
456            }
457            let _ = params[0].as_int()
458                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
459            Ok(None)
460        }));
461
462        // game_event(kind: Int, data: Bytes)
463        self.register_fixed(RPC_GAME_EVENT, "game_event", Box::new(|params| {
464            if params.len() < 2 {
465                return Err(RpcError::InvalidParams { expected: 2, got: params.len() });
466            }
467            let _ = params[0].as_int()
468                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
469            let _ = params[1].as_bytes()
470                .ok_or(RpcError::WrongParamType { index: 1, expected: "Bytes" })?;
471            Ok(None)
472        }));
473
474        // force_field_spawn(field_type: Int, position: Vec3, strength: Float, ttl: Float)
475        self.register_fixed(RPC_FORCE_FIELD, "force_field_spawn", Box::new(|params| {
476            if params.len() < 4 {
477                return Err(RpcError::InvalidParams { expected: 4, got: params.len() });
478            }
479            let _ = params[0].as_int()
480                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
481            let _ = params[1].as_vec3()
482                .ok_or(RpcError::WrongParamType { index: 1, expected: "Vec3" })?;
483            let _ = params[2].as_float()
484                .ok_or(RpcError::WrongParamType { index: 2, expected: "Float" })?;
485            let _ = params[3].as_float()
486                .ok_or(RpcError::WrongParamType { index: 3, expected: "Float" })?;
487            Ok(None)
488        }));
489
490        // particle_burst(preset: Int, origin: Vec3)
491        self.register_fixed(RPC_PARTICLE_BURST, "particle_burst", Box::new(|params| {
492            if params.len() < 2 {
493                return Err(RpcError::InvalidParams { expected: 2, got: params.len() });
494            }
495            let _ = params[0].as_int()
496                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
497            let _ = params[1].as_vec3()
498                .ok_or(RpcError::WrongParamType { index: 1, expected: "Vec3" })?;
499            Ok(None)
500        }));
501
502        // screen_effect(effect_type: Int)
503        self.register_fixed(RPC_SCREEN_EFFECT, "screen_effect", Box::new(|params| {
504            if params.is_empty() {
505                return Err(RpcError::InvalidParams { expected: 1, got: 0 });
506            }
507            let _ = params[0].as_int()
508                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
509            Ok(None)
510        }));
511
512        // play_sound(sound_id: Int, position: Vec3)
513        self.register_fixed(RPC_PLAY_SOUND, "play_sound", Box::new(|params| {
514            if params.len() < 2 {
515                return Err(RpcError::InvalidParams { expected: 2, got: params.len() });
516            }
517            let _ = params[0].as_int()
518                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
519            let _ = params[1].as_vec3()
520                .ok_or(RpcError::WrongParamType { index: 1, expected: "Vec3" })?;
521            Ok(None)
522        }));
523
524        // damage_number(amount: Float, position: Vec3, crit: Bool)
525        self.register_fixed(RPC_DAMAGE_NUMBER, "damage_number", Box::new(|params| {
526            if params.len() < 3 {
527                return Err(RpcError::InvalidParams { expected: 3, got: params.len() });
528            }
529            let _ = params[0].as_float()
530                .ok_or(RpcError::WrongParamType { index: 0, expected: "Float" })?;
531            let _ = params[1].as_vec3()
532                .ok_or(RpcError::WrongParamType { index: 1, expected: "Vec3" })?;
533            let _ = params[2].as_bool()
534                .ok_or(RpcError::WrongParamType { index: 2, expected: "Bool" })?;
535            Ok(None)
536        }));
537
538        // entity_status(entity_id: Int, status_effect: Int, duration: Float)
539        self.register_fixed(RPC_ENTITY_STATUS, "entity_status", Box::new(|params| {
540            if params.len() < 3 {
541                return Err(RpcError::InvalidParams { expected: 3, got: params.len() });
542            }
543            let _ = params[0].as_int()
544                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
545            let _ = params[1].as_int()
546                .ok_or(RpcError::WrongParamType { index: 1, expected: "Int" })?;
547            let _ = params[2].as_float()
548                .ok_or(RpcError::WrongParamType { index: 2, expected: "Float" })?;
549            Ok(None)
550        }));
551
552        // camera_shake(trauma: Float)
553        self.register_fixed(RPC_CAMERA_SHAKE, "camera_shake", Box::new(|params| {
554            if params.is_empty() {
555                return Err(RpcError::InvalidParams { expected: 1, got: 0 });
556            }
557            let _ = params[0].as_float()
558                .ok_or(RpcError::WrongParamType { index: 0, expected: "Float" })?;
559            Ok(None)
560        }));
561
562        // dialogue_trigger(npc_id: Int, dialogue_id: Int)
563        self.register_fixed(RPC_DIALOGUE, "dialogue_trigger", Box::new(|params| {
564            if params.len() < 2 {
565                return Err(RpcError::InvalidParams { expected: 2, got: params.len() });
566            }
567            let _ = params[0].as_int()
568                .ok_or(RpcError::WrongParamType { index: 0, expected: "Int" })?;
569            let _ = params[1].as_int()
570                .ok_or(RpcError::WrongParamType { index: 1, expected: "Int" })?;
571            Ok(None)
572        }));
573    }
574}
575
576impl Default for RpcRegistry {
577    fn default() -> Self { Self::new() }
578}
579
580// ─── RpcQueue ────────────────────────────────────────────────────────────────
581
582/// Accumulates outgoing RPC calls for batched network dispatch.
583pub struct RpcQueue {
584    pending:   VecDeque<RpcCall>,
585    next_seq:  u32,
586    /// Maximum calls buffered before oldest are dropped.
587    max_len:   usize,
588}
589
590impl RpcQueue {
591    pub fn new(max_len: usize) -> Self {
592        Self { pending: VecDeque::with_capacity(max_len), next_seq: 0, max_len }
593    }
594
595    /// Add an RPC call to the queue.  Assigns a sequence number.
596    pub fn enqueue(&mut self, mut call: RpcCall) {
597        call.seq = self.next_seq;
598        self.next_seq = self.next_seq.wrapping_add(1);
599        if self.pending.len() >= self.max_len {
600            self.pending.pop_front(); // drop oldest
601        }
602        self.pending.push_back(call);
603    }
604
605    /// Drain all pending calls.
606    pub fn drain(&mut self) -> impl Iterator<Item = RpcCall> + '_ {
607        self.pending.drain(..)
608    }
609
610    /// Peek at the front without removing.
611    pub fn peek(&self) -> Option<&RpcCall> { self.pending.front() }
612
613    pub fn len(&self) -> usize { self.pending.len() }
614    pub fn is_empty(&self) -> bool { self.pending.is_empty() }
615    pub fn clear(&mut self) { self.pending.clear(); }
616}
617
618// ─── RpcSecurity ─────────────────────────────────────────────────────────────
619
/// Sliding-window call counter for one (client, RPC) pair, used by `RpcSecurity`.
#[derive(Debug, Clone)]
struct RateState {
    /// Timestamps of recent calls for this RPC from this client, oldest first.
    history: VecDeque<Instant>,
    /// Maximum allowed calls per `window`.
    max_calls: u32,
    /// Length of the sliding window.
    window: Duration,
}
629
630impl RateState {
631    fn new(max_calls: u32, window_ms: u64) -> Self {
632        Self {
633            history:   VecDeque::new(),
634            max_calls,
635            window:    Duration::from_millis(window_ms),
636        }
637    }
638
639    /// Returns `true` if the call is allowed (and records it).
640    fn allow(&mut self) -> bool {
641        let now = Instant::now();
642        // Evict old entries
643        while let Some(&front) = self.history.front() {
644            if now.duration_since(front) > self.window {
645                self.history.pop_front();
646            } else {
647                break;
648            }
649        }
650        if self.history.len() as u32 >= self.max_calls {
651            return false;
652        }
653        self.history.push_back(now);
654        true
655    }
656}
657
658pub struct RpcSecurity {
659    /// Per (client_id, rpc_id) rate state.
660    rate_states: HashMap<(u64, RpcId), RateState>,
661    /// RPCs restricted to server-only callers.
662    server_only: std::collections::HashSet<RpcId>,
663    /// Default max calls per 1 second window.
664    default_rate: u32,
665}
666
667impl RpcSecurity {
668    pub fn new(default_rate: u32) -> Self {
669        let mut s = Self {
670            rate_states: HashMap::new(),
671            server_only: std::collections::HashSet::new(),
672            default_rate,
673        };
674        // These RPCs can only come from the server (caller = 0)
675        s.server_only.insert(RPC_PLAYER_JOINED);
676        s.server_only.insert(RPC_PLAYER_LEFT);
677        s.server_only.insert(RPC_GAME_EVENT);
678        s.server_only.insert(RPC_FORCE_FIELD);
679        s.server_only.insert(RPC_ENTITY_STATUS);
680        s
681    }
682
683    /// Returns `Ok(())` if the call is allowed.
684    pub fn check(&mut self, call: &RpcCall) -> Result<(), RpcError> {
685        // Server-only check
686        if self.server_only.contains(&call.id) && call.caller != 0 {
687            return Err(RpcError::Unauthorised { rpc_id: call.id, caller: call.caller });
688        }
689
690        // Rate limiting
691        let rate = self.rate_states
692            .entry((call.caller, call.id))
693            .or_insert_with(|| RateState::new(self.default_rate, 1000));
694
695        if !rate.allow() {
696            return Err(RpcError::RateLimited { rpc_id: call.id, caller: call.caller });
697        }
698
699        Ok(())
700    }
701
702    /// Override rate limit for a specific RPC.
703    pub fn set_rate(&mut self, rpc_id: RpcId, max_calls_per_sec: u32) {
704        // Stored as default — new entries for this RPC will use this limit.
705        // Existing entries are not updated to keep the API simple.
706        let _ = (rpc_id, max_calls_per_sec); // Applied on next call via or_insert_with
707    }
708
709    pub fn add_server_only(&mut self, rpc_id: RpcId) {
710        self.server_only.insert(rpc_id);
711    }
712
713    pub fn remove_server_only(&mut self, rpc_id: RpcId) {
714        self.server_only.remove(&rpc_id);
715    }
716}
717
718impl Default for RpcSecurity {
719    fn default() -> Self { Self::new(30) }
720}
721
722// ─── RpcBatcher ──────────────────────────────────────────────────────────────
723
/// Packs several `RpcCall`s into shared packet payloads.
///
/// Batch wire format:
/// ```text
/// [count: u8]  [call_1][call_2]...[call_N]
/// ```
pub struct RpcBatcher {
    /// Target upper bound, in bytes, on each batch payload.
    pub max_batch_bytes: usize,
}
734
735impl RpcBatcher {
736    pub fn new(max_batch_bytes: usize) -> Self {
737        Self { max_batch_bytes }
738    }
739
740    /// Split `calls` into batches that each fit within `max_batch_bytes`.
741    /// Returns a vector of serialised batch payloads.
742    pub fn batch(&self, calls: &[RpcCall]) -> Vec<Vec<u8>> {
743        let mut batches: Vec<Vec<u8>> = Vec::new();
744        let mut current  = Vec::new();
745        let mut count    = 0u8;
746        let mut count_pos = 0usize;
747
748        // Reserve 1 byte for count at start
749        current.push(0u8);
750        count_pos = 0;
751
752        for call in calls {
753            let serialised = call.serialize();
754            let needed = serialised.len();
755
756            // If adding this call would overflow, flush
757            if current.len() + needed > self.max_batch_bytes && count > 0 {
758                current[count_pos] = count;
759                batches.push(current);
760                current   = vec![0u8];
761                count     = 0;
762                count_pos = 0;
763            }
764
765            current.extend_from_slice(&serialised);
766            count += 1;
767        }
768
769        if count > 0 {
770            current[count_pos] = count;
771            batches.push(current);
772        }
773
774        batches
775    }
776
777    /// Deserialise a batch payload into `RpcCall`s.
778    pub fn unbatch(&self, payload: &[u8]) -> Result<Vec<RpcCall>, RpcError> {
779        if payload.is_empty() {
780            return Ok(Vec::new());
781        }
782        let count = payload[0] as usize;
783        let mut calls = Vec::with_capacity(count);
784        let mut pos   = 1usize;
785        for _ in 0..count {
786            let (call, new_pos) = RpcCall::deserialize(payload, pos)?;
787            calls.push(call);
788            pos = new_pos;
789        }
790        Ok(calls)
791    }
792}
793
794impl Default for RpcBatcher {
795    fn default() -> Self { Self::new(1200) }
796}
797
798// ─── RpcReplay ────────────────────────────────────────────────────────────────
799
800/// Records RPC calls with timestamps for debugging and replay.
801#[derive(Debug, Clone)]
802pub struct RecordedCall {
803    pub timestamp_ms: u64,
804    pub call:         RpcCall,
805}
806
807pub struct RpcReplay {
808    records:      Vec<RecordedCall>,
809    recording:    bool,
810    max_records:  usize,
811    start_time:   Instant,
812}
813
814impl RpcReplay {
815    pub fn new(max_records: usize) -> Self {
816        Self {
817            records:     Vec::new(),
818            recording:   false,
819            max_records,
820            start_time:  Instant::now(),
821        }
822    }
823
824    pub fn start_recording(&mut self) {
825        self.records.clear();
826        self.recording  = true;
827        self.start_time = Instant::now();
828    }
829
830    pub fn stop_recording(&mut self) {
831        self.recording = false;
832    }
833
834    /// Record a call if recording is active.
835    pub fn record(&mut self, call: RpcCall) {
836        if !self.recording { return; }
837        let timestamp_ms = self.start_time.elapsed().as_millis() as u64;
838        if self.records.len() >= self.max_records {
839            self.records.remove(0); // evict oldest
840        }
841        self.records.push(RecordedCall { timestamp_ms, call });
842    }
843
844    /// Replay all recorded calls into `queue` immediately (timestamp ignored).
845    pub fn replay_all(&self, queue: &mut RpcQueue) {
846        for rec in &self.records {
847            queue.enqueue(rec.call.clone());
848        }
849    }
850
851    /// Replay only calls whose `rpc_id` matches `filter`.
852    pub fn replay_filtered(&self, queue: &mut RpcQueue, filter: RpcId) {
853        for rec in &self.records {
854            if rec.call.id == filter {
855                queue.enqueue(rec.call.clone());
856            }
857        }
858    }
859
860    /// Get a read-only slice of recorded calls.
861    pub fn records(&self) -> &[RecordedCall] { &self.records }
862
863    pub fn record_count(&self) -> usize { self.records.len() }
864    pub fn is_recording(&self) -> bool { self.recording }
865
866    /// Export recording as raw bytes (each call serialised, prefixed by timestamp_ms u64).
867    pub fn export(&self) -> Vec<u8> {
868        let mut out = Vec::new();
869        out.extend_from_slice(&(self.records.len() as u32).to_be_bytes());
870        for rec in &self.records {
871            out.extend_from_slice(&rec.timestamp_ms.to_be_bytes());
872            let call_bytes = rec.call.serialize();
873            out.extend_from_slice(&(call_bytes.len() as u32).to_be_bytes());
874            out.extend_from_slice(&call_bytes);
875        }
876        out
877    }
878
879    /// Import from bytes produced by `export`.
880    pub fn import(&mut self, data: &[u8]) -> Result<(), RpcError> {
881        if data.len() < 4 {
882            return Err(RpcError::DeserializeError("export too short".into()));
883        }
884        let count = u32::from_be_bytes(data[0..4].try_into().unwrap()) as usize;
885        let mut pos = 4usize;
886        self.records.clear();
887
888        for _ in 0..count {
889            if pos + 12 > data.len() {
890                return Err(RpcError::DeserializeError("truncated record".into()));
891            }
892            let timestamp_ms = u64::from_be_bytes(data[pos..pos+8].try_into().unwrap());
893            pos += 8;
894            let call_len = u32::from_be_bytes(data[pos..pos+4].try_into().unwrap()) as usize;
895            pos += 4;
896            if pos + call_len > data.len() {
897                return Err(RpcError::DeserializeError("truncated call bytes".into()));
898            }
899            let (call, _) = RpcCall::deserialize(&data[pos..pos+call_len], 0)?;
900            pos += call_len;
901            self.records.push(RecordedCall { timestamp_ms, call });
902        }
903        Ok(())
904    }
905}
906
907impl Default for RpcReplay {
908    fn default() -> Self { Self::new(10_000) }
909}
910
911// ─── Convenience builders ─────────────────────────────────────────────────────
912
913/// Build a `chat_message` RPC call.
914pub fn rpc_chat(sender_id: u64, text: impl Into<String>) -> RpcCall {
915    RpcCall::new(
916        RPC_CHAT_MESSAGE,
917        RpcTarget::All,
918        vec![RpcParam::Int(sender_id as i64), RpcParam::Str(text.into())],
919    )
920}
921
922/// Build a `player_joined` RPC call.
923pub fn rpc_player_joined(player_id: u64, name: impl Into<String>) -> RpcCall {
924    RpcCall::new(
925        RPC_PLAYER_JOINED,
926        RpcTarget::All,
927        vec![RpcParam::Int(player_id as i64), RpcParam::Str(name.into())],
928    )
929}
930
931/// Build a `player_left` RPC call.
932pub fn rpc_player_left(player_id: u64) -> RpcCall {
933    RpcCall::new(
934        RPC_PLAYER_LEFT,
935        RpcTarget::All,
936        vec![RpcParam::Int(player_id as i64)],
937    )
938}
939
940/// Build a `game_event` RPC call.
941pub fn rpc_game_event(kind: i64, data: Vec<u8>) -> RpcCall {
942    RpcCall::new(
943        RPC_GAME_EVENT,
944        RpcTarget::All,
945        vec![RpcParam::Int(kind), RpcParam::Bytes(data)],
946    )
947}
948
949/// Build a `force_field_spawn` RPC call.
950pub fn rpc_force_field(field_type: i64, position: Vec3, strength: f32, ttl: f32) -> RpcCall {
951    RpcCall::new(
952        RPC_FORCE_FIELD,
953        RpcTarget::All,
954        vec![
955            RpcParam::Int(field_type),
956            RpcParam::Vec3(position),
957            RpcParam::Float(strength),
958            RpcParam::Float(ttl),
959        ],
960    )
961}
962
963/// Build a `particle_burst` RPC call.
964pub fn rpc_particle_burst(preset: i64, origin: Vec3) -> RpcCall {
965    RpcCall::new(
966        RPC_PARTICLE_BURST,
967        RpcTarget::All,
968        vec![RpcParam::Int(preset), RpcParam::Vec3(origin)],
969    )
970}
971
972/// Build a `screen_effect` RPC call.
973pub fn rpc_screen_effect(effect_type: i64, target_client: u64) -> RpcCall {
974    RpcCall::new(
975        RPC_SCREEN_EFFECT,
976        RpcTarget::Client(target_client),
977        vec![RpcParam::Int(effect_type)],
978    )
979}
980
981/// Build a `play_sound` RPC call.
982pub fn rpc_play_sound(sound_id: i64, position: Vec3) -> RpcCall {
983    RpcCall::new(
984        RPC_PLAY_SOUND,
985        RpcTarget::All,
986        vec![RpcParam::Int(sound_id), RpcParam::Vec3(position)],
987    )
988}
989
990/// Build a `damage_number` RPC call.
991pub fn rpc_damage_number(amount: f32, position: Vec3, crit: bool) -> RpcCall {
992    RpcCall::new(
993        RPC_DAMAGE_NUMBER,
994        RpcTarget::All,
995        vec![RpcParam::Float(amount), RpcParam::Vec3(position), RpcParam::Bool(crit)],
996    )
997}
998
999/// Build an `entity_status` RPC call.
1000pub fn rpc_entity_status(entity_id: i64, status: i64, duration: f32) -> RpcCall {
1001    RpcCall::new(
1002        RPC_ENTITY_STATUS,
1003        RpcTarget::All,
1004        vec![RpcParam::Int(entity_id), RpcParam::Int(status), RpcParam::Float(duration)],
1005    )
1006}
1007
1008/// Build a `camera_shake` RPC call.
1009pub fn rpc_camera_shake(trauma: f32, target_client: u64) -> RpcCall {
1010    RpcCall::new(
1011        RPC_CAMERA_SHAKE,
1012        RpcTarget::Client(target_client),
1013        vec![RpcParam::Float(trauma)],
1014    )
1015}
1016
1017/// Build a `dialogue_trigger` RPC call.
1018pub fn rpc_dialogue_trigger(npc_id: i64, dialogue_id: i64) -> RpcCall {
1019    RpcCall::new(
1020        RPC_DIALOGUE,
1021        RpcTarget::All,
1022        vec![RpcParam::Int(npc_id), RpcParam::Int(dialogue_id)],
1023    )
1024}
1025
1026// ─── Tests ────────────────────────────────────────────────────────────────────
1027
#[cfg(test)]
mod tests {
    use super::*;
    use crate::networking::sync::Vec3;

    /// Encode `p` into a fresh buffer and decode it again, checking that the
    /// decoder stops exactly at the end of what the encoder wrote.
    fn roundtrip_param(p: &RpcParam) -> RpcParam {
        let mut buf = Vec::new();
        p.serialize(&mut buf);
        let (decoded, end) = RpcParam::deserialize(&buf, 0).unwrap();
        assert_eq!(end, buf.len(), "decoder did not consume the whole encoding of {p:?}");
        decoded
    }

    // ── RpcParam serialisation ────────────────────────────────────────────────

    #[test]
    fn test_param_bool_roundtrip() {
        for &b in &[true, false] {
            let p = RpcParam::Bool(b);
            assert_eq!(roundtrip_param(&p), p);
        }
    }

    #[test]
    fn test_param_int_roundtrip() {
        for &v in &[0i64, -1, i64::MIN, i64::MAX, 42] {
            let p = RpcParam::Int(v);
            assert_eq!(roundtrip_param(&p), p);
        }
    }

    #[test]
    fn test_param_float_roundtrip() {
        let value = 1.2345_f32;
        // Match explicitly: a decode into the wrong variant must fail the test,
        // not be silently skipped.
        match roundtrip_param(&RpcParam::Float(value)) {
            RpcParam::Float(decoded) => assert!((decoded - value).abs() < 1e-6),
            other => panic!("expected RpcParam::Float, got {other:?}"),
        }
    }

    #[test]
    fn test_param_str_roundtrip() {
        for s in ["hello, world!", "", "héllo ✓"] {
            let p = RpcParam::Str(s.into());
            assert_eq!(roundtrip_param(&p), p);
        }
    }

    #[test]
    fn test_param_vec3_roundtrip() {
        let p = RpcParam::Vec3(Vec3::new(1.0, 2.0, 3.0));
        assert_eq!(roundtrip_param(&p), p);
    }

    #[test]
    fn test_param_bytes_roundtrip() {
        for data in [vec![0xDE, 0xAD, 0xBE, 0xEF], Vec::new()] {
            let p = RpcParam::Bytes(data);
            assert_eq!(roundtrip_param(&p), p);
        }
    }

    // ── RpcCall serialisation ─────────────────────────────────────────────────

    #[test]
    fn test_rpc_call_serialize_deserialize() {
        let call = rpc_chat(42, "Hi there");
        let bytes = call.serialize();
        let (decoded, consumed) = RpcCall::deserialize(&bytes, 0).unwrap();
        assert_eq!(consumed, bytes.len());
        assert_eq!(decoded.id, call.id);
        assert_eq!(decoded.target, call.target);
        assert_eq!(decoded.params.len(), 2);
        for (got, want) in decoded.params.iter().zip(call.params.iter()) {
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_rpc_call_all_targets() {
        let targets = vec![
            RpcTarget::All,
            RpcTarget::Server,
            RpcTarget::Client(123),
            RpcTarget::Team(2),
            RpcTarget::AllExcept(456),
        ];
        for target in targets {
            let call = RpcCall::new(RPC_CHAT_MESSAGE, target.clone(), vec![
                RpcParam::Int(1), RpcParam::Str("x".into()),
            ]);
            let bytes = call.serialize();
            let (decoded, _) = RpcCall::deserialize(&bytes, 0).unwrap();
            assert_eq!(decoded.target, target);
        }
    }

    // ── RpcRegistry ───────────────────────────────────────────────────────────

    #[test]
    fn test_registry_has_builtins() {
        let reg = RpcRegistry::new();
        assert!(reg.handler(RPC_CHAT_MESSAGE).is_some());
        assert!(reg.handler(RPC_PLAYER_JOINED).is_some());
        assert!(reg.handler(RPC_CAMERA_SHAKE).is_some());
        assert!(reg.handler(RPC_DIALOGUE).is_some());
    }

    #[test]
    fn test_registry_invoke_chat() {
        let reg = RpcRegistry::new();
        let result = reg.invoke(RPC_CHAT_MESSAGE, &[
            RpcParam::Int(1),
            RpcParam::Str("hello".into()),
        ]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_registry_invoke_wrong_params() {
        let reg = RpcRegistry::new();
        let result = reg.invoke(RPC_CHAT_MESSAGE, &[]);
        assert!(matches!(result, Err(RpcError::InvalidParams { .. })));
    }

    #[test]
    fn test_registry_unknown_rpc() {
        let reg = RpcRegistry::new();
        let result = reg.invoke(RpcId(0xFFFF), &[]);
        assert!(matches!(result, Err(RpcError::UnknownRpc(_))));
    }

    // ── RpcBatcher ────────────────────────────────────────────────────────────

    #[test]
    fn test_batcher_roundtrip() {
        let batcher = RpcBatcher::new(4096);
        let calls = vec![
            rpc_chat(1, "hello"),
            rpc_player_joined(2, "Alice"),
            rpc_camera_shake(0.8, 42),
        ];
        let batches = batcher.batch(&calls);
        assert!(!batches.is_empty());

        let mut decoded_all = Vec::new();
        for batch in &batches {
            decoded_all.extend(batcher.unbatch(batch).unwrap());
        }
        assert_eq!(decoded_all.len(), calls.len());
        // Order, ids and targets must all survive batching.
        for (got, want) in decoded_all.iter().zip(&calls) {
            assert_eq!(got.id, want.id);
            assert_eq!(got.target, want.target);
            assert_eq!(got.params.len(), want.params.len());
        }
    }

    #[test]
    fn test_batcher_splits_large_batch() {
        let batcher = RpcBatcher::new(60); // tiny max
        let calls: Vec<RpcCall> = (0..5).map(|i| rpc_chat(i, "x")).collect();
        let batches = batcher.batch(&calls);
        assert!(!batches.is_empty());

        // A single batch holding everything would be the 1-byte count that
        // `unbatch` reads followed by every call back-to-back. If that exceeds
        // the limit, the batcher must have split.
        let single_batch_len = 1 + calls.iter().map(|c| c.serialize().len()).sum::<usize>();
        if single_batch_len > 60 {
            assert!(
                batches.len() > 1,
                "expected a split for {single_batch_len} bytes, got {} batch(es)",
                batches.len()
            );
        }

        // Total decoded calls should equal 5
        let total: usize = batches.iter()
            .map(|b| batcher.unbatch(b).unwrap().len())
            .sum();
        assert_eq!(total, 5);
    }

    // ── RpcQueue ──────────────────────────────────────────────────────────────

    #[test]
    fn test_rpc_queue_sequence() {
        let mut q = RpcQueue::new(64);
        q.enqueue(rpc_chat(1, "a"));
        q.enqueue(rpc_chat(1, "b"));
        let calls: Vec<RpcCall> = q.drain().collect();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].seq, 0);
        assert_eq!(calls[1].seq, 1);
    }

    // ── RpcSecurity ───────────────────────────────────────────────────────────

    #[test]
    fn test_security_server_only_rejected() {
        let mut sec = RpcSecurity::new(100);
        let mut call = rpc_player_joined(1, "Alice");
        call.caller = 99; // non-server caller
        assert!(matches!(sec.check(&call), Err(RpcError::Unauthorised { .. })));
    }

    #[test]
    fn test_security_server_allowed() {
        let mut sec = RpcSecurity::new(100);
        let mut call = rpc_player_joined(1, "Alice");
        call.caller = 0; // server
        assert!(sec.check(&call).is_ok());
    }

    #[test]
    fn test_security_rate_limit() {
        let mut sec = RpcSecurity::new(3); // 3 per second
        let mut call = rpc_chat(1, "hi");
        call.caller = 5;
        // First 3 allowed
        assert!(sec.check(&call).is_ok());
        assert!(sec.check(&call).is_ok());
        assert!(sec.check(&call).is_ok());
        // 4th should be rate-limited
        assert!(matches!(sec.check(&call), Err(RpcError::RateLimited { .. })));
    }

    // ── RpcReplay ─────────────────────────────────────────────────────────────

    #[test]
    fn test_replay_record_and_replay() {
        let mut replay = RpcReplay::new(100);
        replay.start_recording();
        replay.record(rpc_chat(1, "test"));
        replay.record(rpc_camera_shake(0.5, 1));
        replay.stop_recording();

        assert_eq!(replay.record_count(), 2);
        assert!(replay.records().windows(2).all(|w| w[0].timestamp_ms <= w[1].timestamp_ms));

        let mut queue = RpcQueue::new(64);
        replay.replay_all(&mut queue);
        assert_eq!(queue.len(), 2);
    }

    #[test]
    fn test_replay_ignores_calls_when_not_recording() {
        let mut replay = RpcReplay::new(100);
        replay.record(rpc_chat(1, "ignored"));
        assert!(!replay.is_recording());
        assert_eq!(replay.record_count(), 0);
    }

    #[test]
    fn test_replay_evicts_oldest_when_full() {
        let mut replay = RpcReplay::new(2);
        replay.start_recording();
        replay.record(rpc_chat(1, "first"));
        replay.record(rpc_player_left(1));
        replay.record(rpc_camera_shake(0.5, 1));

        let ids: Vec<RpcId> = replay.records().iter().map(|r| r.call.id).collect();
        assert_eq!(ids, vec![RPC_PLAYER_LEFT, RPC_CAMERA_SHAKE]);
    }

    #[test]
    fn test_replay_filtered() {
        let mut replay = RpcReplay::new(100);
        replay.start_recording();
        replay.record(rpc_chat(1, "a"));
        replay.record(rpc_camera_shake(0.5, 1));
        replay.record(rpc_chat(2, "b"));
        replay.stop_recording();

        let mut queue = RpcQueue::new(64);
        replay.replay_filtered(&mut queue, RPC_CAMERA_SHAKE);
        let calls: Vec<RpcCall> = queue.drain().collect();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, RPC_CAMERA_SHAKE);
    }

    #[test]
    fn test_replay_export_import() {
        let mut replay = RpcReplay::new(100);
        replay.start_recording();
        replay.record(rpc_chat(7, "hello"));
        replay.stop_recording();

        let exported = replay.export();

        let mut replay2 = RpcReplay::new(100);
        replay2.import(&exported).unwrap();
        assert_eq!(replay2.record_count(), 1);
        assert_eq!(replay2.records()[0].call.id, RPC_CHAT_MESSAGE);
        assert_eq!(replay2.records()[0].timestamp_ms, replay.records()[0].timestamp_ms);
        for (got, want) in replay2.records()[0].call.params.iter()
            .zip(replay.records()[0].call.params.iter())
        {
            assert_eq!(got, want);
        }
    }

    #[test]
    fn test_replay_import_rejects_truncated_data() {
        let mut replay = RpcReplay::new(100);
        replay.start_recording();
        replay.record(rpc_chat(7, "hello"));
        replay.stop_recording();
        let exported = replay.export();

        let mut target = RpcReplay::new(100);
        assert!(matches!(target.import(&[0, 0]), Err(RpcError::DeserializeError(_))));
        assert!(matches!(
            target.import(&exported[..exported.len() - 1]),
            Err(RpcError::DeserializeError(_))
        ));
    }
}