use async_trait::async_trait;
use bytes::Bytes;
use http::{HeaderMap, Method, StatusCode};
use std::collections::{BTreeMap, HashMap};
use crate::auth::Principal;
/// An incoming request, as passed by value to [`AwsService::handle`].
#[derive(Debug)]
pub struct AwsRequest {
    /// Target service name. NOTE(review): presumably compared against
    /// `AwsService::service_name` by whatever routes requests — confirm.
    pub service: String,
    /// Name of the requested operation.
    pub action: String,
    /// Region the request is scoped to (e.g. `"us-east-1"`).
    pub region: String,
    /// Account ID associated with the request.
    pub account_id: String,
    /// Identifier for this request. NOTE(review): presumably echoed back in
    /// responses/errors — confirm.
    pub request_id: String,
    /// HTTP request headers.
    pub headers: HeaderMap,
    /// Query-string parameters. Being a `HashMap`, this holds one value per key —
    /// NOTE(review): confirm how duplicate keys are resolved upstream.
    pub query_params: HashMap<String, String>,
    /// Raw request body; see [`AwsRequest::json_body`] for a JSON view of it.
    pub body: Bytes,
    /// Request path split into segments. NOTE(review): whether segments are
    /// percent-decoded is not visible here.
    pub path_segments: Vec<String>,
    /// Request path as received (e.g. `"/"`).
    pub raw_path: String,
    /// Unparsed query string.
    pub raw_query: String,
    /// HTTP method.
    pub method: Method,
    /// NOTE(review): presumably `true` when the request uses the AWS Query
    /// (form-encoded `Action=`) protocol — confirm with the code that sets it.
    pub is_query_protocol: bool,
    /// Access key ID from the caller's credentials, if one was supplied.
    pub access_key_id: Option<String>,
    /// Caller identity, if one was resolved for this request.
    pub principal: Option<Principal>,
}
impl AwsRequest {
    /// Parses [`AwsRequest::body`] as JSON.
    ///
    /// This is best-effort: an empty body or one that is not valid JSON yields
    /// `serde_json::Value::Null` rather than an error.
    pub fn json_body(&self) -> serde_json::Value {
        match serde_json::from_slice::<serde_json::Value>(&self.body) {
            Ok(parsed) => parsed,
            // Parse failures are intentionally discarded; callers treat Null as "no body".
            Err(_) => serde_json::Value::Null,
        }
    }
}
/// Payload of an [`AwsResponse`]: either an in-memory buffer or an open file.
#[derive(Debug)]
pub enum ResponseBody {
    /// Body held entirely in memory.
    Bytes(Bytes),
    /// Body backed by an open file. `size` is the byte count reported by
    /// [`ResponseBody::len`]; it is not checked against the file here.
    File { file: tokio::fs::File, size: u64 },
}
impl ResponseBody {
pub fn len(&self) -> u64 {
match self {
ResponseBody::Bytes(b) => b.len() as u64,
ResponseBody::File { size, .. } => *size,
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn expect_bytes(&self) -> &[u8] {
match self {
ResponseBody::Bytes(b) => b,
ResponseBody::File { .. } => {
panic!("expect_bytes called on ResponseBody::File")
}
}
}
}
impl Default for ResponseBody {
    /// An empty in-memory body.
    fn default() -> Self {
        Self::Bytes(Bytes::default())
    }
}
/// Every `From` conversion in this group produces an in-memory [`ResponseBody::Bytes`].
impl From<Bytes> for ResponseBody {
    fn from(bytes: Bytes) -> Self {
        Self::Bytes(bytes)
    }
}
impl From<Vec<u8>> for ResponseBody {
    fn from(buf: Vec<u8>) -> Self {
        Self::from(Bytes::from(buf))
    }
}
impl From<&'static [u8]> for ResponseBody {
    /// Wraps a static slice via `Bytes::from_static`.
    fn from(slice: &'static [u8]) -> Self {
        Self::from(Bytes::from_static(slice))
    }
}
impl From<String> for ResponseBody {
    fn from(text: String) -> Self {
        Self::from(Bytes::from(text))
    }
}
impl From<&'static str> for ResponseBody {
    /// Wraps the UTF-8 bytes of a static string via `Bytes::from_static`.
    fn from(text: &'static str) -> Self {
        Self::from(Bytes::from_static(text.as_bytes()))
    }
}
impl PartialEq<Bytes> for ResponseBody {
fn eq(&self, other: &Bytes) -> bool {
match self {
ResponseBody::Bytes(b) => b == other,
ResponseBody::File { .. } => false,
}
}
}
/// Response produced by an [`AwsService`] handler.
///
/// Derives `Debug` like the other public types in this module, so that
/// `Result<AwsResponse, AwsServiceError>` supports `unwrap_err`/`{:?}`.
#[derive(Debug)]
pub struct AwsResponse {
    /// HTTP status code.
    pub status: StatusCode,
    /// MIME type of the body (e.g. `"text/xml"` as set by [`AwsResponse::xml`]).
    pub content_type: String,
    /// Response payload.
    pub body: ResponseBody,
    /// Additional headers. NOTE(review): presumably merged into the outgoing HTTP
    /// response alongside `content_type` — confirm.
    pub headers: HeaderMap,
}
impl AwsResponse {
    /// Builds a response with an in-memory XML body (`text/xml`) and no extra headers.
    pub fn xml(status: StatusCode, body: impl Into<Bytes>) -> Self {
        Self::in_memory(status, "text/xml", body.into())
    }

    /// Builds a response with an in-memory `application/x-amz-json-1.1` body and no
    /// extra headers.
    pub fn json(status: StatusCode, body: impl Into<Bytes>) -> Self {
        Self::in_memory(status, "application/x-amz-json-1.1", body.into())
    }

    /// Serializes `value` into a `200 OK` response built by [`AwsResponse::json`].
    pub fn ok_json(value: serde_json::Value) -> Self {
        // A `serde_json::Value` always serializes successfully: its maps have string
        // keys and none of its `Serialize` impls can fail.
        let body = serde_json::to_vec(&value)
            .expect("serializing a serde_json::Value is infallible");
        Self::json(StatusCode::OK, body)
    }

    /// Shared constructor: the given status and content type, an in-memory body, and an
    /// empty header map.
    fn in_memory(status: StatusCode, content_type: &str, body: Bytes) -> Self {
        Self {
            status,
            content_type: content_type.to_string(),
            body: ResponseBody::Bytes(body),
            headers: HeaderMap::new(),
        }
    }
}
/// Errors a service handler can return.
///
/// Each variant is mapped to an HTTP status, an error code, and a message by
/// [`AwsServiceError::status`], [`AwsServiceError::code`], and
/// [`AwsServiceError::message`].
#[derive(Debug, thiserror::Error)]
pub enum AwsServiceError {
    /// The requested service is unknown; reported as `400 Bad Request` / `UnknownService`.
    #[error("service not found: {service}")]
    ServiceNotFound { service: String },
    /// The service does not implement the requested action; reported as
    /// `501 Not Implemented` / `InvalidAction`.
    #[error("action {action} not implemented for service {service}")]
    ActionNotImplemented { service: String, action: String },
    /// A service-specific error with an explicit status, error code, and message.
    #[error("{code}: {message}")]
    AwsError {
        status: StatusCode,
        code: String,
        message: String,
        /// Extra key/value pairs exposed via `extra_fields()`. NOTE(review): presumably
        /// rendered into the error body by the protocol layer — confirm.
        extra_fields: Vec<(String, String)>,
        /// Extra response headers exposed via `response_headers()`.
        headers: Vec<(String, String)>,
    },
}
impl AwsServiceError {
    /// Builds [`AwsServiceError::ActionNotImplemented`] for `action` on `service`.
    pub fn action_not_implemented(service: &str, action: &str) -> Self {
        Self::ActionNotImplemented {
            service: service.to_string(),
            action: action.to_string(),
        }
    }

    /// Builds an [`AwsServiceError::AwsError`] with no extra fields and no extra headers.
    pub fn aws_error(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        Self::aws_error_with_fields(status, code, message, Vec::new())
    }

    /// Builds an [`AwsServiceError::AwsError`] that carries `extra_fields` and no
    /// extra headers.
    pub fn aws_error_with_fields(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
        extra_fields: Vec<(String, String)>,
    ) -> Self {
        Self::AwsError {
            status,
            code: code.into(),
            message: message.into(),
            extra_fields,
            headers: Vec::new(),
        }
    }

    /// Builds an [`AwsServiceError::AwsError`] that carries extra response `headers`
    /// and no extra fields.
    pub fn aws_error_with_headers(
        status: StatusCode,
        code: impl Into<String>,
        message: impl Into<String>,
        headers: Vec<(String, String)>,
    ) -> Self {
        Self::AwsError {
            status,
            code: code.into(),
            message: message.into(),
            extra_fields: Vec::new(),
            headers,
        }
    }

    /// Extra key/value fields attached to the error; empty for the built-in variants.
    pub fn extra_fields(&self) -> &[(String, String)] {
        // Exhaustive on purpose: a new variant must decide what it reports here.
        match self {
            Self::AwsError { extra_fields, .. } => extra_fields,
            Self::ServiceNotFound { .. } | Self::ActionNotImplemented { .. } => &[],
        }
    }

    /// HTTP status associated with the error.
    pub fn status(&self) -> StatusCode {
        match self {
            Self::ServiceNotFound { .. } => StatusCode::BAD_REQUEST,
            Self::ActionNotImplemented { .. } => StatusCode::NOT_IMPLEMENTED,
            Self::AwsError { status, .. } => *status,
        }
    }

    /// Error code string (e.g. `"UnknownService"`, `"InvalidAction"`, or the stored code).
    pub fn code(&self) -> &str {
        match self {
            Self::ServiceNotFound { .. } => "UnknownService",
            Self::ActionNotImplemented { .. } => "InvalidAction",
            Self::AwsError { code, .. } => code,
        }
    }

    /// Human-readable error message.
    ///
    /// For [`AwsServiceError::AwsError`] this is the stored `message` alone, without
    /// the `code: ` prefix that `Display` adds. For the other variants it is exactly
    /// their `Display` text.
    pub fn message(&self) -> String {
        match self {
            Self::AwsError { message, .. } => message.clone(),
            // Reuse the `#[error(...)]` formats so `message()` and `Display` cannot drift.
            Self::ServiceNotFound { .. } | Self::ActionNotImplemented { .. } => self.to_string(),
        }
    }

    /// Extra response headers to send with the error; empty for the built-in variants.
    pub fn response_headers(&self) -> &[(String, String)] {
        // Exhaustive on purpose: a new variant must decide what it reports here.
        match self {
            Self::AwsError { headers, .. } => headers,
            Self::ServiceNotFound { .. } | Self::ActionNotImplemented { .. } => &[],
        }
    }
}
#[async_trait]
pub trait AwsService: Send + Sync {
fn service_name(&self) -> &str;
async fn handle(&self, request: AwsRequest) -> Result<AwsResponse, AwsServiceError>;
fn supported_actions(&self) -> &[&str];
fn iam_enforceable(&self) -> bool {
false
}
fn iam_action_for(&self, _request: &AwsRequest) -> Option<crate::auth::IamAction> {
None
}
fn iam_condition_keys_for(
&self,
_request: &AwsRequest,
_action: &crate::auth::IamAction,
) -> BTreeMap<String, Vec<String>> {
BTreeMap::new()
}
fn resource_tags_for(
&self,
_resource_arn: &str,
) -> Option<std::collections::HashMap<String, String>> {
None
}
fn request_tags_from(
&self,
_request: &AwsRequest,
_action: &str,
) -> Option<std::collections::HashMap<String, String>> {
None
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::IamAction;
    use async_trait::async_trait;

    /// Service relying on every default `AwsService` hook.
    struct DefaultService;
    #[async_trait]
    impl AwsService for DefaultService {
        fn service_name(&self) -> &str {
            "default"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
    }

    /// Service overriding `iam_condition_keys_for` with a fixed map.
    struct PopulatedService;
    #[async_trait]
    impl AwsService for PopulatedService {
        fn service_name(&self) -> &str {
            "populated"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
        fn iam_condition_keys_for(
            &self,
            _request: &AwsRequest,
            _action: &IamAction,
        ) -> BTreeMap<String, Vec<String>> {
            let mut m = BTreeMap::new();
            m.insert("s3:prefix".to_string(), vec!["logs/".to_string()]);
            m
        }
    }

    fn sample_request() -> AwsRequest {
        AwsRequest {
            service: "default".into(),
            action: "Noop".into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "req-1".into(),
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            body: Bytes::new(),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: Method::GET,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    fn sample_action() -> IamAction {
        IamAction {
            service: "s3",
            action: "ListBucket",
            resource: "arn:aws:s3:::my-bucket".to_string(),
        }
    }

    #[test]
    fn iam_condition_keys_for_default_is_empty() {
        let svc = DefaultService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert!(keys.is_empty());
    }

    #[test]
    fn iam_condition_keys_for_override_returns_map() {
        let svc = PopulatedService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert_eq!(keys.get("s3:prefix"), Some(&vec!["logs/".to_string()]));
    }

    #[test]
    fn aws_service_default_hooks_are_inert() {
        let svc = DefaultService;
        let req = sample_request();
        assert_eq!(svc.service_name(), "default");
        assert!(svc.supported_actions().is_empty());
        assert!(!svc.iam_enforceable());
        assert!(svc.iam_action_for(&req).is_none());
        assert!(svc.resource_tags_for("arn:aws:s3:::my-bucket").is_none());
        assert!(svc.request_tags_from(&req, "ListBucket").is_none());
    }

    #[test]
    fn response_body_len_and_is_empty_for_bytes() {
        let body: ResponseBody = Bytes::from_static(b"hello").into();
        assert_eq!(body.len(), 5);
        assert!(!body.is_empty());
        let empty: ResponseBody = ResponseBody::default();
        assert!(empty.is_empty());
    }

    #[test]
    fn response_body_from_vec_and_string_and_str() {
        let from_vec: ResponseBody = vec![1u8, 2, 3].into();
        assert_eq!(from_vec.expect_bytes(), &[1, 2, 3][..]);
        let from_string: ResponseBody = String::from("hi").into();
        assert_eq!(from_string.expect_bytes(), b"hi");
        let from_str: ResponseBody = "hey".into();
        assert_eq!(from_str.expect_bytes(), b"hey");
        let from_static: ResponseBody = (b"123" as &'static [u8]).into();
        assert_eq!(from_static.expect_bytes(), b"123");
    }

    #[test]
    fn response_body_partial_eq_bytes() {
        let body: ResponseBody = Bytes::from_static(b"x").into();
        assert!(body == Bytes::from_static(b"x"));
        assert!(body != Bytes::from_static(b"y"));
    }

    #[test]
    fn aws_request_json_body_empty_returns_null() {
        let req = sample_request();
        assert_eq!(req.json_body(), serde_json::Value::Null);
    }

    #[test]
    fn aws_request_json_body_parses_valid() {
        let mut req = sample_request();
        req.body = Bytes::from_static(br#"{"a":1}"#);
        assert_eq!(req.json_body(), serde_json::json!({"a": 1}));
    }

    #[test]
    fn aws_response_xml_constructor() {
        let resp = AwsResponse::xml(StatusCode::OK, Bytes::from_static(b"<ok/>"));
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.content_type, "text/xml");
        assert_eq!(resp.body.expect_bytes(), b"<ok/>");
        assert!(resp.headers.is_empty());
    }

    #[test]
    fn aws_response_json_constructor() {
        let resp = AwsResponse::json(StatusCode::CREATED, "{}");
        assert_eq!(resp.status, StatusCode::CREATED);
        assert_eq!(resp.content_type, "application/x-amz-json-1.1");
        assert_eq!(resp.body.expect_bytes(), b"{}");
    }

    #[test]
    fn aws_response_ok_json_helper() {
        let resp = AwsResponse::ok_json(serde_json::json!({"ok": true}));
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.content_type, "application/x-amz-json-1.1");
        assert_eq!(resp.body.expect_bytes(), br#"{"ok":true}"#);
    }

    #[test]
    fn aws_error_service_not_found_fields() {
        let err = AwsServiceError::ServiceNotFound {
            service: "sqs".to_string(),
        };
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert_eq!(err.code(), "UnknownService");
        assert!(err.message().contains("sqs"));
        assert!(err.extra_fields().is_empty());
        assert!(err.response_headers().is_empty());
    }

    #[test]
    fn aws_error_action_not_implemented_fields() {
        let err = AwsServiceError::action_not_implemented("sns", "FutureAction");
        assert_eq!(err.status(), StatusCode::NOT_IMPLEMENTED);
        assert_eq!(err.code(), "InvalidAction");
        assert!(err.message().contains("FutureAction"));
        assert!(err.message().contains("sns"));
        assert!(err.extra_fields().is_empty());
        assert!(err.response_headers().is_empty());
    }

    #[test]
    fn aws_error_message_matches_display_for_builtin_variants() {
        let not_found = AwsServiceError::ServiceNotFound {
            service: "sqs".to_string(),
        };
        assert_eq!(not_found.message(), not_found.to_string());
        let not_impl = AwsServiceError::action_not_implemented("sns", "FutureAction");
        assert_eq!(not_impl.message(), not_impl.to_string());
        // `AwsError` is the exception: Display prefixes the code, message() does not.
        let custom = AwsServiceError::aws_error(StatusCode::FORBIDDEN, "Denied", "no");
        assert_eq!(custom.to_string(), "Denied: no");
        assert_eq!(custom.message(), "no");
    }

    #[test]
    fn aws_error_aws_error_helpers() {
        let e = AwsServiceError::aws_error(StatusCode::FORBIDDEN, "Denied", "no");
        assert_eq!(e.status(), StatusCode::FORBIDDEN);
        assert_eq!(e.code(), "Denied");
        assert_eq!(e.message(), "no");
        assert!(e.extra_fields().is_empty());
        assert!(e.response_headers().is_empty());
        let fields = vec![("Bucket".to_string(), "b".to_string())];
        let ef = AwsServiceError::aws_error_with_fields(
            StatusCode::NOT_FOUND,
            "Missing",
            "gone",
            fields.clone(),
        );
        assert_eq!(ef.extra_fields(), fields.as_slice());
        assert!(ef.response_headers().is_empty());
        let hdrs = vec![("X-Retry".to_string(), "1".to_string())];
        let eh = AwsServiceError::aws_error_with_headers(
            StatusCode::TOO_MANY_REQUESTS,
            "Throttled",
            "slow",
            hdrs.clone(),
        );
        assert_eq!(eh.response_headers(), hdrs.as_slice());
        assert!(eh.extra_fields().is_empty());
    }

    #[test]
    #[should_panic(expected = "expect_bytes called on ResponseBody::File")]
    fn response_body_expect_bytes_panics_on_file() {
        // Per-process name so parallel runs don't share the file.
        let path = std::env::temp_dir()
            .join(format!("fc-test-expect-file-{}", std::process::id()));
        let f = std::fs::File::create(&path).unwrap();
        // Unlink right away: the open handle stays usable, and nothing is left behind
        // even though this test ends by panicking. Best-effort, so errors are ignored.
        let _ = std::fs::remove_file(&path);
        let async_f = tokio::fs::File::from_std(f);
        let body = ResponseBody::File {
            file: async_f,
            size: 0,
        };
        let _ = body.expect_bytes();
    }
}