lightyear_deterministic_replication 0.27.0

Primitives for deterministic replication (as opposed to state replication) in the lightyear networking library
Documentation
//! Utilities for computing checksums for data integrity verification.
//!
//! The clients will send checksums at regular intervals to the server, which will verify them against its own computed checksums.
//!
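//! Note: we don't have a good way to guarantee that entities are iterated in a stable order on both client and server.
//! Because of this, we compute an order-independent checksum: each component's data is hashed on its own and the resulting hashes are XOR-ed together.
//! A side effect of XOR-ing is that pairs of identical hashes cancel out: two entities holding exactly the same value of a component contribute nothing to the checksum.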

use crate::archetypes::ChecksumWorld;
#[cfg(all(feature = "client", feature = "replication"))]
use crate::late_join::CatchUpManager;
use crate::plugin::DeterministicReplicationPlugin;
use alloc::collections::BTreeMap;
#[cfg(feature = "server")]
use bevy_app::FixedLast;
use bevy_app::{App, Plugin, PostUpdate};
use bevy_ecs::prelude::*;
use core::hash::Hasher;
#[cfg(feature = "client")]
use lightyear_connection::client::Client;
#[cfg(feature = "server")]
use lightyear_connection::client::Connected;
use lightyear_connection::direction::NetworkDirection;
#[cfg(feature = "server")]
use lightyear_connection::server::Started;
#[cfg(feature = "server")]
use lightyear_core::id::RemoteId;
use lightyear_core::prelude::LocalTimeline;
use lightyear_core::tick::Tick;
#[cfg(feature = "client")]
use lightyear_inputs::InputChannel;
#[cfg(feature = "client")]
use lightyear_inputs::client::InputSystems;
#[cfg(feature = "server")]
use lightyear_link::server::{LinkOf, Server};
#[cfg(feature = "client")]
use lightyear_messages::plugin::MessageSystems;
use lightyear_messages::prelude::AppMessageExt;
#[cfg(feature = "client")]
use lightyear_messages::prelude::MessageSender;
#[cfg(feature = "server")]
use lightyear_messages::receive::MessageReceiver;
#[cfg(feature = "client")]
use lightyear_prediction::manager::{LastConfirmedInput, StateRollbackMetadata};
#[cfg(feature = "client")]
use lightyear_sync::prelude::{InputTimeline, IsSynced};
use serde::{Deserialize, Serialize};
#[cfg(feature = "server")]
use tracing::error;
use tracing::{debug, trace};

/// Server-side record of computed checksums, keyed by tick.
///
/// Incoming client checksums are looked up here by tick and compared with
/// the value the server computed for that same tick.
#[derive(Default, Debug, Component)]
pub struct ChecksumHistory {
    /// Tick -> XOR-combined checksum computed by the server for that tick.
    history: BTreeMap<Tick, u64>,
}

/// Plugin that can be added to clients to compute and send checksums for all deterministic entities with hashable components.
///
/// The server will receive these checksums and verify them against its own computed checksums.
/// If a checksum does not match, it indicates a desync between the client and server.
///
/// Pair it with `ChecksumReceivePlugin` on the server. Adding this plugin also adds
/// [`DeterministicReplicationPlugin`] if it has not been added yet.
#[cfg(feature = "client")]
pub struct ChecksumSendPlugin;

#[cfg(feature = "client")]
impl ChecksumSendPlugin {
    /// Compute a checksum over all deterministic entities' hashable
    /// components at `LastConfirmedInput.tick` and send it to the server.
    fn compute_and_send_checksum(
        mut world: ChecksumWorld<'_, '_, true>,
        local_timeline: Res<LocalTimeline>,
        client: Single<
            (&LastConfirmedInput, &mut MessageSender<ChecksumMessage>),
            (With<Client>, With<IsSynced<InputTimeline>>),
        >,
        #[cfg(feature = "replication")] catchup_manager: Option<
            Single<&CatchUpManager, With<Client>>,
        >,
        state_metadata: Res<StateRollbackMetadata>,
    ) {
        let mut checksum = 0u64;
        let current_tick = local_timeline.tick();
        let (last_confirmed_input, mut sender) = client.into_inner();
        let tick = last_confirmed_input.tick.get();
        // only compute the checksum when we have received remote inputs
        if tick > current_tick {
            return;
        }
        #[cfg(feature = "replication")]
        // Skip while catch-up is running. The client is intentionally hashing
        // pre-catch-up state until the bundled snapshot has been replayed.
        if catchup_manager.is_some_and(|manager| manager.suppresses_checksums()) {
            return;
        }
        // Skip if a one-shot forced rollback is scheduled but not yet
        // consumed. Until the rollback has replayed from the bundled
        // snapshot tick, the client is intentionally hashing pre-catch-up
        // state that the server should not compare against.
        if state_metadata.forced_rollback_tick().is_some() {
            return;
        }

        world.update_archetypes();
        // SAFETY: world.update_archetypes() has been called
        unsafe { world.iter_archetypes() }.for_each(|(archetype, checksum_archetype)| {
            // TODO: guarantee stable entity iteration order across peers.
            archetype.entities().iter().for_each(|entity| {
                checksum_archetype.components.iter().for_each(|(component_id, storage_type)| {
                    trace!("Adding component {:?} from entity {:?} to checksum for tick {:?}",
                        component_id, entity.id(), tick);
                    // SAFETY: the way we constructed the archetypes guarantees that the component exists on the entity and we have unique write access
                    let history_ptr = unsafe {
                        lightyear_utils::ecs::get_component_unchecked_mut(world.world, entity, archetype.table_id(), *storage_type, *component_id)
                    };
                    let (hash_fn, pop_until_tick_and_hash_fn) = world.state.hash_fns.get(component_id).expect("Component in checksum archetype must have a hash function registered");

                    let mut hasher = seahash::SeaHasher::default();
                    pop_until_tick_and_hash_fn.unwrap()(history_ptr, tick, &mut hasher, hash_fn.inner);
                    let hash = hasher.finish();
                    checksum ^= hash; // XOR the hashes together to get an order-independent checksum
                });
            });
        });
        debug!(
            ?current_tick,
            "Computed checksum for LastConfirmedInput tick {:?}: {:016x}", tick, checksum
        );

        sender.send::<InputChannel>(ChecksumMessage { tick, checksum });
    }
}

/// Checksum computed by a client for one tick and sent to the server for
/// validation against the server's own checksum for that tick.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChecksumMessage {
    /// Tick the checksum was computed at (the client's `LastConfirmedInput` tick).
    pub tick: Tick,
    /// XOR of the per-component hashes of all deterministic entities at `tick`.
    pub checksum: u64,
}

#[cfg(feature = "client")]
impl Plugin for ChecksumSendPlugin {
    fn build(&self, app: &mut App) {
        if !app.is_plugin_added::<DeterministicReplicationPlugin>() {
            app.add_plugins(DeterministicReplicationPlugin);
        }

        // we need the LastConfirmedInput to compute the checksums
        app.register_required_components::<InputTimeline, LastConfirmedInput>();

        if !app.is_message_registered::<ChecksumMessage>() {
            app.register_message::<ChecksumMessage>()
                .add_direction(NetworkDirection::ClientToServer);
        }
    }

    fn finish(&self, app: &mut App) {
        app.add_systems(
            PostUpdate,
            ChecksumSendPlugin::compute_and_send_checksum
                // the LastConfirmedInput must be updated before we compute the checksum
                .after(InputSystems::UpdateRemoteInputTicks)
                .before(MessageSystems::Send),
        );
    }
}

/// Plugin that can be added to the server to receive and validate checksums sent by clients.
///
/// The server needs to also run the simulation to be able to compute its own checksums for comparison.
/// Each server tick's checksum is stored in the server entity's [`ChecksumHistory`]. Incoming
/// client checksums are compared against it, and mismatches are logged as errors.
///
/// Adding this plugin also adds [`DeterministicReplicationPlugin`] if it has not been added yet.
#[cfg(feature = "server")]
pub struct ChecksumReceivePlugin;

#[cfg(feature = "server")]
impl ChecksumReceivePlugin {
    /// Compute a checksum over all deterministic entities' hashable
    /// components at the current server tick and store it for later
    /// comparison against incoming client checksums.
    fn compute_and_store_checksum(
        mut world: ChecksumWorld<'_, '_, false>,
        timeline: Res<LocalTimeline>,
        server: Single<&mut ChecksumHistory, With<Started>>,
    ) {
        let mut checksum = 0u64;
        let tick = timeline.tick();
        let mut history = server.into_inner();

        world.update_archetypes();
        // SAFETY: world.update_archetypes() has been called
        unsafe { world.iter_archetypes() }.for_each(|(archetype, checksum_archetype)| {
            // TODO: guarantee stable entity iteration order across peers.
            archetype.entities().iter().for_each(|entity| {
                checksum_archetype
                    .components
                    .iter()
                    .for_each(|(component_id, storage_type)| {
                        trace!(
                            "Adding component {:?} from entity {:?} to checksum for tick {:?}",
                            component_id,
                            entity.id(),
                            tick
                        );
                        // SAFETY: the way we constructed the archetypes guarantees that the component exists on the entity and we have unique write access
                        let component_ptr = unsafe {
                            lightyear_utils::ecs::get_component_unchecked(
                                world.world,
                                entity,
                                archetype.table_id(),
                                *storage_type,
                                *component_id,
                            )
                        };
                        let (hash_fn, _) = world.state.hash_fns.get(component_id).expect(
                            "Component in checksum archetype must have a hash function registered",
                        );
                        let mut hasher = seahash::SeaHasher::default();
                        hash_fn.hash_component(component_ptr, &mut hasher);
                        let hash = hasher.finish();
                        checksum ^= hash; // XOR the hashes together to get an order-independent checksum
                    });
            });
        });

        debug!("Computed checksum for tick {:?}: {:016x}", tick, checksum);

        history.history.insert(tick, checksum);
    }

    fn receive_checksum_message(
        mut messages: Query<
            (&mut MessageReceiver<ChecksumMessage>, &LinkOf, &RemoteId),
            With<Connected>,
        >,
        server: Query<&ChecksumHistory, (With<Server>, With<Started>)>,
    ) {
        messages.iter_mut().for_each(|(mut receiver, link_of, remote_id)| {
            if let Ok(history) = server.get(link_of.server) {
                receiver.receive().for_each(|message| {
                    let Some(&expected) = history.history.get(&message.tick) else {
                        return;
                    };
                    if expected == message.checksum {
                        debug!("Checksum match from client {:?} at tick {:?}: {:016x}", remote_id, message.tick, message.checksum);
                    } else if message.checksum != 0 {
                        error!("Checksum mismatch from client {:?} at tick {:?}: expected {:016x}, got {:016x}", remote_id, message.tick, expected, message.checksum);
                    }
                })
            }
        })
    }

    fn clean_history(
        timeline: Res<LocalTimeline>,
        history: Single<&mut ChecksumHistory, (With<Server>, With<Started>)>,
    ) {
        let tick = timeline.tick();
        let mut history = history.into_inner();
        // keep only the last 30 ticks of history
        history.history.retain(|t, _| *t >= tick - 30);
    }
}

#[cfg(feature = "server")]
impl Plugin for ChecksumReceivePlugin {
    fn build(&self, app: &mut App) {
        if !app.is_plugin_added::<DeterministicReplicationPlugin>() {
            app.add_plugins(DeterministicReplicationPlugin);
        }

        // the server will check the checksum validity
        app.register_required_components::<Server, ChecksumHistory>();

        if !app.is_message_registered::<ChecksumMessage>() {
            app.register_message::<ChecksumMessage>()
                .add_direction(NetworkDirection::ClientToServer);
        }

        app.add_systems(
            PostUpdate,
            (
                ChecksumReceivePlugin::clean_history,
                ChecksumReceivePlugin::receive_checksum_message,
            ),
        );
    }

    fn finish(&self, app: &mut App) {
        app.add_systems(FixedLast, ChecksumReceivePlugin::compute_and_store_checksum);
    }
}