Skip to main content

msg_socket/hooks/
token.rs

//! Token-based authentication connection hooks.
//!
//! This module provides ready-to-use connection hooks for simple token-based authentication:
//!
//! - [`ServerHook`] - Server-side connection hook that validates client tokens
//! - [`ClientHook`] - Client-side connection hook that sends a token to the server
//!
//! During the handshake the client sends its token, and the server checks it with a
//! validator function. The server then replies with either an ACK or a rejection.
//!
//! # Example
//!
//! ```no_run
//! use bytes::Bytes;
//! use msg_socket::{
//!     RepSocket, ReqSocket,
//!     hooks::token::{ClientHook, ServerHook},
//! };
//! use msg_transport::tcp::Tcp;
//!
//! // Server side - validates incoming tokens
//! let rep = RepSocket::new(Tcp::default()).with_connection_hook(ServerHook::new(|token| {
//!     // Custom validation logic
//!     **token == *b"secret"
//! }));
//!
//! // Client side - sends token on connect
//! let req =
//!     ReqSocket::new(Tcp::default()).with_connection_hook(ClientHook::new(Bytes::from("secret")));
//! ```
28
29use std::io;
30
31use bytes::Bytes;
32use futures::SinkExt;
33use tokio::io::{AsyncRead, AsyncWrite};
34use tokio_stream::StreamExt;
35use tokio_util::codec::Framed;
36
37use crate::hooks::{ConnectionHook, Error, HookResult};
38use msg_wire::auth;
39
40/// Error type for server-side token authentication.
41#[derive(Debug, thiserror::Error)]
42pub enum ServerHookError {
43    /// The client's token was rejected by the validator.
44    #[error("authentication rejected")]
45    Rejected,
46    /// The connection was closed before authentication completed.
47    #[error("connection closed")]
48    ConnectionClosed,
49    /// Expected an auth message but received something else.
50    #[error("expected auth message")]
51    ExpectedAuthMessage,
52}
53
54/// Error type for client-side token authentication.
55#[derive(Debug, thiserror::Error)]
56pub enum ClientHookError {
57    /// The server denied the authentication.
58    #[error("authentication denied")]
59    Denied,
60    /// The connection was closed before authentication completed.
61    #[error("connection closed")]
62    ConnectionClosed,
63}
64
/// Connection hook for the accepting side that checks the token each client presents.
///
/// For every new connection, this hook:
/// 1. reads one frame and expects it to carry the client's auth token,
/// 2. passes that token to the validator function,
/// 3. replies with an ACK if the validator accepts it, or sends a rejection and
///    fails the connection otherwise.
///
/// # Example
///
/// ```no_run
/// use msg_socket::hooks::token::ServerHook;
///
/// // Accept all tokens
/// let hook = ServerHook::accept_all();
///
/// // Custom validation
/// let hook = ServerHook::new(|token| **token == *b"my_secret_token");
/// ```
pub struct ServerHook<F> {
    /// Returns `true` for tokens that should be admitted.
    validator: F,
}
86
87impl ServerHook<fn(&Bytes) -> bool> {
88    /// Creates a server hook that accepts all tokens.
89    pub fn accept_all() -> Self {
90        Self { validator: |_| true }
91    }
92}
93
94impl<F> ServerHook<F>
95where
96    F: Fn(&Bytes) -> bool + Send + Sync + 'static,
97{
98    /// Creates a new server hook with the given validator function.
99    ///
100    /// The validator receives the client's token and returns `true` to accept
101    /// the connection or `false` to reject it.
102    pub fn new(validator: F) -> Self {
103        Self { validator }
104    }
105}
106
107impl<Io, F> ConnectionHook<Io> for ServerHook<F>
108where
109    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
110    F: Fn(&Bytes) -> bool + Send + Sync + 'static,
111{
112    type Error = ServerHookError;
113
114    async fn on_connection(&self, io: Io) -> HookResult<Io, Self::Error> {
115        let mut conn = Framed::new(io, auth::Codec::new_server());
116
117        // Wait for client authentication message
118        let msg = conn
119            .next()
120            .await
121            .ok_or(Error::hook(ServerHookError::ConnectionClosed))?
122            .map_err(|e| io::Error::other(e.to_string()))?;
123
124        let auth::Message::Auth(token) = msg else {
125            return Err(Error::hook(ServerHookError::ExpectedAuthMessage));
126        };
127
128        // Validate the token
129        if !(self.validator)(&token) {
130            conn.send(auth::Message::Reject).await?;
131            return Err(Error::hook(ServerHookError::Rejected));
132        }
133
134        // Send acknowledgment
135        conn.send(auth::Message::Ack).await?;
136
137        Ok(conn.into_inner())
138    }
139}
140
141/// Client-side authentication connection hook that sends a token to the server.
142///
143/// When connecting to a server, this connection hook:
144/// 1. Sends the configured token to the server
145/// 2. Waits for the server's ACK response
146/// 3. Returns an error if the server rejects the token
147///
148/// # Example
149///
150/// ```no_run
151/// use bytes::Bytes;
152/// use msg_socket::hooks::token::ClientHook;
153///
154/// let hook = ClientHook::new(Bytes::from("my_secret_token"));
155/// ```
156pub struct ClientHook {
157    token: Bytes,
158}
159
160impl ClientHook {
161    /// Creates a new client hook with the given authentication token.
162    pub fn new(token: Bytes) -> Self {
163        Self { token }
164    }
165}
166
167impl<Io> ConnectionHook<Io> for ClientHook
168where
169    Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
170{
171    type Error = ClientHookError;
172
173    async fn on_connection(&self, io: Io) -> HookResult<Io, Self::Error> {
174        let mut conn = Framed::new(io, auth::Codec::new_client());
175
176        // Send authentication token
177        conn.send(auth::Message::Auth(self.token.clone())).await?;
178
179        conn.flush().await?;
180
181        // Wait for server acknowledgment
182        let ack = conn
183            .next()
184            .await
185            .ok_or(Error::hook(ClientHookError::ConnectionClosed))?
186            .map_err(|e| io::Error::other(e.to_string()))?;
187
188        if !matches!(ack, auth::Message::Ack) {
189            return Err(Error::hook(ClientHookError::Denied));
190        }
191
192        Ok(conn.into_inner())
193    }
194}