1use std::sync::{Arc, Mutex};
2
3use crate::driver;
4use anyhow::Result;
5
6use super::{driver::AsyncDriver, ConnectionPool, Preparable, Statement};
7
8struct Inner<DB>
9where
10 DB: ConnectionPool + Sync + Send,
11{
12 tx: Box<dyn driver::Transaction>,
13 pub(crate) conn: Option<Box<dyn driver::Connection>>,
14 db: DB,
15}
16
17impl<DB> Drop for Inner<DB>
18where
19 DB: ConnectionPool + Sync + Send,
20{
21 fn drop(&mut self) {
22 if let Some(conn) = self.conn.take() {
23 self.db.release_conn(conn);
24 }
25 }
26}
27
28#[derive(Clone)]
30pub struct Transaction<DB>
31where
32 DB: ConnectionPool + Sync + Send,
33{
34 inner: Arc<Mutex<Inner<DB>>>,
35 driver_name: String,
36 conn_url: String,
37}
38
39impl<DB> Transaction<DB>
40where
41 DB: ConnectionPool + Sync + Send,
42{
43 pub(crate) fn new(
44 driver_name: String,
45 conn_url: String,
46 tx: Box<dyn driver::Transaction>,
47 conn: Option<Box<dyn driver::Connection>>,
48 db: DB,
49 ) -> Self {
50 Self {
51 inner: Arc::new(Mutex::new(Inner { tx, conn, db })),
52 driver_name,
53 conn_url,
54 }
55 }
56
57 pub async fn commit(&mut self) -> Result<()> {
58 let async_driver = AsyncDriver::new();
59 self.inner
60 .lock()
61 .unwrap()
62 .tx
63 .commit(async_driver.callback());
64
65 async_driver.await
66 }
67
68 pub async fn rollback(&mut self) -> Result<()> {
69 let async_driver = AsyncDriver::new();
70 self.inner
71 .lock()
72 .unwrap()
73 .tx
74 .rollback(async_driver.callback());
75
76 async_driver.await
77 }
78}
79
80#[async_trait::async_trait]
81impl<DB> Preparable for Transaction<DB>
82where
83 DB: ConnectionPool + Sync + Send + Clone,
84{
85 type DB = DB;
86 async fn prepare<S>(&mut self, query: S) -> anyhow::Result<Statement<Self::DB>>
87 where
88 S: Into<String> + Send,
89 {
90 let async_driver = AsyncDriver::new();
91
92 self.inner
93 .lock()
94 .unwrap()
95 .tx
96 .prepare(query.into(), async_driver.callback());
97
98 let stmt = async_driver.await?;
99
100 Ok(Statement::new(None, None, stmt))
101 }
102
103 fn driver_name(&self) -> &str {
104 &self.driver_name
105 }
106
107 fn conn_str(&self) -> &str {
108 &self.conn_url
109 }
110}