1use axum::extract::{FromRequestParts, OptionalFromRequestParts};
2use http::request::Parts;
3
4use crate::error::Error;
5
6pub use super::types::TierInfo;
7
8impl<S: Send + Sync> FromRequestParts<S> for TierInfo {
15 type Rejection = Error;
16
17 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
18 parts
19 .extensions
20 .get::<TierInfo>()
21 .cloned()
22 .ok_or_else(|| Error::internal("Tier middleware not applied"))
23 }
24}
25
26impl<S: Send + Sync> OptionalFromRequestParts<S> for TierInfo {
31 type Rejection = Error;
32
33 async fn from_request_parts(
34 parts: &mut Parts,
35 _state: &S,
36 ) -> Result<Option<Self>, Self::Rejection> {
37 Ok(parts.extensions.get::<TierInfo>().cloned())
38 }
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    use super::super::types::FeatureAccess;

    /// A representative tier ("pro" with the `sso` toggle enabled) used as the
    /// extension value in the tests below.
    fn test_tier() -> TierInfo {
        TierInfo {
            name: "pro".into(),
            features: HashMap::from([("sso".into(), FeatureAccess::Toggle(true))]),
        }
    }

    /// Builds request `Parts` from a bare request with no extensions inserted.
    fn empty_parts() -> Parts {
        let (parts, ()) = http::Request::builder()
            .body(())
            .expect("a default request builder is always valid")
            .into_parts();
        parts
    }

    /// The mandatory extractor returns the `TierInfo` stored in the extensions.
    #[tokio::test]
    async fn extract_from_extensions() {
        let mut parts = empty_parts();
        parts.extensions.insert(test_tier());

        let tier = <TierInfo as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect("TierInfo should be extracted when present in extensions");
        assert_eq!(tier.name, "pro");
    }

    /// The mandatory extractor rejects with a 500 when no `TierInfo` was inserted.
    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let err = <TierInfo as FromRequestParts<()>>::from_request_parts(&mut parts, &())
            .await
            .expect_err("extraction must fail when TierInfo is missing");
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    /// The optional extractor yields `None`, not an error, when `TierInfo` is missing.
    #[tokio::test]
    async fn optional_none_when_missing() {
        let mut parts = empty_parts();

        let maybe_tier =
            <TierInfo as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("optional extraction never rejects");
        assert!(maybe_tier.is_none());
    }

    /// The optional extractor yields `Some` with the stored `TierInfo` when present.
    #[tokio::test]
    async fn optional_some_when_present() {
        let mut parts = empty_parts();
        parts.extensions.insert(test_tier());

        let tier =
            <TierInfo as OptionalFromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("optional extraction never rejects")
                .expect("TierInfo should be present in extensions");
        assert_eq!(tier.name, "pro");
    }
}