1use async_trait::async_trait;
2use bytes::Bytes;
3use http::{HeaderMap, Method, StatusCode};
4use md5::{Digest, Md5};
5use parking_lot::Mutex;
6use std::collections::{BTreeMap, HashMap};
7use std::path::PathBuf;
8
9use crate::auth::Principal;
10
/// Streaming request body type handed to services; an alias for axum's body.
pub type RequestBodyStream = axum::body::Body;
18
/// A decoded incoming AWS API request, as handed to [`AwsService::handle`].
///
/// The payload is available either buffered in `body` or, for streaming
/// requests, as a one-shot stream in `body_stream` that is taken with
/// [`AwsRequest::take_body_stream`].
pub struct AwsRequest {
    /// Service identifier for this request (e.g. `"default"` in the unit tests).
    pub service: String,
    /// API action name (e.g. `"Noop"` in the unit tests).
    pub action: String,
    /// Region the request targets, e.g. `"us-east-1"`.
    pub region: String,
    /// Account id the request is scoped to, e.g. `"123456789012"`.
    pub account_id: String,
    /// Identifier of this request. NOTE(review): presumably echoed back in
    /// responses/errors — confirm with the dispatcher.
    pub request_id: String,
    /// HTTP request headers as received.
    pub headers: HeaderMap,
    /// Query-string parameters, holding a single value per key.
    pub query_params: HashMap<String, String>,
    /// Buffered request body; parsed by [`AwsRequest::json_body`].
    pub body: Bytes,
    /// Streaming body, if any. Wrapped in a mutex so it can be moved out
    /// exactly once through a shared reference (see
    /// [`AwsRequest::take_body_stream`]).
    pub body_stream: Mutex<Option<RequestBodyStream>>,
    /// Request path split into segments. NOTE(review): presumably split on `/`
    /// with empty segments removed — confirm with the router.
    pub path_segments: Vec<String>,
    /// Request path as received (`"/"` in the unit tests).
    pub raw_path: String,
    /// Raw query string. NOTE(review): presumably without the leading `?` —
    /// confirm.
    pub raw_query: String,
    /// HTTP method of the request.
    pub method: Method,
    /// NOTE(review): presumably `true` when the request uses the AWS Query
    /// (form-encoded) protocol — confirm where it is set.
    pub is_query_protocol: bool,
    /// Access key id associated with the request, if one was supplied.
    pub access_key_id: Option<String>,
    /// Caller identity, if one was resolved. NOTE(review): presumably filled
    /// in by the auth layer — confirm.
    pub principal: Option<Principal>,
}
54
55impl std::fmt::Debug for AwsRequest {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("AwsRequest")
58 .field("service", &self.service)
59 .field("action", &self.action)
60 .field("region", &self.region)
61 .field("account_id", &self.account_id)
62 .field("request_id", &self.request_id)
63 .field("headers", &self.headers)
64 .field("query_params", &self.query_params)
65 .field("body_len", &self.body.len())
66 .field(
67 "body_stream",
68 &self.body_stream.lock().as_ref().map(|_| "<stream>"),
69 )
70 .field("path_segments", &self.path_segments)
71 .field("raw_path", &self.raw_path)
72 .field("raw_query", &self.raw_query)
73 .field("method", &self.method)
74 .field("is_query_protocol", &self.is_query_protocol)
75 .field("access_key_id", &self.access_key_id)
76 .field("principal", &self.principal)
77 .finish()
78 }
79}
80
81impl AwsRequest {
82 pub fn json_body(&self) -> serde_json::Value {
84 serde_json::from_slice(&self.body).unwrap_or(serde_json::Value::Null)
85 }
86
87 pub fn take_body_stream(&self) -> Option<RequestBodyStream> {
92 self.body_stream.lock().take()
93 }
94}
95
96pub async fn drain_request_stream(stream: RequestBodyStream) -> Result<Bytes, AwsServiceError> {
106 use http_body_util::BodyExt;
107 match stream.collect().await {
108 Ok(c) => Ok(c.to_bytes()),
109 Err(e) => Err(stream_error_to_aws(&e.to_string())),
110 }
111}
112
113fn stream_error_to_aws(msg: &str) -> AwsServiceError {
114 let too_large = msg.to_ascii_lowercase().contains("limit");
119 let (status, code, message) = if too_large {
120 (
121 StatusCode::PAYLOAD_TOO_LARGE,
122 "RequestEntityTooLarge",
123 "Streaming request body exceeded the configured limit",
124 )
125 } else {
126 (
127 StatusCode::BAD_REQUEST,
128 "MalformedRequestBody",
129 "Failed to read streaming request body",
130 )
131 };
132 AwsServiceError::aws_error(status, code, message)
133}
134
/// A request body that [`spool_request_stream`] has written to disk.
///
/// The file at `path` is not removed automatically. The receiver of this
/// value owns the file and is responsible for deleting it.
#[derive(Debug)]
pub struct SpooledBody {
    /// Location of the spooled file.
    pub path: PathBuf,
    /// Number of body bytes written to the file.
    pub size: u64,
    /// Lowercase hexadecimal MD5 digest of the body bytes.
    pub md5_hex: String,
}
151
152pub async fn spool_request_stream(
164 stream: RequestBodyStream,
165 dir: Option<&std::path::Path>,
166) -> Result<SpooledBody, AwsServiceError> {
167 use http_body_util::BodyExt;
168 use tokio::io::AsyncWriteExt;
169
170 let dir = dir.map(|d| d.to_path_buf());
171 if let Some(d) = dir.as_ref() {
172 let _ = tokio::fs::create_dir_all(d).await;
174 }
175
176 let mut builder = tempfile::Builder::new();
177 builder.prefix("fc-spool-");
178 let named = match dir.as_ref() {
179 Some(d) => builder.tempfile_in(d),
180 None => builder.tempfile(),
181 }
182 .map_err(|e| {
183 AwsServiceError::aws_error(
184 StatusCode::INTERNAL_SERVER_ERROR,
185 "InternalError",
186 format!("failed to create spool tempfile: {e}"),
187 )
188 })?;
189
190 let (std_file, temp_path) = named.into_parts();
193 let path: PathBuf = temp_path.keep().map_err(|e| {
196 AwsServiceError::aws_error(
197 StatusCode::INTERNAL_SERVER_ERROR,
198 "InternalError",
199 format!("failed to persist spool tempfile: {e}"),
200 )
201 })?;
202
203 let mut file = tokio::fs::File::from_std(std_file);
204 let mut hasher = Md5::new();
205 let mut size: u64 = 0;
206 let mut body = stream;
207
208 async fn cleanup(file: tokio::fs::File, path: &std::path::Path) {
213 drop(file);
214 let _ = tokio::fs::remove_file(path).await;
215 }
216
217 loop {
218 match body.frame().await {
219 Some(Ok(frame)) => {
220 if let Ok(chunk) = frame.into_data() {
221 if !chunk.is_empty() {
222 hasher.update(&chunk);
223 size += chunk.len() as u64;
224 if let Err(e) = file.write_all(&chunk).await {
225 cleanup(file, &path).await;
226 return Err(AwsServiceError::aws_error(
227 StatusCode::INTERNAL_SERVER_ERROR,
228 "InternalError",
229 format!("failed to spool request body: {e}"),
230 ));
231 }
232 }
233 }
234 }
236 Some(Err(e)) => {
237 cleanup(file, &path).await;
238 return Err(stream_error_to_aws(&e.to_string()));
239 }
240 None => break,
241 }
242 }
243
244 if let Err(e) = file.flush().await {
245 cleanup(file, &path).await;
246 return Err(AwsServiceError::aws_error(
247 StatusCode::INTERNAL_SERVER_ERROR,
248 "InternalError",
249 format!("failed to flush spool tempfile: {e}"),
250 ));
251 }
252 drop(file);
253
254 let md5_hex = hex_lower(&hasher.finalize());
255 Ok(SpooledBody {
256 path,
257 size,
258 md5_hex,
259 })
260}
261
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte
/// (high nibble first).
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f',
    ];
    let mut encoded = String::with_capacity(bytes.len() * 2);
    bytes.iter().for_each(|&byte| {
        encoded.push(DIGITS[usize::from(byte >> 4)]);
        encoded.push(DIGITS[usize::from(byte & 0x0f)]);
    });
    encoded
}
271
/// Payload of an [`AwsResponse`]: either bytes held in memory or an open file.
#[derive(Debug)]
pub enum ResponseBody {
    /// Body held entirely in memory.
    Bytes(Bytes),
    /// Body to be read from an open file. `size` is the length reported by
    /// [`ResponseBody::len`]. NOTE(review): `size` is not checked against
    /// the file here, so callers must keep the two consistent.
    File { file: tokio::fs::File, size: u64 },
}
285
286impl ResponseBody {
287 pub fn len(&self) -> u64 {
288 match self {
289 ResponseBody::Bytes(b) => b.len() as u64,
290 ResponseBody::File { size, .. } => *size,
291 }
292 }
293
294 pub fn is_empty(&self) -> bool {
295 self.len() == 0
296 }
297
298 pub fn expect_bytes(&self) -> &[u8] {
302 match self {
303 ResponseBody::Bytes(b) => b,
304 ResponseBody::File { .. } => {
305 panic!("expect_bytes called on ResponseBody::File")
306 }
307 }
308 }
309}
310
311impl Default for ResponseBody {
312 fn default() -> Self {
313 ResponseBody::Bytes(Bytes::new())
314 }
315}
316
317impl From<Bytes> for ResponseBody {
318 fn from(b: Bytes) -> Self {
319 ResponseBody::Bytes(b)
320 }
321}
322
323impl From<Vec<u8>> for ResponseBody {
324 fn from(v: Vec<u8>) -> Self {
325 ResponseBody::Bytes(Bytes::from(v))
326 }
327}
328
329impl From<&'static [u8]> for ResponseBody {
330 fn from(s: &'static [u8]) -> Self {
331 ResponseBody::Bytes(Bytes::from_static(s))
332 }
333}
334
335impl From<String> for ResponseBody {
336 fn from(s: String) -> Self {
337 ResponseBody::Bytes(Bytes::from(s))
338 }
339}
340
341impl From<&'static str> for ResponseBody {
342 fn from(s: &'static str) -> Self {
343 ResponseBody::Bytes(Bytes::from_static(s.as_bytes()))
344 }
345}
346
347impl PartialEq<Bytes> for ResponseBody {
348 fn eq(&self, other: &Bytes) -> bool {
349 match self {
350 ResponseBody::Bytes(b) => b == other,
351 ResponseBody::File { .. } => false,
352 }
353 }
354}
355
/// A service's reply to an [`AwsRequest`].
pub struct AwsResponse {
    /// HTTP status code of the response.
    pub status: StatusCode,
    /// Value for the response's content type (e.g. `"text/xml"` from
    /// [`AwsResponse::xml`]).
    pub content_type: String,
    /// Response payload, in memory or backed by a file.
    pub body: ResponseBody,
    /// Additional HTTP response headers.
    pub headers: HeaderMap,
}
363
364impl AwsResponse {
365 pub fn xml(status: StatusCode, body: impl Into<Bytes>) -> Self {
366 Self {
367 status,
368 content_type: "text/xml".to_string(),
369 body: ResponseBody::Bytes(body.into()),
370 headers: HeaderMap::new(),
371 }
372 }
373
374 pub fn json(status: StatusCode, body: impl Into<Bytes>) -> Self {
375 Self {
376 status,
377 content_type: "application/x-amz-json-1.1".to_string(),
378 body: ResponseBody::Bytes(body.into()),
379 headers: HeaderMap::new(),
380 }
381 }
382
383 pub fn ok_json(value: serde_json::Value) -> Self {
385 Self::json(StatusCode::OK, serde_json::to_vec(&value).unwrap())
386 }
387}
388
/// Errors an [`AwsService`] can return. Each variant maps to an HTTP status and
/// an AWS error code via [`AwsServiceError::status`] and
/// [`AwsServiceError::code`].
#[derive(Debug, thiserror::Error)]
pub enum AwsServiceError {
    /// The requested service is unknown (400, `UnknownService`).
    #[error("service not found: {service}")]
    ServiceNotFound { service: String },

    /// The service does not implement `action` (501, `InvalidAction`).
    #[error("action {action} not implemented for service {service}")]
    ActionNotImplemented { service: String, action: String },

    /// A service-specific AWS error with an explicit status, code and message.
    #[error("{code}: {message}")]
    AwsError {
        status: StatusCode,
        code: String,
        message: String,
        /// Additional `(name, value)` pairs exposed via `extra_fields()`.
        /// NOTE(review): presumably rendered into the error body — confirm.
        extra_fields: Vec<(String, String)>,
        /// `(name, value)` pairs exposed via `response_headers()`.
        /// NOTE(review): presumably emitted as HTTP response headers — confirm.
        headers: Vec<(String, String)>,
    },
}
409
410impl AwsServiceError {
411 pub fn action_not_implemented(service: &str, action: &str) -> Self {
412 Self::ActionNotImplemented {
413 service: service.to_string(),
414 action: action.to_string(),
415 }
416 }
417
418 pub fn aws_error(
419 status: StatusCode,
420 code: impl Into<String>,
421 message: impl Into<String>,
422 ) -> Self {
423 Self::AwsError {
424 status,
425 code: code.into(),
426 message: message.into(),
427 extra_fields: Vec::new(),
428 headers: Vec::new(),
429 }
430 }
431
432 pub fn aws_error_with_fields(
433 status: StatusCode,
434 code: impl Into<String>,
435 message: impl Into<String>,
436 extra_fields: Vec<(String, String)>,
437 ) -> Self {
438 Self::AwsError {
439 status,
440 code: code.into(),
441 message: message.into(),
442 extra_fields,
443 headers: Vec::new(),
444 }
445 }
446
447 pub fn aws_error_with_headers(
448 status: StatusCode,
449 code: impl Into<String>,
450 message: impl Into<String>,
451 headers: Vec<(String, String)>,
452 ) -> Self {
453 Self::AwsError {
454 status,
455 code: code.into(),
456 message: message.into(),
457 extra_fields: Vec::new(),
458 headers,
459 }
460 }
461
462 pub fn extra_fields(&self) -> &[(String, String)] {
463 match self {
464 Self::AwsError { extra_fields, .. } => extra_fields,
465 _ => &[],
466 }
467 }
468
469 pub fn status(&self) -> StatusCode {
470 match self {
471 Self::ServiceNotFound { .. } => StatusCode::BAD_REQUEST,
472 Self::ActionNotImplemented { .. } => StatusCode::NOT_IMPLEMENTED,
473 Self::AwsError { status, .. } => *status,
474 }
475 }
476
477 pub fn code(&self) -> &str {
478 match self {
479 Self::ServiceNotFound { .. } => "UnknownService",
480 Self::ActionNotImplemented { .. } => "InvalidAction",
481 Self::AwsError { code, .. } => code,
482 }
483 }
484
485 pub fn message(&self) -> String {
486 match self {
487 Self::ServiceNotFound { service } => format!("service not found: {service}"),
488 Self::ActionNotImplemented { service, action } => {
489 format!("action {action} not implemented for service {service}")
490 }
491 Self::AwsError { message, .. } => message.clone(),
492 }
493 }
494
495 pub fn response_headers(&self) -> &[(String, String)] {
496 match self {
497 Self::AwsError { headers, .. } => headers,
498 _ => &[],
499 }
500 }
501}
502
/// A pluggable AWS service implementation.
///
/// Only `service_name`, `handle` and `supported_actions` are required. The
/// IAM/tagging hooks have conservative defaults (disabled / empty).
#[async_trait]
pub trait AwsService: Send + Sync {
    /// Identifier of this service (e.g. `"default"` in the unit tests).
    fn service_name(&self) -> &str;

    /// Processes one request and produces a response or an AWS-style error.
    async fn handle(&self, request: AwsRequest) -> Result<AwsResponse, AwsServiceError>;

    /// Action names this service reports as supported.
    fn supported_actions(&self) -> &[&str];

    /// Whether IAM enforcement applies to this service; defaults to `false`.
    /// NOTE(review): presumably the dispatcher skips policy evaluation when
    /// this is `false` — confirm at the call site.
    fn iam_enforceable(&self) -> bool {
        false
    }

    /// The IAM action (service, action name, resource ARN) that `_request`
    /// corresponds to, or `None` (the default) when there is no mapping.
    fn iam_action_for(&self, _request: &AwsRequest) -> Option<crate::auth::IamAction> {
        None
    }

    /// Request-derived IAM condition keys for `_action`, mapping each key
    /// (e.g. `"s3:prefix"`) to its values. Defaults to an empty map.
    fn iam_condition_keys_for(
        &self,
        _request: &AwsRequest,
        _action: &crate::auth::IamAction,
    ) -> BTreeMap<String, Vec<String>> {
        BTreeMap::new()
    }

    /// Tags currently attached to the resource at `_resource_arn`, or `None`
    /// (the default). NOTE(review): presumably `None` means "unknown/not
    /// supported" rather than "no tags" — confirm with the policy evaluator.
    fn resource_tags_for(
        &self,
        _resource_arn: &str,
    ) -> Option<std::collections::HashMap<String, String>> {
        None
    }

    /// Tags that `_request` asks to apply for `_action`, or `None` (the default).
    fn request_tags_from(
        &self,
        _request: &AwsRequest,
        _action: &str,
    ) -> Option<std::collections::HashMap<String, String>> {
        None
    }
}
611
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::IamAction;
    use async_trait::async_trait;

    /// Service relying on every default trait method.
    struct DefaultService;

    #[async_trait]
    impl AwsService for DefaultService {
        fn service_name(&self) -> &str {
            "default"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
    }

    /// Service overriding `iam_condition_keys_for`.
    struct PopulatedService;

    #[async_trait]
    impl AwsService for PopulatedService {
        fn service_name(&self) -> &str {
            "populated"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
        fn iam_condition_keys_for(
            &self,
            _request: &AwsRequest,
            _action: &IamAction,
        ) -> BTreeMap<String, Vec<String>> {
            let mut m = BTreeMap::new();
            m.insert("s3:prefix".to_string(), vec!["logs/".to_string()]);
            m
        }
    }

    fn sample_request() -> AwsRequest {
        AwsRequest {
            service: "default".into(),
            action: "Noop".into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "req-1".into(),
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            body: Bytes::new(),
            body_stream: parking_lot::Mutex::new(None),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: Method::GET,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    fn sample_action() -> IamAction {
        IamAction {
            service: "s3",
            action: "ListBucket",
            resource: "arn:aws:s3:::my-bucket".to_string(),
        }
    }

    #[test]
    fn iam_condition_keys_for_default_is_empty() {
        let svc = DefaultService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert!(keys.is_empty());
    }

    #[test]
    fn iam_condition_keys_for_override_returns_map() {
        let svc = PopulatedService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert_eq!(keys.get("s3:prefix"), Some(&vec!["logs/".to_string()]));
    }

    #[test]
    fn response_body_len_and_is_empty_for_bytes() {
        let body: ResponseBody = Bytes::from_static(b"hello").into();
        assert_eq!(body.len(), 5);
        assert!(!body.is_empty());
        let empty: ResponseBody = ResponseBody::default();
        assert!(empty.is_empty());
    }

    #[test]
    fn response_body_from_vec_and_string_and_str() {
        let from_vec: ResponseBody = vec![1u8, 2, 3].into();
        assert_eq!(from_vec.expect_bytes(), &[1, 2, 3][..]);
        let from_string: ResponseBody = String::from("hi").into();
        assert_eq!(from_string.expect_bytes(), b"hi");
        let from_str: ResponseBody = "hey".into();
        assert_eq!(from_str.expect_bytes(), b"hey");
        let from_static: ResponseBody = (b"123" as &'static [u8]).into();
        assert_eq!(from_static.expect_bytes(), b"123");
    }

    #[test]
    fn response_body_partial_eq_bytes() {
        let body: ResponseBody = Bytes::from_static(b"x").into();
        assert!(body == Bytes::from_static(b"x"));
        assert!(!(body == Bytes::from_static(b"y")));
    }

    #[test]
    fn aws_request_json_body_empty_returns_null() {
        let req = sample_request();
        assert_eq!(req.json_body(), serde_json::Value::Null);
    }

    #[test]
    fn aws_request_json_body_parses_valid() {
        let mut req = sample_request();
        req.body = Bytes::from_static(br#"{"a":1}"#);
        assert_eq!(req.json_body(), serde_json::json!({"a": 1}));
    }

    #[test]
    fn aws_response_xml_constructor() {
        let resp = AwsResponse::xml(StatusCode::OK, Bytes::from_static(b"<ok/>"));
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.content_type, "text/xml");
    }

    #[test]
    fn aws_response_json_constructor() {
        let resp = AwsResponse::json(StatusCode::CREATED, "{}");
        assert_eq!(resp.status, StatusCode::CREATED);
        assert_eq!(resp.content_type, "application/x-amz-json-1.1");
    }

    #[test]
    fn aws_response_ok_json_helper() {
        let resp = AwsResponse::ok_json(serde_json::json!({"ok": true}));
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.content_type, "application/x-amz-json-1.1");
        // Serialization must be compact (no whitespace).
        assert_eq!(resp.body.expect_bytes(), br#"{"ok":true}"#);
    }

    #[test]
    fn aws_error_service_not_found_fields() {
        let err = AwsServiceError::ServiceNotFound {
            service: "sqs".to_string(),
        };
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert_eq!(err.code(), "UnknownService");
        assert!(err.message().contains("sqs"));
        assert!(err.extra_fields().is_empty());
        assert!(err.response_headers().is_empty());
    }

    #[test]
    fn aws_error_action_not_implemented_fields() {
        let err = AwsServiceError::action_not_implemented("sns", "FutureAction");
        assert_eq!(err.status(), StatusCode::NOT_IMPLEMENTED);
        assert_eq!(err.code(), "InvalidAction");
        assert!(err.message().contains("FutureAction"));
        assert!(err.message().contains("sns"));
    }

    #[test]
    fn aws_error_message_matches_display_for_builtin_variants() {
        let not_found = AwsServiceError::ServiceNotFound {
            service: "sqs".to_string(),
        };
        assert_eq!(not_found.message(), not_found.to_string());
        let not_impl = AwsServiceError::action_not_implemented("sns", "FutureAction");
        assert_eq!(not_impl.message(), not_impl.to_string());
        // `AwsError` is the exception: its message omits the `code:` prefix.
        let custom = AwsServiceError::aws_error(StatusCode::FORBIDDEN, "Denied", "no");
        assert_eq!(custom.message(), "no");
        assert_eq!(custom.to_string(), "Denied: no");
    }

    #[test]
    fn aws_error_aws_error_helpers() {
        let e = AwsServiceError::aws_error(StatusCode::FORBIDDEN, "Denied", "no");
        assert_eq!(e.status(), StatusCode::FORBIDDEN);
        assert_eq!(e.code(), "Denied");
        assert_eq!(e.message(), "no");

        let fields = vec![("Bucket".to_string(), "b".to_string())];
        let ef = AwsServiceError::aws_error_with_fields(
            StatusCode::NOT_FOUND,
            "Missing",
            "gone",
            fields.clone(),
        );
        assert_eq!(ef.extra_fields(), fields.as_slice());

        let hdrs = vec![("X-Retry".to_string(), "1".to_string())];
        let eh = AwsServiceError::aws_error_with_headers(
            StatusCode::TOO_MANY_REQUESTS,
            "Throttled",
            "slow",
            hdrs.clone(),
        );
        assert_eq!(eh.response_headers(), hdrs.as_slice());
    }

    #[test]
    fn stream_error_to_aws_maps_limit_errors_to_413() {
        let e = stream_error_to_aws("length limit exceeded");
        assert_eq!(e.status(), StatusCode::PAYLOAD_TOO_LARGE);
        assert_eq!(e.code(), "RequestEntityTooLarge");
        // Keyword matching is case-insensitive.
        let e = stream_error_to_aws("Body LIMIT reached");
        assert_eq!(e.status(), StatusCode::PAYLOAD_TOO_LARGE);
    }

    #[test]
    fn stream_error_to_aws_maps_other_errors_to_400() {
        let e = stream_error_to_aws("connection reset by peer");
        assert_eq!(e.status(), StatusCode::BAD_REQUEST);
        assert_eq!(e.code(), "MalformedRequestBody");
        assert_eq!(e.message(), "Failed to read streaming request body");
    }

    #[test]
    fn hex_lower_encodes_each_byte_as_two_lowercase_digits() {
        assert_eq!(hex_lower(&[]), "");
        assert_eq!(hex_lower(&[0x00, 0x09, 0x0a, 0xf0, 0xff]), "00090af0ff");
    }

    #[test]
    #[should_panic(expected = "expect_bytes called on ResponseBody::File")]
    fn response_body_expect_bytes_panics_on_file() {
        // An anonymous temp file is removed once closed. The test therefore
        // leaves nothing behind in the shared temp directory and cannot collide
        // with concurrent runs the way a fixed file name could.
        let f = tempfile::tempfile().unwrap();
        let async_f = tokio::fs::File::from_std(f);
        let body = ResponseBody::File {
            file: async_f,
            size: 0,
        };
        let _ = body.expect_bytes();
    }
}