Skip to main content

modo/tier/
extractor.rs

1use axum::extract::{FromRequestParts, OptionalFromRequestParts};
2use http::request::Parts;
3
4use crate::error::Error;
5
6pub use super::types::TierInfo;
7
8/// Extracts [`TierInfo`] from request extensions.
9///
10/// # Errors
11///
12/// Returns [`Error::internal`](crate::Error::internal) if [`TierLayer`](super::TierLayer)
13/// has not been applied (i.e., `TierInfo` is missing from extensions).
14impl<S: Send + Sync> FromRequestParts<S> for TierInfo {
15    type Rejection = Error;
16
17    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
18        parts
19            .extensions
20            .get::<TierInfo>()
21            .cloned()
22            .ok_or_else(|| Error::internal("Tier middleware not applied"))
23    }
24}
25
26/// Optionally extracts [`TierInfo`] from request extensions.
27///
28/// Returns `Ok(None)` when [`TierLayer`](super::TierLayer) has not been applied
29/// or the owner extractor returned `None` without a default tier configured.
30impl<S: Send + Sync> OptionalFromRequestParts<S> for TierInfo {
31    type Rejection = Error;
32
33    async fn from_request_parts(
34        parts: &mut Parts,
35        _state: &S,
36    ) -> Result<Option<Self>, Self::Rejection> {
37        Ok(parts.extensions.get::<TierInfo>().cloned())
38    }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use super::super::types::FeatureAccess;

    /// Builds a representative "pro" tier with a single enabled toggle feature.
    fn test_tier() -> TierInfo {
        TierInfo {
            name: "pro".into(),
            features: HashMap::from([("sso".into(), FeatureAccess::Toggle(true))]),
        }
    }

    /// Request parts with no extensions, as if `TierLayer` had not run.
    fn empty_parts() -> Parts {
        http::Request::builder().body(()).unwrap().into_parts().0
    }

    /// Request parts carrying `test_tier()` in their extensions.
    fn parts_with_tier() -> Parts {
        let mut parts = empty_parts();
        parts.extensions.insert(test_tier());
        parts
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let mut parts = parts_with_tier();

        let tier = <TierInfo as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("TierInfo present in extensions should extract");
        assert_eq!(tier.name, "pro");
        assert_eq!(tier.features.len(), 1);
    }

    #[tokio::test]
    async fn extract_leaves_extension_in_place() {
        // The extractor clones rather than removes, so a second extractor on the
        // same request must still find the tier.
        let mut parts = parts_with_tier();

        let first = <TierInfo as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("first extraction should succeed");
        let second = <TierInfo as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("second extraction should also succeed");
        assert_eq!(first.name, "pro");
        assert_eq!(second.name, "pro");
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let err = <TierInfo as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect_err("missing TierInfo must be rejected");
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn optional_none_when_missing() {
        let mut parts = empty_parts();

        let tier =
            <TierInfo as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("optional extraction never rejects");
        assert!(tier.is_none());
    }

    #[tokio::test]
    async fn optional_some_when_present() {
        let mut parts = parts_with_tier();

        let tier =
            <TierInfo as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("optional extraction never rejects")
                .expect("TierInfo present in extensions should be returned");
        assert_eq!(tier.name, "pro");
        // Like the required extractor, the optional one must not consume the extension.
        assert!(parts.extensions.get::<TierInfo>().is_some());
    }
}