axum_realtime_kit/
auth.rs

1//! ## Example
2//!
3//! ```rust,no_run
4//! # use async_trait::async_trait;
5//! # use axum::{response::Response, routing::get, Router, response::IntoResponse};
6//! # use axum_realtime_kit::auth::{TokenValidator, WsAuth};
7//! # use std::{sync::Arc, net::SocketAddr};
8//! #
9//! // 1. Your custom user struct
10//! #[derive(Debug, Clone)]
11//! struct User {
12//!     id: i64,
13//!     username: String,
14//! }
15//!
16//! // Your application's shared state
17//! #[derive(Clone)]
18//! struct AppState {
19//!     // ... your db pools, etc.
20//! }
21//!
22//! // 2. Implement the trait on your state
23//! #[async_trait]
24//! impl TokenValidator for AppState {
25//!     type User = User;
26//!     type Error = std::io::Error;
27//!
28//!     async fn validate_token(&self, token: &str) -> Result<Self::User, Self::Error> {
29//!         // ...
30//! #        if token == "secret-token" {
31//! #            Ok(User { id: 123, username: "Alice".to_string() })
32//! #        } else {
33//! #            Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid token"))
34//! #        }
35//!     }
36//! }
37//!
38//! // 3. Use the extractor in your handler
39//! async fn websocket_handler(
40//!     auth: WsAuth<User>,
41//! ) -> Response {
42//!     let user: User = auth.0;
43//!     println!("Authenticated user: {:?}", user);
44//!     "Hello!".into_response()
45//! }
46//!
47//! #[tokio::main]
48//! async fn main() {
49//!     // The AppState itself is the state. It needs to be Clone.
50//!     let app_state = AppState {};
51//!     let app: Router = Router::new()
52//!         .route("/ws", get(websocket_handler))
53//!         .with_state(app_state);
54//!     // ...
55//! }
56//! ```
57
58use async_trait::async_trait;
59use axum::{
60    extract::{FromRequestParts, Query},
61    http::{HeaderMap, StatusCode, request::Parts},
62    response::{IntoResponse, Response},
63};
64use serde::Deserialize;
65
66/// The public trait the user's `AppState` must implement to enable `WsAuth`.
67///
68/// This trait decouples the authentication mechanism from the extractor, allowing
69/// any token validation strategy to be used.
70#[async_trait]
71pub trait TokenValidator {
72    /// The user type that is returned upon successful validation. This can be any
73    /// struct that is `Send + Sync + 'static`.
74    type User: Send + Sync + 'static;
75    /// The error type returned on validation failure.
76    type Error: std::error::Error + Send + Sync + 'static;
77
78    /// Validates a token string and returns a user on success.
79    ///
80    /// This is where you implement your specific authentication logic, such as
81    /// decoding a JWT or querying a database.
82    ///
83    /// # Arguments
84    /// * `token` - The token string extracted from the request.
85    async fn validate_token(&self, token: &str) -> Result<Self::User, Self::Error>;
86}
87
/// The generic WebSocket authentication extractor.
///
/// This extractor requires the request to be authenticated. On success it
/// holds the validated user object of type `U` in its single public field.
/// If authentication fails, the request is rejected with a
/// `401 Unauthorized` response.
#[derive(Debug)]
pub struct WsAuth<U>(pub U)
where
    U: Send + Sync + 'static;
97
98/// The query parameter struct used internally for token extraction.
99#[derive(Deserialize)]
100struct WebSocketAuthQuery {
101    token: String,
102}
103
104impl<S, U> FromRequestParts<S> for WsAuth<U>
105where
106    S: TokenValidator<User = U> + Send + Sync + 'static,
107    U: Send + Sync + 'static,
108{
109    type Rejection = Response;
110
111    fn from_request_parts(
112        parts: &mut Parts,
113        state: &S,
114    ) -> impl Future<Output = Result<Self, <Self as FromRequestParts<S>>::Rejection>> + Send {
115        Box::pin(async move {
116            // Extract token from header or query
117            let token = get_token_from_headers(&parts.headers);
118            let token = if let Some(t) = token {
119                Some(t)
120            } else {
121                match Query::<WebSocketAuthQuery>::from_request_parts(parts, state).await {
122                    Ok(Query(q)) => Some(q.token),
123                    Err(_) => None,
124                }
125            };
126
127            let token = match token {
128                Some(t) => t,
129                None => return Err(StatusCode::UNAUTHORIZED.into_response()),
130            };
131
132            match state.validate_token(&token).await {
133                Ok(user) => Ok(WsAuth(user)),
134                Err(_) => Err(StatusCode::UNAUTHORIZED.into_response()),
135            }
136        })
137    }
138}
139
140/// A private helper function to extract a bearer token from the Authorization header.
141fn get_token_from_headers(headers: &HeaderMap) -> Option<String> {
142    headers
143        .get("Authorization")
144        .and_then(|header| header.to_str().ok())
145        .and_then(|header_val| {
146            header_val
147                .strip_prefix("Bearer ")
148                .map(|token| token.to_owned())
149        })
150}