msg_socket/hooks/token.rs
1//! Token-based authentication connection hooks.
2//!
3//! This module provides ready-to-use connection hooks for simple token-based authentication:
4//!
5//! - [`ServerHook`] - Server-side connection hook that validates client tokens
6//! - [`ClientHook`] - Client-side connection hook that sends a token to the server
7//!
8//! # Example
9//!
10//! ```no_run
11//! use bytes::Bytes;
12//! use msg_socket::{
13//! RepSocket, ReqSocket,
14//! hooks::token::{ClientHook, ServerHook},
15//! };
16//! use msg_transport::tcp::Tcp;
17//!
18//! // Server side - validates incoming tokens
19//! let rep = RepSocket::new(Tcp::default()).with_connection_hook(ServerHook::new(|token| {
20//! // Custom validation logic
21//! **token == *b"secret"
22//! }));
23//!
24//! // Client side - sends token on connect
25//! let req =
26//! ReqSocket::new(Tcp::default()).with_connection_hook(ClientHook::new(Bytes::from("secret")));
27//! ```
28
29use std::io;
30
31use bytes::Bytes;
32use futures::SinkExt;
33use tokio::io::{AsyncRead, AsyncWrite};
34use tokio_stream::StreamExt;
35use tokio_util::codec::Framed;
36
37use crate::hooks::{ConnectionHook, Error, HookResult};
38use msg_wire::auth;
39
40/// Error type for server-side token authentication.
41#[derive(Debug, thiserror::Error)]
42pub enum ServerHookError {
43 /// The client's token was rejected by the validator.
44 #[error("authentication rejected")]
45 Rejected,
46 /// The connection was closed before authentication completed.
47 #[error("connection closed")]
48 ConnectionClosed,
49 /// Expected an auth message but received something else.
50 #[error("expected auth message")]
51 ExpectedAuthMessage,
52}
53
54/// Error type for client-side token authentication.
55#[derive(Debug, thiserror::Error)]
56pub enum ClientHookError {
57 /// The server denied the authentication.
58 #[error("authentication denied")]
59 Denied,
60 /// The connection was closed before authentication completed.
61 #[error("connection closed")]
62 ConnectionClosed,
63}
64
/// Connection hook for the accepting side that authenticates clients by token.
///
/// For every incoming connection the hook:
/// 1. Reads the first message, which must carry the client's auth token
/// 2. Passes that token to the stored validator
/// 3. Replies with an ACK when the validator accepts; otherwise sends a reject message and
///    fails the connection
///
/// # Example
///
/// ```no_run
/// use msg_socket::hooks::token::ServerHook;
///
/// // Accept all tokens
/// let hook = ServerHook::accept_all();
///
/// // Custom validation
/// let hook = ServerHook::new(|token| **token == *b"my_secret_token");
/// ```
pub struct ServerHook<F> {
    /// Predicate deciding whether a presented token is acceptable.
    validator: F,
}
86
87impl ServerHook<fn(&Bytes) -> bool> {
88 /// Creates a server hook that accepts all tokens.
89 pub fn accept_all() -> Self {
90 Self { validator: |_| true }
91 }
92}
93
94impl<F> ServerHook<F>
95where
96 F: Fn(&Bytes) -> bool + Send + Sync + 'static,
97{
98 /// Creates a new server hook with the given validator function.
99 ///
100 /// The validator receives the client's token and returns `true` to accept
101 /// the connection or `false` to reject it.
102 pub fn new(validator: F) -> Self {
103 Self { validator }
104 }
105}
106
107impl<Io, F> ConnectionHook<Io> for ServerHook<F>
108where
109 Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
110 F: Fn(&Bytes) -> bool + Send + Sync + 'static,
111{
112 type Error = ServerHookError;
113
114 async fn on_connection(&self, io: Io) -> HookResult<Io, Self::Error> {
115 let mut conn = Framed::new(io, auth::Codec::new_server());
116
117 // Wait for client authentication message
118 let msg = conn
119 .next()
120 .await
121 .ok_or(Error::hook(ServerHookError::ConnectionClosed))?
122 .map_err(|e| io::Error::other(e.to_string()))?;
123
124 let auth::Message::Auth(token) = msg else {
125 return Err(Error::hook(ServerHookError::ExpectedAuthMessage));
126 };
127
128 // Validate the token
129 if !(self.validator)(&token) {
130 conn.send(auth::Message::Reject).await?;
131 return Err(Error::hook(ServerHookError::Rejected));
132 }
133
134 // Send acknowledgment
135 conn.send(auth::Message::Ack).await?;
136
137 Ok(conn.into_inner())
138 }
139}
140
141/// Client-side authentication connection hook that sends a token to the server.
142///
143/// When connecting to a server, this connection hook:
144/// 1. Sends the configured token to the server
145/// 2. Waits for the server's ACK response
146/// 3. Returns an error if the server rejects the token
147///
148/// # Example
149///
150/// ```no_run
151/// use bytes::Bytes;
152/// use msg_socket::hooks::token::ClientHook;
153///
154/// let hook = ClientHook::new(Bytes::from("my_secret_token"));
155/// ```
156pub struct ClientHook {
157 token: Bytes,
158}
159
160impl ClientHook {
161 /// Creates a new client hook with the given authentication token.
162 pub fn new(token: Bytes) -> Self {
163 Self { token }
164 }
165}
166
167impl<Io> ConnectionHook<Io> for ClientHook
168where
169 Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
170{
171 type Error = ClientHookError;
172
173 async fn on_connection(&self, io: Io) -> HookResult<Io, Self::Error> {
174 let mut conn = Framed::new(io, auth::Codec::new_client());
175
176 // Send authentication token
177 conn.send(auth::Message::Auth(self.token.clone())).await?;
178
179 conn.flush().await?;
180
181 // Wait for server acknowledgment
182 let ack = conn
183 .next()
184 .await
185 .ok_or(Error::hook(ClientHookError::ConnectionClosed))?
186 .map_err(|e| io::Error::other(e.to_string()))?;
187
188 if !matches!(ack, auth::Message::Ack) {
189 return Err(Error::hook(ClientHookError::Denied));
190 }
191
192 Ok(conn.into_inner())
193 }
194}