1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use tokio::sync::broadcast;
5use tokio::time::{MissedTickBehavior, interval};
6use tracing::{Instrument, debug_span, error, info_span};
7
8use crate::auth::AuthServiceImpl;
9use aetheris_protocol::error::EncodeError;
10use aetheris_protocol::events::{FragmentedEvent, NetworkEvent};
11use aetheris_protocol::reassembler::Reassembler;
12use aetheris_protocol::traits::{Encoder, GameTransport, WorldState};
13
14#[derive(Debug)]
16pub struct TickScheduler {
17 tick_rate: u64,
18 current_tick: u64,
19 auth_service: AuthServiceImpl,
20
21 authenticated_clients:
23 HashMap<aetheris_protocol::types::ClientId, (String, aetheris_protocol::types::NetworkId)>,
24 reassembler: Reassembler,
25 next_message_id: u32,
26}
27
28impl TickScheduler {
29 #[must_use]
31 pub fn new(tick_rate: u64, auth_service: AuthServiceImpl) -> Self {
32 Self {
33 tick_rate,
34 current_tick: 0,
35 auth_service,
36 authenticated_clients: HashMap::new(),
37 reassembler: Reassembler::new(),
38 next_message_id: 0,
39 }
40 }
41
42 pub async fn run(
44 &mut self,
45 mut transport: Box<dyn GameTransport>,
46 mut world: Box<dyn WorldState>,
47 encoder: Box<dyn Encoder>,
48 mut shutdown: broadcast::Receiver<()>,
49 ) {
50 #[allow(clippy::cast_precision_loss)]
51 let tick_duration = Duration::from_secs_f64(1.0 / self.tick_rate as f64);
52 let mut interval = interval(tick_duration);
53 interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
54
55 let mut encode_buffer = vec![0u8; encoder.max_encoded_size()];
58
59 loop {
60 tokio::select! {
61 _ = interval.tick() => {
62 self.current_tick += 1;
63
64 let start = Instant::now();
65 Self::tick_step(
66 transport.as_mut(),
67 world.as_mut(),
68 encoder.as_ref(),
69 &self.auth_service,
70 &mut self.authenticated_clients,
71 &mut self.reassembler,
72 &mut self.next_message_id,
73 &mut encode_buffer,
74 self.current_tick,
75 )
76 .instrument(info_span!("tick", tick = self.current_tick))
77 .await;
78 let elapsed = start.elapsed();
79
80 metrics::histogram!("aetheris_tick_duration_seconds").record(elapsed.as_secs_f64());
81 }
82 _ = shutdown.recv() => {
83 tracing::info!("Server shutting down gracefully");
84 break;
85 }
86 }
87 }
88 }
89
90 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
92 pub async fn tick_step(
93 transport: &mut dyn GameTransport,
94 world: &mut dyn WorldState,
95 encoder: &dyn Encoder,
96 auth_service: &AuthServiceImpl,
97 authenticated_clients: &mut HashMap<
98 aetheris_protocol::types::ClientId,
99 (String, aetheris_protocol::types::NetworkId),
100 >,
101 reassembler: &mut Reassembler,
102 next_message_id: &mut u32,
103 encode_buffer: &mut [u8],
104 tick: u64,
105 ) {
106 world.advance_tick();
112
113 let t1 = Instant::now();
115 let events = match transport
116 .poll_events()
117 .instrument(debug_span!("stage1_poll"))
118 .await
119 {
120 Ok(e) => e,
121 Err(e) => {
122 error!(error = ?e, "Fatal transport error during poll; skipping tick");
123 return;
124 }
125 };
126 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "poll")
127 .record(t1.elapsed().as_secs_f64());
128
129 let inbound_count: u64 = events
130 .iter()
131 .filter(|e| {
132 matches!(
133 e,
134 NetworkEvent::UnreliableMessage { .. } | NetworkEvent::ReliableMessage { .. }
135 )
136 })
137 .count() as u64;
138 metrics::counter!("aetheris_packets_inbound_total").increment(inbound_count);
139
140 if tick.is_multiple_of(60) {
142 let mut to_remove = Vec::new();
143 for (&client_id, (jti, _)) in authenticated_clients.iter() {
144 if !auth_service.is_session_authorized(jti, Some(tick)) {
145 tracing::warn!(?client_id, "Session invalidated during periodic check");
146 to_remove.push(client_id);
147 }
148 }
149 for client_id in to_remove {
150 if let Some((_, network_id)) = authenticated_clients.remove(&client_id) {
151 let _ = world.despawn_networked(network_id);
152 }
153 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
154 }
155 }
156
157 let t2 = Instant::now();
159 let mut pong_responses = None;
160 if !events.is_empty() {
161 let _span = debug_span!("stage2_apply", count = events.len()).entered();
162 let mut updates = Vec::with_capacity(events.len());
163 for event in events {
164 let (client_id, raw_data, is_message) = match event {
166 NetworkEvent::Fragment {
167 client_id,
168 fragment,
169 } => {
170 if let Some(data) = reassembler.ingest(client_id, fragment) {
171 (client_id, data, true)
172 } else {
173 continue;
174 }
175 }
176 NetworkEvent::UnreliableMessage { data, client_id }
177 | NetworkEvent::ReliableMessage { data, client_id } => {
178 if let Ok(NetworkEvent::Fragment { fragment, .. }) =
180 encoder.decode_event(&data)
181 {
182 if let Some(reassembled) = reassembler.ingest(client_id, fragment) {
183 (client_id, reassembled, true)
184 } else {
185 continue;
186 }
187 } else {
188 (client_id, data, true)
189 }
190 }
191 NetworkEvent::ClientConnected(id) => {
192 metrics::gauge!("aetheris_connected_clients").increment(1.0);
193 tracing::info!(client_id = ?id, "Client connected (awaiting auth)");
194 (id, Vec::new(), false)
195 }
196 NetworkEvent::ClientDisconnected(id) | NetworkEvent::Disconnected(id) => {
197 metrics::gauge!("aetheris_connected_clients").decrement(1.0);
198 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
199 let _ = world.despawn_networked(network_id);
200 }
201 tracing::info!(client_id = ?id, "Client disconnected");
202 (id, Vec::new(), false)
203 }
204 NetworkEvent::SessionClosed(id) => {
205 metrics::counter!("aetheris_transport_events_total", "type" => "session_closed")
206 .increment(1);
207 tracing::warn!(client_id = ?id, "WebTransport session closed");
208 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
209 let _ = world.despawn_networked(network_id);
210 }
211 (id, Vec::new(), false)
212 }
213 NetworkEvent::StreamReset(id) => {
214 metrics::counter!("aetheris_transport_events_total", "type" => "stream_reset")
215 .increment(1);
216 tracing::error!(client_id = ?id, "WebTransport stream reset");
217 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
218 let _ = world.despawn_networked(network_id);
219 }
220 (id, Vec::new(), false)
221 }
222 NetworkEvent::Ping { client_id, tick } => {
223 if authenticated_clients.contains_key(&client_id) {
224 pong_responses.get_or_insert_with(Vec::new).push((
225 client_id,
226 tick,
227 Instant::now(),
228 ));
229 metrics::counter!("aetheris_protocol_pings_received_total")
230 .increment(1);
231 }
232 (client_id, Vec::new(), false)
233 }
234 NetworkEvent::ClearWorld { client_id, .. }
235 | NetworkEvent::GameEvent { client_id, .. }
236 | NetworkEvent::StressTest { client_id, .. }
237 | NetworkEvent::Spawn { client_id, .. } => (client_id, Vec::new(), false),
238 NetworkEvent::Pong { .. } | NetworkEvent::Auth { .. } => {
239 (aetheris_protocol::types::ClientId(0), Vec::new(), false)
240 }
241 };
242
243 if !is_message {
244 continue;
245 }
246
247 let jti = if let Some((jti, _)) = authenticated_clients.get(&client_id) {
249 if !auth_service.is_session_authorized(jti, Some(tick)) {
251 tracing::warn!(?client_id, "Session revoked; dropping client");
252 if let Some((_, network_id)) = authenticated_clients.remove(&client_id) {
253 let _ = world.despawn_networked(network_id);
254 }
255 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
256 continue;
257 }
258 jti
259 } else {
260 match encoder.decode_event(&raw_data) {
262 Ok(NetworkEvent::Auth { session_token }) => {
263 tracing::info!(?client_id, "Auth message received");
264 if let Some(jti) =
265 auth_service.validate_and_get_jti(&session_token, Some(tick))
266 {
267 tracing::info!(?client_id, "Client authenticated successfully");
268 let network_id = world.spawn_networked_for(client_id);
269 authenticated_clients.insert(client_id, (jti, network_id));
270 continue;
271 }
272 tracing::warn!(
273 ?client_id,
274 "Client failed authentication (token rejected)"
275 );
276 }
277 Ok(other) => {
278 tracing::warn!(
279 ?client_id,
280 variant = ?std::mem::discriminant(&other),
281 bytes = raw_data.len(),
282 "Unauthenticated client sent non-Auth event — discarding"
283 );
284 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
285 }
286 Err(e) => {
287 tracing::warn!(
288 ?client_id,
289 error = ?e,
290 bytes = raw_data.len(),
291 "Failed to decode message from unauthenticated client"
292 );
293 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
294 }
295 }
296 continue;
297 };
298
299 if let Ok(protocol_event) = encoder.decode_event(&raw_data) {
301 match protocol_event {
302 NetworkEvent::Ping { tick: p_tick, .. } => {
303 pong_responses.get_or_insert_with(Vec::new).push((
304 client_id,
305 p_tick,
306 Instant::now(),
307 ));
308 metrics::counter!("aetheris_protocol_pings_received_total")
309 .increment(1);
310 }
311 NetworkEvent::Auth { .. } => {
312 tracing::debug!(?client_id, "Client re-authenticating (ignored)");
313 }
314 NetworkEvent::StressTest { count, rotate, .. } => {
315 tracing::info!(
316 ?client_id,
317 count,
318 rotate,
319 "StressTest event received from authenticated client"
320 );
321 if can_run_playground_command(jti) {
322 const MAX_STRESS: u16 = 1000;
324 let capped_count = count.min(MAX_STRESS);
325 if count > MAX_STRESS {
326 tracing::warn!(
327 ?client_id,
328 count,
329 capped_count,
330 "Stress test count capped at limit"
331 );
332 }
333
334 tracing::info!(
335 ?client_id,
336 count = capped_count,
337 rotate,
338 "Stress test command executed"
339 );
340 world.stress_test(capped_count, rotate);
341 } else {
342 tracing::warn!(?client_id, "Unauthorized StressTest attempt");
343 metrics::counter!("aetheris_unprivileged_packets_total")
344 .increment(1);
345 }
346 }
347 NetworkEvent::Spawn {
348 entity_type,
349 x,
350 y,
351 rot,
352 ..
353 } => {
354 if can_run_playground_command(jti) {
355 tracing::info!(
356 ?client_id,
357 entity_type,
358 x,
359 y,
360 "Spawn command executed"
361 );
362 world.spawn_kind(entity_type, x, y, rot);
363 } else {
364 tracing::warn!(?client_id, "Unauthorized Spawn attempt");
365 metrics::counter!("aetheris_unprivileged_packets_total")
366 .increment(1);
367 }
368 }
369 NetworkEvent::ClearWorld { .. } => {
370 if can_run_playground_command(jti) {
371 tracing::info!(?client_id, "ClearWorld command executed");
372 world.clear_world();
373 } else {
374 tracing::warn!(?client_id, "Unauthorized ClearWorld attempt");
375 metrics::counter!("aetheris_unprivileged_packets_total")
376 .increment(1);
377 }
378 }
379 _ => {
380 tracing::trace!(?protocol_event, "Protocol event");
381 }
382 }
383 } else {
384 match encoder.decode(&raw_data) {
386 Ok(update) => updates.push((client_id, update)),
387 Err(e) => {
388 metrics::counter!("aetheris_decode_errors_total").increment(1);
389 error!(
390 error = ?e,
391 size = raw_data.len(),
392 "Failed to decode update (not a protocol event)"
393 );
394 }
395 }
396 }
397 }
398 world.apply_updates(&updates);
399 reassembler.prune();
400 }
401 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "apply")
402 .record(t2.elapsed().as_secs_f64());
403
404 if let Some(pongs) = pong_responses {
408 for (client_id, p_tick, received_at) in pongs {
409 let pong_event = NetworkEvent::Pong { tick: p_tick };
410 if let Ok(data) = encoder.encode_event(&pong_event) {
411 let dispatch_start = Instant::now();
415 match transport.send_unreliable(client_id, &data).await {
416 Ok(()) => {
417 let dispatch_ms = dispatch_start.elapsed().as_secs_f64() * 1000.0;
418 let server_hold_ms = received_at.elapsed().as_secs_f64() * 1000.0;
419 metrics::histogram!("aetheris_server_pong_dispatch_ms")
420 .record(dispatch_ms);
421 metrics::histogram!("aetheris_server_ping_hold_ms")
422 .record(server_hold_ms);
423 }
424 Err(e) => {
425 error!(error = ?e, client_id = ?client_id, "Failed to send Pong");
426 }
427 }
428 }
429 }
430 }
431
432 let t3 = Instant::now();
434 {
435 let _span = debug_span!("stage3_simulate").entered();
436 world.simulate();
438 }
439 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "simulate")
440 .record(t3.elapsed().as_secs_f64());
441
442 let t4 = Instant::now();
444 let (deltas, reliable_events) = {
445 let _span = debug_span!("stage4_extract").entered();
446 (world.extract_deltas(), world.extract_reliable_events())
447 };
448 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "extract")
449 .record(t4.elapsed().as_secs_f64());
450
451 let t5 = Instant::now();
453
454 for (target, wire_event) in reliable_events {
456 let targets: Vec<_> = if let Some(id) = target {
458 vec![id]
459 } else {
460 authenticated_clients.keys().copied().collect()
461 };
462
463 for id in targets {
464 let network_event = wire_event.clone().into_network_event(id);
465 match encoder.encode_event(&network_event) {
466 Ok(data) => {
467 if let Err(e) = transport.send_reliable(id, &data).await {
468 error!(error = ?e, client_id = ?id, "Failed to send reliable event");
469 }
470 }
471 Err(e) => {
472 error!(error = ?e, client_id = ?id, "Failed to encode reliable event");
473 }
474 }
475 }
476 }
477
478 if !deltas.is_empty() {
479 let mut broadcast_count: u64 = 0;
480
481 let stage_span = debug_span!("stage5_send", count = deltas.len());
482 let _guard = stage_span.enter();
483
484 for delta in deltas {
485 let encode_result = encoder.encode(&delta, encode_buffer);
486 match encode_result {
487 Ok(len) if len > aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE => {
488 match Self::fragment_and_broadcast(
489 encode_buffer,
490 len,
491 next_message_id,
492 encoder,
493 transport,
494 )
495 .await
496 {
497 Ok(count) => broadcast_count += count,
498 Err(e) => error!(error = ?e, "Failed to fragment and broadcast delta"),
499 }
500 }
501 Ok(len) => {
502 if let Err(e) = transport.broadcast_unreliable(&encode_buffer[..len]).await
503 {
504 error!(error = ?e, "Failed to broadcast delta");
505 } else {
506 broadcast_count += 1;
507 }
508 }
509 Err(EncodeError::BufferOverflow {
510 needed,
511 available: _,
512 }) => {
513 let mut large_buffer = vec![0u8; needed];
514 if let Ok(len) = encoder.encode(&delta, &mut large_buffer) {
515 match Self::fragment_and_broadcast(
516 &large_buffer,
517 len,
518 next_message_id,
519 encoder,
520 transport,
521 )
522 .await
523 {
524 Ok(count) => broadcast_count += count,
525 Err(e) => {
526 error!(error = ?e, "Failed to fragment and broadcast large delta");
527 }
528 }
529 } else {
530 error!("Failed to encode into large scratch buffer");
531 }
532 }
533 Err(e) => {
534 metrics::counter!("aetheris_encode_errors_total").increment(1);
535 error!(
536 network_id = ?delta.network_id,
537 error = ?e,
538 "Failed to encode delta"
539 );
540 }
541 }
542 }
543 metrics::counter!("aetheris_packets_outbound_total").increment(broadcast_count);
544 metrics::counter!("aetheris_packets_broadcast_total").increment(broadcast_count);
545 }
546 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "send")
547 .record(t5.elapsed().as_secs_f64());
548 }
549
550 async fn fragment_and_broadcast(
551 data: &[u8],
552 len: usize,
553 next_message_id: &mut u32,
554 encoder: &dyn Encoder,
555 transport: &dyn GameTransport,
556 ) -> Result<u64, EncodeError> {
557 let message_id = *next_message_id;
558 *next_message_id = next_message_id.wrapping_add(1);
559
560 let chunk_size = aetheris_protocol::MAX_FRAGMENT_PAYLOAD_SIZE;
561 let chunks: Vec<_> = data[..len].chunks(chunk_size).collect();
562
563 let Ok(total_fragments) = u16::try_from(chunks.len()) else {
564 error!(
565 message_id,
566 chunks = chunks.len(),
567 "Too many fragments required for message; dropping payload"
568 );
569 return Err(EncodeError::Io(std::io::Error::new(
570 std::io::ErrorKind::InvalidData,
571 "Too many fragments",
572 )));
573 };
574
575 let mut sent_count = 0;
576 for (i, chunk) in chunks.into_iter().enumerate() {
577 let Ok(fragment_index) = u16::try_from(i) else {
578 error!(message_id, index = i, "Fragment index overflow; stopping");
579 break;
580 };
581
582 let fragment = FragmentedEvent {
583 message_id,
584 fragment_index,
585 total_fragments,
586 payload: chunk.to_vec(),
587 };
588 let fragment_event = NetworkEvent::Fragment {
589 client_id: aetheris_protocol::types::ClientId(0),
590 fragment,
591 };
592
593 match encoder.encode_event(&fragment_event) {
594 Ok(encoded_fragment) => {
595 if let Err(e) = transport.broadcast_unreliable(&encoded_fragment).await {
596 error!(error = ?e, "Failed to broadcast fragment");
597 } else {
598 sent_count += 1;
599 }
600 }
601 Err(e) => {
602 error!(error = ?e, "Failed to encode fragment event");
603 }
604 }
605 }
606
607 Ok(sent_count)
608 }
609}
610
/// Decides whether a session may run playground commands.
///
/// Returns `true` when `jti` is exactly `"admin"`, or when the `AETHERIS_ENV`
/// environment variable is set to exactly `"dev"`. The environment is only consulted
/// for non-admin sessions and is read on every call.
fn can_run_playground_command(jti: &str) -> bool {
    if jti == "admin" {
        return true;
    }

    matches!(std::env::var("AETHERIS_ENV"), Ok(env) if env == "dev")
}