1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use tokio::time::{MissedTickBehavior, interval};
5use tokio_util::sync::CancellationToken;
6use tracing::{Instrument, debug_span, error, info_span};
7
8use crate::auth::AuthServiceImpl;
9use aetheris_protocol::error::EncodeError;
10use aetheris_protocol::events::{FragmentedEvent, NetworkEvent};
11use aetheris_protocol::reassembler::Reassembler;
12use aetheris_protocol::traits::{Encoder, GameTransport, WorldState};
13
/// Fixed-rate game-loop driver.
///
/// Holds the scheduler state that survives between ticks: the table of authenticated
/// connections, the inbound fragment reassembler, and the id counter for outbound
/// fragmented messages. [`TickScheduler::run`] calls [`TickScheduler::tick_step`] once per
/// tick until shutdown is requested.
#[derive(Debug)]
pub struct TickScheduler {
    /// Target number of ticks per second.
    tick_rate: u64,
    /// Number of ticks started so far; `run` increments it before each tick, so the first
    /// tick observed by `tick_step` is 1.
    current_tick: u64,
    /// Validates `Auth` session tokens and re-checks whether a session is still authorized.
    auth_service: AuthServiceImpl,

    /// Authenticated connections, mapped to their session `jti` and the networked entity
    /// spawned for them when they authenticated.
    authenticated_clients:
        HashMap<aetheris_protocol::types::ClientId, (String, aetheris_protocol::types::NetworkId)>,
    /// Reassembles inbound fragmented messages into complete payloads.
    reassembler: Reassembler,
    /// Id given to the next outbound fragmented message (advanced with wrapping arithmetic).
    next_message_id: u32,
}
27
28impl TickScheduler {
29 #[must_use]
31 pub fn new(tick_rate: u64, auth_service: AuthServiceImpl) -> Self {
32 Self {
33 tick_rate,
34 current_tick: 0,
35 auth_service,
36 authenticated_clients: HashMap::new(),
37 reassembler: Reassembler::new(),
38 next_message_id: 0,
39 }
40 }
41
42 pub async fn run(
44 &mut self,
45 mut transport: Box<dyn GameTransport>,
46 mut world: Box<dyn WorldState>,
47 encoder: Box<dyn Encoder>,
48 shutdown: CancellationToken,
49 ) {
50 #[allow(clippy::cast_precision_loss)]
51 let tick_duration = Duration::from_secs_f64(1.0 / self.tick_rate as f64);
52 let mut interval = interval(tick_duration);
53 interval.set_missed_tick_behavior(MissedTickBehavior::Skip);
54
55 let mut encode_buffer = vec![0u8; encoder.max_encoded_size()];
58
59 loop {
60 tokio::select! {
61 _ = interval.tick() => {
62 self.current_tick += 1;
63
64 let start = Instant::now();
65 Self::tick_step(
66 transport.as_mut(),
67 world.as_mut(),
68 encoder.as_ref(),
69 &self.auth_service,
70 &mut self.authenticated_clients,
71 &mut self.reassembler,
72 &mut self.next_message_id,
73 &mut encode_buffer,
74 self.current_tick,
75 )
76 .instrument(info_span!("tick", tick = self.current_tick))
77 .await;
78 let elapsed = start.elapsed();
79
80 metrics::histogram!("aetheris_tick_duration_seconds").record(elapsed.as_secs_f64());
81 }
82 () = shutdown.cancelled() => {
83 tracing::info!("Server shutting down gracefully");
84 break;
85 }
86 }
87 }
88 }
89
90 #[allow(clippy::too_many_lines, clippy::too_many_arguments)]
92 pub async fn tick_step(
93 transport: &mut dyn GameTransport,
94 world: &mut dyn WorldState,
95 encoder: &dyn Encoder,
96 auth_service: &AuthServiceImpl,
97 authenticated_clients: &mut HashMap<
98 aetheris_protocol::types::ClientId,
99 (String, aetheris_protocol::types::NetworkId),
100 >,
101 reassembler: &mut Reassembler,
102 next_message_id: &mut u32,
103 encode_buffer: &mut [u8],
104 tick: u64,
105 ) {
106 world.advance_tick();
112
113 let t1 = Instant::now();
115 let events = match transport
116 .poll_events()
117 .instrument(debug_span!("stage1_poll"))
118 .await
119 {
120 Ok(e) => e,
121 Err(e) => {
122 error!(error = ?e, "Fatal transport error during poll; skipping tick");
123 return;
124 }
125 };
126 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "poll")
127 .record(t1.elapsed().as_secs_f64());
128
129 let inbound_count: u64 = events
130 .iter()
131 .filter(|e| {
132 matches!(
133 e,
134 NetworkEvent::UnreliableMessage { .. } | NetworkEvent::ReliableMessage { .. }
135 )
136 })
137 .count() as u64;
138 metrics::counter!("aetheris_packets_inbound_total").increment(inbound_count);
139
140 if tick.is_multiple_of(60) {
142 let mut to_remove = Vec::new();
143 for (&client_id, (jti, _)) in authenticated_clients.iter() {
144 if !auth_service.is_session_authorized(jti, Some(tick)) {
145 tracing::warn!(?client_id, "Session invalidated during periodic check");
146 to_remove.push(client_id);
147 }
148 }
149 for client_id in to_remove {
150 if let Some((_, network_id)) = authenticated_clients.remove(&client_id) {
151 let _ = world.despawn_networked(network_id);
152 }
153 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
154 }
155 }
156
157 let t2 = Instant::now();
159 let mut pong_responses = None;
160 if !events.is_empty() {
161 let _span = debug_span!("stage2_apply", count = events.len()).entered();
162 let mut updates = Vec::with_capacity(events.len());
163 for event in events {
164 let (client_id, raw_data, is_message) = match event {
166 NetworkEvent::Fragment {
167 client_id,
168 fragment,
169 } => {
170 if let Some(data) = reassembler.add(client_id, fragment) {
171 (client_id, data, true)
172 } else {
173 continue;
174 }
175 }
176 NetworkEvent::UnreliableMessage { data, client_id }
177 | NetworkEvent::ReliableMessage { data, client_id } => {
178 if let Ok(NetworkEvent::Fragment { fragment, .. }) =
180 encoder.decode_event(&data)
181 {
182 if let Some(reassembled) = reassembler.add(client_id, fragment) {
183 (client_id, reassembled, true)
184 } else {
185 continue;
186 }
187 } else {
188 (client_id, data, true)
189 }
190 }
191 NetworkEvent::ClientConnected(id) => {
192 metrics::gauge!("aetheris_connected_clients").increment(1.0);
193 tracing::info!(client_id = ?id, "Client connected (awaiting auth)");
194 (id, Vec::new(), false)
195 }
196 NetworkEvent::ClientDisconnected(id) => {
197 metrics::gauge!("aetheris_connected_clients").decrement(1.0);
198 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
199 let _ = world.despawn_networked(network_id);
200 }
201 tracing::info!(client_id = ?id, "Client disconnected");
202 (id, Vec::new(), false)
203 }
204 NetworkEvent::Ping { client_id, tick } => {
205 if authenticated_clients.contains_key(&client_id) {
206 pong_responses.get_or_insert_with(Vec::new).push((
207 client_id,
208 tick,
209 Instant::now(),
210 ));
211 metrics::counter!("aetheris_protocol_pings_received_total")
212 .increment(1);
213 }
214 (client_id, Vec::new(), false)
215 }
216 NetworkEvent::SessionClosed(id) => {
217 metrics::counter!("aetheris_transport_events_total", "type" => "session_closed")
218 .increment(1);
219 tracing::warn!(client_id = ?id, "WebTransport session closed");
220 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
221 let _ = world.despawn_networked(network_id);
222 }
223 (id, Vec::new(), false)
224 }
225 NetworkEvent::StreamReset(id) => {
226 metrics::counter!("aetheris_transport_events_total", "type" => "stream_reset")
227 .increment(1);
228 tracing::error!(client_id = ?id, "WebTransport stream reset");
229 if let Some((_, network_id)) = authenticated_clients.remove(&id) {
230 let _ = world.despawn_networked(network_id);
231 }
232 (id, Vec::new(), false)
233 }
234 NetworkEvent::Auth { .. }
235 | NetworkEvent::Pong { .. }
236 | NetworkEvent::StressTest { .. }
237 | NetworkEvent::Spawn { .. }
238 | NetworkEvent::ClearWorld { .. } => {
239 continue;
242 }
243 };
244
245 if !is_message {
246 continue;
247 }
248
249 let jti = if let Some((jti, _)) = authenticated_clients.get(&client_id) {
251 if !auth_service.is_session_authorized(jti, Some(tick)) {
253 tracing::warn!(?client_id, "Session revoked; dropping client");
254 if let Some((_, network_id)) = authenticated_clients.remove(&client_id) {
255 let _ = world.despawn_networked(network_id);
256 }
257 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
258 continue;
259 }
260 jti
261 } else {
262 if let Ok(NetworkEvent::Auth { session_token }) =
264 encoder.decode_event(&raw_data)
265 {
266 tracing::info!(?client_id, "Auth message received");
267 if let Some(jti) =
268 auth_service.validate_and_get_jti(&session_token, Some(tick))
269 {
270 tracing::info!(?client_id, "Client authenticated successfully");
271 let network_id = world.spawn_networked_for(client_id);
272 authenticated_clients.insert(client_id, (jti, network_id));
273 continue;
274 }
275 tracing::warn!(?client_id, "Client failed authentication");
276 } else {
277 tracing::debug!(
278 ?client_id,
279 "Discarding message from unauthenticated client"
280 );
281 metrics::counter!("aetheris_unprivileged_packets_total").increment(1);
282 }
283 continue;
284 };
285
286 if let Ok(protocol_event) = encoder.decode_event(&raw_data) {
288 match protocol_event {
289 NetworkEvent::Ping { tick: p_tick, .. } => {
290 pong_responses.get_or_insert_with(Vec::new).push((
291 client_id,
292 p_tick,
293 Instant::now(),
294 ));
295 metrics::counter!("aetheris_protocol_pings_received_total")
296 .increment(1);
297 }
298 NetworkEvent::Auth { .. } => {
299 tracing::debug!(?client_id, "Client re-authenticating (ignored)");
300 }
301 NetworkEvent::StressTest { count, rotate, .. } => {
302 if can_run_playground_command(jti) {
303 const MAX_STRESS: u16 = 1000;
305 let capped_count = count.min(MAX_STRESS);
306 if count > MAX_STRESS {
307 tracing::warn!(
308 ?client_id,
309 count,
310 capped_count,
311 "Stress test count capped at limit"
312 );
313 }
314
315 tracing::info!(
316 ?client_id,
317 count = capped_count,
318 rotate,
319 "Stress test command executed"
320 );
321 world.stress_test(capped_count, rotate);
322 } else {
323 tracing::warn!(?client_id, "Unauthorized StressTest attempt");
324 metrics::counter!("aetheris_unprivileged_packets_total")
325 .increment(1);
326 }
327 }
328 NetworkEvent::Spawn {
329 entity_type,
330 x,
331 y,
332 rot,
333 ..
334 } => {
335 if can_run_playground_command(jti) {
336 tracing::info!(
337 ?client_id,
338 entity_type,
339 x,
340 y,
341 "Spawn command executed"
342 );
343 world.spawn_kind(entity_type, x, y, rot);
344 } else {
345 tracing::warn!(?client_id, "Unauthorized Spawn attempt");
346 metrics::counter!("aetheris_unprivileged_packets_total")
347 .increment(1);
348 }
349 }
350 NetworkEvent::ClearWorld { .. } => {
351 if can_run_playground_command(jti) {
352 tracing::info!(?client_id, "ClearWorld command executed");
353 world.clear_world();
354 } else {
355 tracing::warn!(?client_id, "Unauthorized ClearWorld attempt");
356 metrics::counter!("aetheris_unprivileged_packets_total")
357 .increment(1);
358 }
359 }
360 _ => {
361 tracing::trace!(?protocol_event, "Protocol event");
362 }
363 }
364 } else {
365 match encoder.decode(&raw_data) {
367 Ok(update) => updates.push((client_id, update)),
368 Err(e) => {
369 metrics::counter!("aetheris_decode_errors_total").increment(1);
370 error!(
371 error = ?e,
372 size = raw_data.len(),
373 "Failed to decode update (not a protocol event)"
374 );
375 }
376 }
377 }
378 }
379 world.apply_updates(&updates);
380 reassembler.cleanup();
381 }
382 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "apply")
383 .record(t2.elapsed().as_secs_f64());
384
385 if let Some(pongs) = pong_responses {
389 for (client_id, p_tick, received_at) in pongs {
390 let pong_event = NetworkEvent::Pong { tick: p_tick };
391 if let Ok(data) = encoder.encode_event(&pong_event) {
392 let dispatch_start = Instant::now();
396 match transport.send_unreliable(client_id, &data).await {
397 Ok(()) => {
398 let dispatch_ms = dispatch_start.elapsed().as_secs_f64() * 1000.0;
399 let server_hold_ms = received_at.elapsed().as_secs_f64() * 1000.0;
400 metrics::histogram!("aetheris_server_pong_dispatch_ms")
401 .record(dispatch_ms);
402 metrics::histogram!("aetheris_server_ping_hold_ms")
403 .record(server_hold_ms);
404 }
405 Err(e) => {
406 error!(error = ?e, client_id = ?client_id, "Failed to send Pong");
407 }
408 }
409 }
410 }
411 }
412
413 let t3 = Instant::now();
415 {
416 let _span = debug_span!("stage3_simulate").entered();
417 world.simulate();
419 }
420 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "simulate")
421 .record(t3.elapsed().as_secs_f64());
422
423 let t4 = Instant::now();
425 let deltas = {
426 let _span = debug_span!("stage4_extract").entered();
427 world.extract_deltas()
428 };
429 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "extract")
430 .record(t4.elapsed().as_secs_f64());
431
432 let t5 = Instant::now();
434 if !deltas.is_empty() {
435 let mut broadcast_count: u64 = 0;
436
437 let stage_span = debug_span!("stage5_send", count = deltas.len());
438 let _guard = stage_span.enter();
439
440 for delta in deltas {
441 let encode_result = encoder.encode(&delta, encode_buffer);
442 match encode_result {
443 Ok(len) if len > aetheris_protocol::MAX_SAFE_PAYLOAD_SIZE => {
444 match Self::fragment_and_broadcast(
445 encode_buffer,
446 len,
447 next_message_id,
448 encoder,
449 transport,
450 )
451 .await
452 {
453 Ok(count) => broadcast_count += count,
454 Err(e) => error!(error = ?e, "Failed to fragment and broadcast delta"),
455 }
456 }
457 Ok(len) => {
458 if let Err(e) = transport.broadcast_unreliable(&encode_buffer[..len]).await
459 {
460 error!(error = ?e, "Failed to broadcast delta");
461 } else {
462 broadcast_count += 1;
463 }
464 }
465 Err(EncodeError::BufferOverflow {
466 needed,
467 available: _,
468 }) => {
469 let mut large_buffer = vec![0u8; needed];
470 if let Ok(len) = encoder.encode(&delta, &mut large_buffer) {
471 match Self::fragment_and_broadcast(
472 &large_buffer,
473 len,
474 next_message_id,
475 encoder,
476 transport,
477 )
478 .await
479 {
480 Ok(count) => broadcast_count += count,
481 Err(e) => {
482 error!(error = ?e, "Failed to fragment and broadcast large delta");
483 }
484 }
485 } else {
486 error!("Failed to encode into large scratch buffer");
487 }
488 }
489 Err(e) => {
490 metrics::counter!("aetheris_encode_errors_total").increment(1);
491 error!(
492 network_id = ?delta.network_id,
493 error = ?e,
494 "Failed to encode delta"
495 );
496 }
497 }
498 }
499 metrics::counter!("aetheris_packets_outbound_total").increment(broadcast_count);
500 metrics::counter!("aetheris_packets_broadcast_total").increment(broadcast_count);
501 }
502 metrics::histogram!("aetheris_stage_duration_seconds", "stage" => "send")
503 .record(t5.elapsed().as_secs_f64());
504 }
505
506 async fn fragment_and_broadcast(
507 data: &[u8],
508 len: usize,
509 next_message_id: &mut u32,
510 encoder: &dyn Encoder,
511 transport: &dyn GameTransport,
512 ) -> Result<u64, EncodeError> {
513 let message_id = *next_message_id;
514 *next_message_id = next_message_id.wrapping_add(1);
515
516 let chunk_size = aetheris_protocol::MAX_FRAGMENT_PAYLOAD_SIZE;
517 let chunks: Vec<_> = data[..len].chunks(chunk_size).collect();
518
519 let Ok(total_fragments) = u16::try_from(chunks.len()) else {
520 error!(
521 message_id,
522 chunks = chunks.len(),
523 "Too many fragments required for message; dropping payload"
524 );
525 return Err(EncodeError::Io(std::io::Error::new(
526 std::io::ErrorKind::InvalidData,
527 "Too many fragments",
528 )));
529 };
530
531 let mut sent_count = 0;
532 for (i, chunk) in chunks.into_iter().enumerate() {
533 let Ok(fragment_index) = u16::try_from(i) else {
534 error!(message_id, index = i, "Fragment index overflow; stopping");
535 break;
536 };
537
538 let fragment = FragmentedEvent {
539 message_id,
540 fragment_index,
541 total_fragments,
542 payload: chunk.to_vec(),
543 };
544 let fragment_event = NetworkEvent::Fragment {
545 client_id: aetheris_protocol::types::ClientId(0),
546 fragment,
547 };
548
549 match encoder.encode_event(&fragment_event) {
550 Ok(encoded_fragment) => {
551 if let Err(e) = transport.broadcast_unreliable(&encoded_fragment).await {
552 error!(error = ?e, "Failed to broadcast fragment");
553 } else {
554 sent_count += 1;
555 }
556 }
557 Err(e) => {
558 error!(error = ?e, "Failed to encode fragment event");
559 }
560 }
561 }
562
563 Ok(sent_count)
564 }
565}
566
/// Reports whether a session may issue playground commands (`StressTest`, `Spawn`,
/// `ClearWorld`).
///
/// Returns `true` in either of two cases:
/// - the session's `jti` is exactly `"admin"` (case-sensitive);
/// - for any session, the `AETHERIS_ENV` environment variable is exactly `"dev"`.
///
/// The environment variable is read on every call, and only when the `jti` check fails.
fn can_run_playground_command(jti: &str) -> bool {
    // NOTE(review): authorizing on the token's `jti` value rather than a role/claim looks
    // fragile. Confirm this is the intended admin mechanism.
    if jti == "admin" {
        return true;
    }
    // An unset or non-UTF-8 variable yields `Err`, which never matches.
    matches!(std::env::var("AETHERIS_ENV").as_deref(), Ok("dev"))
}