// awsim_core/protocol/rest.rs

use axum::http::{HeaderMap, Method, StatusCode, Uri};
use bytes::Bytes;
use serde_json::Value;

use crate::error::AwsError;

use super::{ParsedRequest, RouteDefinition};

9/// Parse a restJson1 request.
10///
11/// Operation is determined by matching HTTP method + URI path against route definitions.
12pub fn parse_json_request(
13    method: &Method,
14    uri: &Uri,
15    body: &Bytes,
16    routes: &[RouteDefinition],
17) -> Result<ParsedRequest, AwsError> {
18    let path = uri.path();
19    let query_string = uri.query().unwrap_or("");
20
21    let (operation, path_params) = match_route(method.as_str(), path, query_string, routes)?;
22
23    let mut input = if body.is_empty() {
24        Value::Object(serde_json::Map::new())
25    } else {
26        serde_json::from_slice(body).map_err(|e| {
27            AwsError::bad_request("SerializationException", format!("Invalid JSON body: {e}"))
28        })?
29    };
30
31    // Merge path parameters into input
32    if let Value::Object(ref mut map) = input {
33        for (key, value) in path_params {
34            map.insert(key, Value::String(value));
35        }
36        // Merge query parameters
37        for (key, value) in parse_query_string(query_string) {
38            map.entry(key).or_insert(Value::String(value));
39        }
40    }
41
42    Ok(ParsedRequest {
43        operation: operation.to_string(),
44        input,
45    })
46}
47
48/// Parse a restXml request (used by S3).
49pub fn parse_xml_request(
50    method: &Method,
51    uri: &Uri,
52    headers: &HeaderMap,
53    body: &Bytes,
54    routes: &[RouteDefinition],
55) -> Result<ParsedRequest, AwsError> {
56    let path = uri.path();
57    let query_string = uri.query().unwrap_or("");
58
59    let (operation, path_params) = match_route(method.as_str(), path, query_string, routes)?;
60
61    let mut input = if body.is_empty() {
62        Value::Object(serde_json::Map::new())
63    } else {
64        // Only attempt XML parsing if the body actually looks like XML (starts
65        // with '<').  Otherwise treat it as raw binary data and store it as
66        // base64 in `__raw_body` so handlers like S3 PutObject can access it.
67        let looks_like_xml = body.first().is_some_and(|&b| b == b'<');
68        if looks_like_xml {
69            match parse_xml_body(body) {
70                Ok(v) => v,
71                Err(_) => {
72                    use base64::Engine;
73                    let encoded = base64::engine::general_purpose::STANDARD.encode(body);
74                    let mut map = serde_json::Map::new();
75                    map.insert("__raw_body".to_string(), Value::String(encoded));
76                    Value::Object(map)
77                }
78            }
79        } else {
80            // Non-XML body — always store as raw binary.
81            use base64::Engine;
82            let encoded = base64::engine::general_purpose::STANDARD.encode(body);
83            let mut map = serde_json::Map::new();
84            map.insert("__raw_body".to_string(), Value::String(encoded));
85            Value::Object(map)
86        }
87    };
88
89    // Merge path parameters
90    if let Value::Object(ref mut map) = input {
91        for (key, value) in path_params {
92            map.insert(key, Value::String(value));
93        }
94        for (key, value) in parse_query_string(query_string) {
95            map.entry(key).or_insert(Value::String(value));
96        }
97        // Extract relevant headers (x-amz-*)
98        for (name, value) in headers.iter() {
99            let name_str = name.as_str();
100            if name_str.starts_with("x-amz-") && name_str != "x-amz-target" {
101                if let Ok(v) = value.to_str() {
102                    let key = header_to_param_name(name_str);
103                    map.entry(key).or_insert(Value::String(v.to_string()));
104                }
105            }
106        }
107    }
108
109    Ok(ParsedRequest {
110        operation: operation.to_string(),
111        input,
112    })
113}
114
115/// Match an HTTP request against route definitions.
116/// Returns the operation name and extracted path parameters.
117fn match_route<'a>(
118    method: &str,
119    path: &str,
120    query_string: &str,
121    routes: &'a [RouteDefinition],
122) -> Result<(&'a str, Vec<(String, String)>), AwsError> {
123    // Strip a trailing slash ONLY for bucket-level operations (paths like `/bucket/`).
124    // Don't strip for object keys like `/bucket/folder/` — the trailing slash is
125    // significant (it marks S3 "folder" objects).
126    let segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
127    let path = if segments.len() <= 1 {
128        // Bucket-level: `/bucket/` → `/bucket`
129        let stripped = path.strip_suffix('/').unwrap_or(path);
130        if stripped.is_empty() { "/" } else { stripped }
131    } else {
132        // Object-level: preserve trailing slash for folder markers
133        path
134    };
135
136    let query_params: Vec<(String, String)> = parse_query_string(query_string);
137
138    // Try routes with required_query_param first (more specific matches)
139    let mut best_match: Option<(&str, Vec<(String, String)>)> = None;
140    let mut best_specificity = 0;
141
142    for route in routes {
143        if !route.method.eq_ignore_ascii_case(method) {
144            continue;
145        }
146
147        if let Some(path_params) = match_path_pattern(route.path_pattern, path) {
148            let specificity = if route.required_query_param.is_some() {
149                2
150            } else {
151                1
152            };
153
154            if let Some(required_param) = route.required_query_param {
155                // This route requires a specific query parameter to be present
156                if query_params.iter().any(|(k, _)| k == required_param) && specificity > best_specificity {
157                    best_match = Some((route.operation, path_params));
158                    best_specificity = specificity;
159                }
160            } else if specificity > best_specificity
161                || (specificity == best_specificity && best_match.is_none())
162            {
163                best_match = Some((route.operation, path_params));
164                best_specificity = specificity;
165            }
166        }
167    }
168
169    best_match.ok_or_else(|| {
170        AwsError::unknown_operation(&format!("{method} {path}"))
171    })
172}
173
174/// Match a path pattern like "/2015-03-31/functions/{FunctionName}" against an actual path.
175/// Returns extracted path parameters if matched.
176fn match_path_pattern(pattern: &str, path: &str) -> Option<Vec<(String, String)>> {
177    let pattern_parts: Vec<&str> = pattern.split('/').collect();
178    let path_parts: Vec<&str> = path.split('/').collect();
179
180    // Handle greedy patterns (last segment is {Key+})
181    let has_greedy = pattern_parts
182        .last()
183        .is_some_and(|p| p.starts_with('{') && p.ends_with("+}"));
184
185    if has_greedy {
186        if path_parts.len() < pattern_parts.len() {
187            return None;
188        }
189    } else if pattern_parts.len() != path_parts.len() {
190        return None;
191    }
192
193    let mut params = Vec::new();
194
195    for (i, (pat, actual)) in pattern_parts.iter().zip(path_parts.iter()).enumerate() {
196        if pat.starts_with('{') && pat.ends_with("+}") {
197            // Greedy match - capture rest of path
198            let name = &pat[1..pat.len() - 2];
199            let rest = path_parts[i..].join("/");
200            params.push((name.to_string(), percent_decode(&rest)));
201            return Some(params);
202        } else if pat.starts_with('{') && pat.ends_with('}') {
203            let name = &pat[1..pat.len() - 1];
204            params.push((name.to_string(), percent_decode(actual)));
205        } else if pat != actual {
206            return None;
207        }
208    }
209
210    Some(params)
211}
212
213fn parse_query_string(qs: &str) -> Vec<(String, String)> {
214    if qs.is_empty() {
215        return Vec::new();
216    }
217    qs.split('&')
218        .filter_map(|pair| {
219            let mut parts = pair.splitn(2, '=');
220            let key = parts.next()?;
221            let value = parts.next().unwrap_or("");
222            Some((
223                percent_decode(key),
224                percent_decode(value),
225            ))
226        })
227        .collect()
228}
229
/// Percent-decode `s`, also mapping `+` to a space (form-encoding style).
///
/// Each `%XX` escape (exactly two hex digits) is decoded to a raw byte. The
/// resulting byte string is then interpreted as UTF-8, so multi-byte
/// sequences such as `%C3%A9` become `é`. Byte sequences that are not valid
/// UTF-8 are replaced with U+FFFD. A `%` that is not followed by two hex
/// digits is kept literally.
///
/// NOTE(review): `+` → space is correct for query strings but not for URI path
/// segments (RFC 3986), where `+` is a literal character. This function is
/// currently used for both.
fn percent_decode(s: &str) -> String {
    fn hex_value(b: u8) -> Option<u8> {
        match b {
            b'0'..=b'9' => Some(b - b'0'),
            b'a'..=b'f' => Some(b - b'a' + 10),
            b'A'..=b'F' => Some(b - b'A' + 10),
            _ => None,
        }
    }

    let bytes = s.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_value);
                let lo = bytes.get(i + 2).copied().and_then(hex_value);
                if let (Some(hi), Some(lo)) = (hi, lo) {
                    out.push((hi << 4) | lo);
                    i += 3;
                    continue;
                }
                // Malformed or truncated escape: keep the '%' literally.
                out.push(b'%');
            }
            b'+' => out.push(b' '),
            other => out.push(other),
        }
        i += 1;
    }

    String::from_utf8(out).unwrap_or_else(|e| String::from_utf8_lossy(e.as_bytes()).into_owned())
}

/// Convert an `x-amz-*` header name to a PascalCase parameter name.
///
/// The `x-amz-` prefix is dropped; names without it are converted as-is.
/// Each `-`-separated word is then capitalised and the rest of the word
/// lowercased, e.g. `"x-amz-copy-source"` → `"CopySource"`.
fn header_to_param_name(header: &str) -> String {
    let suffix = header.strip_prefix("x-amz-").unwrap_or(header);
    let mut name = String::with_capacity(suffix.len());
    for word in suffix.split('-') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            name.extend(first.to_uppercase());
            name.push_str(&chars.as_str().to_lowercase());
        }
    }
    name
}

272/// Parse XML body into a JSON-like Value.
273fn parse_xml_body(body: &Bytes) -> Result<Value, AwsError> {
274    let s = std::str::from_utf8(body)
275        .map_err(|_| AwsError::bad_request("InvalidRequest", "Body is not valid UTF-8"))?;
276
277    parse_xml_element(s)
278}
279
280/// Simple XML → JSON parser for AWS request bodies.
281fn parse_xml_element(xml: &str) -> Result<Value, AwsError> {
282    use quick_xml::events::Event;
283    use quick_xml::Reader;
284
285    let mut reader = Reader::from_str(xml);
286    let mut map = serde_json::Map::new();
287    let mut stack: Vec<(String, serde_json::Map<String, Value>)> = Vec::new();
288    let mut current_key = String::new();
289    let mut current_text = String::new();
290
291    loop {
292        match reader.read_event() {
293            Ok(Event::Start(e)) => {
294                let name = String::from_utf8_lossy(e.name().as_ref()).to_string();
295                if !current_key.is_empty() {
296                    stack.push((current_key.clone(), map.clone()));
297                    map = serde_json::Map::new();
298                }
299                current_key = name;
300                current_text.clear();
301            }
302            Ok(Event::Text(e)) => {
303                current_text = e.unescape().unwrap_or_default().to_string();
304            }
305            Ok(Event::End(_)) => {
306                if current_text.is_empty() && !map.is_empty() {
307                    let value = Value::Object(map.clone());
308                    if let Some((parent_key, mut parent_map)) = stack.pop() {
309                        // Check if this key already exists (array case)
310                        if let Some(existing) = parent_map.get_mut(&current_key) {
311                            match existing {
312                                Value::Array(arr) => arr.push(value),
313                                other => {
314                                    let prev = other.take();
315                                    *other = Value::Array(vec![prev, value]);
316                                }
317                            }
318                        } else {
319                            parent_map.insert(current_key.clone(), value);
320                        }
321                        map = parent_map;
322                        current_key = parent_key;
323                    } else {
324                        map.insert(current_key.clone(), Value::Object(map.clone()));
325                    }
326                } else if !current_key.is_empty() {
327                    let value = Value::String(current_text.clone());
328                    if let Some((_parent_key, _parent_map)) = stack.last_mut() {
329                    }
330                    map.insert(current_key.clone(), value);
331                    if let Some((parent_key, mut parent_map)) = stack.pop() {
332                        parent_map.insert(
333                            current_key.clone(),
334                            Value::String(current_text.clone()),
335                        );
336                        map = parent_map;
337                        current_key = parent_key;
338                    }
339                }
340                current_text.clear();
341            }
342            Ok(Event::Eof) => break,
343            Ok(_) => {}
344            Err(e) => {
345                return Err(AwsError::bad_request(
346                    "MalformedXML",
347                    format!("Invalid XML: {e}"),
348                ));
349            }
350        }
351    }
352
353    Ok(Value::Object(map))
354}
355
356/// Serialize a restXml success response.
357///
358/// Special convention: if `output` contains a `__raw_body` key (base64-encoded),
359/// the binary content is returned directly as the response body.  All other
360/// top-level keys are placed in response headers (e.g., `Content-Type`,
361/// `ETag`, `Last-Modified`).  This allows services such as S3 GetObject to
362/// return arbitrary binary data.
363pub fn serialize_xml_response(
364    output: &Value,
365    request_id: &str,
366) -> (StatusCode, HeaderMap, Bytes) {
367    let mut headers = HeaderMap::new();
368    headers.insert("x-amz-request-id", request_id.parse().unwrap());
369
370    // --- Raw binary response (e.g. S3 GetObject) ---
371    if let Some(raw_b64) = output.get("__raw_body").and_then(Value::as_str) {
372        use base64::Engine;
373        let data = base64::engine::general_purpose::STANDARD
374            .decode(raw_b64)
375            .unwrap_or_default();
376
377        // Promote scalar fields to response headers.
378        if let Some(map) = output.as_object() {
379            for (key, val) in map {
380                if key == "__raw_body" || key == "Body" {
381                    continue;
382                }
383                let header_name = pascal_to_header(key);
384                let header_value = match val {
385                    Value::String(s) => s.clone(),
386                    Value::Number(n) => n.to_string(),
387                    Value::Bool(b) => b.to_string(),
388                    _ => continue,
389                };
390                if let (Ok(k), Ok(v)) = (
391                    axum::http::header::HeaderName::from_bytes(header_name.as_bytes()),
392                    axum::http::HeaderValue::from_str(&header_value),
393                ) {
394                    headers.insert(k, v);
395                }
396            }
397        }
398
399        return (StatusCode::OK, headers, Bytes::from(data));
400    }
401
402    // --- Promote well-known fields to HTTP headers (S3 convention) ---
403    if let Some(map) = output.as_object() {
404        let header_fields = [
405            "ETag", "ContentType", "ContentLength", "LastModified",
406            "VersionId", "ServerSideEncryption", "StorageClass",
407        ];
408        for field in &header_fields {
409            if let Some(val) = map.get(*field) {
410                let header_name = pascal_to_header(field);
411                let header_value = match val {
412                    Value::String(s) => s.clone(),
413                    Value::Number(n) => n.to_string(),
414                    _ => continue,
415                };
416                if let (Ok(k), Ok(v)) = (
417                    axum::http::header::HeaderName::from_bytes(header_name.as_bytes()),
418                    axum::http::HeaderValue::from_str(&header_value),
419                ) {
420                    headers.insert(k, v);
421                }
422            }
423        }
424    }
425
426    // --- Normal XML response ---
427    // If `__xml_root` is present, wrap fields in that root element (with S3 namespace).
428    let xml_root = output
429        .get("__xml_root")
430        .and_then(Value::as_str)
431        .map(|s| s.to_string());
432
433    let output_for_xml = if xml_root.is_some() {
434        // Build a Value without the __xml_root sentinel key.
435        if let Some(map) = output.as_object() {
436            let filtered: serde_json::Map<String, Value> = map
437                .iter()
438                .filter(|(k, _)| k.as_str() != "__xml_root")
439                .map(|(k, v)| (k.clone(), v.clone()))
440                .collect();
441            Value::Object(filtered)
442        } else {
443            output.clone()
444        }
445    } else {
446        output.clone()
447    };
448
449    let body = if let Some(root) = xml_root {
450        // When an explicit XML root is present, always emit a root element
451        // (even if there are no child fields — e.g. empty BucketLoggingStatus).
452        let fields = super::query::json_to_xml_fields(&output_for_xml);
453        format!(
454            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n\
455             <{root} xmlns=\"http://s3.amazonaws.com/doc/2006-03-01/\">\n\
456             {fields}</{root}>",
457        )
458    } else if output_for_xml.is_null()
459        || (output_for_xml.is_object() && output_for_xml.as_object().unwrap().is_empty())
460    {
461        String::new()
462    } else {
463        format!(
464            "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n{}",
465            super::query::json_to_xml_fields(&output_for_xml)
466        )
467    };
468
469    if !body.is_empty() {
470        headers.insert("content-type", "application/xml".parse().unwrap());
471    }
472    (StatusCode::OK, headers, Bytes::from(body))
473}
474
/// Convert a PascalCase field name to the lowercase HTTP header name used on
/// the wire.
///
/// A few S3 fields map to headers that the generic rule would get wrong
/// (e.g. `ETag` → `etag`, `VersionId` → `x-amz-version-id`); those come from a
/// fixed table. Every other name is kebab-cased: a `-` is inserted before each
/// uppercase character except the first, and everything is lowercased
/// (`CacheControl` → `cache-control`).
fn pascal_to_header(name: &str) -> String {
    const SPECIAL_CASES: [(&str, &str); 7] = [
        ("ETag", "etag"),
        ("ContentType", "content-type"),
        ("ContentLength", "content-length"),
        ("LastModified", "last-modified"),
        ("VersionId", "x-amz-version-id"),
        ("ServerSideEncryption", "x-amz-server-side-encryption"),
        ("StorageClass", "x-amz-storage-class"),
    ];
    if let Some(&(_, header)) = SPECIAL_CASES.iter().find(|(field, _)| *field == name) {
        return header.to_string();
    }

    let mut header = String::with_capacity(name.len() + 4);
    for (position, ch) in name.chars().enumerate() {
        if position > 0 && ch.is_uppercase() {
            header.push('-');
        }
        header.extend(ch.to_lowercase());
    }
    header
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_match_simple_path() {
        let result = match_path_pattern("/functions", "/functions");
        assert!(result.is_some());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_match_path_with_param() {
        let result =
            match_path_pattern("/2015-03-31/functions/{FunctionName}", "/2015-03-31/functions/my-func");
        assert!(result.is_some());
        let params = result.unwrap();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], ("FunctionName".to_string(), "my-func".to_string()));
    }

    #[test]
    fn test_match_path_no_match() {
        let result = match_path_pattern("/functions/{Name}", "/queues/my-queue");
        assert!(result.is_none());
    }

    #[test]
    fn test_match_greedy_path() {
        let result = match_path_pattern("/{Bucket}/{Key+}", "/my-bucket/path/to/file.txt");
        assert!(result.is_some());
        let params = result.unwrap();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].1, "my-bucket");
        assert_eq!(params[1].1, "path/to/file.txt");
    }

    #[test]
    fn test_route_matching() {
        let routes = vec![
            RouteDefinition {
                method: "GET",
                path_pattern: "/2015-03-31/functions",
                operation: "ListFunctions",
                required_query_param: None,
            },
            RouteDefinition {
                method: "POST",
                path_pattern: "/2015-03-31/functions",
                operation: "CreateFunction",
                required_query_param: None,
            },
            RouteDefinition {
                method: "GET",
                path_pattern: "/2015-03-31/functions/{FunctionName}",
                operation: "GetFunction",
                required_query_param: None,
            },
        ];

        let (op, _) = match_route("GET", "/2015-03-31/functions", "", &routes).unwrap();
        assert_eq!(op, "ListFunctions");

        let (op, params) =
            match_route("GET", "/2015-03-31/functions/my-func", "", &routes).unwrap();
        assert_eq!(op, "GetFunction");
        assert_eq!(params[0].1, "my-func");
    }

    #[test]
    fn test_query_param_disambiguation() {
        let routes = vec![
            RouteDefinition {
                method: "PUT",
                path_pattern: "/{Bucket}",
                operation: "CreateBucket",
                required_query_param: None,
            },
            RouteDefinition {
                method: "PUT",
                path_pattern: "/{Bucket}",
                operation: "PutBucketVersioning",
                required_query_param: Some("versioning"),
            },
        ];

        let (op, _) = match_route("PUT", "/my-bucket", "", &routes).unwrap();
        assert_eq!(op, "CreateBucket");

        let (op, _) = match_route("PUT", "/my-bucket", "versioning", &routes).unwrap();
        assert_eq!(op, "PutBucketVersioning");
    }

    /// Routes shared by the trailing-slash tests.
    fn bucket_and_object_routes() -> Vec<RouteDefinition> {
        vec![
            RouteDefinition {
                method: "PUT",
                path_pattern: "/{Bucket}",
                operation: "CreateBucket",
                required_query_param: None,
            },
            RouteDefinition {
                method: "PUT",
                path_pattern: "/{Bucket}/{Key+}",
                operation: "PutObject",
                required_query_param: None,
            },
        ]
    }

    #[test]
    fn test_bucket_level_trailing_slash_is_stripped() {
        let routes = bucket_and_object_routes();
        let (op, params) = match_route("PUT", "/my-bucket/", "", &routes).unwrap();
        assert_eq!(op, "CreateBucket");
        assert_eq!(params, vec![("Bucket".to_string(), "my-bucket".to_string())]);
    }

    #[test]
    fn test_object_key_trailing_slash_is_preserved() {
        let routes = bucket_and_object_routes();
        let (op, params) = match_route("PUT", "/my-bucket/folder/", "", &routes).unwrap();
        assert_eq!(op, "PutObject");
        assert_eq!(params[1], ("Key".to_string(), "folder/".to_string()));
    }

    #[test]
    fn test_unmatched_route_is_error() {
        let routes = bucket_and_object_routes();
        assert!(match_route("DELETE", "/my-bucket", "", &routes).is_err());
    }

    #[test]
    fn test_parse_query_string() {
        assert_eq!(
            parse_query_string("a=1&b&c=x%20y"),
            vec![
                ("a".to_string(), "1".to_string()),
                ("b".to_string(), String::new()),
                ("c".to_string(), "x y".to_string()),
            ]
        );
        assert!(parse_query_string("").is_empty());
    }

    #[test]
    fn test_header_to_param_name() {
        assert_eq!(header_to_param_name("x-amz-copy-source"), "CopySource");
        assert_eq!(
            header_to_param_name("x-amz-server-side-encryption"),
            "ServerSideEncryption"
        );
    }

    #[test]
    fn test_pascal_to_header() {
        assert_eq!(pascal_to_header("ETag"), "etag");
        assert_eq!(pascal_to_header("ContentType"), "content-type");
        assert_eq!(pascal_to_header("VersionId"), "x-amz-version-id");
        assert_eq!(pascal_to_header("CacheControl"), "cache-control");
    }

    #[test]
    fn test_parse_xml_flat_fields() {
        let v = parse_xml_element("<Root><A>1</A><B>2</B></Root>").unwrap();
        assert_eq!(v["A"], "1");
        assert_eq!(v["B"], "2");
        assert_eq!(v["Root"]["B"], "2");
    }

    #[test]
    fn test_parse_xml_repeated_elements_become_array() {
        let v = parse_xml_element(
            "<Delete><Object><Key>a</Key></Object><Object><Key>b</Key></Object></Delete>",
        )
        .unwrap();
        let objects = v["Object"]
            .as_array()
            .expect("repeated <Object> elements should form an array");
        assert_eq!(objects.len(), 2);
        assert_eq!(objects[0]["Key"], "a");
        assert_eq!(objects[1]["Key"], "b");
    }

    #[test]
    fn test_serialize_raw_body_response() {
        // "aGk=" is base64 for "hi".
        let output = json!({"__raw_body": "aGk=", "ContentType": "text/plain"});
        let (status, headers, body) = serialize_xml_response(&output, "req-1");
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, Bytes::from_static(b"hi"));
        assert_eq!(
            headers.get("content-type").and_then(|v| v.to_str().ok()),
            Some("text/plain")
        );
        assert_eq!(
            headers.get("x-amz-request-id").and_then(|v| v.to_str().ok()),
            Some("req-1")
        );
    }
}