Skip to main content

bsql_core/
transaction.rs

1//! Database transactions with commit/rollback.
2//!
3//! Created via [`Pool::begin()`](crate::pool::Pool::begin). A transaction
4//! holds a single connection from the pool for its entire lifetime. Queries
5//! executed through the `Executor` trait run within the transaction.
6//!
7//! # Drop behavior
8//!
//! If a `Transaction` is dropped without calling [`commit()`](Transaction::commit)
//! or [`rollback()`](Transaction::rollback), the driver discards the connection
//! instead of returning it to the pool; PostgreSQL rolls back the open
//! transaction automatically when the connection closes. A warning is printed
//! to stderr via `eprintln!` to help detect forgotten commits during development.
13
14use std::fmt;
15use std::sync::Mutex;
16
17use bsql_driver_postgres::codec::Encode;
18
19use crate::error::{BsqlError, BsqlResult, QueryError};
20use crate::executor::OwnedResult;
21
/// Transaction isolation levels supported by PostgreSQL.
///
/// Applied with [`Transaction::set_isolation`]; each variant maps to the
/// keyword text used in `SET TRANSACTION ISOLATION LEVEL ...`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// `READ UNCOMMITTED`. PostgreSQL accepts this level but runs it exactly
    /// like `READ COMMITTED`, so dirty reads never occur.
    ReadUncommitted,
    /// `READ COMMITTED`, PostgreSQL's default isolation level.
    ReadCommitted,
    /// `REPEATABLE READ`: every statement sees the snapshot taken at the
    /// transaction's first query.
    RepeatableRead,
    /// `SERIALIZABLE`: the strictest level. Transactions may fail with a
    /// serialization error (SQLSTATE `40001`) and should then be retried.
    Serializable,
}
30
31impl IsolationLevel {
32    /// SQL representation for `SET TRANSACTION ISOLATION LEVEL ...`.
33    fn as_sql(&self) -> &'static str {
34        match self {
35            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
36            IsolationLevel::ReadCommitted => "READ COMMITTED",
37            IsolationLevel::RepeatableRead => "REPEATABLE READ",
38            IsolationLevel::Serializable => "SERIALIZABLE",
39        }
40    }
41}
42
43impl fmt::Display for IsolationLevel {
44    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45        f.write_str(self.as_sql())
46    }
47}
48
49/// A database transaction.
50///
51/// Created by [`Pool::begin()`](crate::pool::Pool::begin). Must be explicitly
52/// committed via [`commit()`](Transaction::commit). If dropped without
53/// `commit()`, the connection is discarded from the pool and a warning is logged.
54///
55/// Use `.defer(&tx)` on queries to buffer writes, then `tx.commit()` to flush
56/// them all in a single pipeline. Use `.run(&tx)` or `.fetch(&tx)` for immediate
57/// execution within the transaction.
58///
59/// # Example
60///
61/// ```rust,ignore
62/// use bsql::Pool;
63///
64/// let pool = Pool::connect("postgres://user:pass@localhost/mydb")?;
65/// let tx = pool.begin()?;
66///
67/// // Buffer writes with .defer() — nothing hits the network yet
68/// bsql::query!("INSERT INTO log (msg) VALUES ($msg: &str)")
69///     .defer(&tx)?;
70///
71/// // Or execute immediately within the transaction
72/// bsql::query!("UPDATE accounts SET balance = 0 WHERE id = $id: i32")
73///     .run(&tx)?;
74///
75/// // commit() flushes all deferred operations, then commits
76/// tx.commit()?;
77/// ```
78pub struct Transaction {
79    inner: Mutex<Option<bsql_driver_postgres::Transaction>>,
80    /// Set to true when commit() or rollback() is called.
81    finished: bool,
82}
83
84impl Transaction {
85    /// Wrap a driver-level transaction.
86    pub(crate) fn from_driver(tx: bsql_driver_postgres::Transaction) -> Self {
87        Self {
88            inner: Mutex::new(Some(tx)),
89            finished: false,
90        }
91    }
92
93    /// Return a "transaction already consumed" error.
94    fn consumed_error() -> BsqlError {
95        BsqlError::Query(QueryError {
96            message: "transaction already consumed".into(),
97            pg_code: None,
98            source: None,
99        })
100    }
101
102    /// Commit the transaction and return the connection to the pool.
103    ///
104    /// Consumes `self` — the transaction cannot be used after commit.
105    pub async fn commit(mut self) -> BsqlResult<()> {
106        self.finished = true;
107        let tx = self
108            .inner
109            .lock()
110            .unwrap_or_else(|e| e.into_inner())
111            .take()
112            .ok_or_else(Self::consumed_error)?;
113        tx.commit().map_err(BsqlError::from)
114    }
115
116    /// Explicitly roll back the transaction and return the connection to the pool.
117    ///
118    /// Consumes `self` — the transaction cannot be used after rollback.
119    pub async fn rollback(mut self) -> BsqlResult<()> {
120        self.finished = true;
121        let tx = self
122            .inner
123            .lock()
124            .unwrap_or_else(|e| e.into_inner())
125            .take()
126            .ok_or_else(Self::consumed_error)?;
127        tx.rollback().map_err(BsqlError::from)
128    }
129
130    /// Create a savepoint within the transaction.
131    ///
132    /// The `name` must be a valid SQL identifier: ASCII alphanumeric and
133    /// underscores only, starting with a letter or underscore. Maximum 63 characters.
134    pub async fn savepoint(&self, name: &str) -> BsqlResult<()> {
135        validate_savepoint_name(name)?;
136        let sql = format!("SAVEPOINT {name}");
137        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
138        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
139        tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
140    }
141
142    /// Release (destroy) a savepoint, keeping its effects.
143    ///
144    /// The `name` must match a previously created savepoint.
145    pub async fn release_savepoint(&self, name: &str) -> BsqlResult<()> {
146        validate_savepoint_name(name)?;
147        let sql = format!("RELEASE SAVEPOINT {name}");
148        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
149        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
150        tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
151    }
152
153    /// Roll back to a savepoint, undoing changes made after it was created.
154    ///
155    /// The savepoint remains valid after this call (can be rolled back to again).
156    pub async fn rollback_to(&self, name: &str) -> BsqlResult<()> {
157        validate_savepoint_name(name)?;
158        let sql = format!("ROLLBACK TO SAVEPOINT {name}");
159        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
160        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
161        tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
162    }
163
164    /// Set the isolation level for this transaction.
165    ///
166    /// Must be called before the first query in the transaction (immediately
167    /// after `begin()`). PostgreSQL rejects `SET TRANSACTION` after any
168    /// data-modifying statement.
169    pub async fn set_isolation(&self, level: IsolationLevel) -> BsqlResult<()> {
170        let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
171        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
172        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
173        tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
174    }
175
176    /// Execute a query within the transaction (used by Executor impl).
177    pub(crate) fn query_inner(
178        &self,
179        sql: &str,
180        sql_hash: u64,
181        params: &[&(dyn Encode + Sync)],
182    ) -> BsqlResult<OwnedResult> {
183        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
184        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
185        let result = tx
186            .query(sql, sql_hash, params)
187            .map_err(BsqlError::from_driver_query)?;
188        Ok(OwnedResult::without_arena(result))
189    }
190
191    /// Execute without result rows within the transaction (used by Executor impl).
192    pub(crate) fn execute_inner(
193        &self,
194        sql: &str,
195        sql_hash: u64,
196        params: &[&(dyn Encode + Sync)],
197    ) -> BsqlResult<u64> {
198        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
199        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
200        tx.execute(sql, sql_hash, params)
201            .map_err(BsqlError::from_driver_query)
202    }
203
204    /// Execute the same statement N times with different params in one pipeline.
205    ///
206    /// Sends all N Bind+Execute messages + one Sync. One round-trip for
207    /// N operations within the transaction. Returns the affected row count
208    /// for each parameter set.
209    pub async fn execute_pipeline(
210        &self,
211        sql: &str,
212        sql_hash: u64,
213        param_sets: &[&[&(dyn Encode + Sync)]],
214    ) -> BsqlResult<Vec<u64>> {
215        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
216        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
217        tx.execute_pipeline(sql, sql_hash, param_sets)
218            .map_err(BsqlError::from_driver_query)
219    }
220
221    // --- Deferred pipeline API ---
222
223    /// Buffer an execute for deferred pipeline flush.
224    ///
225    /// The operation is not sent to the server immediately. Instead, the
226    /// Bind+Execute message bytes are buffered internally. The buffered
227    /// operations are sent as a single pipeline on [`commit()`](Self::commit)
228    /// or [`flush_deferred()`](Self::flush_deferred).
229    ///
230    /// If the statement has not been prepared yet, a single round-trip is
231    /// made to prepare it. After that, the Bind+Execute bytes are buffered
232    /// with no I/O.
233    ///
234    /// Any read operation (`query_inner`, `for_each_raw`, `simple_query`, etc.)
235    /// automatically flushes deferred operations first to ensure
236    /// read-your-writes consistency.
237    #[doc(hidden)]
238    pub async fn defer_execute(
239        &self,
240        sql: &str,
241        sql_hash: u64,
242        params: &[&(dyn Encode + Sync)],
243    ) -> BsqlResult<()> {
244        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
245        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
246        tx.defer_execute(sql, sql_hash, params)
247            .map_err(BsqlError::from_driver_query)
248    }
249
250    /// Flush all deferred operations as a single pipeline.
251    ///
252    /// Sends all buffered Bind+Execute messages + one Sync in a single TCP write.
253    /// Returns the affected row count for each deferred operation.
254    #[doc(hidden)]
255    pub async fn flush_deferred(&self) -> BsqlResult<Vec<u64>> {
256        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
257        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
258        tx.flush_deferred().map_err(BsqlError::from_driver_query)
259    }
260
261    /// Number of operations currently buffered for deferred execution.
262    ///
263    /// This is a diagnostic method primarily for testing. Most users should
264    /// not need to call this -- deferred operations are flushed automatically
265    /// on commit or before any read.
266    #[doc(hidden)]
267    pub fn deferred_count(&self) -> usize {
268        let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
269        match guard.as_ref() {
270            Some(tx) => tx.deferred_count(),
271            None => 0,
272        }
273    }
274
275    /// Process each row directly from the wire buffer within this transaction.
276    ///
277    /// Zero arena allocation — the closure receives a `PgDataRow` that reads
278    /// columns directly from the DataRow message bytes.
279    pub async fn for_each_raw<F>(
280        &self,
281        sql: &str,
282        sql_hash: u64,
283        params: &[&(dyn Encode + Sync)],
284        mut f: F,
285    ) -> BsqlResult<()>
286    where
287        F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
288    {
289        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
290        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
291        let mut user_err: Option<BsqlError> = None;
292        let driver_result = tx.for_each(sql, sql_hash, params, |row| match f(row) {
293            Ok(()) => Ok(()),
294            Err(e) => {
295                user_err = Some(e);
296                Err(bsql_driver_postgres::DriverError::Protocol(
297                    "for_each closure error".into(),
298                ))
299            }
300        });
301        if let Some(e) = user_err {
302            return Err(e);
303        }
304        driver_result.map_err(BsqlError::from_driver_query)
305    }
306
307    /// Process each DataRow as raw bytes within this transaction.
308    ///
309    /// Like `for_each_raw` but passes the raw `&[u8]` DataRow payload directly
310    /// to the closure — no `PgDataRow` construction, no SmallVec pre-scan.
311    #[doc(hidden)]
312    pub async fn __for_each_raw_bytes<F>(
313        &self,
314        sql: &str,
315        sql_hash: u64,
316        params: &[&(dyn Encode + Sync)],
317        mut f: F,
318    ) -> BsqlResult<()>
319    where
320        F: FnMut(&[u8]) -> BsqlResult<()>,
321    {
322        let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
323        let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
324        let mut user_err: Option<BsqlError> = None;
325        let driver_result = tx.for_each_raw(sql, sql_hash, params, |data| match f(data) {
326            Ok(()) => Ok(()),
327            Err(e) => {
328                user_err = Some(e);
329                Err(bsql_driver_postgres::DriverError::Protocol(
330                    "for_each closure error".into(),
331                ))
332            }
333        });
334        if let Some(e) = user_err {
335            return Err(e);
336        }
337        driver_result.map_err(BsqlError::from_driver_query)
338    }
339}
340
341impl fmt::Debug for Transaction {
342    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343        f.debug_struct("Transaction")
344            .field("finished", &self.finished)
345            .finish()
346    }
347}
348
349impl Drop for Transaction {
350    fn drop(&mut self) {
351        if !self.finished {
352            // The transaction was dropped without commit() or rollback().
353            // The driver-level Transaction::drop discards the connection from the
354            // pool — PG server auto-rollbacks when it sees the disconnect.
355            // Log a warning to help catch forgotten commits during development.
356            eprintln!(
357                "bsql: Transaction dropped without commit() or rollback() — \
358                 connection discarded from pool. This is safe but wasteful."
359            );
360        }
361    }
362}
363
364/// Delegate to shared savepoint name validator.
365fn validate_savepoint_name(name: &str) -> BsqlResult<()> {
366    crate::util::validate_savepoint_name(name)
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Every isolation level paired with its expected SQL keyword text.
    const LEVEL_SQL: [(IsolationLevel, &str); 4] = [
        (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
        (IsolationLevel::ReadCommitted, "READ COMMITTED"),
        (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
        (IsolationLevel::Serializable, "SERIALIZABLE"),
    ];

    // --- Savepoint name validation ---

    #[test]
    fn validate_savepoint_name_valid() {
        for name in ["sp1", "_sp", "my_savepoint_123"] {
            assert!(
                validate_savepoint_name(name).is_ok(),
                "{name:?} should be accepted"
            );
        }
    }

    #[test]
    fn validate_savepoint_name_empty() {
        assert!(validate_savepoint_name("").is_err());
    }

    #[test]
    fn validate_savepoint_name_too_long() {
        // One past the documented 63-character maximum.
        assert!(validate_savepoint_name(&"a".repeat(64)).is_err());
    }

    #[test]
    fn validate_savepoint_name_max_length() {
        assert!(validate_savepoint_name(&"a".repeat(63)).is_ok());
    }

    #[test]
    fn validate_savepoint_name_starts_with_digit() {
        assert!(validate_savepoint_name("1sp").is_err());
    }

    #[test]
    fn validate_savepoint_name_starts_with_underscore() {
        assert!(validate_savepoint_name("_sp").is_ok());
    }

    #[test]
    fn validate_savepoint_name_special_chars() {
        for name in ["sp-1", "sp.1", "sp 1", "sp;1", "sp'1"] {
            assert!(
                validate_savepoint_name(name).is_err(),
                "{name:?} should be rejected"
            );
        }
    }

    // --- IsolationLevel formatting ---

    #[test]
    fn isolation_level_display() {
        for (level, sql) in LEVEL_SQL {
            assert_eq!(level.to_string(), sql, "Display of {level:?}");
        }
    }

    // --- IsolationLevel traits ---

    #[test]
    // This test exists to exercise the `Clone` impl itself, which
    // `clippy::clone_on_copy` would otherwise rewrite into a plain copy.
    #[allow(clippy::clone_on_copy)]
    fn isolation_level_clone() {
        let level = IsolationLevel::Serializable;
        assert_eq!(level.clone(), level);
        let copied = level;
        assert_eq!(level, copied);
    }

    #[test]
    fn isolation_level_debug() {
        let level = IsolationLevel::RepeatableRead;
        let dbg = format!("{level:?}");
        assert!(
            dbg.contains("RepeatableRead"),
            "Debug should show variant name: {dbg}"
        );
    }

    #[test]
    fn isolation_level_eq() {
        assert_eq!(IsolationLevel::Serializable, IsolationLevel::Serializable);
        assert_ne!(IsolationLevel::Serializable, IsolationLevel::ReadCommitted);
    }

    // --- Transaction Debug ---

    #[test]
    fn transaction_implements_debug() {
        // A Transaction cannot be constructed in tests without a driver, so
        // this only verifies at compile time that the Debug impl exists.
        fn _assert_debug<T: std::fmt::Debug>() {}
        _assert_debug::<Transaction>();
    }

    // --- Send + Sync assertions ---

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn transaction_is_send() {
        _assert_send::<Transaction>();
    }

    #[test]
    fn transaction_is_sync() {
        _assert_sync::<Transaction>();
    }

    #[test]
    fn isolation_level_is_send_and_sync() {
        _assert_send::<IsolationLevel>();
        _assert_sync::<IsolationLevel>();
    }

    // --- IsolationLevel as_sql covers all variants ---

    #[test]
    fn isolation_level_as_sql_all_variants() {
        for (level, sql) in LEVEL_SQL {
            assert_eq!(level.as_sql(), sql, "as_sql of {level:?}");
        }
    }

    // --- Savepoint name validation edge cases ---

    #[test]
    fn validate_savepoint_name_single_char() {
        for name in ["a", "_"] {
            assert!(
                validate_savepoint_name(name).is_ok(),
                "{name:?} should be accepted"
            );
        }
    }

    #[test]
    fn validate_savepoint_name_all_digits_after_letter() {
        assert!(validate_savepoint_name("a123456789").is_ok());
    }

    #[test]
    fn validate_savepoint_name_all_underscores() {
        assert!(validate_savepoint_name("___").is_ok());
    }

    #[test]
    fn validate_savepoint_name_unicode_rejected() {
        assert!(
            validate_savepoint_name("sp_\u{00e9}").is_err(),
            "unicode chars should be rejected"
        );
    }

    #[test]
    fn validate_savepoint_name_sql_injection_rejected() {
        for name in ["sp; DROP TABLE", "sp'--", "sp\"test"] {
            assert!(
                validate_savepoint_name(name).is_err(),
                "{name:?} should be rejected"
            );
        }
    }
}