ockam_identity/identities/storage/identity_attributes_repository_sql.rs

1use core::str::FromStr;
2use sqlx::encode::IsNull;
3use sqlx::error::BoxDynError;
4use sqlx::*;
5use sqlx_core::any::AnyArgumentBuffer;
6use std::sync::Arc;
7use tracing::debug;
8
9use crate::models::Identifier;
10use crate::{AttributesEntry, IdentityAttributesRepository, TimestampInSeconds};
11use ockam_core::async_trait;
12use ockam_core::Result;
13use ockam_node::database::AutoRetry;
14use ockam_node::database::{FromSqlxError, Nullable, SqlxDatabase, ToVoid};
15
16/// Implementation of [`IdentityAttributesRepository`] trait based on an underlying database
17/// using sqlx as its API, and Sqlite as its driver
18#[derive(Clone)]
19pub struct IdentityAttributesSqlxDatabase {
20    database: SqlxDatabase,
21    node_name: String,
22}
23
24impl IdentityAttributesSqlxDatabase {
25    /// Create a new database
26    pub fn new(database: SqlxDatabase, node_name: &str) -> Self {
27        debug!("create a repository for identity attributes");
28        Self {
29            database,
30            node_name: node_name.to_string(),
31        }
32    }
33
34    /// Create a repository
35    pub fn make_repository(
36        database: SqlxDatabase,
37        node_name: &str,
38    ) -> Arc<dyn IdentityAttributesRepository> {
39        if database.needs_retry() {
40            Arc::new(AutoRetry::new(Self::new(database, node_name)))
41        } else {
42            Arc::new(Self::new(database, node_name))
43        }
44    }
45
46    /// Create a new in-memory database
47    pub async fn create() -> Result<Self> {
48        Ok(Self::new(
49            SqlxDatabase::in_memory("identity attributes").await?,
50            "default",
51        ))
52    }
53}
54
55#[async_trait]
56impl IdentityAttributesRepository for IdentityAttributesSqlxDatabase {
57    async fn get_attributes(
58        &self,
59        identity: &Identifier,
60        attested_by: &Identifier,
61    ) -> Result<Option<AttributesEntry>> {
62        let query = query_as(
63            "SELECT identifier, attributes, added, expires, attested_by FROM identity_attributes WHERE identifier = $1 AND attested_by = $2 AND node_name = $3"
64            )
65            .bind(identity)
66            .bind(attested_by)
67            .bind(&self.node_name);
68        let identity_attributes: Option<IdentityAttributesRow> = query
69            .fetch_optional(&*self.database.pool)
70            .await
71            .into_core()?;
72        Ok(identity_attributes.map(|r| r.attributes()).transpose()?)
73    }
74
75    async fn put_attributes(&self, subject: &Identifier, entry: AttributesEntry) -> Result<()> {
76        let query = query(
77            r#"
78            INSERT INTO identity_attributes (identifier, attributes, added, expires, attested_by, node_name)
79            VALUES ($1, $2, $3, $4, $5, $6)
80            ON CONFLICT (identifier, node_name)
81            DO UPDATE SET attributes = $2, added = $3, expires = $4, attested_by = $5, node_name = $6"#)
82            .bind(subject)
83            .bind(&entry)
84            .bind(entry.added_at())
85            .bind(entry.expires_at())
86            .bind(entry.attested_by())
87            .bind(&self.node_name);
88        query.execute(&*self.database.pool).await.void()
89    }
90
91    // This query is regularly invoked by IdentitiesAttributes to make sure that we expire attributes regularly
92    async fn delete_expired_attributes(&self, now: TimestampInSeconds) -> Result<()> {
93        let query = query("DELETE FROM identity_attributes WHERE expires <= $1 AND node_name = $2")
94            .bind(now)
95            .bind(&self.node_name);
96        query.execute(&*self.database.pool).await.void()
97    }
98}
99
100// Database serialization / deserialization
101
102impl Type<Any> for AttributesEntry {
103    fn type_info() -> <Any as Database>::TypeInfo {
104        <Vec<u8> as Type<Any>>::type_info()
105    }
106}
107
108impl Encode<'_, Any> for AttributesEntry {
109    fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer) -> Result<IsNull, BoxDynError> {
110        <Vec<u8> as Encode<'_, Any>>::encode_by_ref(
111            &ockam_core::cbor_encode_preallocate(self.attrs()).unwrap(),
112            buf,
113        )
114    }
115}
116
117// Low-level representation of a table row
118#[derive(FromRow)]
119struct IdentityAttributesRow {
120    identifier: String,
121    attributes: Vec<u8>,
122    added: i64,
123    expires: Nullable<i64>,
124    attested_by: Nullable<String>,
125}
126
127impl IdentityAttributesRow {
128    #[allow(dead_code)]
129    fn identifier(&self) -> Result<Identifier> {
130        Identifier::from_str(&self.identifier)
131    }
132
133    fn attributes(&self) -> Result<AttributesEntry> {
134        let attributes =
135            minicbor::decode(self.attributes.as_slice()).map_err(SqlxDatabase::map_decode_err)?;
136        let added = TimestampInSeconds(self.added as u64);
137        let expires = self
138            .expires
139            .to_option()
140            .map(|v| TimestampInSeconds(v as u64));
141        let attested_by = self
142            .attested_by
143            .to_option()
144            .map(|v| Identifier::from_str(&v))
145            .transpose()?;
146
147        Ok(AttributesEntry::new(
148            attributes,
149            added,
150            expires,
151            attested_by,
152        ))
153    }
154}
155
#[cfg(test)]
mod tests {
    use ockam_core::compat::collections::BTreeMap;
    use ockam_core::compat::sync::Arc;
    use ockam_node::database::with_dbs;
    use std::ops::Add;

    use super::*;
    use crate::identities;
    use crate::utils::now;

    #[tokio::test]
    async fn test_identities_attributes_repository() -> Result<()> {
        with_dbs(|db| async move {
            let repository: Arc<dyn IdentityAttributesRepository> =
                Arc::new(IdentityAttributesSqlxDatabase::new(db.clone(), "node"));

            let now = now()?;

            // store and retrieve attributes by identity
            let identifier1 = create_identity().await?;
            let attributes1 = create_attributes_entry(&identifier1, now, Some(2.into())).await?;
            let identifier2 = create_identity().await?;
            let attributes2 = create_attributes_entry(&identifier2, now, Some(2.into())).await?;

            repository
                .put_attributes(&identifier1, attributes1.clone())
                .await?;
            repository
                .put_attributes(&identifier2, attributes2.clone())
                .await?;

            let result = repository
                .get_attributes(&identifier1, &identifier1)
                .await?;
            assert_eq!(result, Some(attributes1.clone()));

            let result = repository
                .get_attributes(&identifier2, &identifier2)
                .await?;
            assert_eq!(result, Some(attributes2.clone()));

            // the lookup also filters on the attester: identifier1's attributes were attested
            // by identifier1, so asking for identifier2 as attester must not return them
            let result = repository
                .get_attributes(&identifier1, &identifier2)
                .await?;
            assert_eq!(
                result, None,
                "attributes are only returned for their attester"
            );

            // the lookup is scoped to the repository's node name
            let other_node: Arc<dyn IdentityAttributesRepository> =
                Arc::new(IdentityAttributesSqlxDatabase::new(db, "other node"));
            let result = other_node
                .get_attributes(&identifier1, &identifier1)
                .await?;
            assert_eq!(
                result, None,
                "attributes are scoped to the node that stored them"
            );

            Ok(())
        })
        .await
    }

    #[tokio::test]
    async fn test_delete_expired_attributes() -> Result<()> {
        with_dbs(|db| async move {
            let repository: Arc<dyn IdentityAttributesRepository> =
                Arc::new(IdentityAttributesSqlxDatabase::new(db, "node"));

            let now = now()?;

            // store some attributes with and without an expiry date
            let identifier1 = create_identity().await?;
            let identifier2 = create_identity().await?;
            let identifier3 = create_identity().await?;
            let identifier4 = create_identity().await?;
            let attributes1 = create_attributes_entry(&identifier1, now, Some(1.into())).await?;
            let attributes2 = create_attributes_entry(&identifier2, now, Some(10.into())).await?;
            let attributes3 = create_attributes_entry(&identifier3, now, Some(100.into())).await?;
            let attributes4 = create_attributes_entry(&identifier4, now, None).await?;

            repository
                .put_attributes(&identifier1, attributes1.clone())
                .await?;
            repository
                .put_attributes(&identifier2, attributes2.clone())
                .await?;
            repository
                .put_attributes(&identifier3, attributes3.clone())
                .await?;
            repository
                .put_attributes(&identifier4, attributes4.clone())
                .await?;

            // delete all the attributes with an expiry date <= now + 10
            // only attributes1 and attributes2 must be deleted
            repository.delete_expired_attributes(now.add(10)).await?;

            let result = repository
                .get_attributes(&identifier1, &identifier1)
                .await?;
            assert_eq!(result, None);

            let result = repository
                .get_attributes(&identifier2, &identifier2)
                .await?;
            assert_eq!(result, None);

            let result = repository
                .get_attributes(&identifier3, &identifier3)
                .await?;
            assert_eq!(
                result,
                Some(attributes3),
                "attributes 3 are not expired yet"
            );

            let result = repository
                .get_attributes(&identifier4, &identifier4)
                .await?;
            assert_eq!(
                result,
                Some(attributes4),
                "attributes 4 have no expiry date"
            );

            Ok(())
        })
        .await
    }

    // HELPERS

    /// Build a two-attribute entry added at `now`, expiring at `now + ttl` when a ttl is given,
    /// and attested by `identifier` itself.
    async fn create_attributes_entry(
        identifier: &Identifier,
        now: TimestampInSeconds,
        ttl: Option<TimestampInSeconds>,
    ) -> Result<AttributesEntry> {
        Ok(AttributesEntry::new(
            BTreeMap::from([
                ("name".as_bytes().to_vec(), "alice".as_bytes().to_vec()),
                ("age".as_bytes().to_vec(), "20".as_bytes().to_vec()),
            ]),
            now,
            ttl.map(|ttl| now + ttl),
            Some(identifier.clone()),
        ))
    }

    /// Create a fresh identity and return its identifier.
    async fn create_identity() -> Result<Identifier> {
        let identities = identities().await?;
        identities.identities_creation().create_identity().await
    }
}