Skip to main content

cloudillo_types/
extract.rs

1// SPDX-FileCopyrightText: Szilárd Hajba
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4//! Custom Axum extractors for Cloudillo-specific types.
5//!
6//! Provides `FromRequestParts` implementations for `TnId` and `IdTag`
7//! that work with any state implementing the required traits.
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use axum::extract::FromRequestParts;
13use axum::http::request::Parts;
14
15use crate::error::Error;
16use crate::types::TnId;
17
18// IdTag //
19//*******//
/// Identity tag of the authenticated caller.
///
/// The auth middleware stores it in the request extensions, where the
/// `IdTag` extractor picks it up.
#[derive(Clone, Debug)]
pub struct IdTag(pub Box<str>);

impl IdTag {
	/// Creates an `IdTag` that owns a copy of `id_tag`.
	pub fn new(id_tag: &str) -> Self {
		Self(id_tag.into())
	}
}
29
30impl<S> FromRequestParts<S> for IdTag
31where
32	S: Send + Sync,
33{
34	type Rejection = Error;
35
36	async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
37		if let Some(id_tag) = parts.extensions.get::<IdTag>().cloned() {
38			Ok(id_tag)
39		} else {
40			Err(Error::PermissionDenied)
41		}
42	}
43}
44
45// TnId //
46//******//
47/// Trait for resolving `TnId` from an identity tag string.
48///
49/// Implement this on your application state type to enable the
50/// `TnId` Axum extractor.
51#[async_trait]
52pub trait TnIdResolver: Send + Sync {
53	async fn resolve_tn_id(&self, id_tag: &str) -> Result<TnId, Error>;
54}
55
56/// Blanket impl for `Arc<T>` so that `App = Arc<AppState>` works
57/// when `AppState` implements `TnIdResolver`.
58#[async_trait]
59impl<T: TnIdResolver> TnIdResolver for Arc<T> {
60	async fn resolve_tn_id(&self, id_tag: &str) -> Result<TnId, Error> {
61		(**self).resolve_tn_id(id_tag).await
62	}
63}
64
65impl<S> FromRequestParts<S> for TnId
66where
67	S: TnIdResolver + Send + Sync,
68{
69	type Rejection = Error;
70
71	async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
72		if let Some(id_tag) = parts.extensions.get::<IdTag>().cloned() {
73			state.resolve_tn_id(&id_tag.0).await.map_err(|_| Error::PermissionDenied)
74		} else {
75			Err(Error::PermissionDenied)
76		}
77	}
78}
79
80// vim: ts=4