1use super::{
2 attrs::FlattenedProtectedAttributes, b64_encode, format_term_key, hmac, SealError,
3 SealedTableEntry, Unsealed, MAX_TERMS_PER_INDEX,
4};
5use crate::{
6 encrypted_table::{AttributeName, TableAttribute, TableAttributes, TableEntry},
7 traits::PrimaryKeyParts,
8 IndexType,
9};
10use cipherstash_client::{
11 credentials::{service_credentials::ServiceToken, Credentials},
12 encryption::{
13 compound_indexer::{ComposableIndex, ComposablePlaintext, CompoundIndex},
14 Encryption, IndexTerm,
15 },
16};
17use itertools::Itertools;
18use std::{borrow::Cow, collections::HashMap};
19
/// One index to compute terms for while sealing a record.
///
/// Elements, in order (as destructured by `Sealer::index_all_terms`):
/// 1. the plaintext fed to the index,
/// 2. the index implementation (wrapped in `CompoundIndex::new` before use),
/// 3. the index name, used both in the `"{type_name}#{index_name}"` context string and in
///    the term key passed to `format_term_key`,
/// 4. the index type, also passed to `format_term_key`.
pub type UnsealedIndex = (
    ComposablePlaintext,
    Box<dyn ComposableIndex + Send>,
    Cow<'static, str>,
    IndexType,
);
27
/// A single record prepared for sealing: its primary key, the indexes to compute terms
/// for, and the attribute values that are later split into protected and unprotected parts.
pub struct Sealer {
    /// Partition key. Replaced by the base64-encoded output of `hmac` when
    /// `is_pk_encrypted` is set.
    pub(crate) pk: String,
    /// Sort key. Replaced by the base64-encoded output of `hmac` (with the processed
    /// partition key as extra input) when `is_sk_encrypted` is set.
    pub(crate) sk: String,

    /// Whether `pk` is passed through `hmac` during sealing.
    pub(crate) is_pk_encrypted: bool,
    /// Whether `sk` is passed through `hmac` during sealing.
    pub(crate) is_sk_encrypted: bool,

    /// Record type name; prefixed to each index name (`"{type_name}#{index_name}"`) to
    /// form the context string used when computing index terms.
    pub(crate) type_name: Cow<'static, str>,

    /// Indexes whose terms are computed for this record.
    pub(crate) unsealed_indexes: Vec<UnsealedIndex>,

    /// Attribute values; split via `flatten_into_parts` into protected attributes (which
    /// are encrypted) and unprotected attributes.
    pub(crate) unsealed: Unsealed,
}
42
/// A batch of records whose keys and index terms have been computed but whose protected
/// attributes have not yet been encrypted.
struct RecordsWithTerms {
    /// Number of protected attributes per record; only used to size the buffer of
    /// attributes to encrypt.
    num_protected_attributes: usize,
    /// The indexed records, in input order.
    records: Vec<RecordWithTerms>,
}
47
48impl RecordsWithTerms {
49 fn new(records: Vec<RecordWithTerms>, num_protected_attributes: usize) -> Self {
50 Self {
51 num_protected_attributes,
52 records,
53 }
54 }
55
56 async fn encrypt(
57 self,
58 cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
59 ) -> Result<Vec<Sealed>, SealError> {
60 let num_records = self.records.len();
61 let mut pksks = Vec::with_capacity(num_records);
62 let mut record_terms = Vec::with_capacity(num_records);
63 let mut unprotecteds = Vec::with_capacity(num_records);
64 let mut protected = FlattenedProtectedAttributes::new_with_capacity(
65 num_records * self.num_protected_attributes,
66 );
67
68 for sealer_with_terms in self.records {
69 let (pksk, terms, flattened_protected, unprotected) = sealer_with_terms.into_parts();
70
71 pksks.push(pksk);
72 record_terms.push(terms);
73 unprotecteds.push(unprotected);
74 protected.extend(flattened_protected.into_iter());
75 }
76
77 if protected.is_empty() {
79 unprotecteds
80 .into_iter()
81 .zip_eq(record_terms.into_iter())
82 .zip_eq(pksks.into_iter())
83 .map(|record| {
84 let (attributes, terms, pksk) = flatten_tuple_3(record);
85 Ok(Sealed {
86 pk: pksk.pk,
87 sk: pksk.sk,
88 attributes,
89 terms,
90 })
91 })
92 .collect()
93 } else {
94 let encrypted = protected.encrypt_all(cipher, num_records).await?;
95
96 encrypted
97 .into_iter()
98 .zip_eq(unprotecteds.into_iter())
99 .zip_eq(record_terms.into_iter())
100 .zip_eq(pksks.into_iter())
101 .map(|record| {
102 let (enc_attrs, unprotecteds, terms, pksk) = flatten_tuple_4(record);
103 enc_attrs.denormalize().map(|protected_attrs| Sealed {
104 pk: pksk.pk,
105 sk: pksk.sk,
106 attributes: unprotecteds.merge(protected_attrs),
107 terms,
108 })
109 })
110 .collect()
111 }
112 }
113}
114
/// A single record after key processing and index-term computation, before its protected
/// attributes are encrypted.
struct RecordWithTerms {
    /// Primary key parts, already passed through `hmac` where the sealer requested it.
    pksk: PrimaryKeyParts,
    /// The record's attribute values, still unencrypted.
    unsealed: Unsealed,
    /// Index terms computed for this record.
    terms: Vec<Term>,
}
120
121impl RecordWithTerms {
122 fn into_parts(
123 self,
124 ) -> (
125 PrimaryKeyParts,
126 Vec<Term>,
127 FlattenedProtectedAttributes,
128 TableAttributes,
129 ) {
130 let (flattened_protected, unprotected) = self.unsealed.flatten_into_parts();
131 (self.pksk, self.terms, flattened_protected, unprotected)
132 }
133}
134
135impl Sealer {
136 fn index_all_terms<'a>(
137 records: impl IntoIterator<Item = Sealer>,
138 protected_attributes: impl AsRef<[Cow<'a, str>]>,
139 cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
140 term_length: usize,
141 ) -> Result<RecordsWithTerms, SealError> {
142 let protected_attributes = protected_attributes.as_ref();
143 let num_protected_attributes = protected_attributes.len();
144
145 records
146 .into_iter()
147 .map(|sealer| {
148 let mut pk = sealer.pk;
149 let mut sk = sealer.sk;
150
151 if sealer.is_pk_encrypted {
152 pk = b64_encode(hmac(&pk, None, cipher)?);
153 }
154
155 if sealer.is_sk_encrypted {
156 sk = b64_encode(hmac(&sk, Some(pk.as_str()), cipher)?);
157 }
158
159 let type_name = &sealer.type_name;
160
161 let terms: Vec<(Cow<'_, str>, IndexType, Vec<u8>)> = sealer
163 .unsealed_indexes
164 .into_iter()
165 .map(|(attr, index, index_name, index_type)| {
166 let term = cipher.compound_index(
167 &CompoundIndex::new(index),
168 attr,
169 Some(format!("{}#{}", type_name, index_name)),
170 term_length,
171 )?;
172
173 Ok::<_, SealError>((index_name, index_type, term))
174 })
175 .map(|index_term| match index_term {
176 Ok((index_name, index_type, IndexTerm::Binary(x))) => {
177 Ok(vec![(index_name, index_type, x)])
178 }
179 Ok((index_name, index_type, IndexTerm::BinaryVec(x))) => Ok(x
180 .into_iter()
181 .take(MAX_TERMS_PER_INDEX)
182 .map(|x| (index_name.clone(), index_type, x))
183 .collect()),
184 _ => Err(SealError::InvalidCiphertext("Invalid index term".into())),
185 })
186 .flatten_ok()
187 .try_collect()?;
188
189 let terms = terms
190 .into_iter()
191 .enumerate()
192 .map(|(i, (index_name, index_type, value))| {
193 let sk = b64_encode(hmac(
194 &format_term_key(sk.as_str(), &index_name, index_type, i),
195 Some(pk.as_str()),
196 cipher,
197 )?);
198
199 Ok::<_, SealError>(Term { sk, value })
200 })
201 .collect::<Result<Vec<Term>, _>>()?;
202
203 Ok(RecordWithTerms {
204 pksk: PrimaryKeyParts { pk, sk },
205 unsealed: sealer.unsealed,
206 terms,
207 })
208 })
209 .try_collect()
210 .map(|records| RecordsWithTerms::new(records, num_protected_attributes))
211 }
212
213 pub(crate) async fn seal_all<'a>(
214 records: impl IntoIterator<Item = Sealer>,
215 protected_attributes: impl AsRef<[Cow<'a, str>]>,
216 cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
217 term_length: usize,
218 ) -> Result<Vec<Sealed>, SealError> {
219 Self::index_all_terms(records, protected_attributes, cipher, term_length)?
220 .encrypt(cipher)
221 .await
222 }
223
224 pub(crate) async fn seal<'a>(
225 self,
226 protected_attributes: impl AsRef<[Cow<'a, str>]>,
227 cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
228 term_length: usize,
229 ) -> Result<Sealed, SealError> {
230 let mut vec = Self::seal_all([self], protected_attributes, cipher, term_length).await?;
231
232 if vec.len() != 1 {
233 let actual = vec.len();
234
235 return Err(SealError::AssertionFailed(format!(
236 "Expected seal_all to return 1 result but got {actual}"
237 )));
238 }
239
240 Ok(vec.remove(0))
241 }
242}
243
/// A single index term belonging to a sealed record.
#[derive(Debug)]
struct Term {
    /// Sort key under which this term is stored: the base64-encoded `hmac` of the key built
    /// by `format_term_key` (record sort key, index name, index type, term position), with
    /// the record's partition key as extra input.
    sk: String,
    /// Raw term bytes produced by the index.
    value: Vec<u8>,
}
250
/// A fully sealed record, ready to be turned into table entries.
pub struct Sealed {
    /// Partition key (already passed through `hmac` if the sealer requested it).
    pk: String,
    /// Sort key of the root entry (already passed through `hmac` if requested).
    sk: String,
    /// Unprotected attributes, merged with the encrypted protected attributes when the
    /// record had any.
    attributes: TableAttributes,
    /// Index terms; each becomes its own table entry.
    terms: Vec<Term>,
}
258
259impl Sealed {
260 pub fn len(&self) -> usize {
261 self.terms.len() + 1
263 }
264
265 pub fn primary_key(&self) -> PrimaryKeyParts {
266 PrimaryKeyParts {
267 pk: self.pk.clone(),
268 sk: self.sk.clone(),
269 }
270 }
271
272 pub fn into_table_entries(
275 self,
276 mut index_predicate: impl FnMut(&AttributeName, &TableAttribute) -> bool,
277 ) -> (SealedTableEntry, Vec<SealedTableEntry>) {
278 let root_attributes = self.attributes;
279
280 let index_attributes: TableAttributes = root_attributes
281 .clone()
282 .into_iter()
283 .filter(|(name, value)| index_predicate(name, value))
284 .map(|(name, value)| (name, value.clone()))
285 .collect::<HashMap<_, _>>()
286 .into();
287
288 let term_entries = self
289 .terms
290 .into_iter()
291 .map(|Term { sk, value }| {
292 SealedTableEntry(TableEntry::new_with_attributes(
293 self.pk.clone(),
294 sk,
295 Some(value),
296 index_attributes.clone(),
297 ))
298 })
299 .collect();
300
301 (
302 SealedTableEntry(TableEntry::new_with_attributes(
303 self.pk,
304 self.sk,
305 None,
306 root_attributes,
307 )),
308 term_entries,
309 )
310 }
311}
312
/// Flattens the left-nested tuple produced by three chained `zip`s into a flat 4-tuple.
#[inline]
fn flatten_tuple_4<A, B, C, D>(nested: (((A, B), C), D)) -> (A, B, C, D) {
    let (((first, second), third), fourth) = nested;
    (first, second, third, fourth)
}
318
/// Flattens the left-nested tuple produced by two chained `zip`s into a flat triple.
#[inline]
fn flatten_tuple_3<A, B, C>(nested: ((A, B), C)) -> (A, B, C) {
    let ((first, second), third) = nested;
    (first, second, third)
}