1use std::future::Future;
29use std::net::IpAddr;
30
31use multistore::api::response::ErrorResponse;
32use multistore::error::ProxyError;
33use multistore::maybe_send::{MaybeSend, MaybeSync};
34use multistore::middleware::{CompletedRequest, DispatchContext, Middleware, Next};
35use multistore::route_handler::{HandlerAction, ProxyResponseBody, ProxyResult};
36use multistore::types::{ResolvedIdentity, S3Operation};
37
38use bytes::Bytes;
39use http::HeaderMap;
40
41pub struct UsageEvent<'a> {
43 pub request_id: &'a str,
45 pub identity: Option<&'a ResolvedIdentity>,
47 pub operation: Option<&'a S3Operation>,
49 pub bucket: Option<&'a str>,
51 pub status: u16,
53 pub bytes_transferred: u64,
57 pub was_forwarded: bool,
59 pub source_ip: Option<IpAddr>,
61}
62
/// Error returned by a [`QuotaChecker`] to reject a request.
///
/// When `MeteringMiddleware` receives this error, it answers the request with a
/// `429` `SlowDown` error instead of forwarding it. `message` is not part of that
/// response body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuotaExceeded {
    /// Human-readable explanation of which limit was hit.
    pub message: String,
}

impl std::fmt::Display for QuotaExceeded {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fall back to a generic description so the error never renders as an
        // empty string when a checker leaves `message` blank.
        if self.message.is_empty() {
            f.write_str("quota exceeded")
        } else {
            f.write_str(&self.message)
        }
    }
}

impl std::error::Error for QuotaExceeded {}
71
72pub trait UsageRecorder: MaybeSend + MaybeSync + 'static {
78 fn record_operation<'a>(
83 &'a self,
84 event: UsageEvent<'a>,
85 ) -> impl Future<Output = ()> + MaybeSend + 'a;
86}
87
88pub trait QuotaChecker: MaybeSend + MaybeSync + 'static {
94 fn check_quota<'a>(
99 &'a self,
100 identity: &'a ResolvedIdentity,
101 operation: &'a S3Operation,
102 bucket: Option<&'a str>,
103 estimated_bytes: u64,
104 source_ip: Option<IpAddr>,
105 ) -> impl Future<Output = Result<(), QuotaExceeded>> + MaybeSend + 'a;
106}
107
/// Middleware that enforces quotas before dispatch and records usage after it.
///
/// `Q` is consulted in `handle` to decide whether a request may proceed.
/// `U` receives a [`UsageEvent`] each time `after_dispatch` runs.
pub struct MeteringMiddleware<Q, U> {
    /// Decides whether an incoming request is within quota.
    quota_checker: Q,
    /// Receives the usage event built for each completed request.
    usage_recorder: U,
}

impl<Q, U> MeteringMiddleware<Q, U> {
    /// Assembles the middleware from a quota checker and a usage recorder.
    ///
    /// Pass [`NoopQuotaChecker`] or [`NoopRecorder`] for whichever half is not
    /// needed.
    pub fn new(quota_checker: Q, usage_recorder: U) -> Self {
        MeteringMiddleware {
            quota_checker,
            usage_recorder,
        }
    }
}
135
136impl<Q: QuotaChecker, U: UsageRecorder> Middleware for MeteringMiddleware<Q, U> {
137 async fn handle<'a>(
138 &'a self,
139 ctx: DispatchContext<'a>,
140 next: Next<'a>,
141 ) -> Result<HandlerAction, ProxyError> {
142 let estimated_bytes = ctx
143 .headers
144 .get("content-length")
145 .and_then(|v| v.to_str().ok())
146 .and_then(|v| v.parse::<u64>().ok())
147 .unwrap_or(0);
148
149 let bucket_name = ctx.bucket_config.as_ref().map(|b| b.name.as_str());
150
151 if let Err(_exceeded) = self
152 .quota_checker
153 .check_quota(
154 ctx.identity,
155 ctx.operation,
156 bucket_name,
157 estimated_bytes,
158 ctx.source_ip,
159 )
160 .await
161 {
162 tracing::warn!(bucket = bucket_name, "quota exceeded, returning 429");
163 let xml = ErrorResponse::slow_down(ctx.request_id).to_xml();
164 let mut headers = HeaderMap::new();
165 headers.insert("content-type", "application/xml".parse().unwrap());
166 return Ok(HandlerAction::Response(ProxyResult {
167 status: 429,
168 headers,
169 body: ProxyResponseBody::Bytes(Bytes::from(xml)),
170 }));
171 }
172
173 next.run(ctx).await
174 }
175
176 fn after_dispatch(
177 &self,
178 completed: &CompletedRequest<'_>,
179 ) -> impl Future<Output = ()> + MaybeSend + '_ {
180 let request_id = completed.request_id.to_owned();
184 let identity = completed.identity.cloned();
185 let operation = completed.operation.cloned();
186 let bucket = completed.bucket.map(str::to_owned);
187 let status = completed.status;
188 let bytes_transferred = completed
189 .response_bytes
190 .or(completed.request_bytes)
191 .unwrap_or(0);
192 let was_forwarded = completed.was_forwarded;
193 let source_ip = completed.source_ip;
194
195 async move {
196 self.usage_recorder
197 .record_operation(UsageEvent {
198 request_id: &request_id,
199 identity: identity.as_ref(),
200 operation: operation.as_ref(),
201 bucket: bucket.as_deref(),
202 status,
203 bytes_transferred,
204 was_forwarded,
205 source_ip,
206 })
207 .await;
208 }
209 }
210}
211
212pub struct NoopRecorder;
219
220impl UsageRecorder for NoopRecorder {
221 async fn record_operation<'a>(&'a self, _event: UsageEvent<'a>) {}
222}
223
224pub struct NoopQuotaChecker;
227
228impl QuotaChecker for NoopQuotaChecker {
229 async fn check_quota<'a>(
230 &'a self,
231 _identity: &'a ResolvedIdentity,
232 _operation: &'a S3Operation,
233 _bucket: Option<&'a str>,
234 _estimated_bytes: u64,
235 _source_ip: Option<IpAddr>,
236 ) -> Result<(), QuotaExceeded> {
237 Ok(())
238 }
239}
240
#[cfg(test)]
mod tests {
    // `CompletedRequest`, `ResolvedIdentity`, `S3Operation`, `IpAddr` and
    // `Middleware` all come in through the parent module's imports.
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// [`UsageRecorder`] test double: stores the byte count of the most recent
    /// event and counts how many events it has received.
    struct RecordingRecorder {
        last_bytes: Arc<AtomicU64>,
        call_count: Arc<AtomicU64>,
    }

    impl RecordingRecorder {
        /// Returns the recorder plus handles to its `last_bytes` and
        /// `call_count` counters. Tests can still read the counters after the
        /// recorder has been moved into a middleware.
        fn new() -> (Self, Arc<AtomicU64>, Arc<AtomicU64>) {
            let bytes = Arc::new(AtomicU64::new(0));
            let calls = Arc::new(AtomicU64::new(0));
            let recorder = Self {
                last_bytes: Arc::clone(&bytes),
                call_count: Arc::clone(&calls),
            };
            (recorder, bytes, calls)
        }
    }

    impl UsageRecorder for RecordingRecorder {
        async fn record_operation<'a>(&'a self, event: UsageEvent<'a>) {
            self.last_bytes
                .store(event.bytes_transferred, Ordering::SeqCst);
            self.call_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// [`QuotaChecker`] test double that rejects every request with `message`.
    struct RejectingChecker {
        message: String,
    }

    impl QuotaChecker for RejectingChecker {
        async fn check_quota<'a>(
            &'a self,
            _identity: &'a ResolvedIdentity,
            _operation: &'a S3Operation,
            _bucket: Option<&'a str>,
            _estimated_bytes: u64,
            _source_ip: Option<IpAddr>,
        ) -> Result<(), QuotaExceeded> {
            let message = self.message.clone();
            Err(QuotaExceeded { message })
        }
    }

    /// [`QuotaChecker`] test double that admits every request and stores the
    /// `estimated_bytes` it was called with.
    ///
    /// The stored value starts at `u64::MAX`, so a checker that was never
    /// called is distinguishable from one called with 0.
    struct CapturingChecker {
        last_estimated_bytes: Arc<AtomicU64>,
    }

    impl CapturingChecker {
        fn new() -> (Self, Arc<AtomicU64>) {
            let captured = Arc::new(AtomicU64::new(u64::MAX));
            let checker = Self {
                last_estimated_bytes: Arc::clone(&captured),
            };
            (checker, captured)
        }
    }

    impl QuotaChecker for CapturingChecker {
        async fn check_quota<'a>(
            &'a self,
            _identity: &'a ResolvedIdentity,
            _operation: &'a S3Operation,
            _bucket: Option<&'a str>,
            estimated_bytes: u64,
            _source_ip: Option<IpAddr>,
        ) -> Result<(), QuotaExceeded> {
            self.last_estimated_bytes
                .store(estimated_bytes, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Drives `checker.check_quota` to completion for an anonymous
    /// `ListBuckets` request with no source IP.
    fn check_list_buckets<C: QuotaChecker>(
        checker: &C,
        bucket: Option<&str>,
        estimated_bytes: u64,
    ) -> Result<(), QuotaExceeded> {
        futures::executor::block_on(checker.check_quota(
            &ResolvedIdentity::Anonymous,
            &S3Operation::ListBuckets,
            bucket,
            estimated_bytes,
            None,
        ))
    }

    /// Drives `after_dispatch` for `completed` to completion on this thread.
    fn run_after_dispatch<Q: QuotaChecker, U: UsageRecorder>(
        middleware: &MeteringMiddleware<Q, U>,
        completed: &CompletedRequest<'_>,
    ) {
        futures::executor::block_on(Middleware::after_dispatch(middleware, completed));
    }

    #[test]
    fn rejecting_checker_returns_error() {
        let checker = RejectingChecker {
            message: "over limit".into(),
        };

        let err = check_list_buckets(&checker, Some("test"), 0).unwrap_err();

        assert_eq!(err.message, "over limit");
    }

    #[test]
    fn noop_checker_allows_request() {
        let outcome = check_list_buckets(&NoopQuotaChecker, None, 1_000_000);

        assert!(outcome.is_ok());
    }

    #[test]
    fn capturing_checker_receives_estimated_bytes() {
        let (checker, captured_bytes) = CapturingChecker::new();

        let _ = check_list_buckets(&checker, Some("test"), 42_000);

        assert_eq!(captured_bytes.load(Ordering::SeqCst), 42_000);
    }

    #[test]
    fn after_dispatch_records_usage() {
        let (recorder, last_bytes, call_count) = RecordingRecorder::new();
        let middleware = MeteringMiddleware::new(NoopQuotaChecker, recorder);
        let completed = CompletedRequest {
            request_id: "req-1",
            identity: None,
            operation: None,
            bucket: Some("my-bucket"),
            status: 200,
            response_bytes: Some(1024),
            request_bytes: None,
            was_forwarded: true,
            source_ip: None,
        };

        run_after_dispatch(&middleware, &completed);

        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(last_bytes.load(Ordering::SeqCst), 1024);
    }

    #[test]
    fn after_dispatch_falls_back_to_request_bytes() {
        let (recorder, last_bytes, _) = RecordingRecorder::new();
        let middleware = MeteringMiddleware::new(NoopQuotaChecker, recorder);
        let completed = CompletedRequest {
            request_id: "req-2",
            identity: None,
            operation: None,
            bucket: None,
            status: 200,
            response_bytes: None,
            request_bytes: Some(512),
            was_forwarded: false,
            source_ip: None,
        };

        run_after_dispatch(&middleware, &completed);

        assert_eq!(last_bytes.load(Ordering::SeqCst), 512);
    }

    #[test]
    fn after_dispatch_defaults_to_zero_bytes() {
        let (recorder, last_bytes, call_count) = RecordingRecorder::new();
        let middleware = MeteringMiddleware::new(NoopQuotaChecker, recorder);
        let completed = CompletedRequest {
            request_id: "req-3",
            identity: None,
            operation: None,
            bucket: None,
            status: 500,
            response_bytes: None,
            request_bytes: None,
            was_forwarded: false,
            source_ip: None,
        };

        run_after_dispatch(&middleware, &completed);

        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(last_bytes.load(Ordering::SeqCst), 0);
    }
}