1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
use core::fmt::{Debug, Formatter};
use sqlx::pool::PoolOptions;
use sqlx::sqlite::SqliteConnectOptions;
use std::ops::Deref;
use std::path::Path;

use ockam_core::errcode::{Kind, Origin};
use sqlx::{ConnectOptions, SqlitePool};
use tokio_retry::strategy::{jitter, FixedInterval};
use tokio_retry::Retry;
use tracing::debug;
use tracing::log::LevelFilter;

use crate::database::migrations::application_migration_set::ApplicationMigrationSet;
use crate::database::migrations::node_migration_set::NodeMigrationSet;
use crate::database::migrations::MigrationSet;
use ockam_core::compat::sync::Arc;
use ockam_core::{Error, Result};

/// The SqlxDatabase struct is used to create a database:
///   - at a given path
///   - with a given schema / or migrations applied to an existing schema
///
/// We use sqlx as our primary interface for interacting with the database
/// The database driver is currently Sqlite
///
/// The struct is a thin handle around a reference-counted connection pool:
/// cloning it clones the `Arc`, so every clone points to the same `SqlitePool`.
#[derive(Clone)]
pub struct SqlxDatabase {
    /// Pool of connections to the database, shared by all the clones of this handle
    pub pool: Arc<SqlitePool>,
}

impl Debug for SqlxDatabase {
    /// Print the connection options of the underlying pool,
    /// prefixed with `database options `.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        // Write straight into the formatter instead of building an intermediate `String`
        write!(f, "database options {:?}", self.pool.connect_options())
    }
}

/// Dereferencing a `SqlxDatabase` gives access to its connection pool,
/// so that the pool methods can be called directly on the database value.
impl Deref for SqlxDatabase {
    type Target = SqlitePool;

    fn deref(&self) -> &Self::Target {
        // Borrow the pool through the `Arc`
        let shared_pool: &SqlitePool = &self.pool;
        shared_pool
    }
}

impl SqlxDatabase {
    /// Create a database persisted on disk at `path`,
    /// with the node migrations (`NodeMigrationSet`) applied to it
    pub async fn create(path: impl AsRef<Path>) -> Result<Self> {
        let node_migrations = Some(NodeMigrationSet);
        Self::create_impl(path, node_migrations).await
    }

    /// Create a database persisted on disk at `path`,
    /// with the caller-provided `migration_set` applied to it
    pub async fn create_with_migration(
        path: impl AsRef<Path>,
        migration_set: impl MigrationSet,
    ) -> Result<Self> {
        let migrations = Some(migration_set);
        Self::create_impl(path, migrations).await
    }

    /// Create a database persisted on disk at `path`, without running any migration
    pub async fn create_no_migration(path: impl AsRef<Path>) -> Result<Self> {
        // A concrete migration set type must still be named, even if no value is passed
        let no_migrations: Option<NodeMigrationSet> = None;
        Self::create_impl(path, no_migrations).await
    }

    /// Shared implementation of the on-disk constructors:
    ///   - make sure that the parent directory of `path` exists
    ///   - open the database, retrying when the connection cannot be established
    ///   - run the migrations when a `migration_set` is provided
    async fn create_impl(
        path: impl AsRef<Path>,
        migration_set: Option<impl MigrationSet>,
    ) -> Result<Self> {
        let db_path = path.as_ref();

        if let Some(parent_dir) = db_path.parent() {
            std::fs::create_dir_all(parent_dir)
                .map_err(|e| Error::new(Origin::Api, Kind::Io, e.to_string()))?;
        }

        // Opening a database can fail for a while if its files are still held
        // by another pod which is shutting down, so the attempt is retried:
        // at most 10 retries, with delays derived from a 1000 ms interval passed through `jitter`.
        let delays = FixedInterval::from_millis(1000).map(jitter).take(10);
        let db = Retry::spawn(delays, || Self::create_at(db_path)).await?;

        // Migrations are optional and only run once the database could be opened
        if let Some(migrations) = migration_set {
            let migrator = migrations.create_migrator()?;
            migrator.migrate(&db.pool).await?;
        }

        Ok(db)
    }

    /// Create a nodes database in memory, with the `NodeMigrationSet` migrations applied
    ///   => this database is deleted on an `ockam reset` command! (contrary to the application database below)
    pub async fn in_memory(usage: &str) -> Result<Self> {
        let node_migrations = NodeMigrationSet;
        Self::in_memory_with_migration(usage, node_migrations).await
    }

    /// Create an application database in memory, with the `ApplicationMigrationSet` migrations applied
    /// The application database contains the application configurations
    ///   => this database is NOT deleted on an `ockam reset` command!
    pub async fn application_in_memory(usage: &str) -> Result<Self> {
        let application_migrations = ApplicationMigrationSet;
        Self::in_memory_with_migration(usage, application_migrations).await
    }

    /// Create an in-memory database and apply the given `migration_set` to it
    /// The `usage` string is only used to label the debug log line
    pub async fn in_memory_with_migration(
        usage: &str,
        migration_set: impl MigrationSet,
    ) -> Result<Self> {
        debug!("create an in memory database for {usage}");
        let connections = Self::create_in_memory_connection_pool().await?;
        migration_set
            .create_migrator()?
            .migrate(&connections)
            .await?;
        // FIXME: We should be careful if we run multiple nodes in one process
        Ok(SqlxDatabase {
            pool: Arc::new(connections),
        })
    }

    /// Open a connection pool on the database file at `path` and wrap it in a `SqlxDatabase`
    async fn create_at(path: &Path) -> Result<Self> {
        // The connection pool creates the database file if it doesn't exist
        Self::create_connection_pool(path)
            .await
            .map(|pool| SqlxDatabase {
                pool: Arc::new(pool),
            })
    }

    /// Create a connection pool for the database file at `path`:
    ///   - the file is created if it is missing
    ///   - the executed statements are logged at the debug level
    pub(crate) async fn create_connection_pool(path: &Path) -> Result<SqlitePool> {
        let connect_options = SqliteConnectOptions::new()
            .filename(path)
            .create_if_missing(true)
            .log_statements(LevelFilter::Debug);
        SqlitePool::connect_with(connect_options)
            .await
            .map_err(Self::map_sql_err)
    }

    /// Create a connection pool for an in-memory SQLite database
    pub(crate) async fn create_in_memory_connection_pool() -> Result<SqlitePool> {
        // A SQLite in-memory database is wiped once no connection to it remains.
        // Disabling the idle timeout and the maximum lifetime of the pooled connections
        // tries to ensure that there is always an open connection
        PoolOptions::new()
            .idle_timeout(None)
            .max_lifetime(None)
            .connect("sqlite::memory:")
            .await
            .map_err(Self::map_sql_err)
    }

    /// Map a sqlx error into an ockam error
    ///
    /// The resulting error has the `Origin::Application` origin and the `Kind::Io` kind,
    /// and is built from the sqlx error value itself (not from its string representation).
    // NOTE(review): `#[track_caller]` presumably lets `Error::new` record the location of
    // the caller instead of this line — confirm that `Error::new` is itself `#[track_caller]`.
    #[track_caller]
    pub fn map_sql_err(err: sqlx::Error) -> Error {
        Error::new(Origin::Application, Kind::Io, err)
    }

    /// Map a minicbor decode error into an ockam error
    ///
    /// The resulting error has the `Origin::Application` origin and the `Kind::Io` kind,
    /// and is built from the decode error value itself.
    // NOTE(review): a decoding failure is reported with `Kind::Io`, the same kind as
    // database errors — confirm that this is intended before relying on the kind.
    #[track_caller]
    pub fn map_decode_err(err: minicbor::decode::Error) -> Error {
        Error::new(Origin::Application, Kind::Io, err)
    }
}

/// This trait provides some syntax for transforming sqlx errors into ockam errors
///
/// It is meant to be implemented for `Result` types carrying a sqlx error,
/// so that `.into_core()` can be chained at the end of a sqlx call.
pub trait FromSqlxError<T> {
    /// Make an ockam core Error
    ///
    /// The success value of type `T` is kept as is, only the error is converted.
    fn into_core(self) -> Result<T>;
}

impl<T> FromSqlxError<T> for core::result::Result<T, sqlx::error::Error> {
    #[track_caller]
    fn into_core(self) -> Result<T> {
        match self {
            Ok(r) => Ok(r),
            Err(err) => {
                let err = Error::new(Origin::Api, Kind::Internal, err.to_string());
                Err(err)
            }
        }
    }
}

impl<T> FromSqlxError<T> for core::result::Result<T, sqlx::migrate::MigrateError> {
    #[track_caller]
    fn into_core(self) -> Result<T> {
        match self {
            Ok(r) => Ok(r),
            Err(err) => Err(Error::new(
                Origin::Application,
                Kind::Io,
                format!("migration error {err}"),
            )),
        }
    }
}

/// This trait provides some syntax to shorten queries execution returning ()
///
/// It is meant for calls where only the success / failure of a query matters
/// and the value of type `T` which is returned on success can be dropped.
pub trait ToVoid<T> {
    /// Return a () value
    ///
    /// The success value is discarded, an error is returned as an ockam error.
    fn void(self) -> Result<()>;
}

/// Discard the success value of a sqlx result and convert its error with `into_core`
impl<T> ToVoid<T> for core::result::Result<T, sqlx::error::Error> {
    #[track_caller]
    fn void(self) -> Result<()> {
        // Drop the success value first, then convert the sqlx error
        let without_value = self.map(drop);
        without_value.into_core()
    }
}

#[cfg(test)]
mod tests {
    use sqlx::sqlite::SqliteQueryResult;
    use sqlx::FromRow;
    use tempfile::NamedTempFile;

    use crate::database::ToSqlxType;

    use super::*;

    /// Sanity check: the database can be created from a file path
    /// and the migrations have run, at least for the `identity` table
    #[tokio::test]
    async fn test_create_identity_table() -> Result<()> {
        let temp_file = NamedTempFile::new().unwrap();
        let database = SqlxDatabase::create(temp_file.path()).await?;

        // exactly one row must have been written by the insert
        let query_result = insert_identity(&database).await.unwrap();
        assert_eq!(query_result.rows_affected(), 1);

        Ok(())
    }

    /// Check that a query can be executed and its result returned as an entity
    #[tokio::test]
    async fn test_query() -> Result<()> {
        let temp_file = NamedTempFile::new().unwrap();
        let database = SqlxDatabase::create(temp_file.path()).await?;
        insert_identity(&database).await.unwrap();

        let select_identifier = "SELECT identifier FROM identity WHERE identifier=?1";
        let known_identifier = "Ifa804b7fca12a19eed206ae180b5b576860ae651";

        // the identifier which has been inserted is found
        let found: Option<IdentifierRow> = sqlx::query_as(select_identifier)
            .bind(known_identifier)
            .fetch_optional(&*database.pool)
            .await
            .unwrap();
        assert_eq!(found, Some(IdentifierRow(known_identifier.into())));

        // an unknown identifier does not return any row
        let missing: Option<IdentifierRow> = sqlx::query_as(select_identifier)
            .bind("x")
            .fetch_optional(&*database.pool)
            .await
            .unwrap();
        assert_eq!(missing, None);

        Ok(())
    }

    /// Helper: insert one fixed row in the `identity` table and return the query result
    async fn insert_identity(db: &SqlxDatabase) -> Result<SqliteQueryResult> {
        let insert = sqlx::query("INSERT INTO identity VALUES (?1, ?2)")
            .bind("Ifa804b7fca12a19eed206ae180b5b576860ae651")
            .bind("123".to_sql());
        insert.execute(&*db.pool).await.into_core()
    }

    /// Single-column row, used by the tests to read back the `identifier` column
    #[derive(FromRow, PartialEq, Eq, Debug)]
    struct IdentifierRow(String);
}