1use async_trait::async_trait;
2use bytes::Bytes;
3use http::{HeaderMap, Method, StatusCode};
4use std::collections::{BTreeMap, HashMap};
5
6use crate::auth::Principal;
7
/// A routed request, passed by value to [`AwsService::handle`].
///
/// The fields are filled in by code outside this file; notes that go beyond
/// what a field's type shows are marked as assumptions to confirm.
#[derive(Debug)]
pub struct AwsRequest {
    /// Target service name (e.g. `"default"` in this file's tests).
    pub service: String,
    /// Action/operation name (e.g. `"Noop"` in this file's tests).
    pub action: String,
    /// Region string (e.g. `"us-east-1"` in this file's tests).
    pub region: String,
    /// Account identifier (e.g. `"123456789012"` in this file's tests).
    pub account_id: String,
    /// Identifier for this request; presumably echoed in responses/logs — TODO confirm.
    pub request_id: String,
    /// HTTP request headers.
    pub headers: HeaderMap,
    /// Query parameters keyed by name. A `HashMap` holds one value per key;
    /// NOTE(review): confirm how the parser treats repeated keys.
    pub query_params: HashMap<String, String>,
    /// Raw request body; see [`AwsRequest::json_body`] for a JSON view of it.
    pub body: Bytes,
    /// Request path split into segments; splitting/decoding rules are decided
    /// by the caller — TODO confirm.
    pub path_segments: Vec<String>,
    /// Request path as received (`"/"` in this file's tests).
    pub raw_path: String,
    /// Raw query string (empty in this file's tests); presumably not
    /// percent-decoded — TODO confirm.
    pub raw_query: String,
    /// HTTP method.
    pub method: Method,
    /// Presumably `true` when the request uses the AWS Query (form-encoded)
    /// protocol — TODO confirm with the router.
    pub is_query_protocol: bool,
    /// Access key ID of the caller, if one was extracted.
    pub access_key_id: Option<String>,
    /// Resolved caller identity, if any; presumably set by the auth layer — TODO confirm.
    pub principal: Option<Principal>,
}
37
38impl AwsRequest {
39 pub fn json_body(&self) -> serde_json::Value {
41 serde_json::from_slice(&self.body).unwrap_or(serde_json::Value::Null)
42 }
43}
44
/// Payload of an [`AwsResponse`]: either an in-memory buffer or an open file.
#[derive(Debug)]
pub enum ResponseBody {
    /// Fully buffered body.
    Bytes(Bytes),
    /// Body backed by an open file. `size` is the value reported by
    /// [`ResponseBody::len`]; it is taken as given and not re-read from `file`.
    File { file: tokio::fs::File, size: u64 },
}
58
59impl ResponseBody {
60 pub fn len(&self) -> u64 {
61 match self {
62 ResponseBody::Bytes(b) => b.len() as u64,
63 ResponseBody::File { size, .. } => *size,
64 }
65 }
66
67 pub fn is_empty(&self) -> bool {
68 self.len() == 0
69 }
70
71 pub fn expect_bytes(&self) -> &[u8] {
75 match self {
76 ResponseBody::Bytes(b) => b,
77 ResponseBody::File { .. } => {
78 panic!("expect_bytes called on ResponseBody::File")
79 }
80 }
81 }
82}
83
84impl Default for ResponseBody {
85 fn default() -> Self {
86 ResponseBody::Bytes(Bytes::new())
87 }
88}
89
90impl From<Bytes> for ResponseBody {
91 fn from(b: Bytes) -> Self {
92 ResponseBody::Bytes(b)
93 }
94}
95
96impl From<Vec<u8>> for ResponseBody {
97 fn from(v: Vec<u8>) -> Self {
98 ResponseBody::Bytes(Bytes::from(v))
99 }
100}
101
102impl From<&'static [u8]> for ResponseBody {
103 fn from(s: &'static [u8]) -> Self {
104 ResponseBody::Bytes(Bytes::from_static(s))
105 }
106}
107
108impl From<String> for ResponseBody {
109 fn from(s: String) -> Self {
110 ResponseBody::Bytes(Bytes::from(s))
111 }
112}
113
114impl From<&'static str> for ResponseBody {
115 fn from(s: &'static str) -> Self {
116 ResponseBody::Bytes(Bytes::from_static(s.as_bytes()))
117 }
118}
119
120impl PartialEq<Bytes> for ResponseBody {
121 fn eq(&self, other: &Bytes) -> bool {
122 match self {
123 ResponseBody::Bytes(b) => b == other,
124 ResponseBody::File { .. } => false,
125 }
126 }
127}
128
129pub struct AwsResponse {
131 pub status: StatusCode,
132 pub content_type: String,
133 pub body: ResponseBody,
134 pub headers: HeaderMap,
135}
136
137impl AwsResponse {
138 pub fn xml(status: StatusCode, body: impl Into<Bytes>) -> Self {
139 Self {
140 status,
141 content_type: "text/xml".to_string(),
142 body: ResponseBody::Bytes(body.into()),
143 headers: HeaderMap::new(),
144 }
145 }
146
147 pub fn json(status: StatusCode, body: impl Into<Bytes>) -> Self {
148 Self {
149 status,
150 content_type: "application/x-amz-json-1.1".to_string(),
151 body: ResponseBody::Bytes(body.into()),
152 headers: HeaderMap::new(),
153 }
154 }
155
156 pub fn ok_json(value: serde_json::Value) -> Self {
158 Self::json(StatusCode::OK, serde_json::to_vec(&value).unwrap())
159 }
160}
161
/// Error returned by [`AwsService::handle`].
///
/// The `#[error(...)]` attributes define each variant's `Display` text.
#[derive(Debug, thiserror::Error)]
pub enum AwsServiceError {
    /// No service matched the requested name (status 400, code `UnknownService`).
    #[error("service not found: {service}")]
    ServiceNotFound { service: String },

    /// The service does not implement the requested action
    /// (status 501, code `InvalidAction`).
    #[error("action {action} not implemented for service {service}")]
    ActionNotImplemented { service: String, action: String },

    /// A service-defined error with an explicit status, code and message.
    #[error("{code}: {message}")]
    AwsError {
        status: StatusCode,
        code: String,
        message: String,
        /// Extra name/value pairs, exposed via `extra_fields()`; how they are
        /// rendered is decided outside this file.
        extra_fields: Vec<(String, String)>,
        /// Name/value pairs exposed via `response_headers()`; presumably added
        /// to the HTTP error response — TODO confirm.
        headers: Vec<(String, String)>,
    },
}
182
183impl AwsServiceError {
184 pub fn action_not_implemented(service: &str, action: &str) -> Self {
185 Self::ActionNotImplemented {
186 service: service.to_string(),
187 action: action.to_string(),
188 }
189 }
190
191 pub fn aws_error(
192 status: StatusCode,
193 code: impl Into<String>,
194 message: impl Into<String>,
195 ) -> Self {
196 Self::AwsError {
197 status,
198 code: code.into(),
199 message: message.into(),
200 extra_fields: Vec::new(),
201 headers: Vec::new(),
202 }
203 }
204
205 pub fn aws_error_with_fields(
206 status: StatusCode,
207 code: impl Into<String>,
208 message: impl Into<String>,
209 extra_fields: Vec<(String, String)>,
210 ) -> Self {
211 Self::AwsError {
212 status,
213 code: code.into(),
214 message: message.into(),
215 extra_fields,
216 headers: Vec::new(),
217 }
218 }
219
220 pub fn aws_error_with_headers(
221 status: StatusCode,
222 code: impl Into<String>,
223 message: impl Into<String>,
224 headers: Vec<(String, String)>,
225 ) -> Self {
226 Self::AwsError {
227 status,
228 code: code.into(),
229 message: message.into(),
230 extra_fields: Vec::new(),
231 headers,
232 }
233 }
234
235 pub fn extra_fields(&self) -> &[(String, String)] {
236 match self {
237 Self::AwsError { extra_fields, .. } => extra_fields,
238 _ => &[],
239 }
240 }
241
242 pub fn status(&self) -> StatusCode {
243 match self {
244 Self::ServiceNotFound { .. } => StatusCode::BAD_REQUEST,
245 Self::ActionNotImplemented { .. } => StatusCode::NOT_IMPLEMENTED,
246 Self::AwsError { status, .. } => *status,
247 }
248 }
249
250 pub fn code(&self) -> &str {
251 match self {
252 Self::ServiceNotFound { .. } => "UnknownService",
253 Self::ActionNotImplemented { .. } => "InvalidAction",
254 Self::AwsError { code, .. } => code,
255 }
256 }
257
258 pub fn message(&self) -> String {
259 match self {
260 Self::ServiceNotFound { service } => format!("service not found: {service}"),
261 Self::ActionNotImplemented { service, action } => {
262 format!("action {action} not implemented for service {service}")
263 }
264 Self::AwsError { message, .. } => message.clone(),
265 }
266 }
267
268 pub fn response_headers(&self) -> &[(String, String)] {
269 match self {
270 Self::AwsError { headers, .. } => headers,
271 _ => &[],
272 }
273 }
274}
275
276#[async_trait]
278pub trait AwsService: Send + Sync {
279 fn service_name(&self) -> &str;
281
282 async fn handle(&self, request: AwsRequest) -> Result<AwsResponse, AwsServiceError>;
284
285 fn supported_actions(&self) -> &[&str];
287
288 fn iam_enforceable(&self) -> bool {
303 false
304 }
305
306 fn iam_action_for(&self, _request: &AwsRequest) -> Option<crate::auth::IamAction> {
319 None
320 }
321
322 fn iam_condition_keys_for(
342 &self,
343 _request: &AwsRequest,
344 _action: &crate::auth::IamAction,
345 ) -> BTreeMap<String, Vec<String>> {
346 BTreeMap::new()
347 }
348
349 fn resource_tags_for(
361 &self,
362 _resource_arn: &str,
363 ) -> Option<std::collections::HashMap<String, String>> {
364 None
365 }
366
367 fn request_tags_from(
377 &self,
378 _request: &AwsRequest,
379 _action: &str,
380 ) -> Option<std::collections::HashMap<String, String>> {
381 None
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::IamAction;
    use async_trait::async_trait;

    /// Implements only the required trait methods, so every hook uses its default.
    struct DefaultService;

    #[async_trait]
    impl AwsService for DefaultService {
        fn service_name(&self) -> &str {
            "default"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
    }

    /// Overrides `iam_condition_keys_for` to return a fixed, non-empty map.
    struct PopulatedService;

    #[async_trait]
    impl AwsService for PopulatedService {
        fn service_name(&self) -> &str {
            "populated"
        }
        async fn handle(&self, _request: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
            unreachable!()
        }
        fn supported_actions(&self) -> &[&str] {
            &[]
        }
        fn iam_condition_keys_for(
            &self,
            _request: &AwsRequest,
            _action: &IamAction,
        ) -> BTreeMap<String, Vec<String>> {
            let mut m = BTreeMap::new();
            m.insert("s3:prefix".to_string(), vec!["logs/".to_string()]);
            m
        }
    }

    fn sample_request() -> AwsRequest {
        AwsRequest {
            service: "default".into(),
            action: "Noop".into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "req-1".into(),
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            body: Bytes::new(),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: Method::GET,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    fn sample_action() -> IamAction {
        IamAction {
            service: "s3",
            action: "ListBucket",
            resource: "arn:aws:s3:::my-bucket".to_string(),
        }
    }

    #[test]
    fn iam_condition_keys_for_default_is_empty() {
        let svc = DefaultService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert!(keys.is_empty());
    }

    #[test]
    fn iam_condition_keys_for_override_returns_map() {
        let svc = PopulatedService;
        let keys = svc.iam_condition_keys_for(&sample_request(), &sample_action());
        assert_eq!(keys.get("s3:prefix"), Some(&vec!["logs/".to_string()]));
    }

    #[test]
    fn response_body_len_and_is_empty_for_bytes() {
        let body: ResponseBody = Bytes::from_static(b"hello").into();
        assert_eq!(body.len(), 5);
        assert!(!body.is_empty());
        let empty: ResponseBody = ResponseBody::default();
        assert!(empty.is_empty());
    }

    #[test]
    fn response_body_from_vec_and_string_and_str() {
        let from_vec: ResponseBody = vec![1u8, 2, 3].into();
        assert_eq!(from_vec.expect_bytes(), &[1, 2, 3][..]);
        let from_string: ResponseBody = String::from("hi").into();
        assert_eq!(from_string.expect_bytes(), b"hi");
        let from_str: ResponseBody = "hey".into();
        assert_eq!(from_str.expect_bytes(), b"hey");
        let from_static: ResponseBody = (b"123" as &'static [u8]).into();
        assert_eq!(from_static.expect_bytes(), b"123");
    }

    #[test]
    fn response_body_partial_eq_bytes() {
        let body: ResponseBody = Bytes::from_static(b"x").into();
        assert!(body == Bytes::from_static(b"x"));
        assert!(!(body == Bytes::from_static(b"y")));
    }

    #[test]
    fn aws_request_json_body_empty_returns_null() {
        let req = sample_request();
        assert_eq!(req.json_body(), serde_json::Value::Null);
    }

    #[test]
    fn aws_request_json_body_parses_valid() {
        let mut req = sample_request();
        req.body = Bytes::from_static(br#"{"a":1}"#);
        assert_eq!(req.json_body(), serde_json::json!({"a": 1}));
    }

    #[test]
    fn aws_response_xml_constructor() {
        let resp = AwsResponse::xml(StatusCode::OK, Bytes::from_static(b"<ok/>"));
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.content_type, "text/xml");
    }

    #[test]
    fn aws_response_json_constructor() {
        let resp = AwsResponse::json(StatusCode::CREATED, "{}");
        assert_eq!(resp.status, StatusCode::CREATED);
        assert_eq!(resp.content_type, "application/x-amz-json-1.1");
    }

    #[test]
    fn aws_response_ok_json_helper() {
        let resp = AwsResponse::ok_json(serde_json::json!({"ok": true}));
        assert_eq!(resp.status, StatusCode::OK);
        assert!(resp.body.expect_bytes().starts_with(b"{"));
    }

    #[test]
    fn aws_error_service_not_found_fields() {
        let err = AwsServiceError::ServiceNotFound {
            service: "sqs".to_string(),
        };
        assert_eq!(err.status(), StatusCode::BAD_REQUEST);
        assert_eq!(err.code(), "UnknownService");
        assert!(err.message().contains("sqs"));
        assert!(err.extra_fields().is_empty());
        assert!(err.response_headers().is_empty());
    }

    #[test]
    fn aws_error_action_not_implemented_fields() {
        let err = AwsServiceError::action_not_implemented("sns", "FutureAction");
        assert_eq!(err.status(), StatusCode::NOT_IMPLEMENTED);
        assert_eq!(err.code(), "InvalidAction");
        assert!(err.message().contains("FutureAction"));
        assert!(err.message().contains("sns"));
    }

    #[test]
    fn aws_error_aws_error_helpers() {
        let e = AwsServiceError::aws_error(StatusCode::FORBIDDEN, "Denied", "no");
        assert_eq!(e.status(), StatusCode::FORBIDDEN);
        assert_eq!(e.code(), "Denied");
        assert_eq!(e.message(), "no");

        let fields = vec![("Bucket".to_string(), "b".to_string())];
        let ef = AwsServiceError::aws_error_with_fields(
            StatusCode::NOT_FOUND,
            "Missing",
            "gone",
            fields.clone(),
        );
        assert_eq!(ef.extra_fields(), fields.as_slice());

        let hdrs = vec![("X-Retry".to_string(), "1".to_string())];
        let eh = AwsServiceError::aws_error_with_headers(
            StatusCode::TOO_MANY_REQUESTS,
            "Throttled",
            "slow",
            hdrs.clone(),
        );
        assert_eq!(eh.response_headers(), hdrs.as_slice());
    }

    #[test]
    #[should_panic(expected = "expect_bytes called on ResponseBody::File")]
    fn response_body_expect_bytes_panics_on_file() {
        // Open an existing file read-only (the running test binary) instead of
        // creating a fixed-name file in the shared temp dir. The old approach
        // left a stray file behind on every run, and failed with the wrong
        // panic message whenever that path was not writable.
        let exe = std::env::current_exe().expect("test binary path should be resolvable");
        let std_file = std::fs::File::open(exe).expect("test binary should be readable");
        let body = ResponseBody::File {
            file: tokio::fs::File::from_std(std_file),
            size: 0,
        };
        let _ = body.expect_bytes();
    }
}