cipherstash_dynamodb/encrypted_table/table_attribute.rs

1use super::{ReadConversionError, SealError};
2use aws_sdk_dynamodb::{primitives::Blob, types::AttributeValue};
3use cipherstash_client::zerokms::EncryptedRecord;
4use std::{
5    collections::{BTreeMap, HashMap},
6    str::FromStr,
7};
8
9/// Trait for converting `TableAttribute` to `Self`
10pub trait TryFromTableAttr: Sized {
11    /// Try to convert `value` to `Self`
12    fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError>;
13}
14
/// A single attribute value of a DynamoDB item.
///
/// Every variant corresponds one-to-one with a DynamoDB attribute type; see the
/// `From<TableAttribute> for AttributeValue` conversion in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum TableAttribute {
    /// A string (`S`).
    String(String),
    /// A number (`N`), kept in its textual representation.
    Number(String),
    /// A boolean (`BOOL`).
    Bool(bool),
    /// A binary value (`B`).
    Bytes(Vec<u8>),

    /// A string set (`SS`).
    StringVec(Vec<String>),
    /// A binary set (`BS`).
    ByteVec(Vec<Vec<u8>>),
    /// A number set (`NS`), each number in its textual representation.
    NumberVec(Vec<String>),
    /// A map (`M`) from attribute names to nested values.
    Map(HashMap<String, TableAttribute>),
    /// A list (`L`) of nested values, possibly of differing types.
    List(Vec<TableAttribute>),

    /// The null value (`NULL`).
    Null,
}
30
31impl TableAttribute {
32    // TODO: Unit test this
33    /// Try to convert the `TableAttribute` to an `EncryptedRecord` if it is a `Bytes` variant.
34    /// The descriptor of the record is checked against the `descriptor` argument
35    /// (which will be verified to be the correct descriptor for the record via AAD).
36    ///
37    /// If the descriptor does not match, an error is returned and this may indicate that the record
38    /// has been tampered with (e.g. via a confused deputy attack).
39    pub(crate) fn as_encrypted_record(
40        &self,
41        descriptor: &str,
42    ) -> Result<EncryptedRecord, SealError> {
43        if let TableAttribute::Bytes(s) = self {
44            EncryptedRecord::from_slice(&s[..])
45                .map_err(|_| SealError::AssertionFailed("Could not parse EncryptedRecord".to_string()))
46                .and_then(|record| {
47                    if record.descriptor == descriptor {
48                        Ok(record)
49                    } else {
50                        Err(SealError::AssertionFailed(format!(
51                            "Expected descriptor {}, got {} - WARNING: record may have been tampered with",
52                            descriptor,
53                            record.descriptor
54                        )))
55                    }
56                })
57        } else {
58            Err(SealError::AssertionFailed(format!(
59                "Expected TableAttribute::Bytes, got {}",
60                descriptor
61            )))
62        }
63    }
64
65    pub(crate) fn new_map() -> Self {
66        TableAttribute::Map(HashMap::new())
67    }
68
69    /// Try to insert a new key-value pair if this is a map.
70    /// Returns `Ok(())` if the key-value pair was inserted, otherwise [SealError::AssertionFailed].
71    pub(crate) fn try_insert_map(
72        &mut self,
73        key: impl Into<String>,
74        value: impl Into<TableAttribute>,
75    ) -> Result<(), SealError> {
76        if let Self::Map(map) = self {
77            map.insert(key.into(), value.into());
78            Ok(())
79        } else {
80            Err(SealError::AssertionFailed(
81                "Expected TableAttribute::Map".to_string(),
82            ))
83        }
84    }
85}
86
/// Generates a `From<$ty> for TableAttribute` impl together with the matching
/// `TryFromTableAttr for $ty` impl.
///
/// * `number` mode stores the value's `to_string()` form in
///   `TableAttribute::Number` and reads it back with `str::parse`.
/// * `direct` mode wraps/unwraps the value in `TableAttribute::$variant`
///   without any transformation.
///
/// In both modes a variant mismatch or a parse failure yields
/// `ReadConversionError::ConversionFailed(stringify!($ty))`.
macro_rules! impl_try_from_table_attr_helper {
    (number, $ty:ty) => {
        impl From<$ty> for TableAttribute {
            fn from(value: $ty) -> Self {
                TableAttribute::Number(value.to_string())
            }
        }

        impl TryFromTableAttr for $ty {
            fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
                match value {
                    TableAttribute::Number(raw) => raw.parse().map_err(|_| {
                        ReadConversionError::ConversionFailed(stringify!($ty).to_string())
                    }),
                    _ => Err(ReadConversionError::ConversionFailed(
                        stringify!($ty).to_string(),
                    )),
                }
            }
        }
    };
    (direct, $ty:ty, $variant:ident) => {
        impl From<$ty> for TableAttribute {
            fn from(value: $ty) -> Self {
                TableAttribute::$variant(value)
            }
        }

        impl TryFromTableAttr for $ty {
            fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
                match value {
                    TableAttribute::$variant(inner) => Ok(inner),
                    _ => Err(ReadConversionError::ConversionFailed(
                        stringify!($ty).to_string(),
                    )),
                }
            }
        }
    };
}

/// Accepts a comma-separated list of `Type => Variant` pairs and generates both
/// conversion directions for each.
///
/// The variant `Number` is special-cased to go through the type's textual form;
/// every other variant is expected to wrap a value of exactly `Type`.
macro_rules! impl_try_from_table_attr {
    () => {};
    (, $($rest:tt)*) => {
        impl_try_from_table_attr!($($rest)*);
    };
    ($ty:ty => Number $($rest:tt)*) => {
        impl_try_from_table_attr_helper!(number, $ty);
        impl_try_from_table_attr!($($rest)*);
    };
    ($ty:ty => $variant:ident $($rest:tt)*) => {
        impl_try_from_table_attr_helper!(direct, $ty, $variant);
        impl_try_from_table_attr!($($rest)*);
    };
}
165
166// The following implementations are covered by the blanket implementation on Vec<T>
167// Vec<String> => StringVec,
168// Vec<some number type> => NumberVec,
169// Vec<Vec<u8>> => ByteVec,
170impl_try_from_table_attr!(
171    i16 => Number,
172    i32 => Number,
173    i64 => Number,
174    u16 => Number,
175    u32 => Number,
176    u64 => Number,
177    usize => Number,
178    f32 => Number,
179    f64  => Number,
180    String => String,
181    Vec<u8> => Bytes,
182    bool => Bool
183);
184
185impl From<&str> for TableAttribute {
186    fn from(value: &str) -> Self {
187        TableAttribute::String(value.to_string())
188    }
189}
190
191impl<T> TryFromTableAttr for Option<T>
192where
193    T: TryFromTableAttr,
194{
195    fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
196        if matches!(value, TableAttribute::Null) {
197            Ok(None)
198        } else {
199            Ok(Some(T::try_from_table_attr(value)?))
200        }
201    }
202}
203
204impl<T> TryFromTableAttr for Vec<T>
205where
206    T: TryFromTableAttr,
207{
208    fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
209        match value {
210            TableAttribute::StringVec(v) => v
211                .into_iter()
212                .map(TableAttribute::String)
213                .map(T::try_from_table_attr)
214                .collect(),
215            TableAttribute::ByteVec(v) => v
216                .into_iter()
217                .map(TableAttribute::Bytes)
218                .map(T::try_from_table_attr)
219                .collect(),
220            TableAttribute::NumberVec(v) => v
221                .into_iter()
222                .map(TableAttribute::Number)
223                .map(T::try_from_table_attr)
224                .collect(),
225            TableAttribute::List(v) => v.into_iter().map(T::try_from_table_attr).collect(),
226            _ => Err(ReadConversionError::ConversionFailed(
227                std::any::type_name::<Vec<T>>().to_string(),
228            )),
229        }
230    }
231}
232
233impl<T> From<Option<T>> for TableAttribute
234where
235    T: Into<TableAttribute>,
236{
237    fn from(value: Option<T>) -> Self {
238        match value {
239            Some(value) => value.into(),
240            None => TableAttribute::Null,
241        }
242    }
243}
244
245impl<T> From<Vec<T>> for TableAttribute
246where
247    T: Into<TableAttribute>,
248{
249    fn from(value: Vec<T>) -> Self {
250        // To determin whether we should produce a
251        // Ss, Ns, Bs or a regular list, we will iterate
252        // through the list and check if the all are the same
253        // variant.
254        #[derive(Clone, Copy, PartialEq, Eq)]
255        enum IsVariant {
256            // base case, we haven't looked at any elements yet.
257            Empty,
258            // Is String list
259            IsSs,
260            // Is Number list
261            IsNs,
262            // Is byte list
263            IsBs,
264            // Is mixed list
265            IsList,
266        }
267
268        let len = value.len();
269        let (table_attributes, is_variant) = value.into_iter().fold(
270            (Vec::with_capacity(len), IsVariant::Empty),
271            |(mut acc, mut is_variant), item| {
272                let table_attr = item.into();
273
274                // Don't check the variant if we already know it is a mixed list
275                if is_variant != IsVariant::IsList {
276                    match (&table_attr, is_variant) {
277                        (TableAttribute::Bytes(_), IsVariant::Empty)
278                        | (TableAttribute::Bytes(_), IsVariant::IsBs) => {
279                            is_variant = IsVariant::IsBs
280                        }
281                        (TableAttribute::Number(_), IsVariant::Empty)
282                        | (TableAttribute::Number(_), IsVariant::IsNs) => {
283                            is_variant = IsVariant::IsNs
284                        }
285                        (TableAttribute::String(_), IsVariant::Empty)
286                        | (TableAttribute::String(_), IsVariant::IsSs) => {
287                            is_variant = IsVariant::IsSs
288                        }
289                        _ => is_variant = IsVariant::IsList,
290                    }
291                }
292
293                acc.push(table_attr);
294                (acc, is_variant)
295            },
296        );
297
298        match is_variant {
299            IsVariant::IsList | IsVariant::Empty => TableAttribute::List(table_attributes),
300            IsVariant::IsSs => {
301                let strings = table_attributes
302                    .into_iter()
303                    .map(|string| {
304                        let TableAttribute::String(string) = string else {
305                            // We already checked that all the items are strings
306                            unreachable!()
307                        };
308
309                        string
310                    })
311                    .collect();
312
313                TableAttribute::StringVec(strings)
314            }
315            IsVariant::IsNs => {
316                let numbers = table_attributes
317                    .into_iter()
318                    .map(|number| {
319                        let TableAttribute::Number(number) = number else {
320                            // We already checked that all the items are numbers
321                            unreachable!()
322                        };
323
324                        number
325                    })
326                    .collect();
327
328                TableAttribute::NumberVec(numbers)
329            }
330            IsVariant::IsBs => {
331                let bytes = table_attributes
332                    .into_iter()
333                    .map(|bytes| {
334                        let TableAttribute::Bytes(bytes) = bytes else {
335                            // We already checked that all the items are bytes
336                            unreachable!()
337                        };
338
339                        bytes
340                    })
341                    .collect();
342
343                TableAttribute::ByteVec(bytes)
344            }
345        }
346    }
347}
348
349impl From<TableAttribute> for AttributeValue {
350    fn from(attribute: TableAttribute) -> Self {
351        match attribute {
352            TableAttribute::String(s) => AttributeValue::S(s),
353            TableAttribute::StringVec(s) => AttributeValue::Ss(s),
354
355            TableAttribute::Number(i) => AttributeValue::N(i),
356            TableAttribute::NumberVec(x) => AttributeValue::Ns(x),
357
358            TableAttribute::Bytes(x) => AttributeValue::B(Blob::new(x)),
359            TableAttribute::ByteVec(x) => {
360                AttributeValue::Bs(x.into_iter().map(Blob::new).collect())
361            }
362
363            TableAttribute::Bool(x) => AttributeValue::Bool(x),
364            TableAttribute::List(x) => AttributeValue::L(x.into_iter().map(|x| x.into()).collect()),
365            TableAttribute::Map(x) => {
366                AttributeValue::M(x.into_iter().map(|(k, v)| (k, v.into())).collect())
367            }
368            TableAttribute::Null => AttributeValue::Null(true),
369        }
370    }
371}
372
373impl From<AttributeValue> for TableAttribute {
374    fn from(attribute: AttributeValue) -> Self {
375        match attribute {
376            AttributeValue::S(s) => TableAttribute::String(s),
377            AttributeValue::N(n) => TableAttribute::Number(n),
378            AttributeValue::Bool(n) => TableAttribute::Bool(n),
379            AttributeValue::B(n) => TableAttribute::Bytes(n.into_inner()),
380            AttributeValue::L(l) => {
381                TableAttribute::List(l.into_iter().map(TableAttribute::from).collect())
382            }
383            AttributeValue::M(l) => TableAttribute::Map(
384                l.into_iter()
385                    .map(|(k, v)| (k, TableAttribute::from(v)))
386                    .collect(),
387            ),
388            AttributeValue::Bs(x) => {
389                TableAttribute::ByteVec(x.into_iter().map(|x| x.into_inner()).collect())
390            }
391            AttributeValue::Ss(x) => TableAttribute::StringVec(x),
392            AttributeValue::Ns(x) => TableAttribute::NumberVec(x),
393            AttributeValue::Null(_) => TableAttribute::Null,
394
395            x => panic!("Unsupported Dynamo attribute value: {x:?}"),
396        }
397    }
398}
399
400impl<K, V> From<HashMap<K, V>> for TableAttribute
401where
402    K: ToString,
403    V: Into<TableAttribute>,
404{
405    fn from(map: HashMap<K, V>) -> Self {
406        TableAttribute::Map(
407            map.into_iter()
408                .map(|(k, v)| (k.to_string(), v.into()))
409                .collect(),
410        )
411    }
412}
413
414impl<K, V> TryFromTableAttr for HashMap<K, V>
415where
416    K: FromStr + std::hash::Hash + std::cmp::Eq,
417    V: TryFromTableAttr,
418{
419    fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
420        let TableAttribute::Map(map) = value else {
421            return Err(ReadConversionError::ConversionFailed(
422                std::any::type_name::<Self>().to_string(),
423            ));
424        };
425
426        map.into_iter()
427            .map(|(k, v)| {
428                let k = k.parse().map_err(|_| {
429                    ReadConversionError::ConversionFailed(std::any::type_name::<Self>().to_string())
430                })?;
431                let v = V::try_from_table_attr(v)?;
432
433                Ok((k, v))
434            })
435            .collect()
436    }
437}
438
439impl<K, V> From<BTreeMap<K, V>> for TableAttribute
440where
441    K: ToString,
442    V: Into<TableAttribute>,
443{
444    fn from(map: BTreeMap<K, V>) -> Self {
445        TableAttribute::Map(
446            map.into_iter()
447                .map(|(k, v)| (k.to_string(), v.into()))
448                .collect(),
449        )
450    }
451}
452
453impl<K, V> TryFromTableAttr for BTreeMap<K, V>
454where
455    K: FromStr + std::cmp::Ord,
456    V: TryFromTableAttr,
457{
458    fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
459        let TableAttribute::Map(map) = value else {
460            return Err(ReadConversionError::ConversionFailed(
461                std::any::type_name::<Self>().to_string(),
462            ));
463        };
464
465        map.into_iter()
466            .map(|(k, v)| {
467                let k = k.parse().map_err(|_| {
468                    ReadConversionError::ConversionFailed(std::any::type_name::<Self>().to_string())
469                })?;
470                let v = V::try_from_table_attr(v)?;
471
472                Ok((k, v))
473            })
474            .collect()
475    }
476}
477
#[cfg(test)]
mod test {
    use super::*;

    /// Values that convert to three different scalar variants, used to build a
    /// mixed-type list.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum TestType {
        Number,
        String,
        Bytes,
    }

    impl From<TestType> for TableAttribute {
        fn from(value: TestType) -> Self {
            match value {
                TestType::Number => TableAttribute::Number(42.to_string()),
                TestType::String => TableAttribute::String("forty two".to_string()),
                TestType::Bytes => TableAttribute::Bytes(b"101010".to_vec()),
            }
        }
    }

    impl TryFromTableAttr for TestType {
        fn try_from_table_attr(value: TableAttribute) -> Result<Self, ReadConversionError> {
            match value {
                TableAttribute::Number(n) if n == "42" => Ok(Self::Number),
                TableAttribute::String(s) if s == "forty two" => Ok(Self::String),
                TableAttribute::Bytes(b) if b == b"101010" => Ok(Self::Bytes),
                _ => Err(ReadConversionError::ConversionFailed("".to_string())),
            }
        }
    }

    /// Map keys that round-trip through `Display` / `FromStr`.
    #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
    enum MapKeys {
        A,
        B,
        C,
    }

    impl std::fmt::Display for MapKeys {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let c = match self {
                MapKeys::A => "A",
                MapKeys::B => "B",
                MapKeys::C => "C",
            };

            write!(f, "{c}")
        }
    }

    impl FromStr for MapKeys {
        type Err = ();

        fn from_str(s: &str) -> Result<Self, Self::Err> {
            match s {
                "A" => Ok(MapKeys::A),
                "B" => Ok(MapKeys::B),
                "C" => Ok(MapKeys::C),
                _ => Err(()),
            }
        }
    }

    #[test]
    fn test_to_and_from_list() {
        let test_vec = vec![
            TestType::Number,
            TestType::Number,
            TestType::String,
            TestType::Bytes,
        ];

        let table_attribute = TableAttribute::from(test_vec.clone());

        // Assert that we convert to the correct variant.
        assert!(matches!(&table_attribute, TableAttribute::List(x) if x.len() == test_vec.len()));

        let original = Vec::<TestType>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, test_vec);
    }

    #[test]
    fn test_string_vec() {
        let test_vec = vec![
            "String0".to_string(),
            "String1".to_string(),
            "String2".to_string(),
        ];

        let table_attribute = TableAttribute::from(test_vec.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::StringVec(x)
            if x.len() == test_vec.len()
        ));

        let original = Vec::<String>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, test_vec);
    }

    #[test]
    fn test_number_vec() {
        let test_vec = vec![2, 3, 5, 7, 13];

        let table_attribute = TableAttribute::from(test_vec.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::NumberVec(x)
            if x.len() == test_vec.len()
        ));

        let original = Vec::<i32>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, test_vec);
    }

    #[test]
    fn test_bytes_vec() {
        let test_vec: Vec<Vec<u8>> = (0u8..5).map(|i| (i * 10..i * 10 + 10).collect()).collect();

        let table_attribute = TableAttribute::from(test_vec.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::ByteVec(x)
            if x.len() == test_vec.len()
        ));

        let original = Vec::<Vec<u8>>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, test_vec);
    }

    #[test]
    fn test_empty_vec_is_list() {
        // DynamoDB sets cannot be empty, so an empty Vec must become a list.
        let table_attribute = TableAttribute::from(Vec::<String>::new());

        assert_eq!(table_attribute, TableAttribute::List(vec![]));

        let original = Vec::<String>::try_from_table_attr(table_attribute).unwrap();

        assert!(original.is_empty());
    }

    #[test]
    fn test_hashmap() {
        // Three distinct keys so the round trip exercises more than one entry.
        let map = [
            (MapKeys::A, "Something in A".to_string()),
            (MapKeys::B, "Something in B".to_string()),
            (MapKeys::C, "Something in C".to_string()),
        ]
        .into_iter()
        .collect::<HashMap<_, _>>();

        assert_eq!(map.len(), 3);

        let table_attribute = TableAttribute::from(map.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::Map(x)
            if x.len() == map.len()
        ));

        let original = HashMap::<MapKeys, String>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, map);
    }

    #[test]
    fn test_btreemap() {
        // Three distinct keys so the round trip exercises more than one entry.
        let map = [
            (MapKeys::A, "Something in A".to_string()),
            (MapKeys::B, "Something in B".to_string()),
            (MapKeys::C, "Something in C".to_string()),
        ]
        .into_iter()
        .collect::<BTreeMap<_, _>>();

        assert_eq!(map.len(), 3);

        let table_attribute = TableAttribute::from(map.clone());

        assert!(matches!(
            &table_attribute,
            TableAttribute::Map(x)
            if x.len() == map.len()
        ));

        let original = BTreeMap::<MapKeys, String>::try_from_table_attr(table_attribute).unwrap();

        assert_eq!(original, map);
    }

    #[test]
    fn test_map_rejects_unparseable_key() {
        let table_attribute = TableAttribute::Map(HashMap::from([(
            "Z".to_string(),
            TableAttribute::String("value".to_string()),
        )]));

        assert!(HashMap::<MapKeys, String>::try_from_table_attr(table_attribute).is_err());
    }

    #[test]
    fn test_option() {
        assert_eq!(TableAttribute::from(None::<String>), TableAttribute::Null);
        assert_eq!(
            TableAttribute::from(Some("x")),
            TableAttribute::String("x".to_string())
        );
        assert_eq!(
            Option::<String>::try_from_table_attr(TableAttribute::Null).unwrap(),
            None
        );
        assert_eq!(
            Option::<String>::try_from_table_attr(TableAttribute::String("x".to_string())).unwrap(),
            Some("x".to_string())
        );
    }

    #[test]
    fn test_number_round_trip_and_mismatch() {
        assert_eq!(
            TableAttribute::from(-5i64),
            TableAttribute::Number("-5".to_string())
        );
        assert_eq!(
            i64::try_from_table_attr(TableAttribute::Number("-5".to_string())).unwrap(),
            -5
        );
        // Out of range for the target type.
        assert!(u32::try_from_table_attr(TableAttribute::Number("-5".to_string())).is_err());
        // Wrong variant.
        assert!(i32::try_from_table_attr(TableAttribute::String("5".to_string())).is_err());
    }

    #[test]
    fn test_try_insert_map() {
        let mut map = TableAttribute::new_map();
        assert!(map.try_insert_map("key", "value").is_ok());
        assert_eq!(
            map,
            TableAttribute::Map(HashMap::from([(
                "key".to_string(),
                TableAttribute::String("value".to_string())
            )]))
        );

        let mut null = TableAttribute::Null;
        assert!(null.try_insert_map("key", "value").is_err());
    }
}