use crate::error::SqlxResult;
use sqlx::Row;
use sqlx::sqlite::SqlitePool;
pub struct SqliteHelpers;
impl SqliteHelpers {
pub fn upsert_sql(table: &str, columns: &[&str]) -> String {
let cols = columns.join(", ");
let placeholders: Vec<String> = columns.iter().map(|_| "?".to_string()).collect();
let vals = placeholders.join(", ");
format!(
"INSERT OR REPLACE INTO {} ({}) VALUES ({})",
table, cols, vals
)
}
pub fn on_conflict_sql(
table: &str,
columns: &[&str],
conflict_columns: &[&str],
update_columns: &[&str],
) -> String {
let cols = columns.join(", ");
let placeholders: Vec<String> = columns.iter().map(|_| "?".to_string()).collect();
let vals = placeholders.join(", ");
let conflict = conflict_columns.join(", ");
let updates: Vec<String> = update_columns
.iter()
.map(|c| format!("{} = excluded.{}", c, c))
.collect();
let update_clause = updates.join(", ");
format!(
"INSERT INTO {} ({}) VALUES ({}) ON CONFLICT({}) DO UPDATE SET {}",
table, cols, vals, conflict, update_clause
)
}
pub fn json_extract(column: &str, path: &str) -> String {
format!("json_extract({}, '$.{}')", column, path)
}
pub async fn last_insert_rowid(pool: &SqlitePool) -> SqlxResult<i64> {
let row = sqlx::query("SELECT last_insert_rowid()")
.fetch_one(pool)
.await?;
let id: i64 = row.try_get(0)?;
Ok(id)
}
pub async fn version(pool: &SqlitePool) -> SqlxResult<String> {
let row = sqlx::query("SELECT sqlite_version()")
.fetch_one(pool)
.await?;
let version: String = row.try_get(0)?;
Ok(version)
}
pub async fn table_exists(pool: &SqlitePool, table: &str) -> SqlxResult<bool> {
let sql = "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?";
let row = sqlx::query(sql).bind(table).fetch_one(pool).await?;
let count: i32 = row.try_get(0)?;
Ok(count > 0)
}
pub async fn get_columns(pool: &SqlitePool, table: &str) -> SqlxResult<Vec<String>> {
let sql = format!("PRAGMA table_info({})", table);
let rows = sqlx::query(&sql).fetch_all(pool).await?;
let columns: Vec<String> = rows
.iter()
.map(|r| r.try_get::<String, _>("name").unwrap_or_default())
.collect();
Ok(columns)
}
pub async fn enable_foreign_keys(pool: &SqlitePool) -> SqlxResult<()> {
sqlx::query("PRAGMA foreign_keys = ON")
.execute(pool)
.await?;
Ok(())
}
pub async fn set_journal_mode(pool: &SqlitePool, mode: JournalMode) -> SqlxResult<()> {
let sql = format!("PRAGMA journal_mode = {}", mode.as_str());
sqlx::query(&sql).execute(pool).await?;
Ok(())
}
pub async fn set_synchronous(pool: &SqlitePool, mode: SynchronousMode) -> SqlxResult<()> {
let sql = format!("PRAGMA synchronous = {}", mode.as_str());
sqlx::query(&sql).execute(pool).await?;
Ok(())
}
pub async fn vacuum(pool: &SqlitePool) -> SqlxResult<()> {
sqlx::query("VACUUM").execute(pool).await?;
Ok(())
}
pub async fn analyze(pool: &SqlitePool) -> SqlxResult<()> {
sqlx::query("ANALYZE").execute(pool).await?;
Ok(())
}
pub async fn integrity_check(pool: &SqlitePool) -> SqlxResult<Vec<String>> {
let rows = sqlx::query("PRAGMA integrity_check")
.fetch_all(pool)
.await?;
let results: Vec<String> = rows
.iter()
.map(|r| r.try_get::<String, _>(0).unwrap_or_default())
.collect();
Ok(results)
}
pub fn fts5_match(table: &str, _query: &str) -> String {
format!("{} MATCH ?", table)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JournalMode {
Delete,
Truncate,
Persist,
Memory,
Wal,
Off,
}
impl JournalMode {
pub fn as_str(&self) -> &'static str {
match self {
Self::Delete => "DELETE",
Self::Truncate => "TRUNCATE",
Self::Persist => "PERSIST",
Self::Memory => "MEMORY",
Self::Wal => "WAL",
Self::Off => "OFF",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SynchronousMode {
Full,
Normal,
Off,
Extra,
}
impl SynchronousMode {
pub fn as_str(&self) -> &'static str {
match self {
Self::Full => "FULL",
Self::Normal => "NORMAL",
Self::Off => "OFF",
Self::Extra => "EXTRA",
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_upsert_sql() {
let sql = SqliteHelpers::upsert_sql("users", &["id", "name", "email"]);
assert!(sql.contains("INSERT OR REPLACE"));
assert!(sql.contains("users"));
}
#[test]
fn test_on_conflict_sql() {
let sql = SqliteHelpers::on_conflict_sql(
"users",
&["id", "name", "email"],
&["id"],
&["name", "email"],
);
assert!(sql.contains("ON CONFLICT(id)"));
assert!(sql.contains("DO UPDATE SET"));
}
#[test]
fn test_json_extract() {
assert_eq!(
SqliteHelpers::json_extract("data", "name"),
"json_extract(data, '$.name')"
);
}
#[test]
fn test_journal_mode() {
assert_eq!(JournalMode::Wal.as_str(), "WAL");
assert_eq!(JournalMode::Delete.as_str(), "DELETE");
}
}