use crate::maybe_send::{MaybeSend, MaybeSync};
use bytes::Bytes;
use http::{HeaderMap, Method};
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use url::Url;
/// Body of a response produced by the proxy.
///
/// Use [`ProxyResponseBody::from_bytes`] to build one from a buffer; it maps a
/// zero-length buffer to [`ProxyResponseBody::Empty`].
#[derive(Debug, Clone)]
pub enum ProxyResponseBody {
    /// An in-memory body.
    Bytes(Bytes),
    /// No body content.
    Empty,
}
impl ProxyResponseBody {
pub fn from_bytes(bytes: Bytes) -> Self {
if bytes.is_empty() {
Self::Empty
} else {
Self::Bytes(bytes)
}
}
pub fn empty() -> Self {
Self::Empty
}
}
/// What a route handler decided to do with a request.
pub enum HandlerAction {
    /// Reply directly with the given status, headers and body.
    Response(ProxyResult),
    /// Issue the outbound request described by the [`ForwardRequest`] (method, URL,
    /// headers, request id); presumably sent to a storage backend — confirm with the
    /// dispatcher.
    Forward(ForwardRequest),
    /// The handler cannot finish without the request body; [`PendingRequest`] carries
    /// the operation, bucket config, original headers and request id.
    /// NOTE(review): the code that resumes from this state is outside this file.
    NeedsBody(PendingRequest),
}
impl HandlerAction {
pub fn response_headers_mut(&mut self) -> Option<&mut HeaderMap> {
match self {
HandlerAction::Response(result) => Some(&mut result.headers),
HandlerAction::Forward(fwd) => Some(&mut fwd.headers),
HandlerAction::NeedsBody(_) => None,
}
}
}
/// Description of an outbound request, used by [`HandlerAction::Forward`].
#[derive(Debug, Clone)]
pub struct ForwardRequest {
    /// HTTP method of the outbound request.
    pub method: Method,
    /// Target URL of the outbound request.
    pub url: Url,
    /// Headers attached to the outbound request.
    pub headers: HeaderMap,
    /// Request identifier (presumably for correlating logs/traces — confirm).
    pub request_id: String,
}
/// A complete response produced by the proxy itself.
pub struct ProxyResult {
    /// HTTP status code.
    pub status: u16,
    /// Response headers.
    pub headers: HeaderMap,
    /// Response body.
    pub body: ProxyResponseBody,
}
impl ProxyResult {
pub fn json(status: u16, body: impl Into<String>) -> Self {
let mut headers = HeaderMap::new();
headers.insert("content-type", "application/json".parse().unwrap());
Self {
status,
headers,
body: ProxyResponseBody::from_bytes(Bytes::from(body.into())),
}
}
pub fn xml(status: u16, body: impl Into<String>) -> Self {
let mut headers = HeaderMap::new();
headers.insert("content-type", "application/xml".parse().unwrap());
Self {
status,
headers,
body: ProxyResponseBody::from_bytes(Bytes::from(body.into())),
}
}
}
/// State captured when a handler returns [`HandlerAction::NeedsBody`].
///
/// Fields are crate-private; they are populated and consumed inside the crate
/// (NOTE(review): the consumer is not visible in this file — confirm).
pub struct PendingRequest {
    /// The operation the handler resolved for this request.
    pub(crate) operation: crate::types::S3Operation,
    /// Configuration of the bucket the request targets.
    pub(crate) bucket_config: crate::types::BucketConfig,
    /// Headers of the incoming request as received.
    pub(crate) original_headers: HeaderMap,
    /// Request identifier.
    pub(crate) request_id: String,
}
/// Lowercase names of response headers that [`filter_response_headers`] strips.
///
/// Entries are compared against `HeaderName::as_str()`, which the `http` crate always
/// yields in lowercase, so every entry here must be lowercase to match.
pub const RESPONSE_HEADER_DENYLIST: &[&str] = &[
    // Hop-by-hop / connection-management headers: they describe one connection,
    // not the resource, and must not be relayed.
    "transfer-encoding",
    "connection",
    "keep-alive",
    "proxy-connection",
    "te",
    "trailer",
    "upgrade",
    // Authentication challenges/credentials and cookies.
    "proxy-authenticate",
    "proxy-authorization",
    "www-authenticate",
    "set-cookie",
    // Proxy routing / client-address headers.
    "forwarded",
    "x-forwarded-for",
    "x-forwarded-proto",
    "x-forwarded-host",
    "x-forwarded-port",
    "via",
    // Headers that identify or fingerprint the encryption key used for the object.
    // The algorithm header `x-amz-server-side-encryption` is not listed and passes through.
    "x-amz-server-side-encryption-customer-key-md5",
    "x-amz-server-side-encryption-aws-kms-key-id",
    "x-ms-encryption-key-sha256",
    "x-goog-encryption-key-sha256",
];
/// Returns a copy of `source` without any header named in [`RESPONSE_HEADER_DENYLIST`].
///
/// Every value of a multi-valued header is preserved, in its original order.
pub fn filter_response_headers(source: &http::HeaderMap) -> http::HeaderMap {
    let mut out = http::HeaderMap::new();
    for (name, value) in source.iter() {
        // `http` stores header names lowercased, so this matches the lowercase
        // denylist regardless of how the upstream cased the name.
        if RESPONSE_HEADER_DENYLIST.contains(&name.as_str()) {
            continue;
        }
        // `iter()` yields one item per value. `append` keeps every value, whereas
        // `insert` would drop earlier values and keep only the last one.
        out.append(name.clone(), value.clone());
    }
    out
}
/// Boxed future returned by [`RouteHandler::handle`], resolving to an optional
/// [`ProxyResult`] (`None` presumably means the handler produced no response — confirm
/// with the router). On non-wasm32 targets the future is required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
pub type RouteHandlerFuture<'a> = Pin<Box<dyn Future<Output = Option<ProxyResult>> + Send + 'a>>;
/// wasm32 variant of `RouteHandlerFuture`: identical output type, but without the
/// `Send` bound.
#[cfg(target_arch = "wasm32")]
pub type RouteHandlerFuture<'a> = Pin<Box<dyn Future<Output = Option<ProxyResult>> + 'a>>;
/// Route parameters captured by a path match, stored as ordered `(name, value)` pairs.
#[derive(Debug, Clone, Default)]
pub struct Params(Vec<(String, String)>);
impl Params {
    /// Returns the value captured for the parameter named `key`.
    ///
    /// If the name occurs more than once, the first occurrence wins. Returns `None`
    /// when no parameter has that name.
    pub fn get(&self, key: &str) -> Option<&str> {
        for (name, value) in &self.0 {
            if name == key {
                return Some(value.as_str());
            }
        }
        None
    }

    /// Copies the parameters from a `matchit` match into owned strings, keeping
    /// their order.
    pub(crate) fn from_matchit(params: &matchit::Params<'_, '_>) -> Self {
        let mut pairs = Vec::new();
        for (name, value) in params.iter() {
            pairs.push((name.to_owned(), value.to_owned()));
        }
        Self(pairs)
    }
}
/// Borrowed view of an incoming request, handed to [`RouteHandler::handle`].
pub struct RequestInfo<'a> {
    /// Request method.
    pub method: &'a Method,
    /// Request path.
    pub path: &'a str,
    /// Raw query string, if any.
    pub query: Option<&'a str>,
    /// Request headers.
    pub headers: &'a HeaderMap,
    /// Client IP address, when known.
    pub source_ip: Option<IpAddr>,
    /// Route parameters captured by the router (empty after `RequestInfo::new`).
    pub params: Params,
    /// Optional path override; presumably used instead of `path` when computing or
    /// verifying a request signature — confirm with the signing code.
    pub signing_path: Option<&'a str>,
    /// Optional query override; presumably the signing counterpart of `query` — confirm.
    pub signing_query: Option<&'a str>,
}
impl<'a> RequestInfo<'a> {
pub fn new(
method: &'a Method,
path: &'a str,
query: Option<&'a str>,
headers: &'a HeaderMap,
source_ip: Option<IpAddr>,
) -> Self {
Self {
method,
path,
query,
headers,
source_ip,
params: Params::default(),
signing_path: None,
signing_query: None,
}
}
pub fn with_signing_path(mut self, signing_path: &'a str) -> Self {
self.signing_path = Some(signing_path);
self
}
pub fn with_signing_query(mut self, signing_query: Option<&'a str>) -> Self {
self.signing_query = signing_query;
self
}
}
/// A handler the router can dispatch requests to.
///
/// The `MaybeSend + MaybeSync` supertraits come from `crate::maybe_send`; presumably they
/// mean `Send`/`Sync` on native targets and nothing on wasm32, mirroring the cfg split on
/// `RouteHandlerFuture` — confirm in that module.
pub trait RouteHandler: MaybeSend + MaybeSync {
    /// Handles `req`, returning a future that may borrow both `self` and `req` for `'a`.
    fn handle<'a>(&'a self, req: &'a RequestInfo<'a>) -> RouteHandlerFuture<'a>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Headers inserted via `response_headers_mut` land on the wrapped `ProxyResult`.
    #[test]
    fn test_response_headers_mut_on_response() {
        let mut action = HandlerAction::Response(ProxyResult {
            status: 200,
            headers: HeaderMap::new(),
            body: ProxyResponseBody::Empty,
        });
        let headers = action.response_headers_mut().unwrap();
        headers.insert("x-custom", "value".parse().unwrap());
        // Fail loudly on a variant mismatch; an `if let` would silently skip the check.
        match &action {
            HandlerAction::Response(result) => {
                assert_eq!(result.headers.get("x-custom").unwrap(), "value");
            }
            _ => panic!("expected HandlerAction::Response"),
        }
    }

    /// A `Forward` action exposes its header map.
    #[test]
    fn test_response_headers_mut_on_forward() {
        let mut action = HandlerAction::Forward(ForwardRequest {
            method: Method::GET,
            url: "https://example.com".parse().unwrap(),
            headers: HeaderMap::new(),
            request_id: String::new(),
        });
        assert!(action.response_headers_mut().is_some());
    }

    /// A `NeedsBody` action has no header map to expose.
    #[test]
    fn test_response_headers_mut_on_needs_body() {
        use crate::types::{BucketConfig, S3Operation};
        let mut action = HandlerAction::NeedsBody(PendingRequest {
            operation: S3Operation::CreateMultipartUpload {
                bucket: "b".into(),
                key: "k".into(),
            },
            bucket_config: BucketConfig {
                name: String::new(),
                backend_type: "s3".into(),
                backend_prefix: None,
                anonymous_access: false,
                backend_options: Default::default(),
                allowed_roles: Vec::new(),
            },
            original_headers: HeaderMap::new(),
            request_id: String::new(),
        });
        assert!(action.response_headers_mut().is_none());
    }

    #[test]
    fn test_blocks_hop_by_hop_headers() {
        let mut headers = http::HeaderMap::new();
        headers.insert("transfer-encoding", "chunked".parse().unwrap());
        headers.insert("connection", "keep-alive".parse().unwrap());
        headers.insert("content-type", "text/plain".parse().unwrap());
        let filtered = filter_response_headers(&headers);
        assert!(filtered.get("transfer-encoding").is_none());
        assert!(filtered.get("connection").is_none());
        assert!(filtered.get("content-type").is_some());
    }

    #[test]
    fn test_blocks_auth_and_cookie_headers() {
        let mut headers = http::HeaderMap::new();
        headers.insert("www-authenticate", "Basic".parse().unwrap());
        headers.insert("set-cookie", "session=abc".parse().unwrap());
        headers.insert("etag", "\"abc\"".parse().unwrap());
        let filtered = filter_response_headers(&headers);
        assert!(filtered.get("www-authenticate").is_none());
        assert!(filtered.get("set-cookie").is_none());
        assert!(filtered.get("etag").is_some());
    }

    /// Key identifiers are stripped, but the algorithm header passes through.
    #[test]
    fn test_blocks_encryption_key_material() {
        let mut headers = http::HeaderMap::new();
        headers.insert(
            "x-amz-server-side-encryption-aws-kms-key-id",
            "arn:aws:kms:us-east-1:123456:key/abc".parse().unwrap(),
        );
        headers.insert(
            "x-amz-server-side-encryption-customer-key-md5",
            "abc123".parse().unwrap(),
        );
        headers.insert("x-amz-server-side-encryption", "aws:kms".parse().unwrap());
        let filtered = filter_response_headers(&headers);
        assert!(filtered
            .get("x-amz-server-side-encryption-aws-kms-key-id")
            .is_none());
        assert!(filtered
            .get("x-amz-server-side-encryption-customer-key-md5")
            .is_none());
        assert!(filtered.get("x-amz-server-side-encryption").is_some());
    }

    #[test]
    fn test_passes_cloud_metadata_headers() {
        let mut headers = http::HeaderMap::new();
        headers.insert("x-amz-meta-author", "alice".parse().unwrap());
        headers.insert("x-ms-meta-version", "2".parse().unwrap());
        headers.insert("x-goog-meta-project", "test".parse().unwrap());
        headers.insert("x-amz-storage-class", "STANDARD".parse().unwrap());
        headers.insert("x-amz-version-id", "v1".parse().unwrap());
        let filtered = filter_response_headers(&headers);
        assert_eq!(filtered.len(), 5);
    }

    #[test]
    fn test_passes_standard_content_headers() {
        let mut headers = http::HeaderMap::new();
        headers.insert("content-type", "application/json".parse().unwrap());
        headers.insert("content-length", "1234".parse().unwrap());
        headers.insert("content-range", "bytes 0-499/1000".parse().unwrap());
        headers.insert("etag", "\"abc\"".parse().unwrap());
        headers.insert(
            "last-modified",
            "Mon, 01 Jan 2024 00:00:00 GMT".parse().unwrap(),
        );
        headers.insert("accept-ranges", "bytes".parse().unwrap());
        headers.insert("cache-control", "max-age=3600".parse().unwrap());
        headers.insert("location", "/new".parse().unwrap());
        let filtered = filter_response_headers(&headers);
        assert_eq!(filtered.len(), 8);
    }

    #[test]
    fn test_blocks_proxy_routing_headers() {
        let mut headers = http::HeaderMap::new();
        headers.insert("x-forwarded-for", "1.2.3.4".parse().unwrap());
        headers.insert("via", "1.1 proxy".parse().unwrap());
        headers.insert("forwarded", "for=1.2.3.4".parse().unwrap());
        let filtered = filter_response_headers(&headers);
        assert!(filtered.is_empty());
    }
}