use async_trait::async_trait;
use bytes::Bytes;
use http::{HeaderMap, Method, StatusCode};
use md5::{Digest, Md5};
use parking_lot::Mutex;
use std::collections::{BTreeMap, HashMap};
use std::path::PathBuf;
use crate::auth::Principal;
/// Streaming request body type: an alias for axum's body, consumed incrementally by
/// [`drain_request_stream`] and [`spool_request_stream`].
pub type RequestBodyStream = axum::body::Body;
/// A request as handed to [`AwsService::handle`].
pub struct AwsRequest {
/// Target service name — presumably matched against `AwsService::service_name`; confirm with the router.
pub service: String,
/// Name of the requested action.
pub action: String,
pub region: String,
pub account_id: String,
pub request_id: String,
pub headers: HeaderMap,
pub query_params: HashMap<String, String>,
/// Buffered request body; [`AwsRequest::json_body`] parses it.
/// NOTE(review): presumably empty when the body is delivered through `body_stream` instead — confirm.
pub body: Bytes,
/// Streaming body, if any. Behind a `Mutex` so [`AwsRequest::take_body_stream`] can move it
/// out through `&self`.
pub body_stream: Mutex<Option<RequestBodyStream>>,
pub path_segments: Vec<String>,
pub raw_path: String,
pub raw_query: String,
pub method: Method,
/// NOTE(review): presumably `true` for AWS Query-protocol (form-encoded) requests — confirm.
pub is_query_protocol: bool,
pub access_key_id: Option<String>,
/// Authenticated principal, when one was resolved.
pub principal: Option<Principal>,
}
impl std::fmt::Debug for AwsRequest {
    /// Formats the request for diagnostics.
    ///
    /// The buffered body is summarised by its length (`body_len`), and the streaming body only
    /// by whether one is present, so payload bytes are never printed.
    ///
    /// NOTE(review): `headers` is printed verbatim and may include signing or session-token
    /// headers — confirm this output is never logged where credentials must not appear.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `parking_lot::Mutex` is not reentrant: a blocking `lock()` here deadlocks if the
        // calling thread already holds the body-stream guard while formatting the request.
        // `try_lock` reports the contended state instead of hanging.
        let body_stream: Option<&str> = match self.body_stream.try_lock() {
            Some(guard) => {
                if guard.is_some() {
                    Some("<stream>")
                } else {
                    None
                }
            }
            None => Some("<locked>"),
        };
        f.debug_struct("AwsRequest")
            .field("service", &self.service)
            .field("action", &self.action)
            .field("region", &self.region)
            .field("account_id", &self.account_id)
            .field("request_id", &self.request_id)
            .field("headers", &self.headers)
            .field("query_params", &self.query_params)
            .field("body_len", &self.body.len())
            .field("body_stream", &body_stream)
            .field("path_segments", &self.path_segments)
            .field("raw_path", &self.raw_path)
            .field("raw_query", &self.raw_query)
            .field("method", &self.method)
            .field("is_query_protocol", &self.is_query_protocol)
            .field("access_key_id", &self.access_key_id)
            .field("principal", &self.principal)
            .finish()
    }
}
impl AwsRequest {
    /// Parses the buffered `body` as JSON.
    ///
    /// Returns `serde_json::Value::Null` when the body is empty or not valid JSON.
    pub fn json_body(&self) -> serde_json::Value {
        match serde_json::from_slice::<serde_json::Value>(&self.body) {
            Ok(parsed) => parsed,
            Err(_) => serde_json::Value::Null,
        }
    }

    /// Moves the streaming body out of the request, leaving `None` behind.
    ///
    /// Only the first call can return `Some`; later calls return `None`.
    pub fn take_body_stream(&self) -> Option<RequestBodyStream> {
        std::mem::take(&mut *self.body_stream.lock())
    }
}
/// Reads a streaming body to completion and returns it as one contiguous buffer.
///
/// # Errors
///
/// Stream failures are translated by [`stream_error_to_aws`] (`413` for limit errors,
/// `400` otherwise).
pub async fn drain_request_stream(stream: RequestBodyStream) -> Result<Bytes, AwsServiceError> {
    use http_body_util::BodyExt;
    let collected = stream
        .collect()
        .await
        .map_err(|err| stream_error_to_aws(&err.to_string()))?;
    Ok(collected.to_bytes())
}
/// Maps a body-stream error message to an AWS-style error.
///
/// A message containing `limit` (ASCII case-insensitive) becomes `413 RequestEntityTooLarge`.
/// Any other message becomes `400 MalformedRequestBody`. The original message text is not
/// carried into the returned error.
///
/// NOTE(review): this is a substring heuristic, presumably aimed at body-limit errors such as
/// "length limit exceeded" — confirm against the limit layer actually in use.
fn stream_error_to_aws(msg: &str) -> AwsServiceError {
    let hit_size_limit = msg.to_ascii_lowercase().contains("limit");
    if hit_size_limit {
        return AwsServiceError::aws_error(
            StatusCode::PAYLOAD_TOO_LARGE,
            "RequestEntityTooLarge",
            "Streaming request body exceeded the configured limit",
        );
    }
    AwsServiceError::aws_error(
        StatusCode::BAD_REQUEST,
        "MalformedRequestBody",
        "Failed to read streaming request body",
    )
}
/// Result of [`spool_request_stream`]: a request body written to a file on disk.
#[derive(Debug)]
pub struct SpooledBody {
/// Location of the spooled file. `spool_request_stream` calls `keep()` on it, so it is not
/// deleted automatically — presumably the caller is responsible for removing it.
pub path: PathBuf,
/// Number of body bytes written to the file.
pub size: u64,
/// MD5 digest of the written bytes, as lowercase hex.
pub md5_hex: String,
}
/// Streams `stream` into a new temporary file and returns its path, size and MD5 digest.
///
/// The file gets the `fc-spool-` prefix. It is created inside `dir`, which is created first if
/// missing (best effort). When `dir` is `None`, it goes in the default temporary directory
/// chosen by `tempfile`. On success the file is kept on disk and is not removed
/// automatically.
///
/// # Errors
///
/// * `500 InternalError` if the temporary file cannot be created, written, flushed or persisted.
/// * Body-stream failures are translated by [`stream_error_to_aws`] (`413` or `400`).
///
/// On every error path the partially written file is deleted. The same happens if this future
/// is dropped before it completes, for example when the client disconnects.
pub async fn spool_request_stream(
    stream: RequestBodyStream,
    dir: Option<&std::path::Path>,
) -> Result<SpooledBody, AwsServiceError> {
    use http_body_util::BodyExt;
    use tokio::io::AsyncWriteExt;

    if let Some(d) = dir {
        // Best effort: if the directory is still unusable, `tempfile_in` reports the failure.
        let _ = tokio::fs::create_dir_all(d).await;
    }
    let mut builder = tempfile::Builder::new();
    builder.prefix("fc-spool-");
    let named = match dir {
        Some(d) => builder.tempfile_in(d),
        None => builder.tempfile(),
    }
    .map_err(|e| spool_internal_error("failed to create spool tempfile", e))?;

    // `temp_path` deletes the file when dropped. It is released with `keep()` only after the
    // body has been fully written and flushed. Every early return, and cancellation of this
    // future, therefore removes the partial file. It is declared before `file`, so `file` is
    // dropped (closed) first.
    let (std_file, temp_path) = named.into_parts();
    let mut file = tokio::fs::File::from_std(std_file);
    let mut hasher = Md5::new();
    let mut size: u64 = 0;
    let mut body = stream;
    while let Some(frame) = body.frame().await {
        let frame = frame.map_err(|e| stream_error_to_aws(&e.to_string()))?;
        // Non-data frames (e.g. trailers) and empty chunks carry no payload bytes.
        match frame.into_data() {
            Ok(chunk) if !chunk.is_empty() => {
                hasher.update(&chunk);
                size += chunk.len() as u64;
                file.write_all(&chunk)
                    .await
                    .map_err(|e| spool_internal_error("failed to spool request body", e))?;
            }
            _ => {}
        }
    }
    // tokio's `File` completes writes in the background; `flush` surfaces any pending error.
    file.flush()
        .await
        .map_err(|e| spool_internal_error("failed to flush spool tempfile", e))?;
    drop(file);
    let path: PathBuf = temp_path
        .keep()
        .map_err(|e| spool_internal_error("failed to persist spool tempfile", e))?;
    let md5_hex = hex_lower(&hasher.finalize());
    Ok(SpooledBody {
        path,
        size,
        md5_hex,
    })
}

/// Builds the `500 InternalError` used for local spool-file failures (`"{context}: {e}"`).
fn spool_internal_error(context: &str, e: impl std::fmt::Display) -> AwsServiceError {
    AwsServiceError::aws_error(
        StatusCode::INTERNAL_SERVER_ERROR,
        "InternalError",
        format!("{context}: {e}"),
    )
}
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte, high nibble first.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    bytes
        .iter()
        .fold(String::with_capacity(bytes.len() * 2), |mut hex, &byte| {
            hex.push(DIGITS[usize::from(byte >> 4)]);
            hex.push(DIGITS[usize::from(byte & 0x0f)]);
            hex
        })
}
/// Body of an [`AwsResponse`]: either an in-memory buffer or an open file.
#[derive(Debug)]
pub enum ResponseBody {
/// Fully buffered payload.
Bytes(Bytes),
/// Payload backed by an open file. `size` is what [`ResponseBody::len`] reports for it.
/// NOTE(review): presumably the number of bytes to send from `file` — confirm with the writer.
File { file: tokio::fs::File, size: u64 },
}
impl ResponseBody {
pub fn len(&self) -> u64 {
match self {
ResponseBody::Bytes(b) => b.len() as u64,
ResponseBody::File { size, .. } => *size,
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn expect_bytes(&self) -> &[u8] {
match self {
ResponseBody::Bytes(b) => b,
ResponseBody::File { .. } => {
panic!("expect_bytes called on ResponseBody::File")
}
}
}
}
impl Default for ResponseBody {
    /// An empty in-memory body.
    fn default() -> Self {
        let empty = Bytes::new();
        Self::Bytes(empty)
    }
}
/// Wraps an existing buffer without copying.
impl From<Bytes> for ResponseBody {
    fn from(buf: Bytes) -> Self {
        Self::Bytes(buf)
    }
}
/// Takes ownership of the vector's allocation.
impl From<Vec<u8>> for ResponseBody {
    fn from(vec: Vec<u8>) -> Self {
        Self::Bytes(vec.into())
    }
}
/// Borrows static data without allocating.
impl From<&'static [u8]> for ResponseBody {
    fn from(slice: &'static [u8]) -> Self {
        Self::Bytes(Bytes::from(slice))
    }
}
/// Takes ownership of the string's UTF-8 bytes.
impl From<String> for ResponseBody {
    fn from(text: String) -> Self {
        Self::Bytes(text.into())
    }
}
/// Borrows a static string's bytes without allocating.
impl From<&'static str> for ResponseBody {
    fn from(text: &'static str) -> Self {
        Self::Bytes(Bytes::from(text))
    }
}
impl PartialEq<Bytes> for ResponseBody {
fn eq(&self, other: &Bytes) -> bool {
match self {
ResponseBody::Bytes(b) => b == other,
ResponseBody::File { .. } => false,
}
}
}
/// Response produced by an [`AwsService`].
pub struct AwsResponse {
/// HTTP status code.
pub status: StatusCode,
/// Content type. The constructors here set `text/xml` or `application/x-amz-json-1.1`.
pub content_type: String,
/// Response payload.
pub body: ResponseBody,
/// Additional response headers; the constructors here leave this empty.
pub headers: HeaderMap,
}
impl AwsResponse {
    /// Builds a `text/xml` response with no extra headers.
    pub fn xml(status: StatusCode, body: impl Into<Bytes>) -> Self {
        let body = ResponseBody::Bytes(body.into());
        Self {
            status,
            content_type: String::from("text/xml"),
            body,
            headers: HeaderMap::new(),
        }
    }

    /// Builds an `application/x-amz-json-1.1` response with no extra headers.
    pub fn json(status: StatusCode, body: impl Into<Bytes>) -> Self {
        let body = ResponseBody::Bytes(body.into());
        Self {
            status,
            content_type: String::from("application/x-amz-json-1.1"),
            body,
            headers: HeaderMap::new(),
        }
    }

    /// Serialises `value` and wraps it with [`AwsResponse::json`].
    pub fn json_value(status: StatusCode, value: serde_json::Value) -> Self {
        let encoded =
            serde_json::to_vec(&value).expect("serde_json::Value serialization is infallible");
        Self::json(status, encoded)
    }

    /// Shorthand for a `200 OK` JSON response.
    pub fn ok_json(value: serde_json::Value) -> Self {
        let status = StatusCode::OK;
        Self::json_value(status, value)
    }
}
/// Error returned by service handlers.
///
/// Every variant exposes an HTTP status, an AWS error code and a message through `status()`,
/// `code()` and `message()`.
#[derive(Debug, thiserror::Error)]
pub enum AwsServiceError {
/// The requested service is unknown (reported as `400 UnknownService`).
#[error("service not found: {service}")]
ServiceNotFound { service: String },
/// The service does not implement `action` (reported as `501 InvalidAction`).
#[error("action {action} not implemented for service {service}")]
ActionNotImplemented { service: String, action: String },
/// A fully specified AWS error.
#[error("{code}: {message}")]
AwsError {
status: StatusCode,
code: String,
message: String,
/// Extra name/value pairs returned by `extra_fields()`.
/// NOTE(review): presumably rendered into the error body — confirm with the serializer.
extra_fields: Vec<(String, String)>,
/// Extra response headers returned by `response_headers()`.
headers: Vec<(String, String)>,
},
}
impl AwsServiceError {
    /// Error for an action that `service` does not implement.
    pub fn action_not_implemented(service: &str, action: &str) -> Self {
        let (service, action) = (service.to_owned(), action.to_owned());
        Self::ActionNotImplemented { service, action }
    }

    /// `AwsError` with no extra fields and no extra response headers.
    pub fn aws_error(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        Self::aws_error_with_fields(status, code, message, Vec::new())
    }

    /// `AwsError` carrying additional name/value fields and no extra headers.
    pub fn aws_error_with_fields(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
        extra_fields: Vec<(String, String)>,
    ) -> Self {
        let code = code.into();
        let message = message.into();
        Self::AwsError {
            status,
            code,
            message,
            extra_fields,
            headers: Vec::new(),
        }
    }

    /// `AwsError` carrying additional response headers and no extra fields.
    pub fn aws_error_with_headers(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
        headers: Vec<(String, String)>,
    ) -> Self {
        let code = code.into();
        let message = message.into();
        Self::AwsError {
            status,
            code,
            message,
            extra_fields: Vec::new(),
            headers,
        }
    }

    /// Extra fields attached to an `AwsError`; empty for the other variants.
    pub fn extra_fields(&self) -> &[(String, String)] {
        match self {
            Self::AwsError { extra_fields, .. } => extra_fields.as_slice(),
            Self::ServiceNotFound { .. } | Self::ActionNotImplemented { .. } => &[],
        }
    }

    /// HTTP status: 400 for an unknown service, 501 for an unimplemented action, otherwise
    /// the stored status.
    pub fn status(&self) -> StatusCode {
        match self {
            Self::AwsError { status, .. } => *status,
            Self::ActionNotImplemented { .. } => StatusCode::NOT_IMPLEMENTED,
            Self::ServiceNotFound { .. } => StatusCode::BAD_REQUEST,
        }
    }

    /// AWS error code for this error.
    pub fn code(&self) -> &str {
        match self {
            Self::AwsError { code, .. } => code.as_str(),
            Self::ActionNotImplemented { .. } => "InvalidAction",
            Self::ServiceNotFound { .. } => "UnknownService",
        }
    }

    /// Human-readable message. The built-in variants reuse the text of their
    /// `#[error(...)]` attribute, which is exactly their `Display` output.
    pub fn message(&self) -> String {
        match self {
            Self::AwsError { message, .. } => message.clone(),
            Self::ServiceNotFound { .. } | Self::ActionNotImplemented { .. } => self.to_string(),
        }
    }

    /// Extra response headers attached to an `AwsError`; empty for the other variants.
    pub fn response_headers(&self) -> &[(String, String)] {
        match self {
            Self::AwsError { headers, .. } => headers.as_slice(),
            Self::ServiceNotFound { .. } | Self::ActionNotImplemented { .. } => &[],
        }
    }
}
/// An emulated AWS service that handles [`AwsRequest`]s.
///
/// Implementors must provide `service_name`, `handle` and `supported_actions`. The IAM and
/// tagging hooks have defaults that return `false`, `None` or an empty map.
#[async_trait]
pub trait AwsService: Send + Sync {
/// Name identifying this service.
fn service_name(&self) -> &str;
/// Processes a request, producing a response or an AWS-style error.
async fn handle(&self, request: AwsRequest) -> Result<AwsResponse, AwsServiceError>;
/// Names of the actions this service supports.
fn supported_actions(&self) -> &[&str];
/// Whether IAM policy checks apply to this service. Defaults to `false`.
/// NOTE(review): presumably consulted by the auth layer before `handle` — confirm.
fn iam_enforceable(&self) -> bool {
false
}
/// The IAM action (service, action, resource) that `_request` maps to, if any.
/// Defaults to `None`.
fn iam_action_for(&self, _request: &AwsRequest) -> Option<crate::auth::IamAction> {
None
}
/// Condition-key values (e.g. `s3:prefix`) used when evaluating `_action` for `_request`.
/// Defaults to an empty map.
fn iam_condition_keys_for(
&self,
_request: &AwsRequest,
_action: &crate::auth::IamAction,
) -> BTreeMap<String, Vec<String>> {
BTreeMap::new()
}
/// Tags attached to the resource identified by `_resource_arn`, if known. Defaults to `None`.
fn resource_tags_for(
&self,
_resource_arn: &str,
) -> Option<std::collections::HashMap<String, String>> {
None
}
/// Tags supplied by the request itself for `_action`. Defaults to `None`.
/// NOTE(review): presumably feeds request-tag policy conditions — confirm.
fn request_tags_from(
&self,
_request: &AwsRequest,
_action: &str,
) -> Option<std::collections::HashMap<String, String>> {
None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::IamAction;
    use async_trait::async_trait;

    /// Implements only the required trait methods, so every hook keeps its default.
    struct DefaultService;
    #[async_trait]
    impl AwsService for DefaultService {
        fn service_name(&self) -> &str {
            "default"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
    }

    /// Overrides `iam_condition_keys_for` to check that overrides are honoured.
    struct PopulatedService;
    #[async_trait]
    impl AwsService for PopulatedService {
        fn service_name(&self) -> &str {
            "populated"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
        fn iam_condition_keys_for(
            &self,
            _request: &AwsRequest,
            _action: &IamAction,
        ) -> BTreeMap<String, Vec<String>> {
            let mut m = BTreeMap::new();
            m.insert("s3:prefix".to_string(), vec!["logs/".to_string()]);
            m
        }
    }

    fn sample_request() -> AwsRequest {
        AwsRequest {
            service: "default".into(),
            action: "Noop".into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "req-1".into(),
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            body: Bytes::new(),
            body_stream: parking_lot::Mutex::new(None),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: Method::GET,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    fn sample_action() -> IamAction {
        IamAction {
            service: "s3",
            action: "ListBucket",
            resource: "arn:aws:s3:::my-bucket".to_string(),
        }
    }

    #[test]
    fn iam_condition_keys_for_default_is_empty() {
        let svc = DefaultService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert!(keys.is_empty());
    }

    #[test]
    fn iam_condition_keys_for_override_returns_map() {
        let svc = PopulatedService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert_eq!(keys.get("s3:prefix"), Some(&vec!["logs/".to_string()]));
    }

    #[test]
    fn response_body_len_and_is_empty_for_bytes() {
        let body: ResponseBody = Bytes::from_static(b"hello").into();
        assert_eq!(body.len(), 5);
        assert!(!body.is_empty());
        let empty: ResponseBody = ResponseBody::default();
        assert!(empty.is_empty());
    }

    #[test]
    fn response_body_from_vec_and_string_and_str() {
        let from_vec: ResponseBody = vec![1u8, 2, 3].into();
        assert_eq!(from_vec.expect_bytes(), &[1, 2, 3][..]);
        let from_string: ResponseBody = String::from("hi").into();
        assert_eq!(from_string.expect_bytes(), b"hi");
        let from_str: ResponseBody = "hey".into();
        assert_eq!(from_str.expect_bytes(), b"hey");
        let from_static: ResponseBody = (b"123" as &'static [u8]).into();
        assert_eq!(from_static.expect_bytes(), b"123");
    }

    #[test]
    fn response_body_partial_eq_bytes() {
        let body: ResponseBody = Bytes::from_static(b"x").into();
        assert!(body == Bytes::from_static(b"x"));
        assert!(!(body == Bytes::from_static(b"y")));
    }

    #[test]
    fn aws_request_json_body_empty_returns_null() {
        let req = sample_request();
        assert_eq!(req.json_body(), serde_json::Value::Null);
    }

    #[test]
    fn aws_request_json_body_parses_valid() {
        let mut req = sample_request();
        req.body = Bytes::from_static(br#"{"a":1}"#);
        assert_eq!(req.json_body(), serde_json::json!({"a": 1}));
    }

    #[test]
    fn aws_request_json_body_invalid_returns_null() {
        let mut req = sample_request();
        req.body = Bytes::from_static(b"not json");
        assert_eq!(req.json_body(), serde_json::Value::Null);
    }

    #[test]
    fn aws_request_take_body_stream_without_stream_is_none() {
        let req = sample_request();
        assert!(req.take_body_stream().is_none());
        assert!(req.take_body_stream().is_none());
    }

    #[test]
    fn aws_request_debug_reports_body_length_not_contents() {
        let mut req = sample_request();
        req.body = Bytes::from_static(b"secret-payload");
        let rendered = format!("{req:?}");
        assert!(rendered.starts_with("AwsRequest"));
        assert!(rendered.contains("body_len: 14"));
        assert!(!rendered.contains("secret-payload"));
    }

    #[test]
    fn hex_lower_encodes_each_byte_as_two_lowercase_digits() {
        assert_eq!(hex_lower(&[]), "");
        assert_eq!(hex_lower(&[0x00, 0x0f, 0xab, 0xff]), "000fabff");
    }

    #[test]
    fn stream_error_to_aws_classifies_limit_errors() {
        let too_large = stream_error_to_aws("length limit exceeded");
        assert_eq!(too_large.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(too_large.code(), "RequestEntityTooLarge");
        let malformed = stream_error_to_aws("connection reset by peer");
        assert_eq!(malformed.status(), StatusCode::BAD_REQUEST);
        assert_eq!(malformed.code(), "MalformedRequestBody");
    }

    #[test]
    fn aws_response_xml_constructor() {
        let resp = AwsResponse::xml(StatusCode::OK, Bytes::from_static(b"<ok/>"));
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.content_type, "text/xml");
    }

    #[test]
    fn aws_response_json_constructor() {
        let resp = AwsResponse::json(StatusCode::CREATED, "{}");
        assert_eq!(resp.status, StatusCode::CREATED);
        assert_eq!(resp.content_type, "application/x-amz-json-1.1");
    }

    #[test]
    fn aws_response_ok_json_helper() {
        let resp = AwsResponse::ok_json(serde_json::json!({"ok": true}));
        assert_eq!(resp.status, StatusCode::OK);
        assert!(resp.body.expect_bytes().starts_with(b"{"));
    }

    #[test]
    fn aws_error_service_not_found_fields() {
        let err = AwsServiceError::ServiceNotFound {
            service: "sqs".to_string(),
        };
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert_eq!(err.code(), "UnknownService");
        assert!(err.message().contains("sqs"));
        assert!(err.extra_fields().is_empty());
        assert!(err.response_headers().is_empty());
    }

    #[test]
    fn aws_error_action_not_implemented_fields() {
        let err = AwsServiceError::action_not_implemented("sns", "FutureAction");
        assert_eq!(err.status(), StatusCode::NOT_IMPLEMENTED);
        assert_eq!(err.code(), "InvalidAction");
        assert!(err.message().contains("FutureAction"));
        assert!(err.message().contains("sns"));
    }

    #[test]
    fn aws_error_aws_error_helpers() {
        let e = AwsServiceError::aws_error(StatusCode::FORBIDDEN, "Denied", "no");
        assert_eq!(e.status(), StatusCode::FORBIDDEN);
        assert_eq!(e.code(), "Denied");
        assert_eq!(e.message(), "no");
        let fields = vec![("Bucket".to_string(), "b".to_string())];
        let ef = AwsServiceError::aws_error_with_fields(
            StatusCode::NOT_FOUND,
            "Missing",
            "gone",
            fields.clone(),
        );
        assert_eq!(ef.extra_fields(), fields.as_slice());
        let hdrs = vec![("X-Retry".to_string(), "1".to_string())];
        let eh = AwsServiceError::aws_error_with_headers(
            StatusCode::TOO_MANY_REQUESTS,
            "Throttled",
            "slow",
            hdrs.clone(),
        );
        assert_eq!(eh.response_headers(), hdrs.as_slice());
    }

    #[test]
    #[should_panic(expected = "expect_bytes called on ResponseBody::File")]
    fn response_body_expect_bytes_panics_on_file() {
        // An anonymous temp file is reclaimed by the OS once closed, so the test leaves nothing
        // behind. A fixed path under the temp dir would persist across runs and could collide
        // with concurrently running test processes.
        let f = tempfile::tempfile().unwrap();
        let async_f = tokio::fs::File::from_std(f);
        let body = ResponseBody::File {
            file: async_f,
            size: 0,
        };
        let _ = body.expect_bytes();
    }
}