use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Mutex;
use async_trait::async_trait;
use crate::error::{EncodeError, TransportError, WorldError};
use crate::events::{ComponentUpdate, NetworkEvent, ReplicationEvent};
use crate::traits::{Encoder, GameTransport, WorldState};
use crate::types::{ClientId, ComponentKind, LocalId, NetworkId};
/// In-memory transport double that records outbound traffic and replays
/// injected inbound events instead of touching a real network.
///
/// Every field sits behind its own [`Mutex`], which is what lets the helper
/// methods and the sending half of `GameTransport` mutate state through `&self`.
#[derive(Debug, Default)]
pub struct MockTransport {
/// Ids currently treated as connected; `send_unreliable` / `send_reliable`
/// reject any id that is not in this set, and broadcasts fan out to it.
pub connected_clients: Mutex<HashSet<ClientId>>,
/// Payloads recorded by `send_unreliable` and `broadcast_unreliable`,
/// grouped per recipient in the order they were sent.
pub per_client_unreliable: Mutex<HashMap<ClientId, Vec<Vec<u8>>>>,
/// Payloads recorded by `send_reliable`, grouped per recipient in send order.
pub per_client_reliable: Mutex<HashMap<ClientId, Vec<Vec<u8>>>>,
/// Events queued by `inject_event`; `poll_events` drains them front to back.
pub inbound_queue: Mutex<VecDeque<NetworkEvent>>,
}
impl MockTransport {
    /// Creates a transport with no connected clients, no recorded traffic and
    /// an empty inbound queue.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `event` to the back of the inbound queue so that a later
    /// `poll_events` call returns it.
    ///
    /// # Panics
    /// Panics if the inbound-queue mutex is poisoned.
    pub fn inject_event(&self, event: NetworkEvent) {
        let mut inbound = self.inbound_queue.lock().unwrap();
        inbound.push_back(event);
    }

    /// Removes and returns every unreliable payload recorded for `cid`, in
    /// send order. Returns an empty vector when nothing was recorded.
    ///
    /// # Panics
    /// Panics if the unreliable-outbox mutex is poisoned.
    #[must_use]
    pub fn take_unreliable(&self, cid: ClientId) -> Vec<Vec<u8>> {
        let mut outbox = self.per_client_unreliable.lock().unwrap();
        outbox.remove(&cid).unwrap_or_default()
    }

    /// Removes and returns every reliable payload recorded for `cid`, in
    /// send order. Returns an empty vector when nothing was recorded.
    ///
    /// # Panics
    /// Panics if the reliable-outbox mutex is poisoned.
    #[must_use]
    pub fn take_reliable(&self, cid: ClientId) -> Vec<Vec<u8>> {
        let mut outbox = self.per_client_reliable.lock().unwrap();
        outbox.remove(&cid).unwrap_or_default()
    }

    /// Marks `client_id` as connected. Connecting an already-connected id is
    /// a no-op because the ids live in a set.
    ///
    /// # Panics
    /// Panics if the connected-clients mutex is poisoned.
    pub fn connect(&self, client_id: ClientId) {
        let mut roster = self.connected_clients.lock().unwrap();
        roster.insert(client_id);
    }

    /// Marks `client_id` as disconnected and discards any payloads that were
    /// recorded for it but not yet taken.
    ///
    /// # Panics
    /// Panics if any of the three mutexes involved is poisoned.
    pub fn disconnect(&self, client_id: ClientId) {
        // Every guard below is a statement-scoped temporary, so at most one
        // lock is held at any moment.
        self.connected_clients.lock().unwrap().remove(&client_id);
        for outbox in [&self.per_client_unreliable, &self.per_client_reliable] {
            outbox.lock().unwrap().remove(&client_id);
        }
    }
}
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl GameTransport for MockTransport {
    /// Records `data` in the unreliable outbox of `client_id`.
    ///
    /// # Errors
    /// * `TransportError::ClientNotConnected` when `client_id` is not in the
    ///   connected set (checked first).
    /// * `TransportError::PayloadTooLarge` when `data` is longer than
    ///   `crate::MAX_SAFE_PAYLOAD_SIZE`.
    ///
    /// # Panics
    /// Panics if one of the mutexes involved is poisoned.
    async fn send_unreliable(
        &self,
        client_id: ClientId,
        data: &[u8],
    ) -> Result<(), TransportError> {
        let is_connected = self.connected_clients.lock().unwrap().contains(&client_id);
        if !is_connected {
            return Err(TransportError::ClientNotConnected(client_id));
        }

        let size = data.len();
        if size > crate::MAX_SAFE_PAYLOAD_SIZE {
            return Err(TransportError::PayloadTooLarge {
                size,
                max: crate::MAX_SAFE_PAYLOAD_SIZE,
            });
        }

        let mut outbox = self.per_client_unreliable.lock().unwrap();
        outbox.entry(client_id).or_default().push(data.to_vec());
        Ok(())
    }

    /// Records `data` in the reliable outbox of `client_id`.
    ///
    /// # Errors
    /// * `TransportError::ClientNotConnected` when `client_id` is not in the
    ///   connected set (checked first).
    /// * `TransportError::PayloadTooLarge` when `data` is longer than
    ///   65535 bytes.
    ///
    /// # Panics
    /// Panics if one of the mutexes involved is poisoned.
    async fn send_reliable(&self, client_id: ClientId, data: &[u8]) -> Result<(), TransportError> {
        // Upper bound accepted for a single reliable payload.
        const MAX_RELIABLE_PAYLOAD: usize = 65535;

        let is_connected = self.connected_clients.lock().unwrap().contains(&client_id);
        if !is_connected {
            return Err(TransportError::ClientNotConnected(client_id));
        }

        let size = data.len();
        if size > MAX_RELIABLE_PAYLOAD {
            return Err(TransportError::PayloadTooLarge {
                size,
                max: MAX_RELIABLE_PAYLOAD,
            });
        }

        let mut outbox = self.per_client_reliable.lock().unwrap();
        outbox.entry(client_id).or_default().push(data.to_vec());
        Ok(())
    }

    /// Records a copy of `data` in the unreliable outbox of every connected
    /// client. With no clients connected this succeeds and records nothing.
    ///
    /// # Errors
    /// `TransportError::PayloadTooLarge` when `data` is longer than
    /// `crate::MAX_SAFE_PAYLOAD_SIZE`.
    ///
    /// # Panics
    /// Panics if one of the mutexes involved is poisoned.
    async fn broadcast_unreliable(&self, data: &[u8]) -> Result<(), TransportError> {
        let size = data.len();
        if size > crate::MAX_SAFE_PAYLOAD_SIZE {
            return Err(TransportError::PayloadTooLarge {
                size,
                max: crate::MAX_SAFE_PAYLOAD_SIZE,
            });
        }

        // Lock order: connected set first, then the outbox map. Both guards
        // stay alive for the whole fan-out.
        let recipients = self.connected_clients.lock().unwrap();
        let mut outboxes = self.per_client_unreliable.lock().unwrap();
        for recipient in recipients.iter().copied() {
            outboxes.entry(recipient).or_default().push(data.to_vec());
        }
        Ok(())
    }

    /// Drains the inbound queue and returns its events in FIFO order.
    /// This implementation never returns `Err`.
    ///
    /// # Panics
    /// Panics if the inbound-queue mutex is poisoned.
    async fn poll_events(&mut self) -> Result<Vec<NetworkEvent>, TransportError> {
        // `&mut self` already guarantees exclusive access, so the queue can be
        // reached without taking the lock.
        let inbound = self.inbound_queue.get_mut().unwrap();
        let drained: Vec<NetworkEvent> = inbound.drain(..).collect();
        Ok(drained)
    }

    /// Returns how many clients are currently in the connected set.
    ///
    /// # Panics
    /// Panics if the connected-clients mutex is poisoned.
    async fn connected_client_count(&self) -> usize {
        let roster = self.connected_clients.lock().unwrap();
        roster.len()
    }
}
/// In-memory world double that hands out sequential ids and records what
/// the code under test asked it to do.
///
/// `Default` is implemented by hand (rather than derived) so that
/// `MockWorldState::default()` and [`MockWorldState::new`] agree on the first
/// id handed out; a derived impl would start the counter at `0`.
#[derive(Debug)]
pub struct MockWorldState {
    /// Raw value used for the next spawned `NetworkId` / `LocalId` pair.
    next_id: u64,
    /// Lookup from network id to local id for every live entity.
    pub forward_bimap: HashMap<NetworkId, LocalId>,
    /// Inverse of `forward_bimap`: lookup from local id to network id.
    pub reverse_bimap: HashMap<LocalId, NetworkId>,
    /// Events queued with `queue_delta`, waiting to be extracted.
    pub pending_deltas: Mutex<Vec<ReplicationEvent>>,
    /// Every update batch the world was asked to apply, in arrival order.
    pub applied_updates: Mutex<Vec<(ClientId, ComponentUpdate)>>,
}

impl MockWorldState {
    /// Raw id given to the first spawned entity.
    // NOTE(review): presumably 0 is reserved as a "null" id elsewhere — confirm.
    const FIRST_ID: u64 = 1;

    /// Creates an empty world whose first spawned entity receives id `1`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            next_id: Self::FIRST_ID,
            forward_bimap: HashMap::new(),
            reverse_bimap: HashMap::new(),
            pending_deltas: Mutex::new(Vec::new()),
            applied_updates: Mutex::new(Vec::new()),
        }
    }

    /// Appends `event` to the list of pending deltas.
    ///
    /// # Panics
    /// Panics if the pending-deltas mutex is poisoned.
    pub fn queue_delta(&self, event: ReplicationEvent) {
        self.pending_deltas.lock().unwrap().push(event);
    }
}

impl Default for MockWorldState {
    /// Equivalent to [`MockWorldState::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl WorldState for MockWorldState {
    /// Allocates the next sequential id, registers it in both lookup maps
    /// (network id and local id share the same raw value) and returns the
    /// network id.
    fn spawn_networked(&mut self) -> NetworkId {
        let raw = self.next_id;
        self.next_id += 1;

        let (network_id, local_id) = (NetworkId(raw), LocalId(raw));
        self.forward_bimap.insert(network_id, local_id);
        self.reverse_bimap.insert(local_id, network_id);
        network_id
    }

    /// Same as [`Self::spawn_networked`]; the mock does not track ownership,
    /// so the client id is ignored.
    fn spawn_networked_for(&mut self, _client_id: ClientId) -> NetworkId {
        self.spawn_networked()
    }

    /// Removes `network_id` from both lookup maps.
    ///
    /// # Errors
    /// `WorldError::EntityNotFound` when `network_id` is not registered.
    fn despawn_networked(&mut self, network_id: NetworkId) -> Result<(), WorldError> {
        match self.forward_bimap.remove(&network_id) {
            Some(local_id) => {
                self.reverse_bimap.remove(&local_id);
                Ok(())
            }
            None => Err(WorldError::EntityNotFound(network_id)),
        }
    }

    /// Looks up the local id registered for `network_id`, if any.
    fn get_local_id(&self, network_id: NetworkId) -> Option<LocalId> {
        let found = self.forward_bimap.get(&network_id);
        found.copied()
    }

    /// Looks up the network id registered for `local_id`, if any.
    fn get_network_id(&self, local_id: LocalId) -> Option<NetworkId> {
        let found = self.reverse_bimap.get(&local_id);
        found.copied()
    }

    /// Returns every pending delta in queue order and leaves the queue empty.
    ///
    /// # Panics
    /// Panics if the pending-deltas mutex is poisoned.
    fn extract_deltas(&mut self) -> Vec<ReplicationEvent> {
        // Exclusive access through `&mut self` makes taking the lock unnecessary.
        let pending = self.pending_deltas.get_mut().unwrap();
        std::mem::take(pending)
    }

    /// Appends a clone of every entry in `updates` to the applied-updates log.
    ///
    /// # Panics
    /// Panics if the applied-updates mutex is poisoned.
    fn apply_updates(&mut self, updates: &[(ClientId, ComponentUpdate)]) {
        let log = self.applied_updates.get_mut().unwrap();
        log.extend_from_slice(updates);
    }

    /// No-op: the mock has no simulation step.
    fn simulate(&mut self) {}

    /// No-op: the mock spawns nothing for stress tests.
    fn stress_test(&mut self, _count: u16, _rotate: bool) {}

    /// Spawns a plain networked entity; kind, position and rotation are ignored.
    fn spawn_kind(&mut self, _kind: u16, _x: f32, _y: f32, _rot: f32) -> NetworkId {
        self.spawn_networked()
    }

    /// No-op: registered entities, pending deltas and the applied-updates log
    /// are all left untouched.
    fn clear_world(&mut self) {}
}
/// Stateless encoder double that uses a trivial fixed-layout wire format.
#[derive(Debug, Default)]
pub struct MockEncoder;

impl MockEncoder {
    /// First byte of every frame written by the mock's `encode`.
    pub const MOCK_SENTINEL: u8 = 0xAE;

    /// First byte that makes the mock's `decode` fail on purpose, so tests
    /// can exercise error paths.
    pub const MOCK_ERROR_BYTE: u8 = 0xFF;

    /// Creates the encoder; it carries no state.
    #[must_use]
    pub fn new() -> Self {
        MockEncoder
    }
}
impl Encoder for MockEncoder {
fn encode(&self, event: &ReplicationEvent, buffer: &mut [u8]) -> Result<usize, EncodeError> {
let required = 1 + 8 + 2 + 8 + event.payload.len();
if buffer.len() < required {
return Err(EncodeError::BufferOverflow {
needed: required,
available: buffer.len(),
});
}
buffer[0] = Self::MOCK_SENTINEL;
buffer[1..9].copy_from_slice(&event.network_id.0.to_le_bytes());
buffer[9..11].copy_from_slice(&event.component_kind.0.to_le_bytes());
buffer[11..19].copy_from_slice(&event.tick.to_le_bytes());
buffer[19..required].copy_from_slice(&event.payload);
Ok(required)
}
fn decode(&self, buffer: &[u8]) -> Result<ComponentUpdate, EncodeError> {
if buffer.len() < 19 {
return Err(EncodeError::MalformedPayload {
offset: 0,
message: "Buffer too small for mock header".to_string(),
});
}
if buffer[0] == Self::MOCK_ERROR_BYTE {
return Err(EncodeError::MalformedPayload {
offset: 0,
message: "Triggered artificial MOCK_ERROR_BYTE".to_string(),
});
}
if buffer[0] != Self::MOCK_SENTINEL {
return Err(EncodeError::MalformedPayload {
offset: 0,
message: format!(
"Invalid sentinel: expected {:#x}, got {:#x}",
Self::MOCK_SENTINEL,
buffer[0]
),
});
}
let network_id = u64::from_le_bytes(buffer[1..9].try_into().unwrap());
let component_kind = u16::from_le_bytes(buffer[9..11].try_into().unwrap());
let tick = u64::from_le_bytes(buffer[11..19].try_into().unwrap());
Ok(ComponentUpdate {
network_id: NetworkId(network_id),
component_kind: ComponentKind(component_kind),
payload: buffer[19..].to_vec(),
tick,
})
}
fn encode_event(&self, event: &NetworkEvent) -> Result<Vec<u8>, EncodeError> {
match event {
NetworkEvent::Auth { .. } => Ok(vec![b'A']),
_ => Err(EncodeError::Io(std::io::Error::other(format!(
"MockEncoder: encoding not implemented for {event:?}"
)))),
}
}
fn decode_event(&self, data: &[u8]) -> Result<NetworkEvent, EncodeError> {
if data.is_empty() {
return Err(EncodeError::MalformedPayload {
offset: 0,
message: "Empty event data".to_string(),
});
}
if data[0] == b'A' {
return Ok(NetworkEvent::Auth {
session_token: "mock_token".to_string(),
});
}
Err(EncodeError::MalformedPayload {
offset: 0,
message: format!("Unexpected first byte for mock event: {:#x}", data[0]),
})
}
fn max_encoded_size(&self) -> usize {
crate::MAX_SAFE_PAYLOAD_SIZE
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time checks: each helper only type-checks when `T` implements
    // the named trait.
    const fn assert_transport_bounds<T: GameTransport>() {}
    const fn assert_world_bounds<T: WorldState>() {}
    const fn assert_encoder_bounds<T: Encoder>() {}

    /// Every mock must implement the trait it stands in for.
    #[test]
    fn test_compile_bounds() {
        assert_transport_bounds::<MockTransport>();
        assert_world_bounds::<MockWorldState>();
        assert_encoder_bounds::<MockEncoder>();
    }

    /// Drives the three mocks through 1000 ticks of a poll / apply / simulate /
    /// replicate loop and then checks the state they accumulated.
    ///
    /// A client connects every 100th tick and an entity is spawned (and its
    /// delta broadcast) every 50th tick.
    #[tokio::test]
    async fn test_tick_loop_integration() {
        let mut transport = MockTransport::new();
        let mut world = MockWorldState::new();
        let encoder = MockEncoder::new();
        // Scratch buffer reused for every encoded delta; only the encoded
        // prefix is ever sent.
        let mut buf = vec![0u8; 1500];

        for tick in 1..=1000 {
            let mut events = transport.poll_events().await.unwrap();
            if tick % 100 == 0 {
                let cid = ClientId(tick);
                transport.connect(cid);
                events.push(NetworkEvent::ClientConnected(cid));
            }

            let mut updates = Vec::new();
            for event in events {
                if let NetworkEvent::UnreliableMessage { data, client_id } = event
                    && let Ok(update) = encoder.decode(&data)
                {
                    updates.push((client_id, update));
                }
            }
            world.apply_updates(&updates);
            world.simulate();

            if tick % 50 == 0 {
                let ent = world.spawn_networked();
                world.queue_delta(ReplicationEvent {
                    network_id: ent,
                    component_kind: ComponentKind(1),
                    payload: vec![1, 2, 3],
                    tick,
                });
            }

            for delta in world.extract_deltas() {
                let size = encoder.encode(&delta, &mut buf).unwrap();
                transport
                    .broadcast_unreliable(&buf[..size])
                    .await
                    .expect("broadcasting a small mock frame must succeed");
            }
        }

        // Ticks 100, 200, ..., 1000 each connected one client.
        assert_eq!(transport.connected_client_count().await, 10);
        // Ticks 50, 100, ..., 1000 each spawned one entity.
        assert_eq!(world.forward_bimap.len(), 20);
        assert_eq!(world.reverse_bimap.len(), 20);
        // No inbound messages were injected, so nothing was applied.
        assert!(world.applied_updates.lock().unwrap().is_empty());

        // The first client saw every broadcast from tick 100 onwards, the
        // last client only the one from tick 1000.
        let earliest = transport.take_unreliable(ClientId(100));
        assert_eq!(earliest.len(), 19);
        let latest = transport.take_unreliable(ClientId(1000));
        assert_eq!(latest.len(), 1);

        // The first frame the first client received is the tick-100 delta of
        // the second spawned entity, and it survives a decode round trip.
        let first = encoder
            .decode(&earliest[0])
            .expect("frames produced by the mock encoder must decode");
        assert_eq!(first.network_id, NetworkId(2));
        assert_eq!(first.tick, 100);
        assert_eq!(first.payload, vec![1u8, 2, 3]);
    }
}