1use std::fmt;
15
16use bsql_driver_postgres::arena::acquire_arena;
17use bsql_driver_postgres::codec::Encode;
18use tokio::sync::Mutex;
19
20use crate::error::{BsqlError, BsqlResult, QueryError};
21use crate::executor::OwnedResult;
22
/// Transaction isolation level, rendered as the SQL keyword sequence used by
/// `SET TRANSACTION ISOLATION LEVEL ...`.
///
/// Note: PostgreSQL accepts all four standard levels but runs
/// `ReadUncommitted` with `ReadCommitted` semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IsolationLevel {
    /// `READ UNCOMMITTED` (behaves as `READ COMMITTED` in PostgreSQL).
    ReadUncommitted,
    /// `READ COMMITTED` (PostgreSQL's default level).
    ReadCommitted,
    /// `REPEATABLE READ`.
    RepeatableRead,
    /// `SERIALIZABLE`.
    Serializable,
}
31
32impl IsolationLevel {
33 fn as_sql(&self) -> &'static str {
35 match self {
36 IsolationLevel::ReadUncommitted => "READ UNCOMMITTED",
37 IsolationLevel::ReadCommitted => "READ COMMITTED",
38 IsolationLevel::RepeatableRead => "REPEATABLE READ",
39 IsolationLevel::Serializable => "SERIALIZABLE",
40 }
41 }
42}
43
44impl fmt::Display for IsolationLevel {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 f.write_str(self.as_sql())
47 }
48}
49
50pub struct Transaction {
80 inner: Mutex<Option<bsql_driver_postgres::Transaction>>,
81 finished: bool,
83}
84
85impl Transaction {
86 pub(crate) fn from_driver(tx: bsql_driver_postgres::Transaction) -> Self {
88 Self {
89 inner: Mutex::new(Some(tx)),
90 finished: false,
91 }
92 }
93
94 fn consumed_error() -> BsqlError {
96 BsqlError::Query(QueryError {
97 message: "transaction already consumed".into(),
98 pg_code: None,
99 source: None,
100 })
101 }
102
103 pub async fn commit(mut self) -> BsqlResult<()> {
107 self.finished = true;
108 let tx = self
109 .inner
110 .lock()
111 .await
112 .take()
113 .ok_or_else(Self::consumed_error)?;
114 tx.commit().await.map_err(BsqlError::from)
115 }
116
117 pub async fn rollback(mut self) -> BsqlResult<()> {
121 self.finished = true;
122 let tx = self
123 .inner
124 .lock()
125 .await
126 .take()
127 .ok_or_else(Self::consumed_error)?;
128 tx.rollback().await.map_err(BsqlError::from)
129 }
130
131 pub async fn savepoint(&self, name: &str) -> BsqlResult<()> {
136 validate_savepoint_name(name)?;
137 let sql = format!("SAVEPOINT {name}");
138 let mut guard = self.inner.lock().await;
139 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
140 tx.simple_query(&sql)
141 .await
142 .map_err(BsqlError::from_driver_query)
143 }
144
145 pub async fn release_savepoint(&self, name: &str) -> BsqlResult<()> {
149 validate_savepoint_name(name)?;
150 let sql = format!("RELEASE SAVEPOINT {name}");
151 let mut guard = self.inner.lock().await;
152 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
153 tx.simple_query(&sql)
154 .await
155 .map_err(BsqlError::from_driver_query)
156 }
157
158 pub async fn rollback_to(&self, name: &str) -> BsqlResult<()> {
162 validate_savepoint_name(name)?;
163 let sql = format!("ROLLBACK TO SAVEPOINT {name}");
164 let mut guard = self.inner.lock().await;
165 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
166 tx.simple_query(&sql)
167 .await
168 .map_err(BsqlError::from_driver_query)
169 }
170
171 pub async fn set_isolation(&self, level: IsolationLevel) -> BsqlResult<()> {
177 let sql = format!("SET TRANSACTION ISOLATION LEVEL {}", level.as_sql());
178 let mut guard = self.inner.lock().await;
179 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
180 tx.simple_query(&sql)
181 .await
182 .map_err(BsqlError::from_driver_query)
183 }
184
185 pub(crate) async fn query_inner(
187 &self,
188 sql: &str,
189 sql_hash: u64,
190 params: &[&(dyn Encode + Sync)],
191 ) -> BsqlResult<OwnedResult> {
192 let mut guard = self.inner.lock().await;
193 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
194 let mut arena = acquire_arena();
195 let result = tx
196 .query(sql, sql_hash, params, &mut arena)
197 .await
198 .map_err(BsqlError::from_driver_query)?;
199 Ok(OwnedResult::new(result, arena))
200 }
201
202 pub(crate) async fn execute_inner(
204 &self,
205 sql: &str,
206 sql_hash: u64,
207 params: &[&(dyn Encode + Sync)],
208 ) -> BsqlResult<u64> {
209 let mut guard = self.inner.lock().await;
210 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
211 tx.execute(sql, sql_hash, params)
212 .await
213 .map_err(BsqlError::from_driver_query)
214 }
215
216 pub async fn execute_pipeline(
222 &self,
223 sql: &str,
224 sql_hash: u64,
225 param_sets: &[&[&(dyn Encode + Sync)]],
226 ) -> BsqlResult<Vec<u64>> {
227 let mut guard = self.inner.lock().await;
228 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
229 tx.execute_pipeline(sql, sql_hash, param_sets)
230 .await
231 .map_err(BsqlError::from_driver_query)
232 }
233
234 #[doc(hidden)]
251 pub async fn defer_execute(
252 &self,
253 sql: &str,
254 sql_hash: u64,
255 params: &[&(dyn Encode + Sync)],
256 ) -> BsqlResult<()> {
257 let mut guard = self.inner.lock().await;
258 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
259 tx.defer_execute(sql, sql_hash, params)
260 .await
261 .map_err(BsqlError::from_driver_query)
262 }
263
264 #[doc(hidden)]
269 pub async fn flush_deferred(&self) -> BsqlResult<Vec<u64>> {
270 let mut guard = self.inner.lock().await;
271 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
272 tx.flush_deferred()
273 .await
274 .map_err(BsqlError::from_driver_query)
275 }
276
277 #[doc(hidden)]
283 pub async fn deferred_count(&self) -> usize {
284 let guard = self.inner.lock().await;
285 match guard.as_ref() {
286 Some(tx) => tx.deferred_count(),
287 None => 0,
288 }
289 }
290
291 pub async fn for_each_raw<F>(
296 &self,
297 sql: &str,
298 sql_hash: u64,
299 params: &[&(dyn Encode + Sync)],
300 mut f: F,
301 ) -> BsqlResult<()>
302 where
303 F: FnMut(bsql_driver_postgres::PgDataRow<'_>) -> BsqlResult<()>,
304 {
305 let mut guard = self.inner.lock().await;
306 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
307 let mut user_err: Option<BsqlError> = None;
308 let driver_result = tx
309 .for_each(sql, sql_hash, params, |row| match f(row) {
310 Ok(()) => Ok(()),
311 Err(e) => {
312 user_err = Some(e);
313 Err(bsql_driver_postgres::DriverError::Protocol(
314 "for_each closure error".into(),
315 ))
316 }
317 })
318 .await;
319 if let Some(e) = user_err {
320 return Err(e);
321 }
322 driver_result.map_err(BsqlError::from_driver_query)
323 }
324
325 #[doc(hidden)]
330 pub async fn __for_each_raw_bytes<F>(
331 &self,
332 sql: &str,
333 sql_hash: u64,
334 params: &[&(dyn Encode + Sync)],
335 mut f: F,
336 ) -> BsqlResult<()>
337 where
338 F: FnMut(&[u8]) -> BsqlResult<()>,
339 {
340 let mut guard = self.inner.lock().await;
341 let tx = guard.as_mut().ok_or_else(Self::consumed_error)?;
342 let mut user_err: Option<BsqlError> = None;
343 let driver_result = tx
344 .for_each_raw(sql, sql_hash, params, |data| match f(data) {
345 Ok(()) => Ok(()),
346 Err(e) => {
347 user_err = Some(e);
348 Err(bsql_driver_postgres::DriverError::Protocol(
349 "for_each closure error".into(),
350 ))
351 }
352 })
353 .await;
354 if let Some(e) = user_err {
355 return Err(e);
356 }
357 driver_result.map_err(BsqlError::from_driver_query)
358 }
359}
360
361impl fmt::Debug for Transaction {
362 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
363 f.debug_struct("Transaction")
364 .field("finished", &self.finished)
365 .finish()
366 }
367}
368
369impl Drop for Transaction {
370 fn drop(&mut self) {
371 if !self.finished {
372 eprintln!(
377 "bsql: Transaction dropped without commit() or rollback() — \
378 connection discarded from pool. This is safe but wasteful."
379 );
380 }
381 }
382}
383
384fn validate_savepoint_name(name: &str) -> BsqlResult<()> {
386 crate::util::validate_savepoint_name(name)
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn validate_savepoint_name_valid() {
        assert!(validate_savepoint_name("sp1").is_ok());
        assert!(validate_savepoint_name("_sp").is_ok());
        assert!(validate_savepoint_name("my_savepoint_123").is_ok());
    }

    #[test]
    fn validate_savepoint_name_empty() {
        assert!(validate_savepoint_name("").is_err());
    }

    #[test]
    fn validate_savepoint_name_too_long() {
        let long = "a".repeat(64);
        assert!(validate_savepoint_name(&long).is_err());
    }

    #[test]
    fn validate_savepoint_name_max_length() {
        let max = "a".repeat(63);
        assert!(validate_savepoint_name(&max).is_ok());
    }

    #[test]
    fn validate_savepoint_name_starts_with_digit() {
        assert!(validate_savepoint_name("1sp").is_err());
    }

    #[test]
    fn validate_savepoint_name_starts_with_underscore() {
        assert!(validate_savepoint_name("_sp").is_ok());
    }

    #[test]
    fn validate_savepoint_name_special_chars() {
        assert!(validate_savepoint_name("sp-1").is_err());
        assert!(validate_savepoint_name("sp.1").is_err());
        assert!(validate_savepoint_name("sp 1").is_err());
        assert!(validate_savepoint_name("sp;1").is_err());
        assert!(validate_savepoint_name("sp'1").is_err());
    }

    #[test]
    fn isolation_level_display() {
        assert_eq!(
            IsolationLevel::ReadUncommitted.to_string(),
            "READ UNCOMMITTED"
        );
        assert_eq!(IsolationLevel::ReadCommitted.to_string(), "READ COMMITTED");
        assert_eq!(
            IsolationLevel::RepeatableRead.to_string(),
            "REPEATABLE READ"
        );
        assert_eq!(IsolationLevel::Serializable.to_string(), "SERIALIZABLE");
    }

    /// `Display` must render exactly the SQL used by `set_isolation`.
    #[test]
    fn isolation_level_display_matches_as_sql() {
        for level in [
            IsolationLevel::ReadUncommitted,
            IsolationLevel::ReadCommitted,
            IsolationLevel::RepeatableRead,
            IsolationLevel::Serializable,
        ] {
            assert_eq!(level.to_string(), level.as_sql());
        }
    }

    /// Exercises both `Copy` (implicit duplicate on move) and `Clone`.
    #[test]
    fn isolation_level_copy_and_clone() {
        // Generic helper so the explicit clone does not trip clippy's
        // `clone_on_copy` lint.
        fn clone_of<T: Clone>(value: &T) -> T {
            value.clone()
        }
        let level = IsolationLevel::Serializable;
        let copied = level; // `level` remains usable because the type is `Copy`.
        assert_eq!(level, copied);
        assert_eq!(clone_of(&level), level);
    }

    #[test]
    fn isolation_level_debug() {
        let level = IsolationLevel::RepeatableRead;
        let dbg = format!("{level:?}");
        assert!(
            dbg.contains("RepeatableRead"),
            "Debug should show variant name: {dbg}"
        );
    }

    #[test]
    fn isolation_level_eq() {
        assert_eq!(IsolationLevel::Serializable, IsolationLevel::Serializable);
        assert_ne!(IsolationLevel::Serializable, IsolationLevel::ReadCommitted);
    }

    /// Compile-time check only: `Transaction` cannot be built without a live
    /// driver connection, so its `Debug` output is not inspected here.
    #[test]
    fn transaction_implements_debug() {
        fn assert_debug<T: std::fmt::Debug>() {}
        assert_debug::<Transaction>();
    }

    fn _assert_send<T: Send>() {}
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn transaction_is_send() {
        _assert_send::<Transaction>();
    }

    #[test]
    fn transaction_is_sync() {
        _assert_sync::<Transaction>();
    }

    #[test]
    fn isolation_level_is_send_and_sync() {
        _assert_send::<IsolationLevel>();
        _assert_sync::<IsolationLevel>();
    }
}