use crate::{
client::{Client, ClientBorrowMut},
error::Error,
execute::Execute,
};
use super::Transaction;
/// Transaction isolation level, rendered as the `ISOLATION LEVEL` clause of the
/// `START TRANSACTION` statement built by `TransactionBuilder`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum IsolationLevel {
    /// Renders as `READ UNCOMMITTED`.
    ReadUncommitted,
    /// Renders as `READ COMMITTED`.
    ReadCommitted,
    /// Renders as `REPEATABLE READ`.
    RepeatableRead,
    /// Renders as `SERIALIZABLE`.
    Serializable,
}
impl IsolationLevel {
const PREFIX: &str = " ISOLATION LEVEL ";
const READ_UNCOMMITTED: &str = "READ UNCOMMITTED,";
const READ_COMMITTED: &str = "READ COMMITTED,";
const REPEATABLE_READ: &str = "REPEATABLE READ,";
const SERIALIZABLE: &str = "SERIALIZABLE,";
fn write(self, str: &mut StackStr) {
str.push_str(Self::PREFIX);
str.push_str(match self {
IsolationLevel::ReadUncommitted => Self::READ_UNCOMMITTED,
IsolationLevel::ReadCommitted => Self::READ_COMMITTED,
IsolationLevel::RepeatableRead => Self::REPEATABLE_READ,
IsolationLevel::Serializable => Self::SERIALIZABLE,
});
}
}
/// Builder for a `START TRANSACTION` statement with optional transaction modes.
///
/// Any mode that is left unset is omitted from the generated statement.
/// `TransactionBuilder::default()` produces the same all-unset state as
/// `TransactionBuilder::new()`.
#[derive(Debug, Clone, Default)]
pub struct TransactionBuilder {
    /// `ISOLATION LEVEL` clause, emitted only when `Some`.
    isolation_level: Option<IsolationLevel>,
    /// `READ ONLY` (`true`) or `READ WRITE` (`false`) clause, emitted only when `Some`.
    read_only: Option<bool>,
    /// `DEFERRABLE` (`true`) or `NOT DEFERRABLE` (`false`) clause, emitted only when `Some`.
    deferrable: Option<bool>,
}
impl TransactionBuilder {
pub const fn new() -> Self {
Self {
isolation_level: None,
read_only: None,
deferrable: None,
}
}
pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self {
self.isolation_level = Some(isolation_level);
self
}
pub fn read_only(mut self, read_only: bool) -> Self {
self.read_only = Some(read_only);
self
}
pub fn deferrable(mut self, deferrable: bool) -> Self {
self.deferrable = Some(deferrable);
self
}
pub async fn begin<C>(self, cli: &mut C) -> Result<Transaction<'_, C>, Error>
where
C: ClientBorrowMut,
{
self._begin(cli.borrow_cli_mut()).await.map(|_| Transaction::new(cli))
}
pub async fn begin_owned<'a, C>(self, mut cli: C) -> Result<Transaction<'a, C>, Error>
where
C: ClientBorrowMut + 'a,
{
self._begin(cli.borrow_cli_mut())
.await
.map(|_| Transaction::new_owned(cli))
}
async fn _begin(self, cli: &Client) -> Result<u64, Error> {
let mut query = const { StackStr::new("START TRANSACTION") };
let Self {
isolation_level,
read_only,
deferrable,
} = self;
if let Some(isolation_level) = isolation_level {
isolation_level.write(&mut query);
}
if let Some(read_only) = read_only {
let s = if read_only { " READ ONLY," } else { " READ WRITE," };
query.push_str(s);
}
if let Some(deferrable) = deferrable {
let s = if deferrable { " DEFERRABLE" } else { " NOT DEFERRABLE" };
query.push_str(s);
}
query.pop_if_ends_with(",");
query.as_str().execute(cli).await
}
}
/// Capacity in bytes of [`StackStr`]'s inline buffer.
const STACK_STR_CAP: usize = 120;

/// Fixed-capacity string buffer stored inline (no heap allocation), used to
/// assemble short statements.
///
/// Only whole `&str` values are appended and only whole `&str` suffixes are removed.
/// The filled prefix `buf[..cursor]` is therefore always valid UTF-8.
struct StackStr {
    buf: [u8; STACK_STR_CAP],
    cursor: usize,
}

impl StackStr {
    /// Creates a buffer holding a copy of `str`; callable in `const` context.
    ///
    /// # Panics
    ///
    /// Panics if `str` is longer than `STACK_STR_CAP` bytes. In `const` context this
    /// becomes a compile-time error.
    const fn new(str: &str) -> Self {
        let bytes = str.as_bytes();
        assert!(bytes.len() <= STACK_STR_CAP, "StackStr capacity exceeded");
        let mut buf = [0; STACK_STR_CAP];
        let mut cursor = 0;
        // Byte-by-byte copy: a `while` loop is usable inside a `const fn`.
        while cursor < bytes.len() {
            buf[cursor] = bytes[cursor];
            cursor += 1;
        }
        Self { buf, cursor }
    }

    /// Appends `str` to the end of the buffer.
    ///
    /// # Panics
    ///
    /// Panics if the result would exceed `STACK_STR_CAP` bytes.
    fn push_str(&mut self, str: &str) {
        let end = self.cursor + str.len();
        assert!(end <= STACK_STR_CAP, "StackStr capacity exceeded");
        self.buf[self.cursor..end].copy_from_slice(str.as_bytes());
        // Advance only after the copy succeeded, so the cursor never points past written data.
        self.cursor = end;
    }

    /// Returns the filled part of the buffer.
    fn as_str(&self) -> &str {
        core::str::from_utf8(&self.buf[..self.cursor])
            .expect("StackStr only ever holds whole `&str` fragments")
    }

    /// Removes `needle` from the end of the buffer if the buffer ends with it;
    /// otherwise leaves the buffer unchanged.
    fn pop_if_ends_with(&mut self, needle: &str) {
        let needle = needle.as_bytes();
        if self.buf[..self.cursor].ends_with(needle) {
            self.cursor -= needle.len();
        }
    }
}
#[cfg(test)]
mod test {
    // Both tests connect to a live PostgreSQL server at
    // postgres://postgres:postgres@localhost:5432 and fail (via `unwrap`) if the
    // connection cannot be established.
    use std::sync::Arc;

    use crate::{
        Client, Postgres,
        client::{ClientBorrow, ClientBorrowMut},
    };

    use super::{IsolationLevel, TransactionBuilder};

    // Checks that `begin` and `begin_owned` obtain the client through
    // `ClientBorrowMut::borrow_cli_mut`, using a wrapper whose mutable borrow
    // panics when the inner `Arc` is shared.
    #[tokio::test]
    async fn client_borrow_mut() {
        // Client behind an `Arc`. `borrow_cli_mut` relies on `Arc::get_mut`, which
        // yields `None` (and so panics on `unwrap`) while another clone is alive.
        #[derive(Clone)]
        struct PanicCli(Arc<Client>);

        impl PanicCli {
            fn new(cli: Client) -> Self {
                Self(Arc::new(cli))
            }
        }

        impl ClientBorrow for PanicCli {
            fn borrow_cli_ref(&self) -> &Client {
                &self.0
            }
        }

        impl ClientBorrowMut for PanicCli {
            fn borrow_cli_mut(&mut self) -> &mut Client {
                Arc::get_mut(&mut self.0).unwrap()
            }
        }

        let (cli, drv) = Postgres::new("postgres://postgres:postgres@localhost:5432")
            .connect()
            .await
            .unwrap();

        // The driver returned by `connect` runs as its own task.
        // NOTE(review): presumably it must be polled for the client to make progress — confirm.
        tokio::spawn(drv.into_future());

        let mut cli = PanicCli::new(cli);

        // Sole owner of the `Arc`: the mutable borrow inside `begin` succeeds.
        {
            let _tx = TransactionBuilder::new().begin(&mut cli).await.unwrap();
        }

        // `_cli2` keeps a second `Arc` reference alive, so the `borrow_cli_mut` call made
        // by `begin_owned` panics. The spawned task's `JoinError` must report a panic.
        let res = tokio::spawn(async move {
            let _cli2 = cli.clone();
            let _tx = TransactionBuilder::new().begin_owned(cli).await.unwrap();
        })
        .await
        .err()
        .unwrap();

        assert!(res.is_panic());
    }

    // Sets every transaction mode. This executes
    // `START TRANSACTION ISOLATION LEVEL READ UNCOMMITTED, READ WRITE, NOT DEFERRABLE`.
    #[tokio::test]
    async fn transaction_builder() {
        let (mut cli, drv) = Postgres::new("postgres://postgres:postgres@localhost:5432")
            .connect()
            .await
            .unwrap();

        tokio::spawn(drv.into_future());

        let _ = TransactionBuilder::new()
            .isolation_level(IsolationLevel::ReadUncommitted)
            .read_only(false)
            .deferrable(false)
            .begin(&mut cli)
            .await
            .unwrap();
    }
}