bottle_orm/
transaction.rs1use heck::ToSnakeCase;
11use std::sync::Arc;
12use tokio::sync::Mutex;
13use futures::future::BoxFuture;
14use sqlx::any::AnyArguments;
15
16use crate::{
21 database::{Connection, Drivers, RawQuery},
22 Model, QueryBuilder,
23};
24
25#[derive(Debug, Clone)]
34pub struct Transaction<'a> {
35 pub(crate) tx: Arc<Mutex<Option<sqlx::Transaction<'a, sqlx::Any>>>>,
36 pub(crate) driver: Drivers,
37}
38
39impl Connection for Transaction<'_> {
47 fn execute<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyQueryResult, sqlx::Error>> {
48 Box::pin(async move {
49 let mut guard = self.tx.lock().await;
50 if let Some(tx) = guard.as_mut() {
51 sqlx::query_with(sql, args).execute(&mut **tx).await
52 } else {
53 Err(sqlx::Error::WorkerCrashed)
54 }
55 })
56 }
57
58 fn fetch_all<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Vec<sqlx::any::AnyRow>, sqlx::Error>> {
59 Box::pin(async move {
60 let mut guard = self.tx.lock().await;
61 if let Some(tx) = guard.as_mut() {
62 sqlx::query_with(sql, args).fetch_all(&mut **tx).await
63 } else {
64 Err(sqlx::Error::WorkerCrashed)
65 }
66 })
67 }
68
69 fn fetch_one<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<sqlx::any::AnyRow, sqlx::Error>> {
70 Box::pin(async move {
71 let mut guard = self.tx.lock().await;
72 if let Some(tx) = guard.as_mut() {
73 sqlx::query_with(sql, args).fetch_one(&mut **tx).await
74 } else {
75 Err(sqlx::Error::WorkerCrashed)
76 }
77 })
78 }
79
80 fn fetch_optional<'a, 'q: 'a>(&'a self, sql: &'q str, args: AnyArguments<'q>) -> BoxFuture<'a, Result<Option<sqlx::any::AnyRow>, sqlx::Error>> {
81 Box::pin(async move {
82 let mut guard = self.tx.lock().await;
83 if let Some(tx) = guard.as_mut() {
84 sqlx::query_with(sql, args).fetch_optional(&mut **tx).await
85 } else {
86 Err(sqlx::Error::WorkerCrashed)
87 }
88 })
89 }
90}
91
92impl<'a> Transaction<'a> {
97 pub fn model<T: Model + Send + Sync + Unpin + crate::AnyImpl>(
99 &self,
100 ) -> QueryBuilder<T, Self> {
101 let active_columns = T::active_columns();
102 let mut columns: Vec<String> = Vec::with_capacity(active_columns.capacity());
103
104 for col in active_columns {
105 columns.push(col.strip_prefix("r#").unwrap_or(col).to_snake_case());
106 }
107
108 QueryBuilder::new(self.clone(), self.driver, T::table_name(), <T as Model>::columns(), columns)
109 }
110
111 pub fn raw<'b>(&self, sql: &'b str) -> RawQuery<'b, Self> {
113 RawQuery::new(self.clone(), sql)
114 }
115
116 pub async fn commit(self) -> Result<(), sqlx::Error> {
118 let mut guard = self.tx.lock().await;
119 if let Some(tx) = guard.take() {
120 tx.commit().await
121 } else {
122 Ok(())
123 }
124 }
125
126 pub async fn rollback(self) -> Result<(), sqlx::Error> {
128 let mut guard = self.tx.lock().await;
129 if let Some(tx) = guard.take() {
130 tx.rollback().await
131 } else {
132 Ok(())
133 }
134 }
135}