dynamodb_tools/
connector.rs

1use crate::TableConfig;
2use crate::error::{DynamoToolsError, Result};
3use aws_config::meta::region::RegionProviderChain;
4use aws_config::{BehaviorVersion, Region};
5use aws_sdk_dynamodb::config::Credentials;
6use aws_sdk_dynamodb::types::{AttributeValue, PutRequest, WriteRequest};
7use aws_sdk_dynamodb::{Client, operation::create_table::CreateTableInput};
8use serde_json::Value;
9use std::{collections::HashMap, fs, path::Path};
10#[cfg(feature = "test_utils")]
11use tokio::runtime::Runtime;
12
13/// Provides a connection to DynamoDB, potentially managing test table lifecycles.
14///
15/// This struct encapsulates an AWS DynamoDB client (`aws_sdk_dynamodb::Client`).
16/// If configured with table definitions and a local `endpoint` in [`TableConfig`],
17/// it will create uniquely named tables upon construction.
18///
19/// If the `test_utils` feature is enabled and `delete_on_exit` is true in the
20/// configuration, the created tables will be automatically deleted when this
21/// connector is dropped.
22#[derive(Debug)]
23pub struct DynamodbConnector {
24    client: Option<Client>,
25    // Map base table name to actual unique table name created
26    created_tables: HashMap<String, String>,
27    // Keep track of the original config for Drop
28    #[cfg(feature = "test_utils")]
29    config: TableConfig,
30}
31
32impl DynamodbConnector {
33    /// Creates a new connector by loading configuration from a YAML file.
34    ///
35    /// See [`TableConfig::load_from_file`] and [`DynamodbConnector::try_new`].
36    ///
37    /// # Errors
38    ///
39    /// Returns `Err` if loading the config file fails or if creating the connector fails
40    /// (e.g., table creation fails, AWS configuration error).
41    pub async fn load(config_path: impl AsRef<Path>) -> Result<Self> {
42        let config = TableConfig::load_from_file(config_path)?;
43        DynamodbConnector::try_new(config).await
44    }
45
46    /// Returns a reference to the underlying `aws_sdk_dynamodb::Client`.
47    ///
48    /// # Errors
49    ///
50    /// Returns `Err` ([`DynamoToolsError::Internal`]) if the client has already been
51    /// taken (e.g., after `Drop` has started).
52    pub fn client(&self) -> Result<&Client> {
53        self.client
54            .as_ref()
55            .ok_or_else(|| DynamoToolsError::Internal("Client instance is missing".to_string()))
56    }
57
58    /// Returns the unique name of a table created by this connector, given its base name.
59    ///
60    /// The `base_name` corresponds to the `table_name` field within [`TableInfo`]
61    /// in the configuration.
62    pub fn get_created_table_name(&self, base_name: &str) -> Option<&str> {
63        self.created_tables.get(base_name).map(|s| s.as_str())
64    }
65
66    /// Returns a map of all tables created by this connector.
67    /// Keys are the base names from the config, values are the unique created names.
68    pub fn get_all_created_table_names(&self) -> &HashMap<String, String> {
69        &self.created_tables
70    }
71
72    /// Creates a new connector based on the provided [`TableConfig`].
73    ///
74    /// - Sets up AWS SDK configuration.
75    /// - Creates a `aws_sdk_dynamodb::Client`.
76    /// - Iterates through `config.tables`. For each `TableInfo`:
77    ///   - Attempts to create a DynamoDB table with a unique name derived from `TableInfo.table_name`.
78    ///   - Stores the mapping from the base name to the unique name.
79    ///
80    /// # Errors
81    ///
82    /// Returns `Err` if AWS config fails, client creation fails, or any table creation fails.
83    pub async fn try_new(config: TableConfig) -> Result<Self> {
84        let endpoint = config.endpoint.clone();
85        // Store config for Drop
86        #[cfg(feature = "test_utils")]
87        let connector_config = config.clone();
88
89        let base_sdk_config_builder = aws_config::defaults(BehaviorVersion::latest()).region(
90            RegionProviderChain::first_try(Region::new(config.region.clone()))
91                .or_default_provider(),
92        );
93        let loaded_sdk_config = base_sdk_config_builder.load().await;
94        let builder = aws_sdk_dynamodb::config::Builder::from(&loaded_sdk_config);
95        let dynamodb_config = if let Some(url) = endpoint.as_ref() {
96            builder
97                .endpoint_url(url)
98                .credentials_provider(Credentials::for_tests())
99                .build()
100        } else {
101            builder.build()
102        };
103        let client = Client::from_conf(dynamodb_config);
104
105        let mut created_tables = HashMap::new();
106
107        for table_info in config.tables {
108            let base_table_name = table_info.table_name.clone();
109            let seed_file = table_info.seed_data_file.clone(); // Clone seed file path
110            let mut input = CreateTableInput::try_from(table_info)?;
111
112            let unique_table_name = format!("{}-{}", base_table_name, xid::new());
113            input.table_name = Some(unique_table_name.clone());
114
115            // Build the CreateTable request (logic adapted from previous version)
116            let create_table_builder = client
117                .create_table()
118                .table_name(&unique_table_name)
119                .set_key_schema(input.key_schema)
120                .set_attribute_definitions(input.attribute_definitions)
121                .set_global_secondary_indexes(input.global_secondary_indexes)
122                .set_local_secondary_indexes(input.local_secondary_indexes);
123
124            let create_table_builder = match input.provisioned_throughput {
125                Some(pt) => create_table_builder.provisioned_throughput(pt),
126                None => create_table_builder.billing_mode(input.billing_mode.ok_or_else(|| {
127                    DynamoToolsError::MissingField(format!(
128                        "Billing mode missing for table '{}' with no throughput",
129                        base_table_name
130                    ))
131                })?),
132            };
133
134            // Send the request
135            create_table_builder
136                .send()
137                .await
138                .map_err(DynamoToolsError::TableCreation)?; // Propagate SDK errors, wrapped in our type
139
140            created_tables.insert(base_table_name.clone(), unique_table_name.clone());
141
142            // --- Seed Data ---
143            if let Some(file_path) = seed_file {
144                println!(
145                    "[INFO] Seeding data for table '{}' from file '{}'",
146                    unique_table_name, file_path
147                );
148                // Read file content
149                let content = fs::read_to_string(&file_path)
150                    .map_err(|e| DynamoToolsError::SeedFileRead(file_path.clone(), e))?;
151
152                // Parse JSON array
153                let items_json: Vec<Value> = serde_json::from_str(&content)
154                    .map_err(|e| DynamoToolsError::SeedJsonParse(file_path.clone(), e))?;
155
156                // Convert to WriteRequests
157                let mut write_requests = Vec::new();
158                for item_value in items_json {
159                    let item_map: HashMap<String, AttributeValue> =
160                        serde_dynamo::to_item(item_value)?;
161                    let put_request = PutRequest::builder()
162                        .set_item(Some(item_map))
163                        .build()
164                        .map_err(|e| {
165                            DynamoToolsError::Internal(format!("Failed to build PutRequest: {}", e))
166                        })?;
167                    write_requests.push(WriteRequest::builder().put_request(put_request).build());
168                }
169
170                // Batch write items (chunking by 25)
171                for chunk in write_requests.chunks(25) {
172                    let request_items =
173                        HashMap::from([(unique_table_name.clone(), chunk.to_vec())]);
174                    client
175                        .batch_write_item()
176                        .set_request_items(Some(request_items))
177                        .send()
178                        .await
179                        .map_err(|e| {
180                            DynamoToolsError::SeedBatchWrite(unique_table_name.clone(), e)
181                        })?;
182                    println!(
183                        "[INFO] Wrote batch of {} items to table '{}'",
184                        chunk.len(),
185                        unique_table_name
186                    );
187                }
188            }
189            // --- End Seed Data ---
190        }
191
192        Ok(Self {
193            client: Some(client),
194            created_tables,
195            #[cfg(feature = "test_utils")]
196            config: connector_config,
197        })
198    }
199}
200
/// Best-effort table cleanup on drop (requires `test_utils` feature).
///
/// Deletion is attempted only when `delete_on_exit` is true **and** an `endpoint` is
/// configured. Otherwise an informational message is printed and nothing is deleted.
///
/// Each created table is deleted from its own OS thread that runs a fresh Tokio runtime.
/// Tokio's `Runtime::block_on` panics when called from within an async execution context,
/// which `drop` may be running in (e.g. at the end of an async test). The threads are
/// detached (never joined), so deletion may not finish if the process exits first. Failures
/// are only reported on stderr.
#[cfg(feature = "test_utils")]
impl Drop for DynamodbConnector {
    fn drop(&mut self) {
        // Decide from the config before touching the client.
        if !self.config.delete_on_exit || self.config.endpoint.is_none() {
            println!(
                "[INFO] Skipping delete on drop (delete_on_exit: {}, endpoint: {:?})",
                self.config.delete_on_exit, self.config.endpoint
            );
            return;
        }

        if let Some(client) = self.client.take() {
            // `self` is being destroyed, so move the map out instead of cloning it.
            let tables_to_delete = std::mem::take(&mut self.created_tables);
            println!(
                "[INFO] Drop: Attempting to delete tables: {:?}",
                tables_to_delete.values()
            );

            for unique_name in tables_to_delete.into_values() {
                // `Client` is cloned per thread; each thread owns its copy.
                let client_clone = client.clone();
                std::thread::spawn(move || {
                    let rt = match Runtime::new() {
                        Ok(rt) => rt,
                        Err(e) => {
                            eprintln!(
                                "[ERROR] Failed to create Tokio runtime for table deletion: {}",
                                e
                            );
                            return;
                        }
                    };

                    rt.block_on(async move {
                        match client_clone
                            .delete_table()
                            .table_name(&unique_name)
                            .send()
                            .await
                        {
                            Ok(_) => println!("[INFO] Deleted table: {}", unique_name),
                            Err(e) => {
                                eprintln!("[ERROR] Failed to delete table '{}': {}", unique_name, e)
                            }
                        }
                    });
                });
            }
        }
    }
}