// nestforge_core/auth.rs

1use std::{collections::BTreeMap, ops::Deref, sync::Arc};
2
3use axum::{
4    extract::FromRequestParts,
5    http::{header::AUTHORIZATION, request::Parts},
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{request::request_id_from_extensions, HttpException};
11
12#[derive(Debug, Clone, Serialize, Deserialize, Default)]
13pub struct AuthIdentity {
14    pub subject: String,
15    pub roles: Vec<String>,
16    #[serde(default)]
17    pub claims: BTreeMap<String, Value>,
18}
19
20impl AuthIdentity {
21    pub fn new(subject: impl Into<String>) -> Self {
22        Self {
23            subject: subject.into(),
24            roles: Vec::new(),
25            claims: BTreeMap::new(),
26        }
27    }
28
29    pub fn with_roles<I, S>(mut self, roles: I) -> Self
30    where
31        I: IntoIterator<Item = S>,
32        S: Into<String>,
33    {
34        self.roles = roles.into_iter().map(Into::into).collect();
35        self
36    }
37
38    pub fn with_claim(mut self, key: impl Into<String>, value: Value) -> Self {
39        self.claims.insert(key.into(), value);
40        self
41    }
42
43    pub fn has_role(&self, role: &str) -> bool {
44        self.roles.iter().any(|candidate| candidate == role)
45    }
46
47    pub fn require_role(&self, role: &str) -> Result<(), HttpException> {
48        if self.has_role(role) {
49            Ok(())
50        } else {
51            Err(HttpException::forbidden(format!(
52                "Missing required role `{role}`"
53            )))
54        }
55    }
56}
57
58#[derive(Debug, Clone)]
59pub struct AuthUser(pub Arc<AuthIdentity>);
60
61impl AuthUser {
62    pub fn into_inner(self) -> Arc<AuthIdentity> {
63        self.0
64    }
65
66    pub fn value(&self) -> &AuthIdentity {
67        &self.0
68    }
69}
70
71impl Deref for AuthUser {
72    type Target = AuthIdentity;
73
74    fn deref(&self) -> &Self::Target {
75        &self.0
76    }
77}
78
79impl<S> FromRequestParts<S> for AuthUser
80where
81    S: Send + Sync,
82{
83    type Rejection = HttpException;
84
85    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
86        let request_id = request_id_from_extensions(&parts.extensions);
87        let identity = parts
88            .extensions
89            .get::<Arc<AuthIdentity>>()
90            .cloned()
91            .ok_or_else(|| {
92                HttpException::unauthorized("Authentication required")
93                    .with_optional_request_id(request_id)
94            })?;
95
96        Ok(Self(identity))
97    }
98}
99
100#[derive(Debug, Clone, Default)]
101pub struct OptionalAuthUser(pub Option<Arc<AuthIdentity>>);
102
103impl OptionalAuthUser {
104    pub fn into_inner(self) -> Option<Arc<AuthIdentity>> {
105        self.0
106    }
107
108    pub fn value(&self) -> Option<&AuthIdentity> {
109        self.0.as_deref()
110    }
111
112    pub fn is_authenticated(&self) -> bool {
113        self.0.is_some()
114    }
115}
116
117impl Deref for OptionalAuthUser {
118    type Target = Option<Arc<AuthIdentity>>;
119
120    fn deref(&self) -> &Self::Target {
121        &self.0
122    }
123}
124
125impl<S> FromRequestParts<S> for OptionalAuthUser
126where
127    S: Send + Sync,
128{
129    type Rejection = HttpException;
130
131    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
132        Ok(Self(parts.extensions.get::<Arc<AuthIdentity>>().cloned()))
133    }
134}
135
/// Bearer credentials, as produced by this type's `FromRequestParts` impl from
/// the `Authorization` header.
///
/// The token lives in an `Arc<str>`, so cloning shares one allocation.
#[derive(Debug, Clone)]
pub struct BearerToken(pub Arc<str>);

impl BearerToken {
    /// Borrows the token text.
    pub fn value(&self) -> &str {
        &self.0
    }

    /// Copies the token into a newly allocated, owned `String`.
    pub fn into_inner(self) -> String {
        String::from(&*self.0)
    }
}

impl Deref for BearerToken {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.value()
    }
}
156
157impl<S> FromRequestParts<S> for BearerToken
158where
159    S: Send + Sync,
160{
161    type Rejection = HttpException;
162
163    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
164        let request_id = request_id_from_extensions(&parts.extensions);
165        let header = parts
166            .headers
167            .get(AUTHORIZATION)
168            .and_then(|value| value.to_str().ok())
169            .ok_or_else(|| {
170                HttpException::unauthorized("Missing Authorization header")
171                    .with_optional_request_id(request_id.clone())
172            })?;
173
174        let token = header
175            .strip_prefix("Bearer ")
176            .map(str::trim)
177            .filter(|value| !value.is_empty())
178            .ok_or_else(|| {
179                HttpException::unauthorized("Invalid bearer token")
180                    .with_optional_request_id(request_id)
181            })?;
182
183        Ok(Self(Arc::<str>::from(token.to_string())))
184    }
185}