// nestforge_core/auth.rs

1use std::{collections::BTreeMap, ops::Deref, sync::Arc};
2
3use axum::{
4    extract::FromRequestParts,
5    http::{header::AUTHORIZATION, request::Parts},
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{request::request_id_from_extensions, HttpException};
11
12/**
13 * Authentication Identity
14 *
15 * Represents an authenticated user's identity within the NestForge framework.
16 * Contains the subject (typically user ID), roles, and additional custom claims.
17 *
18 * # Fields
19 * - `subject`: The unique identifier for the user (e.g., user ID, email)
20 * - `roles`: List of role names the user possesses
21 * - `claims`: Additional key-value pairs for custom authentication data
22 *
23 * # Usage
24 * This type is automatically populated by the framework's authentication
25 * middleware and can be accessed in handlers via the `AuthUser` extractor.
26 */
27#[derive(Debug, Clone, Serialize, Deserialize, Default)]
28pub struct AuthIdentity {
29    pub subject: String,
30    pub roles: Vec<String>,
31    #[serde(default)]
32    pub claims: BTreeMap<String, Value>,
33}
34
35impl AuthIdentity {
36    /**
37     * Creates a new authentication identity.
38     *
39     * # Arguments
40     * - `subject`: The unique identifier for the user
41     *
42     * # Example
43     * ```rust
44     * let identity = AuthIdentity::new("user-123");
45     * ```
46     */
47    pub fn new(subject: impl Into<String>) -> Self {
48        Self {
49            subject: subject.into(),
50            roles: Vec::new(),
51            claims: BTreeMap::new(),
52        }
53    }
54
55    /**
56     * Adds roles to the authentication identity.
57     *
58     * # Arguments
59     * - `roles`: An iterator of role names to assign
60     */
61    pub fn with_roles<I, S>(mut self, roles: I) -> Self
62    where
63        I: IntoIterator<Item = S>,
64        S: Into<String>,
65    {
66        self.roles = roles.into_iter().map(Into::into).collect();
67        self
68    }
69
70    /**
71     * Adds a custom claim to the authentication identity.
72     *
73     * # Arguments
74     * - `key`: The claim key
75     * - `value`: The claim value
76     */
77    pub fn with_claim(mut self, key: impl Into<String>, value: Value) -> Self {
78        self.claims.insert(key.into(), value);
79        self
80    }
81
82    /**
83     * Checks if the identity has a specific role.
84     *
85     * # Arguments
86     * - `role`: The role name to check
87     *
88     * Returns true if the role is present in the identity's roles.
89     */
90    pub fn has_role(&self, role: &str) -> bool {
91        self.roles.iter().any(|candidate| candidate == role)
92    }
93
94    /**
95     * Requires a specific role, returning an error if not present.
96     *
97     * # Arguments
98     * - `role`: The required role name
99     *
100     * Returns Ok if the role is present, or a forbidden HttpException if not.
101     */
102    pub fn require_role(&self, role: &str) -> Result<(), HttpException> {
103        if self.has_role(role) {
104            Ok(())
105        } else {
106            Err(HttpException::forbidden(format!(
107                "Missing required role `{role}`"
108            )))
109        }
110    }
111}
112
113/**
114 * AuthUser Extractor
115 *
116 * A request extractor that provides mandatory authentication.
117 * Fails with 401 Unauthorized if no authenticated identity is present.
118 *
119 * # Usage
120 * ```rust
121 * async fn handler(user: AuthUser) -> impl IntoResponse {
122 *     format!("Hello, {}", user.subject)
123 * }
124 * ```
125 */
126#[derive(Debug, Clone)]
127pub struct AuthUser(pub Arc<AuthIdentity>);
128
129impl AuthUser {
130    pub fn into_inner(self) -> Arc<AuthIdentity> {
131        self.0
132    }
133
134    pub fn value(&self) -> &AuthIdentity {
135        &self.0
136    }
137}
138
139impl Deref for AuthUser {
140    type Target = AuthIdentity;
141
142    fn deref(&self) -> &Self::Target {
143        &self.0
144    }
145}
146
147impl<S> FromRequestParts<S> for AuthUser
148where
149    S: Send + Sync,
150{
151    type Rejection = HttpException;
152
153    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
154        let request_id = request_id_from_extensions(&parts.extensions);
155        let identity = parts
156            .extensions
157            .get::<Arc<AuthIdentity>>()
158            .cloned()
159            .ok_or_else(|| {
160                HttpException::unauthorized("Authentication required")
161                    .with_optional_request_id(request_id)
162            })?;
163
164        Ok(Self(identity))
165    }
166}
167
168/**
169 * OptionalAuthUser Extractor
170 *
171 * A request extractor that provides optional authentication.
172 * Unlike `AuthUser`, this succeeds even when no identity is present,
173 * returning None in that case.
174 *
175 * # Usage
176 * ```rust
177 * async fn handler(user: OptionalAuthUser) -> impl IntoResponse {
178 *     match user.value() {
179 *         Some(identity) => format!("Hello, {}", identity.subject),
180 *         None => "Hello, guest".to_string(),
181 *     }
182 * }
183 * ```
184 */
185#[derive(Debug, Clone, Default)]
186pub struct OptionalAuthUser(pub Option<Arc<AuthIdentity>>);
187
/**
 * BearerToken Extractor
 *
 * A request extractor that pulls the bearer token out of the `Authorization`
 * header. Useful for custom authentication schemes.
 *
 * # Value
 * Holds the token text with the authentication-scheme prefix removed and
 * surrounding whitespace trimmed.
 *
 * # Errors
 * Extraction rejects with `401 Unauthorized` (tagged with the request id when
 * one is available) when:
 * - the `Authorization` header is absent or is not visible ASCII
 * - the header does not use the `Bearer` scheme
 * - the token is empty after trimming
 *
 * # Debug output
 * `Debug` renders as `BearerToken(<redacted>)` instead of the credential, so the
 * token cannot leak through `{:?}` logging or panic messages. Use `value()` or
 * `Deref` when the raw text is actually needed.
 */
#[derive(Clone)]
pub struct BearerToken(pub Arc<str>);

impl std::fmt::Debug for BearerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately never render the credential itself.
        f.debug_tuple("BearerToken")
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
204
205impl BearerToken {
206    pub fn value(&self) -> &str {
207        self.0.as_ref()
208    }
209
210    pub fn into_inner(self) -> String {
211        self.0.as_ref().to_string()
212    }
213}
214
215impl Deref for BearerToken {
216    type Target = str;
217
218    fn deref(&self) -> &Self::Target {
219        self.0.as_ref()
220    }
221}
222
223impl<S> FromRequestParts<S> for BearerToken
224where
225    S: Send + Sync,
226{
227    type Rejection = HttpException;
228
229    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
230        let request_id = request_id_from_extensions(&parts.extensions);
231        let header = parts
232            .headers
233            .get(AUTHORIZATION)
234            .and_then(|value| value.to_str().ok())
235            .ok_or_else(|| {
236                HttpException::unauthorized("Missing Authorization header")
237                    .with_optional_request_id(request_id.clone())
238            })?;
239
240        let token = header
241            .strip_prefix("Bearer ")
242            .map(str::trim)
243            .filter(|value| !value.is_empty())
244            .ok_or_else(|| {
245                HttpException::unauthorized("Invalid bearer token")
246                    .with_optional_request_id(request_id)
247            })?;
248
249        Ok(Self(Arc::<str>::from(token.to_string())))
250    }
251}