use core::str::FromStr;
use sqlx::*;
use ockam::identity::Identifier;
use ockam_core::async_trait;
use ockam_core::compat::sync::Arc;
use ockam_core::Result;
use ockam_node::database::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid};
use crate::cli_state::{IdentitiesRepository, NamedIdentity};
/// SQL-backed storage for named identities.
///
/// All repository operations go through the wrapped [`SqlxDatabase`]: either its
/// connection `pool` directly, or a transaction obtained with `begin()`.
/// Cloning this value clones the `Arc`, so every clone shares the same database.
#[derive(Clone)]
pub struct IdentitiesSqlxDatabase {
    /// Shared database handle used by every repository method.
    database: Arc<SqlxDatabase>,
}
impl IdentitiesSqlxDatabase {
    /// Build a repository on top of an already-opened, shared database.
    pub fn new(database: Arc<SqlxDatabase>) -> Self {
        debug!("create a repository for identities");
        Self { database }
    }

    /// Build a repository backed by `SqlxDatabase::in_memory("identities")`.
    ///
    /// Used by this module's tests to get a self-contained repository.
    pub async fn create() -> Result<Arc<Self>> {
        let in_memory_database = SqlxDatabase::in_memory("identities").await?;
        let repository = Self::new(in_memory_database);
        Ok(Arc::new(repository))
    }
}
/// [`IdentitiesRepository`] implementation backed by the `named_identity` table.
///
/// Each row links an identity `name` to its `identifier`, the `vault_name` it belongs to,
/// and an `is_default` flag. Operations that touch several rows run in one transaction.
#[async_trait]
impl IdentitiesRepository for IdentitiesSqlxDatabase {
    /// Store (insert or replace) the identity `identifier` under `name`.
    ///
    /// If an identity stored under the same `name` was already the default one, the newly
    /// stored row keeps the default flag; otherwise it is stored as non-default.
    /// `INSERT OR REPLACE` replaces any existing row that conflicts on a uniqueness
    /// constraint (NOTE(review): the schema is not visible here — presumably `name`).
    async fn store_named_identity(
        &self,
        identifier: &Identifier,
        name: &str,
        vault_name: &str,
    ) -> Result<NamedIdentity> {
        let mut transaction = self.database.begin().await.into_core()?;

        // Preserve the default flag when overwriting the current default identity.
        let query1 = query_scalar(
            "SELECT EXISTS(SELECT 1 FROM named_identity WHERE is_default=$1 AND name=$2)",
        )
        .bind(true.to_sql())
        .bind(name.to_sql());
        let is_already_default: bool = query1.fetch_one(&mut *transaction).await.into_core()?;

        // Name the columns explicitly so the insert does not depend on the physical
        // column order of the table.
        let query2 = query(
            "INSERT OR REPLACE INTO named_identity (identifier, name, vault_name, is_default) VALUES (?, ?, ?, ?)",
        )
        .bind(identifier.to_sql())
        .bind(name.to_sql())
        .bind(vault_name.to_sql())
        .bind(is_already_default.to_sql());
        query2.execute(&mut *transaction).await.void()?;
        transaction.commit().await.void()?;

        Ok(NamedIdentity::new(
            identifier.clone(),
            name.to_string(),
            vault_name.to_string(),
            is_already_default,
        ))
    }

    /// Delete the identity stored under `name`.
    ///
    /// Returns the deleted identity's identifier, or `None` if no identity has that name.
    /// If the deleted identity was the default one, one of the remaining identities (if
    /// any) becomes the default, in the same transaction.
    async fn delete_identity(&self, name: &str) -> Result<Option<Identifier>> {
        let mut transaction = self.database.begin().await.into_core()?;

        let query1 = query_as("SELECT * FROM named_identity WHERE name=$1").bind(name.to_sql());
        let row: Option<NamedIdentityRow> =
            query1.fetch_optional(&mut *transaction).await.into_core()?;
        let named_identity = row.map(|r| r.named_identity()).transpose()?;

        let result = match named_identity {
            None => None,
            Some(named_identity) => {
                let query2 = query("DELETE FROM named_identity WHERE name=?").bind(name.to_sql());
                query2.execute(&mut *transaction).await.void()?;

                if named_identity.is_default() {
                    // No ORDER BY: which remaining identity gets promoted is unspecified.
                    let other_name =
                        query_scalar::<_, String>("SELECT name FROM named_identity LIMIT 1")
                            .fetch_optional(&mut *transaction)
                            .await
                            .into_core()?;
                    if let Some(other_name) = other_name {
                        let query3 =
                            query("UPDATE named_identity SET is_default = ? WHERE name = ?")
                                .bind(true.to_sql())
                                .bind(other_name.to_sql());
                        query3.execute(&mut *transaction).await.void()?;
                    }
                }
                Some(named_identity.identifier())
            }
        };
        transaction.commit().await.void()?;
        Ok(result)
    }

    /// Delete the identity whose identifier is `identifier` and return its name.
    ///
    /// Returns `None` if no identity has that identifier. The name lookup and the deletion
    /// are not in the same transaction, so the name is only reported when the deletion
    /// actually removed a row.
    async fn delete_identity_by_identifier(
        &self,
        identifier: &Identifier,
    ) -> Result<Option<String>> {
        match self.get_identity_name_by_identifier(identifier).await? {
            Some(name) => {
                let deleted = self.delete_identity(&name).await?;
                Ok(deleted.map(|_| name))
            }
            None => Ok(None),
        }
    }

    /// Return the identifier stored under `name`, or `None` if there is no such identity.
    ///
    /// Fails if the stored identifier cannot be parsed.
    async fn get_identifier(&self, name: &str) -> Result<Option<Identifier>> {
        let query = query_as("SELECT * FROM named_identity WHERE name=$1").bind(name.to_sql());
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&self.database.pool)
            .await
            .into_core()?;
        row.map(|r| r.identifier()).transpose()
    }

    /// Return the name under which `identifier` is stored, or `None` if it is not stored.
    async fn get_identity_name_by_identifier(
        &self,
        identifier: &Identifier,
    ) -> Result<Option<String>> {
        let query =
            query_as("SELECT * FROM named_identity WHERE identifier=$1").bind(identifier.to_sql());
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&self.database.pool)
            .await
            .into_core()?;
        Ok(row.map(|r| r.name()))
    }

    /// Return the full named identity stored under `name`, or `None` if there is none.
    async fn get_named_identity(&self, name: &str) -> Result<Option<NamedIdentity>> {
        let query = query_as("SELECT * FROM named_identity WHERE name=$1").bind(name.to_sql());
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&self.database.pool)
            .await
            .into_core()?;
        row.map(|r| r.named_identity()).transpose()
    }

    /// Return the full named identity for `identifier`, or `None` if it is not stored.
    async fn get_named_identity_by_identifier(
        &self,
        identifier: &Identifier,
    ) -> Result<Option<NamedIdentity>> {
        let query =
            query_as("SELECT * FROM named_identity WHERE identifier=$1").bind(identifier.to_sql());
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&self.database.pool)
            .await
            .into_core()?;
        row.map(|r| r.named_identity()).transpose()
    }

    /// Return every stored named identity.
    ///
    /// The query has no ORDER BY, so rows come back in the database's scan order.
    async fn get_named_identities(&self) -> Result<Vec<NamedIdentity>> {
        let query = query_as("SELECT * FROM named_identity");
        let rows: Vec<NamedIdentityRow> = query.fetch_all(&self.database.pool).await.into_core()?;
        rows.iter().map(|r| r.named_identity()).collect()
    }

    /// Make the identity stored under `name` the only default identity.
    ///
    /// If no identity is stored under `name`, nothing changes and the current default
    /// (if any) is kept.
    async fn set_as_default(&self, name: &str) -> Result<()> {
        let mut transaction = self.database.begin().await.into_core()?;
        let query1 = query("UPDATE named_identity SET is_default = ? WHERE name = ?")
            .bind(true.to_sql())
            .bind(name.to_sql());
        let rows_updated = query1
            .execute(&mut *transaction)
            .await
            .into_core()?
            .rows_affected();

        // Only demote the other identities if the requested one exists; otherwise an
        // unknown name would leave the store without any default identity.
        if rows_updated > 0 {
            let query2 = query("UPDATE named_identity SET is_default = ? WHERE name <> ?")
                .bind(false.to_sql())
                .bind(name.to_sql());
            query2.execute(&mut *transaction).await.void()?;
        }
        transaction.commit().await.void()
    }

    /// Make the identity with `identifier` the only default identity.
    ///
    /// If no identity has `identifier`, nothing changes and the current default (if any)
    /// is kept.
    async fn set_as_default_by_identifier(&self, identifier: &Identifier) -> Result<()> {
        let mut transaction = self.database.begin().await.into_core()?;
        let query1 = query("UPDATE named_identity SET is_default = ? WHERE identifier = ?")
            .bind(true.to_sql())
            .bind(identifier.to_sql());
        let rows_updated = query1
            .execute(&mut *transaction)
            .await
            .into_core()?
            .rows_affected();

        // Same guard as `set_as_default`: never clear the default for an unknown identifier.
        if rows_updated > 0 {
            let query2 = query("UPDATE named_identity SET is_default = ? WHERE identifier <> ?")
                .bind(false.to_sql())
                .bind(identifier.to_sql());
            query2.execute(&mut *transaction).await.void()?;
        }
        transaction.commit().await.void()
    }

    /// Return the default named identity, or `None` if no identity is flagged as default.
    ///
    /// If several rows are flagged as default, the first row returned by the query is used.
    async fn get_default_named_identity(&self) -> Result<Option<NamedIdentity>> {
        let query =
            query_as("SELECT * FROM named_identity WHERE is_default=$1").bind(true.to_sql());
        let row: Option<NamedIdentityRow> = query
            .fetch_optional(&self.database.pool)
            .await
            .into_core()?;
        row.map(|r| r.named_identity()).transpose()
    }
}
/// Raw row of the `named_identity` table.
///
/// `FromRow` maps result columns onto these fields by name, so each field name must
/// match a column of the table.
#[derive(sqlx::FromRow)]
pub(crate) struct NamedIdentityRow {
    /// Textual form of the identity's [`Identifier`], parsed by `identifier()`.
    identifier: String,
    /// User-chosen name of the identity.
    name: String,
    /// Name of the vault associated with the identity.
    vault_name: String,
    /// Whether this row is flagged as the default identity.
    is_default: bool,
}
impl NamedIdentityRow {
    /// Parse the stored identifier text into an [`Identifier`].
    ///
    /// Fails if the stored text is not a valid identifier.
    pub(crate) fn identifier(&self) -> Result<Identifier> {
        let stored: &str = &self.identifier;
        Identifier::from_str(stored)
    }

    /// Owned copy of the identity name.
    pub(crate) fn name(&self) -> String {
        self.name.to_owned()
    }

    /// Owned copy of the vault name.
    pub(crate) fn vault_name(&self) -> String {
        self.vault_name.to_owned()
    }

    /// Convert this row into a [`NamedIdentity`].
    ///
    /// Fails if the stored identifier cannot be parsed.
    pub(crate) fn named_identity(&self) -> Result<NamedIdentity> {
        let identifier = self.identifier()?;
        let identity = NamedIdentity::new(identifier, self.name(), self.vault_name(), self.is_default);
        Ok(identity)
    }
}
#[cfg(test)]
mod tests {
    use ockam::identity::identities;

    use super::*;

    /// Store, look up, list and delete named identities.
    #[tokio::test]
    async fn test_identities_repository_named_identities() -> Result<()> {
        let repository = create_repository().await?;

        let identifier1 = create_identity().await?;
        repository
            .store_named_identity(&identifier1, "name1", "vault")
            .await?;
        let identifier2 = create_identity().await?;
        repository
            .store_named_identity(&identifier2, "name2", "vault")
            .await?;

        let result = repository.get_identifier("name1").await?;
        assert_eq!(result, Some(identifier1.clone()));

        let result = repository
            .get_identity_name_by_identifier(&identifier1)
            .await?;
        assert_eq!(result, Some("name1".into()));

        let result = repository.get_named_identity("name2").await?;
        assert_eq!(result.map(|n| n.identifier()), Some(identifier2.clone()));

        let result = repository.get_named_identities().await?;
        assert_eq!(
            result.iter().map(|n| n.identifier()).collect::<Vec<_>>(),
            vec![identifier1.clone(), identifier2.clone()]
        );

        repository.delete_identity("name1").await?;
        let result = repository.get_named_identities().await?;
        assert_eq!(
            result.iter().map(|n| n.identifier()).collect::<Vec<_>>(),
            vec![identifier2.clone()]
        );
        Ok(())
    }

    /// Switch the default identity by identifier and by name.
    #[tokio::test]
    async fn test_identities_repository_default_identities() -> Result<()> {
        let repository = create_repository().await?;

        let identifier1 = create_identity().await?;
        let named_identity1 = repository
            .store_named_identity(&identifier1, "name1", "vault")
            .await?;
        let identifier2 = create_identity().await?;
        let named_identity2 = repository
            .store_named_identity(&identifier2, "name2", "vault")
            .await?;

        repository
            .set_as_default_by_identifier(&identifier1)
            .await?;
        let result = repository.get_default_named_identity().await?;
        assert_eq!(result, Some(named_identity1.set_as_default()));

        repository.set_as_default("name2").await?;
        let result = repository.get_default_named_identity().await?;
        assert_eq!(result, Some(named_identity2.set_as_default()));

        let result = repository.get_named_identity("name1").await?;
        assert!(!result.unwrap().is_default());

        let result = repository.get_default_named_identity().await?;
        assert_eq!(result.map(|i| i.name()), Some("name2".to_string()));
        Ok(())
    }

    /// Deleting the default identity promotes a remaining identity to default;
    /// deleting an unknown name is a no-op returning `None`.
    #[tokio::test]
    async fn test_identities_repository_delete_default_identity() -> Result<()> {
        let repository = create_repository().await?;

        let identifier1 = create_identity().await?;
        repository
            .store_named_identity(&identifier1, "name1", "vault")
            .await?;
        let identifier2 = create_identity().await?;
        repository
            .store_named_identity(&identifier2, "name2", "vault")
            .await?;
        repository.set_as_default("name1").await?;

        let deleted = repository.delete_identity("name1").await?;
        assert_eq!(deleted, Some(identifier1));

        // the only remaining identity becomes the default one
        let result = repository.get_default_named_identity().await?;
        assert_eq!(result.map(|i| i.identifier()), Some(identifier2));

        let deleted = repository.delete_identity("unknown").await?;
        assert_eq!(deleted, None);
        Ok(())
    }

    /// Deleting by identifier returns the name once, then `None`.
    #[tokio::test]
    async fn test_identities_repository_delete_by_identifier() -> Result<()> {
        let repository = create_repository().await?;

        let identifier = create_identity().await?;
        repository
            .store_named_identity(&identifier, "name", "vault")
            .await?;

        let deleted = repository
            .delete_identity_by_identifier(&identifier)
            .await?;
        assert_eq!(deleted, Some("name".to_string()));
        assert_eq!(repository.get_named_identity("name").await?, None);

        // a second deletion finds nothing to delete
        let deleted = repository
            .delete_identity_by_identifier(&identifier)
            .await?;
        assert_eq!(deleted, None);
        Ok(())
    }

    /// HELPERS
    async fn create_repository() -> Result<Arc<dyn IdentitiesRepository>> {
        Ok(IdentitiesSqlxDatabase::create().await?)
    }

    async fn create_identity() -> Result<Identifier> {
        let identities = identities().await?;
        identities.identities_creation().create_identity().await
    }
}