1use std::collections::{HashMap, VecDeque};
12
/// Simulation frame number.
pub type Frame = u64;
/// Player slot index; every per-player array in this module is guarded by `< MAX_PLAYERS`.
pub type PlayerId = u8;

/// Number of per-player slots in a `FrameInput` and an `InputPredictor`.
pub const MAX_PLAYERS: usize = 4;
/// Rollback depth unit; `RollbackSession::new` sizes its snapshot buffer to four times this.
pub const MAX_ROLLBACK_FRAMES: usize = 8;
/// Initial value of `RollbackSession::input_delay` (presumably in frames — confirm with callers).
pub const INPUT_DELAY: usize = 2;
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
26pub struct PlayerInput {
27 pub buttons: u32, pub buttons_pressed: u32, pub buttons_released: u32, pub axis_x: i16, pub axis_y: i16, pub axis_rx: i16, pub axis_ry: i16, pub frame: Frame,
35}
36
37impl PlayerInput {
38 pub fn is_held(&self, btn: u32) -> bool { self.buttons & btn != 0 }
39 pub fn is_pressed(&self, btn: u32) -> bool { self.buttons_pressed & btn != 0 }
40 pub fn is_released(&self, btn: u32) -> bool { self.buttons_released & btn != 0 }
41
42 pub fn direction(&self) -> (f32, f32) {
43 (self.axis_x as f32 / 32767.0, self.axis_y as f32 / 32767.0)
44 }
45
46 pub fn to_bytes(&self) -> [u8; 16] {
47 let mut buf = [0u8; 16];
48 buf[0..4].copy_from_slice(&self.buttons.to_le_bytes());
49 buf[4..8].copy_from_slice(&self.buttons_pressed.to_le_bytes());
50 buf[8..10].copy_from_slice(&self.axis_x.to_le_bytes());
51 buf[10..12].copy_from_slice(&self.axis_y.to_le_bytes());
52 buf[12..14].copy_from_slice(&self.axis_rx.to_le_bytes());
53 buf[14..16].copy_from_slice(&self.axis_ry.to_le_bytes());
54 buf
55 }
56
57 pub fn from_bytes(bytes: &[u8; 16]) -> Self {
58 let buttons = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
59 let buttons_pressed = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
60 let axis_x = i16::from_le_bytes([bytes[8], bytes[9]]);
61 let axis_y = i16::from_le_bytes([bytes[10], bytes[11]]);
62 let axis_rx = i16::from_le_bytes([bytes[12], bytes[13]]);
63 let axis_ry = i16::from_le_bytes([bytes[14], bytes[15]]);
64 Self { buttons, buttons_pressed, buttons_released: 0, axis_x, axis_y, axis_rx, axis_ry, frame: 0 }
65 }
66
67 pub fn is_neutral(&self) -> bool {
69 self.buttons == 0 && self.axis_x == 0 && self.axis_y == 0
70 }
71}
72
73#[derive(Debug, Clone, Default)]
75pub struct FrameInput {
76 pub frame: Frame,
77 pub inputs: [PlayerInput; MAX_PLAYERS],
78 pub confirmed: [bool; MAX_PLAYERS],
79}
80
81impl FrameInput {
82 pub fn new(frame: Frame) -> Self {
83 let mut fi = Self::default();
84 fi.frame = frame;
85 fi
86 }
87
88 pub fn set_input(&mut self, player: PlayerId, input: PlayerInput) {
89 let idx = player as usize;
90 if idx < MAX_PLAYERS {
91 self.inputs[idx] = input;
92 self.confirmed[idx] = true;
93 }
94 }
95
96 pub fn all_confirmed(&self, player_count: u8) -> bool {
97 (0..player_count as usize).all(|i| self.confirmed[i])
98 }
99
100 pub fn checksum(&self) -> u32 {
101 let mut h = 0u32;
102 for inp in &self.inputs {
103 h ^= inp.buttons;
104 h = h.wrapping_add(inp.axis_x as u32).wrapping_mul(0x9e3779b9);
105 }
106 h
107 }
108}
109
110pub struct InputPredictor {
114 last_confirmed: [PlayerInput; MAX_PLAYERS],
115 last_confirmed_frame: [Frame; MAX_PLAYERS],
116 prediction_streak: [u32; MAX_PLAYERS],
117}
118
119impl InputPredictor {
120 pub fn new() -> Self {
121 Self {
122 last_confirmed: [PlayerInput::default(); MAX_PLAYERS],
123 last_confirmed_frame: [0; MAX_PLAYERS],
124 prediction_streak: [0; MAX_PLAYERS],
125 }
126 }
127
128 pub fn confirm_input(&mut self, player: PlayerId, input: PlayerInput) {
129 let idx = player as usize;
130 if idx < MAX_PLAYERS {
131 self.last_confirmed[idx] = input;
132 self.last_confirmed_frame[idx] = input.frame;
133 self.prediction_streak[idx] = 0;
134 }
135 }
136
137 pub fn predict(&mut self, player: PlayerId, frame: Frame) -> PlayerInput {
138 let idx = player as usize;
139 if idx >= MAX_PLAYERS { return PlayerInput::default(); }
140 self.prediction_streak[idx] += 1;
141
142 let mut predicted = self.last_confirmed[idx];
143 if self.prediction_streak[idx] > 6 {
145 predicted.buttons = 0;
146 predicted.axis_x = 0;
147 predicted.axis_y = 0;
148 }
149 predicted.frame = frame;
150 predicted.buttons_pressed = 0; predicted.buttons_released = 0;
152 predicted
153 }
154
155 pub fn prediction_error(&self, player: PlayerId, actual: &PlayerInput) -> bool {
156 let idx = player as usize;
157 if idx >= MAX_PLAYERS { return false; }
158 let predicted = self.last_confirmed[idx];
159 predicted.buttons != actual.buttons || predicted.axis_x != actual.axis_x
160 }
161
162 pub fn streak(&self, player: PlayerId) -> u32 {
163 self.prediction_streak.get(player as usize).copied().unwrap_or(0)
164 }
165}
166
167pub trait GameState: Clone + Send + 'static {
171 fn advance(&mut self, inputs: &FrameInput);
173
174 fn checksum(&self) -> u64;
176
177 fn snapshot_size_hint() -> usize { 4096 }
179}
180
181#[derive(Clone)]
184pub struct StateSnapshot<S: GameState> {
185 pub frame: Frame,
186 pub state: S,
187 pub checksum: u64,
188}
189
190impl<S: GameState> StateSnapshot<S> {
191 pub fn capture(frame: Frame, state: &S) -> Self {
192 let cs = state.checksum();
193 Self { frame, state: state.clone(), checksum: cs }
194 }
195}
196
197pub struct RollbackBuffer<S: GameState> {
201 snapshots: VecDeque<StateSnapshot<S>>,
202 frame_inputs: VecDeque<FrameInput>,
203 capacity: usize,
204}
205
206impl<S: GameState> RollbackBuffer<S> {
207 pub fn new(capacity: usize) -> Self {
208 Self {
209 snapshots: VecDeque::with_capacity(capacity),
210 frame_inputs: VecDeque::with_capacity(capacity * 2),
211 capacity,
212 }
213 }
214
215 pub fn save_snapshot(&mut self, frame: Frame, state: &S) {
216 let snapshot = StateSnapshot::capture(frame, state);
217 if self.snapshots.len() >= self.capacity {
218 self.snapshots.pop_front();
219 }
220 self.snapshots.push_back(snapshot);
221 }
222
223 pub fn save_inputs(&mut self, inputs: FrameInput) {
224 if self.frame_inputs.len() >= self.capacity * 2 {
225 self.frame_inputs.pop_front();
226 }
227 self.frame_inputs.push_back(inputs);
228 }
229
230 pub fn get_snapshot(&self, frame: Frame) -> Option<&StateSnapshot<S>> {
231 self.snapshots.iter().rfind(|s| s.frame == frame)
232 }
233
234 pub fn latest_snapshot(&self) -> Option<&StateSnapshot<S>> {
235 self.snapshots.back()
236 }
237
238 pub fn get_inputs_from(&self, start_frame: Frame) -> Vec<&FrameInput> {
239 self.frame_inputs.iter()
240 .filter(|fi| fi.frame >= start_frame)
241 .collect()
242 }
243
244 pub fn oldest_snapshot_frame(&self) -> Option<Frame> {
245 self.snapshots.front().map(|s| s.frame)
246 }
247
248 pub fn len(&self) -> usize { self.snapshots.len() }
249}
250
251#[derive(Debug, Clone)]
254pub struct DesyncEvent {
255 pub frame: Frame,
256 pub local_checksum: u64,
257 pub remote_checksum: u64,
258 pub player_id: PlayerId,
259}
260
261pub struct DesyncDetector {
262 local_checksums: HashMap<Frame, u64>,
263 remote_checksums: HashMap<(Frame, PlayerId), u64>,
264 desyncs: Vec<DesyncEvent>,
265 check_interval: u32,
266}
267
268impl DesyncDetector {
269 pub fn new(check_interval: u32) -> Self {
270 Self {
271 local_checksums: HashMap::new(),
272 remote_checksums: HashMap::new(),
273 desyncs: Vec::new(),
274 check_interval,
275 }
276 }
277
278 pub fn record_local(&mut self, frame: Frame, checksum: u64) {
279 self.local_checksums.insert(frame, checksum);
280 }
281
282 pub fn record_remote(&mut self, frame: Frame, player: PlayerId, checksum: u64) {
283 self.remote_checksums.insert((frame, player), checksum);
284 if let Some(&local) = self.local_checksums.get(&frame) {
286 if local != checksum {
287 self.desyncs.push(DesyncEvent { frame, local_checksum: local, remote_checksum: checksum, player_id: player });
288 }
289 }
290 }
291
292 pub fn has_desync(&self) -> bool { !self.desyncs.is_empty() }
293
294 pub fn drain_desyncs(&mut self) -> Vec<DesyncEvent> {
295 std::mem::take(&mut self.desyncs)
296 }
297
298 pub fn cleanup_old(&mut self, oldest_frame: Frame) {
299 self.local_checksums.retain(|&f, _| f >= oldest_frame);
300 self.remote_checksums.retain(|(f, _), _| *f >= oldest_frame);
301 }
302}
303
304#[derive(Debug, Clone)]
307pub struct PeerStats {
308 pub player_id: PlayerId,
309 pub rtt_ms: f32,
310 pub rtt_variance_ms: f32,
311 pub packet_loss_pct: f32,
312 pub frames_ahead: i32, pub last_recv_frame: Frame,
314 pub predicted_frames: u32,
315}
316
317pub struct NetworkStats {
318 peers: HashMap<PlayerId, PeerStats>,
319 local_frame: Frame,
320 rtt_samples: VecDeque<f32>,
321 max_rtt_samples: usize,
322}
323
324impl NetworkStats {
325 pub fn new() -> Self {
326 Self {
327 peers: HashMap::new(),
328 local_frame: 0,
329 rtt_samples: VecDeque::with_capacity(64),
330 max_rtt_samples: 64,
331 }
332 }
333
334 pub fn record_rtt(&mut self, player: PlayerId, rtt_ms: f32) {
335 if self.rtt_samples.len() >= self.max_rtt_samples {
336 self.rtt_samples.pop_front();
337 }
338 self.rtt_samples.push_back(rtt_ms);
339 let entry = self.peers.entry(player).or_insert_with(|| PeerStats {
340 player_id: player, rtt_ms: 0.0, rtt_variance_ms: 0.0,
341 packet_loss_pct: 0.0, frames_ahead: 0, last_recv_frame: 0,
342 predicted_frames: 0,
343 });
344 let sum: f32 = self.rtt_samples.iter().sum();
345 entry.rtt_ms = sum / self.rtt_samples.len() as f32;
346 let variance: f32 = self.rtt_samples.iter()
347 .map(|&r| (r - entry.rtt_ms).powi(2))
348 .sum::<f32>() / self.rtt_samples.len() as f32;
349 entry.rtt_variance_ms = variance.sqrt();
350 }
351
352 pub fn recommended_input_delay(&self) -> usize {
353 let max_rtt = self.peers.values()
354 .map(|p| p.rtt_ms)
355 .fold(0.0f32, f32::max);
356 let frames_per_ms = 1000.0 / 60.0; let delay = (max_rtt / (2.0 * frames_per_ms)).ceil() as usize;
358 delay.clamp(1, 6)
359 }
360
361 pub fn peer(&self, player: PlayerId) -> Option<&PeerStats> {
362 self.peers.get(&player)
363 }
364
365 pub fn average_rtt(&self) -> f32 {
366 if self.rtt_samples.is_empty() { return 0.0; }
367 self.rtt_samples.iter().sum::<f32>() / self.rtt_samples.len() as f32
368 }
369
370 pub fn update_frame(&mut self, frame: Frame) { self.local_frame = frame; }
371}
372
373pub struct RollbackSession<S: GameState> {
377 pub current_frame: Frame,
378 pub confirmed_frame: Frame,
379 pub buffer: RollbackBuffer<S>,
380 pub predictor: InputPredictor,
381 pub desync_detector: DesyncDetector,
382 pub net_stats: NetworkStats,
383 pub player_count: u8,
384 pub local_player_id: PlayerId,
385 pub input_delay: usize,
386 local_input_queue: VecDeque<PlayerInput>,
387 pending_remote: HashMap<(Frame, PlayerId), PlayerInput>,
388 rollback_count: u64,
389}
390
391impl<S: GameState> RollbackSession<S> {
392 pub fn new(player_count: u8, local_player_id: PlayerId) -> Self {
393 Self {
394 current_frame: 0,
395 confirmed_frame: 0,
396 buffer: RollbackBuffer::new(MAX_ROLLBACK_FRAMES * 4),
397 predictor: InputPredictor::new(),
398 desync_detector: DesyncDetector::new(8),
399 net_stats: NetworkStats::new(),
400 player_count,
401 local_player_id,
402 input_delay: INPUT_DELAY,
403 local_input_queue: VecDeque::new(),
404 pending_remote: HashMap::new(),
405 rollback_count: 0,
406 }
407 }
408
409 pub fn queue_local_input(&mut self, input: PlayerInput) {
411 self.local_input_queue.push_back(input);
412 }
413
414 pub fn receive_remote_input(&mut self, player: PlayerId, frame: Frame, input: PlayerInput) {
416 self.pending_remote.insert((frame, player), input);
417 self.predictor.confirm_input(player, input);
418 }
419
420 pub fn build_frame_input(&mut self, state: &S) -> FrameInput {
422 let frame = self.current_frame;
423 let mut fi = FrameInput::new(frame);
424
425 let local_input = self.local_input_queue.pop_front().unwrap_or_default();
427 fi.set_input(self.local_player_id, local_input);
428
429 for player in 0..self.player_count {
431 if player == self.local_player_id { continue; }
432 let key = (frame, player);
433 let input = if let Some(&remote) = self.pending_remote.get(&key) {
434 self.pending_remote.remove(&key);
435 fi.confirmed[player as usize] = true;
436 remote
437 } else {
438 self.predictor.predict(player, frame)
439 };
440 fi.inputs[player as usize] = input;
441 }
442
443 let cs = state.checksum();
445 self.desync_detector.record_local(frame, cs);
446
447 fi
448 }
449
450 pub fn advance(&mut self, state: &mut S) -> FrameInput {
452 self.buffer.save_snapshot(self.current_frame, state);
454
455 let fi = self.build_frame_input(state);
456 self.buffer.save_inputs(fi.clone());
457
458 state.advance(&fi);
459 self.net_stats.update_frame(self.current_frame);
460 self.current_frame += 1;
461 fi
462 }
463
464 pub fn check_rollback(&mut self) -> Option<Frame> {
467 let earliest_incorrect = self.pending_remote.keys()
468 .filter(|(frame, _)| *frame < self.current_frame)
469 .map(|(frame, _)| *frame)
470 .min()?;
471
472 if earliest_incorrect < self.current_frame {
473 Some(earliest_incorrect)
474 } else {
475 None
476 }
477 }
478
479 pub fn rollback_to(&mut self, target_frame: Frame, state: &mut S) -> bool {
482 let snapshot = match self.buffer.get_snapshot(target_frame) {
483 Some(s) => s.clone(),
484 None => return false,
485 };
486
487 *state = snapshot.state;
488 let resim_start = target_frame;
489
490 let inputs: Vec<FrameInput> = self.buffer
492 .get_inputs_from(resim_start)
493 .iter()
494 .map(|fi| (*fi).clone())
495 .collect();
496
497 let inputs_len = inputs.len();
499 for mut fi in inputs {
500 for player in 0..self.player_count {
501 let key = (fi.frame, player);
502 if let Some(&confirmed) = self.pending_remote.get(&key) {
503 fi.inputs[player as usize] = confirmed;
504 fi.confirmed[player as usize] = true;
505 self.pending_remote.remove(&key);
506 self.predictor.confirm_input(player, confirmed);
507 }
508 }
509 state.advance(&fi);
510 }
511
512 self.current_frame = target_frame + inputs_len as Frame;
513 self.rollback_count += 1;
514 true
515 }
516
517 pub fn rollback_count(&self) -> u64 { self.rollback_count }
518 pub fn frames_behind(&self) -> u64 {
519 self.current_frame.saturating_sub(self.confirmed_frame)
520 }
521}
522
523#[derive(Debug, Clone)]
526pub struct InputPacket {
527 pub from_player: PlayerId,
528 pub frame: Frame,
529 pub inputs: Vec<(Frame, PlayerInput)>, pub checksum: u32,
531 pub ack_frame: Frame, }
533
534impl InputPacket {
535 pub fn new(player: PlayerId, frame: Frame) -> Self {
536 Self { from_player: player, frame, inputs: Vec::new(), checksum: 0, ack_frame: 0 }
537 }
538
539 pub fn add_input(&mut self, frame: Frame, input: PlayerInput) {
540 self.inputs.push((frame, input));
541 }
542
543 pub fn to_bytes(&self) -> Vec<u8> {
544 let mut buf = Vec::new();
545 buf.push(self.from_player);
546 buf.extend_from_slice(&self.frame.to_le_bytes());
547 buf.extend_from_slice(&self.ack_frame.to_le_bytes());
548 buf.push(self.inputs.len() as u8);
549 for (frame, input) in &self.inputs {
550 buf.extend_from_slice(&frame.to_le_bytes());
551 buf.extend_from_slice(&input.to_bytes());
552 }
553 buf.extend_from_slice(&self.checksum.to_le_bytes());
554 buf
555 }
556
557 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
558 if bytes.len() < 18 { return None; }
559 let from_player = bytes[0];
560 let frame = Frame::from_le_bytes(bytes[1..9].try_into().ok()?);
561 let ack_frame = Frame::from_le_bytes(bytes[9..17].try_into().ok()?);
562 let count = bytes[17] as usize;
563 let mut inputs = Vec::new();
564 let mut offset = 18;
565 for _ in 0..count {
566 if offset + 24 > bytes.len() { break; }
567 let f = Frame::from_le_bytes(bytes[offset..offset+8].try_into().ok()?);
568 let inp_bytes: &[u8; 16] = bytes[offset+8..offset+24].try_into().ok()?;
569 let inp = PlayerInput::from_bytes(inp_bytes);
570 inputs.push((f, inp));
571 offset += 24;
572 }
573 let checksum = if offset + 4 <= bytes.len() {
574 u32::from_le_bytes(bytes[offset..offset+4].try_into().unwrap_or([0;4]))
575 } else { 0 };
576 Some(Self { from_player, frame, inputs, checksum, ack_frame })
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal state: counts frames and how often player 0 held button bit 1.
    #[derive(Clone)]
    struct TestState {
        frame: Frame,
        value: i64,
    }

    impl GameState for TestState {
        fn advance(&mut self, inputs: &FrameInput) {
            self.frame += 1;
            if inputs.inputs[0].is_held(1) {
                self.value += 1;
            }
        }

        fn checksum(&self) -> u64 {
            (self.frame << 32) ^ self.value as u64
        }
    }

    #[test]
    fn test_player_input_roundtrip() {
        let original = PlayerInput {
            buttons: 0b1010,
            axis_x: 1000,
            axis_y: -500,
            frame: 42,
            ..Default::default()
        };
        let decoded = PlayerInput::from_bytes(&original.to_bytes());
        assert_eq!(
            (decoded.buttons, decoded.axis_x, decoded.axis_y),
            (original.buttons, original.axis_x, original.axis_y)
        );
    }

    #[test]
    fn test_predictor_streak() {
        let mut predictor = InputPredictor::new();
        predictor.confirm_input(1, PlayerInput { buttons: 0b0001, frame: 5, ..Default::default() });

        let first = predictor.predict(1, 6);
        assert_eq!(first.buttons, 0b0001);
        assert_eq!(predictor.streak(1), 1);

        // Keep predicting well past the point where the predictor falls back to neutral.
        (7..20).for_each(|frame| {
            predictor.predict(1, frame);
        });
        assert_eq!(predictor.predict(1, 20).buttons, 0);
    }

    #[test]
    fn test_rollback_session_advance() {
        let mut session: RollbackSession<TestState> = RollbackSession::new(1, 0);
        let mut state = TestState { frame: 0, value: 0 };
        let held = PlayerInput { buttons: 1, ..Default::default() };

        (0..5).for_each(|_| session.queue_local_input(held));
        (0..5).for_each(|_| {
            session.advance(&mut state);
        });
        assert_eq!(session.current_frame, 5);
    }

    #[test]
    fn test_network_stats_rtt() {
        let mut stats = NetworkStats::new();
        [20.0, 24.0, 22.0, 18.0, 21.0]
            .iter()
            .for_each(|&rtt| stats.record_rtt(0, rtt));
        let avg = stats.average_rtt();
        assert!(18.0 < avg && avg < 25.0);
    }

    #[test]
    fn test_input_packet_roundtrip() {
        let mut packet = InputPacket::new(0, 100);
        packet.add_input(100, PlayerInput { buttons: 3, ..Default::default() });
        let decoded = InputPacket::from_bytes(&packet.to_bytes()).unwrap();
        assert_eq!(decoded.from_player, 0);
        assert_eq!(decoded.frame, 100);
        assert_eq!(decoded.inputs.len(), 1);
    }

    #[test]
    fn test_desync_detector() {
        let mut detector = DesyncDetector::new(4);
        detector.record_local(10, 0xABCDEF);
        detector.record_remote(10, 1, 0xABCDEF);
        assert!(!detector.has_desync());

        detector.record_remote(10, 1, 0x000000);
        detector.record_local(11, 0x111111);
        detector.record_remote(11, 1, 0x222222);
        assert!(detector.has_desync());
    }
}