use crate::error::VecXError;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::Connection;
use std::path::Path;
use std::time::Duration;
/// Number of database pages copied per `Backup::step` call in `backup_connection`.
const PAGES_PER_STEP: i32 = 100;
/// Pause in milliseconds between backup steps; `backup_connection` waits ten times
/// this long after a step reports `Busy` or `Locked`.
const STEP_DELAY_MS: u64 = 10;
/// Copies the live database behind `pool` into a standalone SQLite file at `dest_path`.
///
/// The destination is opened with `Connection::open` (so it is created if absent) and its
/// contents are then overwritten page-by-page through SQLite's online backup API.
///
/// # Returns
/// The size in bytes of the file at `dest_path` once the backup has completed.
///
/// # Errors
/// - [`VecXError::Other`] if no pooled source connection can be obtained.
/// - [`VecXError::SqlError`] if the destination cannot be opened or the copy fails.
/// - [`VecXError::IoError`] if the finished file's metadata cannot be read.
pub fn backup_database(
    pool: &Pool<SqliteConnectionManager>,
    dest_path: &Path,
) -> Result<u64, VecXError> {
    let source = pool
        .get()
        .map_err(|err| VecXError::Other(format!("Failed to get connection for backup: {}", err)))?;
    let mut destination = Connection::open(dest_path).map_err(|err| {
        VecXError::SqlError(format!("Failed to open destination database: {}", err))
    })?;

    backup_connection(&source, &mut destination)?;

    std::fs::metadata(dest_path)
        .map(|meta| meta.len())
        .map_err(|err| VecXError::IoError(format!("Failed to get backup file size: {}", err)))
}
/// Copies every page of `source` into `dest` using SQLite's online backup API.
///
/// Pages are copied `PAGES_PER_STEP` at a time, sleeping `STEP_DELAY_MS` between steps.
/// When a step reports `Busy` or `Locked`, the loop waits ten times as long and retries.
/// Retries are bounded rather than unlimited, so a lock that is never released
/// surfaces as an error instead of hanging the caller forever.
///
/// # Errors
/// Returns [`VecXError::SqlError`] in three cases:
/// - the backup cannot be initialised;
/// - a step fails;
/// - more than `MAX_CONSECUTIVE_STALLS` steps in a row report `Busy`/`Locked`.
fn backup_connection(
    source: &rusqlite::Connection,
    dest: &mut rusqlite::Connection,
) -> Result<(), VecXError> {
    use rusqlite::backup::{Backup, StepResult};

    // Back-to-back Busy/Locked results tolerated before giving up. With the 10x back-off
    // below this is at least ~30s of sleeping, plus any time SQLite's own busy handler
    // spends inside `step`.
    const MAX_CONSECUTIVE_STALLS: u32 = 300;

    let backup = Backup::new(source, dest)
        .map_err(|e| VecXError::SqlError(format!("Failed to initialize backup: {}", e)))?;

    let mut stalls: u32 = 0;
    loop {
        let step_result = backup
            .step(PAGES_PER_STEP)
            .map_err(|e| VecXError::SqlError(format!("Backup step failed: {}", e)))?;

        let delay_ms = match step_result {
            StepResult::Done => return Ok(()),
            StepResult::More => {
                // Progress was made, so any earlier contention has cleared.
                stalls = 0;
                STEP_DELAY_MS
            }
            StepResult::Busy | StepResult::Locked => {
                stalls += 1;
                if stalls > MAX_CONSECUTIVE_STALLS {
                    return Err(VecXError::SqlError(format!(
                        "Backup aborted: database stayed busy/locked for {} consecutive steps",
                        stalls
                    )));
                }
                STEP_DELAY_MS * 10
            }
            // Catch-all kept for forward compatibility in case `StepResult` gains
            // variants (it may be `#[non_exhaustive]`); treated like a normal step.
            _ => STEP_DELAY_MS,
        };
        std::thread::sleep(Duration::from_millis(delay_ms));
    }
}
/// Produces an in-memory copy of the database behind `pool` as raw SQLite file bytes.
///
/// SQLite's backup API needs a destination connection. The backup is therefore written
/// to a scratch file in [`std::env::temp_dir`], read back, and the scratch file is deleted.
/// Cleanup happens on both the success and the failure path.
///
/// # Errors
/// Any error from [`backup_database`], or [`VecXError::IoError`] if the scratch file
/// cannot be read back.
pub fn backup_to_memory(pool: &Pool<SqliteConnectionManager>) -> Result<Vec<u8>, VecXError> {
    // Per-call sequence number. The process id alone is shared by every thread, so
    // concurrent calls would otherwise back up into, and delete, the same file.
    static NEXT_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    let call_id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let temp_path = std::env::temp_dir().join(format!(
        "vxlite_backup_{}_{}.db",
        std::process::id(),
        call_id
    ));

    let result = backup_database(pool, &temp_path).and_then(|_size| {
        std::fs::read(&temp_path)
            .map_err(|e| VecXError::IoError(format!("Failed to read backup file: {}", e)))
    });

    // Best-effort cleanup on every path. A leftover scratch file must not mask the
    // real result.
    let _ = std::fs::remove_file(&temp_path);
    result
}
pub fn restore_database(
backup_path: &Path,
dest_pool: &Pool<SqliteConnectionManager>,
) -> Result<(), VecXError> {
let backup_conn = Connection::open(backup_path).map_err(|e| {
VecXError::SqlError(format!("Failed to open backup file: {}", e))
})?;
let mut dest_conn = dest_pool.get().map_err(|e| {
VecXError::Other(format!("Failed to get destination connection: {}", e))
})?;
backup_connection(&backup_conn, &mut *dest_conn)?;
Ok(())
}
/// Restores the database behind `dest_pool` from raw SQLite file bytes.
///
/// `data` is written to a scratch file in [`std::env::temp_dir`] and restored through
/// [`restore_database`]. The scratch file is removed afterwards, whatever the outcome.
///
/// # Errors
/// - [`VecXError::Other`] if `data` is empty. A zero-length file is a valid *empty*
///   SQLite database, so restoring it would silently wipe the destination.
/// - [`VecXError::IoError`] if the scratch file cannot be written.
/// - Any error from [`restore_database`].
pub fn restore_from_memory(
    data: &[u8],
    dest_pool: &Pool<SqliteConnectionManager>,
) -> Result<(), VecXError> {
    if data.is_empty() {
        return Err(VecXError::Other(
            "Refusing to restore from empty backup data".to_string(),
        ));
    }

    // Per-call sequence number. The process id alone is shared by every thread, so
    // concurrent restores would otherwise overwrite and delete each other's file.
    static NEXT_ID: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    let call_id = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let temp_path = std::env::temp_dir().join(format!(
        "vxlite_restore_{}_{}.db",
        std::process::id(),
        call_id
    ));

    let result = std::fs::write(&temp_path, data)
        .map_err(|e| VecXError::IoError(format!("Failed to write temp restore file: {}", e)))
        .and_then(|()| restore_database(&temp_path, dest_pool));

    // Best-effort cleanup, including after a partial write.
    let _ = std::fs::remove_file(&temp_path);
    result
}
/// Lists the index files referenced by `vectorlite` virtual tables in the database.
///
/// Scans `sqlite_master` for table definitions mentioning `vectorlite` and extracts the
/// index-file argument from each with [`extract_index_path`]. Two kinds of definition
/// are skipped:
/// - those without a file argument;
/// - those whose path is empty or `:memory:`.
///
/// # Errors
/// - [`VecXError::Other`] if no pooled connection can be obtained.
/// - [`VecXError::SqlError`] if the catalog query cannot be prepared or run, or if a row
///   cannot be read. Row failures are propagated rather than skipped, so a caller
///   assembling a backup never silently misses an index file.
pub fn get_index_files(pool: &Pool<SqliteConnectionManager>) -> Result<Vec<String>, VecXError> {
    let conn = pool.get().map_err(|e| {
        VecXError::Other(format!("Failed to get connection: {}", e))
    })?;
    let mut stmt = conn
        .prepare(
            "SELECT sql FROM sqlite_master WHERE type='table' AND sql LIKE '%vectorlite%'",
        )
        .map_err(|e| VecXError::SqlError(format!("Failed to prepare query: {}", e)))?;

    // `sqlite_master.sql` is nullable. Reading it as `Option` means a NULL is skipped
    // below instead of surfacing as a type-conversion error.
    let sql_strings: Vec<Option<String>> = stmt
        .query_map([], |row| row.get::<_, Option<String>>(0))
        .map_err(|e| VecXError::SqlError(format!("Failed to query tables: {}", e)))?
        .collect::<Result<_, _>>()
        .map_err(|e| VecXError::SqlError(format!("Failed to read table definition: {}", e)))?;

    Ok(sql_strings
        .iter()
        .flatten()
        .filter_map(|sql| extract_index_path(sql))
        .filter(|path| !path.is_empty() && path.as_str() != ":memory:")
        .collect())
}
/// Extracts the index-file path from a `CREATE VIRTUAL TABLE ... USING vectorlite(...)`
/// statement.
///
/// The `vectorlite(...)` argument list is split on top-level commas. Commas and
/// parentheses nested inside options such as `hnsw(...)`, or inside quoted text, do not
/// split. Scanning from the last argument backwards, the first argument is returned,
/// with surrounding whitespace and quotes removed, if it:
/// - contains `/` or `\`, or
/// - ends with `.idx`.
///
/// Returns `None` in three cases:
/// - the `using vectorlite(` marker (ASCII case-insensitive) is absent;
/// - the argument list is never closed;
/// - no argument looks like a path.
fn extract_index_path(sql: &str) -> Option<String> {
    const MARKER: &str = "using vectorlite(";
    // `to_ascii_lowercase` (unlike `to_lowercase`) never changes byte lengths, so an
    // offset found in the lowered copy is also a valid char boundary in `sql`.
    let args_start = sql.to_ascii_lowercase().find(MARKER)? + MARKER.len();
    split_top_level_args(&sql[args_start..])?
        .into_iter()
        .rev()
        .map(|arg| arg.trim().trim_matches(|c| c == '\'' || c == '"'))
        .find(|arg| arg.contains('/') || arg.contains('\\') || arg.ends_with(".idx"))
        .map(String::from)
}

/// Splits an argument list whose opening `(` has already been consumed.
///
/// Scans `rest` up to the `)` that balances that opening parenthesis and splits on commas
/// at nesting depth zero. Parentheses and commas inside `'...'` or `"..."` quoted text are
/// ignored. A SQL doubled quote (`''`) simply closes and reopens the quote. Returns `None`
/// if the list is never closed.
fn split_top_level_args(rest: &str) -> Option<Vec<&str>> {
    let mut args = Vec::new();
    let mut depth = 0usize;
    let mut quote: Option<char> = None;
    let mut arg_start = 0;
    for (i, c) in rest.char_indices() {
        if let Some(q) = quote {
            if c == q {
                quote = None;
            }
            continue;
        }
        match c {
            '\'' | '"' => quote = Some(c),
            '(' => depth += 1,
            ')' if depth == 0 => {
                args.push(&rest[arg_start..i]);
                return Some(args);
            }
            ')' => depth -= 1,
            ',' if depth == 0 => {
                args.push(&rest[arg_start..i]);
                // ',' is one byte, so `i + 1` is the next char boundary.
                arg_start = i + 1;
            }
            _ => {}
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of `extract_index_path`: a trailing quoted path is found even
    /// after a nested `hnsw(...)` option, and a definition without a path yields `None`.
    #[test]
    fn test_extract_index_path() {
        // (CREATE statement, expected index path)
        let cases: [(&str, Option<&str>); 2] = [
            (
                "CREATE VIRTUAL TABLE vt_vector_test USING vectorlite(vector_embedding float32[128] cosine, hnsw(max_elements=100000), '/tmp/test.idx')",
                Some("/tmp/test.idx"),
            ),
            (
                "CREATE VIRTUAL TABLE vt_vector_test USING vectorlite(vector_embedding float32[128] cosine, hnsw(max_elements=100000))",
                None,
            ),
        ];
        for &(sql, expected) in cases.iter() {
            assert_eq!(extract_index_path(sql).as_deref(), expected, "sql: {}", sql);
        }
    }
}