bloop_server_framework/
network.rs

1//! TLS-based network server for handling authenticated client connections.
2//!
3//! This module provides a [`NetworkListener`] that manages client authentication,
4//! request processing, and message dispatching over a secure TCP connection.
5
6use crate::engine::EngineRequest;
7use crate::event::Event;
8use crate::message::{ClientMessage, ErrorResponse, Message, ServerFeatures, ServerMessage};
9use argon2::{Argon2, PasswordVerifier, password_hash::PasswordHashString};
10use async_trait::async_trait;
11use rustls::ServerConfig;
12use rustls::pki_types::{
13    CertificateDer, PrivateKeyDer,
14    pem::{self, PemObject},
15};
16use std::collections::HashMap;
17use std::fmt::Debug;
18use std::net::{IpAddr, SocketAddr};
19use std::path::PathBuf;
20use std::result;
21use std::sync::Arc;
22use std::time::Duration;
23use thiserror::Error;
24use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
25use tokio::net::{TcpListener, TcpStream};
26use tokio::sync::{RwLock, broadcast, mpsc, oneshot};
27use tokio::time::timeout;
28#[cfg(feature = "tokio-graceful-shutdown")]
29use tokio_graceful_shutdown::{FutureExt, IntoSubsystem, SubsystemHandle};
30use tokio_rustls::TlsAcceptor;
31use tracing::{info, instrument, warn};
32
33/// Maps client IDs to their stored password hashes (Argon2).
34///
35/// Used during client authentication.
36pub type ClientRegistry = HashMap<String, PasswordHashString>;
37
38#[derive(Error, Debug)]
39pub enum Error {
40    #[error(transparent)]
41    Io(#[from] io::Error),
42
43    #[error(transparent)]
44    Oneshot(#[from] oneshot::error::RecvError),
45
46    #[error("client sent unexpected message: {0:?}")]
47    UnexpectedMessage(ClientMessage),
48
49    #[error("client requested an unsupported version range: {0} - {1}")]
50    UnsupportedVersion(u8, u8),
51
52    #[error("unrecognized request code: {0}")]
53    UnknownRequest(u8),
54
55    #[error("client provided invalid credentials")]
56    InvalidCredentials,
57}
58
59pub type Result<T> = result::Result<T, Error>;
60
61/// A wrapper for custom client requests that are forwarded to the application
62/// layer.
63#[derive(Debug)]
64#[allow(dead_code)]
65pub struct CustomRequestMessage {
66    client_id: String,
67    message: Message,
68    response: oneshot::Sender<Option<Message>>,
69}
70
71/// A TLS-secured TCP server that accepts client connections, authenticates
72/// them, and dispatches their requests to the appropriate handlers.
73///
74/// This listener handles authentication, version negotiation, and supports
75/// custom client messages.
76///
77/// # Examples
78///
79/// ```no_run
80/// use std::sync::Arc;
81/// use tokio::sync::{mpsc, broadcast, RwLock};
82/// use bloop_server_framework::network::NetworkListenerBuilder;
83///
84/// #[tokio::main]
85/// async fn main() {
86///   let clients = Arc::new(RwLock::new(Default::default()));
87///   let (engine_tx, _) = mpsc::channel(10);
88///   let (event_tx, _) = broadcast::channel(10);
89///
90///   let listener = NetworkListenerBuilder::new()
91///       .address("127.0.0.1:12345")
92///       .cert_path("server.crt")
93///       .key_path("server.key")
94///       .clients(clients)
95///       .engine_tx(engine_tx)
96///       .event_tx(event_tx)
97///       .build()
98///       .unwrap();
99///
100///   listener.listen().await.unwrap();
101/// }
102/// ```
103pub struct NetworkListener {
104    clients: Arc<RwLock<ClientRegistry>>,
105    addr: SocketAddr,
106    tls_acceptor: TlsAcceptor,
107    engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
108    event_tx: broadcast::Sender<Event>,
109    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
110}
111
112impl Debug for NetworkListener {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        f.debug_struct("NetworkListener")
115            .field("clients", &self.clients)
116            .field("addr", &self.addr)
117            .field("engine_tx", &self.engine_tx)
118            .field("event_tx", &self.event_tx)
119            .field("custom_req_tx", &self.custom_req_tx)
120            .finish()
121    }
122}
123
124impl NetworkListener {
125    /// Starts listening for incoming TCP connections.
126    ///
127    /// This method blocks indefinitely, accepting and processing new connections.
128    /// Each connection is handled asynchronously in its own task.
129    ///
130    /// Returns an error only if the server fails to bind to the specified address.
131    pub async fn listen(&self) -> Result<()> {
132        let listener = TcpListener::bind(self.addr).await?;
133        let mut con_counter: usize = 0;
134
135        loop {
136            let (stream, peer_addr) = listener.accept().await?;
137            let conn_id = con_counter;
138            con_counter += 1;
139            let event_tx = self.event_tx.clone();
140
141            self.handle_stream(stream, peer_addr, self.clients.clone(), conn_id, event_tx);
142        }
143    }
144
145    #[instrument(skip(self, stream, peer_addr, clients, event_tx))]
146    fn handle_stream(
147        &self,
148        stream: TcpStream,
149        peer_addr: SocketAddr,
150        clients: Arc<RwLock<ClientRegistry>>,
151        conn_id: usize,
152        event_tx: broadcast::Sender<Event>,
153    ) {
154        let acceptor = self.tls_acceptor.clone();
155        let engine_tx = self.engine_tx.clone();
156        let custom_req_tx = self.custom_req_tx.clone();
157
158        tokio::spawn(async move {
159            info!("new connection from {}", peer_addr);
160
161            let stream = match acceptor.accept(stream).await {
162                Ok(stream) => stream,
163                Err(error) => {
164                    warn!("failed to accept stream: {}", error);
165                    return;
166                }
167            };
168
169            let (reader, writer) = io::split(stream);
170            let mut reader = BufReader::new(reader);
171            let mut writer = BufWriter::new(writer);
172
173            let (client_id, local_ip, _version) = match timeout(
174                Duration::from_secs(2),
175                authenticate(&mut reader, &mut writer, clients),
176            )
177            .await
178            {
179                Ok(Ok(result)) => result,
180                Ok(Err(Error::UnexpectedMessage(message))) => {
181                    warn!("client error: unexpected message: {:?}", message);
182                    let _ = write_to_stream(
183                        &mut writer,
184                        ServerMessage::Error(ErrorResponse::UnexpectedMessage),
185                    )
186                    .await;
187                    return;
188                }
189                Ok(Err(Error::UnsupportedVersion(min_version, max_version))) => {
190                    warn!("client error: unsupported version range: {min_version} - {max_version}");
191                    let _ = write_to_stream(
192                        &mut writer,
193                        ServerMessage::Error(ErrorResponse::UnsupportedVersionRange),
194                    )
195                    .await;
196                    return;
197                }
198                Ok(Err(Error::InvalidCredentials)) => {
199                    warn!("client error: invalid credentials");
200                    let _ = write_to_stream(
201                        &mut writer,
202                        ServerMessage::Error(ErrorResponse::InvalidCredentials),
203                    )
204                    .await;
205                    return;
206                }
207                Ok(Err(Error::Io(error)))
208                    if matches!(
209                        error.kind(),
210                        io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
211                    ) =>
212                {
213                    warn!("client error: malformed message: {:?}", error);
214                    let _ = write_to_stream(
215                        &mut writer,
216                        ServerMessage::Error(ErrorResponse::MalformedMessage),
217                    )
218                    .await;
219                    return;
220                }
221                Ok(Err(error)) => {
222                    warn!("client error: connection died: {:?}", error);
223                    return;
224                }
225                Err(_) => {
226                    warn!("client error: authentication timed out");
227                    return;
228                }
229            };
230
231            let _ = event_tx.send(Event::ClientConnect {
232                client_id: client_id.clone(),
233                conn_id,
234                local_ip,
235            });
236
237            match handle_connection(
238                &mut reader,
239                &mut writer,
240                &client_id,
241                engine_tx,
242                custom_req_tx,
243            )
244            .await
245            {
246                Ok(()) => {
247                    let _ = event_tx.send(Event::ClientDisconnect { client_id, conn_id });
248                    return;
249                }
250                Err(Error::UnexpectedMessage(message)) => {
251                    warn!("client error: unexpected message: {:?}", message);
252                    let _ = write_to_stream(
253                        &mut writer,
254                        ServerMessage::Error(ErrorResponse::UnexpectedMessage),
255                    )
256                    .await;
257                }
258                Err(Error::Io(error))
259                    if matches!(
260                        error.kind(),
261                        io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
262                    ) =>
263                {
264                    warn!("client error: malformed message: {:?}", error);
265                    let _ = write_to_stream(
266                        &mut writer,
267                        ServerMessage::Error(ErrorResponse::MalformedMessage),
268                    )
269                    .await;
270                    return;
271                }
272                Err(error) => {
273                    warn!("client error: connection died: {:?}", error);
274                    return;
275                }
276            }
277
278            let _ = event_tx.send(Event::ClientConnectionLoss { client_id, conn_id });
279        });
280    }
281}
282
#[cfg(feature = "tokio-graceful-shutdown")]
#[async_trait]
impl IntoSubsystem<Error> for NetworkListener {
    /// Runs [`NetworkListener::listen`] until it finishes or the subsystem is
    /// asked to shut down.
    ///
    /// A result produced by `listen` is passed through unchanged; when the
    /// listener is cancelled by the shutdown request instead, `Ok(())` is
    /// returned.
    async fn run(self, subsys: SubsystemHandle) -> Result<()> {
        match self.listen().cancel_on_shutdown(&subsys).await {
            Ok(listen_result) => listen_result,
            Err(_) => Ok(()),
        }
    }
}
294
295async fn read_from_stream<S: AsyncRead + Unpin + Send>(stream: &mut S) -> Result<ClientMessage> {
296    let message_type = stream.read_u8().await?;
297    let payload_length = stream.read_u32_le().await?;
298
299    if payload_length == 0 {
300        return Ok(Message::new(message_type, vec![]).try_into()?);
301    }
302
303    let mut message = vec![0; payload_length as usize];
304    stream.read_exact(&mut message).await?;
305
306    Ok(Message::new(message_type, message).try_into()?)
307}
308
309async fn write_to_stream<S: AsyncWrite + Unpin + Send>(
310    stream: &mut S,
311    message: impl Into<Message>,
312) -> Result<()> {
313    let message: Message = message.into();
314    stream.write_all(&message.into_bytes()).await?;
315    stream.flush().await?;
316
317    Ok(())
318}
319
320#[derive(Debug, Error)]
321pub enum BuilderError {
322    #[error("missing field: {0}")]
323    MissingField(&'static str),
324
325    #[error(transparent)]
326    AddrParse(#[from] std::net::AddrParseError),
327
328    #[error("failed to read PEM file at {path}: {source}")]
329    Pem {
330        path: PathBuf,
331        #[source]
332        source: pem::Error,
333    },
334
335    #[error(transparent)]
336    Rustls(#[from] rustls::Error),
337}
338
339pub type BuilderResult<T> = result::Result<T, BuilderError>;
340
341/// Builder for [`NetworkListener`].
342///
343/// This allows configuring the address, TLS certificates, client registry,
344/// message channels, and custom request handlers.
345///
346/// # Examples
347///
348/// ```
349/// use std::sync::Arc;
350/// use tokio::sync::{broadcast, RwLock};
351/// use tokio::sync::mpsc;
352/// use bloop_server_framework::network::NetworkListenerBuilder;
353///
354/// let (engine_tx, engine_rx) = mpsc::channel(512);
355/// let (event_tx, event_rx) = broadcast::channel(512);
356///
357/// let builder = NetworkListenerBuilder::new()
358///     .address("127.0.0.1:12345")
359///     .clients(Arc::new(RwLock::new(Default::default())))
360///     .engine_tx(engine_tx)
361///     .event_tx(event_tx)
362///     .cert_path("examples/cert.pem")
363///     .key_path("examples/key.pem")
364///     .build()
365///     .unwrap();
366/// ```
367#[derive(Debug, Default)]
368pub struct NetworkListenerBuilder {
369    address: Option<String>,
370    cert_path: Option<PathBuf>,
371    key_path: Option<PathBuf>,
372    clients: Option<Arc<RwLock<ClientRegistry>>>,
373    engine_tx: Option<mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>>,
374    event_tx: Option<broadcast::Sender<Event>>,
375    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
376}
377
378impl NetworkListenerBuilder {
379    pub fn new() -> Self {
380        Self {
381            address: None,
382            cert_path: None,
383            key_path: None,
384            clients: None,
385            engine_tx: None,
386            event_tx: None,
387            custom_req_tx: None,
388        }
389    }
390
391    pub fn custom_req_tx(
392        self,
393        custom_req_tx: mpsc::Sender<CustomRequestMessage>,
394    ) -> NetworkListenerBuilder {
395        NetworkListenerBuilder {
396            clients: self.clients,
397            address: self.address,
398            cert_path: self.cert_path,
399            key_path: self.key_path,
400            engine_tx: self.engine_tx,
401            event_tx: self.event_tx,
402            custom_req_tx: Some(custom_req_tx),
403        }
404    }
405
406    pub fn address(mut self, address: impl Into<String>) -> Self {
407        self.address = Some(address.into());
408        self
409    }
410
411    pub fn cert_path(mut self, path: impl Into<PathBuf>) -> Self {
412        self.cert_path = Some(path.into());
413        self
414    }
415
416    pub fn key_path(mut self, path: impl Into<PathBuf>) -> Self {
417        self.key_path = Some(path.into());
418        self
419    }
420
421    pub fn clients(mut self, clients: Arc<RwLock<ClientRegistry>>) -> Self {
422        self.clients = Some(clients);
423        self
424    }
425
426    pub fn engine_tx(
427        mut self,
428        tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
429    ) -> Self {
430        self.engine_tx = Some(tx);
431        self
432    }
433
434    pub fn event_tx(mut self, tx: broadcast::Sender<Event>) -> Self {
435        self.event_tx = Some(tx);
436        self
437    }
438
439    /// Builds the [`NetworkListener`] from the provided configuration.
440    ///
441    /// Returns an error if required fields are missing, or TLS setup fails.
442    pub fn build(self) -> BuilderResult<NetworkListener> {
443        let addr: SocketAddr = self
444            .address
445            .ok_or_else(|| BuilderError::MissingField("address"))?
446            .parse()?;
447
448        let cert_path = self
449            .cert_path
450            .ok_or_else(|| BuilderError::MissingField("cert_path"))?;
451        let key_path = self
452            .key_path
453            .ok_or_else(|| BuilderError::MissingField("key_path"))?;
454
455        let certs = CertificateDer::pem_file_iter(&cert_path)
456            .map_err(|err| BuilderError::Pem {
457                path: cert_path.clone(),
458                source: err,
459            })?
460            .collect::<result::Result<Vec<_>, _>>()
461            .map_err(|err| BuilderError::Pem {
462                path: cert_path,
463                source: err,
464            })?;
465        let key = PrivateKeyDer::from_pem_file(&key_path).map_err(|err| BuilderError::Pem {
466            path: key_path,
467            source: err,
468        })?;
469
470        let config = ServerConfig::builder()
471            .with_no_client_auth()
472            .with_single_cert(certs, key)?;
473        let tls_acceptor = TlsAcceptor::from(Arc::new(config));
474
475        Ok(NetworkListener {
476            clients: self
477                .clients
478                .ok_or_else(|| BuilderError::MissingField("clients"))?,
479            addr,
480            tls_acceptor,
481            engine_tx: self
482                .engine_tx
483                .ok_or_else(|| BuilderError::MissingField("engine_tx"))?,
484            event_tx: self
485                .event_tx
486                .ok_or_else(|| BuilderError::MissingField("event_tx"))?,
487            custom_req_tx: self.custom_req_tx,
488        })
489    }
490}
491
492/// Handles an authenticated client connection.
493///
494/// Reads client messages from the stream, dispatches them to the appropriate
495/// handlers, and sends back server responses.
496async fn handle_connection<R, W>(
497    reader: &mut R,
498    writer: &mut W,
499    client_id: &str,
500    engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
501    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
502) -> Result<()>
503where
504    R: AsyncRead + Unpin + Send,
505    W: AsyncWrite + Unpin + Send,
506{
507    loop {
508        let message = match timeout(Duration::from_secs(30), read_from_stream(reader)).await {
509            Ok(Ok(message)) => message,
510            Ok(Err(error)) => return Err(error),
511            Err(_) => return Ok(()),
512        };
513
514        let engine_request = match message {
515            ClientMessage::Bloop { nfc_uid } => EngineRequest::Bloop {
516                nfc_uid,
517                client_id: client_id.to_string(),
518            },
519            ClientMessage::RetrieveAudio { achievement_id } => {
520                EngineRequest::RetrieveAudio { id: achievement_id }
521            }
522            ClientMessage::PreloadCheck {
523                audio_manifest_hash,
524            } => EngineRequest::PreloadCheck {
525                manifest_hash: audio_manifest_hash,
526            },
527            ClientMessage::Ping => {
528                write_to_stream(writer, ServerMessage::Pong).await?;
529                continue;
530            }
531            ClientMessage::Quit => break,
532            ClientMessage::Unknown(message) => {
533                if let Some(sender) = custom_req_tx.as_ref() {
534                    let (resp_tx, resp_rx) = oneshot::channel();
535
536                    let _ = sender
537                        .send(CustomRequestMessage {
538                            client_id: client_id.to_string(),
539                            message,
540                            response: resp_tx,
541                        })
542                        .await;
543
544                    if let Some(message) = resp_rx.await? {
545                        write_to_stream(writer, message).await?;
546                    }
547                }
548
549                continue;
550            }
551            message => return Err(Error::UnexpectedMessage(message)),
552        };
553
554        let (resp_tx, resp_rx) = oneshot::channel();
555        let _ = engine_tx.send((engine_request, resp_tx)).await;
556        let response = resp_rx.await?;
557
558        write_to_stream(writer, response).await?;
559    }
560
561    Ok(())
562}
563
564/// Authenticates a client by performing handshake and credential verification.
565///
566/// Returns the client ID, its IP address, and the negotiated protocol version
567/// on success.
568async fn authenticate<R, W>(
569    reader: &mut R,
570    writer: &mut W,
571    clients: Arc<RwLock<ClientRegistry>>,
572) -> Result<(String, IpAddr, u8)>
573where
574    R: AsyncRead + Unpin + Send,
575    W: AsyncWrite + Unpin + Send,
576{
577    let (min_version, max_version) = match read_from_stream(reader).await? {
578        ClientMessage::ClientHandshake {
579            min_version,
580            max_version,
581        } => (min_version, max_version),
582        message => return Err(Error::UnexpectedMessage(message)),
583    };
584
585    if min_version > 3 || max_version < 3 {
586        return Err(Error::UnsupportedVersion(min_version, max_version));
587    }
588
589    write_to_stream(
590        writer,
591        ServerMessage::ServerHandshake {
592            accepted_version: 3,
593            features: ServerFeatures::PreloadCheck,
594        },
595    )
596    .await?;
597
598    let (client_id, client_secret, ip_addr) = match read_from_stream(reader).await? {
599        ClientMessage::Authentication {
600            client_id,
601            client_secret,
602            ip_addr,
603        } => (client_id, client_secret, ip_addr),
604        message => return Err(Error::UnexpectedMessage(message)),
605    };
606
607    let clients = clients.read().await;
608    let Some(secret_hash) = clients.get(&client_id) else {
609        return Err(Error::InvalidCredentials);
610    };
611
612    if Argon2::default()
613        .verify_password(client_secret.as_bytes(), &secret_hash.password_hash())
614        .is_err()
615    {
616        return Err(Error::InvalidCredentials);
617    }
618
619    write_to_stream(writer, ServerMessage::AuthenticationAccepted).await?;
620
621    Ok((client_id.to_string(), ip_addr, 3))
622}
623
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Argon2id hash stored for the test client; the password `secret`
    /// verifies against it.
    const CLIENT_SECRET_HASH: &str =
        "$argon2id$v=19$m=10,t=1,p=1$THh0RHE5YWNkQUZNa2lqUA$dmB4X7J49jjCGA";

    #[tokio::test]
    async fn builder_fails_with_missing_fields() {
        let result = NetworkListenerBuilder::new().build();
        assert!(matches!(result, Err(BuilderError::MissingField(_))));
    }

    #[tokio::test]
    async fn builder_fails_with_invalid_address() {
        let result = configured_builder("invalid-addr", "cert.pem", "key.pem").build();
        assert!(matches!(result, Err(BuilderError::AddrParse(_))));
    }

    #[tokio::test]
    async fn builder_fails_on_invalid_pem_files() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        fs::write(&cert_path, b"invalid-cert").unwrap();
        fs::write(&key_path, b"invalid-key").unwrap();

        let result = configured_builder("127.0.0.1:12345", &cert_path, &key_path).build();
        assert!(matches!(result, Err(BuilderError::Pem { .. })));
    }

    #[tokio::test]
    async fn builder_succeeds_with_valid_dummy_pem() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        let cert_data = include_bytes!("../examples/cert.pem");
        let key_data = include_bytes!("../examples/key.pem");

        fs::write(&cert_path, cert_data).unwrap();
        fs::write(&key_path, key_data).unwrap();

        let result = configured_builder("127.0.0.1:12345", &cert_path, &key_path).build();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_client_id() {
        let clients = Arc::new(RwLock::new(ClientRegistry::new()));

        let result = run_authentication(
            clients,
            "unknown-client",
            "password",
            vec![server_handshake()],
        )
        .await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    #[tokio::test]
    async fn authentication_succeeds_with_correct_credentials() {
        let result = run_authentication(
            registry_with_client(),
            "client",
            "secret",
            vec![
                server_handshake(),
                ServerMessage::AuthenticationAccepted.into(),
            ],
        )
        .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_password() {
        // The client ID must be the registered one; otherwise the lookup fails
        // and the password is never checked.
        let result = run_authentication(
            registry_with_client(),
            "client",
            "wrong-secret",
            vec![server_handshake()],
        )
        .await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    /// Runs `authenticate` against mocked streams.
    ///
    /// The reader yields a version 3 handshake followed by an authentication
    /// message with the given credentials; the writer expects exactly
    /// `expected_responses`, in order.
    async fn run_authentication(
        clients: Arc<RwLock<ClientRegistry>>,
        client_id: &str,
        password: &str,
        expected_responses: Vec<Message>,
    ) -> Result<(String, IpAddr, u8)> {
        let mut reader = tokio_test::io::Builder::new()
            .read(&build_handshake(3, 3))
            .read(&build_authentication(client_id, password, "127.0.0.1"))
            .build();

        let mut expected_writes = tokio_test::io::Builder::new();
        for response in expected_responses {
            expected_writes.write(&response.into_bytes());
        }
        let mut writer = expected_writes.build();

        authenticate(&mut reader, &mut writer, clients).await
    }

    /// The handshake reply the server is expected to send.
    fn server_handshake() -> Message {
        ServerMessage::ServerHandshake {
            accepted_version: 3,
            features: ServerFeatures::PreloadCheck,
        }
        .into()
    }

    /// A registry containing the single client `client`.
    fn registry_with_client() -> Arc<RwLock<ClientRegistry>> {
        let mut registry = ClientRegistry::new();
        registry.insert(
            "client".into(),
            PasswordHashString::new(CLIENT_SECRET_HASH).unwrap(),
        );

        Arc::new(RwLock::new(registry))
    }

    /// A builder with every required field set.
    fn configured_builder(
        address: &str,
        cert_path: impl Into<PathBuf>,
        key_path: impl Into<PathBuf>,
    ) -> NetworkListenerBuilder {
        NetworkListenerBuilder::new()
            .address(address)
            .cert_path(cert_path)
            .key_path(key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx())
    }

    fn dummy_engine_tx() -> mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)> {
        let (tx, _rx) = mpsc::channel(1);
        tx
    }

    fn dummy_event_tx() -> broadcast::Sender<Event> {
        let (tx, _rx) = broadcast::channel(1);
        tx
    }

    /// Encodes a client handshake frame: type `0x01`, little-endian payload
    /// length, then the two version bytes.
    fn build_handshake(min_version: u8, max_version: u8) -> Vec<u8> {
        let payload = [min_version, max_version];

        let mut buf = Vec::new();
        buf.push(0x01);
        buf.extend((payload.len() as u32).to_le_bytes());
        buf.extend(payload);

        buf
    }

    /// Encodes an authentication frame: type `0x03`, little-endian payload
    /// length, then length-prefixed client ID and password followed by an IP
    /// version tag (`4` or `6`) and the address octets.
    fn build_authentication(client_id: &str, password: &str, ip_addr: &str) -> Vec<u8> {
        let client_id_bytes = client_id.as_bytes();
        let password_bytes = password.as_bytes();

        let mut payload = Vec::new();
        payload.push(client_id_bytes.len() as u8);
        payload.extend(client_id_bytes);

        payload.push(password_bytes.len() as u8);
        payload.extend(password_bytes);

        let ip: IpAddr = ip_addr.parse().expect("Invalid IP address");
        match ip {
            IpAddr::V4(v4) => {
                payload.push(4); // IPv4
                payload.extend(v4.octets());
            }
            IpAddr::V6(v6) => {
                payload.push(6); // IPv6
                payload.extend(v6.octets());
            }
        }

        let mut buf = Vec::new();
        buf.push(0x03);
        buf.extend((payload.len() as u32).to_le_bytes());
        buf.extend(payload);

        buf
    }
}