1use crate::engine::EngineRequest;
7use crate::event::Event;
8use crate::message::{ClientMessage, ErrorResponse, Message, ServerFeatures, ServerMessage};
9use argon2::{Argon2, PasswordVerifier, password_hash::PasswordHashString};
10use async_trait::async_trait;
11use rustls::ServerConfig;
12use rustls::pki_types::{
13 CertificateDer, PrivateKeyDer,
14 pem::{self, PemObject},
15};
16use std::collections::HashMap;
17use std::fmt::Debug;
18use std::net::{IpAddr, SocketAddr};
19use std::path::PathBuf;
20use std::result;
21use std::sync::Arc;
22use std::time::Duration;
23use thiserror::Error;
24use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
25use tokio::net::{TcpListener, TcpStream};
26use tokio::sync::{RwLock, broadcast, mpsc, oneshot};
27use tokio::time::timeout;
28#[cfg(feature = "tokio-graceful-shutdown")]
29use tokio_graceful_shutdown::{FutureExt, IntoSubsystem, SubsystemHandle};
30use tokio_rustls::TlsAcceptor;
31use tracing::{info, instrument, warn};
32
/// Registered clients: maps each client ID to the Argon2 password hash (PHC string) of
/// that client's secret. `authenticate` looks the client up here and verifies the secret
/// it presents against the stored hash.
pub type ClientRegistry = HashMap<String, PasswordHashString>;
37
/// Errors produced while authenticating and serving a client connection.
#[derive(Error, Debug)]
pub enum Error {
    /// I/O failure on the client stream. The connection handler treats the
    /// `InvalidInput` / `InvalidData` kinds as malformed client messages.
    #[error(transparent)]
    Io(#[from] io::Error),

    /// A reply channel was closed before a response arrived (the engine or the custom
    /// request handler dropped its oneshot sender).
    #[error(transparent)]
    Oneshot(#[from] oneshot::error::RecvError),

    /// The client sent a message that is not valid at the current protocol stage.
    #[error("client sent unexpected message: {0:?}")]
    UnexpectedMessage(ClientMessage),

    /// The client's advertised `(min, max)` version range does not include a version the
    /// server supports.
    #[error("client requested an unsupported version range: {0} - {1}")]
    UnsupportedVersion(u8, u8),

    /// Carries a request code the server does not recognise.
    #[error("unrecognized request code: {0}")]
    UnknownRequest(u8),

    /// Unknown client ID, or a secret that does not match the registered hash.
    #[error("client provided invalid credentials")]
    InvalidCredentials,
}
58
/// Result type used by the connection-handling code in this module.
pub type Result<T> = result::Result<T, Error>;
60
61#[derive(Debug)]
64#[allow(dead_code)]
65pub struct CustomRequestMessage {
66 client_id: String,
67 message: Message,
68 response: oneshot::Sender<Option<Message>>,
69}
70
/// TLS-secured TCP server that authenticates clients and relays their requests.
///
/// Build it with [`NetworkListenerBuilder`] and run it with [`NetworkListener::listen`]
/// (or as a subsystem when the `tokio-graceful-shutdown` feature is enabled).
pub struct NetworkListener {
    /// Registered clients (client ID → password hash), shared so it can be updated at runtime.
    clients: Arc<RwLock<ClientRegistry>>,
    /// Address the TCP listener binds to.
    addr: SocketAddr,
    /// Performs the TLS handshake on every accepted connection.
    tls_acceptor: TlsAcceptor,
    /// Channel to the engine; each request carries a oneshot sender for the engine's reply.
    engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
    /// Broadcasts client connect / disconnect / connection-loss events.
    event_tx: broadcast::Sender<Event>,
    /// Optional handler for client messages the protocol layer does not recognise.
    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
}
111
112impl Debug for NetworkListener {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 f.debug_struct("NetworkListener")
115 .field("clients", &self.clients)
116 .field("addr", &self.addr)
117 .field("engine_tx", &self.engine_tx)
118 .field("event_tx", &self.event_tx)
119 .field("custom_req_tx", &self.custom_req_tx)
120 .finish()
121 }
122}
123
124impl NetworkListener {
125 pub async fn listen(&self) -> Result<()> {
132 let listener = TcpListener::bind(self.addr).await?;
133 let mut con_counter: usize = 0;
134
135 loop {
136 let (stream, peer_addr) = listener.accept().await?;
137 let conn_id = con_counter;
138 con_counter += 1;
139 let event_tx = self.event_tx.clone();
140
141 self.handle_stream(stream, peer_addr, self.clients.clone(), conn_id, event_tx);
142 }
143 }
144
145 #[instrument(skip(self, stream, peer_addr, clients, event_tx))]
146 fn handle_stream(
147 &self,
148 stream: TcpStream,
149 peer_addr: SocketAddr,
150 clients: Arc<RwLock<ClientRegistry>>,
151 conn_id: usize,
152 event_tx: broadcast::Sender<Event>,
153 ) {
154 let acceptor = self.tls_acceptor.clone();
155 let engine_tx = self.engine_tx.clone();
156 let custom_req_tx = self.custom_req_tx.clone();
157
158 tokio::spawn(async move {
159 info!("new connection from {}", peer_addr);
160
161 let stream = match acceptor.accept(stream).await {
162 Ok(stream) => stream,
163 Err(error) => {
164 warn!("failed to accept stream: {}", error);
165 return;
166 }
167 };
168
169 let (reader, writer) = io::split(stream);
170 let mut reader = BufReader::new(reader);
171 let mut writer = BufWriter::new(writer);
172
173 let (client_id, local_ip, _version) = match timeout(
174 Duration::from_secs(2),
175 authenticate(&mut reader, &mut writer, clients),
176 )
177 .await
178 {
179 Ok(Ok(result)) => result,
180 Ok(Err(Error::UnexpectedMessage(message))) => {
181 warn!("client error: unexpected message: {:?}", message);
182 let _ = write_to_stream(
183 &mut writer,
184 ServerMessage::Error(ErrorResponse::UnexpectedMessage),
185 )
186 .await;
187 return;
188 }
189 Ok(Err(Error::UnsupportedVersion(min_version, max_version))) => {
190 warn!("client error: unsupported version range: {min_version} - {max_version}");
191 let _ = write_to_stream(
192 &mut writer,
193 ServerMessage::Error(ErrorResponse::UnsupportedVersionRange),
194 )
195 .await;
196 return;
197 }
198 Ok(Err(Error::InvalidCredentials)) => {
199 warn!("client error: invalid credentials");
200 let _ = write_to_stream(
201 &mut writer,
202 ServerMessage::Error(ErrorResponse::InvalidCredentials),
203 )
204 .await;
205 return;
206 }
207 Ok(Err(Error::Io(error)))
208 if matches!(
209 error.kind(),
210 io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
211 ) =>
212 {
213 warn!("client error: malformed message: {:?}", error);
214 let _ = write_to_stream(
215 &mut writer,
216 ServerMessage::Error(ErrorResponse::MalformedMessage),
217 )
218 .await;
219 return;
220 }
221 Ok(Err(error)) => {
222 warn!("client error: connection died: {:?}", error);
223 return;
224 }
225 Err(_) => {
226 warn!("client error: authentication timed out");
227 return;
228 }
229 };
230
231 let _ = event_tx.send(Event::ClientConnect {
232 client_id: client_id.clone(),
233 conn_id,
234 local_ip,
235 });
236
237 match handle_connection(
238 &mut reader,
239 &mut writer,
240 &client_id,
241 engine_tx,
242 custom_req_tx,
243 )
244 .await
245 {
246 Ok(()) => {
247 let _ = event_tx.send(Event::ClientDisconnect { client_id, conn_id });
248 return;
249 }
250 Err(Error::UnexpectedMessage(message)) => {
251 warn!("client error: unexpected message: {:?}", message);
252 let _ = write_to_stream(
253 &mut writer,
254 ServerMessage::Error(ErrorResponse::UnexpectedMessage),
255 )
256 .await;
257 }
258 Err(Error::Io(error))
259 if matches!(
260 error.kind(),
261 io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
262 ) =>
263 {
264 warn!("client error: malformed message: {:?}", error);
265 let _ = write_to_stream(
266 &mut writer,
267 ServerMessage::Error(ErrorResponse::MalformedMessage),
268 )
269 .await;
270 return;
271 }
272 Err(error) => {
273 warn!("client error: connection died: {:?}", error);
274 return;
275 }
276 }
277
278 let _ = event_tx.send(Event::ClientConnectionLoss { client_id, conn_id });
279 });
280 }
281}
282
#[cfg(feature = "tokio-graceful-shutdown")]
#[async_trait]
impl IntoSubsystem<Error> for NetworkListener {
    /// Runs [`NetworkListener::listen`] until it fails or the subsystem is shut down.
    ///
    /// Cancellation by a shutdown request counts as a successful exit. An error returned
    /// by `listen` itself is propagated to the subsystem manager.
    async fn run(self, subsys: SubsystemHandle) -> Result<()> {
        match self.listen().cancel_on_shutdown(&subsys).await {
            Ok(result) => result,
            // Cancelled because shutdown was requested: not an error.
            Err(_cancelled) => Ok(()),
        }
    }
}
294
295async fn read_from_stream<S: AsyncRead + Unpin + Send>(stream: &mut S) -> Result<ClientMessage> {
296 let message_type = stream.read_u8().await?;
297 let payload_length = stream.read_u32_le().await?;
298
299 if payload_length == 0 {
300 return Ok(Message::new(message_type, vec![]).try_into()?);
301 }
302
303 let mut message = vec![0; payload_length as usize];
304 stream.read_exact(&mut message).await?;
305
306 Ok(Message::new(message_type, message).try_into()?)
307}
308
309async fn write_to_stream<S: AsyncWrite + Unpin + Send>(
310 stream: &mut S,
311 message: impl Into<Message>,
312) -> Result<()> {
313 let message: Message = message.into();
314 stream.write_all(&message.into_bytes()).await?;
315 stream.flush().await?;
316
317 Ok(())
318}
319
/// Errors returned by [`NetworkListenerBuilder::build`].
#[derive(Debug, Error)]
pub enum BuilderError {
    /// A required builder field was never set; carries the field's name.
    #[error("missing field: {0}")]
    MissingField(&'static str),

    /// The configured address could not be parsed as a socket address (`ip:port`).
    #[error(transparent)]
    AddrParse(#[from] std::net::AddrParseError),

    /// The certificate or private-key PEM file at `path` could not be read or parsed.
    #[error("failed to read PEM file at {path}: {source}")]
    Pem {
        path: PathBuf,
        #[source]
        source: pem::Error,
    },

    /// rustls rejected the certificate chain / private key pair.
    #[error(transparent)]
    Rustls(#[from] rustls::Error),
}
338
/// Result type returned by [`NetworkListenerBuilder::build`].
pub type BuilderResult<T> = result::Result<T, BuilderError>;
340
341#[derive(Debug, Default)]
368pub struct NetworkListenerBuilder {
369 address: Option<String>,
370 cert_path: Option<PathBuf>,
371 key_path: Option<PathBuf>,
372 clients: Option<Arc<RwLock<ClientRegistry>>>,
373 engine_tx: Option<mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>>,
374 event_tx: Option<broadcast::Sender<Event>>,
375 custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
376}
377
378impl NetworkListenerBuilder {
379 pub fn new() -> Self {
380 Self {
381 address: None,
382 cert_path: None,
383 key_path: None,
384 clients: None,
385 engine_tx: None,
386 event_tx: None,
387 custom_req_tx: None,
388 }
389 }
390
391 pub fn custom_req_tx(
392 self,
393 custom_req_tx: mpsc::Sender<CustomRequestMessage>,
394 ) -> NetworkListenerBuilder {
395 NetworkListenerBuilder {
396 clients: self.clients,
397 address: self.address,
398 cert_path: self.cert_path,
399 key_path: self.key_path,
400 engine_tx: self.engine_tx,
401 event_tx: self.event_tx,
402 custom_req_tx: Some(custom_req_tx),
403 }
404 }
405
406 pub fn address(mut self, address: impl Into<String>) -> Self {
407 self.address = Some(address.into());
408 self
409 }
410
411 pub fn cert_path(mut self, path: impl Into<PathBuf>) -> Self {
412 self.cert_path = Some(path.into());
413 self
414 }
415
416 pub fn key_path(mut self, path: impl Into<PathBuf>) -> Self {
417 self.key_path = Some(path.into());
418 self
419 }
420
421 pub fn clients(mut self, clients: Arc<RwLock<ClientRegistry>>) -> Self {
422 self.clients = Some(clients);
423 self
424 }
425
426 pub fn engine_tx(
427 mut self,
428 tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
429 ) -> Self {
430 self.engine_tx = Some(tx);
431 self
432 }
433
434 pub fn event_tx(mut self, tx: broadcast::Sender<Event>) -> Self {
435 self.event_tx = Some(tx);
436 self
437 }
438
439 pub fn build(self) -> BuilderResult<NetworkListener> {
443 let addr: SocketAddr = self
444 .address
445 .ok_or_else(|| BuilderError::MissingField("address"))?
446 .parse()?;
447
448 let cert_path = self
449 .cert_path
450 .ok_or_else(|| BuilderError::MissingField("cert_path"))?;
451 let key_path = self
452 .key_path
453 .ok_or_else(|| BuilderError::MissingField("key_path"))?;
454
455 let certs = CertificateDer::pem_file_iter(&cert_path)
456 .map_err(|err| BuilderError::Pem {
457 path: cert_path.clone(),
458 source: err,
459 })?
460 .collect::<result::Result<Vec<_>, _>>()
461 .map_err(|err| BuilderError::Pem {
462 path: cert_path,
463 source: err,
464 })?;
465 let key = PrivateKeyDer::from_pem_file(&key_path).map_err(|err| BuilderError::Pem {
466 path: key_path,
467 source: err,
468 })?;
469
470 let config = ServerConfig::builder()
471 .with_no_client_auth()
472 .with_single_cert(certs, key)?;
473 let tls_acceptor = TlsAcceptor::from(Arc::new(config));
474
475 Ok(NetworkListener {
476 clients: self
477 .clients
478 .ok_or_else(|| BuilderError::MissingField("clients"))?,
479 addr,
480 tls_acceptor,
481 engine_tx: self
482 .engine_tx
483 .ok_or_else(|| BuilderError::MissingField("engine_tx"))?,
484 event_tx: self
485 .event_tx
486 .ok_or_else(|| BuilderError::MissingField("event_tx"))?,
487 custom_req_tx: self.custom_req_tx,
488 })
489 }
490}
491
492async fn handle_connection<R, W>(
497 reader: &mut R,
498 writer: &mut W,
499 client_id: &str,
500 engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
501 custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
502) -> Result<()>
503where
504 R: AsyncRead + Unpin + Send,
505 W: AsyncWrite + Unpin + Send,
506{
507 loop {
508 let message = match timeout(Duration::from_secs(30), read_from_stream(reader)).await {
509 Ok(Ok(message)) => message,
510 Ok(Err(error)) => return Err(error),
511 Err(_) => return Ok(()),
512 };
513
514 let engine_request = match message {
515 ClientMessage::Bloop { nfc_uid } => EngineRequest::Bloop {
516 nfc_uid,
517 client_id: client_id.to_string(),
518 },
519 ClientMessage::RetrieveAudio { achievement_id } => {
520 EngineRequest::RetrieveAudio { id: achievement_id }
521 }
522 ClientMessage::PreloadCheck {
523 audio_manifest_hash,
524 } => EngineRequest::PreloadCheck {
525 manifest_hash: audio_manifest_hash,
526 },
527 ClientMessage::Ping => {
528 write_to_stream(writer, ServerMessage::Pong).await?;
529 continue;
530 }
531 ClientMessage::Quit => break,
532 ClientMessage::Unknown(message) => {
533 if let Some(sender) = custom_req_tx.as_ref() {
534 let (resp_tx, resp_rx) = oneshot::channel();
535
536 let _ = sender
537 .send(CustomRequestMessage {
538 client_id: client_id.to_string(),
539 message,
540 response: resp_tx,
541 })
542 .await;
543
544 if let Some(message) = resp_rx.await? {
545 write_to_stream(writer, message).await?;
546 }
547 }
548
549 continue;
550 }
551 message => return Err(Error::UnexpectedMessage(message)),
552 };
553
554 let (resp_tx, resp_rx) = oneshot::channel();
555 let _ = engine_tx.send((engine_request, resp_tx)).await;
556 let response = resp_rx.await?;
557
558 write_to_stream(writer, response).await?;
559 }
560
561 Ok(())
562}
563
564async fn authenticate<R, W>(
569 reader: &mut R,
570 writer: &mut W,
571 clients: Arc<RwLock<ClientRegistry>>,
572) -> Result<(String, IpAddr, u8)>
573where
574 R: AsyncRead + Unpin + Send,
575 W: AsyncWrite + Unpin + Send,
576{
577 let (min_version, max_version) = match read_from_stream(reader).await? {
578 ClientMessage::ClientHandshake {
579 min_version,
580 max_version,
581 } => (min_version, max_version),
582 message => return Err(Error::UnexpectedMessage(message)),
583 };
584
585 if min_version > 3 || max_version < 3 {
586 return Err(Error::UnsupportedVersion(min_version, max_version));
587 }
588
589 write_to_stream(
590 writer,
591 ServerMessage::ServerHandshake {
592 accepted_version: 3,
593 features: ServerFeatures::PreloadCheck,
594 },
595 )
596 .await?;
597
598 let (client_id, client_secret, ip_addr) = match read_from_stream(reader).await? {
599 ClientMessage::Authentication {
600 client_id,
601 client_secret,
602 ip_addr,
603 } => (client_id, client_secret, ip_addr),
604 message => return Err(Error::UnexpectedMessage(message)),
605 };
606
607 let clients = clients.read().await;
608 let Some(secret_hash) = clients.get(&client_id) else {
609 return Err(Error::InvalidCredentials);
610 };
611
612 if Argon2::default()
613 .verify_password(client_secret.as_bytes(), &secret_hash.password_hash())
614 .is_err()
615 {
616 return Err(Error::InvalidCredentials);
617 }
618
619 write_to_stream(writer, ServerMessage::AuthenticationAccepted).await?;
620
621 Ok((client_id.to_string(), ip_addr, 3))
622}
623
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Argon2id PHC string registered for the client `"client"`.
    /// `authentication_succeeds_with_correct_credentials` relies on it matching `"secret"`.
    const CLIENT_SECRET_HASH: &str =
        "$argon2id$v=19$m=10,t=1,p=1$THh0RHE5YWNkQUZNa2lqUA$dmB4X7J49jjCGA";

    #[tokio::test]
    async fn builder_fails_with_missing_fields() {
        let builder = NetworkListenerBuilder::new();
        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::MissingField(_))));
    }

    #[tokio::test]
    async fn builder_fails_with_invalid_address() {
        let builder = NetworkListenerBuilder::new()
            .address("invalid-addr")
            .cert_path("cert.pem")
            .key_path("key.pem")
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::AddrParse(_))));
    }

    #[tokio::test]
    async fn builder_fails_on_invalid_pem_files() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        fs::write(&cert_path, b"invalid-cert").unwrap();
        fs::write(&key_path, b"invalid-key").unwrap();

        let builder = NetworkListenerBuilder::new()
            .address("127.0.0.1:12345")
            .cert_path(&cert_path)
            .key_path(&key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::Pem { .. })));
    }

    #[tokio::test]
    async fn builder_succeeds_with_valid_dummy_pem() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        let cert_data = include_bytes!("../examples/cert.pem");
        let key_data = include_bytes!("../examples/key.pem");

        fs::write(&cert_path, cert_data).unwrap();
        fs::write(&key_path, key_data).unwrap();

        let builder = NetworkListenerBuilder::new()
            .address("127.0.0.1:12345")
            .cert_path(&cert_path)
            .key_path(&key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn authentication_rejects_unsupported_version_range() {
        let clients = registry_with_client();
        let client_handshake = build_handshake(1, 2);

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .build();
        // Nothing may be written back when the version range is rejected.
        let mut writer = tokio_test::io::Builder::new().build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::UnsupportedVersion(1, 2))));
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_client_id() {
        let clients = registry_with_client();

        let client_handshake = build_handshake(3, 3);
        let authentication = build_authentication("unknown-client", "password", "127.0.0.1");

        let server_handshake: Message = ServerMessage::ServerHandshake {
            accepted_version: 3,
            features: ServerFeatures::PreloadCheck,
        }
        .into();

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake.into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    #[tokio::test]
    async fn authentication_succeeds_with_correct_credentials() {
        let clients = registry_with_client();

        let client_handshake = build_handshake(3, 3);
        let authentication = build_authentication("client", "secret", "127.0.0.1");

        let server_handshake: Message = ServerMessage::ServerHandshake {
            accepted_version: 3,
            features: ServerFeatures::PreloadCheck,
        }
        .into();

        let authentication_accepted: Message = ServerMessage::AuthenticationAccepted.into();

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake.into_bytes())
            .write(&authentication_accepted.into_bytes())
            .build();

        let (client_id, ip_addr, version) = authenticate(&mut reader, &mut writer, clients)
            .await
            .expect("valid credentials must be accepted");

        assert_eq!(client_id, "client");
        assert_eq!(ip_addr, "127.0.0.1".parse::<IpAddr>().unwrap());
        assert_eq!(version, 3);
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_password() {
        let clients = registry_with_client();

        let client_handshake = build_handshake(3, 3);
        // Use the *registered* client ID so the rejection comes from password verification
        // rather than from the client lookup.
        let authentication = build_authentication("client", "wrong-secret", "127.0.0.1");

        let server_handshake: Message = ServerMessage::ServerHandshake {
            accepted_version: 3,
            features: ServerFeatures::PreloadCheck,
        }
        .into();

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake.into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    /// Registry containing the single client `"client"` with [`CLIENT_SECRET_HASH`].
    fn registry_with_client() -> Arc<RwLock<ClientRegistry>> {
        let mut registry = ClientRegistry::default();
        registry.insert(
            "client".into(),
            PasswordHashString::new(CLIENT_SECRET_HASH).unwrap(),
        );
        Arc::new(RwLock::new(registry))
    }

    fn dummy_engine_tx() -> mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)> {
        let (tx, _rx) = mpsc::channel(1);
        tx
    }

    fn dummy_event_tx() -> broadcast::Sender<Event> {
        let (tx, _rx) = broadcast::channel(1);
        tx
    }

    /// Encodes a `ClientHandshake` frame (type `0x01`) with the given version range.
    fn build_handshake(min_version: u8, max_version: u8) -> Vec<u8> {
        let mut buf = Vec::new();
        let payload = [min_version, max_version];

        buf.push(0x01);
        buf.extend(&(payload.len() as u32).to_le_bytes());
        buf.extend(&payload);

        buf
    }

    /// Encodes an `Authentication` frame (type `0x03`).
    ///
    /// The payload is laid out as:
    /// - the length-prefixed client ID;
    /// - the length-prefixed secret;
    /// - the IP address as a version tag (`4` or `6`) followed by its octets.
    fn build_authentication(client_id: &str, password: &str, ip_addr: &str) -> Vec<u8> {
        let mut buf = Vec::new();

        let client_id_bytes = client_id.as_bytes();
        let password_bytes = password.as_bytes();

        let mut payload = Vec::new();
        payload.push(client_id_bytes.len() as u8);
        payload.extend(client_id_bytes);

        payload.push(password_bytes.len() as u8);
        payload.extend(password_bytes);

        let ip: IpAddr = ip_addr.parse().expect("Invalid IP address");
        match ip {
            IpAddr::V4(v4) => {
                payload.push(4);
                payload.extend(&v4.octets());
            }
            IpAddr::V6(v6) => {
                payload.push(6);
                payload.extend(&v6.octets());
            }
        }

        buf.push(0x03);
        buf.extend(&(payload.len() as u32).to_le_bytes());
        buf.extend(payload);

        buf
    }
}