use std::future::Future;
use std::net::IpAddr;
use multistore::api::response::ErrorResponse;
use multistore::error::ProxyError;
use multistore::maybe_send::{MaybeSend, MaybeSync};
use multistore::middleware::{CompletedRequest, DispatchContext, Middleware, Next};
use multistore::route_handler::{HandlerAction, ProxyResponseBody, ProxyResult};
use multistore::types::{ResolvedIdentity, S3Operation};
use bytes::Bytes;
use http::HeaderMap;
/// Usage information for one completed request, handed to [`UsageRecorder::record_operation`].
///
/// All borrowed fields live for `'a`; recorders that need the data after the
/// returned future resolves must copy it.
pub struct UsageEvent<'a> {
/// Identifier of the request this event describes.
pub request_id: &'a str,
/// Resolved caller identity, or `None` when none was available for the request.
pub identity: Option<&'a ResolvedIdentity>,
/// Parsed S3 operation, or `None` when none was available for the request.
pub operation: Option<&'a S3Operation>,
/// Target bucket name, if the request had one.
pub bucket: Option<&'a str>,
/// Status code of the completed request (presumably the HTTP status — confirm against `CompletedRequest`).
pub status: u16,
/// As filled in by `MeteringMiddleware`: `response_bytes` when present,
/// otherwise `request_bytes`, otherwise `0`.
pub bytes_transferred: u64,
/// Copied from `CompletedRequest::was_forwarded` (presumably: whether the
/// request was proxied to a backend — TODO confirm).
pub was_forwarded: bool,
/// Client address, when known.
pub source_ip: Option<IpAddr>,
}
/// Error returned by a [`QuotaChecker`] to reject a request.
///
/// `message` is a human-readable reason intended for logs; it is not sent to
/// the client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuotaExceeded {
    /// Human-readable explanation of which limit was exceeded.
    pub message: String,
}

impl QuotaExceeded {
    /// Builds a rejection carrying `message` as its reason.
    pub fn new(message: impl Into<String>) -> Self {
        Self {
            message: message.into(),
        }
    }
}

impl std::fmt::Display for QuotaExceeded {
    /// Formats as the bare reason message, so callers can add their own prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}

// Allows `QuotaExceeded` to flow through `?`, `Box<dyn Error>` and error-reporting crates.
impl std::error::Error for QuotaExceeded {}
/// Sink for per-request usage events.
///
/// `MeteringMiddleware::after_dispatch` builds one [`UsageEvent`] per call and
/// awaits `record_operation` with it.
///
/// The `MaybeSend + MaybeSync` bounds presumably resolve to `Send + Sync` on
/// multi-threaded targets and relax elsewhere — TODO confirm in `multistore::maybe_send`.
pub trait UsageRecorder: MaybeSend + MaybeSync + 'static {
/// Records one usage event.
///
/// The event only borrows its data for `'a`, so implementations must copy
/// anything they need to keep beyond the returned future.
fn record_operation<'a>(
&'a self,
event: UsageEvent<'a>,
) -> impl Future<Output = ()> + MaybeSend + 'a;
}
/// Decides, before a request is dispatched, whether it may proceed.
///
/// `MeteringMiddleware::handle` awaits `check_quota` for every request it
/// handles. An `Err(QuotaExceeded)` makes it short-circuit with an HTTP 429
/// `SlowDown` response instead of running the rest of the chain.
///
/// The `MaybeSend + MaybeSync` bounds presumably resolve to `Send + Sync` on
/// multi-threaded targets — TODO confirm in `multistore::maybe_send`.
pub trait QuotaChecker: MaybeSend + MaybeSync + 'static {
/// Checks whether `identity` may perform `operation`.
///
/// * `bucket` — name of the resolved bucket config, or `None` when there was none.
/// * `estimated_bytes` — as supplied by `MeteringMiddleware`: the request's
///   `Content-Length` header, or `0` when it is missing or not a valid `u64`.
/// * `source_ip` — client address, when known.
///
/// # Errors
///
/// Returns [`QuotaExceeded`] to reject the request.
fn check_quota<'a>(
&'a self,
identity: &'a ResolvedIdentity,
operation: &'a S3Operation,
bucket: Option<&'a str>,
estimated_bytes: u64,
source_ip: Option<IpAddr>,
) -> impl Future<Output = Result<(), QuotaExceeded>> + MaybeSend + 'a;
}
/// Middleware that enforces quotas before dispatch and reports usage afterwards.
///
/// `Q` is consulted before a request reaches the rest of the chain.
/// `U` receives a usage event from `after_dispatch`.
pub struct MeteringMiddleware<Q, U> {
    /// Consulted before the request is dispatched.
    quota_checker: Q,
    /// Notified once the request has completed.
    usage_recorder: U,
}

impl<Q, U> MeteringMiddleware<Q, U> {
    /// Combines a quota checker and a usage recorder into one middleware.
    pub fn new(quota_checker: Q, usage_recorder: U) -> Self {
        MeteringMiddleware {
            usage_recorder,
            quota_checker,
        }
    }
}
impl<Q: QuotaChecker, U: UsageRecorder> Middleware for MeteringMiddleware<Q, U> {
async fn handle<'a>(
&'a self,
ctx: DispatchContext<'a>,
next: Next<'a>,
) -> Result<HandlerAction, ProxyError> {
let estimated_bytes = ctx
.headers
.get("content-length")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse::<u64>().ok())
.unwrap_or(0);
let bucket_name = ctx.bucket_config.as_ref().map(|b| b.name.as_str());
if let Err(_exceeded) = self
.quota_checker
.check_quota(
ctx.identity,
ctx.operation,
bucket_name,
estimated_bytes,
ctx.source_ip,
)
.await
{
tracing::warn!(bucket = bucket_name, "quota exceeded, returning 429");
let xml = ErrorResponse::slow_down(ctx.request_id).to_xml();
let mut headers = HeaderMap::new();
headers.insert("content-type", "application/xml".parse().unwrap());
return Ok(HandlerAction::Response(ProxyResult {
status: 429,
headers,
body: ProxyResponseBody::Bytes(Bytes::from(xml)),
}));
}
next.run(ctx).await
}
fn after_dispatch(
&self,
completed: &CompletedRequest<'_>,
) -> impl Future<Output = ()> + MaybeSend + '_ {
let request_id = completed.request_id.to_owned();
let identity = completed.identity.cloned();
let operation = completed.operation.cloned();
let bucket = completed.bucket.map(str::to_owned);
let status = completed.status;
let bytes_transferred = completed
.response_bytes
.or(completed.request_bytes)
.unwrap_or(0);
let was_forwarded = completed.was_forwarded;
let source_ip = completed.source_ip;
async move {
self.usage_recorder
.record_operation(UsageEvent {
request_id: &request_id,
identity: identity.as_ref(),
operation: operation.as_ref(),
bucket: bucket.as_deref(),
status,
bytes_transferred,
was_forwarded,
source_ip,
})
.await;
}
}
}
/// A [`UsageRecorder`] that discards every event; use it to disable usage recording.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopRecorder;

impl UsageRecorder for NoopRecorder {
    /// Ignores the event; the returned future completes immediately.
    async fn record_operation<'a>(&'a self, _event: UsageEvent<'a>) {}
}
/// A [`QuotaChecker`] that admits every request; use it to disable quota enforcement.
#[derive(Debug, Default, Clone, Copy)]
pub struct NoopQuotaChecker;

impl QuotaChecker for NoopQuotaChecker {
    /// Always returns `Ok(())`, regardless of identity, operation, bucket, size or source.
    async fn check_quota<'a>(
        &'a self,
        _identity: &'a ResolvedIdentity,
        _operation: &'a S3Operation,
        _bucket: Option<&'a str>,
        _estimated_bytes: u64,
        _source_ip: Option<IpAddr>,
    ) -> Result<(), QuotaExceeded> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the quota-checker / usage-recorder plumbing of the metering middleware.

    use super::*;
    use multistore::middleware::CompletedRequest;
    use multistore::types::{ResolvedIdentity, S3Operation};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Recorder that remembers the byte count of the latest event and counts events.
    ///
    /// The counters live behind `Arc`s so a test can still read them after the
    /// recorder itself has been moved into a `MeteringMiddleware`.
    struct RecordingRecorder {
        last_bytes: Arc<AtomicU64>,
        call_count: Arc<AtomicU64>,
    }

    impl RecordingRecorder {
        /// Returns the recorder together with handles to `(last_bytes, call_count)`.
        fn new() -> (Self, Arc<AtomicU64>, Arc<AtomicU64>) {
            let recorder = Self {
                last_bytes: Arc::new(AtomicU64::new(0)),
                call_count: Arc::new(AtomicU64::new(0)),
            };
            let bytes_handle = Arc::clone(&recorder.last_bytes);
            let calls_handle = Arc::clone(&recorder.call_count);
            (recorder, bytes_handle, calls_handle)
        }
    }

    impl UsageRecorder for RecordingRecorder {
        async fn record_operation<'a>(&'a self, event: UsageEvent<'a>) {
            self.last_bytes.store(event.bytes_transferred, Ordering::SeqCst);
            self.call_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Checker that rejects every request with a fixed message.
    struct RejectingChecker {
        message: String,
    }

    impl QuotaChecker for RejectingChecker {
        async fn check_quota<'a>(
            &'a self,
            _identity: &'a ResolvedIdentity,
            _operation: &'a S3Operation,
            _bucket: Option<&'a str>,
            _estimated_bytes: u64,
            _source_ip: Option<IpAddr>,
        ) -> Result<(), QuotaExceeded> {
            let message = self.message.clone();
            Err(QuotaExceeded { message })
        }
    }

    /// Checker that admits every request and remembers the last `estimated_bytes`.
    ///
    /// The slot starts at `u64::MAX` so "never called" is distinguishable from `0`.
    struct CapturingChecker {
        last_estimated_bytes: Arc<AtomicU64>,
    }

    impl CapturingChecker {
        /// Returns the checker together with a handle to its captured value.
        fn new() -> (Self, Arc<AtomicU64>) {
            let checker = Self {
                last_estimated_bytes: Arc::new(AtomicU64::new(u64::MAX)),
            };
            let captured = Arc::clone(&checker.last_estimated_bytes);
            (checker, captured)
        }
    }

    impl QuotaChecker for CapturingChecker {
        async fn check_quota<'a>(
            &'a self,
            _identity: &'a ResolvedIdentity,
            _operation: &'a S3Operation,
            _bucket: Option<&'a str>,
            estimated_bytes: u64,
            _source_ip: Option<IpAddr>,
        ) -> Result<(), QuotaExceeded> {
            self.last_estimated_bytes.store(estimated_bytes, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Drives `checker.check_quota` to completion for an anonymous `ListBuckets` request.
    fn run_check<C: QuotaChecker>(
        checker: &C,
        bucket: Option<&str>,
        estimated_bytes: u64,
    ) -> Result<(), QuotaExceeded> {
        futures::executor::block_on(checker.check_quota(
            &ResolvedIdentity::Anonymous,
            &S3Operation::ListBuckets,
            bucket,
            estimated_bytes,
            None,
        ))
    }

    /// Feeds `completed` through `after_dispatch` of a middleware wrapping `recorder`.
    fn dispatch_completed(recorder: RecordingRecorder, completed: &CompletedRequest<'_>) {
        let middleware = MeteringMiddleware::new(NoopQuotaChecker, recorder);
        futures::executor::block_on(Middleware::after_dispatch(&middleware, completed));
    }

    #[test]
    fn rejecting_checker_returns_error() {
        let checker = RejectingChecker {
            message: "over limit".into(),
        };
        let err = run_check(&checker, Some("test"), 0).unwrap_err();
        assert_eq!(err.message, "over limit");
    }

    #[test]
    fn noop_checker_allows_request() {
        let outcome = run_check(&NoopQuotaChecker, None, 1_000_000);
        assert!(outcome.is_ok());
    }

    #[test]
    fn capturing_checker_receives_estimated_bytes() {
        let (checker, captured_bytes) = CapturingChecker::new();
        let _ = run_check(&checker, Some("test"), 42_000);
        assert_eq!(captured_bytes.load(Ordering::SeqCst), 42_000);
    }

    #[test]
    fn after_dispatch_records_usage() {
        let (recorder, last_bytes, call_count) = RecordingRecorder::new();
        dispatch_completed(
            recorder,
            &CompletedRequest {
                request_id: "req-1",
                identity: None,
                operation: None,
                bucket: Some("my-bucket"),
                status: 200,
                response_bytes: Some(1024),
                request_bytes: None,
                was_forwarded: true,
                source_ip: None,
            },
        );
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(last_bytes.load(Ordering::SeqCst), 1024);
    }

    #[test]
    fn after_dispatch_falls_back_to_request_bytes() {
        let (recorder, last_bytes, _) = RecordingRecorder::new();
        dispatch_completed(
            recorder,
            &CompletedRequest {
                request_id: "req-2",
                identity: None,
                operation: None,
                bucket: None,
                status: 200,
                response_bytes: None,
                request_bytes: Some(512),
                was_forwarded: false,
                source_ip: None,
            },
        );
        assert_eq!(last_bytes.load(Ordering::SeqCst), 512);
    }

    #[test]
    fn after_dispatch_defaults_to_zero_bytes() {
        let (recorder, last_bytes, call_count) = RecordingRecorder::new();
        dispatch_completed(
            recorder,
            &CompletedRequest {
                request_id: "req-3",
                identity: None,
                operation: None,
                bucket: None,
                status: 500,
                response_bytes: None,
                request_bytes: None,
                was_forwarded: false,
                source_ip: None,
            },
        );
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
        assert_eq!(last_bytes.load(Ordering::SeqCst), 0);
    }
}