1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use serde_json::{Value, json};
6use winterbaume_core::{
7 BackendState, MockRequest, MockResponse, MockService, StateChangeNotifier, default_account_id,
8};
9
10use crate::backend::{InMemoryRedshiftQueryBackend, RedshiftQueryBackend};
11use crate::model::{
12 CancelStatementResponse, ColumnMetadata, DescribeStatementResponse, ExecuteStatementOutput,
13 Field, GetStatementResultResponse, ListStatementsResponse, SqlParameter, StatementData,
14};
15use crate::state::{RedshiftDataError, RedshiftDataState};
16use crate::types::{Statement, StatementParameter};
17use crate::views::RedshiftDataStateView;
18use crate::wire;
19
/// Mock implementation of the Amazon Redshift Data API (`redshift-data`).
///
/// SQL execution is delegated to a pluggable [`RedshiftQueryBackend`], while
/// statement bookkeeping is kept in [`RedshiftDataState`] instances that
/// `dispatch` looks up per (account id, region).
pub struct RedshiftDataService {
    /// Engine that runs SQL text; its results are stored with each new statement.
    pub(crate) query_backend: Arc<dyn RedshiftQueryBackend>,
    /// Statement state, resolved per (account id, region) on every request.
    pub(crate) state: Arc<BackendState<RedshiftDataState>>,
    /// Change notifier for [`RedshiftDataStateView`] observers.
    /// NOTE(review): not referenced by any handler in this file — confirm where it is used.
    pub(crate) notifier: StateChangeNotifier<RedshiftDataStateView>,
}
25
26impl RedshiftDataService {
27 pub fn new() -> Self {
28 Self {
29 query_backend: Arc::new(InMemoryRedshiftQueryBackend),
30 state: Arc::new(BackendState::new()),
31 notifier: StateChangeNotifier::new(),
32 }
33 }
34
35 pub fn with_query_backend(query_backend: Arc<dyn RedshiftQueryBackend>) -> Self {
37 Self {
38 query_backend,
39 state: Arc::new(BackendState::new()),
40 notifier: StateChangeNotifier::new(),
41 }
42 }
43}
44
45impl Default for RedshiftDataService {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl MockService for RedshiftDataService {
52 fn service_name(&self) -> &str {
53 "redshift-data"
54 }
55
56 fn url_patterns(&self) -> Vec<&str> {
57 vec![
58 r"https?://redshift-data\.(.+)\.amazonaws\.com",
59 r"https?://redshift-data\.amazonaws\.com",
60 ]
61 }
62
63 fn handle(
64 &self,
65 request: MockRequest,
66 ) -> Pin<Box<dyn Future<Output = MockResponse> + Send + '_>> {
67 Box::pin(async move { self.dispatch(request).await })
68 }
69}
70
71impl RedshiftDataService {
72 async fn dispatch(&self, request: MockRequest) -> MockResponse {
73 let region = winterbaume_core::auth::extract_region_from_uri(&request.uri);
74 let account_id = default_account_id();
75
76 let action = request
79 .headers
80 .get("x-amz-target")
81 .and_then(|v| v.to_str().ok())
82 .and_then(|v| v.rsplit('.').next())
83 .map(|s| s.to_string());
84
85 let action = match action {
86 Some(a) => a,
87 None => {
88 return json_error_response(400, "MissingAction", "Missing X-Amz-Target header");
89 }
90 };
91
92 if serde_json::from_slice::<Value>(&request.body).is_err() {
95 return json_error_response(400, "SerializationException", "Invalid JSON body");
96 }
97 let body_bytes: &[u8] = &request.body;
98
99 let state = self.state.get(account_id, ®ion);
100
101 match action.as_str() {
102 "ExecuteStatement" => self.handle_execute_statement(&state, body_bytes).await,
103 "BatchExecuteStatement" => {
104 self.handle_batch_execute_statement(&state, body_bytes)
105 .await
106 }
107 "DescribeStatement" => self.handle_describe_statement(&state, body_bytes).await,
108 "DescribeTable" => self.handle_describe_table(&state, body_bytes).await,
109 "CancelStatement" => self.handle_cancel_statement(&state, body_bytes).await,
110 "ListStatements" => self.handle_list_statements(&state).await,
111 "GetStatementResult" => self.handle_get_statement_result(&state, body_bytes).await,
112 "GetStatementResultV2" => {
113 self.handle_get_statement_result_v2(&state, body_bytes)
114 .await
115 }
116 "ListDatabases" => self.handle_list_databases(&state, body_bytes).await,
117 "ListSchemas" => self.handle_list_schemas(&state, body_bytes).await,
118 "ListTables" => self.handle_list_tables(&state, body_bytes).await,
119 _ => json_error_response(
120 400,
121 "InvalidAction",
122 &format!("Could not find operation {action} for RedshiftData"),
123 ),
124 }
125 }
126
127 async fn handle_execute_statement(
128 &self,
129 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
130 body: &[u8],
131 ) -> MockResponse {
132 let input = match wire::deserialize_execute_statement_request(body) {
133 Ok(v) => v,
134 Err(e) => return json_error_response(400, "ValidationException", &e),
135 };
136 if input.sql.is_empty() {
137 return json_error_response(400, "ValidationException", "Sql is required");
138 }
139 let database = match input.database.as_deref() {
140 Some(d) => d,
141 None => {
142 return json_error_response(400, "ValidationException", "Database is required");
143 }
144 };
145
146 let sql = input.sql.as_str();
147 let cluster_identifier = input.cluster_identifier.as_deref();
148 let workgroup_name = input.workgroup_name.as_deref();
149 let db_user = input.db_user.as_deref();
150 let secret_arn = input.secret_arn.as_deref();
151
152 let parameters: Vec<StatementParameter> = input
153 .parameters
154 .unwrap_or_default()
155 .into_iter()
156 .map(|p| StatementParameter {
157 name: p.name,
158 value: p.value,
159 })
160 .collect();
161
162 let result = self.query_backend.execute_statement(sql.to_string()).await;
163
164 let mut state = state.write().await;
165 match state.execute_statement(
166 sql,
167 database,
168 cluster_identifier,
169 workgroup_name,
170 db_user,
171 secret_arn,
172 parameters,
173 result,
174 ) {
175 Ok(id) => {
176 let output = ExecuteStatementOutput {
177 id: Some(id),
178 created_at: Some(chrono::Utc::now().timestamp() as f64),
179 database: Some(database.to_string()),
180 cluster_identifier: cluster_identifier.map(String::from),
181 workgroup_name: workgroup_name.map(String::from),
182 db_user: db_user.map(String::from),
183 secret_arn: secret_arn.map(String::from),
184 ..Default::default()
185 };
186 wire::serialize_execute_statement_response(&output)
187 }
188 Err(e) => redshiftdata_error_response(&e),
189 }
190 }
191
192 async fn handle_describe_statement(
193 &self,
194 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
195 body: &[u8],
196 ) -> MockResponse {
197 let input = match wire::deserialize_describe_statement_request(body) {
198 Ok(v) => v,
199 Err(e) => return json_error_response(400, "ValidationException", &e),
200 };
201 if input.id.is_empty() {
202 return json_error_response(400, "ValidationException", "Id is required");
203 }
204 let id = input.id.as_str();
205
206 let state = state.read().await;
207 match state.describe_statement(id) {
208 Ok(stmt) => {
209 let resp = statement_to_describe_response(stmt);
210 wire::serialize_describe_statement_response(&resp)
211 }
212 Err(e) => redshiftdata_error_response(&e),
213 }
214 }
215
216 async fn handle_cancel_statement(
217 &self,
218 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
219 body: &[u8],
220 ) -> MockResponse {
221 let input = match wire::deserialize_cancel_statement_request(body) {
222 Ok(v) => v,
223 Err(e) => return json_error_response(400, "ValidationException", &e),
224 };
225 if input.id.is_empty() {
226 return json_error_response(400, "ValidationException", "Id is required");
227 }
228 let id = input.id.as_str();
229
230 let mut state = state.write().await;
231 match state.cancel_statement(id) {
232 Ok(status) => {
233 let resp = CancelStatementResponse {
234 status: Some(status),
235 };
236 wire::serialize_cancel_statement_response(&resp)
237 }
238 Err(e) => redshiftdata_error_response(&e),
239 }
240 }
241
242 async fn handle_list_statements(
243 &self,
244 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
245 ) -> MockResponse {
246 let state = state.read().await;
247 let stmts = state.list_statements();
248 let entries: Vec<StatementData> = stmts
249 .iter()
250 .map(|s| statement_to_statement_data(s))
251 .collect();
252
253 let resp = ListStatementsResponse {
254 statements: Some(entries),
255 next_token: None,
256 };
257 wire::serialize_list_statements_response(&resp)
258 }
259
260 async fn handle_get_statement_result(
261 &self,
262 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
263 body: &[u8],
264 ) -> MockResponse {
265 let input = match wire::deserialize_get_statement_result_request(body) {
266 Ok(v) => v,
267 Err(e) => return json_error_response(400, "ValidationException", &e),
268 };
269 if input.id.is_empty() {
270 return json_error_response(400, "ValidationException", "Id is required");
271 }
272 let id = input.id.as_str();
273
274 let state = state.read().await;
275 match state.describe_statement(id) {
276 Ok(stmt) => {
277 let column_metadata: Vec<ColumnMetadata> = stmt
278 .result_columns
279 .iter()
280 .map(|(name, type_str)| ColumnMetadata {
281 name: Some(name.clone()),
282 type_name: Some(type_str.clone()),
283 ..Default::default()
284 })
285 .collect();
286
287 let records: Vec<Vec<Field>> = stmt
288 .result_data
289 .iter()
290 .map(|row| {
291 row.iter()
292 .zip(stmt.result_columns.iter())
293 .map(|(cell, (_, type_str))| string_to_field(cell, type_str))
294 .collect()
295 })
296 .collect();
297
298 let total = records.len() as i64;
299 let resp = GetStatementResultResponse {
300 records: Some(records),
301 column_metadata: Some(column_metadata),
302 total_num_rows: Some(total),
303 next_token: None,
304 };
305 wire::serialize_get_statement_result_response(&resp)
306 }
307 Err(e) => redshiftdata_error_response(&e),
308 }
309 }
310
311 async fn handle_batch_execute_statement(
312 &self,
313 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
314 body: &[u8],
315 ) -> MockResponse {
316 let input = match wire::deserialize_batch_execute_statement_request(body) {
317 Ok(v) => v,
318 Err(e) => return json_error_response(400, "ValidationException", &e),
319 };
320 if input.sqls.is_empty() {
321 return json_error_response(400, "ValidationException", "Sqls is required");
322 }
323 let sqls: Vec<String> = input.sqls;
324
325 let database = match input.database.as_deref() {
326 Some(d) => d,
327 None => {
328 return json_error_response(400, "ValidationException", "Database is required");
329 }
330 };
331
332 let cluster_identifier = input.cluster_identifier.as_deref();
333 let workgroup_name = input.workgroup_name.as_deref();
334 let db_user = input.db_user.as_deref();
335 let secret_arn = input.secret_arn.as_deref();
336 let statement_name = input.statement_name.as_deref();
337
338 let result = self.query_backend.batch_execute(sqls.clone()).await;
339
340 let mut state = state.write().await;
341 match state.batch_execute_statement(
342 sqls.clone(),
343 database,
344 cluster_identifier,
345 workgroup_name,
346 db_user,
347 secret_arn,
348 statement_name,
349 result,
350 ) {
351 Ok(id) => {
352 let output = crate::model::BatchExecuteStatementOutput {
353 id: Some(id),
354 created_at: Some(chrono::Utc::now().timestamp() as f64),
355 database: Some(database.to_string()),
356 cluster_identifier: cluster_identifier.map(String::from),
357 workgroup_name: workgroup_name.map(String::from),
358 db_user: db_user.map(String::from),
359 secret_arn: secret_arn.map(String::from),
360 ..Default::default()
361 };
362 wire::serialize_batch_execute_statement_response(&output)
363 }
364 Err(e) => redshiftdata_error_response(&e),
365 }
366 }
367
368 async fn handle_describe_table(
369 &self,
370 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
371 body: &[u8],
372 ) -> MockResponse {
373 let input = match wire::deserialize_describe_table_request(body) {
374 Ok(v) => v,
375 Err(e) => return json_error_response(400, "ValidationException", &e),
376 };
377 let table_name = input.table.as_deref();
378
379 let state = state.read().await;
380 let columns = state.describe_table(table_name);
381 let column_list: Vec<ColumnMetadata> = columns
382 .iter()
383 .map(|(name, type_str)| ColumnMetadata {
384 name: Some(name.clone()),
385 type_name: Some(type_str.clone()),
386 ..Default::default()
387 })
388 .collect();
389
390 let resp = crate::model::DescribeTableResponse {
391 column_list: Some(column_list),
392 next_token: None,
393 table_name: table_name.map(String::from),
394 };
395 wire::serialize_describe_table_response(&resp)
396 }
397
398 async fn handle_get_statement_result_v2(
399 &self,
400 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
401 body: &[u8],
402 ) -> MockResponse {
403 let input = match wire::deserialize_get_statement_result_v2_request(body) {
404 Ok(v) => v,
405 Err(e) => return json_error_response(400, "ValidationException", &e),
406 };
407 if input.id.is_empty() {
408 return json_error_response(400, "ValidationException", "Id is required");
409 }
410 let id = input.id.as_str();
411
412 let state = state.read().await;
413 match state.describe_statement(id) {
414 Ok(stmt) => {
415 let column_metadata: Vec<ColumnMetadata> = stmt
416 .result_columns
417 .iter()
418 .map(|(name, type_str)| ColumnMetadata {
419 name: Some(name.clone()),
420 type_name: Some(type_str.clone()),
421 ..Default::default()
422 })
423 .collect();
424
425 let records: Vec<crate::model::QueryRecords> = stmt
427 .result_data
428 .iter()
429 .map(|row| {
430 let csv_line: String = row
431 .iter()
432 .map(|cell| cell.as_deref().unwrap_or(""))
433 .collect::<Vec<&str>>()
434 .join(",");
435 crate::model::QueryRecords {
436 c_s_v_records: Some(csv_line),
437 }
438 })
439 .collect();
440
441 let total = records.len() as i64;
442 let resp = crate::model::GetStatementResultV2Response {
443 column_metadata: Some(column_metadata),
444 next_token: None,
445 records: Some(records),
446 result_format: Some("CSV".to_string()),
447 total_num_rows: Some(total),
448 };
449 wire::serialize_get_statement_result_v2_response(&resp)
450 }
451 Err(e) => redshiftdata_error_response(&e),
452 }
453 }
454
455 async fn handle_list_databases(
456 &self,
457 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
458 body: &[u8],
459 ) -> MockResponse {
460 if let Err(e) = wire::deserialize_list_databases_request(body) {
461 return json_error_response(400, "ValidationException", &e);
462 }
463 let state = state.read().await;
464 let databases = state.list_databases();
465 let resp = crate::model::ListDatabasesResponse {
466 databases: Some(databases),
467 next_token: None,
468 };
469 wire::serialize_list_databases_response(&resp)
470 }
471
472 async fn handle_list_schemas(
473 &self,
474 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
475 body: &[u8],
476 ) -> MockResponse {
477 if let Err(e) = wire::deserialize_list_schemas_request(body) {
478 return json_error_response(400, "ValidationException", &e);
479 }
480 let state = state.read().await;
481 let schemas = state.list_schemas();
482 let resp = crate::model::ListSchemasResponse {
483 schemas: Some(schemas),
484 next_token: None,
485 };
486 wire::serialize_list_schemas_response(&resp)
487 }
488
489 async fn handle_list_tables(
490 &self,
491 state: &Arc<tokio::sync::RwLock<RedshiftDataState>>,
492 body: &[u8],
493 ) -> MockResponse {
494 if let Err(e) = wire::deserialize_list_tables_request(body) {
495 return json_error_response(400, "ValidationException", &e);
496 }
497 let state = state.read().await;
498 let table_names = state.list_tables();
499 let tables: Vec<crate::model::TableMember> = table_names
500 .iter()
501 .map(|name| crate::model::TableMember {
502 name: Some(name.clone()),
503 ..Default::default()
504 })
505 .collect();
506 let resp = crate::model::ListTablesResponse {
507 tables: Some(tables),
508 next_token: None,
509 };
510 wire::serialize_list_tables_response(&resp)
511 }
512}
513
514fn statement_to_describe_response(stmt: &Statement) -> DescribeStatementResponse {
516 let query_parameters: Option<Vec<SqlParameter>> = if stmt.parameters.is_empty() {
517 None
518 } else {
519 Some(
520 stmt.parameters
521 .iter()
522 .map(|p| SqlParameter {
523 name: p.name.clone(),
524 value: p.value.clone(),
525 })
526 .collect(),
527 )
528 };
529
530 DescribeStatementResponse {
531 id: Some(stmt.id.clone()),
532 status: Some(stmt.status.as_str().to_string()),
533 created_at: Some(stmt.created_at.timestamp() as f64),
534 updated_at: Some(stmt.updated_at.timestamp() as f64),
535 query_string: Some(stmt.query_string.clone()),
536 database: Some(stmt.database.clone()),
537 has_result_set: Some(stmt.has_result_set),
538 result_rows: Some(stmt.result_rows),
539 result_size: Some(stmt.result_size),
540 duration: Some(0),
541 cluster_identifier: stmt.cluster_identifier.clone(),
542 workgroup_name: stmt.workgroup_name.clone(),
543 db_user: stmt.db_user.clone(),
544 secret_arn: stmt.secret_arn.clone(),
545 query_parameters,
546 ..Default::default()
547 }
548}
549
550fn statement_to_statement_data(stmt: &Statement) -> StatementData {
552 StatementData {
553 id: Some(stmt.id.clone()),
554 status: Some(stmt.status.as_str().to_string()),
555 created_at: Some(stmt.created_at.timestamp() as f64),
556 updated_at: Some(stmt.updated_at.timestamp() as f64),
557 query_string: Some(stmt.query_string.clone()),
558 statement_name: Some(String::new()),
559 is_batch_statement: Some(false),
560 secret_arn: stmt.secret_arn.clone(),
561 ..Default::default()
562 }
563}
564
565fn string_to_field(value: &Option<String>, type_str: &str) -> Field {
568 match value {
569 None => Field {
570 is_null: Some(true),
571 ..Default::default()
572 },
573 Some(s) => {
574 let t = type_str.to_ascii_lowercase();
575 if t.contains("int") || t.contains("bigint") {
576 if let Ok(n) = s.parse::<i64>() {
577 return Field {
578 long_value: Some(n),
579 ..Default::default()
580 };
581 }
582 }
583 if t.contains("float") || t.contains("double") || t.contains("real") {
584 if let Ok(f) = s.parse::<f64>() {
585 return Field {
586 double_value: Some(f),
587 ..Default::default()
588 };
589 }
590 }
591 if t.contains("bool") {
592 let b = s == "true" || s == "1";
593 return Field {
594 boolean_value: Some(b),
595 ..Default::default()
596 };
597 }
598 Field {
599 string_value: Some(s.clone()),
600 ..Default::default()
601 }
602 }
603 }
604}
605
606fn redshiftdata_error_response(err: &RedshiftDataError) -> MockResponse {
607 let (status, error_type) = match err {
608 RedshiftDataError::SqlRequired
609 | RedshiftDataError::SqlsRequired
610 | RedshiftDataError::InvalidStatementId => (400, "ValidationException"),
611 RedshiftDataError::StatementNotFound => (404, "ResourceNotFoundException"),
612 };
613 let body = json!({
614 "__type": error_type,
615 "message": err.to_string(),
616 });
617 MockResponse::json(status, body.to_string())
618}
619
620fn json_error_response(status: u16, code: &str, message: &str) -> MockResponse {
621 let body = json!({
622 "__type": code,
623 "message": message,
624 });
625 MockResponse::json(status, body.to_string())
626}