use crate::actions;
use crate::constants;
use crate::store::Store;
use crate::validation;
use crate::validation::rules;
use axum::body::Bytes;
use axum::extract::State;
use axum::http::{HeaderMap, Method, StatusCode, Uri};
use axum::response::{IntoResponse, Response};
use serde_json::{Value, json};
/// Largest accepted request body in bytes (7 MiB); bigger bodies get a 413.
const MAX_REQUEST_BYTES: usize = 7 << 20;
/// `Content-Type` for JSON-encoded requests and responses.
const AMZ_JSON: &str = "application/x-amz-json-1.1";
/// `Content-Type` for CBOR-encoded requests and responses.
const AMZ_CBOR: &str = "application/x-amz-cbor-1.1";
/// Service prefixes accepted before the `.` in the `X-Amz-Target` header.
const VALID_APIS: &[&str] = &["Kinesis_20131202"];
/// Operation names accepted after the `.` in the `X-Amz-Target` header.
/// Kept in alphabetical order; membership is checked with a linear `contains`.
const VALID_OPERATIONS: &[&str] = &[
    "AddTagsToStream",
    "CreateStream",
    "DecreaseStreamRetentionPeriod",
    "DeleteResourcePolicy",
    "DeleteStream",
    "DeregisterStreamConsumer",
    "DescribeAccountSettings",
    "DescribeLimits",
    "DescribeStream",
    "DescribeStreamConsumer",
    "DescribeStreamSummary",
    "DisableEnhancedMonitoring",
    "EnableEnhancedMonitoring",
    "GetRecords",
    "GetResourcePolicy",
    "GetShardIterator",
    "IncreaseStreamRetentionPeriod",
    "ListShards",
    "ListStreamConsumers",
    "ListStreams",
    "ListTagsForResource",
    "ListTagsForStream",
    "MergeShards",
    "PutRecord",
    "PutRecords",
    "PutResourcePolicy",
    "RegisterStreamConsumer",
    "RemoveTagsFromStream",
    "SplitShard",
    "StartStreamEncryption",
    "StopStreamEncryption",
    "SubscribeToShard",
    "TagResource",
    "UntagResource",
    "UpdateAccountSettings",
    "UpdateMaxRecordSize",
    "UpdateShardCount",
    "UpdateStreamMode",
    "UpdateStreamWarmThroughput",
];
pub async fn handler(
method: Method,
uri: Uri,
headers: HeaderMap,
State(store): State<Store>,
body: Bytes,
) -> Response {
let request_id = uuid::Uuid::new_v4().to_string();
let mut response_headers = HeaderMap::new();
response_headers.insert("x-amzn-RequestId", request_id.parse().unwrap());
let has_origin = headers.get("origin").is_some();
if method != Method::OPTIONS || !has_origin {
let id2 = base64::Engine::encode(
&base64::engine::general_purpose::STANDARD,
rand::random::<[u8; 72]>(),
);
response_headers.insert("x-amz-id-2", id2.parse().unwrap());
}
if has_origin {
response_headers.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
if method == Method::OPTIONS {
if let Some(req_headers) = headers.get("access-control-request-headers") {
response_headers.insert("Access-Control-Allow-Headers", req_headers.clone());
}
if let Some(req_method) = headers.get("access-control-request-method") {
response_headers.insert("Access-Control-Allow-Methods", req_method.clone());
}
response_headers.insert("Access-Control-Max-Age", "172800".parse().unwrap());
response_headers.insert("Content-Length", "0".parse().unwrap());
return (StatusCode::OK, response_headers, "").into_response();
}
response_headers.insert(
"Access-Control-Expose-Headers",
"x-amzn-RequestId,x-amzn-ErrorType,x-amz-request-id,x-amz-id-2,x-amzn-ErrorMessage,Date".parse().unwrap(),
);
}
if method != Method::POST {
return send_xml_error(
&response_headers,
"AccessDeniedException",
"Unable to determine service/operation name to be authorized",
403,
);
}
let content_type = headers
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.split(';')
.next()
.unwrap_or("")
.trim();
let content_valid = matches!(
content_type,
"application/x-amz-json-1.1" | "application/x-amz-cbor-1.1" | "application/json"
);
let target = headers
.get("x-amz-target")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let parts: Vec<&str> = target.splitn(2, '.').collect();
let service = parts.first().copied().unwrap_or("");
let operation = if parts.len() > 1 { parts[1] } else { "" };
let service_valid = !service.is_empty() && VALID_APIS.contains(&service);
let operation_valid = !operation.is_empty() && VALID_OPERATIONS.contains(&operation);
let response_content_type = if content_type == AMZ_JSON {
AMZ_JSON
} else {
AMZ_CBOR
};
if body.is_empty() {
let error_type = if service_valid && operation_valid {
constants::SERIALIZATION_EXCEPTION
} else {
constants::UNKNOWN_OPERATION
};
return send_json_response(
&response_headers,
response_content_type,
&json!({"__type": error_type}),
400,
);
}
if body.len() > MAX_REQUEST_BYTES {
response_headers.insert("Transfer-Encoding", "chunked".parse().unwrap());
return (StatusCode::from_u16(413).unwrap(), response_headers, "").into_response();
}
if !content_valid {
if service.is_empty() || operation.is_empty() {
return send_xml_error(
&response_headers,
"AccessDeniedException",
"Unable to determine service/operation name to be authorized",
403,
);
}
return send_xml_error_code(&response_headers, constants::UNKNOWN_OPERATION, 404);
}
let data: Option<Value> = if content_type == AMZ_CBOR {
ciborium::from_reader::<Value, _>(&body[..]).ok()
} else {
serde_json::from_slice(&body).ok()
};
let data = match data {
Some(Value::Object(map)) => Value::Object(map),
Some(_) | None => {
if content_type == "application/json" {
return send_json_response(
&response_headers,
"application/json",
&json!({
"Output": {"__type": "com.amazon.coral.service#SerializationException"},
"Version": "1.0",
}),
400,
);
}
return send_json_response(
&response_headers,
response_content_type,
&json!({"__type": constants::SERIALIZATION_EXCEPTION}),
400,
);
}
};
if content_type == "application/json" {
return send_json_response(
&response_headers,
"application/json",
&json!({
"Output": {"__type": "com.amazon.coral.service#UnknownOperationException"},
"Version": "1.0",
}),
404,
);
}
if !service_valid || !operation_valid {
return send_json_response(
&response_headers,
response_content_type,
&json!({"__type": constants::UNKNOWN_OPERATION}),
400,
);
}
let auth_header = headers.get("authorization").and_then(|v| v.to_str().ok());
let query_string = uri.query().unwrap_or("");
let auth_query = query_string.contains("X-Amz-Algorithm");
if auth_header.is_some() && auth_query {
return send_error_response(
&response_headers,
content_valid,
response_content_type,
"InvalidSignatureException",
"Found both 'X-Amz-Algorithm' as a query-string param and 'Authorization' as HTTP header.",
400,
);
}
if auth_header.is_none() && !auth_query {
return send_error_response(
&response_headers,
content_valid,
response_content_type,
"MissingAuthenticationTokenException",
"Missing Authentication Token",
400,
);
}
if let Some(auth) = auth_header {
let mut msg = String::new();
let auth_params: std::collections::HashMap<String, String> = auth
.split([',', ' '])
.skip(1)
.filter(|s| !s.is_empty())
.filter_map(|s| {
let kv: Vec<&str> = s.trim().splitn(2, '=').collect();
if kv.len() == 2 {
Some((kv[0].to_string(), kv[1].to_string()))
} else {
None
}
})
.collect();
for param in ["Credential", "Signature", "SignedHeaders"] {
if !auth_params.contains_key(param) {
msg += &format!("Authorization header requires '{param}' parameter. ");
}
}
if !headers.contains_key("x-amz-date") && !headers.contains_key("date") {
msg += "Authorization header requires existence of either a 'X-Amz-Date' or a 'Date' header. ";
}
if !msg.is_empty() {
msg += &format!("Authorization={auth}");
return send_error_response(
&response_headers,
content_valid,
response_content_type,
"IncompleteSignatureException",
&msg,
400,
);
}
} else {
let query_params: std::collections::HashMap<String, String> = uri
.query()
.unwrap_or("")
.split('&')
.filter_map(|s| {
let kv: Vec<&str> = s.splitn(2, '=').collect();
if kv.len() == 2 {
Some((kv[0].to_string(), kv[1].to_string()))
} else if !kv[0].is_empty() {
Some((kv[0].to_string(), String::new()))
} else {
None
}
})
.collect();
let mut msg = String::new();
for param in [
"X-Amz-Algorithm",
"X-Amz-Credential",
"X-Amz-Signature",
"X-Amz-SignedHeaders",
"X-Amz-Date",
] {
if !query_params.contains_key(param) || query_params[param].is_empty() {
msg += &format!("AWS query-string parameters must include '{param}'. ");
}
}
if !msg.is_empty() {
msg += "Re-examine the query-string parameters.";
return send_error_response(
&response_headers,
content_valid,
response_content_type,
"IncompleteSignatureException",
&msg,
400,
);
}
}
let validation_rules = get_validation_rules(operation);
let field_refs: Vec<(&str, &validation::FieldDef)> =
validation_rules.iter().map(|(k, v)| (*k, v)).collect();
let data = match validation::check_types(&data, &field_refs) {
Ok(d) => d,
Err(err) => {
return send_json_response(
&response_headers,
response_content_type,
&json!({"__type": err.body.__type, "Message": err.body.message_upper}),
err.status_code,
);
}
};
if let Err(err) = validation::check_validations(&data, &field_refs, None) {
return send_json_response(
&response_headers,
response_content_type,
&json!({"__type": err.body.__type, "message": err.body.message}),
err.status_code,
);
}
if operation == "SubscribeToShard" {
return match actions::subscribe_to_shard::execute_streaming(&store, data).await {
Ok(body) => {
response_headers.insert(
"Content-Type",
"application/vnd.amazon.eventstream".parse().unwrap(),
);
(StatusCode::OK, response_headers, body).into_response()
}
Err(err) => send_json_response(
&response_headers,
response_content_type,
&json!({"__type": err.body.__type, "message": err.body.message}),
err.status_code,
),
};
}
match actions::dispatch(&store, operation, data).await {
Ok(Some(result)) => {
send_json_response(&response_headers, response_content_type, &result, 200)
}
Ok(None) => {
response_headers.insert("Content-Type", response_content_type.parse().unwrap());
response_headers.insert("Content-Length", "0".parse().unwrap());
(StatusCode::OK, response_headers, "").into_response()
}
Err(err) => send_json_response(
&response_headers,
response_content_type,
&json!({"__type": err.body.__type, "message": err.body.message}),
err.status_code,
),
}
}
fn get_validation_rules(operation: &str) -> Vec<(&'static str, validation::FieldDef)> {
match operation {
"AddTagsToStream" => rules::add_tags_to_stream(),
"CreateStream" => rules::create_stream(),
"DecreaseStreamRetentionPeriod" => rules::decrease_stream_retention_period(),
"DeleteResourcePolicy" => rules::delete_resource_policy(),
"DeleteStream" => rules::delete_stream(),
"DeregisterStreamConsumer" => rules::deregister_stream_consumer(),
"DescribeAccountSettings" => vec![],
"DescribeLimits" => vec![],
"DescribeStream" => rules::describe_stream(),
"DescribeStreamConsumer" => rules::describe_stream_consumer(),
"DescribeStreamSummary" => rules::describe_stream_summary(),
"DisableEnhancedMonitoring" => rules::disable_enhanced_monitoring(),
"EnableEnhancedMonitoring" => rules::enable_enhanced_monitoring(),
"GetRecords" => rules::get_records(),
"GetResourcePolicy" => rules::get_resource_policy(),
"GetShardIterator" => rules::get_shard_iterator(),
"IncreaseStreamRetentionPeriod" => rules::increase_stream_retention_period(),
"ListShards" => rules::list_shards(),
"ListStreamConsumers" => rules::list_stream_consumers(),
"ListStreams" => rules::list_streams(),
"ListTagsForResource" => rules::list_tags_for_resource(),
"ListTagsForStream" => rules::list_tags_for_stream(),
"MergeShards" => rules::merge_shards(),
"PutRecord" => rules::put_record(),
"PutRecords" => rules::put_records(),
"PutResourcePolicy" => rules::put_resource_policy(),
"RegisterStreamConsumer" => rules::register_stream_consumer(),
"RemoveTagsFromStream" => rules::remove_tags_from_stream(),
"SplitShard" => rules::split_shard(),
"StartStreamEncryption" => rules::start_stream_encryption(),
"StopStreamEncryption" => rules::stop_stream_encryption(),
"SubscribeToShard" => rules::subscribe_to_shard(),
"TagResource" => rules::tag_resource(),
"UntagResource" => rules::untag_resource(),
"UpdateAccountSettings" => vec![],
"UpdateMaxRecordSize" => rules::update_max_record_size(),
"UpdateShardCount" => rules::update_shard_count(),
"UpdateStreamMode" => rules::update_stream_mode(),
"UpdateStreamWarmThroughput" => rules::update_stream_warm_throughput(),
_ => vec![],
}
}
/// Serializes `data` and sends it with the given status code.
///
/// The body is CBOR when `content_type` is `application/x-amz-cbor-1.1` and
/// JSON for every other content type. `Content-Type` and `Content-Length` are
/// set on a copy of `extra_headers`. Encoding failures are not reported:
/// - CBOR: the body is whatever was written before the error.
/// - JSON: the body is empty.
///
/// An out-of-range `status_code` becomes 500.
fn send_json_response(
    extra_headers: &HeaderMap,
    content_type: &str,
    data: &Value,
    status_code: u16,
) -> Response {
    let payload: Vec<u8> = match content_type {
        AMZ_CBOR => {
            let mut encoded = Vec::new();
            // Best effort: an encoding error keeps whatever was written so far.
            let _ = ciborium::into_writer(data, &mut encoded);
            encoded
        }
        _ => serde_json::to_vec(data).unwrap_or_default(),
    };
    let status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let length = payload.len().to_string();

    let mut headers = extra_headers.clone();
    headers.insert("Content-Type", content_type.parse().unwrap());
    headers.insert("Content-Length", length.parse().unwrap());
    (status, headers, payload).into_response()
}
/// Sends an XML error document of the form
/// `<ErrorType>\n <Message>...</Message>\n</ErrorType>\n`.
///
/// `message` is XML-escaped (`&`, `<`, `>`) because callers may pass
/// request-supplied text, such as an echoed `Authorization` header. Messages
/// without those characters are emitted unchanged. `error_type` is used verbatim
/// as the element name, so it must be a valid XML name.
///
/// Only `Content-Length` is added to a copy of `extra_headers`. An out-of-range
/// `status_code` becomes 500.
fn send_xml_error(
    extra_headers: &HeaderMap,
    error_type: &str,
    message: &str,
    status_code: u16,
) -> Response {
    let message = xml_escape(message);
    let body = format!("<{error_type}>\n <Message>{message}</Message>\n</{error_type}>\n");
    let mut headers = extra_headers.clone();
    headers.insert("Content-Length", body.len().to_string().parse().unwrap());
    (
        StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
        headers,
        body,
    )
        .into_response()
}

/// Escapes the characters that would break XML character data (`&`, `<`, `>`).
fn xml_escape(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Sends a self-closing XML error element (`<ErrorType/>\n`) with no message.
///
/// Only `Content-Length` is added to a copy of `extra_headers`. An out-of-range
/// `status_code` becomes 500.
fn send_xml_error_code(extra_headers: &HeaderMap, error_type: &str, status_code: u16) -> Response {
    let status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    let xml = format!("<{error_type}/>\n");
    let length = xml.len().to_string();

    let mut headers = extra_headers.clone();
    headers.insert("Content-Length", length.parse().unwrap());
    (status, headers, xml).into_response()
}
/// Sends an error in the encoding that matches the request.
///
/// - When `content_valid` is true, the error is a `{"__type", "message"}` object
///   encoded by `send_json_response` in `content_type`, with `status_code`.
/// - Otherwise an XML error document is sent. That path always uses status 403
///   and ignores `status_code`.
fn send_error_response(
    extra_headers: &HeaderMap,
    content_valid: bool,
    content_type: &str,
    error_type: &str,
    message: &str,
    status_code: u16,
) -> Response {
    if !content_valid {
        return send_xml_error(extra_headers, error_type, message, 403);
    }
    let payload = json!({"__type": error_type, "message": message});
    send_json_response(extra_headers, content_type, &payload, status_code)
}