use crate::assert_send_sync;
use crate::transaction::DropBehavior;
use crate::transaction::TransactionBehavior;
use crate::Error;
use crate::IntoParams;
use crate::Row;
use crate::Rows;
use crate::Statement;
use std::fmt::Debug;
use std::sync::atomic::AtomicU8;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::Waker;
pub type Result<T> = std::result::Result<T, Error>;
/// A [`DropBehavior`] kept in an [`AtomicU8`], so it can be read and replaced
/// through a shared reference (`&self`).
///
/// Values are converted to and from their `u8` form with `Into`.
/// NOTE(review): assumes that round-trip is lossless for every `DropBehavior`
/// variant — confirm against the conversion impls on `DropBehavior`.
pub(crate) struct AtomicDropBehavior {
    // Raw `u8` encoding of the currently stored behavior.
    inner: AtomicU8,
}

impl AtomicDropBehavior {
    /// Creates a cell initially holding `behavior`.
    fn new(behavior: DropBehavior) -> Self {
        let raw: u8 = behavior.into();
        Self {
            inner: AtomicU8::new(raw),
        }
    }

    /// Reads the stored behavior using the given memory `ordering`.
    fn load(&self, ordering: Ordering) -> DropBehavior {
        let raw = self.inner.load(ordering);
        raw.into()
    }

    /// Replaces the stored behavior using the given memory `ordering`.
    pub(crate) fn store(&self, behavior: DropBehavior, ordering: Ordering) {
        let raw: u8 = behavior.into();
        self.inner.store(raw, ordering);
    }
}
/// Async handle to a Turso database connection.
///
/// Holds the shared low-level `TursoConnection` together with per-handle
/// transaction settings. `Clone` shares `inner` and `extra_io` (both `Arc`s)
/// but copies `dangling_tx` into a fresh, independent atomic.
pub struct Connection {
/// Shared low-level connection. `create` always sets this to `Some`; when it is
/// `None`, operations that need it fail with `Error::Misuse`.
inner: Option<Arc<turso_sdk_kit::rsapi::TursoConnection>>,
/// Transaction behavior for this handle; `create` sets
/// `TransactionBehavior::Deferred`. NOTE(review): not read in this file —
/// presumably consumed by the transaction module.
pub(crate) transaction_behavior: TransactionBehavior,
/// Pending action for an unfinished transaction. It is applied by
/// `maybe_handle_dangling_tx`, which runs before `query`, `execute` and
/// `execute_batch`. After a successful COMMIT/ROLLBACK it is reset to
/// `DropBehavior::Ignore`. NOTE(review): presumably written when a transaction
/// is dropped without being finished — the writer is not in this file.
pub(crate) dangling_tx: AtomicDropBehavior,
/// Optional callback taking a `Waker`, stored at `create` time.
/// NOTE(review): not invoked in this file; presumably used by `Statement` to
/// drive additional I/O.
pub(crate) extra_io: Option<Arc<dyn Fn(Waker) -> Result<()> + Send + Sync>>,
}
// Compile-time check (macro defined elsewhere in the crate) that `Connection`
// is `Send + Sync`. NOTE(review): meaning inferred from the macro name.
assert_send_sync!(Connection);
impl Clone for Connection {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
transaction_behavior: self.transaction_behavior,
dangling_tx: AtomicDropBehavior::new(self.dangling_tx.load(Ordering::SeqCst)),
extra_io: self.extra_io.clone(),
}
}
}
impl Connection {
pub fn create(
conn: Arc<turso_sdk_kit::rsapi::TursoConnection>,
extra_io: Option<Arc<dyn Fn(Waker) -> Result<()> + Send + Sync>>,
) -> Self {
#[allow(clippy::arc_with_non_send_sync)]
let connection = Connection {
inner: Some(conn),
transaction_behavior: TransactionBehavior::Deferred,
dangling_tx: AtomicDropBehavior::new(DropBehavior::Ignore),
extra_io,
};
connection
}
pub(crate) async fn maybe_handle_dangling_tx(&self) -> Result<()> {
match self.dangling_tx.load(Ordering::SeqCst) {
DropBehavior::Rollback => {
let mut stmt = self.prepare("ROLLBACK").await?;
stmt.execute(()).await?;
self.dangling_tx
.store(DropBehavior::Ignore, Ordering::SeqCst);
}
DropBehavior::Commit => {
let mut stmt = self.prepare("COMMIT").await?;
stmt.execute(()).await?;
self.dangling_tx
.store(DropBehavior::Ignore, Ordering::SeqCst);
}
DropBehavior::Ignore => {}
DropBehavior::Panic => {
panic!("Transaction dropped unexpectedly.");
}
}
Ok(())
}
pub async fn query(&self, sql: impl AsRef<str>, params: impl IntoParams) -> Result<Rows> {
self.maybe_handle_dangling_tx().await?;
let mut stmt = self.prepare(sql).await?;
stmt.query(params).await
}
pub async fn execute(&self, sql: impl AsRef<str>, params: impl IntoParams) -> Result<u64> {
self.maybe_handle_dangling_tx().await?;
let mut stmt = self.prepare(sql).await?;
stmt.execute(params).await
}
fn get_inner_connection(&self) -> Result<Arc<turso_sdk_kit::rsapi::TursoConnection>> {
match &self.inner {
Some(inner) => Ok(inner.clone()),
None => Err(Error::Misuse("inner connection must be set".to_string())),
}
}
pub async fn execute_batch(&self, sql: impl AsRef<str>) -> Result<()> {
self.maybe_handle_dangling_tx().await?;
self.prepare_execute_batch(sql).await?;
Ok(())
}
pub async fn prepare(&self, sql: impl AsRef<str>) -> Result<Statement> {
let conn = self.get_inner_connection()?;
let stmt = conn.prepare_single(sql)?;
#[allow(clippy::arc_with_non_send_sync)]
let statement = Statement {
conn: self.clone(),
inner: Arc::new(Mutex::new(stmt)),
};
Ok(statement)
}
pub async fn prepare_cached(&self, sql: impl AsRef<str>) -> Result<Statement> {
let conn = self.get_inner_connection()?;
let stmt = conn.prepare_cached(sql)?;
#[allow(clippy::arc_with_non_send_sync)]
let statement = Statement {
conn: self.clone(),
inner: Arc::new(Mutex::new(stmt)),
};
Ok(statement)
}
async fn prepare_execute_batch(&self, sql: impl AsRef<str>) -> Result<()> {
self.maybe_handle_dangling_tx().await?;
let conn = self.get_inner_connection()?;
let mut sql = sql.as_ref();
while let Some((stmt, offset)) = conn.prepare_first(sql)? {
let mut stmt = Statement {
conn: self.clone(),
inner: Arc::new(Mutex::new(stmt)),
};
let _ = stmt.execute(()).await?;
sql = &sql[offset..];
}
Ok(())
}
pub async fn pragma_query<F>(&self, pragma_name: &str, mut f: F) -> Result<()>
where
F: FnMut(&Row) -> std::result::Result<(), turso_sdk_kit::rsapi::TursoError>,
{
let sql = format!("PRAGMA {pragma_name}");
let mut stmt = self.prepare(&sql).await?;
let mut rows = stmt.query(()).await?;
while let Some(row) = rows.next().await? {
f(&row)?;
}
Ok(())
}
pub async fn pragma_update<V: std::fmt::Display>(
&self,
pragma_name: &str,
pragma_value: V,
) -> Result<Vec<Row>> {
let sql = format!("PRAGMA {pragma_name} = {pragma_value}");
let mut stmt = self.prepare(&sql).await?;
let mut rows = stmt.query(()).await?;
let mut collected = Vec::new();
while let Some(row) = rows.next().await? {
collected.push(row);
}
Ok(collected)
}
pub fn last_insert_rowid(&self) -> i64 {
let conn = self.get_inner_connection().unwrap();
conn.last_insert_rowid()
}
pub fn cacheflush(&self) -> Result<()> {
let conn = self.get_inner_connection()?;
conn.cacheflush()?;
Ok(())
}
pub fn is_autocommit(&self) -> Result<bool> {
let conn = self.get_inner_connection()?;
Ok(conn.get_auto_commit())
}
pub fn busy_timeout(&self, duration: std::time::Duration) -> Result<()> {
let conn = self.get_inner_connection()?;
conn.set_busy_timeout(duration);
Ok(())
}
}
impl Debug for Connection {
    // Deliberately opaque: only the type name is printed, never any fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Connection")
    }
}