1use async_trait::async_trait;
2use bytes::Bytes;
3use http::{HeaderMap, Method, StatusCode};
4use md5::{Digest, Md5};
5use parking_lot::Mutex;
6use std::collections::{BTreeMap, HashMap};
7use std::path::PathBuf;
8
9use crate::auth::Principal;
10
/// Streaming request body type; an alias for axum's [`axum::body::Body`].
pub type RequestBodyStream = axum::body::Body;
18
/// An incoming AWS API request, as passed to [`AwsService::handle`].
pub struct AwsRequest {
    /// Name of the service the request is addressed to.
    pub service: String,
    /// Name of the requested operation (e.g. `ListBucket`).
    pub action: String,
    /// Region the request targets (e.g. `us-east-1`).
    pub region: String,
    /// Account the request is executed against (e.g. `123456789012`).
    pub account_id: String,
    /// Identifier of this request; presumably echoed back to the client — confirm.
    pub request_id: String,
    /// HTTP request headers.
    pub headers: HeaderMap,
    /// Query-string parameters keyed by name. Being a map, it holds at most one
    /// value per key. NOTE(review): confirm how the router treats repeated keys.
    pub query_params: HashMap<String, String>,
    /// Buffered request body. Presumably empty when the body is delivered through
    /// `body_stream` instead — confirm with the router.
    pub body: Bytes,
    /// Streaming request body, if one is attached. It sits behind a mutex so that
    /// [`AwsRequest::take_body_stream`] can move it out through `&self`.
    pub body_stream: Mutex<Option<RequestBodyStream>>,
    /// Request path split into segments.
    pub path_segments: Vec<String>,
    /// Request path as a single string (presumably not percent-decoded — confirm).
    pub raw_path: String,
    /// Query string (presumably without the leading `?` and not decoded — confirm).
    pub raw_query: String,
    /// HTTP method.
    pub method: Method,
    /// Whether the request uses the AWS Query protocol — presumably form-encoded
    /// `Action=...` requests; confirm with the router.
    pub is_query_protocol: bool,
    /// Access key id presented by the caller, if any.
    pub access_key_id: Option<String>,
    /// Authenticated caller identity, if one was resolved.
    pub principal: Option<Principal>,
}
54
55impl std::fmt::Debug for AwsRequest {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("AwsRequest")
58 .field("service", &self.service)
59 .field("action", &self.action)
60 .field("region", &self.region)
61 .field("account_id", &self.account_id)
62 .field("request_id", &self.request_id)
63 .field("headers", &self.headers)
64 .field("query_params", &self.query_params)
65 .field("body_len", &self.body.len())
66 .field(
67 "body_stream",
68 &self.body_stream.lock().as_ref().map(|_| "<stream>"),
69 )
70 .field("path_segments", &self.path_segments)
71 .field("raw_path", &self.raw_path)
72 .field("raw_query", &self.raw_query)
73 .field("method", &self.method)
74 .field("is_query_protocol", &self.is_query_protocol)
75 .field("access_key_id", &self.access_key_id)
76 .field("principal", &self.principal)
77 .finish()
78 }
79}
80
81impl AwsRequest {
82 pub fn json_body(&self) -> serde_json::Value {
84 serde_json::from_slice(&self.body).unwrap_or(serde_json::Value::Null)
85 }
86
87 pub fn take_body_stream(&self) -> Option<RequestBodyStream> {
92 self.body_stream.lock().take()
93 }
94}
95
96pub async fn drain_request_stream(stream: RequestBodyStream) -> Result<Bytes, AwsServiceError> {
106 use http_body_util::BodyExt;
107 match stream.collect().await {
108 Ok(c) => Ok(c.to_bytes()),
109 Err(e) => Err(stream_error_to_aws(&e.to_string())),
110 }
111}
112
113fn stream_error_to_aws(msg: &str) -> AwsServiceError {
114 let too_large = msg.to_ascii_lowercase().contains("limit");
119 let (status, code, message) = if too_large {
120 (
121 StatusCode::PAYLOAD_TOO_LARGE,
122 "RequestEntityTooLarge",
123 "Streaming request body exceeded the configured limit",
124 )
125 } else {
126 (
127 StatusCode::BAD_REQUEST,
128 "MalformedRequestBody",
129 "Failed to read streaming request body",
130 )
131 };
132 AwsServiceError::aws_error(status, code, message)
133}
134
/// A request body that has been written to a file by [`spool_request_stream`].
///
/// The file at `path` is persisted: nothing in this type deletes it, so the
/// holder is responsible for removing it when done.
#[derive(Debug)]
pub struct SpooledBody {
    /// Location of the spooled file.
    pub path: PathBuf,
    /// Number of body bytes written to the file.
    pub size: u64,
    /// MD5 digest of the body bytes, as lowercase hex.
    pub md5_hex: String,
}
151
152pub async fn spool_request_stream(
164 stream: RequestBodyStream,
165 dir: Option<&std::path::Path>,
166) -> Result<SpooledBody, AwsServiceError> {
167 use http_body_util::BodyExt;
168 use tokio::io::AsyncWriteExt;
169
170 let dir = dir.map(|d| d.to_path_buf());
171 if let Some(d) = dir.as_ref() {
172 let _ = tokio::fs::create_dir_all(d).await;
174 }
175
176 let mut builder = tempfile::Builder::new();
177 builder.prefix("fc-spool-");
178 let named = match dir.as_ref() {
179 Some(d) => builder.tempfile_in(d),
180 None => builder.tempfile(),
181 }
182 .map_err(|e| {
183 AwsServiceError::aws_error(
184 StatusCode::INTERNAL_SERVER_ERROR,
185 "InternalError",
186 format!("failed to create spool tempfile: {e}"),
187 )
188 })?;
189
190 let (std_file, temp_path) = named.into_parts();
193 let path: PathBuf = temp_path.keep().map_err(|e| {
196 AwsServiceError::aws_error(
197 StatusCode::INTERNAL_SERVER_ERROR,
198 "InternalError",
199 format!("failed to persist spool tempfile: {e}"),
200 )
201 })?;
202
203 let mut file = tokio::fs::File::from_std(std_file);
204 let mut hasher = Md5::new();
205 let mut size: u64 = 0;
206 let mut body = stream;
207
208 async fn cleanup(file: tokio::fs::File, path: &std::path::Path) {
213 drop(file);
214 let _ = tokio::fs::remove_file(path).await;
215 }
216
217 loop {
218 match body.frame().await {
219 Some(Ok(frame)) => {
220 if let Ok(chunk) = frame.into_data() {
221 if !chunk.is_empty() {
222 hasher.update(&chunk);
223 size += chunk.len() as u64;
224 if let Err(e) = file.write_all(&chunk).await {
225 cleanup(file, &path).await;
226 return Err(AwsServiceError::aws_error(
227 StatusCode::INTERNAL_SERVER_ERROR,
228 "InternalError",
229 format!("failed to spool request body: {e}"),
230 ));
231 }
232 }
233 }
234 }
236 Some(Err(e)) => {
237 cleanup(file, &path).await;
238 return Err(stream_error_to_aws(&e.to_string()));
239 }
240 None => break,
241 }
242 }
243
244 if let Err(e) = file.flush().await {
245 cleanup(file, &path).await;
246 return Err(AwsServiceError::aws_error(
247 StatusCode::INTERNAL_SERVER_ERROR,
248 "InternalError",
249 format!("failed to flush spool tempfile: {e}"),
250 ));
251 }
252 drop(file);
253
254 let md5_hex = hex_lower(&hasher.finalize());
255 Ok(SpooledBody {
256 path,
257 size,
258 md5_hex,
259 })
260}
261
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte, high nibble first.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(bytes.iter().flat_map(|&byte| {
        [
            DIGITS[usize::from(byte >> 4)],
            DIGITS[usize::from(byte & 0x0f)],
        ]
    }));
    encoded
}
271
/// Body of an [`AwsResponse`]: either an in-memory buffer or an open file.
#[derive(Debug)]
pub enum ResponseBody {
    /// Payload held fully in memory.
    Bytes(Bytes),
    /// Payload backed by an open file. `size` is the length reported by
    /// [`ResponseBody::len`]; nothing in this type checks it against the file.
    File { file: tokio::fs::File, size: u64 },
}
285
286impl ResponseBody {
287 pub fn len(&self) -> u64 {
288 match self {
289 ResponseBody::Bytes(b) => b.len() as u64,
290 ResponseBody::File { size, .. } => *size,
291 }
292 }
293
294 pub fn is_empty(&self) -> bool {
295 self.len() == 0
296 }
297
298 pub fn expect_bytes(&self) -> &[u8] {
302 match self {
303 ResponseBody::Bytes(b) => b,
304 ResponseBody::File { .. } => {
305 panic!("expect_bytes called on ResponseBody::File")
306 }
307 }
308 }
309}
310
311impl Default for ResponseBody {
312 fn default() -> Self {
313 ResponseBody::Bytes(Bytes::new())
314 }
315}
316
317impl From<Bytes> for ResponseBody {
318 fn from(b: Bytes) -> Self {
319 ResponseBody::Bytes(b)
320 }
321}
322
323impl From<Vec<u8>> for ResponseBody {
324 fn from(v: Vec<u8>) -> Self {
325 ResponseBody::Bytes(Bytes::from(v))
326 }
327}
328
329impl From<&'static [u8]> for ResponseBody {
330 fn from(s: &'static [u8]) -> Self {
331 ResponseBody::Bytes(Bytes::from_static(s))
332 }
333}
334
335impl From<String> for ResponseBody {
336 fn from(s: String) -> Self {
337 ResponseBody::Bytes(Bytes::from(s))
338 }
339}
340
341impl From<&'static str> for ResponseBody {
342 fn from(s: &'static str) -> Self {
343 ResponseBody::Bytes(Bytes::from_static(s.as_bytes()))
344 }
345}
346
347impl PartialEq<Bytes> for ResponseBody {
348 fn eq(&self, other: &Bytes) -> bool {
349 match self {
350 ResponseBody::Bytes(b) => b == other,
351 ResponseBody::File { .. } => false,
352 }
353 }
354}
355
/// Response returned by [`AwsService::handle`].
pub struct AwsResponse {
    /// HTTP status code.
    pub status: StatusCode,
    /// MIME type of `body`, e.g. `text/xml` as set by [`AwsResponse::xml`] or
    /// `application/x-amz-json-1.1` as set by [`AwsResponse::json`].
    pub content_type: String,
    /// Response payload.
    pub body: ResponseBody,
    /// Additional response headers (empty when built via `xml`/`json`).
    pub headers: HeaderMap,
}
363
364impl AwsResponse {
365 pub fn xml(status: StatusCode, body: impl Into<Bytes>) -> Self {
366 Self {
367 status,
368 content_type: "text/xml".to_string(),
369 body: ResponseBody::Bytes(body.into()),
370 headers: HeaderMap::new(),
371 }
372 }
373
374 pub fn json(status: StatusCode, body: impl Into<Bytes>) -> Self {
375 Self {
376 status,
377 content_type: "application/x-amz-json-1.1".to_string(),
378 body: ResponseBody::Bytes(body.into()),
379 headers: HeaderMap::new(),
380 }
381 }
382
383 pub fn json_value(status: StatusCode, value: serde_json::Value) -> Self {
389 Self::json(
390 status,
391 serde_json::to_vec(&value).expect("serde_json::Value serialization is infallible"),
392 )
393 }
394
395 pub fn ok_json(value: serde_json::Value) -> Self {
397 Self::json_value(StatusCode::OK, value)
398 }
399}
400
/// Error returned by an [`AwsService`] handler.
///
/// The accessor methods (`status`, `code`, `message`, `extra_fields`,
/// `response_headers`) expose the HTTP status, error code, text, and any
/// extra data for every variant.
#[derive(Debug, thiserror::Error)]
pub enum AwsServiceError {
    /// The requested service name is not known (presumably: no registered
    /// service matched `service` — confirm with the dispatcher). Reported as
    /// `400 UnknownService`.
    #[error("service not found: {service}")]
    ServiceNotFound { service: String },

    /// `service` does not implement `action`. Reported as `501 InvalidAction`.
    #[error("action {action} not implemented for service {service}")]
    ActionNotImplemented { service: String, action: String },

    /// A fully specified AWS error with its own status, code and message.
    #[error("{code}: {message}")]
    AwsError {
        status: StatusCode,
        code: String,
        message: String,
        /// Extra `(name, value)` pairs describing the error; presumably rendered
        /// as additional elements of the error body — confirm with the serializer.
        extra_fields: Vec<(String, String)>,
        /// `(name, value)` header pairs to attach to the error response.
        headers: Vec<(String, String)>,
    },
}
421
422impl AwsServiceError {
423 pub fn action_not_implemented(service: &str, action: &str) -> Self {
424 Self::ActionNotImplemented {
425 service: service.to_string(),
426 action: action.to_string(),
427 }
428 }
429
430 pub fn aws_error(
431 status: StatusCode,
432 code: impl Into<String>,
433 message: impl Into<String>,
434 ) -> Self {
435 Self::AwsError {
436 status,
437 code: code.into(),
438 message: message.into(),
439 extra_fields: Vec::new(),
440 headers: Vec::new(),
441 }
442 }
443
444 pub fn aws_error_with_fields(
445 status: StatusCode,
446 code: impl Into<String>,
447 message: impl Into<String>,
448 extra_fields: Vec<(String, String)>,
449 ) -> Self {
450 Self::AwsError {
451 status,
452 code: code.into(),
453 message: message.into(),
454 extra_fields,
455 headers: Vec::new(),
456 }
457 }
458
459 pub fn aws_error_with_headers(
460 status: StatusCode,
461 code: impl Into<String>,
462 message: impl Into<String>,
463 headers: Vec<(String, String)>,
464 ) -> Self {
465 Self::AwsError {
466 status,
467 code: code.into(),
468 message: message.into(),
469 extra_fields: Vec::new(),
470 headers,
471 }
472 }
473
474 pub fn extra_fields(&self) -> &[(String, String)] {
475 match self {
476 Self::AwsError { extra_fields, .. } => extra_fields,
477 _ => &[],
478 }
479 }
480
481 pub fn status(&self) -> StatusCode {
482 match self {
483 Self::ServiceNotFound { .. } => StatusCode::BAD_REQUEST,
484 Self::ActionNotImplemented { .. } => StatusCode::NOT_IMPLEMENTED,
485 Self::AwsError { status, .. } => *status,
486 }
487 }
488
489 pub fn code(&self) -> &str {
490 match self {
491 Self::ServiceNotFound { .. } => "UnknownService",
492 Self::ActionNotImplemented { .. } => "InvalidAction",
493 Self::AwsError { code, .. } => code,
494 }
495 }
496
497 pub fn message(&self) -> String {
498 match self {
499 Self::ServiceNotFound { service } => format!("service not found: {service}"),
500 Self::ActionNotImplemented { service, action } => {
501 format!("action {action} not implemented for service {service}")
502 }
503 Self::AwsError { message, .. } => message.clone(),
504 }
505 }
506
507 pub fn response_headers(&self) -> &[(String, String)] {
508 match self {
509 Self::AwsError { headers, .. } => headers,
510 _ => &[],
511 }
512 }
513}
514
/// A service implementation that handles [`AwsRequest`]s.
///
/// Only `service_name`, `handle` and `supported_actions` are required; the
/// IAM-related hooks have defaults that opt out (`false`, `None`, or empty).
#[async_trait]
pub trait AwsService: Send + Sync {
    /// Identifier of this service.
    fn service_name(&self) -> &str;

    /// Handles one request, producing a response or an [`AwsServiceError`].
    async fn handle(&self, request: AwsRequest) -> Result<AwsResponse, AwsServiceError>;

    /// Names of the actions this service implements (presumably the actions
    /// `handle` accepts — confirm against the dispatcher).
    fn supported_actions(&self) -> &[&str];

    /// Whether IAM policy evaluation applies to this service (presumably an
    /// opt-in switch for enforcement — confirm). Defaults to `false`.
    fn iam_enforceable(&self) -> bool {
        false
    }

    /// The IAM action (service, action name, resource) to authorize for
    /// `_request`. Defaults to `None`.
    fn iam_action_for(&self, _request: &AwsRequest) -> Option<crate::auth::IamAction> {
        None
    }

    /// Condition-key values (e.g. `s3:prefix`) supplied by this request for
    /// evaluating policy conditions on `_action`. Defaults to an empty map.
    fn iam_condition_keys_for(
        &self,
        _request: &AwsRequest,
        _action: &crate::auth::IamAction,
    ) -> BTreeMap<String, Vec<String>> {
        BTreeMap::new()
    }

    /// Tags attached to the resource identified by `_resource_arn`, if the
    /// service tracks them (presumably for `aws:ResourceTag/*` conditions —
    /// confirm). Defaults to `None`.
    fn resource_tags_for(
        &self,
        _resource_arn: &str,
    ) -> Option<std::collections::HashMap<String, String>> {
        None
    }

    /// Tags supplied in `_request` for `_action` (presumably for
    /// `aws:RequestTag/*` conditions — confirm). Defaults to `None`.
    fn request_tags_from(
        &self,
        _request: &AwsRequest,
        _action: &str,
    ) -> Option<std::collections::HashMap<String, String>> {
        None
    }
}
623
#[cfg(test)]
mod tests {
    //! Unit tests for the request/response/error types, the spooling helpers,
    //! and the `AwsService` trait defaults.

    use super::*;
    use crate::auth::IamAction;
    use async_trait::async_trait;

    /// Service that relies on every default trait method.
    struct DefaultService;

    #[async_trait]
    impl AwsService for DefaultService {
        fn service_name(&self) -> &str {
            "default"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
    }

    /// Service that overrides `iam_condition_keys_for` with a fixed map.
    struct PopulatedService;

    #[async_trait]
    impl AwsService for PopulatedService {
        fn service_name(&self) -> &str {
            "populated"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
        fn iam_condition_keys_for(
            &self,
            _request: &AwsRequest,
            _action: &IamAction,
        ) -> BTreeMap<String, Vec<String>> {
            let mut m = BTreeMap::new();
            m.insert("s3:prefix".to_string(), vec!["logs/".to_string()]);
            m
        }
    }

    fn sample_request() -> AwsRequest {
        AwsRequest {
            service: "default".into(),
            action: "Noop".into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "req-1".into(),
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            body: Bytes::new(),
            body_stream: parking_lot::Mutex::new(None),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: Method::GET,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    fn sample_action() -> IamAction {
        IamAction {
            service: "s3",
            action: "ListBucket",
            resource: "arn:aws:s3:::my-bucket".to_string(),
        }
    }

    #[test]
    fn iam_condition_keys_for_default_is_empty() {
        let svc = DefaultService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert!(keys.is_empty());
    }

    #[test]
    fn iam_condition_keys_for_override_returns_map() {
        let svc = PopulatedService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert_eq!(keys.get("s3:prefix"), Some(&vec!["logs/".to_string()]));
    }

    #[test]
    fn response_body_len_and_is_empty_for_bytes() {
        let body: ResponseBody = Bytes::from_static(b"hello").into();
        assert_eq!(body.len(), 5);
        assert!(!body.is_empty());
        let empty: ResponseBody = ResponseBody::default();
        assert!(empty.is_empty());
    }

    #[test]
    fn response_body_from_vec_and_string_and_str() {
        let from_vec: ResponseBody = vec![1u8, 2, 3].into();
        assert_eq!(from_vec.expect_bytes(), &[1, 2, 3][..]);
        let from_string: ResponseBody = String::from("hi").into();
        assert_eq!(from_string.expect_bytes(), b"hi");
        let from_str: ResponseBody = "hey".into();
        assert_eq!(from_str.expect_bytes(), b"hey");
        let from_static: ResponseBody = (b"123" as &'static [u8]).into();
        assert_eq!(from_static.expect_bytes(), b"123");
    }

    #[test]
    fn response_body_partial_eq_bytes() {
        let body: ResponseBody = Bytes::from_static(b"x").into();
        assert!(body == Bytes::from_static(b"x"));
        assert!(!(body == Bytes::from_static(b"y")));
    }

    #[test]
    fn aws_request_json_body_empty_returns_null() {
        let req = sample_request();
        assert_eq!(req.json_body(), serde_json::Value::Null);
    }

    #[test]
    fn aws_request_json_body_parses_valid() {
        let mut req = sample_request();
        req.body = Bytes::from_static(br#"{"a":1}"#);
        assert_eq!(req.json_body(), serde_json::json!({"a": 1}));
    }

    #[test]
    fn aws_response_xml_constructor() {
        let resp = AwsResponse::xml(StatusCode::OK, Bytes::from_static(b"<ok/>"));
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.content_type, "text/xml");
    }

    #[test]
    fn aws_response_json_constructor() {
        let resp = AwsResponse::json(StatusCode::CREATED, "{}");
        assert_eq!(resp.status, StatusCode::CREATED);
        assert_eq!(resp.content_type, "application/x-amz-json-1.1");
    }

    #[test]
    fn aws_response_ok_json_helper() {
        let resp = AwsResponse::ok_json(serde_json::json!({"ok": true}));
        assert_eq!(resp.status, StatusCode::OK);
        assert!(resp.body.expect_bytes().starts_with(b"{"));
    }

    #[test]
    fn aws_error_service_not_found_fields() {
        let err = AwsServiceError::ServiceNotFound {
            service: "sqs".to_string(),
        };
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert_eq!(err.code(), "UnknownService");
        assert!(err.message().contains("sqs"));
        assert!(err.extra_fields().is_empty());
        assert!(err.response_headers().is_empty());
    }

    #[test]
    fn aws_error_action_not_implemented_fields() {
        let err = AwsServiceError::action_not_implemented("sns", "FutureAction");
        assert_eq!(err.status(), StatusCode::NOT_IMPLEMENTED);
        assert_eq!(err.code(), "InvalidAction");
        assert!(err.message().contains("FutureAction"));
        assert!(err.message().contains("sns"));
    }

    #[test]
    fn aws_error_aws_error_helpers() {
        let e = AwsServiceError::aws_error(StatusCode::FORBIDDEN, "Denied", "no");
        assert_eq!(e.status(), StatusCode::FORBIDDEN);
        assert_eq!(e.code(), "Denied");
        assert_eq!(e.message(), "no");

        let fields = vec![("Bucket".to_string(), "b".to_string())];
        let ef = AwsServiceError::aws_error_with_fields(
            StatusCode::NOT_FOUND,
            "Missing",
            "gone",
            fields.clone(),
        );
        assert_eq!(ef.extra_fields(), fields.as_slice());

        let hdrs = vec![("X-Retry".to_string(), "1".to_string())];
        let eh = AwsServiceError::aws_error_with_headers(
            StatusCode::TOO_MANY_REQUESTS,
            "Throttled",
            "slow",
            hdrs.clone(),
        );
        assert_eq!(eh.response_headers(), hdrs.as_slice());
    }

    #[test]
    fn hex_lower_encodes_bytes_as_lowercase_pairs() {
        assert_eq!(hex_lower(&[]), "");
        assert_eq!(hex_lower(&[0x00, 0x0f, 0x10, 0xab, 0xff]), "000f10abff");
        // RFC 1321 test vector: MD5("abc").
        assert_eq!(
            hex_lower(&Md5::digest(b"abc")),
            "900150983cd24fb0d6963f7d28e17f72"
        );
    }

    #[test]
    fn stream_error_to_aws_maps_limit_errors_to_413() {
        let err = stream_error_to_aws("length limit exceeded");
        assert_eq!(err.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(err.code(), "RequestEntityTooLarge");
        // Matching is case-insensitive.
        assert_eq!(
            stream_error_to_aws("Body LIMIT hit").status(),
            StatusCode::PAYLOAD_TOO_LARGE
        );
    }

    #[test]
    fn stream_error_to_aws_maps_other_errors_to_400() {
        let err = stream_error_to_aws("connection reset by peer");
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert_eq!(err.code(), "MalformedRequestBody");
    }

    #[test]
    #[should_panic(expected = "expect_bytes called on ResponseBody::File")]
    fn response_body_expect_bytes_panics_on_file() {
        // An anonymous tempfile is deleted automatically once closed, so nothing
        // is left in the shared temp dir and concurrent or multi-user runs
        // cannot collide on a fixed file name.
        let f = tempfile::tempfile().unwrap();
        let async_f = tokio::fs::File::from_std(f);
        let body = ResponseBody::File {
            file: async_f,
            size: 0,
        };
        let _ = body.expect_bytes();
    }
}