Skip to main content

bsql_core/
transaction.rs

//! Database transactions with commit/rollback.
//!
//! Created via [`Pool::begin()`](crate::pool::Pool::begin). A transaction
//! holds a single connection from the pool for its entire lifetime. Queries
//! executed through the `Executor` trait run within the transaction.
//!
//! # Drop behavior
//!
//! If a `Transaction` is dropped without calling [`commit()`](Transaction::commit)
//! or [`rollback()`](Transaction::rollback), the driver discards the connection
//! instead of returning it to the pool, and PostgreSQL automatically rolls back the
//! open transaction when that connection closes. A warning is emitted via
//! `log::warn!` to help detect forgotten commits during development.

14use std::fmt;
15
16use bsql_driver_postgres::codec::Encode;
17
18use crate::error::{BsqlError, BsqlResult, QueryError};
19use crate::executor::OwnedResult;
20
/// Transaction isolation levels accepted by PostgreSQL's
/// `SET TRANSACTION ISOLATION LEVEL` statement.
///
/// Pass one to `Transaction::set_isolation` before the transaction's first
/// statement runs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// `READ UNCOMMITTED`. PostgreSQL accepts this level but runs it exactly
    /// like `READ COMMITTED`, so dirty reads never happen.
    ReadUncommitted,
    /// `READ COMMITTED`, PostgreSQL's default level.
    ReadCommitted,
    /// `REPEATABLE READ`.
    RepeatableRead,
    /// `SERIALIZABLE`. Transactions at this level may fail with a
    /// serialization error (SQLSTATE `40001`) and should be retried.
    Serializable,
}
29
30impl IsolationLevel {
31    /// SQL representation for `SET TRANSACTION ISOLATION LEVEL ...`.
32    fn as_sql(&self) -> &'static str {
33        match self {
34            IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
35            IsolationLevel::ReadCommitted => "READ COMMITTED",
36            IsolationLevel::RepeatableRead => "REPEATABLE READ",
37            IsolationLevel::Serializable => "SERIALIZABLE",
38        }
39    }
40}
41
42impl fmt::Display for IsolationLevel {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        f.write_str(self.as_sql())
45    }
46}
47
48/// A database transaction.
49///
50/// Created by [`Pool::begin()`](crate::pool::Pool::begin). Must be explicitly
51/// committed via [`commit()`](Transaction::commit). If dropped without
52/// `commit()`, the connection is discarded from the pool and a warning is logged.
53///
54/// Use `.defer(&mut tx)` on queries to buffer writes, then `tx.commit()` to flush
55/// them all in a single pipeline. Use `.execute(&mut tx)` or `.fetch(&mut tx)` for immediate
56/// execution within the transaction.
57///
58/// # Example
59///
60/// ```rust,ignore
61/// use bsql::Pool;
62///
63/// let pool = Pool::connect("postgres://user:pass@localhost/mydb")?;
64/// let tx = pool.begin()?;
65///
66/// // Buffer writes with .defer() — nothing hits the network yet
67/// bsql::query!("INSERT INTO log (msg) VALUES ($msg: &str)")
68///     .defer(&mut tx)?;
69///
70/// // Or execute immediately within the transaction
71/// bsql::query!("UPDATE accounts SET balance = 0 WHERE id = $id: i32")
72///     .execute(&mut tx)?;
73///
74/// // commit() flushes all deferred operations, then commits
75/// tx.commit()?;
76/// ```
77pub struct Transaction {
78    inner: Option<bsql_driver_postgres::Transaction>,
79    /// Set to true when commit() or rollback() is called.
80    finished: bool,
81}
82
83impl Transaction {
84    /// Wrap a driver-level transaction.
85    pub(crate) fn from_driver(tx: bsql_driver_postgres::Transaction) -> Self {
86        Self {
87            inner: Some(tx),
88            finished: false,
89        }
90    }
91
92    /// Return a "transaction already consumed" error.
93    fn consumed_error() -> BsqlError {
94        BsqlError::Query(QueryError {
95            message: "transaction already consumed".into(),
96            pg_code: None,
97            source: None,
98        })
99    }
100
101    /// Commit the transaction and return the connection to the pool.
102    ///
103    /// Consumes `self` — the transaction cannot be used after commit.
104    pub async fn commit(mut self) -> BsqlResult<()> {
105        self.finished = true;
106        let tx = self.inner.take().ok_or_else(Self::consumed_error)?;
107        tx.commit().map_err(BsqlError::from)
108    }
109
110    /// Explicitly roll back the transaction and return the connection to the pool.
111    ///
112    /// Consumes `self` — the transaction cannot be used after rollback.
113    pub async fn rollback(mut self) -> BsqlResult<()> {
114        self.finished = true;
115        let tx = self.inner.take().ok_or_else(Self::consumed_error)?;
116        tx.rollback().map_err(BsqlError::from)
117    }
118
119    /// Create a savepoint within the transaction.
120    ///
121    /// The `name` must be a valid SQL identifier: ASCII alphanumeric and
122    /// underscores only, starting with a letter or underscore. Maximum 63 characters.
123    pub async fn savepoint(&mut self, name: &str) -> BsqlResult<()> {
124        validate_savepoint_name(name)?;
125        let sql = format!("SAVEPOINT {name}");
126        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
127        tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
128    }
129
130    /// Release (destroy) a savepoint, keeping its effects.
131    ///
132    /// The `name` must match a previously created savepoint.
133    pub async fn release_savepoint(&mut self, name: &str) -> BsqlResult<()> {
134        validate_savepoint_name(name)?;
135        let sql = format!("RELEASE SAVEPOINT {name}");
136        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
137        tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
138    }
139
140    /// Roll back to a savepoint, undoing changes made after it was created.
141    ///
142    /// The savepoint remains valid after this call (can be rolled back to again).
143    pub async fn rollback_to(&mut self, name: &str) -> BsqlResult<()> {
144        validate_savepoint_name(name)?;
145        let sql = format!("ROLLBACK TO SAVEPOINT {name}");
146        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
147        tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
148    }
149
150    /// Set the isolation level for this transaction.
151    ///
152    /// Must be called before the first query in the transaction (immediately
153    /// after `begin()`). PostgreSQL rejects `SET TRANSACTION` after any
154    /// data-modifying statement.
155    pub async fn set_isolation(&mut self, level: IsolationLevel) -> BsqlResult<()> {
156        let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
157        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
158        tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
159    }
160
161    /// Execute a query within the transaction (used by QueryTarget dispatch).
162    pub(crate) fn query_inner(
163        &mut self,
164        sql: &str,
165        sql_hash: u64,
166        params: &[&(dyn Encode + Sync)],
167    ) -> BsqlResult<OwnedResult> {
168        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
169        let result = tx
170            .query(sql, sql_hash, params)
171            .map_err(BsqlError::from_driver_query)?;
172        Ok(OwnedResult::without_arena(result))
173    }
174
175    /// Execute without result rows within the transaction (used by QueryTarget dispatch).
176    pub(crate) fn execute_inner(
177        &mut self,
178        sql: &str,
179        sql_hash: u64,
180        params: &[&(dyn Encode + Sync)],
181    ) -> BsqlResult<u64> {
182        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
183        tx.execute(sql, sql_hash, params)
184            .map_err(BsqlError::from_driver_query)
185    }
186
187    /// Execute the same statement N times with different params in one pipeline.
188    ///
189    /// Sends all N Bind+Execute messages + one Sync. One round-trip for
190    /// N operations within the transaction. Returns the affected row count
191    /// for each parameter set.
192    pub async fn execute_pipeline(
193        &mut self,
194        sql: &str,
195        sql_hash: u64,
196        param_sets: &[&[&(dyn Encode + Sync)]],
197    ) -> BsqlResult<Vec<u64>> {
198        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
199        tx.execute_pipeline(sql, sql_hash, param_sets)
200            .map_err(BsqlError::from_driver_query)
201    }
202
203    // --- Deferred pipeline API ---
204
205    /// Buffer an execute for deferred pipeline flush.
206    ///
207    /// The operation is not sent to the server immediately. Instead, the
208    /// Bind+Execute message bytes are buffered internally. The buffered
209    /// operations are sent as a single pipeline on [`commit()`](Self::commit)
210    /// or [`flush_deferred()`](Self::flush_deferred).
211    ///
212    /// If the statement has not been prepared yet, a single round-trip is
213    /// made to prepare it. After that, the Bind+Execute bytes are buffered
214    /// with no I/O.
215    ///
216    /// Any read operation (`query_inner`, `for_each_raw`, `simple_query`, etc.)
217    /// automatically flushes deferred operations first to ensure
218    /// read-your-writes consistency.
219    #[doc(hidden)]
220    pub async fn defer_execute(
221        &mut self,
222        sql: &str,
223        sql_hash: u64,
224        params: &[&(dyn Encode + Sync)],
225    ) -> BsqlResult<()> {
226        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
227        tx.defer_execute(sql, sql_hash, params)
228            .map_err(BsqlError::from_driver_query)
229    }
230
231    /// Flush all deferred operations as a single pipeline.
232    ///
233    /// Sends all buffered Bind+Execute messages + one Sync in a single TCP write.
234    /// Returns the affected row count for each deferred operation.
235    #[doc(hidden)]
236    pub async fn flush_deferred(&mut self) -> BsqlResult<Vec<u64>> {
237        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
238        tx.flush_deferred().map_err(BsqlError::from_driver_query)
239    }
240
241    /// Number of operations currently buffered for deferred execution.
242    ///
243    /// This is a diagnostic method primarily for testing. Most users should
244    /// not need to call this -- deferred operations are flushed automatically
245    /// on commit or before any read.
246    #[doc(hidden)]
247    pub fn deferred_count(&self) -> usize {
248        match self.inner.as_ref() {
249            Some(tx) => tx.deferred_count(),
250            None => 0,
251        }
252    }
253
254    /// Process each row directly from the wire buffer within this transaction.
255    ///
256    /// Zero arena allocation — the closure receives a `PgDataRow` that reads
257    /// columns directly from the DataRow message bytes.
258    pub async fn for_each_raw<F>(
259        &mut self,
260        sql: &str,
261        sql_hash: u64,
262        params: &[&(dyn Encode + Sync)],
263        mut f: F,
264    ) -> BsqlResult<()>
265    where
266        F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
267    {
268        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
269        let mut user_err: Option<BsqlError> = None;
270        let driver_result = tx.for_each(sql, sql_hash, params, |row| match f(row) {
271            Ok(()) => Ok(()),
272            Err(e) => {
273                user_err = Some(e);
274                Err(bsql_driver_postgres::DriverError::Protocol(
275                    "for_each closure error".into(),
276                ))
277            }
278        });
279        if let Some(e) = user_err {
280            return Err(e);
281        }
282        driver_result.map_err(BsqlError::from_driver_query)
283    }
284
285    /// Process each DataRow as raw bytes within this transaction.
286    ///
287    /// Like `for_each_raw` but passes the raw `&[u8]` DataRow payload directly
288    /// to the closure — no `PgDataRow` construction, no SmallVec pre-scan.
289    #[doc(hidden)]
290    pub async fn __for_each_raw_bytes<F>(
291        &mut self,
292        sql: &str,
293        sql_hash: u64,
294        params: &[&(dyn Encode + Sync)],
295        mut f: F,
296    ) -> BsqlResult<()>
297    where
298        F: FnMut(&[u8]) -> BsqlResult<()>,
299    {
300        let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
301        let mut user_err: Option<BsqlError> = None;
302        let driver_result = tx.for_each_raw(sql, sql_hash, params, |data| match f(data) {
303            Ok(()) => Ok(()),
304            Err(e) => {
305                user_err = Some(e);
306                Err(bsql_driver_postgres::DriverError::Protocol(
307                    "for_each closure error".into(),
308                ))
309            }
310        });
311        if let Some(e) = user_err {
312            return Err(e);
313        }
314        driver_result.map_err(BsqlError::from_driver_query)
315    }
316}
317
318impl fmt::Debug for Transaction {
319    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320        f.debug_struct("Transaction")
321            .field("finished", &self.finished)
322            .finish()
323    }
324}
325
326impl Drop for Transaction {
327    fn drop(&mut self) {
328        if !self.finished {
329            // The transaction was dropped without commit() or rollback().
330            // The driver-level Transaction::drop discards the connection from the
331            // pool — PG server auto-rollbacks when it sees the disconnect.
332            // Log a warning to help catch forgotten commits during development.
333            log::warn!(
334                "bsql: Transaction dropped without commit() or rollback() — \
335                 connection discarded from pool. This is safe but wasteful."
336            );
337        }
338    }
339}
340
341/// Delegate to shared savepoint name validator.
342fn validate_savepoint_name(name: &str) -> BsqlResult<()> {
343    crate::util::validate_savepoint_name(name)
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `IsolationLevel` variant paired with its SQL spelling.
    const LEVELS: [(IsolationLevel, &str); 4] = [
        (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
        (IsolationLevel::ReadCommitted, "READ COMMITTED"),
        (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
        (IsolationLevel::Serializable, "SERIALIZABLE"),
    ];

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}
    fn _assert_debug<T: std::fmt::Debug>() {}

    // --- Savepoint name validation ---

    #[test]
    fn validate_savepoint_name_accepts_valid_identifiers() {
        for name in [
            "sp1",
            "_sp",
            "my_savepoint_123",
            "a",
            "_",
            "a123456789",
            "___",
        ] {
            assert!(
                validate_savepoint_name(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    #[test]
    fn validate_savepoint_name_rejects_empty() {
        assert!(validate_savepoint_name("").is_err());
    }

    #[test]
    fn validate_savepoint_name_length_boundary() {
        // 63 characters is the maximum; 64 is one too many.
        assert!(validate_savepoint_name(&"a".repeat(63)).is_ok());
        assert!(validate_savepoint_name(&format!("a{}", "b".repeat(62))).is_ok());
        assert!(validate_savepoint_name(&"a".repeat(64)).is_err());
        assert!(validate_savepoint_name(&format!("a{}", "b".repeat(63))).is_err());
    }

    #[test]
    fn validate_savepoint_name_rejects_leading_digit() {
        assert!(validate_savepoint_name("1sp").is_err());
    }

    #[test]
    fn validate_savepoint_name_rejects_special_characters() {
        for name in ["sp-1", "sp.1", "sp 1", "sp;1", "sp'1", "sp\0name"] {
            assert!(
                validate_savepoint_name(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    #[test]
    fn validate_savepoint_name_rejects_non_ascii() {
        assert!(
            validate_savepoint_name("sp_\u{00e9}").is_err(),
            "unicode chars should be rejected"
        );
    }

    #[test]
    fn validate_savepoint_name_rejects_sql_injection() {
        for name in ["sp; DROP TABLE", "sp'--", "sp\"test"] {
            assert!(
                validate_savepoint_name(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    // --- IsolationLevel ---

    #[test]
    fn isolation_level_as_sql_all_variants() {
        for (level, sql) in LEVELS {
            assert_eq!(level.as_sql(), sql);
        }
    }

    #[test]
    fn isolation_level_display_matches_as_sql() {
        for (level, sql) in LEVELS {
            assert_eq!(level.to_string(), sql);
            assert_eq!(level.to_string(), level.as_sql());
        }
    }

    #[test]
    #[allow(clippy::clone_on_copy)] // deliberately exercise the derived `Clone` impl
    fn isolation_level_clone_and_copy() {
        let level = IsolationLevel::Serializable;
        let cloned = level.clone();
        let copied = level;
        assert_eq!(level, cloned);
        assert_eq!(level, copied);
    }

    #[test]
    fn isolation_level_debug_shows_variant_name() {
        let dbg = format!("{:?}", IsolationLevel::RepeatableRead);
        assert!(
            dbg.contains("RepeatableRead"),
            "Debug should show variant name: {dbg}"
        );
    }

    #[test]
    fn isolation_level_eq_only_matches_same_variant() {
        for (i, (a, _)) in LEVELS.iter().enumerate() {
            for (j, (b, _)) in LEVELS.iter().enumerate() {
                assert_eq!(a == b, i == j, "{a:?} vs {b:?}");
            }
        }
    }

    // --- Compile-time trait assertions ---
    //
    // A `Transaction` cannot be built without a live driver connection, so
    // these tests only prove the impls and signatures exist.

    #[test]
    fn transaction_implements_debug() {
        _assert_debug::<Transaction>();
    }

    #[test]
    fn transaction_is_send() {
        _assert_send::<Transaction>();
    }

    #[test]
    fn isolation_level_is_send_and_sync() {
        _assert_send::<IsolationLevel>();
        _assert_sync::<IsolationLevel>();
    }

    #[test]
    fn transaction_from_driver_signature_compiles() {
        fn _check(tx: bsql_driver_postgres::Transaction) -> Transaction {
            Transaction::from_driver(tx)
        }
    }

    // --- consumed_error ---

    #[test]
    fn consumed_error_message_is_descriptive() {
        let display = Transaction::consumed_error().to_string();
        assert!(
            display.contains("transaction already consumed"),
            "consumed error should be descriptive: {display}"
        );
    }

    #[test]
    fn consumed_error_is_query_variant() {
        let e = Transaction::consumed_error();
        assert!(
            matches!(e, BsqlError::Query(_)),
            "consumed_error should be Query variant"
        );
    }
}