use std::fmt::{Display, Formatter};
use std::path::PathBuf;
use std::sync::Arc;
use ockam::identity::{Identities, Vault};
use ockam_core::errcode::{Kind, Origin};
use ockam_node::database::SqlxDatabase;
use ockam_vault_aws::AwsSigningVault;
use crate::cli_state::{random_name, CliState, Result};
impl CliState {
    /// Create a regular (non-KMS) vault called `vault_name`.
    ///
    /// If no default vault exists yet, the new vault also becomes the default one.
    pub async fn create_named_vault(&self, vault_name: &str) -> Result<NamedVault> {
        let is_kms = false;
        self.create_a_vault(vault_name, is_kms).await
    }

    /// Create a vault called `vault_name` flagged as KMS.
    ///
    /// When instantiated, such a vault uses an `AwsSigningVault` for its identity and
    /// credential vaults. If no default vault exists yet, it also becomes the default one.
    pub async fn create_kms_vault(&self, vault_name: &str) -> Result<NamedVault> {
        let is_kms = true;
        self.create_a_vault(vault_name, is_kms).await
    }

    /// Mark the vault called `vault_name` as the default vault.
    pub async fn set_default_vault(&self, vault_name: &str) -> Result<()> {
        let repository = self.vaults_repository().await?;
        repository.set_as_default(vault_name).await?;
        Ok(())
    }

    /// Delete the vault called `vault_name`. This does nothing if no such vault exists.
    ///
    /// After the vault record is removed, the vault's database file is also removed from disk,
    /// unless it is the file at `database_path()`.
    pub async fn delete_named_vault(&self, vault_name: &str) -> Result<()> {
        let repository = self.vaults_repository().await?;
        let Some(vault) = repository.get_named_vault(vault_name).await? else {
            return Ok(());
        };
        repository.delete_vault(vault_name).await?;

        // NOTE(review): presumably the file at `database_path()` also backs the rest of the
        // CLI state, which is why it is never deleted together with a single vault.
        let uses_main_database = vault.path == self.database_path();
        if !uses_main_database {
            // Best-effort cleanup: any I/O error (e.g. a missing file) is deliberately ignored.
            let _ = std::fs::remove_file(&vault.path);
        }
        Ok(())
    }

    /// Delete every vault, one at a time, through [`CliState::delete_named_vault`].
    pub async fn delete_all_named_vaults(&self) -> Result<()> {
        let names: Vec<String> = self
            .vaults_repository()
            .await?
            .get_named_vaults()
            .await?
            .into_iter()
            .map(|vault| vault.name)
            .collect();
        for name in &names {
            self.delete_named_vault(name).await?;
        }
        Ok(())
    }
}
impl CliState {
pub async fn get_named_vaults(&self) -> Result<Vec<NamedVault>> {
Ok(self.vaults_repository().await?.get_named_vaults().await?)
}
pub async fn get_named_vault(&self, vault_name: &str) -> Result<NamedVault> {
let result = self
.vaults_repository()
.await?
.get_named_vault(vault_name)
.await?;
result.ok_or_else(|| {
ockam_core::Error::new(
Origin::Api,
Kind::NotFound,
format!("no vault found with name {vault_name}"),
)
.into()
})
}
pub async fn get_default_named_vault(&self) -> Result<NamedVault> {
let result = self.vaults_repository().await?.get_default_vault().await?;
match result {
Some(vault) => Ok(vault),
None => self.create_named_vault(&random_name()).await,
}
}
pub async fn get_named_vault_or_default(
&self,
vault_name: &Option<String>,
) -> Result<NamedVault> {
match vault_name {
Some(name) => self.get_named_vault(name).await,
None => self.get_default_named_vault().await,
}
}
}
impl CliState {
    /// Build an [`Identities`] instance backed by `vault`.
    ///
    /// It also uses this CLI state's change-history, identity-attributes and purpose-keys
    /// repositories.
    pub async fn make_identities(&self, vault: Vault) -> Result<Arc<Identities>> {
        // The repositories are fetched in the same order as they are plugged into the builder.
        let identities = Identities::builder()
            .await?
            .with_vault(vault)
            .with_change_history_repository(self.change_history_repository().await?)
            .with_identity_attributes_repository(self.identity_attributes_repository().await?)
            .with_purpose_keys_repository(self.purpose_keys_repository().await?)
            .build();
        Ok(identities)
    }
}
impl CliState {
    /// Store a new vault record called `vault_name` and return it.
    ///
    /// When no default vault exists yet, the new vault is stored at `database_path()` and
    /// marked as the default. Otherwise it gets its own file at `dir()/<vault_name>` and is
    /// not made the default.
    async fn create_a_vault(&self, vault_name: &str, is_kms: bool) -> Result<NamedVault> {
        let repository = self.vaults_repository().await?;
        let becomes_default = repository.get_default_vault().await?.is_none();

        // NOTE(review): `vault_name` is joined into a path without validation, so a name
        // containing path separators (or an absolute path) would place the file elsewhere.
        let vault_path = if becomes_default {
            self.database_path()
        } else {
            self.dir().join(vault_name)
        };

        let stored = repository
            .store_vault(vault_name, vault_path, is_kms)
            .await?;
        if !becomes_default {
            return Ok(stored);
        }

        // Persist the default flag, then reflect it in the returned value as well.
        repository.set_as_default(vault_name).await?;
        Ok(stored.set_as_default())
    }
}
/// A vault record as stored in the CLI vaults repository.
///
/// This only describes the vault (name, location, flags). Use [`NamedVault::vault`] to
/// instantiate the actual `Vault`.
#[derive(Debug, PartialEq, Eq, Clone, serde::Serialize, serde::Deserialize)]
pub struct NamedVault {
// Name used to look the vault up in the repository.
name: String,
// Path of the SQL database opened by `NamedVault::database`. For a vault created while no
// default vault existed, this is the CLI `database_path()`.
path: PathBuf,
// Whether this vault is the default vault.
is_default: bool,
// Whether `NamedVault::vault` should use an `AwsSigningVault` for identity/credential vaults.
is_kms: bool,
}
impl NamedVault {
pub fn new(name: &str, path: PathBuf, is_default: bool, is_kms: bool) -> Self {
Self {
name: name.to_string(),
path,
is_default,
is_kms,
}
}
pub fn name(&self) -> String {
self.name.clone()
}
pub fn path(&self) -> PathBuf {
self.path.clone()
}
pub fn is_default(&self) -> bool {
self.is_default
}
pub fn set_as_default(&self) -> NamedVault {
let mut result = self.clone();
result.is_default = true;
result
}
pub fn is_kms(&self) -> bool {
self.is_kms
}
pub async fn vault(&self) -> Result<Vault> {
if self.is_kms {
let mut vault = Vault::create().await?;
let aws_vault = Arc::new(AwsSigningVault::create().await?);
vault.identity_vault = aws_vault.clone();
vault.credential_vault = aws_vault;
Ok(vault)
} else {
Ok(Vault::create_with_database(self.database().await?))
}
}
async fn database(&self) -> Result<Arc<SqlxDatabase>> {
Ok(Arc::new(SqlxDatabase::create(self.path.as_path()).await?))
}
}
impl Display for NamedVault {
    // Renders two newline-terminated lines: the vault name, then its type
    // ("AWS KMS" for KMS vaults, "OCKAM" otherwise).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let vault_type = if self.is_kms { "AWS KMS" } else { "OCKAM" };
        writeln!(f, "Name: {}", self.name)?;
        writeln!(f, "Type: {}", vault_type)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_create_named_vault() -> Result<()> {
        let cli = CliState::test().await?;

        // A freshly created vault can be read back by its name.
        let vault1 = cli.create_named_vault("vault1").await?;
        assert_eq!(cli.get_named_vault("vault1").await?, vault1);

        // Both vaults are expected back, in creation order.
        let vault2 = cli.create_named_vault("vault2").await?;
        assert_eq!(
            cli.get_named_vaults().await?,
            vec![vault1.clone(), vault2.clone()]
        );

        // The first vault stays the default until another one is explicitly chosen.
        assert_eq!(cli.get_default_named_vault().await?, vault1);
        cli.set_default_vault("vault2").await?;
        assert_eq!(cli.get_default_named_vault().await?, vault2.set_as_default());

        // Once the default vault is deleted, the remaining vault is reported as the default.
        cli.delete_named_vault("vault2").await?;
        assert_eq!(cli.get_default_named_vault().await?, vault1.set_as_default());

        cli.delete_all_named_vaults().await?;
        assert!(cli.get_named_vaults().await?.is_empty());
        Ok(())
    }

    #[tokio::test]
    async fn test_get_default_named_vault() -> Result<()> {
        let cli = CliState::test().await?;

        // On a fresh test state, asking for the default vault creates one.
        let default_vault = cli.get_default_named_vault().await?;
        assert!(default_vault.is_default());
        assert!(default_vault.path().starts_with(cli.dir()));
        Ok(())
    }
}