1use std::fmt;
15
16use bsql_driver_postgres::codec::Encode;
17
18use crate::error::{BsqlError, BsqlResult, QueryError};
19use crate::executor::OwnedResult;
20
/// Transaction isolation level accepted by [`Transaction::set_isolation`].
///
/// The `Display` output of each variant is the SQL keyword phrase that is
/// spliced into `SET TRANSACTION ISOLATION LEVEL ...`.
///
/// `Hash` is derived alongside `Eq` so levels can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`.
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}
29
30impl IsolationLevel {
31 fn as_sql(&self) -> &'static str {
33 match self {
34 IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
35 IsolationLevel::ReadCommitted => "READ COMMITTED",
36 IsolationLevel::RepeatableRead => "REPEATABLE READ",
37 IsolationLevel::Serializable => "SERIALIZABLE",
38 }
39 }
40}
41
42impl fmt::Display for IsolationLevel {
43 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 f.write_str(self.as_sql())
45 }
46}
47
48pub struct Transaction {
78 inner: Option<bsql_driver_postgres::Transaction>,
79 finished: bool,
81}
82
83impl Transaction {
84 pub(crate) fn from_driver(tx: bsql_driver_postgres::Transaction) -> Self {
86 Self {
87 inner: Some(tx),
88 finished: false,
89 }
90 }
91
92 fn consumed_error() -> BsqlError {
94 BsqlError::Query(QueryError {
95 message: "transaction already consumed".into(),
96 pg_code: None,
97 source: None,
98 })
99 }
100
101 pub async fn commit(mut self) -> BsqlResult<()> {
105 self.finished = true;
106 let tx = self.inner.take().ok_or_else(Self::consumed_error)?;
107 tx.commit().map_err(BsqlError::from)
108 }
109
110 pub async fn rollback(mut self) -> BsqlResult<()> {
114 self.finished = true;
115 let tx = self.inner.take().ok_or_else(Self::consumed_error)?;
116 tx.rollback().map_err(BsqlError::from)
117 }
118
119 pub async fn savepoint(&mut self, name: &str) -> BsqlResult<()> {
124 validate_savepoint_name(name)?;
125 let sql = format!("SAVEPOINT {name}");
126 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
127 tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
128 }
129
130 pub async fn release_savepoint(&mut self, name: &str) -> BsqlResult<()> {
134 validate_savepoint_name(name)?;
135 let sql = format!("RELEASE SAVEPOINT {name}");
136 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
137 tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
138 }
139
140 pub async fn rollback_to(&mut self, name: &str) -> BsqlResult<()> {
144 validate_savepoint_name(name)?;
145 let sql = format!("ROLLBACK TO SAVEPOINT {name}");
146 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
147 tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
148 }
149
150 pub async fn set_isolation(&mut self, level: IsolationLevel) -> BsqlResult<()> {
156 let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
157 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
158 tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
159 }
160
161 pub(crate) fn query_inner(
163 &mut self,
164 sql: &str,
165 sql_hash: u64,
166 params: &[&(dyn Encode + Sync)],
167 ) -> BsqlResult<OwnedResult> {
168 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
169 let result = tx
170 .query(sql, sql_hash, params)
171 .map_err(BsqlError::from_driver_query)?;
172 Ok(OwnedResult::without_arena(result))
173 }
174
175 pub(crate) fn execute_inner(
177 &mut self,
178 sql: &str,
179 sql_hash: u64,
180 params: &[&(dyn Encode + Sync)],
181 ) -> BsqlResult<u64> {
182 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
183 tx.execute(sql, sql_hash, params)
184 .map_err(BsqlError::from_driver_query)
185 }
186
187 pub async fn execute_pipeline(
193 &mut self,
194 sql: &str,
195 sql_hash: u64,
196 param_sets: &[&[&(dyn Encode + Sync)]],
197 ) -> BsqlResult<Vec<u64>> {
198 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
199 tx.execute_pipeline(sql, sql_hash, param_sets)
200 .map_err(BsqlError::from_driver_query)
201 }
202
203 #[doc(hidden)]
220 pub async fn defer_execute(
221 &mut self,
222 sql: &str,
223 sql_hash: u64,
224 params: &[&(dyn Encode + Sync)],
225 ) -> BsqlResult<()> {
226 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
227 tx.defer_execute(sql, sql_hash, params)
228 .map_err(BsqlError::from_driver_query)
229 }
230
231 #[doc(hidden)]
236 pub async fn flush_deferred(&mut self) -> BsqlResult<Vec<u64>> {
237 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
238 tx.flush_deferred().map_err(BsqlError::from_driver_query)
239 }
240
241 #[doc(hidden)]
247 pub fn deferred_count(&self) -> usize {
248 match self.inner.as_ref() {
249 Some(tx) => tx.deferred_count(),
250 None => 0,
251 }
252 }
253
254 pub async fn for_each_raw<F>(
259 &mut self,
260 sql: &str,
261 sql_hash: u64,
262 params: &[&(dyn Encode + Sync)],
263 mut f: F,
264 ) -> BsqlResult<()>
265 where
266 F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
267 {
268 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
269 let mut user_err: Option<BsqlError> = None;
270 let driver_result = tx.for_each(sql, sql_hash, params, |row| match f(row) {
271 Ok(()) => Ok(()),
272 Err(e) => {
273 user_err = Some(e);
274 Err(bsql_driver_postgres::DriverError::Protocol(
275 "for_each closure error".into(),
276 ))
277 }
278 });
279 if let Some(e) = user_err {
280 return Err(e);
281 }
282 driver_result.map_err(BsqlError::from_driver_query)
283 }
284
285 #[doc(hidden)]
290 pub async fn __for_each_raw_bytes<F>(
291 &mut self,
292 sql: &str,
293 sql_hash: u64,
294 params: &[&(dyn Encode + Sync)],
295 mut f: F,
296 ) -> BsqlResult<()>
297 where
298 F: FnMut(&[u8]) -> BsqlResult<()>,
299 {
300 let tx = self.inner.as_mut().ok_or_else(Self::consumed_error)?;
301 let mut user_err: Option<BsqlError> = None;
302 let driver_result = tx.for_each_raw(sql, sql_hash, params, |data| match f(data) {
303 Ok(()) => Ok(()),
304 Err(e) => {
305 user_err = Some(e);
306 Err(bsql_driver_postgres::DriverError::Protocol(
307 "for_each closure error".into(),
308 ))
309 }
310 });
311 if let Some(e) = user_err {
312 return Err(e);
313 }
314 driver_result.map_err(BsqlError::from_driver_query)
315 }
316}
317
318impl fmt::Debug for Transaction {
319 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320 f.debug_struct("Transaction")
321 .field("finished", &self.finished)
322 .finish()
323 }
324}
325
326impl Drop for Transaction {
327 fn drop(&mut self) {
328 if !self.finished {
329 log::warn!(
334 "bsql: Transaction dropped without commit() or rollback() — \
335 connection discarded from pool. This is safe but wasteful."
336 );
337 }
338 }
339}
340
/// Checks that `name` is acceptable as a savepoint identifier before it is
/// interpolated into `SAVEPOINT` / `RELEASE SAVEPOINT` / `ROLLBACK TO
/// SAVEPOINT` SQL text. The rules live in `crate::util::validate_savepoint_name`.
/// This module's tests expect: non-empty, at most 63 characters, not starting
/// with a digit, and only underscores/alphanumerics (punctuation, whitespace,
/// quotes, NUL and non-ASCII characters are rejected).
fn validate_savepoint_name(name: &str) -> BsqlResult<()> {
    crate::util::validate_savepoint_name(name)
}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `IsolationLevel` variant paired with its expected SQL text.
    const LEVEL_SQL: [(IsolationLevel, &str); 4] = [
        (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
        (IsolationLevel::ReadCommitted, "READ COMMITTED"),
        (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
        (IsolationLevel::Serializable, "SERIALIZABLE"),
    ];

    /// Asserts that every name in `names` passes savepoint-name validation.
    fn assert_all_accepted(names: &[&str]) {
        for &name in names {
            assert!(
                validate_savepoint_name(name).is_ok(),
                "expected {name:?} to be accepted"
            );
        }
    }

    /// Asserts that every name in `names` fails savepoint-name validation.
    fn assert_all_rejected(names: &[&str]) {
        for &name in names {
            assert!(
                validate_savepoint_name(name).is_err(),
                "expected {name:?} to be rejected"
            );
        }
    }

    /// Compile-time check that `T: Send`.
    fn assert_send<T: Send>() {}

    #[test]
    fn validate_savepoint_name_valid() {
        assert_all_accepted(&["sp1", "_sp", "my_savepoint_123"]);
    }

    #[test]
    fn validate_savepoint_name_empty() {
        assert_all_rejected(&[""]);
    }

    #[test]
    fn validate_savepoint_name_too_long() {
        let too_long = "a".repeat(64);
        assert_all_rejected(&[too_long.as_str()]);
    }

    #[test]
    fn validate_savepoint_name_max_length() {
        let at_limit = "a".repeat(63);
        assert_all_accepted(&[at_limit.as_str()]);
    }

    #[test]
    fn validate_savepoint_name_starts_with_digit() {
        assert_all_rejected(&["1sp"]);
    }

    #[test]
    fn validate_savepoint_name_starts_with_underscore() {
        assert_all_accepted(&["_sp"]);
    }

    #[test]
    fn validate_savepoint_name_special_chars() {
        assert_all_rejected(&["sp-1", "sp.1", "sp 1", "sp;1", "sp'1"]);
    }

    #[test]
    fn isolation_level_display() {
        for (level, sql) in LEVEL_SQL {
            assert_eq!(level.to_string(), sql);
        }
    }

    // Exercises the derived `Copy`/`Clone` + `PartialEq` pair.
    #[test]
    fn isolation_level_clone() {
        let original = IsolationLevel::Serializable;
        let copied = original;
        assert_eq!(original, copied);
    }

    #[test]
    fn isolation_level_debug() {
        let rendered = format!("{:?}", IsolationLevel::RepeatableRead);
        assert!(
            rendered.contains("RepeatableRead"),
            "Debug should show variant name: {rendered}"
        );
    }

    #[test]
    fn isolation_level_eq() {
        assert_eq!(IsolationLevel::Serializable, IsolationLevel::Serializable);
        assert_ne!(IsolationLevel::Serializable, IsolationLevel::ReadCommitted);
    }

    // Compile-time only: `Transaction` implements `Debug`.
    #[test]
    fn transaction_debug_shows_finished_false() {
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<Transaction>();
    }

    #[test]
    fn transaction_is_send() {
        assert_send::<Transaction>();
    }

    #[test]
    fn isolation_level_is_send() {
        assert_send::<IsolationLevel>();
    }

    #[test]
    fn isolation_level_as_sql_all_variants() {
        for (level, sql) in LEVEL_SQL {
            assert_eq!(level.as_sql(), sql);
        }
    }

    #[test]
    fn validate_savepoint_name_single_char() {
        assert_all_accepted(&["a", "_"]);
    }

    #[test]
    fn validate_savepoint_name_all_digits_after_letter() {
        assert_all_accepted(&["a123456789"]);
    }

    #[test]
    fn validate_savepoint_name_all_underscores() {
        assert_all_accepted(&["___"]);
    }

    #[test]
    fn validate_savepoint_name_unicode_rejected() {
        assert!(
            validate_savepoint_name("sp_\u{00e9}").is_err(),
            "unicode chars should be rejected"
        );
    }

    #[test]
    fn validate_savepoint_name_sql_injection_rejected() {
        assert_all_rejected(&["sp; DROP TABLE", "sp'--", "sp\"test"]);
    }

    #[test]
    fn consumed_error_message_is_descriptive() {
        let display = Transaction::consumed_error().to_string();
        assert!(
            display.contains("transaction already consumed"),
            "consumed error should be descriptive: {display}"
        );
    }

    #[test]
    fn isolation_level_as_sql_is_idempotent() {
        let level = IsolationLevel::Serializable;
        assert_eq!(level.as_sql(), level.as_sql());
        assert_eq!(level.as_sql(), "SERIALIZABLE");
    }

    #[test]
    fn isolation_level_display_matches_as_sql() {
        for (level, _) in LEVEL_SQL {
            assert_eq!(level.to_string(), level.as_sql());
        }
    }

    // Compile-time only: `from_driver` accepts a driver transaction by value.
    #[test]
    fn transaction_from_driver_compiles() {
        fn _check(tx: bsql_driver_postgres::Transaction) -> Transaction {
            Transaction::from_driver(tx)
        }
    }

    #[test]
    fn validate_savepoint_name_null_byte_rejected() {
        assert!(
            validate_savepoint_name("sp\0name").is_err(),
            "null byte in savepoint name should be rejected"
        );
    }

    #[test]
    fn validate_savepoint_name_boundary_63_and_64() {
        let ok_63 = format!("a{}", "b".repeat(62));
        assert!(validate_savepoint_name(&ok_63).is_ok());
        let err_64 = format!("a{}", "b".repeat(63));
        assert!(validate_savepoint_name(&err_64).is_err());
    }

    #[test]
    fn consumed_error_is_query_variant() {
        let err = Transaction::consumed_error();
        assert!(
            matches!(err, BsqlError::Query(_)),
            "consumed_error should be Query variant"
        );
    }
}