Skip to main content

cloudillo_types/
extract.rs

1// SPDX-FileCopyrightText: Szilárd Hajba
2// SPDX-License-Identifier: LGPL-3.0-or-later
3
4//! Custom Axum extractors for Cloudillo-specific types.
5//!
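//! Provides `FromRequestParts` implementations for `TnId` and `IdTag` (plus an
//! `OptionalFromRequestParts` implementation for `IdTag`), together with the
//! `TnIdResolver` trait that application state implements to enable the `TnId`
//! extractor.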
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use axum::extract::{FromRequestParts, OptionalFromRequestParts};
13use axum::http::request::Parts;
14
15use crate::error::Error;
16use crate::types::TnId;
17
// IdTag //
//*******//
/// Identity tag of the caller, read from the request extensions.
///
/// The auth middleware is expected to insert this value into the request
/// extensions; the extractors in this module only read it back.
#[derive(Clone, Debug)]
pub struct IdTag(pub Box<str>);

impl IdTag {
	/// Creates an `IdTag` holding an owned, boxed copy of `id_tag`.
	pub fn new(id_tag: &str) -> IdTag {
		let owned: Box<str> = id_tag.into();
		Self(owned)
	}
}
29
30impl<S> FromRequestParts<S> for IdTag
31where
32	S: Send + Sync,
33{
34	type Rejection = Error;
35
36	async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
37		if let Some(id_tag) = parts.extensions.get::<IdTag>().cloned() {
38			Ok(id_tag)
39		} else {
40			Err(Error::PermissionDenied)
41		}
42	}
43}
44
45impl<S> OptionalFromRequestParts<S> for IdTag
46where
47	S: Send + Sync,
48{
49	type Rejection = Error;
50
51	async fn from_request_parts(
52		parts: &mut Parts,
53		_state: &S,
54	) -> Result<Option<Self>, Self::Rejection> {
55		Ok(parts.extensions.get::<IdTag>().cloned())
56	}
57}
58
59// TnId //
60//******//
61/// Trait for resolving `TnId` from an identity tag string.
62///
63/// Implement this on your application state type to enable the
64/// `TnId` Axum extractor.
65#[async_trait]
66pub trait TnIdResolver: Send + Sync {
67	async fn resolve_tn_id(&self, id_tag: &str) -> Result<TnId, Error>;
68}
69
70/// Blanket impl for `Arc<T>` so that `App = Arc<AppState>` works
71/// when `AppState` implements `TnIdResolver`.
72#[async_trait]
73impl<T: TnIdResolver> TnIdResolver for Arc<T> {
74	async fn resolve_tn_id(&self, id_tag: &str) -> Result<TnId, Error> {
75		(**self).resolve_tn_id(id_tag).await
76	}
77}
78
79impl<S> FromRequestParts<S> for TnId
80where
81	S: TnIdResolver + Send + Sync,
82{
83	type Rejection = Error;
84
85	async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
86		if let Some(id_tag) = parts.extensions.get::<IdTag>().cloned() {
87			state.resolve_tn_id(&id_tag.0).await.map_err(|_| Error::PermissionDenied)
88		} else {
89			Err(Error::PermissionDenied)
90		}
91	}
92}
93
94// vim: ts=4