msg_socket/hooks/mod.rs
1//! Connection hooks for customizing connection establishment.
2//!
3//! Connection hooks are attached when establishing connections and allow custom
4//! authentication, handshakes, or protocol negotiations. The [`ConnectionHook`] trait
5//! is called during connection setup, before the connection is used for messaging.
6//!
7//! # Built-in Hooks
8//!
9//! The [`token`] module provides ready-to-use token-based authentication hooks:
10//! - [`token::ServerHook`] - Server-side hook that validates client tokens
11//! - [`token::ClientHook`] - Client-side hook that sends a token to the server
12//!
13//! # Custom Hooks
14//!
15//! Implement [`ConnectionHook`] for custom authentication or protocol negotiation:
16//!
17//! ```no_run
18//! use msg_socket::hooks::{ConnectionHook, Error, HookResult};
19//! use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
20//!
21//! struct MyAuth;
22//!
//! #[derive(Debug, thiserror::Error)]
//! enum MyAuthError {
//!     #[error("invalid token")]
//!     InvalidToken,
//! }
//!
//! impl<Io> ConnectionHook<Io> for MyAuth
//! where
//!     Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
//! {
//!     type Error = MyAuthError;
//!
//!     async fn on_connection(&self, mut io: Io) -> HookResult<Io, Self::Error> {
//!         let mut buf = [0u8; 32];
//!         io.read_exact(&mut buf).await?;
//!         if &buf == b"expected_token_value_32_bytes!!!" {
//!             io.write_all(b"OK").await?;
//!             Ok(io)
//!         } else {
//!             Err(Error::hook(MyAuthError::InvalidToken))
//!         }
//!     }
//! }
46//! ```
47//!
48//! # Future Extensions
49//!
50//! TODO: Additional hooks may be added for different parts of the connection lifecycle
51//! (e.g., disconnection, reconnection, periodic health checks).
52
53use std::{error::Error as StdError, future::Future, io, pin::Pin, sync::Arc};
54
55use tokio::io::{AsyncRead, AsyncWrite};
56
57pub mod token;
58
59/// Error type for connection hooks.
60///
61/// Distinguishes between I/O errors and hook-specific errors.
62#[derive(Debug, thiserror::Error)]
63pub enum Error<E> {
64 /// An I/O error occurred.
65 #[error("IO error: {0}")]
66 Io(#[from] io::Error),
67 /// A hook-specific error.
68 #[error("Hook error: {0}")]
69 Hook(#[source] E),
70}
71
72impl<E> Error<E> {
73 /// Create a hook error from a hook-specific error.
74 pub fn hook(err: E) -> Self {
75 Error::Hook(err)
76 }
77}
78
79/// Result type for connection hooks.
80///
81/// This is intentionally named `HookResult` (not `Result`) to make it clear this is not
82/// `std::result::Result`. A `HookResult` can be:
83/// - `Ok(io)` - success, returns the IO stream
84/// - `Err(Error::Io(..))` - an I/O error occurred
85/// - `Err(Error::Hook(..))` - a hook-specific error occurred
86pub type HookResult<T, E> = std::result::Result<T, Error<E>>;
87
88/// Type-erased hook result used internally by drivers.
89pub(crate) type ErasedHookResult<T> = HookResult<T, Box<dyn StdError + Send + Sync>>;
90
91/// Connection hook executed during connection establishment.
92///
93/// For server sockets: called when a connection is accepted.
94/// For client sockets: called after connecting.
95///
96/// The connection hook receives the raw IO stream and has full control over the handshake protocol.
97pub trait ConnectionHook<Io>: Send + Sync + 'static
98where
99 Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
100{
101 /// The hook-specific error type.
102 type Error: StdError + Send + Sync + 'static;
103
104 /// Called when a connection is established.
105 ///
106 /// # Arguments
107 /// * `io` - The raw IO stream for this connection
108 ///
109 /// # Returns
110 /// - `Ok(io)` - The IO stream on success (potentially wrapped/transformed)
111 /// - `Err(Error::Io(..))` - An I/O error occurred
112 /// - `Err(Error::Hook(Self::Error))` - A hook-specific error to reject the connection
113 fn on_connection(&self, io: Io) -> impl Future<Output = HookResult<Io, Self::Error>> + Send;
114}
115
116// ============================================================================
117// Type-erased connection hook for internal use
118// ============================================================================
119
120/// Type-erased connection hook for internal use.
121///
122/// This trait allows storing connection hooks with different concrete types behind a single
123/// `Arc<dyn ConnectionHookErased<Io>>`. The hook error type is erased to `Box<dyn Error>`.
124pub(crate) trait ConnectionHookErased<Io>: Send + Sync + 'static
125where
126 Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
127{
128 fn on_connection(
129 self: Arc<Self>,
130 io: Io,
131 ) -> Pin<Box<dyn Future<Output = ErasedHookResult<Io>> + Send + 'static>>;
132}
133
134impl<T, Io> ConnectionHookErased<Io> for T
135where
136 T: ConnectionHook<Io>,
137 Io: AsyncRead + AsyncWrite + Send + Unpin + 'static,
138{
139 fn on_connection(
140 self: Arc<Self>,
141 io: Io,
142 ) -> Pin<Box<dyn Future<Output = ErasedHookResult<Io>> + Send + 'static>> {
143 Box::pin(async move {
144 ConnectionHook::on_connection(&*self, io).await.map_err(|e| match e {
145 Error::Io(io_err) => Error::Io(io_err),
146 Error::Hook(hook_err) => {
147 Error::Hook(Box::new(hook_err) as Box<dyn StdError + Send + Sync>)
148 }
149 })
150 })
151 }
152}