use std::path::Path;
use covert_types::state::StorageState;
use futures::{future::BoxFuture, Stream};
use sqlx::{
pool::PoolConnection,
sqlite::{SqliteQueryResult, SqliteRow},
Pool, Sqlite, Transaction,
};
use crate::{
states::{Sealed, Uninitialized, Unsealed},
storage::{create_ecrypted_pool, create_master_key, Storage},
utils::owned_rw_lock::{OwnedRwLock, TransitionResult},
};
/// Handle to the encrypted storage, wrapping the current [`PoolState`] in an
/// [`OwnedRwLock`].
///
/// A shared reference (`&EncryptedPool`) implements `sqlx::Executor` and
/// `sqlx::Acquire`; database access only succeeds while the wrapped state is
/// `PoolState::Unsealed`, otherwise callers receive `sqlx::Error::PoolClosed`.
#[derive(Debug)]
pub struct EncryptedPool(OwnedRwLock<PoolState>);
struct PoolClosedStream;
impl Stream for PoolClosedStream {
type Item = Result<sqlx::Either<SqliteQueryResult, SqliteRow>, sqlx::Error>;
fn poll_next(
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
std::task::Poll::Ready(Some(Err(sqlx::Error::PoolClosed)))
}
}
/// Lets a shared `&EncryptedPool` be used wherever sqlx expects an executor.
///
/// Each method looks up the inner pool through `EncryptedPool::pool` and
/// forwards the call to it. When that lookup fails, the method reports the
/// lookup error (`sqlx::Error::PoolClosed`) instead of touching the database.
impl<'c> sqlx::Executor<'c> for &EncryptedPool {
    type Database = Sqlite;

    /// Runs `query` on the inner pool and streams back its results.
    ///
    /// When no pool is available the returned stream yields a single
    /// `Err(sqlx::Error::PoolClosed)` and then ends.
    fn fetch_many<'e, 'q, E>(
        self,
        query: E,
    ) -> futures::stream::BoxStream<
        'e,
        Result<
            sqlx::Either<
                <Self::Database as sqlx::Database>::QueryResult,
                <Self::Database as sqlx::Database>::Row,
            >,
            sqlx::Error,
        >,
    >
    where
        'c: 'e,
        'q: 'e,
        E: 'q + sqlx::Execute<'q, Self::Database>,
    {
        let Ok(pool) = self.pool() else {
            // `PoolClosedStream` repeats its error on every poll and never
            // finishes. Cap it at one item so that consumers which keep
            // draining the stream after an error still terminate.
            return Box::pin(futures::StreamExt::take(PoolClosedStream, 1));
        };
        pool.fetch_many(query)
    }

    /// Runs `query` on the inner pool and resolves to at most one row.
    ///
    /// Resolves to `Err(sqlx::Error::PoolClosed)` when no pool is available.
    fn fetch_optional<'e, 'q, E>(
        self,
        query: E,
    ) -> futures::future::BoxFuture<
        'e,
        Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>,
    >
    where
        'c: 'e,
        'q: 'e,
        E: 'q + sqlx::Execute<'q, Self::Database>,
    {
        let pool = match self.pool() {
            Ok(p) => p,
            Err(err) => return Box::pin(async { Err(err) }),
        };
        pool.fetch_optional(query)
    }

    /// Prepares `sql` on the inner pool using the given parameter type hints.
    ///
    /// Resolves to `Err(sqlx::Error::PoolClosed)` when no pool is available.
    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
    ) -> futures::future::BoxFuture<
        'e,
        Result<<Self::Database as sqlx::database::HasStatement<'q>>::Statement, sqlx::Error>,
    >
    where
        'c: 'e,
    {
        let pool = match self.pool() {
            Ok(p) => p,
            Err(err) => return Box::pin(async { Err(err) }),
        };
        pool.prepare_with(sql, parameters)
    }

    /// Asks the inner pool to describe `sql`.
    ///
    /// Resolves to `Err(sqlx::Error::PoolClosed)` when no pool is available.
    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> futures::future::BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
    where
        'c: 'e,
    {
        let pool = match self.pool() {
            Ok(p) => p,
            Err(err) => return Box::pin(async { Err(err) }),
        };
        pool.describe(sql)
    }
}
/// Lets a shared `&EncryptedPool` hand out connections and transactions.
///
/// The inner pool is looked up when the method is called; the returned future
/// then either uses that pool or resolves to the lookup error.
impl<'c> sqlx::Acquire<'c> for &EncryptedPool {
    type Database = Sqlite;
    type Connection = PoolConnection<Sqlite>;

    /// Acquires a connection from the inner pool.
    ///
    /// Resolves to `Err(sqlx::Error::PoolClosed)` when no pool is available.
    fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, sqlx::Error>> {
        // Resolve the pool eagerly so the future does not borrow `self`.
        let lookup = self.pool();
        Box::pin(async move {
            let pool = lookup?;
            pool.acquire().await
        })
    }

    /// Starts a transaction on the inner pool.
    ///
    /// Resolves to `Err(sqlx::Error::PoolClosed)` when no pool is available.
    fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, sqlx::Error>> {
        // Resolve the pool eagerly so the future does not borrow `self`.
        let lookup = self.pool();
        Box::pin(async move {
            let pool = lookup?;
            pool.begin().await
        })
    }
}
/// The lifecycle state of the storage, each variant carrying the [`Storage`]
/// value typed for that state.
#[derive(Debug)]
pub enum PoolState {
/// Storage that has not been initialized yet.
Uninitialized(Storage<Uninitialized>),
/// Storage that is sealed; no connection pool is available in this state.
Sealed(Storage<Sealed>),
/// Storage that is unsealed; its state carries the usable connection pool.
Unsealed(Storage<Unsealed>),
}
impl PoolState {
    /// Borrows the unsealed storage.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptedPoolError::InvalidState`] naming the current state
    /// when the storage is uninitialized or sealed.
    pub fn get_unsealed(&self) -> Result<&Storage<Unsealed>, EncryptedPoolError> {
        // Only the unsealed variant succeeds; the other arms name the state
        // that is reported back to the caller.
        let current = match self {
            Self::Unsealed(storage) => return Ok(storage),
            Self::Sealed(_) => StorageState::Sealed,
            Self::Uninitialized(_) => StorageState::Uninitialized,
        };
        Err(EncryptedPoolError::InvalidState(current))
    }
}
/// Errors reported by [`EncryptedPool`] and [`PoolState`] operations.
#[derive(Debug, thiserror::Error)]
pub enum EncryptedPoolError {
/// The requested operation is not permitted in the state carried by this
/// variant (the state the pool was in when the operation was attempted).
#[error("This operation is not allowed when the current state is `{0}`")]
InvalidState(StorageState),
/// The pool was in the right state for the operation, but the underlying
/// storage transition failed; the pool stays in `from`.
#[error("Failed to transition the pool state from `{from}` to `{to}`")]
Transition {
/// State the pool was in before the attempted transition.
from: StorageState,
/// State the transition was trying to reach.
to: StorageState,
},
}
impl EncryptedPool {
pub fn new(storage_path: &impl ToString) -> Self {
let storage_path = storage_path.to_string();
if Path::new(&storage_path).exists() {
Self(OwnedRwLock::new(PoolState::Sealed(Storage {
state: Sealed,
storage_path,
})))
} else {
Self(OwnedRwLock::new(PoolState::Uninitialized(Storage {
state: Uninitialized,
storage_path,
})))
}
}
#[must_use]
pub fn new_tmp() -> Self {
let storage_path = ":memory:".to_string();
let master_key = create_master_key();
let pool = create_ecrypted_pool(true, &storage_path, master_key)
.expect("to create encrypted pool and this should only be used for testing");
Self(OwnedRwLock::new(PoolState::Unsealed(Storage {
state: Unsealed { pool },
storage_path,
})))
}
pub fn state(&self) -> StorageState {
#[allow(clippy::redundant_closure_for_method_calls)]
self.0.map(|barrier| barrier.into())
}
pub fn initialize(&self) -> Result<Option<String>, EncryptedPoolError> {
self.0.write(|barrier| {
let barrier = match barrier {
PoolState::Uninitialized(barrier) => barrier,
PoolState::Sealed(barrier) => {
return TransitionResult {
state: PoolState::Sealed(barrier),
result: Err(EncryptedPoolError::InvalidState(StorageState::Sealed)),
}
}
PoolState::Unsealed(barrier) => {
return TransitionResult {
state: PoolState::Unsealed(barrier),
result: Err(EncryptedPoolError::InvalidState(StorageState::Unsealed)),
}
}
};
match barrier.initialize() {
Ok(res) => TransitionResult {
state: PoolState::Sealed(res.sealed_storage),
result: Ok(res.master_key),
},
Err(barrier) => TransitionResult {
state: PoolState::Uninitialized(barrier),
result: Err(EncryptedPoolError::Transition {
from: StorageState::Uninitialized,
to: StorageState::Sealed,
}),
},
}
})
}
pub fn unseal(&self, master_key: String) -> Result<(), EncryptedPoolError> {
self.0.write(|barrier| {
let barrier = match barrier {
PoolState::Uninitialized(barrier) => {
return TransitionResult {
state: PoolState::Uninitialized(barrier),
result: Err(EncryptedPoolError::InvalidState(
StorageState::Uninitialized,
)),
}
}
PoolState::Sealed(barrier) => barrier,
PoolState::Unsealed(barrier) => {
return TransitionResult {
state: PoolState::Unsealed(barrier),
result: Err(EncryptedPoolError::InvalidState(StorageState::Unsealed)),
}
}
};
match barrier.unseal(master_key) {
Ok(barrier) => TransitionResult {
state: PoolState::Unsealed(barrier),
result: Ok(()),
},
Err(barrier) => TransitionResult {
state: PoolState::Sealed(barrier),
result: Err(EncryptedPoolError::Transition {
from: StorageState::Sealed,
to: StorageState::Unsealed,
}),
},
}
})
}
pub fn seal(&self) -> Result<(), EncryptedPoolError> {
self.0.write(|barrier| {
let barrier = match barrier {
PoolState::Uninitialized(barrier) => {
return TransitionResult {
state: PoolState::Uninitialized(barrier),
result: Err(EncryptedPoolError::InvalidState(
StorageState::Uninitialized,
)),
}
}
PoolState::Sealed(barrier) => {
return TransitionResult {
state: PoolState::Sealed(barrier),
result: Err(EncryptedPoolError::InvalidState(StorageState::Sealed)),
}
}
PoolState::Unsealed(barrier) => barrier,
};
let barrier = barrier.seal();
TransitionResult {
state: PoolState::Sealed(barrier),
result: Ok(()),
}
})
}
fn pool(&self) -> Result<Pool<Sqlite>, sqlx::Error> {
self.0
.read()
.get_unsealed()
.map(|storage| storage.state.pool.clone())
.map_err(|_| sqlx::Error::PoolClosed)
}
pub async fn begin(&self) -> Result<Transaction<'static, Sqlite>, sqlx::Error> {
let pool = self.pool()?;
pool.begin().await
}
}
/// Maps a [`PoolState`] to the payload-free [`StorageState`] naming the same
/// lifecycle state.
impl From<&PoolState> for StorageState {
    fn from(pool_state: &PoolState) -> Self {
        match pool_state {
            PoolState::Unsealed(_) => Self::Unsealed,
            PoolState::Sealed(_) => Self::Sealed,
            PoolState::Uninitialized(_) => Self::Uninitialized,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Query used to probe whether the pool currently accepts statements.
    const PROBE_QUERY: &str = "SELECT count(*) FROM sqlite_master";

    /// Executes the probe query with `pool` as the executor.
    async fn probe(pool: &EncryptedPool) -> Result<SqliteQueryResult, sqlx::Error> {
        sqlx::query(PROBE_QUERY).execute(pool).await
    }

    /// Walks the pool through initialize -> unseal -> seal -> unseal and checks
    /// that queries only succeed while it is unsealed.
    #[sqlx::test]
    async fn unseal_and_query() {
        let pool = EncryptedPool::new(&":memory:".to_string());

        // Freshly created for a path that does not exist: not usable yet.
        assert!(matches!(
            probe(&pool).await.unwrap_err(),
            sqlx::Error::PoolClosed
        ));

        // Initialized but still sealed: queries keep failing.
        let master_key = pool.initialize().unwrap().unwrap();
        assert!(matches!(
            probe(&pool).await.unwrap_err(),
            sqlx::Error::PoolClosed
        ));

        // Unsealed: queries go through.
        pool.unseal(master_key.clone()).unwrap();
        assert!(probe(&pool).await.is_ok());

        // Sealed again: queries are rejected once more.
        pool.seal().unwrap();
        assert!(matches!(
            probe(&pool).await.unwrap_err(),
            sqlx::Error::PoolClosed
        ));

        // Unsealing a second time with the same key restores access.
        pool.unseal(master_key).unwrap();
        assert!(probe(&pool).await.is_ok());
    }
}