1use std::ops::Deref;
2use std::sync::Arc;
3
4use axum::extract::{FromRequestParts, OptionalFromRequestParts};
5use http::request::Parts;
6
7use crate::Error;
8
9use super::traits::HasTenantId;
10
/// Request extractor yielding the tenant resolved for the current request.
///
/// The tenant is held behind an [`Arc`], so cloning or extracting a
/// `Tenant<T>` only bumps a reference count; `T` itself is never cloned.
///
/// Extraction looks up an `Arc<T>` in the request extensions. The rejection
/// message ("Tenant middleware not applied") suggests that value is expected
/// to be inserted by a tenant middleware layer. When it is missing:
/// - `Tenant<T>` rejects with an internal error;
/// - `Option<Tenant<T>>` yields `None`.
pub struct Tenant<T>(pub(crate) Arc<T>);
20
21impl<T> Tenant<T> {
22 pub fn get(&self) -> &T {
24 &self.0
25 }
26}
27
28impl<T> Deref for Tenant<T> {
29 type Target = T;
30 fn deref(&self) -> &T {
31 &self.0
32 }
33}
34
35impl<T> Clone for Tenant<T> {
36 fn clone(&self) -> Self {
37 Self(self.0.clone())
38 }
39}
40
41impl<T: std::fmt::Debug> std::fmt::Debug for Tenant<T> {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_tuple("Tenant").field(&self.0).finish()
44 }
45}
46
47impl<T, S> FromRequestParts<S> for Tenant<T>
48where
49 T: HasTenantId + Send + Sync + Clone + 'static,
50 S: Send + Sync,
51{
52 type Rejection = Error;
53
54 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
55 parts
56 .extensions
57 .get::<Arc<T>>()
58 .cloned()
59 .map(Tenant)
60 .ok_or_else(|| Error::internal("Tenant middleware not applied"))
61 }
62}
63
64impl<T, S> OptionalFromRequestParts<S> for Tenant<T>
65where
66 T: HasTenantId + Send + Sync + Clone + 'static,
67 S: Send + Sync,
68{
69 type Rejection = Error;
70
71 async fn from_request_parts(
72 parts: &mut Parts,
73 _state: &S,
74 ) -> Result<Option<Self>, Self::Rejection> {
75 Ok(parts.extensions.get::<Arc<T>>().cloned().map(Tenant))
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[derive(Clone, Debug)]
    struct TestTenant {
        id: String,
        name: String,
    }

    impl HasTenantId for TestTenant {
        fn tenant_id(&self) -> &str {
            &self.id
        }
    }

    /// Fixture tenant shared by every test.
    fn sample() -> TestTenant {
        TestTenant {
            id: "t1".into(),
            name: "Test".into(),
        }
    }

    /// Request parts with nothing in the extensions.
    fn empty_parts() -> Parts {
        http::Request::builder().body(()).unwrap().into_parts().0
    }

    /// Request parts carrying `tenant` as an `Arc<T>`, the type the extractors look up.
    fn parts_with(tenant: TestTenant) -> Parts {
        let mut parts = empty_parts();
        parts.extensions.insert(Arc::new(tenant));
        parts
    }

    #[test]
    fn tenant_get() {
        let t = Tenant(Arc::new(sample()));
        assert_eq!(t.get().id, "t1");
        assert_eq!(t.get().name, "Test");
    }

    #[test]
    fn tenant_deref() {
        let t = Tenant(Arc::new(sample()));
        assert_eq!(t.id, "t1");
        assert_eq!(t.name, "Test");
        // Trait methods of `T` are reachable through auto-deref as well.
        assert_eq!(t.tenant_id(), "t1");
    }

    #[test]
    fn clone_shares_the_same_allocation() {
        let original = Tenant(Arc::new(sample()));
        let copy = original.clone();
        assert!(Arc::ptr_eq(&original.0, &copy.0));
        assert_eq!(Arc::strong_count(&original.0), 2);
    }

    #[test]
    fn debug_wraps_inner_value() {
        let t = Tenant(Arc::new(sample()));
        assert_eq!(
            format!("{t:?}"),
            r#"Tenant(TestTenant { id: "t1", name: "Test" })"#
        );
    }

    #[tokio::test]
    async fn extract_from_extensions() {
        let mut parts = parts_with(sample());

        let tenant =
            <Tenant<TestTenant> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("tenant present in extensions should extract");
        assert_eq!(tenant.get().id, "t1");
    }

    #[tokio::test]
    async fn extract_reuses_the_stored_arc() {
        let stored = Arc::new(sample());
        let mut parts = empty_parts();
        parts.extensions.insert(Arc::clone(&stored));

        let tenant =
            <Tenant<TestTenant> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect("tenant present in extensions should extract");
        assert!(Arc::ptr_eq(&stored, &tenant.0));
    }

    #[tokio::test]
    async fn extract_missing_returns_500() {
        let mut parts = empty_parts();

        let err =
            <Tenant<TestTenant> as FromRequestParts<()>>::from_request_parts(&mut parts, &())
                .await
                .expect_err("missing tenant must be rejected");
        assert_eq!(err.status(), http::StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn option_tenant_none_when_missing() {
        let mut parts = empty_parts();

        let extracted = <Tenant<TestTenant> as OptionalFromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await
        .expect("optional extraction never rejects");
        assert!(extracted.is_none());
    }

    #[tokio::test]
    async fn option_tenant_some_when_present() {
        let mut parts = parts_with(sample());

        let tenant = <Tenant<TestTenant> as OptionalFromRequestParts<()>>::from_request_parts(
            &mut parts,
            &(),
        )
        .await
        .expect("optional extraction never rejects")
        .expect("tenant present in extensions should be Some");
        assert_eq!(tenant.get().id, "t1");
    }
}