// awsim_core/protocol/rest.rs

1use axum::http::{HeaderMap, Method, StatusCode, Uri};
2use bytes::Bytes;
3use serde_json::Value;
4
5use crate::error::AwsError;
6
7use super::{ParsedRequest, RouteDefinition};
8
9/// Parse a restJson1 request.
10///
11/// Operation is determined by matching HTTP method + URI path against route definitions.
12pub fn parse_json_request(
13    method: &Method,
14    uri: &Uri,
15    body: &Bytes,
16    routes: &[RouteDefinition],
17) -> Result<ParsedRequest, AwsError> {
18    let path = uri.path();
19    let query_string = uri.query().unwrap_or("");
20
21    let (operation, path_params) = match_route(method.as_str(), path, query_string, routes)?;
22
23    let mut input = if body.is_empty() {
24        Value::Object(serde_json::Map::new())
25    } else {
26        serde_json::from_slice(body).map_err(|e| {
27            AwsError::bad_request("SerializationException", format!("Invalid JSON body: {e}"))
28        })?
29    };
30
31    // Merge path parameters into input
32    if let Value::Object(ref mut map) = input {
33        for (key, value) in path_params {
34            map.insert(key, Value::String(value));
35        }
36        // Merge query parameters
37        for (key, value) in parse_query_string(query_string) {
38            map.entry(key).or_insert(Value::String(value));
39        }
40    }
41
42    Ok(ParsedRequest {
43        operation: operation.to_string(),
44        input,
45    })
46}
47
48/// Parse a restXml request (used by S3).
49pub fn parse_xml_request(
50    method: &Method,
51    uri: &Uri,
52    headers: &HeaderMap,
53    body: &Bytes,
54    routes: &[RouteDefinition],
55) -> Result<ParsedRequest, AwsError> {
56    let path = uri.path();
57    let query_string = uri.query().unwrap_or("");
58
59    let (operation, path_params) = match_route(method.as_str(), path, query_string, routes)?;
60
61    let mut input = if body.is_empty() {
62        Value::Object(serde_json::Map::new())
63    } else {
64        // Only attempt XML parsing if the body actually looks like XML (starts
65        // with '<').  Otherwise treat it as raw binary data and store it as
66        // base64 in `__raw_body` so handlers like S3 PutObject can access it.
67        let looks_like_xml = body.first().is_some_and(|&b| b == b'<');
68        if looks_like_xml {
69            match parse_xml_body(body) {
70                Ok(v) => v,
71                Err(_) => {
72                    use base64::Engine;
73                    let encoded = base64::engine::general_purpose::STANDARD.encode(body);
74                    let mut map = serde_json::Map::new();
75                    map.insert("__raw_body".to_string(), Value::String(encoded));
76                    Value::Object(map)
77                }
78            }
79        } else {
80            // Non-XML body — always store as raw binary.
81            use base64::Engine;
82            let encoded = base64::engine::general_purpose::STANDARD.encode(body);
83            let mut map = serde_json::Map::new();
84            map.insert("__raw_body".to_string(), Value::String(encoded));
85            Value::Object(map)
86        }
87    };
88
89    // Merge path parameters
90    if let Value::Object(ref mut map) = input {
91        for (key, value) in path_params {
92            map.insert(key, Value::String(value));
93        }
94        for (key, value) in parse_query_string(query_string) {
95            map.entry(key).or_insert(Value::String(value));
96        }
97        // Extract relevant headers (x-amz-*)
98        for (name, value) in headers.iter() {
99            let name_str = name.as_str();
100            if name_str.starts_with("x-amz-")
101                && name_str != "x-amz-target"
102                && let Ok(v) = value.to_str()
103            {
104                let key = header_to_param_name(name_str);
105                map.entry(key).or_insert(Value::String(v.to_string()));
106            }
107        }
108    }
109
110    Ok(ParsedRequest {
111        operation: operation.to_string(),
112        input,
113    })
114}
115
116/// Result of matching an HTTP request against routes: the operation name and path parameters.
117pub type RouteMatch<'a> = (&'a str, Vec<(String, String)>);
118
119/// Match an HTTP request against route definitions.
120/// Returns the operation name and extracted path parameters.
121fn match_route<'a>(
122    method: &str,
123    path: &str,
124    query_string: &str,
125    routes: &'a [RouteDefinition],
126) -> Result<RouteMatch<'a>, AwsError> {
127    // Strip a trailing slash ONLY for bucket-level operations (paths like `/bucket/`).
128    // Don't strip for object keys like `/bucket/folder/` — the trailing slash is
129    // significant (it marks S3 "folder" objects).
130    let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
131    let path = if segments.len() <= 1 {
132        // Bucket-level: `/bucket/` → `/bucket`
133        let stripped = path.strip_suffix('/').unwrap_or(path);
134        if stripped.is_empty() { "/" } else { stripped }
135    } else {
136        // Object-level: preserve trailing slash for folder markers
137        path
138    };
139
140    let query_params: Vec<(String, String)> = parse_query_string(query_string);
141
142    // Try routes with required_query_param first (more specific matches)
143    let mut best_match: Option<(&str, Vec<(String, String)>)> = None;
144    let mut best_specificity = 0;
145
146    for route in routes {
147        if !route.method.eq_ignore_ascii_case(method) {
148            continue;
149        }
150
151        if let Some(path_params) = match_path_pattern(route.path_pattern, path) {
152            let specificity = if route.required_query_param.is_some() {
153                2
154            } else {
155                1
156            };
157
158            if let Some(required_param) = route.required_query_param {
159                // This route requires a specific query parameter to be present
160                if query_params.iter().any(|(k, _)| k == required_param)
161                    && specificity > best_specificity
162                {
163                    best_match = Some((route.operation, path_params));
164                    best_specificity = specificity;
165                }
166            } else if specificity > best_specificity
167                || (specificity == best_specificity && best_match.is_none())
168            {
169                best_match = Some((route.operation, path_params));
170                best_specificity = specificity;
171            }
172        }
173    }
174
175    best_match.ok_or_else(|| AwsError::unknown_operation(&format!("{method} {path}")))
176}
177
178/// Match a path pattern like "/2015-03-31/functions/{FunctionName}" against an actual path.
179/// Returns extracted path parameters if matched.
180fn match_path_pattern(pattern: &str, path: &str) -> Option<Vec<(String, String)>> {
181    let pattern_parts: Vec<&str> = pattern.split('/').collect();
182    let path_parts: Vec<&str> = path.split('/').collect();
183
184    // Handle greedy patterns (last segment is {Key+})
185    let has_greedy = pattern_parts
186        .last()
187        .is_some_and(|p| p.starts_with('{') && p.ends_with("+}"));
188
189    if has_greedy {
190        if path_parts.len() < pattern_parts.len() {
191            return None;
192        }
193    } else if pattern_parts.len() != path_parts.len() {
194        return None;
195    }
196
197    let mut params = Vec::new();
198
199    for (i, (pat, actual)) in pattern_parts.iter().zip(path_parts.iter()).enumerate() {
200        if pat.starts_with('{') && pat.ends_with("+}") {
201            // Greedy match - capture rest of path
202            let name = &pat[1..pat.len() - 2];
203            let rest = path_parts[i..].join("/");
204            params.push((name.to_string(), percent_decode(&rest)));
205            return Some(params);
206        } else if pat.starts_with('{') && pat.ends_with('}') {
207            let name = &pat[1..pat.len() - 1];
208            params.push((name.to_string(), percent_decode(actual)));
209        } else if pat != actual {
210            return None;
211        }
212    }
213
214    Some(params)
215}
216
217fn parse_query_string(qs: &str) -> Vec<(String, String)> {
218    if qs.is_empty() {
219        return Vec::new();
220    }
221    qs.split('&')
222        .filter_map(|pair| {
223            let mut parts = pair.splitn(2, '=');
224            let key = parts.next()?;
225            let value = parts.next().unwrap_or("");
226            Some((percent_decode(key), percent_decode(value)))
227        })
228        .collect()
229}
230
/// Decode a percent-encoded URL component.
///
/// - `%XX`, where both `X` are hexadecimal digits, becomes the byte `0xXX`.
/// - `+` becomes a space.
/// - A `%` that is not followed by two hexadecimal digits is kept literally
///   and decoding resumes with the next character.
///
/// Decoding happens at the byte level and the result is interpreted as UTF-8,
/// so multi-byte escapes such as `%C3%A9` yield `é`.  Byte sequences that are
/// not valid UTF-8 are replaced with U+FFFD rather than causing an error.
fn percent_decode(s: &str) -> String {
    /// Numeric value of a single ASCII hexadecimal digit.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut decoded: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut pos = 0;

    while pos < bytes.len() {
        match bytes[pos] {
            b'%' => {
                let high = bytes.get(pos + 1).copied().and_then(hex_value);
                let low = bytes.get(pos + 2).copied().and_then(hex_value);
                if let Some((high, low)) = high.zip(low) {
                    decoded.push((high << 4) | low);
                    pos += 3;
                    continue;
                }
                // Malformed or truncated escape: keep the '%' as-is.
                decoded.push(b'%');
            }
            b'+' => decoded.push(b' '),
            other => decoded.push(other),
        }
        pos += 1;
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
252
/// Turn an `x-amz-*` header name into a PascalCase parameter name,
/// e.g. `"x-amz-copy-source"` → `"CopySource"`.
///
/// The `x-amz-` prefix is removed when present (other names are converted
/// whole).  The remainder is split on `-`; each non-empty word contributes its
/// first character upper-cased followed by the rest lower-cased.
fn header_to_param_name(header: &str) -> String {
    let suffix = header.strip_prefix("x-amz-").unwrap_or(header);
    let mut param = String::with_capacity(suffix.len());

    for word in suffix.split('-') {
        let mut rest = word.chars();
        // Empty words (from doubled or trailing dashes) contribute nothing.
        if let Some(initial) = rest.next() {
            param.extend(initial.to_uppercase());
            param.push_str(&rest.as_str().to_lowercase());
        }
    }

    param
}
272
273/// Parse XML body into a JSON-like Value.
274fn parse_xml_body(body: &Bytes) -> Result<Value, AwsError> {
275    let s = std::str::from_utf8(body)
276        .map_err(|_| AwsError::bad_request("InvalidRequest", "Body is not valid UTF-8"))?;
277
278    parse_xml_element(s)
279}
280
281/// Simple XML → JSON parser for AWS request bodies.
282fn parse_xml_element(xml: &str) -> Result<Value, AwsError> {
283    use quick_xml::Reader;
284    use quick_xml::events::Event;
285
286    let mut reader = Reader::from_str(xml);
287    let mut map = serde_json::Map::new();
288    let mut stack: Vec<(String, serde_json::Map<String, Value>)> = Vec::new();
289    let mut current_key = String::new();
290    let mut current_text = String::new();
291
292    loop {
293        match reader.read_event() {
294            Ok(Event::Start(e)) => {
295                let name = String::from_utf8_lossy(e.name().as_ref()).to_string();
296                if !current_key.is_empty() {
297                    stack.push((current_key.clone(), map.clone()));
298                    map = serde_json::Map::new();
299                }
300                current_key = name;
301                current_text.clear();
302            }
303            Ok(Event::Text(e)) => {
304                current_text = e.unescape().unwrap_or_default().to_string();
305            }
306            Ok(Event::End(_)) => {
307                if current_text.is_empty() && !map.is_empty() {
308                    let value = Value::Object(map.clone());
309                    if let Some((parent_key, mut parent_map)) = stack.pop() {
310                        // Check if this key already exists (array case)
311                        if let Some(existing) = parent_map.get_mut(&current_key) {
312                            match existing {
313                                Value::Array(arr) => arr.push(value),
314                                other => {
315                                    let prev = other.take();
316                                    *other = Value::Array(vec![prev, value]);
317                                }
318                            }
319                        } else {
320                            parent_map.insert(current_key.clone(), value);
321                        }
322                        map = parent_map;
323                        current_key = parent_key;
324                    } else {
325                        map.insert(current_key.clone(), Value::Object(map.clone()));
326                    }
327                } else if !current_key.is_empty() {
328                    let value = Value::String(current_text.clone());
329                    if let Some((_parent_key, _parent_map)) = stack.last_mut() {}
330                    map.insert(current_key.clone(), value);
331                    if let Some((parent_key, mut parent_map)) = stack.pop() {
332                        parent_map.insert(current_key.clone(), Value::String(current_text.clone()));
333                        map = parent_map;
334                        current_key = parent_key;
335                    }
336                }
337                current_text.clear();
338            }
339            Ok(Event::Eof) => break,
340            Ok(_) => {}
341            Err(e) => {
342                return Err(AwsError::bad_request(
343                    "MalformedXML",
344                    format!("Invalid XML: {e}"),
345                ));
346            }
347        }
348    }
349
350    Ok(Value::Object(map))
351}
352
353/// Serialize a restXml success response.
354///
355/// Special convention: if `output` contains a `__raw_body` key (base64-encoded),
356/// the binary content is returned directly as the response body.  All other
357/// top-level keys are placed in response headers (e.g., `Content-Type`,
358/// `ETag`, `Last-Modified`).  This allows services such as S3 GetObject to
359/// return arbitrary binary data.
360pub fn serialize_xml_response(output: &Value, request_id: &str) -> (StatusCode, HeaderMap, Bytes) {
361    let mut headers = HeaderMap::new();
362    headers.insert("x-amz-request-id", request_id.parse().unwrap());
363
364    // --- Raw binary response (e.g. S3 GetObject) ---
365    if let Some(raw_b64) = output.get("__raw_body").and_then(Value::as_str) {
366        use base64::Engine;
367        let data = base64::engine::general_purpose::STANDARD
368            .decode(raw_b64)
369            .unwrap_or_default();
370
371        // Promote scalar fields to response headers.
372        if let Some(map) = output.as_object() {
373            for (key, val) in map {
374                if key == "__raw_body" || key == "Body" {
375                    continue;
376                }
377                let header_name = pascal_to_header(key);
378                let header_value = match val {
379                    Value::String(s) => s.clone(),
380                    Value::Number(n) => n.to_string(),
381                    Value::Bool(b) => b.to_string(),
382                    _ => continue,
383                };
384                if let (Ok(k), Ok(v)) = (
385                    axum::http::header::HeaderName::from_bytes(header_name.as_bytes()),
386                    axum::http::HeaderValue::from_str(&header_value),
387                ) {
388                    headers.insert(k, v);
389                }
390            }
391        }
392
393        return (StatusCode::OK, headers, Bytes::from(data));
394    }
395
396    // --- Promote well-known fields to HTTP headers (S3 convention) ---
397    if let Some(map) = output.as_object() {
398        let header_fields = [
399            "ETag",
400            "ContentType",
401            "ContentLength",
402            "LastModified",
403            "VersionId",
404            "ServerSideEncryption",
405            "StorageClass",
406        ];
407        for field in &header_fields {
408            if let Some(val) = map.get(*field) {
409                let header_name = pascal_to_header(field);
410                let header_value = match val {
411                    Value::String(s) => s.clone(),
412                    Value::Number(n) => n.to_string(),
413                    _ => continue,
414                };
415                if let (Ok(k), Ok(v)) = (
416                    axum::http::header::HeaderName::from_bytes(header_name.as_bytes()),
417                    axum::http::HeaderValue::from_str(&header_value),
418                ) {
419                    headers.insert(k, v);
420                }
421            }
422        }
423    }
424
425    // --- Normal XML response ---
426    // If `__xml_root` is present, wrap fields in that root element (with S3 namespace).
427    let xml_root = output
428        .get("__xml_root")
429        .and_then(Value::as_str)
430        .map(|s| s.to_string());
431
432    let output_for_xml = if xml_root.is_some() {
433        // Build a Value without the __xml_root sentinel key.
434        if let Some(map) = output.as_object() {
435            let filtered: serde_json::Map<String, Value> = map
436                .iter()
437                .filter(|(k, _)| k.as_str() != "__xml_root")
438                .map(|(k, v)| (k.clone(), v.clone()))
439                .collect();
440            Value::Object(filtered)
441        } else {
442            output.clone()
443        }
444    } else {
445        output.clone()
446    };
447
448    let body = if let Some(root) = xml_root {
449        // When an explicit XML root is present, always emit a root element
450        // (even if there are no child fields — e.g. empty BucketLoggingStatus).
451        let fields = super::query::json_to_xml_fields(&output_for_xml);
452        format!(
453            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
454             <{root} xmlns=\"http://s3.amazonaws.com/doc/2006-03-01/\">\n\
455             {fields}</{root}>",
456        )
457    } else if output_for_xml.is_null()
458        || (output_for_xml.is_object() && output_for_xml.as_object().unwrap().is_empty())
459    {
460        String::new()
461    } else {
462        format!(
463            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n{}",
464            super::query::json_to_xml_fields(&output_for_xml)
465        )
466    };
467
468    if !body.is_empty() {
469        headers.insert("content-type", "application/xml".parse().unwrap());
470    }
471    (StatusCode::OK, headers, Bytes::from(body))
472}
473
/// Map a PascalCase field name to a lowercase HTTP header name,
/// e.g. `"ContentType"` → `"content-type"`, `"ETag"` → `"etag"`.
///
/// A handful of names have fixed translations (including the `x-amz-`
/// prefixed ones).  Every other name is lower-cased with a `-` inserted before
/// each uppercase character except the first, so consecutive capitals are
/// split individually (`"ContentMD5"` → `"content-m-d5"`).
fn pascal_to_header(name: &str) -> String {
    // Names whose header form is not what the generic conversion would produce.
    let fixed = match name {
        "ETag" => Some("etag"),
        "ContentType" => Some("content-type"),
        "ContentLength" => Some("content-length"),
        "LastModified" => Some("last-modified"),
        "VersionId" => Some("x-amz-version-id"),
        "ServerSideEncryption" => Some("x-amz-server-side-encryption"),
        "StorageClass" => Some("x-amz-storage-class"),
        _ => None,
    };
    if let Some(header) = fixed {
        return header.to_string();
    }

    name.char_indices()
        .fold(String::new(), |mut header, (offset, ch)| {
            if offset > 0 && ch.is_uppercase() {
                header.push('-');
            }
            header.extend(ch.to_lowercase());
            header
        })
}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Values of the captured parameters, in capture order.
    fn values(params: &[(String, String)]) -> Vec<&str> {
        params.iter().map(|(_, value)| value.as_str()).collect()
    }

    // A literal pattern matches the identical path and captures nothing.
    #[test]
    fn test_match_simple_path() {
        assert_eq!(
            match_path_pattern("/functions", "/functions"),
            Some(Vec::new())
        );
    }

    // A `{Name}` placeholder captures exactly one path segment.
    #[test]
    fn test_match_path_with_param() {
        let params = match_path_pattern(
            "/2015-03-31/functions/{FunctionName}",
            "/2015-03-31/functions/my-func",
        )
        .expect("pattern should match");

        assert_eq!(
            params,
            vec![("FunctionName".to_string(), "my-func".to_string())]
        );
    }

    // A differing literal segment rejects the path.
    #[test]
    fn test_match_path_no_match() {
        assert_eq!(
            match_path_pattern("/functions/{Name}", "/queues/my-queue"),
            None
        );
    }

    // A trailing `{Name+}` placeholder swallows the remaining segments.
    #[test]
    fn test_match_greedy_path() {
        let params = match_path_pattern("/{Bucket}/{Key+}", "/my-bucket/path/to/file.txt")
            .expect("pattern should match");

        assert_eq!(values(&params), ["my-bucket", "path/to/file.txt"]);
    }

    // Method and path together select the operation.
    #[test]
    fn test_route_matching() {
        let routes = [
            RouteDefinition {
                method: "GET",
                path_pattern: "/2015-03-31/functions",
                operation: "ListFunctions",
                required_query_param: None,
            },
            RouteDefinition {
                method: "POST",
                path_pattern: "/2015-03-31/functions",
                operation: "CreateFunction",
                required_query_param: None,
            },
            RouteDefinition {
                method: "GET",
                path_pattern: "/2015-03-31/functions/{FunctionName}",
                operation: "GetFunction",
                required_query_param: None,
            },
        ];

        let (operation, _) = match_route("GET", "/2015-03-31/functions", "", &routes).unwrap();
        assert_eq!(operation, "ListFunctions");

        let (operation, params) =
            match_route("GET", "/2015-03-31/functions/my-func", "", &routes).unwrap();
        assert_eq!(operation, "GetFunction");
        assert_eq!(params[0].1, "my-func");
    }

    // A route with a required query parameter beats the plain route only when
    // that parameter is present.
    #[test]
    fn test_query_param_disambiguation() {
        let routes = [
            RouteDefinition {
                method: "PUT",
                path_pattern: "/{Bucket}",
                operation: "CreateBucket",
                required_query_param: None,
            },
            RouteDefinition {
                method: "PUT",
                path_pattern: "/{Bucket}",
                operation: "PutBucketVersioning",
                required_query_param: Some("versioning"),
            },
        ];

        let expectations = [("", "CreateBucket"), ("versioning", "PutBucketVersioning")];
        for (query, expected) in expectations {
            let (operation, _) = match_route("PUT", "/my-bucket", query, &routes).unwrap();
            assert_eq!(operation, expected);
        }
    }

    // The `x-amz-` prefix is dropped and the rest becomes PascalCase.
    #[test]
    fn test_header_to_param_name() {
        let cases = [
            ("x-amz-copy-source", "CopySource"),
            ("x-amz-server-side-encryption", "ServerSideEncryption"),
        ];
        for (header, expected) in cases {
            assert_eq!(header_to_param_name(header), expected);
        }
    }
}