rdbc_rs/future/
tx.rs

1use std::sync::{Arc, Mutex};
2
3use crate::driver;
4use anyhow::Result;
5
6use super::{driver::AsyncDriver, ConnectionPool, Preparable, Statement};
7
8struct Inner<DB>
9where
10    DB: ConnectionPool + Sync + Send,
11{
12    tx: Box<dyn driver::Transaction>,
13    pub(crate) conn: Option<Box<dyn driver::Connection>>,
14    db: DB,
15}
16
17impl<DB> Drop for Inner<DB>
18where
19    DB: ConnectionPool + Sync + Send,
20{
21    fn drop(&mut self) {
22        if let Some(conn) = self.conn.take() {
23            self.db.release_conn(conn);
24        }
25    }
26}
27
28/// Asynchronous wrapper type for [`crate::driver::Transaction`]
29#[derive(Clone)]
30pub struct Transaction<DB>
31where
32    DB: ConnectionPool + Sync + Send,
33{
34    inner: Arc<Mutex<Inner<DB>>>,
35    driver_name: String,
36    conn_url: String,
37}
38
39impl<DB> Transaction<DB>
40where
41    DB: ConnectionPool + Sync + Send,
42{
43    pub(crate) fn new(
44        driver_name: String,
45        conn_url: String,
46        tx: Box<dyn driver::Transaction>,
47        conn: Option<Box<dyn driver::Connection>>,
48        db: DB,
49    ) -> Self {
50        Self {
51            inner: Arc::new(Mutex::new(Inner { tx, conn, db })),
52            driver_name,
53            conn_url,
54        }
55    }
56
57    pub async fn commit(&mut self) -> Result<()> {
58        let async_driver = AsyncDriver::new();
59        self.inner
60            .lock()
61            .unwrap()
62            .tx
63            .commit(async_driver.callback());
64
65        async_driver.await
66    }
67
68    pub async fn rollback(&mut self) -> Result<()> {
69        let async_driver = AsyncDriver::new();
70        self.inner
71            .lock()
72            .unwrap()
73            .tx
74            .rollback(async_driver.callback());
75
76        async_driver.await
77    }
78}
79
80#[async_trait::async_trait]
81impl<DB> Preparable for Transaction<DB>
82where
83    DB: ConnectionPool + Sync + Send + Clone,
84{
85    type DB = DB;
86    async fn prepare<S>(&mut self, query: S) -> anyhow::Result<Statement<Self::DB>>
87    where
88        S: Into<String> + Send,
89    {
90        let async_driver = AsyncDriver::new();
91
92        self.inner
93            .lock()
94            .unwrap()
95            .tx
96            .prepare(query.into(), async_driver.callback());
97
98        let stmt = async_driver.await?;
99
100        Ok(Statement::new(None, None, stmt))
101    }
102
103    fn driver_name(&self) -> &str {
104        &self.driver_name
105    }
106
107    fn conn_str(&self) -> &str {
108        &self.conn_url
109    }
110}