Skip to main content

mls_rs_provider_sqlite/
lib.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use connection_strategy::ConnectionStrategy;
6use group_state::SqLiteGroupStateStorage;
7use psk::SqLitePreSharedKeyStorage;
8use rusqlite::Connection;
9use storage::{SqLiteApplicationStorage, SqLiteKeyPackageStorage};
10use thiserror::Error;
11
12mod application;
13mod group_state;
14mod key_package;
15mod psk;
16
17#[cfg(any(feature = "sqlcipher", feature = "sqlcipher-bundled"))]
18mod cipher;
19
20#[cfg(test)]
21pub(crate) mod test_utils;
22
23/// Connection strategies.
24pub mod connection_strategy;
25
26/// SQLite storage components.
27pub mod storage {
28    pub use {
29        crate::application::{Item, SqLiteApplicationStorage},
30        crate::group_state::SqLiteGroupStateStorage,
31        crate::key_package::SqLiteKeyPackageStorage,
32        crate::psk::SqLitePreSharedKeyStorage,
33    };
34}
35
36#[derive(Debug, Error)]
37/// SQLite data storage error.
38pub enum SqLiteDataStorageError {
39    #[error(transparent)]
40    /// SQLite error.
41    SqlEngineError(Box<dyn std::error::Error + Send + Sync + 'static>),
42    #[error(transparent)]
43    /// Stored data is not compatible with the expected data type.
44    DataConversionError(Box<dyn std::error::Error + Send + Sync + 'static>),
45    #[error("epoch ID {0} exceeds maximum supported value (i64::MAX)")]
46    /// Epoch ID is too large to store in SQLite.
47    ///
48    /// SQLite uses signed 64-bit integers, limiting epoch IDs to values up to 9,223,372,036,854,775,807.
49    EpochIdOverflow(u64),
50    #[error("timestamp {0} exceeds maximum supported value (i64::MAX)")]
51    /// Timestamp is too large to store in SQLite.
52    ///
53    /// SQLite uses signed 64-bit integers, limiting timestamps to values up to 9,223,372,036,854,775,807.
54    TimestampOverflow(u64),
55    #[cfg(any(feature = "sqlcipher", feature = "sqlcipher-bundled"))]
56    #[error("invalid key, must use SqlCipherKey::RawKeyWithSalt with plaintext_header_size > 0")]
57    /// Invalid SQLCipher key header.
58    SqlCipherKeyInvalidWithHeader,
59}
60
61impl mls_rs_core::error::IntoAnyError for SqLiteDataStorageError {
62    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
63        Ok(self.into())
64    }
65}
66
/// SQLite journal mode applied to each new connection with `PRAGMA journal_mode`.
///
/// Each variant selects the SQLite journal mode of the same name.
///
/// Note: in-memory databases only support `Memory` and `Off`. Requesting any other mode
/// on such a database does not produce an error; it is silently a no-op.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JournalMode {
    /// The `DELETE` journal mode.
    Delete,
    /// The `TRUNCATE` journal mode.
    Truncate,
    /// The `PERSIST` journal mode.
    Persist,
    /// The `MEMORY` journal mode (supported by in-memory databases).
    Memory,
    /// The `WAL` (write-ahead log) journal mode.
    Wal,
    /// The `OFF` journal mode (supported by in-memory databases).
    Off,
}
76
77/// Note: for in-memory dbs (such as what the tests use), the only available options are MEMORY or OFF
78/// Invalid modes do not error, only no-op
79impl JournalMode {
80    fn as_str(&self) -> &'static str {
81        match self {
82            JournalMode::Delete => "DELETE",
83            JournalMode::Truncate => "TRUNCATE",
84            JournalMode::Persist => "PERSIST",
85            JournalMode::Memory => "MEMORY",
86            JournalMode::Wal => "WAL",
87            JournalMode::Off => "OFF",
88        }
89    }
90}
91
92#[derive(Clone, Debug)]
93/// SQLite data storage engine.
94pub struct SqLiteDataStorageEngine<CS>
95where
96    CS: ConnectionStrategy,
97{
98    connection_strategy: CS,
99    journal_mode: Option<JournalMode>,
100}
101
102impl<CS> SqLiteDataStorageEngine<CS>
103where
104    CS: ConnectionStrategy,
105{
106    pub fn new(
107        connection_strategy: CS,
108    ) -> Result<SqLiteDataStorageEngine<CS>, SqLiteDataStorageError> {
109        Ok(SqLiteDataStorageEngine {
110            connection_strategy,
111            journal_mode: None,
112        })
113    }
114
115    /// A `journal_mode` of `None` means the SQLite default is used.
116    pub fn with_journal_mode(self, journal_mode: Option<JournalMode>) -> Self {
117        Self {
118            journal_mode,
119            ..self
120        }
121    }
122
123    fn create_connection(&self) -> Result<Connection, SqLiteDataStorageError> {
124        let connection = self.connection_strategy.make_connection()?;
125
126        // Run SQL to establish the schema
127        let current_schema = connection
128            .pragma_query_value(None, "user_version", |rows| rows.get::<_, u32>(0))
129            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
130
131        if let Some(journal_mode) = &self.journal_mode {
132            connection
133                .pragma_update(None, "journal_mode", journal_mode.as_str())
134                .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
135        }
136
137        if current_schema < 1 {
138            create_tables_v1(&connection)?;
139        }
140
141        Ok(connection)
142    }
143
144    /// Returns a struct that implements the `GroupStateStorage` trait for use in MLS.
145    pub fn group_state_storage(&self) -> Result<SqLiteGroupStateStorage, SqLiteDataStorageError> {
146        Ok(SqLiteGroupStateStorage::new(self.create_connection()?))
147    }
148
149    /// Returns a struct that implements the `KeyPackageStorage` trait for use in MLS.
150    pub fn key_package_storage(&self) -> Result<SqLiteKeyPackageStorage, SqLiteDataStorageError> {
151        Ok(SqLiteKeyPackageStorage::new(self.create_connection()?))
152    }
153
154    /// Returns a struct that implements the `PreSharedKeyStorage` trait for use in MLS.
155    pub fn pre_shared_key_storage(
156        &self,
157    ) -> Result<SqLitePreSharedKeyStorage, SqLiteDataStorageError> {
158        Ok(SqLitePreSharedKeyStorage::new(self.create_connection()?))
159    }
160
161    /// Returns a key value store that can be used to store application specific data.
162    pub fn application_data_storage(
163        &self,
164    ) -> Result<SqLiteApplicationStorage, SqLiteDataStorageError> {
165        Ok(SqLiteApplicationStorage::new(self.create_connection()?))
166    }
167}
168
169fn create_tables_v1(connection: &Connection) -> Result<(), SqLiteDataStorageError> {
170    connection
171        .execute_batch(
172            "BEGIN;
173            CREATE TABLE mls_group (
174                group_id BLOB PRIMARY KEY,
175                snapshot BLOB NOT NULL
176            ) WITHOUT ROWID;
177            CREATE TABLE epoch (
178                group_id BLOB,
179                epoch_id INTEGER,
180                epoch_data BLOB NOT NULL,
181                FOREIGN KEY (group_id) REFERENCES mls_group (group_id) ON DELETE CASCADE
182                PRIMARY KEY (group_id, epoch_id)
183            ) WITHOUT ROWID;
184            CREATE TABLE key_package (
185                id BLOB PRIMARY KEY,
186                expiration INTEGER,
187                data BLOB NOT NULL
188            ) WITHOUT ROWID;
189            CREATE INDEX key_package_exp ON key_package (expiration);
190            CREATE TABLE psk (
191                psk_id BLOB PRIMARY KEY,
192                data BLOB NOT NULL
193            ) WITHOUT ROWID;
194            CREATE TABLE kvs (
195                key TEXT PRIMARY KEY,
196                value BLOB NOT NULL
197            ) WITHOUT ROWID;
198            PRAGMA user_version = 1;
199            COMMIT;",
200        )
201        .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
202}
203
#[cfg(test)]
mod tests {
    use rusqlite::Connection;
    use tempfile::tempdir;

    use crate::{
        connection_strategy::{FileConnectionStrategy, MemoryStrategy},
        JournalMode, SqLiteDataStorageEngine,
    };

    /// Reads the `user_version` pragma, which this crate uses as its schema version.
    fn read_user_version(connection: &Connection) -> u32 {
        connection
            .pragma_query_value(None, "user_version", |row| row.get::<_, u32>(0))
            .unwrap()
    }

    /// Returns whether a table called `name` exists in the connection's main schema.
    fn table_exists(connection: &Connection, name: &str) -> bool {
        connection
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?1",
                rusqlite::params![name],
                |row| row.get::<_, i32>(0),
            )
            .map(|count| count > 0)
            .unwrap()
    }

    #[test]
    fn user_version_test() {
        // A fresh database is migrated to schema version 1.
        let memory = SqLiteDataStorageEngine::new(MemoryStrategy).unwrap();
        assert_eq!(read_user_version(&memory.create_connection().unwrap()), 1);

        // A second connection to an already-migrated database must not re-run the migration,
        // which would fail on `CREATE TABLE`. A file-backed database is used here because
        // `MemoryStrategy` presumably opens an independent in-memory database per connection,
        // which would never exercise this path.
        let temp = tempdir().unwrap();
        let database = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(
            &temp.path().join("user_version_test.sqlite"),
        ))
        .unwrap();

        let _first = database.create_connection().unwrap();
        let second = database.create_connection().unwrap();

        assert_eq!(read_user_version(&second), 1);
    }

    #[test]
    fn journal_mode_test() {
        // Use a file-backed database. In-memory databases only accept MEMORY or OFF, so a
        // TRUNCATE request would be silently ignored there.
        let temp = tempdir().unwrap();
        let database = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(
            &temp.path().join("test_db.sqlite"),
        ))
        .unwrap()
        .with_journal_mode(Some(JournalMode::Truncate));

        let connection = database.create_connection().unwrap();

        let journal_mode = connection
            .pragma_query_value(None, "journal_mode", |row| row.get::<_, String>(0))
            .unwrap();

        // SQLite reports the active mode in lower case.
        assert_eq!(journal_mode, "truncate");
    }

    #[test]
    fn extended_schema_version_test() {
        // Downstream applications may extend the schema beyond version 1 without breaking
        // mls-rs connection creation.
        let temp = tempdir().unwrap();
        let database = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(
            &temp.path().join("extended_schema_test.sqlite"),
        ))
        .unwrap();

        // Initialize the database, which creates the v1 schema.
        let connection = database.create_connection().unwrap();

        // Simulate a downstream application extending the schema.
        connection
            .execute_batch(
                "BEGIN;
                CREATE TABLE custom_table (
                    id INTEGER PRIMARY KEY,
                    data TEXT NOT NULL
                );
                PRAGMA user_version = 2;
                COMMIT;",
            )
            .unwrap();

        drop(connection);

        // Reopening must not try to recreate the mls-rs tables or lower the version.
        let reopened = database.create_connection().unwrap();

        assert_eq!(read_user_version(&reopened), 2);
        assert!(table_exists(&reopened, "mls_group"));
        assert!(table_exists(&reopened, "custom_table"));
    }
}