use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex as StdMutex};
use async_trait::async_trait;
use limbo::params::Params as LimboParams;
use limbo::Builder;
use tokio::sync::Mutex as TokioMutex;
/// Maximum number of prepared statements kept in the LRU statement cache built by
/// `new_stmt_cache`.
const STMT_CACHE_CAPACITY: usize = 128;
use oxisql_core::{
ColumnInfo, Connection, ForeignKeyInfo, IndexInfo, OxiSqlError, PreparedStatement, Row,
TableInfo, TableType, ToSqlValue, Transaction, Value,
};
use crate::error::SqliteCompatError;
use crate::types::{limbo_to_core_typed, rewrite_params, split_statements};
/// Shared, mutex-protected LRU cache of prepared statements keyed by (rewritten) SQL text.
type StmtCache = Arc<StdMutex<lru::LruCache<String, limbo::Statement>>>;

/// Builds an empty statement cache holding at most `STMT_CACHE_CAPACITY` entries.
///
/// `LruCache::new` takes a `NonZeroUsize`; should the constant ever be zero, a capacity
/// of one is used instead.
fn new_stmt_cache() -> StmtCache {
    let capacity = match NonZeroUsize::new(STMT_CACHE_CAPACITY) {
        Some(non_zero) => non_zero,
        None => NonZeroUsize::MIN,
    };
    let lru = lru::LruCache::new(capacity);
    Arc::new(StdMutex::new(lru))
}
/// Runs an already-rewritten statement on `conn` and returns `conn.changes()`, clamped to
/// be non-negative.
///
/// With a `cache`, the prepared statement for `sql` is taken out of the LRU cache (or
/// prepared on a miss), executed, and then put back. Because it is removed while in use,
/// two callers sharing the cache never drive the same `limbo::Statement` at once. If
/// execution fails, the statement is dropped instead of being returned to the cache.
/// Without a cache, the SQL goes through `limbo::Connection::execute`.
///
/// NOTE(review): the count comes from `changes()`, not from the statement's own result.
/// If limbo mirrors SQLite here, statements other than INSERT/UPDATE/DELETE leave the
/// previous count in place. Confirm this before relying on the value for DDL.
///
/// # Errors
///
/// Returns `SqliteCompatError::Other` when the cache mutex is poisoned or `changes()`
/// fails. Prepare and execute errors from limbo are propagated.
async fn exec_rewritten(
    conn: &limbo::Connection,
    sql: &str,
    limbo_params: Vec<limbo::Value>,
    cache: Option<&StmtCache>,
) -> Result<u64, SqliteCompatError> {
    fn poisoned<T>(e: std::sync::PoisonError<T>) -> SqliteCompatError {
        SqliteCompatError::Other(format!("stmt_cache lock poisoned: {e}"))
    }

    let bound = if limbo_params.is_empty() {
        LimboParams::None
    } else {
        LimboParams::Positional(limbo_params)
    };

    if let Some(cache) = cache {
        // The guard is a temporary of this statement, so it is released before any `.await`.
        let reused = cache.lock().map_err(poisoned)?.pop(sql);
        let mut stmt = if let Some(stmt) = reused {
            stmt
        } else {
            conn.prepare(sql).await.map_err(SqliteCompatError::from)?
        };
        stmt.execute(bound).await.map_err(SqliteCompatError::from)?;
        cache.lock().map_err(poisoned)?.put(sql.to_owned(), stmt);
    } else {
        conn.execute(sql, bound)
            .await
            .map_err(SqliteCompatError::from)?;
    }

    let changed = conn
        .changes()
        .map_err(|e| SqliteCompatError::Other(format!("changes() failed: {e}")))?;
    Ok(changed.max(0) as u64)
}
/// Prepares `sql`, binds `limbo_params` positionally (or binds nothing when the list is
/// empty), and collects every result row.
///
/// Column names and declared types are read from the prepared statement once, up front.
/// Each value is converted with `limbo_to_core_typed`, passing the column's declared
/// type when the statement reports one. This path does not use the statement cache.
///
/// # Errors
///
/// Propagates prepare, query, row-fetch and value-conversion errors.
async fn query_rewritten(
    conn: &limbo::Connection,
    sql: &str,
    limbo_params: Vec<limbo::Value>,
) -> Result<Vec<Row>, SqliteCompatError> {
    let bound = if limbo_params.is_empty() {
        LimboParams::None
    } else {
        LimboParams::Positional(limbo_params)
    };

    let mut stmt = conn.prepare(sql).await.map_err(SqliteCompatError::from)?;

    let mut col_names: Vec<String> = Vec::new();
    let mut decl_types: Vec<Option<String>> = Vec::new();
    for column in stmt.columns().iter() {
        col_names.push(column.name().to_owned());
        decl_types.push(column.decl_type().map(str::to_owned));
    }

    let mut rows_iter = stmt.query(bound).await.map_err(SqliteCompatError::from)?;
    let mut collected: Vec<Row> = Vec::new();
    loop {
        let Some(limbo_row) = rows_iter.next().await.map_err(SqliteCompatError::from)? else {
            break;
        };
        let values = (0..limbo_row.column_count())
            .map(|idx| -> Result<Value, SqliteCompatError> {
                let raw = limbo_row.get_value(idx).map_err(SqliteCompatError::from)?;
                let decl = decl_types.get(idx).and_then(Option::as_deref);
                limbo_to_core_typed(raw, decl).map_err(SqliteCompatError::from)
            })
            .collect::<Result<Vec<Value>, SqliteCompatError>>()?;
        collected.push(Row::new(col_names.clone(), values));
    }
    Ok(collected)
}
/// SQLite-compatible `Connection` implementation backed by limbo.
///
/// `Clone` is derived. Clones share `txn_lock` and `stmt_cache` through their `Arc`s and
/// clone the `limbo::Connection` handle. That handle presumably refers to the same
/// underlying connection, since `transaction()` relies on a cloned handle seeing its
/// `BEGIN`; confirm this against limbo.
#[derive(Clone)]
pub struct SqliteConnection {
/// Underlying limbo connection that every statement runs on.
conn: limbo::Connection,
/// Held by a `SqliteTransaction` for as long as it lives, so only one transaction
/// started through this connection (or its clones) is open at a time.
txn_lock: Arc<TokioMutex<()>>,
/// LRU cache of prepared statements keyed by rewritten SQL, used by `exec_rewritten`.
stmt_cache: StmtCache,
/// Path given to `open`; returned by `path()` and shown in `Debug` output.
path: String,
}
impl std::fmt::Debug for SqliteConnection {
    /// Formats the connection as its path plus the current statement-cache size.
    ///
    /// The connection handle and transaction lock are omitted (`finish_non_exhaustive`).
    /// A poisoned cache mutex is reported as a length of 0 instead of failing to format.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cached_statements = match self.stmt_cache.lock() {
            Ok(cache) => cache.len(),
            Err(_) => 0,
        };
        let mut out = f.debug_struct("SqliteConnection");
        out.field("path", &self.path);
        out.field("stmt_cache_len", &cached_statements);
        out.finish_non_exhaustive()
    }
}
impl SqliteConnection {
pub async fn open(path: &str) -> Result<Self, OxiSqlError> {
let db = Builder::new_local(path)
.build()
.await
.map_err(|e| OxiSqlError::Other(format!("limbo open error: {e}")))?;
let conn = db
.connect()
.map_err(|e| OxiSqlError::Other(format!("limbo connect error: {e}")))?;
Ok(Self {
conn,
txn_lock: Arc::new(TokioMutex::new(())),
stmt_cache: new_stmt_cache(),
path: path.to_owned(),
})
}
pub async fn open_memory() -> Result<Self, OxiSqlError> {
Self::open(":memory:").await
}
pub fn path(&self) -> &str {
&self.path
}
}
/// Wraps `name` in double quotes as an SQLite identifier, doubling any embedded `"`.
///
/// PRAGMA arguments cannot be bound as parameters, so table names passed to
/// `PRAGMA table_info` / `PRAGMA foreign_key_list` are spliced in through this helper.
fn quote_identifier(name: &str) -> String {
    format!("\"{}\"", name.replace('"', "\"\""))
}

/// Reports whether an index definition is `CREATE UNIQUE INDEX ...`.
///
/// Only the keyword directly after `CREATE` is inspected. Index, table or column names
/// that merely contain the text "unique" are therefore not misreported as unique.
fn is_unique_index_sql(idx_sql: &str) -> bool {
    let mut words = idx_sql.split_whitespace();
    words.next().is_some_and(|w| w.eq_ignore_ascii_case("CREATE"))
        && words.next().is_some_and(|w| w.eq_ignore_ascii_case("UNIQUE"))
}

/// Extracts the indexed-column list from an index's `CREATE ... ON table(...)` text.
///
/// The list is the first parenthesised group. Its matching `)` is found by tracking
/// nesting depth, and the group is split only on top-level commas. This keeps expression
/// columns such as `lower(name)` intact, and it ignores parentheses in a trailing
/// partial-index `WHERE` clause. Returns an empty list if no balanced group is found;
/// it never panics on unbalanced input.
fn parse_index_columns(idx_sql: &str) -> Vec<String> {
    let Some(open) = idx_sql.find('(') else {
        return Vec::new();
    };
    let mut columns = Vec::new();
    let mut depth = 0usize;
    let mut start = open + 1;
    for (offset, ch) in idx_sql[open..].char_indices() {
        let pos = open + offset;
        match ch {
            // The first character visited is the opening '(' so depth is >= 1 below.
            '(' => depth += 1,
            ')' => {
                depth -= 1;
                if depth == 0 {
                    columns.push(idx_sql[start..pos].trim().to_string());
                    columns.retain(|c: &String| !c.is_empty());
                    return columns;
                }
            }
            ',' if depth == 1 => {
                columns.push(idx_sql[start..pos].trim().to_string());
                start = pos + 1;
            }
            _ => {}
        }
    }
    Vec::new()
}

#[async_trait]
impl Connection for SqliteConnection {
    /// Rewrites `params` into limbo values and executes `sql` through the statement
    /// cache. Returns the change count reported by `exec_rewritten`.
    async fn execute(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
        let (rewritten, limbo_params) = rewrite_params(sql, params).map_err(OxiSqlError::from)?;
        exec_rewritten(&self.conn, &rewritten, limbo_params, Some(&self.stmt_cache))
            .await
            .map_err(OxiSqlError::from)
    }

    /// Rewrites `params` into limbo values and runs `sql`, collecting every result row.
    async fn query(&self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<Vec<Row>, OxiSqlError> {
        let (rewritten, limbo_params) = rewrite_params(sql, params).map_err(OxiSqlError::from)?;
        query_rewritten(&self.conn, &rewritten, limbo_params)
            .await
            .map_err(OxiSqlError::from)
    }

    /// Starts a transaction: waits for `txn_lock`, issues `BEGIN`, and returns a
    /// `SqliteTransaction` that keeps the lock guard until it is dropped.
    ///
    /// NOTE(review): `execute`/`query` on this connection (and its clones) do not take
    /// `txn_lock`, so they are not excluded while a transaction is open.
    async fn transaction(&self) -> Result<Box<dyn Transaction + '_>, OxiSqlError> {
        let guard = self.txn_lock.lock().await;
        self.conn
            .execute("BEGIN", LimboParams::None)
            .await
            .map_err(|e| OxiSqlError::Other(format!("BEGIN failed: {e}")))?;
        Ok(Box::new(SqliteTransaction {
            conn: self.conn.clone(),
            stmt_cache: Arc::clone(&self.stmt_cache),
            _guard: guard,
            done: false,
        }))
    }

    /// Splits `sql` with `split_statements` and executes the pieces in order, stopping
    /// at the first error. Returns the sum of the per-statement change counts.
    async fn execute_batch(&self, sql: &str) -> Result<u64, OxiSqlError> {
        let stmts = split_statements(sql);
        let mut total = 0u64;
        for stmt in stmts {
            total += self.execute(stmt, &[]).await?;
        }
        Ok(total)
    }

    /// Checks the connection by running `SELECT 1`.
    async fn ping(&self) -> Result<(), OxiSqlError> {
        self.query("SELECT 1", &[]).await?;
        Ok(())
    }

    /// Returns a statement handle for `sql`. Nothing is sent to limbo until the handle
    /// is executed or queried.
    async fn prepare(&self, sql: &str) -> Result<Box<dyn PreparedStatement + '_>, OxiSqlError> {
        Ok(Box::new(SqlitePrepared {
            conn: &self.conn,
            stmt_cache: Arc::clone(&self.stmt_cache),
            sql: sql.to_owned(),
        }))
    }

    /// Lists tables and views from `sqlite_master`, ordered by name. Entries whose names
    /// match `LIKE 'sqlite_%'` are skipped.
    async fn tables(&self) -> Result<Vec<TableInfo>, OxiSqlError> {
        let rows = self
            .query(
                "SELECT name, type FROM sqlite_master \
                 WHERE type IN ('table','view') AND name NOT LIKE 'sqlite_%' \
                 ORDER BY name",
                &[],
            )
            .await?;
        let infos = rows
            .into_iter()
            .map(|row| {
                let name = row
                    .get_by_index(0)
                    .and_then(|v| {
                        if let Value::Text(s) = v {
                            Some(s.clone())
                        } else {
                            None
                        }
                    })
                    .unwrap_or_default();
                let ttype_str = row
                    .get_by_index(1)
                    .and_then(|v| {
                        if let Value::Text(s) = v {
                            Some(s.as_str())
                        } else {
                            None
                        }
                    })
                    .unwrap_or("table");
                let table_type = match ttype_str {
                    "view" => TableType::View,
                    _ => TableType::Base,
                };
                TableInfo {
                    name,
                    schema: None,
                    table_type,
                }
            })
            .collect();
        Ok(infos)
    }

    /// Describes the columns of `table` using `PRAGMA table_info`.
    ///
    /// Row positions read: 0 = cid, 1 = name, 2 = type, 3 = notnull, 4 = dflt_value.
    async fn columns(&self, table: &str) -> Result<Vec<ColumnInfo>, OxiSqlError> {
        // PRAGMA arguments cannot be bound, so the name is quoted and escaped. An
        // unescaped `"` in `table` would otherwise break out of the identifier.
        let sql = format!("PRAGMA table_info({})", quote_identifier(table));
        let rows = self.query(&sql, &[]).await?;
        let infos = rows
            .into_iter()
            .map(|row| {
                let text_at = |r: &Row, idx: usize| -> String {
                    r.get_by_index(idx)
                        .and_then(|v| match v {
                            Value::Text(s) => Some(s.clone()),
                            Value::I64(n) => Some(n.to_string()),
                            Value::Null => Some(String::new()),
                            _ => None,
                        })
                        .unwrap_or_default()
                };
                let i64_at = |r: &Row, idx: usize| -> i64 {
                    r.get_by_index(idx)
                        .and_then(|v| {
                            if let Value::I64(n) = v {
                                Some(*n)
                            } else {
                                None
                            }
                        })
                        .unwrap_or(0)
                };
                // `cid` is 0-based; `ordinal_position` is 1-based. Checked conversion
                // instead of a truncating `as` cast.
                let ordinal = u32::try_from(i64_at(&row, 0))
                    .unwrap_or(0)
                    .saturating_add(1);
                let name = text_at(&row, 1);
                let data_type = text_at(&row, 2);
                let notnull = i64_at(&row, 3) != 0;
                let default_val = row.get_by_index(4).and_then(|v| match v {
                    Value::Text(s) => Some(s.clone()),
                    Value::Null => None,
                    other => Some(format!("{other:?}")),
                });
                ColumnInfo {
                    name,
                    ordinal_position: ordinal,
                    data_type,
                    nullable: !notnull,
                    default: default_val,
                    max_length: None,
                    numeric_precision: None,
                    numeric_scale: None,
                }
            })
            .collect();
        Ok(infos)
    }

    /// Lists indexes on `table` by parsing each index's stored SQL text from
    /// `sqlite_master`. Names matching `LIKE 'sqlite_%'` are skipped.
    ///
    /// Non-text `name`/`sql` values (e.g. NULL) are treated as empty text, which yields
    /// an empty name, `unique == false` and no columns.
    async fn indexes(&self, table: &str) -> Result<Vec<IndexInfo>, OxiSqlError> {
        let sql = "SELECT name, sql FROM sqlite_master \
                   WHERE type='index' AND tbl_name=$1 AND name NOT LIKE 'sqlite_%'";
        let rows = self.query(sql, &[&table]).await?;
        let infos = rows
            .into_iter()
            .map(|row| {
                let text_at = |idx: usize| -> String {
                    match row.get_by_index(idx) {
                        Some(Value::Text(s)) => s.clone(),
                        _ => String::new(),
                    }
                };
                let name = text_at(0);
                let idx_sql = text_at(1);
                IndexInfo {
                    name,
                    columns: parse_index_columns(&idx_sql),
                    unique: is_unique_index_sql(&idx_sql),
                    primary: false,
                }
            })
            .collect();
        Ok(infos)
    }

    /// Lists foreign keys declared on `table` using `PRAGMA foreign_key_list`.
    ///
    /// Rows without a text referenced table (position 2) or source column (position 3)
    /// are skipped. Constraint names are synthesized as `fk_<table>_<id>`.
    async fn foreign_keys(&self, table: &str) -> Result<Vec<ForeignKeyInfo>, OxiSqlError> {
        let sql = format!("PRAGMA foreign_key_list({})", quote_identifier(table));
        let rows = query_rewritten(&self.conn, &sql, vec![])
            .await
            .map_err(OxiSqlError::from)?;
        let mut infos: Vec<ForeignKeyInfo> = Vec::with_capacity(rows.len());
        for row in &rows {
            let id = match row.get_by_index(0) {
                Some(Value::I64(v)) => *v,
                _ => 0,
            };
            let from_col = match row.get_by_index(3) {
                Some(Value::Text(s)) => s.clone(),
                _ => continue,
            };
            let foreign_table = match row.get_by_index(2) {
                Some(Value::Text(s)) => s.clone(),
                _ => continue,
            };
            let foreign_column = match row.get_by_index(4) {
                Some(Value::Text(s)) => s.clone(),
                _ => String::new(),
            };
            let on_update = match row.get_by_index(5) {
                Some(Value::Text(s)) => Some(s.clone()),
                _ => None,
            };
            let on_delete = match row.get_by_index(6) {
                Some(Value::Text(s)) => Some(s.clone()),
                _ => None,
            };
            let constraint_name = format!("fk_{table}_{id}");
            infos.push(ForeignKeyInfo {
                constraint_name,
                column: from_col,
                foreign_table,
                foreign_column,
                on_update,
                on_delete,
            });
        }
        Ok(infos)
    }
}
/// Open transaction created by `SqliteConnection::transaction`.
///
/// Holds the parent connection's `txn_lock` guard until dropped. If it is dropped while
/// `done` is still false, `Drop` schedules a `ROLLBACK`.
pub struct SqliteTransaction<'a> {
/// Clone of the parent connection's limbo handle; `BEGIN` was issued on the parent.
conn: limbo::Connection,
/// Shared handle to the parent connection's prepared-statement cache.
stmt_cache: StmtCache,
/// Guard on the parent's `txn_lock`, serialising transactions on that connection.
_guard: tokio::sync::MutexGuard<'a, ()>,
/// Set by `commit`/`rollback`; when false at drop time a `ROLLBACK` is spawned.
done: bool,
}
impl<'a> Drop for SqliteTransaction<'a> {
    /// Best-effort rollback of a transaction that was neither committed nor rolled back.
    ///
    /// `ROLLBACK` is async, so it is spawned onto the current Tokio runtime. If the value
    /// is dropped outside a runtime, `tokio::spawn` would panic, and a panic inside
    /// `Drop` during unwinding aborts the process. Instead, the missing runtime is
    /// detected and the skipped rollback is logged.
    ///
    /// NOTE(review): `_guard` is released as soon as this returns, before the spawned
    /// `ROLLBACK` has necessarily run. A transaction started immediately afterwards on
    /// the same connection can therefore race with it.
    fn drop(&mut self) {
        if self.done {
            return;
        }
        let conn = self.conn.clone();
        match tokio::runtime::Handle::try_current() {
            Ok(handle) => {
                handle.spawn(async move {
                    if let Err(e) = conn.execute("ROLLBACK", LimboParams::None).await {
                        log::warn!("SqliteTransaction drop: ROLLBACK failed: {e}");
                    }
                });
            }
            Err(_) => {
                log::warn!(
                    "SqliteTransaction drop: no Tokio runtime available, ROLLBACK skipped"
                );
            }
        }
    }
}
#[async_trait]
impl<'a> Transaction for SqliteTransaction<'a> {
    /// Executes `sql` inside the transaction through the shared statement cache and
    /// returns the change count reported by `exec_rewritten`.
    async fn execute(&mut self, sql: &str, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
        let (rewritten, limbo_params) = rewrite_params(sql, params).map_err(OxiSqlError::from)?;
        exec_rewritten(&self.conn, &rewritten, limbo_params, Some(&self.stmt_cache))
            .await
            .map_err(OxiSqlError::from)
    }

    /// Runs `sql` inside the transaction and collects every result row.
    async fn query(
        &mut self,
        sql: &str,
        params: &[&dyn ToSqlValue],
    ) -> Result<Vec<Row>, OxiSqlError> {
        let (rewritten, limbo_params) = rewrite_params(sql, params).map_err(OxiSqlError::from)?;
        query_rewritten(&self.conn, &rewritten, limbo_params)
            .await
            .map_err(OxiSqlError::from)
    }

    /// Commits the transaction.
    ///
    /// `done` is set only after `COMMIT` succeeds. A failed `COMMIT` can leave the
    /// transaction open on the shared connection; SQLite keeps it active on e.g.
    /// `SQLITE_BUSY`. Leaving `done` false lets `Drop` schedule a `ROLLBACK`, instead of
    /// leaking an open transaction that would make later `BEGIN`s fail.
    ///
    /// # Errors
    ///
    /// Returns `OxiSqlError::Other` ("COMMIT failed: ...") if the commit fails.
    async fn commit(mut self: Box<Self>) -> Result<(), OxiSqlError> {
        self.conn
            .execute("COMMIT", LimboParams::None)
            .await
            .map_err(|e| OxiSqlError::Other(format!("COMMIT failed: {e}")))?;
        self.done = true;
        Ok(())
    }

    /// Rolls the transaction back.
    ///
    /// As with `commit`, `done` is set only on success, so a failed `ROLLBACK` is
    /// retried on a best-effort basis when the transaction is dropped.
    ///
    /// # Errors
    ///
    /// Returns `OxiSqlError::Other` ("ROLLBACK failed: ...") if the rollback fails.
    async fn rollback(mut self: Box<Self>) -> Result<(), OxiSqlError> {
        self.conn
            .execute("ROLLBACK", LimboParams::None)
            .await
            .map_err(|e| OxiSqlError::Other(format!("ROLLBACK failed: {e}")))?;
        self.done = true;
        Ok(())
    }
}
/// Statement handle returned by `SqliteConnection::prepare`.
///
/// Nothing is prepared up front: only the SQL text is stored. Each call re-runs
/// `rewrite_params`, then either executes through the shared statement cache or queries.
pub struct SqlitePrepared<'a> {
/// Connection borrowed from the `SqliteConnection` that created this handle.
conn: &'a limbo::Connection,
/// Shared handle to the parent connection's prepared-statement cache.
stmt_cache: StmtCache,
/// Original SQL text, before parameter rewriting.
sql: String,
}
#[async_trait]
impl<'a> PreparedStatement for SqlitePrepared<'a> {
    /// Rewrites the stored SQL with `params` and executes it through the statement cache.
    /// Returns the change count reported by `exec_rewritten`.
    async fn execute(&mut self, params: &[&dyn ToSqlValue]) -> Result<u64, OxiSqlError> {
        let (sql, bound) = rewrite_params(&self.sql, params).map_err(OxiSqlError::from)?;
        let cache = Some(&self.stmt_cache);
        let outcome = exec_rewritten(self.conn, &sql, bound, cache).await;
        outcome.map_err(OxiSqlError::from)
    }

    /// Rewrites the stored SQL with `params`, runs it, and collects every result row.
    async fn query(&mut self, params: &[&dyn ToSqlValue]) -> Result<Vec<Row>, OxiSqlError> {
        let (sql, bound) = rewrite_params(&self.sql, params).map_err(OxiSqlError::from)?;
        let outcome = query_rewritten(self.conn, &sql, bound).await;
        outcome.map_err(OxiSqlError::from)
    }

    /// Returns the SQL text as given to `prepare`, before parameter rewriting.
    fn sql(&self) -> &str {
        self.sql.as_str()
    }
}