mls_rs_provider_sqlite/
lib.rs1use connection_strategy::ConnectionStrategy;
6use group_state::SqLiteGroupStateStorage;
7use psk::SqLitePreSharedKeyStorage;
8use rusqlite::Connection;
9use storage::{SqLiteApplicationStorage, SqLiteKeyPackageStorage};
10use thiserror::Error;
11
12mod application;
13mod group_state;
14mod key_package;
15mod psk;
16
17#[cfg(any(feature = "sqlcipher", feature = "sqlcipher-bundled"))]
18mod cipher;
19
20#[cfg(test)]
21pub(crate) mod test_utils;
22
23pub mod connection_strategy;
25
26pub mod storage {
28 pub use {
29 crate::application::{Item, SqLiteApplicationStorage},
30 crate::group_state::SqLiteGroupStateStorage,
31 crate::key_package::SqLiteKeyPackageStorage,
32 crate::psk::SqLitePreSharedKeyStorage,
33 };
34}
35
36#[derive(Debug, Error)]
37pub enum SqLiteDataStorageError {
39 #[error(transparent)]
40 SqlEngineError(Box<dyn std::error::Error + Send + Sync + 'static>),
42 #[error(transparent)]
43 DataConversionError(Box<dyn std::error::Error + Send + Sync + 'static>),
45 #[cfg(any(feature = "sqlcipher", feature = "sqlcipher-bundled"))]
46 #[error("invalid key, must use SqlCipherKey::RawKeyWithSalt with plaintext_header_size > 0")]
47 SqlCipherKeyInvalidWithHeader,
49}
50
51impl mls_rs_core::error::IntoAnyError for SqLiteDataStorageError {
52 fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
53 Ok(self.into())
54 }
55}
56
/// SQLite journal mode that the storage engine applies to every connection it
/// opens, by writing the value to `PRAGMA journal_mode`.
///
/// Each variant maps one-to-one onto a value accepted by that pragma.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JournalMode {
    /// `DELETE`
    Delete,
    /// `TRUNCATE`
    Truncate,
    /// `PERSIST`
    Persist,
    /// `MEMORY`
    Memory,
    /// `WAL` (write-ahead logging)
    Wal,
    /// `OFF`
    Off,
}

impl JournalMode {
    /// Returns the exact (upper-case) value written to `PRAGMA journal_mode`
    /// for this mode.
    fn as_str(&self) -> &'static str {
        match self {
            JournalMode::Delete => "DELETE",
            JournalMode::Truncate => "TRUNCATE",
            JournalMode::Persist => "PERSIST",
            JournalMode::Memory => "MEMORY",
            JournalMode::Wal => "WAL",
            JournalMode::Off => "OFF",
        }
    }
}
81
82#[derive(Clone, Debug)]
83pub struct SqLiteDataStorageEngine<CS>
85where
86 CS: ConnectionStrategy,
87{
88 connection_strategy: CS,
89 journal_mode: Option<JournalMode>,
90}
91
92impl<CS> SqLiteDataStorageEngine<CS>
93where
94 CS: ConnectionStrategy,
95{
96 pub fn new(
97 connection_strategy: CS,
98 ) -> Result<SqLiteDataStorageEngine<CS>, SqLiteDataStorageError> {
99 Ok(SqLiteDataStorageEngine {
100 connection_strategy,
101 journal_mode: None,
102 })
103 }
104
105 pub fn with_journal_mode(self, journal_mode: Option<JournalMode>) -> Self {
107 Self {
108 journal_mode,
109 ..self
110 }
111 }
112
113 fn create_connection(&self) -> Result<Connection, SqLiteDataStorageError> {
114 let connection = self.connection_strategy.make_connection()?;
115
116 let current_schema = connection
118 .pragma_query_value(None, "user_version", |rows| rows.get::<_, u32>(0))
119 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
120
121 if let Some(journal_mode) = &self.journal_mode {
122 connection
123 .pragma_update(None, "journal_mode", journal_mode.as_str())
124 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
125 }
126
127 if current_schema != 1 {
128 create_tables_v1(&connection)?;
129 }
130
131 Ok(connection)
132 }
133
134 pub fn group_state_storage(&self) -> Result<SqLiteGroupStateStorage, SqLiteDataStorageError> {
136 Ok(SqLiteGroupStateStorage::new(self.create_connection()?))
137 }
138
139 pub fn key_package_storage(&self) -> Result<SqLiteKeyPackageStorage, SqLiteDataStorageError> {
141 Ok(SqLiteKeyPackageStorage::new(self.create_connection()?))
142 }
143
144 pub fn pre_shared_key_storage(
146 &self,
147 ) -> Result<SqLitePreSharedKeyStorage, SqLiteDataStorageError> {
148 Ok(SqLitePreSharedKeyStorage::new(self.create_connection()?))
149 }
150
151 pub fn application_data_storage(
153 &self,
154 ) -> Result<SqLiteApplicationStorage, SqLiteDataStorageError> {
155 Ok(SqLiteApplicationStorage::new(self.create_connection()?))
156 }
157}
158
159fn create_tables_v1(connection: &Connection) -> Result<(), SqLiteDataStorageError> {
160 connection
161 .execute_batch(
162 "BEGIN;
163 CREATE TABLE mls_group (
164 group_id BLOB PRIMARY KEY,
165 snapshot BLOB NOT NULL
166 ) WITHOUT ROWID;
167 CREATE TABLE epoch (
168 group_id BLOB,
169 epoch_id INTEGER,
170 epoch_data BLOB NOT NULL,
171 FOREIGN KEY (group_id) REFERENCES mls_group (group_id) ON DELETE CASCADE
172 PRIMARY KEY (group_id, epoch_id)
173 ) WITHOUT ROWID;
174 CREATE TABLE key_package (
175 id BLOB PRIMARY KEY,
176 expiration INTEGER,
177 data BLOB NOT NULL
178 ) WITHOUT ROWID;
179 CREATE INDEX key_package_exp ON key_package (expiration);
180 CREATE TABLE psk (
181 psk_id BLOB PRIMARY KEY,
182 data BLOB NOT NULL
183 ) WITHOUT ROWID;
184 CREATE TABLE kvs (
185 key TEXT PRIMARY KEY,
186 value BLOB NOT NULL
187 ) WITHOUT ROWID;
188 PRAGMA user_version = 1;
189 COMMIT;",
190 )
191 .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
192}
193
#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use crate::{
        connection_strategy::{FileConnectionStrategy, MemoryStrategy},
        SqLiteDataStorageEngine,
    };

    #[test]
    pub fn user_version_test() {
        let database = SqLiteDataStorageEngine::new(MemoryStrategy).unwrap();

        let _connection = database.create_connection().unwrap();

        // A second connection must not fail while setting up the schema.
        // NOTE(review): if `MemoryStrategy` opens a fresh in-memory database per
        // connection, this does not exercise reopening an initialized database.
        // `reopen_file_database_test` covers that path with a real file.
        let connection = database.create_connection().unwrap();

        let current_schema = connection
            .pragma_query_value(None, "user_version", |rows| rows.get::<_, u32>(0))
            .unwrap();

        assert_eq!(current_schema, 1);
    }

    #[test]
    pub fn reopen_file_database_test() {
        let temp = tempdir().unwrap();
        let path = temp.path().join("reopen_db.sqlite");

        // The first connection initializes the schema on the fresh file.
        let first = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(&path))
            .unwrap()
            .create_connection()
            .unwrap();
        drop(first);

        // A new engine on the same file must find the schema already in place
        // and open without trying to create it again.
        let connection = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(&path))
            .unwrap()
            .create_connection()
            .unwrap();

        let current_schema = connection
            .pragma_query_value(None, "user_version", |rows| rows.get::<_, u32>(0))
            .unwrap();

        assert_eq!(current_schema, 1);
    }

    #[test]
    pub fn journal_mode_test() {
        let temp = tempdir().unwrap();

        let database = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(
            &temp.path().join("test_db.sqlite"),
        ))
        .unwrap();

        let connection = database
            .with_journal_mode(Some(crate::JournalMode::Truncate))
            .create_connection()
            .unwrap();

        let journal_mode = connection
            .pragma_query_value(None, "journal_mode", |rows| rows.get::<_, String>(0))
            .unwrap();

        assert_eq!(journal_mode, "truncate");
    }
}