cdk_sql_common/
database.rs

1//! Database traits definition
2
3use std::fmt::Debug;
4use std::marker::PhantomData;
5use std::ops::{Deref, DerefMut};
6
7use cdk_common::database::Error;
8
9use crate::stmt::{query, Column, Statement};
10
11/// Database Executor
12///
13/// This trait defines the expectations of a database execution
14#[async_trait::async_trait]
15pub trait DatabaseExecutor: Debug + Sync + Send {
16    /// Database driver name
17    fn name() -> &'static str;
18
19    /// Executes a query and returns the affected rows
20    async fn execute(&self, statement: Statement) -> Result<usize, Error>;
21
22    /// Runs the query and returns the first row or None
23    async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error>;
24
25    /// Runs the query and returns the first row or None
26    async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error>;
27
28    /// Fetches the first row and column from a query
29    async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error>;
30
31    /// Batch execution
32    async fn batch(&self, statement: Statement) -> Result<(), Error>;
33}
34
35/// Database transaction trait
36#[async_trait::async_trait]
37pub trait DatabaseTransaction<DB>
38where
39    DB: DatabaseExecutor,
40{
41    /// Consumes the current transaction committing the changes
42    async fn commit(conn: &mut DB) -> Result<(), Error>;
43
44    /// Begin a transaction
45    async fn begin(conn: &mut DB) -> Result<(), Error>;
46
47    /// Consumes the transaction rolling back all changes
48    async fn rollback(conn: &mut DB) -> Result<(), Error>;
49}
50
51/// Database connection with a transaction
52#[derive(Debug)]
53pub struct ConnectionWithTransaction<DB, W>
54where
55    DB: DatabaseConnector + 'static,
56    W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
57{
58    inner: Option<W>,
59}
60
61impl<DB, W> ConnectionWithTransaction<DB, W>
62where
63    DB: DatabaseConnector,
64    W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
65{
66    /// Creates a new transaction
67    pub async fn new(mut inner: W) -> Result<Self, Error> {
68        DB::Transaction::begin(inner.deref_mut()).await?;
69        Ok(Self { inner: Some(inner) })
70    }
71
72    /// Commits the transaction consuming it and releasing the connection back to the pool (or
73    /// disconnecting)
74    pub async fn commit(mut self) -> Result<(), Error> {
75        let mut conn = self
76            .inner
77            .take()
78            .ok_or(Error::Internal("Missing connection".to_owned()))?;
79
80        DB::Transaction::commit(&mut conn).await?;
81
82        Ok(())
83    }
84
85    /// Rollback the transaction consuming it and releasing the connection back to the pool (or
86    /// disconnecting)
87    pub async fn rollback(mut self) -> Result<(), Error> {
88        let mut conn = self
89            .inner
90            .take()
91            .ok_or(Error::Internal("Missing connection".to_owned()))?;
92
93        DB::Transaction::rollback(&mut conn).await?;
94
95        Ok(())
96    }
97}
98
99impl<DB, W> Drop for ConnectionWithTransaction<DB, W>
100where
101    DB: DatabaseConnector,
102    W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
103{
104    fn drop(&mut self) {
105        if let Some(mut conn) = self.inner.take() {
106            tokio::spawn(async move {
107                let _ = DB::Transaction::rollback(conn.deref_mut()).await;
108            });
109        }
110    }
111}
112
113#[async_trait::async_trait]
114impl<DB, W> DatabaseExecutor for ConnectionWithTransaction<DB, W>
115where
116    DB: DatabaseConnector,
117    W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
118{
119    fn name() -> &'static str {
120        "Transaction"
121    }
122
123    /// Executes a query and returns the affected rows
124    async fn execute(&self, statement: Statement) -> Result<usize, Error> {
125        self.inner
126            .as_ref()
127            .ok_or(Error::Internal("Missing internal connection".to_owned()))?
128            .execute(statement)
129            .await
130    }
131
132    /// Runs the query and returns the first row or None
133    async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
134        self.inner
135            .as_ref()
136            .ok_or(Error::Internal("Missing internal connection".to_owned()))?
137            .fetch_one(statement)
138            .await
139    }
140
141    /// Runs the query and returns the first row or None
142    async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
143        self.inner
144            .as_ref()
145            .ok_or(Error::Internal("Missing internal connection".to_owned()))?
146            .fetch_all(statement)
147            .await
148    }
149
150    /// Fetches the first row and column from a query
151    async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
152        self.inner
153            .as_ref()
154            .ok_or(Error::Internal("Missing internal connection".to_owned()))?
155            .pluck(statement)
156            .await
157    }
158
159    /// Batch execution
160    async fn batch(&self, statement: Statement) -> Result<(), Error> {
161        self.inner
162            .as_ref()
163            .ok_or(Error::Internal("Missing internal connection".to_owned()))?
164            .batch(statement)
165            .await
166    }
167}
168
169/// Generic transaction handler for SQLite
170pub struct GenericTransactionHandler<W>(PhantomData<W>);
171
172#[async_trait::async_trait]
173impl<W> DatabaseTransaction<W> for GenericTransactionHandler<W>
174where
175    W: DatabaseExecutor,
176{
177    /// Consumes the current transaction committing the changes
178    async fn commit(conn: &mut W) -> Result<(), Error> {
179        query("COMMIT")?.execute(conn).await?;
180        Ok(())
181    }
182
183    /// Begin a transaction
184    async fn begin(conn: &mut W) -> Result<(), Error> {
185        query("START TRANSACTION")?.execute(conn).await?;
186        Ok(())
187    }
188
189    /// Consumes the transaction rolling back all changes
190    async fn rollback(conn: &mut W) -> Result<(), Error> {
191        query("ROLLBACK")?.execute(conn).await?;
192        Ok(())
193    }
194}
195
196/// Database connector
197#[async_trait::async_trait]
198pub trait DatabaseConnector: Debug + DatabaseExecutor + Send + Sync {
199    /// Database static trait for the database
200    type Transaction: DatabaseTransaction<Self>
201    where
202        Self: Sized;
203}