cipherstash_dynamodb/encrypted_table/
query.rs1use aws_sdk_dynamodb::{primitives::Blob, types::AttributeValue};
2use cipherstash_client::{
3 credentials::{service_credentials::ServiceToken, Credentials},
4 encryption::{
5 compound_indexer::{ComposableIndex, ComposablePlaintext},
6 Encryption, Plaintext,
7 },
8};
9use itertools::Itertools;
10use std::{borrow::Cow, collections::HashMap, marker::PhantomData};
11
12use crate::{
13 traits::{Decryptable, Searchable},
14 Identifiable, IndexType, SingleIndex,
15};
16use cipherstash_client::encryption::{compound_indexer::CompoundIndex, IndexTerm};
17
18use super::{Dynamo, EncryptedTable, QueryError};
19
20pub struct QueryBuilder<S, B = ()> {
21 parts: Vec<(String, SingleIndex, Plaintext)>,
22 backend: B,
23 __searchable: PhantomData<S>,
24}
25
26pub struct PreparedQuery {
27 index_name: String,
28 type_name: String,
29 composed_index: Box<dyn ComposableIndex + Send>,
30 plaintext: ComposablePlaintext,
31}
32
33impl PreparedQuery {
34 pub async fn encrypt(
35 self,
36 cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
37 ) -> Result<AttributeValue, QueryError> {
38 let PreparedQuery {
39 index_name,
40 composed_index,
41 plaintext,
42 type_name,
43 } = self;
44
45 let index_term = cipher.compound_query(
46 &CompoundIndex::new(composed_index),
47 plaintext,
48 Some(format!("{}#{}", type_name, index_name)),
49 12,
50 )?;
51
52 let term = if let IndexTerm::Binary(x) = index_term {
54 AttributeValue::B(Blob::new(x))
55 } else {
56 Err(QueryError::Other(format!(
57 "Returned IndexTerm had invalid type: {index_term:?}"
58 )))?
59 };
60
61 Ok(term)
62 }
63
64 pub async fn send(
65 self,
66 table: &EncryptedTable<Dynamo>,
67 ) -> Result<Vec<HashMap<String, AttributeValue>>, QueryError> {
68 let term = self.encrypt(&table.cipher).await?;
69
70 let query = table
71 .db
72 .query()
73 .table_name(&table.db.table_name)
74 .index_name("TermIndex")
75 .key_condition_expression("term = :term")
76 .expression_attribute_values(":term", term);
77
78 query
79 .send()
80 .await?
81 .items
82 .ok_or_else(|| QueryError::Other("Expected items entry on aws response".into()))
83 }
84}
85
86impl<S> Default for QueryBuilder<S> {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl<S> QueryBuilder<S> {
93 pub fn new() -> Self {
94 Self {
95 parts: vec![],
96 backend: Default::default(),
97 __searchable: Default::default(),
98 }
99 }
100}
101
102impl<S, B> QueryBuilder<S, B> {
103 pub fn with_backend(backend: B) -> Self {
104 Self {
105 parts: vec![],
106 backend,
107 __searchable: Default::default(),
108 }
109 }
110
111 pub fn eq(mut self, name: impl Into<String>, plaintext: impl Into<Plaintext>) -> Self {
112 self.parts
113 .push((name.into(), SingleIndex::Exact, plaintext.into()));
114 self
115 }
116
117 pub fn starts_with(mut self, name: impl Into<String>, plaintext: impl Into<Plaintext>) -> Self {
118 self.parts
119 .push((name.into(), SingleIndex::Prefix, plaintext.into()));
120 self
121 }
122}
123
124impl<S, B> QueryBuilder<S, B>
125where
126 S: Searchable,
127{
128 pub fn build(self) -> Result<PreparedQuery, QueryError> {
129 PreparedQueryBuilder::new::<S>().build(self.parts)
130 }
131}
132
133impl<S> QueryBuilder<S, &EncryptedTable<Dynamo>>
134where
135 S: Searchable + Identifiable,
136{
137 pub async fn load<T>(self) -> Result<Vec<T>, QueryError>
138 where
139 T: Decryptable + Identifiable,
140 {
141 let table = self.backend;
142 let query = self.build()?;
143
144 let items = query.send(table).await?;
145 let results = table.decrypt_all(items).await?;
146
147 Ok(results)
148 }
149}
150
151impl<S> QueryBuilder<S, &EncryptedTable<Dynamo>>
152where
153 S: Searchable + Decryptable + Identifiable,
154{
155 pub async fn send(self) -> Result<Vec<S>, QueryError> {
156 self.load::<S>().await
157 }
158}
159
160pub struct PreparedQueryBuilder {
161 pub type_name: Cow<'static, str>,
162 pub index_by_name: fn(&str, IndexType) -> Option<Box<dyn ComposableIndex + Send>>,
163}
164
165impl PreparedQueryBuilder {
166 pub fn new<S: Searchable>() -> Self {
167 Self {
168 type_name: S::type_name(),
169 index_by_name: S::index_by_name,
170 }
171 }
172
173 pub fn build(
174 &self,
175 parts: Vec<(String, SingleIndex, Plaintext)>,
176 ) -> Result<PreparedQuery, QueryError> {
177 let items_len = parts.len();
178
179 for perm in parts.iter().permutations(items_len) {
182 let (indexes, plaintexts): (Vec<(&String, &SingleIndex)>, Vec<&Plaintext>) =
183 perm.into_iter().map(|x| ((&x.0, &x.1), &x.2)).unzip();
184
185 let index_name = indexes.iter().map(|(index_name, _)| index_name).join("#");
186
187 let mut indexes_iter = indexes.iter().map(|(_, index)| **index);
188
189 let index_type = match indexes.len() {
190 1 => IndexType::Single(indexes_iter.next().ok_or_else(|| {
191 QueryError::InvalidQuery(
192 "Expected indexes_iter to include have enough components".to_string(),
193 )
194 })?),
195
196 2 => IndexType::Compound2((
197 indexes_iter.next().ok_or_else(|| {
198 QueryError::InvalidQuery(
199 "Expected indexes_iter to include have enough components".to_string(),
200 )
201 })?,
202 indexes_iter.next().ok_or_else(|| {
203 QueryError::InvalidQuery(
204 "Expected indexes_iter to include have enough components".to_string(),
205 )
206 })?,
207 )),
208
209 x => {
210 return Err(QueryError::InvalidQuery(format!(
211 "Query included an invalid number of components: {x}"
212 )));
213 }
214 };
215
216 if let Some(composed_index) = (self.index_by_name)(index_name.as_str(), index_type) {
217 let mut plaintext = ComposablePlaintext::new(plaintexts[0].clone());
218
219 for p in plaintexts[1..].iter() {
220 plaintext = plaintext
221 .try_compose((*p).clone())
222 .expect("Failed to compose");
223 }
224
225 return Ok(PreparedQuery {
226 index_name,
227 type_name: self.type_name.to_string(),
228 plaintext,
229 composed_index,
230 });
231 }
232 }
233
234 let fields = parts.iter().map(|x| &x.0).join(",");
235
236 Err(QueryError::InvalidQuery(format!(
237 "Could not build query for fields: {fields}"
238 )))
239 }
240}