Skip to main content

multistore_metering/
lib.rs

//! Usage metering and quota enforcement middleware.
//!
//! This crate provides trait abstractions for tracking API usage and enforcing
//! quotas, along with a [`MeteringMiddleware`] that wires them into the proxy's
//! middleware chain. Integrators bring their own storage backends by implementing
//! [`UsageRecorder`] and [`QuotaChecker`].
//!
//! ## Quick start
//!
//! ```rust,ignore
//! use multistore_metering::{MeteringMiddleware, UsageRecorder, QuotaChecker};
//!
//! // Implement UsageRecorder and QuotaChecker for your storage backend,
//! // then register the middleware on the ProxyGateway builder:
//! let metering = MeteringMiddleware::new(my_quota_checker, my_usage_recorder);
//! gateway_builder.add_middleware(metering);
//! ```
//!
//! ## Architecture
//!
//! - **Pre-dispatch:** [`QuotaChecker::check_quota`] runs before the request
//!   proceeds, using `Content-Length` as a byte estimate. Return
//!   [`Err(QuotaExceeded)`](QuotaExceeded) to reject with HTTP 429.
//! - **Post-dispatch:** [`UsageRecorder::record_operation`] runs after the
//!   response is available, recording actual status and byte counts from
//!   the backend response.
27
28use std::future::Future;
29use std::net::IpAddr;
30
31use multistore::api::response::ErrorResponse;
32use multistore::error::ProxyError;
33use multistore::maybe_send::{MaybeSend, MaybeSync};
34use multistore::middleware::{CompletedRequest, DispatchContext, Middleware, Next};
35use multistore::route_handler::{HandlerAction, ProxyResponseBody, ProxyResult};
36use multistore::types::{ResolvedIdentity, S3Operation};
37
38use bytes::Bytes;
39use http::HeaderMap;
40
41/// A completed operation's metadata, passed to [`UsageRecorder::record_operation`].
42pub struct UsageEvent<'a> {
43    /// The unique request identifier.
44    pub request_id: &'a str,
45    /// The resolved caller identity, if any.
46    pub identity: Option<&'a ResolvedIdentity>,
47    /// The parsed S3 operation, if determined.
48    pub operation: Option<&'a S3Operation>,
49    /// The target bucket name, if applicable.
50    pub bucket: Option<&'a str>,
51    /// The HTTP status code of the response.
52    pub status: u16,
53    /// Best-available byte count: `content_length` from backend response
54    /// for forwarded requests, response body length for direct responses,
55    /// or `Content-Length` header estimate as fallback.
56    pub bytes_transferred: u64,
57    /// Whether the request was forwarded to a backend via presigned URL.
58    pub was_forwarded: bool,
59    /// The client's IP address, if known.
60    pub source_ip: Option<IpAddr>,
61}
62
/// Quota violation error returned by [`QuotaChecker::check_quota`].
///
/// [`MeteringMiddleware`] turns this into an HTTP 429 response whose body is
/// a generic `SlowDown` XML error built from the request id alone; `message`
/// is not copied into that body, so treat it as diagnostic text.
///
/// Implements [`std::fmt::Display`] (prints `message` verbatim) and
/// [`std::error::Error`], so it composes with `?`, `Box<dyn Error>` and
/// error-reporting crates.
#[derive(Debug)]
pub struct QuotaExceeded {
    /// Human-readable explanation of the quota violation.
    pub message: String,
}

impl std::fmt::Display for QuotaExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

impl std::error::Error for QuotaExceeded {}
71
72/// Records completed operations for usage tracking.
73///
74/// Integrators implement this trait with their storage backend (Redis,
75/// DynamoDB, in-memory, etc.). The recorder is called after every
76/// dispatched request, including failed ones.
77pub trait UsageRecorder: MaybeSend + MaybeSync + 'static {
78    /// Record a completed operation.
79    ///
80    /// This runs in the post-dispatch phase. Implementations should be
81    /// fire-and-forget — recording failures must not affect the response.
82    fn record_operation<'a>(
83        &'a self,
84        event: UsageEvent<'a>,
85    ) -> impl Future<Output = ()> + MaybeSend + 'a;
86}
87
88/// Pre-dispatch quota enforcement.
89///
90/// Integrators implement this trait to enforce usage limits before a
91/// request proceeds. The `estimated_bytes` value comes from the request's
92/// `Content-Length` header (for uploads) or is 0 when unknown.
93pub trait QuotaChecker: MaybeSend + MaybeSync + 'static {
94    /// Check whether the caller is within their quota.
95    ///
96    /// Return `Ok(())` to allow the request, or `Err(QuotaExceeded)` to
97    /// reject it with HTTP 429.
98    fn check_quota<'a>(
99        &'a self,
100        identity: &'a ResolvedIdentity,
101        operation: &'a S3Operation,
102        bucket: Option<&'a str>,
103        estimated_bytes: u64,
104        source_ip: Option<IpAddr>,
105    ) -> impl Future<Output = Result<(), QuotaExceeded>> + MaybeSend + 'a;
106}
107
/// Quota enforcement before dispatch, usage recording after it.
///
/// The type is generic over the quota checker `Q` and the usage recorder
/// `U`, so integrators can plug in their own storage backends.
///
/// ## Request flow
///
/// 1. Read `Content-Length` from the request headers as a byte estimate.
/// 2. Call [`QuotaChecker::check_quota`]; answer with 429 if over limit.
/// 3. Otherwise pass the request on through [`Next::run`].
/// 4. In [`after_dispatch`](Middleware::after_dispatch), hand the actual
///    response metadata to [`UsageRecorder::record_operation`].
pub struct MeteringMiddleware<Q, U> {
    quota_checker: Q,
    usage_recorder: U,
}

impl<Q, U> MeteringMiddleware<Q, U> {
    /// Build a metering middleware from a quota checker and a usage
    /// recorder; both are stored as-is.
    pub fn new(quota_checker: Q, usage_recorder: U) -> Self {
        MeteringMiddleware {
            quota_checker,
            usage_recorder,
        }
    }
}
135
136impl<Q: QuotaChecker, U: UsageRecorder> Middleware for MeteringMiddleware<Q, U> {
137    async fn handle<'a>(
138        &'a self,
139        ctx: DispatchContext<'a>,
140        next: Next<'a>,
141    ) -> Result<HandlerAction, ProxyError> {
142        let estimated_bytes = ctx
143            .headers
144            .get("content-length")
145            .and_then(|v| v.to_str().ok())
146            .and_then(|v| v.parse::<u64>().ok())
147            .unwrap_or(0);
148
149        let bucket_name = ctx.bucket_config.as_ref().map(|b| b.name.as_str());
150
151        if let Err(_exceeded) = self
152            .quota_checker
153            .check_quota(
154                ctx.identity,
155                ctx.operation,
156                bucket_name,
157                estimated_bytes,
158                ctx.source_ip,
159            )
160            .await
161        {
162            tracing::warn!(bucket = bucket_name, "quota exceeded, returning 429");
163            let xml = ErrorResponse::slow_down(ctx.request_id).to_xml();
164            let mut headers = HeaderMap::new();
165            headers.insert("content-type", "application/xml".parse().unwrap());
166            return Ok(HandlerAction::Response(ProxyResult {
167                status: 429,
168                headers,
169                body: ProxyResponseBody::Bytes(Bytes::from(xml)),
170            }));
171        }
172
173        next.run(ctx).await
174    }
175
176    fn after_dispatch(
177        &self,
178        completed: &CompletedRequest<'_>,
179    ) -> impl Future<Output = ()> + MaybeSend + '_ {
180        // Extract all fields synchronously to avoid capturing `completed`
181        // in the returned future (the future's lifetime is tied to `&self`,
182        // not `completed`).
183        let request_id = completed.request_id.to_owned();
184        let identity = completed.identity.cloned();
185        let operation = completed.operation.cloned();
186        let bucket = completed.bucket.map(str::to_owned);
187        let status = completed.status;
188        let bytes_transferred = completed
189            .response_bytes
190            .or(completed.request_bytes)
191            .unwrap_or(0);
192        let was_forwarded = completed.was_forwarded;
193        let source_ip = completed.source_ip;
194
195        async move {
196            self.usage_recorder
197                .record_operation(UsageEvent {
198                    request_id: &request_id,
199                    identity: identity.as_ref(),
200                    operation: operation.as_ref(),
201                    bucket: bucket.as_deref(),
202                    status,
203                    bytes_transferred,
204                    was_forwarded,
205                    source_ip,
206                })
207                .await;
208        }
209    }
210}
211
212// ===========================================================================
213// No-op implementations
214// ===========================================================================
215
216/// A [`UsageRecorder`] that does nothing. Useful when only quota checking
217/// is needed, or for testing.
218pub struct NoopRecorder;
219
220impl UsageRecorder for NoopRecorder {
221    async fn record_operation<'a>(&'a self, _event: UsageEvent<'a>) {}
222}
223
224/// A [`QuotaChecker`] that always allows requests. Useful when only usage
225/// recording is needed, or for testing.
226pub struct NoopQuotaChecker;
227
228impl QuotaChecker for NoopQuotaChecker {
229    async fn check_quota<'a>(
230        &'a self,
231        _identity: &'a ResolvedIdentity,
232        _operation: &'a S3Operation,
233        _bucket: Option<&'a str>,
234        _estimated_bytes: u64,
235        _source_ip: Option<IpAddr>,
236    ) -> Result<(), QuotaExceeded> {
237        Ok(())
238    }
239}
240
241// ===========================================================================
242// Tests
243// ===========================================================================
244
#[cfg(test)]
mod tests {
    use super::*;
    use multistore::middleware::CompletedRequest;
    use multistore::types::{ResolvedIdentity, S3Operation};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    // -- Test helpers ---------------------------------------------------------

    /// Recorder that remembers the last byte count it saw and how many
    /// times it was called.
    struct RecordingRecorder {
        last_bytes: Arc<AtomicU64>,
        call_count: Arc<AtomicU64>,
    }

    impl RecordingRecorder {
        /// Returns the recorder plus shared handles to its two counters
        /// (`last_bytes`, `call_count`).
        fn new() -> (Self, Arc<AtomicU64>, Arc<AtomicU64>) {
            let last_bytes = Arc::new(AtomicU64::new(0));
            let call_count = Arc::new(AtomicU64::new(0));
            (
                Self {
                    last_bytes: Arc::clone(&last_bytes),
                    call_count: Arc::clone(&call_count),
                },
                last_bytes,
                call_count,
            )
        }
    }

    impl UsageRecorder for RecordingRecorder {
        async fn record_operation<'a>(&'a self, event: UsageEvent<'a>) {
            self.last_bytes
                .store(event.bytes_transferred, Ordering::SeqCst);
            self.call_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Checker that rejects every request with a fixed message.
    struct RejectingChecker {
        message: String,
    }

    impl QuotaChecker for RejectingChecker {
        async fn check_quota<'a>(
            &'a self,
            _identity: &'a ResolvedIdentity,
            _operation: &'a S3Operation,
            _bucket: Option<&'a str>,
            _estimated_bytes: u64,
            _source_ip: Option<IpAddr>,
        ) -> Result<(), QuotaExceeded> {
            Err(QuotaExceeded {
                message: self.message.clone(),
            })
        }
    }

    /// Checker that allows every request and remembers the byte estimate
    /// it was given.
    struct CapturingChecker {
        last_estimated_bytes: Arc<AtomicU64>,
    }

    impl CapturingChecker {
        /// The counter starts at `u64::MAX` so that a captured `0` is
        /// distinguishable from "never called".
        fn new() -> (Self, Arc<AtomicU64>) {
            let last_estimated_bytes = Arc::new(AtomicU64::new(u64::MAX));
            (
                Self {
                    last_estimated_bytes: Arc::clone(&last_estimated_bytes),
                },
                last_estimated_bytes,
            )
        }
    }

    impl QuotaChecker for CapturingChecker {
        async fn check_quota<'a>(
            &'a self,
            _identity: &'a ResolvedIdentity,
            _operation: &'a S3Operation,
            _bucket: Option<&'a str>,
            estimated_bytes: u64,
            _source_ip: Option<IpAddr>,
        ) -> Result<(), QuotaExceeded> {
            self.last_estimated_bytes
                .store(estimated_bytes, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Drive `after_dispatch` to completion on the current thread.
    fn run_after_dispatch<Q: QuotaChecker, U: UsageRecorder>(
        middleware: &MeteringMiddleware<Q, U>,
        completed: &CompletedRequest<'_>,
    ) {
        futures::executor::block_on(async {
            Middleware::after_dispatch(middleware, completed).await;
        });
    }

    // -- Tests ----------------------------------------------------------------

    // Tests for `handle` use the ProxyGateway integration tests in core.
    // Here we test the quota checking and after_dispatch logic directly.

    #[test]
    fn rejecting_checker_returns_error() {
        let checker = RejectingChecker {
            message: "over limit".into(),
        };

        let result = futures::executor::block_on(async {
            checker
                .check_quota(
                    &ResolvedIdentity::Anonymous,
                    &S3Operation::ListBuckets,
                    Some("test"),
                    0,
                    None,
                )
                .await
        });

        let err = result.unwrap_err();
        assert_eq!(err.message, "over limit");
    }

    #[test]
    fn noop_checker_allows_request() {
        let result = futures::executor::block_on(async {
            NoopQuotaChecker
                .check_quota(
                    &ResolvedIdentity::Anonymous,
                    &S3Operation::ListBuckets,
                    None,
                    1_000_000,
                    None,
                )
                .await
        });

        assert!(result.is_ok());
    }

    #[test]
    fn capturing_checker_receives_estimated_bytes() {
        let (checker, captured_bytes) = CapturingChecker::new();

        let result = futures::executor::block_on(async {
            checker
                .check_quota(
                    &ResolvedIdentity::Anonymous,
                    &S3Operation::ListBuckets,
                    Some("test"),
                    42_000,
                    None,
                )
                .await
        });

        // The capturing checker always allows; make sure that is not
        // silently ignored.
        assert!(result.is_ok());
        assert_eq!(captured_bytes.load(Ordering::SeqCst), 42_000);
    }

    #[test]
    fn after_dispatch_records_usage() {
        let (recorder, last_bytes, call_count) = RecordingRecorder::new();
        let middleware = MeteringMiddleware::new(NoopQuotaChecker, recorder);

        let completed = CompletedRequest {
            request_id: "req-1",
            identity: None,
            operation: None,
            bucket: Some("my-bucket"),
            status: 200,
            response_bytes: Some(1024),
            request_bytes: None,
            was_forwarded: true,
            source_ip: None,
        };
        run_after_dispatch(&middleware, &completed);

        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(last_bytes.load(Ordering::SeqCst), 1024);
    }

    #[test]
    fn after_dispatch_falls_back_to_request_bytes() {
        let (recorder, last_bytes, _) = RecordingRecorder::new();
        let middleware = MeteringMiddleware::new(NoopQuotaChecker, recorder);

        let completed = CompletedRequest {
            request_id: "req-2",
            identity: None,
            operation: None,
            bucket: None,
            status: 200,
            response_bytes: None,
            request_bytes: Some(512),
            was_forwarded: false,
            source_ip: None,
        };
        run_after_dispatch(&middleware, &completed);

        assert_eq!(last_bytes.load(Ordering::SeqCst), 512);
    }

    #[test]
    fn after_dispatch_defaults_to_zero_bytes() {
        let (recorder, last_bytes, call_count) = RecordingRecorder::new();
        let middleware = MeteringMiddleware::new(NoopQuotaChecker, recorder);

        let completed = CompletedRequest {
            request_id: "req-3",
            identity: None,
            operation: None,
            bucket: None,
            status: 500,
            response_bytes: None,
            request_bytes: None,
            was_forwarded: false,
            source_ip: None,
        };
        run_after_dispatch(&middleware, &completed);

        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(last_bytes.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn after_dispatch_prefers_response_bytes_over_request_bytes() {
        let (recorder, last_bytes, call_count) = RecordingRecorder::new();
        let middleware = MeteringMiddleware::new(NoopQuotaChecker, recorder);

        let completed = CompletedRequest {
            request_id: "req-4",
            identity: None,
            operation: None,
            bucket: None,
            status: 200,
            response_bytes: Some(2048),
            request_bytes: Some(64),
            was_forwarded: true,
            source_ip: None,
        };
        run_after_dispatch(&middleware, &completed);

        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(last_bytes.load(Ordering::SeqCst), 2048);
    }
}