use anyhow::{Context, Result};
use indicatif::{ProgressBar, ProgressStyle};
use log::{error, info};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::str;
use tokio::process::Command;
use crate::config::{get_backup_dir, get_mongodb_bin_path, MongoConfig};
/// Summary produced by [`verify_backup`] after inspecting an on-disk backup.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BackupVerification {
    /// Name of the database the backup was checked for.
    pub database: String,
    /// Backup folder that was inspected, rendered for display.
    pub backup_path: String,
    /// `backup_path/<database>`: the dump directory for this database, rendered for display.
    pub database_path: String,
    /// Whether the dump directory exists.
    pub exists: bool,
    /// Number of regular files found (recursively) in the dump directory.
    pub file_count: usize,
    /// Total size in bytes of those files.
    pub byte_count: u64,
    /// Overall verdict; `verify_backup` currently sets this equal to `exists`.
    pub verified: bool,
}
/// Validate a MongoDB database name before it is spliced into a filesystem path
/// or handed to the MongoDB command-line tools.
///
/// Rejects names MongoDB itself refuses (see the MongoDB "Naming Restrictions"
/// documentation):
/// - empty names;
/// - names of 64 bytes or more;
/// - names containing any character forbidden on Unix (`/ \ . " $` and space)
///   or on Windows (additionally `* < > : | ?`);
/// - names containing the null character.
///
/// # Errors
///
/// Returns an error describing the first rule the name violates.
pub fn validate_db_name(name: &str) -> Result<()> {
    // MongoDB requires database names to be *fewer* than 64 bytes.
    const MAX_DB_NAME_BYTES: usize = 63;
    // Union of the Unix and Windows restrictions so names are portable across
    // deployments; `$` is forbidden on every platform.
    const INVALID_CHARS: [char; 13] = [
        '/', '\\', '.', '"', '$', '*', '<', '>', ':', '|', '?', '\0', ' ',
    ];

    if name.is_empty() {
        anyhow::bail!("Database name cannot be empty");
    }
    // `len()` counts UTF-8 bytes, which is what MongoDB limits.
    if name.len() > MAX_DB_NAME_BYTES {
        anyhow::bail!("Database name too long (max {} bytes)", MAX_DB_NAME_BYTES);
    }
    if let Some(c) = name.chars().find(|c| INVALID_CHARS.contains(c)) {
        anyhow::bail!("Database name contains invalid character: '{}'", c);
    }
    Ok(())
}
/// Render a MongoDB connection string in a form that is safe to log.
///
/// If the string contains an `@`, everything before the *last* `@` is treated
/// as scheme plus user info. Only the scheme is kept (`mongodb` when no `://`
/// is present) and the credentials become `*****`. Splitting on the last `@`
/// means an unencoded `@` inside a password is still hidden.
///
/// Any query string is replaced by the literal `<params>`, so option values are
/// not echoed.
///
/// NOTE(review): if the query string itself contains `@`, the "host" shown is
/// only the text after that `@`.
pub fn mask_connection_string(uri: &str) -> String {
    match uri.rsplit_once('@') {
        Some((credentials, location)) => {
            let scheme = match credentials.split_once("://") {
                Some((scheme, _)) => scheme,
                None => "mongodb",
            };
            match location.split_once('?') {
                Some((host, _)) => format!("{scheme}://*****@{host}?<params>"),
                None => format!("{scheme}://*****@{location}"),
            }
        }
        None => match uri.split_once('?') {
            Some((base, _)) => format!("{base}?<params>"),
            None => uri.to_string(),
        },
    }
}
/// List the names of all databases visible through the connection described by `config`.
///
/// # Errors
///
/// Fails if the client options cannot be built, the client cannot be created,
/// or the server rejects the `listDatabases` call.
pub async fn list_databases(config: &MongoConfig) -> Result<Vec<String>> {
    let options = config.get_client_options().await?;
    let client = mongodb::Client::with_options(options)?;
    Ok(client.list_database_names().await?)
}
/// Dump `database` from the environment described by `config` into `output_dir`
/// using the external `mongodump` tool.
///
/// On success the dump lives in `output_dir/<database>`. If `mongodump` produced
/// no directory for it, the database is logged as empty and an empty placeholder
/// directory is created so later steps find the expected path.
///
/// # Errors
///
/// Fails if:
/// - the database name is invalid;
/// - the MongoDB tools cannot be located;
/// - `mongodump` cannot be spawned or exits unsuccessfully;
/// - the placeholder directory cannot be created.
pub async fn export_database(
    config: &MongoConfig,
    database: &str,
    output_dir: &Path,
) -> Result<()> {
    validate_db_name(database)?;
    info!(
        "Exporting database {} from {}",
        database, config.environment
    );
    let mut progress = create_progress_bar("Exporting");

    let bin_path = get_mongodb_bin_path().map_err(|e| {
        error!("Failed to find MongoDB tools: {}", e);
        anyhow::anyhow!("Failed to find mongodump")
    })?;
    let mongodump_path = bin_path.join("mongodump");
    info!("Using mongodump from: {}", mongodump_path.display());
    info!(
        "MongoDB connection string: {}",
        mask_connection_string(&config.connection_string)
    );

    // NOTE(review): the full URI (including any credentials) is passed on the
    // command line, where it may be visible in the local process list.
    let output = Command::new(mongodump_path)
        .arg("--uri")
        .arg(&config.connection_string)
        .arg("--db")
        .arg(database)
        .arg("--out")
        .arg(output_dir)
        .output()
        .await
        .context("Failed to execute mongodump")?;

    // Tool output is only logged/reported. Decode it lossily so a stray
    // non-UTF-8 byte can neither mask the real failure nor turn a successful
    // dump into an error.
    if !output.status.success() {
        progress.finish_with_message("Export failed");
        let stderr = String::from_utf8_lossy(&output.stderr);
        error!("Export failed: {}", stderr);
        anyhow::bail!("Export failed: {}", stderr);
    }
    progress.finish_with_message("Export completed");
    info!("Export output: {}", String::from_utf8_lossy(&output.stdout));

    let db_path = output_dir.join(database);
    if !db_path.exists() {
        info!(
            "Database '{}' appears to be empty, creating placeholder directory",
            database
        );
        std::fs::create_dir_all(&db_path)
            .context("Failed to create placeholder for empty database")?;
    }
    Ok(())
}
/// Restore `database` from a dump under `input_dir` into the environment
/// described by `config` using the external `mongorestore` tool.
///
/// Only namespaces matching `<database>.*` are restored.
/// - `drop`: passes `--drop` to `mongorestore`.
/// - `clear` (only when `drop` is false): every collection is emptied with
///   [`clear_collections`] first, and `--noIndexRestore` is passed.
///
/// The tool location and the dump directory (`input_dir/<database>`) are
/// checked *before* anything on the target is modified.
///
/// # Errors
///
/// Fails if:
/// - the database name is invalid;
/// - the tools or the dump directory are missing;
/// - clearing fails;
/// - `mongorestore` cannot be spawned or exits unsuccessfully.
pub async fn import_database(
    config: &MongoConfig,
    database: &str,
    input_dir: &Path,
    drop: bool,
    clear: bool,
) -> Result<()> {
    validate_db_name(database)?;
    info!("Importing database {} to {}", database, config.environment);

    // Resolve the tool and check the dump exists before touching the target.
    // Clearing first and then failing here would destroy data with nothing to
    // restore in its place.
    let bin_path = get_mongodb_bin_path().map_err(|e| {
        error!("Failed to find MongoDB tools: {}", e);
        anyhow::anyhow!("Failed to find mongorestore")
    })?;
    let mongorestore_path = bin_path.join("mongorestore");
    info!("Using mongorestore from: {}", mongorestore_path.display());
    let db_path = input_dir.join(database);
    if !db_path.exists() {
        error!("Database directory not found: {}", db_path.display());
        anyhow::bail!("Database directory not found: {}", db_path.display());
    }

    // With `--drop`, mongorestore drops each target collection itself, so a
    // manual clear is only done when dropping is disabled.
    let clear_first = clear && !drop;
    if clear_first {
        clear_collections(config, database).await?;
    }

    let mut progress = create_progress_bar("Importing");
    let mut command = Command::new(&mongorestore_path);
    command
        .arg("--uri")
        .arg(&config.connection_string)
        .arg("--nsInclude")
        .arg(format!("{}.*", database));
    if drop {
        command.arg("--drop");
    }
    if clear_first {
        // Clearing with `delete_many` keeps existing indexes; presumably that is
        // why index restore is skipped here — TODO confirm intent.
        command.arg("--noIndexRestore");
    }
    command.arg(input_dir);
    info!("Running restore with directory: {}", input_dir.display());
    let output = command
        .output()
        .await
        .context("Failed to execute mongorestore")?;

    // Lossy decoding: output is only logged/reported and must not mask the
    // real outcome of the restore.
    if !output.status.success() {
        progress.finish_with_message("Import failed");
        let stderr = String::from_utf8_lossy(&output.stderr);
        error!("Import failed: {}", stderr);
        anyhow::bail!("Import failed: {}", stderr);
    }
    progress.finish_with_message("Import completed");
    info!("Import output: {}", String::from_utf8_lossy(&output.stdout));
    Ok(())
}
pub async fn create_backup(config: &MongoConfig, database: &str) -> Result<std::path::PathBuf> {
info!(
"Creating backup of {} from {}",
database, config.environment
);
let backup_dir = get_backup_dir();
let timestamp = chrono::Utc::now().format("%Y%m%d%H%M%S");
let backup_path = backup_dir.join(format!("backup_{}_{}", database, timestamp));
std::fs::create_dir_all(&backup_path)?;
export_database(config, database, &backup_path).await?;
Ok(backup_path)
}
/// Inspect a backup folder and summarise the dump for `database` inside it.
///
/// Looks for `backup_path/<database>`. When it exists, every regular file
/// beneath it is counted (recursively) and the file sizes are summed. The
/// `verified` flag currently mirrors `exists`: the counts are reported but not
/// checked.
///
/// # Errors
///
/// Fails if `database` is not a valid name or the dump directory tree cannot
/// be read.
pub fn verify_backup(backup_path: &Path, database: &str) -> Result<BackupVerification> {
    validate_db_name(database)?;
    let database_path = backup_path.join(database);
    let exists = database_path.exists();

    let (file_count, byte_count) = if exists {
        let mut files = 0usize;
        let mut bytes = 0u64;
        accumulate_backup_stats(&database_path, &mut files, &mut bytes)?;
        (files, bytes)
    } else {
        (0, 0)
    };

    Ok(BackupVerification {
        database: database.to_owned(),
        backup_path: backup_path.display().to_string(),
        database_path: database_path.display().to_string(),
        exists,
        file_count,
        byte_count,
        verified: exists,
    })
}
/// Recursively walk `path`. For every regular file found, add one to
/// `file_count` and add its length to `byte_count`.
///
/// Entries that are neither directories nor regular files are skipped.
/// `DirEntry::metadata` does not traverse symlinks, so symlinks fall into that
/// skipped category.
///
/// # Errors
///
/// Fails if any directory in the tree cannot be read or an entry's metadata
/// cannot be obtained.
fn accumulate_backup_stats(
    path: &Path,
    file_count: &mut usize,
    byte_count: &mut u64,
) -> Result<()> {
    let entries = std::fs::read_dir(path)
        .with_context(|| format!("Failed to read backup directory {}", path.display()))?;
    for item in entries {
        let item = item?;
        let meta = item.metadata()?;
        if meta.is_file() {
            *file_count += 1;
            *byte_count += meta.len();
        } else if meta.is_dir() {
            accumulate_backup_stats(&item.path(), file_count, byte_count)?;
        }
    }
    Ok(())
}
/// Restore `database` from a backup folder such as one produced by [`create_backup`].
///
/// Delegates to [`import_database`] with `drop = true` and `clear = false`, so
/// `mongorestore` is run with `--drop` rather than clearing collections in place.
pub async fn restore_backup(
    config: &MongoConfig,
    database: &str,
    backup_path: &Path,
) -> Result<()> {
    info!("Restoring backup of {} to {}", database, config.environment);
    let (drop, clear) = (true, false);
    import_database(config, database, backup_path, drop, clear).await
}
pub async fn clear_collections(config: &MongoConfig, database: &str) -> Result<()> {
info!(
"Clearing all collections in database {} on {}",
database, config.environment
);
let mut progress = create_progress_bar("Clearing collections");
let client_options = config.get_client_options().await?;
let client = mongodb::Client::with_options(client_options)?;
let db = client.database(database);
let mut collections = db.list_collection_names().await?;
collections.retain(|name| !name.starts_with("system."));
for collection_name in collections {
let collection = db.collection::<mongodb::bson::Document>(&collection_name);
collection.delete_many(mongodb::bson::doc! {}).await?;
}
progress.finish_with_message("Collections cleared");
Ok(())
}
/// Owns a progress indicator for one long-running operation.
///
/// If the guard is dropped without [`ProgressGuard::finish_with_message`] having
/// been called (for example because an error propagated with `?`), its `Drop`
/// impl clears the indicator.
struct ProgressGuard {
    /// The indicatif bar; hidden when JSON output mode is active.
    pb: ProgressBar,
    /// Set once `finish_with_message` has run.
    finished: bool,
}
impl ProgressGuard {
    /// Start tracking an operation described by `message`.
    ///
    /// In JSON output mode (`crate::output::is_json()`) a hidden bar is used so
    /// nothing is drawn. Otherwise a green spinner reading
    /// "`message` in progress..." is shown and ticks every 100 ms.
    fn new(message: &str) -> Self {
        if crate::output::is_json() {
            return Self {
                pb: ProgressBar::hidden(),
                finished: false,
            };
        }
        let spinner = ProgressBar::new_spinner();
        let style = ProgressStyle::default_spinner()
            .template("{spinner:.green} {msg}")
            .expect("Invalid progress template - this is a bug");
        spinner.set_style(style);
        spinner.set_message(format!("{} in progress...", message));
        spinner.enable_steady_tick(std::time::Duration::from_millis(100));
        Self {
            pb: spinner,
            finished: false,
        }
    }

    /// Stop the indicator, leaving `msg` as its final text. Marks the guard as
    /// finished so dropping it afterwards does not clear that text.
    fn finish_with_message(&mut self, msg: &str) {
        self.pb.finish_with_message(msg.to_owned());
        self.finished = true;
    }
}
impl Drop for ProgressGuard {
    /// Clear the indicator of a guard that was never explicitly finished, e.g.
    /// when the tracked operation returned early with an error.
    fn drop(&mut self) {
        if self.finished {
            return;
        }
        self.pb.finish_and_clear();
    }
}
/// Create the progress guard used by the MongoDB operations in this module.
///
/// The guard must be kept alive for the duration of the operation. Dropping it
/// before it is finished clears the indicator.
#[must_use = "dropping the guard immediately clears the progress indicator"]
fn create_progress_bar(message: &str) -> ProgressGuard {
    ProgressGuard::new(message)
}