mod common;
mod error;
mod migration;
mod secret;
mod storage;
use alloc::sync::Weak;
use std::{
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use iota_stronghold::{KeyProvider, SnapshotPath, Stronghold};
use log::{debug, error, warn};
use tokio::{
sync::{Mutex, MutexGuard},
task::JoinHandle,
};
use zeroize::Zeroizing;
pub(crate) use self::common::PRIVATE_DATA_CLIENT_PATH;
pub use self::error::Error;
use super::{storage::StorageAdapter, utils::Password};
/// Adapter that bundles a [`Stronghold`] instance with the key provider used for its snapshot
/// and an optional task that purges that key after a timeout.
#[derive(Debug)]
pub struct StrongholdAdapter {
/// The Stronghold instance, shared with the key-clearing task.
stronghold: Arc<Mutex<Stronghold>>,
/// Key provider used to load and commit the snapshot; `None` once the key has been cleared.
key_provider: Arc<Mutex<Option<KeyProvider>>>,
/// Delay handed to the key-clearing task; when `None`, no such task is spawned.
timeout: Option<Duration>,
/// Handle of the currently scheduled key-clearing task, if there is one.
timeout_task: Arc<Mutex<Option<TaskHandle>>>,
/// Path of the snapshot file this adapter reads from and writes to by default.
pub(crate) snapshot_path: PathBuf,
}
/// Loads the private-data client from the snapshot at `snapshot_path`, creating the client and
/// committing a fresh snapshot when the snapshot file does not exist yet.
///
/// # Errors
///
/// - [`Error::InvalidPassword`] when the load fails with an inner error mentioning
///   `XCHACHA20-POLY1305` or `BadFileKey`.
/// - [`Error::UnsupportedSnapshotVersion`] when the snapshot is reported as version 2 while
///   version 3 is expected.
/// - The underlying client error for any other "unsupported version" failure, and for failures
///   while creating, committing or fetching the client.
fn check_or_create_snapshot(
    stronghold: &Stronghold,
    key_provider: &KeyProvider,
    snapshot_path: &SnapshotPath,
) -> Result<(), Error> {
    match stronghold.load_client_from_snapshot(PRIVATE_DATA_CLIENT_PATH, key_provider, snapshot_path) {
        // No snapshot on disk yet: start with an empty client and persist it right away.
        Err(iota_stronghold::ClientError::SnapshotFileMissing(_)) => {
            stronghold.create_client(PRIVATE_DATA_CLIENT_PATH)?;
            stronghold.commit_with_keyprovider(snapshot_path, key_provider)?;
        }
        // The client is already in memory: only make sure it can be fetched.
        Err(iota_stronghold::ClientError::ClientAlreadyLoaded(_)) => {
            stronghold.get_client(PRIVATE_DATA_CLIENT_PATH)?;
        }
        // The error is only available as a message here, so it is classified by substring.
        Err(iota_stronghold::ClientError::Inner(err_msg)) => {
            if err_msg.contains("XCHACHA20-POLY1305") || err_msg.contains("BadFileKey") {
                return Err(Error::InvalidPassword);
            }
            if err_msg.contains("unsupported version") {
                if err_msg.contains("expected [3, 0], found [2, 0]") {
                    return Err(Error::UnsupportedSnapshotVersion { found: 2, expected: 3 });
                }
                // An unknown version pair depends on the file on disk, so report it to the
                // caller instead of aborting the whole process.
                return Err(iota_stronghold::ClientError::Inner(err_msg).into());
            }
            // NOTE(review): any other inner error is ignored and reported as success, exactly as
            // before — confirm that this best-effort behaviour is intended.
        }
        // A successful load and all remaining error variants are treated as success.
        _ => {}
    }
    Ok(())
}
/// Builder for [`StrongholdAdapter`]; every field starts out as `None`.
#[derive(Default, Debug)]
pub struct StrongholdAdapterBuilder {
/// Stronghold instance to wrap; `build` falls back to `Stronghold::default()` when unset.
stronghold: Option<Stronghold>,
/// Key provider for the snapshot; when unset, `build` does not load or create a snapshot.
key_provider: Option<KeyProvider>,
/// Delay after which the key is purged; only used by `build` when a key provider is set too.
timeout: Option<Duration>,
}
impl StrongholdAdapterBuilder {
    /// Uses the given Stronghold instance; passing `None` resets it so that `build` creates a
    /// default one.
    pub fn stronghold(self, stronghold: impl Into<Option<Stronghold>>) -> Self {
        Self {
            stronghold: stronghold.into(),
            ..self
        }
    }

    /// Uses the given key provider; passing `None` removes a previously configured one.
    pub fn key_provider(self, key_provider: impl Into<Option<KeyProvider>>) -> Self {
        Self {
            key_provider: key_provider.into(),
            ..self
        }
    }

    /// Sets the delay after which the key is purged; passing `None` disables the purge.
    pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
        Self {
            timeout: timeout.into(),
            ..self
        }
    }

    /// Derives a key provider from `password`, replacing any key provider set before.
    pub fn password(self, password: impl Into<Password>) -> Self {
        let derived = self::common::key_provider_from_password(password.into());
        Self {
            key_provider: Some(derived),
            ..self
        }
    }

    /// Builds the adapter for the snapshot file at `snapshot_path`.
    ///
    /// When a key provider is configured, the snapshot is loaded (or created if missing) right
    /// away. When a timeout is configured as well, the key-clearing task is spawned with
    /// `tokio::spawn`, so the call must then happen inside a Tokio runtime.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when `snapshot_path` is a directory, and any error reported while
    /// loading or creating the snapshot.
    pub fn build<P: AsRef<Path>>(self, snapshot_path: P) -> Result<StrongholdAdapter, Error> {
        let path = snapshot_path.as_ref();
        if path.is_dir() {
            let not_a_file = std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Path is not a file: {:?}", path.to_path_buf()),
            );
            return Err(not_a_file.into());
        }

        let Self {
            stronghold,
            key_provider,
            timeout,
        } = self;
        let stronghold = stronghold.unwrap_or_default();

        // Tests use the cheapest work factor to keep snapshot encryption fast.
        #[cfg(test)]
        iota_stronghold::engine::snapshot::try_set_encrypt_work_factor(0).unwrap();

        let has_key_provider = match key_provider.as_ref() {
            Some(provider) => {
                check_or_create_snapshot(&stronghold, provider, &SnapshotPath::from_path(&snapshot_path))?;
                true
            }
            None => false,
        };

        let key_provider = Arc::new(Mutex::new(key_provider));
        let stronghold = Arc::new(Mutex::new(stronghold));
        let timeout_task = Arc::new(Mutex::new(None));

        // A key-clearing task only makes sense when there is a key and a timeout.
        if has_key_provider {
            if let Some(delay) = timeout {
                let handle = tokio::spawn(task_key_clear(
                    Arc::downgrade(&timeout_task),
                    Arc::clone(&stronghold),
                    Arc::clone(&key_provider),
                    delay,
                ));
                // The mutex was created a few lines above and is not shared yet.
                *timeout_task.try_lock().unwrap() = Some(handle);
            }
        }

        Ok(StrongholdAdapter {
            stronghold,
            key_provider,
            timeout,
            timeout_task,
            snapshot_path: path.to_path_buf(),
        })
    }
}
impl StrongholdAdapter {
/// Returns the path of the snapshot file used by this adapter.
pub fn snapshot_path(&self) -> &Path {
    &self.snapshot_path
}
pub fn builder() -> StrongholdAdapterBuilder {
StrongholdAdapterBuilder::default()
}
/// Returns `true` while a key provider is stored, i.e. the key has not been cleared.
pub async fn is_key_available(&self) -> bool {
    let key_provider = self.key_provider.lock().await;
    key_provider.is_some()
}
/// Sets the password used for the snapshot and (re)starts the key-clearing task.
///
/// If a key is already stored, `password` must derive the same key. The snapshot is then
/// loaded, or created when it does not exist yet, with the derived key.
///
/// # Errors
///
/// Returns [`Error::InvalidPassword`] when the derived key differs from the stored one, and any
/// error reported while unlocking the keys or loading/creating the snapshot.
pub async fn set_password(&self, password: impl Into<Password> + Send) -> Result<(), Error> {
    let password = password.into();

    let mut key_provider_guard = self.key_provider.lock().await;

    let key_provider = self::common::key_provider_from_password(password);

    // A key that is still loaded can only be "set" again to the very same key.
    if let Some(old_key_provider) = &*key_provider_guard {
        if old_key_provider.try_unlock()? != key_provider.try_unlock()? {
            return Err(Error::InvalidPassword);
        }
    }

    let snapshot_path = SnapshotPath::from_path(&self.snapshot_path);
    {
        // Keep the Stronghold lock only for the snapshot check. The key-clearing task locks
        // `timeout_task` first and Stronghold afterwards, so holding Stronghold while waiting
        // for `timeout_task` below could deadlock with a task whose timeout has just elapsed.
        let stronghold = self.stronghold.lock().await;
        check_or_create_snapshot(&stronghold, &key_provider, &snapshot_path)?;
    }

    *key_provider_guard = Some(key_provider);
    drop(key_provider_guard);

    if let Some(timeout) = self.timeout {
        // Replace a task that is still scheduled so the timeout starts over.
        if let Some(timeout_task) = self.timeout_task.lock().await.take() {
            timeout_task.abort();
        }

        let key_provider = self.key_provider.clone();
        *self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
            Arc::downgrade(&self.timeout_task),
            self.stronghold.clone(),
            key_provider,
            timeout,
        )));
    }

    Ok(())
}
pub async fn change_password(&self, new_password: impl Into<Password> + Send) -> Result<(), Error> {
let new_password = new_password.into();
if let Some(timeout_task) = self.timeout_task.lock().await.take() {
timeout_task.abort();
}
self.write_stronghold_snapshot(None).await?;
let mut values = Vec::new();
let keys_to_re_encrypt = self
.stronghold
.lock()
.await
.get_client(PRIVATE_DATA_CLIENT_PATH)?
.store()
.keys()?
.into_iter()
.map(|k| unsafe { String::from_utf8_unchecked(k) })
.collect::<Vec<_>>();
for key in keys_to_re_encrypt {
let value = match self.get_bytes(&key).await {
Err(err) => {
error!("an error occurred during the re-encryption of Stronghold Store: {err}");
if let Some(timeout) = self.timeout {
let key_provider = self.key_provider.clone();
*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
Arc::downgrade(&self.timeout_task),
self.stronghold.clone(),
key_provider,
timeout,
)));
}
return Err(err);
}
Ok(None) => continue,
Ok(Some(value)) => Zeroizing::new(value),
};
values.push((key, value));
}
let old_key_provider = {
let mut lock = self.key_provider.lock().await;
let old_key_provider = lock.take();
*lock = Some(self::common::key_provider_from_password(new_password));
old_key_provider
};
for (key, value) in values {
if let Err(err) = self.set_bytes(&key, &value).await {
error!("an error occurred during the re-encryption of Stronghold store: {err}");
*self.key_provider.lock().await = old_key_provider;
self.read_stronghold_snapshot().await?;
if let Some(timeout) = self.timeout {
let key_provider = self.key_provider.clone();
*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
Arc::downgrade(&self.timeout_task),
self.stronghold.clone(),
key_provider,
timeout,
)));
}
return Err(err);
}
}
self.write_stronghold_snapshot(None).await?;
if let Some(timeout) = self.timeout {
let key_provider = self.key_provider.clone();
*self.timeout_task.lock().await = Some(tokio::spawn(task_key_clear(
Arc::downgrade(&self.timeout_task),
self.stronghold.clone(),
key_provider,
timeout,
)));
}
Ok(())
}
/// Stops the key-clearing task, unloads the snapshot if a key is loaded and drops the key.
///
/// A failure to unload the snapshot is only logged; the key is dropped regardless.
pub async fn clear_key(&self) {
    // Cancel a scheduled purge; it is being done by hand right now.
    if let Some(running_task) = self.timeout_task.lock().await.take() {
        running_task.abort();
    }

    let key_loaded = self.is_key_available().await;
    if key_loaded {
        match self.unload_stronghold_snapshot().await {
            Ok(()) => {}
            Err(err) => warn!("failed to unload Stronghold while clearing the key: {err}"),
        }
    }

    *self.key_provider.lock().await = None;
    debug!("cleared stronghold key");
}
pub fn get_timeout(&self) -> Option<Duration> {
self.timeout
}
/// Replaces the key-clearing timeout.
///
/// A scheduled key-clearing task is aborted. A new one is spawned only when a key is currently
/// loaded and `new_timeout` is `Some`.
pub async fn set_timeout(&mut self, new_timeout: Option<Duration>) {
    if let Some(running_task) = self.timeout_task.lock().await.take() {
        running_task.abort();
    }

    self.timeout = new_timeout;

    // The key-provider lock stays held until the new task handle has been stored.
    let key_guard = self.key_provider.lock().await;
    let timeout = match (key_guard.as_ref(), self.timeout) {
        (Some(_), Some(timeout)) => timeout,
        _ => return,
    };

    let handle = tokio::spawn(task_key_clear(
        Arc::downgrade(&self.timeout_task),
        Arc::clone(&self.stronghold),
        Arc::clone(&self.key_provider),
        timeout,
    ));
    *self.timeout_task.lock().await = Some(handle);
}
/// Restarts the key-clearing task by setting the current timeout again.
pub async fn restart_key_clearing_task(&mut self) {
    let current_timeout = self.get_timeout();
    self.set_timeout(current_timeout).await;
}
pub async fn read_stronghold_snapshot(&self) -> Result<(), Error> {
let locked_key_provider = self.key_provider.lock().await;
let key_provider = if let Some(key_provider) = &*locked_key_provider {
key_provider
} else {
return Err(Error::KeyCleared);
};
self.stronghold.lock().await.load_client_from_snapshot(
PRIVATE_DATA_CLIENT_PATH,
key_provider,
&SnapshotPath::from_path(&self.snapshot_path),
)?;
Ok(())
}
pub async fn write_stronghold_snapshot(&self, snapshot_path: Option<&Path>) -> Result<(), Error> {
if let Some(p) = snapshot_path {
if p.is_dir() {
return Err(
std::io::Error::new(std::io::ErrorKind::Other, format!("Path is not a file: {:?}", p)).into(),
);
}
}
let locked_key_provider = self.key_provider.lock().await;
let key_provider = if let Some(key_provider) = &*locked_key_provider {
key_provider
} else {
return Err(Error::KeyCleared);
};
self.stronghold.lock().await.commit_with_keyprovider(
&SnapshotPath::from_path(snapshot_path.unwrap_or(&self.snapshot_path)),
key_provider,
)?;
Ok(())
}
/// Writes the snapshot to the adapter's snapshot path and then clears the Stronghold instance.
///
/// # Errors
///
/// Returns the error of the snapshot write (in which case nothing is cleared) or of the clear.
pub async fn unload_stronghold_snapshot(&self) -> Result<(), Error> {
    match self.write_stronghold_snapshot(None).await {
        Ok(()) => {
            self.stronghold.lock().await.clear()?;
            Ok(())
        }
        Err(err) => Err(err),
    }
}
/// Locks the wrapped [`Stronghold`] and returns the guard giving direct access to it.
pub async fn inner(&self) -> MutexGuard<'_, Stronghold> {
    let stronghold = self.stronghold.lock().await;
    stronghold
}
}
/// Join handle of a spawned key-clearing task (see `task_key_clear`).
type TaskHandle = JoinHandle<()>;
async fn task_key_clear(
task: Weak<Mutex<Option<TaskHandle>>>,
stronghold: Arc<Mutex<Stronghold>>,
key_provider: Arc<Mutex<Option<KeyProvider>>>,
timeout: Duration,
) {
tokio::time::sleep(timeout).await;
if let Some(task) = task.upgrade() {
let mut lock = task.lock().await;
lock.take();
debug!("StrongholdAdapter is purging the key");
key_provider.lock().await.take();
if let Err(e) = stronghold.lock().await.clear() {
log::error!("Failed to clear stronghold keys: {e}");
}
drop(lock);
}
}
#[cfg(test)]
mod tests {
use std::fs;
use pretty_assertions::assert_eq;
use super::*;
// Checks that the key is purged by the timeout task, and that manual clearing and restarting
// the task behave as expected when no timeout is configured.
#[tokio::test]
async fn test_clear_key() {
// Use the cheapest work factor so snapshot encryption does not slow the test down.
iota_stronghold::engine::snapshot::try_set_encrypt_work_factor(0).unwrap();
let timeout = Duration::from_millis(100);
let stronghold_path = "test_clear_key.stronghold";
let mut adapter = StrongholdAdapter::builder()
.password("drowssap".to_owned())
.timeout(timeout)
.build(stronghold_path)
.unwrap();
// Well before the 100 ms timeout: key and scheduled task must both be present.
tokio::time::sleep(Duration::from_millis(10)).await;
assert!(adapter.key_provider.lock().await.is_some());
assert_eq!(adapter.get_timeout(), Some(timeout));
assert!(adapter.timeout_task.lock().await.is_some());
// Well after the timeout: the task must have purged the key and removed its own handle.
tokio::time::sleep(Duration::from_millis(200)).await;
assert!(adapter.key_provider.lock().await.is_none());
assert_eq!(adapter.get_timeout(), Some(timeout));
assert!(adapter.timeout_task.lock().await.is_none());
// Disable the timeout; a password different from the one used to build the adapter is rejected.
let timeout = None;
adapter.set_timeout(timeout).await;
assert!(adapter.set_password("password".to_owned()).await.is_err());
// Clearing manually leaves no key and no task behind.
adapter.clear_key().await;
assert!(adapter.key_provider.lock().await.is_none());
assert_eq!(adapter.get_timeout(), timeout);
assert!(adapter.timeout_task.lock().await.is_none());
// Restarting the task without a key and without a timeout must not spawn anything.
adapter.restart_key_clearing_task().await;
assert!(adapter.key_provider.lock().await.is_none());
assert_eq!(adapter.get_timeout(), timeout);
assert!(adapter.timeout_task.lock().await.is_none());
// Remove the snapshot file created by `build`.
fs::remove_file(stronghold_path).unwrap();
}
// Checks that the same password can be set repeatedly while a different one is rejected.
#[tokio::test]
async fn stronghold_password_already_set() {
// Use the cheapest work factor so snapshot encryption does not slow the test down.
iota_stronghold::engine::snapshot::try_set_encrypt_work_factor(0).unwrap();
let stronghold_path = "stronghold_password_already_set.stronghold";
let adapter = StrongholdAdapter::builder()
.password("drowssap".to_owned())
.build(stronghold_path)
.unwrap();
adapter.clear_key().await;
// First call: no key is loaded. Second call: the same key is already loaded.
assert!(adapter.set_password("drowssap".to_owned()).await.is_ok());
assert!(adapter.set_password("drowssap".to_owned()).await.is_ok());
// A different password while a key is loaded must fail.
assert!(adapter.set_password("other_password".to_owned()).await.is_err());
// Remove the snapshot file created by `build`.
fs::remove_file(stronghold_path).unwrap();
}
}