use crate::actions::{self, Operation};
use crate::constants;
use crate::error::KinesisErrorResponse;
#[cfg(feature = "mirror")]
use crate::mirror::Mirror;
use crate::store::Store;
use crate::validation;
use axum::body::Bytes;
#[cfg(feature = "mirror")]
use axum::extract::Extension;
use axum::extract::{Request, State};
use axum::http::{HeaderMap, Method, StatusCode, Uri};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use base64::Engine;
use serde::Serialize;
use serde_json::{Value, json};
#[cfg(feature = "mirror")]
use std::sync::Arc;
use tracing::Instrument;
/// Extractor used by [`handler`] for the optional `Mirror` extension.
///
/// With the `mirror` feature disabled this is the unit type, which lets the
/// handler keep a single parameter list in both build configurations.
#[cfg(not(feature = "mirror"))]
type MirrorExt = ();
/// With the `mirror` feature enabled: the `Arc<Mirror>` request extension,
/// or `None` when the request carries no such extension.
#[cfg(feature = "mirror")]
type MirrorExt = Option<Extension<Arc<Mirror>>>;
/// Axum handler that serves every Kinesis API call.
///
/// The request is vetted in a fixed sequence and returns early on the first
/// failure:
///
/// 1. CORS: any request with an `Origin` header gets
///    `Access-Control-Allow-Origin: *`. Preflight `OPTIONS` requests are
///    answered here with an empty 200.
/// 2. Method: anything other than `POST` gets an XML `AccessDenied` (403).
/// 3. Body / content type / target: each of these produces its own error:
///    - an empty body;
///    - an unsupported content type;
///    - an undecodable or non-object body;
///    - an unknown `X-Amz-Target`.
///
///    Plain `application/json` requests are parsed but always answered with
///    `UnknownOperationException`.
/// 4. Authentication: exactly one of an `Authorization` header or a
///    presigned (`X-Amz-Algorithm`) query string must be present, with all of
///    its required parameters. Only presence is checked; nothing in this
///    function verifies a signature.
/// 5. Validation: the operation's field rules are applied to the body.
/// 6. Dispatch: `SubscribeToShard` returns an event-stream body. Every other
///    operation goes through `actions::dispatch`.
///
/// Because the checks short-circuit, their order decides which error a
/// malformed request receives. Reorder with care.
///
/// With the `mirror` feature, successful results of operations for which
/// `Mirror::should_mirror` is true are handed to `Mirror::spawn_forward`,
/// together with the original target, content type and body.
pub async fn handler(
    method: Method,
    uri: Uri,
    headers: HeaderMap,
    State(store): State<Store>,
    mirror: MirrorExt,
    body: Bytes,
) -> Response {
    // Every response, success or error, carries a fresh request id.
    let request_id = uuid::Uuid::new_v4().to_string();
    let mut response_headers = HeaderMap::new();
    response_headers.insert("x-amzn-RequestId", request_id.parse().unwrap());
    let has_origin = headers.get("origin").is_some();
    // CORS preflights (OPTIONS + Origin) are the only responses without
    // `x-amz-id-2`. The value is 72 random bytes, base64-encoded.
    if method != Method::OPTIONS || !has_origin {
        let id2 = base64::Engine::encode(
            &base64::engine::general_purpose::STANDARD,
            rand::random::<[u8; 72]>(),
        );
        response_headers.insert("x-amz-id-2", id2.parse().unwrap());
    }
    // CORS: allow any origin. A preflight is answered immediately by echoing
    // the requested headers and method back.
    if has_origin {
        response_headers.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
        if method == Method::OPTIONS {
            if let Some(req_headers) = headers.get("access-control-request-headers") {
                response_headers.insert("Access-Control-Allow-Headers", req_headers.clone());
            }
            if let Some(req_method) = headers.get("access-control-request-method") {
                response_headers.insert("Access-Control-Allow-Methods", req_method.clone());
            }
            response_headers.insert("Access-Control-Max-Age", "172800".parse().unwrap());
            response_headers.insert("Content-Length", "0".parse().unwrap());
            return (StatusCode::OK, response_headers, "").into_response();
        }
        response_headers.insert(
            "Access-Control-Expose-Headers",
            "x-amzn-RequestId,x-amzn-ErrorType,x-amz-request-id,x-amz-id-2,x-amzn-ErrorMessage,Date".parse().unwrap(),
        );
    }
    // Only POST is served. Anything else, including OPTIONS without an
    // Origin header, gets the XML AccessDenied response.
    if method != Method::POST {
        let mut h = response_headers.clone();
        h.insert(
            "x-amzn-ErrorType",
            constants::ACCESS_DENIED.parse().unwrap(),
        );
        return send_xml_error(
            h,
            constants::ACCESS_DENIED,
            "Unable to determine service/operation name to be authorized",
            403,
        );
    }
    // Media type only: parameters after ';' (e.g. charset) are dropped.
    let content_type = headers
        .get("content-type")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .split(';')
        .next()
        .unwrap_or("")
        .trim();
    let content_valid = matches!(
        content_type,
        "application/x-amz-json-1.1" | "application/x-amz-cbor-1.1" | "application/json"
    );
    // `X-Amz-Target` is "<service>.<operation>". Everything after the first
    // '.' is treated as the operation name.
    let target = headers
        .get("x-amz-target")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    let parts: Vec<&str> = target.splitn(2, '.').collect();
    let service = parts.first().copied().unwrap_or("");
    let operation_str = if parts.len() > 1 { parts[1] } else { "" };
    let service_valid = service == constants::KINESIS_API;
    let operation = operation_str.parse::<Operation>().ok();
    let operation_valid = operation.is_some();
    // Kinesis-format errors are JSON only when the request used
    // CONTENT_TYPE_JSON. Every other content type, including invalid ones and
    // plain "application/json", is answered in CBOR.
    let response_content_type = if content_type == constants::CONTENT_TYPE_JSON {
        constants::CONTENT_TYPE_JSON
    } else {
        constants::CONTENT_TYPE_CBOR
    };
    // An empty body is rejected before the content type is even checked.
    if body.is_empty() {
        let error_type = if service_valid && operation_valid {
            constants::SERIALIZATION_EXCEPTION
        } else {
            constants::UNKNOWN_OPERATION
        };
        let err = KinesisErrorResponse::client_error(error_type, None);
        return send_kinesis_error(&response_headers, response_content_type, &err);
    }
    // Unsupported content type: answer in XML. A target missing its service
    // or operation part is AccessDenied (403); otherwise UnknownOperation (404).
    if !content_valid {
        if service.is_empty() || operation_str.is_empty() {
            let mut h = response_headers.clone();
            h.insert(
                "x-amzn-ErrorType",
                constants::ACCESS_DENIED.parse().unwrap(),
            );
            return send_xml_error(
                h,
                constants::ACCESS_DENIED,
                "Unable to determine service/operation name to be authorized",
                403,
            );
        }
        let mut h = response_headers.clone();
        h.insert(
            "x-amzn-ErrorType",
            constants::UNKNOWN_OPERATION.parse().unwrap(),
        );
        return send_xml_error_code(h, constants::UNKNOWN_OPERATION, 404);
    }
    // Decode the body according to its content type. CBOR is converted to
    // the JSON value model, with byte strings becoming base64 text.
    let data: Option<Value> = if content_type == constants::CONTENT_TYPE_CBOR {
        ciborium::from_reader::<ciborium::Value, _>(&body[..])
            .ok()
            .map(|v| cbor_to_json(&v))
    } else {
        serde_json::from_slice(&body).ok()
    };
    // The body must decode to an object. Otherwise it is a SerializationException,
    // which plain "application/json" requests get in the coral envelope.
    let data = match data {
        Some(Value::Object(map)) => Value::Object(map),
        Some(_) | None => {
            if content_type == "application/json" {
                return send_json_response(
                    response_headers.clone(),
                    "application/json",
                    &json!({
                        "Output": {"__type": "com.amazon.coral.service#SerializationException"},
                        "Version": "1.0",
                    }),
                    400,
                );
            }
            let err = KinesisErrorResponse::client_error(constants::SERIALIZATION_EXCEPTION, None);
            return send_kinesis_error(&response_headers, response_content_type, &err);
        }
    };
    // Plain "application/json" is parsed but never dispatched: a well-formed
    // object always yields UnknownOperationException (404).
    if content_type == "application/json" {
        return send_json_response(
            response_headers.clone(),
            "application/json",
            &json!({
                "Output": {"__type": "com.amazon.coral.service#UnknownOperationException"},
                "Version": "1.0",
            }),
            404,
        );
    }
    let Some(operation) = operation else {
        let err = KinesisErrorResponse::client_error(constants::UNKNOWN_OPERATION, None);
        return send_kinesis_error(&response_headers, response_content_type, &err);
    };
    if !service_valid {
        let err = KinesisErrorResponse::client_error(constants::UNKNOWN_OPERATION, None);
        return send_kinesis_error(&response_headers, response_content_type, &err);
    }
    // Authentication: only the presence of the required SigV4 pieces is
    // checked below. No signature is computed or compared in this function.
    let auth_header = headers.get("authorization").and_then(|v| v.to_str().ok());
    let query_string = uri.query().unwrap_or("");
    // NOTE(review): this is a substring test on the raw query string, so a
    // parameter *value* containing "X-Amz-Algorithm" also counts as presigned.
    let auth_query = query_string.contains("X-Amz-Algorithm");
    if auth_header.is_some() && auth_query {
        return send_error_response(
            &response_headers,
            content_valid,
            response_content_type,
            constants::INVALID_SIGNATURE,
            "Found both 'X-Amz-Algorithm' as a query-string param and 'Authorization' as HTTP header.",
            400,
        );
    }
    if auth_header.is_none() && !auth_query {
        return send_error_response(
            &response_headers,
            content_valid,
            response_content_type,
            constants::MISSING_AUTH_TOKEN,
            "Missing Authentication Token",
            400,
        );
    }
    if let Some(auth) = auth_header {
        let mut msg = String::new();
        // Parse "ALGO Key=Value, Key=Value, ..." into a map. The leading
        // algorithm token is skipped and tokens without '=' are dropped.
        // NOTE(review): "Credential=" (empty value) counts as present here,
        // while the query-string branch below treats empty values as
        // missing. Confirm which behaviour is intended.
        let auth_params: std::collections::HashMap<String, String> = auth
            .split([',', ' '])
            .skip(1)
            .filter(|s| !s.is_empty())
            .filter_map(|s| {
                let kv: Vec<&str> = s.trim().splitn(2, '=').collect();
                if kv.len() == 2 {
                    Some((kv[0].to_string(), kv[1].to_string()))
                } else {
                    None
                }
            })
            .collect();
        // All missing pieces are reported in one message, with the offending
        // header echoed at the end.
        for param in ["Credential", "Signature", "SignedHeaders"] {
            if !auth_params.contains_key(param) {
                msg += &format!("Authorization header requires '{param}' parameter. ");
            }
        }
        if !headers.contains_key("x-amz-date") && !headers.contains_key("date") {
            msg += "Authorization header requires existence of either a 'X-Amz-Date' or a 'Date' header. ";
        }
        if !msg.is_empty() {
            msg += &format!("Authorization={auth}");
            return send_error_response(
                &response_headers,
                content_valid,
                response_content_type,
                constants::INCOMPLETE_SIGNATURE,
                &msg,
                403,
            );
        }
    } else {
        // Presigned request: split the raw query on '&' and '=' without
        // percent-decoding. A bare key maps to an empty value.
        let query_params: std::collections::HashMap<String, String> = uri
            .query()
            .unwrap_or("")
            .split('&')
            .filter_map(|s| {
                let kv: Vec<&str> = s.splitn(2, '=').collect();
                if kv.len() == 2 {
                    Some((kv[0].to_string(), kv[1].to_string()))
                } else if !kv[0].is_empty() {
                    Some((kv[0].to_string(), String::new()))
                } else {
                    None
                }
            })
            .collect();
        let mut msg = String::new();
        for param in [
            "X-Amz-Algorithm",
            "X-Amz-Credential",
            "X-Amz-Signature",
            "X-Amz-SignedHeaders",
            "X-Amz-Date",
        ] {
            if !query_params.contains_key(param) || query_params[param].is_empty() {
                msg += &format!("AWS query-string parameters must include '{param}'. ");
            }
        }
        if !msg.is_empty() {
            msg += "Re-examine the query-string parameters.";
            return send_error_response(
                &response_headers,
                content_valid,
                response_content_type,
                constants::INCOMPLETE_SIGNATURE,
                &msg,
                403,
            );
        }
    }
    // Per-operation validation. `check_types` returns the data it checked,
    // and that returned value (presumably type-normalised; confirm in
    // validation.rs) is what gets dispatched.
    let validation_rules = operation.validation_rules();
    let field_refs: Vec<(&str, &validation::FieldDef)> =
        validation_rules.iter().map(|(k, v)| (*k, v)).collect();
    let data = match validation::check_types(&data, &field_refs) {
        Ok(d) => d,
        Err(err) => {
            return send_kinesis_error(&response_headers, response_content_type, &err);
        }
    };
    if let Err(err) = validation::check_validations(&data, &field_refs, None) {
        return send_kinesis_error(&response_headers, response_content_type, &err);
    }
    let span = tracing::info_span!("kinesis", %operation, %request_id);
    // SubscribeToShard answers with an `application/vnd.amazon.eventstream`
    // body instead of a single JSON/CBOR document, and is rejected on wasm32.
    if operation == Operation::SubscribeToShard {
        #[cfg(not(target_arch = "wasm32"))]
        return match actions::subscribe_to_shard::execute_streaming(
            &store,
            data,
            response_content_type,
        )
        .instrument(span.clone())
        .await
        {
            Ok(body) => {
                tracing::debug!(parent: &span, "ok");
                response_headers.insert(
                    "Content-Type",
                    "application/vnd.amazon.eventstream".parse().unwrap(),
                );
                (StatusCode::OK, response_headers, body).into_response()
            }
            Err(ref err) => {
                log_and_send_error(&span, &response_headers, response_content_type, err)
            }
        };
        #[cfg(target_arch = "wasm32")]
        {
            let err = KinesisErrorResponse::client_error(
                constants::INVALID_ARGUMENT,
                Some("SubscribeToShard is not supported in this build."),
            );
            return log_and_send_error(&span, &response_headers, response_content_type, &err);
        }
    }
    // Regular operations. `Ok(None)` means the action has no payload, so the
    // response is an empty 200. The dispatch result is kept for the mirror.
    let dispatch_result = actions::dispatch(&store, operation, data)
        .instrument(span.clone())
        .await;
    let (response, mirrorable_result) = match dispatch_result {
        Ok(opt_result) => {
            tracing::debug!(parent: &span, "ok");
            let response = match &opt_result {
                Some(result) => {
                    send_value_response(response_headers, response_content_type, result, 200)
                }
                None => {
                    response_headers.insert("Content-Type", response_content_type.parse().unwrap());
                    response_headers.insert("Content-Length", "0".parse().unwrap());
                    (StatusCode::OK, response_headers, "").into_response()
                }
            };
            (response, Ok(opt_result))
        }
        Err(err) => {
            let response =
                log_and_send_error(&span, &response_headers, response_content_type, &err);
            (response, Err(err))
        }
    };
    // Forward only locally-successful, mirrorable operations. Failures are
    // logged at debug level and not forwarded.
    #[cfg(feature = "mirror")]
    if let Some(Extension(ref mirror)) = mirror
        && Mirror::should_mirror(&operation)
    {
        match mirrorable_result {
            Ok(result) => {
                mirror.spawn_forward(target.to_string(), content_type.to_string(), body, result);
            }
            Err(e) => {
                tracing::debug!(
                    parent: &span,
                    error_type = %e.body.error_type,
                    "skipping mirror: local dispatch failed"
                );
            }
        }
    }
    // Without the feature, consume the otherwise-unused values.
    #[cfg(not(feature = "mirror"))]
    {
        let _ = (mirror, mirrorable_result);
    }
    response
}
/// Renders a Kinesis error body in `content_type` (JSON or CBOR).
///
/// Starts from a copy of `extra_headers` and adds `x-amzn-ErrorType`. The HTTP
/// status is taken from `err.status_code`.
///
/// # Panics
///
/// Panics if the error type is not a valid header value.
fn send_kinesis_error(
    extra_headers: &HeaderMap,
    content_type: &str,
    err: &KinesisErrorResponse,
) -> Response {
    let body = &err.body;
    let mut headers = extra_headers.clone();
    let error_type = body
        .error_type
        .parse()
        .expect("error_type must be valid ASCII");
    headers.insert("x-amzn-ErrorType", error_type);
    send_json_response(headers, content_type, body, err.status_code)
}
fn log_and_send_error(
span: &tracing::Span,
headers: &HeaderMap,
content_type: &str,
err: &KinesisErrorResponse,
) -> Response {
if err.status_code >= 500 {
tracing::error!(parent: span, error_type = %err.body.error_type, "server error");
} else {
tracing::info!(parent: span, error_type = %err.body.error_type, "client error");
}
send_kinesis_error(headers, content_type, err)
}
fn send_json_response(
mut headers: HeaderMap,
content_type: &str,
data: &impl Serialize,
status_code: u16,
) -> Response {
let body_bytes = if content_type == constants::CONTENT_TYPE_CBOR {
let mut buf = Vec::new();
let _ = ciborium::into_writer(data, &mut buf);
buf
} else {
serde_json::to_vec(data).unwrap_or_default()
};
headers.insert("Content-Type", content_type.parse().unwrap());
headers.insert(
"Content-Length",
body_bytes.len().to_string().parse().unwrap(),
);
(
StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
headers,
body_bytes,
)
.into_response()
}
fn send_value_response(
mut headers: HeaderMap,
content_type: &str,
data: &Value,
status_code: u16,
) -> Response {
let body_bytes = if content_type == constants::CONTENT_TYPE_CBOR {
let mut buf = Vec::new();
let _ = ciborium::into_writer(&BlobAwareValue::new(data), &mut buf);
buf
} else {
serde_json::to_vec(data).unwrap_or_default()
};
headers.insert("Content-Type", content_type.parse().unwrap());
headers.insert(
"Content-Length",
body_bytes.len().to_string().parse().unwrap(),
);
(
StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
headers,
body_bytes,
)
.into_response()
}
/// Serialization adapter that writes a JSON [`Value`] for CBOR output.
///
/// Any string found directly under a `constants::DATA` key is base64-decoded
/// and emitted as a CBOR byte string. If it is not valid base64 it is emitted
/// unchanged as text.
pub(crate) struct BlobAwareValue<'a> {
    /// The value being serialized.
    val: &'a Value,
    /// True when `val` is the value of a `Data` entry in its parent object.
    is_blob: bool,
}

impl<'a> BlobAwareValue<'a> {
    /// Wraps a root value. The root itself is never treated as a blob.
    pub(crate) fn new(val: &'a Value) -> Self {
        Self { val, is_blob: false }
    }
}

impl Serialize for BlobAwareValue<'_> {
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        use serde::ser::{SerializeMap, SerializeSeq};

        match (self.val, self.is_blob) {
            (Value::String(text), true) => {
                let engine = &base64::engine::general_purpose::STANDARD;
                match engine.decode(text) {
                    Ok(raw) => s.serialize_bytes(&raw),
                    // Not valid base64: pass the string through as text.
                    Err(_) => s.serialize_str(text),
                }
            }
            (Value::Object(fields), _) => {
                let mut out = s.serialize_map(Some(fields.len()))?;
                for (key, child) in fields {
                    let wrapped = BlobAwareValue {
                        val: child,
                        is_blob: key == constants::DATA,
                    };
                    out.serialize_entry(key, &wrapped)?;
                }
                out.end()
            }
            // Array elements never inherit the blob flag; objects inside
            // them decide for their own `Data` keys.
            (Value::Array(items), _) => {
                let mut out = s.serialize_seq(Some(items.len()))?;
                for item in items {
                    out.serialize_element(&BlobAwareValue::new(item))?;
                }
                out.end()
            }
            (other, _) => other.serialize(s),
        }
    }
}
/// Builds an XML error document of the form
/// `<ErrorType>\n <Message>...</Message>\n</ErrorType>\n`.
///
/// Sets `Content-Length` on top of `headers`. An out-of-range `status_code`
/// becomes 500.
///
/// `message` may carry request-derived text: the handler echoes the raw
/// `Authorization` header into signature errors, which can reach this
/// function through `send_error_response`. The XML metacharacters `&`, `<`
/// and `>` are therefore escaped so the document stays well-formed. Quotes
/// are left alone because they are legal in element content, and existing
/// messages contain apostrophes that must be emitted verbatim.
fn send_xml_error(
    mut headers: HeaderMap,
    error_type: &str,
    message: &str,
    status_code: u16,
) -> Response {
    let mut escaped = String::with_capacity(message.len());
    for ch in message.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            other => escaped.push(other),
        }
    }
    let body = format!("<{error_type}>\n <Message>{escaped}</Message>\n</{error_type}>\n");
    headers.insert(
        "Content-Length",
        body.len()
            .to_string()
            .parse()
            .expect("decimal length is a valid header value"),
    );
    (
        StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR),
        headers,
        body,
    )
        .into_response()
}
/// Builds a message-less XML error (`<ErrorType/>\n`) with `Content-Length`
/// set. An out-of-range `status_code` becomes 500.
fn send_xml_error_code(mut headers: HeaderMap, error_type: &str, status_code: u16) -> Response {
    let xml = format!("<{error_type}/>\n");
    let length = xml.len().to_string();
    headers.insert("Content-Length", length.parse().unwrap());
    let status = StatusCode::from_u16(status_code).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    (status, headers, xml).into_response()
}
pub async fn kinesis_413_middleware(request: Request, next: Next) -> Response {
let content_type = request
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.split(';')
.next()
.unwrap_or("")
.trim()
.to_owned();
let response = next.run(request).await;
if response.status() != StatusCode::PAYLOAD_TOO_LARGE {
return response;
}
let error = json!({
"__type": constants::SERIALIZATION_EXCEPTION,
"Message": "Request body is too large"
});
let response_content_type = if content_type == constants::CONTENT_TYPE_JSON {
constants::CONTENT_TYPE_JSON
} else {
constants::CONTENT_TYPE_CBOR
};
let body_bytes = if response_content_type == constants::CONTENT_TYPE_CBOR {
let mut buf = Vec::new();
let _ = ciborium::into_writer(&error, &mut buf);
buf
} else {
serde_json::to_vec(&error).unwrap_or_default()
};
(
StatusCode::PAYLOAD_TOO_LARGE,
[
("Content-Type", response_content_type.to_owned()),
("Content-Length", body_bytes.len().to_string()),
],
body_bytes,
)
.into_response()
}
/// Emits an error in whichever format suits the request.
///
/// If the request's content type was recognised (`content_valid`), the error
/// is rendered as a Kinesis JSON/CBOR body via `send_kinesis_error`.
/// Otherwise it falls back to the XML envelope, with `x-amzn-ErrorType` set
/// explicitly.
///
/// # Panics
///
/// Panics in the XML path if `error_type` is not a valid header value.
fn send_error_response(
    extra_headers: &HeaderMap,
    content_valid: bool,
    content_type: &str,
    error_type: &str,
    message: &str,
    status_code: u16,
) -> Response {
    if !content_valid {
        let mut headers = extra_headers.clone();
        let error_type_value = error_type.parse().expect("error_type must be valid ASCII");
        headers.insert("x-amzn-ErrorType", error_type_value);
        return send_xml_error(headers, error_type, message, status_code);
    }
    let err = KinesisErrorResponse::new(status_code, error_type, Some(message));
    send_kinesis_error(extra_headers, content_type, &err)
}
/// Converts a decoded CBOR value into the equivalent `serde_json` value.
///
/// - Integers keep full precision when they fit in `i64` or `u64`. CBOR can
///   also encode negatives down to -2^64. Those fall back to an `f64`
///   approximation, or `null` if that is not representable.
/// - Byte strings become standard (padded) base64 text, the same engine
///   `BlobAwareValue` uses to decode blobs.
/// - Map keys that are not text are stringified with their `Debug` form.
/// - Tags are unwrapped. Non-finite floats and any other variant become `null`.
#[doc(hidden)]
pub fn cbor_to_json(val: &ciborium::Value) -> Value {
    match val {
        ciborium::Value::Null => Value::Null,
        ciborium::Value::Bool(b) => Value::Bool(*b),
        ciborium::Value::Integer(n) => {
            let n: i128 = (*n).into();
            if let Ok(i) = i64::try_from(n) {
                Value::Number(serde_json::Number::from(i))
            } else if let Ok(u) = u64::try_from(n) {
                // Large unsigned values (e.g. u64::MAX) are exact in serde_json;
                // routing them through f64 would silently lose precision.
                Value::Number(serde_json::Number::from(u))
            } else {
                // Only reachable for negatives below i64::MIN; no exact JSON
                // number representation is available, so approximate.
                #[allow(clippy::cast_precision_loss)]
                let f = n as f64;
                serde_json::Number::from_f64(f)
                    .map(Value::Number)
                    .unwrap_or(Value::Null)
            }
        }
        ciborium::Value::Float(f) => serde_json::Number::from_f64(*f)
            .map(Value::Number)
            .unwrap_or(Value::Null),
        ciborium::Value::Text(s) => Value::String(s.clone()),
        ciborium::Value::Bytes(b) => {
            Value::String(base64::engine::general_purpose::STANDARD.encode(b))
        }
        ciborium::Value::Array(arr) => Value::Array(arr.iter().map(cbor_to_json).collect()),
        ciborium::Value::Map(map) => {
            let mut obj = serde_json::Map::new();
            for (k, v) in map {
                let key = match k {
                    ciborium::Value::Text(s) => s.clone(),
                    _ => format!("{k:?}"),
                };
                obj.insert(key, cbor_to_json(v));
            }
            Value::Object(obj)
        }
        ciborium::Value::Tag(_, inner) => cbor_to_json(inner),
        _ => Value::Null,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;
    use serde_json::json;

    /// Round-trips `bav` through encoded CBOR so assertions can inspect the
    /// decoded `ciborium::Value` tree.
    fn to_cbor_value(bav: &BlobAwareValue<'_>) -> ciborium::Value {
        let mut encoded = Vec::new();
        ciborium::into_writer(bav, &mut encoded).expect("CBOR serialization failed");
        ciborium::from_reader(encoded.as_slice()).expect("CBOR deserialization failed")
    }

    /// Standard base64 encoding, as used for Kinesis blobs.
    fn b64(raw: &[u8]) -> String {
        base64::engine::general_purpose::STANDARD.encode(raw)
    }

    /// Wraps `val` with an explicit blob flag.
    fn wrap(val: &Value, is_blob: bool) -> BlobAwareValue<'_> {
        BlobAwareValue { val, is_blob }
    }

    #[test]
    fn blob_valid_base64_emits_bytes() {
        let raw = b"hello world";
        let val = Value::String(b64(raw));
        assert_eq!(
            to_cbor_value(&wrap(&val, true)),
            ciborium::Value::Bytes(raw.to_vec())
        );
    }

    #[test]
    fn blob_invalid_base64_falls_back_to_text() {
        let val = Value::String("NOT!VALID!BASE64".to_string());
        assert_eq!(
            to_cbor_value(&wrap(&val, true)),
            ciborium::Value::Text("NOT!VALID!BASE64".to_string())
        );
    }

    #[test]
    fn non_blob_string_emits_text() {
        let encoded = b64(b"bytes");
        let val = Value::String(encoded.clone());
        assert_eq!(
            to_cbor_value(&wrap(&val, false)),
            ciborium::Value::Text(encoded)
        );
    }

    #[test]
    fn object_with_data_key_decodes_blob() {
        let raw = b"payload";
        let val = json!({"Data": b64(raw), "PartitionKey": "pk"});
        let cbor = to_cbor_value(&BlobAwareValue::new(&val));
        let ciborium::Value::Map(entries) = &cbor else {
            panic!("expected CBOR map, got {cbor:?}");
        };
        for (key, value) in entries {
            match key {
                ciborium::Value::Text(name) if name == "Data" => {
                    assert_eq!(value, &ciborium::Value::Bytes(raw.to_vec()));
                }
                ciborium::Value::Text(name) if name == "PartitionKey" => {
                    assert_eq!(value, &ciborium::Value::Text("pk".to_string()));
                }
                other => panic!("unexpected key: {other:?}"),
            }
        }
    }

    #[test]
    fn array_does_not_propagate_is_blob() {
        let encoded = b64(b"data");
        let val = json!([encoded]);
        // The flag is set on the array itself; its elements must ignore it.
        match to_cbor_value(&wrap(&val, true)) {
            ciborium::Value::Array(items) => {
                assert_eq!(items[0], ciborium::Value::Text(encoded));
            }
            _ => panic!("expected CBOR array"),
        }
    }

    #[test]
    fn blob_empty_base64_emits_empty_bytes() {
        let val = Value::String(String::new());
        assert_eq!(
            to_cbor_value(&wrap(&val, true)),
            ciborium::Value::Bytes(vec![])
        );
    }
}