allframe_core/auth/
axum.rs1use std::{
32 future::Future,
33 marker::PhantomData,
34 pin::Pin,
35 sync::Arc,
36 task::{Context, Poll},
37};
38
39use super::{extract_bearer_token, AuthContext, AuthError, Authenticator};
40
/// Newtype wrapper around the claims of an authenticated caller.
///
/// The claims are reachable three ways: through the public tuple field,
/// through [`AuthenticatedUser::claims`], or transparently through `Deref`
/// (so `user.some_field` reads a field of the claims directly).
#[derive(Debug, Clone)]
pub struct AuthenticatedUser<C>(pub C);

impl<C> AuthenticatedUser<C> {
    /// Borrows the wrapped claims.
    pub fn claims(&self) -> &C {
        let AuthenticatedUser(inner) = self;
        inner
    }

    /// Consumes the wrapper and hands back the owned claims.
    pub fn into_inner(self) -> C {
        let AuthenticatedUser(inner) = self;
        inner
    }
}

impl<C> std::ops::Deref for AuthenticatedUser<C> {
    type Target = C;

    /// Same borrow as [`AuthenticatedUser::claims`].
    fn deref(&self) -> &C {
        self.claims()
    }
}
90
/// Tower layer that wraps a service in [`AuthService`], sharing a single
/// authenticator between every service it wraps.
///
/// The wrapped service attaches an `AuthContext` to requests whose bearer
/// token authenticates successfully. It forwards every request to the inner
/// service, including unauthenticated ones.
///
/// The authenticator lives behind an [`Arc`]. Cloning the layer is therefore a
/// reference-count bump and does not require the authenticator to be `Clone`.
pub struct AuthLayer<A> {
    authenticator: Arc<A>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `A: Clone` bound even though only the `Arc` is ever cloned.
impl<A> Clone for AuthLayer<A> {
    fn clone(&self) -> Self {
        Self {
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}

impl<A> AuthLayer<A> {
    /// Creates a layer that authenticates bearer tokens with `authenticator`.
    pub fn new(authenticator: A) -> Self {
        Self {
            authenticator: Arc::new(authenticator),
        }
    }
}
120
121impl<S, A> tower::Layer<S> for AuthLayer<A>
122where
123 A: Clone,
124{
125 type Service = AuthService<S, A>;
126
127 fn layer(&self, inner: S) -> Self::Service {
128 AuthService {
129 inner,
130 authenticator: self.authenticator.clone(),
131 }
132 }
133}
134
/// Middleware service produced by `AuthLayer` and `OptionalAuthLayer`.
///
/// Holds the inner service plus a shared handle to the authenticator.
pub struct AuthService<S, A> {
    inner: S,
    authenticator: Arc<A>,
}

// Implemented by hand rather than derived. `#[derive(Clone)]` would also
// require `A: Clone`, although only the `Arc` is cloned. Frameworks such as axum
// need the layered service to be `Clone`, so that bound would leak onto every
// authenticator.
impl<S: Clone, A> Clone for AuthService<S, A> {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}
141
142impl<S, A, ReqBody> tower::Service<hyper::Request<ReqBody>> for AuthService<S, A>
143where
144 S: tower::Service<hyper::Request<ReqBody>> + Clone + Send + 'static,
145 S::Future: Send,
146 A: Authenticator + 'static,
147 ReqBody: Send + 'static,
148{
149 type Response = S::Response;
150 type Error = S::Error;
151 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
152
153 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
154 self.inner.poll_ready(cx)
155 }
156
157 fn call(&mut self, mut req: hyper::Request<ReqBody>) -> Self::Future {
158 let authenticator = self.authenticator.clone();
159 let mut inner = self.inner.clone();
160
161 Box::pin(async move {
162 if let Some(auth_header) = req.headers().get(hyper::header::AUTHORIZATION) {
164 if let Ok(header_str) = auth_header.to_str() {
165 if let Some(token) = extract_bearer_token(header_str) {
166 if let Ok(claims) = authenticator.authenticate(token).await {
168 let ctx = AuthContext::new(claims, token);
169 req.extensions_mut().insert(ctx);
170 }
171 }
172 }
173 }
174
175 inner.call(req).await
176 })
177 }
178}
179
/// Tower layer for routes where authentication is optional.
///
/// It wraps services in the same `AuthService` as `AuthLayer`. An
/// `AuthContext` is attached when a bearer token authenticates, and every
/// request is forwarded either way.
///
/// The authenticator lives behind an [`Arc`]. Cloning the layer does not
/// require the authenticator to be `Clone`.
pub struct OptionalAuthLayer<A> {
    authenticator: Arc<A>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `A: Clone` bound even though only the `Arc` is ever cloned.
impl<A> Clone for OptionalAuthLayer<A> {
    fn clone(&self) -> Self {
        Self {
            authenticator: Arc::clone(&self.authenticator),
        }
    }
}

impl<A> OptionalAuthLayer<A> {
    /// Creates a layer that authenticates bearer tokens with `authenticator`.
    pub fn new(authenticator: A) -> Self {
        Self {
            authenticator: Arc::new(authenticator),
        }
    }
}
197
198impl<S, A> tower::Layer<S> for OptionalAuthLayer<A>
199where
200 A: Clone,
201{
202 type Service = AuthService<S, A>;
203
204 fn layer(&self, inner: S) -> Self::Service {
205 AuthService {
206 inner,
207 authenticator: self.authenticator.clone(),
208 }
209 }
210}
211
212pub trait AuthExt {
214 fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>>;
216
217 fn claims<C: Clone + Send + Sync + 'static>(&self) -> Option<&C> {
219 self.auth_context::<C>().map(|ctx| &ctx.claims)
220 }
221}
222
223impl<B> AuthExt for hyper::Request<B> {
224 fn auth_context<C: Clone + Send + Sync + 'static>(&self) -> Option<&AuthContext<C>> {
225 self.extensions().get::<AuthContext<C>>()
226 }
227}
228
229#[derive(Debug)]
231pub struct AuthRejection {
232 pub error: AuthError,
234}
235
236impl AuthRejection {
237 pub fn new(error: AuthError) -> Self {
239 Self { error }
240 }
241}
242
243impl std::fmt::Display for AuthRejection {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 write!(f, "{}", self.error)
246 }
247}
248
249impl std::error::Error for AuthRejection {}
250
/// Type-level marker for [`Auth`]. With `Auth<C, Required>` the context is
/// expected to be present, so claims are returned directly.
#[derive(Clone, Copy, Debug)]
pub struct Required;

/// Type-level marker for [`Auth`]. With `Auth<C, Optional>` the context may be
/// absent, so claims are returned as an `Option`.
#[derive(Clone, Copy, Debug)]
pub struct Optional;
276
277#[derive(Debug, Clone)]
279pub struct Auth<C, R = Required> {
280 pub context: Option<AuthContext<C>>,
282 _requirement: PhantomData<R>,
283}
284
285impl<C: Clone> Auth<C, Required> {
286 pub fn claims(&self) -> &C {
288 &self.context.as_ref().unwrap().claims
289 }
290
291 pub fn token(&self) -> &str {
293 self.context.as_ref().unwrap().token()
294 }
295}
296
297impl<C> Auth<C, Optional> {
298 pub fn claims(&self) -> Option<&C> {
300 self.context.as_ref().map(|ctx| &ctx.claims)
301 }
302
303 pub fn is_authenticated(&self) -> bool {
305 self.context.is_some()
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal claims type shared by the tests below.
    #[derive(Clone, Debug, PartialEq)]
    struct Claims {
        sub: String,
    }

    fn sample_claims() -> Claims {
        Claims {
            sub: "user123".to_string(),
        }
    }

    #[test]
    fn test_authenticated_user() {
        let user = AuthenticatedUser(sample_claims());

        assert_eq!(user.claims().sub, "user123");
        // Field access through `Deref`.
        assert_eq!(user.sub, "user123");

        let claims = user.into_inner();
        assert_eq!(claims.sub, "user123");
    }

    #[test]
    fn test_auth_rejection() {
        let rejection = AuthRejection::new(AuthError::MissingToken);
        assert!(rejection.to_string().contains("missing"));
    }

    #[test]
    fn test_auth_ext_trait() {
        let mut req = hyper::Request::builder().body(()).unwrap();

        assert!(req.auth_context::<Claims>().is_none());
        assert!(req.claims::<Claims>().is_none());

        req.extensions_mut()
            .insert(AuthContext::new(sample_claims(), "token"));

        assert!(req.auth_context::<Claims>().is_some());
        assert_eq!(req.claims::<Claims>().unwrap().sub, "user123");
    }

    #[test]
    fn test_auth_required_accessors() {
        let auth: Auth<Claims> = Auth {
            context: Some(AuthContext::new(sample_claims(), "token")),
            _requirement: std::marker::PhantomData,
        };

        assert_eq!(auth.claims().sub, "user123");
        assert_eq!(auth.token(), "token");
    }

    #[test]
    #[should_panic]
    fn test_auth_required_without_context_panics() {
        let auth: Auth<Claims> = Auth {
            context: None,
            _requirement: std::marker::PhantomData,
        };
        let _ = auth.claims();
    }

    #[test]
    fn test_auth_optional_accessors() {
        let anonymous: Auth<Claims, Optional> = Auth {
            context: None,
            _requirement: std::marker::PhantomData,
        };
        assert!(!anonymous.is_authenticated());
        assert!(anonymous.claims().is_none());

        let signed_in: Auth<Claims, Optional> = Auth {
            context: Some(AuthContext::new(sample_claims(), "token")),
            _requirement: std::marker::PhantomData,
        };
        assert!(signed_in.is_authenticated());
        assert_eq!(signed_in.claims().unwrap().sub, "user123");
    }
}