mod common;
mod error;
mod secret;
mod storage;
use std::{
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use derive_builder::Builder;
use iota_stronghold::{KeyProvider, SnapshotPath, Stronghold};
use log::{debug, error, warn};
use tokio::{sync::Mutex, task::JoinHandle};
use zeroize::Zeroizing;
use self::common::PRIVATE_DATA_CLIENT_PATH;
pub use self::error::Error;
use crate::client::storage::StorageProvider;
/// Adapter around a [`Stronghold`] instance whose private-data client is loaded from, and committed to,
/// a snapshot file.
///
/// The key derived from the password may be purged automatically after `timeout` by a background task;
/// while no key is held, snapshot reads/writes fail with [`Error::KeyCleared`] until a password is set again.
#[derive(Builder)]
#[builder(pattern = "owned", build_fn(skip))]
pub struct StrongholdAdapter {
    /// The Stronghold instance, shared with the key-clearing task.
    #[builder(field(type = "Option<Stronghold>"))]
    stronghold: Arc<Mutex<Stronghold>>,
    /// Key provider derived from the password; `None` when no password is set or the key has been cleared.
    #[builder(setter(custom))]
    #[builder(field(type = "Option<KeyProvider>"))]
    key_provider: Arc<Mutex<Option<KeyProvider>>>,
    /// Delay after which the key is purged; `None` disables automatic purging.
    #[builder(setter(strip_option))]
    timeout: Option<Duration>,
    /// Handle of the pending key-clearing task, if one is running.
    #[builder(setter(custom))]
    timeout_task: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Snapshot file the client is loaded from, and the default target when writing the snapshot.
    #[builder(setter(skip))]
    pub snapshot_path: PathBuf,
}
fn check_or_create_snapshot(
stronghold: &Stronghold,
key_provider: &KeyProvider,
snapshot_path: &SnapshotPath,
) -> Result<(), Error> {
let result = stronghold.load_client_from_snapshot(PRIVATE_DATA_CLIENT_PATH, key_provider, snapshot_path);
match result {
Err(iota_stronghold::ClientError::SnapshotFileMissing(_)) => {
stronghold.create_client(PRIVATE_DATA_CLIENT_PATH)?;
stronghold.commit_with_keyprovider(snapshot_path, key_provider)?;
}
Err(iota_stronghold::ClientError::ClientAlreadyLoaded(_)) => {
stronghold.get_client(PRIVATE_DATA_CLIENT_PATH)?;
}
Err(iota_stronghold::ClientError::Inner(ref err_msg)) => {
if err_msg.to_string().contains("XCHACHA20-POLY1305") {
return Err(Error::InvalidPassword);
}
}
_ => {}
}
Ok(())
}
impl StrongholdAdapterBuilder {
pub fn password(mut self, password: &str) -> Self {
self.key_provider = Some(self::common::key_provider_from_password(password));
self
}
pub fn build<P: AsRef<Path>>(mut self, snapshot_path: P) -> Result<StrongholdAdapter, Error> {
let stronghold = self.stronghold.unwrap_or_default();
if let Some(key_provider) = &self.key_provider {
check_or_create_snapshot(&stronghold, key_provider, &SnapshotPath::from_path(&snapshot_path))?;
}
let has_key_provider = self.key_provider.is_some();
let key_provider = Arc::new(Mutex::new(self.key_provider));
let stronghold = Arc::new(Mutex::new(stronghold));
if let (true, Some(Some(timeout))) = (has_key_provider, self.timeout) {
let timeout_task = Arc::new(Mutex::new(None));
let task_self = timeout_task.clone();
let key_provider = key_provider.clone();
let stronghold_clone = stronghold.clone();
tokio::spawn(async move {
*task_self.lock().await = Some(tokio::spawn(task_key_clear(
task_self.clone(), stronghold_clone,
key_provider,
timeout,
)));
});
self.timeout_task = Some(timeout_task);
}
Ok(StrongholdAdapter {
stronghold,
key_provider,
timeout: self.timeout.unwrap_or(None),
timeout_task: self.timeout_task.unwrap_or_else(|| Arc::new(Mutex::new(None))),
snapshot_path: snapshot_path.as_ref().to_path_buf(),
})
}
}
impl StrongholdAdapter {
pub fn builder() -> StrongholdAdapterBuilder {
StrongholdAdapterBuilder::default()
}
pub async fn is_key_available(&self) -> bool {
self.key_provider.lock().await.is_some()
}
pub async fn set_password(&mut self, password: &str) -> Result<(), Error> {
let mut key_provider_guard = self.key_provider.lock().await;
let key_provider = self::common::key_provider_from_password(password);
if let Some(old_key_provider) = &*key_provider_guard {
if old_key_provider.try_unlock()? != key_provider.try_unlock()? {
return Err(Error::InvalidPassword);
}
}
let snapshot_path = SnapshotPath::from_path(&self.snapshot_path);
let stronghold = self.stronghold.lock().await;
check_or_create_snapshot(&stronghold, &key_provider, &snapshot_path)?;
*key_provider_guard = Some(key_provider);
drop(key_provider_guard);
if let Some(timeout) = self.timeout {
if let Some(timeout_task) = self.timeout_task.lock().await.take() {
timeout_task.abort();
}
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();
*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
self.stronghold.clone(),
key_provider,
timeout,
)));
}
Ok(())
}
pub async fn change_password(&mut self, new_password: &str) -> Result<(), Error> {
if let Some(timeout_task) = self.timeout_task.lock().await.take() {
timeout_task.abort();
}
self.write_stronghold_snapshot(None).await?;
let mut values = Vec::new();
let keys_to_re_encrypt = self
.stronghold
.lock()
.await
.get_client(PRIVATE_DATA_CLIENT_PATH)?
.store()
.keys()?;
for key in keys_to_re_encrypt {
let value = match self.get(&key).await {
Err(err) => {
error!("an error occurred during the re-encryption of Stronghold Store: {err}");
if let Some(timeout) = self.timeout {
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();
*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
self.stronghold.clone(),
key_provider,
timeout,
)));
}
return Err(err);
}
Ok(None) => continue,
Ok(Some(value)) => Zeroizing::new(value),
};
values.push((key, value));
}
let old_key_provider = {
let mut lock = self.key_provider.lock().await;
let old_key_provider = lock.take();
*lock = Some(self::common::key_provider_from_password(new_password));
old_key_provider
};
for (key, value) in values {
if let Err(err) = self.insert(&key, &value).await {
error!("an error occurred during the re-encryption of Stronghold store: {err}");
*self.key_provider.lock().await = old_key_provider;
self.read_stronghold_snapshot().await?;
if let Some(timeout) = self.timeout {
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();
*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
self.stronghold.clone(),
key_provider,
timeout,
)));
}
return Err(err);
}
}
self.write_stronghold_snapshot(None).await?;
if let Some(timeout) = self.timeout {
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();
*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
self.stronghold.clone(),
key_provider,
timeout,
)));
}
Ok(())
}
pub async fn clear_key(&mut self) {
if let Some(timeout_task) = self.timeout_task.lock().await.take() {
timeout_task.abort();
}
if self.is_key_available().await {
if let Err(err) = self.unload_stronghold_snapshot().await {
warn!("failed to unload Stronghold while clearing the key: {err}");
}
}
self.key_provider.lock().await.take();
debug!("cleared stronghold key");
}
pub fn get_timeout(&self) -> Option<Duration> {
self.timeout
}
pub async fn set_timeout(&mut self, new_timeout: Option<Duration>) {
if let Some(timeout_task) = self.timeout_task.lock().await.take() {
timeout_task.abort();
}
self.timeout = new_timeout;
if let (Some(_), Some(timeout)) = (self.key_provider.lock().await.as_ref(), self.timeout) {
let task_self = self.timeout_task.clone();
let key_provider = self.key_provider.clone();
*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
task_self,
self.stronghold.clone(),
key_provider,
timeout,
)));
}
}
pub async fn restart_key_clearing_task(&mut self) {
self.set_timeout(self.get_timeout()).await;
}
#[allow(clippy::significant_drop_tightening)]
pub async fn read_stronghold_snapshot(&mut self) -> Result<(), Error> {
let locked_key_provider = self.key_provider.lock().await;
let key_provider = if let Some(key_provider) = &*locked_key_provider {
key_provider
} else {
return Err(Error::KeyCleared);
};
self.stronghold.lock().await.load_client_from_snapshot(
PRIVATE_DATA_CLIENT_PATH,
key_provider,
&SnapshotPath::from_path(&self.snapshot_path),
)?;
Ok(())
}
#[allow(clippy::significant_drop_tightening)]
pub async fn write_stronghold_snapshot(&self, snapshot_path: Option<&Path>) -> Result<(), Error> {
let locked_key_provider = self.key_provider.lock().await;
let key_provider = if let Some(key_provider) = &*locked_key_provider {
key_provider
} else {
return Err(Error::KeyCleared);
};
self.stronghold.lock().await.commit_with_keyprovider(
&SnapshotPath::from_path(snapshot_path.unwrap_or(&self.snapshot_path)),
key_provider,
)?;
Ok(())
}
pub async fn unload_stronghold_snapshot(&mut self) -> Result<(), Error> {
self.write_stronghold_snapshot(None).await?;
self.stronghold.lock().await.clear()?;
Ok(())
}
}
async fn task_key_clear(
task_self: Arc<Mutex<Option<JoinHandle<()>>>>,
stronghold: Arc<Mutex<Stronghold>>,
key_provider: Arc<Mutex<Option<KeyProvider>>>,
timeout: Duration,
) {
tokio::time::sleep(timeout).await;
debug!("StrongholdAdapter is purging the key");
key_provider.lock().await.take();
stronghold.lock().await.clear().unwrap();
task_self.lock().await.take();
}
#[cfg(test)]
mod tests {
    use std::fs;

    use super::*;

    /// Checks the timed purge and then the adapter's behaviour without a key.
    ///
    /// - After the timeout, the key and the task handle are gone while the timeout value is kept.
    /// - `set_timeout(None)`, a rejected `set_password`, `clear_key` and `restart_key_clearing_task`
    ///   then run on the key-less adapter.
    #[tokio::test]
    async fn test_clear_key() {
        let timeout = Duration::from_millis(100);
        // Snapshot file in the current working directory; removed at the end of the test.
        let stronghold_path = "test_clear_key.stronghold";
        let mut adapter = StrongholdAdapter::builder()
            .password("drowssap")
            .timeout(timeout)
            .build(stronghold_path)
            .unwrap();

        // `build` stores the task handle from a spawned task, so give that task a moment to run.
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(matches!(*adapter.key_provider.lock().await, Some(_)));
        assert_eq!(adapter.get_timeout(), Some(timeout));
        assert!(matches!(*adapter.timeout_task.lock().await, Some(_)));

        // Sleep past the 100 ms timeout so the key-clearing task has run to completion.
        tokio::time::sleep(Duration::from_millis(200)).await;
        assert!(matches!(*adapter.key_provider.lock().await, None));
        assert_eq!(adapter.get_timeout(), Some(timeout));
        assert!(matches!(*adapter.timeout_task.lock().await, None));

        let timeout = None;
        adapter.set_timeout(timeout).await;
        // The snapshot was created with "drowssap", so a different password is expected to be rejected.
        assert!(adapter.set_password("password").await.is_err());

        adapter.clear_key().await;
        assert!(matches!(*adapter.key_provider.lock().await, None));
        assert_eq!(adapter.get_timeout(), timeout);
        assert!(matches!(*adapter.timeout_task.lock().await, None));

        // With no key and no timeout, restarting must not spawn a task.
        adapter.restart_key_clearing_task().await;
        assert!(matches!(*adapter.key_provider.lock().await, None));
        assert_eq!(adapter.get_timeout(), timeout);
        assert!(matches!(*adapter.timeout_task.lock().await, None));

        fs::remove_file(stronghold_path).unwrap();
    }

    /// Checks `set_password` after the key was cleared.
    ///
    /// Setting the same password succeeds, repeatedly. A different password is rejected once a key is held.
    #[tokio::test]
    async fn stronghold_password_already_set() {
        // Snapshot file in the current working directory; removed at the end of the test.
        let stronghold_path = "stronghold_password_already_set.stronghold";
        let mut adapter = StrongholdAdapter::builder()
            .password("drowssap")
            .build(stronghold_path)
            .unwrap();

        adapter.clear_key().await;
        assert!(adapter.set_password("drowssap").await.is_ok());
        // Same password again: the held key matches, so this is accepted too.
        assert!(adapter.set_password("drowssap").await.is_ok());
        assert!(adapter.set_password("other_password").await.is_err());

        fs::remove_file(stronghold_path).unwrap();
    }
}