kassandra_shared/
db.rs

1//! Shared types to be stored in the host databases
2
3use borsh::{BorshDeserialize, BorshSerialize};
4use chacha20poly1305::Key;
5use core::fmt::Formatter;
6use serde::de::{Error, Visitor};
7use serde::{Deserialize, Deserializer, Serialize, Serializer};
8
9/// A wrapper around a ChaCha key
10///
11/// Used to encrypted enclave responses for users
12#[derive(Debug, Clone)]
13pub struct EncKey(Key);
14
15impl EncKey {
16    /// Get the hash of this key
17    pub fn hash(&self) -> alloc::string::String {
18        use sha2::Digest;
19        let mut hasher = sha2::Sha256::new();
20        hasher.update(self.0.as_slice());
21        let hash: [u8; 32] = hasher.finalize().into();
22        hex::encode(hash)
23    }
24}
25
26impl From<Key> for EncKey {
27    fn from(key: Key) -> Self {
28        Self(key)
29    }
30}
31
32impl<'a> From<&'a EncKey> for &'a Key {
33    fn from(key: &'a EncKey) -> Self {
34        &key.0
35    }
36}
37
38impl Serialize for EncKey {
39    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
40    where
41        S: Serializer,
42    {
43        serializer.serialize_bytes(self.0.as_slice())
44    }
45}
46
47impl<'de> Deserialize<'de> for EncKey {
48    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
49    where
50        D: Deserializer<'de>,
51    {
52        struct EncKeyVisitor;
53        impl Visitor<'_> for EncKeyVisitor {
54            type Value = EncKey;
55
56            fn expecting(&self, formatter: &mut Formatter) -> core::fmt::Result {
57                formatter.write_str("32 bytes")
58            }
59
60            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
61            where
62                E: Error,
63            {
64                let bytes: [u8; 32] = v
65                    .try_into()
66                    .map_err(|_| Error::custom("Unexpected length of encryption key"))?;
67                Ok(EncKey(*Key::from_slice(&bytes)))
68            }
69        }
70
71        deserializer.deserialize_bytes(EncKeyVisitor)
72    }
73}
74
75/// Simplified domain type for indexing a Tx on chain
76#[derive(
77    Debug,
78    Copy,
79    Clone,
80    Hash,
81    PartialEq,
82    Eq,
83    PartialOrd,
84    Ord,
85    Serialize,
86    Deserialize,
87    BorshSerialize,
88    BorshDeserialize,
89)]
90pub struct Index {
91    pub height: u64,
92    pub tx: u32,
93}
94
95impl Index {
96    pub fn as_bytes(&self) -> [u8; 12] {
97        let mut bytes = [0u8; 12];
98        let h_bytes = self.height.to_le_bytes();
99        let tx_bytes = self.tx.to_le_bytes();
100        for ix in 0..12 {
101            if ix < 8 {
102                bytes[ix] = h_bytes[ix];
103            } else {
104                bytes[ix] = tx_bytes[ix - 8];
105            }
106        }
107        bytes
108    }
109
110    pub fn try_from_bytes(bytes: &[u8]) -> Option<Self> {
111        if bytes.len() != 12 {
112            None
113        } else {
114            let mut h_bytes = [0u8; 8];
115            let mut tx_bytes = [0u8; 4];
116            for ix in 0..12 {
117                if ix < 8 {
118                    h_bytes[ix] = bytes[ix];
119                } else {
120                    tx_bytes[ix - 8] = bytes[ix];
121                }
122            }
123            Some(Self {
124                height: u64::from_le_bytes(h_bytes),
125                tx: u32::from_le_bytes(tx_bytes),
126            })
127        }
128    }
129}
130
131#[derive(
132    Debug,
133    Clone,
134    Hash,
135    PartialEq,
136    Eq,
137    Default,
138    Serialize,
139    Deserialize,
140    BorshSerialize,
141    BorshDeserialize,
142)]
143pub struct IndexList(alloc::vec::Vec<Index>);
144
145impl IndexList {
146    /// Try to parse bytes into a list of indices
147    pub fn try_from_bytes(bytes: &[u8]) -> Option<Self> {
148        if 12 * (bytes.len() / 12) != bytes.len() {
149            return None;
150        }
151        let len = bytes.len() / 12;
152        let indices: alloc::vec::Vec<_> =
153            bytes.chunks(12).filter_map(Index::try_from_bytes).collect();
154        if indices.len() != len {
155            None
156        } else {
157            Some(Self(indices))
158        }
159    }
160
161    /// Given two index sets, produce a new index set
162    /// modifying self in-place.
163    ///
164    /// We assume `self` is synced ahead of `other`. The intersection of
165    /// the indices up to the common block height is kept along with all
166    /// indices of `self` with block height greater than `other`'s maximum
167    /// block height.
168    pub fn combine(&mut self, mut other: Self) {
169        if self.0.is_empty() {
170            *self = other;
171            return;
172        }
173        if other.0.is_empty() {
174            return;
175        }
176        self.0.sort();
177        other.0.sort();
178        let a_height = self.0.last().map(|ix| ix.height).unwrap_or_default();
179        let b_height = other.0.last().map(|ix| ix.height).unwrap_or_default();
180        // from here on out, we assume that `self` is synced further than `other`
181        let height = if a_height < b_height {
182            core::mem::swap(self, &mut other);
183            a_height
184        } else {
185            b_height
186        };
187        self.0.retain(|ix| {
188            if ix.height > height {
189                true
190            } else {
191                other.contains(ix)
192            }
193        });
194    }
195
196    /// Create a union of two index sets
197    pub fn union(&mut self, other: &Self) {
198        self.0.extend_from_slice(&other.0[..]);
199        self.0.sort();
200        self.0.dedup();
201    }
202
203    /// Check if an index is contained in `self`
204    /// Assumes `self` is sorted.
205    pub fn contains(&self, index: &Index) -> bool {
206        self.0.binary_search(index).is_ok()
207    }
208
209    /// Check if the index set contains a given height.
210    /// Assumes `self` is sorted.
211    pub fn contains_height(&self, height: u64) -> bool {
212        self.0.binary_search_by_key(&height, |ix| ix.height).is_ok()
213    }
214
215    /// Return and iterator of references to the
216    /// contained indices
217    pub fn iter(&self) -> alloc::slice::Iter<Index> {
218        self.0.iter()
219    }
220
221    /// Function for filtering out elements in-place
222    pub fn retain<P>(&mut self, pred: P)
223    where
224        P: FnMut(&Index) -> bool,
225    {
226        self.0.retain(pred);
227    }
228}
229
230impl IntoIterator for IndexList {
231    type Item = Index;
232    type IntoIter = alloc::vec::IntoIter<Index>;
233
234    fn into_iter(self) -> Self::IntoIter {
235        self.0.into_iter()
236    }
237}
238
239impl FromIterator<Index> for IndexList {
240    fn from_iter<T: IntoIterator<Item = Index>>(iter: T) -> Self {
241        Self(iter.into_iter().collect())
242    }
243}
244
245/// The response from the enclave for performing
246/// FMD for a particular uses
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct EncryptedResponse {
249    /// Hash of user's encryption key used to identify
250    /// when database entries belong to them
251    pub owner: alloc::string::String,
252    /// Nonce needed to decrypt the indices
253    pub nonce: [u8; 12],
254    /// encrypted indices
255    pub indices: alloc::vec::Vec<u8>,
256    /// The last height FMD was performed at
257    pub height: u64,
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec::Vec;

    #[test]
    fn test_combine_indices() {
        let a = IndexList(Vec::from([
            Index { height: 0, tx: 0 },
            Index { height: 0, tx: 1 },
            Index { height: 1, tx: 0 },
            Index { height: 3, tx: 0 },
        ]));
        let mut b = IndexList(Vec::from([
            Index { height: 0, tx: 1 },
            Index { height: 1, tx: 4 },
        ]));
        let expected = IndexList(Vec::from([
            Index { height: 0, tx: 1 },
            Index { height: 3, tx: 0 },
        ]));

        let mut first = a.clone();
        first.combine(b.clone());
        assert_eq!(first, expected);
        b.combine(a.clone());
        assert_eq!(b, expected);

        let mut new = IndexList::default();
        new.combine(a.clone());
        assert_eq!(new, a);
        let mut third = a.clone();
        third.combine(IndexList::default());
        assert_eq!(third, a);
    }

    /// `Index::as_bytes` is little-endian height then tx, and
    /// `Index::try_from_bytes` inverts it and rejects wrong lengths.
    #[test]
    fn test_index_bytes_roundtrip() {
        let ix = Index {
            height: 0x0102_0304_0506_0708,
            tx: 0x090a_0b0c,
        };
        let bytes = ix.as_bytes();
        assert_eq!(bytes, [8, 7, 6, 5, 4, 3, 2, 1, 12, 11, 10, 9]);
        assert_eq!(Index::try_from_bytes(&bytes), Some(ix));
        assert_eq!(Index::try_from_bytes(&bytes[..11]), None);
        assert_eq!(Index::try_from_bytes(&[0u8; 13]), None);
    }

    /// `IndexList::try_from_bytes` decodes concatenated encodings in order,
    /// rejects ragged input and maps empty input to an empty list.
    #[test]
    fn test_index_list_from_bytes() {
        let a = Index { height: 1, tx: 2 };
        let b = Index { height: 3, tx: 4 };
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&a.as_bytes());
        bytes.extend_from_slice(&b.as_bytes());
        assert_eq!(
            IndexList::try_from_bytes(&bytes),
            Some(IndexList(Vec::from([a, b])))
        );
        assert_eq!(IndexList::try_from_bytes(&bytes[..13]), None);
        assert_eq!(IndexList::try_from_bytes(&[]), Some(IndexList::default()));
    }

    /// `union` merges, sorts and de-duplicates; height lookups then work.
    #[test]
    fn test_union_sorts_and_dedups() {
        let mut x = IndexList(Vec::from([
            Index { height: 2, tx: 0 },
            Index { height: 1, tx: 0 },
        ]));
        x.union(&IndexList(Vec::from([
            Index { height: 1, tx: 0 },
            Index { height: 3, tx: 1 },
        ])));
        assert_eq!(
            x,
            IndexList(Vec::from([
                Index { height: 1, tx: 0 },
                Index { height: 2, tx: 0 },
                Index { height: 3, tx: 1 },
            ]))
        );
        assert!(x.contains(&Index { height: 2, tx: 0 }));
        assert!(x.contains_height(3));
        assert!(!x.contains_height(4));
    }
}
295}