1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::response::IntoResponse;
8use http::Request;
9use http::request::Parts;
10use tower::{Layer, Service};
11
12use super::types::{TierInfo, TierResolver};
13
/// Shared, thread-safe callback that derives an owner id from the request head.
///
/// Returning `None` means no owner id could be determined for the request; the
/// middleware then injects the configured default tier, if any.
type OwnerExtractor = Arc<dyn Fn(&Parts) -> Option<String> + Send + Sync>;
15
/// Tower [`Layer`] that resolves a [`TierInfo`] for each request and stores it in
/// the request extensions.
///
/// Wrapping a service with this layer produces a [`TierMiddleware`].
pub struct TierLayer {
    /// Looks up the tier for an owner id returned by `extractor`.
    resolver: TierResolver,
    /// Derives the owner id from the request head.
    extractor: OwnerExtractor,
    /// Tier injected when `extractor` yields no owner id; `None` injects nothing.
    default: Option<Arc<TierInfo>>,
}
38
39impl TierLayer {
40 pub fn new<F>(resolver: TierResolver, extractor: F) -> Self
45 where
46 F: Fn(&Parts) -> Option<String> + Send + Sync + 'static,
47 {
48 Self {
49 resolver,
50 extractor: Arc::new(extractor),
51 default: None,
52 }
53 }
54
55 pub fn with_default(mut self, default: TierInfo) -> Self {
58 self.default = Some(Arc::new(default));
59 self
60 }
61}
62
63impl Clone for TierLayer {
64 fn clone(&self) -> Self {
65 Self {
66 resolver: self.resolver.clone(),
67 extractor: self.extractor.clone(),
68 default: self.default.clone(),
69 }
70 }
71}
72
73impl<S> Layer<S> for TierLayer {
74 type Service = TierMiddleware<S>;
75
76 fn layer(&self, inner: S) -> Self::Service {
77 TierMiddleware {
78 inner,
79 resolver: self.resolver.clone(),
80 extractor: self.extractor.clone(),
81 default: self.default.clone(),
82 }
83 }
84}
85
/// Service produced by [`TierLayer`]. It resolves the request's tier, stores it
/// as a [`TierInfo`] extension, and then delegates to `inner`.
pub struct TierMiddleware<S> {
    /// Wrapped service that handles the request after tier resolution.
    inner: S,
    /// Looks up the tier for the owner id produced by `extractor`.
    resolver: TierResolver,
    /// Derives the owner id from the request head.
    extractor: OwnerExtractor,
    /// Tier injected when `extractor` yields no owner id; `None` injects nothing.
    default: Option<Arc<TierInfo>>,
}
93
94impl<S: Clone> Clone for TierMiddleware<S> {
95 fn clone(&self) -> Self {
96 Self {
97 inner: self.inner.clone(),
98 resolver: self.resolver.clone(),
99 extractor: self.extractor.clone(),
100 default: self.default.clone(),
101 }
102 }
103}
104
105impl<S> Service<Request<Body>> for TierMiddleware<S>
106where
107 S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
108 S::Future: Send + 'static,
109 S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
110{
111 type Response = http::Response<Body>;
112 type Error = S::Error;
113 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
114
115 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116 self.inner.poll_ready(cx)
117 }
118
119 fn call(&mut self, request: Request<Body>) -> Self::Future {
120 let resolver = self.resolver.clone();
121 let extractor = self.extractor.clone();
122 let default = self.default.clone();
123 let mut inner = self.inner.clone();
124 std::mem::swap(&mut self.inner, &mut inner);
125
126 Box::pin(async move {
127 let (mut parts, body) = request.into_parts();
128
129 let tier_info = match (extractor)(&parts) {
130 Some(owner_id) => match resolver.resolve(&owner_id).await {
131 Ok(info) => Some(info),
132 Err(e) => return Ok(e.into_response()),
133 },
134 None => default.map(|arc| (*arc).clone()),
135 };
136
137 if let Some(info) = tier_info {
138 parts.extensions.insert(info);
139 }
140
141 let request = Request::from_parts(parts, body);
142 inner.call(request).await
143 })
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::convert::Infallible;

    use http::{Response, StatusCode};
    use tower::ServiceExt;

    use super::super::types::FeatureAccess;
    use super::super::types::test_support::{FailingTierBackend, StaticTierBackend};

    fn pro_tier() -> TierInfo {
        TierInfo {
            name: "pro".into(),
            features: HashMap::from([("sso".into(), FeatureAccess::Toggle(true))]),
        }
    }

    fn anon_tier() -> TierInfo {
        TierInfo {
            name: "anonymous".into(),
            features: HashMap::from([("public_api".into(), FeatureAccess::Toggle(true))]),
        }
    }

    fn resolver(tier: TierInfo) -> TierResolver {
        TierResolver::from_backend(Arc::new(StaticTierBackend::new(tier)))
    }

    fn failing_resolver() -> TierResolver {
        TierResolver::from_backend(Arc::new(FailingTierBackend))
    }

    fn empty_request() -> Request<Body> {
        Request::builder().body(Body::empty()).unwrap()
    }

    /// Reports whether a `TierInfo` extension reached the inner service.
    async fn ok_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let has_tier = req.extensions().get::<TierInfo>().is_some();
        let body = if has_tier { "tier-present" } else { "no-tier" };
        Ok(Response::new(Body::from(body)))
    }

    /// Echoes the injected tier's name, or `no-tier` when none was injected.
    async fn tier_name_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let name = req
            .extensions()
            .get::<TierInfo>()
            .map_or_else(|| "no-tier".to_owned(), |tier| tier.name.to_string());
        Ok(Response::new(Body::from(name)))
    }

    /// Collects a response body into a `String`.
    async fn body_text(resp: Response<Body>) -> String {
        let bytes = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn extractor_some_resolves_tier() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| Some("tenant_1".into()));
        let svc = layer.layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(body_text(resp).await, "tier-present");
    }

    #[tokio::test]
    async fn extractor_none_no_default_skips() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| None);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(empty_request()).await.unwrap();
        assert_eq!(body_text(resp).await, "no-tier");
    }

    #[tokio::test]
    async fn extractor_none_with_default_injects_default() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| None).with_default(anon_tier());
        let svc = layer.layer(tower::service_fn(tier_name_handler));

        let resp = svc.oneshot(empty_request()).await.unwrap();
        // The injected tier must be the default, not merely "some" tier.
        assert_eq!(body_text(resp).await, "anonymous");
    }

    #[tokio::test]
    async fn resolved_tier_takes_precedence_over_default() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| Some("tenant_1".into()))
            .with_default(anon_tier());
        let svc = layer.layer(tower::service_fn(tier_name_handler));

        let resp = svc.oneshot(empty_request()).await.unwrap();
        assert_eq!(body_text(resp).await, "pro");
    }

    #[tokio::test]
    async fn backend_error_returns_error_response() {
        let layer = TierLayer::new(failing_resolver(), |_| Some("tenant_1".into()));
        let svc = layer.layer(tower::service_fn(ok_handler));

        let resp = svc.oneshot(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn backend_error_does_not_fall_back_to_default() {
        let layer = TierLayer::new(failing_resolver(), |_| Some("tenant_1".into()))
            .with_default(anon_tier());
        let svc = layer.layer(tower::service_fn(tier_name_handler));

        let resp = svc.oneshot(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn backend_error_does_not_call_inner() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let layer = TierLayer::new(failing_resolver(), |_| Some("tenant_1".into()));
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));

        let _resp = svc.oneshot(empty_request()).await.unwrap();
        assert!(!called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn tier_info_accessible_in_inner_service() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| Some("t".into()));

        let inner = tower::service_fn(|req: Request<Body>| async move {
            let tier = req.extensions().get::<TierInfo>().unwrap();
            assert_eq!(tier.name, "pro");
            assert!(tier.has_feature("sso"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let resp = svc.oneshot(empty_request()).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn extractor_reads_from_extensions() {
        #[derive(Clone)]
        struct OwnerId(String);

        let layer = TierLayer::new(resolver(pro_tier()), |parts| {
            parts.extensions.get::<OwnerId>().map(|id| id.0.clone())
        });
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = empty_request();
        req.extensions_mut().insert(OwnerId("owner_42".into()));
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(body_text(resp).await, "tier-present");
    }

    /// Exercises the clone-and-swap in `call`: the same middleware instance must
    /// keep working across several `ready` + `call` cycles.
    #[tokio::test]
    async fn service_handles_repeated_calls() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| Some("tenant_1".into()));
        let mut svc = layer.layer(tower::service_fn(ok_handler));

        for _ in 0..3 {
            let resp = svc
                .ready()
                .await
                .unwrap()
                .call(empty_request())
                .await
                .unwrap();
            assert_eq!(body_text(resp).await, "tier-present");
        }
    }
}