use crate::{AzothDb, AzothError, BackupManifest, ProjectionStore, Result};
use std::fs::File;
use std::io::{Read, Write};
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
/// Key material used to encrypt and decrypt backup files with `age`.
///
/// Wraps a full age X25519 identity. The matching public recipient is derived
/// from it on demand, so one value serves both encryption and decryption.
#[derive(Clone)]
pub struct EncryptionKey {
// The age X25519 identity. Its string form is a secret and is only exposed
// through `to_identity_string`.
identity: age::x25519::Identity,
}
impl EncryptionKey {
    /// Creates a key backed by a freshly generated age X25519 identity.
    pub fn generate() -> Self {
        Self {
            identity: age::x25519::Identity::generate(),
        }
    }

    /// Returns the identity in its string form.
    ///
    /// The returned `String` holds the exposed secret, so treat it as sensitive.
    pub fn to_identity_string(&self) -> String {
        use age::secrecy::ExposeSecret;

        let secret = self.identity.to_string();
        secret.expose_secret().to_string()
    }

    /// Returns the public recipient derived from the identity.
    pub fn to_recipient(&self) -> age::x25519::Recipient {
        let public = self.identity.to_public();
        public
    }

    /// Returns the string form of the public recipient derived from the identity.
    pub fn to_recipient_string(&self) -> String {
        let public = self.identity.to_public();
        public.to_string()
    }

    /// Validates `s` as an age recipient, then always fails.
    ///
    /// An `EncryptionKey` holds a full identity, which cannot be rebuilt from a
    /// recipient string alone.
    ///
    /// # Errors
    ///
    /// * `AzothError::Config("Invalid age recipient: …")` if `s` does not parse
    ///   as a recipient.
    /// * `AzothError::Config` pointing the caller at `from_str()` in every
    ///   other case.
    pub fn from_recipient_str(s: &str) -> Result<Self> {
        match s.parse::<age::x25519::Recipient>() {
            Err(e) => Err(AzothError::Config(format!(
                "Invalid age recipient: {}",
                e
            ))),
            Ok(_recipient) => Err(AzothError::Config(
                "Use from_str() with full identity for encryption/decryption".into(),
            )),
        }
    }
}
impl FromStr for EncryptionKey {
    type Err = AzothError;

    /// Parses the string form of an age X25519 identity into a key.
    ///
    /// # Errors
    ///
    /// Returns `AzothError::Config` when `s` is not a valid age identity.
    fn from_str(s: &str) -> Result<Self> {
        match s.parse::<age::x25519::Identity>() {
            Ok(identity) => Ok(Self { identity }),
            Err(e) => Err(AzothError::Config(format!(
                "Invalid age identity: {}",
                e
            ))),
        }
    }
}
/// Options controlling how a backup is written and how it is read back.
///
/// The same options must be supplied to restore as were used for the backup,
/// because they decide which on-disk artifact names are looked up.
#[derive(Clone)]
pub struct BackupOptions {
    /// Key used to encrypt (backup) or decrypt (restore) the artifacts.
    /// `None` disables encryption.
    pub encryption: Option<EncryptionKey>,
    /// Whether the backup artifacts are zstd-compressed.
    pub compression: bool,
    /// Compression level handed to zstd. `with_compression_level` caps it at 9.
    pub compression_level: u32,
}

impl Default for BackupOptions {
    /// Same as [`BackupOptions::new`].
    ///
    /// A derived `Default` would yield `compression_level: 0`, silently
    /// disagreeing with `new()`, so the two are kept in lock-step here.
    fn default() -> Self {
        Self::new()
    }
}

impl BackupOptions {
    /// Creates options with no encryption, no compression and level 6.
    pub fn new() -> Self {
        Self {
            encryption: None,
            compression: false,
            compression_level: 6,
        }
    }

    /// Enables encryption with `key`.
    pub fn with_encryption(mut self, key: EncryptionKey) -> Self {
        self.encryption = Some(key);
        self
    }

    /// Turns compression on or off.
    pub fn with_compression(mut self, enabled: bool) -> Self {
        self.compression = enabled;
        self
    }

    /// Sets the compression level; values above 9 are clamped to 9.
    pub fn with_compression_level(mut self, level: u32) -> Self {
        self.compression_level = level.min(9);
        self
    }

    /// Returns `true` when an encryption key is configured.
    pub fn is_encrypted(&self) -> bool {
        self.encryption.is_some()
    }
}
impl AzothDb {
/// Writes a backup of this database into `dir`, optionally compressing and/or
/// encrypting the artifacts as requested by `options`.
///
/// Steps, in order:
/// 1. create `dir`, pause ingestion and seal the canonical store;
/// 2. run the projector until its lag is zero;
/// 3. copy the canonical store to `<dir>/canonical` and the projection to
///    `<dir>/projection.db`;
/// 4. if `options.compression`, replace them with `canonical.tar.zst` and
///    `projection.db.zst`;
/// 5. if `options.encryption` is set, replace the current artifacts with
///    `.age` files;
/// 6. write `<dir>/manifest.json`.
///
/// # Errors
///
/// Returns the first error raised by any step. Once ingestion has been
/// paused, the seal is cleared and ingestion is resumed on every return
/// path (see `IngestionGuard` below).
pub fn backup_with_options<P: AsRef<Path>>(
&self,
dir: P,
options: &BackupOptions,
) -> Result<()> {
use crate::CanonicalStore;
let backup_dir = dir.as_ref();
std::fs::create_dir_all(backup_dir)?;
let canonical = self.canonical().clone();
// Scope guard: when dropped it clears the seal and resumes ingestion.
// Failures in the destructor are only logged because `drop` cannot
// return an error.
struct IngestionGuard {
canonical: Arc<crate::LmdbCanonicalStore>,
}
impl Drop for IngestionGuard {
fn drop(&mut self) {
if let Err(e) = self.canonical.clear_seal() {
tracing::error!("Failed to clear seal after backup: {}", e);
}
if let Err(e) = self.canonical.resume_ingestion() {
tracing::error!("Failed to resume ingestion after backup: {}", e);
}
}
}
// The guard is created only after the pause succeeded, so a failed pause
// does not trigger a clear/resume.
canonical.pause_ingestion()?;
let _guard = IngestionGuard { canonical };
let sealed_id = self.canonical().seal()?;
tracing::info!("Sealed canonical at event {}", sealed_id);
// Drain the projector before the stores are copied.
// NOTE(review): this loops until the lag reaches zero; it presumably relies
// on `run_once` always making progress — confirm.
while self.projector().get_lag()? > 0 {
self.projector().run_once()?;
}
tracing::info!("Projector caught up");
let canonical_dir = backup_dir.join("canonical");
self.canonical().backup_to(&canonical_dir)?;
let projection_path = backup_dir.join("projection.db");
self.projection().backup_to(&projection_path)?;
// Compression runs before encryption; both helpers delete their input
// after writing the output.
if options.compression {
tracing::info!("Compressing backup...");
compress_directory(&canonical_dir, options.compression_level)?;
compress_file(&projection_path, options.compression_level)?;
}
if let Some(ref key) = options.encryption {
tracing::info!("Encrypting backup...");
if options.compression {
// After compression the artifacts are the two archives written above.
let canonical_archive = canonical_dir.with_extension("tar.zst");
let projection_archive = projection_path.with_extension("db.zst");
encrypt_file(&canonical_archive, key)?;
encrypt_file(&projection_archive, key)?;
} else {
// Uncompressed: encrypt each file directly inside `canonical/`, plus
// the projection file.
encrypt_backup(&canonical_dir, key)?;
encrypt_file(&projection_path, key)?;
}
}
let cursor = self.projection().get_cursor()?;
// NOTE(review): the literal `1` is presumably a manifest/format version;
// confirm against `BackupManifest::new`.
let manifest = BackupManifest::new(
sealed_id,
"lmdb".to_string(),
"sqlite".to_string(),
cursor,
1,
self.projection().schema_version()?,
);
// The manifest is written last, after all artifacts are in their final form.
let manifest_path = backup_dir.join("manifest.json");
let manifest_json = serde_json::to_string_pretty(&manifest)
.map_err(|e| AzothError::Serialization(e.to_string()))?;
std::fs::write(&manifest_path, manifest_json)?;
tracing::info!("Backup complete at {}", backup_dir.display());
Ok(())
}
/// Restores a database from `backup_dir` into `target_path`, undoing the
/// encryption and/or compression described by `options` first.
///
/// `options` must match the options used to create the backup: they decide
/// which artifact names (`*.age`, `*.zst`) are looked up. Nothing passed to
/// `BackupManifest::new` at backup time records them.
///
/// NOTE(review): decryption and decompression happen in place inside
/// `backup_dir`, and the helpers delete their inputs, so after this call the
/// backup directory holds the plain artifacts rather than the encrypted or
/// compressed ones.
///
/// # Errors
///
/// Fails if `manifest.json` is missing or cannot be parsed, if an expected
/// artifact is missing or cannot be decrypted/decompressed, or if
/// `Self::restore_from` fails.
pub fn restore_with_options<P: AsRef<Path>, Q: AsRef<Path>>(
backup_dir: P,
target_path: Q,
options: &BackupOptions,
) -> Result<Self> {
let backup_dir = backup_dir.as_ref();
let target_path = target_path.as_ref();
let manifest_path = backup_dir.join("manifest.json");
let manifest_json = std::fs::read_to_string(&manifest_path)?;
// Parsed only to validate that the manifest is well-formed; the value is
// not used further in this function.
let _manifest: BackupManifest = serde_json::from_str(&manifest_json)
.map_err(|e| AzothError::Serialization(e.to_string()))?;
let canonical_dir = backup_dir.join("canonical");
let projection_file = backup_dir.join("projection.db");
if options.compression {
// Reverse of the backup order: decrypt first, then decompress.
if let Some(ref key) = options.encryption {
tracing::info!("Decrypting backup artifacts...");
let canonical_archive_age = canonical_dir.with_extension("tar.zst.age");
let projection_file_age = projection_file.with_extension("db.zst.age");
decrypt_file(&canonical_archive_age, key)?;
decrypt_file(&projection_file_age, key)?;
}
tracing::info!("Decompressing backup artifacts...");
decompress_directory(&canonical_dir)?;
decompress_file(&projection_file)?;
} else if let Some(ref key) = options.encryption {
tracing::info!("Decrypting backup...");
decrypt_backup(&canonical_dir, key)?;
let projection_file_age = projection_file.with_extension("db.age");
decrypt_file(&projection_file_age, key)?;
}
Self::restore_from(backup_dir, target_path)
}
}
/// Encrypts every regular file directly inside `dir` (non-recursive).
///
/// Each file is replaced by its `.age` counterpart via `encrypt_file`.
/// Subdirectories and other non-file entries are left untouched.
///
/// # Errors
///
/// Returns the first I/O error from listing the directory or the first error
/// from `encrypt_file`. Files encrypted before the failure stay encrypted.
fn encrypt_backup(dir: &Path, key: &EncryptionKey) -> Result<()> {
    // Snapshot the file list before touching anything. `encrypt_file` creates
    // and deletes entries in this very directory, and whether a live directory
    // iterator reports entries added mid-iteration is unspecified; without the
    // snapshot a freshly written `.age` file could be encrypted a second time.
    let mut targets = Vec::new();
    for entry in std::fs::read_dir(dir)? {
        let path = entry?.path();
        if path.is_file() {
            targets.push(path);
        }
    }

    for path in &targets {
        encrypt_file(path, key)?;
    }
    Ok(())
}
/// Encrypts the file at `path` for the recipient derived from `key`.
///
/// The whole file is read into memory, encrypted with `age`, written next to
/// the original as `<file name>.age`, and the original is then removed.
///
/// The `.age` suffix is appended to the complete file name, so stripping the
/// last extension (as `decrypt_file` does) restores the original name exactly,
/// including for files without an extension or with a non-UTF-8 one. For names
/// with a UTF-8 extension the result is unchanged (`x.db` becomes `x.db.age`).
///
/// # Errors
///
/// Returns an I/O error if the file cannot be read, written or removed, and
/// `AzothError::Encryption` if the age stream cannot be set up, written or
/// finished.
///
/// # Panics
///
/// Panics only if `age` rejects the recipient list, which always holds exactly
/// one recipient here.
fn encrypt_file(path: &Path, key: &EncryptionKey) -> Result<()> {
    let plaintext = std::fs::read(path)?;
    let recipient = key.to_recipient();
    let encryptor = age::Encryptor::with_recipients(vec![Box::new(recipient)])
        .expect("We provided a recipient");

    let mut encrypted = vec![];
    let mut writer = encryptor
        .wrap_output(&mut encrypted)
        .map_err(|e| AzothError::Encryption(format!("Failed to wrap output: {}", e)))?;
    writer
        .write_all(&plaintext)
        .map_err(|e| AzothError::Encryption(format!("Failed to write: {}", e)))?;
    // `finish` writes the final chunk; without it the ciphertext is truncated.
    writer
        .finish()
        .map_err(|e| AzothError::Encryption(format!("Failed to finish: {}", e)))?;

    // Append ".age" to the full name rather than rewriting the extension, so
    // the original name round-trips whatever its extension looks like.
    let mut encrypted_name = path.as_os_str().to_os_string();
    encrypted_name.push(".age");
    let encrypted_path = std::path::PathBuf::from(encrypted_name);

    std::fs::write(&encrypted_path, encrypted)?;
    // Remove the plaintext only after the ciphertext has been written.
    std::fs::remove_file(path)?;
    Ok(())
}
/// Decrypts every regular `*.age` file directly inside `dir` (non-recursive).
///
/// Entries that are not regular files, or whose extension is not `age`, are
/// skipped. Each match is handed to `decrypt_file`.
///
/// # Errors
///
/// Returns the first I/O error from listing the directory or the first error
/// from `decrypt_file`.
fn decrypt_backup(dir: &Path, key: &EncryptionKey) -> Result<()> {
    for dir_entry in std::fs::read_dir(dir)? {
        let candidate = dir_entry?.path();
        if !candidate.is_file() {
            continue;
        }
        let is_age_file = match candidate.extension().and_then(|ext| ext.to_str()) {
            Some("age") => true,
            _ => false,
        };
        if is_age_file {
            decrypt_file(&candidate, key)?;
        }
    }
    Ok(())
}
/// Decrypts the age-encrypted file at `path` with the identity held by `key`.
///
/// The plaintext is written next to the input under the input's name with its
/// last extension removed (`x.db.age` becomes `x.db`), after which the
/// encrypted input is deleted. The whole file is processed in memory.
///
/// # Errors
///
/// Returns an I/O error if the file cannot be read, written or removed, and
/// `AzothError::Encryption` if the header cannot be parsed, the file is not
/// recipient-encrypted, the key does not match, or the stream cannot be read.
fn decrypt_file(path: &Path, key: &EncryptionKey) -> Result<()> {
    let ciphertext = std::fs::read(path)?;

    let parsed = age::Decryptor::new(ciphertext.as_slice())
        .map_err(|e| AzothError::Encryption(format!("Failed to create decryptor: {}", e)))?;
    // Only recipient-based encryption is produced by this module; any other
    // kind of age file is rejected.
    let recipients_decryptor = if let age::Decryptor::Recipients(inner) = parsed {
        inner
    } else {
        return Err(AzothError::Encryption(
            "Unexpected decryptor type".to_string(),
        ));
    };

    let identities = std::iter::once(&key.identity as &dyn age::Identity);
    let mut reader = recipients_decryptor
        .decrypt(identities)
        .map_err(|e| AzothError::Encryption(format!("Failed to decrypt: {}", e)))?;

    let mut plaintext = Vec::new();
    reader
        .read_to_end(&mut plaintext)
        .map_err(|e| AzothError::Encryption(format!("Failed to read: {}", e)))?;

    // Drop the trailing extension to recover the pre-encryption name.
    let restored_path = match path.file_stem() {
        Some(stem) => path.with_file_name(stem),
        None => path.with_extension(""),
    };

    std::fs::write(&restored_path, plaintext)?;
    std::fs::remove_file(path)?;
    Ok(())
}
/// Packs `dir` into a zstd-compressed tar archive and removes the directory.
///
/// The archive is written next to `dir` as `dir.with_extension("tar.zst")`,
/// and its entries are rooted at the directory's own name.
///
/// # Errors
///
/// Returns an I/O error if the archive cannot be created, written or finished,
/// or if the directory cannot be removed. Returns `AzothError::Config` if
/// `dir` has no final path component.
fn compress_directory(dir: &Path, level: u32) -> Result<()> {
    let target = dir.with_extension("tar.zst");
    let zstd_writer = zstd::Encoder::new(File::create(&target)?, level as i32)?;
    let mut archive = tar::Builder::new(zstd_writer);

    let root_name = match dir.file_name() {
        Some(name) => name,
        None => {
            return Err(AzothError::Config(
                "Cannot get directory name".to_string(),
            ))
        }
    };
    archive
        .append_dir_all(root_name, dir)
        .map_err(AzothError::Io)?;

    // Finish the tar stream first, then the zstd frame that wraps it.
    let zstd_writer = archive.into_inner().map_err(AzothError::Io)?;
    zstd_writer.finish()?;

    // Remove the source only after the archive is complete.
    std::fs::remove_dir_all(dir)?;
    Ok(())
}
/// Compresses the file at `path` with zstd and removes the original.
///
/// The output keeps the original extension and gains `.zst`
/// (`x.db` becomes `x.db.zst`). When the extension is missing or not valid
/// UTF-8, `dat` is used in its place. The whole file is processed in memory.
///
/// # Errors
///
/// Returns an I/O error if the file cannot be read, compressed, written or
/// removed.
fn compress_file(path: &Path, level: u32) -> Result<()> {
    let raw = std::fs::read(path)?;
    let packed = zstd::encode_all(raw.as_slice(), level as i32)?;

    let original_ext = match path.extension().and_then(|ext| ext.to_str()) {
        Some(ext) => ext,
        None => "dat",
    };
    let target = path.with_extension(format!("{}.zst", original_ext));

    std::fs::write(&target, packed)?;
    // Remove the source only after the compressed copy has been written.
    std::fs::remove_file(path)?;
    Ok(())
}
/// Restores the directory `dir_path` from its sibling `*.tar.zst` archive.
///
/// The archive at `dir_path.with_extension("tar.zst")` is unpacked into the
/// parent of `dir_path` and then deleted.
///
/// # Errors
///
/// Returns `AzothError::Io` with `NotFound` if the archive does not exist, an
/// I/O error if it cannot be opened, unpacked or removed, and
/// `AzothError::Config` if `dir_path` has no parent.
fn decompress_directory(dir_path: &Path) -> Result<()> {
    let source = dir_path.with_extension("tar.zst");
    if !source.exists() {
        let not_found = std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("Archive not found: {}", source.display()),
        );
        return Err(AzothError::Io(not_found));
    }

    let zstd_reader = zstd::Decoder::new(File::open(&source)?)?;
    let mut archive = tar::Archive::new(zstd_reader);

    // Entries are rooted at the directory name, so unpack into the parent.
    let destination = match dir_path.parent() {
        Some(parent) => parent,
        None => {
            return Err(AzothError::Config(
                "Cannot get parent directory".to_string(),
            ))
        }
    };
    archive.unpack(destination).map_err(AzothError::Io)?;

    std::fs::remove_file(&source)?;
    Ok(())
}
/// Decompresses a zstd-compressed file and removes the compressed input.
///
/// `path` may be given in two forms:
/// * the compressed file itself (extension `zst`): the output is that path
///   with the `.zst` suffix removed;
/// * the desired output path: the compressed input is looked up under the
///   name `compress_file` produces for it (`<ext>.zst`, or `dat.zst` when the
///   extension is missing or not UTF-8), and the output is written to `path`
///   itself.
///
/// Writing to `path` in the second form matters for targets without an
/// extension: deriving the output from the archive name would yield `foo.dat`
/// instead of the requested `foo`.
///
/// # Errors
///
/// Returns `AzothError::Io` with `NotFound` if the compressed file does not
/// exist, or an I/O error if it cannot be read, decoded, written or removed.
fn decompress_file(path: &Path) -> Result<()> {
    let given_archive = path.extension().and_then(|s| s.to_str()) == Some("zst");

    let compressed_path = if given_archive {
        path.to_path_buf()
    } else {
        // Mirror the naming used by `compress_file`.
        path.with_extension(format!(
            "{}.zst",
            path.extension().and_then(|s| s.to_str()).unwrap_or("dat")
        ))
    };
    if !compressed_path.exists() {
        return Err(AzothError::Io(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("Compressed file not found: {}", compressed_path.display()),
        )));
    }

    let compressed = std::fs::read(&compressed_path)?;
    let decompressed = zstd::decode_all(&compressed[..])?;

    let original_path = if given_archive {
        // Strip the trailing ".zst" from the archive name.
        match compressed_path.file_stem() {
            Some(stem) => compressed_path.with_file_name(stem),
            None => path.to_path_buf(),
        }
    } else {
        // The caller named the target explicitly; honour it exactly.
        path.to_path_buf()
    };

    std::fs::write(&original_path, decompressed)?;
    std::fs::remove_file(&compressed_path)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated key survives a round trip through its identity string.
    #[test]
    fn test_encryption_key_generation() {
        let original = EncryptionKey::generate();
        let serialized = original.to_identity_string();
        assert!(!serialized.is_empty());

        let round_tripped: EncryptionKey = serialized.parse().unwrap();
        assert_eq!(round_tripped.to_identity_string(), serialized);
    }

    /// The recipient string carries the `age1` prefix.
    #[test]
    fn test_encryption_key_recipient() {
        let recipient = EncryptionKey::generate().to_recipient_string();
        assert!(recipient.starts_with("age1"));
    }

    /// Builder methods set the corresponding fields.
    #[test]
    fn test_backup_options() {
        let key = EncryptionKey::generate();
        let configured = BackupOptions::new().with_encryption(key);
        let configured = configured.with_compression(true);
        let configured = configured.with_compression_level(9);

        assert!(configured.is_encrypted());
        assert!(configured.compression);
        assert_eq!(configured.compression_level, 9);
    }
}