1use crate::engine::EngineRequest;
7use crate::event::Event;
8use crate::message::{Capabilities, ClientMessage, ErrorResponse, Message, ServerMessage};
9use argon2::{Argon2, PasswordVerifier, password_hash::PasswordHashString};
10use rustls::ServerConfig;
11use rustls::pki_types::{
12 CertificateDer, PrivateKeyDer,
13 pem::{self, PemObject},
14};
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::net::{IpAddr, SocketAddr};
18use std::path::PathBuf;
19use std::result;
20use std::sync::Arc;
21use std::time::Duration;
22use thiserror::Error;
23use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
24use tokio::net::{TcpListener, TcpStream};
25use tokio::sync::{RwLock, broadcast, mpsc, oneshot};
26use tokio::time::timeout;
27#[cfg(feature = "tokio-graceful-shutdown")]
28use tokio_graceful_shutdown::{FutureExt, IntoSubsystem, SubsystemHandle};
29use tokio_rustls::TlsAcceptor;
30use tracing::{info, instrument, warn};
31
/// Registered clients: maps each client ID to the PHC-format hash of that client's secret.
///
/// During authentication the submitted secret is checked against this hash with
/// `Argon2::default().verify_password`.
pub type ClientRegistry = HashMap<String, PasswordHashString>;
36
/// Errors that can end a client connection, either during authentication or afterwards.
#[derive(Error, Debug)]
pub enum Error {
    /// Socket, TLS, or decoding failure. The connection handler treats the `InvalidInput` and
    /// `InvalidData` kinds as a malformed client message and everything else as a dead connection.
    #[error(transparent)]
    Io(#[from] io::Error),

    /// A reply channel was dropped before a reply was sent, e.g. by the engine or by the
    /// custom-request handler.
    #[error(transparent)]
    Oneshot(#[from] oneshot::error::RecvError),

    /// The client sent a message that is not valid at this point in the protocol.
    #[error("client sent unexpected message: {0:?}")]
    UnexpectedMessage(ClientMessage),

    /// The client's handshake range `(min, max)` does not include a version the server speaks.
    #[error("client requested an unsupported version range: {0} - {1}")]
    UnsupportedVersion(u8, u8),

    /// NOTE(review): not constructed in this part of the file; presumably produced when decoding
    /// an unrecognized message code (see `message` module) — confirm.
    #[error("unrecognized request code: {0}")]
    UnknownRequest(u8),

    /// The client ID is not registered, or the secret does not match its stored hash.
    #[error("client provided invalid credentials")]
    InvalidCredentials,
}

/// `Result` alias using this module's [`Error`].
pub type Result<T> = result::Result<T, Error>;
59
/// A client message the listener does not handle itself, forwarded to an external handler.
///
/// The handler must answer on `response`:
/// - `Some(message)` is written back to the client.
/// - `None` sends nothing.
///
/// Dropping `response` without answering makes the listener drop the client's connection.
#[derive(Debug)]
// NOTE(review): every field is `pub`, so it is unclear what this allow still silences — confirm
// whether it is needed.
#[allow(dead_code)]
pub struct CustomRequestMessage {
    /// ID of the authenticated client that sent the message.
    pub client_id: String,
    /// The message as received (the payload of `ClientMessage::Unknown`).
    pub message: Message,
    /// One-shot reply channel back to the connection task.
    pub response: oneshot::Sender<Option<Message>>,
}
69
/// TLS-terminating TCP server that authenticates clients and relays their requests to the engine.
///
/// Construct it with [`NetworkListenerBuilder`] and run it with [`NetworkListener::listen`].
pub struct NetworkListener {
    /// Client ID → secret-hash registry consulted during authentication. It sits behind an async
    /// lock, presumably so it can be updated while the listener runs — confirm with callers.
    clients: Arc<RwLock<ClientRegistry>>,
    /// Address the TCP listener binds to.
    addr: SocketAddr,
    /// Performs the server side of the TLS handshake on every accepted TCP stream.
    tls_acceptor: TlsAcceptor,
    /// Engine request queue; each request carries a one-shot sender for the engine's reply.
    engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
    /// Broadcasts client connect / disconnect / connection-loss events.
    event_tx: broadcast::Sender<Event>,
    /// Optional handler for messages the listener does not understand itself.
    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
}
110
111impl Debug for NetworkListener {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct("NetworkListener")
114 .field("clients", &self.clients)
115 .field("addr", &self.addr)
116 .field("engine_tx", &self.engine_tx)
117 .field("event_tx", &self.event_tx)
118 .field("custom_req_tx", &self.custom_req_tx)
119 .finish()
120 }
121}
122
123impl NetworkListener {
124 pub async fn listen(&self) -> Result<()> {
131 let listener = TcpListener::bind(self.addr).await?;
132 let mut con_counter: usize = 0;
133
134 loop {
135 let (stream, peer_addr) = listener.accept().await?;
136 let conn_id = con_counter;
137 con_counter += 1;
138 let event_tx = self.event_tx.clone();
139
140 self.handle_stream(stream, peer_addr, self.clients.clone(), conn_id, event_tx);
141 }
142 }
143
144 #[instrument(skip(self, stream, peer_addr, clients, event_tx))]
145 fn handle_stream(
146 &self,
147 stream: TcpStream,
148 peer_addr: SocketAddr,
149 clients: Arc<RwLock<ClientRegistry>>,
150 conn_id: usize,
151 event_tx: broadcast::Sender<Event>,
152 ) {
153 let acceptor = self.tls_acceptor.clone();
154 let engine_tx = self.engine_tx.clone();
155 let custom_req_tx = self.custom_req_tx.clone();
156
157 tokio::spawn(async move {
158 info!("new connection from {}", peer_addr);
159
160 let stream = match acceptor.accept(stream).await {
161 Ok(stream) => stream,
162 Err(error) => {
163 warn!("failed to accept stream: {}", error);
164 return;
165 }
166 };
167
168 let (reader, writer) = io::split(stream);
169 let mut reader = BufReader::new(reader);
170 let mut writer = BufWriter::new(writer);
171
172 let (client_id, local_ip, _version) = match timeout(
173 Duration::from_secs(2),
174 authenticate(&mut reader, &mut writer, clients),
175 )
176 .await
177 {
178 Ok(Ok(result)) => result,
179 Ok(Err(Error::UnexpectedMessage(message))) => {
180 warn!("client error: unexpected message: {:?}", message);
181 let _ = write_to_stream(
182 &mut writer,
183 ServerMessage::Error(ErrorResponse::UnexpectedMessage),
184 )
185 .await;
186 return;
187 }
188 Ok(Err(Error::UnsupportedVersion(min_version, max_version))) => {
189 warn!("client error: unsupported version range: {min_version} - {max_version}");
190 let _ = write_to_stream(
191 &mut writer,
192 ServerMessage::Error(ErrorResponse::UnsupportedVersionRange),
193 )
194 .await;
195 return;
196 }
197 Ok(Err(Error::InvalidCredentials)) => {
198 warn!("client error: invalid credentials");
199 let _ = write_to_stream(
200 &mut writer,
201 ServerMessage::Error(ErrorResponse::InvalidCredentials),
202 )
203 .await;
204 return;
205 }
206 Ok(Err(Error::Io(error)))
207 if matches!(
208 error.kind(),
209 io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
210 ) =>
211 {
212 warn!("client error: malformed message: {:?}", error);
213 let _ = write_to_stream(
214 &mut writer,
215 ServerMessage::Error(ErrorResponse::MalformedMessage),
216 )
217 .await;
218 return;
219 }
220 Ok(Err(error)) => {
221 warn!("client error: connection died: {:?}", error);
222 return;
223 }
224 Err(_) => {
225 warn!("client error: authentication timed out");
226 return;
227 }
228 };
229
230 let _ = event_tx.send(Event::ClientConnect {
231 client_id: client_id.clone(),
232 conn_id,
233 local_ip,
234 });
235
236 match handle_connection(
237 &mut reader,
238 &mut writer,
239 &client_id,
240 engine_tx,
241 custom_req_tx,
242 )
243 .await
244 {
245 Ok(()) => {
246 let _ = event_tx.send(Event::ClientDisconnect { client_id, conn_id });
247 return;
248 }
249 Err(Error::UnexpectedMessage(message)) => {
250 warn!("client error: unexpected message: {:?}", message);
251 let _ = write_to_stream(
252 &mut writer,
253 ServerMessage::Error(ErrorResponse::UnexpectedMessage),
254 )
255 .await;
256 }
257 Err(Error::Io(error))
258 if matches!(
259 error.kind(),
260 io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
261 ) =>
262 {
263 warn!("client error: malformed message: {:?}", error);
264 let _ = write_to_stream(
265 &mut writer,
266 ServerMessage::Error(ErrorResponse::MalformedMessage),
267 )
268 .await;
269 return;
270 }
271 Err(error) => {
272 warn!("client error: connection died: {:?}", error);
273 return;
274 }
275 }
276
277 let _ = event_tx.send(Event::ClientConnectionLoss { client_id, conn_id });
278 });
279 }
280}
281
/// Runs the listener as a `tokio-graceful-shutdown` subsystem.
///
/// If the listener stops on its own, its own result (e.g. a bind failure) is returned. If a
/// shutdown request cancels it first, that counts as a clean exit.
#[cfg(feature = "tokio-graceful-shutdown")]
impl IntoSubsystem<Error> for NetworkListener {
    async fn run(self, subsys: &mut SubsystemHandle) -> Result<()> {
        match self.listen().cancel_on_shutdown(subsys).await {
            Ok(listen_outcome) => listen_outcome,
            Err(_cancelled) => Ok(()),
        }
    }
}
292
293async fn read_from_stream<S: AsyncRead + Unpin + Send>(stream: &mut S) -> Result<ClientMessage> {
294 let message_type = stream.read_u8().await?;
295 let payload_length = stream.read_u32_le().await?;
296
297 if payload_length == 0 {
298 return Ok(Message::new(message_type, vec![]).try_into()?);
299 }
300
301 let mut message = vec![0; payload_length as usize];
302 stream.read_exact(&mut message).await?;
303
304 Ok(Message::new(message_type, message).try_into()?)
305}
306
307async fn write_to_stream<S: AsyncWrite + Unpin + Send>(
308 stream: &mut S,
309 message: impl Into<Message>,
310) -> Result<()> {
311 let message: Message = message.into();
312 stream.write_all(&message.into_bytes()).await?;
313 stream.flush().await?;
314
315 Ok(())
316}
317
/// Errors returned by [`NetworkListenerBuilder::build`].
#[derive(Debug, Error)]
pub enum BuilderError {
    /// A required builder field was never set; the payload is the field's name.
    #[error("missing field: {0}")]
    MissingField(&'static str),

    /// The configured address is not a valid `SocketAddr` (e.g. `"127.0.0.1:1234"`).
    #[error(transparent)]
    AddrParse(#[from] std::net::AddrParseError),

    /// The certificate or private-key PEM file at `path` could not be read or parsed.
    #[error("failed to read PEM file at {path}: {source}")]
    Pem {
        path: PathBuf,
        #[source]
        source: pem::Error,
    },

    /// rustls rejected the certificate chain / private key combination.
    #[error(transparent)]
    Rustls(#[from] rustls::Error),
}

/// `Result` alias using [`BuilderError`].
pub type BuilderResult<T> = result::Result<T, BuilderError>;
338
/// Builder for [`NetworkListener`].
///
/// `address`, `cert_path`, `key_path`, `clients`, `engine_tx` and `event_tx` are required.
/// `custom_req_tx` is optional.
///
/// NOTE(review): the derived `Debug` prints `clients`, which contains secret hashes — avoid
/// logging this builder.
#[derive(Debug, Default)]
pub struct NetworkListenerBuilder {
    /// Socket address string, parsed in `build`.
    address: Option<String>,
    /// PEM file holding the server certificate chain.
    cert_path: Option<PathBuf>,
    /// PEM file holding the server private key.
    key_path: Option<PathBuf>,
    /// Shared client registry used for authentication.
    clients: Option<Arc<RwLock<ClientRegistry>>>,
    /// Engine request queue.
    engine_tx: Option<mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>>,
    /// Client lifecycle event broadcaster.
    event_tx: Option<broadcast::Sender<Event>>,
    /// Optional handler for messages the listener does not understand.
    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
}
376
377impl NetworkListenerBuilder {
378 pub fn new() -> Self {
379 Self {
380 address: None,
381 cert_path: None,
382 key_path: None,
383 clients: None,
384 engine_tx: None,
385 event_tx: None,
386 custom_req_tx: None,
387 }
388 }
389
390 pub fn custom_req_tx(
391 self,
392 custom_req_tx: mpsc::Sender<CustomRequestMessage>,
393 ) -> NetworkListenerBuilder {
394 NetworkListenerBuilder {
395 clients: self.clients,
396 address: self.address,
397 cert_path: self.cert_path,
398 key_path: self.key_path,
399 engine_tx: self.engine_tx,
400 event_tx: self.event_tx,
401 custom_req_tx: Some(custom_req_tx),
402 }
403 }
404
405 pub fn address(mut self, address: impl Into<String>) -> Self {
406 self.address = Some(address.into());
407 self
408 }
409
410 pub fn cert_path(mut self, path: impl Into<PathBuf>) -> Self {
411 self.cert_path = Some(path.into());
412 self
413 }
414
415 pub fn key_path(mut self, path: impl Into<PathBuf>) -> Self {
416 self.key_path = Some(path.into());
417 self
418 }
419
420 pub fn clients(mut self, clients: Arc<RwLock<ClientRegistry>>) -> Self {
421 self.clients = Some(clients);
422 self
423 }
424
425 pub fn engine_tx(
426 mut self,
427 tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
428 ) -> Self {
429 self.engine_tx = Some(tx);
430 self
431 }
432
433 pub fn event_tx(mut self, tx: broadcast::Sender<Event>) -> Self {
434 self.event_tx = Some(tx);
435 self
436 }
437
438 pub fn build(self) -> BuilderResult<NetworkListener> {
442 let addr: SocketAddr = self
443 .address
444 .ok_or_else(|| BuilderError::MissingField("address"))?
445 .parse()?;
446
447 let cert_path = self
448 .cert_path
449 .ok_or_else(|| BuilderError::MissingField("cert_path"))?;
450 let key_path = self
451 .key_path
452 .ok_or_else(|| BuilderError::MissingField("key_path"))?;
453
454 let certs = CertificateDer::pem_file_iter(&cert_path)
455 .map_err(|err| BuilderError::Pem {
456 path: cert_path.clone(),
457 source: err,
458 })?
459 .collect::<result::Result<Vec<_>, _>>()
460 .map_err(|err| BuilderError::Pem {
461 path: cert_path,
462 source: err,
463 })?;
464 let key = PrivateKeyDer::from_pem_file(&key_path).map_err(|err| BuilderError::Pem {
465 path: key_path,
466 source: err,
467 })?;
468
469 let config = ServerConfig::builder()
470 .with_no_client_auth()
471 .with_single_cert(certs, key)?;
472 let tls_acceptor = TlsAcceptor::from(Arc::new(config));
473
474 Ok(NetworkListener {
475 clients: self
476 .clients
477 .ok_or_else(|| BuilderError::MissingField("clients"))?,
478 addr,
479 tls_acceptor,
480 engine_tx: self
481 .engine_tx
482 .ok_or_else(|| BuilderError::MissingField("engine_tx"))?,
483 event_tx: self
484 .event_tx
485 .ok_or_else(|| BuilderError::MissingField("event_tx"))?,
486 custom_req_tx: self.custom_req_tx,
487 })
488 }
489}
490
491async fn handle_connection<R, W>(
496 reader: &mut R,
497 writer: &mut W,
498 client_id: &str,
499 engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
500 custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
501) -> Result<()>
502where
503 R: AsyncRead + Unpin + Send,
504 W: AsyncWrite + Unpin + Send,
505{
506 loop {
507 let message = match timeout(Duration::from_secs(30), read_from_stream(reader)).await {
508 Ok(Ok(message)) => message,
509 Ok(Err(error)) => return Err(error),
510 Err(_) => return Ok(()),
511 };
512
513 let engine_request = match message {
514 ClientMessage::Bloop { nfc_uid } => EngineRequest::Bloop {
515 nfc_uid,
516 client_id: client_id.to_string(),
517 },
518 ClientMessage::RetrieveAudio { achievement_id } => {
519 EngineRequest::RetrieveAudio { id: achievement_id }
520 }
521 ClientMessage::PreloadCheck {
522 audio_manifest_hash,
523 } => EngineRequest::PreloadCheck {
524 manifest_hash: audio_manifest_hash,
525 },
526 ClientMessage::Ping => {
527 write_to_stream(writer, ServerMessage::Pong).await?;
528 continue;
529 }
530 ClientMessage::Quit => break,
531 ClientMessage::Unknown(message) => {
532 if let Some(sender) = custom_req_tx.as_ref() {
533 let (resp_tx, resp_rx) = oneshot::channel();
534
535 let _ = sender
536 .send(CustomRequestMessage {
537 client_id: client_id.to_string(),
538 message,
539 response: resp_tx,
540 })
541 .await;
542
543 if let Some(message) = resp_rx.await? {
544 write_to_stream(writer, message).await?;
545 }
546 }
547
548 continue;
549 }
550 message => return Err(Error::UnexpectedMessage(message)),
551 };
552
553 let (resp_tx, resp_rx) = oneshot::channel();
554 let _ = engine_tx.send((engine_request, resp_tx)).await;
555 let response = resp_rx.await?;
556
557 write_to_stream(writer, response).await?;
558 }
559
560 Ok(())
561}
562
563async fn authenticate<R, W>(
568 reader: &mut R,
569 writer: &mut W,
570 clients: Arc<RwLock<ClientRegistry>>,
571) -> Result<(String, IpAddr, u8)>
572where
573 R: AsyncRead + Unpin + Send,
574 W: AsyncWrite + Unpin + Send,
575{
576 let (min_version, max_version) = match read_from_stream(reader).await? {
577 ClientMessage::ClientHandshake {
578 min_version,
579 max_version,
580 } => (min_version, max_version),
581 message => return Err(Error::UnexpectedMessage(message)),
582 };
583
584 if min_version > 3 || max_version < 3 {
585 return Err(Error::UnsupportedVersion(min_version, max_version));
586 }
587
588 write_to_stream(
589 writer,
590 ServerMessage::ServerHandshake {
591 accepted_version: 3,
592 capabilities: Capabilities::PreloadCheck,
593 },
594 )
595 .await?;
596
597 let (client_id, client_secret, ip_addr) = match read_from_stream(reader).await? {
598 ClientMessage::Authentication {
599 client_id,
600 client_secret,
601 ip_addr,
602 } => (client_id, client_secret, ip_addr),
603 message => return Err(Error::UnexpectedMessage(message)),
604 };
605
606 let clients = clients.read().await;
607 let Some(secret_hash) = clients.get(&client_id) else {
608 return Err(Error::InvalidCredentials);
609 };
610
611 if Argon2::default()
612 .verify_password(client_secret.as_bytes(), &secret_hash.password_hash())
613 .is_err()
614 {
615 return Err(Error::InvalidCredentials);
616 }
617
618 write_to_stream(writer, ServerMessage::AuthenticationAccepted).await?;
619
620 Ok((client_id.to_string(), ip_addr, 3))
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[tokio::test]
    async fn builder_fails_with_missing_fields() {
        let builder = NetworkListenerBuilder::new();
        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::MissingField(_))));
    }

    #[tokio::test]
    async fn builder_fails_with_invalid_address() {
        let builder = NetworkListenerBuilder::new()
            .address("invalid-addr")
            .cert_path("cert.pem")
            .key_path("key.pem")
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::AddrParse(_))));
    }

    #[tokio::test]
    async fn builder_fails_on_invalid_pem_files() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        fs::write(&cert_path, b"invalid-cert").unwrap();
        fs::write(&key_path, b"invalid-key").unwrap();

        let builder = NetworkListenerBuilder::new()
            .address("127.0.0.1:12345")
            .cert_path(&cert_path)
            .key_path(&key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::Pem { .. })));
    }

    #[tokio::test]
    async fn builder_succeeds_with_valid_dummy_pem() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        let cert_data = include_bytes!("../examples/cert.pem");
        let key_data = include_bytes!("../examples/key.pem");

        fs::write(&cert_path, cert_data).unwrap();
        fs::write(&key_path, key_data).unwrap();

        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let builder = NetworkListenerBuilder::new()
            .address("127.0.0.1:12345")
            .cert_path(&cert_path)
            .key_path(&key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_client_id() {
        let clients = Arc::new(RwLock::new(Default::default()));

        let client_handshake = build_handshake(3, 3);
        let authentication = build_authentication("unknown-client", "password", "127.0.0.1");

        let server_handshake: Message = ServerMessage::ServerHandshake {
            accepted_version: 3,
            capabilities: Capabilities::PreloadCheck,
        }
        .into();

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake.into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    #[tokio::test]
    async fn authentication_succeeds_with_correct_credentials() {
        let clients = Arc::new(RwLock::new(HashMap::default()));
        clients.write().await.insert(
            "client".into(),
            PasswordHashString::new(
                "$argon2id$v=19$m=10,t=1,p=1$THh0RHE5YWNkQUZNa2lqUA$dmB4X7J49jjCGA",
            )
            .unwrap(),
        );

        let client_handshake = build_handshake(3, 3);
        let authentication = build_authentication("client", "secret", "127.0.0.1");

        let server_handshake: Message = ServerMessage::ServerHandshake {
            accepted_version: 3,
            capabilities: Capabilities::PreloadCheck,
        }
        .into();

        let authentication_accepted: Message = ServerMessage::AuthenticationAccepted.into();

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake.into_bytes())
            .write(&authentication_accepted.into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(result.is_ok(), "unexpected result: {result:?}");
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_password() {
        let clients = Arc::new(RwLock::new(HashMap::default()));
        clients.write().await.insert(
            "client".into(),
            PasswordHashString::new(
                "$argon2id$v=19$m=10,t=1,p=1$THh0RHE5YWNkQUZNa2lqUA$dmB4X7J49jjCGA",
            )
            .unwrap(),
        );

        // Use the *registered* client ID so that password verification (not the ID lookup) is
        // what rejects the attempt.
        let client_handshake = build_handshake(3, 3);
        let authentication = build_authentication("client", "wrong-secret", "127.0.0.1");

        let server_handshake: Message = ServerMessage::ServerHandshake {
            accepted_version: 3,
            capabilities: Capabilities::PreloadCheck,
        }
        .into();

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake.into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    #[tokio::test]
    async fn authentication_fails_with_unsupported_version() {
        let clients = Arc::new(RwLock::new(Default::default()));

        let client_handshake = build_handshake(1, 2);

        // The server must reject the range without writing anything.
        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .build();
        let mut writer = tokio_test::io::Builder::new().build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::UnsupportedVersion(1, 2))));
    }

    fn dummy_engine_tx() -> mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)> {
        let (tx, _rx) = mpsc::channel(1);
        tx
    }

    fn dummy_event_tx() -> broadcast::Sender<Event> {
        let (tx, _rx) = broadcast::channel(1);
        tx
    }

    /// Encodes a `ClientHandshake` frame: type 0x01, LE length, `[min_version, max_version]`.
    fn build_handshake(min_version: u8, max_version: u8) -> Vec<u8> {
        let mut buf = Vec::new();
        let payload = [min_version, max_version];

        buf.push(0x01);
        buf.extend(&(payload.len() as u32).to_le_bytes());
        buf.extend(&payload);

        buf
    }

    /// Encodes an `Authentication` frame with this payload layout:
    /// - type 0x03 and the LE payload length;
    /// - the length-prefixed client ID;
    /// - the length-prefixed password;
    /// - an IP-version tag (4 or 6) followed by the address octets.
    fn build_authentication(client_id: &str, password: &str, ip_addr: &str) -> Vec<u8> {
        let mut buf = Vec::new();

        let client_id_bytes = client_id.as_bytes();
        let password_bytes = password.as_bytes();

        let mut payload = Vec::new();
        payload.push(client_id_bytes.len() as u8);
        payload.extend(client_id_bytes);

        payload.push(password_bytes.len() as u8);
        payload.extend(password_bytes);

        let ip: IpAddr = ip_addr.parse().expect("Invalid IP address");
        match ip {
            IpAddr::V4(v4) => {
                payload.push(4);
                payload.extend(&v4.octets());
            }
            IpAddr::V6(v6) => {
                payload.push(6);
                payload.extend(&v6.octets());
            }
        }

        buf.push(0x03);
        buf.extend(&(payload.len() as u32).to_le_bytes());
        buf.extend(payload);

        buf
    }
}