Skip to main content

bsql_core/
transaction.rs

1//! Database transactions with commit/rollback.
2//!
3//! Created via [`Pool::begin()`](crate::pool::Pool::begin). A transaction
4//! holds a single connection from the pool for its entire lifetime. Queries
5//! executed through the `Executor` trait run within the transaction.
6//!
7//! # Drop behavior
8//!
//! If a `Transaction` is dropped without calling [`commit()`](Transaction::commit)
//! or [`rollback()`](Transaction::rollback), the driver discards the connection
//! from the pool. PostgreSQL automatically rolls back the transaction when the
//! connection closes. A warning is emitted via `eprintln!` to help detect
//! forgotten commits during development.
13
14use std::fmt;
15
16use bsql_driver_postgres::arena::acquire_arena;
17use bsql_driver_postgres::codec::Encode;
18use tokio::sync::Mutex;
19
20use crate::error::{BsqlError, BsqlResult, QueryError};
21use crate::executor::OwnedResult;
22
/// The transaction isolation levels PostgreSQL understands.
///
/// Every variant corresponds to the keyword sequence that follows
/// `SET TRANSACTION ISOLATION LEVEL`; the `Display` impl prints exactly
/// that keyword text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`.
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Keyword text for this level, as spliced into
    /// `SET TRANSACTION ISOLATION LEVEL ...`.
    fn as_sql(&self) -> &'static str {
        match *self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        }
    }
}

impl fmt::Display for IsolationLevel {
    /// Writes the SQL keyword form of the level (the same text `as_sql` returns).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keywords = self.as_sql();
        f.write_str(keywords)
    }
}
49
50/// A database transaction.
51///
52/// Created by [`Pool::begin()`](crate::pool::Pool::begin). Must be explicitly
53/// committed via [`commit()`](Transaction::commit). If dropped without
54/// `commit()`, the connection is discarded from the pool and a warning is logged.
55///
56/// Use `.defer(&tx)` on queries to buffer writes, then `tx.commit()` to flush
57/// them all in a single pipeline. Use `.run(&tx)` or `.fetch(&tx)` for immediate
58/// execution within the transaction.
59///
60/// # Example
61///
62/// ```rust,ignore
63/// use bsql::Pool;
64///
65/// let pool = Pool::connect("postgres://user:pass@localhost/mydb").await?;
66/// let tx = pool.begin().await?;
67///
68/// // Buffer writes with .defer() — nothing hits the network yet
69/// bsql::query!("INSERT INTO log (msg) VALUES ($msg: &str)")
70///     .defer(&tx).await?;
71///
72/// // Or execute immediately within the transaction
73/// bsql::query!("UPDATE accounts SET balance = 0 WHERE id = $id: i32")
74///     .run(&tx).await?;
75///
76/// // commit() flushes all deferred operations, then commits
77/// tx.commit().await?;
78/// ```
79pub struct Transaction {
80    inner: Mutex<Option<bsql_driver_postgres::Transaction>>,
81    /// Set to true when commit() or rollback() is called.
82    finished: bool,
83}
84
85impl Transaction {
86    /// Wrap a driver-level transaction.
87    pub(crate) fn from_driver(tx: bsql_driver_postgres::Transaction) -> Self {
88        Self {
89            inner: Mutex::new(Some(tx)),
90            finished: false,
91        }
92    }
93
94    /// Return a "transaction already consumed" error.
95    fn consumed_error() -> BsqlError {
96        BsqlError::Query(QueryError {
97            message: "transaction already consumed".into(),
98            pg_code: None,
99            source: None,
100        })
101    }
102
103    /// Commit the transaction and return the connection to the pool.
104    ///
105    /// Consumes `self` — the transaction cannot be used after commit.
106    pub async fn commit(mut self) -> BsqlResult<()> {
107        self.finished = true;
108        let tx = self
109            .inner
110            .lock()
111            .await
112            .take()
113            .ok_or_else(Self::consumed_error)?;
114        tx.commit().await.map_err(BsqlError::from)
115    }
116
117    /// Explicitly roll back the transaction and return the connection to the pool.
118    ///
119    /// Consumes `self` — the transaction cannot be used after rollback.
120    pub async fn rollback(mut self) -> BsqlResult<()> {
121        self.finished = true;
122        let tx = self
123            .inner
124            .lock()
125            .await
126            .take()
127            .ok_or_else(Self::consumed_error)?;
128        tx.rollback().await.map_err(BsqlError::from)
129    }
130
131    /// Create a savepoint within the transaction.
132    ///
133    /// The `name` must be a valid SQL identifier: ASCII alphanumeric and
134    /// underscores only, starting with a letter or underscore. Maximum 63 characters.
135    pub async fn savepoint(&self, name: &str) -> BsqlResult<()> {
136        validate_savepoint_name(name)?;
137        let sql = format!("SAVEPOINT {name}");
138        let mut guard = self.inner.lock().await;
139        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
140        tx.simple_query(&sql)
141            .await
142            .map_err(BsqlError::from_driver_query)
143    }
144
145    /// Release (destroy) a savepoint, keeping its effects.
146    ///
147    /// The `name` must match a previously created savepoint.
148    pub async fn release_savepoint(&self, name: &str) -> BsqlResult<()> {
149        validate_savepoint_name(name)?;
150        let sql = format!("RELEASE SAVEPOINT {name}");
151        let mut guard = self.inner.lock().await;
152        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
153        tx.simple_query(&sql)
154            .await
155            .map_err(BsqlError::from_driver_query)
156    }
157
158    /// Roll back to a savepoint, undoing changes made after it was created.
159    ///
160    /// The savepoint remains valid after this call (can be rolled back to again).
161    pub async fn rollback_to(&self, name: &str) -> BsqlResult<()> {
162        validate_savepoint_name(name)?;
163        let sql = format!("ROLLBACK TO SAVEPOINT {name}");
164        let mut guard = self.inner.lock().await;
165        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
166        tx.simple_query(&sql)
167            .await
168            .map_err(BsqlError::from_driver_query)
169    }
170
171    /// Set the isolation level for this transaction.
172    ///
173    /// Must be called before the first query in the transaction (immediately
174    /// after `begin()`). PostgreSQL rejects `SET TRANSACTION` after any
175    /// data-modifying statement.
176    pub async fn set_isolation(&self, level: IsolationLevel) -> BsqlResult<()> {
177        let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
178        let mut guard = self.inner.lock().await;
179        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
180        tx.simple_query(&sql)
181            .await
182            .map_err(BsqlError::from_driver_query)
183    }
184
185    /// Execute a query within the transaction (used by Executor impl).
186    pub(crate) async fn query_inner(
187        &self,
188        sql: &str,
189        sql_hash: u64,
190        params: &[&(dyn Encode + Sync)],
191    ) -> BsqlResult<OwnedResult> {
192        let mut guard = self.inner.lock().await;
193        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
194        let mut arena = acquire_arena();
195        let result = tx
196            .query(sql, sql_hash, params, &mut arena)
197            .await
198            .map_err(BsqlError::from_driver_query)?;
199        Ok(OwnedResult::new(result, arena))
200    }
201
202    /// Execute without result rows within the transaction (used by Executor impl).
203    pub(crate) async fn execute_inner(
204        &self,
205        sql: &str,
206        sql_hash: u64,
207        params: &[&(dyn Encode + Sync)],
208    ) -> BsqlResult<u64> {
209        let mut guard = self.inner.lock().await;
210        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
211        tx.execute(sql, sql_hash, params)
212            .await
213            .map_err(BsqlError::from_driver_query)
214    }
215
216    /// Execute the same statement N times with different params in one pipeline.
217    ///
218    /// Sends all N Bind+Execute messages + one Sync. One round-trip for
219    /// N operations within the transaction. Returns the affected row count
220    /// for each parameter set.
221    pub async fn execute_pipeline(
222        &self,
223        sql: &str,
224        sql_hash: u64,
225        param_sets: &[&[&(dyn Encode + Sync)]],
226    ) -> BsqlResult<Vec<u64>> {
227        let mut guard = self.inner.lock().await;
228        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
229        tx.execute_pipeline(sql, sql_hash, param_sets)
230            .await
231            .map_err(BsqlError::from_driver_query)
232    }
233
234    // --- Deferred pipeline API ---
235
236    /// Buffer an execute for deferred pipeline flush.
237    ///
238    /// The operation is not sent to the server immediately. Instead, the
239    /// Bind+Execute message bytes are buffered internally. The buffered
240    /// operations are sent as a single pipeline on [`commit()`](Self::commit)
241    /// or [`flush_deferred()`](Self::flush_deferred).
242    ///
243    /// If the statement has not been prepared yet, a single round-trip is
244    /// made to prepare it. After that, the Bind+Execute bytes are buffered
245    /// with no I/O.
246    ///
247    /// Any read operation (`query_inner`, `for_each_raw`, `simple_query`, etc.)
248    /// automatically flushes deferred operations first to ensure
249    /// read-your-writes consistency.
250    #[doc(hidden)]
251    pub async fn defer_execute(
252        &self,
253        sql: &str,
254        sql_hash: u64,
255        params: &[&(dyn Encode + Sync)],
256    ) -> BsqlResult<()> {
257        let mut guard = self.inner.lock().await;
258        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
259        tx.defer_execute(sql, sql_hash, params)
260            .await
261            .map_err(BsqlError::from_driver_query)
262    }
263
264    /// Flush all deferred operations as a single pipeline.
265    ///
266    /// Sends all buffered Bind+Execute messages + one Sync in a single TCP write.
267    /// Returns the affected row count for each deferred operation.
268    #[doc(hidden)]
269    pub async fn flush_deferred(&self) -> BsqlResult<Vec<u64>> {
270        let mut guard = self.inner.lock().await;
271        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
272        tx.flush_deferred()
273            .await
274            .map_err(BsqlError::from_driver_query)
275    }
276
277    /// Number of operations currently buffered for deferred execution.
278    ///
279    /// This is a diagnostic method primarily for testing. Most users should
280    /// not need to call this -- deferred operations are flushed automatically
281    /// on commit or before any read.
282    #[doc(hidden)]
283    pub async fn deferred_count(&self) -> usize {
284        let guard = self.inner.lock().await;
285        match guard.as_ref() {
286            Some(tx) => tx.deferred_count(),
287            None => 0,
288        }
289    }
290
291    /// Process each row directly from the wire buffer within this transaction.
292    ///
293    /// Zero arena allocation — the closure receives a `PgDataRow` that reads
294    /// columns directly from the DataRow message bytes.
295    pub async fn for_each_raw<F>(
296        &self,
297        sql: &str,
298        sql_hash: u64,
299        params: &[&(dyn Encode + Sync)],
300        mut f: F,
301    ) -> BsqlResult<()>
302    where
303        F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
304    {
305        let mut guard = self.inner.lock().await;
306        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
307        let mut user_err: Option<BsqlError> = None;
308        let driver_result = tx
309            .for_each(sql, sql_hash, params, |row| match f(row) {
310                Ok(()) => Ok(()),
311                Err(e) => {
312                    user_err = Some(e);
313                    Err(bsql_driver_postgres::DriverError::Protocol(
314                        "for_each closure error".into(),
315                    ))
316                }
317            })
318            .await;
319        if let Some(e) = user_err {
320            return Err(e);
321        }
322        driver_result.map_err(BsqlError::from_driver_query)
323    }
324
325    /// Process each DataRow as raw bytes within this transaction.
326    ///
327    /// Like `for_each_raw` but passes the raw `&[u8]` DataRow payload directly
328    /// to the closure — no `PgDataRow` construction, no SmallVec pre-scan.
329    #[doc(hidden)]
330    pub async fn __for_each_raw_bytes<F>(
331        &self,
332        sql: &str,
333        sql_hash: u64,
334        params: &[&(dyn Encode + Sync)],
335        mut f: F,
336    ) -> BsqlResult<()>
337    where
338        F: FnMut(&[u8]) -> BsqlResult<()>,
339    {
340        let mut guard = self.inner.lock().await;
341        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
342        let mut user_err: Option<BsqlError> = None;
343        let driver_result = tx
344            .for_each_raw(sql, sql_hash, params, |data| match f(data) {
345                Ok(()) => Ok(()),
346                Err(e) => {
347                    user_err = Some(e);
348                    Err(bsql_driver_postgres::DriverError::Protocol(
349                        "for_each closure error".into(),
350                    ))
351                }
352            })
353            .await;
354        if let Some(e) = user_err {
355            return Err(e);
356        }
357        driver_result.map_err(BsqlError::from_driver_query)
358    }
359}
360
361impl fmt::Debug for Transaction {
362    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363        f.debug_struct("Transaction")
364            .field("finished", &self.finished)
365            .finish()
366    }
367}
368
369impl Drop for Transaction {
370    fn drop(&mut self) {
371        if !self.finished {
372            // The transaction was dropped without commit() or rollback().
373            // The driver-level Transaction::drop discards the connection from the
374            // pool — PG server auto-rollbacks when it sees the disconnect.
375            // Log a warning to help catch forgotten commits during development.
376            eprintln!(
377                "bsql: Transaction dropped without commit() or rollback() — \
378                 connection discarded from pool. This is safe but wasteful."
379            );
380        }
381    }
382}
383
384/// Delegate to shared savepoint name validator.
385fn validate_savepoint_name(name: &str) -> BsqlResult<()> {
386    crate::util::validate_savepoint_name(name)
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    // --- Savepoint name validation ---

    #[test]
    fn validate_savepoint_name_valid() {
        let accepted = ["sp1", "_sp", "my_savepoint_123"];
        assert!(accepted
            .iter()
            .all(|name| validate_savepoint_name(name).is_ok()));
    }

    #[test]
    fn validate_savepoint_name_empty() {
        let outcome = validate_savepoint_name("");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_savepoint_name_too_long() {
        // One character past the 63-character limit.
        let name: String = std::iter::repeat('a').take(64).collect();
        assert!(validate_savepoint_name(&name).is_err());
    }

    #[test]
    fn validate_savepoint_name_max_length() {
        // Exactly at the 63-character limit.
        let name: String = std::iter::repeat('a').take(63).collect();
        assert!(validate_savepoint_name(&name).is_ok());
    }

    #[test]
    fn validate_savepoint_name_starts_with_digit() {
        let outcome = validate_savepoint_name("1sp");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_savepoint_name_starts_with_underscore() {
        let outcome = validate_savepoint_name("_sp");
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_savepoint_name_special_chars() {
        let rejected = ["sp-1", "sp.1", "sp 1", "sp;1", "sp'1"];
        assert!(rejected
            .iter()
            .all(|name| validate_savepoint_name(name).is_err()));
    }

    // --- IsolationLevel rendering ---

    #[test]
    fn isolation_level_display() {
        let expected = [
            (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
            (IsolationLevel::ReadCommitted, "READ COMMITTED"),
            (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
            (IsolationLevel::Serializable, "SERIALIZABLE"),
        ];
        for (level, sql) in expected {
            assert_eq!(level.to_string(), sql);
        }
    }

    // --- IsolationLevel traits ---

    #[test]
    fn isolation_level_clone() {
        let original = IsolationLevel::Serializable;
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    #[test]
    fn isolation_level_debug() {
        let rendered = format!("{:?}", IsolationLevel::RepeatableRead);
        assert!(
            rendered.contains("RepeatableRead"),
            "Debug should show variant name: {}",
            rendered
        );
    }

    #[test]
    fn isolation_level_eq() {
        let serializable = IsolationLevel::Serializable;
        assert_eq!(serializable, IsolationLevel::Serializable);
        assert_ne!(serializable, IsolationLevel::ReadCommitted);
    }

    // --- Transaction Debug ---

    #[test]
    fn transaction_debug_shows_finished_false() {
        // A Transaction cannot be built here without a driver, so this only
        // checks at compile time that the Debug impl exists.
        fn _assert_debug<T>()
        where
            T: std::fmt::Debug,
        {
        }
        _assert_debug::<Transaction>();
    }

    // --- Send + Sync assertions ---

    fn _assert_send<T>()
    where
        T: Send,
    {
    }

    fn _assert_sync<T>()
    where
        T: Sync,
    {
    }

    #[test]
    fn transaction_is_send() {
        _assert_send::<Transaction>();
    }

    #[test]
    fn transaction_is_sync() {
        _assert_sync::<Transaction>();
    }

    #[test]
    fn isolation_level_is_send_and_sync() {
        _assert_send::<IsolationLevel>();
        _assert_sync::<IsolationLevel>();
    }
}
505}