use std::path::{Path, PathBuf};
use rusqlite::Connection;
use crate::error::{WalletError, Result};
use super::schema;
/// Owning handle to a SQLite database file.
///
/// Keeps the path the database was opened from and, until [`Database::close`]
/// is called, the live rusqlite connection.
#[derive(Debug)]
pub struct Database {
    /// Location of the database file on disk.
    path: PathBuf,
    /// Open connection; `None` once the database has been closed.
    conn: Option<Connection>,
}
impl Database {
pub fn open(path: &Path) -> Result<Self> {
let conn = Connection::open(path)?;
Ok(Self {
path: path.to_path_buf(),
conn: Some(conn),
})
}
pub fn create(path: &Path) -> Result<Self> {
let conn = Connection::open(path)?;
for sql in schema::CREATE_ALL_TABLES {
conn.execute(sql, [])?;
}
Ok(Self {
path: path.to_path_buf(),
conn: Some(conn),
})
}
pub fn connection(&self) -> Result<&Connection> {
self.conn.as_ref().ok_or_else(|| {
WalletError::DatabaseError("Database not open".to_string())
})
}
pub fn connection_mut(&mut self) -> Result<&mut Connection> {
self.conn.as_mut().ok_or_else(|| {
WalletError::DatabaseError("Database not open".to_string())
})
}
pub fn path(&self) -> &Path {
&self.path
}
pub fn close(&mut self) {
self.conn = None;
}
pub fn is_open(&self) -> bool {
self.conn.is_some()
}
pub fn begin_transaction(&mut self) -> Result<()> {
self.connection()?.execute("BEGIN TRANSACTION", [])?;
Ok(())
}
pub fn commit_transaction(&mut self) -> Result<()> {
self.connection()?.execute("COMMIT", [])?;
Ok(())
}
pub fn rollback_transaction(&mut self) -> Result<()> {
self.connection()?.execute("ROLLBACK", [])?;
Ok(())
}
pub fn checkpoint(&self) -> Result<()> {
self.connection()?.execute_batch("PRAGMA wal_checkpoint(TRUNCATE)")?;
Ok(())
}
}
impl Drop for Database {
    /// Releases the SQLite connection, if one is still held, when the wrapper
    /// goes out of scope.
    fn drop(&mut self) {
        if self.is_open() {
            self.close();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Statement inserting one fixed row into `nswallet_properties`.
    const INSERT_PROPERTIES: &str = "INSERT INTO nswallet_properties (database_id, lang, version, email, sync_timestamp, update_timestamp) VALUES (?, ?, ?, ?, ?, ?)";

    /// Inserts the fixed test row via the database's connection.
    fn insert_properties(db: &Database) {
        db.connection()
            .unwrap()
            .execute(
                INSERT_PROPERTIES,
                rusqlite::params![
                    "test-id",
                    "en",
                    "4",
                    "0",
                    "2024-01-01 00:00:00",
                    "2024-01-01 00:00:00"
                ],
            )
            .unwrap();
    }

    /// Number of rows currently in `nswallet_properties`.
    fn count_properties(db: &Database) -> i64 {
        db.connection()
            .unwrap()
            .query_row("SELECT COUNT(*) FROM nswallet_properties", [], |row| {
                row.get::<_, i64>(0)
            })
            .unwrap()
    }

    #[test]
    fn test_checkpoint_no_error() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let db = Database::create(&db_path).unwrap();
        db.checkpoint().unwrap();
    }

    #[test]
    fn test_checkpoint_after_write() {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let db = Database::create(&db_path).unwrap();
        insert_properties(&db);
        db.checkpoint().unwrap();
    }

    #[test]
    fn test_checkpoint_clears_wal() {
        use std::fs;
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let wal_path = temp_dir.path().join("test.db-wal");
        let db = Database::create(&db_path).unwrap();
        db.connection()
            .unwrap()
            .execute_batch("PRAGMA journal_mode=WAL")
            .unwrap();
        insert_properties(&db);
        let wal_size_before = fs::metadata(&wal_path).map(|m| m.len()).unwrap_or(0);
        // Without this guard an empty/missing WAL would make the test pass vacuously.
        assert!(wal_size_before > 0, "insert should have written frames to the WAL");
        db.checkpoint().unwrap();
        let wal_size_after = fs::metadata(&wal_path).map(|m| m.len()).unwrap_or(0);
        assert_eq!(wal_size_after, 0, "WAL should be truncated after checkpoint");
    }

    #[test]
    fn test_commit_persists_insert() {
        let temp_dir = TempDir::new().unwrap();
        let mut db = Database::create(&temp_dir.path().join("test.db")).unwrap();
        let before = count_properties(&db);
        db.begin_transaction().unwrap();
        insert_properties(&db);
        db.commit_transaction().unwrap();
        assert_eq!(count_properties(&db), before + 1);
    }

    #[test]
    fn test_rollback_discards_insert() {
        let temp_dir = TempDir::new().unwrap();
        let mut db = Database::create(&temp_dir.path().join("test.db")).unwrap();
        let before = count_properties(&db);
        db.begin_transaction().unwrap();
        insert_properties(&db);
        db.rollback_transaction().unwrap();
        assert_eq!(count_properties(&db), before);
    }

    #[test]
    fn test_close_marks_database_closed() {
        let temp_dir = TempDir::new().unwrap();
        let mut db = Database::create(&temp_dir.path().join("test.db")).unwrap();
        assert!(db.is_open());
        db.close();
        assert!(!db.is_open());
        assert!(db.connection().is_err());
        assert!(db.connection_mut().is_err());
        assert!(db.checkpoint().is_err());
    }
}