use async_trait::async_trait;
use bytes::Bytes;
use http::{HeaderMap, Method, StatusCode};
use md5::{Digest, Md5};
use parking_lot::Mutex;
use std::collections::{BTreeMap, HashMap};
use std::path::PathBuf;
use crate::auth::Principal;
/// Streaming request body type carried by [`AwsRequest::body_stream`]; an
/// alias for axum's request body.
pub type RequestBodyStream = axum::body::Body;
/// A single AWS API request as handed to an [`AwsService`].
///
/// The body is available either buffered in `body` or as a stream in
/// `body_stream` (which can be taken once via
/// [`AwsRequest::take_body_stream`]). NOTE(review): which of the two is
/// populated is decided by whoever constructs this struct — confirm against
/// the router.
pub struct AwsRequest {
    /// Service this request was routed to.
    pub service: String,
    /// Operation (action) name being invoked.
    pub action: String,
    /// Region the request targets.
    pub region: String,
    /// Account the request is attributed to.
    pub account_id: String,
    /// Identifier of this request.
    pub request_id: String,
    /// HTTP request headers.
    pub headers: HeaderMap,
    /// Query-string parameters.
    pub query_params: HashMap<String, String>,
    /// Buffered request body; parsed by [`AwsRequest::json_body`].
    pub body: Bytes,
    /// Unconsumed streaming body, if any. Behind a mutex so it can be taken
    /// through a shared reference.
    pub body_stream: Mutex<Option<RequestBodyStream>>,
    /// Path split into segments.
    pub path_segments: Vec<String>,
    /// Request path as received.
    pub raw_path: String,
    /// Query string as received.
    pub raw_query: String,
    /// HTTP method.
    pub method: Method,
    /// NOTE(review): presumably `true` when the request uses the AWS Query
    /// protocol (form-encoded `Action=`) — confirm.
    pub is_query_protocol: bool,
    /// Access key id the request was signed with, if any.
    pub access_key_id: Option<String>,
    /// Authenticated principal, if resolved.
    pub principal: Option<Principal>,
}
impl std::fmt::Debug for AwsRequest {
    /// Formats every field, except that the buffered body is summarized by
    /// its length and the stream only by a `"<stream>"` marker.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hold the lock only for this statement: we just need to know
        // whether a stream is still attached.
        let stream_marker: Option<&str> = match *self.body_stream.lock() {
            Some(_) => Some("<stream>"),
            None => None,
        };
        let mut out = f.debug_struct("AwsRequest");
        out.field("service", &self.service);
        out.field("action", &self.action);
        out.field("region", &self.region);
        out.field("account_id", &self.account_id);
        out.field("request_id", &self.request_id);
        out.field("headers", &self.headers);
        out.field("query_params", &self.query_params);
        out.field("body_len", &self.body.len());
        out.field("body_stream", &stream_marker);
        out.field("path_segments", &self.path_segments);
        out.field("raw_path", &self.raw_path);
        out.field("raw_query", &self.raw_query);
        out.field("method", &self.method);
        out.field("is_query_protocol", &self.is_query_protocol);
        out.field("access_key_id", &self.access_key_id);
        out.field("principal", &self.principal);
        out.finish()
    }
}
impl AwsRequest {
    /// Parses the buffered `body` as JSON.
    ///
    /// Any parse failure, including an empty body, yields
    /// [`serde_json::Value::Null`] instead of an error.
    pub fn json_body(&self) -> serde_json::Value {
        match serde_json::from_slice::<serde_json::Value>(&self.body) {
            Ok(value) => value,
            Err(_) => serde_json::Value::Null,
        }
    }

    /// Removes and returns the streaming body, if one is still attached.
    ///
    /// The slot is left empty, so every later call returns `None`.
    pub fn take_body_stream(&self) -> Option<RequestBodyStream> {
        let mut slot = self.body_stream.lock();
        slot.take()
    }
}
/// Reads the whole streaming body into memory.
///
/// # Errors
///
/// A stream failure is mapped by `stream_error_to_aws`. The result is 413
/// `RequestEntityTooLarge` when the error text mentions a limit, and 400
/// `MalformedRequestBody` otherwise.
pub async fn drain_request_stream(stream: RequestBodyStream) -> Result<Bytes, AwsServiceError> {
    use http_body_util::BodyExt;
    stream
        .collect()
        .await
        .map(|collected| collected.to_bytes())
        .map_err(|err| stream_error_to_aws(&err.to_string()))
}
/// Converts a body-stream error message into an AWS-style error.
///
/// Only the message text is available here, so classification is a
/// heuristic. Any message containing "limit" (case-insensitive) is reported
/// as 413 `RequestEntityTooLarge`; everything else is 400
/// `MalformedRequestBody`.
/// NOTE(review): presumably the "limit" text comes from a body-size limit
/// layer — confirm.
fn stream_error_to_aws(msg: &str) -> AwsServiceError {
    if msg.to_ascii_lowercase().contains("limit") {
        return AwsServiceError::aws_error(
            StatusCode::PAYLOAD_TOO_LARGE,
            "RequestEntityTooLarge",
            "Streaming request body exceeded the configured limit",
        );
    }
    AwsServiceError::aws_error(
        StatusCode::BAD_REQUEST,
        "MalformedRequestBody",
        "Failed to read streaming request body",
    )
}
/// Result of spooling a request body to a file on disk
/// (see `spool_request_stream`).
#[derive(Debug)]
pub struct SpooledBody {
    /// Location of the spooled file. The file stays on disk and nothing in
    /// this type deletes it. NOTE(review): confirm which caller is
    /// responsible for removing it.
    pub path: PathBuf,
    /// Number of payload bytes written. When aws-chunked decoding was
    /// applied, this counts decoded bytes.
    pub size: u64,
    /// Lower-case hex MD5 digest of the written payload bytes.
    pub md5_hex: String,
}
/// Incremental decoder for `aws-chunked` framed request bodies.
///
/// Raw bytes are pushed in through `feed`, which accepts input split at
/// arbitrary boundaries and returns the payload bytes decoded so far.
#[derive(Default)]
pub struct AwsChunkedDecoder {
    /// Current position within the chunked framing.
    state: ChunkState,
    /// Partial header/trailer line accumulated so far. Carriage returns are
    /// not stored.
    line: Vec<u8>,
    /// Payload bytes still expected for the current chunk.
    remaining: usize,
    /// Set once the blank line closing the trailer section has been read.
    /// After that, further input is ignored.
    done: bool,
}
/// Parser states of [`AwsChunkedDecoder`].
#[derive(Default, PartialEq)]
enum ChunkState {
    /// Reading a chunk-size line (`<hex-size>[;extensions]`). Initial state.
    #[default]
    Header,
    /// Copying chunk payload; the decoder's `remaining` counts what is left.
    Data,
    /// Consuming the line terminator that follows a chunk's payload.
    AfterData,
    /// After the zero-size chunk: reading trailer lines until a blank line.
    Trailer,
}
/// Error returned by `AwsChunkedDecoder::feed` when the input does not follow
/// the `aws-chunked` framing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MalformedChunk;
impl AwsChunkedDecoder {
    /// Longest chunk-header or trailer line accepted before the body is
    /// rejected as malformed.
    ///
    /// Real chunk headers are about a hundred bytes: a hex size plus a
    /// `chunk-signature` extension. The cap exists only so a peer cannot make
    /// the decoder buffer an endless "line" that never contains `\n`.
    const MAX_LINE_LEN: usize = 16 * 1024;

    /// Feeds the next slice of raw `aws-chunked` bytes and returns the payload
    /// bytes decoded from it.
    ///
    /// Input may be split at any byte boundary; partial lines and partial
    /// chunks are carried over to the next call. Carriage returns in header
    /// and trailer lines are ignored. Once the blank line that ends the
    /// trailer section has been consumed, any further input is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`MalformedChunk`] in any of these cases:
    /// * a chunk-size line is not a bare hexadecimal number (extensions after
    ///   `;` are ignored);
    /// * chunk data is not followed directly by its line terminator, i.e. the
    ///   declared size and the actual data disagree;
    /// * a header or trailer line exceeds [`Self::MAX_LINE_LEN`] bytes.
    pub fn feed(&mut self, input: &[u8]) -> Result<Vec<u8>, MalformedChunk> {
        // Decoded output is never longer than the framed input.
        let mut out = Vec::with_capacity(input.len());
        let mut i = 0;
        while i < input.len() && !self.done {
            match self.state {
                ChunkState::Data => {
                    let take = self.remaining.min(input.len() - i);
                    out.extend_from_slice(&input[i..i + take]);
                    i += take;
                    self.remaining -= take;
                    if self.remaining == 0 {
                        self.state = ChunkState::AfterData;
                    }
                }
                ChunkState::AfterData => {
                    let b = input[i];
                    i += 1;
                    match b {
                        b'\r' => {}
                        b'\n' => self.state = ChunkState::Header,
                        // The chunk carried more bytes than it declared;
                        // skipping them would silently drop payload data.
                        _ => return Err(MalformedChunk),
                    }
                }
                ChunkState::Header | ChunkState::Trailer => {
                    let b = input[i];
                    i += 1;
                    match b {
                        b'\n' => self.finish_line()?,
                        b'\r' => {}
                        _ => {
                            if self.line.len() >= Self::MAX_LINE_LEN {
                                return Err(MalformedChunk);
                            }
                            self.line.push(b);
                        }
                    }
                }
            }
        }
        Ok(out)
    }

    /// Acts on a just-completed header or trailer line.
    fn finish_line(&mut self) -> Result<(), MalformedChunk> {
        let line = std::mem::take(&mut self.line);
        if self.state == ChunkState::Header {
            match Self::parse_chunk_size(&line)? {
                // Zero-size chunk: only trailer lines remain.
                0 => self.state = ChunkState::Trailer,
                size => {
                    self.remaining = size;
                    self.state = ChunkState::Data;
                }
            }
        } else if line.is_empty() {
            // A blank line terminates the trailer section and the body.
            self.done = true;
        }
        Ok(())
    }

    /// Parses the hexadecimal size from a chunk-header line, ignoring any
    /// `;`-separated extensions such as `chunk-signature=...`.
    fn parse_chunk_size(line: &[u8]) -> Result<usize, MalformedChunk> {
        let hex_part: &[u8] = line.split(|&c| c == b';').next().unwrap_or(&[]);
        let hex = std::str::from_utf8(hex_part)
            .map_err(|_| MalformedChunk)?
            .trim();
        // `from_str_radix` would also accept a leading '+'; chunk sizes are
        // bare hex digits.
        if hex.is_empty() || !hex.bytes().all(|c| c.is_ascii_hexdigit()) {
            return Err(MalformedChunk);
        }
        usize::from_str_radix(hex, 16).map_err(|_| MalformedChunk)
    }
}
pub fn is_aws_chunked(headers: &http::HeaderMap) -> bool {
headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.is_some_and(|v| {
v.split(',')
.any(|t| t.trim().eq_ignore_ascii_case("aws-chunked"))
})
|| headers
.get("x-amz-content-sha256")
.and_then(|v| v.to_str().ok())
.is_some_and(|v| v.starts_with("STREAMING-"))
}
/// Removes the `aws-chunked` token from a `Content-Encoding` value.
///
/// Tokens are trimmed, empty ones are dropped, and the rest are re-joined
/// with `", "`. Returns `None` when the input is `None` or when no real
/// encoding remains.
pub fn strip_aws_chunked_encoding(content_encoding: Option<&str>) -> Option<String> {
    let mut joined = String::new();
    for token in content_encoding?.split(',').map(str::trim) {
        if token.is_empty() || token.eq_ignore_ascii_case("aws-chunked") {
            continue;
        }
        if !joined.is_empty() {
            joined.push_str(", ");
        }
        joined.push_str(token);
    }
    // Every kept token is non-empty, so an empty result means nothing was kept.
    (!joined.is_empty()).then_some(joined)
}
pub async fn spool_request_stream(
stream: RequestBodyStream,
dir: Option<&std::path::Path>,
aws_chunked: bool,
) -> Result<SpooledBody, AwsServiceError> {
use http_body_util::BodyExt;
use tokio::io::AsyncWriteExt;
let dir = dir.map(|d| d.to_path_buf());
if let Some(d) = dir.as_ref() {
let _ = tokio::fs::create_dir_all(d).await;
}
let mut builder = tempfile::Builder::new();
builder.prefix("fc-spool-");
let named = match dir.as_ref() {
Some(d) => builder.tempfile_in(d),
None => builder.tempfile(),
}
.map_err(|e| {
AwsServiceError::aws_error(
StatusCode::INTERNAL_SERVER_ERROR,
"InternalError",
format!("failed to create spool tempfile: {e}"),
)
})?;
let (std_file, temp_path) = named.into_parts();
let path: PathBuf = temp_path.keep().map_err(|e| {
AwsServiceError::aws_error(
StatusCode::INTERNAL_SERVER_ERROR,
"InternalError",
format!("failed to persist spool tempfile: {e}"),
)
})?;
let mut file = tokio::fs::File::from_std(std_file);
let mut hasher = Md5::new();
let mut size: u64 = 0;
let mut body = stream;
let mut decoder = aws_chunked.then(AwsChunkedDecoder::default);
async fn cleanup(file: tokio::fs::File, path: &std::path::Path) {
drop(file);
let _ = tokio::fs::remove_file(path).await;
}
loop {
match body.frame().await {
Some(Ok(frame)) => {
if let Ok(raw) = frame.into_data() {
if !raw.is_empty() {
let payload = match decoder.as_mut() {
Some(d) => match d.feed(&raw) {
Ok(decoded) => decoded,
Err(_) => {
cleanup(file, &path).await;
return Err(AwsServiceError::aws_error(
StatusCode::BAD_REQUEST,
"InvalidChunkSizeError",
"Malformed aws-chunked request body",
));
}
},
None => raw.to_vec(),
};
if !payload.is_empty() {
hasher.update(&payload);
size += payload.len() as u64;
if let Err(e) = file.write_all(&payload).await {
cleanup(file, &path).await;
return Err(AwsServiceError::aws_error(
StatusCode::INTERNAL_SERVER_ERROR,
"InternalError",
format!("failed to spool request body: {e}"),
));
}
}
}
}
}
Some(Err(e)) => {
cleanup(file, &path).await;
return Err(stream_error_to_aws(&e.to_string()));
}
None => break,
}
}
if let Err(e) = file.flush().await {
cleanup(file, &path).await;
return Err(AwsServiceError::aws_error(
StatusCode::INTERNAL_SERVER_ERROR,
"InternalError",
format!("failed to flush spool tempfile: {e}"),
));
}
drop(file);
let md5_hex = hex_lower(&hasher.finalize());
Ok(SpooledBody {
path,
size,
md5_hex,
})
}
/// Encodes `bytes` as lower-case hexadecimal: two digits per byte, high
/// nibble first.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(bytes.iter().flat_map(|&byte| {
        [
            DIGITS[usize::from(byte >> 4)],
            DIGITS[usize::from(byte & 0x0f)],
        ]
    }));
    encoded
}
/// Payload of an [`AwsResponse`].
#[derive(Debug)]
pub enum ResponseBody {
    /// Body held fully in memory.
    Bytes(Bytes),
    /// Body backed by an open file; `size` is the length reported by
    /// `ResponseBody::len`.
    File { file: tokio::fs::File, size: u64 },
}
impl ResponseBody {
pub fn len(&self) -> u64 {
match self {
ResponseBody::Bytes(b) => b.len() as u64,
ResponseBody::File { size, .. } => *size,
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn expect_bytes(&self) -> &[u8] {
match self {
ResponseBody::Bytes(b) => b,
ResponseBody::File { .. } => {
panic!("expect_bytes called on ResponseBody::File")
}
}
}
}
impl Default for ResponseBody {
    /// An empty in-memory body.
    fn default() -> Self {
        Self::Bytes(Bytes::default())
    }
}
impl From<Bytes> for ResponseBody {
    /// Wraps the given `Bytes` as an in-memory body.
    fn from(bytes: Bytes) -> Self {
        Self::Bytes(bytes)
    }
}

impl From<Vec<u8>> for ResponseBody {
    /// Converts the vector into `Bytes` and wraps it as an in-memory body.
    fn from(vec: Vec<u8>) -> Self {
        Self::from(Bytes::from(vec))
    }
}

impl From<&'static [u8]> for ResponseBody {
    /// Wraps a static byte slice as an in-memory body.
    fn from(slice: &'static [u8]) -> Self {
        Self::from(Bytes::from(slice))
    }
}

impl From<String> for ResponseBody {
    /// Uses the string's UTF-8 bytes as an in-memory body.
    fn from(text: String) -> Self {
        Self::from(Bytes::from(text))
    }
}

impl From<&'static str> for ResponseBody {
    /// Uses a static string's UTF-8 bytes as an in-memory body.
    fn from(text: &'static str) -> Self {
        Self::from(Bytes::from(text))
    }
}
impl PartialEq<Bytes> for ResponseBody {
fn eq(&self, other: &Bytes) -> bool {
match self {
ResponseBody::Bytes(b) => b == other,
ResponseBody::File { .. } => false,
}
}
}
/// Successful response produced by an [`AwsService`].
pub struct AwsResponse {
    /// HTTP status code.
    pub status: StatusCode,
    /// Content type of `body` (e.g. `text/xml` or
    /// `application/x-amz-json-1.1` from the constructors below).
    pub content_type: String,
    /// Response payload, either in memory or backed by a file.
    pub body: ResponseBody,
    /// Additional response headers.
    pub headers: HeaderMap,
}
impl AwsResponse {
    /// Builds a `text/xml` response with an in-memory body and no extra
    /// headers.
    pub fn xml(status: StatusCode, body: impl Into<Bytes>) -> Self {
        let bytes: Bytes = body.into();
        Self {
            status,
            content_type: String::from("text/xml"),
            body: ResponseBody::Bytes(bytes),
            headers: HeaderMap::new(),
        }
    }

    /// Builds an `application/x-amz-json-1.1` response with an in-memory
    /// body and no extra headers.
    pub fn json(status: StatusCode, body: impl Into<Bytes>) -> Self {
        let bytes: Bytes = body.into();
        Self {
            status,
            content_type: String::from("application/x-amz-json-1.1"),
            body: ResponseBody::Bytes(bytes),
            headers: HeaderMap::new(),
        }
    }

    /// Serializes `value` and wraps it via [`AwsResponse::json`].
    pub fn json_value(status: StatusCode, value: serde_json::Value) -> Self {
        let encoded =
            serde_json::to_vec(&value).expect("serde_json::Value serialization is infallible");
        Self::json(status, encoded)
    }

    /// Shorthand for a `200 OK` [`AwsResponse::json_value`].
    pub fn ok_json(value: serde_json::Value) -> Self {
        Self::json_value(StatusCode::OK, value)
    }
}
/// Error returned by an [`AwsService`] (or by request plumbing).
///
/// Each variant maps to an HTTP status and an AWS error code through the
/// `status()` and `code()` accessors.
#[derive(Debug, thiserror::Error)]
pub enum AwsServiceError {
    /// The requested service is not known. Reported as 400 `UnknownService`.
    #[error("service not found: {service}")]
    ServiceNotFound { service: String },
    /// The service does not implement the action. Reported as 501
    /// `InvalidAction`.
    #[error("action {action} not implemented for service {service}")]
    ActionNotImplemented { service: String, action: String },
    /// Fully specified AWS-style error.
    #[error("{code}: {message}")]
    AwsError {
        /// HTTP status to respond with.
        status: StatusCode,
        /// AWS error code (e.g. `InvalidChunkSizeError`).
        code: String,
        /// Human-readable error message.
        message: String,
        /// Extra name/value pairs attached to the error.
        /// NOTE(review): presumably rendered as extra elements in the error
        /// body — confirm.
        extra_fields: Vec<(String, String)>,
        /// Extra HTTP headers to send with the error response.
        headers: Vec<(String, String)>,
    },
}
impl AwsServiceError {
    /// Error for an action the given service does not implement.
    pub fn action_not_implemented(service: &str, action: &str) -> Self {
        Self::ActionNotImplemented {
            action: action.to_owned(),
            service: service.to_owned(),
        }
    }

    /// AWS-style error with no extra fields and no extra headers.
    pub fn aws_error(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        Self::aws_error_with_fields(status, code, message, Vec::new())
    }

    /// AWS-style error carrying extra name/value fields.
    pub fn aws_error_with_fields(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
        extra_fields: Vec<(String, String)>,
    ) -> Self {
        Self::AwsError {
            status,
            code: code.into(),
            message: message.into(),
            extra_fields,
            headers: Vec::new(),
        }
    }

    /// AWS-style error carrying extra response headers.
    pub fn aws_error_with_headers(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
        headers: Vec<(String, String)>,
    ) -> Self {
        Self::AwsError {
            status,
            code: code.into(),
            message: message.into(),
            extra_fields: Vec::new(),
            headers,
        }
    }

    /// Extra fields of an [`AwsServiceError::AwsError`]; empty for other
    /// variants.
    pub fn extra_fields(&self) -> &[(String, String)] {
        if let Self::AwsError { extra_fields, .. } = self {
            extra_fields
        } else {
            &[]
        }
    }

    /// HTTP status to respond with.
    pub fn status(&self) -> StatusCode {
        match *self {
            Self::AwsError { status, .. } => status,
            Self::ActionNotImplemented { .. } => StatusCode::NOT_IMPLEMENTED,
            Self::ServiceNotFound { .. } => StatusCode::BAD_REQUEST,
        }
    }

    /// AWS error code string.
    pub fn code(&self) -> &str {
        match self {
            Self::AwsError { code, .. } => code.as_str(),
            Self::ActionNotImplemented { .. } => "InvalidAction",
            Self::ServiceNotFound { .. } => "UnknownService",
        }
    }

    /// Human-readable message. For `AwsError` this is the bare message,
    /// without the code prefix that `Display` adds.
    pub fn message(&self) -> String {
        match self {
            Self::AwsError { message, .. } => message.clone(),
            Self::ActionNotImplemented { service, action } => {
                format!("action {action} not implemented for service {service}")
            }
            Self::ServiceNotFound { service } => format!("service not found: {service}"),
        }
    }

    /// Extra response headers of an [`AwsServiceError::AwsError`]; empty
    /// for other variants.
    pub fn response_headers(&self) -> &[(String, String)] {
        if let Self::AwsError { headers, .. } = self {
            headers
        } else {
            &[]
        }
    }
}
/// A service implementation that handles [`AwsRequest`]s.
///
/// Only `service_name`, `handle` and `supported_actions` must be
/// implemented. The IAM-related hooks have defaults that report
/// "not enforceable" and provide no extra information.
#[async_trait]
pub trait AwsService: Send + Sync {
    /// Name identifying this service.
    fn service_name(&self) -> &str;
    /// Handles one request and returns either a response or an AWS-style error.
    async fn handle(&self, request: AwsRequest) -> Result<AwsResponse, AwsServiceError>;
    /// Action names this service reports as supported.
    fn supported_actions(&self) -> &[&str];
    /// Whether IAM checks apply to this service. Defaults to `false`.
    /// NOTE(review): presumably callers skip policy evaluation when this is
    /// `false` — confirm.
    fn iam_enforceable(&self) -> bool {
        false
    }
    /// IAM action (service, action, resource) that `_request` represents, if
    /// any. Defaults to `None`.
    fn iam_action_for(&self, _request: &AwsRequest) -> Option<crate::auth::IamAction> {
        None
    }
    /// Condition-key values (key → list of values, e.g. `s3:prefix`) for
    /// evaluating `_action` against policy conditions. Defaults to an empty map.
    fn iam_condition_keys_for(
        &self,
        _request: &AwsRequest,
        _action: &crate::auth::IamAction,
    ) -> BTreeMap<String, Vec<String>> {
        BTreeMap::new()
    }
    /// Tags attached to the resource identified by `_resource_arn`, if known.
    /// Defaults to `None`.
    fn resource_tags_for(
        &self,
        _resource_arn: &str,
    ) -> Option<std::collections::HashMap<String, String>> {
        None
    }
    /// Tags supplied in the request for `_action`, if any. Defaults to `None`.
    /// NOTE(review): presumably used for request-tag policy conditions — confirm.
    fn request_tags_from(
        &self,
        _request: &AwsRequest,
        _action: &str,
    ) -> Option<std::collections::HashMap<String, String>> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::IamAction;
    use async_trait::async_trait;

    /// Frames `payload` as an aws-chunked body with fixed dummy signatures,
    /// optionally followed by a checksum trailer.
    fn aws_chunked_body(payload: &[u8], chunk_size: usize, with_trailer: bool) -> Vec<u8> {
        let sig = "0".repeat(64);
        let mut out = Vec::new();
        for c in payload.chunks(chunk_size.max(1)) {
            out.extend_from_slice(format!("{:x};chunk-signature={sig}\r\n", c.len()).as_bytes());
            out.extend_from_slice(c);
            out.extend_from_slice(b"\r\n");
        }
        out.extend_from_slice(format!("0;chunk-signature={sig}\r\n").as_bytes());
        if with_trailer {
            out.extend_from_slice(b"x-amz-checksum-crc32:AAAAAA==\r\n");
        }
        out.extend_from_slice(b"\r\n");
        out
    }

    /// Decodes `body` by feeding it to a fresh decoder `feed_size` bytes at a time.
    fn decode_all(body: &[u8], feed_size: usize) -> Vec<u8> {
        let mut d = AwsChunkedDecoder::default();
        let mut out = Vec::new();
        for frame in body.chunks(feed_size.max(1)) {
            out.extend(d.feed(frame).expect("valid chunked body"));
        }
        out
    }

    #[test]
    fn aws_chunked_decoder_roundtrips_across_frame_boundaries() {
        let payload: Vec<u8> = (0..5000u32).map(|i| (i % 251) as u8).collect();
        for with_trailer in [false, true] {
            let body = aws_chunked_body(&payload, 1024, with_trailer);
            for feed in [1usize, 7, 64, 1000, body.len()] {
                let decoded = decode_all(&body, feed);
                assert_eq!(decoded, payload, "feed={feed} trailer={with_trailer}");
            }
        }
    }

    #[test]
    fn aws_chunked_decoder_handles_empty_payload() {
        let body = aws_chunked_body(b"", 1024, false);
        assert_eq!(decode_all(&body, 3), Vec::<u8>::new());
    }

    #[test]
    fn aws_chunked_decoder_rejects_bad_size_line() {
        let mut d = AwsChunkedDecoder::default();
        assert!(d.feed(b"zz;chunk-signature=x\r\n").is_err());
    }

    #[test]
    fn is_aws_chunked_detects_streaming_markers() {
        let mut h = http::HeaderMap::new();
        assert!(!is_aws_chunked(&h));
        h.insert("content-encoding", "aws-chunked".parse().unwrap());
        assert!(is_aws_chunked(&h));
        let mut h2 = http::HeaderMap::new();
        h2.insert(
            "x-amz-content-sha256",
            "STREAMING-AWS4-HMAC-SHA256-PAYLOAD".parse().unwrap(),
        );
        assert!(is_aws_chunked(&h2));
        let mut h3 = http::HeaderMap::new();
        h3.insert("content-encoding", "gzip".parse().unwrap());
        assert!(!is_aws_chunked(&h3));
    }

    #[test]
    fn strip_aws_chunked_keeps_real_encoding() {
        assert_eq!(strip_aws_chunked_encoding(Some("aws-chunked")), None);
        assert_eq!(
            strip_aws_chunked_encoding(Some("aws-chunked, gzip")).as_deref(),
            Some("gzip")
        );
        assert_eq!(
            strip_aws_chunked_encoding(Some("gzip")).as_deref(),
            Some("gzip")
        );
        assert_eq!(strip_aws_chunked_encoding(None), None);
    }

    #[test]
    fn hex_lower_encodes_each_byte_as_two_lowercase_digits() {
        assert_eq!(hex_lower(&[]), "");
        assert_eq!(hex_lower(&[0x00, 0x0f, 0xa5, 0xff]), "000fa5ff");
    }

    /// Service relying on every default trait method.
    struct DefaultService;

    #[async_trait]
    impl AwsService for DefaultService {
        fn service_name(&self) -> &str {
            "default"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
    }

    /// Service overriding `iam_condition_keys_for`.
    struct PopulatedService;

    #[async_trait]
    impl AwsService for PopulatedService {
        fn service_name(&self) -> &str {
            "populated"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
        fn iam_condition_keys_for(
            &self,
            _request: &AwsRequest,
            _action: &IamAction,
        ) -> BTreeMap<String, Vec<String>> {
            let mut m = BTreeMap::new();
            m.insert("s3:prefix".to_string(), vec!["logs/".to_string()]);
            m
        }
    }

    fn sample_request() -> AwsRequest {
        AwsRequest {
            service: "default".into(),
            action: "Noop".into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "req-1".into(),
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            body: Bytes::new(),
            body_stream: parking_lot::Mutex::new(None),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: Method::GET,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    fn sample_action() -> IamAction {
        IamAction {
            service: "s3",
            action: "ListBucket",
            resource: "arn:aws:s3:::my-bucket".to_string(),
        }
    }

    #[test]
    fn iam_condition_keys_for_default_is_empty() {
        let svc = DefaultService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert!(keys.is_empty());
    }

    #[test]
    fn iam_condition_keys_for_override_returns_map() {
        let svc = PopulatedService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert_eq!(keys.get("s3:prefix"), Some(&vec!["logs/".to_string()]));
    }

    #[test]
    fn response_body_len_and_is_empty_for_bytes() {
        let body: ResponseBody = Bytes::from_static(b"hello").into();
        assert_eq!(body.len(), 5);
        assert!(!body.is_empty());
        let empty: ResponseBody = ResponseBody::default();
        assert!(empty.is_empty());
    }

    #[test]
    fn response_body_from_vec_and_string_and_str() {
        let from_vec: ResponseBody = vec![1u8, 2, 3].into();
        assert_eq!(from_vec.expect_bytes(), &[1, 2, 3][..]);
        let from_string: ResponseBody = String::from("hi").into();
        assert_eq!(from_string.expect_bytes(), b"hi");
        let from_str: ResponseBody = "hey".into();
        assert_eq!(from_str.expect_bytes(), b"hey");
        let from_static: ResponseBody = (b"123" as &'static [u8]).into();
        assert_eq!(from_static.expect_bytes(), b"123");
    }

    #[test]
    fn response_body_partial_eq_bytes() {
        let body: ResponseBody = Bytes::from_static(b"x").into();
        assert!(body == Bytes::from_static(b"x"));
        assert!(!(body == Bytes::from_static(b"y")));
    }

    #[test]
    fn aws_request_json_body_empty_returns_null() {
        let req = sample_request();
        assert_eq!(req.json_body(), serde_json::Value::Null);
    }

    #[test]
    fn aws_request_json_body_parses_valid() {
        let mut req = sample_request();
        req.body = Bytes::from_static(br#"{"a":1}"#);
        assert_eq!(req.json_body(), serde_json::json!({"a": 1}));
    }

    #[test]
    fn aws_response_xml_constructor() {
        let resp = AwsResponse::xml(StatusCode::OK, Bytes::from_static(b"<ok/>"));
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.content_type, "text/xml");
    }

    #[test]
    fn aws_response_json_constructor() {
        let resp = AwsResponse::json(StatusCode::CREATED, "{}");
        assert_eq!(resp.status, StatusCode::CREATED);
        assert_eq!(resp.content_type, "application/x-amz-json-1.1");
    }

    #[test]
    fn aws_response_ok_json_helper() {
        let resp = AwsResponse::ok_json(serde_json::json!({"ok": true}));
        assert_eq!(resp.status, StatusCode::OK);
        assert!(resp.body.expect_bytes().starts_with(b"{"));
    }

    #[test]
    fn aws_error_service_not_found_fields() {
        let err = AwsServiceError::ServiceNotFound {
            service: "sqs".to_string(),
        };
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert_eq!(err.code(), "UnknownService");
        assert!(err.message().contains("sqs"));
        assert!(err.extra_fields().is_empty());
        assert!(err.response_headers().is_empty());
    }

    #[test]
    fn aws_error_action_not_implemented_fields() {
        let err = AwsServiceError::action_not_implemented("sns", "FutureAction");
        assert_eq!(err.status(), StatusCode::NOT_IMPLEMENTED);
        assert_eq!(err.code(), "InvalidAction");
        assert!(err.message().contains("FutureAction"));
        assert!(err.message().contains("sns"));
    }

    #[test]
    fn aws_error_aws_error_helpers() {
        let e = AwsServiceError::aws_error(StatusCode::FORBIDDEN, "Denied", "no");
        assert_eq!(e.status(), StatusCode::FORBIDDEN);
        assert_eq!(e.code(), "Denied");
        assert_eq!(e.message(), "no");
        let fields = vec![("Bucket".to_string(), "b".to_string())];
        let ef = AwsServiceError::aws_error_with_fields(
            StatusCode::NOT_FOUND,
            "Missing",
            "gone",
            fields.clone(),
        );
        assert_eq!(ef.extra_fields(), fields.as_slice());
        let hdrs = vec![("X-Retry".to_string(), "1".to_string())];
        let eh = AwsServiceError::aws_error_with_headers(
            StatusCode::TOO_MANY_REQUESTS,
            "Throttled",
            "slow",
            hdrs.clone(),
        );
        assert_eq!(eh.response_headers(), hdrs.as_slice());
    }

    #[test]
    #[should_panic(expected = "expect_bytes called on ResponseBody::File")]
    fn response_body_expect_bytes_panics_on_file() {
        // An anonymous temp file is removed by the OS once its handle closes.
        // That matters here because the panic below skips any cleanup code,
        // and a fixed path in the shared temp dir would leak and could
        // collide between concurrent test runs.
        let f = tempfile::tempfile().expect("create anonymous temp file");
        let async_f = tokio::fs::File::from_std(f);
        let body = ResponseBody::File {
            file: async_f,
            size: 0,
        };
        let _ = body.expect_bytes();
    }
}