1use std::{collections::BTreeMap, ops::Deref, sync::Arc};
2
3use axum::{
4 extract::FromRequestParts,
5 http::{header::AUTHORIZATION, request::Parts},
6};
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10use crate::{request::request_id_from_extensions, HttpException};
11
12#[derive(Debug, Clone, Serialize, Deserialize, Default)]
13pub struct AuthIdentity {
14 pub subject: String,
15 pub roles: Vec<String>,
16 #[serde(default)]
17 pub claims: BTreeMap<String, Value>,
18}
19
20impl AuthIdentity {
21 pub fn new(subject: impl Into<String>) -> Self {
22 Self {
23 subject: subject.into(),
24 roles: Vec::new(),
25 claims: BTreeMap::new(),
26 }
27 }
28
29 pub fn with_roles<I, S>(mut self, roles: I) -> Self
30 where
31 I: IntoIterator<Item = S>,
32 S: Into<String>,
33 {
34 self.roles = roles.into_iter().map(Into::into).collect();
35 self
36 }
37
38 pub fn with_claim(mut self, key: impl Into<String>, value: Value) -> Self {
39 self.claims.insert(key.into(), value);
40 self
41 }
42
43 pub fn has_role(&self, role: &str) -> bool {
44 self.roles.iter().any(|candidate| candidate == role)
45 }
46
47 pub fn require_role(&self, role: &str) -> Result<(), HttpException> {
48 if self.has_role(role) {
49 Ok(())
50 } else {
51 Err(HttpException::forbidden(format!(
52 "Missing required role `{role}`"
53 )))
54 }
55 }
56}
57
58#[derive(Debug, Clone)]
59pub struct AuthUser(pub Arc<AuthIdentity>);
60
61impl AuthUser {
62 pub fn into_inner(self) -> Arc<AuthIdentity> {
63 self.0
64 }
65
66 pub fn value(&self) -> &AuthIdentity {
67 &self.0
68 }
69}
70
71impl Deref for AuthUser {
72 type Target = AuthIdentity;
73
74 fn deref(&self) -> &Self::Target {
75 &self.0
76 }
77}
78
79impl<S> FromRequestParts<S> for AuthUser
80where
81 S: Send + Sync,
82{
83 type Rejection = HttpException;
84
85 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
86 let request_id = request_id_from_extensions(&parts.extensions);
87 let identity = parts
88 .extensions
89 .get::<Arc<AuthIdentity>>()
90 .cloned()
91 .ok_or_else(|| {
92 HttpException::unauthorized("Authentication required")
93 .with_optional_request_id(request_id)
94 })?;
95
96 Ok(Self(identity))
97 }
98}
99
100#[derive(Debug, Clone, Default)]
101pub struct OptionalAuthUser(pub Option<Arc<AuthIdentity>>);
102
103impl OptionalAuthUser {
104 pub fn into_inner(self) -> Option<Arc<AuthIdentity>> {
105 self.0
106 }
107
108 pub fn value(&self) -> Option<&AuthIdentity> {
109 self.0.as_deref()
110 }
111
112 pub fn is_authenticated(&self) -> bool {
113 self.0.is_some()
114 }
115}
116
117impl Deref for OptionalAuthUser {
118 type Target = Option<Arc<AuthIdentity>>;
119
120 fn deref(&self) -> &Self::Target {
121 &self.0
122 }
123}
124
125impl<S> FromRequestParts<S> for OptionalAuthUser
126where
127 S: Send + Sync,
128{
129 type Rejection = HttpException;
130
131 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
132 Ok(Self(parts.extensions.get::<Arc<AuthIdentity>>().cloned()))
133 }
134}
135
/// Raw bearer token taken from the request's `Authorization` header.
///
/// `Debug` is implemented by hand so the credential is never printed. The
/// token is shown as `<redacted>` to keep it out of logs and panic messages.
#[derive(Clone)]
pub struct BearerToken(pub Arc<str>);

impl std::fmt::Debug for BearerToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("BearerToken").field(&"<redacted>").finish()
    }
}
138
139impl BearerToken {
140 pub fn value(&self) -> &str {
141 self.0.as_ref()
142 }
143
144 pub fn into_inner(self) -> String {
145 self.0.as_ref().to_string()
146 }
147}
148
149impl Deref for BearerToken {
150 type Target = str;
151
152 fn deref(&self) -> &Self::Target {
153 self.0.as_ref()
154 }
155}
156
157impl<S> FromRequestParts<S> for BearerToken
158where
159 S: Send + Sync,
160{
161 type Rejection = HttpException;
162
163 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
164 let request_id = request_id_from_extensions(&parts.extensions);
165 let header = parts
166 .headers
167 .get(AUTHORIZATION)
168 .and_then(|value| value.to_str().ok())
169 .ok_or_else(|| {
170 HttpException::unauthorized("Missing Authorization header")
171 .with_optional_request_id(request_id.clone())
172 })?;
173
174 let token = header
175 .strip_prefix("Bearer ")
176 .map(str::trim)
177 .filter(|value| !value.is_empty())
178 .ok_or_else(|| {
179 HttpException::unauthorized("Invalid bearer token")
180 .with_optional_request_id(request_id)
181 })?;
182
183 Ok(Self(Arc::<str>::from(token.to_string())))
184 }
185}