cipherstash_dynamodb/encrypted_table/query.rs

1use aws_sdk_dynamodb::{primitives::Blob, types::AttributeValue};
2use cipherstash_client::{
3    credentials::{service_credentials::ServiceToken, Credentials},
4    encryption::{
5        compound_indexer::{ComposableIndex, ComposablePlaintext},
6        Encryption, Plaintext,
7    },
8};
9use itertools::Itertools;
10use std::{borrow::Cow, collections::HashMap, marker::PhantomData};
11
12use crate::{
13    traits::{Decryptable, Searchable},
14    Identifiable, IndexType, SingleIndex,
15};
16use cipherstash_client::encryption::{compound_indexer::CompoundIndex, IndexTerm};
17
18use super::{Dynamo, EncryptedTable, QueryError};
19
20pub struct QueryBuilder<S, B = ()> {
21    parts: Vec<(String, SingleIndex, Plaintext)>,
22    backend: B,
23    __searchable: PhantomData<S>,
24}
25
26pub struct PreparedQuery {
27    index_name: String,
28    type_name: String,
29    composed_index: Box<dyn ComposableIndex + Send>,
30    plaintext: ComposablePlaintext,
31}
32
33impl PreparedQuery {
34    pub async fn encrypt(
35        self,
36        cipher: &Encryption<impl Credentials<Token = ServiceToken>>,
37    ) -> Result<AttributeValue, QueryError> {
38        let PreparedQuery {
39            index_name,
40            composed_index,
41            plaintext,
42            type_name,
43        } = self;
44
45        let index_term = cipher.compound_query(
46            &CompoundIndex::new(composed_index),
47            plaintext,
48            Some(format!("{}#{}", type_name, index_name)),
49            12,
50        )?;
51
52        // With DynamoDB queries must always return a single term
53        let term = if let IndexTerm::Binary(x) = index_term {
54            AttributeValue::B(Blob::new(x))
55        } else {
56            Err(QueryError::Other(format!(
57                "Returned IndexTerm had invalid type: {index_term:?}"
58            )))?
59        };
60
61        Ok(term)
62    }
63
64    pub async fn send(
65        self,
66        table: &EncryptedTable<Dynamo>,
67    ) -> Result<Vec<HashMap<String, AttributeValue>>, QueryError> {
68        let term = self.encrypt(&table.cipher).await?;
69
70        let query = table
71            .db
72            .query()
73            .table_name(&table.db.table_name)
74            .index_name("TermIndex")
75            .key_condition_expression("term = :term")
76            .expression_attribute_values(":term", term);
77
78        query
79            .send()
80            .await?
81            .items
82            .ok_or_else(|| QueryError::Other("Expected items entry on aws response".into()))
83    }
84}
85
86impl<S> Default for QueryBuilder<S> {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92impl<S> QueryBuilder<S> {
93    pub fn new() -> Self {
94        Self {
95            parts: vec![],
96            backend: Default::default(),
97            __searchable: Default::default(),
98        }
99    }
100}
101
102impl<S, B> QueryBuilder<S, B> {
103    pub fn with_backend(backend: B) -> Self {
104        Self {
105            parts: vec![],
106            backend,
107            __searchable: Default::default(),
108        }
109    }
110
111    pub fn eq(mut self, name: impl Into<String>, plaintext: impl Into<Plaintext>) -> Self {
112        self.parts
113            .push((name.into(), SingleIndex::Exact, plaintext.into()));
114        self
115    }
116
117    pub fn starts_with(mut self, name: impl Into<String>, plaintext: impl Into<Plaintext>) -> Self {
118        self.parts
119            .push((name.into(), SingleIndex::Prefix, plaintext.into()));
120        self
121    }
122}
123
124impl<S, B> QueryBuilder<S, B>
125where
126    S: Searchable,
127{
128    pub fn build(self) -> Result<PreparedQuery, QueryError> {
129        PreparedQueryBuilder::new::<S>().build(self.parts)
130    }
131}
132
133impl<S> QueryBuilder<S, &EncryptedTable<Dynamo>>
134where
135    S: Searchable + Identifiable,
136{
137    pub async fn load<T>(self) -> Result<Vec<T>, QueryError>
138    where
139        T: Decryptable + Identifiable,
140    {
141        let table = self.backend;
142        let query = self.build()?;
143
144        let items = query.send(table).await?;
145        let results = table.decrypt_all(items).await?;
146
147        Ok(results)
148    }
149}
150
151impl<S> QueryBuilder<S, &EncryptedTable<Dynamo>>
152where
153    S: Searchable + Decryptable + Identifiable,
154{
155    pub async fn send(self) -> Result<Vec<S>, QueryError> {
156        self.load::<S>().await
157    }
158}
159
160pub struct PreparedQueryBuilder {
161    pub type_name: Cow<'static, str>,
162    pub index_by_name: fn(&str, IndexType) -> Option<Box<dyn ComposableIndex + Send>>,
163}
164
165impl PreparedQueryBuilder {
166    pub fn new<S: Searchable>() -> Self {
167        Self {
168            type_name: S::type_name(),
169            index_by_name: S::index_by_name,
170        }
171    }
172
173    pub fn build(
174        &self,
175        parts: Vec<(String, SingleIndex, Plaintext)>,
176    ) -> Result<PreparedQuery, QueryError> {
177        let items_len = parts.len();
178
179        // this is the simplest way to brute force the index names but relies on some gross
180        // stringly typing which doesn't feel good
181        for perm in parts.iter().permutations(items_len) {
182            let (indexes, plaintexts): (Vec<(&String, &SingleIndex)>, Vec<&Plaintext>) =
183                perm.into_iter().map(|x| ((&x.0, &x.1), &x.2)).unzip();
184
185            let index_name = indexes.iter().map(|(index_name, _)| index_name).join("#");
186
187            let mut indexes_iter = indexes.iter().map(|(_, index)| **index);
188
189            let index_type = match indexes.len() {
190                1 => IndexType::Single(indexes_iter.next().ok_or_else(|| {
191                    QueryError::InvalidQuery(
192                        "Expected indexes_iter to include have enough components".to_string(),
193                    )
194                })?),
195
196                2 => IndexType::Compound2((
197                    indexes_iter.next().ok_or_else(|| {
198                        QueryError::InvalidQuery(
199                            "Expected indexes_iter to include have enough components".to_string(),
200                        )
201                    })?,
202                    indexes_iter.next().ok_or_else(|| {
203                        QueryError::InvalidQuery(
204                            "Expected indexes_iter to include have enough components".to_string(),
205                        )
206                    })?,
207                )),
208
209                x => {
210                    return Err(QueryError::InvalidQuery(format!(
211                        "Query included an invalid number of components: {x}"
212                    )));
213                }
214            };
215
216            if let Some(composed_index) = (self.index_by_name)(index_name.as_str(), index_type) {
217                let mut plaintext = ComposablePlaintext::new(plaintexts[0].clone());
218
219                for p in plaintexts[1..].iter() {
220                    plaintext = plaintext
221                        .try_compose((*p).clone())
222                        .expect("Failed to compose");
223                }
224
225                return Ok(PreparedQuery {
226                    index_name,
227                    type_name: self.type_name.to_string(),
228                    plaintext,
229                    composed_index,
230                });
231            }
232        }
233
234        let fields = parts.iter().map(|x| &x.0).join(",");
235
236        Err(QueryError::InvalidQuery(format!(
237            "Could not build query for fields: {fields}"
238        )))
239    }
240}