use amplify::s;
use chacha20poly1305::aead::{generic_array::GenericArray, stream};
use chacha20poly1305::{Key, KeyInit, XChaCha20Poly1305};
use rand::{distributions::Alphanumeric, Rng};
use scrypt::password_hash::{PasswordHasher, Salt, SaltString};
use scrypt::{Params, Scrypt};
use sea_orm::ActiveValue;
use serde::{Deserialize, Serialize};
use slog::Logger;
use tempfile::TempDir;
use typenum::consts::U32;
use walkdir::WalkDir;
use zip::write::FileOptions;
use std::fs::{create_dir_all, read_to_string, remove_file, write, File};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use crate::database::entities::backup_info::{
ActiveModel as DbBackupInfoActMod, Model as DbBackupInfo,
};
use crate::utils::now;
use crate::wallet::{setup_logger, InternalError, LOG_FILE};
use crate::{Error, Wallet};
/// Number of plaintext bytes encrypted per STREAM chunk.
const BACKUP_BUFFER_LEN_ENCRYPT: usize = 239;
/// Number of ciphertext bytes per full STREAM chunk: the plaintext chunk plus the
/// 16-byte Poly1305 authentication tag.
const BACKUP_BUFFER_LEN_DECRYPT: usize = 16 + BACKUP_BUFFER_LEN_ENCRYPT;
/// Length in bytes of the scrypt-derived encryption key.
const BACKUP_KEY_LENGTH: usize = 32;
/// Length in bytes of the nonce handed to the STREAM (`EncryptorBE32`/`DecryptorBE32`) constructor.
const BACKUP_NONCE_LENGTH: usize = 19;
/// Backup format version written by `backup` and required by `restore_backup`.
const BACKUP_VERSION: u8 = 1;
/// Scratch locations used while building or unpacking a backup.
///
/// `_get_backup_paths` places every path below inside `tempdir`.
struct BackupPaths {
    /// Temporary directory that holds the intermediate files.
    tempdir: TempDir,
    /// Plain zip archive (wallet contents before encryption / after decryption).
    zip: PathBuf,
    /// Encrypted form of the `zip` archive.
    encrypted: PathBuf,
    /// JSON-serialized `BackupPubData`.
    backup_pub_data: PathBuf,
}
/// Scrypt key-derivation settings. They are serialized as part of `BackupPubData`
/// so that a restore can re-derive the same key from the password.
#[derive(Clone, Deserialize, Serialize)]
pub(crate) struct ScryptParams {
/// Base-2 logarithm of scrypt's cost parameter N (the `log_n` argument of `Params::new`).
log_n: u8,
/// Scrypt `r` (block size) parameter.
r: u32,
/// Scrypt `p` (parallelization) parameter.
p: u32,
/// Output length in bytes of the derived key; `ScryptParams::new` sets it to `BACKUP_KEY_LENGTH`.
len: usize,
/// Not used when converting to `scrypt::Params`; `ScryptParams::new` leaves it `None`.
/// NOTE(review): presumably reserved for future formats — confirm.
version: Option<u32>,
/// Not used when converting to `scrypt::Params`; `ScryptParams::new` leaves it `None`.
/// NOTE(review): presumably reserved for future formats — confirm.
algorithm: Option<String>,
}
impl ScryptParams {
    /// Build scrypt settings for backup key derivation.
    ///
    /// Every argument left as `None` falls back to the `scrypt` crate's recommended
    /// value (`Params::RECOMMENDED_*`). The output length is always
    /// `BACKUP_KEY_LENGTH`, and `version`/`algorithm` are left unset.
    pub(crate) fn new(log_n: Option<u8>, r: Option<u32>, p: Option<u32>) -> ScryptParams {
        let log_n = log_n.unwrap_or(Params::RECOMMENDED_LOG_N);
        let r = r.unwrap_or(Params::RECOMMENDED_R);
        let p = p.unwrap_or(Params::RECOMMENDED_P);
        ScryptParams {
            log_n,
            r,
            p,
            len: BACKUP_KEY_LENGTH,
            version: None,
            algorithm: None,
        }
    }
}
impl Default for ScryptParams {
fn default() -> ScryptParams {
ScryptParams::new(None, None, None)
}
}
impl TryInto<Params> for ScryptParams {
type Error = Error;
fn try_into(self: ScryptParams) -> Result<Params, Error> {
Params::new(self.log_n, self.r, self.p, self.len).map_err(|e| Error::Internal {
details: format!("invalid params {}", e),
})
}
}
/// Unencrypted metadata stored as JSON next to the encrypted archive inside a backup.
/// Together with the password, it holds what is needed to derive the key and decrypt.
#[derive(Deserialize, Serialize)]
struct BackupPubData {
/// Scrypt settings used to derive the encryption key from the password.
scrypt_params: ScryptParams,
/// Scrypt salt, as the string form of a `SaltString` (parsed back with `Salt::from_b64`).
salt: String,
/// Random alphanumeric string; its first `BACKUP_NONCE_LENGTH` bytes are the STREAM nonce.
nonce: String,
/// Backup format version; `restore_backup` rejects anything other than `BACKUP_VERSION`.
version: u8,
}
impl BackupPubData {
    /// Return the first `BACKUP_NONCE_LENGTH` bytes of the stored nonce string.
    ///
    /// The nonce is read from backup metadata, which may be malformed or tampered
    /// with. A nonce shorter than `BACKUP_NONCE_LENGTH` bytes yields
    /// `InternalError::Unexpected` instead of panicking on an out-of-range slice.
    /// Bytes beyond `BACKUP_NONCE_LENGTH` are ignored, as before.
    fn nonce(&self) -> Result<[u8; BACKUP_NONCE_LENGTH], InternalError> {
        self.nonce
            .as_bytes()
            .get(..BACKUP_NONCE_LENGTH)
            .and_then(|bytes| <[u8; BACKUP_NONCE_LENGTH]>::try_from(bytes).ok())
            .ok_or(InternalError::Unexpected)
    }
}
impl Wallet {
pub fn backup(&self, backup_path: &str, password: &str) -> Result<(), Error> {
self.backup_customize(backup_path, password, None)
}
pub(crate) fn backup_customize(
&self,
backup_path: &str,
password: &str,
scrypt_params: Option<ScryptParams>,
) -> Result<(), Error> {
let prev_backup_info = self.update_backup_info(true)?;
match self._backup(backup_path, password, scrypt_params) {
Ok(()) => Ok(()),
Err(e) => {
if let Some(prev_backup_info) = prev_backup_info {
let mut prev_backup_info: DbBackupInfoActMod = prev_backup_info.into();
self.database.update_backup_info(&mut prev_backup_info)?;
} else {
self.database.del_backup_info()?;
}
Err(e)
}
}
}
fn _backup(
&self,
backup_path: &str,
password: &str,
scrypt_params: Option<ScryptParams>,
) -> Result<(), Error> {
info!(self.logger, "starting backup...");
let backup_file = PathBuf::from(&backup_path);
if backup_file.exists() {
return Err(Error::FileAlreadyExists {
path: backup_path.to_string(),
})?;
}
let tmp_base_path = _get_parent_path(&backup_file)?;
let files = _get_backup_paths(&tmp_base_path)?;
let scrypt_params = scrypt_params.unwrap_or_default();
let mut rng = rand::thread_rng();
let salt = SaltString::generate(&mut rng);
let str_params = serde_json::to_string(&scrypt_params).map_err(InternalError::from)?;
debug!(self.logger, "using generated scrypt params: {}", str_params);
let nonce: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(BACKUP_NONCE_LENGTH)
.map(char::from)
.collect();
debug!(self.logger, "using generated nonce: {}", &nonce);
let backup_pub_data = BackupPubData {
scrypt_params,
salt: salt.to_string(),
nonce,
version: BACKUP_VERSION,
};
debug!(
self.logger,
"\nzipping {:?} to {:?}", &self.wallet_dir, &files.zip
);
_zip_dir(&self.wallet_dir, &files.zip, true, &self.logger)?;
debug!(
self.logger,
"\nencrypting {:?} to {:?}", &files.zip, &files.encrypted
);
_encrypt_file(&files.zip, &files.encrypted, password, &backup_pub_data)?;
write(
files.backup_pub_data,
serde_json::to_string(&backup_pub_data).unwrap(),
)?;
debug!(
self.logger,
"\nzipping {:?} to {:?}", &files.tempdir, &backup_file
);
_zip_dir(
&PathBuf::from(files.tempdir.path()),
&backup_file,
false,
&self.logger,
)?;
info!(self.logger, "backup completed");
Ok(())
}
pub fn backup_info(&self) -> Result<bool, Error> {
let backup_required = if let Some(backup_info) = self.database.get_backup_info()? {
backup_info
.last_operation_timestamp
.parse::<i128>()
.unwrap()
> backup_info.last_backup_timestamp.parse::<i128>().unwrap()
} else {
false
};
Ok(backup_required)
}
pub(crate) fn update_backup_info(
&self,
doing_backup: bool,
) -> Result<Option<DbBackupInfo>, Error> {
let now = ActiveValue::Set(now().unix_timestamp_nanos().to_string());
if let Some(backup_info) = self.database.get_backup_info()? {
let prev_backup_info = backup_info.clone();
let mut backup_info: DbBackupInfoActMod = backup_info.into();
if doing_backup {
backup_info.last_backup_timestamp = now;
} else {
backup_info.last_operation_timestamp = now;
}
self.database.update_backup_info(&mut backup_info)?;
Ok(Some(prev_backup_info))
} else {
let (last_backup_timestamp, last_operation_timestamp) = if doing_backup {
(now, ActiveValue::Set(s!("0")))
} else {
(ActiveValue::Set(s!("0")), now)
};
let backup_info = DbBackupInfoActMod {
last_backup_timestamp,
last_operation_timestamp,
..Default::default()
};
self.database.set_backup_info(backup_info)?;
Ok(None)
}
}
}
/// Restore a backup created by `Wallet::backup` into `target_dir`.
///
/// Steps:
/// 1. Create `target_dir` if needed and set up a `restore_<unix timestamp>` log there.
/// 2. Unpack the outer archive into a temporary directory in the parent directory
///    of `backup_path`.
/// 3. Read the public metadata and check its version.
/// 4. Decrypt the inner archive with a key derived from `password`.
/// 5. Unzip the inner archive into `target_dir`.
///
/// # Errors
/// Returns `Error::UnsupportedBackupVersion` if the backup version is not
/// `BACKUP_VERSION` and `Error::WrongPassword` if decryption fails. I/O and
/// parsing errors are propagated.
pub fn restore_backup(backup_path: &str, password: &str, target_dir: &str) -> Result<(), Error> {
    create_dir_all(target_dir)?;
    let log_name = format!("restore_{}", now().unix_timestamp());
    let logger = setup_logger(PathBuf::from(target_dir), Some(&log_name))?;
    info!(logger, "starting restore...");

    let backup_file = PathBuf::from(backup_path);
    let files = _get_backup_paths(&_get_parent_path(&backup_file)?)?;
    let target_dir_path = PathBuf::from(target_dir);
    let scratch_dir = PathBuf::from(files.tempdir.path());

    info!(logger, "unzipping {:?}", backup_file);
    _unzip(&backup_file, &scratch_dir, &logger)?;

    let json_pub_data = read_to_string(&files.backup_pub_data)?;
    debug!(logger, "using retrieved backup_pub_data: {}", json_pub_data);
    let backup_pub_data: BackupPubData =
        serde_json::from_str(&json_pub_data).map_err(InternalError::from)?;
    let version = backup_pub_data.version;
    debug!(logger, "retrieved version: {}", &version);
    if version != BACKUP_VERSION {
        return Err(Error::UnsupportedBackupVersion {
            version: version.to_string(),
        });
    }

    info!(logger, "decrypting {:?} to {:?}", files.encrypted, files.zip);
    _decrypt_file(&files.encrypted, &files.zip, password, &backup_pub_data)?;
    info!(logger, "unzipping {:?} to {:?}", &files.zip, &target_dir_path);
    _unzip(&files.zip, &target_dir_path, &logger)?;
    info!(logger, "restore completed");
    Ok(())
}
/// Create `tmp_base_path` if needed, make a fresh temporary directory inside it, and
/// return the scratch paths (`backup.enc`, `backup.pub_data`, `backup.zip`) within
/// that directory.
fn _get_backup_paths(tmp_base_path: &Path) -> Result<BackupPaths, Error> {
    create_dir_all(tmp_base_path)?;
    let tempdir = tempfile::tempdir_in(tmp_base_path)?;
    let root = tempdir.path();
    Ok(BackupPaths {
        encrypted: root.join("backup.enc"),
        backup_pub_data: root.join("backup.pub_data"),
        zip: root.join("backup.zip"),
        tempdir,
    })
}
fn _get_parent_path(file: &Path) -> Result<PathBuf, Error> {
if let Some(parent) = file.parent() {
Ok(parent.to_path_buf())
} else {
Err(Error::IO {
details: "provided file path has no parent".to_string(),
})
}
}
fn _zip_dir(
path_in: &PathBuf,
path_out: &PathBuf,
keep_last_path_component: bool,
logger: &Logger,
) -> Result<(), Error> {
let writer = File::create(path_out)?;
let mut zip = zip::ZipWriter::new(writer);
let options = FileOptions::default().compression_method(zip::CompressionMethod::Zstd);
let mut buffer = [0u8; 4096];
let prefix = if keep_last_path_component {
if let Some(parent) = path_in.parent() {
parent
} else {
return Err(Error::Internal {
details: "no parent directory".to_string(),
});
}
} else {
path_in
};
let entry_iterator = WalkDir::new(path_in).into_iter().filter_map(|e| e.ok());
for entry in entry_iterator {
let path = entry.path();
let name = path.strip_prefix(prefix).map_err(InternalError::from)?;
let name_str = name.to_str().ok_or_else(|| InternalError::Unexpected)?;
if path.is_file() {
if path.ends_with(LOG_FILE) {
continue;
}; debug!(logger, "adding file {path:?} as {name:?}");
zip.start_file(name_str, options)
.map_err(InternalError::from)?;
let mut f = File::open(path)?;
loop {
let read_count = f.read(&mut buffer)?;
if read_count != 0 {
zip.write_all(&buffer[..read_count])?;
} else {
break;
}
}
} else if !name.as_os_str().is_empty() {
debug!(logger, "adding directory {path:?} as {name:?}");
zip.add_directory(name_str, options)
.map_err(InternalError::from)?;
}
}
let mut file = zip.finish().map_err(InternalError::from)?;
file.flush()?;
file.sync_all()?;
Ok(())
}
/// Extract every entry of the zip archive at `zip_path` under `path_out`.
///
/// Entries for which the zip crate's `enclosed_name()` returns `None` are skipped.
/// Per the zip crate, those are names that are unsafe to join onto `path_out`.
/// Names ending in `/` are created as directories. Other entries are written as
/// files, creating any missing parent directory first.
fn _unzip(zip_path: &PathBuf, path_out: &Path, logger: &Logger) -> Result<(), Error> {
    let zip_file = File::open(zip_path).map_err(InternalError::from)?;
    let mut archive = zip::ZipArchive::new(zip_file).map_err(InternalError::from)?;
    let entry_count = archive.len();
    for i in 0..entry_count {
        let mut entry = archive.by_index(i).map_err(InternalError::from)?;
        let outpath = if let Some(relative) = entry.enclosed_name() {
            path_out.join(relative)
        } else {
            continue;
        };
        if !entry.name().ends_with('/') {
            debug!(
                logger,
                "extracting file {i} to {} ({} bytes)",
                outpath.display(),
                entry.size()
            );
            if let Some(parent) = outpath.parent().filter(|dir| !dir.exists()) {
                debug!(logger, "creating parent dir {}", parent.display());
                create_dir_all(parent)?;
            }
            let mut outfile = File::create(&outpath)?;
            std::io::copy(&mut entry, &mut outfile)?;
        } else {
            debug!(logger, "creating directory {i} as {}", outpath.display());
            create_dir_all(&outpath)?;
        }
    }
    Ok(())
}
fn _get_cypher_secrets(
password: &str,
backup_pub_data: &BackupPubData,
) -> Result<GenericArray<u8, U32>, Error> {
let password_bytes = password.as_bytes();
let salt = Salt::from_b64(&backup_pub_data.salt).map_err(InternalError::from)?;
let password_hash = Scrypt
.hash_password_customized(
password_bytes,
None,
None,
backup_pub_data.scrypt_params.clone().try_into()?,
salt,
)
.map_err(InternalError::from)?;
let hash_output = password_hash
.hash
.ok_or_else(|| InternalError::NoPasswordHashError)?;
let hash = hash_output.as_bytes();
let key = Key::clone_from_slice(hash);
Ok(key)
}
/// Encrypt the file at `path_cleartext` into `path_encrypted`, then delete the cleartext.
///
/// Encryption uses XChaCha20-Poly1305 through the STREAM `EncryptorBE32`. The
/// input is cut into `BACKUP_BUFFER_LEN_ENCRYPT`-byte chunks, each sealed with
/// `encrypt_next`. The final, shorter chunk (possibly empty) is sealed with
/// `encrypt_last`. The key is derived from `password`, and the nonce comes from
/// `backup_pub_data`.
///
/// # Errors
/// Key-derivation, nonce, AEAD and I/O errors are propagated.
fn _encrypt_file(
    path_cleartext: &PathBuf,
    path_encrypted: &PathBuf,
    password: &str,
    backup_pub_data: &BackupPubData,
) -> Result<(), Error> {
    let key = _get_cypher_secrets(password, backup_pub_data)?;
    let aead = XChaCha20Poly1305::new(&key);
    let nonce = backup_pub_data.nonce()?;
    let nonce = GenericArray::from_slice(&nonce);
    let mut stream_encryptor = stream::EncryptorBE32::from_aead(aead, nonce);
    let mut source_file = std::io::BufReader::new(File::open(path_cleartext)?);
    let mut destination_file = std::io::BufWriter::new(File::create(path_encrypted)?);
    let mut chunk = Vec::with_capacity(BACKUP_BUFFER_LEN_ENCRYPT);
    loop {
        // `Read::read` may return fewer bytes than requested before EOF. Treating
        // such a short read as the last chunk would seal the stream early and
        // silently truncate the backup. `take` + `read_to_end` keeps reading until
        // the chunk is full or EOF is reached. (usize -> u64 is lossless here.)
        chunk.clear();
        source_file
            .by_ref()
            .take(BACKUP_BUFFER_LEN_ENCRYPT as u64)
            .read_to_end(&mut chunk)?;
        if chunk.len() == BACKUP_BUFFER_LEN_ENCRYPT {
            let ciphertext = stream_encryptor
                .encrypt_next(chunk.as_slice())
                .map_err(|e| InternalError::AeadError(e.to_string()))?;
            destination_file.write_all(&ciphertext)?;
        } else {
            let ciphertext = stream_encryptor
                .encrypt_last(chunk.as_slice())
                .map_err(|e| InternalError::AeadError(e.to_string()))?;
            destination_file.write_all(&ciphertext)?;
            break;
        }
    }
    // `BufWriter` ignores write errors on drop, so flush explicitly.
    destination_file.flush()?;
    // Close the source before deleting it.
    drop(source_file);
    remove_file(path_cleartext)?;
    Ok(())
}
/// Decrypt the file at `path_encrypted` (produced by `_encrypt_file`) into `path_cleartext`.
///
/// Full `BACKUP_BUFFER_LEN_DECRYPT`-byte chunks are opened with `decrypt_next`.
/// The first shorter chunk, including an empty one at EOF, is opened with
/// `decrypt_last`, and decryption stops there.
///
/// # Errors
/// Returns `Error::WrongPassword` for any authentication failure. That covers a
/// wrong password, a corrupted file, or a file truncated at a chunk boundary
/// (which used to be accepted silently). Key-derivation, nonce and I/O errors are
/// propagated.
fn _decrypt_file(
    path_encrypted: &PathBuf,
    path_cleartext: &PathBuf,
    password: &str,
    backup_pub_data: &BackupPubData,
) -> Result<(), Error> {
    let key = _get_cypher_secrets(password, backup_pub_data)?;
    let aead = XChaCha20Poly1305::new(&key);
    let nonce = backup_pub_data.nonce()?;
    let nonce = GenericArray::from_slice(&nonce);
    let mut stream_decryptor = stream::DecryptorBE32::from_aead(aead, nonce);
    let mut source_file = std::io::BufReader::new(File::open(path_encrypted)?);
    let mut destination_file = std::io::BufWriter::new(File::create(path_cleartext)?);
    let mut chunk = Vec::with_capacity(BACKUP_BUFFER_LEN_DECRYPT);
    loop {
        // Fill the chunk completely unless EOF is hit; a short `read` mid-file must
        // not be mistaken for the final chunk. (usize -> u64 is lossless here.)
        chunk.clear();
        source_file
            .by_ref()
            .take(BACKUP_BUFFER_LEN_DECRYPT as u64)
            .read_to_end(&mut chunk)?;
        if chunk.len() == BACKUP_BUFFER_LEN_DECRYPT {
            let cleartext = stream_decryptor
                .decrypt_next(chunk.as_slice())
                .map_err(|_| Error::WrongPassword)?;
            destination_file.write_all(&cleartext)?;
        } else {
            // `_encrypt_file` always ends the stream with an `encrypt_last` chunk
            // shorter than a full one (at least the 16-byte tag). An empty chunk here
            // therefore means truncation, which `decrypt_last` rejects.
            let cleartext = stream_decryptor
                .decrypt_last(chunk.as_slice())
                .map_err(|_| Error::WrongPassword)?;
            destination_file.write_all(&cleartext)?;
            break;
        }
    }
    // `BufWriter` ignores write errors on drop, so flush explicitly.
    destination_file.flush()?;
    Ok(())
}