use std::path::{Path, PathBuf};
use directories::ProjectDirs;
use thiserror::Error;
use turso::{Builder, Connection, Row, Rows, Value};
#[derive(Error, Debug)]
pub enum StorageError {
#[error("Database error: {0}")]
Database(#[from] turso::Error),
#[error("Failed to access config directory: {0}")]
ConfigDir(#[from] std::io::Error),
#[error("Could not determine config directory")]
NoConfigDir,
#[error("Data conversion error: {0}")]
Conversion(String),
#[error("Record not found: {0}")]
NotFound(String),
#[error("Migration error: {0}")]
Migration(String),
}
/// Shorthand for results whose error type is [`StorageError`].
pub type Result<T> = core::result::Result<T, StorageError>;
/// Thin async wrapper around a single turso [`Connection`].
///
/// Every helper on this type runs its SQL through the one wrapped connection. Cloning a
/// `Database` clones that `Connection` via its own `Clone` impl. NOTE(review): whether clones
/// share one underlying session is up to turso — confirm before relying on it.
#[derive(Clone)]
pub struct Database {
    /// The connection all statements are executed on; set up by the `Database::open*` constructors.
    conn: Connection,
}
impl Database {
    /// Opens (or creates) the database file at `path`.
    ///
    /// Missing parent directories are created first, and foreign-key enforcement is turned on
    /// for the returned connection.
    ///
    /// # Errors
    ///
    /// - [`StorageError::Conversion`] if `path` is not valid UTF-8. The turso builder only
    ///   accepts `&str`, and a lossy conversion would silently open a different file.
    /// - [`StorageError::ConfigDir`] if a parent directory cannot be created.
    /// - [`StorageError::Database`] if turso fails to open the file or apply the pragma.
    pub async fn open(path: &Path) -> Result<Self> {
        // Validate before touching the filesystem so a rejected path leaves no directories behind.
        let location = path.to_str().ok_or_else(|| {
            StorageError::Conversion(format!(
                "database path is not valid UTF-8: {}",
                path.display()
            ))
        })?;
        if let Some(parent) = path.parent() {
            // For a bare file name the parent is "", which `create_dir_all` treats as a no-op.
            std::fs::create_dir_all(parent)?;
        }
        Self::connect_at(location).await
    }

    /// Opens a fresh in-memory database with foreign-key enforcement enabled.
    ///
    /// # Errors
    ///
    /// [`StorageError::Database`] if turso fails to create the database or apply the pragma.
    pub async fn open_in_memory() -> Result<Self> {
        Self::connect_at(":memory:").await
    }

    /// Opens the database at [`Database::default_path`].
    ///
    /// # Errors
    ///
    /// Any error from [`Database::default_path`] or [`Database::open`].
    pub async fn open_default() -> Result<Self> {
        let path = Self::default_path()?;
        Self::open(&path).await
    }

    /// Returns the default database location, `<config dir>/ratado.db`.
    ///
    /// The config directory comes from `directories::ProjectDirs` for the application name
    /// `"ratado"`. It is created if missing, so calling this has a filesystem side effect.
    ///
    /// # Errors
    ///
    /// - [`StorageError::NoConfigDir`] if no config directory can be determined.
    /// - [`StorageError::ConfigDir`] if the directory cannot be created.
    pub fn default_path() -> Result<PathBuf> {
        let proj_dirs = ProjectDirs::from("", "", "ratado").ok_or(StorageError::NoConfigDir)?;
        let config_dir = proj_dirs.config_dir();
        std::fs::create_dir_all(config_dir)?;
        Ok(config_dir.join("ratado.db"))
    }

    /// Executes a single statement with bound `params`.
    ///
    /// Returns the `u64` reported by turso's `Connection::execute`, presumably the number of
    /// affected rows (confirm against turso docs).
    pub async fn execute(
        &self,
        sql: impl AsRef<str>,
        params: impl turso::IntoParams,
    ) -> Result<u64> {
        Ok(self.conn.execute(sql, params).await?)
    }

    /// Executes a string containing several `;`-separated statements, without parameters.
    pub async fn execute_batch(&self, sql: impl AsRef<str>) -> Result<()> {
        Ok(self.conn.execute_batch(sql).await?)
    }

    /// Runs a query and returns turso's [`Rows`] cursor; iterate it with `next().await`.
    pub async fn query(
        &self,
        sql: impl AsRef<str>,
        params: impl turso::IntoParams,
    ) -> Result<Rows> {
        Ok(self.conn.query(sql, params).await?)
    }

    /// Returns the first row produced by the query, or `None` if it yields no rows.
    ///
    /// Any rows after the first are discarded.
    pub async fn query_one(
        &self,
        sql: impl AsRef<str>,
        params: impl turso::IntoParams,
    ) -> Result<Option<Row>> {
        let mut rows = self.query(sql, params).await?;
        Ok(rows.next().await?)
    }

    /// Returns column 0 of the first row, or `None` if the query yields no rows.
    pub async fn query_scalar(
        &self,
        sql: impl AsRef<str>,
        params: impl turso::IntoParams,
    ) -> Result<Option<Value>> {
        let first_row = self.query_one(sql, params).await?;
        Ok(first_row.map(|row| row.get_value(0)).transpose()?)
    }

    /// Borrows the underlying turso connection for operations this wrapper does not expose.
    pub fn connection(&self) -> &Connection {
        &self.conn
    }

    /// Builds a turso database at `location` (a file path or `":memory:"`), connects to it,
    /// and enables foreign-key enforcement. Shared setup for the public constructors.
    async fn connect_at(location: &str) -> Result<Self> {
        let db = Builder::new_local(location).build().await?;
        let conn = db.connect()?;
        conn.execute("PRAGMA foreign_keys = ON", ()).await?;
        Ok(Self { conn })
    }
}
impl std::fmt::Debug for Database {
    // Prints `Database { .. }`. The connection field is deliberately left out (the original
    // impl omitted it too — presumably because it has no useful Debug output; TODO confirm).
    // `finish_non_exhaustive` makes that omission visible instead of implying there are no fields.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Database").finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_open_in_memory() {
        let db = Database::open_in_memory().await.unwrap();
        db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)", ())
            .await
            .unwrap();
        db.execute("INSERT INTO test (name) VALUES (?1)", ["Alice"])
            .await
            .unwrap();
        let mut rows = db.query("SELECT name FROM test", ()).await.unwrap();
        let row = rows.next().await.unwrap().unwrap();
        let name = row.get_value(0).unwrap();
        assert_eq!(name, Value::Text("Alice".to_string()));
    }

    #[tokio::test]
    async fn test_execute_batch() {
        let db = Database::open_in_memory().await.unwrap();
        db.execute_batch(
            "
            CREATE TABLE test1 (id INTEGER PRIMARY KEY);
            CREATE TABLE test2 (id INTEGER PRIMARY KEY);
            INSERT INTO test1 (id) VALUES (1);
            INSERT INTO test2 (id) VALUES (2);
            ",
        )
        .await
        .unwrap();
        let value = db
            .query_scalar("SELECT COUNT(*) FROM test1", ())
            .await
            .unwrap();
        assert_eq!(value, Some(Value::Integer(1)));
        // The last statement of the batch must have run too, not just the first few.
        let value = db.query_scalar("SELECT id FROM test2", ()).await.unwrap();
        assert_eq!(value, Some(Value::Integer(2)));
    }

    #[tokio::test]
    async fn test_query_one() {
        let db = Database::open_in_memory().await.unwrap();
        db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)", ())
            .await
            .unwrap();
        db.execute("INSERT INTO test (name) VALUES (?1)", ["Bob"])
            .await
            .unwrap();
        let row = db
            .query_one("SELECT name FROM test WHERE id = 1", ())
            .await
            .unwrap()
            .expect("row with id 1 should exist");
        assert_eq!(row.get_value(0).unwrap(), Value::Text("Bob".to_string()));
        let row = db
            .query_one("SELECT name FROM test WHERE id = 999", ())
            .await
            .unwrap();
        assert!(row.is_none());
    }

    #[tokio::test]
    async fn test_query_scalar_no_rows() {
        let db = Database::open_in_memory().await.unwrap();
        db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)", ())
            .await
            .unwrap();
        let value = db.query_scalar("SELECT id FROM test", ()).await.unwrap();
        assert_eq!(value, None);
    }

    // NOTE(review): `default_path` creates the real per-user config directory as a side effect.
    #[test]
    fn test_default_path() {
        let path = Database::default_path().unwrap();
        assert!(path.ends_with("ratado.db"));
        assert!(path.to_string_lossy().contains("ratado"));
    }
}