1use mls_rs_core::{
6 key_package::{KeyPackageData, KeyPackageStorage},
7 mls_rs_codec::{MlsDecode, MlsEncode},
8 time::MlsTime,
9};
10use rusqlite::{params, Connection, OptionalExtension};
11use std::sync::{Arc, Mutex};
12
13use crate::SqLiteDataStorageError;
14
15#[derive(Debug, Clone)]
16pub struct SqLiteKeyPackageStorage {
24 connection: Arc<Mutex<Connection>>,
25}
26
27impl SqLiteKeyPackageStorage {
28 pub(crate) fn new(connection: Connection) -> SqLiteKeyPackageStorage {
29 SqLiteKeyPackageStorage {
30 connection: Arc::new(Mutex::new(connection)),
31 }
32 }
33
34 fn insert(
35 &mut self,
36 id: &[u8],
37 key_package: KeyPackageData,
38 ) -> Result<(), SqLiteDataStorageError> {
39 let connection = self.connection.lock().unwrap();
40
41 connection
42 .execute(
43 "INSERT INTO key_package (id, expiration, data) VALUES (?,?,?)",
44 params![
45 id,
46 i64::try_from(key_package.expiration).map_err(|_| {
47 SqLiteDataStorageError::TimestampOverflow(key_package.expiration)
48 })?,
49 key_package
50 .mls_encode_to_vec()
51 .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))?
52 ],
53 )
54 .map(|_| ())
55 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
56 }
57
58 fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, SqLiteDataStorageError> {
59 let connection = self.connection.lock().unwrap();
60
61 connection
62 .query_row(
63 "SELECT data FROM key_package WHERE id = ?",
64 params![id],
65 |row| {
66 Ok(
67 KeyPackageData::mls_decode(&mut row.get::<_, Vec<u8>>(0)?.as_slice())
68 .unwrap(),
69 )
70 },
71 )
72 .optional()
73 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
74 }
75
76 pub fn delete(&self, id: &[u8]) -> Result<(), SqLiteDataStorageError> {
78 let connection = self.connection.lock().unwrap();
79
80 connection
81 .execute("DELETE FROM key_package where id = ?", params![id])
82 .map(|_| ())
83 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
84 }
85
86 pub fn delete_expired(&self) -> Result<(), SqLiteDataStorageError> {
88 self.delete_expired_by_time(MlsTime::now().seconds_since_epoch())
89 }
90
91 pub fn delete_expired_by_time(&self, time: u64) -> Result<(), SqLiteDataStorageError> {
94 let connection = self.connection.lock().unwrap();
95
96 connection
97 .execute(
98 "DELETE FROM key_package where expiration < ?",
99 params![i64::try_from(time)
100 .map_err(|_| SqLiteDataStorageError::TimestampOverflow(time))?],
101 )
102 .map(|_| ())
103 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
104 }
105
106 pub fn count(&self) -> Result<usize, SqLiteDataStorageError> {
108 let connection = self.connection.lock().unwrap();
109
110 connection
111 .query_row("SELECT count(*) FROM key_package", params![], |row| {
112 row.get::<_, i64>(0).and_then(|v| {
113 usize::try_from(v).map_err(|_| rusqlite::Error::IntegralValueOutOfRange(0, v))
114 })
115 })
116 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
117 }
118
119 pub fn count_at_time(&self, time: u64) -> Result<usize, SqLiteDataStorageError> {
123 let connection = self.connection.lock().unwrap();
124
125 connection
126 .query_row(
127 "SELECT count(*) FROM key_package where expiration >= ?",
128 params![i64::try_from(time)
129 .map_err(|_| SqLiteDataStorageError::TimestampOverflow(time))?],
130 |row| {
131 row.get::<_, i64>(0).and_then(|v| {
132 usize::try_from(v)
133 .map_err(|_| rusqlite::Error::IntegralValueOutOfRange(0, v))
134 })
135 },
136 )
137 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
138 }
139}
140
141#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
142#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
143impl KeyPackageStorage for SqLiteKeyPackageStorage {
144 type Error = SqLiteDataStorageError;
145
146 async fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> {
147 self.insert(id.as_slice(), pkg)
148 }
149
150 async fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> {
151 self.get(id)
152 }
153
154 async fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> {
155 (*self).delete(id)
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::SqLiteKeyPackageStorage;
    use crate::{
        connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes, SqLiteDataStorageEngine,
        SqLiteDataStorageError,
    };
    use assert_matches::assert_matches;
    use mls_rs_core::{crypto::HpkeSecretKey, key_package::KeyPackageData};

    /// Key package storage created from a `MemoryStrategy` engine.
    fn test_storage() -> SqLiteKeyPackageStorage {
        SqLiteDataStorageEngine::new(MemoryStrategy)
            .unwrap()
            .key_package_storage()
            .unwrap()
    }

    /// Random key package id paired with random key package data that has an
    /// expiration of 123.
    fn test_key_package() -> (Vec<u8>, KeyPackageData) {
        let key_id = gen_rand_bytes(32);

        let key_package = KeyPackageData::new(
            gen_rand_bytes(256),
            HpkeSecretKey::from(gen_rand_bytes(256)),
            HpkeSecretKey::from(gen_rand_bytes(256)),
            123,
        );

        (key_id, key_package)
    }

    #[test]
    fn key_package_insert() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage
            .insert(&key_package_id, key_package.clone())
            .unwrap();

        let from_storage = storage.get(&key_package_id).unwrap().unwrap();
        assert_eq!(from_storage, key_package);
    }

    #[test]
    fn duplicate_insert_should_fail() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage
            .insert(&key_package_id, key_package.clone())
            .unwrap();

        let dupe_res = storage.insert(&key_package_id, key_package);

        assert_matches!(dupe_res, Err(SqLiteDataStorageError::SqlEngineError(_)));
    }

    #[test]
    fn key_package_not_found() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage.insert(&key_package_id, key_package).unwrap();

        let (another_package_id, _) = test_key_package();

        assert!(storage.get(&another_package_id).unwrap().is_none());
    }

    #[test]
    fn key_package_delete() {
        let mut storage = test_storage();
        let (key_package_id, key_package) = test_key_package();

        storage.insert(&key_package_id, key_package).unwrap();

        storage.delete(&key_package_id).unwrap();
        assert!(storage.get(&key_package_id).unwrap().is_none());
    }

    #[test]
    fn expired_key_package_delete() {
        let mut storage = test_storage();

        let data = [1, 15, 30, 1698652376].map(|exp| {
            let mut kp = test_key_package();
            kp.1.expiration = exp;
            kp
        });

        for (id, data) in &data {
            storage.insert(id, data.clone()).unwrap();
        }

        storage.delete_expired_by_time(30).unwrap();

        // Only entries that expire strictly before the cutoff are removed.
        assert!(storage.get(&data[0].0).unwrap().is_none());
        assert!(storage.get(&data[1].0).unwrap().is_none());
        assert!(storage.get(&data[2].0).unwrap().is_some());
        assert!(storage.get(&data[3].0).unwrap().is_some());

        storage.delete_expired().unwrap();

        assert!(storage.get(&data[2].0).unwrap().is_none());
        assert!(storage.get(&data[3].0).unwrap().is_none());
    }

    #[test]
    fn key_count() {
        let mut storage = test_storage();

        for _ in 0..10 {
            let (key_package_id, key_package) = test_key_package();
            storage.insert(&key_package_id, key_package).unwrap();
        }

        assert_eq!(storage.count().unwrap(), 10);
    }

    #[test]
    fn key_count_at_time() {
        let mut storage = test_storage();

        let mut kp_1 = test_key_package();
        kp_1.1.expiration = 1;
        storage.insert(&kp_1.0, kp_1.1).unwrap();

        let mut kp_2 = test_key_package();
        kp_2.1.expiration = 2;
        storage.insert(&kp_2.0, kp_2.1).unwrap();

        assert_eq!(storage.count_at_time(3).unwrap(), 0);
        assert_eq!(storage.count_at_time(2).unwrap(), 1);
        assert_eq!(storage.count_at_time(1).unwrap(), 2);
        assert_eq!(storage.count_at_time(0).unwrap(), 2);
    }

    #[test]
    fn timestamp_overflow() {
        let mut storage = test_storage();
        let (id, mut kp) = test_key_package();
        kp.expiration = u64::MAX;

        let err = storage.insert(&id, kp).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));

        let err = storage.delete_expired_by_time(u64::MAX).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));

        let err = storage.count_at_time(u64::MAX).unwrap_err();
        assert_matches!(err, SqLiteDataStorageError::TimestampOverflow(u64::MAX));
    }
}