// ockam_identity/identities/storage/credential_repository_sql.rs

use sqlx::encode::IsNull;
use sqlx::error::BoxDynError;
use sqlx::*;
use sqlx_core::any::AnyArgumentBuffer;
use std::sync::Arc;
use tracing::debug;

use crate::models::{CredentialAndPurposeKey, Identifier};
use crate::{CredentialRepository, TimestampInSeconds};
use ockam_core::async_trait;
use ockam_core::Result;
use ockam_node::database::AutoRetry;
use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid};

15/// Implementation of `CredentialRepository` trait based on an underlying database
16/// using sqlx as its API, and Sqlite as its driver
17#[derive(Clone)]
18pub struct CredentialSqlxDatabase {
19    database: SqlxDatabase,
20    node_name: String,
21}
22
23impl CredentialSqlxDatabase {
24    /// Create a new database
25    pub fn new(database: SqlxDatabase, node_name: &str) -> Self {
26        debug!("create a repository for credentials");
27        Self {
28            database,
29            node_name: node_name.to_string(),
30        }
31    }
32
33    /// Create a repository
34    pub fn make_repository(
35        database: SqlxDatabase,
36        node_name: &str,
37    ) -> Arc<dyn CredentialRepository> {
38        if database.needs_retry() {
39            Arc::new(AutoRetry::new(Self::new(database, node_name)))
40        } else {
41            Arc::new(Self::new(database, node_name))
42        }
43    }
44
45    /// Create a new in-memory database
46    pub async fn create() -> Result<Self> {
47        Ok(Self::new(
48            SqlxDatabase::in_memory("credential").await?,
49            "default",
50        ))
51    }
52}
53
54impl CredentialSqlxDatabase {
55    /// Return all cached credentials for the given node
56    pub async fn get_all(&self) -> Result<Vec<(CredentialAndPurposeKey, String)>> {
57        let query = query_as("SELECT credential, scope FROM credential WHERE node_name = $1")
58            .bind(self.node_name.clone());
59
60        let cached_credential: Vec<CachedCredentialAndScopeRow> =
61            query.fetch_all(&*self.database.pool).await.into_core()?;
62
63        let res = cached_credential
64            .into_iter()
65            .map(|c| {
66                let cred = c.credential()?;
67                Ok((cred, c.scope().to_string()))
68            })
69            .collect::<Result<Vec<_>>>()?;
70
71        Ok(res)
72    }
73}
74
75#[async_trait]
76impl CredentialRepository for CredentialSqlxDatabase {
77    async fn get(
78        &self,
79        subject: &Identifier,
80        issuer: &Identifier,
81        scope: &str,
82    ) -> Result<Option<CredentialAndPurposeKey>> {
83        let query = query_as(
84            "SELECT credential FROM credential WHERE subject_identifier = $1 AND issuer_identifier = $2 AND scope = $3 AND node_name = $4"
85            )
86            .bind(subject)
87            .bind(issuer)
88            .bind(scope)
89            .bind(self.node_name.clone());
90        let cached_credential: Option<CachedCredentialRow> = query
91            .fetch_optional(&*self.database.pool)
92            .await
93            .into_core()?;
94        cached_credential.map(|c| c.credential()).transpose()
95    }
96
97    async fn put(
98        &self,
99        subject: &Identifier,
100        issuer: &Identifier,
101        scope: &str,
102        expires_at: TimestampInSeconds,
103        credential: CredentialAndPurposeKey,
104    ) -> Result<()> {
105        let query = query(
106            r#"INSERT INTO credential (subject_identifier, issuer_identifier, scope, credential, expires_at, node_name)
107            VALUES ($1, $2, $3, $4, $5, $6)
108            ON CONFLICT (subject_identifier, issuer_identifier, scope)
109            DO UPDATE SET credential = $4, expires_at = $5, node_name = $6"#)
110            .bind(subject)
111            .bind(issuer)
112            .bind(scope)
113            .bind(credential)
114            .bind(expires_at)
115            .bind(self.node_name.clone());
116        query.execute(&*self.database.pool).await.void()
117    }
118
119    async fn delete(&self, subject: &Identifier, issuer: &Identifier, scope: &str) -> Result<()> {
120        let query = query("DELETE FROM credential WHERE subject_identifier = $1 AND issuer_identifier = $2 AND scope = $3 AND node_name = $4")
121            .bind(subject)
122            .bind(issuer)
123            .bind(scope)
124            .bind(self.node_name.clone());
125        query.execute(&*self.database.pool).await.void()
126    }
127}
128
// Database serialization / deserialization of custom column types

131impl Type<Any> for CredentialAndPurposeKey {
132    fn type_info() -> <Any as Database>::TypeInfo {
133        <Vec<u8> as Type<Any>>::type_info()
134    }
135}
136
137impl Encode<'_, Any> for CredentialAndPurposeKey {
138    fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer) -> Result<IsNull, BoxDynError> {
139        <Vec<u8> as Encode<'_, Any>>::encode_by_ref(&self.encode_as_cbor_bytes().unwrap(), buf)
140    }
141}
142
143impl Type<Any> for TimestampInSeconds {
144    fn type_info() -> <Any as Database>::TypeInfo {
145        <i64 as Type<Any>>::type_info()
146    }
147}
148
149impl Encode<'_, Any> for TimestampInSeconds {
150    fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer) -> Result<IsNull, BoxDynError> {
151        <i64 as Encode<'_, Any>>::encode_by_ref(&(self.0 as i64), buf)
152    }
153}
154
155// Low-level representation of a table row
156#[derive(FromRow)]
157struct CachedCredentialRow {
158    credential: Vec<u8>,
159}
160
161impl CachedCredentialRow {
162    fn credential(&self) -> Result<CredentialAndPurposeKey> {
163        CredentialAndPurposeKey::decode_from_cbor_bytes(&self.credential)
164    }
165}
166
167#[derive(FromRow)]
168struct CachedCredentialAndScopeRow {
169    credential: Vec<u8>,
170    scope: String,
171}
172
173impl CachedCredentialAndScopeRow {
174    fn credential(&self) -> Result<CredentialAndPurposeKey> {
175        CredentialAndPurposeKey::decode_from_cbor_bytes(&self.credential)
176    }
177    pub fn scope(&self) -> &str {
178        &self.scope
179    }
180}
181
#[cfg(test)]
mod tests {
    use ockam_core::compat::sync::Arc;
    use ockam_node::database::with_dbs;
    use std::time::Duration;

    use super::*;
    use crate::identities;
    use crate::models::CredentialSchemaIdentifier;
    use crate::utils::AttributesBuilder;

    /// Exercises put / get / get_all / delete, including overwriting an existing
    /// entry and node-name isolation for reads on a shared database.
    #[tokio::test]
    async fn test_cached_credential_repository() -> Result<()> {
        with_dbs(|db| async move {
            let credentials_database = CredentialSqlxDatabase::new(db.clone(), "node");
            // A second repository on the same database, for a different node.
            let other_node_database = CredentialSqlxDatabase::new(db, "other-node");
            let repository: Arc<dyn CredentialRepository> = Arc::new(credentials_database.clone());

            let scope = "test".to_string();

            let all = credentials_database.get_all().await?;
            assert_eq!(all.len(), 0);

            let identities = identities().await?;

            let issuer = identities.identities_creation().create_identity().await?;
            let subject = identities.identities_creation().create_identity().await?;

            let attributes1 = AttributesBuilder::with_schema(CredentialSchemaIdentifier(1))
                .with_attribute("key1", "value1")
                .build();
            let credential1 = identities
                .credentials()
                .credentials_creation()
                .issue_credential(&issuer, &subject, attributes1, Duration::from_secs(60 * 60))
                .await?;

            repository
                .put(
                    &subject,
                    &issuer,
                    &scope,
                    credential1.get_credential_data()?.expires_at,
                    credential1.clone(),
                )
                .await?;

            // get_all returns the stored credential together with its scope.
            let all = credentials_database.get_all().await?;
            assert_eq!(all.len(), 1);
            assert_eq!(all[0].0, credential1);
            assert_eq!(all[0].1, scope);

            // Reads are filtered by node name, so the other node sees nothing.
            assert_eq!(
                other_node_database.get(&subject, &issuer, &scope).await?,
                None
            );
            assert!(other_node_database.get_all().await?.is_empty());

            let credential2 = repository.get(&subject, &issuer, &scope).await?;
            assert_eq!(credential2, Some(credential1));

            let attributes2 = AttributesBuilder::with_schema(CredentialSchemaIdentifier(1))
                .with_attribute("key2", "value2")
                .build();
            let credential3 = identities
                .credentials()
                .credentials_creation()
                .issue_credential(&issuer, &subject, attributes2, Duration::from_secs(60 * 60))
                .await?;
            // Putting again under the same key replaces the previous credential.
            repository
                .put(
                    &subject,
                    &issuer,
                    &scope,
                    credential3.get_credential_data()?.expires_at,
                    credential3.clone(),
                )
                .await?;
            let all = credentials_database.get_all().await?;
            assert_eq!(all.len(), 1);
            let credential4 = repository.get(&subject, &issuer, &scope).await?;
            assert_eq!(credential4, Some(credential3));

            repository.delete(&subject, &issuer, &scope).await?;
            let result = repository.get(&subject, &issuer, &scope).await?;
            assert_eq!(result, None);

            Ok(())
        })
        .await
    }
}