1use std::fmt;
15use std::sync::Mutex;
16
17use bsql_driver_postgres::codec::Encode;
18
19use crate::error::{BsqlError, BsqlResult, QueryError};
20use crate::executor::OwnedResult;
21
/// Transaction isolation levels, one variant per SQL isolation keyword.
///
/// Each variant renders (via [`fmt::Display`] or the private `as_sql`) to the exact
/// keyword text that is appended to `SET TRANSACTION ISOLATION LEVEL`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IsolationLevel {
    /// Rendered as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Rendered as `READ COMMITTED`.
    ReadCommitted,
    /// Rendered as `REPEATABLE READ`.
    RepeatableRead,
    /// Rendered as `SERIALIZABLE`.
    Serializable,
}

impl IsolationLevel {
    /// Returns the SQL keyword text for this level.
    ///
    /// The strings are fixed literals, so interpolating the result into a statement
    /// cannot introduce caller-controlled SQL.
    const fn as_sql(&self) -> &'static str {
        match *self {
            Self::ReadUncommitted => "READ UNCOMMITTED",
            Self::ReadCommitted => "READ COMMITTED",
            Self::RepeatableRead => "REPEATABLE READ",
            Self::Serializable => "SERIALIZABLE",
        }
    }
}

impl fmt::Display for IsolationLevel {
    /// Writes the SQL keyword text for the level.
    ///
    /// Formatter flags are honoured: `{:>14}` pads, `{:.4}` truncates. With no flags
    /// the output is exactly the string returned by `as_sql`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` applies width/fill/alignment/precision; `write_str` would drop them.
        f.pad(self.as_sql())
    }
}
48
49pub struct Transaction {
79 inner: Mutex<Option<bsql_driver_postgres::Transaction>>,
80 finished: bool,
82}
83
84impl Transaction {
85 pub(crate) fn from_driver(tx: bsql_driver_postgres::Transaction) -> Self {
87 Self {
88 inner: Mutex::new(Some(tx)),
89 finished: false,
90 }
91 }
92
93 fn consumed_error() -> BsqlError {
95 BsqlError::Query(QueryError {
96 message: "transaction already consumed".into(),
97 pg_code: None,
98 source: None,
99 })
100 }
101
102 pub async fn commit(mut self) -> BsqlResult<()> {
106 self.finished = true;
107 let tx = self
108 .inner
109 .lock()
110 .unwrap_or_else(|e| e.into_inner())
111 .take()
112 .ok_or_else(Self::consumed_error)?;
113 tx.commit().map_err(BsqlError::from)
114 }
115
116 pub async fn rollback(mut self) -> BsqlResult<()> {
120 self.finished = true;
121 let tx = self
122 .inner
123 .lock()
124 .unwrap_or_else(|e| e.into_inner())
125 .take()
126 .ok_or_else(Self::consumed_error)?;
127 tx.rollback().map_err(BsqlError::from)
128 }
129
130 pub async fn savepoint(&self, name: &str) -> BsqlResult<()> {
135 validate_savepoint_name(name)?;
136 let sql = format!("SAVEPOINT {name}");
137 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
138 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
139 tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
140 }
141
142 pub async fn release_savepoint(&self, name: &str) -> BsqlResult<()> {
146 validate_savepoint_name(name)?;
147 let sql = format!("RELEASE SAVEPOINT {name}");
148 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
149 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
150 tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
151 }
152
153 pub async fn rollback_to(&self, name: &str) -> BsqlResult<()> {
157 validate_savepoint_name(name)?;
158 let sql = format!("ROLLBACK TO SAVEPOINT {name}");
159 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
160 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
161 tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
162 }
163
164 pub async fn set_isolation(&self, level: IsolationLevel) -> BsqlResult<()> {
170 let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
171 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
172 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
173 tx.simple_query(&sql).map_err(BsqlError::from_driver_query)
174 }
175
176 pub(crate) fn query_inner(
178 &self,
179 sql: &str,
180 sql_hash: u64,
181 params: &[&(dyn Encode + Sync)],
182 ) -> BsqlResult<OwnedResult> {
183 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
184 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
185 let result = tx
186 .query(sql, sql_hash, params)
187 .map_err(BsqlError::from_driver_query)?;
188 Ok(OwnedResult::without_arena(result))
189 }
190
191 pub(crate) fn execute_inner(
193 &self,
194 sql: &str,
195 sql_hash: u64,
196 params: &[&(dyn Encode + Sync)],
197 ) -> BsqlResult<u64> {
198 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
199 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
200 tx.execute(sql, sql_hash, params)
201 .map_err(BsqlError::from_driver_query)
202 }
203
204 pub async fn execute_pipeline(
210 &self,
211 sql: &str,
212 sql_hash: u64,
213 param_sets: &[&[&(dyn Encode + Sync)]],
214 ) -> BsqlResult<Vec<u64>> {
215 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
216 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
217 tx.execute_pipeline(sql, sql_hash, param_sets)
218 .map_err(BsqlError::from_driver_query)
219 }
220
221 #[doc(hidden)]
238 pub async fn defer_execute(
239 &self,
240 sql: &str,
241 sql_hash: u64,
242 params: &[&(dyn Encode + Sync)],
243 ) -> BsqlResult<()> {
244 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
245 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
246 tx.defer_execute(sql, sql_hash, params)
247 .map_err(BsqlError::from_driver_query)
248 }
249
250 #[doc(hidden)]
255 pub async fn flush_deferred(&self) -> BsqlResult<Vec<u64>> {
256 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
257 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
258 tx.flush_deferred().map_err(BsqlError::from_driver_query)
259 }
260
261 #[doc(hidden)]
267 pub fn deferred_count(&self) -> usize {
268 let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
269 match guard.as_ref() {
270 Some(tx) => tx.deferred_count(),
271 None => 0,
272 }
273 }
274
275 pub async fn for_each_raw<F>(
280 &self,
281 sql: &str,
282 sql_hash: u64,
283 params: &[&(dyn Encode + Sync)],
284 mut f: F,
285 ) -> BsqlResult<()>
286 where
287 F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
288 {
289 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
290 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
291 let mut user_err: Option<BsqlError> = None;
292 let driver_result = tx.for_each(sql, sql_hash, params, |row| match f(row) {
293 Ok(()) => Ok(()),
294 Err(e) => {
295 user_err = Some(e);
296 Err(bsql_driver_postgres::DriverError::Protocol(
297 "for_each closure error".into(),
298 ))
299 }
300 });
301 if let Some(e) = user_err {
302 return Err(e);
303 }
304 driver_result.map_err(BsqlError::from_driver_query)
305 }
306
307 #[doc(hidden)]
312 pub async fn __for_each_raw_bytes<F>(
313 &self,
314 sql: &str,
315 sql_hash: u64,
316 params: &[&(dyn Encode + Sync)],
317 mut f: F,
318 ) -> BsqlResult<()>
319 where
320 F: FnMut(&[u8]) -> BsqlResult<()>,
321 {
322 let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
323 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
324 let mut user_err: Option<BsqlError> = None;
325 let driver_result = tx.for_each_raw(sql, sql_hash, params, |data| match f(data) {
326 Ok(()) => Ok(()),
327 Err(e) => {
328 user_err = Some(e);
329 Err(bsql_driver_postgres::DriverError::Protocol(
330 "for_each closure error".into(),
331 ))
332 }
333 });
334 if let Some(e) = user_err {
335 return Err(e);
336 }
337 driver_result.map_err(BsqlError::from_driver_query)
338 }
339}
340
341impl fmt::Debug for Transaction {
342 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343 f.debug_struct("Transaction")
344 .field("finished", &self.finished)
345 .finish()
346 }
347}
348
349impl Drop for Transaction {
350 fn drop(&mut self) {
351 if !self.finished {
352 eprintln!(
357 "bsql: Transaction dropped without commit() or rollback() — \
358 connection discarded from pool. This is safe but wasteful."
359 );
360 }
361 }
362}
363
364fn validate_savepoint_name(name: &str) -> BsqlResult<()> {
366 crate::util::validate_savepoint_name(name)
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time probes: naming one with a type only type-checks if the bound holds.
    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn validate_savepoint_name_valid() {
        for name in ["sp1", "_sp", "my_savepoint_123"] {
            assert!(
                validate_savepoint_name(name).is_ok(),
                "{name:?} should be accepted"
            );
        }
    }

    #[test]
    fn validate_savepoint_name_empty() {
        let outcome = validate_savepoint_name("");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_savepoint_name_too_long() {
        // One character past the longest accepted length.
        let name: String = std::iter::repeat('a').take(64).collect();
        assert!(validate_savepoint_name(&name).is_err());
    }

    #[test]
    fn validate_savepoint_name_max_length() {
        let name: String = std::iter::repeat('a').take(63).collect();
        assert!(validate_savepoint_name(&name).is_ok());
    }

    #[test]
    fn validate_savepoint_name_starts_with_digit() {
        let outcome = validate_savepoint_name("1sp");
        assert!(outcome.is_err());
    }

    #[test]
    fn validate_savepoint_name_starts_with_underscore() {
        let outcome = validate_savepoint_name("_sp");
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_savepoint_name_special_chars() {
        for name in ["sp-1", "sp.1", "sp 1", "sp;1", "sp'1"] {
            assert!(
                validate_savepoint_name(name).is_err(),
                "{name:?} should be rejected"
            );
        }
    }

    #[test]
    fn isolation_level_display() {
        let expectations = [
            (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
            (IsolationLevel::ReadCommitted, "READ COMMITTED"),
            (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
            (IsolationLevel::Serializable, "SERIALIZABLE"),
        ];
        for (level, text) in expectations {
            assert_eq!(level.to_string(), text);
        }
    }

    #[test]
    fn isolation_level_clone() {
        // `IsolationLevel` is `Copy`, so the source stays usable after the assignment.
        let original = IsolationLevel::Serializable;
        let duplicate = original;
        assert_eq!(original, duplicate);
    }

    #[test]
    fn isolation_level_debug() {
        let dbg = format!("{:?}", IsolationLevel::RepeatableRead);
        let names_variant = dbg.contains("RepeatableRead");
        assert!(names_variant, "Debug should show variant name: {dbg}");
    }

    #[test]
    fn isolation_level_eq() {
        let serializable = IsolationLevel::Serializable;
        assert_eq!(serializable, IsolationLevel::Serializable);
        assert_ne!(serializable, IsolationLevel::ReadCommitted);
    }

    #[test]
    fn transaction_debug_shows_finished_false() {
        // Only proves that `Transaction: Debug`; no instance is built or formatted here.
        fn _assert_debug<T: std::fmt::Debug>() {}
        _assert_debug::<Transaction>();
    }

    #[test]
    fn transaction_is_send() {
        _assert_send::<Transaction>();
    }

    #[test]
    fn transaction_is_sync() {
        _assert_sync::<Transaction>();
    }

    #[test]
    fn isolation_level_is_send_and_sync() {
        _assert_send::<IsolationLevel>();
        _assert_sync::<IsolationLevel>();
    }

    #[test]
    fn isolation_level_as_sql_all_variants() {
        let expectations = [
            (IsolationLevel::ReadUncommitted, "READ UNCOMMITTED"),
            (IsolationLevel::ReadCommitted, "READ COMMITTED"),
            (IsolationLevel::RepeatableRead, "REPEATABLE READ"),
            (IsolationLevel::Serializable, "SERIALIZABLE"),
        ];
        for (level, sql) in expectations {
            assert_eq!(level.as_sql(), sql);
        }
    }

    #[test]
    fn validate_savepoint_name_single_char() {
        for name in ["a", "_"] {
            assert!(
                validate_savepoint_name(name).is_ok(),
                "{name:?} should be accepted"
            );
        }
    }

    #[test]
    fn validate_savepoint_name_all_digits_after_letter() {
        let outcome = validate_savepoint_name("a123456789");
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_savepoint_name_all_underscores() {
        let outcome = validate_savepoint_name("___");
        assert!(outcome.is_ok());
    }

    #[test]
    fn validate_savepoint_name_unicode_rejected() {
        let outcome = validate_savepoint_name("sp_\u{00e9}");
        assert!(outcome.is_err(), "unicode chars should be rejected");
    }

    #[test]
    fn validate_savepoint_name_sql_injection_rejected() {
        for name in ["sp; DROP TABLE", "sp'--", "sp\"test"] {
            assert!(
                validate_savepoint_name(name).is_err(),
                "{name:?} should be rejected"
            );
        }
    }
}