use crate::api::list::{
build_list_prefix, build_list_xml, build_list_xml_v1, parse_list_query_params, ListXmlParams,
ListXmlParamsV1,
};
use crate::api::list_rewrite::ListRewrite;
use crate::api::request::{self, HostStyle};
use crate::api::response::{BucketList, ErrorResponse, ListAllMyBucketsResult};
use crate::auth;
use crate::auth::TemporaryCredentialResolver;
use crate::backend::multipart::{build_backend_url, sign_s3_request};
use crate::backend::request_signer::{hash_payload, UNSIGNED_PAYLOAD};
use crate::backend::ForwardResponse;
use crate::backend::ProxyBackend;
use crate::error::ProxyError;
use crate::middleware::{
CompletedRequest, Dispatch, DispatchContext, DispatchFuture, ErasedMiddleware, Middleware, Next,
};
use crate::registry::{BucketRegistry, CredentialRegistry};
use crate::route_handler::{ProxyResponseBody, RequestInfo};
use crate::router::Router;
use crate::types::{Action, BucketConfig, ResolvedIdentity, S3Operation};
use bytes::Bytes;
use http::{HeaderMap, Method};
use object_store::list::PaginatedListOptions;
use std::borrow::Cow;
use std::net::IpAddr;
use std::time::Duration;
use uuid::Uuid;
/// Validity window of the presigned backend URLs created in `build_forward`.
const PRESIGNED_URL_TTL: Duration = Duration::from_secs(300);
/// Message returned as `ProxyError::NotImplemented` when a client sends an
/// aws-chunked upload whose chunks are individually signed.
const SIGNED_AWS_CHUNKED_UNSUPPORTED: &str =
"aws-chunked uploads with signed chunks (x-amz-content-sha256: \
STREAMING-AWS4-HMAC-SHA256-PAYLOAD) are not supported; configure the client \
to use a trailing checksum (the default) or multipart";
/// User-Agent sent on backend requests unless overridden with
/// `ProxyGateway::with_user_agent`.
pub const DEFAULT_USER_AGENT: &str = concat!("multistore/", env!("CARGO_PKG_VERSION"));
pub use crate::route_handler::{
filter_response_headers, ForwardRequest, HandlerAction, PendingRequest, ProxyResult,
RESPONSE_HEADER_DENYLIST,
};
/// Result of [`ProxyGateway::handle_request`].
pub enum GatewayResponse<S> {
/// A response produced by the gateway itself (error documents, LIST results,
/// multipart / batch-delete results, or a custom route's response).
Response(ProxyResult),
/// A response returned by the backend after the request was forwarded; its
/// headers have already been passed through `filter_response_headers`.
Forward(ForwardResponse<S>),
}
/// Facts gathered while resolving a request, later reported to middleware as a
/// `CompletedRequest`.
pub struct RequestMetadata {
/// Per-request identifier (a freshly generated UUID v4 string).
pub request_id: String,
/// Resolved caller identity; `None` whenever resolution or dispatch failed.
pub identity: Option<ResolvedIdentity>,
/// Parsed S3 operation; `None` whenever resolution or dispatch failed.
pub operation: Option<S3Operation>,
/// Bucket targeted by the operation, if any; `None` on failure.
pub bucket: Option<String>,
/// Client address as supplied by the caller of the gateway.
pub source_ip: Option<IpAddr>,
}
/// S3-compatible gateway: parses and authenticates requests, resolves buckets,
/// runs middleware, and answers directly or forwards to a storage backend.
pub struct ProxyGateway<B, R, C> {
/// Backend used to sign, forward, list, and send raw requests.
backend: B,
/// Resolves bucket configs, lists buckets, and authorizes per-key access.
bucket_registry: R,
/// Credential store consulted when authenticating requests.
credential_registry: C,
/// Middleware chain; each entry's `after_dispatch` runs once a request completes.
middleware: Vec<Box<dyn ErasedMiddleware>>,
/// Base domain enabling virtual-hosted addressing (`<bucket>.<domain>`);
/// `None` means path-style only.
virtual_host_domain: Option<String>,
/// Optional resolver for temporary credentials, passed to identity resolution.
credential_resolver: Option<Box<dyn TemporaryCredentialResolver>>,
/// Custom routes tried before S3 handling; a match bypasses S3 resolution
/// and middleware.
router: Router,
/// Forwarded to error-document rendering (presumably adds detail when on).
debug_errors: bool,
/// User-Agent header value sent on every backend request.
user_agent: String,
/// Whether a `Server-Timing` header is added to responses.
server_timing: bool,
/// Optional limit on the declared `Content-Length` of uploads.
max_request_body_size: Option<u64>,
}
impl<B, R, C> ProxyGateway<B, R, C>
where
B: ProxyBackend,
R: BucketRegistry,
C: CredentialRegistry,
{
/// Creates a gateway with default settings.
///
/// Defaults: no middleware, no temporary-credential resolver, a fresh
/// `Router::new()`, debug errors off, [`DEFAULT_USER_AGENT`], `Server-Timing`
/// enabled, and no upload size limit.
pub fn new(
    backend: B,
    bucket_registry: R,
    credential_registry: C,
    virtual_host_domain: Option<String>,
) -> Self {
    let user_agent = String::from(DEFAULT_USER_AGENT);
    Self {
        backend,
        bucket_registry,
        credential_registry,
        virtual_host_domain,
        router: Router::new(),
        middleware: Vec::new(),
        credential_resolver: None,
        max_request_body_size: None,
        debug_errors: false,
        server_timing: true,
        user_agent,
    }
}
pub fn with_middleware(mut self, middleware: impl Middleware) -> Self {
self.middleware.push(Box::new(middleware));
self
}
pub fn with_credential_resolver(
mut self,
resolver: impl TemporaryCredentialResolver + 'static,
) -> Self {
self.credential_resolver = Some(Box::new(resolver));
self
}
pub fn with_router(mut self, router: Router) -> Self {
self.router = router;
self
}
pub fn with_debug_errors(mut self, enabled: bool) -> Self {
self.debug_errors = enabled;
self
}
pub fn with_user_agent(mut self, user_agent: impl Into<String>) -> Self {
self.user_agent = user_agent.into();
self
}
pub fn with_server_timing(mut self, enabled: bool) -> Self {
self.server_timing = enabled;
self
}
pub fn with_max_request_body_size(mut self, max_bytes: u64) -> Self {
self.max_request_body_size = Some(max_bytes);
self
}
/// Rejects a request whose declared `Content-Length` exceeds the configured
/// maximum.
///
/// Passes when no limit is configured, or when the header is absent or not a
/// valid `u64`.
fn check_upload_size(&self, headers: &HeaderMap) -> Result<(), ProxyError> {
    let Some(limit) = self.max_request_body_size else {
        return Ok(());
    };
    match content_length(headers) {
        Some(declared) if declared > limit => {
            tracing::warn!(
                content_length = declared,
                max = limit,
                "rejecting upload exceeding configured max body size"
            );
            Err(ProxyError::EntityTooLarge)
        }
        _ => Ok(()),
    }
}
/// Writes a `Server-Timing` header (`total`, plus `dispatch` / `backend` when
/// their start times are given) into `headers`, if the feature is enabled.
///
/// Durations are whole milliseconds measured up to "now", clamped at zero.
/// The header is skipped if the value cannot be converted to a header value.
fn maybe_inject_server_timing(
    &self,
    headers: &mut HeaderMap,
    total_start: chrono::DateTime<chrono::Utc>,
    dispatch_start: Option<chrono::DateTime<chrono::Utc>>,
    backend_start: Option<chrono::DateTime<chrono::Utc>>,
) {
    if !self.server_timing {
        return;
    }
    let now = chrono::Utc::now();
    let elapsed_ms =
        |start: chrono::DateTime<chrono::Utc>| (now - start).num_milliseconds().max(0);

    let mut metrics = vec![format!("total;dur={}", elapsed_ms(total_start))];
    if let Some(start) = dispatch_start {
        metrics.push(format!("dispatch;dur={}", elapsed_ms(start)));
    }
    if let Some(start) = backend_start {
        metrics.push(format!("backend;dur={}", elapsed_ms(start)));
    }

    let header = metrics.join(", ");
    if let Ok(value) = header.parse() {
        headers.insert("server-timing", value);
    }
}
/// Handles one incoming request end to end.
///
/// Custom routes registered on the [`Router`] are tried first. A matching
/// route bypasses S3 parsing, authentication, and the middleware chain
/// (including `after_dispatch`). Otherwise the request is resolved via
/// [`Self::resolve_request_with_metadata`] and then handled in one of three ways:
/// - answered directly;
/// - forwarded to the backend;
/// - for operations needing the body, buffered with `collect_body` and executed
///   by [`Self::handle_with_body`].
///
/// Each middleware's `after_dispatch` then receives a [`CompletedRequest`]
/// summary. A `Server-Timing` header is added when that feature is enabled.
///
/// Backend and body-collection failures become S3 XML error responses; this
/// function never returns an error.
pub async fn handle_request<CF, Fut, E>(
    &self,
    req: &RequestInfo<'_>,
    body: B::Body,
    collect_body: CF,
) -> GatewayResponse<B::ResponseBody>
where
    CF: FnOnce(B::Body) -> Fut,
    Fut: std::future::Future<Output = Result<Bytes, E>>,
    E: std::fmt::Display,
{
    let total_start = chrono::Utc::now();

    // Custom routes short-circuit S3 handling: no request id is generated and
    // middleware hooks are not run for them.
    if let Some(action) = self.router.dispatch(req).await {
        return match action {
            HandlerAction::Response(mut r) => {
                self.maybe_inject_server_timing(&mut r.headers, total_start, None, None);
                GatewayResponse::Response(r)
            }
            HandlerAction::Forward(fwd) => {
                let backend_start = chrono::Utc::now();
                match self.backend.forward(fwd, body).await {
                    Ok(mut resp) => {
                        resp.headers = filter_response_headers(&resp.headers);
                        self.maybe_inject_server_timing(
                            &mut resp.headers,
                            total_start,
                            None,
                            Some(backend_start),
                        );
                        GatewayResponse::Forward(resp)
                    }
                    Err(e) => {
                        let mut r = error_response(&e, req.path, "", self.debug_errors);
                        self.maybe_inject_server_timing(
                            &mut r.headers,
                            total_start,
                            None,
                            Some(backend_start),
                        );
                        GatewayResponse::Response(r)
                    }
                }
            }
            // Custom routes are not given a way to receive a buffered body.
            HandlerAction::NeedsBody(_) => {
                let mut r = error_response(
                    &ProxyError::Internal("unexpected NeedsBody from route handler".into()),
                    req.path,
                    "",
                    self.debug_errors,
                );
                self.maybe_inject_server_timing(&mut r.headers, total_start, None, None);
                GatewayResponse::Response(r)
            }
        };
    }

    let dispatch_start = chrono::Utc::now();
    let (action, metadata) = self.resolve_request_with_metadata(req).await;

    // Size of a gateway-generated body (always known for buffered bodies).
    fn response_body_bytes(body: &ProxyResponseBody) -> Option<u64> {
        match body {
            ProxyResponseBody::Bytes(b) => Some(b.len() as u64),
            ProxyResponseBody::Empty => Some(0),
        }
    }

    let request_bytes = content_length(req.headers);
    let (mut response, status, resp_bytes, was_forwarded, backend_start) = match action {
        HandlerAction::Response(r) => {
            let s = r.status;
            let rb = response_body_bytes(&r.body);
            (GatewayResponse::Response(r), s, rb, false, None)
        }
        HandlerAction::Forward(fwd) => {
            let backend_start = chrono::Utc::now();
            match self.backend.forward(fwd, body).await {
                Ok(mut resp) => {
                    resp.headers = filter_response_headers(&resp.headers);
                    let s = resp.status;
                    let cl = resp.content_length;
                    (
                        GatewayResponse::Forward(resp),
                        s,
                        cl,
                        true,
                        Some(backend_start),
                    )
                }
                Err(e) => {
                    let err_resp =
                        error_response(&e, req.path, &metadata.request_id, self.debug_errors);
                    let s = err_resp.status;
                    (
                        GatewayResponse::Response(err_resp),
                        s,
                        None,
                        true,
                        Some(backend_start),
                    )
                }
            }
        }
        HandlerAction::NeedsBody(pending) => {
            let backend_start = chrono::Utc::now();
            match collect_body(body).await {
                Ok(bytes) => {
                    let result = self.handle_with_body(pending, bytes).await;
                    let s = result.status;
                    let rb = response_body_bytes(&result.body);
                    (
                        GatewayResponse::Response(result),
                        s,
                        rb,
                        false,
                        Some(backend_start),
                    )
                }
                Err(e) => {
                    tracing::error!(error = %e, "failed to read request body");
                    // Report the request path as the error resource, like every
                    // other error path in this function.
                    let err_resp = error_response(
                        &ProxyError::Internal("failed to read request body".into()),
                        req.path,
                        &metadata.request_id,
                        self.debug_errors,
                    );
                    let s = err_resp.status;
                    (
                        GatewayResponse::Response(err_resp),
                        s,
                        None,
                        false,
                        Some(backend_start),
                    )
                }
            }
        }
    };

    let completed = CompletedRequest {
        request_id: &metadata.request_id,
        identity: metadata.identity.as_ref(),
        operation: metadata.operation.as_ref(),
        bucket: metadata.bucket.as_deref(),
        status,
        response_bytes: resp_bytes,
        request_bytes,
        was_forwarded,
        source_ip: metadata.source_ip,
    };
    for m in &self.middleware {
        m.after_dispatch(&completed).await;
    }

    let headers = match &mut response {
        GatewayResponse::Response(r) => &mut r.headers,
        GatewayResponse::Forward(fwd) => &mut fwd.headers,
    };
    self.maybe_inject_server_timing(headers, total_start, Some(dispatch_start), backend_start);
    response
}
/// Resolves a request into a [`HandlerAction`] without reading its body or
/// contacting the backend; the per-request metadata is discarded.
pub async fn resolve_request(
    &self,
    method: Method,
    path: &str,
    query: Option<&str>,
    headers: &HeaderMap,
    source_ip: Option<IpAddr>,
) -> HandlerAction {
    let info = RequestInfo::new(&method, path, query, headers, source_ip);
    self.resolve_request_with_metadata(&info).await.0
}
/// Parses, authenticates, and resolves a request, then runs it through the
/// middleware chain, which ends in the gateway's own dispatch.
///
/// The steps are:
/// 1. Generate a UUID v4 request id.
/// 2. Determine path- or virtual-hosted-style addressing.
/// 3. Parse the S3 operation.
/// 4. Resolve the caller identity.
/// 5. Look up the bucket config, if the operation names a bucket.
/// 6. Run `Next` over the middleware.
///
/// Any failure along the way becomes an S3 XML error response via
/// `error_result`. In that case the returned metadata carries only the request
/// id and source IP.
pub(crate) async fn resolve_request_with_metadata(
&self,
req: &RequestInfo<'_>,
) -> (HandlerAction, RequestMetadata) {
let request_id = Uuid::new_v4().to_string();
tracing::info!(
request_id = %request_id,
method = %req.method,
path = %req.path,
query = ?req.query,
"incoming request"
);
let host_style = determine_host_style(req.headers, self.virtual_host_domain.as_deref());
let operation = match request::parse_s3_request(
req.method,
req.path,
req.query,
req.headers,
host_style,
) {
Ok(op) => op,
Err(err) => return self.error_result(err, req.path, &request_id, req.source_ip),
};
tracing::debug!(operation = ?operation, "parsed S3 operation");
// Signature verification prefers the explicit signing path/query when the
// caller supplied them (presumably because the client signed a different
// path/query than the one routed here — TODO confirm).
let identity = match auth::resolve_identity(
req.method,
req.signing_path.unwrap_or(req.path),
req.signing_query.or(req.query).unwrap_or(""),
req.headers,
&self.credential_registry,
self.credential_resolver.as_deref(),
)
.await
{
Ok(id) => id,
Err(err) => return self.error_result(err, req.path, &request_id, req.source_ip),
};
tracing::debug!(identity = ?identity, "resolved identity");
// The registry receives identity and operation; a successful lookup is what
// this code treats as "authorization passed".
let resolved = if let Some(bucket_name) = operation.bucket() {
match self
.bucket_registry
.get_bucket(bucket_name, &identity, &operation)
.await
{
Ok(resolved) => {
tracing::debug!(
bucket = %bucket_name,
backend_type = %resolved.config.backend_type,
"resolved bucket config"
);
tracing::trace!("authorization passed");
Some(resolved)
}
Err(err) => return self.error_result(err, req.path, &request_id, req.source_ip),
}
} else {
None
};
let ctx = DispatchContext {
identity: &identity,
operation: &operation,
bucket_config: resolved.as_ref().map(|r| Cow::Borrowed(&r.config)),
headers: req.headers,
source_ip: req.source_ip,
request_id: &request_id,
list_rewrite: resolved.as_ref().and_then(|r| r.list_rewrite.as_ref()),
display_name: resolved.as_ref().and_then(|r| r.display_name.as_deref()),
extensions: http::Extensions::new(),
};
// `self` is handed to `Next` together with the middleware list; the gateway's
// `Dispatch` impl is reached after the middleware.
let next = Next::new(&self.middleware, self);
// Metadata is captured before dispatch. If dispatch fails, it is discarded
// in favour of the reduced metadata built by `error_result`.
let metadata = RequestMetadata {
request_id: request_id.clone(),
identity: Some(identity.clone()),
operation: Some(operation.clone()),
bucket: operation.bucket().map(str::to_string),
source_ip: req.source_ip,
};
match next.run(ctx).await {
Ok(action) => {
match &action {
HandlerAction::Response(resp) => {
tracing::info!(
request_id = %request_id,
status = resp.status,
"request completed"
);
}
HandlerAction::Forward(fwd) => {
tracing::info!(
request_id = %request_id,
method = %fwd.method,
"forwarding via presigned URL"
);
}
HandlerAction::NeedsBody(_) => {
tracing::debug!(
request_id = %request_id,
"request needs body (multipart)"
);
}
}
(action, metadata)
}
Err(err) => self.error_result(err, req.path, &request_id, req.source_ip),
}
}
/// Logs `err` and converts it into an S3 XML error response.
///
/// The accompanying metadata carries only the request id and source IP.
/// Identity, operation, and bucket are left unset even if they had already
/// been resolved.
fn error_result(
    &self,
    err: ProxyError,
    path: &str,
    request_id: &str,
    source_ip: Option<IpAddr>,
) -> (HandlerAction, RequestMetadata) {
    tracing::warn!(
        request_id = %request_id,
        error = %err,
        status = err.status_code(),
        s3_code = %err.s3_error_code(),
        "request failed"
    );
    let response = error_response(&err, path, request_id, self.debug_errors);
    let metadata = RequestMetadata {
        request_id: request_id.to_owned(),
        identity: None,
        operation: None,
        bucket: None,
        source_ip,
    };
    (HandlerAction::Response(response), metadata)
}
/// Executes a deferred operation once its request body has been buffered.
///
/// `DeleteObjects` goes through the batch-delete path; every other pending
/// operation is sent as a signed raw multipart request. Failures are logged
/// and rendered as S3 XML error documents keyed on the operation's object key.
pub async fn handle_with_body(&self, pending: PendingRequest, body: Bytes) -> ProxyResult {
    let outcome = if matches!(pending.operation, S3Operation::DeleteObjects { .. }) {
        self.execute_delete_objects(&pending, body).await
    } else {
        self.execute_multipart(&pending, body).await
    };

    let err = match outcome {
        Ok(result) => {
            tracing::info!(
                request_id = %pending.request_id,
                status = result.status,
                "body request completed"
            );
            return result;
        }
        Err(err) => err,
    };

    tracing::warn!(
        request_id = %pending.request_id,
        error = %err,
        status = err.status_code(),
        s3_code = %err.s3_error_code(),
        "body request failed"
    );
    error_response(
        &err,
        pending.operation.key(),
        &pending.request_id,
        self.debug_errors,
    )
}
/// Maps a resolved S3 operation onto a [`HandlerAction`].
///
/// - `ListBuckets` and `ListBucket` are answered directly with XML.
/// - `GetObject`, `HeadObject`, `PutObject`, and `DeleteObject` are forwarded
///   to the backend via presigned URLs. aws-chunked `PutObject` uploads are
///   re-signed and streamed instead.
/// - Multipart operations and `DeleteObjects` require an S3 backend. They are
///   deferred with [`HandlerAction::NeedsBody`], except an `UploadPart` that can
///   be streamed.
///
/// # Errors
///
/// Returns `ProxyError::Internal` in two cases: a bucket-targeted operation
/// arrives without a bucket config, or the operation is not handled here.
/// Other errors come from the bucket registry, backend signing, upload-size
/// checks, and backend-type checks.
async fn dispatch_operation(
    &self,
    ctx: &DispatchContext<'_>,
) -> Result<HandlerAction, ProxyError> {
    // Headers forwarded on GET/HEAD so range and conditional requests work.
    const READ_HEADERS: &[&str] = &[
        "range",
        "if-match",
        "if-none-match",
        "if-modified-since",
        "if-unmodified-since",
    ];
    // Content and metadata headers forwarded on single-request PUTs.
    const PUT_HEADERS: &[&str] = &[
        "content-type",
        "content-length",
        "content-md5",
        "content-disposition",
        "content-encoding",
        "content-language",
        "cache-control",
        "expires",
    ];

    let original_headers = ctx.headers;
    let list_rewrite = ctx.list_rewrite;
    let request_id = ctx.request_id;
    let operation = ctx.operation;

    if matches!(operation, S3Operation::ListBuckets) {
        let buckets = self.bucket_registry.list_buckets(ctx.identity).await?;
        tracing::info!(count = buckets.len(), "listing virtual buckets");
        let xml = ListAllMyBucketsResult {
            owner: self.bucket_registry.bucket_owner(),
            buckets: BucketList { buckets },
        }
        .to_xml();
        let mut resp_headers = HeaderMap::new();
        resp_headers.insert("content-type", "application/xml".parse().unwrap());
        return Ok(HandlerAction::Response(ProxyResult {
            status: 200,
            headers: resp_headers,
            body: ProxyResponseBody::from_bytes(Bytes::from(xml)),
        }));
    }

    // Report a missing bucket config as an internal error instead of panicking
    // the request task (e.g. if a middleware altered the context).
    let bucket_config = ctx.bucket_config.as_deref().ok_or_else(|| {
        ProxyError::Internal("bucket_config must be set for bucket-targeted operations".into())
    })?;

    let pending = || PendingRequest {
        operation: operation.clone(),
        bucket_config: bucket_config.clone(),
        original_headers: original_headers.clone(),
        request_id: request_id.to_string(),
        identity: ctx.identity.clone(),
    };

    match operation {
        S3Operation::GetObject { key, .. } => {
            let fwd = self
                .build_forward(
                    Method::GET,
                    bucket_config,
                    key,
                    original_headers,
                    READ_HEADERS,
                    request_id,
                )
                .await?;
            tracing::debug!(path = fwd.url.path(), "GET via presigned URL");
            Ok(HandlerAction::Forward(fwd))
        }
        S3Operation::HeadObject { key, .. } => {
            let fwd = self
                .build_forward(
                    Method::HEAD,
                    bucket_config,
                    key,
                    original_headers,
                    READ_HEADERS,
                    request_id,
                )
                .await?;
            tracing::debug!(path = fwd.url.path(), "HEAD via presigned URL");
            Ok(HandlerAction::Forward(fwd))
        }
        S3Operation::PutObject { key, .. } => {
            self.check_upload_size(original_headers)?;
            if let Some(fwd) = self
                .try_streaming_forward(bucket_config, operation, original_headers, request_id)
                .await?
            {
                return Ok(HandlerAction::Forward(fwd));
            }
            let fwd = self
                .build_forward(
                    Method::PUT,
                    bucket_config,
                    key,
                    original_headers,
                    PUT_HEADERS,
                    request_id,
                )
                .await?;
            tracing::debug!(path = fwd.url.path(), "PUT via presigned URL");
            Ok(HandlerAction::Forward(fwd))
        }
        S3Operation::DeleteObject { key, .. } => {
            let fwd = self
                .build_forward(
                    Method::DELETE,
                    bucket_config,
                    key,
                    original_headers,
                    &[],
                    request_id,
                )
                .await?;
            tracing::debug!(path = fwd.url.path(), "DELETE via presigned URL");
            Ok(HandlerAction::Forward(fwd))
        }
        S3Operation::ListBucket { raw_query, .. } => {
            let result = self
                .handle_list(
                    bucket_config,
                    raw_query.as_deref(),
                    list_rewrite,
                    ctx.display_name,
                )
                .await?;
            Ok(HandlerAction::Response(result))
        }
        S3Operation::UploadPart { .. } => {
            Self::require_s3_backend(bucket_config)?;
            self.check_upload_size(original_headers)?;
            if let Some(fwd) = self
                .try_streaming_forward(bucket_config, operation, original_headers, request_id)
                .await?
            {
                return Ok(HandlerAction::Forward(fwd));
            }
            Ok(HandlerAction::NeedsBody(pending()))
        }
        S3Operation::CreateMultipartUpload { .. }
        | S3Operation::CompleteMultipartUpload { .. }
        | S3Operation::AbortMultipartUpload { .. } => {
            Self::require_s3_backend(bucket_config)?;
            Ok(HandlerAction::NeedsBody(pending()))
        }
        S3Operation::DeleteObjects { .. } => {
            if !bucket_config.is_s3_backend() {
                return Err(ProxyError::NotImplemented(format!(
                    "batch delete not supported for '{}' backends",
                    bucket_config.backend_type
                )));
            }
            self.check_upload_size(original_headers)?;
            Ok(HandlerAction::NeedsBody(pending()))
        }
        _ => Err(ProxyError::Internal("unexpected operation".into())),
    }
}
/// Builds a [`ForwardRequest`] for a single-object operation, using a
/// presigned backend URL valid for [`PRESIGNED_URL_TTL`].
///
/// Only the client headers listed in `forward_header_names` are copied; the
/// gateway's User-Agent is always added.
///
/// # Errors
///
/// Fails if:
/// - the backend signer cannot be created;
/// - the URL cannot be signed;
/// - the configured User-Agent is not a valid HTTP header value.
async fn build_forward(
    &self,
    method: Method,
    config: &BucketConfig,
    key: &str,
    original_headers: &HeaderMap,
    forward_header_names: &[&'static str],
    request_id: &str,
) -> Result<ForwardRequest, ProxyError> {
    let signer = self.backend.create_signer(config)?;
    let path = build_object_path(config, key);
    let url = signer
        .signed_url(method.clone(), &path, PRESIGNED_URL_TTL)
        .await
        .map_err(ProxyError::from_object_store_error)?;
    let mut fwd_headers = HeaderMap::new();
    for name in forward_header_names {
        if let Some(v) = original_headers.get(*name) {
            fwd_headers.insert(*name, v.clone());
        }
    }
    // `with_user_agent` accepts any string; surface an invalid value as an
    // error rather than panicking on every request.
    let user_agent: http::HeaderValue = self
        .user_agent
        .parse()
        .map_err(|_| ProxyError::Internal("invalid user-agent header".into()))?;
    fwd_headers.insert(http::header::USER_AGENT, user_agent);
    Ok(ForwardRequest {
        method,
        url,
        headers: fwd_headers,
        request_id: request_id.to_string(),
    })
}
async fn try_streaming_forward(
&self,
config: &BucketConfig,
operation: &S3Operation,
original_headers: &HeaderMap,
request_id: &str,
) -> Result<Option<ForwardRequest>, ProxyError> {
match crate::aws_chunked::streaming_upload(original_headers) {
Some((crate::aws_chunked::StreamingUpload::Unsigned, sentinel)) => {
if !config.is_s3_backend() {
return Err(ProxyError::InvalidRequest(format!(
"aws-chunked streaming uploads are not supported for '{}' backends",
config.backend_type
)));
}
Ok(Some(
self.build_streaming_forward(
config,
operation,
sentinel,
original_headers,
request_id,
)
.await?,
))
}
Some((crate::aws_chunked::StreamingUpload::Signed, _)) => Err(
ProxyError::NotImplemented(SIGNED_AWS_CHUNKED_UNSUPPORTED.to_string()),
),
None => Ok(None),
}
}
/// Ensures `config` points at an S3 backend; multipart operations are only
/// supported there.
fn require_s3_backend(config: &BucketConfig) -> Result<(), ProxyError> {
    if !config.is_s3_backend() {
        return Err(ProxyError::InvalidRequest(format!(
            "multipart operations not supported for '{}' backends",
            config.backend_type
        )));
    }
    Ok(())
}
/// Builds a PUT [`ForwardRequest`] that streams an aws-chunked upload to the
/// backend.
///
/// The request is re-signed with `sign_s3_request`, using `payload_hash` (the
/// client's streaming sentinel) as the payload hash. Only a fixed set of
/// content / aws-chunked headers is carried over before signing.
/// `Content-Length` and User-Agent are added after signing.
///
/// # Errors
///
/// Fails if:
/// - the backend URL cannot be built or parsed;
/// - signing fails;
/// - the configured User-Agent is not a valid HTTP header value.
async fn build_streaming_forward(
    &self,
    config: &BucketConfig,
    operation: &S3Operation,
    payload_hash: &str,
    original_headers: &HeaderMap,
    request_id: &str,
) -> Result<ForwardRequest, ProxyError> {
    let url = url::Url::parse(&build_backend_url(config, operation)?)
        .map_err(|e| ProxyError::Internal(format!("invalid backend URL: {e}")))?;
    let mut headers = HeaderMap::new();
    for name in &[
        "content-type",
        "content-encoding",
        "x-amz-decoded-content-length",
        "x-amz-trailer",
    ] {
        if let Some(v) = original_headers.get(*name) {
            headers.insert(*name, v.clone());
        }
    }
    sign_s3_request(
        &Method::PUT,
        url.as_str(),
        &mut headers,
        config,
        payload_hash,
    )?;
    if let Some(cl) = original_headers.get(http::header::CONTENT_LENGTH) {
        headers.insert(http::header::CONTENT_LENGTH, cl.clone());
    }
    // `with_user_agent` accepts any string; surface an invalid value as an
    // error rather than panicking on every request.
    let user_agent: http::HeaderValue = self
        .user_agent
        .parse()
        .map_err(|_| ProxyError::Internal("invalid user-agent header".into()))?;
    headers.insert(http::header::USER_AGENT, user_agent);
    tracing::debug!(path = url.path(), "aws-chunked write via streaming re-sign");
    Ok(ForwardRequest {
        method: Method::PUT,
        url,
        headers,
        request_id: request_id.to_string(),
    })
}
/// Serves ListObjects (v1) and ListObjectsV2 for `config` from the backend's
/// paginated list API and renders the S3 XML response.
///
/// The client prefix and the start position are passed through
/// `build_list_prefix` (presumably adding the bucket's backend prefix — TODO
/// confirm). The start position is `start-after` for v2 and `marker` for v1.
/// Both versions use `continuation_token` as the backend page token. The
/// result is reported as truncated exactly when the backend returned another
/// page token. `display_name`, when present, replaces the config name in the
/// XML.
async fn handle_list(
&self,
config: &BucketConfig,
raw_query: Option<&str>,
list_rewrite: Option<&ListRewrite>,
display_name: Option<&str>,
) -> Result<ProxyResult, ProxyError> {
let store = self.backend.create_paginated_store(config)?;
let list_params = parse_list_query_params(raw_query);
let client_prefix = &list_params.prefix;
let delimiter = &list_params.delimiter;
let full_prefix = build_list_prefix(config, client_prefix);
// Start position (exclusive offset) for the backend listing.
let offset = if list_params.is_v2 {
list_params
.start_after
.as_ref()
.map(|sa| build_list_prefix(config, sa))
} else {
list_params
.marker
.as_ref()
.map(|m| build_list_prefix(config, m))
};
tracing::debug!(
full_prefix = %full_prefix,
delimiter = %delimiter,
max_keys = list_params.max_keys,
has_page_token = list_params.continuation_token.is_some(),
"LIST via PaginatedListStore"
);
// An empty prefix or delimiter is sent to the backend as "not set".
let prefix = if full_prefix.is_empty() {
None
} else {
Some(full_prefix.as_str())
};
let opts = PaginatedListOptions {
offset,
delimiter: if delimiter.is_empty() {
None
} else {
Some(Cow::Owned(delimiter.clone()))
},
max_keys: Some(list_params.max_keys),
page_token: list_params.continuation_token.clone(),
..Default::default()
};
let paginated = store
.list_paginated(prefix, opts)
.await
.map_err(ProxyError::from_object_store_error)?;
let bucket_name = display_name.unwrap_or(&config.name);
let is_truncated = paginated.page_token.is_some();
let xml = if list_params.is_v2 {
// KeyCount counts both objects and common prefixes on this page.
let key_count = paginated.result.objects.len() + paginated.result.common_prefixes.len();
build_list_xml(
&ListXmlParams {
bucket_name,
client_prefix,
delimiter,
max_keys: list_params.max_keys,
is_truncated,
key_count,
start_after: &list_params.start_after,
continuation_token: &list_params.continuation_token,
next_continuation_token: paginated.page_token,
encoding_type: &list_params.encoding_type,
},
&paginated.result,
config,
list_rewrite,
)?
} else {
// NOTE(review): NextMarker is the raw backend location of the last object.
// - If `build_list_xml_v1` does not strip the backend prefix from it, a
//   client echoing it back as `marker` would be prefixed twice by
//   `build_list_prefix` above.
// - A truncated page ending in (or consisting only of) common prefixes
//   yields a NextMarker that ignores those prefixes.
// Confirm against `build_list_xml_v1`.
let next_marker = if is_truncated {
paginated
.result
.objects
.last()
.map(|obj| obj.location.to_string())
} else {
None
};
build_list_xml_v1(
&ListXmlParamsV1 {
bucket_name,
client_prefix,
delimiter,
max_keys: list_params.max_keys,
is_truncated,
marker: list_params.marker.as_deref().unwrap_or(""),
next_marker,
encoding_type: &list_params.encoding_type,
},
&paginated.result,
config,
list_rewrite,
)?
};
let mut resp_headers = HeaderMap::new();
resp_headers.insert("content-type", "application/xml".parse().unwrap());
Ok(ProxyResult {
status: 200,
headers: resp_headers,
body: ProxyResponseBody::Bytes(Bytes::from(xml)),
})
}
async fn execute_multipart(
&self,
pending: &PendingRequest,
body: Bytes,
) -> Result<ProxyResult, ProxyError> {
let backend_url = build_backend_url(&pending.bucket_config, &pending.operation)?;
tracing::debug!(backend_url = %backend_url, "multipart via raw HTTP");
let mut headers = HeaderMap::new();
for (name, val) in pending.original_headers.iter() {
let n = name.as_str();
if matches!(n, "content-type" | "content-length" | "content-md5")
|| n.starts_with("x-amz-checksum")
|| n == "x-amz-sdk-checksum-algorithm"
{
headers.insert(name.clone(), val.clone());
}
}
headers.insert(http::header::USER_AGENT, self.user_agent.parse().unwrap());
let payload_hash = if body.is_empty() {
UNSIGNED_PAYLOAD.to_string()
} else {
hash_payload(&body)
};
let method = pending.operation.method();
sign_s3_request(
&method,
&backend_url,
&mut headers,
&pending.bucket_config,
&payload_hash,
)?;
let raw_resp = self
.backend
.send_raw(method, backend_url, headers, body)
.await?;
tracing::debug!(status = raw_resp.status, "multipart backend response");
Ok(ProxyResult {
status: raw_resp.status,
headers: filter_response_headers(&raw_resp.headers),
body: ProxyResponseBody::from_bytes(raw_resp.body),
})
}
/// Executes an S3 `DeleteObjects` (batch delete) request.
///
/// Each requested key is authorized individually. Denied keys are reported as
/// per-key `AccessDenied` errors and never reach the backend. Allowed keys get
/// the backend prefix applied and are sent in one signed POST. Keys in the
/// backend's result are mapped back to client key space. A 200 response with
/// the combined `DeleteResult` XML is returned (honouring `quiet`).
///
/// # Errors
///
/// Fails if:
/// - the request body cannot be parsed;
/// - the backend URL cannot be built, or signing fails;
/// - the configured User-Agent is not a valid header value;
/// - the backend answers with status >= 300, or its result cannot be parsed.
async fn execute_delete_objects(
    &self,
    pending: &PendingRequest,
    body: Bytes,
) -> Result<ProxyResult, ProxyError> {
    use crate::api::delete;
    let config = &pending.bucket_config;
    let bucket = pending.operation.bucket().unwrap_or_default();
    let request = delete::DeleteRequest::parse(&body)?;
    let quiet = request.quiet;
    let mut allowed_backend: Vec<String> = Vec::new();
    let mut errors: Vec<delete::DeleteError> = Vec::new();
    for key in request.keys() {
        if self
            .bucket_registry
            .authorize_key(bucket, &pending.identity, Action::DeleteObject, key)
            .await
        {
            allowed_backend.push(apply_backend_prefix(config, key));
        } else {
            errors.push(delete::DeleteError {
                key: key.to_string(),
                code: "AccessDenied".into(),
                message: "Access Denied".into(),
            });
        }
    }
    let mut deleted_client: Vec<String> = Vec::new();
    if !allowed_backend.is_empty() {
        let backend_body = Bytes::from(delete::build_backend_delete_body(&allowed_backend));
        let backend_url = build_backend_url(config, &pending.operation)?;
        let mut headers = HeaderMap::new();
        headers.insert("content-type", "application/xml".parse().unwrap());
        headers.insert(
            "content-md5",
            content_md5(&backend_body)
                .parse()
                .map_err(|_| ProxyError::Internal("invalid content-md5 header".into()))?,
        );
        // `with_user_agent` accepts any string; surface an invalid value as an
        // error rather than panicking on every request.
        let user_agent: http::HeaderValue = self
            .user_agent
            .parse()
            .map_err(|_| ProxyError::Internal("invalid user-agent header".into()))?;
        headers.insert(http::header::USER_AGENT, user_agent);
        let payload_hash = hash_payload(&backend_body);
        sign_s3_request(
            &Method::POST,
            &backend_url,
            &mut headers,
            config,
            &payload_hash,
        )?;
        let raw_resp = self
            .backend
            .send_raw(Method::POST, backend_url, headers, backend_body)
            .await?;
        tracing::debug!(status = raw_resp.status, "batch delete backend response");
        if raw_resp.status >= 300 {
            return Err(ProxyError::BackendError(format!(
                "backend rejected batch delete with status {}",
                raw_resp.status
            )));
        }
        match delete::parse_backend_result(&raw_resp.body) {
            Ok(outcome) => {
                for k in outcome.deleted {
                    deleted_client.push(strip_backend_prefix(config, &k));
                }
                for mut e in outcome.errors {
                    e.key = strip_backend_prefix(config, &e.key);
                    errors.push(e);
                }
            }
            Err(e) => {
                tracing::error!(error = %e, "backend returned an unparseable delete result");
                return Err(ProxyError::BackendError(
                    "backend returned an unparseable delete result".into(),
                ));
            }
        }
    }
    let xml = delete::build_delete_result(&deleted_client, &errors, quiet);
    let mut resp_headers = HeaderMap::new();
    resp_headers.insert("content-type", "application/xml".parse().unwrap());
    Ok(ProxyResult {
        status: 200,
        headers: resp_headers,
        body: ProxyResponseBody::from_bytes(Bytes::from(xml)),
    })
}
}
/// Lets the gateway serve as the final [`Dispatch`] target of the middleware
/// chain (it is handed to `Next::new` together with the middleware list).
impl<B, R, C> Dispatch for ProxyGateway<B, R, C>
where
    B: ProxyBackend,
    R: BucketRegistry,
    C: CredentialRegistry,
{
    fn dispatch<'a>(&'a self, ctx: DispatchContext<'a>) -> DispatchFuture<'a> {
        // The context is moved into the future so it lives as long as the work.
        let work = async move { self.dispatch_operation(&ctx).await };
        Box::pin(work)
    }
}
/// Decides between virtual-hosted-style (`<bucket>.<domain>`) and path-style
/// addressing.
///
/// Virtual-hosted style requires all of the following:
/// - `virtual_host_domain` is configured;
/// - the `Host` header (port removed) ends in `.<domain>`;
/// - a non-empty bucket label precedes that suffix.
///
/// The domain comparison is ASCII case-insensitive because host names are
/// case-insensitive. Everything else is path-style.
fn determine_host_style(headers: &HeaderMap, virtual_host_domain: Option<&str>) -> HostStyle {
    let bucket = virtual_host_domain.and_then(|domain| {
        let host = headers.get("host").and_then(|v| v.to_str().ok())?;
        virtual_host_bucket(host, domain)
    });
    match bucket {
        Some(bucket) => HostStyle::VirtualHosted {
            bucket: bucket.to_string(),
        },
        None => HostStyle::Path,
    }
}

/// Returns the bucket label of `host` when it has the form
/// `<bucket>.<domain>[:port]` (domain matched ASCII case-insensitively).
///
/// Returns `None` when the host does not have that form or the label is empty.
fn virtual_host_bucket<'a>(host: &'a str, domain: &str) -> Option<&'a str> {
    let host = host.split(':').next().unwrap_or(host);
    // Length of the bucket label: everything before ".<domain>".
    let bucket_len = host.len().checked_sub(domain.len() + 1)?;
    if bucket_len == 0 {
        return None;
    }
    let bucket = host.get(..bucket_len)?;
    let suffix_domain = host.get(bucket_len..)?.strip_prefix('.')?;
    suffix_domain.eq_ignore_ascii_case(domain).then_some(bucket)
}
/// Renders `err` as an S3 XML error document with the error's HTTP status and
/// an `application/xml` content type.
fn error_response(err: &ProxyError, resource: &str, request_id: &str, debug: bool) -> ProxyResult {
    let mut headers = HeaderMap::new();
    headers.insert("content-type", "application/xml".parse().unwrap());
    let document = ErrorResponse::from_proxy_error(err, resource, request_id, debug).to_xml();
    ProxyResult {
        status: err.status_code(),
        headers,
        body: ProxyResponseBody::from_bytes(Bytes::from(document)),
    }
}
/// Converts a client object key into a backend `object_store` path, with the
/// bucket's backend prefix applied.
fn build_object_path(config: &BucketConfig, key: &str) -> object_store::path::Path {
    let backend_key = apply_backend_prefix(config, key);
    backend_key.into()
}
/// Parses the `Content-Length` header as `u64`.
///
/// Returns `None` if the header is missing, is not visible ASCII, or is not a
/// number.
fn content_length(headers: &HeaderMap) -> Option<u64> {
    let raw = headers.get(http::header::CONTENT_LENGTH)?;
    raw.to_str().ok()?.parse().ok()
}
/// Maps a client key to its backend key by prepending the bucket's backend
/// prefix (trailing slashes trimmed) and a `/` separator.
///
/// The key is returned unchanged when no prefix is set or the prefix is empty.
fn apply_backend_prefix(config: &BucketConfig, key: &str) -> String {
    let prefix = config
        .backend_prefix
        .as_deref()
        .map(|p| p.trim_end_matches('/'))
        .unwrap_or_default();
    if prefix.is_empty() {
        key.to_string()
    } else {
        format!("{prefix}/{key}")
    }
}
/// Inverse of `apply_backend_prefix`: removes `<prefix>/` from a backend key.
///
/// Keys that do not start with `<prefix>/` are returned unchanged, as are all
/// keys when no (or an empty) prefix is configured.
fn strip_backend_prefix(config: &BucketConfig, key: &str) -> String {
    let prefix = config
        .backend_prefix
        .as_deref()
        .map_or("", |p| p.trim_end_matches('/'));
    if prefix.is_empty() {
        return key.to_owned();
    }
    match key.strip_prefix(prefix).and_then(|rest| rest.strip_prefix('/')) {
        Some(client_key) => client_key.to_owned(),
        None => key.to_owned(),
    }
}
/// Computes the `Content-MD5` header value: the standard base64 encoding of
/// the body's MD5 digest.
fn content_md5(body: &[u8]) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine as _};
    use md5::{Digest, Md5};
    let digest = Md5::digest(body);
    STANDARD.encode(digest)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::api::response::BucketEntry;
use crate::backend::RawResponse;
use crate::registry::{BucketRegistry, CredentialRegistry, ResolvedBucket};
use crate::types::{ResolvedIdentity, RoleConfig, StoredCredential};
use object_store::list::PaginatedListStore;
use object_store::signer::Signer;
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Clone)]
struct MockBackend;
impl ProxyBackend for MockBackend {
type ResponseBody = ();
type Body = ();
async fn forward(
&self,
_request: ForwardRequest,
_body: (),
) -> Result<ForwardResponse<()>, ProxyError> {
unimplemented!("not needed for resolve_request tests")
}
fn create_paginated_store(
&self,
_config: &BucketConfig,
) -> Result<Box<dyn PaginatedListStore>, ProxyError> {
unimplemented!("not needed for forward tests")
}
fn create_signer(&self, config: &BucketConfig) -> Result<Arc<dyn Signer>, ProxyError> {
crate::backend::build_signer(config)
}
async fn send_raw(
&self,
_method: http::Method,
_url: String,
_headers: HeaderMap,
_body: Bytes,
) -> Result<RawResponse, ProxyError> {
unimplemented!("not needed for forward tests")
}
}
#[derive(Clone)]
struct MockRegistry;
impl BucketRegistry for MockRegistry {
async fn get_bucket(
&self,
name: &str,
_identity: &ResolvedIdentity,
_operation: &S3Operation,
) -> Result<ResolvedBucket, ProxyError> {
Ok(ResolvedBucket {
config: test_bucket_config(name),
list_rewrite: None,
display_name: None,
})
}
async fn list_buckets(
&self,
_identity: &ResolvedIdentity,
) -> Result<Vec<BucketEntry>, ProxyError> {
Ok(vec![])
}
}
#[derive(Clone)]
struct MockCreds;
impl CredentialRegistry for MockCreds {
async fn get_credential(
&self,
_access_key_id: &str,
) -> Result<Option<StoredCredential>, ProxyError> {
Ok(None)
}
async fn get_role(&self, _role_id: &str) -> Result<Option<RoleConfig>, ProxyError> {
Ok(None)
}
}
fn test_bucket_config(name: &str) -> BucketConfig {
let mut backend_options = HashMap::new();
backend_options.insert(
"endpoint".into(),
"https://s3.us-east-1.amazonaws.com".into(),
);
backend_options.insert("bucket_name".into(), "backend-bucket".into());
backend_options.insert("region".into(), "us-east-1".into());
backend_options.insert("access_key_id".into(), "AKIAIOSFODNN7EXAMPLE".into());
backend_options.insert(
"secret_access_key".into(),
"wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".into(),
);
let backend_type = if name.starts_with("azure") {
"azure"
} else {
"s3"
};
BucketConfig {
name: name.to_string(),
backend_type: backend_type.into(),
backend_prefix: None,
anonymous_access: true,
allowed_roles: vec![],
backend_options,
}
}
fn run<F: std::future::Future>(f: F) -> F::Output {
futures::executor::block_on(f)
}
fn gateway() -> ProxyGateway<MockBackend, MockRegistry, MockCreds> {
ProxyGateway::new(MockBackend, MockRegistry, MockCreds, None)
}
#[test]
fn get_forward_preserves_range_header() {
run(async {
let gw = gateway();
let mut headers = HeaderMap::new();
headers.insert("range", "bytes=0-99".parse().unwrap());
let action = gw
.resolve_request(Method::GET, "/test-bucket/key.txt", None, &headers, None)
.await;
match action {
HandlerAction::Forward(fwd) => {
assert_eq!(fwd.method, Method::GET);
assert_eq!(
fwd.headers.get("range").map(|v| v.to_str().unwrap()),
Some("bytes=0-99"),
"GET forward should pass through the Range header"
);
}
other => panic!("expected Forward, got {:?}", std::mem::discriminant(&other)),
}
});
}
#[test]
fn head_forward_preserves_range_header() {
run(async {
let gw = gateway();
let mut headers = HeaderMap::new();
headers.insert("range", "bytes=0-1023".parse().unwrap());
let action = gw
.resolve_request(Method::HEAD, "/test-bucket/key.txt", None, &headers, None)
.await;
match action {
HandlerAction::Forward(fwd) => {
assert_eq!(fwd.method, Method::HEAD);
assert_eq!(
fwd.headers.get("range").map(|v| v.to_str().unwrap()),
Some("bytes=0-1023"),
"HEAD forward should pass through the Range header"
);
}
other => panic!("expected Forward, got {:?}", std::mem::discriminant(&other)),
}
});
}
#[test]
fn forward_includes_user_agent_header() {
run(async {
let gw = gateway();
let headers = HeaderMap::new();
let action = gw
.resolve_request(Method::GET, "/test-bucket/key.txt", None, &headers, None)
.await;
match action {
HandlerAction::Forward(fwd) => {
let ua = fwd
.headers
.get(http::header::USER_AGENT)
.expect("forward should include User-Agent header");
assert!(
ua.to_str().unwrap().starts_with("multistore/"),
"User-Agent should start with 'multistore/', got: {}",
ua.to_str().unwrap()
);
}
other => panic!("expected Forward, got {:?}", std::mem::discriminant(&other)),
}
});
}
/// A forwarded PUT must identify the proxy with the gateway's default User-Agent.
///
/// This pins the exact value (`DEFAULT_USER_AGENT`) instead of only a
/// `multistore/` prefix, consistent with the DELETE variant of this check.
#[test]
fn put_forward_includes_user_agent_header() {
    run(async {
        let gw = gateway();
        let mut headers = HeaderMap::new();
        headers.insert("content-type", "application/octet-stream".parse().unwrap());
        let action = gw
            .resolve_request(Method::PUT, "/test-bucket/key.txt", None, &headers, None)
            .await;
        match action {
            HandlerAction::Forward(fwd) => {
                let ua = fwd
                    .headers
                    .get(http::header::USER_AGENT)
                    .expect("PUT forward should include User-Agent header")
                    .to_str()
                    .unwrap();
                assert_eq!(
                    ua, DEFAULT_USER_AGENT,
                    "PUT forward should use the default User-Agent"
                );
            }
            other => panic!("expected Forward, got {:?}", std::mem::discriminant(&other)),
        }
    });
}
/// A forwarded DELETE must carry exactly the default User-Agent string.
#[test]
fn delete_forward_includes_user_agent_header() {
    let gw = gateway();
    let headers = HeaderMap::new();

    let action = run(gw.resolve_request(
        Method::DELETE,
        "/test-bucket/key.txt",
        None,
        &headers,
        None,
    ));
    let fwd = match action {
        HandlerAction::Forward(fwd) => fwd,
        other => panic!("expected Forward, got {:?}", std::mem::discriminant(&other)),
    };

    let user_agent = fwd
        .headers
        .get(http::header::USER_AGENT)
        .expect("DELETE forward should include User-Agent header")
        .to_str()
        .unwrap();
    assert_eq!(user_agent, DEFAULT_USER_AGENT);
}
/// A User-Agent configured via `with_user_agent` replaces the default on forwards.
#[test]
fn custom_user_agent_is_used_in_forward() {
    let gw = gateway().with_user_agent("myapp/1.0 multistore/0.2.0");
    let headers = HeaderMap::new();

    let action = run(gw.resolve_request(
        Method::GET,
        "/test-bucket/key.txt",
        None,
        &headers,
        None,
    ));
    let fwd = match action {
        HandlerAction::Forward(fwd) => fwd,
        other => panic!("expected Forward, got {:?}", std::mem::discriminant(&other)),
    };

    let user_agent = fwd
        .headers
        .get(http::header::USER_AGENT)
        .expect("forward should include User-Agent header")
        .to_str()
        .unwrap();
    assert_eq!(user_agent, "myapp/1.0 multistore/0.2.0");
}
/// A CreateMultipartUpload (`POST ?uploads`) must be deferred until the body is read.
///
/// NOTE(review): despite its name, this test only asserts that the gateway asks
/// for the request body (`NeedsBody`). It does not inspect the User-Agent of the
/// backend request that is eventually sent.
#[test]
fn multipart_needs_body_then_includes_user_agent() {
    let gw = gateway();
    let headers = HeaderMap::new();

    let action = run(gw.resolve_request(
        Method::POST,
        "/test-bucket/key.txt",
        Some("uploads"),
        &headers,
        None,
    ));

    let deferred = matches!(action, HandlerAction::NeedsBody(_));
    assert!(deferred, "CreateMultipartUpload should return NeedsBody");
}
/// A PUT whose declared `content-length` exceeds the configured limit is
/// answered directly with a 400 instead of being forwarded.
#[test]
fn put_over_max_body_size_is_rejected() {
    let gw = gateway().with_max_request_body_size(1024);
    let mut headers = HeaderMap::new();
    headers.insert("content-length", "2048".parse().unwrap());

    let action = run(gw.resolve_request(
        Method::PUT,
        "/test-bucket/big.bin",
        None,
        &headers,
        None,
    ));
    let rejection = match action {
        HandlerAction::Response(r) => r,
        other => panic!(
            "expected Response, got {:?}",
            std::mem::discriminant(&other)
        ),
    };

    assert_eq!(
        rejection.status, 400,
        "oversized PUT should be rejected with EntityTooLarge (400)"
    );
}
/// A PUT whose declared `content-length` is within the limit is forwarded.
#[test]
fn put_under_max_body_size_forwards() {
    let gw = gateway().with_max_request_body_size(1_000_000);
    let mut headers = HeaderMap::new();
    headers.insert("content-length", "1024".parse().unwrap());

    let action = run(gw.resolve_request(
        Method::PUT,
        "/test-bucket/ok.bin",
        None,
        &headers,
        None,
    ));

    let forwarded = matches!(action, HandlerAction::Forward(_));
    assert!(forwarded, "PUT within the limit should forward");
}
/// With no body-size limit configured on the default gateway, even a very
/// large declared `content-length` is forwarded rather than rejected.
#[test]
fn put_with_no_limit_forwards_large_body() {
    run(async {
        // One statement per line; the original packed two `let`s onto a single
        // line, which rustfmt would reject.
        let gw = gateway();
        let mut headers = HeaderMap::new();
        headers.insert("content-length", "999999999".parse().unwrap());
        let action = gw
            .resolve_request(Method::PUT, "/test-bucket/huge.bin", None, &headers, None)
            .await;
        assert!(
            matches!(action, HandlerAction::Forward(_)),
            "with no limit configured, large PUT should still forward"
        );
    });
}
/// Request headers for an unsigned `aws-chunked` upload with a trailing
/// `x-amz-checksum-crc64nvme` checksum: 52 encoded bytes, 7 decoded bytes.
fn unsigned_aws_chunked_headers() -> HeaderMap {
    let mut map = HeaderMap::new();
    for (name, value) in [
        ("content-encoding", "aws-chunked"),
        ("x-amz-content-sha256", "STREAMING-UNSIGNED-PAYLOAD-TRAILER"),
        ("content-length", "52"),
        ("x-amz-decoded-content-length", "7"),
        ("x-amz-trailer", "x-amz-checksum-crc64nvme"),
    ] {
        map.insert(name, value.parse().unwrap());
    }
    map
}
/// An unsigned `aws-chunked` PUT is forwarded as a stream. The streaming
/// payload markers are kept and a fresh `authorization` header is attached.
#[test]
fn put_unsigned_aws_chunked_streams_via_resign() {
    let gw = gateway();
    let headers = unsigned_aws_chunked_headers();

    let action = run(gw.resolve_request(
        Method::PUT,
        "/test-bucket/test.md",
        None,
        &headers,
        None,
    ));
    let fwd = match action {
        HandlerAction::Forward(fwd) => fwd,
        other => panic!("expected Forward, got {:?}", std::mem::discriminant(&other)),
    };

    assert_eq!(fwd.method, Method::PUT);
    let sent = &fwd.headers;
    assert_eq!(
        sent.get("x-amz-content-sha256").unwrap(),
        "STREAMING-UNSIGNED-PAYLOAD-TRAILER"
    );
    assert_eq!(sent.get("content-encoding").unwrap(), "aws-chunked");
    assert!(sent.contains_key("x-amz-decoded-content-length"));
    assert!(sent.contains_key("authorization"));
}
/// An `aws-chunked` PUT with per-chunk signatures is refused with a 501.
#[test]
fn put_signed_aws_chunked_is_rejected() {
    let gw = gateway();
    let mut headers = HeaderMap::new();
    headers.insert("content-encoding", "aws-chunked".parse().unwrap());
    headers.insert(
        "x-amz-content-sha256",
        "STREAMING-AWS4-HMAC-SHA256-PAYLOAD".parse().unwrap(),
    );

    let action = run(gw.resolve_request(
        Method::PUT,
        "/test-bucket/test.md",
        None,
        &headers,
        None,
    ));
    let rejection = match action {
        HandlerAction::Response(r) => r,
        other => panic!(
            "expected Response(501), got {:?}",
            std::mem::discriminant(&other)
        ),
    };

    assert_eq!(
        rejection.status, 501,
        "signed aws-chunked uploads should be rejected with NotImplemented"
    );
}
/// An unsigned `aws-chunked` UploadPart is streamed (forwarded, not buffered).
/// The forward keeps its `partNumber`/`uploadId` query and the streaming
/// payload marker.
#[test]
fn upload_part_unsigned_aws_chunked_streams_via_resign() {
    let gw = gateway();
    let headers = unsigned_aws_chunked_headers();

    let action = run(gw.resolve_request(
        Method::PUT,
        "/test-bucket/key.bin",
        Some("partNumber=1&uploadId=abc"),
        &headers,
        None,
    ));
    let fwd = match action {
        HandlerAction::Forward(fwd) => fwd,
        other => panic!(
            "expected Forward (stream via re-sign, not buffer), got {:?}",
            std::mem::discriminant(&other)
        ),
    };

    let q = fwd.url.query().unwrap_or("");
    let carries_part_params = q.contains("partNumber=1") && q.contains("uploadId=abc");
    assert!(
        carries_part_params,
        "UploadPart forward must carry partNumber/uploadId, got query {q:?}"
    );
    assert_eq!(
        fwd.headers.get("x-amz-content-sha256").unwrap(),
        "STREAMING-UNSIGNED-PAYLOAD-TRAILER"
    );
}
/// An `aws-chunked` PUT aimed at the `azure-bucket` test bucket is answered
/// with a 400 rather than being forwarded.
#[test]
fn streaming_put_on_non_s3_backend_is_rejected() {
    let gw = gateway();
    let headers = unsigned_aws_chunked_headers();

    let action = run(gw.resolve_request(
        Method::PUT,
        "/azure-bucket/test.md",
        None,
        &headers,
        None,
    ));
    let rejection = match action {
        HandlerAction::Response(r) => r,
        other => panic!(
            "expected Response(400), got {:?}",
            std::mem::discriminant(&other)
        ),
    };

    assert_eq!(
        rejection.status, 400,
        "aws-chunked PUT to a non-S3 backend should be rejected, not mis-signed"
    );
}
/// The body-size limit also applies to UploadPart: an oversized part gets a 400.
#[test]
fn upload_part_over_max_body_size_is_rejected() {
    let gw = gateway().with_max_request_body_size(1024);
    let mut headers = HeaderMap::new();
    headers.insert("content-length", "5000".parse().unwrap());

    let action = run(gw.resolve_request(
        Method::PUT,
        "/test-bucket/key.bin",
        Some("partNumber=1&uploadId=abc"),
        &headers,
        None,
    ));
    let rejection = match action {
        HandlerAction::Response(r) => r,
        other => panic!(
            "expected Response, got {:?}",
            std::mem::discriminant(&other)
        ),
    };

    assert_eq!(
        rejection.status, 400,
        "oversized UploadPart should be rejected with EntityTooLarge (400)"
    );
}
/// Test middleware that never calls `next`. It answers every request itself
/// with an empty-bodied 429.
struct BlockMiddleware;

impl crate::middleware::Middleware for BlockMiddleware {
    async fn handle<'a>(
        &'a self,
        _ctx: crate::middleware::DispatchContext<'a>,
        _next: crate::middleware::Next<'a>,
    ) -> Result<HandlerAction, ProxyError> {
        let throttled = ProxyResult {
            status: 429,
            headers: HeaderMap::default(),
            body: ProxyResponseBody::Empty,
        };
        Ok(HandlerAction::Response(throttled))
    }
}
/// Test middleware that delegates straight to the rest of the chain.
struct PassMiddleware;

impl crate::middleware::Middleware for PassMiddleware {
    async fn handle<'a>(
        &'a self,
        context: crate::middleware::DispatchContext<'a>,
        rest: crate::middleware::Next<'a>,
    ) -> Result<HandlerAction, ProxyError> {
        rest.run(context).await
    }
}
/// A middleware that returns its own response stops the request from reaching
/// the backend forward path.
#[test]
fn middleware_short_circuits_request() {
    let gw = gateway().with_middleware(BlockMiddleware);
    let headers = HeaderMap::new();

    let action = run(gw.resolve_request(
        Method::GET,
        "/test-bucket/key.txt",
        None,
        &headers,
        None,
    ));
    let resp = match action {
        HandlerAction::Response(resp) => resp,
        other => panic!(
            "expected Response, got {:?}",
            std::mem::discriminant(&other)
        ),
    };

    assert_eq!(resp.status, 429, "blocking middleware should return 429");
}
/// A middleware that delegates to `next` leaves normal forwarding untouched.
#[test]
fn middleware_passthrough_allows_request() {
    let gw = gateway().with_middleware(PassMiddleware);
    let headers = HeaderMap::new();

    let action = run(gw.resolve_request(
        Method::GET,
        "/test-bucket/key.txt",
        None,
        &headers,
        None,
    ));
    let fwd = match action {
        HandlerAction::Forward(fwd) => fwd,
        other => panic!("expected Forward, got {:?}", std::mem::discriminant(&other)),
    };

    assert_eq!(
        fwd.method,
        Method::GET,
        "passthrough middleware should allow normal forwarding"
    );
}
/// Backend stub whose `forward` always succeeds with an empty 200.
///
/// `create_signer` delegates to the real `build_signer`. The paginated-store and
/// raw-send hooks are `unimplemented!()` and must not be reached by these tests.
#[derive(Clone)]
struct ForwardMockBackend;

impl ProxyBackend for ForwardMockBackend {
    type Body = ();
    type ResponseBody = ();

    async fn forward(
        &self,
        _req: ForwardRequest,
        _payload: (),
    ) -> Result<ForwardResponse<()>, ProxyError> {
        let empty_ok = ForwardResponse {
            status: 200,
            headers: HeaderMap::default(),
            body: (),
            content_length: Some(0),
        };
        Ok(empty_ok)
    }

    fn create_signer(&self, config: &BucketConfig) -> Result<Arc<dyn Signer>, ProxyError> {
        crate::backend::build_signer(config)
    }

    fn create_paginated_store(
        &self,
        _bucket: &BucketConfig,
    ) -> Result<Box<dyn PaginatedListStore>, ProxyError> {
        unimplemented!()
    }

    async fn send_raw(
        &self,
        _verb: http::Method,
        _target: String,
        _hdrs: HeaderMap,
        _payload: Bytes,
    ) -> Result<RawResponse, ProxyError> {
        unimplemented!()
    }
}
/// Builds a gateway whose backend forward always returns an empty 200, so the
/// full `handle_request` path can run end to end.
fn forward_gateway() -> ProxyGateway<ForwardMockBackend, MockRegistry, MockCreds> {
    let backend = ForwardMockBackend;
    let registry = MockRegistry;
    let creds = MockCreds;
    ProxyGateway::new(backend, registry, creds, None)
}
fn extract_server_timing(response: &GatewayResponse<()>) -> Option<String> {
let headers = match response {
GatewayResponse::Response(r) => &r.headers,
GatewayResponse::Forward(f) => &f.headers,
};
headers
.get("server-timing")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
}
/// A forwarded response reports total, dispatch and backend phases in
/// `Server-Timing`.
#[test]
fn server_timing_present_on_forward_response() {
    let gw = forward_gateway();
    let headers = HeaderMap::new();
    let req = RequestInfo::new(&Method::GET, "/test-bucket/key.txt", None, &headers, None);

    let response = run(gw.handle_request(&req, (), |_| async { Ok::<_, String>(Bytes::new()) }));

    let timing = extract_server_timing(&response)
        .expect("forwarded response should have Server-Timing header");
    for metric in ["total", "dispatch", "backend"] {
        assert!(
            timing.contains(&format!("{metric};dur=")),
            "should contain {metric}: {timing}"
        );
    }
}
/// A request to `/`, which this test treats as an error path, still carries a
/// `Server-Timing` header with the total duration.
#[test]
fn server_timing_present_on_error_response() {
    let gw = forward_gateway();
    let headers = HeaderMap::new();
    let req = RequestInfo::new(&Method::GET, "/", None, &headers, None);

    let response = run(gw.handle_request(&req, (), |_| async { Ok::<_, String>(Bytes::new()) }));

    let timing = extract_server_timing(&response)
        .expect("error response should have Server-Timing header");
    let has_total = timing.contains("total;dur=");
    assert!(has_total, "should contain total: {timing}");
}
/// `with_server_timing(false)` suppresses the `Server-Timing` header entirely.
#[test]
fn server_timing_disabled_when_configured() {
    let gw = forward_gateway().with_server_timing(false);
    let headers = HeaderMap::new();
    let req = RequestInfo::new(&Method::GET, "/test-bucket/key.txt", None, &headers, None);

    let response = run(gw.handle_request(&req, (), |_| async { Ok::<_, String>(Bytes::new()) }));

    let timing = extract_server_timing(&response);
    assert!(
        timing.is_none(),
        "Server-Timing should not be present when disabled"
    );
}
/// Backend stub for batch-delete tests.
///
/// `send_raw` stores the request body it receives in `captured` and replies with
/// a canned 200 `DeleteResult` reporting `allowed/a.txt` as deleted. `forward`
/// and the paginated-store hook are `unimplemented!()`.
#[derive(Clone)]
struct DeleteMockBackend {
    captured: Arc<std::sync::Mutex<Option<Bytes>>>,
}

impl ProxyBackend for DeleteMockBackend {
    type Body = ();
    type ResponseBody = ();

    async fn forward(
        &self,
        _req: ForwardRequest,
        _payload: (),
    ) -> Result<ForwardResponse<()>, ProxyError> {
        unimplemented!()
    }

    fn create_signer(&self, config: &BucketConfig) -> Result<Arc<dyn Signer>, ProxyError> {
        crate::backend::build_signer(config)
    }

    fn create_paginated_store(
        &self,
        _bucket: &BucketConfig,
    ) -> Result<Box<dyn PaginatedListStore>, ProxyError> {
        unimplemented!()
    }

    async fn send_raw(
        &self,
        _verb: http::Method,
        _target: String,
        _hdrs: HeaderMap,
        body: Bytes,
    ) -> Result<RawResponse, ProxyError> {
        {
            // Scope the guard so the lock is released before building the reply.
            let mut slot = self.captured.lock().unwrap();
            *slot = Some(body);
        }
        let canned = Bytes::from_static(
            b"<?xml version=\"1.0\"?><DeleteResult><Deleted><Key>allowed/a.txt</Key></Deleted></DeleteResult>",
        );
        Ok(RawResponse {
            status: 200,
            headers: HeaderMap::default(),
            body: canned,
        })
    }
}
/// DeleteObjects is authorized per key. An allowed key is forwarded and
/// reported as deleted. A key outside the caller's prefixes is kept out of the
/// backend request and reported back as `AccessDenied`.
#[test]
fn batch_delete_filters_unauthorized_keys_per_key() {
    use crate::types::{AccessScope, AuthenticatedIdentity};

    let captured = Arc::new(std::sync::Mutex::new(None));
    let gw = ProxyGateway::new(
        DeleteMockBackend {
            captured: captured.clone(),
        },
        MockRegistry,
        MockCreds,
        None,
    );
    let scope = AccessScope {
        bucket: "test-bucket".into(),
        prefixes: vec!["allowed/".into()],
        actions: vec![Action::DeleteObject],
    };
    let identity = ResolvedIdentity::Authenticated(AuthenticatedIdentity {
        principal_name: "tester".into(),
        allowed_scopes: vec![scope],
    });
    let pending = PendingRequest {
        operation: S3Operation::DeleteObjects {
            bucket: "test-bucket".into(),
        },
        bucket_config: test_bucket_config("test-bucket"),
        original_headers: HeaderMap::new(),
        request_id: "rid".into(),
        identity,
    };
    let body = Bytes::from_static(
        br#"<Delete><Object><Key>allowed/a.txt</Key></Object><Object><Key>denied/b.txt</Key></Object></Delete>"#,
    );

    let result = run(gw.handle_with_body(pending, body));

    assert_eq!(result.status, 200);
    let xml = match result.body {
        ProxyResponseBody::Bytes(b) => String::from_utf8(b.to_vec()).unwrap(),
        ProxyResponseBody::Empty => panic!("expected a body"),
    };
    for needle in [
        "<Deleted><Key>allowed/a.txt</Key></Deleted>",
        "<Key>denied/b.txt</Key>",
        "<Code>AccessDenied</Code>",
    ] {
        assert!(xml.contains(needle), "{xml}");
    }

    let forwarded = captured
        .lock()
        .unwrap()
        .clone()
        .expect("backend was called");
    let sent = String::from_utf8(forwarded.to_vec()).unwrap();
    assert!(sent.contains("allowed/a.txt"), "forwarded body: {sent}");
    assert!(
        !sent.contains("denied/b.txt"),
        "denied key leaked to backend: {sent}"
    );
}
/// When every key in a DeleteObjects batch is outside the caller's scope, the
/// gateway answers without calling the backend.
///
/// The response must still report each refused key as `AccessDenied`. Without
/// that check, an empty "success" response would silently hide the refusal
/// from the client.
#[test]
fn batch_delete_all_denied_skips_backend() {
    use crate::types::{AccessScope, AuthenticatedIdentity};
    run(async {
        let captured = Arc::new(std::sync::Mutex::new(None));
        let backend = DeleteMockBackend {
            captured: captured.clone(),
        };
        let gw = ProxyGateway::new(backend, MockRegistry, MockCreds, None);
        let identity = ResolvedIdentity::Authenticated(AuthenticatedIdentity {
            principal_name: "tester".into(),
            allowed_scopes: vec![AccessScope {
                bucket: "test-bucket".into(),
                prefixes: vec!["other/".into()],
                actions: vec![Action::DeleteObject],
            }],
        });
        let pending = PendingRequest {
            operation: S3Operation::DeleteObjects {
                bucket: "test-bucket".into(),
            },
            bucket_config: test_bucket_config("test-bucket"),
            original_headers: HeaderMap::new(),
            request_id: "rid".into(),
            identity,
        };
        let body =
            Bytes::from_static(br#"<Delete><Object><Key>secret/a.txt</Key></Object></Delete>"#);
        let result = gw.handle_with_body(pending, body).await;
        assert_eq!(result.status, 200);
        assert!(
            captured.lock().unwrap().is_none(),
            "backend should be skipped"
        );
        // The denied key must be reported back rather than silently dropped.
        let xml = match result.body {
            ProxyResponseBody::Bytes(b) => String::from_utf8(b.to_vec()).unwrap(),
            ProxyResponseBody::Empty => panic!("expected a body"),
        };
        assert!(xml.contains("<Key>secret/a.txt</Key>"), "{xml}");
        assert!(xml.contains("<Code>AccessDenied</Code>"), "{xml}");
    });
}
/// Backend stub that records the headers passed to `send_raw` in `captured`
/// and replies with an empty 200.
///
/// `forward` and the paginated-store hook are `unimplemented!()`.
#[derive(Clone)]
struct CaptureHeadersBackend {
    captured: Arc<std::sync::Mutex<Option<HeaderMap>>>,
}

impl ProxyBackend for CaptureHeadersBackend {
    type Body = ();
    type ResponseBody = ();

    async fn forward(
        &self,
        _req: ForwardRequest,
        _payload: (),
    ) -> Result<ForwardResponse<()>, ProxyError> {
        unimplemented!()
    }

    fn create_signer(&self, config: &BucketConfig) -> Result<Arc<dyn Signer>, ProxyError> {
        crate::backend::build_signer(config)
    }

    fn create_paginated_store(
        &self,
        _bucket: &BucketConfig,
    ) -> Result<Box<dyn PaginatedListStore>, ProxyError> {
        unimplemented!()
    }

    async fn send_raw(
        &self,
        _verb: http::Method,
        _target: String,
        headers: HeaderMap,
        _payload: Bytes,
    ) -> Result<RawResponse, ProxyError> {
        {
            // Scope the guard so the lock is released before building the reply.
            let mut slot = self.captured.lock().unwrap();
            *slot = Some(headers);
        }
        Ok(RawResponse {
            status: 200,
            headers: HeaderMap::default(),
            body: Bytes::new(),
        })
    }
}
/// CompleteMultipartUpload forwards the client's checksum headers to the backend.
///
/// It also replaces the client's `authorization` with a freshly signed SigV4
/// header whose SignedHeaders list covers those checksum headers.
#[test]
fn complete_multipart_forwards_and_signs_checksum_headers() {
    use crate::types::AuthenticatedIdentity;

    let captured = Arc::new(std::sync::Mutex::new(None));
    let gw = ProxyGateway::new(
        CaptureHeadersBackend {
            captured: captured.clone(),
        },
        MockRegistry,
        MockCreds,
        None,
    );

    let mut original_headers = HeaderMap::new();
    for (name, value) in [
        ("content-type", "application/xml"),
        ("x-amz-checksum-crc32", "AAAAAA=="),
        ("x-amz-checksum-type", "FULL_OBJECT"),
        ("x-amz-sdk-checksum-algorithm", "CRC32"),
        ("authorization", "AWS4-HMAC-SHA256 client-bogus"),
    ] {
        original_headers.insert(name, value.parse().unwrap());
    }

    let pending = PendingRequest {
        operation: S3Operation::CompleteMultipartUpload {
            bucket: "test-bucket".into(),
            key: "big.dmg".into(),
            upload_id: "upload-1".into(),
        },
        bucket_config: test_bucket_config("test-bucket"),
        original_headers,
        request_id: "rid".into(),
        identity: ResolvedIdentity::Authenticated(AuthenticatedIdentity {
            principal_name: "tester".into(),
            allowed_scopes: vec![],
        }),
    };
    let body = Bytes::from_static(
        br#"<CompleteMultipartUpload><Part><PartNumber>1</PartNumber><ETag>"abc"</ETag><ChecksumCRC32>AAAAAA==</ChecksumCRC32></Part></CompleteMultipartUpload>"#,
    );

    let result = run(gw.handle_with_body(pending, body));
    assert_eq!(result.status, 200);

    let sent = captured
        .lock()
        .unwrap()
        .clone()
        .expect("backend was called");
    for (name, expected) in [
        ("x-amz-checksum-crc32", "AAAAAA=="),
        ("x-amz-checksum-type", "FULL_OBJECT"),
        ("x-amz-sdk-checksum-algorithm", "CRC32"),
    ] {
        assert_eq!(sent.get(name).unwrap(), expected);
    }

    let auth = sent.get("authorization").unwrap().to_str().unwrap();
    assert!(
        auth.starts_with("AWS4-HMAC-SHA256 Credential="),
        "expected re-signed Authorization, got: {auth}"
    );
    let signs_all_checksums = [
        "x-amz-checksum-crc32",
        "x-amz-checksum-type",
        "x-amz-sdk-checksum-algorithm",
    ]
    .iter()
    .all(|&h| auth.contains(h));
    assert!(
        signs_all_checksums,
        "checksum headers missing from SignedHeaders: {auth}"
    );
}
}