//! Request handlers for the `winterbaume_simpledbv2` crate (`handlers.rs`):
//! the SimpleDB v2 mock service and its restJson1 operation routing.

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::Arc;
5
6use http::header::HeaderName;
7use serde_json::json;
8use winterbaume_core::{
9    BackendState, MockRequest, MockResponse, MockService, StateChangeNotifier, StatefulService,
10    default_account_id,
11};
12
13use crate::state::{SdbError, SdbState};
14use crate::views::SdbStateView;
15use crate::wire;
16
17const X_AMZN_ERRORTYPE: HeaderName = HeaderName::from_static("x-amzn-errortype");
18
19pub struct SimpleDbV2Service {
20    pub(crate) state: Arc<BackendState<SdbState>>,
21    pub(crate) notifier: StateChangeNotifier<SdbStateView>,
22}
23
24impl SimpleDbV2Service {
25    pub fn new() -> Self {
26        Self {
27            state: Arc::new(BackendState::new()),
28            notifier: StateChangeNotifier::new(),
29        }
30    }
31
32    /// Pre-register a domain as existing in the mock state.
33    /// This is useful for tests that need domains to exist before starting exports.
34    pub async fn with_domain(self, region: &str, domain_name: &str) -> Self {
35        let state = self.state.get(default_account_id(), region);
36        state.write().await.add_domain(domain_name);
37        self
38    }
39}
40
41impl Default for SimpleDbV2Service {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl MockService for SimpleDbV2Service {
48    fn service_name(&self) -> &str {
49        "sdb"
50    }
51
52    fn url_patterns(&self) -> Vec<&str> {
53        vec![
54            r"https?://sdb\.(.+)\.amazonaws\.com",
55            r"https?://sdb\.amazonaws\.com",
56        ]
57    }
58
59    fn handle(
60        &self,
61        request: MockRequest,
62    ) -> Pin<Box<dyn Future<Output = MockResponse> + Send + '_>> {
63        Box::pin(async move { self.dispatch(request).await })
64    }
65}
66
67impl SimpleDbV2Service {
68    async fn dispatch(&self, request: MockRequest) -> MockResponse {
69        let region = winterbaume_core::auth::extract_region_from_uri(&request.uri);
70        let account_id = default_account_id();
71        let state = self.state.get(account_id, &region);
72
73        // Parse URL path for restJson1 routing
74        let path = extract_path(&request.uri);
75        let raw_query = extract_query_string(&request.uri);
76        let query_map: HashMap<String, String> = winterbaume_core::parse_query_string(&raw_query);
77
78        let response = match path.as_str() {
79            "/v2/StartDomainExport" => {
80                self.handle_start_domain_export(
81                    &state,
82                    &request,
83                    &[],
84                    &query_map,
85                    account_id,
86                    &region,
87                )
88                .await
89            }
90            "/v2/GetExport" => {
91                self.handle_get_export(&state, &request, &[], &query_map)
92                    .await
93            }
94            "/v2/ListExports" => {
95                self.handle_list_exports(&state, &request, &[], &query_map)
96                    .await
97            }
98            _ => rest_json_error(404, "UnknownOperationException", "Not found"),
99        };
100        if response.status / 100 == 2 {
101            self.notify_state_changed(account_id, &region).await;
102        }
103        response
104    }
105
106    #[allow(clippy::too_many_arguments)]
107    async fn handle_start_domain_export(
108        &self,
109        state: &Arc<tokio::sync::RwLock<SdbState>>,
110        request: &MockRequest,
111        labels: &[(&str, &str)],
112        query: &HashMap<String, String>,
113        account_id: &str,
114        region: &str,
115    ) -> MockResponse {
116        let input = match wire::deserialize_start_domain_export_request(request, labels, query) {
117            Ok(v) => v,
118            Err(_) => return rest_json_error(400, "SerializationException", "Invalid JSON body"),
119        };
120        if input.domain_name.is_empty() {
121            return rest_json_error(
122                400,
123                "InvalidParameterValueException",
124                "Missing 'domainName'",
125            );
126        }
127        if input.s3_bucket.is_empty() {
128            return rest_json_error(400, "InvalidParameterValueException", "Missing 's3Bucket'");
129        }
130        let client_token = match input.client_token.as_deref() {
131            Some(t) if !t.is_empty() => t,
132            _ => {
133                return rest_json_error(
134                    400,
135                    "InvalidParameterValueException",
136                    "clientToken is required",
137                );
138            }
139        };
140
141        let mut state = state.write().await;
142        match state.start_domain_export(
143            &input.domain_name,
144            &input.s3_bucket,
145            input.s3_key_prefix.as_deref(),
146            input.s3_sse_algorithm.as_deref(),
147            input.s3_sse_kms_key_id.as_deref(),
148            input.s3_bucket_owner.as_deref(),
149            Some(client_token),
150            account_id,
151            region,
152        ) {
153            Ok(export) => {
154                wire::serialize_start_domain_export_response(&wire::StartDomainExportResponse {
155                    client_token: Some(export.client_token.clone()),
156                    export_arn: Some(export.export_arn.clone()),
157                    requested_at: Some(export.requested_at.timestamp() as f64),
158                })
159            }
160            Err(e) => sdb_error_response(&e),
161        }
162    }
163
164    async fn handle_get_export(
165        &self,
166        state: &Arc<tokio::sync::RwLock<SdbState>>,
167        request: &MockRequest,
168        labels: &[(&str, &str)],
169        query: &HashMap<String, String>,
170    ) -> MockResponse {
171        let input = match wire::deserialize_get_export_request(request, labels, query) {
172            Ok(v) => v,
173            Err(_) => return rest_json_error(400, "SerializationException", "Invalid JSON body"),
174        };
175        if input.export_arn.is_empty() {
176            return rest_json_error(400, "InvalidParameterValueException", "Missing 'exportArn'");
177        }
178
179        let state = state.read().await;
180        match state.get_export(&input.export_arn) {
181            Ok(export) => wire::serialize_get_export_response(&wire::GetExportResponse {
182                export_arn: Some(export.export_arn.clone()),
183                client_token: Some(export.client_token.clone()),
184                export_status: Some(export.export_status.clone()),
185                domain_name: Some(export.domain_name.clone()),
186                requested_at: Some(export.requested_at.timestamp() as f64),
187                s3_bucket: Some(export.s3_bucket.clone()),
188                s3_key_prefix: export.s3_key_prefix.clone(),
189                s3_sse_algorithm: export.s3_sse_algorithm.clone(),
190                s3_sse_kms_key_id: export.s3_sse_kms_key_id.clone(),
191                s3_bucket_owner: export.s3_bucket_owner.clone(),
192                failure_code: export.failure_code.clone(),
193                failure_message: export.failure_message.clone(),
194                export_manifest: export.export_manifest.clone(),
195                items_count: export.items_count,
196                export_data_cutoff_time: export
197                    .export_data_cutoff_time
198                    .map(|dt| dt.timestamp() as f64),
199            }),
200            Err(e) => sdb_error_response(&e),
201        }
202    }
203
204    async fn handle_list_exports(
205        &self,
206        state: &Arc<tokio::sync::RwLock<SdbState>>,
207        request: &MockRequest,
208        labels: &[(&str, &str)],
209        query: &HashMap<String, String>,
210    ) -> MockResponse {
211        let input = match wire::deserialize_list_exports_request(request, labels, query) {
212            Ok(v) => v,
213            Err(_) => return rest_json_error(400, "SerializationException", "Invalid JSON body"),
214        };
215
216        let state = state.read().await;
217        match state.list_exports(
218            input.domain_name.as_deref(),
219            input.max_results,
220            input.next_token.as_deref(),
221        ) {
222            Ok((summaries, next_token)) => {
223                let entries: Vec<wire::ExportSummary> = summaries
224                    .iter()
225                    .map(|s| wire::ExportSummary {
226                        export_arn: Some(s.export_arn.clone()),
227                        export_status: Some(s.export_status.clone()),
228                        requested_at: Some(s.requested_at.timestamp() as f64),
229                        domain_name: Some(s.domain_name.clone()),
230                    })
231                    .collect();
232
233                wire::serialize_list_exports_response(&wire::ListExportsResponse {
234                    export_summaries: Some(entries),
235                    next_token,
236                })
237            }
238            Err(e) => sdb_error_response(&e),
239        }
240    }
241}
242
243fn sdb_error_response(err: &SdbError) -> MockResponse {
244    let (status, error_type) = match err {
245        SdbError::NoSuchDomain { .. } => (400, "NoSuchDomainException"),
246        SdbError::NoSuchExport { .. } => (400, "NoSuchExportException"),
247        SdbError::Conflict => (400, "ConflictException"),
248    };
249    let body = json!({
250        "Type": "User",
251        "Message": err.to_string(),
252    });
253    let mut resp = MockResponse::rest_json(status, body.to_string());
254    resp.headers
255        .insert(X_AMZN_ERRORTYPE, error_type.parse().unwrap());
256    resp
257}
258
259fn rest_json_error(status: u16, code: &str, message: &str) -> MockResponse {
260    let body = json!({
261        "Type": "User",
262        "Message": message,
263    });
264    let mut resp = MockResponse::rest_json(status, body.to_string());
265    resp.headers.insert(X_AMZN_ERRORTYPE, code.parse().unwrap());
266    resp
267}
268
269fn extract_path(uri: &str) -> String {
270    // Parse the URI and extract the path component
271    if let Ok(parsed) = uri.parse::<http::Uri>() {
272        parsed.path().to_string()
273    } else {
274        // Fallback: find path after host
275        if let Some(pos) = uri.find("amazonaws.com") {
276            let rest = &uri[pos + "amazonaws.com".len()..];
277            rest.split('?').next().unwrap_or("/").to_string()
278        } else {
279            "/".to_string()
280        }
281    }
282}
283
284fn extract_query_string(uri: &str) -> String {
285    if let Ok(parsed) = uri.parse::<http::Uri>() {
286        parsed.query().unwrap_or("").to_string()
287    } else if let Some(idx) = uri.find('?') {
288        uri[idx + 1..].to_string()
289    } else {
290        String::new()
291    }
292}