bottle_orm/
transaction.rs1use heck::ToSnakeCase;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13use futures::future::BoxFuture;
14use sqlx::any::AnyArguments;
15
16use crate::{
21 database::{Connection, Drivers, RawQuery},
22 Model, QueryBuilder,
23};
24
25#[derive(Debug, Clone)]
34pub struct Transaction<'a> {
35 pub(crate) tx: Arc<Mutex<Option<sqlx::Transaction<'a, sqlx::Any>>>>,
36 pub(crate) driver: Drivers,
37}
38
39impl Connection for Transaction<'_> {
47 fn driver(&self) -> Drivers { self.driver }
48 fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>> {
49 Box::pin(async move {
50 let mut guard = self.tx.lock().await;
51 if let Some(tx) = guard.as_mut() {
52 sqlx::query_with(sql, args).execute(&mut **tx).await
53 } else {
54 Err(sqlx::Error::WorkerCrashed)
55 }
56 })
57 }
58
59 fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>> {
60 Box::pin(async move {
61 let mut guard = self.tx.lock().await;
62 if let Some(tx) = guard.as_mut() {
63 sqlx::query_with(sql, args).fetch_all(&mut **tx).await
64 } else {
65 Err(sqlx::Error::WorkerCrashed)
66 }
67 })
68 }
69
70 fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>> {
71 Box::pin(async move {
72 let mut guard = self.tx.lock().await;
73 if let Some(tx) = guard.as_mut() {
74 sqlx::query_with(sql, args).fetch_one(&mut **tx).await
75 } else {
76 Err(sqlx::Error::WorkerCrashed)
77 }
78 })
79 }
80
81 fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>> {
82 Box::pin(async move {
83 let mut guard = self.tx.lock().await;
84 if let Some(tx) = guard.as_mut() {
85 sqlx::query_with(sql, args).fetch_optional(&mut **tx).await
86 } else {
87 Err(sqlx::Error::WorkerCrashed)
88 }
89 })
90 }
91}
92
93impl<'a> Transaction<'a> {
98 pub fn model<T: Model + Send + Sync + Unpin + crate::AnyImpl>(
100 &self,
101 ) -> QueryBuilder<T, Self> {
102 let active_columns = T::active_columns();
103 let mut columns: Vec<String> = Vec::with_capacity(active_columns.capacity());
104
105 for col in active_columns {
106 columns.push(col.strip_prefix("r#").unwrap_or(col).to_snake_case());
107 }
108
109 QueryBuilder::new(self.clone(), self.driver, T::table_name(), <T as Model>::columns(), columns)
110 }
111
112 pub fn raw<'b>(&self, sql: &'b str) -> RawQuery<'b, Self> {
114 RawQuery::new(self.clone(), sql)
115 }
116
117 pub async fn commit(self) -> Result<(), sqlx::Error> {
119 let mut guard = self.tx.lock().await;
120 if let Some(tx) = guard.take() {
121 tx.commit().await
122 } else {
123 Ok(())
124 }
125 }
126
127 pub async fn rollback(self) -> Result<(), sqlx::Error> {
129 let mut guard = self.tx.lock().await;
130 if let Some(tx) = guard.take() {
131 tx.rollback().await
132 } else {
133 Ok(())
134 }
135 }
136}