use std::str::FromStr;
use std::sync::Arc;
use sqlx::sqlite::SqliteRow;
use sqlx::FromRow;
use sqlx::*;
use time::OffsetDateTime;
use ockam::identity::Identifier;
use ockam::{FromSqlxError, SqlxDatabase, ToSqlxType, ToVoid};
use ockam_core::async_trait;
use ockam_core::Result;
use crate::cli_state::enrollments::IdentityEnrollment;
use crate::cli_state::EnrollmentsRepository;
/// SQL-backed implementation of [`EnrollmentsRepository`].
///
/// All queries run against the connection pool of the wrapped [`SqlxDatabase`].
pub struct EnrollmentsSqlxDatabase {
    database: Arc<SqlxDatabase>,
}

impl EnrollmentsSqlxDatabase {
    /// Wraps an already-opened, shared database handle.
    pub fn new(database: Arc<SqlxDatabase>) -> Self {
        debug!("create a repository for enrollments");
        Self { database }
    }

    /// Opens a fresh in-memory database named `"enrollments"` and wraps it in a
    /// shared repository.
    ///
    /// # Errors
    /// Propagates any failure from `SqlxDatabase::in_memory`.
    // May have no callers inside this crate; the allow keeps it warning-free.
    #[allow(unused)]
    pub async fn create() -> Result<Arc<EnrollmentsSqlxDatabase>> {
        let in_memory = SqlxDatabase::in_memory("enrollments").await?;
        let repository = Self::new(in_memory);
        Ok(Arc::new(repository))
    }
}
#[async_trait]
impl EnrollmentsRepository for EnrollmentsSqlxDatabase {
    /// Records `identifier` as enrolled, stamped with the current UTC time.
    ///
    /// Uses `INSERT OR REPLACE`, so if `identifier` is a unique key of
    /// `identity_enrollment` (presumably — confirm in the schema) re-enrolling
    /// overwrites the previous timestamp.
    async fn set_as_enrolled(&self, identifier: &Identifier) -> Result<()> {
        // Name the columns explicitly so the insert does not depend on the
        // physical column order of the table.
        let query = query(
            "INSERT OR REPLACE INTO identity_enrollment (identifier, enrolled_at) VALUES (?, ?)",
        )
        .bind(identifier.to_sql())
        .bind(OffsetDateTime::now_utc().to_sql());
        Ok(query.execute(&self.database.pool).await.void()?)
    }

    /// Returns only the named identities that have an enrollment record.
    ///
    /// # Errors
    /// Fails on a database error or if a stored identifier cannot be parsed.
    async fn get_enrolled_identities(&self) -> Result<Vec<IdentityEnrollment>> {
        // The query has no placeholders, so nothing is bound.
        let query = query_as(
            r#"
            SELECT
              identity.identifier, named_identity.name, named_identity.is_default,
              identity_enrollment.enrolled_at
            FROM identity
            INNER JOIN identity_enrollment ON
              identity.identifier = identity_enrollment.identifier
            INNER JOIN named_identity ON
              identity.identifier = named_identity.identifier
            "#,
        );
        let result: Vec<EnrollmentRow> = query.fetch_all(&self.database.pool).await.into_core()?;
        result
            .into_iter()
            .map(|r| r.identity_enrollment())
            .collect::<Result<Vec<_>>>()
    }

    /// Returns every named identity, with `enrolled_at` left empty for the ones
    /// that were never enrolled (hence the `LEFT JOIN`).
    ///
    /// # Errors
    /// Fails on a database error or if a stored identifier cannot be parsed.
    async fn get_all_identities_enrollments(&self) -> Result<Vec<IdentityEnrollment>> {
        let query = query_as(
            r#"
            SELECT
              identity.identifier, named_identity.name, named_identity.is_default,
              identity_enrollment.enrolled_at
            FROM identity
            LEFT JOIN identity_enrollment ON
              identity.identifier = identity_enrollment.identifier
            INNER JOIN named_identity ON
              identity.identifier = named_identity.identifier
            "#,
        );
        let result: Vec<EnrollmentRow> = query.fetch_all(&self.database.pool).await.into_core()?;
        result
            .into_iter()
            .map(|r| r.identity_enrollment())
            .collect::<Result<Vec<_>>>()
    }

    /// Returns `true` if an identity flagged `is_default` in `named_identity`
    /// has an enrollment record.
    async fn is_default_identity_enrolled(&self) -> Result<bool> {
        let query = query(
            r#"
            SELECT
              identity_enrollment.enrolled_at
            FROM identity
            INNER JOIN identity_enrollment ON
              identity.identifier = identity_enrollment.identifier
            INNER JOIN named_identity ON
              identity.identifier = named_identity.identifier
            WHERE
              named_identity.is_default = ?
            "#,
        )
        .bind(true.to_sql());
        let result: Option<SqliteRow> = query
            .fetch_optional(&self.database.pool)
            .await
            .into_core()?;
        Ok(result.is_some())
    }

    /// Returns `true` if the identity registered under `name` has an enrollment
    /// record.
    async fn is_identity_enrolled(&self, name: &str) -> Result<bool> {
        // `named_identity` must be joined exactly once: a second, un-aliased join
        // of the same table makes `named_identity.*` references ambiguous and
        // SQLite rejects the query.
        let query = query(
            r#"
            SELECT
              identity_enrollment.enrolled_at
            FROM identity
            INNER JOIN identity_enrollment ON
              identity.identifier = identity_enrollment.identifier
            INNER JOIN named_identity ON
              identity.identifier = named_identity.identifier
            WHERE
              named_identity.name = ?
            "#,
        )
        .bind(name.to_sql());
        let result: Option<SqliteRow> = query
            .fetch_optional(&self.database.pool)
            .await
            .into_core()?;
        Ok(result.is_some())
    }
}
/// One result row of the enrollment SELECT queries: an identity joined with its
/// name, default flag and (optionally) its enrollment timestamp.
///
/// Field names must match the selected column names, because `FromRow` maps
/// columns to fields by name.
#[derive(FromRow)]
pub struct EnrollmentRow {
    /// Textual identifier, parsed with `Identifier::from_str`.
    identifier: String,
    /// Value of `named_identity.name`.
    name: Option<String>,
    /// Value of `named_identity.is_default`.
    is_default: bool,
    /// Enrollment time in Unix seconds. `None` when the identity has no
    /// enrollment record, which the `LEFT JOIN` query can produce.
    enrolled_at: Option<i64>,
}

impl EnrollmentRow {
    /// Converts this row into an [`IdentityEnrollment`].
    ///
    /// # Errors
    /// Fails if the stored identifier cannot be parsed as an [`Identifier`].
    fn identity_enrollment(&self) -> Result<IdentityEnrollment> {
        let identifier = Identifier::from_str(&self.identifier)?;
        Ok(IdentityEnrollment::new(
            identifier,
            self.name.clone(),
            self.is_default,
            self.enrolled_at(),
        ))
    }

    /// Enrollment time as an `OffsetDateTime`, or `None` if the identity is not
    /// enrolled.
    ///
    /// A stored value that `OffsetDateTime::from_unix_timestamp` rejects falls
    /// back to the current UTC time instead of failing the whole query.
    fn enrolled_at(&self) -> Option<OffsetDateTime> {
        self.enrolled_at.map(|at| {
            // Lazy fallback: the clock is read only when the timestamp is invalid,
            // not once per row.
            OffsetDateTime::from_unix_timestamp(at).unwrap_or_else(|_| OffsetDateTime::now_utc())
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli_state::{EnrollmentsRepository, IdentitiesRepository, IdentitiesSqlxDatabase};
    use ockam::identity::{
        identities, ChangeHistoryRepository, ChangeHistorySqlxDatabase, Identity,
    };

    /// Stores two named identities, enrolls only the first one (which is also
    /// the default), and checks what each query method reports.
    #[tokio::test]
    async fn test_identities_enrollment_repository() -> Result<()> {
        let db = create_database().await?;
        let repository = create_repository(db.clone());

        let first = create_identity(db.clone(), "identity1").await?;
        create_identity(db.clone(), "identity2").await?;
        repository.set_as_enrolled(first.identifier()).await?;

        // Both identities are listed, enrolled or not.
        let everything = repository.get_all_identities_enrollments().await?;
        assert_eq!(everything.len(), 2);

        // Only the enrolled identity is listed.
        let enrolled = repository.get_enrolled_identities().await?;
        assert_eq!(enrolled.len(), 1);

        // "identity1" is the default identity and was enrolled above.
        let default_enrolled = repository.is_default_identity_enrolled().await?;
        assert!(default_enrolled);

        Ok(())
    }

    /// Generates a brand-new identity and persists it in `db` under `name`.
    async fn create_identity(db: Arc<SqlxDatabase>, name: &str) -> Result<Identity> {
        let factory = identities().await?;
        let created = factory.identities_creation().create_identity().await?;
        let identity = factory.get_identity(&created).await?;
        store_identity(db, name, identity).await
    }

    /// Saves the change history of `identity`, registers it under `name` with
    /// vault `"vault"`, and makes it the default when `name` is `"identity1"`.
    async fn store_identity(
        db: Arc<SqlxDatabase>,
        name: &str,
        identity: Identity,
    ) -> Result<Identity> {
        let histories = create_change_history_repository(db.clone()).await?;
        let named = create_identities_repository(db).await?;
        let identifier = identity.identifier();

        histories
            .store_change_history(identifier, identity.change_history().clone())
            .await?;
        named.store_named_identity(identifier, name, "vault").await?;

        let is_default = name == "identity1";
        if is_default {
            named.set_as_default_by_identifier(identifier).await?;
        }
        Ok(identity)
    }

    /// Builds the repository under test on top of `db`.
    fn create_repository(db: Arc<SqlxDatabase>) -> Arc<dyn EnrollmentsRepository> {
        let repository = EnrollmentsSqlxDatabase::new(db);
        Arc::new(repository)
    }

    /// Opens an in-memory database dedicated to this test module.
    async fn create_database() -> Result<Arc<SqlxDatabase>> {
        SqlxDatabase::in_memory("enrollments-test").await
    }

    /// Builds a change-history repository sharing `db`.
    async fn create_change_history_repository(
        db: Arc<SqlxDatabase>,
    ) -> Result<Arc<dyn ChangeHistoryRepository>> {
        let repository: Arc<dyn ChangeHistoryRepository> =
            Arc::new(ChangeHistorySqlxDatabase::new(db));
        Ok(repository)
    }

    /// Builds a named-identities repository sharing `db`.
    async fn create_identities_repository(
        db: Arc<SqlxDatabase>,
    ) -> Result<Arc<dyn IdentitiesRepository>> {
        let repository: Arc<dyn IdentitiesRepository> = Arc::new(IdentitiesSqlxDatabase::new(db));
        Ok(repository)
    }
}