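//! `dynamodb_tools/config.rs`: YAML-loadable DynamoDB table configuration and
//! its conversion into `aws_sdk_dynamodb` request types.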

1use anyhow::Result;
2use aws_sdk_dynamodb::{
3    operation::create_table::CreateTableInput,
4    types::{
5        AttributeDefinition, BillingMode, GlobalSecondaryIndex, KeySchemaElement, KeyType,
6        LocalSecondaryIndex, Projection, ProjectionType, ProvisionedThroughput,
7        ScalarAttributeType,
8    },
9};
10use serde::{Deserialize, Serialize};
11use std::{fs::File, io::BufReader, path::Path};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TableConfig {
15    pub table_name: String,
16    /// local endpoints, if provided, dynamodb connector will connect dynamodb local
17    #[serde(default)]
18    pub(crate) local_endpoint: Option<String>,
19    /// drop table when connector is dropped. Would only work if local_endpoint is provided
20    #[serde(default)]
21    pub(crate) delete_on_exit: bool,
22    /// table info
23    pub(crate) info: Option<TableInfo>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct TableInfo {
28    pub table_name: String,
29    pub pk: TableAttr,
30    #[serde(default)]
31    pub sk: Option<TableAttr>,
32    #[serde(default)]
33    pub attrs: Vec<TableAttr>,
34    #[serde(default)]
35    pub gsis: Vec<TableGsi>,
36    #[serde(default)]
37    pub lsis: Vec<TableLsi>,
38    #[serde(default)]
39    pub throughput: Option<Throughput>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct Throughput {
44    pub read: i64,
45    pub write: i64,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct TableAttr {
50    pub name: String,
51    #[serde(rename = "type")]
52    pub attr_type: AttrType,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
56pub enum AttrType {
57    S,
58    N,
59    B,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct TableGsi {
64    pub name: String,
65    pub pk: TableAttr,
66    #[serde(default)]
67    pub sk: Option<TableAttr>,
68    #[serde(default)]
69    pub attrs: Vec<String>,
70    #[serde(default)]
71    pub throughput: Option<Throughput>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct TableLsi {
76    pub name: String,
77    // must be the same as the pk of the table
78    pub pk: TableAttr,
79    pub sk: TableAttr,
80    #[serde(default)]
81    pub attrs: Vec<String>,
82}
83
84impl From<AttrType> for ScalarAttributeType {
85    fn from(attr_type: AttrType) -> Self {
86        match attr_type {
87            AttrType::S => ScalarAttributeType::S,
88            AttrType::N => ScalarAttributeType::N,
89            AttrType::B => ScalarAttributeType::B,
90        }
91    }
92}
93
94impl From<TableAttr> for AttributeDefinition {
95    fn from(attr: TableAttr) -> Self {
96        let attr_type = attr.attr_type.into();
97        AttributeDefinition::builder()
98            .attribute_name(attr.name)
99            .attribute_type(attr_type)
100            .build()
101            .expect("attr should be valid")
102    }
103}
104
105impl TableAttr {
106    fn to_pk(&self) -> KeySchemaElement {
107        KeySchemaElement::builder()
108            .attribute_name(self.name.clone())
109            .key_type(KeyType::Hash)
110            .build()
111            .expect("pk should be valid")
112    }
113
114    fn to_sk(&self) -> KeySchemaElement {
115        KeySchemaElement::builder()
116            .attribute_name(self.name.clone())
117            .key_type(KeyType::Range)
118            .build()
119            .expect("sk should be valid")
120    }
121}
122
123impl From<TableGsi> for GlobalSecondaryIndex {
124    fn from(gsi: TableGsi) -> Self {
125        let pk = gsi.pk.to_pk();
126        let sk = gsi.sk.map(|sk| sk.to_sk());
127
128        let key_schema = if let Some(sk) = sk {
129            vec![pk, sk]
130        } else {
131            vec![pk]
132        };
133
134        let mut builder = GlobalSecondaryIndex::builder()
135            .set_key_schema(Some(key_schema))
136            .projection(
137                Projection::builder()
138                    .projection_type(ProjectionType::Include)
139                    .set_non_key_attributes(Some(gsi.attrs))
140                    .build(),
141            )
142            // .provisioned_throughput(pt)
143            .index_name(gsi.name);
144
145        if let Some(throughput) = gsi.throughput {
146            let pt = ProvisionedThroughput::builder()
147                .read_capacity_units(throughput.read)
148                .write_capacity_units(throughput.write)
149                .build()
150                .expect("throughput should be valid");
151            builder = builder.provisioned_throughput(pt);
152        }
153        builder.build().expect("gsi should be valid")
154    }
155}
156
157impl From<TableLsi> for LocalSecondaryIndex {
158    fn from(lsi: TableLsi) -> Self {
159        let pk = lsi.pk.to_pk();
160        let sk = lsi.sk.to_sk();
161        let key_schema = vec![pk, sk];
162        let projection = if lsi.attrs.is_empty() {
163            Projection::builder()
164                .projection_type(ProjectionType::All)
165                .build()
166        } else {
167            Projection::builder()
168                .projection_type(ProjectionType::Include)
169                .set_non_key_attributes(Some(lsi.attrs))
170                .build()
171        };
172        LocalSecondaryIndex::builder()
173            .set_key_schema(Some(key_schema))
174            .projection(projection)
175            .index_name(lsi.name)
176            .build()
177            .expect("lsi should be valid")
178    }
179}
180
181impl From<TableInfo> for CreateTableInput {
182    fn from(config: TableInfo) -> Self {
183        let pk = config.pk.to_pk();
184        let sk = config.sk.as_ref().map(|sk| sk.to_sk());
185
186        let key_schema = if let Some(sk) = sk {
187            vec![pk, sk]
188        } else {
189            vec![pk]
190        };
191
192        // add pk and sk to attrs
193        let mut attrs = config.attrs.clone();
194        attrs.push(config.pk);
195        if let Some(sk) = config.sk {
196            attrs.push(sk);
197        }
198        let attrs = attrs.into_iter().map(AttributeDefinition::from).collect();
199
200        let gsis = config
201            .gsis
202            .into_iter()
203            .map(GlobalSecondaryIndex::from)
204            .collect();
205
206        let lsis = config
207            .lsis
208            .into_iter()
209            .map(LocalSecondaryIndex::from)
210            .collect();
211
212        let mut builder = CreateTableInput::builder()
213            .table_name(config.table_name)
214            .set_key_schema(Some(key_schema))
215            .set_attribute_definitions(Some(attrs))
216            .set_global_secondary_indexes(Some(gsis))
217            .set_local_secondary_indexes(Some(lsis));
218
219        match config.throughput {
220            Some(throughput) => {
221                let pt = ProvisionedThroughput::builder()
222                    .read_capacity_units(throughput.read)
223                    .write_capacity_units(throughput.write)
224                    .build()
225                    .expect("throughput should be valid");
226                builder = builder.provisioned_throughput(pt);
227            }
228            None => {
229                builder = builder.billing_mode(BillingMode::PayPerRequest);
230            }
231        }
232
233        builder.build().expect("table info should be valid")
234    }
235}
236
237impl TableConfig {
238    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
239        let file = File::open(path)?;
240        let reader = BufReader::new(file);
241        let config = serde_yaml::from_reader(reader)?;
242        Ok(config)
243    }
244
245    pub fn new(
246        table_name: String,
247        local_endpoint: Option<String>,
248        delete_on_exit: bool,
249        info: Option<TableInfo>,
250    ) -> Self {
251        let delete_on_exit = if local_endpoint.is_some() {
252            delete_on_exit
253        } else {
254            false
255        };
256
257        Self {
258            table_name,
259            local_endpoint,
260            delete_on_exit,
261            info,
262        }
263    }
264}
265
266impl TableInfo {
267    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
268        let file = File::open(path)?;
269        let reader = BufReader::new(file);
270        let config = serde_yaml::from_reader(reader)?;
271        Ok(config)
272    }
273
274    pub fn load(s: &str) -> Result<Self> {
275        let config = serde_yaml::from_str(s)?;
276        Ok(config)
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// `fixtures/dev.yml` holds a full `TableConfig` with an embedded schema.
    #[test]
    fn config_could_be_loaded() {
        let TableConfig {
            table_name, info, ..
        } = TableConfig::load_from_file("fixtures/dev.yml").unwrap();
        let info = info.expect("info should be present");

        assert_eq!(table_name, "users");
        assert_eq!(info.pk.name, "pk");
        assert_eq!(info.pk.attr_type, AttrType::S);

        // The conversion is infallible, so use `From` directly.
        let input = CreateTableInput::from(info);
        assert_eq!(input.attribute_definitions().len(), 5);
        assert_eq!(input.global_secondary_indexes().len(), 1);
        assert_eq!(input.local_secondary_indexes().len(), 1);
    }

    /// `fixtures/info.yml` holds a bare `TableInfo` document.
    #[test]
    fn table_info_could_be_loaded() {
        let TableInfo { pk, .. } = TableInfo::load_from_file("fixtures/info.yml").unwrap();

        assert_eq!(pk.name, "pk");
        assert_eq!(pk.attr_type, AttrType::S);
    }
}