mls_rs_provider_sqlite/
psk.rs1use crate::SqLiteDataStorageError;
6use mls_rs_core::psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage};
7use rusqlite::{params, Connection, OptionalExtension};
8use std::{
9 ops::Deref,
10 sync::{Arc, Mutex},
11};
12
13#[derive(Debug, Clone)]
14pub struct SqLitePreSharedKeyStorage {
16 connection: Arc<Mutex<Connection>>,
17}
18
19impl SqLitePreSharedKeyStorage {
20 pub(crate) fn new(connection: Connection) -> SqLitePreSharedKeyStorage {
21 SqLitePreSharedKeyStorage {
22 connection: Arc::new(Mutex::new(connection)),
23 }
24 }
25
26 pub fn insert(&self, psk_id: &[u8], psk: &PreSharedKey) -> Result<(), SqLiteDataStorageError> {
28 let connection = self.connection.lock().unwrap();
29
30 connection
32 .execute(
33 "INSERT INTO psk (psk_id, data) VALUES (?,?) ON CONFLICT(psk_id) DO UPDATE SET data=excluded.data",
34 params![psk_id, psk.deref()],
35 )
36 .map(|_| ())
37 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
38 }
39
40 pub fn get(&self, psk_id: &[u8]) -> Result<Option<PreSharedKey>, SqLiteDataStorageError> {
42 let connection = self.connection.lock().unwrap();
43
44 connection
45 .query_row(
46 "SELECT data FROM psk WHERE psk_id = ?",
47 params![psk_id],
48 |row| Ok(PreSharedKey::new(row.get(0)?)),
49 )
50 .optional()
51 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
52 }
53
54 pub fn delete(&self, psk_id: &[u8]) -> Result<(), SqLiteDataStorageError> {
56 let connection = self.connection.lock().unwrap();
57
58 connection
59 .execute("DELETE FROM psk WHERE psk_id = ?", params![psk_id])
60 .map(|_| ())
61 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
62 }
63}
64
65#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
66#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
67impl PreSharedKeyStorage for SqLitePreSharedKeyStorage {
68 type Error = SqLiteDataStorageError;
69
70 async fn get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
71 self.get(id)
72 .map_err(|e| SqLiteDataStorageError::DataConversionError(e.into()))
73 }
74}
75
#[cfg(test)]
mod tests {
    use mls_rs_core::psk::PreSharedKey;

    use crate::{
        connection_strategy::MemoryStrategy, test_utils::gen_rand_bytes, SqLiteDataStorageEngine,
    };

    use super::SqLitePreSharedKeyStorage;

    /// Returns a random 32-byte id paired with a random 64-byte key.
    fn test_psk() -> (Vec<u8>, PreSharedKey) {
        let psk_id = gen_rand_bytes(32);
        let stored_psk = PreSharedKey::new(gen_rand_bytes(64));

        (psk_id, stored_psk)
    }

    /// Builds PSK storage from a fresh engine created with `MemoryStrategy`.
    fn test_storage() -> SqLitePreSharedKeyStorage {
        SqLiteDataStorageEngine::new(MemoryStrategy)
            .unwrap()
            .pre_shared_key_storage()
            .unwrap()
    }

    #[test]
    fn test_insert() {
        let (psk_id, psk) = test_psk();
        let storage = test_storage();

        storage.insert(&psk_id, &psk).unwrap();

        let from_storage = storage.get(&psk_id).unwrap().unwrap();
        assert_eq!(from_storage, psk);
    }

    #[test]
    fn test_insert_existing_overwrite() {
        let (psk_id, psk) = test_psk();
        let (_, new_psk) = test_psk();

        let storage = test_storage();

        storage.insert(&psk_id, &psk).unwrap();
        storage.insert(&psk_id, &new_psk).unwrap();

        let from_storage = storage.get(&psk_id).unwrap().unwrap();
        assert_eq!(from_storage, new_psk);
    }

    #[test]
    fn test_delete() {
        let (psk_id, psk) = test_psk();
        let storage = test_storage();

        storage.insert(&psk_id, &psk).unwrap();
        storage.delete(&psk_id).unwrap();

        assert!(storage.get(&psk_id).unwrap().is_none());
    }

    // Looking up an id that was never stored must yield `Ok(None)`, not an error.
    #[test]
    fn test_get_missing() {
        let (psk_id, _) = test_psk();
        let storage = test_storage();

        assert!(storage.get(&psk_id).unwrap().is_none());
    }

    // Deleting an id that was never stored must succeed and leave nothing behind.
    #[test]
    fn test_delete_missing() {
        let (psk_id, _) = test_psk();
        let storage = test_storage();

        storage.delete(&psk_id).unwrap();

        assert!(storage.get(&psk_id).unwrap().is_none());
    }
}