ockam_identity/identities/storage/
credential_repository_sql.rs1use sqlx::encode::IsNull;
2use sqlx::error::BoxDynError;
3use sqlx::*;
4use sqlx_core::any::AnyArgumentBuffer;
5use std::sync::Arc;
6use tracing::debug;
7
8use crate::models::{CredentialAndPurposeKey, Identifier};
9use crate::{CredentialRepository, TimestampInSeconds};
10use ockam_core::async_trait;
11use ockam_core::Result;
12use ockam_node::database::AutoRetry;
13use ockam_node::database::{FromSqlxError, SqlxDatabase, ToVoid};
14
15#[derive(Clone)]
18pub struct CredentialSqlxDatabase {
19 database: SqlxDatabase,
20 node_name: String,
21}
22
23impl CredentialSqlxDatabase {
24 pub fn new(database: SqlxDatabase, node_name: &str) -> Self {
26 debug!("create a repository for credentials");
27 Self {
28 database,
29 node_name: node_name.to_string(),
30 }
31 }
32
33 pub fn make_repository(
35 database: SqlxDatabase,
36 node_name: &str,
37 ) -> Arc<dyn CredentialRepository> {
38 if database.needs_retry() {
39 Arc::new(AutoRetry::new(Self::new(database, node_name)))
40 } else {
41 Arc::new(Self::new(database, node_name))
42 }
43 }
44
45 pub async fn create() -> Result<Self> {
47 Ok(Self::new(
48 SqlxDatabase::in_memory("credential").await?,
49 "default",
50 ))
51 }
52}
53
54impl CredentialSqlxDatabase {
55 pub async fn get_all(&self) -> Result<Vec<(CredentialAndPurposeKey, String)>> {
57 let query = query_as("SELECT credential, scope FROM credential WHERE node_name = $1")
58 .bind(self.node_name.clone());
59
60 let cached_credential: Vec<CachedCredentialAndScopeRow> =
61 query.fetch_all(&*self.database.pool).await.into_core()?;
62
63 let res = cached_credential
64 .into_iter()
65 .map(|c| {
66 let cred = c.credential()?;
67 Ok((cred, c.scope().to_string()))
68 })
69 .collect::<Result<Vec<_>>>()?;
70
71 Ok(res)
72 }
73}
74
75#[async_trait]
76impl CredentialRepository for CredentialSqlxDatabase {
77 async fn get(
78 &self,
79 subject: &Identifier,
80 issuer: &Identifier,
81 scope: &str,
82 ) -> Result<Option<CredentialAndPurposeKey>> {
83 let query = query_as(
84 "SELECT credential FROM credential WHERE subject_identifier = $1 AND issuer_identifier = $2 AND scope = $3 AND node_name = $4"
85 )
86 .bind(subject)
87 .bind(issuer)
88 .bind(scope)
89 .bind(self.node_name.clone());
90 let cached_credential: Option<CachedCredentialRow> = query
91 .fetch_optional(&*self.database.pool)
92 .await
93 .into_core()?;
94 cached_credential.map(|c| c.credential()).transpose()
95 }
96
97 async fn put(
98 &self,
99 subject: &Identifier,
100 issuer: &Identifier,
101 scope: &str,
102 expires_at: TimestampInSeconds,
103 credential: CredentialAndPurposeKey,
104 ) -> Result<()> {
105 let query = query(
106 r#"INSERT INTO credential (subject_identifier, issuer_identifier, scope, credential, expires_at, node_name)
107 VALUES ($1, $2, $3, $4, $5, $6)
108 ON CONFLICT (subject_identifier, issuer_identifier, scope)
109 DO UPDATE SET credential = $4, expires_at = $5, node_name = $6"#)
110 .bind(subject)
111 .bind(issuer)
112 .bind(scope)
113 .bind(credential)
114 .bind(expires_at)
115 .bind(self.node_name.clone());
116 query.execute(&*self.database.pool).await.void()
117 }
118
119 async fn delete(&self, subject: &Identifier, issuer: &Identifier, scope: &str) -> Result<()> {
120 let query = query("DELETE FROM credential WHERE subject_identifier = $1 AND issuer_identifier = $2 AND scope = $3 AND node_name = $4")
121 .bind(subject)
122 .bind(issuer)
123 .bind(scope)
124 .bind(self.node_name.clone());
125 query.execute(&*self.database.pool).await.void()
126 }
127}
128
129impl Type<Any> for CredentialAndPurposeKey {
132 fn type_info() -> <Any as Database>::TypeInfo {
133 <Vec<u8> as Type<Any>>::type_info()
134 }
135}
136
137impl Encode<'_, Any> for CredentialAndPurposeKey {
138 fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer) -> Result<IsNull, BoxDynError> {
139 <Vec<u8> as Encode<'_, Any>>::encode_by_ref(&self.encode_as_cbor_bytes().unwrap(), buf)
140 }
141}
142
143impl Type<Any> for TimestampInSeconds {
144 fn type_info() -> <Any as Database>::TypeInfo {
145 <i64 as Type<Any>>::type_info()
146 }
147}
148
149impl Encode<'_, Any> for TimestampInSeconds {
150 fn encode_by_ref(&self, buf: &mut AnyArgumentBuffer) -> Result<IsNull, BoxDynError> {
151 <i64 as Encode<'_, Any>>::encode_by_ref(&(self.0 as i64), buf)
152 }
153}
154
155#[derive(FromRow)]
157struct CachedCredentialRow {
158 credential: Vec<u8>,
159}
160
161impl CachedCredentialRow {
162 fn credential(&self) -> Result<CredentialAndPurposeKey> {
163 CredentialAndPurposeKey::decode_from_cbor_bytes(&self.credential)
164 }
165}
166
167#[derive(FromRow)]
168struct CachedCredentialAndScopeRow {
169 credential: Vec<u8>,
170 scope: String,
171}
172
173impl CachedCredentialAndScopeRow {
174 fn credential(&self) -> Result<CredentialAndPurposeKey> {
175 CredentialAndPurposeKey::decode_from_cbor_bytes(&self.credential)
176 }
177 pub fn scope(&self) -> &str {
178 &self.scope
179 }
180}
181
#[cfg(test)]
mod tests {
    use ockam_core::compat::sync::Arc;
    use ockam_node::database::with_dbs;
    use std::time::Duration;

    use super::*;
    use crate::identities;
    use crate::models::CredentialSchemaIdentifier;
    use crate::utils::AttributesBuilder;

    /// Walk a cached credential through its lifecycle: store, read back, overwrite, delete.
    #[tokio::test]
    async fn test_cached_credential_repository() -> Result<()> {
        with_dbs(|db| async move {
            let database = CredentialSqlxDatabase::new(db, "node");
            let repository: Arc<dyn CredentialRepository> = Arc::new(database.clone());

            let scope = String::from("test");
            let ttl = Duration::from_secs(60 * 60);

            // The repository starts out empty.
            let stored = database.get_all().await?;
            assert_eq!(stored.len(), 0);

            let identities = identities().await?;

            let issuer = identities.identities_creation().create_identity().await?;
            let subject = identities.identities_creation().create_identity().await?;

            // Store a first credential and read it back.
            let first_attributes = AttributesBuilder::with_schema(CredentialSchemaIdentifier(1))
                .with_attribute("key1", "value1")
                .build();
            let first_credential = identities
                .credentials()
                .credentials_creation()
                .issue_credential(&issuer, &subject, first_attributes, ttl)
                .await?;

            repository
                .put(
                    &subject,
                    &issuer,
                    &scope,
                    first_credential.get_credential_data()?.expires_at,
                    first_credential.clone(),
                )
                .await?;

            let stored = database.get_all().await?;
            assert_eq!(stored.len(), 1);

            let fetched = repository.get(&subject, &issuer, &scope).await?;
            assert_eq!(fetched, Some(first_credential));

            // A second put with the same key replaces the credential instead of adding a row.
            let second_attributes = AttributesBuilder::with_schema(CredentialSchemaIdentifier(1))
                .with_attribute("key2", "value2")
                .build();
            let second_credential = identities
                .credentials()
                .credentials_creation()
                .issue_credential(&issuer, &subject, second_attributes, ttl)
                .await?;

            repository
                .put(
                    &subject,
                    &issuer,
                    &scope,
                    second_credential.get_credential_data()?.expires_at,
                    second_credential.clone(),
                )
                .await?;

            let stored = database.get_all().await?;
            assert_eq!(stored.len(), 1);

            let fetched = repository.get(&subject, &issuer, &scope).await?;
            assert_eq!(fetched, Some(second_credential));

            // After deletion the credential can no longer be found.
            repository.delete(&subject, &issuer, &scope).await?;
            let fetched = repository.get(&subject, &issuer, &scope).await?;
            assert_eq!(fetched, None);

            Ok(())
        })
        .await
    }
}