mls_rs_provider_sqlite/
lib.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use connection_strategy::ConnectionStrategy;
6use group_state::SqLiteGroupStateStorage;
7use psk::SqLitePreSharedKeyStorage;
8use rusqlite::Connection;
9use storage::{SqLiteApplicationStorage, SqLiteKeyPackageStorage};
10use thiserror::Error;
11
12mod application;
13mod group_state;
14mod key_package;
15mod psk;
16
17#[cfg(any(feature = "sqlcipher", feature = "sqlcipher-bundled"))]
18mod cipher;
19
20#[cfg(test)]
21pub(crate) mod test_utils;
22
23/// Connection strategies.
24pub mod connection_strategy;
25
26/// SQLite storage components.
27pub mod storage {
28    pub use {
29        crate::application::{Item, SqLiteApplicationStorage},
30        crate::group_state::SqLiteGroupStateStorage,
31        crate::key_package::SqLiteKeyPackageStorage,
32        crate::psk::SqLitePreSharedKeyStorage,
33    };
34}
35
36#[derive(Debug, Error)]
37/// SQLite data storage error.
38pub enum SqLiteDataStorageError {
39    #[error(transparent)]
40    /// SQLite error.
41    SqlEngineError(Box<dyn std::error::Error + Send + Sync + 'static>),
42    #[error(transparent)]
43    /// Stored data is not compatible with the expected data type.
44    DataConversionError(Box<dyn std::error::Error + Send + Sync + 'static>),
45    #[cfg(any(feature = "sqlcipher", feature = "sqlcipher-bundled"))]
46    #[error("invalid key, must use SqlCipherKey::RawKeyWithSalt with plaintext_header_size > 0")]
47    /// Invalid SQLCipher key header.
48    SqlCipherKeyInvalidWithHeader,
49}
50
51impl mls_rs_core::error::IntoAnyError for SqLiteDataStorageError {
52    fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
53        Ok(self.into())
54    }
55}
56
/// Journal mode requested from SQLite through `PRAGMA journal_mode`.
///
/// Note: for in-memory dbs (such as what the tests use), the only available
/// options are `MEMORY` or `OFF`. Invalid modes do not error, only no-op.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JournalMode {
    /// Requests `PRAGMA journal_mode = DELETE`.
    Delete,
    /// Requests `PRAGMA journal_mode = TRUNCATE`.
    Truncate,
    /// Requests `PRAGMA journal_mode = PERSIST`.
    Persist,
    /// Requests `PRAGMA journal_mode = MEMORY`.
    Memory,
    /// Requests `PRAGMA journal_mode = WAL`.
    Wal,
    /// Requests `PRAGMA journal_mode = OFF`.
    Off,
}

impl JournalMode {
    /// Returns the upper-case keyword used as the pragma value for this mode.
    fn as_str(&self) -> &'static str {
        match self {
            JournalMode::Delete => "DELETE",
            JournalMode::Truncate => "TRUNCATE",
            JournalMode::Persist => "PERSIST",
            JournalMode::Memory => "MEMORY",
            JournalMode::Wal => "WAL",
            JournalMode::Off => "OFF",
        }
    }
}
81
82#[derive(Clone, Debug)]
83/// SQLite data storage engine.
84pub struct SqLiteDataStorageEngine<CS>
85where
86    CS: ConnectionStrategy,
87{
88    connection_strategy: CS,
89    journal_mode: Option<JournalMode>,
90}
91
92impl<CS> SqLiteDataStorageEngine<CS>
93where
94    CS: ConnectionStrategy,
95{
96    pub fn new(
97        connection_strategy: CS,
98    ) -> Result<SqLiteDataStorageEngine<CS>, SqLiteDataStorageError> {
99        Ok(SqLiteDataStorageEngine {
100            connection_strategy,
101            journal_mode: None,
102        })
103    }
104
105    /// A `journal_mode` of `None` means the SQLite default is used.
106    pub fn with_journal_mode(self, journal_mode: Option<JournalMode>) -> Self {
107        Self {
108            journal_mode,
109            ..self
110        }
111    }
112
113    fn create_connection(&self) -> Result<Connection, SqLiteDataStorageError> {
114        let connection = self.connection_strategy.make_connection()?;
115
116        // Run SQL to establish the schema
117        let current_schema = connection
118            .pragma_query_value(None, "user_version", |rows| rows.get::<_, u32>(0))
119            .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
120
121        if let Some(journal_mode) = &self.journal_mode {
122            connection
123                .pragma_update(None, "journal_mode", journal_mode.as_str())
124                .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))?;
125        }
126
127        if current_schema != 1 {
128            create_tables_v1(&connection)?;
129        }
130
131        Ok(connection)
132    }
133
134    /// Returns a struct that implements the `GroupStateStorage` trait for use in MLS.
135    pub fn group_state_storage(&self) -> Result<SqLiteGroupStateStorage, SqLiteDataStorageError> {
136        Ok(SqLiteGroupStateStorage::new(self.create_connection()?))
137    }
138
139    /// Returns a struct that implements the `KeyPackageStorage` trait for use in MLS.
140    pub fn key_package_storage(&self) -> Result<SqLiteKeyPackageStorage, SqLiteDataStorageError> {
141        Ok(SqLiteKeyPackageStorage::new(self.create_connection()?))
142    }
143
144    /// Returns a struct that implements the `PreSharedKeyStorage` trait for use in MLS.
145    pub fn pre_shared_key_storage(
146        &self,
147    ) -> Result<SqLitePreSharedKeyStorage, SqLiteDataStorageError> {
148        Ok(SqLitePreSharedKeyStorage::new(self.create_connection()?))
149    }
150
151    /// Returns a key value store that can be used to store application specific data.
152    pub fn application_data_storage(
153        &self,
154    ) -> Result<SqLiteApplicationStorage, SqLiteDataStorageError> {
155        Ok(SqLiteApplicationStorage::new(self.create_connection()?))
156    }
157}
158
159fn create_tables_v1(connection: &Connection) -> Result<(), SqLiteDataStorageError> {
160    connection
161        .execute_batch(
162            "BEGIN;
163            CREATE TABLE mls_group (
164                group_id BLOB PRIMARY KEY,
165                snapshot BLOB NOT NULL
166            ) WITHOUT ROWID;
167            CREATE TABLE epoch (
168                group_id BLOB,
169                epoch_id INTEGER,
170                epoch_data BLOB NOT NULL,
171                FOREIGN KEY (group_id) REFERENCES mls_group (group_id) ON DELETE CASCADE
172                PRIMARY KEY (group_id, epoch_id)
173            ) WITHOUT ROWID;
174            CREATE TABLE key_package (
175                id BLOB PRIMARY KEY,
176                expiration INTEGER,
177                data BLOB NOT NULL
178            ) WITHOUT ROWID;
179            CREATE INDEX key_package_exp ON key_package (expiration);
180            CREATE TABLE psk (
181                psk_id BLOB PRIMARY KEY,
182                data BLOB NOT NULL
183            ) WITHOUT ROWID;
184            CREATE TABLE kvs (
185                key TEXT PRIMARY KEY,
186                value BLOB NOT NULL
187            ) WITHOUT ROWID;
188            PRAGMA user_version = 1;
189            COMMIT;",
190        )
191        .map_err(|e| SqLiteDataStorageError::SqlEngineError(e.into()))
192}
193
#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use crate::{
        connection_strategy::{FileConnectionStrategy, MemoryStrategy},
        JournalMode, SqLiteDataStorageEngine,
    };

    /// After connections have been created, `user_version` must read back as 1.
    #[test]
    pub fn user_version_test() {
        let engine = SqLiteDataStorageEngine::new(MemoryStrategy).unwrap();

        let _first = engine.create_connection().unwrap();

        // Create another connection to make sure the migration doesn't try to happen again.
        let second = engine.create_connection().unwrap();

        // Read back the schema version stored in the database.
        let version: u32 = second
            .pragma_query_value(None, "user_version", |row| row.get(0))
            .unwrap();

        assert_eq!(version, 1);
    }

    /// A journal mode configured on the engine must be reported by the connection.
    #[test]
    pub fn journal_mode_test() {
        let temp = tempdir().unwrap();

        // Use a file-backed database: in-memory databases only accept MEMORY or OFF,
        // so TRUNCATE could not be observed there.
        let db_path = temp.path().join("test_db.sqlite");
        let engine = SqLiteDataStorageEngine::new(FileConnectionStrategy::new(&db_path))
            .unwrap()
            .with_journal_mode(Some(JournalMode::Truncate));

        let connection = engine.create_connection().unwrap();

        let reported: String = connection
            .pragma_query_value(None, "journal_mode", |row| row.get(0))
            .unwrap();

        assert_eq!(reported, "truncate");
    }
}