1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use tokio::sync::broadcast;
5use tokio::time::{MissedTickBehavior, interval};
6use tracing::{Instrument, debug_span, error, info_span};
7
8use crate::auth::AuthServiceImpl;
9use aetheris_protocol::error::EncodeError;
10use aetheris_protocol::events::{FragmentedEvent, NetworkEvent};
11use aetheris_protocol::reassembler::Reassembler;
12use aetheris_protocol::traits::{Encoder, GameTransport, WorldState};
13
14#[derive(Debug)]
16pub struct TickScheduler {
17 tick_rate: u64,
18 current_tick: u64,
19 auth_service: AuthServiceImpl,
20
21 authenticated_clients:
23 HashMap<aetheris_protocol::types::ClientId, (String, aetheris_protocol::types::NetworkId)>,
24 reassembler: Reassembler,
25 next_message_id: u32,
26}
27
28impl TickScheduler {
29 #[must_use]
31 pub fn new(tick_rate: u64, auth_service: AuthServiceImpl) -> Self {
32 Self {
33 tick_rate,
34 current_tick: 0,
35 auth_service,
36 authenticated_clients: HashMap::new(),
37 reassembler: Reassembler::new(),
38 next_message_id: 0,
39 }
40 }
41
42 pub async fn run(
44 &mut self,
45 mut transport: Box<dyn GameTransport>,
46 mut world: Box<dyn WorldState>,
47 encoder: Box<dyn Encoder>,
48 mut shutdown: broadcast::Receiver<()>,
49 ) {
50 #[allow(clippy::cast_precision_loss)]
51 let tick_duration = Duration::from_secs_f64(1.0 / self.tick_rate as f64);
52 let mut interval = interval(tick_duration);
53 interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
54
55 let mut encode_buffer = vec![0u8; encoder.max_encoded_size()];
58
59 loop {
60 tokio::select! {
61 _ = interval.tick() => {
62 self.current_tick += 1;
63
64 let start = Instant::now();
65 Self::tick_step(
66 transport.as_mut(),
67 world.as_mut(),
68 encoder.as_ref(),
69 &self.auth_service,
70 &mut self.authenticated_clients,
71 &mut self.reassembler,
72 &mut self.next_message_id,
73 &mut encode_buffer,
74 self.current_tick,
75 )
76 .instrument(info_span!("tick", tick = self.current_tick))
77 .await;
78 let elapsed = start.elapsed();
79
80 metrics::histogram!("aetheris_tick_duration_seconds").record(elapsed.as_secs_f64());
81 }
82 _ = shutdown.recv() => {
83 tracing::info!("Server shutting down gracefully");
84 break;
85 }
86 }
87 }
88 }
89
90 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
92 pub async fn tick_step(
93 transport: &mut dyn GameTransport,
94 world: &mut dyn WorldState,
95 encoder: &dyn Encoder,
96 auth_service: &AuthServiceImpl,
97 authenticated_clients: &mut HashMap<
98 aetheris_protocol::types::ClientId,
99 (String, aetheris_protocol::types::NetworkId),
100 >,
101 reassembler: &mut Reassembler,
102 next_message_id: &mut u32,
103 encode_buffer: &mut [u8],
104 tick: u64,
105 ) {
106 world.advance_tick();
112
113 let t1 = Instant::now();
115 let events = match transport
116 .poll_events()
117 .instrument(debug_span!("stage1_poll"))
118 .await
119 {
120 Ok(e) => e,
121 Err(e) => {
122 error!(error = ?e, "Fatal transport error during poll; skipping tick");
123 return;
124 }
125 };
126 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "poll")
127 .record(t1.elapsed().as_secs_f64());
128
129 let inbound_count: u64 = events
130 .iter()
131 .filter(|e| {
132 matches!(
133 e,
134 NetworkEvent::UnreliableMessage { .. } | NetworkEvent::ReliableMessage { .. }
135 )
136 })
137 .count() as u64;
138 metrics::counter!("aetheris_packets_inbound_total").increment(inbound_count);
139
140 if tick.is_multiple_of(60) {
142 let mut to_remove = Vec::new();
143 for (&client_id, (jti, _)) in authenticated_clients.iter() {
144 if !auth_service.is_session_authorized(jti, Some(tick)) {
145 tracing::warn!(?client_id, "Session invalidated during periodic check");
146 to_remove.push(client_id);
147 }
148 }
149 for client_id in to_remove {
150 if let Some((_, network_id)) = authenticated_clients.remove(&client_id) {
151 let _ = world.despawn_networked(network_id);
152 }
153 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
154 }
155 }
156
157 let t2 = Instant::now();
159 let mut pong_responses = None;
160 if !events.is_empty() {
161 let _span = debug_span!("stage2_apply", count = events.len()).entered();
162 let mut updates = Vec::with_capacity(events.len());
163 for event in events {
164 let (client_id, raw_data, is_message) = match event {
166 NetworkEvent::Fragment {
167 client_id,
168 fragment,
169 } => {
170 if let Some(data) = reassembler.ingest(client_id, fragment) {
171 (client_id, data, true)
172 } else {
173 continue;
174 }
175 }
176 NetworkEvent::UnreliableMessage { data, client_id }
177 | NetworkEvent::ReliableMessage { data, client_id } => {
178 if let Ok(NetworkEvent::Fragment { fragment, .. }) =
180 encoder.decode_event(&data)
181 {
182 if let Some(reassembled) = reassembler.ingest(client_id, fragment) {
183 (client_id, reassembled, true)
184 } else {
185 continue;
186 }
187 } else {
188 (client_id, data, true)
189 }
190 }
191 NetworkEvent::ClientConnected(id) => {
192 metrics::gauge!("aetheris_connected_clients").increment(1.0);
193 tracing::info!(client_id = ?id, "Client connected (awaiting auth)");
194 (id, Vec::new(), false)
195 }
196 NetworkEvent::ClientDisconnected(id) | NetworkEvent::Disconnected(id) => {
197 metrics::gauge!("aetheris_connected_clients").decrement(1.0);
198 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
199 let _ = world.despawn_networked(network_id);
200 }
201 tracing::info!(client_id = ?id, "Client disconnected");
202 (id, Vec::new(), false)
203 }
204 NetworkEvent::Ping { client_id, tick } => {
205 if authenticated_clients.contains_key(&client_id) {
206 pong_responses.get_or_insert_with(Vec::new).push((
207 client_id,
208 tick,
209 Instant::now(),
210 ));
211 metrics::counter!("aetheris_protocol_pings_received_total")
212 .increment(1);
213 }
214 (client_id, Vec::new(), false)
215 }
216 NetworkEvent::SessionClosed(id) => {
217 metrics::counter!("aetheris_transport_events_total", "type" => "session_closed")
218 .increment(1);
219 tracing::warn!(client_id = ?id, "WebTransport session closed");
220 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
221 let _ = world.despawn_networked(network_id);
222 }
223 (id, Vec::new(), false)
224 }
225 NetworkEvent::StreamReset(id) => {
226 metrics::counter!("aetheris_transport_events_total", "type" => "stream_reset")
227 .increment(1);
228 tracing::error!(client_id = ?id, "WebTransport stream reset");
229 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
230 let _ = world.despawn_networked(network_id);
231 }
232 (id, Vec::new(), false)
233 }
234 NetworkEvent::Auth { .. }
235 | NetworkEvent::Pong { .. }
236 | NetworkEvent::StressTest { .. }
237 | NetworkEvent::Spawn { .. }
238 | NetworkEvent::ClearWorld { .. } => {
239 continue;
242 }
243 };
244
245 if !is_message {
246 continue;
247 }
248
249 let jti = if let Some((jti, _)) = authenticated_clients.get(&client_id) {
251 if !auth_service.is_session_authorized(jti, Some(tick)) {
253 tracing::warn!(?client_id, "Session revoked; dropping client");
254 if let Some((_, network_id)) = authenticated_clients.remove(&client_id) {
255 let _ = world.despawn_networked(network_id);
256 }
257 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
258 continue;
259 }
260 jti
261 } else {
262 match encoder.decode_event(&raw_data) {
264 Ok(NetworkEvent::Auth { session_token }) => {
265 tracing::info!(?client_id, "Auth message received");
266 if let Some(jti) =
267 auth_service.validate_and_get_jti(&session_token, Some(tick))
268 {
269 tracing::info!(?client_id, "Client authenticated successfully");
270 let network_id = world.spawn_networked_for(client_id);
271 authenticated_clients.insert(client_id, (jti, network_id));
272 continue;
273 }
274 tracing::warn!(
275 ?client_id,
276 "Client failed authentication (token rejected)"
277 );
278 }
279 Ok(other) => {
280 tracing::warn!(
281 ?client_id,
282 variant = ?std::mem::discriminant(&other),
283 bytes = raw_data.len(),
284 "Unauthenticated client sent non-Auth event — discarding"
285 );
286 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
287 }
288 Err(e) => {
289 tracing::warn!(
290 ?client_id,
291 error = ?e,
292 bytes = raw_data.len(),
293 "Failed to decode message from unauthenticated client"
294 );
295 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
296 }
297 }
298 continue;
299 };
300
301 if let Ok(protocol_event) = encoder.decode_event(&raw_data) {
303 match protocol_event {
304 NetworkEvent::Ping { tick: p_tick, .. } => {
305 pong_responses.get_or_insert_with(Vec::new).push((
306 client_id,
307 p_tick,
308 Instant::now(),
309 ));
310 metrics::counter!("aetheris_protocol_pings_received_total")
311 .increment(1);
312 }
313 NetworkEvent::Auth { .. } => {
314 tracing::debug!(?client_id, "Client re-authenticating (ignored)");
315 }
316 NetworkEvent::StressTest { count, rotate, .. } => {
317 tracing::info!(
318 ?client_id,
319 count,
320 rotate,
321 "StressTest event received from authenticated client"
322 );
323 if can_run_playground_command(jti) {
324 const MAX_STRESS: u16 = 1000;
326 let capped_count = count.min(MAX_STRESS);
327 if count > MAX_STRESS {
328 tracing::warn!(
329 ?client_id,
330 count,
331 capped_count,
332 "Stress test count capped at limit"
333 );
334 }
335
336 tracing::info!(
337 ?client_id,
338 count = capped_count,
339 rotate,
340 "Stress test command executed"
341 );
342 world.stress_test(capped_count, rotate);
343 } else {
344 tracing::warn!(?client_id, "Unauthorized StressTest attempt");
345 metrics::counter!("aetheris_unprivileged_packets_total")
346 .increment(1);
347 }
348 }
349 NetworkEvent::Spawn {
350 entity_type,
351 x,
352 y,
353 rot,
354 ..
355 } => {
356 if can_run_playground_command(jti) {
357 tracing::info!(
358 ?client_id,
359 entity_type,
360 x,
361 y,
362 "Spawn command executed"
363 );
364 world.spawn_kind(entity_type, x, y, rot);
365 } else {
366 tracing::warn!(?client_id, "Unauthorized Spawn attempt");
367 metrics::counter!("aetheris_unprivileged_packets_total")
368 .increment(1);
369 }
370 }
371 NetworkEvent::ClearWorld { .. } => {
372 if can_run_playground_command(jti) {
373 tracing::info!(?client_id, "ClearWorld command executed");
374 world.clear_world();
375 } else {
376 tracing::warn!(?client_id, "Unauthorized ClearWorld attempt");
377 metrics::counter!("aetheris_unprivileged_packets_total")
378 .increment(1);
379 }
380 }
381 _ => {
382 tracing::trace!(?protocol_event, "Protocol event");
383 }
384 }
385 } else {
386 match encoder.decode(&raw_data) {
388 Ok(update) => updates.push((client_id, update)),
389 Err(e) => {
390 metrics::counter!("aetheris_decode_errors_total").increment(1);
391 error!(
392 error = ?e,
393 size = raw_data.len(),
394 "Failed to decode update (not a protocol event)"
395 );
396 }
397 }
398 }
399 }
400 world.apply_updates(&updates);
401 reassembler.prune();
402 }
403 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "apply")
404 .record(t2.elapsed().as_secs_f64());
405
406 if let Some(pongs) = pong_responses {
410 for (client_id, p_tick, received_at) in pongs {
411 let pong_event = NetworkEvent::Pong { tick: p_tick };
412 if let Ok(data) = encoder.encode_event(&pong_event) {
413 let dispatch_start = Instant::now();
417 match transport.send_unreliable(client_id, &data).await {
418 Ok(()) => {
419 let dispatch_ms = dispatch_start.elapsed().as_secs_f64() * 1000.0;
420 let server_hold_ms = received_at.elapsed().as_secs_f64() * 1000.0;
421 metrics::histogram!("aetheris_server_pong_dispatch_ms")
422 .record(dispatch_ms);
423 metrics::histogram!("aetheris_server_ping_hold_ms")
424 .record(server_hold_ms);
425 }
426 Err(e) => {
427 error!(error = ?e, client_id = ?client_id, "Failed to send Pong");
428 }
429 }
430 }
431 }
432 }
433
434 let t3 = Instant::now();
436 {
437 let _span = debug_span!("stage3_simulate").entered();
438 world.simulate();
440 }
441 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "simulate")
442 .record(t3.elapsed().as_secs_f64());
443
444 let t4 = Instant::now();
446 let deltas = {
447 let _span = debug_span!("stage4_extract").entered();
448 world.extract_deltas()
449 };
450 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "extract")
451 .record(t4.elapsed().as_secs_f64());
452
453 let t5 = Instant::now();
455 if !deltas.is_empty() {
456 let mut broadcast_count: u64 = 0;
457
458 let stage_span = debug_span!("stage5_send", count = deltas.len());
459 let _guard = stage_span.enter();
460
461 for delta in deltas {
462 let encode_result = encoder.encode(&delta, encode_buffer);
463 match encode_result {
464 Ok(len) if len > aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE => {
465 match Self::fragment_and_broadcast(
466 encode_buffer,
467 len,
468 next_message_id,
469 encoder,
470 transport,
471 )
472 .await
473 {
474 Ok(count) => broadcast_count += count,
475 Err(e) => error!(error = ?e, "Failed to fragment and broadcast delta"),
476 }
477 }
478 Ok(len) => {
479 if let Err(e) = transport.broadcast_unreliable(&encode_buffer[..len]).await
480 {
481 error!(error = ?e, "Failed to broadcast delta");
482 } else {
483 broadcast_count += 1;
484 }
485 }
486 Err(EncodeError::BufferOverflow {
487 needed,
488 available: _,
489 }) => {
490 let mut large_buffer = vec![0u8; needed];
491 if let Ok(len) = encoder.encode(&delta, &mut large_buffer) {
492 match Self::fragment_and_broadcast(
493 &large_buffer,
494 len,
495 next_message_id,
496 encoder,
497 transport,
498 )
499 .await
500 {
501 Ok(count) => broadcast_count += count,
502 Err(e) => {
503 error!(error = ?e, "Failed to fragment and broadcast large delta");
504 }
505 }
506 } else {
507 error!("Failed to encode into large scratch buffer");
508 }
509 }
510 Err(e) => {
511 metrics::counter!("aetheris_encode_errors_total").increment(1);
512 error!(
513 network_id = ?delta.network_id,
514 error = ?e,
515 "Failed to encode delta"
516 );
517 }
518 }
519 }
520 metrics::counter!("aetheris_packets_outbound_total").increment(broadcast_count);
521 metrics::counter!("aetheris_packets_broadcast_total").increment(broadcast_count);
522 }
523 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "send")
524 .record(t5.elapsed().as_secs_f64());
525 }
526
527 async fn fragment_and_broadcast(
528 data: &[u8],
529 len: usize,
530 next_message_id: &mut u32,
531 encoder: &dyn Encoder,
532 transport: &dyn GameTransport,
533 ) -> Result<u64, EncodeError> {
534 let message_id = *next_message_id;
535 *next_message_id = next_message_id.wrapping_add(1);
536
537 let chunk_size = aetheris_protocol::MAX_FRAGMENT_PAYLOAD_SIZE;
538 let chunks: Vec<_> = data[..len].chunks(chunk_size).collect();
539
540 let Ok(total_fragments) = u16::try_from(chunks.len()) else {
541 error!(
542 message_id,
543 chunks = chunks.len(),
544 "Too many fragments required for message; dropping payload"
545 );
546 return Err(EncodeError::Io(std::io::Error::new(
547 std::io::ErrorKind::InvalidData,
548 "Too many fragments",
549 )));
550 };
551
552 let mut sent_count = 0;
553 for (i, chunk) in chunks.into_iter().enumerate() {
554 let Ok(fragment_index) = u16::try_from(i) else {
555 error!(message_id, index = i, "Fragment index overflow; stopping");
556 break;
557 };
558
559 let fragment = FragmentedEvent {
560 message_id,
561 fragment_index,
562 total_fragments,
563 payload: chunk.to_vec(),
564 };
565 let fragment_event = NetworkEvent::Fragment {
566 client_id: aetheris_protocol::types::ClientId(0),
567 fragment,
568 };
569
570 match encoder.encode_event(&fragment_event) {
571 Ok(encoded_fragment) => {
572 if let Err(e) = transport.broadcast_unreliable(&encoded_fragment).await {
573 error!(error = ?e, "Failed to broadcast fragment");
574 } else {
575 sent_count += 1;
576 }
577 }
578 Err(e) => {
579 error!(error = ?e, "Failed to encode fragment event");
580 }
581 }
582 }
583
584 Ok(sent_count)
585 }
586}
587
/// Reports whether the session identified by `jti` may run playground
/// commands.
///
/// Access is granted when `jti` is exactly `"admin"`, or when the
/// `AETHERIS_ENV` environment variable is set to `"dev"`. The environment is
/// only consulted when the `jti` check fails; an unset or non-Unicode
/// variable counts as "not dev".
fn can_run_playground_command(jti: &str) -> bool {
    if jti == "admin" {
        return true;
    }

    match std::env::var("AETHERIS_ENV") {
        Ok(environment) => environment == "dev",
        Err(_) => false,
    }
}