#![allow(missing_docs)]
#![allow(dead_code)]
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
/// Numeric identifier for a participant in a rollback session.
pub type PeerId = u64;
/// Reserved id under which [`RollbackWorld::step`] stores the local player's command.
/// Being the largest possible id, it sorts after every other peer's input.
pub const LOCAL_PEER_ID: PeerId = PeerId::MAX;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "Cmd: Serialize", deserialize = "Cmd: DeserializeOwned"))]
pub struct Frame<Cmd>
where
Cmd: Serialize + DeserializeOwned,
{
pub tick: u64,
pub snapshot: Vec<u8>,
pub inputs: HashMap<PeerId, Cmd>,
}
/// Bounded, oldest-first history of recently simulated [`Frame`]s.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound(serialize = "Cmd: Serialize", deserialize = "Cmd: DeserializeOwned"))]
pub struct RollbackBuffer<Cmd: Serialize + DeserializeOwned> {
    /// Frames in the order they were pushed; the front holds the oldest tick.
    frames: VecDeque<Frame<Cmd>>,
    /// Tick of the frame most recently passed to `push` (0 before any push).
    pub head_tick: u64,
    /// Frame count at which `push` starts evicting from the front.
    pub capacity: usize,
}
impl<Cmd: Clone + Serialize + DeserializeOwned> RollbackBuffer<Cmd> {
    /// Creates an empty buffer that retains at most `capacity` frames
    /// (a `capacity` of 0 behaves like 1: only the newest frame is kept).
    pub fn new(capacity: usize) -> Self {
        Self {
            frames: VecDeque::with_capacity(capacity),
            head_tick: 0,
            capacity,
        }
    }

    /// Appends `frame` as the newest entry and records its tick in `head_tick`.
    ///
    /// Oldest frames are evicted until there is room for the new one. Looping
    /// (rather than dropping a single frame) makes a `capacity` lowered after
    /// construction — the field is public — take effect on the next push.
    pub fn push(&mut self, frame: Frame<Cmd>) {
        // `max(1)` keeps zero capacity meaning "newest frame only" and ensures
        // the loop terminates once the deque is empty.
        let limit = self.capacity.max(1);
        while self.frames.len() >= limit {
            self.frames.pop_front();
        }
        self.head_tick = frame.tick;
        self.frames.push_back(frame);
    }

    /// Returns the buffered frame for `tick`, or `None` if it was never pushed
    /// or has been evicted. Linear scan from the oldest frame.
    pub fn get(&self, tick: u64) -> Option<&Frame<Cmd>> {
        self.frames.iter().find(|f| f.tick == tick)
    }

    /// Mutable counterpart of [`Self::get`].
    pub fn get_mut(&mut self, tick: u64) -> Option<&mut Frame<Cmd>> {
        self.frames.iter_mut().find(|f| f.tick == tick)
    }

    /// Tick of the oldest buffered frame, or `None` when the buffer is empty.
    pub fn tail_tick(&self) -> Option<u64> {
        self.frames.front().map(|f| f.tick)
    }

    /// Number of frames currently buffered.
    pub fn len(&self) -> usize {
        self.frames.len()
    }

    /// `true` when no frames are buffered.
    pub fn is_empty(&self) -> bool {
        self.frames.is_empty()
    }
}
/// Produced when the local snapshot for `tick` hashes differently from the
/// hash supplied for the remote peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct DesyncReport {
    /// Tick whose snapshots were compared.
    pub tick: u64,
    /// Hash of the locally buffered snapshot.
    pub local_hash: u64,
    /// Hash supplied by the caller for the remote peer's snapshot.
    pub remote_hash: u64,
}
/// Reasons [`RollbackWorld::resimulate_from`] can refuse to roll back.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RollbackError {
    /// The requested starting tick is not held in the buffer.
    TickNotInBuffer(u64),
    /// No snapshot exists for the tick before the starting tick: either the
    /// start is tick 0, or that earlier frame has already been evicted.
    PriorSnapshotUnavailable,
}

impl std::fmt::Display for RollbackError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match self {
            RollbackError::TickNotInBuffer(tick) => {
                write!(f, "tick {} is not in the rollback buffer", tick)
            }
            RollbackError::PriorSnapshotUnavailable => {
                f.write_str("no snapshot is available for the tick before the rollback point")
            }
        }
    }
}

// Lets the error be boxed as `Box<dyn std::error::Error>` and propagated with `?`.
impl std::error::Error for RollbackError {}
/// 64-bit FNV-1a hash of a world snapshot, used by `check_desync` to compare
/// the local state against a hash supplied for a remote peer.
///
/// Uses the standard FNV-1a 64-bit offset basis and prime, so any conforming
/// FNV-1a implementation produces the same value for the same bytes.
fn hash_snapshot(snap: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    // 2^40 + 2^8 + 0xb3: the 64-bit FNV prime.
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    snap.iter().fold(FNV_OFFSET_BASIS, |hash, &byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// A simulation `World` plus the bookkeeping needed to roll it back: a
/// [`RollbackBuffer`] of per-tick snapshots and inputs, caller-supplied
/// advance/snapshot/restore hooks, and inputs received for future ticks.
///
/// Only `Serialize + DeserializeOwned` are required here, because the `buffer`
/// field's type demands them; `Clone + PartialEq` are required by the `impl`
/// block instead of the type itself.
pub struct RollbackWorld<World, Cmd>
where
    Cmd: Serialize + DeserializeOwned,
{
    /// The live simulation state.
    pub world: World,
    /// Recent frames used to restore and replay past ticks.
    pub buffer: RollbackBuffer<Cmd>,
    /// Advances `world` by one tick of length `dt` using commands ordered by peer id.
    advance_fn: fn(&mut World, f64, &[Cmd]),
    /// Captures `world` as opaque snapshot bytes.
    snapshot_fn: fn(&World) -> Vec<u8>,
    /// Restores `world` from bytes previously produced by `snapshot_fn`.
    restore_fn: fn(&mut World, &[u8]),
    /// Last tick that has been simulated.
    pub current_tick: u64,
    /// Inputs recorded for ticks not yet simulated, keyed by tick then peer.
    pending_inputs: HashMap<u64, HashMap<PeerId, Cmd>>,
}
impl<World, Cmd> RollbackWorld<World, Cmd>
where
    Cmd: Clone + PartialEq + Serialize + DeserializeOwned,
{
    /// Wraps `world` for rollback, recording its current state as tick 0.
    ///
    /// `capacity` is the number of frames the [`RollbackBuffer`] retains. Rolling
    /// back to tick `t` needs frame `t - 1` to still be buffered.
    pub fn new(
        world: World,
        capacity: usize,
        advance_fn: fn(&mut World, f64, &[Cmd]),
        snapshot_fn: fn(&World) -> Vec<u8>,
        restore_fn: fn(&mut World, &[u8]),
    ) -> Self {
        let mut buffer = RollbackBuffer::new(capacity);
        buffer.push(Frame {
            tick: 0,
            snapshot: snapshot_fn(&world),
            inputs: HashMap::new(),
        });
        Self {
            world,
            buffer,
            advance_fn,
            snapshot_fn,
            restore_fn,
            current_tick: 0,
            pending_inputs: HashMap::new(),
        }
    }

    /// Records `cmd` as `peer`'s input for `tick`, replacing any earlier input
    /// from the same peer for that tick.
    ///
    /// - Future ticks (`tick > current_tick`): held until [`Self::step`] reaches them.
    /// - Already-simulated ticks: written into the buffered frame, but the world
    ///   is not changed until [`Self::resimulate_from`] is called. If that frame
    ///   has been evicted, the input is silently dropped.
    ///
    /// Inputs recorded for tick 0 are stored but never applied, because tick 0
    /// is the initial state and cannot be resimulated.
    pub fn record_input(&mut self, peer: PeerId, tick: u64, cmd: Cmd) {
        if tick > self.current_tick {
            self.pending_inputs
                .entry(tick)
                .or_default()
                .insert(peer, cmd);
        } else if let Some(frame) = self.buffer.get_mut(tick) {
            frame.inputs.insert(peer, cmd);
        }
    }

    /// Simulates the next tick with time step `dt` and buffers the resulting snapshot.
    ///
    /// The tick's inputs are any pending inputs recorded for it plus the first
    /// element of `local_cmds`, stored under [`LOCAL_PEER_ID`]. A [`Frame`] holds
    /// one command per peer, so further elements of `local_cmds` are ignored.
    pub fn step(&mut self, dt: f64, local_cmds: &[Cmd]) {
        let next_tick = self.current_tick + 1;
        let mut inputs = self.pending_inputs.remove(&next_tick).unwrap_or_default();
        if let Some(cmd) = local_cmds.first() {
            inputs.insert(LOCAL_PEER_ID, cmd.clone());
        }
        let ordered_cmds = Self::inputs_to_sorted_slice(&inputs);
        (self.advance_fn)(&mut self.world, dt, &ordered_cmds);
        let snapshot = (self.snapshot_fn)(&self.world);
        self.buffer.push(Frame {
            tick: next_tick,
            snapshot,
            inputs,
        });
        self.current_tick = next_tick;
    }

    /// Restores the state of tick `from_tick - 1` and replays every tick from
    /// `from_tick` through `current_tick` with the inputs now in the buffer,
    /// overwriting each replayed frame's snapshot.
    ///
    /// # Errors
    ///
    /// - [`RollbackError::TickNotInBuffer`] if `from_tick` is not buffered.
    /// - [`RollbackError::PriorSnapshotUnavailable`] if `from_tick` is 0 or the
    ///   frame for `from_tick - 1` has been evicted.
    ///
    /// On error the world is left untouched.
    pub fn resimulate_from(&mut self, from_tick: u64, dt: f64) -> Result<(), RollbackError> {
        if self.buffer.get(from_tick).is_none() {
            return Err(RollbackError::TickNotInBuffer(from_tick));
        }
        // Tick 0 is the initial state; there is nothing earlier to restore.
        let prior_tick = from_tick
            .checked_sub(1)
            .ok_or(RollbackError::PriorSnapshotUnavailable)?;
        // Borrow the snapshot in place: `buffer` and `world` are disjoint fields,
        // so no copy of a potentially large snapshot is needed.
        let prior_snapshot = self
            .buffer
            .get(prior_tick)
            .map(|f| f.snapshot.as_slice())
            .ok_or(RollbackError::PriorSnapshotUnavailable)?;
        (self.restore_fn)(&mut self.world, prior_snapshot);

        for tick in from_tick..=self.current_tick {
            // One lookup per tick: read the inputs and later store the snapshot
            // through the same frame reference.
            let frame = self.buffer.get_mut(tick);
            let ordered_cmds = frame
                .as_deref()
                .map(|f| Self::inputs_to_sorted_slice(&f.inputs))
                .unwrap_or_default();
            (self.advance_fn)(&mut self.world, dt, &ordered_cmds);
            if let Some(frame) = frame {
                frame.snapshot = (self.snapshot_fn)(&self.world);
            }
        }
        Ok(())
    }

    /// Compares the hash of the buffered snapshot for `tick` with `remote_hash`.
    ///
    /// Returns `None` when the hashes match or when `tick` is not buffered;
    /// otherwise a [`DesyncReport`] carrying both hashes.
    pub fn check_desync(&self, tick: u64, remote_hash: u64) -> Option<DesyncReport> {
        let frame = self.buffer.get(tick)?;
        let local_hash = hash_snapshot(&frame.snapshot);
        if local_hash == remote_hash {
            None
        } else {
            Some(DesyncReport {
                tick,
                local_hash,
                remote_hash,
            })
        }
    }

    /// Clones the commands in `inputs` into a `Vec` ordered by ascending peer id.
    /// Every peer therefore feeds `advance_fn` the same sequence, regardless of
    /// `HashMap` iteration order.
    fn inputs_to_sorted_slice(inputs: &HashMap<PeerId, Cmd>) -> Vec<Cmd> {
        let mut pairs: Vec<(PeerId, &Cmd)> =
            inputs.iter().map(|(&peer, cmd)| (peer, cmd)).collect();
        // Peer ids are unique map keys, so an unstable sort gives the same order.
        pairs.sort_unstable_by_key(|&(peer, _)| peer);
        pairs.into_iter().map(|(_, cmd)| cmd.clone()).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal world: an integer accumulator plus a log of every state it reached.
    #[derive(Debug, Default)]
    struct ToyWorld {
        state: i64,
        history: Vec<i64>,
    }

    /// Adds the sum of this tick's commands to `state` and logs the new value.
    fn advance(w: &mut ToyWorld, _dt: f64, cmds: &[i32]) {
        let delta: i32 = cmds.iter().sum();
        w.state += delta as i64;
        w.history.push(w.state);
    }

    /// Snapshots only `state`, as 8 little-endian bytes.
    fn snapshot(w: &ToyWorld) -> Vec<u8> {
        w.state.to_le_bytes().to_vec()
    }

    /// Restores `state` from an 8-byte snapshot; other lengths are ignored.
    /// `history` is not part of the snapshot and is therefore not restored.
    fn restore(w: &mut ToyWorld, snap: &[u8]) {
        if snap.len() == 8 {
            w.state = i64::from_le_bytes(snap.try_into().unwrap_or([0u8; 8]));
        }
    }

    fn make_world(capacity: usize) -> RollbackWorld<ToyWorld, i32> {
        RollbackWorld::new(ToyWorld::default(), capacity, advance, snapshot, restore)
    }

    #[test]
    fn test_late_input_triggers_resimulate() {
        let mut rw = make_world(16);
        rw.step(1.0 / 60.0, &[1_i32]);
        assert_eq!(rw.world.state, 1, "after tick 1 state should be 1");
        assert_eq!(rw.current_tick, 1);
        // Remote peer 2's input for tick 1 arrives late; replaying tick 1 applies
        // it alongside the local command.
        rw.record_input(2, 1, 10_i32);
        rw.resimulate_from(1, 1.0 / 60.0)
            .expect("resimulate should succeed");
        assert_eq!(rw.world.state, 11, "after resim state should be 11");
    }

    #[test]
    fn test_identical_input_hash_equality() {
        let cmds: &[i32] = &[3, 7, -2, 1, 5, 0, 4, 8, -3, 2];
        let mut rw1 = make_world(32);
        let mut rw2 = make_world(32);
        let dt = 1.0 / 60.0;
        for &c in cmds {
            rw1.step(dt, &[c]);
            rw2.step(dt, &[c]);
        }
        let tick = rw1.current_tick;
        let snap1 = rw1.buffer.get(tick).map(|f| f.snapshot.clone()).unwrap();
        let snap2 = rw2.buffer.get(tick).map(|f| f.snapshot.clone()).unwrap();
        assert_eq!(
            hash_snapshot(&snap1),
            hash_snapshot(&snap2),
            "identical simulations should produce identical snapshot hashes"
        );
    }

    #[test]
    fn test_desync_detection() {
        let mut rw = make_world(16);
        rw.step(1.0 / 60.0, &[5_i32]);
        let tick = rw.current_tick;
        {
            // Corrupt the buffered snapshot so it no longer matches the true state.
            let frame = rw.buffer.get_mut(tick).unwrap();
            if let Some(b) = frame.snapshot.first_mut() {
                *b ^= 0xFF;
            }
        }
        let true_state: i64 = 5;
        let true_snap = true_state.to_le_bytes().to_vec();
        let true_hash = hash_snapshot(&true_snap);
        let report = rw.check_desync(tick, true_hash);
        assert!(
            report.is_some(),
            "check_desync should detect the tampered snapshot"
        );
        let r = report.unwrap();
        assert_eq!(r.tick, tick);
        assert_eq!(r.remote_hash, true_hash);
        assert_ne!(r.local_hash, true_hash);
    }

    #[test]
    fn test_buffer_capacity() {
        let mut rw = make_world(3);
        let dt = 1.0 / 60.0;
        for _ in 0..10 {
            rw.step(dt, &[1_i32]);
        }
        assert_eq!(rw.buffer.len(), 3, "buffer should hold at most 3 frames");
        // Ticks 0..=10 were pushed; only the three newest (8, 9, 10) survive.
        assert_eq!(
            rw.buffer.tail_tick(),
            Some(8),
            "oldest surviving tick should be 8"
        );
        assert_eq!(rw.buffer.head_tick, 10);
        assert!(
            rw.buffer.get(7).is_none(),
            "evicted frames must not be returned"
        );
    }

    #[test]
    fn test_serde_frame() {
        let mut inputs = HashMap::new();
        inputs.insert(1u64, 42_i32);
        inputs.insert(2u64, -7_i32);
        let frame: Frame<i32> = Frame {
            tick: 99,
            snapshot: vec![0xDE, 0xAD, 0xBE, 0xEF],
            inputs,
        };
        let json = serde_json::to_string(&frame).expect("serialize frame");
        let restored: Frame<i32> = serde_json::from_str(&json).expect("deserialize frame");
        assert_eq!(restored.tick, 99);
        assert_eq!(restored.inputs.get(&1u64), Some(&42_i32));
        assert_eq!(restored.inputs.get(&2u64), Some(&-7_i32));
        assert_eq!(restored.snapshot, vec![0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn test_resimulate_from_missing_tick() {
        let mut rw = make_world(4);
        rw.step(1.0 / 60.0, &[1_i32]);
        let result = rw.resimulate_from(999, 1.0 / 60.0);
        assert_eq!(result, Err(RollbackError::TickNotInBuffer(999)));
    }

    #[test]
    fn test_future_input_applied_when_its_tick_is_stepped() {
        let mut rw = make_world(16);
        let dt = 1.0 / 60.0;
        rw.record_input(2, 2, 5_i32);
        rw.step(dt, &[1_i32]);
        assert_eq!(rw.world.state, 1, "input for tick 2 must not affect tick 1");
        rw.step(dt, &[1_i32]);
        assert_eq!(rw.world.state, 7, "pending input for tick 2 should be applied");
    }

    #[test]
    fn test_resimulate_without_prior_snapshot() {
        let mut rw = make_world(2);
        let dt = 1.0 / 60.0;
        assert_eq!(
            rw.resimulate_from(0, dt),
            Err(RollbackError::PriorSnapshotUnavailable),
            "tick 0 has no earlier snapshot"
        );
        rw.step(dt, &[1_i32]);
        rw.step(dt, &[1_i32]);
        // Capacity 2 keeps ticks 1 and 2; tick 0 (needed to replay tick 1) is gone.
        assert_eq!(
            rw.resimulate_from(1, dt),
            Err(RollbackError::PriorSnapshotUnavailable)
        );
        assert_eq!(rw.world.state, 2, "a failed rollback must not touch the world");
    }
}