Skip to main content

awsim_core/protocol/
mod.rs

1pub mod eventstream;
2pub mod json;
3pub mod query;
4pub mod rest;
5
6use axum::http::{HeaderMap, Method, Uri};
7use bytes::Bytes;
8use serde_json::Value;
9
10use crate::error::AwsError;
11
/// AWS API protocols.
///
/// The protocol decides how an incoming request is parsed and how the
/// response (or error) is encoded on the wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Protocol {
    /// awsJson 1.0: JSON body; detected via the `X-Amz-Target` header plus a
    /// `x-amz-json-1.0` content type.
    AwsJson1_0,
    /// awsJson 1.1: JSON body; detected via the `X-Amz-Target` header when the
    /// content type does not say 1.0.
    AwsJson1_1,
    /// REST + JSON: operation resolved from the HTTP method and path.
    RestJson1,
    /// REST + XML: operation resolved from the HTTP method and path.
    RestXml,
    /// Query: form-encoded request carrying an `Action` parameter, XML response.
    AwsQuery,
    /// EC2 flavour of the query protocol, XML response.
    Ec2Query,
}

impl Protocol {
    /// `Content-Type` to use for a (non-streaming) response in this protocol.
    ///
    /// awsJson responses use the versioned media type that matches the
    /// protocol version (`application/x-amz-json-1.0` / `-1.1`). restJson1
    /// uses plain `application/json`. The XML-based protocols use
    /// `application/xml`.
    pub fn response_content_type(&self) -> &'static str {
        match self {
            Self::AwsJson1_0 => "application/x-amz-json-1.0",
            // Previously reported 1.0 for 1.1 services, mislabelling every
            // awsJson1_1 response.
            Self::AwsJson1_1 => "application/x-amz-json-1.1",
            // restJson1 does not use the x-amz-json media types.
            Self::RestJson1 => "application/json",
            Self::RestXml | Self::AwsQuery | Self::Ec2Query => "application/xml",
        }
    }

    /// Returns `true` for the JSON-encoded protocols (awsJson 1.0/1.1, restJson1).
    pub fn is_json(&self) -> bool {
        matches!(self, Self::AwsJson1_0 | Self::AwsJson1_1 | Self::RestJson1)
    }

    /// Returns `true` for the XML-encoded protocols (restXml, awsQuery, ec2Query).
    ///
    /// For every variant exactly one of [`Protocol::is_json`] and this method
    /// returns `true`.
    pub fn is_xml(&self) -> bool {
        matches!(self, Self::RestXml | Self::AwsQuery | Self::Ec2Query)
    }
}
39
40/// Parsed AWS request ready for dispatch to a service handler.
41#[derive(Debug)]
42pub struct ParsedRequest {
43    pub operation: String,
44    pub input: Value,
45}
46
/// One row of a REST-style service's routing table.
///
/// Associates an HTTP method and path pattern, and optionally a query
/// parameter, with the operation name they select.
#[derive(Debug, Clone)]
pub struct RouteDefinition {
    /// HTTP method this route applies to (e.g. `"PUT"`).
    pub method: &'static str,
    /// Path template for the route (e.g. `"/{Bucket}"`).
    pub path_pattern: &'static str,
    /// Operation name the route resolves to.
    pub operation: &'static str,
    /// Query parameter used to tell apart routes that share a method and path,
    /// S3-style: `PUT /{Bucket}?versioning` selects `PutBucketVersioning`.
    /// `None` when the route needs no such parameter.
    pub required_query_param: Option<&'static str>,
}
57
58/// Detect which protocol an incoming request uses.
59pub fn detect_protocol(headers: &HeaderMap, body: &Bytes) -> Option<Protocol> {
60    // Check X-Amz-Target header → awsJson
61    if let Some(target) = headers.get("x-amz-target") {
62        let content_type = headers
63            .get("content-type")
64            .and_then(|v| v.to_str().ok())
65            .unwrap_or("");
66        if content_type.contains("x-amz-json-1.0") {
67            return Some(Protocol::AwsJson1_0);
68        }
69        // Default to 1.1 if X-Amz-Target present but content-type doesn't specify 1.0
70        let _ = target;
71        return Some(Protocol::AwsJson1_1);
72    }
73
74    let content_type = headers
75        .get("content-type")
76        .and_then(|v| v.to_str().ok())
77        .unwrap_or("");
78
79    // Check form-encoded → awsQuery or ec2Query
80    if content_type.contains("x-www-form-urlencoded") {
81        let body_str = std::str::from_utf8(body).unwrap_or("");
82        if body_str.contains("Action=") {
83            return Some(Protocol::AwsQuery);
84        }
85    }
86
87    // Check JSON content type → restJson1
88    if content_type.contains("json") {
89        return Some(Protocol::RestJson1);
90    }
91
92    // Check XML content type → restXml
93    if content_type.contains("xml") {
94        return Some(Protocol::RestXml);
95    }
96
97    // For REST protocols without explicit content-type (GET/HEAD/DELETE with no body),
98    // we determine protocol from the service's declared protocol
99    None
100}
101
102/// Parse a request based on the detected protocol.
103pub fn parse_request(
104    protocol: Protocol,
105    method: &Method,
106    uri: &Uri,
107    headers: &HeaderMap,
108    body: &Bytes,
109    routes: &[RouteDefinition],
110) -> Result<ParsedRequest, AwsError> {
111    match protocol {
112        Protocol::AwsJson1_0 | Protocol::AwsJson1_1 => json::parse_request(headers, body),
113        Protocol::AwsQuery | Protocol::Ec2Query => query::parse_request(body),
114        Protocol::RestJson1 => rest::parse_json_request(method, uri, body, routes),
115        Protocol::RestXml => rest::parse_xml_request(method, uri, headers, body, routes),
116    }
117}
118
119/// Serialize a successful response based on protocol.
120pub fn serialize_response(
121    protocol: Protocol,
122    operation: &str,
123    output: &Value,
124    request_id: &str,
125) -> (axum::http::StatusCode, HeaderMap, Bytes) {
126    // Streaming responses (Bedrock ConverseStream / InvokeModelWith
127    // ResponseStream, etc.) tag their output with an event-stream
128    // marker. Detect it before falling through to the per-protocol
129    // JSON/XML/Query encoders so the SDK gets the binary frames it
130    // expects under `application/vnd.amazon.eventstream`.
131    if let Some(body) = eventstream::try_encode(output) {
132        let mut headers = HeaderMap::new();
133        if let Ok(v) = "application/vnd.amazon.eventstream".parse() {
134            headers.insert(axum::http::header::CONTENT_TYPE, v);
135        }
136        if let Ok(v) = request_id.parse() {
137            headers.insert("x-amzn-requestid", v);
138        }
139        return (axum::http::StatusCode::OK, headers, Bytes::from(body));
140    }
141
142    match protocol {
143        Protocol::AwsJson1_0 | Protocol::AwsJson1_1 | Protocol::RestJson1 => {
144            json::serialize_response(output, request_id)
145        }
146        Protocol::AwsQuery | Protocol::Ec2Query => {
147            query::serialize_response(operation, output, request_id)
148        }
149        Protocol::RestXml => rest::serialize_xml_response(output, request_id),
150    }
151}
152
153/// Serialize an error response based on protocol.
154pub fn serialize_error(
155    protocol: Protocol,
156    error: &AwsError,
157    request_id: &str,
158) -> (axum::http::StatusCode, HeaderMap, Bytes) {
159    match protocol {
160        Protocol::AwsJson1_0 | Protocol::AwsJson1_1 | Protocol::RestJson1 => {
161            json::serialize_error(error, request_id)
162        }
163        Protocol::AwsQuery | Protocol::Ec2Query => query::serialize_error(error, request_id),
164        Protocol::RestXml => rest::serialize_error(error, request_id),
165    }
166}