Skip to main content

mls_rs_provider_sqlite/
key_package.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use mls_rs_core::{
6    key_package::{KeyPackageData, KeyPackageStorage},
7    mls_rs_codec::{MlsDecode, MlsEncode},
8    time::MlsTime,
9};
10use rusqlite::{params, Connection, OptionalExtension};
11use std::sync::{Arc, Mutex};
12
13use crate::SqLiteDataStorageError;
14
15#[derive(Debug, Clone)]
16/// SQLite storage for MLS Key Packages.
17///
18/// # Limitations
19///
20/// Expiration timestamps are stored as SQLite INTEGER (signed 64-bit), limiting
21/// the maximum timestamp to [`i64::MAX`] (9,223,372,036,854,775,807). Operations
22/// with timestamps exceeding this value will return [`SqLiteDataStorageError::TimestampOverflow`].
23pub struct SqLiteKeyPackageStorage {
24    connection: Arc<Mutex<Connection>>,
25}
26
27impl SqLiteKeyPackageStorage {
28    pub(crate) fn new(connection: Connection) -> SqLiteKeyPackageStorage {
29        SqLiteKeyPackageStorage {
30            connection: Arc::new(Mutex::new(connection)),
31        }
32    }
33
34    fn insert(
35        &mut self,
36        id: &[u8],
37        key_package: KeyPackageData,
38    ) -> Result<(), SqLiteDataStorageError> {
39        let connection = self.connection.lock().unwrap();
40
41        connection
42            .execute(
43                "INSERT INTO key_package (id, expiration, data) VALUES (?,?,?)",
44                params![
45                    id,
46                    i64::try_from(key_package.expiration).map_err(|_| {
47                        SqLiteDataStorageError::TimestampOverflow(key_package.expiration)
48                    })?,
49                    key_package
50                        .mls_encode_to_vec()
51                        .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))?
52                ],
53            )
54            .map(|_| ())
55            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
56    }
57
58    fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, SqLiteDataStorageError> {
59        let connection = self.connection.lock().unwrap();
60
61        connection
62            .query_row(
63                "SELECT data FROM key_package WHERE id = ?",
64                params![id],
65                |row| {
66                    Ok(
67                        KeyPackageData::mls_decode(&mut row.get::<_, Vec<u8>>(0)?.as_slice())
68                            .unwrap(),
69                    )
70                },
71            )
72            .optional()
73            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
74    }
75
76    /// Delete a specific key package from storage based on it's id.
77    pub fn delete(&self, id: &[u8]) -> Result<(), SqLiteDataStorageError> {
78        let connection = self.connection.lock().unwrap();
79
80        connection
81            .execute("DELETE FROM key_package where id = ?", params![id])
82            .map(|_| ())
83            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
84    }
85
86    /// Delete key packages that are expired based on the current system clock time.
87    pub fn delete_expired(&self) -> Result<(), SqLiteDataStorageError> {
88        self.delete_expired_by_time(MlsTime::now().seconds_since_epoch())
89    }
90
91    /// Delete key packages that are expired based on an application provided time in seconds since
92    /// unix epoch.
93    pub fn delete_expired_by_time(&self, time: u64) -> Result<(), SqLiteDataStorageError> {
94        let connection = self.connection.lock().unwrap();
95
96        connection
97            .execute(
98                "DELETE FROM key_package where expiration < ?",
99                params![i64::try_from(time)
100                    .map_err(|_| SqLiteDataStorageError::TimestampOverflow(time))?],
101            )
102            .map(|_| ())
103            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
104    }
105
106    /// Total number of key packages held in storage.
107    pub fn count(&self) -> Result<usize, SqLiteDataStorageError> {
108        let connection = self.connection.lock().unwrap();
109
110        connection
111            .query_row("SELECT count(*) FROM key_package", params![], |row| {
112                row.get::<_, i64>(0).and_then(|v| {
113                    usize::try_from(v).map_err(|_| rusqlite::Error::IntegralValueOutOfRange(0, v))
114                })
115            })
116            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
117    }
118
119    /// Total number of key packages that will still remain in storage at a specific application provided
120    /// time in seconds since unix epoch. This assumes that the application would also be calling
121    /// [SqLiteKeyPackageStorage::delete_expired] at a reasonable cadence to be accurate.
122    pub fn count_at_time(&self, time: u64) -> Result<usize, SqLiteDataStorageError> {
123        let connection = self.connection.lock().unwrap();
124
125        connection
126            .query_row(
127                "SELECT count(*) FROM key_package where expiration >= ?",
128                params![i64::try_from(time)
129                    .map_err(|_| SqLiteDataStorageError::TimestampOverflow(time))?],
130                |row| {
131                    row.get::<_, i64>(0).and_then(|v| {
132                        usize::try_from(v)
133                            .map_err(|_| rusqlite::Error::IntegralValueOutOfRange(0, v))
134                    })
135                },
136            )
137            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
138    }
139}
140
141#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
142#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
143impl KeyPackageStorage for SqLiteKeyPackageStorage {
144    type Error = SqLiteDataStorageError;
145
146    async fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> {
147        self.insert(id.as_slice(), pkg)
148    }
149
150    async fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> {
151        self.get(id)
152    }
153
154    async fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> {
155        (*self).delete(id)
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::SqLiteKeyPackageStorage;
    use crate::{
        connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes, SqLiteDataStorageEngine,
        SqLiteDataStorageError,
    };
    use assert_matches::assert_matches;
    use mls_rs_core::{crypto::HpkeSecretKey, key_package::KeyPackageData};

    /// Fresh key package storage backed by an in-memory database.
    fn test_storage() -> SqLiteKeyPackageStorage {
        SqLiteDataStorageEngine::new(MemoryStrategy)
            .unwrap()
            .key_package_storage()
            .unwrap()
    }

    /// Random key package id paired with a random key package expiring at timestamp 123.
    fn test_key_package() -> (Vec<u8>, KeyPackageData) {
        let key_id = gen_rand_bytes(32);
        let key_package = KeyPackageData::new(
            gen_rand_bytes(256),
            HpkeSecretKey::from(gen_rand_bytes(256)),
            HpkeSecretKey::from(gen_rand_bytes(256)),
            123,
        );

        (key_id, key_package)
    }

    #[test]
    fn key_package_insert() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage
            .insert(&key_package_id, key_package.clone())
            .unwrap();

        let from_storage = storage.get(&key_package_id).unwrap().unwrap();
        assert_eq!(from_storage, key_package);
    }

    #[test]
    fn duplicate_insert_should_fail() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage
            .insert(&key_package_id, key_package.clone())
            .unwrap();

        let dupe_res = storage.insert(&key_package_id, key_package);

        assert_matches!(dupe_res, Err(SqLiteDataStorageError::SqlEngineError(_)));
    }

    #[test]
    fn key_package_not_found() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage.insert(&key_package_id, key_package).unwrap();

        let (another_package_id, _) = test_key_package();

        assert!(storage.get(&another_package_id).unwrap().is_none());
    }

    #[test]
    fn key_package_delete() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage.insert(&key_package_id, key_package).unwrap();

        storage.delete(&key_package_id).unwrap();
        assert!(storage.get(&key_package_id).unwrap().is_none());
    }

    #[test]
    fn deleting_missing_key_package_is_ok() {
        let storage = test_storage();
        let (key_package_id, _) = test_key_package();

        storage.delete(&key_package_id).unwrap();
        assert_eq!(storage.count().unwrap(), 0);
    }

    #[test]
    fn expired_key_package_delete() {
        let mut storage = test_storage();

        let data = [1, 15, 30, 1698652376].map(|exp| {
            let mut kp = test_key_package();
            kp.1.expiration = exp;
            kp
        });

        for (id, data) in &data {
            storage.insert(id, data.clone()).unwrap();
        }

        // Expiration strictly before 30 is deleted; expiration == 30 is kept.
        storage.delete_expired_by_time(30).unwrap();

        assert!(storage.get(&data[0].0).unwrap().is_none());
        assert!(storage.get(&data[1].0).unwrap().is_none());
        storage.get(&data[2].0).unwrap().unwrap();
        storage.get(&data[3].0).unwrap().unwrap();

        // The wall clock is past 1698652376 (Oct 2023), so everything left is now expired.
        storage.delete_expired().unwrap();

        assert!(storage.get(&data[2].0).unwrap().is_none());
        assert!(storage.get(&data[3].0).unwrap().is_none());
    }

    #[test]
    fn key_count() {
        let mut storage = test_storage();

        for (key_package_id, key_package) in (0..10).map(|_| test_key_package()) {
            storage.insert(&key_package_id, key_package).unwrap();
        }

        assert_eq!(storage.count().unwrap(), 10);
    }

    #[test]
    fn key_count_at_time() {
        let mut storage = test_storage();

        let mut kp_1 = test_key_package();
        kp_1.1.expiration = 1;
        storage.insert(&kp_1.0, kp_1.1).unwrap();

        let mut kp_2 = test_key_package();
        kp_2.1.expiration = 2;
        storage.insert(&kp_2.0, kp_2.1).unwrap();

        assert_eq!(storage.count_at_time(3).unwrap(), 0);
        assert_eq!(storage.count_at_time(2).unwrap(), 1);
        assert_eq!(storage.count_at_time(1).unwrap(), 2);
        assert_eq!(storage.count_at_time(0).unwrap(), 2);
    }

    #[test]
    fn timestamp_overflow() {
        let mut storage = test_storage();
        let (id, mut kp) = test_key_package();
        kp.expiration = u64::MAX;

        let err = storage.insert(&id, kp).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));

        let err = storage.delete_expired_by_time(u64::MAX).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));

        let err = storage.count_at_time(u64::MAX).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));
    }
}