Skip to main content

modo/tier/
middleware.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::response::IntoResponse;
8use http::Request;
9use http::request::Parts;
10use tower::{Layer, Service};
11
12use super::types::{TierInfo, TierResolver};
13
14type OwnerExtractor = Arc<dyn Fn(&Parts) -> Option<String> + Send + Sync>;
15
16/// Tower middleware layer that resolves [`TierInfo`] and inserts it into
17/// request extensions.
18///
19/// Apply with `.layer()` on the router. Guards ([`super::require_feature`],
20/// [`super::require_limit`]) are applied separately with `.route_layer()`.
21///
22/// # Owner ID extraction
23///
24/// The extractor closure reads from `&Parts` (populated by upstream middleware)
25/// and returns `Some(owner_id)` or `None`.
26///
27/// # Default tier
28///
29/// When the extractor returns `None` and a default is set via
30/// [`with_default`](Self::with_default), the default `TierInfo` is inserted.
31/// Otherwise, no `TierInfo` is inserted and the inner service is called
32/// directly — downstream guards handle the absence.
33pub struct TierLayer {
34    resolver: TierResolver,
35    extractor: OwnerExtractor,
36    default: Option<Arc<TierInfo>>,
37}
38
39impl TierLayer {
40    /// Create a new tier layer.
41    ///
42    /// `extractor` is a sync closure that returns the owner ID from request
43    /// parts, or `None` if no owner context is available.
44    pub fn new<F>(resolver: TierResolver, extractor: F) -> Self
45    where
46        F: Fn(&Parts) -> Option<String> + Send + Sync + 'static,
47    {
48        Self {
49            resolver,
50            extractor: Arc::new(extractor),
51            default: None,
52        }
53    }
54
55    /// When the extractor returns `None`, inject this `TierInfo` instead of
56    /// skipping.
57    pub fn with_default(mut self, default: TierInfo) -> Self {
58        self.default = Some(Arc::new(default));
59        self
60    }
61}
62
63impl Clone for TierLayer {
64    fn clone(&self) -> Self {
65        Self {
66            resolver: self.resolver.clone(),
67            extractor: self.extractor.clone(),
68            default: self.default.clone(),
69        }
70    }
71}
72
73impl<S> Layer<S> for TierLayer {
74    type Service = TierMiddleware<S>;
75
76    fn layer(&self, inner: S) -> Self::Service {
77        TierMiddleware {
78            inner,
79            resolver: self.resolver.clone(),
80            extractor: self.extractor.clone(),
81            default: self.default.clone(),
82        }
83    }
84}
85
86/// Tower service produced by [`TierLayer`].
87pub struct TierMiddleware<S> {
88    inner: S,
89    resolver: TierResolver,
90    extractor: OwnerExtractor,
91    default: Option<Arc<TierInfo>>,
92}
93
94impl<S: Clone> Clone for TierMiddleware<S> {
95    fn clone(&self) -> Self {
96        Self {
97            inner: self.inner.clone(),
98            resolver: self.resolver.clone(),
99            extractor: self.extractor.clone(),
100            default: self.default.clone(),
101        }
102    }
103}
104
105impl<S> Service<Request<Body>> for TierMiddleware<S>
106where
107    S: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
108    S::Future: Send + 'static,
109    S::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
110{
111    type Response = http::Response<Body>;
112    type Error = S::Error;
113    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
114
115    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116        self.inner.poll_ready(cx)
117    }
118
119    fn call(&mut self, request: Request<Body>) -> Self::Future {
120        let resolver = self.resolver.clone();
121        let extractor = self.extractor.clone();
122        let default = self.default.clone();
123        let mut inner = self.inner.clone();
124        std::mem::swap(&mut self.inner, &mut inner);
125
126        Box::pin(async move {
127            let (mut parts, body) = request.into_parts();
128
129            let tier_info = match (extractor)(&parts) {
130                Some(owner_id) => match resolver.resolve(&owner_id).await {
131                    Ok(info) => Some(info),
132                    Err(e) => return Ok(e.into_response()),
133                },
134                None => default.map(|arc| (*arc).clone()),
135            };
136
137            if let Some(info) = tier_info {
138                parts.extensions.insert(info);
139            }
140
141            let request = Request::from_parts(parts, body);
142            inner.call(request).await
143        })
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::convert::Infallible;

    use http::{Response, StatusCode};
    use tower::ServiceExt;

    use super::super::types::FeatureAccess;
    use super::super::types::test_support::{FailingTierBackend, StaticTierBackend};

    fn pro_tier() -> TierInfo {
        TierInfo {
            name: "pro".into(),
            features: HashMap::from([("sso".into(), FeatureAccess::Toggle(true))]),
        }
    }

    fn anon_tier() -> TierInfo {
        TierInfo {
            name: "anonymous".into(),
            features: HashMap::from([("public_api".into(), FeatureAccess::Toggle(true))]),
        }
    }

    fn resolver(tier: TierInfo) -> TierResolver {
        TierResolver::from_backend(Arc::new(StaticTierBackend::new(tier)))
    }

    fn failing_resolver() -> TierResolver {
        TierResolver::from_backend(Arc::new(FailingTierBackend))
    }

    /// Inner service that reports whether a `TierInfo` extension reached it.
    async fn ok_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let has_tier = req.extensions().get::<TierInfo>().is_some();
        let body = if has_tier { "tier-present" } else { "no-tier" };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn extractor_some_resolves_tier() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| Some("tenant_1".into()));
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "tier-present");
    }

    #[tokio::test]
    async fn extractor_none_no_default_skips() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| None);
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "no-tier");
    }

    #[tokio::test]
    async fn extractor_none_with_default_injects_default() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| None).with_default(anon_tier());
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "tier-present");
    }

    #[tokio::test]
    async fn default_tier_contents_are_injected() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| None).with_default(anon_tier());

        let inner = tower::service_fn(|req: Request<Body>| async move {
            let tier = req.extensions().get::<TierInfo>().unwrap();
            assert_eq!(tier.name, "anonymous");
            assert!(tier.has_feature("public_api"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn resolved_tier_takes_precedence_over_default() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| Some("tenant_1".into()))
            .with_default(anon_tier());

        let inner = tower::service_fn(|req: Request<Body>| async move {
            let tier = req.extensions().get::<TierInfo>().unwrap();
            // The default only applies when the extractor yields `None`.
            assert_eq!(tier.name, "pro");
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn backend_error_returns_error_response() {
        let layer = TierLayer::new(failing_resolver(), |_| Some("tenant_1".into()));
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn backend_error_is_not_masked_by_default() {
        let layer = TierLayer::new(failing_resolver(), |_| Some("tenant_1".into()))
            .with_default(anon_tier());
        let svc = layer.layer(tower::service_fn(ok_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn backend_error_does_not_call_inner() {
        use std::sync::atomic::{AtomicBool, Ordering};

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let layer = TierLayer::new(failing_resolver(), |_| Some("tenant_1".into()));
        let svc = layer.layer(tower::service_fn(move |_req: Request<Body>| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::SeqCst);
                Ok::<_, Infallible>(Response::new(Body::from("should not reach")))
            }
        }));

        let req = Request::builder().body(Body::empty()).unwrap();
        let _resp = svc.oneshot(req).await.unwrap();
        assert!(!called.load(Ordering::SeqCst));
    }

    #[tokio::test]
    async fn tier_info_accessible_in_inner_service() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| Some("t".into()));

        let inner = tower::service_fn(|req: Request<Body>| async move {
            let tier = req.extensions().get::<TierInfo>().unwrap();
            assert_eq!(tier.name, "pro");
            assert!(tier.has_feature("sso"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn extractor_reads_from_extensions() {
        #[derive(Clone)]
        struct OwnerId(String);

        let layer = TierLayer::new(resolver(pro_tier()), |parts| {
            parts.extensions.get::<OwnerId>().map(|id| id.0.clone())
        });
        let svc = layer.layer(tower::service_fn(ok_handler));

        let mut req = Request::builder().body(Body::empty()).unwrap();
        req.extensions_mut().insert(OwnerId("owner_42".into()));
        let resp = svc.oneshot(req).await.unwrap();

        let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
            .await
            .unwrap();
        assert_eq!(body, "tier-present");
    }

    /// Drives the same middleware through several ready/call cycles, which
    /// exercises the clone-and-replace of the inner service inside `call`.
    #[tokio::test]
    async fn middleware_handles_repeated_calls() {
        let layer = TierLayer::new(resolver(pro_tier()), |_| Some("tenant_1".into()));
        let mut svc = layer.layer(tower::service_fn(ok_handler));

        for _ in 0..2 {
            let req = Request::builder().body(Body::empty()).unwrap();
            let resp = svc.ready().await.unwrap().call(req).await.unwrap();
            let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
                .await
                .unwrap();
            assert_eq!(body, "tier-present");
        }
    }
}