1use connection_strategy::ConnectionStrategy;
6use group_state::SqLiteGroupStateStorage;
7use psk::SqLitePreSharedKeyStorage;
8use rusqlite::Connection;
9use storage::{SqLiteApplicationStorage, SqLiteKeyPackageStorage};
10use thiserror::Error;
11
12mod application;
13mod group_state;
14mod key_package;
15mod psk;
16
17#[cfg(any(feature = "sqlcipher", feature = "sqlcipher-bundled"))]
18mod cipher;
19
20#[cfg(test)]
21pub(crate) mod test_utils;
22
23pub mod connection_strategy;
25
26pub mod storage {
28 pub use {
29 crate::application::{Item, SqLiteApplicationStorage},
30 crate::group_state::SqLiteGroupStateStorage,
31 crate::key_package::SqLiteKeyPackageStorage,
32 crate::psk::SqLitePreSharedKeyStorage,
33 };
34}
35
36#[derive(Debug, Error)]
37pub enum SqLiteDataStorageError {
39 #[error(transparent)]
40 SqlEngineError(Box<dyn std::error::Error + Send + Sync + 'static>),
42 #[error(transparent)]
43 DataConversionError(Box<dyn std::error::Error + Send + Sync + 'static>),
45 #[error("epoch ID {0} exceeds maximum supported value (i64::MAX)")]
46 EpochIdOverflow(u64),
50 #[error("timestamp {0} exceeds maximum supported value (i64::MAX)")]
51 TimestampOverflow(u64),
55 #[cfg(any(feature = "sqlcipher", feature = "sqlcipher-bundled"))]
56 #[error("invalid key, must use SqlCipherKey::RawKeyWithSalt with plaintext_header_size > 0")]
57 SqlCipherKeyInvalidWithHeader,
59}
60
61impl mls_rs_core::error::IntoAnyError for SqLiteDataStorageError {
62 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
63 Ok(self.into())
64 }
65}
66
/// SQLite journal mode that the storage engine applies with
/// `PRAGMA journal_mode = <MODE>` to every connection it creates.
///
/// Each variant corresponds to the SQLite keyword of the same name. This is a
/// small fieldless enum, so it is `Copy` and can be compared and hashed,
/// for example when checking configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JournalMode {
    /// `DELETE`: the rollback journal is deleted at the end of each transaction.
    Delete,
    /// `TRUNCATE`: the rollback journal is truncated to zero length instead of deleted.
    Truncate,
    /// `PERSIST`: the journal file is kept and its header is overwritten.
    Persist,
    /// `MEMORY`: the rollback journal is kept in memory.
    Memory,
    /// `WAL`: write-ahead logging.
    Wal,
    /// `OFF`: no rollback journal.
    Off,
}

impl JournalMode {
    /// Returns the SQLite keyword for this mode, as passed to `PRAGMA journal_mode`.
    fn as_str(&self) -> &'static str {
        match self {
            JournalMode::Delete => "DELETE",
            JournalMode::Truncate => "TRUNCATE",
            JournalMode::Persist => "PERSIST",
            JournalMode::Memory => "MEMORY",
            JournalMode::Wal => "WAL",
            JournalMode::Off => "OFF",
        }
    }
}
91
92#[derive(Clone, Debug)]
93pub struct SqLiteDataStorageEngine<CS>
95where
96 CS: ConnectionStrategy,
97{
98 connection_strategy: CS,
99 journal_mode: Option<JournalMode>,
100}
101
102impl<CS> SqLiteDataStorageEngine<CS>
103where
104 CS: ConnectionStrategy,
105{
106 pub fn new(
107 connection_strategy: CS,
108 ) -> Result<SqLiteDataStorageEngine<CS>, SqLiteDataStorageError> {
109 Ok(SqLiteDataStorageEngine {
110 connection_strategy,
111 journal_mode: None,
112 })
113 }
114
115 pub fn with_journal_mode(self, journal_mode: Option<JournalMode>) -> Self {
117 Self {
118 journal_mode,
119 ..self
120 }
121 }
122
123 fn create_connection(&self) -> Result<Connection, SqLiteDataStorageError> {
124 let connection = self.connection_strategy.make_connection()?;
125
126 let current_schema = connection
128 .pragma_query_value(None, "user_version", |rows| rows.get::<_, u32>(0))
129 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
130
131 if let Some(journal_mode) = &self.journal_mode {
132 connection
133 .pragma_update(None, "journal_mode", journal_mode.as_str())
134 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
135 }
136
137 if current_schema < 1 {
138 create_tables_v1(&connection)?;
139 }
140
141 Ok(connection)
142 }
143
144 pub fn group_state_storage(&self) -> Result<SqLiteGroupStateStorage, SqLiteDataStorageError> {
146 Ok(SqLiteGroupStateStorage::new(self.create_connection()?))
147 }
148
149 pub fn key_package_storage(&self) -> Result<SqLiteKeyPackageStorage, SqLiteDataStorageError> {
151 Ok(SqLiteKeyPackageStorage::new(self.create_connection()?))
152 }
153
154 pub fn pre_shared_key_storage(
156 &self,
157 ) -> Result<SqLitePreSharedKeyStorage, SqLiteDataStorageError> {
158 Ok(SqLitePreSharedKeyStorage::new(self.create_connection()?))
159 }
160
161 pub fn application_data_storage(
163 &self,
164 ) -> Result<SqLiteApplicationStorage, SqLiteDataStorageError> {
165 Ok(SqLiteApplicationStorage::new(self.create_connection()?))
166 }
167}
168
169fn create_tables_v1(connection: &Connection) -> Result<(), SqLiteDataStorageError> {
170 connection
171 .execute_batch(
172 "BEGIN;
173 CREATE TABLE mls_group (
174 group_id BLOB PRIMARY KEY,
175 snapshot BLOB NOT NULL
176 ) WITHOUT ROWID;
177 CREATE TABLE epoch (
178 group_id BLOB,
179 epoch_id INTEGER,
180 epoch_data BLOB NOT NULL,
181 FOREIGN KEY (group_id) REFERENCES mls_group (group_id) ON DELETE CASCADE
182 PRIMARY KEY (group_id, epoch_id)
183 ) WITHOUT ROWID;
184 CREATE TABLE key_package (
185 id BLOB PRIMARY KEY,
186 expiration INTEGER,
187 data BLOB NOT NULL
188 ) WITHOUT ROWID;
189 CREATE INDEX key_package_exp ON key_package (expiration);
190 CREATE TABLE psk (
191 psk_id BLOB PRIMARY KEY,
192 data BLOB NOT NULL
193 ) WITHOUT ROWID;
194 CREATE TABLE kvs (
195 key TEXT PRIMARY KEY,
196 value BLOB NOT NULL
197 ) WITHOUT ROWID;
198 PRAGMA user_version = 1;
199 COMMIT;",
200 )
201 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
202}
203
#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use crate::{
        connection_strategy::{FileConnectionStrategy, MemoryStrategy},
        SqLiteDataStorageEngine,
    };

    /// A connection produced by the engine reports schema version 1.
    #[test]
    pub fn user_version_test() {
        let engine = SqLiteDataStorageEngine::new(MemoryStrategy).unwrap();

        // A first connection is opened and kept alive for the whole test.
        let _first = engine.create_connection().unwrap();
        let second = engine.create_connection().unwrap();

        let version: u32 = second
            .pragma_query_value(None, "user_version", |row| row.get(0))
            .unwrap();

        assert_eq!(version, 1);
    }

    /// The configured journal mode is applied to new connections.
    #[test]
    pub fn journal_mode_test() {
        let dir = tempdir().unwrap();
        let db_path = dir.path().join("test_db.sqlite");

        let connection = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(&db_path))
            .unwrap()
            .with_journal_mode(Some(crate::JournalMode::Truncate))
            .create_connection()
            .unwrap();

        let mode: String = connection
            .pragma_query_value(None, "journal_mode", |row| row.get(0))
            .unwrap();

        assert_eq!(mode, "truncate");
    }

    /// A schema version above 1 written by the application is left alone, and
    /// both the engine's tables and the application's tables survive reopening.
    #[test]
    pub fn extended_schema_version_test() {
        let dir = tempdir().unwrap();

        let engine = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(
            &dir.path().join("extended_schema_test.sqlite"),
        ))
        .unwrap();

        {
            // The first connection creates the v1 schema. On top of it we add
            // an application table and bump the version. The connection is
            // closed when this scope ends.
            let setup = engine.create_connection().unwrap();

            setup
                .execute_batch(
                    "BEGIN;
                    CREATE TABLE custom_table (
                        id INTEGER PRIMARY KEY,
                        data TEXT NOT NULL
                    );
                    PRAGMA user_version = 2;
                    COMMIT;",
                )
                .unwrap();
        }

        let reopened = engine.create_connection().unwrap();

        let version: u32 = reopened
            .pragma_query_value(None, "user_version", |row| row.get(0))
            .unwrap();

        assert_eq!(version, 2);

        let mls_group_tables: i32 = reopened
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='mls_group'",
                [],
                |row| row.get(0),
            )
            .unwrap();

        let custom_tables: i32 = reopened
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='custom_table'",
                [],
                |row| row.get(0),
            )
            .unwrap();

        assert!(mls_group_tables > 0);
        assert!(custom_tables > 0);
    }
}