use anyhow::{Context, Result};
use aws_sdk_dynamodb::{
    operation::create_table::CreateTableInput,
    types::{
        AttributeDefinition, BillingMode, GlobalSecondaryIndex, KeySchemaElement, KeyType,
        LocalSecondaryIndex, Projection, ProjectionType, ProvisionedThroughput,
        ScalarAttributeType,
    },
};
use serde::{Deserialize, Serialize};
use std::{collections::HashSet, fs::File, io::BufReader, path::Path};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct TableConfig {
15 pub table_name: String,
16 #[serde(default)]
18 pub(crate) local_endpoint: Option<String>,
19 #[serde(default)]
21 pub(crate) delete_on_exit: bool,
22 pub(crate) info: Option<TableInfo>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct TableInfo {
28 pub table_name: String,
29 pub pk: TableAttr,
30 #[serde(default)]
31 pub sk: Option<TableAttr>,
32 #[serde(default)]
33 pub attrs: Vec<TableAttr>,
34 #[serde(default)]
35 pub gsis: Vec<TableGsi>,
36 #[serde(default)]
37 pub lsis: Vec<TableLsi>,
38 #[serde(default)]
39 pub throughput: Option<Throughput>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct Throughput {
44 pub read: i64,
45 pub write: i64,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct TableAttr {
50 pub name: String,
51 #[serde(rename = "type")]
52 pub attr_type: AttrType,
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
56pub enum AttrType {
57 S,
58 N,
59 B,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct TableGsi {
64 pub name: String,
65 pub pk: TableAttr,
66 #[serde(default)]
67 pub sk: Option<TableAttr>,
68 #[serde(default)]
69 pub attrs: Vec<String>,
70 #[serde(default)]
71 pub throughput: Option<Throughput>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct TableLsi {
76 pub name: String,
77 pub pk: TableAttr,
79 pub sk: TableAttr,
80 #[serde(default)]
81 pub attrs: Vec<String>,
82}
83
84impl From<AttrType> for ScalarAttributeType {
85 fn from(attr_type: AttrType) -> Self {
86 match attr_type {
87 AttrType::S => ScalarAttributeType::S,
88 AttrType::N => ScalarAttributeType::N,
89 AttrType::B => ScalarAttributeType::B,
90 }
91 }
92}
93
94impl From<TableAttr> for AttributeDefinition {
95 fn from(attr: TableAttr) -> Self {
96 let attr_type = attr.attr_type.into();
97 AttributeDefinition::builder()
98 .attribute_name(attr.name)
99 .attribute_type(attr_type)
100 .build()
101 .expect("attr should be valid")
102 }
103}
104
105impl TableAttr {
106 fn to_pk(&self) -> KeySchemaElement {
107 KeySchemaElement::builder()
108 .attribute_name(self.name.clone())
109 .key_type(KeyType::Hash)
110 .build()
111 .expect("pk should be valid")
112 }
113
114 fn to_sk(&self) -> KeySchemaElement {
115 KeySchemaElement::builder()
116 .attribute_name(self.name.clone())
117 .key_type(KeyType::Range)
118 .build()
119 .expect("sk should be valid")
120 }
121}
122
123impl From<TableGsi> for GlobalSecondaryIndex {
124 fn from(gsi: TableGsi) -> Self {
125 let pk = gsi.pk.to_pk();
126 let sk = gsi.sk.map(|sk| sk.to_sk());
127
128 let key_schema = if let Some(sk) = sk {
129 vec![pk, sk]
130 } else {
131 vec![pk]
132 };
133
134 let mut builder = GlobalSecondaryIndex::builder()
135 .set_key_schema(Some(key_schema))
136 .projection(
137 Projection::builder()
138 .projection_type(ProjectionType::Include)
139 .set_non_key_attributes(Some(gsi.attrs))
140 .build(),
141 )
142 .index_name(gsi.name);
144
145 if let Some(throughput) = gsi.throughput {
146 let pt = ProvisionedThroughput::builder()
147 .read_capacity_units(throughput.read)
148 .write_capacity_units(throughput.write)
149 .build()
150 .expect("throughput should be valid");
151 builder = builder.provisioned_throughput(pt);
152 }
153 builder.build().expect("gsi should be valid")
154 }
155}
156
157impl From<TableLsi> for LocalSecondaryIndex {
158 fn from(lsi: TableLsi) -> Self {
159 let pk = lsi.pk.to_pk();
160 let sk = lsi.sk.to_sk();
161 let key_schema = vec![pk, sk];
162 let projection = if lsi.attrs.is_empty() {
163 Projection::builder()
164 .projection_type(ProjectionType::All)
165 .build()
166 } else {
167 Projection::builder()
168 .projection_type(ProjectionType::Include)
169 .set_non_key_attributes(Some(lsi.attrs))
170 .build()
171 };
172 LocalSecondaryIndex::builder()
173 .set_key_schema(Some(key_schema))
174 .projection(projection)
175 .index_name(lsi.name)
176 .build()
177 .expect("lsi should be valid")
178 }
179}
180
181impl From<TableInfo> for CreateTableInput {
182 fn from(config: TableInfo) -> Self {
183 let pk = config.pk.to_pk();
184 let sk = config.sk.as_ref().map(|sk| sk.to_sk());
185
186 let key_schema = if let Some(sk) = sk {
187 vec![pk, sk]
188 } else {
189 vec![pk]
190 };
191
192 let mut attrs = config.attrs.clone();
194 attrs.push(config.pk);
195 if let Some(sk) = config.sk {
196 attrs.push(sk);
197 }
198 let attrs = attrs.into_iter().map(AttributeDefinition::from).collect();
199
200 let gsis = config
201 .gsis
202 .into_iter()
203 .map(GlobalSecondaryIndex::from)
204 .collect();
205
206 let lsis = config
207 .lsis
208 .into_iter()
209 .map(LocalSecondaryIndex::from)
210 .collect();
211
212 let mut builder = CreateTableInput::builder()
213 .table_name(config.table_name)
214 .set_key_schema(Some(key_schema))
215 .set_attribute_definitions(Some(attrs))
216 .set_global_secondary_indexes(Some(gsis))
217 .set_local_secondary_indexes(Some(lsis));
218
219 match config.throughput {
220 Some(throughput) => {
221 let pt = ProvisionedThroughput::builder()
222 .read_capacity_units(throughput.read)
223 .write_capacity_units(throughput.write)
224 .build()
225 .expect("throughput should be valid");
226 builder = builder.provisioned_throughput(pt);
227 }
228 None => {
229 builder = builder.billing_mode(BillingMode::PayPerRequest);
230 }
231 }
232
233 builder.build().expect("table info should be valid")
234 }
235}
236
237impl TableConfig {
238 pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
239 let file = File::open(path)?;
240 let reader = BufReader::new(file);
241 let config = serde_yaml::from_reader(reader)?;
242 Ok(config)
243 }
244
245 pub fn new(
246 table_name: String,
247 local_endpoint: Option<String>,
248 delete_on_exit: bool,
249 info: Option<TableInfo>,
250 ) -> Self {
251 let delete_on_exit = if local_endpoint.is_some() {
252 delete_on_exit
253 } else {
254 false
255 };
256
257 Self {
258 table_name,
259 local_endpoint,
260 delete_on_exit,
261 info,
262 }
263 }
264}
265
266impl TableInfo {
267 pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
268 let file = File::open(path)?;
269 let reader = BufReader::new(file);
270 let config = serde_yaml::from_reader(reader)?;
271 Ok(config)
272 }
273
274 pub fn load(s: &str) -> Result<Self> {
275 let config = serde_yaml::from_str(s)?;
276 Ok(config)
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// The dev fixture carries a full schema that converts into a
    /// `CreateTable` request with the expected attributes and indexes.
    #[test]
    fn config_could_be_loaded() {
        let cfg = TableConfig::load_from_file("fixtures/dev.yml").unwrap();
        let schema = cfg.info.expect("info should be present");

        assert_eq!(cfg.table_name, "users");
        assert_eq!(schema.pk.name, "pk");
        assert_eq!(schema.pk.attr_type, AttrType::S);

        let request = CreateTableInput::from(schema);
        assert_eq!(request.attribute_definitions().len(), 5);
        assert_eq!(request.global_secondary_indexes().len(), 1);
        assert_eq!(request.local_secondary_indexes().len(), 1);
    }

    /// A standalone schema file loads and exposes its partition key.
    #[test]
    fn table_info_could_be_loaded() {
        let schema = TableInfo::load_from_file("fixtures/info.yml").unwrap();

        assert_eq!(schema.pk.name, "pk");
        assert_eq!(schema.pk.attr_type, AttrType::S);
    }
}