use crate::sqlite::database::{OpenDatabaseInput, OpenDatabaseTool};
use crate::sqlite::manager::DATABASE_MANAGER;
use mixtape_core::tool::{Tool, ToolResult};
use std::path::PathBuf;
use tempfile::TempDir;
/// Extracts the JSON payload from a [`ToolResult`].
///
/// # Panics
///
/// Panics with the debug representation of `result` when it is not the
/// `ToolResult::Json` variant.
pub fn unwrap_json(result: ToolResult) -> serde_json::Value {
    if let ToolResult::Json(value) = result {
        return value;
    }
    // Nothing was moved out of `result` on this path, so it can still be reported.
    panic!("Expected JSON result, got {:?}", result)
}
/// A throwaway database for tests, backed by a file inside its own temporary
/// directory and addressed through `DATABASE_MANAGER` by `key`.
pub struct TestDatabase {
// Never read after construction: it is stored only so the temporary directory
// lives as long as this struct (tempfile's `TempDir` removes the directory
// when it is dropped), hence the `dead_code` allowance.
#[allow(dead_code)]
temp_dir: TempDir,
/// Value of the `"database"` field returned by `OpenDatabaseTool`; used for
/// every `DATABASE_MANAGER` lookup and for closing the database.
key: String,
}
impl TestDatabase {
    /// Creates a test database named `test.db` in a fresh temporary directory.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`TestDatabase::with_name`].
    pub async fn new() -> Self {
        Self::with_name("test.db").await
    }

    /// Creates a test database whose file is called `name` inside a fresh
    /// temporary directory, opening it through `OpenDatabaseTool` with
    /// `create: true`.
    ///
    /// # Panics
    ///
    /// Panics if the temporary directory cannot be created, if the tool fails,
    /// or if the tool's result is not JSON carrying a string `"database"` field.
    pub async fn with_name(name: &str) -> Self {
        let temp_dir = TempDir::new().expect("Failed to create temp directory");
        let db_path = temp_dir.path().join(name);
        let tool = OpenDatabaseTool;
        let input = OpenDatabaseInput {
            db_path,
            create: true,
        };
        let result = tool
            .execute(input)
            .await
            .expect("Failed to open test database");
        // The tool reports the key to use for later manager lookups in the
        // `"database"` field of its JSON result.
        let key = match result {
            ToolResult::Json(json) => json["database"]
                .as_str()
                .expect("OpenDatabaseTool should return database key")
                .to_string(),
            other => panic!(
                "Expected JSON result from OpenDatabaseTool, got {:?}",
                other
            ),
        };
        Self { temp_dir, key }
    }

    /// Returns an owned copy of the key identifying this database in
    /// `DATABASE_MANAGER`.
    pub fn key(&self) -> String {
        self.key.clone()
    }

    /// Returns the key interpreted as a filesystem path.
    ///
    /// NOTE(review): this assumes the key reported by `OpenDatabaseTool` is the
    /// database file path — confirm against the tool's implementation.
    pub fn path(&self) -> PathBuf {
        PathBuf::from(&self.key)
    }

    /// Creates a default test database (see [`TestDatabase::new`]) and runs
    /// `schema` against it via [`TestDatabase::execute`].
    ///
    /// # Panics
    ///
    /// Panics if the database cannot be created or the schema fails to execute.
    pub async fn with_schema(schema: &str) -> Self {
        let db = Self::new().await;
        db.execute(schema);
        db
    }

    /// Runs `sql` as a batch (it may contain several statements) on this
    /// database's connection.
    ///
    /// # Panics
    ///
    /// Panics if the connection cannot be obtained from `DATABASE_MANAGER`, if
    /// its lock is poisoned, or if executing the batch fails.
    pub fn execute(&self, sql: &str) {
        let conn = DATABASE_MANAGER
            .get(Some(&self.key))
            .expect("Failed to get test database connection");
        let conn = conn.lock().unwrap();
        conn.execute_batch(sql)
            .expect("Failed to execute SQL in test database");
    }

    /// Runs the parameterless query `sql` and returns every row as a vector of
    /// JSON values, one per column.
    ///
    /// Column values are converted as follows: `NULL` becomes JSON null,
    /// integers and reals become JSON numbers, text becomes a string (invalid
    /// UTF-8 is replaced lossily) and blobs become standard-alphabet base64
    /// strings.
    ///
    /// # Panics
    ///
    /// Panics if the connection cannot be obtained, if its lock is poisoned, if
    /// the statement cannot be prepared or started, or if reading any row
    /// fails. Row errors are reported rather than skipped so that a test never
    /// asserts against silently truncated results.
    pub fn query(&self, sql: &str) -> Vec<Vec<serde_json::Value>> {
        let conn = DATABASE_MANAGER
            .get(Some(&self.key))
            .expect("Failed to get test database connection");
        let conn = conn.lock().unwrap();
        let mut stmt = conn.prepare(sql).expect("Failed to prepare query");
        let column_count = stmt.column_count();
        let mapped_rows = stmt
            .query_map([], |row| {
                (0..column_count)
                    .map(|i| row.get_ref(i).map(Self::value_to_json))
                    .collect::<rusqlite::Result<Vec<_>>>()
            })
            .expect("Failed to execute query");
        // Collect into a local first so the row iterator (which borrows `stmt`)
        // is consumed before `stmt` and the connection guard are dropped.
        let rows = mapped_rows
            .collect::<rusqlite::Result<Vec<_>>>()
            .expect("Failed to read row from query result");
        rows
    }

    /// Returns the number of rows in `table`.
    ///
    /// The table name is wrapped in double quotes but not otherwise escaped, so
    /// it must not itself contain a double quote.
    ///
    /// # Panics
    ///
    /// Panics if the connection cannot be obtained, if its lock is poisoned, or
    /// if the `COUNT(*)` query fails (for example because the table is missing).
    pub fn count(&self, table: &str) -> i64 {
        let conn = DATABASE_MANAGER
            .get(Some(&self.key))
            .expect("Failed to get test database connection");
        let conn = conn.lock().unwrap();
        conn.query_row(&format!("SELECT COUNT(*) FROM \"{}\"", table), [], |row| {
            row.get(0)
        })
        .expect("Failed to count rows")
    }

    /// Converts a single borrowed SQLite value into its JSON representation.
    fn value_to_json(value: rusqlite::types::ValueRef<'_>) -> serde_json::Value {
        match value {
            rusqlite::types::ValueRef::Null => serde_json::Value::Null,
            rusqlite::types::ValueRef::Integer(n) => serde_json::json!(n),
            rusqlite::types::ValueRef::Real(f) => serde_json::json!(f),
            rusqlite::types::ValueRef::Text(s) => {
                serde_json::json!(String::from_utf8_lossy(s))
            }
            rusqlite::types::ValueRef::Blob(b) => {
                serde_json::json!(base64::Engine::encode(
                    &base64::engine::general_purpose::STANDARD,
                    b
                ))
            }
        }
    }
}
impl Drop for TestDatabase {
    /// Asks `DATABASE_MANAGER` to close this database when the helper goes out
    /// of scope.
    fn drop(&mut self) {
        let key = &self.key;
        // Best effort: whatever `close` returns is deliberately discarded so
        // that dropping the helper never panics.
        let _ = DATABASE_MANAGER.close(key);
    }
}