1#![cfg(not(target_arch = "wasm32"))]
3use std::collections::{HashMap, VecDeque};
9use std::net::SocketAddr;
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
13use async_trait::async_trait;
14use rcgen::{CertificateParams, KeyPair};
15use sha2::{Digest, Sha256};
16use tracing::{error, info, warn};
17use wtransport::endpoint::IncomingSession;
18use wtransport::endpoint::endpoint_side::Server as ServerSide;
19use wtransport::{Connection, Endpoint, Identity, ServerConfig};
20
21use aetheris_protocol::events::NetworkEvent;
22use aetheris_protocol::traits::{ClientId, PlatformTransport, TransportError};
23
/// Live WebTransport sessions, keyed by the `ClientId` assigned when the
/// session was accepted.
type ConnectionMap = HashMap<ClientId, Connection>;
/// Shared predicate used for first-message authentication: it is handed the
/// trimmed text read from the client's first bidirectional stream and returns
/// `true` to admit the client. `Send + Sync` so it can be cloned into spawned
/// tasks.
type AuthValidator = Arc<dyn Fn(&str) -> bool + Send + Sync>;
26
/// Server-side WebTransport implementation of `PlatformTransport`.
///
/// Background tasks accept sessions and push `NetworkEvent`s into a shared
/// queue; callers drain that queue with `poll_events` and send data through
/// the per-client connections stored in the connection map.
pub struct WebTransportBridge {
    /// Handle to the server endpoint. The listener task holds its own clone;
    /// this copy is only read by the tests in this file (for `local_addr`).
    _endpoint: Arc<Endpoint<ServerSide>>,
    /// Queue of events produced by the connection tasks and drained by
    /// `poll_events`.
    events: Arc<Mutex<VecDeque<NetworkEvent>>>,
    /// Currently registered client connections.
    connections: Arc<Mutex<ConnectionMap>>,
    /// Number of registered clients, readable without taking the map lock.
    connected_client_count: Arc<std::sync::atomic::AtomicUsize>,
    /// Base64-encoded SHA-256 digest of the server certificate (DER form).
    cert_hash: String,
    /// Optional first-message auth check; when `None`, every accepted session
    /// is registered without authentication.
    auth_validator: Option<AuthValidator>,
}
37
38impl WebTransportBridge {
39 pub async fn new(addr: SocketAddr, auth_validator: Option<AuthValidator>) -> Self {
46 let (identity, cert_hash) = generate_self_signed_identity().await;
47
48 let config = ServerConfig::builder()
49 .with_bind_address(addr)
50 .with_identity(identity)
51 .max_idle_timeout(Some(std::time::Duration::from_secs(30)))
52 .expect("Invalid idle timeout")
53 .keep_alive_interval(Some(std::time::Duration::from_secs(10)))
54 .build();
55
56 let endpoint = Endpoint::server(config).expect("Failed to create WebTransport endpoint");
57 let endpoint = Arc::new(endpoint);
58 let events = Arc::new(Mutex::new(VecDeque::new()));
59 let connections = Arc::new(Mutex::new(HashMap::new()));
60 let connected_client_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
61
62 let server = Self {
63 _endpoint: Arc::clone(&endpoint),
64 events,
65 connections,
66 connected_client_count,
67 cert_hash,
68 auth_validator,
69 };
70
71 server.spawn_listener(endpoint);
72
73 server
74 }
75
76 fn spawn_listener(&self, endpoint: Arc<Endpoint<ServerSide>>) {
77 let events = Arc::clone(&self.events);
78 let connections = Arc::clone(&self.connections);
79 let client_count = Arc::clone(&self.connected_client_count);
80 let validator = self.auth_validator.clone();
81
82 tokio::spawn(async move {
83 if let Ok(local_addr) = endpoint.local_addr() {
84 info!(
85 "WebTransport listener task started (address: {:?})",
86 local_addr
87 );
88 }
89 loop {
90 info!("WebTransport waiting for next incoming session...");
91 let incoming = endpoint.accept().await;
92 info!("WebTransport received an incoming session attempt");
93 let events_inner = Arc::clone(&events);
94 let connections_inner = Arc::clone(&connections);
95 let count_inner = Arc::clone(&client_count);
96 let validator_inner = validator.clone();
97
98 tokio::spawn(async move {
99 handle_incoming_connection(
100 incoming,
101 events_inner,
102 connections_inner,
103 count_inner,
104 validator_inner,
105 )
106 .await;
107 });
108 }
109 });
110 }
111}
112
113#[async_trait]
114impl PlatformTransport for WebTransportBridge {
115 #[tracing::instrument(skip(self, data), fields(client_id = %client_id.0, size = data.len()))]
116 async fn send_unreliable(
117 &self,
118 client_id: ClientId,
119 data: &[u8],
120 ) -> Result<(), TransportError> {
121 let mut conn_guard = self.connections.lock().await;
122 let connection_map: &mut ConnectionMap = &mut conn_guard;
123 if let Some(conn) = connection_map.get_mut(&client_id) {
124 let conn: &mut Connection = conn;
125 if let Err(e) = conn.send_datagram(data) {
126 metrics::counter!("aetheris_transport_errors_total", "transport" => "webtransport", "type" => "datagram_send_fail").increment(1);
127 return Err(TransportError::Io(std::io::Error::other(format!(
128 "{:?}",
129 e
130 ))));
131 }
132 metrics::counter!("aetheris_transport_packets_total", "transport" => "webtransport", "direction" => "outbound", "channel" => "unreliable").increment(1);
133 metrics::counter!("aetheris_transport_bytes_total", "transport" => "webtransport", "direction" => "outbound", "channel" => "unreliable").increment(data.len() as u64);
134 Ok(())
135 } else {
136 metrics::counter!("aetheris_transport_errors_total", "transport" => "webtransport", "type" => "client_not_connected").increment(1);
137 Err(TransportError::ClientNotConnected(client_id))
138 }
139 }
140
141 #[tracing::instrument(skip(self, data), fields(client_id = %client_id.0, size = data.len()))]
142 async fn send_reliable(&self, client_id: ClientId, data: &[u8]) -> Result<(), TransportError> {
143 let conn = {
144 let conn_guard = self.connections.lock().await;
145 conn_guard.get(&client_id).cloned()
146 };
147
148 if let Some(conn) = conn {
149 match conn.open_bi().await {
151 Ok(opening) => match opening.await {
152 Ok((mut send_stream, _recv_stream)) => {
153 send_stream.write_all(data).await.map_err(|e| {
154 metrics::counter!("aetheris_transport_errors_total", "transport" => "webtransport", "type" => "stream_write_fail").increment(1);
155 TransportError::Io(std::io::Error::other(format!(
156 "Failed to send reliable data: {}",
157 e
158 )))
159 })?;
160 send_stream.finish().await.map_err(|e| {
161 metrics::counter!("aetheris_transport_errors_total", "transport" => "webtransport", "type" => "stream_finish_fail").increment(1);
162 TransportError::Io(std::io::Error::other(format!(
163 "Failed to finish reliable stream: {}",
164 e
165 )))
166 })?;
167 metrics::counter!("aetheris_transport_packets_total", "transport" => "webtransport", "direction" => "outbound", "channel" => "reliable").increment(1);
168 metrics::counter!("aetheris_transport_bytes_total", "transport" => "webtransport", "direction" => "outbound", "channel" => "reliable").increment(data.len() as u64);
169 Ok(())
170 }
171 Err(e) => {
172 metrics::counter!("aetheris_transport_errors_total", "transport" => "webtransport", "type" => "stream_open_fail").increment(1);
173 Err(TransportError::Io(std::io::Error::other(format!(
174 "Failed to establish bidirectional stream: {}",
175 e
176 ))))
177 }
178 },
179 Err(e) => {
180 metrics::counter!("aetheris_transport_errors_total", "transport" => "webtransport", "type" => "stream_init_fail").increment(1);
181 Err(TransportError::Io(std::io::Error::other(format!(
182 "Failed to initiate bidirectional stream: {}",
183 e
184 ))))
185 }
186 }
187 } else {
188 metrics::counter!("aetheris_transport_errors_total", "transport" => "webtransport", "type" => "client_not_connected").increment(1);
189 Err(TransportError::ClientNotConnected(client_id))
190 }
191 }
192
193 #[tracing::instrument(skip(self, data), fields(size = data.len()))]
194 async fn broadcast_unreliable(&self, data: &[u8]) -> Result<(), TransportError> {
195 let mut conn_guard = self.connections.lock().await;
196 let connection_map: &mut ConnectionMap = &mut conn_guard;
197 for (client_id, conn) in connection_map.iter_mut() {
198 if let Err(e) = conn.send_datagram(data) {
199 metrics::counter!("aetheris_transport_errors_total", "transport" => "webtransport", "type" => "broadcast_fail").increment(1);
200 warn!(
201 "Failed to broadcast unreliable datagram to client {:?}: {:?}",
202 client_id, e
203 );
204 } else {
205 metrics::counter!("aetheris_transport_packets_total", "transport" => "webtransport", "direction" => "outbound", "channel" => "broadcast_unreliable").increment(1);
206 metrics::counter!("aetheris_transport_bytes_total", "transport" => "webtransport", "direction" => "outbound", "channel" => "broadcast_unreliable").increment(data.len() as u64);
207 }
208 }
209 Ok(())
210 }
211
212 #[tracing::instrument(skip(self))]
213 async fn poll_events(&mut self) -> Result<Vec<NetworkEvent>, TransportError> {
214 let mut events = self.events.lock().await;
215 Ok(events.drain(..).collect())
216 }
217
218 async fn connected_client_count(&self) -> usize {
219 self.connected_client_count
220 .load(std::sync::atomic::Ordering::Relaxed)
221 }
222
223 async fn disconnect(&self, client_id: ClientId) -> Result<(), TransportError> {
224 let mut conn_guard = self.connections.lock().await;
225 if let Some(conn) = conn_guard.remove(&client_id) {
226 conn.close(4001u32.into(), b"Session Replaced");
229 self.connected_client_count
230 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
231 Ok(())
232 } else {
233 Err(TransportError::ClientNotConnected(client_id))
234 }
235 }
236}
237
238impl WebTransportBridge {
239 #[must_use]
241 pub fn cert_hash(&self) -> &str {
242 &self.cert_hash
243 }
244}
245
246async fn handle_incoming_connection(
247 incoming: IncomingSession,
248 events: Arc<Mutex<VecDeque<NetworkEvent>>>,
249 connections: Arc<Mutex<ConnectionMap>>,
250 connected_client_count: Arc<std::sync::atomic::AtomicUsize>,
251 auth_validator: Option<AuthValidator>,
252) {
253 info!("Handling incoming WebTransport connection...");
254 let session_request = match incoming.await {
255 Ok(r) => {
256 info!(
257 "WebTransport session request received from {:?}",
258 r.remote_address()
259 );
260 r
261 }
262 Err(e) => {
263 tracing::debug!(
265 "Failed to accept incoming WebTransport session request: {}",
266 e
267 );
268 return;
269 }
270 };
271
272 let connection = match session_request.accept().await {
273 Ok(c) => {
274 info!(
275 "WebTransport connection accepted for {:?}",
276 c.remote_address()
277 );
278 c
279 }
280 Err(e) => {
281 warn!("Failed to accept WebTransport connection: {}", e);
282 return;
283 }
284 };
285
286 if let Some(validator) = auth_validator {
288 let auth_result = tokio::time::timeout(std::time::Duration::from_secs(5), async {
289 match connection.accept_bi().await {
290 Ok((_send, recv)) => {
291 use tokio::io::AsyncReadExt;
292 let mut buffer = String::new();
293 if tokio::time::timeout(
295 std::time::Duration::from_secs(2),
296 recv.take(1024).read_to_string(&mut buffer),
297 )
298 .await
299 .is_ok_and(|res| res.is_ok())
300 {
301 validator(buffer.trim())
302 } else {
303 false
304 }
305 }
306 Err(_) => false,
307 }
308 })
309 .await
310 .unwrap_or(false);
311
312 if !auth_result {
313 warn!("WebTransport connection rejected: invalid or missing First-Message Auth token");
314 return;
315 }
316 }
317
318 let client_id = ClientId(rand::random());
319 {
320 let mut conn_guard = connections.lock().await;
321 let connection_map: &mut ConnectionMap = &mut conn_guard;
322 connection_map.insert(client_id, connection.clone());
323 }
324
325 connected_client_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
326
327 {
328 let mut events_guard = events.lock().await;
329 events_guard.push_back(NetworkEvent::ClientConnected(client_id));
330 }
331
332 info!("Client connected via WebTransport: {:?}", client_id);
333
334 let conn_clone = connection.clone();
336 let events_clone = Arc::clone(&events);
337 let connections_clone = Arc::clone(&connections);
338 let count_clone = Arc::clone(&connected_client_count);
339 tokio::spawn(async move {
340 loop {
341 tokio::select! {
342 datagram = conn_clone.receive_datagram() => {
343 match datagram {
344 Ok(data) => {
345 let mut events_guard = events_clone.lock().await;
346 events_guard.push_back(NetworkEvent::UnreliableMessage {
347 client_id,
348 data: data.to_vec(),
349 });
350 metrics::counter!("aetheris_transport_packets_total", "transport" => "webtransport", "direction" => "inbound", "channel" => "unreliable").increment(1);
351 metrics::counter!("aetheris_transport_bytes_total", "transport" => "webtransport", "direction" => "inbound", "channel" => "unreliable").increment(data.len() as u64);
352 }
353 Err(e) => {
354 tracing::debug!("WebTransport receive_datagram loop ended for client {:?}: {:?}", client_id, e);
356 let mut events_guard = events_clone.lock().await;
357 events_guard.push_back(NetworkEvent::SessionClosed(client_id));
358 break;
359 }
360 }
361 }
362 stream_res = conn_clone.accept_bi() => {
363 match stream_res {
364 Ok(bi) => {
365 let events_inner = Arc::clone(&events_clone);
366 tokio::spawn(async move {
367 use tokio::io::AsyncReadExt;
368 const MAX_RELIABLE_PAYLOAD_SIZE: usize = 1024 * 1024; let mut buffer = Vec::new();
371 let mut limited_reader = bi.1.take(MAX_RELIABLE_PAYLOAD_SIZE as u64 + 1);
374
375 if let Err(e) = limited_reader.read_to_end(&mut buffer).await {
376 error!("Failed to read reliable stream for client {:?}: {}", client_id, e);
377 let mut events_guard = events_inner.lock().await;
378 events_guard.push_back(NetworkEvent::StreamReset(client_id));
379 return;
380 }
381
382 if buffer.len() > MAX_RELIABLE_PAYLOAD_SIZE {
384 error!("Reliable message exceeded maximum size ({}) from client {:?}", MAX_RELIABLE_PAYLOAD_SIZE, client_id);
385 return;
386 }
387
388 {
389 let mut events_guard = events_inner.lock().await;
390 let buffer_len = buffer.len() as u64;
391 events_guard.push_back(NetworkEvent::ReliableMessage {
392 client_id,
393 data: buffer,
394 });
395 metrics::counter!("aetheris_transport_packets_total", "transport" => "webtransport", "direction" => "inbound", "channel" => "reliable").increment(1);
396 metrics::counter!("aetheris_transport_bytes_total", "transport" => "webtransport", "direction" => "inbound", "channel" => "reliable").increment(buffer_len);
397 }
398 });
399 }
400 Err(e) => {
401 tracing::debug!("WebTransport accept_bi loop ended for client {:?}: {:?}", client_id, e);
402 break;
403 }
404 }
405 }
406 }
407 }
408
409 tracing::info!("Client session finalized via WebTransport: {:?}", client_id);
410 count_clone.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
411 {
412 let mut conn_guard = connections_clone.lock().await;
413 conn_guard.remove(&client_id);
414 }
415
416 let mut events_guard = events_clone.lock().await;
417 events_guard.push_back(NetworkEvent::ClientDisconnected(client_id));
418 });
419}
420
421async fn generate_self_signed_identity() -> (Identity, String) {
422 const CERT_VERSION: &str = "v3";
425
426 let cert_dir = std::path::PathBuf::from("target/dev-certs");
427 let cert_path = cert_dir.join("cert.pem");
428 let key_path = cert_dir.join("key.pem");
429 let hash_path = cert_dir.join("cert.sha256");
430 let version_path = cert_dir.join("cert.version");
431
432 if cert_path.exists() && key_path.exists() && hash_path.exists() {
433 let version_ok = tokio::fs::read_to_string(&version_path)
434 .await
435 .map(|v| v.trim() == CERT_VERSION)
436 .unwrap_or(false);
437
438 if version_ok {
439 match (
440 tokio::fs::read_to_string(&hash_path).await,
441 Identity::load_pemfiles(&cert_path, &key_path).await,
442 ) {
443 (Ok(hash_b64), Ok(identity)) => {
444 info!("--------------------------------------------------");
445 info!("WEBTRANSPORT SELF-SIGNED CERTIFICATE LOADED");
446 info!("SHA-256 Hash (Base64): {}", hash_b64.trim());
447 info!("(Delete target/dev-certs/ to force regeneration)");
448 info!("--------------------------------------------------");
449 return (identity, hash_b64.trim().to_string());
450 }
451 (hash_err, identity_err) => {
452 warn!(
453 "Failed to load persistent certificate (hash_err: {:?}, id_err: {:?}). Regenerating...",
454 hash_err.is_err(),
455 identity_err.is_err()
456 );
457 }
459 }
460 } else {
461 info!(
462 "Dev cert version mismatch (expected {CERT_VERSION}) — regenerating to pick up updated SAN list"
463 );
464 }
465 }
466
467 let mut params = CertificateParams::new(vec!["localhost".to_string()])
472 .expect("Failed to create cert params");
473 params
474 .subject_alt_names
475 .push(rcgen::SanType::IpAddress(std::net::IpAddr::V4(
476 std::net::Ipv4Addr::new(127, 0, 0, 1),
477 )));
478 params
479 .subject_alt_names
480 .push(rcgen::SanType::IpAddress(std::net::IpAddr::V6(
481 std::net::Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1),
482 )));
483 params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
484 params.not_after = time::OffsetDateTime::now_utc() + time::Duration::days(10);
485
486 let key_pair = KeyPair::generate().expect("Failed to generate key pair");
487 let cert = params
488 .self_signed(&key_pair)
489 .expect("Failed to self-sign cert");
490
491 let cert_pem = cert.pem();
492 let key_pem = key_pair.serialize_pem();
493
494 let cert_der = cert.der();
495 let mut hasher = Sha256::new();
496 hasher.update(cert_der.as_ref());
497 let hash = hasher.finalize();
498 let hash_b64 = base64::Engine::encode(&base64::prelude::BASE64_STANDARD, hash);
499
500 tokio::fs::create_dir_all(&cert_dir)
501 .await
502 .expect("Failed to create cert directory");
503
504 tokio::fs::write(&cert_path, &cert_pem)
505 .await
506 .expect("Failed to write cert");
507 tokio::fs::write(&key_path, &key_pem)
508 .await
509 .expect("Failed to write key");
510 tokio::fs::write(&hash_path, &hash_b64)
511 .await
512 .expect("Failed to write cert hash");
513 tokio::fs::write(&version_path, CERT_VERSION)
514 .await
515 .expect("Failed to write cert version");
516
517 let validity_days = (params.not_after - params.not_before).whole_days();
518
519 info!("--------------------------------------------------");
520 info!("WEBTRANSPORT SELF-SIGNED CERTIFICATE GENERATED");
521 info!("SHA-256 Hash (Base64): {}", hash_b64);
522 info!(
523 "Valid for: {} days (Chrome serverCertificateHashes constraint: <= 14 days)",
524 validity_days
525 );
526 info!("Saved to: {}", cert_dir.display());
527 info!("--------------------------------------------------");
528
529 Identity::load_pemfiles(&cert_path, &key_path)
530 .await
531 .map(|id| (id, hash_b64))
532 .expect("Failed to load identity from persistent files")
533}
534
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Connects 100 clients to a bridge on an ephemeral port, has each send
    /// one datagram, and checks that every connection was reported, the
    /// atomic counter was non-zero at some point, and at least 95 datagrams
    /// arrived.
    #[tokio::test]
    async fn test_concurrent_send_unreliable_load() {
        let bind_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let mut server = WebTransportBridge::new(bind_addr, None).await;
        let server_addr = server._endpoint.local_addr().unwrap();
        let server_hash = server.cert_hash().to_string();

        let num_clients = 100;
        let mut client_tasks = Vec::new();

        for i in 0..num_clients {
            let pinned_hash = server_hash.clone();
            let task = tokio::spawn(async move {
                // Pin the server's self-signed certificate by its digest.
                let digest_bytes =
                    base64::Engine::decode(&base64::prelude::BASE64_STANDARD, pinned_hash.trim())
                        .expect("Failed to decode base64 hash");
                let digest = wtransport::tls::Sha256Digest::new(
                    digest_bytes.try_into().expect("Invalid hash length"),
                );

                let client_config = wtransport::ClientConfig::builder()
                    .with_bind_address("127.0.0.1:0".parse().unwrap())
                    .with_server_certificate_hashes(vec![digest])
                    .build();
                let client =
                    Endpoint::client(client_config).expect("Failed to create client endpoint");

                let url = format!("https://{}/", server_addr);
                let connection = match timeout(Duration::from_secs(5), client.connect(&url)).await
                {
                    Ok(Ok(conn)) => conn,
                    Ok(Err(e)) => panic!("Client {} failed to connect: {:?}", i, e),
                    Err(_) => panic!("Client {} connection timed out", i),
                };

                // Stagger the sends so the datagrams do not all arrive at once.
                tokio::time::sleep(Duration::from_millis((i as u64) * 10)).await;
                let payload = format!("message from client {}", i);
                connection
                    .send_datagram(payload.as_bytes())
                    .expect("Failed to send datagram");

                // Keep the connection open long enough for delivery.
                tokio::time::sleep(Duration::from_secs(2)).await;
            });
            client_tasks.push(task);
        }

        let mut connected_count = 0;
        let mut message_count = 0;
        let mut peak_client_count = 0usize;
        let deadline = Duration::from_secs(20);
        let started = std::time::Instant::now();

        loop {
            let all_seen = connected_count >= num_clients && message_count >= num_clients;
            if all_seen || started.elapsed() >= deadline {
                break;
            }

            for event in server.poll_events().await.unwrap() {
                match event {
                    NetworkEvent::ClientConnected(_) => connected_count += 1,
                    NetworkEvent::UnreliableMessage { .. } => message_count += 1,
                    _ => {}
                }
            }

            let current_count = server.connected_client_count().await;
            peak_client_count = std::cmp::max(peak_client_count, current_count);

            tokio::time::sleep(Duration::from_millis(50)).await;
        }

        for task in client_tasks {
            task.await.expect("Client task panicked or failed");
        }

        assert!(
            connected_count >= num_clients,
            "Only {}/{} clients connected at some point",
            connected_count,
            num_clients
        );
        assert!(
            peak_client_count > 0,
            "No clients were ever recorded as connected in the atomic counter"
        );
        assert!(
            message_count >= 95,
            "Only {}/{} messages received",
            message_count,
            num_clients
        );
    }
}