use crate::config::SqliteConfig;
use crate::connection::Connection;
use crate::error::SqliteError;
use crate::migration::{MigrationBuilder, MigrationReport, Migrator};
use crate::query::{DeleteBuilder, InsertBuilder, SelectBuilder, UpdateBuilder};
use crate::transaction::{with_transaction, Transaction};
use crate::types::{Param, Row, Rows};
use std::path::PathBuf;
/// A SQLite database handle with a small convenience API on top of [`Connection`].
///
/// Exposes query builders, raw SQL execution, migrations, transactions and a few
/// schema-introspection helpers.
pub struct EmbeddedDb {
    /// The underlying SQLite connection every operation is delegated to.
    conn: Connection,
    /// Display name: the `name` argument of [`EmbeddedDb::open`] /
    /// [`EmbeddedDb::open_with_config`], the file stem for
    /// [`EmbeddedDb::open_path`], or `":memory:"` for [`EmbeddedDb::memory`].
    name: String,
}

impl EmbeddedDb {
    /// Opens the database called `name` at its default location.
    ///
    /// The file path is `<data_dir>/unistore/db/<name>.db` (see `default_path`).
    /// Missing parent directories are created first.
    ///
    /// # Errors
    /// Returns an error if the directory cannot be created or the connection
    /// fails to open.
    pub fn open(name: &str) -> Result<Self, SqliteError> {
        let path = Self::default_path(name);
        Self::ensure_parent_dir(&path)?;
        let config = SqliteConfig::default().with_path(path);
        let conn = Connection::open(config)?;
        Ok(Self {
            conn,
            name: name.to_string(),
        })
    }

    /// Opens the database at an explicit `path`.
    ///
    /// The database name is the path's file stem, or `"db"` if the stem is
    /// missing or not valid UTF-8. Unlike [`EmbeddedDb::open`], parent
    /// directories are NOT created; the caller owns the location.
    ///
    /// # Errors
    /// Returns an error if the connection fails to open.
    pub fn open_path(path: impl Into<PathBuf>) -> Result<Self, SqliteError> {
        let path = path.into();
        let name = path
            .file_stem()
            .and_then(|s| s.to_str())
            .unwrap_or("db")
            .to_string();
        let config = SqliteConfig::default().with_path(path);
        let conn = Connection::open(config)?;
        Ok(Self { conn, name })
    }

    /// Opens the database `name` with a caller-supplied configuration.
    ///
    /// If `config` has no path, the default location for `name` is used, and
    /// its parent directories are created exactly as in [`EmbeddedDb::open`].
    /// An explicit path in `config` is used untouched.
    ///
    /// # Errors
    /// Returns an error if the default directory cannot be created or the
    /// connection fails to open.
    pub fn open_with_config(name: &str, config: SqliteConfig) -> Result<Self, SqliteError> {
        let config = if config.path.is_none() {
            let path = Self::default_path(name);
            // The default data directory may not exist yet (e.g. first run);
            // without this, opening would fail where `open` succeeds.
            Self::ensure_parent_dir(&path)?;
            config.with_path(path)
        } else {
            config
        };
        let conn = Connection::open(config)?;
        Ok(Self {
            conn,
            name: name.to_string(),
        })
    }

    /// Opens an in-memory database built from [`SqliteConfig::memory`].
    ///
    /// # Errors
    /// Returns an error if the connection fails to open.
    pub fn memory() -> Result<Self, SqliteError> {
        let conn = Connection::open(SqliteConfig::memory())?;
        Ok(Self {
            conn,
            name: ":memory:".to_string(),
        })
    }

    /// Returns `<data_dir>/unistore/db/<name>.db`.
    ///
    /// Uses the current directory when `dirs::data_dir()` yields nothing.
    fn default_path(name: &str) -> PathBuf {
        let data_dir = dirs::data_dir()
            .unwrap_or_else(|| PathBuf::from("."))
            .join("unistore")
            .join("db");
        data_dir.join(format!("{}.db", name))
    }

    /// Creates `path`'s parent directory (and any missing ancestors), if it has one.
    fn ensure_parent_dir(path: &std::path::Path) -> Result<(), SqliteError> {
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        Ok(())
    }

    /// Returns the database's display name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the underlying connection for operations not wrapped here.
    pub fn connection(&self) -> &Connection {
        &self.conn
    }

    /// Declares migrations via `f` and applies them with [`Migrator::migrate_with`].
    ///
    /// # Errors
    /// Propagates any error reported by the migrator.
    pub fn migrate<F>(&self, f: F) -> Result<MigrationReport, SqliteError>
    where
        F: FnOnce(&mut MigrationBuilder),
    {
        let migrator = Migrator::new(&self.conn);
        migrator.migrate_with(f)
    }

    /// Returns the schema version reported by [`Migrator::current_version`].
    ///
    /// # Errors
    /// Propagates any error reported by the migrator.
    pub fn schema_version(&self) -> Result<u32, SqliteError> {
        let migrator = Migrator::new(&self.conn);
        migrator.current_version()
    }

    /// Starts a `SELECT` builder for `table`.
    pub fn select(&self, table: &str) -> SelectBuilder<'_> {
        SelectBuilder::new(&self.conn, table)
    }

    /// Starts an `INSERT` builder for `table`.
    pub fn insert(&self, table: &str) -> InsertBuilder<'_> {
        InsertBuilder::new(&self.conn, table)
    }

    /// Starts an `UPDATE` builder for `table`.
    pub fn update(&self, table: &str) -> UpdateBuilder<'_> {
        UpdateBuilder::new(&self.conn, table)
    }

    /// Starts a `DELETE` builder for `table`.
    pub fn delete(&self, table: &str) -> DeleteBuilder<'_> {
        DeleteBuilder::new(&self.conn, table)
    }

    /// Executes one SQL statement with bound `params` and returns the connection's
    /// result count (forwarded from [`Connection::execute`]).
    pub fn execute(&self, sql: &str, params: &[Param]) -> Result<usize, SqliteError> {
        self.conn.execute(sql, params)
    }

    /// Executes a batch of SQL statements with no parameters.
    pub fn execute_batch(&self, sql: &str) -> Result<(), SqliteError> {
        self.conn.execute_batch(sql)
    }

    /// Runs `sql` and returns its first row, or `None` if there is none.
    pub fn query_row(&self, sql: &str, params: &[Param]) -> Result<Option<Row>, SqliteError> {
        self.conn.query_row(sql, params)
    }

    /// Runs `sql` and returns all result rows.
    pub fn query(&self, sql: &str, params: &[Param]) -> Result<Rows, SqliteError> {
        self.conn.query(sql, params)
    }

    /// Returns the connection's last inserted rowid.
    pub fn last_insert_id(&self) -> Result<i64, SqliteError> {
        self.conn.last_insert_rowid()
    }

    /// Begins a transaction that borrows this database's connection.
    pub fn begin_transaction(&self) -> Result<Transaction<'_>, SqliteError> {
        Transaction::begin(&self.conn)
    }

    /// Runs `f` inside a transaction and returns its result.
    ///
    /// Delegates to the crate's `with_transaction` helper. The in-file
    /// transaction test relies on an `Err` from `f` discarding the work.
    pub fn with_transaction<F, T>(&self, f: F) -> Result<T, SqliteError>
    where
        F: FnOnce(&Transaction<'_>) -> Result<T, SqliteError>,
    {
        with_transaction(&self.conn, f)
    }

    /// Returns whether `sqlite_master` lists a table named `table`.
    ///
    /// Only rows with `type='table'` count, so views are not reported.
    pub fn table_exists(&self, table: &str) -> Result<bool, SqliteError> {
        let row = self.conn.query_row(
            "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
            &[table.into()],
        )?;
        Ok(row.is_some())
    }

    /// Lists the columns of `table`, as reported by SQLite's `table_info` pragma.
    ///
    /// Returns an empty vector when the table does not exist. Missing textual
    /// fields become empty strings and missing flags become `false`.
    pub fn table_columns(&self, table: &str) -> Result<Vec<ColumnInfo>, SqliteError> {
        // The table-valued form of the pragma lets the name be bound as a
        // parameter. Formatting it into `PRAGMA table_info(...)` broke on names
        // that need quoting and allowed SQL injection.
        let rows = self
            .conn
            .query("SELECT * FROM pragma_table_info(?)", &[table.into()])?;
        Ok(rows
            .into_iter()
            .map(|row| ColumnInfo {
                name: row.get_string("name").unwrap_or_default(),
                type_name: row.get_string("type").unwrap_or_default(),
                not_null: row.get_bool("notnull").unwrap_or(false),
                default_value: row.get_string("dflt_value"),
                is_primary_key: row.get_bool("pk").unwrap_or(false),
            })
            .collect())
    }

    /// Returns the database size in bytes (`page_count * page_size`).
    ///
    /// Yields `0` if the query returns no row or no `size` value.
    pub fn database_size(&self) -> Result<i64, SqliteError> {
        let row = self.conn.query_row(
            "SELECT page_count * page_size as size FROM pragma_page_count(), pragma_page_size()",
            &[],
        )?;
        Ok(row.and_then(|r| r.get_i64("size")).unwrap_or(0))
    }

    /// Runs SQLite's `VACUUM` on the database.
    pub fn vacuum(&self) -> Result<(), SqliteError> {
        self.conn.execute_batch("VACUUM")
    }

    /// Consumes the database and closes its connection, surfacing any close error.
    pub fn close(self) -> Result<(), SqliteError> {
        self.conn.close()
    }
}
/// Description of one table column, as returned by `EmbeddedDb::table_columns`.
///
/// Populated from SQLite's `table_info` pragma. Implements equality so
/// introspection results can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnInfo {
    /// Column name (pragma column `name`).
    pub name: String,
    /// Declared type text, possibly empty (pragma column `type`).
    pub type_name: String,
    /// Whether the column is declared `NOT NULL` (pragma column `notnull`).
    pub not_null: bool,
    /// Default-value text, if any (pragma column `dflt_value`).
    pub default_value: Option<String>,
    /// Whether the column is part of the primary key (pragma column `pk`).
    pub is_primary_key: bool,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An in-memory database reports the `:memory:` name.
    #[test]
    fn test_embedded_db_memory() {
        let mem_db = EmbeddedDb::memory().unwrap();
        let reported = mem_db.name();
        assert_eq!(reported, ":memory:");
    }

    /// Round-trips one row through insert, select, update and delete.
    #[test]
    fn test_crud_operations() {
        let store = EmbeddedDb::memory().unwrap();
        store
            .migrate(|builder| {
                builder.version(1, "创建用户表", |schema| {
                    schema.create_table("users", |table| {
                        table.id().text_not_null("name").integer("age").timestamps()
                    })
                });
            })
            .unwrap();

        // Insert: the first row gets id 1.
        let new_id = store
            .insert("users")
            .set("name", "Alice")
            .set("age", 30)
            .execute()
            .unwrap();
        assert_eq!(new_id, 1);

        // Select the inserted row back.
        let fetched = store
            .select("users")
            .filter_eq("id", 1)
            .fetch_one()
            .unwrap()
            .unwrap();
        assert_eq!(fetched.get_str("name"), Some("Alice"));
        assert_eq!(fetched.get_i64("age"), Some(30));

        // Update exactly one row and observe the change.
        let updated = store
            .update("users")
            .set("age", 31)
            .filter_eq("id", 1)
            .execute()
            .unwrap();
        assert_eq!(updated, 1);
        let refetched = store
            .select("users")
            .filter_eq("id", 1)
            .fetch_one()
            .unwrap()
            .unwrap();
        assert_eq!(refetched.get_i64("age"), Some(31));

        // Delete it and confirm the table is empty.
        let removed = store.delete("users").filter_eq("id", 1).execute().unwrap();
        assert_eq!(removed, 1);
        assert_eq!(store.select("users").count().unwrap(), 0);
    }

    /// Committed work persists; work in a failing transaction is discarded.
    #[test]
    fn test_transaction() {
        let store = EmbeddedDb::memory().unwrap();
        store
            .execute_batch("CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)")
            .unwrap();

        store
            .with_transaction(|tx| {
                for value in [1i32, 2i32] {
                    tx.execute("INSERT INTO test (value) VALUES (?)", &[value.into()])?;
                }
                Ok(())
            })
            .unwrap();
        assert_eq!(store.select("test").count().unwrap(), 2);

        let outcome: Result<(), SqliteError> = store.with_transaction(|tx| {
            tx.execute("INSERT INTO test (value) VALUES (?)", &[3i32.into()])?;
            Err(SqliteError::Internal("test".into()))
        });
        assert!(outcome.is_err());
        // The insert inside the failed transaction must not be visible.
        assert_eq!(store.select("test").count().unwrap(), 2);
    }

    /// Schema introspection: table existence and column metadata.
    #[test]
    fn test_table_info() {
        let store = EmbeddedDb::memory().unwrap();
        store
            .migrate(|builder| {
                builder.version(1, "创建表", |schema| {
                    schema.create_table("test", |table| {
                        table.id().text_not_null("name").integer("value")
                    })
                });
            })
            .unwrap();

        assert!(store.table_exists("test").unwrap());
        assert!(!store.table_exists("nonexistent").unwrap());

        let cols = store.table_columns("test").unwrap();
        assert_eq!(cols.len(), 3);
        let has_pk_id = cols.iter().any(|col| col.name == "id" && col.is_primary_key);
        let has_required_name = cols.iter().any(|col| col.name == "name" && col.not_null);
        assert!(has_pk_id);
        assert!(has_required_name);
    }
}