1use std::path::Path;
2
3use covert_types::state::StorageState;
4use futures::{future::BoxFuture, Stream};
5use sqlx::{
6 pool::PoolConnection,
7 sqlite::{SqliteQueryResult, SqliteRow},
8 Pool, Sqlite, Transaction,
9};
10
11use crate::{
12 states::{Sealed, Uninitialized, Unsealed},
13 storage::{create_ecrypted_pool, create_master_key, Storage},
14 utils::owned_rw_lock::{OwnedRwLock, TransitionResult},
15};
16
17#[derive(Debug)]
18pub struct EncryptedPool(OwnedRwLock<PoolState>);
19
20struct PoolClosedStream;
21
22impl Stream for PoolClosedStream {
23 type Item = Result<sqlx::Either<SqliteQueryResult, SqliteRow>, sqlx::Error>;
24
25 fn poll_next(
26 self: std::pin::Pin<&mut Self>,
27 _cx: &mut std::task::Context<'_>,
28 ) -> std::task::Poll<Option<Self::Item>> {
29 std::task::Poll::Ready(Some(Err(sqlx::Error::PoolClosed)))
30 }
31}
32
33impl<'c> sqlx::Executor<'c> for &EncryptedPool {
34 type Database = Sqlite;
35
36 fn fetch_many<'e, 'q, E>(
37 self,
38 query: E,
39 ) -> futures::stream::BoxStream<
40 'e,
41 Result<
42 sqlx::Either<
43 <Self::Database as sqlx::Database>::QueryResult,
44 <Self::Database as sqlx::Database>::Row,
45 >,
46 sqlx::Error,
47 >,
48 >
49 where
50 'c: 'e,
51 'q: 'e,
52 E: 'q + sqlx::Execute<'q, Self::Database>,
53 {
54 let Ok(pool) = self.pool() else {
55 return Box::pin(PoolClosedStream);
56 };
57 pool.fetch_many(query)
58 }
59
60 fn fetch_optional<'e, 'q, E>(
61 self,
62 query: E,
63 ) -> futures::future::BoxFuture<
64 'e,
65 Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>,
66 >
67 where
68 'c: 'e,
69 'q: 'e,
70 E: 'q + sqlx::Execute<'q, Self::Database>,
71 {
72 let pool = match self.pool() {
73 Ok(p) => p,
74 Err(err) => return Box::pin(async { Err(err) }),
75 };
76 pool.fetch_optional(query)
77 }
78
79 fn prepare_with<'e, 'q: 'e>(
80 self,
81 sql: &'q str,
82 parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
83 ) -> futures::future::BoxFuture<
84 'e,
85 Result<<Self::Database as sqlx::database::HasStatement<'q>>::Statement, sqlx::Error>,
86 >
87 where
88 'c: 'e,
89 {
90 let pool = match self.pool() {
91 Ok(p) => p,
92 Err(err) => return Box::pin(async { Err(err) }),
93 };
94 pool.prepare_with(sql, parameters)
95 }
96
97 fn describe<'e, 'q: 'e>(
98 self,
99 sql: &'q str,
100 ) -> futures::future::BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
101 where
102 'c: 'e,
103 {
104 let pool = match self.pool() {
105 Ok(p) => p,
106 Err(err) => return Box::pin(async { Err(err) }),
107 };
108 pool.describe(sql)
109 }
110}
111
112impl<'c> sqlx::Acquire<'c> for &EncryptedPool {
113 type Database = Sqlite;
114
115 type Connection = PoolConnection<Sqlite>;
116
117 fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, sqlx::Error>> {
118 let pool = match self.pool() {
119 Ok(p) => p,
120 Err(err) => return Box::pin(async { Err(err) }),
121 };
122 Box::pin(pool.acquire())
123 }
124
125 fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, sqlx::Error>> {
126 let pool = match self.pool() {
127 Ok(p) => p,
128 Err(err) => return Box::pin(async { Err(err) }),
129 };
130 Box::pin(async move { pool.begin().await })
131 }
132}
133
134#[derive(Debug)]
135pub enum PoolState {
136 Uninitialized(Storage<Uninitialized>),
137 Sealed(Storage<Sealed>),
138 Unsealed(Storage<Unsealed>),
139}
140
141impl PoolState {
142 pub fn get_unsealed(&self) -> Result<&Storage<Unsealed>, EncryptedPoolError> {
148 match self {
149 PoolState::Uninitialized(_) => Err(EncryptedPoolError::InvalidState(
150 StorageState::Uninitialized,
151 )),
152 PoolState::Sealed(_) => Err(EncryptedPoolError::InvalidState(StorageState::Sealed)),
153 PoolState::Unsealed(b) => Ok(b),
154 }
155 }
156}
157
158#[derive(Debug, thiserror::Error)]
159pub enum EncryptedPoolError {
160 #[error("This operation is not allowed when the current state is `{0}`")]
161 InvalidState(StorageState),
162 #[error("Failed to transition the pool state from `{from}` to `{to}`")]
163 Transition {
164 from: StorageState,
165 to: StorageState,
166 },
167}
168
169impl EncryptedPool {
170 pub fn new(storage_path: &impl ToString) -> Self {
171 let storage_path = storage_path.to_string();
172
173 if Path::new(&storage_path).exists() {
174 Self(OwnedRwLock::new(PoolState::Sealed(Storage {
175 state: Sealed,
176 storage_path,
177 })))
178 } else {
179 Self(OwnedRwLock::new(PoolState::Uninitialized(Storage {
180 state: Uninitialized,
181 storage_path,
182 })))
183 }
184 }
185
186 #[must_use]
188 pub fn new_tmp() -> Self {
189 let storage_path = ":memory:".to_string();
190 let master_key = create_master_key();
191 let pool = create_ecrypted_pool(true, &storage_path, master_key)
192 .expect("to create encrypted pool and this should only be used for testing");
193
194 Self(OwnedRwLock::new(PoolState::Unsealed(Storage {
195 state: Unsealed { pool },
196 storage_path,
197 })))
198 }
199
200 pub fn state(&self) -> StorageState {
201 #[allow(clippy::redundant_closure_for_method_calls)]
202 self.0.map(|barrier| barrier.into())
203 }
204
205 pub fn initialize(&self) -> Result<Option<String>, EncryptedPoolError> {
211 self.0.write(|barrier| {
212 let barrier = match barrier {
213 PoolState::Uninitialized(barrier) => barrier,
214 PoolState::Sealed(barrier) => {
215 return TransitionResult {
216 state: PoolState::Sealed(barrier),
217 result: Err(EncryptedPoolError::InvalidState(StorageState::Sealed)),
218 }
219 }
220 PoolState::Unsealed(barrier) => {
221 return TransitionResult {
222 state: PoolState::Unsealed(barrier),
223 result: Err(EncryptedPoolError::InvalidState(StorageState::Unsealed)),
224 }
225 }
226 };
227
228 match barrier.initialize() {
229 Ok(res) => TransitionResult {
230 state: PoolState::Sealed(res.sealed_storage),
231 result: Ok(res.master_key),
232 },
233 Err(barrier) => TransitionResult {
234 state: PoolState::Uninitialized(barrier),
235 result: Err(EncryptedPoolError::Transition {
236 from: StorageState::Uninitialized,
237 to: StorageState::Sealed,
238 }),
239 },
240 }
241 })
242 }
243
244 pub fn unseal(&self, master_key: String) -> Result<(), EncryptedPoolError> {
250 self.0.write(|barrier| {
251 let barrier = match barrier {
252 PoolState::Uninitialized(barrier) => {
253 return TransitionResult {
254 state: PoolState::Uninitialized(barrier),
255 result: Err(EncryptedPoolError::InvalidState(
256 StorageState::Uninitialized,
257 )),
258 }
259 }
260 PoolState::Sealed(barrier) => barrier,
261 PoolState::Unsealed(barrier) => {
262 return TransitionResult {
263 state: PoolState::Unsealed(barrier),
264 result: Err(EncryptedPoolError::InvalidState(StorageState::Unsealed)),
265 }
266 }
267 };
268
269 match barrier.unseal(master_key) {
270 Ok(barrier) => TransitionResult {
271 state: PoolState::Unsealed(barrier),
272 result: Ok(()),
273 },
274 Err(barrier) => TransitionResult {
275 state: PoolState::Sealed(barrier),
276 result: Err(EncryptedPoolError::Transition {
277 from: StorageState::Sealed,
278 to: StorageState::Unsealed,
279 }),
280 },
281 }
282 })
283 }
284
285 pub fn seal(&self) -> Result<(), EncryptedPoolError> {
291 self.0.write(|barrier| {
292 let barrier = match barrier {
293 PoolState::Uninitialized(barrier) => {
294 return TransitionResult {
295 state: PoolState::Uninitialized(barrier),
296 result: Err(EncryptedPoolError::InvalidState(
297 StorageState::Uninitialized,
298 )),
299 }
300 }
301 PoolState::Sealed(barrier) => {
302 return TransitionResult {
303 state: PoolState::Sealed(barrier),
304 result: Err(EncryptedPoolError::InvalidState(StorageState::Sealed)),
305 }
306 }
307 PoolState::Unsealed(barrier) => barrier,
308 };
309
310 let barrier = barrier.seal();
311 TransitionResult {
312 state: PoolState::Sealed(barrier),
313 result: Ok(()),
314 }
315 })
316 }
317
318 fn pool(&self) -> Result<Pool<Sqlite>, sqlx::Error> {
319 self.0
320 .read()
321 .get_unsealed()
322 .map(|storage| storage.state.pool.clone())
323 .map_err(|_| sqlx::Error::PoolClosed)
324 }
325
326 pub async fn begin(&self) -> Result<Transaction<'static, Sqlite>, sqlx::Error> {
333 let pool = self.pool()?;
334 pool.begin().await
335 }
336}
337
338impl From<&PoolState> for StorageState {
339 fn from(barrier: &PoolState) -> Self {
340 match barrier {
341 PoolState::Uninitialized(_) => StorageState::Uninitialized,
342 PoolState::Sealed(_) => StorageState::Sealed,
343 PoolState::Unsealed(_) => StorageState::Unsealed,
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    const QUERY: &str = "SELECT count(*) FROM sqlite_master";

    /// Asserts that running a query through `pool` fails with `PoolClosed`.
    async fn assert_pool_closed(pool: &EncryptedPool) {
        let res = sqlx::query(QUERY).execute(pool).await;
        assert!(matches!(res.unwrap_err(), sqlx::Error::PoolClosed));
    }

    /// Asserts that running a query through `pool` succeeds.
    async fn assert_query_ok(pool: &EncryptedPool) {
        let res = sqlx::query(QUERY).execute(pool).await;
        assert!(res.is_ok(), "query failed: {res:?}");
    }

    #[sqlx::test]
    async fn unseal_and_query() {
        let pool = EncryptedPool::new(&":memory:".to_string());
        assert!(matches!(pool.state(), StorageState::Uninitialized));
        assert_pool_closed(&pool).await;

        let master_key = pool.initialize().unwrap().unwrap();
        assert!(matches!(pool.state(), StorageState::Sealed));
        assert_pool_closed(&pool).await;

        pool.unseal(master_key.clone()).unwrap();
        assert!(matches!(pool.state(), StorageState::Unsealed));
        assert_query_ok(&pool).await;

        pool.seal().unwrap();
        assert!(matches!(pool.state(), StorageState::Sealed));
        assert_pool_closed(&pool).await;

        pool.unseal(master_key).unwrap();
        assert!(matches!(pool.state(), StorageState::Unsealed));
        assert_query_ok(&pool).await;
    }

    #[sqlx::test]
    async fn rejects_transitions_from_wrong_state() {
        let pool = EncryptedPool::new(&":memory:".to_string());

        // Before initialization, only `initialize` is accepted.
        assert!(matches!(
            pool.seal().unwrap_err(),
            EncryptedPoolError::InvalidState(StorageState::Uninitialized)
        ));
        assert!(matches!(
            pool.unseal(String::new()).unwrap_err(),
            EncryptedPoolError::InvalidState(StorageState::Uninitialized)
        ));
        assert!(matches!(pool.state(), StorageState::Uninitialized));

        let master_key = pool.initialize().unwrap().unwrap();

        // While sealed, re-initializing and sealing again are rejected.
        assert!(matches!(
            pool.initialize().unwrap_err(),
            EncryptedPoolError::InvalidState(StorageState::Sealed)
        ));
        assert!(matches!(
            pool.seal().unwrap_err(),
            EncryptedPoolError::InvalidState(StorageState::Sealed)
        ));
        assert!(matches!(pool.state(), StorageState::Sealed));

        pool.unseal(master_key.clone()).unwrap();

        // While unsealed, re-initializing and unsealing again are rejected.
        assert!(matches!(
            pool.initialize().unwrap_err(),
            EncryptedPoolError::InvalidState(StorageState::Unsealed)
        ));
        assert!(matches!(
            pool.unseal(master_key).unwrap_err(),
            EncryptedPoolError::InvalidState(StorageState::Unsealed)
        ));
        // Rejected transitions must leave the pool usable.
        assert!(matches!(pool.state(), StorageState::Unsealed));
        assert_query_ok(&pool).await;
    }
}