use std::path::Path;
use secrecy::{ExposeSecret, SecretBox};
use zeroize::Zeroizing;
use super::connection::Connection;
use super::error::{DbResult, Error};
/// Opens the database file at `path` and unlocks it with the 32-byte key `k_intermediate`.
///
/// `read_only` is forwarded unchanged to `Connection::open`. After keying (which also
/// checks the key by reading `sqlite_master`), the connection gets its standard PRAGMAs:
/// foreign keys on, WAL journaling, `synchronous = FULL` and `secure_delete = ON`.
///
/// # Errors
///
/// Fails if the file cannot be opened, if key verification fails, or if any of the
/// connection PRAGMAs fails.
pub fn open_encrypted(
    path: &Path,
    k_intermediate: &SecretBox<[u8; 32]>,
    read_only: bool,
) -> DbResult<Connection> {
    let connection = Connection::open(path, read_only)?;
    apply_key(&connection, k_intermediate)
        .and_then(|()| configure_connection(&connection))
        .map(|()| connection)
}
/// Keys the SQLCipher connection with `k_intermediate`, then verifies that the key works.
///
/// The key is hex-encoded and supplied as a raw key (`PRAGMA key = "x'<hex>'";`). A read
/// of `sqlite_master` follows, so a wrong key surfaces here as an error instead of on
/// some later query.
///
/// The hex string and the PRAGMA text are both held in [`Zeroizing`] buffers, which are
/// wiped on drop.
///
/// # Errors
///
/// Returns the error from executing the key PRAGMA. If the verification read fails, it
/// returns an error with the same code and the message prefixed by
/// `"encryption key verification failed (is the key correct?): "`.
fn apply_key(conn: &Connection, k_intermediate: &SecretBox<[u8; 32]>) -> DbResult<()> {
    const PRAGMA_PREFIX: &str = "PRAGMA key = \"x'";
    const PRAGMA_SUFFIX: &str = "'\";";

    let key_hex = Zeroizing::new(hex::encode(k_intermediate.expose_secret()));

    // Build the statement in a buffer sized exactly up front. `format!` grows its output
    // while writing, and every reallocation frees the previous buffer (which may already
    // contain key hex) without zeroizing it. With enough capacity reserved there is only
    // ever one buffer, and `Zeroizing` wipes that one on drop.
    let mut pragma = Zeroizing::new(String::with_capacity(
        PRAGMA_PREFIX.len() + key_hex.len() + PRAGMA_SUFFIX.len(),
    ));
    pragma.push_str(PRAGMA_PREFIX);
    pragma.push_str(key_hex.as_str());
    pragma.push_str(PRAGMA_SUFFIX);
    conn.execute_batch_zeroized(&pragma)?;

    // Touch the schema so that a wrong key fails now, with a descriptive message.
    conn.execute_batch("SELECT count(*) FROM sqlite_master;")
        .map_err(|e| {
            Error::new(
                e.code.0,
                format!(
                    "encryption key verification failed (is the key correct?): {}",
                    e.message
                ),
            )
        })?;
    Ok(())
}
/// Applies the PRAGMAs every encrypted connection runs with.
///
/// These are: foreign-key enforcement, WAL journaling, `synchronous = FULL`, and
/// `secure_delete`, which makes SQLite overwrite deleted content.
fn configure_connection(conn: &Connection) -> DbResult<()> {
    const CONNECTION_PRAGMAS: &str = concat!(
        "PRAGMA foreign_keys = ON;\n",
        "PRAGMA journal_mode = WAL;\n",
        "PRAGMA synchronous = FULL;\n",
        "PRAGMA secure_delete = ON;",
    );
    conn.execute_batch(CONNECTION_PRAGMAS)
}
pub fn export_plaintext_copy(
conn: &Connection,
dest_path: &Path,
tables: &[&str],
) -> DbResult<()> {
let dest_str = dest_path.to_string_lossy();
let attach_sql = format!(
"ATTACH DATABASE '{}' AS backup KEY '';",
dest_str.replace('\'', "''")
);
conn.execute_batch(&attach_sql)?;
let result = (|| {
let tx = conn.transaction()?;
for table in tables {
tx.execute_batch(&format!(
"CREATE TABLE backup.{table} AS SELECT * FROM {table};"
))?;
}
tx.commit()
})();
let detach_result = conn.execute_batch("DETACH DATABASE backup;");
result?;
detach_result?;
Ok(())
}
pub fn import_plaintext_copy(
conn: &Connection,
source_path: &Path,
tables: &[&str],
) -> DbResult<()> {
if !source_path.exists() {
return Err(Error::new(
-1,
format!("backup file does not exist: {}", source_path.display()),
));
}
let source_str = source_path.to_string_lossy();
let attach_sql = format!(
"ATTACH DATABASE '{}' AS backup KEY '';",
source_str.replace('\'', "''")
);
conn.execute_batch(&attach_sql)?;
let result = (|| {
for table in tables {
let count: i64 =
conn.query_row(&format!("SELECT COUNT(*) FROM {table}"), &[], |row| {
Ok(row.column_i64(0))
})?;
if count > 0 {
return Err(Error::new(
-1,
format!("cannot import into non-empty table: {table}"),
));
}
}
let tx = conn.transaction()?;
for table in tables {
tx.execute_batch(&format!(
"INSERT INTO {table} SELECT * FROM backup.{table};"
))?;
}
tx.commit()
})();
let detach_result = conn.execute_batch("DETACH DATABASE backup;");
result?;
detach_result?;
Ok(())
}
/// Runs `PRAGMA integrity_check` and reports whether the database is healthy.
///
/// Returns `true` only when the first row of the PRAGMA result is `ok`, after trimming
/// surrounding whitespace. Any other first row (a problem description) yields `false`.
///
/// # Errors
///
/// Propagates any error from executing the PRAGMA.
pub fn integrity_check(conn: &Connection) -> DbResult<bool> {
    Ok(conn
        .query_row("PRAGMA integrity_check;", &[], |row| Ok(row.column_text(0)))?
        .trim()
        == "ok")
}
#[cfg(test)]
mod tests {
    use super::{
        export_plaintext_copy, import_plaintext_copy, integrity_check, open_encrypted,
    };
    use crate::params;
    use crate::sqlite::Connection;
    use crate::test_utils::init_sqlite;
    use secrecy::SecretBox;

    /// Schema shared by the source and restore databases in the export/import tests.
    const WIDGETS_DDL: &str =
        "CREATE TABLE widgets (id INTEGER PRIMARY KEY, val TEXT NOT NULL);";
    /// Parameterised insert used to seed `widgets` rows.
    const INSERT_WIDGET: &str = "INSERT INTO widgets (id, val) VALUES (?1, ?2)";

    /// Data written under one key is readable after reopening with that key, and the
    /// file refuses to open with a different key.
    #[test]
    fn test_cipher_encrypted_round_trip() {
        init_sqlite();
        let tmp = tempfile::tempdir().expect("create temp dir");
        let db_path = tmp.path().join("cipher-test.sqlite");
        let key = SecretBox::init_with(|| [0xABu8; 32]);

        // Write session: create and populate the encrypted file.
        let writer = open_encrypted(&db_path, &key, false).expect("open encrypted");
        writer
            .execute_batch("CREATE TABLE secret (id INTEGER PRIMARY KEY, val TEXT);")
            .expect("create table");
        writer
            .execute("INSERT INTO secret (id, val) VALUES (1, 'top-secret')", &[])
            .expect("insert");
        drop(writer);

        // Read session: the same key must decrypt the row written above.
        let reader = open_encrypted(&db_path, &key, false).expect("reopen encrypted");
        let stored = reader
            .query_row("SELECT val FROM secret WHERE id = 1", &[], |row| {
                Ok(row.column_text(0))
            })
            .expect("query");
        assert_eq!(stored, "top-secret");
        drop(reader);

        // A different key must be rejected at open time.
        let wrong_key = SecretBox::init_with(|| [0xCDu8; 32]);
        assert!(
            open_encrypted(&db_path, &wrong_key, false).is_err(),
            "wrong key should fail"
        );
    }

    /// A freshly opened in-memory database passes the integrity check.
    #[test]
    fn test_integrity_check() {
        init_sqlite();
        let conn = Connection::open_in_memory().expect("open in-memory db");
        assert!(integrity_check(&conn).expect("check"));
    }

    /// Rows exported to a plaintext file can be imported into a new encrypted database.
    #[test]
    fn test_cipher_plaintext_export_import_roundtrip() {
        init_sqlite();
        let tmp = tempfile::tempdir().expect("create temp dir");
        let src_path = tmp.path().join("source.sqlite");
        let plain_path = tmp.path().join("backup.plain.sqlite");
        let restore_path = tmp.path().join("restore.sqlite");
        let key = SecretBox::init_with(|| [0x11u8; 32]);

        let src = open_encrypted(&src_path, &key, false).expect("open src");
        src.execute_batch(WIDGETS_DDL).expect("create table");
        src.execute(INSERT_WIDGET, params![1_i64, "alpha"])
            .expect("insert");
        src.execute(INSERT_WIDGET, params![2_i64, "beta"])
            .expect("insert");
        export_plaintext_copy(&src, &plain_path, &["widgets"]).expect("export");
        drop(src);

        let restored = open_encrypted(&restore_path, &key, false).expect("open restore");
        restored.execute_batch(WIDGETS_DDL).expect("create table");
        import_plaintext_copy(&restored, &plain_path, &["widgets"]).expect("import");

        let rows: i64 = restored
            .query_row("SELECT COUNT(*) FROM widgets", &[], |row| {
                Ok(row.column_i64(0))
            })
            .expect("count");
        assert_eq!(rows, 2);

        let second = restored
            .query_row("SELECT val FROM widgets WHERE id = 2", &[], |row| {
                Ok(row.column_text(0))
            })
            .expect("query");
        assert_eq!(second, "beta");
    }

    /// Import refuses to merge into a destination table that already holds rows.
    #[test]
    fn test_cipher_import_rejects_non_empty_destination() {
        init_sqlite();
        let tmp = tempfile::tempdir().expect("create temp dir");
        let src_path = tmp.path().join("source.sqlite");
        let plain_path = tmp.path().join("backup.plain.sqlite");
        let restore_path = tmp.path().join("restore.sqlite");
        let key = SecretBox::init_with(|| [0x22u8; 32]);

        let src = open_encrypted(&src_path, &key, false).expect("open src");
        src.execute_batch(WIDGETS_DDL).expect("create table");
        src.execute(INSERT_WIDGET, params![1_i64, "alpha"])
            .expect("insert");
        export_plaintext_copy(&src, &plain_path, &["widgets"]).expect("export");
        drop(src);

        let restored = open_encrypted(&restore_path, &key, false).expect("open restore");
        restored.execute_batch(WIDGETS_DDL).expect("create table");
        restored
            .execute(INSERT_WIDGET, params![99_i64, "preexisting"])
            .expect("insert");

        let err = import_plaintext_copy(&restored, &plain_path, &["widgets"])
            .expect_err("import should refuse non-empty destination");
        let message = err.to_string();
        assert!(
            message.contains("non-empty table"),
            "expected non-empty-table error, got: {err}"
        );
    }
}