// cdk_sql_common/database.rs

use std::fmt::Debug;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

use cdk_common::database::Error;

use crate::stmt::{query, Column, Statement};
11#[async_trait::async_trait]
15pub trait DatabaseExecutor: Debug + Sync + Send {
16 fn name() -> &'static str;
18
19 async fn execute(&self, statement: Statement) -> Result<usize, Error>;
21
22 async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error>;
24
25 async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error>;
27
28 async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error>;
30
31 async fn batch(&self, statement: Statement) -> Result<(), Error>;
33}
34
35#[async_trait::async_trait]
37pub trait DatabaseTransaction<DB>
38where
39 DB: DatabaseExecutor,
40{
41 async fn commit(conn: &mut DB) -> Result<(), Error>;
43
44 async fn begin(conn: &mut DB) -> Result<(), Error>;
46
47 async fn rollback(conn: &mut DB) -> Result<(), Error>;
49}
50
51#[derive(Debug)]
53pub struct ConnectionWithTransaction<DB, W>
54where
55 DB: DatabaseConnector + 'static,
56 W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
57{
58 inner: Option<W>,
59}
60
61impl<DB, W> ConnectionWithTransaction<DB, W>
62where
63 DB: DatabaseConnector,
64 W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
65{
66 pub async fn new(mut inner: W) -> Result<Self, Error> {
68 DB::Transaction::begin(inner.deref_mut()).await?;
69 Ok(Self { inner: Some(inner) })
70 }
71
72 pub async fn commit(mut self) -> Result<(), Error> {
75 let mut conn = self
76 .inner
77 .take()
78 .ok_or(Error::Internal("Missing connection".to_owned()))?;
79
80 DB::Transaction::commit(&mut conn).await?;
81
82 Ok(())
83 }
84
85 pub async fn rollback(mut self) -> Result<(), Error> {
88 let mut conn = self
89 .inner
90 .take()
91 .ok_or(Error::Internal("Missing connection".to_owned()))?;
92
93 DB::Transaction::rollback(&mut conn).await?;
94
95 Ok(())
96 }
97}
98
99impl<DB, W> Drop for ConnectionWithTransaction<DB, W>
100where
101 DB: DatabaseConnector,
102 W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
103{
104 fn drop(&mut self) {
105 if let Some(mut conn) = self.inner.take() {
106 tokio::spawn(async move {
107 let _ = DB::Transaction::rollback(conn.deref_mut()).await;
108 });
109 }
110 }
111}
112
113#[async_trait::async_trait]
114impl<DB, W> DatabaseExecutor for ConnectionWithTransaction<DB, W>
115where
116 DB: DatabaseConnector,
117 W: Debug + Deref<Target = DB> + DerefMut<Target = DB> + Send + Sync + 'static,
118{
119 fn name() -> &'static str {
120 "Transaction"
121 }
122
123 async fn execute(&self, statement: Statement) -> Result<usize, Error> {
125 self.inner
126 .as_ref()
127 .ok_or(Error::Internal("Missing internal connection".to_owned()))?
128 .execute(statement)
129 .await
130 }
131
132 async fn fetch_one(&self, statement: Statement) -> Result<Option<Vec<Column>>, Error> {
134 self.inner
135 .as_ref()
136 .ok_or(Error::Internal("Missing internal connection".to_owned()))?
137 .fetch_one(statement)
138 .await
139 }
140
141 async fn fetch_all(&self, statement: Statement) -> Result<Vec<Vec<Column>>, Error> {
143 self.inner
144 .as_ref()
145 .ok_or(Error::Internal("Missing internal connection".to_owned()))?
146 .fetch_all(statement)
147 .await
148 }
149
150 async fn pluck(&self, statement: Statement) -> Result<Option<Column>, Error> {
152 self.inner
153 .as_ref()
154 .ok_or(Error::Internal("Missing internal connection".to_owned()))?
155 .pluck(statement)
156 .await
157 }
158
159 async fn batch(&self, statement: Statement) -> Result<(), Error> {
161 self.inner
162 .as_ref()
163 .ok_or(Error::Internal("Missing internal connection".to_owned()))?
164 .batch(statement)
165 .await
166 }
167}
168
/// Zero-sized [`DatabaseTransaction`] strategy for connections of type `C`
/// that controls transactions by issuing plain SQL statements.
pub struct GenericTransactionHandler<C>(PhantomData<C>);
172#[async_trait::async_trait]
173impl<W> DatabaseTransaction<W> for GenericTransactionHandler<W>
174where
175 W: DatabaseExecutor,
176{
177 async fn commit(conn: &mut W) -> Result<(), Error> {
179 query("COMMIT")?.execute(conn).await?;
180 Ok(())
181 }
182
183 async fn begin(conn: &mut W) -> Result<(), Error> {
185 query("START TRANSACTION")?.execute(conn).await?;
186 Ok(())
187 }
188
189 async fn rollback(conn: &mut W) -> Result<(), Error> {
191 query("ROLLBACK")?.execute(conn).await?;
192 Ok(())
193 }
194}
195
196#[async_trait::async_trait]
198pub trait DatabaseConnector: Debug + DatabaseExecutor + Send + Sync {
199 type Transaction: DatabaseTransaction<Self>
201 where
202 Self: Sized;
203}