covert_storage/
backend_pool.rs

1use std::{borrow::Cow, sync::Arc};
2
3use sqlx::{
4    error::DatabaseError,
5    sqlite::{SqliteArguments, SqliteQueryResult, SqliteRow},
6    Arguments, Encode, Sqlite, Type,
7};
8
9use crate::{scoped_queries::ScopedQuery, EncryptedPool};
10
11#[derive(Debug, thiserror::Error)]
12pub enum CovertDatabaseError {
13    #[error("Unable to prefix query: `{query}` with prefix: `{prefix}`")]
14    BadPrefixQuery {
15        prefix: String,
16        query: String,
17        message: String,
18    },
19}
20
21impl DatabaseError for CovertDatabaseError {
22    #[inline]
23    fn code(&self) -> Option<Cow<'_, str>> {
24        // The SQLITE_ERROR result code is a generic error code that is used when no other more specific error code is available.
25        Some("1".into())
26    }
27
28    fn message(&self) -> &str {
29        match self {
30            CovertDatabaseError::BadPrefixQuery { message, .. } => message,
31        }
32    }
33
34    fn as_error(&self) -> &(dyn std::error::Error + Send + Sync + 'static) {
35        self
36    }
37
38    fn as_error_mut(&mut self) -> &mut (dyn std::error::Error + Send + Sync + 'static) {
39        self
40    }
41
42    fn into_error(self: Box<Self>) -> Box<dyn std::error::Error + Send + Sync + 'static> {
43        self
44    }
45}
46
47#[derive(Debug, Clone)]
48pub struct BackendStoragePool {
49    prefix: String,
50    pool: Arc<EncryptedPool>,
51}
52
53impl BackendStoragePool {
54    pub fn new(prefix: &str, pool: Arc<EncryptedPool>) -> Self {
55        Self {
56            prefix: prefix.to_string(),
57            pool,
58        }
59    }
60
61    /// Construct a prefixed query.
62    ///
63    /// # Errors
64    ///
65    /// Returns error if the sql query cannot be prefixed.
66    pub fn query(&self, sql: &impl ToString) -> Result<Query, sqlx::Error> {
67        ScopedQuery::new(&self.prefix, &sql.to_string())
68            .map(|query| Query {
69                query,
70                pool: Arc::clone(&self.pool),
71                arguments: SqliteArguments::default(),
72            })
73            .map_err(|err| {
74                let prefix = self.prefix.clone();
75                let query = sql.to_string();
76                let message = format!(
77                    "Unable to prefix query: `{query}` with prefix: `{prefix}`. Error: {err:?}"
78                );
79                sqlx::Error::Database(Box::new(CovertDatabaseError::BadPrefixQuery {
80                    prefix,
81                    query,
82                    message,
83                }))
84            })
85    }
86
87    #[must_use]
88    pub fn prefix(&self) -> &str {
89        &self.prefix
90    }
91}
92
/// A prefixed SQL statement together with its bound arguments.
///
/// Built by [`BackendStoragePool::query`], extended with [`Query::bind`], and
/// consumed by one of the `execute` / `fetch_*` methods.
pub struct Query<'a> {
    /// The statement after prefix scoping by [`ScopedQuery`].
    query: ScopedQuery,
    /// Pool the statement is executed against.
    pool: Arc<EncryptedPool>,
    /// Arguments accumulated through `bind`.
    arguments: SqliteArguments<'a>,
}
98
99impl<'a> Query<'a> {
100    pub fn bind<T: 'a + Send + Encode<'a, Sqlite> + Type<Sqlite>>(mut self, value: T) -> Self {
101        self.arguments.add(value);
102
103        self
104    }
105
106    pub async fn execute(self) -> Result<SqliteQueryResult, sqlx::Error> {
107        sqlx::query_with(self.query.sql(), self.arguments)
108            .execute(self.pool.as_ref())
109            .await
110    }
111
112    pub async fn fetch_one<T>(self) -> Result<T, sqlx::Error>
113    where
114        T: Send + for<'r> sqlx::FromRow<'r, SqliteRow> + Unpin,
115    {
116        sqlx::query_as_with(self.query.sql(), self.arguments)
117            .fetch_one(self.pool.as_ref())
118            .await
119    }
120
121    pub async fn fetch_all<T>(self) -> Result<Vec<T>, sqlx::Error>
122    where
123        T: Send + for<'r> sqlx::FromRow<'r, SqliteRow> + Unpin,
124    {
125        sqlx::query_as_with(self.query.sql(), self.arguments)
126            .fetch_all(self.pool.as_ref())
127            .await
128    }
129
130    pub async fn fetch_optional<T>(self) -> Result<Option<T>, sqlx::Error>
131    where
132        T: Send + for<'r> sqlx::FromRow<'r, SqliteRow> + Unpin,
133    {
134        sqlx::query_as_with(self.query.sql(), self.arguments)
135            .fetch_optional(self.pool.as_ref())
136            .await
137    }
138}