// predawn_sea_orm/data_source.rs

1use std::sync::Arc;
2
3use sea_orm::{DatabaseConnection, TransactionTrait};
4use snafu::ResultExt;
5use tokio::sync::Mutex;
6
7use crate::{
8    error::{
9        DbErrSnafu, InconsistentDataSourceAndTransactionSnafu,
10        NestedTransactionHaveMoreThanOneReferenceSnafu, TransactionHaveMoreThanOneReferenceSnafu,
11    },
12    Error, Transaction,
13};
14
15#[derive(Debug)]
16pub struct DataSource {
17    name: Arc<str>,
18    connection: DatabaseConnection,
19    transactions: Mutex<Vec<Transaction>>,
20}
21
22impl DataSource {
23    pub(crate) fn new(name: Arc<str>, conn: DatabaseConnection) -> Self {
24        Self {
25            name,
26            connection: conn,
27            transactions: Default::default(),
28        }
29    }
30
31    pub async fn current_txn(&self) -> Result<Transaction, Error> {
32        {
33            let transactions = self.transactions.lock().await;
34
35            if let Some(transaction) = transactions.last() {
36                return Ok(transaction.clone());
37            }
38        }
39
40        self.create_txn().await
41    }
42
43    pub async fn create_txn(&self) -> Result<Transaction, Error> {
44        let mut transactions = self.transactions.lock().await;
45
46        let result = match transactions.last() {
47            Some(txn) => txn.begin().await,
48            None => self.connection.begin().await,
49        };
50
51        let transaction = result.context(DbErrSnafu)?;
52
53        let transaction = Transaction {
54            name: self.name.clone(),
55            inner: Arc::new(transaction),
56            index: transactions.len(),
57        };
58
59        transactions.push(transaction.clone());
60
61        Ok(transaction)
62    }
63
64    pub fn standalone(&self) -> Self {
65        Self {
66            name: self.name.clone(),
67            connection: self.connection.clone(),
68            transactions: Default::default(),
69        }
70    }
71}
72
macro_rules! single_operation {
    ($ident:ident) => {
        /// Finishes `txn` on this data source after first finishing, innermost first, every
        /// transaction stacked above it.
        ///
        /// # Errors
        ///
        /// - `InconsistentDataSourceAndTransactionSnafu` if `txn` was not created by this data
        ///   source: either its name differs, or this data source's stack does not hold that
        ///   same transaction at `txn.index` (e.g. it came from a `standalone()` copy).
        /// - `NestedTransactionHaveMoreThanOneReferenceSnafu` if a transaction stacked above
        ///   `txn` is still shared; it is put back on the stack untouched.
        /// - `TransactionHaveMoreThanOneReferenceSnafu` if `txn` itself is still shared
        ///   elsewhere; it is put back on the stack untouched.
        /// - The `DbErrSnafu`-wrapped database error if finishing a transaction fails.
        pub async fn $ident(&self, txn: Transaction) -> Result<(), Error> {
            if self.name != txn.name {
                return InconsistentDataSourceAndTransactionSnafu {
                    data_source_name: self.name.clone(),
                    transaction_name: txn.name.clone(),
                    txn,
                }
                .fail();
            }

            let mut transactions = self.transactions.lock().await;

            // A matching name does not prove ownership: a `standalone()` data source shares the
            // name but keeps its own stack. Without this check, a foreign transaction would
            // either pop and finish unrelated transactions here, or (index out of range) skip
            // the loop and report `Ok(())` without finishing anything.
            let owned_here = matches!(
                transactions.get(txn.index),
                Some(entry) if Arc::ptr_eq(&entry.inner, &txn.inner)
            );

            if !owned_here {
                return InconsistentDataSourceAndTransactionSnafu {
                    data_source_name: self.name.clone(),
                    transaction_name: txn.name.clone(),
                    txn,
                }
                .fail();
            }

            // Pop from the top down to and including `txn`'s own entry.
            for _ in txn.index..transactions.len() {
                let Transaction { name, inner, index } = transactions
                    .pop()
                    .expect("loop runs at most `len - txn.index` times");

                if index == txn.index {
                    // Release the caller's handle so that, if nobody else holds a clone, the
                    // stack's entry is the sole owner and can be unwrapped.
                    drop(txn);

                    match Arc::try_unwrap(inner) {
                        Ok(txn) => {
                            txn.$ident().await.context(DbErrSnafu)?;
                            return Ok(());
                        }
                        Err(inner) => {
                            let last = Transaction { name, inner, index };
                            let txn = last.clone();

                            transactions.push(last);

                            return TransactionHaveMoreThanOneReferenceSnafu {
                                data_source_name: self.name.clone(),
                                transaction_hierarchy: index,
                                txn,
                            }
                            .fail();
                        }
                    }
                } else {
                    // A transaction nested above `txn`: it must be finished first.
                    match Arc::try_unwrap(inner) {
                        Ok(txn) => {
                            txn.$ident().await.context(DbErrSnafu)?;
                        }
                        Err(inner) => {
                            let last = Transaction { name, inner, index };

                            transactions.push(last);

                            return NestedTransactionHaveMoreThanOneReferenceSnafu {
                                data_source_name: self.name.clone(),
                                current_transaction_hierarchy: txn.index,
                                nested_transaction_hierarchy: index,
                                txn,
                            }
                            .fail();
                        }
                    }
                }
            }

            // Expected to be unreachable after the ownership check above (the loop always
            // reaches `txn`'s own entry, which returns); kept as a conservative fallback.
            Ok(())
        }
    };
}
142
macro_rules! multi_operation {
    ($all:ident, $each:ident) => {
        /// Finishes every open transaction on this data source, innermost first.
        ///
        /// # Errors
        ///
        /// - `TransactionHaveMoreThanOneReferenceSnafu` if a transaction is still shared
        ///   elsewhere; it is put back on the stack untouched, along with everything below it.
        /// - The `DbErrSnafu`-wrapped database error if finishing a transaction fails; the
        ///   transactions below it stay on the stack.
        pub async fn $all(&self) -> Result<(), Error> {
            let mut stack = self.transactions.lock().await;

            loop {
                let Transaction { name, inner, index } = match stack.pop() {
                    Some(top) => top,
                    None => return Ok(()),
                };

                // Only a transaction with no outstanding handles can be consumed.
                let owned = match Arc::try_unwrap(inner) {
                    Ok(owned) => owned,
                    Err(shared) => {
                        let entry = Transaction {
                            name,
                            inner: shared,
                            index,
                        };

                        stack.push(entry.clone());

                        return TransactionHaveMoreThanOneReferenceSnafu {
                            data_source_name: self.name.clone(),
                            transaction_hierarchy: index,
                            txn: entry,
                        }
                        .fail();
                    }
                };

                owned.$each().await.context(DbErrSnafu)?;
            }
        }
    };
}
172
173impl DataSource {
174    single_operation!(commit);
175
176    single_operation!(rollback);
177
178    multi_operation!(commit_all, commit);
179
180    multi_operation!(rollback_all, rollback);
181}