bloop_server_framework/
network.rs

1//! TLS-based network server for handling authenticated client connections.
2//!
3//! This module provides a [`NetworkListener`] that manages client authentication,
4//! request processing, and message dispatching over a secure TCP connection.
5
6use crate::engine::EngineRequest;
7use crate::event::Event;
8use crate::message::{Capabilities, ClientMessage, ErrorResponse, Message, ServerMessage};
9use argon2::{Argon2, PasswordVerifier, password_hash::PasswordHashString};
10#[cfg(feature = "tokio-graceful-shutdown")]
11use async_trait::async_trait;
12use rustls::ServerConfig;
13use rustls::pki_types::{
14    CertificateDer, PrivateKeyDer,
15    pem::{self, PemObject},
16};
17use std::collections::HashMap;
18use std::fmt::Debug;
19use std::net::{IpAddr, SocketAddr};
20use std::path::PathBuf;
21use std::result;
22use std::sync::Arc;
23use std::time::Duration;
24use thiserror::Error;
25use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
26use tokio::net::{TcpListener, TcpStream};
27use tokio::sync::{RwLock, broadcast, mpsc, oneshot};
28use tokio::time::timeout;
29#[cfg(feature = "tokio-graceful-shutdown")]
30use tokio_graceful_shutdown::{FutureExt, IntoSubsystem, SubsystemHandle};
31use tokio_rustls::TlsAcceptor;
32use tracing::{info, instrument, warn};
33
34/// Maps client IDs to their stored password hashes (Argon2).
35///
36/// Used during client authentication.
37pub type ClientRegistry = HashMap<String, PasswordHashString>;
38
39#[derive(Error, Debug)]
40pub enum Error {
41    #[error(transparent)]
42    Io(#[from] io::Error),
43
44    #[error(transparent)]
45    Oneshot(#[from] oneshot::error::RecvError),
46
47    #[error("client sent unexpected message: {0:?}")]
48    UnexpectedMessage(ClientMessage),
49
50    #[error("client requested an unsupported version range: {0} - {1}")]
51    UnsupportedVersion(u8, u8),
52
53    #[error("unrecognized request code: {0}")]
54    UnknownRequest(u8),
55
56    #[error("client provided invalid credentials")]
57    InvalidCredentials,
58}
59
60pub type Result<T> = result::Result<T, Error>;
61
62/// A wrapper for custom client requests that are forwarded to the application
63/// layer.
64#[derive(Debug)]
65#[allow(dead_code)]
66pub struct CustomRequestMessage {
67    client_id: String,
68    message: Message,
69    response: oneshot::Sender<Option<Message>>,
70}
71
72/// A TLS-secured TCP server that accepts client connections, authenticates
73/// them, and dispatches their requests to the appropriate handlers.
74///
75/// This listener handles authentication, version negotiation, and supports
76/// custom client messages.
77///
78/// # Examples
79///
80/// ```no_run
81/// use std::sync::Arc;
82/// use tokio::sync::{mpsc, broadcast, RwLock};
83/// use bloop_server_framework::network::NetworkListenerBuilder;
84///
85/// #[tokio::main]
86/// async fn main() {
87///   let clients = Arc::new(RwLock::new(Default::default()));
88///   let (engine_tx, _) = mpsc::channel(10);
89///   let (event_tx, _) = broadcast::channel(10);
90///
91///   let listener = NetworkListenerBuilder::new()
92///       .address("127.0.0.1:12345")
93///       .cert_path("server.crt")
94///       .key_path("server.key")
95///       .clients(clients)
96///       .engine_tx(engine_tx)
97///       .event_tx(event_tx)
98///       .build()
99///       .unwrap();
100///
101///   listener.listen().await.unwrap();
102/// }
103/// ```
104pub struct NetworkListener {
105    clients: Arc<RwLock<ClientRegistry>>,
106    addr: SocketAddr,
107    tls_acceptor: TlsAcceptor,
108    engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
109    event_tx: broadcast::Sender<Event>,
110    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
111}
112
113impl Debug for NetworkListener {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("NetworkListener")
116            .field("clients", &self.clients)
117            .field("addr", &self.addr)
118            .field("engine_tx", &self.engine_tx)
119            .field("event_tx", &self.event_tx)
120            .field("custom_req_tx", &self.custom_req_tx)
121            .finish()
122    }
123}
124
125impl NetworkListener {
126    /// Starts listening for incoming TCP connections.
127    ///
128    /// This method blocks indefinitely, accepting and processing new connections.
129    /// Each connection is handled asynchronously in its own task.
130    ///
131    /// Returns an error only if the server fails to bind to the specified address.
132    pub async fn listen(&self) -> Result<()> {
133        let listener = TcpListener::bind(self.addr).await?;
134        let mut con_counter: usize = 0;
135
136        loop {
137            let (stream, peer_addr) = listener.accept().await?;
138            let conn_id = con_counter;
139            con_counter += 1;
140            let event_tx = self.event_tx.clone();
141
142            self.handle_stream(stream, peer_addr, self.clients.clone(), conn_id, event_tx);
143        }
144    }
145
146    #[instrument(skip(self, stream, peer_addr, clients, event_tx))]
147    fn handle_stream(
148        &self,
149        stream: TcpStream,
150        peer_addr: SocketAddr,
151        clients: Arc<RwLock<ClientRegistry>>,
152        conn_id: usize,
153        event_tx: broadcast::Sender<Event>,
154    ) {
155        let acceptor = self.tls_acceptor.clone();
156        let engine_tx = self.engine_tx.clone();
157        let custom_req_tx = self.custom_req_tx.clone();
158
159        tokio::spawn(async move {
160            info!("new connection from {}", peer_addr);
161
162            let stream = match acceptor.accept(stream).await {
163                Ok(stream) => stream,
164                Err(error) => {
165                    warn!("failed to accept stream: {}", error);
166                    return;
167                }
168            };
169
170            let (reader, writer) = io::split(stream);
171            let mut reader = BufReader::new(reader);
172            let mut writer = BufWriter::new(writer);
173
174            let (client_id, local_ip, _version) = match timeout(
175                Duration::from_secs(2),
176                authenticate(&mut reader, &mut writer, clients),
177            )
178            .await
179            {
180                Ok(Ok(result)) => result,
181                Ok(Err(Error::UnexpectedMessage(message))) => {
182                    warn!("client error: unexpected message: {:?}", message);
183                    let _ = write_to_stream(
184                        &mut writer,
185                        ServerMessage::Error(ErrorResponse::UnexpectedMessage),
186                    )
187                    .await;
188                    return;
189                }
190                Ok(Err(Error::UnsupportedVersion(min_version, max_version))) => {
191                    warn!("client error: unsupported version range: {min_version} - {max_version}");
192                    let _ = write_to_stream(
193                        &mut writer,
194                        ServerMessage::Error(ErrorResponse::UnsupportedVersionRange),
195                    )
196                    .await;
197                    return;
198                }
199                Ok(Err(Error::InvalidCredentials)) => {
200                    warn!("client error: invalid credentials");
201                    let _ = write_to_stream(
202                        &mut writer,
203                        ServerMessage::Error(ErrorResponse::InvalidCredentials),
204                    )
205                    .await;
206                    return;
207                }
208                Ok(Err(Error::Io(error)))
209                    if matches!(
210                        error.kind(),
211                        io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
212                    ) =>
213                {
214                    warn!("client error: malformed message: {:?}", error);
215                    let _ = write_to_stream(
216                        &mut writer,
217                        ServerMessage::Error(ErrorResponse::MalformedMessage),
218                    )
219                    .await;
220                    return;
221                }
222                Ok(Err(error)) => {
223                    warn!("client error: connection died: {:?}", error);
224                    return;
225                }
226                Err(_) => {
227                    warn!("client error: authentication timed out");
228                    return;
229                }
230            };
231
232            let _ = event_tx.send(Event::ClientConnect {
233                client_id: client_id.clone(),
234                conn_id,
235                local_ip,
236            });
237
238            match handle_connection(
239                &mut reader,
240                &mut writer,
241                &client_id,
242                engine_tx,
243                custom_req_tx,
244            )
245            .await
246            {
247                Ok(()) => {
248                    let _ = event_tx.send(Event::ClientDisconnect { client_id, conn_id });
249                    return;
250                }
251                Err(Error::UnexpectedMessage(message)) => {
252                    warn!("client error: unexpected message: {:?}", message);
253                    let _ = write_to_stream(
254                        &mut writer,
255                        ServerMessage::Error(ErrorResponse::UnexpectedMessage),
256                    )
257                    .await;
258                }
259                Err(Error::Io(error))
260                    if matches!(
261                        error.kind(),
262                        io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
263                    ) =>
264                {
265                    warn!("client error: malformed message: {:?}", error);
266                    let _ = write_to_stream(
267                        &mut writer,
268                        ServerMessage::Error(ErrorResponse::MalformedMessage),
269                    )
270                    .await;
271                    return;
272                }
273                Err(error) => {
274                    warn!("client error: connection died: {:?}", error);
275                    return;
276                }
277            }
278
279            let _ = event_tx.send(Event::ClientConnectionLoss { client_id, conn_id });
280        });
281    }
282}
283
/// Runs the listener as a `tokio-graceful-shutdown` subsystem.
///
/// [`NetworkListener::listen`] is cancelled as soon as shutdown is requested.
/// An error returned by `listen` itself (failing to bind) is propagated.
#[cfg(feature = "tokio-graceful-shutdown")]
#[async_trait]
impl IntoSubsystem<Error> for NetworkListener {
    async fn run(self, subsys: SubsystemHandle) -> Result<()> {
        // `cancel_on_shutdown` yields `Err` when shutdown won the race, which
        // is a normal, successful exit for this subsystem.
        if let Ok(result) = self.listen().cancel_on_shutdown(&subsys).await {
            result?;
        }

        Ok(())
    }
}
295
296async fn read_from_stream<S: AsyncRead + Unpin + Send>(stream: &mut S) -> Result<ClientMessage> {
297    let message_type = stream.read_u8().await?;
298    let payload_length = stream.read_u32_le().await?;
299
300    if payload_length == 0 {
301        return Ok(Message::new(message_type, vec![]).try_into()?);
302    }
303
304    let mut message = vec![0; payload_length as usize];
305    stream.read_exact(&mut message).await?;
306
307    Ok(Message::new(message_type, message).try_into()?)
308}
309
310async fn write_to_stream<S: AsyncWrite + Unpin + Send>(
311    stream: &mut S,
312    message: impl Into<Message>,
313) -> Result<()> {
314    let message: Message = message.into();
315    stream.write_all(&message.into_bytes()).await?;
316    stream.flush().await?;
317
318    Ok(())
319}
320
321#[derive(Debug, Error)]
322pub enum BuilderError {
323    #[error("missing field: {0}")]
324    MissingField(&'static str),
325
326    #[error(transparent)]
327    AddrParse(#[from] std::net::AddrParseError),
328
329    #[error("failed to read PEM file at {path}: {source}")]
330    Pem {
331        path: PathBuf,
332        #[source]
333        source: pem::Error,
334    },
335
336    #[error(transparent)]
337    Rustls(#[from] rustls::Error),
338}
339
340pub type BuilderResult<T> = result::Result<T, BuilderError>;
341
342/// Builder for [`NetworkListener`].
343///
344/// This allows configuring the address, TLS certificates, client registry,
345/// message channels, and custom request handlers.
346///
347/// # Examples
348///
349/// ```
350/// use std::sync::Arc;
351/// use tokio::sync::{broadcast, RwLock};
352/// use tokio::sync::mpsc;
353/// use bloop_server_framework::network::NetworkListenerBuilder;
354///
355/// let (engine_tx, engine_rx) = mpsc::channel(512);
356/// let (event_tx, event_rx) = broadcast::channel(512);
357///
358/// let builder = NetworkListenerBuilder::new()
359///     .address("127.0.0.1:12345")
360///     .clients(Arc::new(RwLock::new(Default::default())))
361///     .engine_tx(engine_tx)
362///     .event_tx(event_tx)
363///     .cert_path("examples/cert.pem")
364///     .key_path("examples/key.pem")
365///     .build()
366///     .unwrap();
367/// ```
368#[derive(Debug, Default)]
369pub struct NetworkListenerBuilder {
370    address: Option<String>,
371    cert_path: Option<PathBuf>,
372    key_path: Option<PathBuf>,
373    clients: Option<Arc<RwLock<ClientRegistry>>>,
374    engine_tx: Option<mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>>,
375    event_tx: Option<broadcast::Sender<Event>>,
376    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
377}
378
379impl NetworkListenerBuilder {
380    pub fn new() -> Self {
381        Self {
382            address: None,
383            cert_path: None,
384            key_path: None,
385            clients: None,
386            engine_tx: None,
387            event_tx: None,
388            custom_req_tx: None,
389        }
390    }
391
392    pub fn custom_req_tx(
393        self,
394        custom_req_tx: mpsc::Sender<CustomRequestMessage>,
395    ) -> NetworkListenerBuilder {
396        NetworkListenerBuilder {
397            clients: self.clients,
398            address: self.address,
399            cert_path: self.cert_path,
400            key_path: self.key_path,
401            engine_tx: self.engine_tx,
402            event_tx: self.event_tx,
403            custom_req_tx: Some(custom_req_tx),
404        }
405    }
406
407    pub fn address(mut self, address: impl Into<String>) -> Self {
408        self.address = Some(address.into());
409        self
410    }
411
412    pub fn cert_path(mut self, path: impl Into<PathBuf>) -> Self {
413        self.cert_path = Some(path.into());
414        self
415    }
416
417    pub fn key_path(mut self, path: impl Into<PathBuf>) -> Self {
418        self.key_path = Some(path.into());
419        self
420    }
421
422    pub fn clients(mut self, clients: Arc<RwLock<ClientRegistry>>) -> Self {
423        self.clients = Some(clients);
424        self
425    }
426
427    pub fn engine_tx(
428        mut self,
429        tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
430    ) -> Self {
431        self.engine_tx = Some(tx);
432        self
433    }
434
435    pub fn event_tx(mut self, tx: broadcast::Sender<Event>) -> Self {
436        self.event_tx = Some(tx);
437        self
438    }
439
440    /// Builds the [`NetworkListener`] from the provided configuration.
441    ///
442    /// Returns an error if required fields are missing, or TLS setup fails.
443    pub fn build(self) -> BuilderResult<NetworkListener> {
444        let addr: SocketAddr = self
445            .address
446            .ok_or_else(|| BuilderError::MissingField("address"))?
447            .parse()?;
448
449        let cert_path = self
450            .cert_path
451            .ok_or_else(|| BuilderError::MissingField("cert_path"))?;
452        let key_path = self
453            .key_path
454            .ok_or_else(|| BuilderError::MissingField("key_path"))?;
455
456        let certs = CertificateDer::pem_file_iter(&cert_path)
457            .map_err(|err| BuilderError::Pem {
458                path: cert_path.clone(),
459                source: err,
460            })?
461            .collect::<result::Result<Vec<_>, _>>()
462            .map_err(|err| BuilderError::Pem {
463                path: cert_path,
464                source: err,
465            })?;
466        let key = PrivateKeyDer::from_pem_file(&key_path).map_err(|err| BuilderError::Pem {
467            path: key_path,
468            source: err,
469        })?;
470
471        let config = ServerConfig::builder()
472            .with_no_client_auth()
473            .with_single_cert(certs, key)?;
474        let tls_acceptor = TlsAcceptor::from(Arc::new(config));
475
476        Ok(NetworkListener {
477            clients: self
478                .clients
479                .ok_or_else(|| BuilderError::MissingField("clients"))?,
480            addr,
481            tls_acceptor,
482            engine_tx: self
483                .engine_tx
484                .ok_or_else(|| BuilderError::MissingField("engine_tx"))?,
485            event_tx: self
486                .event_tx
487                .ok_or_else(|| BuilderError::MissingField("event_tx"))?,
488            custom_req_tx: self.custom_req_tx,
489        })
490    }
491}
492
493/// Handles an authenticated client connection.
494///
495/// Reads client messages from the stream, dispatches them to the appropriate
496/// handlers, and sends back server responses.
497async fn handle_connection<R, W>(
498    reader: &mut R,
499    writer: &mut W,
500    client_id: &str,
501    engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
502    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
503) -> Result<()>
504where
505    R: AsyncRead + Unpin + Send,
506    W: AsyncWrite + Unpin + Send,
507{
508    loop {
509        let message = match timeout(Duration::from_secs(30), read_from_stream(reader)).await {
510            Ok(Ok(message)) => message,
511            Ok(Err(error)) => return Err(error),
512            Err(_) => return Ok(()),
513        };
514
515        let engine_request = match message {
516            ClientMessage::Bloop { nfc_uid } => EngineRequest::Bloop {
517                nfc_uid,
518                client_id: client_id.to_string(),
519            },
520            ClientMessage::RetrieveAudio { achievement_id } => {
521                EngineRequest::RetrieveAudio { id: achievement_id }
522            }
523            ClientMessage::PreloadCheck {
524                audio_manifest_hash,
525            } => EngineRequest::PreloadCheck {
526                manifest_hash: audio_manifest_hash,
527            },
528            ClientMessage::Ping => {
529                write_to_stream(writer, ServerMessage::Pong).await?;
530                continue;
531            }
532            ClientMessage::Quit => break,
533            ClientMessage::Unknown(message) => {
534                if let Some(sender) = custom_req_tx.as_ref() {
535                    let (resp_tx, resp_rx) = oneshot::channel();
536
537                    let _ = sender
538                        .send(CustomRequestMessage {
539                            client_id: client_id.to_string(),
540                            message,
541                            response: resp_tx,
542                        })
543                        .await;
544
545                    if let Some(message) = resp_rx.await? {
546                        write_to_stream(writer, message).await?;
547                    }
548                }
549
550                continue;
551            }
552            message => return Err(Error::UnexpectedMessage(message)),
553        };
554
555        let (resp_tx, resp_rx) = oneshot::channel();
556        let _ = engine_tx.send((engine_request, resp_tx)).await;
557        let response = resp_rx.await?;
558
559        write_to_stream(writer, response).await?;
560    }
561
562    Ok(())
563}
564
565/// Authenticates a client by performing handshake and credential verification.
566///
567/// Returns the client ID, its IP address, and the negotiated protocol version
568/// on success.
569async fn authenticate<R, W>(
570    reader: &mut R,
571    writer: &mut W,
572    clients: Arc<RwLock<ClientRegistry>>,
573) -> Result<(String, IpAddr, u8)>
574where
575    R: AsyncRead + Unpin + Send,
576    W: AsyncWrite + Unpin + Send,
577{
578    let (min_version, max_version) = match read_from_stream(reader).await? {
579        ClientMessage::ClientHandshake {
580            min_version,
581            max_version,
582        } => (min_version, max_version),
583        message => return Err(Error::UnexpectedMessage(message)),
584    };
585
586    if min_version > 3 || max_version < 3 {
587        return Err(Error::UnsupportedVersion(min_version, max_version));
588    }
589
590    write_to_stream(
591        writer,
592        ServerMessage::ServerHandshake {
593            accepted_version: 3,
594            capabilities: Capabilities::PreloadCheck,
595        },
596    )
597    .await?;
598
599    let (client_id, client_secret, ip_addr) = match read_from_stream(reader).await? {
600        ClientMessage::Authentication {
601            client_id,
602            client_secret,
603            ip_addr,
604        } => (client_id, client_secret, ip_addr),
605        message => return Err(Error::UnexpectedMessage(message)),
606    };
607
608    let clients = clients.read().await;
609    let Some(secret_hash) = clients.get(&client_id) else {
610        return Err(Error::InvalidCredentials);
611    };
612
613    if Argon2::default()
614        .verify_password(client_secret.as_bytes(), &secret_hash.password_hash())
615        .is_err()
616    {
617        return Err(Error::InvalidCredentials);
618    }
619
620    write_to_stream(writer, ServerMessage::AuthenticationAccepted).await?;
621
622    Ok((client_id.to_string(), ip_addr, 3))
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Argon2id hash of the secret `"secret"`, used by the credential tests.
    const CLIENT_SECRET_HASH: &str =
        "$argon2id$v=19$m=10,t=1,p=1$THh0RHE5YWNkQUZNa2lqUA$dmB4X7J49jjCGA";

    #[tokio::test]
    async fn builder_fails_with_missing_fields() {
        let builder = NetworkListenerBuilder::new();
        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::MissingField(_))));
    }

    #[tokio::test]
    async fn builder_fails_with_invalid_address() {
        let builder = NetworkListenerBuilder::new()
            .address("invalid-addr")
            .cert_path("cert.pem")
            .key_path("key.pem")
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::AddrParse(_))));
    }

    #[tokio::test]
    async fn builder_fails_on_invalid_pem_files() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        fs::write(&cert_path, b"invalid-cert").unwrap();
        fs::write(&key_path, b"invalid-key").unwrap();

        let builder = NetworkListenerBuilder::new()
            .address("127.0.0.1:12345")
            .cert_path(&cert_path)
            .key_path(&key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::Pem { .. })));
    }

    #[tokio::test]
    async fn builder_succeeds_with_valid_dummy_pem() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        let cert_data = include_bytes!("../examples/cert.pem");
        let key_data = include_bytes!("../examples/key.pem");

        fs::write(&cert_path, cert_data).unwrap();
        fs::write(&key_path, key_data).unwrap();

        let builder = NetworkListenerBuilder::new()
            .address("127.0.0.1:12345")
            .cert_path(&cert_path)
            .key_path(&key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(result.is_ok(), "unexpected build result: {result:?}");
    }

    #[tokio::test]
    async fn authentication_fails_with_unsupported_version() {
        let clients = Arc::new(RwLock::new(Default::default()));

        let mut reader = tokio_test::io::Builder::new()
            .read(&build_handshake(4, 5))
            .build();
        // Nothing may be written when the version range is rejected.
        let mut writer = tokio_test::io::Builder::new().build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::UnsupportedVersion(4, 5))));
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_client_id() {
        let clients = Arc::new(RwLock::new(Default::default()));

        let client_handshake = build_handshake(3, 3);
        let authentication = build_authentication("unknown-client", "password", "127.0.0.1");

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    #[tokio::test]
    async fn authentication_succeeds_with_correct_credentials() {
        let clients = registry_with_client().await;

        let client_handshake = build_handshake(3, 3);
        let authentication = build_authentication("client", "secret", "127.0.0.1");

        let authentication_accepted: Message = ServerMessage::AuthenticationAccepted.into();

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake_bytes())
            .write(&authentication_accepted.into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(result.is_ok(), "unexpected authentication result: {result:?}");
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_password() {
        let clients = registry_with_client().await;

        let client_handshake = build_handshake(3, 3);
        // Known client ID with a wrong secret, so the Argon2 check itself is
        // what rejects the login (the unknown-ID path is tested separately).
        let authentication = build_authentication("client", "wrong-secret", "127.0.0.1");

        let mut reader = tokio_test::io::Builder::new()
            .read(&client_handshake)
            .read(&authentication)
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    /// Registry containing the single client `"client"` with secret `"secret"`.
    async fn registry_with_client() -> Arc<RwLock<ClientRegistry>> {
        let clients = Arc::new(RwLock::new(HashMap::default()));
        clients.write().await.insert(
            "client".into(),
            PasswordHashString::new(CLIENT_SECRET_HASH).unwrap(),
        );
        clients
    }

    /// Encoded server handshake the server is expected to send for version 3.
    fn server_handshake_bytes() -> Vec<u8> {
        let server_handshake: Message = ServerMessage::ServerHandshake {
            accepted_version: 3,
            capabilities: Capabilities::PreloadCheck,
        }
        .into();
        server_handshake.into_bytes().to_vec()
    }

    fn dummy_engine_tx() -> mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)> {
        let (tx, _rx) = mpsc::channel(1);
        tx
    }

    fn dummy_event_tx() -> broadcast::Sender<Event> {
        let (tx, _rx) = broadcast::channel(1);
        tx
    }

    fn build_handshake(min_version: u8, max_version: u8) -> Vec<u8> {
        let mut buf = Vec::new();
        let payload = [min_version, max_version];

        buf.push(0x01);
        buf.extend(&(payload.len() as u32).to_le_bytes());
        buf.extend(&payload);

        buf
    }

    fn build_authentication(client_id: &str, password: &str, ip_addr: &str) -> Vec<u8> {
        let mut buf = Vec::new();

        let client_id_bytes = client_id.as_bytes();
        let password_bytes = password.as_bytes();

        let mut payload = Vec::new();
        payload.push(client_id_bytes.len() as u8);
        payload.extend(client_id_bytes);

        payload.push(password_bytes.len() as u8);
        payload.extend(password_bytes);

        let ip: IpAddr = ip_addr.parse().expect("Invalid IP address");
        match ip {
            IpAddr::V4(v4) => {
                payload.push(4); // IPv4
                payload.extend(&v4.octets());
            }
            IpAddr::V6(v6) => {
                payload.push(6); // IPv6
                payload.extend(&v6.octets());
            }
        }

        buf.push(0x03);
        buf.extend(&(payload.len() as u32).to_le_bytes());
        buf.extend(payload);

        buf
    }
}