axum_realtime_kit/auth.rs
//! Token authentication for WebSocket endpoints.
//!
//! [`WsAuth`] is an axum extractor that pulls a token from the request (an
//! `Authorization: Bearer <token>` header, or failing that a `token` query
//! parameter) and checks it with the [`TokenValidator`] implemented on your
//! application state. Handlers that take `WsAuth<User>` therefore only run for
//! authenticated requests; all other requests receive `401 Unauthorized`.
//!
//! ## Example
//!
//! ```rust,no_run
//! # use async_trait::async_trait;
//! # use axum::{response::Response, routing::get, Router, response::IntoResponse};
//! # use axum_realtime_kit::auth::{TokenValidator, WsAuth};
//! # use std::{sync::Arc, net::SocketAddr};
//! #
//! // 1. Your custom user struct
//! #[derive(Debug, Clone)]
//! struct User {
//!     id: i64,
//!     username: String,
//! }
//!
//! // Your application's shared state
//! #[derive(Clone)]
//! struct AppState {
//!     // ... your db pools, etc.
//! }
//!
//! // 2. Implement the trait on your state
//! #[async_trait]
//! impl TokenValidator for AppState {
//!     type User = User;
//!     type Error = std::io::Error;
//!
//!     async fn validate_token(&self, token: &str) -> Result<Self::User, Self::Error> {
//!         // ...
//! #       if token == "secret-token" {
//! #           Ok(User { id: 123, username: "Alice".to_string() })
//! #       } else {
//! #           Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Invalid token"))
//! #       }
//!     }
//! }
//!
//! // 3. Use the extractor in your handler
//! async fn websocket_handler(
//!     auth: WsAuth<User>,
//! ) -> Response {
//!     let user: User = auth.0;
//!     println!("Authenticated user: {:?}", user);
//!     "Hello!".into_response()
//! }
//!
//! #[tokio::main]
//! async fn main() {
//!     // The AppState itself is the state. It needs to be Clone.
//!     let app_state = AppState {};
//!     // `with_state` supplies the state, so the result is a plain `Router`
//!     // (i.e. `Router<()>`) that can be served directly.
//!     let app: Router = Router::new()
//!         .route("/ws", get(websocket_handler))
//!         .with_state(app_state);
//!     // ...
//! }
//! ```
57
58use async_trait::async_trait;
59use axum::{
60 extract::{FromRequestParts, Query},
61 http::{HeaderMap, StatusCode, request::Parts},
62 response::{IntoResponse, Response},
63};
64use serde::Deserialize;
65
66/// The public trait the user's `AppState` must implement to enable `WsAuth`.
67///
68/// This trait decouples the authentication mechanism from the extractor, allowing
69/// any token validation strategy to be used.
70#[async_trait]
71pub trait TokenValidator {
72 /// The user type that is returned upon successful validation. This can be any
73 /// struct that is `Send + Sync + 'static`.
74 type User: Send + Sync + 'static;
75 /// The error type returned on validation failure.
76 type Error: std::error::Error + Send + Sync + 'static;
77
78 /// Validates a token string and returns a user on success.
79 ///
80 /// This is where you implement your specific authentication logic, such as
81 /// decoding a JWT or querying a database.
82 ///
83 /// # Arguments
84 /// * `token` - The token string extracted from the request.
85 async fn validate_token(&self, token: &str) -> Result<Self::User, Self::Error>;
86}
87
88/// The generic WebSocket authentication extractor.
89///
90/// This extractor requires the request to be authenticated. It holds the validated
91/// user object of type `U`. If authentication fails, the request is rejected with
92// a `401 Unauthorized` response.
93#[derive(Debug)]
94pub struct WsAuth<U>(pub U)
95where
96 U: Send + Sync + 'static;
97
98/// The query parameter struct used internally for token extraction.
99#[derive(Deserialize)]
100struct WebSocketAuthQuery {
101 token: String,
102}
103
104impl<S, U> FromRequestParts<S> for WsAuth<U>
105where
106 S: TokenValidator<User = U> + Send + Sync + 'static,
107 U: Send + Sync + 'static,
108{
109 type Rejection = Response;
110
111 fn from_request_parts(
112 parts: &mut Parts,
113 state: &S,
114 ) -> impl Future<Output = Result<Self, <Self as FromRequestParts<S>>::Rejection>> + Send {
115 Box::pin(async move {
116 // Extract token from header or query
117 let token = get_token_from_headers(&parts.headers);
118 let token = if let Some(t) = token {
119 Some(t)
120 } else {
121 match Query::<WebSocketAuthQuery>::from_request_parts(parts, state).await {
122 Ok(Query(q)) => Some(q.token),
123 Err(_) => None,
124 }
125 };
126
127 let token = match token {
128 Some(t) => t,
129 None => return Err(StatusCode::UNAUTHORIZED.into_response()),
130 };
131
132 match state.validate_token(&token).await {
133 Ok(user) => Ok(WsAuth(user)),
134 Err(_) => Err(StatusCode::UNAUTHORIZED.into_response()),
135 }
136 })
137 }
138}
139
140/// A private helper function to extract a bearer token from the Authorization header.
141fn get_token_from_headers(headers: &HeaderMap) -> Option<String> {
142 headers
143 .get("Authorization")
144 .and_then(|header| header.to_str().ok())
145 .and_then(|header_val| {
146 header_val
147 .strip_prefix("Bearer ")
148 .map(|token| token.to_owned())
149 })
150}