// covert_storage/encrypted_pool.rs

1use std::path::Path;
2
3use covert_types::state::StorageState;
4use futures::{future::BoxFuture, Stream};
5use sqlx::{
6    pool::PoolConnection,
7    sqlite::{SqliteQueryResult, SqliteRow},
8    Pool, Sqlite, Transaction,
9};
10
11use crate::{
12    states::{Sealed, Uninitialized, Unsealed},
13    storage::{create_ecrypted_pool, create_master_key, Storage},
14    utils::owned_rw_lock::{OwnedRwLock, TransitionResult},
15};
16
17#[derive(Debug)]
18pub struct EncryptedPool(OwnedRwLock<PoolState>);
19
20struct PoolClosedStream;
21
22impl Stream for PoolClosedStream {
23    type Item = Result<sqlx::Either<SqliteQueryResult, SqliteRow>, sqlx::Error>;
24
25    fn poll_next(
26        self: std::pin::Pin<&mut Self>,
27        _cx: &mut std::task::Context<'_>,
28    ) -> std::task::Poll<Option<Self::Item>> {
29        std::task::Poll::Ready(Some(Err(sqlx::Error::PoolClosed)))
30    }
31}
32
33impl<'c> sqlx::Executor<'c> for &EncryptedPool {
34    type Database = Sqlite;
35
36    fn fetch_many<'e, 'q, E>(
37        self,
38        query: E,
39    ) -> futures::stream::BoxStream<
40        'e,
41        Result<
42            sqlx::Either<
43                <Self::Database as sqlx::Database>::QueryResult,
44                <Self::Database as sqlx::Database>::Row,
45            >,
46            sqlx::Error,
47        >,
48    >
49    where
50        'c: 'e,
51        'q: 'e,
52        E: 'q + sqlx::Execute<'q, Self::Database>,
53    {
54        let Ok(pool) = self.pool() else {
55            return Box::pin(PoolClosedStream);
56        };
57        pool.fetch_many(query)
58    }
59
60    fn fetch_optional<'e, 'q, E>(
61        self,
62        query: E,
63    ) -> futures::future::BoxFuture<
64        'e,
65        Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>,
66    >
67    where
68        'c: 'e,
69        'q: 'e,
70        E: 'q + sqlx::Execute<'q, Self::Database>,
71    {
72        let pool = match self.pool() {
73            Ok(p) => p,
74            Err(err) => return Box::pin(async { Err(err) }),
75        };
76        pool.fetch_optional(query)
77    }
78
79    fn prepare_with<'e, 'q: 'e>(
80        self,
81        sql: &'q str,
82        parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
83    ) -> futures::future::BoxFuture<
84        'e,
85        Result<<Self::Database as sqlx::database::HasStatement<'q>>::Statement, sqlx::Error>,
86    >
87    where
88        'c: 'e,
89    {
90        let pool = match self.pool() {
91            Ok(p) => p,
92            Err(err) => return Box::pin(async { Err(err) }),
93        };
94        pool.prepare_with(sql, parameters)
95    }
96
97    fn describe<'e, 'q: 'e>(
98        self,
99        sql: &'q str,
100    ) -> futures::future::BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
101    where
102        'c: 'e,
103    {
104        let pool = match self.pool() {
105            Ok(p) => p,
106            Err(err) => return Box::pin(async { Err(err) }),
107        };
108        pool.describe(sql)
109    }
110}
111
112impl<'c> sqlx::Acquire<'c> for &EncryptedPool {
113    type Database = Sqlite;
114
115    type Connection = PoolConnection<Sqlite>;
116
117    fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, sqlx::Error>> {
118        let pool = match self.pool() {
119            Ok(p) => p,
120            Err(err) => return Box::pin(async { Err(err) }),
121        };
122        Box::pin(pool.acquire())
123    }
124
125    fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, sqlx::Error>> {
126        let pool = match self.pool() {
127            Ok(p) => p,
128            Err(err) => return Box::pin(async { Err(err) }),
129        };
130        Box::pin(async move { pool.begin().await })
131    }
132}
133
134#[derive(Debug)]
135pub enum PoolState {
136    Uninitialized(Storage<Uninitialized>),
137    Sealed(Storage<Sealed>),
138    Unsealed(Storage<Unsealed>),
139}
140
141impl PoolState {
142    /// Try to get a unsealed storage.
143    ///
144    /// # Errors
145    ///
146    /// Returns error if the storage is not unsealed.
147    pub fn get_unsealed(&self) -> Result<&Storage<Unsealed>, EncryptedPoolError> {
148        match self {
149            PoolState::Uninitialized(_) => Err(EncryptedPoolError::InvalidState(
150                StorageState::Uninitialized,
151            )),
152            PoolState::Sealed(_) => Err(EncryptedPoolError::InvalidState(StorageState::Sealed)),
153            PoolState::Unsealed(b) => Ok(b),
154        }
155    }
156}
157
158#[derive(Debug, thiserror::Error)]
159pub enum EncryptedPoolError {
160    #[error("This operation is not allowed when the current state is `{0}`")]
161    InvalidState(StorageState),
162    #[error("Failed to transition the pool state from `{from}` to `{to}`")]
163    Transition {
164        from: StorageState,
165        to: StorageState,
166    },
167}
168
169impl EncryptedPool {
170    pub fn new(storage_path: &impl ToString) -> Self {
171        let storage_path = storage_path.to_string();
172
173        if Path::new(&storage_path).exists() {
174            Self(OwnedRwLock::new(PoolState::Sealed(Storage {
175                state: Sealed,
176                storage_path,
177            })))
178        } else {
179            Self(OwnedRwLock::new(PoolState::Uninitialized(Storage {
180                state: Uninitialized,
181                storage_path,
182            })))
183        }
184    }
185
186    /// Creates an unsealed temporary pool which is useful when writing tests.
187    #[must_use]
188    pub fn new_tmp() -> Self {
189        let storage_path = ":memory:".to_string();
190        let master_key = create_master_key();
191        let pool = create_ecrypted_pool(true, &storage_path, master_key)
192            .expect("to create encrypted pool and this should only be used for testing");
193
194        Self(OwnedRwLock::new(PoolState::Unsealed(Storage {
195            state: Unsealed { pool },
196            storage_path,
197        })))
198    }
199
200    pub fn state(&self) -> StorageState {
201        #[allow(clippy::redundant_closure_for_method_calls)]
202        self.0.map(|barrier| barrier.into())
203    }
204
205    /// Initialize the pool.
206    ///
207    /// # Errors
208    ///
209    /// Returns error if the pool is not uninitialized or the initialization fails.
210    pub fn initialize(&self) -> Result<Option<String>, EncryptedPoolError> {
211        self.0.write(|barrier| {
212            let barrier = match barrier {
213                PoolState::Uninitialized(barrier) => barrier,
214                PoolState::Sealed(barrier) => {
215                    return TransitionResult {
216                        state: PoolState::Sealed(barrier),
217                        result: Err(EncryptedPoolError::InvalidState(StorageState::Sealed)),
218                    }
219                }
220                PoolState::Unsealed(barrier) => {
221                    return TransitionResult {
222                        state: PoolState::Unsealed(barrier),
223                        result: Err(EncryptedPoolError::InvalidState(StorageState::Unsealed)),
224                    }
225                }
226            };
227
228            match barrier.initialize() {
229                Ok(res) => TransitionResult {
230                    state: PoolState::Sealed(res.sealed_storage),
231                    result: Ok(res.master_key),
232                },
233                Err(barrier) => TransitionResult {
234                    state: PoolState::Uninitialized(barrier),
235                    result: Err(EncryptedPoolError::Transition {
236                        from: StorageState::Uninitialized,
237                        to: StorageState::Sealed,
238                    }),
239                },
240            }
241        })
242    }
243
244    /// Unseal the pool.
245    ///
246    /// # Errors
247    ///
248    /// Returns error if the pool is not sealed or the unseal process fails.
249    pub fn unseal(&self, master_key: String) -> Result<(), EncryptedPoolError> {
250        self.0.write(|barrier| {
251            let barrier = match barrier {
252                PoolState::Uninitialized(barrier) => {
253                    return TransitionResult {
254                        state: PoolState::Uninitialized(barrier),
255                        result: Err(EncryptedPoolError::InvalidState(
256                            StorageState::Uninitialized,
257                        )),
258                    }
259                }
260                PoolState::Sealed(barrier) => barrier,
261                PoolState::Unsealed(barrier) => {
262                    return TransitionResult {
263                        state: PoolState::Unsealed(barrier),
264                        result: Err(EncryptedPoolError::InvalidState(StorageState::Unsealed)),
265                    }
266                }
267            };
268
269            match barrier.unseal(master_key) {
270                Ok(barrier) => TransitionResult {
271                    state: PoolState::Unsealed(barrier),
272                    result: Ok(()),
273                },
274                Err(barrier) => TransitionResult {
275                    state: PoolState::Sealed(barrier),
276                    result: Err(EncryptedPoolError::Transition {
277                        from: StorageState::Sealed,
278                        to: StorageState::Unsealed,
279                    }),
280                },
281            }
282        })
283    }
284
285    /// Seal the pool.
286    ///
287    /// # Errors
288    ///
289    /// Returns error if the pool is not unsealed.
290    pub fn seal(&self) -> Result<(), EncryptedPoolError> {
291        self.0.write(|barrier| {
292            let barrier = match barrier {
293                PoolState::Uninitialized(barrier) => {
294                    return TransitionResult {
295                        state: PoolState::Uninitialized(barrier),
296                        result: Err(EncryptedPoolError::InvalidState(
297                            StorageState::Uninitialized,
298                        )),
299                    }
300                }
301                PoolState::Sealed(barrier) => {
302                    return TransitionResult {
303                        state: PoolState::Sealed(barrier),
304                        result: Err(EncryptedPoolError::InvalidState(StorageState::Sealed)),
305                    }
306                }
307                PoolState::Unsealed(barrier) => barrier,
308            };
309
310            let barrier = barrier.seal();
311            TransitionResult {
312                state: PoolState::Sealed(barrier),
313                result: Ok(()),
314            }
315        })
316    }
317
318    fn pool(&self) -> Result<Pool<Sqlite>, sqlx::Error> {
319        self.0
320            .read()
321            .get_unsealed()
322            .map(|storage| storage.state.pool.clone())
323            .map_err(|_| sqlx::Error::PoolClosed)
324    }
325
326    /// Retrieves a connection and immediately begins a new transaction.
327    ///
328    /// # Errors
329    ///
330    /// Returns error if it is unable to retrieve the db pool or start the
331    /// transaction.
332    pub async fn begin(&self) -> Result<Transaction<'static, Sqlite>, sqlx::Error> {
333        let pool = self.pool()?;
334        pool.begin().await
335    }
336}
337
338impl From<&PoolState> for StorageState {
339    fn from(barrier: &PoolState) -> Self {
340        match barrier {
341            PoolState::Uninitialized(_) => StorageState::Uninitialized,
342            PoolState::Sealed(_) => StorageState::Sealed,
343            PoolState::Unsealed(_) => StorageState::Unsealed,
344        }
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// A cheap query that succeeds against any open SQLite database.
    const PROBE_QUERY: &str = "SELECT count(*) FROM sqlite_master";

    /// Runs the probe query and reports whether it was rejected with `PoolClosed`.
    async fn probe_is_rejected(pool: &EncryptedPool) -> bool {
        let outcome = sqlx::query(PROBE_QUERY).execute(pool).await;
        matches!(outcome, Err(sqlx::Error::PoolClosed))
    }

    /// Runs the probe query and reports whether it succeeded.
    async fn probe_succeeds(pool: &EncryptedPool) -> bool {
        sqlx::query(PROBE_QUERY).execute(pool).await.is_ok()
    }

    #[sqlx::test]
    async fn unseal_and_query() {
        let pool = EncryptedPool::new(&":memory:".to_string());

        // Uninitialized: nothing can be queried yet.
        assert!(probe_is_rejected(&pool).await);

        // Initialization leaves the pool sealed, so queries are still rejected.
        let master_key = pool.initialize().unwrap().unwrap();
        assert!(probe_is_rejected(&pool).await);

        // Unsealing opens the pool.
        pool.unseal(master_key.clone()).unwrap();
        assert!(probe_succeeds(&pool).await);

        // Sealing closes it again.
        pool.seal().unwrap();
        assert!(probe_is_rejected(&pool).await);

        // The same key unseals it a second time.
        pool.unseal(master_key).unwrap();
        assert!(probe_succeeds(&pool).await);
    }
}
380}