use core::str::FromStr;
use sqlx::*;
use std::sync::Arc;
use crate::cli_state::{IdentitiesRepository, NamedIdentity};
use ockam::identity::Identifier;
use ockam_core::async_trait;
use ockam_core::Result;
use ockam_node::database::AutoRetry;
use ockam_node::database::{Boolean, FromSqlxError, SqlxDatabase, ToVoid};
/// sqlx-backed implementation of [`IdentitiesRepository`].
///
/// All queries issued by this type target the `named_identity` table.
#[derive(Clone)]
pub struct IdentitiesSqlxDatabase {
    /// Handle to the database holding the `named_identity` table.
    database: SqlxDatabase,
}
impl IdentitiesSqlxDatabase {
    /// Build a repository on top of an existing database handle.
    pub fn new(database: SqlxDatabase) -> Self {
        debug!("create a repository for identities");
        Self { database }
    }

    /// Build a shared, type-erased repository.
    ///
    /// When the database reports that it needs retries, the repository is wrapped in
    /// [`AutoRetry`]; otherwise it is returned as-is.
    pub fn make_repository(database: SqlxDatabase) -> Arc<dyn IdentitiesRepository> {
        let retry_required = database.needs_retry();
        let repository = Self::new(database);
        if !retry_required {
            return Arc::new(repository);
        }
        Arc::new(AutoRetry::new(repository))
    }

    /// Build a repository backed by an in-memory database named `identities`.
    pub async fn create() -> Result<Self> {
        let in_memory = SqlxDatabase::in_memory("identities").await?;
        Ok(Self::new(in_memory))
    }
}
#[async_trait]
impl IdentitiesRepository for IdentitiesSqlxDatabase {
    /// Insert the named identity, or update the existing row with the same `identifier`.
    ///
    /// The stored row is flagged as default only when an identity with the same `name` was
    /// already flagged as default before this call. The returned [`NamedIdentity`] carries
    /// that same flag.
    async fn store_named_identity(
        &self,
        identifier: &Identifier,
        name: &str,
        vault_name: &str,
    ) -> Result<NamedIdentity> {
        let mut transaction = self.database.begin().await.into_core()?;
        // Keep the default flag when an identity is stored again under a name that is
        // currently the default one.
        let query1 = query_scalar(
            "SELECT EXISTS(SELECT 1 FROM named_identity WHERE is_default = $1 AND name = $2)",
        )
        .bind(true)
        .bind(name);
        let is_already_default: Boolean = query1.fetch_one(&mut *transaction).await.into_core()?;
        let is_already_default = is_already_default.to_bool();
        // Upsert keyed on the identifier: an existing row has its name, vault and flag replaced.
        let query2 = query(
            r#"
INSERT INTO named_identity (identifier, name, vault_name, is_default)
VALUES ($1, $2, $3, $4)
ON CONFLICT (identifier)
DO UPDATE SET name = $2, vault_name = $3, is_default = $4"#,
        )
        .bind(identifier)
        .bind(name)
        .bind(vault_name)
        .bind(is_already_default);
        query2.execute(&mut *transaction).await.void()?;
        transaction.commit().await.void()?;
        Ok(NamedIdentity::new(
            identifier.clone(),
            name.to_string(),
            vault_name.to_string(),
            is_already_default,
        ))
    }

    /// Delete the identity called `name` and return its identifier, or `None` if it is unknown.
    ///
    /// If the deleted identity was the default one, some other remaining identity (the first
    /// row returned by an unordered `SELECT`) is promoted to default.
    async fn delete_identity(&self, name: &str) -> Result<Option<Identifier>> {
        let mut transaction = self.database.begin().await.into_core()?;
        let query1 = query_as(
            "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name = $1",
        )
        .bind(name);
        let row: Option<NamedIdentityRow> =
            query1.fetch_optional(&mut *transaction).await.into_core()?;
        let named_identity = row.map(|r| r.named_identity()).transpose()?;
        let result = match named_identity {
            // nothing to delete
            None => None,
            Some(named_identity) => {
                let query2 = query("DELETE FROM named_identity WHERE name = $1").bind(name);
                query2.execute(&mut *transaction).await.void()?;
                // Never leave the store without a default when the default was just removed
                // and other identities remain.
                if named_identity.is_default() {
                    if let Some(other_name) =
                        query_scalar::<_, String>("SELECT name FROM named_identity")
                            .fetch_optional(&mut *transaction)
                            .await
                            .into_core()?
                    {
                        let query3 =
                            query("UPDATE named_identity SET is_default = $1 WHERE name = $2")
                                .bind(true)
                                .bind(other_name);
                        query3.execute(&mut *transaction).await.void()?
                    }
                }
                Some(named_identity.identifier())
            }
        };
        transaction.commit().await.void()?;
        Ok(result)
    }

    /// Delete the identity with the given `identifier` and return its name, or `None` if unknown.
    ///
    /// NOTE: the name lookup and the deletion run as two separate operations, not in a single
    /// transaction.
    async fn delete_identity_by_identifier(
        &self,
        identifier: &Identifier,
    ) -> Result<Option<String>> {
        if let Some(name) = self.get_identity_name_by_identifier(identifier).await? {
            self.delete_identity(&name).await?;
            Ok(Some(name))
        } else {
            Ok(None)
        }
    }

    /// Rename the identity with the given `identifier`; a missing identifier is a no-op.
    async fn update_name(&self, identifier: &Identifier, name: &str) -> Result<()> {
        query("UPDATE named_identity SET name = $1 WHERE identifier = $2")
            .bind(name)
            .bind(identifier)
            .execute(&*self.database.pool)
            .await
            .void()
    }

    /// Return the identifier of the identity called `name`, if any.
    async fn get_identifier(&self, name: &str) -> Result<Option<Identifier>> {
        let query = query_as(
            "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name = $1",
        )
        .bind(name);
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&*self.database.pool)
            .await
            .into_core()?;
        row.map(|r| r.identifier()).transpose()
    }

    /// Return the name of the identity with the given `identifier`, if any.
    async fn get_identity_name_by_identifier(
        &self,
        identifier: &Identifier,
    ) -> Result<Option<String>> {
        let query =
            query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE identifier = $1").bind(identifier);
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&*self.database.pool)
            .await
            .into_core()?;
        Ok(row.map(|r| r.name()))
    }

    /// Return the full named identity called `name`, if any.
    async fn get_named_identity(&self, name: &str) -> Result<Option<NamedIdentity>> {
        let query = query_as(
            "SELECT identifier, name, vault_name, is_default FROM named_identity WHERE name = $1",
        )
        .bind(name);
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&*self.database.pool)
            .await
            .into_core()?;
        row.map(|r| r.named_identity()).transpose()
    }

    /// Return the full named identity with the given `identifier`, if any.
    async fn get_named_identity_by_identifier(
        &self,
        identifier: &Identifier,
    ) -> Result<Option<NamedIdentity>> {
        let query =
            query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE identifier = $1").bind(identifier);
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&*self.database.pool)
            .await
            .into_core()?;
        row.map(|r| r.named_identity()).transpose()
    }

    /// Return every stored identity.
    ///
    /// The query has no `ORDER BY`, so the order is whatever the database yields.
    async fn get_named_identities(&self) -> Result<Vec<NamedIdentity>> {
        let query = query_as("SELECT identifier, name, vault_name, is_default FROM named_identity");
        let rows: Vec<NamedIdentityRow> = query.fetch_all(&*self.database.pool).await.into_core()?;
        rows.iter().map(|r| r.named_identity()).collect()
    }

    /// Return every identity associated with `vault_name` (unordered, like
    /// `get_named_identities`).
    async fn get_named_identities_by_vault_name(
        &self,
        vault_name: &str,
    ) -> Result<Vec<NamedIdentity>> {
        let query = query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE vault_name = $1").bind(vault_name);
        let rows: Vec<NamedIdentityRow> = query.fetch_all(&*self.database.pool).await.into_core()?;
        rows.iter().map(|r| r.named_identity()).collect()
    }

    /// Make the identity called `name` the only default identity.
    ///
    /// If no identity is called `name`, nothing is changed. In particular, the current
    /// default keeps its flag.
    async fn set_as_default(&self, name: &str) -> Result<()> {
        let mut transaction = self.database.begin().await.into_core()?;
        // Without this guard, an unknown name would match nothing in the first UPDATE while
        // the second UPDATE cleared the default flag of every identity.
        let exists: Boolean =
            query_scalar("SELECT EXISTS(SELECT 1 FROM named_identity WHERE name = $1)")
                .bind(name)
                .fetch_one(&mut *transaction)
                .await
                .into_core()?;
        if exists.to_bool() {
            let query1 = query("UPDATE named_identity SET is_default = $1 WHERE name = $2")
                .bind(true)
                .bind(name);
            query1.execute(&mut *transaction).await.void()?;
            let query2 = query("UPDATE named_identity SET is_default = $1 WHERE name <> $2")
                .bind(false)
                .bind(name);
            query2.execute(&mut *transaction).await.void()?;
        }
        transaction.commit().await.void()
    }

    /// Make the identity with the given `identifier` the only default identity.
    ///
    /// If no identity has this `identifier`, nothing is changed. In particular, the current
    /// default keeps its flag.
    async fn set_as_default_by_identifier(&self, identifier: &Identifier) -> Result<()> {
        let mut transaction = self.database.begin().await.into_core()?;
        // Same guard as in `set_as_default`: never clear every default for an unknown target.
        let exists: Boolean =
            query_scalar("SELECT EXISTS(SELECT 1 FROM named_identity WHERE identifier = $1)")
                .bind(identifier)
                .fetch_one(&mut *transaction)
                .await
                .into_core()?;
        if exists.to_bool() {
            let query1 = query("UPDATE named_identity SET is_default = $1 WHERE identifier = $2")
                .bind(true)
                .bind(identifier);
            query1.execute(&mut *transaction).await.void()?;
            let query2 =
                query("UPDATE named_identity SET is_default = $1 WHERE identifier <> $2")
                    .bind(false)
                    .bind(identifier);
            query2.execute(&mut *transaction).await.void()?;
        }
        transaction.commit().await.void()
    }

    /// Return the identity flagged as default, if any (the first matching row if several are).
    async fn get_default_named_identity(&self) -> Result<Option<NamedIdentity>> {
        let query =
            query_as("SELECT identifier, name, vault_name, is_default FROM named_identity WHERE is_default = $1").bind(true);
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&*self.database.pool)
            .await
            .into_core()?;
        row.map(|r| r.named_identity()).transpose()
    }
}
/// One row of the `named_identity` table, as decoded by sqlx.
///
/// `FromRow` maps columns to fields by name, so the field names must match the selected
/// column names.
#[derive(sqlx::FromRow)]
pub(crate) struct NamedIdentityRow {
    /// Identifier in its textual form; parsed on demand by `identifier()`.
    identifier: String,
    /// User-facing name of the identity.
    name: String,
    /// Name of the vault associated with the identity.
    vault_name: String,
    /// Whether this identity is flagged as the default one.
    is_default: Boolean,
}
impl NamedIdentityRow {
    /// Parse the stored identifier text into an [`Identifier`].
    ///
    /// Returns the parse error if the stored text is not a valid identifier.
    pub(crate) fn identifier(&self) -> Result<Identifier> {
        Identifier::from_str(&self.identifier)
    }

    /// The user-facing name of the identity.
    pub(crate) fn name(&self) -> String {
        self.name.clone()
    }

    /// The name of the vault associated with the identity.
    pub(crate) fn vault_name(&self) -> String {
        self.vault_name.clone()
    }

    /// Convert the row into a [`NamedIdentity`].
    ///
    /// Fails only if the stored identifier cannot be parsed. The conversion goes through the
    /// accessors above, which also keeps `vault_name` used without a lint suppression.
    pub(crate) fn named_identity(&self) -> Result<NamedIdentity> {
        Ok(NamedIdentity::new(
            self.identifier()?,
            self.name(),
            self.vault_name(),
            self.is_default.to_bool(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ockam::identity::identities;
    use ockam_core::compat::sync::Arc;
    use ockam_node::database::with_dbs;

    /// Build a repository directly on `db`, without the auto-retry wrapper.
    fn repository_for(db: SqlxDatabase) -> Arc<dyn IdentitiesRepository> {
        Arc::new(IdentitiesSqlxDatabase::new(db))
    }

    #[tokio::test]
    async fn test_identities_repository_named_identities() -> Result<()> {
        with_dbs(|db| async move {
            let repository = repository_for(db);

            let first = create_identity().await?;
            repository
                .store_named_identity(&first, "name1", "vault")
                .await?;
            let second = create_identity().await?;
            repository
                .store_named_identity(&second, "name2", "vault")
                .await?;

            // lookups by name and by identifier
            assert_eq!(
                repository.get_identifier("name1").await?,
                Some(first.clone())
            );
            let first_name = repository.get_identity_name_by_identifier(&first).await?;
            assert_eq!(first_name, Some("name1".into()));
            let second_named = repository.get_named_identity("name2").await?;
            assert_eq!(
                second_named.map(|n| n.identifier()),
                Some(second.clone())
            );

            // listing before and after a deletion
            let all = repository.get_named_identities().await?;
            let listed: Vec<_> = all.iter().map(|n| n.identifier()).collect();
            assert_eq!(listed, vec![first.clone(), second.clone()]);

            repository.delete_identity("name1").await?;
            let remaining = repository.get_named_identities().await?;
            let listed: Vec<_> = remaining.iter().map(|n| n.identifier()).collect();
            assert_eq!(listed, vec![second.clone()]);

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_identities_repository_default_identities() -> Result<()> {
        with_dbs(|db| async move {
            let repository = repository_for(db);

            let first = create_identity().await?;
            let first_named = repository
                .store_named_identity(&first, "name1", "vault")
                .await?;
            let second = create_identity().await?;
            let second_named = repository
                .store_named_identity(&second, "name2", "vault")
                .await?;

            // default selected by identifier
            repository.set_as_default_by_identifier(&first).await?;
            let default = repository.get_default_named_identity().await?;
            assert_eq!(default, Some(first_named.set_as_default()));

            // default switched by name
            repository.set_as_default("name2").await?;
            let default = repository.get_default_named_identity().await?;
            assert_eq!(default, Some(second_named.set_as_default()));

            let previous = repository.get_named_identity("name1").await?;
            assert!(!previous.unwrap().is_default());

            let default = repository.get_default_named_identity().await?;
            assert_eq!(default.map(|i| i.name()), Some("name2".to_string()));

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_get_identities_by_vault_name() -> Result<()> {
        with_dbs(|db| async move {
            let repository = repository_for(db);

            let first = create_identity().await?;
            repository
                .store_named_identity(&first, "name1", "vault1")
                .await?;
            let second = create_identity().await?;
            repository
                .store_named_identity(&second, "name2", "vault2")
                .await?;
            let third = create_identity().await?;
            repository
                .store_named_identity(&third, "name3", "vault1")
                .await?;

            let in_vault1 = repository
                .get_named_identities_by_vault_name("vault1")
                .await?;
            let names: Vec<String> = in_vault1.iter().map(|i| i.name()).collect();
            assert_eq!(names, vec!["name1", "name3"]);

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_update_name() -> Result<()> {
        with_dbs(|db| async move {
            let repository = repository_for(db);

            let identifier = create_identity().await?;
            repository
                .store_named_identity(&identifier, "name1", "vault1")
                .await?;

            repository.update_name(&identifier, "new-name1").await?;
            let renamed = repository.get_named_identity("new-name1").await?;
            assert_eq!(renamed.map(|i| i.name()), Some("new-name1".to_string()));

            Ok(())
        })
        .await
    }

    /// Create a fresh identity and return its identifier.
    async fn create_identity() -> Result<Identifier> {
        let factory = identities().await?;
        factory.identities_creation().create_identity().await
    }
}