use crate::frame_info::PlayerInput;
use crate::network::compression::{decode, encode};
use crate::network::messages::{
ChecksumReport, ConnectionStatus, Input, InputAck, Message, MessageBody, MessageHeader,
QualityReply, QualityReport, SyncReply, SyncRequest,
};
use crate::time_sync::TimeSync;
use crate::{
Config, DesyncDetection, Frame, GgrsError, NonBlockingSocket, PlayerHandle, NULL_FRAME,
};
use tracing::{trace, warn};
use instant::{Duration, Instant};
use std::collections::vec_deque::Drain;
use std::collections::{HashMap, HashSet, VecDeque};
use std::convert::TryFrom;
use std::ops::Add;
use super::network_stats::NetworkStats;
/// Per-packet overhead (bytes) added to the payload estimate in `network_stats`.
const UDP_HEADER_SIZE: usize = 28;
/// Number of successful sync roundtrips required before the endpoint is running.
const NUM_SYNC_PACKETS: u32 = 5;
/// Milliseconds between `disconnect` and the transition to `Shutdown`.
const UDP_SHUTDOWN_TIMER: u64 = 5000;
/// Initial capacity of the pending-output queue; exceeding it raises `Disconnected`.
const PENDING_OUTPUT_SIZE: usize = 128;
/// How long to wait for a sync reply before sending another sync request.
const SYNC_RETRY_INTERVAL: Duration = Duration::from_millis(200);
/// Pending output is resent when no input arrived for this long.
const RUNNING_RETRY_INTERVAL: Duration = Duration::from_millis(200);
/// A keep-alive is queued when nothing was sent for this long.
const KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(200);
/// Interval between quality reports while running.
const QUALITY_REPORT_INTERVAL: Duration = Duration::from_millis(200);
/// Upper bound used when pruning the received-checksum history.
pub const MAX_CHECKSUM_HISTORY_SIZE: usize = 32;
/// Wall-clock milliseconds since the Unix epoch.
///
/// Native targets read `SystemTime`; `wasm32` uses the JavaScript `Date` API.
/// The wall clock is not monotonic, so successive values may decrease.
///
/// # Panics
///
/// On native targets, panics if the system clock reports a time before the epoch.
fn millis_since_epoch() -> u128 {
    #[cfg(target_arch = "wasm32")]
    {
        js_sys::Date::new_0().get_time() as u128
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        let since_epoch = std::time::UNIX_EPOCH
            .elapsed()
            .expect("Time went backwards");
        since_epoch.as_millis()
    }
}
/// Serialized inputs of several players for a single frame, stored back to back.
#[derive(Clone)]
struct InputBytes {
    /// Frame the inputs belong to, or `NULL_FRAME` when unset.
    pub frame: Frame,
    /// Concatenated bincode-serialized inputs, one equally sized slot per player.
    pub bytes: Vec<u8>,
}

impl InputBytes {
    /// Returns an entry for `NULL_FRAME` whose bytes are all zero, sized for
    /// `num_players` copies of the serialized `T::Input::default()`.
    ///
    /// # Panics
    ///
    /// Panics if the default input cannot be serialized.
    fn zeroed<T: Config>(num_players: usize) -> Self {
        let per_player = bincode::serialized_size(&T::Input::default())
            .expect("input serialization failed") as usize;
        Self {
            frame: NULL_FRAME,
            bytes: vec![0; per_player * num_players],
        }
    }

    /// Serializes the inputs of handles `0..num_players` that are present in
    /// `inputs`, in ascending handle order. Missing handles are skipped.
    ///
    /// The resulting frame is the (shared) non-null frame of the inputs, or
    /// `NULL_FRAME` if every included input is null.
    ///
    /// # Panics
    ///
    /// Panics if two included inputs carry different non-null frames, or if an
    /// input fails to serialize.
    fn from_inputs<T: Config>(
        num_players: usize,
        inputs: &HashMap<PlayerHandle, PlayerInput<T::Input>>,
    ) -> Self {
        let mut frame = NULL_FRAME;
        let mut bytes = Vec::new();
        let present = (0..num_players).filter_map(|handle| inputs.get(&handle));
        for input in present {
            assert!(frame == NULL_FRAME || input.frame == NULL_FRAME || frame == input.frame);
            if input.frame != NULL_FRAME {
                frame = input.frame;
            }
            bincode::serialize_into(&mut bytes, &input.input)
                .expect("input serialization failed");
        }
        Self { frame, bytes }
    }

    /// Splits `bytes` into `num_players` equal slots and deserializes each one
    /// into a `PlayerInput` tagged with `self.frame`.
    ///
    /// # Panics
    ///
    /// Panics if the byte length is not a multiple of `num_players` (including
    /// `num_players == 0`), or if a slot fails to deserialize.
    fn to_player_inputs<T: Config>(&self, num_players: usize) -> Vec<PlayerInput<T::Input>> {
        assert!(self.bytes.len().is_multiple_of(num_players));
        let slot_len = self.bytes.len() / num_players;
        (0..num_players)
            .map(|slot| {
                let offset = slot * slot_len;
                let raw = &self.bytes[offset..offset + slot_len];
                let input: T::Input =
                    bincode::deserialize(raw).expect("input deserialization failed");
                PlayerInput::new(self.frame, input)
            })
            .collect()
    }
}
/// Events a protocol endpoint reports to the session; drained through `poll`.
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum Event<T: Config> {
    /// Sync handshake progress: `count` of `total` roundtrips completed.
    Synchronizing { total: u32, count: u32 },
    /// The sync handshake finished and the endpoint is running.
    Synchronized,
    /// A remote input for `player` was received.
    Input {
        input: PlayerInput<T::Input>,
        player: PlayerHandle,
    },
    /// The remote is considered disconnected.
    Disconnected,
    /// Nothing was received for a while. `disconnect_timeout` is the number of
    /// milliseconds that remain before `Disconnected` is raised.
    NetworkInterrupted { disconnect_timeout: u128 },
    /// Messages arrived again after a `NetworkInterrupted`.
    NetworkResumed,
}
/// Lifecycle of a protocol endpoint.
#[derive(PartialEq, Eq, Debug)]
enum ProtocolState {
    /// Created; `synchronize` has not been called yet.
    Initializing,
    /// Performing the sync-request/sync-reply handshake.
    Synchronizing,
    /// Exchanging inputs with the peer.
    Running,
    /// Disconnected; waiting for the shutdown timer to expire.
    Disconnected,
    /// Fully shut down; messages are neither sent nor handled.
    Shutdown,
}
/// Endpoint of the UDP protocol for one remote address.
///
/// It covers the sync handshake, input exchange, quality reports, checksum
/// reports and disconnect detection.
pub(crate) struct UdpProtocol<T: Config> {
    /// Total number of players in the session.
    num_players: usize,
    /// Handles (sorted in `new`) of the players whose inputs come from this peer.
    handles: Vec<PlayerHandle>,
    /// Outgoing messages, flushed by `send_all_messages`.
    send_queue: VecDeque<Message>,
    /// Events for the session, drained by `poll`.
    event_queue: VecDeque<Event<T>>,
    /// Current lifecycle state.
    state: ProtocolState,
    /// Sync roundtrips still needed before switching to `Running`.
    sync_remaining_roundtrips: u32,
    /// Nonces of outstanding sync requests.
    sync_random_requests: HashSet<u32>,
    /// When the last quality report was queued.
    running_last_quality_report: Instant,
    /// When input was last received (or pending output last resent).
    running_last_input_recv: Instant,
    /// Whether `NetworkInterrupted` was emitted and not yet resolved.
    disconnect_notify_sent: bool,
    /// Whether `Disconnected` was already emitted.
    disconnect_event_sent: bool,
    /// Silence after which `Disconnected` is emitted.
    disconnect_timeout: Duration,
    /// Silence after which `NetworkInterrupted` is emitted.
    disconnect_notify_start: Duration,
    /// Instant after which a disconnected endpoint moves to `Shutdown`.
    shutdown_timeout: Instant,
    /// Frames per second, used to convert latency into frames.
    fps: usize,
    /// Our non-zero magic, stamped on every outgoing message header.
    magic: u16,
    /// Address of the remote peer.
    peer_addr: T::Address,
    /// Peer's magic learned at the end of sync; 0 while unknown.
    remote_magic: u16,
    /// Connection status per player, merged from the peer's reports.
    peer_connect_status: Vec<ConnectionStatus>,
    /// Local inputs sent but not yet acknowledged by the peer.
    pending_output: VecDeque<InputBytes>,
    /// Most recently acknowledged local input; the base for delta encoding.
    last_acked_input: InputBytes,
    /// Maximum prediction window; bounds how many received frames are retained.
    max_prediction: usize,
    /// Received inputs keyed by frame.
    recv_inputs: HashMap<Frame, InputBytes>,
    /// Frame-advantage history used for time synchronization.
    time_sync_layer: TimeSync,
    /// Our estimated frame advantage over the peer.
    local_frame_advantage: i32,
    /// Frame advantage reported by the peer.
    remote_frame_advantage: i32,
    /// Wall-clock ms when synchronization started (bandwidth statistics).
    stats_start_time: u128,
    /// Number of messages queued.
    packets_sent: usize,
    /// Estimated payload bytes queued.
    bytes_sent: usize,
    /// Last measured round-trip time in ms.
    round_trip_time: u128,
    /// When a message was last queued.
    last_send_time: Instant,
    /// When the last sync request was queued.
    last_sync_request_time: Instant,
    /// When a valid message was last received.
    last_recv_time: Instant,
    /// Checksums reported by the peer, keyed by frame.
    pub(crate) pending_checksums: HashMap<Frame, u128>,
    /// Desync-detection configuration.
    desync_detection: DesyncDetection,
}
/// Two endpoints are equal when they talk to the same peer address; no other
/// state is compared.
impl<T> PartialEq for UdpProtocol<T>
where
    T: Config,
{
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.peer_addr, &other.peer_addr)
    }
}
impl<T: Config> UdpProtocol<T> {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
mut handles: Vec<PlayerHandle>,
peer_addr: T::Address,
num_players: usize,
local_players: usize,
max_prediction: usize,
disconnect_timeout: Duration,
disconnect_notify_start: Duration,
fps: usize,
desync_detection: DesyncDetection,
) -> Self {
let mut magic = rand::random::<u16>();
while magic == 0 {
magic = rand::random::<u16>();
}
handles.sort_unstable();
let recv_player_num = handles.len();
let mut peer_connect_status = Vec::new();
for _ in 0..num_players {
peer_connect_status.push(ConnectionStatus::default());
}
let mut recv_inputs = HashMap::new();
recv_inputs.insert(NULL_FRAME, InputBytes::zeroed::<T>(recv_player_num));
Self {
num_players,
handles,
send_queue: VecDeque::new(),
event_queue: VecDeque::new(),
state: ProtocolState::Initializing,
sync_remaining_roundtrips: NUM_SYNC_PACKETS,
sync_random_requests: HashSet::new(),
running_last_quality_report: Instant::now(),
running_last_input_recv: Instant::now(),
disconnect_notify_sent: false,
disconnect_event_sent: false,
disconnect_timeout,
disconnect_notify_start,
shutdown_timeout: Instant::now(),
fps,
magic,
peer_addr,
remote_magic: 0,
peer_connect_status,
pending_output: VecDeque::with_capacity(PENDING_OUTPUT_SIZE),
last_acked_input: InputBytes::zeroed::<T>(local_players),
max_prediction,
recv_inputs,
time_sync_layer: TimeSync::new(),
local_frame_advantage: 0,
remote_frame_advantage: 0,
stats_start_time: 0,
packets_sent: 0,
bytes_sent: 0,
round_trip_time: 0,
last_send_time: Instant::now(),
last_sync_request_time: Instant::now(),
last_recv_time: Instant::now(),
pending_checksums: HashMap::new(),
desync_detection,
}
}
pub(crate) fn update_local_frame_advantage(&mut self, local_frame: Frame) {
if local_frame == NULL_FRAME || self.last_recv_frame() == NULL_FRAME {
return;
}
let ping = i32::try_from(self.round_trip_time / 2).unwrap_or(i32::MAX);
let remote_frame = self.last_recv_frame() + ((ping * self.fps as i32) / 1000);
self.local_frame_advantage = remote_frame - local_frame;
}
pub(crate) fn network_stats(&self) -> Result<NetworkStats, GgrsError> {
if self.state != ProtocolState::Synchronizing && self.state != ProtocolState::Running {
return Err(GgrsError::NotSynchronized);
}
let now = millis_since_epoch();
let seconds = (now - self.stats_start_time) / 1000;
if seconds == 0 {
return Err(GgrsError::NotEnoughData);
}
let total_bytes_sent = self.bytes_sent + (self.packets_sent * UDP_HEADER_SIZE);
let bps = total_bytes_sent / seconds as usize;
Ok(NetworkStats {
ping: self.round_trip_time,
send_queue_len: self.pending_output.len(),
kbps_sent: bps / 1024,
local_frames_behind: self.local_frame_advantage,
remote_frames_behind: self.remote_frame_advantage,
})
}
pub(crate) fn handles(&self) -> &Vec<PlayerHandle> {
&self.handles
}
pub(crate) fn is_synchronized(&self) -> bool {
self.state == ProtocolState::Running
|| self.state == ProtocolState::Disconnected
|| self.state == ProtocolState::Shutdown
}
pub(crate) fn is_running(&self) -> bool {
self.state == ProtocolState::Running
}
pub(crate) fn is_handling_message(&self, addr: &T::Address) -> bool {
self.peer_addr == *addr
}
pub(crate) fn peer_connect_status(&self, handle: PlayerHandle) -> ConnectionStatus {
self.peer_connect_status[handle]
}
pub(crate) fn disconnect(&mut self) {
if self.state == ProtocolState::Shutdown {
return;
}
self.state = ProtocolState::Disconnected;
self.shutdown_timeout = Instant::now().add(Duration::from_millis(UDP_SHUTDOWN_TIMER));
}
pub(crate) fn synchronize(&mut self) {
assert_eq!(self.state, ProtocolState::Initializing);
self.state = ProtocolState::Synchronizing;
self.sync_remaining_roundtrips = NUM_SYNC_PACKETS;
self.stats_start_time = millis_since_epoch();
self.send_sync_request();
}
pub(crate) fn average_frame_advantage(&self) -> i32 {
self.time_sync_layer.average_frame_advantage()
}
pub(crate) fn peer_addr(&self) -> T::Address {
self.peer_addr.clone()
}
pub(crate) fn poll(&mut self, connect_status: &[ConnectionStatus]) -> Drain<'_, Event<T>> {
let now = Instant::now();
match self.state {
ProtocolState::Synchronizing => {
if self.last_sync_request_time + SYNC_RETRY_INTERVAL < now {
self.send_sync_request();
}
}
ProtocolState::Running => {
if self.running_last_input_recv + RUNNING_RETRY_INTERVAL < now {
self.send_pending_output(connect_status);
self.running_last_input_recv = Instant::now();
}
if self.running_last_quality_report + QUALITY_REPORT_INTERVAL < now {
self.send_quality_report();
}
if self.last_send_time + KEEP_ALIVE_INTERVAL < now {
self.send_keep_alive();
}
if !self.disconnect_notify_sent
&& self.last_recv_time + self.disconnect_notify_start < now
{
let duration: Duration = self
.disconnect_timeout
.saturating_sub(self.disconnect_notify_start);
self.event_queue.push_back(Event::NetworkInterrupted {
disconnect_timeout: Duration::as_millis(&duration),
});
self.disconnect_notify_sent = true;
}
if !self.disconnect_event_sent
&& self.last_recv_time + self.disconnect_timeout < now
{
self.event_queue.push_back(Event::Disconnected);
self.disconnect_event_sent = true;
}
}
ProtocolState::Disconnected => {
if self.shutdown_timeout < Instant::now() {
self.state = ProtocolState::Shutdown;
}
}
ProtocolState::Initializing | ProtocolState::Shutdown => (),
}
self.event_queue.drain(..)
}
fn pop_pending_output(&mut self, ack_frame: Frame) {
while !self.pending_output.is_empty() {
if let Some(input) = self.pending_output.front() {
if input.frame <= ack_frame {
self.last_acked_input = self
.pending_output
.pop_front()
.expect("Expected input to exist");
} else {
break;
}
}
}
}
pub(crate) fn send_all_messages(
&mut self,
socket: &mut Box<dyn NonBlockingSocket<T::Address>>,
) {
if self.state == ProtocolState::Shutdown {
trace!(
"Protocol is shutting down; dropping {} messages",
self.send_queue.len()
);
self.send_queue.drain(..);
return;
}
if self.send_queue.is_empty() {
return;
}
trace!("Sending {} messages over socket", self.send_queue.len());
for msg in self.send_queue.drain(..) {
socket.send_to(&msg, &self.peer_addr);
}
}
pub(crate) fn send_input(
&mut self,
inputs: &HashMap<PlayerHandle, PlayerInput<T::Input>>,
connect_status: &[ConnectionStatus],
) {
if self.state != ProtocolState::Running {
return;
}
let endpoint_data = InputBytes::from_inputs::<T>(self.num_players, inputs);
self.time_sync_layer.advance_frame(
endpoint_data.frame,
self.local_frame_advantage,
self.remote_frame_advantage,
);
self.pending_output.push_back(endpoint_data);
if self.pending_output.len() > PENDING_OUTPUT_SIZE {
self.event_queue.push_back(Event::Disconnected);
}
self.send_pending_output(connect_status);
}
fn send_pending_output(&mut self, connect_status: &[ConnectionStatus]) {
let mut body = Input::default();
if let Some(input) = self.pending_output.front() {
assert!(
self.last_acked_input.frame == NULL_FRAME
|| self.last_acked_input.frame + 1 == input.frame
);
body.start_frame = input.frame;
body.bytes = encode(
&self.last_acked_input.bytes,
self.pending_output.iter().map(|gi| &gi.bytes),
);
trace!(
"Encoded {} bytes from {} pending output(s) into {} bytes",
{
let mut sum = 0;
for gi in &self.pending_output {
sum += gi.bytes.len();
}
sum
},
self.pending_output.len(),
body.bytes.len()
);
body.ack_frame = self.last_recv_frame();
body.disconnect_requested = self.state == ProtocolState::Disconnected;
connect_status.clone_into(&mut body.peer_connect_status);
self.queue_message(MessageBody::Input(body));
}
}
fn send_input_ack(&mut self) {
let body = InputAck {
ack_frame: self.last_recv_frame(),
};
self.queue_message(MessageBody::InputAck(body));
}
fn send_keep_alive(&mut self) {
self.queue_message(MessageBody::KeepAlive);
}
fn send_sync_request(&mut self) {
self.last_sync_request_time = Instant::now();
let random_number = rand::random::<u32>();
self.sync_random_requests.insert(random_number);
let body = SyncRequest {
random_request: random_number,
};
self.queue_message(MessageBody::SyncRequest(body));
}
fn send_quality_report(&mut self) {
self.running_last_quality_report = Instant::now();
let body = QualityReport {
frame_advantage: i16::try_from(
self.local_frame_advantage
.clamp(i32::from(i16::MIN), i32::from(i16::MAX)),
)
.expect("local_frame_advantage should have been clamped into the range of an i16"),
ping: millis_since_epoch(),
};
self.queue_message(MessageBody::QualityReport(body));
}
fn queue_message(&mut self, body: MessageBody) {
trace!("Queuing message to {:?}: {:?}", self.peer_addr, body);
let header = MessageHeader { magic: self.magic };
let msg = Message { header, body };
self.packets_sent += 1;
self.last_send_time = Instant::now();
self.bytes_sent += std::mem::size_of_val(&msg);
self.send_queue.push_back(msg);
}
pub(crate) fn handle_message(&mut self, msg: &Message) {
trace!("Handling message from {:?}: {:?}", self.peer_addr, msg);
if self.state == ProtocolState::Shutdown {
trace!("Protocol is shutting down; ignoring message");
return;
}
if self.remote_magic != 0 && msg.header.magic != self.remote_magic {
trace!("Received message with wrong magic; ignoring");
return;
}
self.last_recv_time = Instant::now();
if self.disconnect_notify_sent && self.state == ProtocolState::Running {
trace!("Received message on interrupted protocol; sending NetworkResumed event");
self.disconnect_notify_sent = false;
self.event_queue.push_back(Event::NetworkResumed);
}
match &msg.body {
MessageBody::SyncRequest(body) => self.on_sync_request(*body),
MessageBody::SyncReply(body) => self.on_sync_reply(msg.header, *body),
MessageBody::Input(body) => self.on_input(body),
MessageBody::InputAck(body) => self.on_input_ack(*body),
MessageBody::QualityReport(body) => self.on_quality_report(body),
MessageBody::QualityReply(body) => self.on_quality_reply(body),
MessageBody::ChecksumReport(body) => self.on_checksum_report(body),
MessageBody::KeepAlive => (),
}
}
fn on_sync_request(&mut self, body: SyncRequest) {
let reply_body = SyncReply {
random_reply: body.random_request,
};
self.queue_message(MessageBody::SyncReply(reply_body));
}
fn on_sync_reply(&mut self, header: MessageHeader, body: SyncReply) {
if self.state != ProtocolState::Synchronizing {
return;
}
if !self.sync_random_requests.remove(&body.random_reply) {
return;
}
self.sync_remaining_roundtrips -= 1;
if self.sync_remaining_roundtrips > 0 {
let evt = Event::Synchronizing {
total: NUM_SYNC_PACKETS,
count: NUM_SYNC_PACKETS - self.sync_remaining_roundtrips,
};
self.event_queue.push_back(evt);
self.send_sync_request();
} else {
self.state = ProtocolState::Running;
self.event_queue.push_back(Event::Synchronized);
self.remote_magic = header.magic;
}
}
fn on_input(&mut self, body: &Input) {
self.pop_pending_output(body.ack_frame);
if body.disconnect_requested {
if self.state != ProtocolState::Disconnected && !self.disconnect_event_sent {
self.event_queue.push_back(Event::Disconnected);
self.disconnect_event_sent = true;
}
} else {
for i in 0..self.peer_connect_status.len() {
self.peer_connect_status[i].disconnected = body.peer_connect_status[i].disconnected
|| self.peer_connect_status[i].disconnected;
self.peer_connect_status[i].last_frame = std::cmp::max(
self.peer_connect_status[i].last_frame,
body.peer_connect_status[i].last_frame,
);
}
}
assert!(
self.last_recv_frame() == NULL_FRAME || self.last_recv_frame() + 1 >= body.start_frame
);
let decode_frame = if self.last_recv_frame() == NULL_FRAME {
NULL_FRAME
} else {
body.start_frame - 1
};
if let Some(decode_inp) = self.recv_inputs.get(&decode_frame) {
self.running_last_input_recv = Instant::now();
let recv_inputs = match decode(&decode_inp.bytes, &body.bytes) {
Ok(inputs) => inputs,
Err(e) => {
warn!("Failed to decode input packet, discarding: {e}");
return;
}
};
for (i, inp) in recv_inputs.into_iter().enumerate() {
let inp_frame = body.start_frame + i as i32;
if inp_frame <= self.last_recv_frame() {
continue;
}
let input_data = InputBytes {
frame: inp_frame,
bytes: inp,
};
let player_inputs = input_data.to_player_inputs::<T>(self.handles.len());
self.recv_inputs.insert(input_data.frame, input_data);
for (i, player_input) in player_inputs.into_iter().enumerate() {
self.event_queue.push_back(Event::Input {
input: player_input,
player: self.handles[i],
});
}
}
self.send_input_ack();
let last_recv_frame = self.last_recv_frame();
self.recv_inputs
.retain(|&k, _| k >= last_recv_frame - 2 * self.max_prediction as i32);
}
}
fn on_input_ack(&mut self, body: InputAck) {
self.pop_pending_output(body.ack_frame);
}
fn on_quality_report(&mut self, body: &QualityReport) {
self.remote_frame_advantage = i32::from(body.frame_advantage);
let reply_body = QualityReply { pong: body.ping };
self.queue_message(MessageBody::QualityReply(reply_body));
}
fn on_quality_reply(&mut self, body: &QualityReply) {
let millis = millis_since_epoch();
self.round_trip_time = millis.saturating_sub(body.pong);
}
fn on_checksum_report(&mut self, body: &ChecksumReport) {
let interval = if let DesyncDetection::On { interval } = self.desync_detection {
interval
} else {
debug_assert!(
false,
"Received checksum report, but desync detection is off. Check
that configuration is consistent between peers."
);
1
};
if self.pending_checksums.len() >= MAX_CHECKSUM_HISTORY_SIZE {
let oldest_frame_to_keep =
body.frame - (MAX_CHECKSUM_HISTORY_SIZE as i32 - 1) * interval as i32;
self.pending_checksums
.retain(|&frame, _| frame >= oldest_frame_to_keep);
}
self.pending_checksums.insert(body.frame, body.checksum);
}
fn last_recv_frame(&self) -> Frame {
match self.recv_inputs.iter().max_by_key(|&(k, _)| k) {
Some((k, _)) => *k,
None => NULL_FRAME,
}
}
pub(crate) fn send_checksum_report(&mut self, frame_to_send: Frame, checksum: u128) {
let body = ChecksumReport {
frame: frame_to_send,
checksum,
};
self.queue_message(MessageBody::ChecksumReport(body));
}
}