predawn_sea_orm/
data_source.rs1use std::sync::Arc;
2
3use sea_orm::{DatabaseConnection, TransactionTrait};
4use snafu::ResultExt;
5use tokio::sync::Mutex;
6
7use crate::{
8 error::{
9 DbErrSnafu, InconsistentDataSourceAndTransactionSnafu,
10 NestedTransactionHaveMoreThanOneReferenceSnafu, TransactionHaveMoreThanOneReferenceSnafu,
11 },
12 Error, Transaction,
13};
14
15#[derive(Debug)]
16pub struct DataSource {
17 name: Arc<str>,
18 connection: DatabaseConnection,
19 transactions: Mutex<Vec<Transaction>>,
20}
21
22impl DataSource {
23 pub(crate) fn new(name: Arc<str>, conn: DatabaseConnection) -> Self {
24 Self {
25 name,
26 connection: conn,
27 transactions: Default::default(),
28 }
29 }
30
31 pub async fn current_txn(&self) -> Result<Transaction, Error> {
32 {
33 let transactions = self.transactions.lock().await;
34
35 if let Some(transaction) = transactions.last() {
36 return Ok(transaction.clone());
37 }
38 }
39
40 self.create_txn().await
41 }
42
43 pub async fn create_txn(&self) -> Result<Transaction, Error> {
44 let mut transactions = self.transactions.lock().await;
45
46 let result = match transactions.last() {
47 Some(txn) => txn.begin().await,
48 None => self.connection.begin().await,
49 };
50
51 let transaction = result.context(DbErrSnafu)?;
52
53 let transaction = Transaction {
54 name: self.name.clone(),
55 inner: Arc::new(transaction),
56 index: transactions.len(),
57 };
58
59 transactions.push(transaction.clone());
60
61 Ok(transaction)
62 }
63
64 pub fn standalone(&self) -> Self {
65 Self {
66 name: self.name.clone(),
67 connection: self.connection.clone(),
68 transactions: Default::default(),
69 }
70 }
71}
72
// Generates `pub async fn $ident(&self, txn: Transaction) -> Result<(), Error>`. The generated
// method finishes `txn` by calling `$ident` on it. Before that, it calls `$ident` on every
// transaction nested on top of `txn`, innermost first.
macro_rules! single_operation {
    ($ident:ident) => {
        /// Finishes `txn` together with every transaction nested on top of it (innermost
        /// first), removing each one from this data source's stack as it is finished.
        ///
        /// # Errors
        ///
        /// - An error built by `InconsistentDataSourceAndTransactionSnafu` if `txn` was not
        ///   created by this data source. This covers a different name, or not being the entry
        ///   recorded at its position on this stack. Nothing is finished in that case.
        /// - An error built by `NestedTransactionHaveMoreThanOneReferenceSnafu` if a nested
        ///   transaction is still referenced elsewhere. That transaction, `txn`, and everything
        ///   below them stay on the stack; transactions above it were already finished.
        /// - An error built by `TransactionHaveMoreThanOneReferenceSnafu` if `txn` itself is
        ///   still referenced elsewhere. It stays on the stack.
        /// - An error built by `DbErrSnafu` if the database operation fails. The failed
        ///   transaction is no longer on the stack.
        pub async fn $ident(&self, txn: Transaction) -> Result<(), Error> {
            if self.name != txn.name {
                return InconsistentDataSourceAndTransactionSnafu {
                    data_source_name: self.name.clone(),
                    transaction_name: txn.name.clone(),
                    txn,
                }
                .fail();
            }

            let mut transactions = self.transactions.lock().await;

            // The name check cannot tell apart data sources that share a name (e.g. one
            // returned by `standalone`). Require that `txn` is the very entry recorded at its
            // index here. Otherwise an unrelated transaction would be finished. With an
            // out-of-range index, the old `debug_assert!` was a no-op in release builds, so
            // nothing happened and `Ok` was returned.
            let is_ours = matches!(
                transactions.get(txn.index),
                Some(entry) if Arc::ptr_eq(&entry.inner, &txn.inner)
            );
            if !is_ours {
                return InconsistentDataSourceAndTransactionSnafu {
                    data_source_name: self.name.clone(),
                    transaction_name: txn.name.clone(),
                    txn,
                }
                .fail();
            }

            // Finish every transaction nested on top of `txn`, innermost first.
            while transactions.len() > txn.index + 1 {
                let Transaction { name, inner, index } = transactions
                    .pop()
                    .expect("stack is longer than `txn.index + 1`, so it is not empty");

                match Arc::try_unwrap(inner) {
                    Ok(nested) => {
                        nested.$ident().await.context(DbErrSnafu)?;
                    }
                    Err(inner) => {
                        // Still shared elsewhere: put it back and stop unwinding.
                        transactions.push(Transaction { name, inner, index });

                        return NestedTransactionHaveMoreThanOneReferenceSnafu {
                            data_source_name: self.name.clone(),
                            current_transaction_hierarchy: txn.index,
                            nested_transaction_hierarchy: index,
                            txn,
                        }
                        .fail();
                    }
                }
            }

            // Only `txn`'s own entry (verified above) is left at the top.
            let Transaction { name, inner, index } = transactions
                .pop()
                .expect("the entry verified above is now at the top of the stack");

            // Give up the caller's handle so the popped entry can hold the last reference.
            drop(txn);

            match Arc::try_unwrap(inner) {
                Ok(target) => {
                    target.$ident().await.context(DbErrSnafu)?;
                    Ok(())
                }
                Err(inner) => {
                    let last = Transaction { name, inner, index };
                    let txn = last.clone();

                    transactions.push(last);

                    TransactionHaveMoreThanOneReferenceSnafu {
                        data_source_name: self.name.clone(),
                        transaction_hierarchy: index,
                        txn,
                    }
                    .fail()
                }
            }
        }
    };
}
142
// Generates `pub async fn $multi(&self) -> Result<(), Error>`. The generated method drains the
// whole transaction stack, calling `$single` on each owned transaction, innermost first.
macro_rules! multi_operation {
    ($multi:ident, $single:ident) => {
        /// Finishes every open transaction of this data source, innermost first, removing
        /// each one from the stack as it is finished.
        ///
        /// # Errors
        ///
        /// - An error built by `TransactionHaveMoreThanOneReferenceSnafu` if a transaction is
        ///   still referenced elsewhere. It is put back on the stack, and it and everything
        ///   below it stay open.
        /// - An error built by `DbErrSnafu` if the database operation fails. The failed
        ///   transaction is no longer on the stack; those below it remain.
        pub async fn $multi(&self) -> Result<(), Error> {
            let mut stack = self.transactions.lock().await;

            loop {
                let Transaction { name, inner, index } = match stack.pop() {
                    Some(top) => top,
                    None => return Ok(()),
                };

                let owned = match Arc::try_unwrap(inner) {
                    Ok(owned) => owned,
                    Err(shared) => {
                        // Someone else still holds this transaction: restore it and bail out.
                        let entry = Transaction {
                            name,
                            inner: shared,
                            index,
                        };
                        stack.push(entry.clone());

                        return TransactionHaveMoreThanOneReferenceSnafu {
                            data_source_name: self.name.clone(),
                            transaction_hierarchy: index,
                            txn: entry,
                        }
                        .fail();
                    }
                };

                owned.$single().await.context(DbErrSnafu)?;
            }
        }
    };
}
172
173impl DataSource {
174 single_operation!(commit);
175
176 single_operation!(rollback);
177
178 multi_operation!(commit_all, commit);
179
180 multi_operation!(rollback_all, rollback);
181}