Skip to main content

modo/tenant/
extractor.rs

1use std::ops::Deref;
2use std::sync::Arc;
3
4use axum::extract::{FromRequestParts, OptionalFromRequestParts};
5use http::request::Parts;
6
7use crate::Error;
8
9use super::traits::HasTenantId;
10
/// Axum extractor yielding the tenant resolved for the current request.
///
/// The tenant is read from the request extensions, where
/// [`TenantMiddleware`](super::TenantMiddleware) stores it. If that middleware is
/// missing, extraction fails with an HTTP 500 error: a handler that asks for a tenant
/// on a route without the middleware is a developer misconfiguration, not a client error.
///
/// Handlers that must also run without a tenant should extract `Option<Tenant<T>>`
/// instead; its [`OptionalFromRequestParts`] impl yields `Ok(None)` rather than failing.
///
/// The tenant lives behind an [`Arc`], so cloning a `Tenant<T>` is cheap and does not
/// require `T: Clone`.
pub struct Tenant<T>(pub(crate) Arc<T>);

impl<T> Tenant<T> {
    /// Borrows the resolved tenant.
    pub fn get(&self) -> &T {
        &*self.0
    }
}

impl<T> Deref for Tenant<T> {
    type Target = T;

    /// Forwards to [`Tenant::get`], so tenant fields and methods are reachable directly.
    fn deref(&self) -> &Self::Target {
        self.get()
    }
}

impl<T> Clone for Tenant<T> {
    /// Clones the shared handle only (a reference-count bump); `T` itself is not cloned.
    fn clone(&self) -> Self {
        Tenant(Arc::clone(&self.0))
    }
}

impl<T> std::fmt::Debug for Tenant<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Format the tenant value directly; `Arc`'s Debug would render it identically.
        f.debug_tuple("Tenant").field(self.get()).finish()
    }
}
46
47impl<T, S> FromRequestParts<S> for Tenant<T>
48where
49    T: HasTenantId + Send + Sync + Clone + 'static,
50    S: Send + Sync,
51{
52    type Rejection = Error;
53
54    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
55        parts
56            .extensions
57            .get::<Arc<T>>()
58            .cloned()
59            .map(Tenant)
60            .ok_or_else(|| Error::internal("Tenant middleware not applied"))
61    }
62}
63
64impl<T, S> OptionalFromRequestParts<S> for Tenant<T>
65where
66    T: HasTenantId + Send + Sync + Clone + 'static,
67    S: Send + Sync,
68{
69    type Rejection = Error;
70
71    async fn from_request_parts(
72        parts: &mut Parts,
73        _state: &S,
74    ) -> Result<Option<Self>, Self::Rejection> {
75        Ok(parts.extensions.get::<Arc<T>>().cloned().map(Tenant))
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    struct TestTenant {
        id: String,
        name: String,
    }

    impl HasTenantId for TestTenant {
        fn tenant_id(&self) -> &str {
            &self.id
        }
    }

    /// Builds a fixture tenant with the given id and the name `"Test"`.
    fn test_tenant(id: &str) -> TestTenant {
        TestTenant {
            id: id.into(),
            name: "Test".into(),
        }
    }

    /// Returns the parts of an empty request (no extensions inserted).
    fn empty_parts() -> Parts {
        http::Request::builder().body(()).unwrap().into_parts().0
    }

    #[test]
    fn tenant_get() {
        let t = Tenant(Arc::new(test_tenant("t1")));
        assert_eq!(t.get().id, "t1");
        assert_eq!(t.get().name, "Test");
    }

    #[test]
    fn tenant_deref() {
        let t = Tenant(Arc::new(test_tenant("t1")));
        // Deref gives direct field access
        assert_eq!(t.name, "Test");
    }

    #[test]
    fn tenant_clone_shares_allocation() {
        let t = Tenant(Arc::new(test_tenant("t1")));
        let copy = t.clone();
        // Cloning the extractor must only bump the refcount, not copy the tenant.
        assert!(Arc::ptr_eq(&t.0, &copy.0));
    }

    #[test]
    fn tenant_debug_shows_inner_value() {
        let t = Tenant(Arc::new(test_tenant("t1")));
        assert_eq!(
            format!("{t:?}"),
            r#"Tenant(TestTenant { id: "t1", name: "Test" })"#
        );
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let mut parts = empty_parts();
        parts.extensions.insert(Arc::new(test_tenant("t1")));

        let tenant =
            <Tenant<TestTenant> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("tenant should be extracted when present");
        assert_eq!(tenant.get().id, "t1");
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let err =
            <Tenant<TestTenant> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect_err("extraction must fail without the middleware");
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn extract_requires_arc_wrapped_tenant() {
        // The lookup is keyed on `Arc<T>`; a bare `T` in the extensions is not found.
        let mut parts = empty_parts();
        parts.extensions.insert(test_tenant("t1"));

        let result =
            <Tenant<TestTenant> as FromRequestParts<()>>::from_request_parts(&mut parts, &()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn option_tenant_none_when_missing() {
        let mut parts = empty_parts();

        let result = <Tenant<TestTenant> as OptionalFromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await
        .expect("optional extraction never rejects");
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn option_tenant_some_when_present() {
        let mut parts = empty_parts();
        parts.extensions.insert(Arc::new(test_tenant("t1")));

        let tenant = <Tenant<TestTenant> as OptionalFromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await
        .expect("optional extraction never rejects")
        .expect("tenant should be present");
        // Verify it is the tenant we inserted, not just that something came back.
        assert_eq!(tenant.get().id, "t1");
    }
}