bloop_server_framework/
network.rs

1//! TLS-based network server for handling authenticated client connections.
2//!
3//! This module provides a [`NetworkListener`] that manages client authentication,
4//! request processing, and message dispatching over a secure TCP connection.
5
6use crate::engine::EngineRequest;
7use crate::event::Event;
8use crate::message::{Capabilities, ClientMessage, ErrorResponse, Message, ServerMessage};
9use argon2::{Argon2, PasswordVerifier, password_hash::PasswordHashString};
10use rustls::ServerConfig;
11use rustls::pki_types::{
12    CertificateDer, PrivateKeyDer,
13    pem::{self, PemObject},
14};
15use std::collections::HashMap;
16use std::fmt::Debug;
17use std::net::{IpAddr, SocketAddr};
18use std::path::PathBuf;
19use std::result;
20use std::sync::Arc;
21use std::time::Duration;
22use thiserror::Error;
23use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
24use tokio::net::{TcpListener, TcpStream};
25use tokio::sync::{RwLock, broadcast, mpsc, oneshot};
26use tokio::time::timeout;
27#[cfg(feature = "tokio-graceful-shutdown")]
28use tokio_graceful_shutdown::{FutureExt, IntoSubsystem, SubsystemHandle};
29use tokio_rustls::TlsAcceptor;
30use tracing::{info, instrument, warn};
31
32/// Maps client IDs to their stored password hashes (Argon2).
33///
34/// Used during client authentication.
35pub type ClientRegistry = HashMap<String, PasswordHashString>;
36
37#[derive(Error, Debug)]
38pub enum Error {
39    #[error(transparent)]
40    Io(#[from] io::Error),
41
42    #[error(transparent)]
43    Oneshot(#[from] oneshot::error::RecvError),
44
45    #[error("client sent unexpected message: {0:?}")]
46    UnexpectedMessage(ClientMessage),
47
48    #[error("client requested an unsupported version range: {0} - {1}")]
49    UnsupportedVersion(u8, u8),
50
51    #[error("unrecognized request code: {0}")]
52    UnknownRequest(u8),
53
54    #[error("client provided invalid credentials")]
55    InvalidCredentials,
56}
57
58pub type Result<T> = result::Result<T, Error>;
59
60/// A wrapper for custom client requests that are forwarded to the application
61/// layer.
62#[derive(Debug)]
63#[allow(dead_code)]
64pub struct CustomRequestMessage {
65    pub client_id: String,
66    pub message: Message,
67    pub response: oneshot::Sender<Option<Message>>,
68}
69
70/// A TLS-secured TCP server that accepts client connections, authenticates
71/// them, and dispatches their requests to the appropriate handlers.
72///
73/// This listener handles authentication, version negotiation, and supports
74/// custom client messages.
75///
76/// # Examples
77///
78/// ```no_run
79/// use std::sync::Arc;
80/// use tokio::sync::{mpsc, broadcast, RwLock};
81/// use bloop_server_framework::network::NetworkListenerBuilder;
82///
83/// #[tokio::main]
84/// async fn main() {
85///   let clients = Arc::new(RwLock::new(Default::default()));
86///   let (engine_tx, _) = mpsc::channel(10);
87///   let (event_tx, _) = broadcast::channel(10);
88///
89///   let listener = NetworkListenerBuilder::new()
90///       .address("127.0.0.1:12345")
91///       .cert_path("server.crt")
92///       .key_path("server.key")
93///       .clients(clients)
94///       .engine_tx(engine_tx)
95///       .event_tx(event_tx)
96///       .build()
97///       .unwrap();
98///
99///   listener.listen().await.unwrap();
100/// }
101/// ```
102pub struct NetworkListener {
103    clients: Arc<RwLock<ClientRegistry>>,
104    addr: SocketAddr,
105    tls_acceptor: TlsAcceptor,
106    engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
107    event_tx: broadcast::Sender<Event>,
108    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
109}
110
111impl Debug for NetworkListener {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        f.debug_struct("NetworkListener")
114            .field("clients", &self.clients)
115            .field("addr", &self.addr)
116            .field("engine_tx", &self.engine_tx)
117            .field("event_tx", &self.event_tx)
118            .field("custom_req_tx", &self.custom_req_tx)
119            .finish()
120    }
121}
122
123impl NetworkListener {
124    /// Starts listening for incoming TCP connections.
125    ///
126    /// This method blocks indefinitely, accepting and processing new connections.
127    /// Each connection is handled asynchronously in its own task.
128    ///
129    /// Returns an error only if the server fails to bind to the specified address.
130    pub async fn listen(&self) -> Result<()> {
131        let listener = TcpListener::bind(self.addr).await?;
132        let mut con_counter: usize = 0;
133
134        loop {
135            let (stream, peer_addr) = listener.accept().await?;
136            let conn_id = con_counter;
137            con_counter += 1;
138            let event_tx = self.event_tx.clone();
139
140            self.handle_stream(stream, peer_addr, self.clients.clone(), conn_id, event_tx);
141        }
142    }
143
144    #[instrument(skip(self, stream, peer_addr, clients, event_tx))]
145    fn handle_stream(
146        &self,
147        stream: TcpStream,
148        peer_addr: SocketAddr,
149        clients: Arc<RwLock<ClientRegistry>>,
150        conn_id: usize,
151        event_tx: broadcast::Sender<Event>,
152    ) {
153        let acceptor = self.tls_acceptor.clone();
154        let engine_tx = self.engine_tx.clone();
155        let custom_req_tx = self.custom_req_tx.clone();
156
157        tokio::spawn(async move {
158            info!("new connection from {}", peer_addr);
159
160            let stream = match acceptor.accept(stream).await {
161                Ok(stream) => stream,
162                Err(error) => {
163                    warn!("failed to accept stream: {}", error);
164                    return;
165                }
166            };
167
168            let (reader, writer) = io::split(stream);
169            let mut reader = BufReader::new(reader);
170            let mut writer = BufWriter::new(writer);
171
172            let (client_id, local_ip, _version) = match timeout(
173                Duration::from_secs(2),
174                authenticate(&mut reader, &mut writer, clients),
175            )
176            .await
177            {
178                Ok(Ok(result)) => result,
179                Ok(Err(Error::UnexpectedMessage(message))) => {
180                    warn!("client error: unexpected message: {:?}", message);
181                    let _ = write_to_stream(
182                        &mut writer,
183                        ServerMessage::Error(ErrorResponse::UnexpectedMessage),
184                    )
185                    .await;
186                    return;
187                }
188                Ok(Err(Error::UnsupportedVersion(min_version, max_version))) => {
189                    warn!("client error: unsupported version range: {min_version} - {max_version}");
190                    let _ = write_to_stream(
191                        &mut writer,
192                        ServerMessage::Error(ErrorResponse::UnsupportedVersionRange),
193                    )
194                    .await;
195                    return;
196                }
197                Ok(Err(Error::InvalidCredentials)) => {
198                    warn!("client error: invalid credentials");
199                    let _ = write_to_stream(
200                        &mut writer,
201                        ServerMessage::Error(ErrorResponse::InvalidCredentials),
202                    )
203                    .await;
204                    return;
205                }
206                Ok(Err(Error::Io(error)))
207                    if matches!(
208                        error.kind(),
209                        io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
210                    ) =>
211                {
212                    warn!("client error: malformed message: {:?}", error);
213                    let _ = write_to_stream(
214                        &mut writer,
215                        ServerMessage::Error(ErrorResponse::MalformedMessage),
216                    )
217                    .await;
218                    return;
219                }
220                Ok(Err(error)) => {
221                    warn!("client error: connection died: {:?}", error);
222                    return;
223                }
224                Err(_) => {
225                    warn!("client error: authentication timed out");
226                    return;
227                }
228            };
229
230            let _ = event_tx.send(Event::ClientConnect {
231                client_id: client_id.clone(),
232                conn_id,
233                local_ip,
234            });
235
236            match handle_connection(
237                &mut reader,
238                &mut writer,
239                &client_id,
240                engine_tx,
241                custom_req_tx,
242            )
243            .await
244            {
245                Ok(()) => {
246                    let _ = event_tx.send(Event::ClientDisconnect { client_id, conn_id });
247                    return;
248                }
249                Err(Error::UnexpectedMessage(message)) => {
250                    warn!("client error: unexpected message: {:?}", message);
251                    let _ = write_to_stream(
252                        &mut writer,
253                        ServerMessage::Error(ErrorResponse::UnexpectedMessage),
254                    )
255                    .await;
256                }
257                Err(Error::Io(error))
258                    if matches!(
259                        error.kind(),
260                        io::ErrorKind::InvalidInput | io::ErrorKind::InvalidData
261                    ) =>
262                {
263                    warn!("client error: malformed message: {:?}", error);
264                    let _ = write_to_stream(
265                        &mut writer,
266                        ServerMessage::Error(ErrorResponse::MalformedMessage),
267                    )
268                    .await;
269                    return;
270                }
271                Err(error) => {
272                    warn!("client error: connection died: {:?}", error);
273                    return;
274                }
275            }
276
277            let _ = event_tx.send(Event::ClientConnectionLoss { client_id, conn_id });
278        });
279    }
280}
281
#[cfg(feature = "tokio-graceful-shutdown")]
impl IntoSubsystem<Error> for NetworkListener {
    /// Runs the listener as a graceful-shutdown subsystem.
    ///
    /// A shutdown request cancels the accept loop and counts as a clean exit.
    /// An error returned by [`NetworkListener::listen`] is propagated.
    async fn run(self, subsys: &mut SubsystemHandle) -> Result<()> {
        match self.listen().cancel_on_shutdown(subsys).await {
            Ok(listen_result) => listen_result,
            Err(_cancelled) => Ok(()),
        }
    }
}
292
293async fn read_from_stream<S: AsyncRead + Unpin + Send>(stream: &mut S) -> Result<ClientMessage> {
294    let message_type = stream.read_u8().await?;
295    let payload_length = stream.read_u32_le().await?;
296
297    if payload_length == 0 {
298        return Ok(Message::new(message_type, vec![]).try_into()?);
299    }
300
301    let mut message = vec![0; payload_length as usize];
302    stream.read_exact(&mut message).await?;
303
304    Ok(Message::new(message_type, message).try_into()?)
305}
306
307async fn write_to_stream<S: AsyncWrite + Unpin + Send>(
308    stream: &mut S,
309    message: impl Into<Message>,
310) -> Result<()> {
311    let message: Message = message.into();
312    stream.write_all(&message.into_bytes()).await?;
313    stream.flush().await?;
314
315    Ok(())
316}
317
318#[derive(Debug, Error)]
319pub enum BuilderError {
320    #[error("missing field: {0}")]
321    MissingField(&'static str),
322
323    #[error(transparent)]
324    AddrParse(#[from] std::net::AddrParseError),
325
326    #[error("failed to read PEM file at {path}: {source}")]
327    Pem {
328        path: PathBuf,
329        #[source]
330        source: pem::Error,
331    },
332
333    #[error(transparent)]
334    Rustls(#[from] rustls::Error),
335}
336
337pub type BuilderResult<T> = result::Result<T, BuilderError>;
338
339/// Builder for [`NetworkListener`].
340///
341/// This allows configuring the address, TLS certificates, client registry,
342/// message channels, and custom request handlers.
343///
344/// # Examples
345///
346/// ```
347/// use std::sync::Arc;
348/// use tokio::sync::{broadcast, RwLock};
349/// use tokio::sync::mpsc;
350/// use bloop_server_framework::network::NetworkListenerBuilder;
351///
352/// let (engine_tx, engine_rx) = mpsc::channel(512);
353/// let (event_tx, event_rx) = broadcast::channel(512);
354///
355/// # let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
356/// let builder = NetworkListenerBuilder::new()
357///     .address("127.0.0.1:12345")
358///     .clients(Arc::new(RwLock::new(Default::default())))
359///     .engine_tx(engine_tx)
360///     .event_tx(event_tx)
361///     .cert_path("examples/cert.pem")
362///     .key_path("examples/key.pem")
363///     .build()
364///     .unwrap();
365/// ```
366#[derive(Debug, Default)]
367pub struct NetworkListenerBuilder {
368    address: Option<String>,
369    cert_path: Option<PathBuf>,
370    key_path: Option<PathBuf>,
371    clients: Option<Arc<RwLock<ClientRegistry>>>,
372    engine_tx: Option<mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>>,
373    event_tx: Option<broadcast::Sender<Event>>,
374    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
375}
376
377impl NetworkListenerBuilder {
378    pub fn new() -> Self {
379        Self {
380            address: None,
381            cert_path: None,
382            key_path: None,
383            clients: None,
384            engine_tx: None,
385            event_tx: None,
386            custom_req_tx: None,
387        }
388    }
389
390    pub fn custom_req_tx(
391        self,
392        custom_req_tx: mpsc::Sender<CustomRequestMessage>,
393    ) -> NetworkListenerBuilder {
394        NetworkListenerBuilder {
395            clients: self.clients,
396            address: self.address,
397            cert_path: self.cert_path,
398            key_path: self.key_path,
399            engine_tx: self.engine_tx,
400            event_tx: self.event_tx,
401            custom_req_tx: Some(custom_req_tx),
402        }
403    }
404
405    pub fn address(mut self, address: impl Into<String>) -> Self {
406        self.address = Some(address.into());
407        self
408    }
409
410    pub fn cert_path(mut self, path: impl Into<PathBuf>) -> Self {
411        self.cert_path = Some(path.into());
412        self
413    }
414
415    pub fn key_path(mut self, path: impl Into<PathBuf>) -> Self {
416        self.key_path = Some(path.into());
417        self
418    }
419
420    pub fn clients(mut self, clients: Arc<RwLock<ClientRegistry>>) -> Self {
421        self.clients = Some(clients);
422        self
423    }
424
425    pub fn engine_tx(
426        mut self,
427        tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
428    ) -> Self {
429        self.engine_tx = Some(tx);
430        self
431    }
432
433    pub fn event_tx(mut self, tx: broadcast::Sender<Event>) -> Self {
434        self.event_tx = Some(tx);
435        self
436    }
437
438    /// Builds the [`NetworkListener`] from the provided configuration.
439    ///
440    /// Returns an error if required fields are missing, or TLS setup fails.
441    pub fn build(self) -> BuilderResult<NetworkListener> {
442        let addr: SocketAddr = self
443            .address
444            .ok_or_else(|| BuilderError::MissingField("address"))?
445            .parse()?;
446
447        let cert_path = self
448            .cert_path
449            .ok_or_else(|| BuilderError::MissingField("cert_path"))?;
450        let key_path = self
451            .key_path
452            .ok_or_else(|| BuilderError::MissingField("key_path"))?;
453
454        let certs = CertificateDer::pem_file_iter(&cert_path)
455            .map_err(|err| BuilderError::Pem {
456                path: cert_path.clone(),
457                source: err,
458            })?
459            .collect::<result::Result<Vec<_>, _>>()
460            .map_err(|err| BuilderError::Pem {
461                path: cert_path,
462                source: err,
463            })?;
464        let key = PrivateKeyDer::from_pem_file(&key_path).map_err(|err| BuilderError::Pem {
465            path: key_path,
466            source: err,
467        })?;
468
469        let config = ServerConfig::builder()
470            .with_no_client_auth()
471            .with_single_cert(certs, key)?;
472        let tls_acceptor = TlsAcceptor::from(Arc::new(config));
473
474        Ok(NetworkListener {
475            clients: self
476                .clients
477                .ok_or_else(|| BuilderError::MissingField("clients"))?,
478            addr,
479            tls_acceptor,
480            engine_tx: self
481                .engine_tx
482                .ok_or_else(|| BuilderError::MissingField("engine_tx"))?,
483            event_tx: self
484                .event_tx
485                .ok_or_else(|| BuilderError::MissingField("event_tx"))?,
486            custom_req_tx: self.custom_req_tx,
487        })
488    }
489}
490
491/// Handles an authenticated client connection.
492///
493/// Reads client messages from the stream, dispatches them to the appropriate
494/// handlers, and sends back server responses.
495async fn handle_connection<R, W>(
496    reader: &mut R,
497    writer: &mut W,
498    client_id: &str,
499    engine_tx: mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)>,
500    custom_req_tx: Option<mpsc::Sender<CustomRequestMessage>>,
501) -> Result<()>
502where
503    R: AsyncRead + Unpin + Send,
504    W: AsyncWrite + Unpin + Send,
505{
506    loop {
507        let message = match timeout(Duration::from_secs(30), read_from_stream(reader)).await {
508            Ok(Ok(message)) => message,
509            Ok(Err(error)) => return Err(error),
510            Err(_) => return Ok(()),
511        };
512
513        let engine_request = match message {
514            ClientMessage::Bloop { nfc_uid } => EngineRequest::Bloop {
515                nfc_uid,
516                client_id: client_id.to_string(),
517            },
518            ClientMessage::RetrieveAudio { achievement_id } => {
519                EngineRequest::RetrieveAudio { id: achievement_id }
520            }
521            ClientMessage::PreloadCheck {
522                audio_manifest_hash,
523            } => EngineRequest::PreloadCheck {
524                manifest_hash: audio_manifest_hash,
525            },
526            ClientMessage::Ping => {
527                write_to_stream(writer, ServerMessage::Pong).await?;
528                continue;
529            }
530            ClientMessage::Quit => break,
531            ClientMessage::Unknown(message) => {
532                if let Some(sender) = custom_req_tx.as_ref() {
533                    let (resp_tx, resp_rx) = oneshot::channel();
534
535                    let _ = sender
536                        .send(CustomRequestMessage {
537                            client_id: client_id.to_string(),
538                            message,
539                            response: resp_tx,
540                        })
541                        .await;
542
543                    if let Some(message) = resp_rx.await? {
544                        write_to_stream(writer, message).await?;
545                    }
546                }
547
548                continue;
549            }
550            message => return Err(Error::UnexpectedMessage(message)),
551        };
552
553        let (resp_tx, resp_rx) = oneshot::channel();
554        let _ = engine_tx.send((engine_request, resp_tx)).await;
555        let response = resp_rx.await?;
556
557        write_to_stream(writer, response).await?;
558    }
559
560    Ok(())
561}
562
563/// Authenticates a client by performing handshake and credential verification.
564///
565/// Returns the client ID, its IP address, and the negotiated protocol version
566/// on success.
567async fn authenticate<R, W>(
568    reader: &mut R,
569    writer: &mut W,
570    clients: Arc<RwLock<ClientRegistry>>,
571) -> Result<(String, IpAddr, u8)>
572where
573    R: AsyncRead + Unpin + Send,
574    W: AsyncWrite + Unpin + Send,
575{
576    let (min_version, max_version) = match read_from_stream(reader).await? {
577        ClientMessage::ClientHandshake {
578            min_version,
579            max_version,
580        } => (min_version, max_version),
581        message => return Err(Error::UnexpectedMessage(message)),
582    };
583
584    if min_version > 3 || max_version < 3 {
585        return Err(Error::UnsupportedVersion(min_version, max_version));
586    }
587
588    write_to_stream(
589        writer,
590        ServerMessage::ServerHandshake {
591            accepted_version: 3,
592            capabilities: Capabilities::PreloadCheck,
593        },
594    )
595    .await?;
596
597    let (client_id, client_secret, ip_addr) = match read_from_stream(reader).await? {
598        ClientMessage::Authentication {
599            client_id,
600            client_secret,
601            ip_addr,
602        } => (client_id, client_secret, ip_addr),
603        message => return Err(Error::UnexpectedMessage(message)),
604    };
605
606    let clients = clients.read().await;
607    let Some(secret_hash) = clients.get(&client_id) else {
608        return Err(Error::InvalidCredentials);
609    };
610
611    if Argon2::default()
612        .verify_password(client_secret.as_bytes(), &secret_hash.password_hash())
613        .is_err()
614    {
615        return Err(Error::InvalidCredentials);
616    }
617
618    write_to_stream(writer, ServerMessage::AuthenticationAccepted).await?;
619
620    Ok((client_id.to_string(), ip_addr, 3))
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Argon2id hash (cheap test parameters) that the secret `"secret"`
    /// verifies against.
    const CLIENT_SECRET_HASH: &str =
        "$argon2id$v=19$m=10,t=1,p=1$THh0RHE5YWNkQUZNa2lqUA$dmB4X7J49jjCGA";

    #[test]
    fn builder_fails_with_missing_fields() {
        let result = NetworkListenerBuilder::new().build();
        assert!(matches!(result, Err(BuilderError::MissingField("address"))));
    }

    #[test]
    fn builder_fails_with_invalid_address() {
        let builder = NetworkListenerBuilder::new()
            .address("invalid-addr")
            .cert_path("cert.pem")
            .key_path("key.pem")
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::AddrParse(_))));
    }

    #[test]
    fn builder_fails_on_invalid_pem_files() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");
        fs::write(&cert_path, b"invalid-cert").unwrap();
        fs::write(&key_path, b"invalid-key").unwrap();

        let builder = NetworkListenerBuilder::new()
            .address("127.0.0.1:12345")
            .cert_path(&cert_path)
            .key_path(&key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(matches!(result, Err(BuilderError::Pem { .. })));
    }

    #[test]
    fn builder_succeeds_with_valid_dummy_pem() {
        let dir = tempdir().unwrap();
        let cert_path = dir.path().join("cert.pem");
        let key_path = dir.path().join("key.pem");

        let cert_data = include_bytes!("../examples/cert.pem");
        let key_data = include_bytes!("../examples/key.pem");

        fs::write(&cert_path, cert_data).unwrap();
        fs::write(&key_path, key_data).unwrap();

        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let builder = NetworkListenerBuilder::new()
            .address("127.0.0.1:12345")
            .cert_path(&cert_path)
            .key_path(&key_path)
            .clients(Arc::new(RwLock::new(Default::default())))
            .engine_tx(dummy_engine_tx())
            .event_tx(dummy_event_tx());

        let result = builder.build();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_client_id() {
        let clients = Arc::new(RwLock::new(Default::default()));

        let mut reader = tokio_test::io::Builder::new()
            .read(&build_handshake(3, 3))
            .read(&build_authentication("unknown-client", "password", "127.0.0.1"))
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake().into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, clients).await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    #[tokio::test]
    async fn authentication_succeeds_with_correct_credentials() {
        let authentication_accepted: Message = ServerMessage::AuthenticationAccepted.into();

        let mut reader = tokio_test::io::Builder::new()
            .read(&build_handshake(3, 3))
            .read(&build_authentication("client", "secret", "127.0.0.1"))
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake().into_bytes())
            .write(&authentication_accepted.into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, registry_with_client()).await;

        assert!(result.is_ok(), "authentication failed: {result:?}");
    }

    #[tokio::test]
    async fn authentication_fails_with_wrong_password() {
        // The ID is registered, so the rejection must come from the Argon2
        // check itself rather than from the ID lookup.
        let mut reader = tokio_test::io::Builder::new()
            .read(&build_handshake(3, 3))
            .read(&build_authentication("client", "wrong-secret", "127.0.0.1"))
            .build();
        let mut writer = tokio_test::io::Builder::new()
            .write(&server_handshake().into_bytes())
            .build();

        let result = authenticate(&mut reader, &mut writer, registry_with_client()).await;

        assert!(matches!(result, Err(Error::InvalidCredentials)));
    }

    #[tokio::test]
    async fn authentication_rejects_unsupported_version_range() {
        let mut reader = tokio_test::io::Builder::new()
            .read(&build_handshake(4, 5))
            .build();
        // Nothing may be written when the version range is rejected.
        let mut writer = tokio_test::io::Builder::new().build();

        let result = authenticate(&mut reader, &mut writer, registry_with_client()).await;

        assert!(matches!(result, Err(Error::UnsupportedVersion(4, 5))));
    }

    #[tokio::test]
    async fn read_fails_on_truncated_payload() {
        // The header announces a 16-byte payload but only 3 bytes follow.
        let mut data: &[u8] = &[0x01, 16, 0, 0, 0, 1, 2, 3];

        let result = read_from_stream(&mut data).await;

        assert!(matches!(
            result,
            Err(Error::Io(ref error)) if error.kind() == io::ErrorKind::UnexpectedEof
        ));
    }

    fn registry_with_client() -> Arc<RwLock<ClientRegistry>> {
        let mut registry = ClientRegistry::new();
        registry.insert(
            "client".into(),
            PasswordHashString::new(CLIENT_SECRET_HASH).unwrap(),
        );
        Arc::new(RwLock::new(registry))
    }

    fn server_handshake() -> Message {
        ServerMessage::ServerHandshake {
            accepted_version: 3,
            capabilities: Capabilities::PreloadCheck,
        }
        .into()
    }

    fn dummy_engine_tx() -> mpsc::Sender<(EngineRequest, oneshot::Sender<ServerMessage>)> {
        let (tx, _rx) = mpsc::channel(1);
        tx
    }

    fn dummy_event_tx() -> broadcast::Sender<Event> {
        let (tx, _rx) = broadcast::channel(1);
        tx
    }

    fn build_handshake(min_version: u8, max_version: u8) -> Vec<u8> {
        let payload = [min_version, max_version];

        let mut buf = vec![0x01];
        buf.extend(&(payload.len() as u32).to_le_bytes());
        buf.extend(&payload);
        buf
    }

    fn build_authentication(client_id: &str, password: &str, ip_addr: &str) -> Vec<u8> {
        let client_id_bytes = client_id.as_bytes();
        let password_bytes = password.as_bytes();

        let mut payload = Vec::new();
        payload.push(client_id_bytes.len() as u8);
        payload.extend(client_id_bytes);

        payload.push(password_bytes.len() as u8);
        payload.extend(password_bytes);

        let ip: IpAddr = ip_addr.parse().expect("Invalid IP address");
        match ip {
            IpAddr::V4(v4) => {
                payload.push(4); // IPv4
                payload.extend(&v4.octets());
            }
            IpAddr::V6(v6) => {
                payload.push(6); // IPv6
                payload.extend(&v6.octets());
            }
        }

        let mut buf = vec![0x03];
        buf.extend(&(payload.len() as u32).to_le_bytes());
        buf.extend(payload);
        buf
    }
}