1use super::{ReadConversionError, SealError};
2use aws_sdk_dynamodb::{primitives::Blob, types::AttributeValue};
3use cipherstash_client::zerokms::EncryptedRecord;
4use std::{
5 collections::{BTreeMap, HashMap},
6 str::FromStr,
7};
8
9pub trait TryFromTableAttr: Sized {
11 fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError>;
13}
14
/// An in-memory representation of a single DynamoDB attribute value.
///
/// Each variant corresponds to one DynamoDB attribute type (see the
/// `From<TableAttribute> for AttributeValue` conversion in this module).
#[derive(Debug, Clone, PartialEq)]
pub enum TableAttribute {
    /// A string (`S`).
    String(String),
    /// A number (`N`), kept in its textual form.
    Number(String),
    /// A boolean (`BOOL`).
    Bool(bool),
    /// A binary blob (`B`).
    Bytes(Vec<u8>),

    /// A string set (`SS`).
    StringVec(Vec<String>),
    /// A binary set (`BS`).
    ByteVec(Vec<Vec<u8>>),
    /// A number set (`NS`), each number kept in its textual form.
    NumberVec(Vec<String>),
    /// A map of attribute names to nested attributes (`M`).
    Map(HashMap<String, TableAttribute>),
    /// An ordered list of nested attributes (`L`).
    List(Vec<TableAttribute>),

    /// An explicit null (`NULL`).
    Null,
}
30
31impl TableAttribute {
32 pub(crate) fn as_encrypted_record(
40 &self,
41 descriptor: &str,
42 ) -> Result<EncryptedRecord, SealError> {
43 if let TableAttribute::Bytes(s) = self {
44 EncryptedRecord::from_slice(&s[..])
45 .map_err(|_| SealError::AssertionFailed("Could not parse EncryptedRecord".to_string()))
46 .and_then(|record| {
47 if record.descriptor == descriptor {
48 Ok(record)
49 } else {
50 Err(SealError::AssertionFailed(format!(
51 "Expected descriptor {}, got {} - WARNING: record may have been tampered with",
52 descriptor,
53 record.descriptor
54 )))
55 }
56 })
57 } else {
58 Err(SealError::AssertionFailed(format!(
59 "Expected TableAttribute::Bytes, got {}",
60 descriptor
61 )))
62 }
63 }
64
65 pub(crate) fn new_map() -> Self {
66 TableAttribute::Map(HashMap::new())
67 }
68
69 pub(crate) fn try_insert_map(
72 &mut self,
73 key: impl Into<String>,
74 value: impl Into<TableAttribute>,
75 ) -> Result<(), SealError> {
76 if let Self::Map(map) = self {
77 map.insert(key.into(), value.into());
78 Ok(())
79 } else {
80 Err(SealError::AssertionFailed(
81 "Expected TableAttribute::Map".to_string(),
82 ))
83 }
84 }
85}
86
/// Implementation detail of `impl_try_from_table_attr!`.
///
/// The first four arms are small code snippets selected by keyword. The `body` arm combines
/// one "from" snippet (write side) and one "parse" snippet (read side) into a pair of trait
/// impls for a single type.
macro_rules! impl_try_from_table_attr_helper {
    // Read side for numeric types: parse the `Number` payload (a `String`) with `str::parse`.
    // Any parse error is reported as `ConversionFailed` carrying the stringified target type.
    (number_parse, $ty:ty, $value:ident) => {
        $value
            .parse()
            .map_err(|_| ReadConversionError::ConversionFailed(stringify!($ty).to_string()))
    };
    // Read side for types stored as-is: the variant's payload already is the target value.
    (simple_parse, $_:ty, $value:ident) => {
        Ok::<_, ReadConversionError>($value)
    };
    // Write side for numeric types: store the value's `to_string()` form in `Number`.
    // The variant argument is ignored; this arm is only selected for `Type => Number` entries.
    (number_from, $_:ident, $value:ident) => {
        TableAttribute::Number($value.to_string())
    };
    // Write side for types stored as-is: wrap the value directly in `$variant`.
    (simple_from, $variant:ident, $value:ident) => {
        TableAttribute::$variant($value)
    };
    // Generates `From<$ty> for TableAttribute` (using the chosen "from" snippet) and
    // `TryFromTableAttr for $ty` (using the chosen "parse" snippet). The read impl rejects
    // every variant other than `$variant` with `ConversionFailed(stringify!($ty))` before the
    // payload reaches the parse snippet.
    (
        body,
        $ty:ty,
        $variant:ident,
        $from_impl:ident!($from_args:tt),
        $try_from_impl:ident!($try_from_args:tt)
    ) => {
        impl From<$ty> for TableAttribute {
            fn from(value: $ty) -> Self {
                $from_impl!($from_args, $variant, value)
            }
        }

        impl TryFromTableAttr for $ty {
            fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
                let TableAttribute::$variant(value) = value else {
                    return Err(ReadConversionError::ConversionFailed(
                        stringify!($ty).to_string(),
                    ));
                };

                $try_from_impl!($try_from_args, $ty, value)
            }
        }
    };
}
128
/// Implements `From<T> for TableAttribute` and `TryFromTableAttr for T` for every entry in a
/// comma-separated list of `Type => Variant` pairs.
///
/// For `Type => Number` entries, values are written with `to_string()` and read back with
/// `str::parse`. Any other `Type => Variant` entry requires `Type` to be exactly the variant's
/// payload type; the value is then stored and returned unchanged.
///
/// Entries are consumed one at a time, with the macro recursing on the remaining tokens.
macro_rules! impl_try_from_table_attr {
    // Nothing left to process.
    () => {};
    // Skip the comma separating two entries (also accepts a trailing comma).
    (, $($tail:tt)*) => {
        impl_try_from_table_attr!($($tail)*);
    };
    // `Type => Number`: numeric conversion through strings. This arm must stay above the
    // generic arm below; otherwise `Number` would be matched as a plain `$variant:ident`.
    ($ty:ty => Number $($tail:tt)*) => {
        impl_try_from_table_attr_helper!(
            body,
            $ty,
            Number,
            impl_try_from_table_attr_helper!(
                number_from
            ),
            impl_try_from_table_attr_helper!(
                number_parse
            )
        );

        impl_try_from_table_attr!($($tail)*);
    };
    // `Type => Variant`: the value is the variant's payload, stored and returned as-is.
    ($ty:ty => $variant:ident $($tail:tt)*) => {
        impl_try_from_table_attr_helper!(
            body,
            $ty,
            $variant,
            impl_try_from_table_attr_helper!(
                simple_from
            ),
            impl_try_from_table_attr_helper!(
                simple_parse
            )
        );

        impl_try_from_table_attr!($($tail)*);
    };
}
165
166impl_try_from_table_attr!(
171 i16 => Number,
172 i32 => Number,
173 i64 => Number,
174 u16 => Number,
175 u32 => Number,
176 u64 => Number,
177 usize => Number,
178 f32 => Number,
179 f64 => Number,
180 String => String,
181 Vec<u8> => Bytes,
182 bool => Bool
183);
184
185impl From<&str> for TableAttribute {
186 fn from(value: &str) -> Self {
187 TableAttribute::String(value.to_string())
188 }
189}
190
191impl<T> TryFromTableAttr for Option<T>
192where
193 T: TryFromTableAttr,
194{
195 fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
196 if matches!(value, TableAttribute::Null) {
197 Ok(None)
198 } else {
199 Ok(Some(T::try_from_table_attr(value)?))
200 }
201 }
202}
203
204impl<T> TryFromTableAttr for Vec<T>
205where
206 T: TryFromTableAttr,
207{
208 fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
209 match value {
210 TableAttribute::StringVec(v) => v
211 .into_iter()
212 .map(TableAttribute::String)
213 .map(T::try_from_table_attr)
214 .collect(),
215 TableAttribute::ByteVec(v) => v
216 .into_iter()
217 .map(TableAttribute::Bytes)
218 .map(T::try_from_table_attr)
219 .collect(),
220 TableAttribute::NumberVec(v) => v
221 .into_iter()
222 .map(TableAttribute::Number)
223 .map(T::try_from_table_attr)
224 .collect(),
225 TableAttribute::List(v) => v.into_iter().map(T::try_from_table_attr).collect(),
226 _ => Err(ReadConversionError::ConversionFailed(
227 std::any::type_name::<Vec<T>>().to_string(),
228 )),
229 }
230 }
231}
232
233impl<T> From<Option<T>> for TableAttribute
234where
235 T: Into<TableAttribute>,
236{
237 fn from(value: Option<T>) -> Self {
238 match value {
239 Some(value) => value.into(),
240 None => TableAttribute::Null,
241 }
242 }
243}
244
245impl<T> From<Vec<T>> for TableAttribute
246where
247 T: Into<TableAttribute>,
248{
249 fn from(value: Vec<T>) -> Self {
250 #[derive(Clone, Copy, PartialEq, Eq)]
255 enum IsVariant {
256 Empty,
258 IsSs,
260 IsNs,
262 IsBs,
264 IsList,
266 }
267
268 let len = value.len();
269 let (table_attributes, is_variant) = value.into_iter().fold(
270 (Vec::with_capacity(len), IsVariant::Empty),
271 |(mut acc, mut is_variant), item| {
272 let table_attr = item.into();
273
274 if is_variant != IsVariant::IsList {
276 match (&table_attr, is_variant) {
277 (TableAttribute::Bytes(_), IsVariant::Empty)
278 | (TableAttribute::Bytes(_), IsVariant::IsBs) => {
279 is_variant = IsVariant::IsBs
280 }
281 (TableAttribute::Number(_), IsVariant::Empty)
282 | (TableAttribute::Number(_), IsVariant::IsNs) => {
283 is_variant = IsVariant::IsNs
284 }
285 (TableAttribute::String(_), IsVariant::Empty)
286 | (TableAttribute::String(_), IsVariant::IsSs) => {
287 is_variant = IsVariant::IsSs
288 }
289 _ => is_variant = IsVariant::IsList,
290 }
291 }
292
293 acc.push(table_attr);
294 (acc, is_variant)
295 },
296 );
297
298 match is_variant {
299 IsVariant::IsList | IsVariant::Empty => TableAttribute::List(table_attributes),
300 IsVariant::IsSs => {
301 let strings = table_attributes
302 .into_iter()
303 .map(|string| {
304 let TableAttribute::String(string) = string else {
305 unreachable!()
307 };
308
309 string
310 })
311 .collect();
312
313 TableAttribute::StringVec(strings)
314 }
315 IsVariant::IsNs => {
316 let numbers = table_attributes
317 .into_iter()
318 .map(|number| {
319 let TableAttribute::Number(number) = number else {
320 unreachable!()
322 };
323
324 number
325 })
326 .collect();
327
328 TableAttribute::NumberVec(numbers)
329 }
330 IsVariant::IsBs => {
331 let bytes = table_attributes
332 .into_iter()
333 .map(|bytes| {
334 let TableAttribute::Bytes(bytes) = bytes else {
335 unreachable!()
337 };
338
339 bytes
340 })
341 .collect();
342
343 TableAttribute::ByteVec(bytes)
344 }
345 }
346 }
347}
348
349impl From<TableAttribute> for AttributeValue {
350 fn from(attribute: TableAttribute) -> Self {
351 match attribute {
352 TableAttribute::String(s) => AttributeValue::S(s),
353 TableAttribute::StringVec(s) => AttributeValue::Ss(s),
354
355 TableAttribute::Number(i) => AttributeValue::N(i),
356 TableAttribute::NumberVec(x) => AttributeValue::Ns(x),
357
358 TableAttribute::Bytes(x) => AttributeValue::B(Blob::new(x)),
359 TableAttribute::ByteVec(x) => {
360 AttributeValue::Bs(x.into_iter().map(Blob::new).collect())
361 }
362
363 TableAttribute::Bool(x) => AttributeValue::Bool(x),
364 TableAttribute::List(x) => AttributeValue::L(x.into_iter().map(|x| x.into()).collect()),
365 TableAttribute::Map(x) => {
366 AttributeValue::M(x.into_iter().map(|(k, v)| (k, v.into())).collect())
367 }
368 TableAttribute::Null => AttributeValue::Null(true),
369 }
370 }
371}
372
373impl From<AttributeValue> for TableAttribute {
374 fn from(attribute: AttributeValue) -> Self {
375 match attribute {
376 AttributeValue::S(s) => TableAttribute::String(s),
377 AttributeValue::N(n) => TableAttribute::Number(n),
378 AttributeValue::Bool(n) => TableAttribute::Bool(n),
379 AttributeValue::B(n) => TableAttribute::Bytes(n.into_inner()),
380 AttributeValue::L(l) => {
381 TableAttribute::List(l.into_iter().map(TableAttribute::from).collect())
382 }
383 AttributeValue::M(l) => TableAttribute::Map(
384 l.into_iter()
385 .map(|(k, v)| (k, TableAttribute::from(v)))
386 .collect(),
387 ),
388 AttributeValue::Bs(x) => {
389 TableAttribute::ByteVec(x.into_iter().map(|x| x.into_inner()).collect())
390 }
391 AttributeValue::Ss(x) => TableAttribute::StringVec(x),
392 AttributeValue::Ns(x) => TableAttribute::NumberVec(x),
393 AttributeValue::Null(_) => TableAttribute::Null,
394
395 x => panic!("Unsupported Dynamo attribute value: {x:?}"),
396 }
397 }
398}
399
400impl<K, V> From<HashMap<K, V>> for TableAttribute
401where
402 K: ToString,
403 V: Into<TableAttribute>,
404{
405 fn from(map: HashMap<K, V>) -> Self {
406 TableAttribute::Map(
407 map.into_iter()
408 .map(|(k, v)| (k.to_string(), v.into()))
409 .collect(),
410 )
411 }
412}
413
414impl<K, V> TryFromTableAttr for HashMap<K, V>
415where
416 K: FromStr + std::hash::Hash + std::cmp::Eq,
417 V: TryFromTableAttr,
418{
419 fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
420 let TableAttribute::Map(map) = value else {
421 return Err(ReadConversionError::ConversionFailed(
422 std::any::type_name::<Self>().to_string(),
423 ));
424 };
425
426 map.into_iter()
427 .map(|(k, v)| {
428 let k = k.parse().map_err(|_| {
429 ReadConversionError::ConversionFailed(std::any::type_name::<Self>().to_string())
430 })?;
431 let v = V::try_from_table_attr(v)?;
432
433 Ok((k, v))
434 })
435 .collect()
436 }
437}
438
439impl<K, V> From<BTreeMap<K, V>> for TableAttribute
440where
441 K: ToString,
442 V: Into<TableAttribute>,
443{
444 fn from(map: BTreeMap<K, V>) -> Self {
445 TableAttribute::Map(
446 map.into_iter()
447 .map(|(k, v)| (k.to_string(), v.into()))
448 .collect(),
449 )
450 }
451}
452
453impl<K, V> TryFromTableAttr for BTreeMap<K, V>
454where
455 K: FromStr + std::cmp::Ord,
456 V: TryFromTableAttr,
457{
458 fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
459 let TableAttribute::Map(map) = value else {
460 return Err(ReadConversionError::ConversionFailed(
461 std::any::type_name::<Self>().to_string(),
462 ));
463 };
464
465 map.into_iter()
466 .map(|(k, v)| {
467 let k = k.parse().map_err(|_| {
468 ReadConversionError::ConversionFailed(std::any::type_name::<Self>().to_string())
469 })?;
470 let v = V::try_from_table_attr(v)?;
471
472 Ok((k, v))
473 })
474 .collect()
475 }
476}
477
#[cfg(test)]
mod test {
    use super::*;

    /// Element type whose values convert to three different `TableAttribute` variants, so a
    /// `Vec<TestType>` exercises the heterogeneous (`List`) path of `From<Vec<T>>`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum TestType {
        Number,
        String,
        Bytes,
    }

    impl From<TestType> for TableAttribute {
        fn from(value: TestType) -> Self {
            match value {
                TestType::Number => TableAttribute::Number(42.to_string()),
                TestType::String => TableAttribute::String("fourty two".to_string()),
                TestType::Bytes => TableAttribute::Bytes(b"101010".to_vec()),
            }
        }
    }

    impl TryFromTableAttr for TestType {
        fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
            match value {
                TableAttribute::Number(n) if n == "42" => Ok(Self::Number),
                TableAttribute::String(s) if s == "fourty two" => Ok(Self::String),
                TableAttribute::Bytes(b) if b == b"101010" => Ok(Self::Bytes),
                _ => Err(ReadConversionError::ConversionFailed("".to_string())),
            }
        }
    }

    /// Non-`String` map key with a `Display`/`FromStr` round trip.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
    enum MapKeys {
        A,
        B,
        C,
    }

    impl std::fmt::Display for MapKeys {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let c = match self {
                MapKeys::A => "A",
                MapKeys::B => "B",
                MapKeys::C => "C",
            };

            write!(f, "{c}")
        }
    }

    impl FromStr for MapKeys {
        type Err = ();

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(MapKeys::A),
                "B" => Ok(MapKeys::B),
                "C" => Ok(MapKeys::C),
                _ => Err(()),
            }
        }
    }

    /// One entry per `MapKeys` variant, so the map tests exercise every key rather than
    /// repeatedly overwriting a single one.
    fn sample_entries() -> Vec<(MapKeys, String)> {
        vec![
            (MapKeys::A, "Something in A".to_string()),
            (MapKeys::B, "Something in B".to_string()),
            (MapKeys::C, "Something in C".to_string()),
        ]
    }

    #[test]
    fn test_to_and_from_list() {
        let test_vec = vec![
            TestType::Number,
            TestType::Number,
            TestType::String,
            TestType::Bytes,
        ];

        let table_attribute = TableAttribute::from(test_vec.clone());

        assert!(matches!(&table_attribute, TableAttribute::List(x) if x.len() == test_vec.len()));

        let original = Vec::<TestType>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, test_vec);
    }

    #[test]
    fn test_empty_vec_is_list() {
        let table_attribute = TableAttribute::from(Vec::<String>::new());

        assert_eq!(table_attribute, TableAttribute::List(vec![]));

        let original = Vec::<String>::try_from_table_attr(table_attribute).unwrap();

        assert!(original.is_empty());
    }

    #[test]
    fn test_string_vec() {
        let test_vec = vec![
            "String0".to_string(),
            "String1".to_string(),
            "String2".to_string(),
        ];

        let table_attribute = TableAttribute::from(test_vec.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::StringVec(x)
            if x.len() == test_vec.len()
        ));

        let original = Vec::<String>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, test_vec);
    }

    #[test]
    fn test_number_vec() {
        let test_vec = vec![2, 3, 5, 7, 13];

        let table_attribute = TableAttribute::from(test_vec.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::NumberVec(x)
            if x.len() == test_vec.len()
        ));

        let original = Vec::<i32>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, test_vec);
    }

    #[test]
    fn test_bytes_vec() {
        let test_vec: Vec<Vec<u8>> = (0u8..5).map(|i| (i * 10..i * 10 + 10).collect()).collect();

        let table_attribute = TableAttribute::from(test_vec.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::ByteVec(x)
            if x.len() == test_vec.len()
        ));

        let original = Vec::<Vec<u8>>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, test_vec);
    }

    #[test]
    fn test_option() {
        assert_eq!(TableAttribute::from(None::<i32>), TableAttribute::Null);
        assert_eq!(
            TableAttribute::from(Some(7i32)),
            TableAttribute::Number("7".to_string())
        );

        assert_eq!(
            Option::<i32>::try_from_table_attr(TableAttribute::Null).unwrap(),
            None
        );
        assert_eq!(
            Option::<i32>::try_from_table_attr(TableAttribute::Number("7".to_string())).unwrap(),
            Some(7)
        );
    }

    #[test]
    fn test_hashmap() {
        let map = sample_entries().into_iter().collect::<HashMap<_, _>>();

        assert_eq!(map.len(), 3);

        let table_attribute = TableAttribute::from(map.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::Map(x)
            if x.len() == map.len()
        ));

        let original = HashMap::<MapKeys, String>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, map);
    }

    #[test]
    fn test_btreemap() {
        let map = sample_entries().into_iter().collect::<BTreeMap<_, _>>();

        assert_eq!(map.len(), 3);

        let table_attribute = TableAttribute::from(map.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::Map(x)
            if x.len() == map.len()
        ));

        let original = BTreeMap::<MapKeys, String>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, map);
    }

    #[test]
    fn test_try_insert_map() {
        let mut map = TableAttribute::new_map();

        assert!(map.try_insert_map("answer", 42i32).is_ok());

        let mut expected = HashMap::new();
        expected.insert("answer".to_string(), TableAttribute::Number("42".to_string()));
        assert_eq!(map, TableAttribute::Map(expected));

        let mut not_a_map = TableAttribute::Null;
        assert!(not_a_map.try_insert_map("answer", 42i32).is_err());
        assert_eq!(not_a_map, TableAttribute::Null);
    }

    #[test]
    fn test_as_encrypted_record_rejects_non_bytes() {
        let result = TableAttribute::Null.as_encrypted_record("descriptor");

        assert!(matches!(result, Err(SealError::AssertionFailed(_))));
    }
}
661}