bottle_orm/
transaction.rs1use heck::ToSnakeCase;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13use futures::future::BoxFuture;
14use sqlx::any::AnyArguments;
15
16use crate::{
21 database::{Connection, Drivers, RawQuery},
22 Model, QueryBuilder,
23};
24
25#[derive(Debug, Clone)]
34pub struct Transaction<'a> {
35 pub(crate) tx: Arc<Mutex<Option<sqlx::Transaction<'a, sqlx::Any>>>>,
36 pub(crate) pool: sqlx::AnyPool,
37 pub(crate) driver: Drivers,
38}
39
40impl Connection for Transaction<'_> {
48 fn driver(&self) -> Drivers { self.driver }
49 fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>> {
50 Box::pin(async move {
51 let mut guard = self.tx.lock().await;
52 if let Some(tx) = guard.as_mut() {
53 sqlx::query_with(sql, args).execute(&mut **tx).await
54 } else {
55 Err(sqlx::Error::WorkerCrashed)
56 }
57 })
58 }
59
60 fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>> {
61 Box::pin(async move {
62 let mut guard = self.tx.lock().await;
63 if let Some(tx) = guard.as_mut() {
64 sqlx::query_with(sql, args).fetch_all(&mut **tx).await
65 } else {
66 Err(sqlx::Error::WorkerCrashed)
67 }
68 })
69 }
70
71 fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>> {
72 Box::pin(async move {
73 let mut guard = self.tx.lock().await;
74 if let Some(tx) = guard.as_mut() {
75 sqlx::query_with(sql, args).fetch_one(&mut **tx).await
76 } else {
77 Err(sqlx::Error::WorkerCrashed)
78 }
79 })
80 }
81
82 fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>> {
83 Box::pin(async move {
84 let mut guard = self.tx.lock().await;
85 if let Some(tx) = guard.as_mut() {
86 sqlx::query_with(sql, args).fetch_optional(&mut **tx).await
87 } else {
88 Err(sqlx::Error::WorkerCrashed)
89 }
90 })
91 }
92
93 fn clone_db(&self) -> crate::Database {
94 crate::Database {
95 pool: self.pool.clone(),
96 driver: self.driver,
97 }
98 }
99}
100
101impl<'a> Transaction<'a> {
106 pub fn model<T: Model + Send + Sync + Unpin + crate::AnyImpl>(
108 &self,
109 ) -> QueryBuilder<T, Self> {
110 let active_columns = T::active_columns();
111 let mut columns: Vec<String> = Vec::with_capacity(active_columns.capacity());
112
113 for col in active_columns {
114 columns.push(col.strip_prefix("r#").unwrap_or(col).to_snake_case());
115 }
116
117 QueryBuilder::new(self.clone(), self.driver, T::table_name(), <T as Model>::columns(), columns)
118 }
119
120 pub fn raw<'b>(&self, sql: &'b str) -> RawQuery<'b, Self> {
122 RawQuery::new(self.clone(), sql)
123 }
124
125 pub async fn commit(self) -> Result<(), sqlx::Error> {
127 let mut guard = self.tx.lock().await;
128 if let Some(tx) = guard.take() {
129 tx.commit().await
130 } else {
131 Ok(())
132 }
133 }
134
135 pub async fn rollback(self) -> Result<(), sqlx::Error> {
137 let mut guard = self.tx.lock().await;
138 if let Some(tx) = guard.take() {
139 tx.rollback().await
140 } else {
141 Ok(())
142 }
143 }
144}