// cipherstash_dynamodb/crypto/sealer.rs

1use super::{
2    attrs::FlattenedProtectedAttributes, b64_encode, format_term_key, hmac, SealError,
3    SealedTableEntry, Unsealed, MAX_TERMS_PER_INDEX,
4};
5use crate::{
6    encrypted_table::{AttributeName, TableAttribute, TableAttributes, TableEntry},
7    traits::PrimaryKeyParts,
8    IndexType,
9};
10use cipherstash_client::{
11    credentials::{service_credentials::ServiceToken, Credentials},
12    encryption::{
13        compound_indexer::{ComposableIndex, ComposablePlaintext, CompoundIndex},
14        Encryption, IndexTerm,
15    },
16};
17use itertools::Itertools;
18use std::{borrow::Cow, collections::HashMap};
19
20/// The combination of plaintext, index, name and index type for a particular field
21pub type UnsealedIndex = (
22    ComposablePlaintext,
23    Box<dyn ComposableIndex + Send>,
24    Cow<'static, str>,
25    IndexType,
26);
27
28/// Builder pattern for sealing a record of type, `T`.
29pub struct Sealer {
30    pub(crate) pk: String,
31    pub(crate) sk: String,
32
33    pub(crate) is_pk_encrypted: bool,
34    pub(crate) is_sk_encrypted: bool,
35
36    pub(crate) type_name: Cow<'static, str>,
37
38    pub(crate) unsealed_indexes: Vec<UnsealedIndex>,
39
40    pub(crate) unsealed: Unsealed,
41}
42
43struct RecordsWithTerms {
44    num_protected_attributes: usize,
45    records: Vec<RecordWithTerms>,
46}
47
48impl RecordsWithTerms {
49    fn new(records: Vec<RecordWithTerms>, num_protected_attributes: usize) -> Self {
50        Self {
51            num_protected_attributes,
52            records,
53        }
54    }
55
56    async fn encrypt(
57        self,
58        cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
59    ) -> Result<Vec<Sealed>, SealError> {
60        let num_records = self.records.len();
61        let mut pksks = Vec::with_capacity(num_records);
62        let mut record_terms = Vec::with_capacity(num_records);
63        let mut unprotecteds = Vec::with_capacity(num_records);
64        let mut protected = FlattenedProtectedAttributes::new_with_capacity(
65            num_records * self.num_protected_attributes,
66        );
67
68        for sealer_with_terms in self.records {
69            let (pksk, terms, flattened_protected, unprotected) = sealer_with_terms.into_parts();
70
71            pksks.push(pksk);
72            record_terms.push(terms);
73            unprotecteds.push(unprotected);
74            protected.extend(flattened_protected.into_iter());
75        }
76
77        // TODO: Split this out into separate functions and/or implement From for the tuple into Sealed
78        if protected.is_empty() {
79            unprotecteds
80                .into_iter()
81                .zip_eq(record_terms.into_iter())
82                .zip_eq(pksks.into_iter())
83                .map(|record| {
84                    let (attributes, terms, pksk) = flatten_tuple_3(record);
85                    Ok(Sealed {
86                        pk: pksk.pk,
87                        sk: pksk.sk,
88                        attributes,
89                        terms,
90                    })
91                })
92                .collect()
93        } else {
94            let encrypted = protected.encrypt_all(cipher, num_records).await?;
95
96            encrypted
97                .into_iter()
98                .zip_eq(unprotecteds.into_iter())
99                .zip_eq(record_terms.into_iter())
100                .zip_eq(pksks.into_iter())
101                .map(|record| {
102                    let (enc_attrs, unprotecteds, terms, pksk) = flatten_tuple_4(record);
103                    enc_attrs.denormalize().map(|protected_attrs| Sealed {
104                        pk: pksk.pk,
105                        sk: pksk.sk,
106                        attributes: unprotecteds.merge(protected_attrs),
107                        terms,
108                    })
109                })
110                .collect()
111        }
112    }
113}
114
115struct RecordWithTerms {
116    pksk: PrimaryKeyParts,
117    unsealed: Unsealed,
118    terms: Vec<Term>,
119}
120
121impl RecordWithTerms {
122    fn into_parts(
123        self,
124    ) -> (
125        PrimaryKeyParts,
126        Vec<Term>,
127        FlattenedProtectedAttributes,
128        TableAttributes,
129    ) {
130        let (flattened_protected, unprotected) = self.unsealed.flatten_into_parts();
131        (self.pksk, self.terms, flattened_protected, unprotected)
132    }
133}
134
135impl Sealer {
136    fn index_all_terms<'a>(
137        records: impl IntoIterator<Item = Sealer>,
138        protected_attributes: impl AsRef<[Cow<'a, str>]>,
139        cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
140        term_length: usize,
141    ) -> Result<RecordsWithTerms, SealError> {
142        let protected_attributes = protected_attributes.as_ref();
143        let num_protected_attributes = protected_attributes.len();
144
145        records
146            .into_iter()
147            .map(|sealer| {
148                let mut pk = sealer.pk;
149                let mut sk = sealer.sk;
150
151                if sealer.is_pk_encrypted {
152                    pk = b64_encode(hmac(&pk, None, cipher)?);
153                }
154
155                if sealer.is_sk_encrypted {
156                    sk = b64_encode(hmac(&sk, Some(pk.as_str()), cipher)?);
157                }
158
159                let type_name = &sealer.type_name;
160
161                // Index name, type and term
162                let terms: Vec<(Cow<'_, str>, IndexType, Vec<u8>)> = sealer
163                    .unsealed_indexes
164                    .into_iter()
165                    .map(|(attr, index, index_name, index_type)| {
166                        let term = cipher.compound_index(
167                            &CompoundIndex::new(index),
168                            attr,
169                            Some(format!("{}#{}", type_name, index_name)),
170                            term_length,
171                        )?;
172
173                        Ok::<_, SealError>((index_name, index_type, term))
174                    })
175                    .map(|index_term| match index_term {
176                        Ok((index_name, index_type, IndexTerm::Binary(x))) => {
177                            Ok(vec![(index_name, index_type, x)])
178                        }
179                        Ok((index_name, index_type, IndexTerm::BinaryVec(x))) => Ok(x
180                            .into_iter()
181                            .take(MAX_TERMS_PER_INDEX)
182                            .map(|x| (index_name.clone(), index_type, x))
183                            .collect()),
184                        _ => Err(SealError::InvalidCiphertext("Invalid index term".into())),
185                    })
186                    .flatten_ok()
187                    .try_collect()?;
188
189                let terms = terms
190                    .into_iter()
191                    .enumerate()
192                    .map(|(i, (index_name, index_type, value))| {
193                        let sk = b64_encode(hmac(
194                            &format_term_key(sk.as_str(), &index_name, index_type, i),
195                            Some(pk.as_str()),
196                            cipher,
197                        )?);
198
199                        Ok::<_, SealError>(Term { sk, value })
200                    })
201                    .collect::<Result<Vec<Term>, _>>()?;
202
203                Ok(RecordWithTerms {
204                    pksk: PrimaryKeyParts { pk, sk },
205                    unsealed: sealer.unsealed,
206                    terms,
207                })
208            })
209            .try_collect()
210            .map(|records| RecordsWithTerms::new(records, num_protected_attributes))
211    }
212
213    pub(crate) async fn seal_all<'a>(
214        records: impl IntoIterator<Item = Sealer>,
215        protected_attributes: impl AsRef<[Cow<'a, str>]>,
216        cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
217        term_length: usize,
218    ) -> Result<Vec<Sealed>, SealError> {
219        Self::index_all_terms(records, protected_attributes, cipher, term_length)?
220            .encrypt(cipher)
221            .await
222    }
223
224    pub(crate) async fn seal<'a>(
225        self,
226        protected_attributes: impl AsRef<[Cow<'a, str>]>,
227        cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
228        term_length: usize,
229    ) -> Result<Sealed, SealError> {
230        let mut vec = Self::seal_all([self], protected_attributes, cipher, term_length).await?;
231
232        if vec.len() != 1 {
233            let actual = vec.len();
234
235            return Err(SealError::AssertionFailed(format!(
236                "Expected seal_all to return 1 result but got {actual}"
237            )));
238        }
239
240        Ok(vec.remove(0))
241    }
242}
243
/// A single index term: the sort key it is stored under and its raw term bytes.
// FIXME: Remove this (only used for debugging)
#[derive(Debug)]
struct Term {
    /// Derived sort key for this term's table entry.
    sk: String,
    /// Raw term bytes, stored as the term entry's value.
    value: Vec<u8>,
}
250
251// FIXME: This struct is almost _identical_ to the one in encrypted_table/table_entry.rs
252pub struct Sealed {
253    pk: String,
254    sk: String,
255    attributes: TableAttributes,
256    terms: Vec<Term>,
257}
258
259impl Sealed {
260    pub fn len(&self) -> usize {
261        // the length of the terms plus the root entry
262        self.terms.len() + 1
263    }
264
265    pub fn primary_key(&self) -> PrimaryKeyParts {
266        PrimaryKeyParts {
267            pk: self.pk.clone(),
268            sk: self.sk.clone(),
269        }
270    }
271
272    /// Returns the root entry and the term entries for this record.
273    /// `index_predicate` is used to... TODO!!!.
274    pub fn into_table_entries(
275        self,
276        mut index_predicate: impl FnMut(&AttributeName, &TableAttribute) -> bool,
277    ) -> (SealedTableEntry, Vec<SealedTableEntry>) {
278        let root_attributes = self.attributes;
279
280        let index_attributes: TableAttributes = root_attributes
281            .clone()
282            .into_iter()
283            .filter(|(name, value)| index_predicate(name, value))
284            .map(|(name, value)| (name, value.clone()))
285            .collect::<HashMap<_, _>>()
286            .into();
287
288        let term_entries = self
289            .terms
290            .into_iter()
291            .map(|Term { sk, value }| {
292                SealedTableEntry(TableEntry::new_with_attributes(
293                    self.pk.clone(),
294                    sk,
295                    Some(value),
296                    index_attributes.clone(),
297                ))
298            })
299            .collect();
300
301        (
302            SealedTableEntry(TableEntry::new_with_attributes(
303                self.pk,
304                self.sk,
305                None,
306                root_attributes,
307            )),
308            term_entries,
309        )
310    }
311}
312
313// TODO: Move these somewhere else
314#[inline]
315fn flatten_tuple_4<A, B, C, D>((((a, b), c), d): (((A, B), C), D)) -> (A, B, C, D) {
316    (a, b, c, d)
317}
318
/// Flattens the left-nested pairs produced by two chained `zip`/`zip_eq` calls,
/// `((a, b), c)`, into the flat tuple `(a, b, c)`.
#[inline]
fn flatten_tuple_3<A, B, C>(nested: ((A, B), C)) -> (A, B, C) {
    let ((first, second), third) = nested;
    (first, second, third)
}