// covert_storage/migrator.rs

1use sha2::{Digest, Sha384};
2use sqlx::Executor;
3
4use crate::scoped_queries::ScopedQuery;
5use crate::{BackendStoragePool, EncryptedPool};
6
/// Name of the bookkeeping table that records, per mount, which migration
/// versions have been applied.
const BACKEND_MIGRATIONS_TABLE: &str = "_BACKEND_STORAGE_MIGRATIONS";
8
/// A single migration: the SQL to run plus a label identifying it.
#[derive(Debug)]
pub struct MigrationScript {
    /// SQL text of the migration.
    pub script: String,
    /// Label stored with the applied version and used as the file name in
    /// error values; `migration_scripts` fills it with the embedded file name.
    pub description: String,
}
14
15async fn create_migrate_table(pool: &EncryptedPool) -> Result<(), sqlx::Error> {
16    // TODO: is mount_id here the correct type?
17    let sql = format!(
18        "CREATE TABLE IF NOT EXISTS {BACKEND_MIGRATIONS_TABLE}(
19        mount_id INTEGER NOT NULL REFERENCES MOUNTS(id) ON DELETE CASCADE ON UPDATE CASCADE,
20        version INTEGER NOT NULL,
21        description TEXT NOT NULL,
22        checksum BLOB NOT NULL,
23        created_at TIMESTAMP NOT NULL,
24        PRIMARY KEY(mount_id, version)
25    )"
26    );
27    sqlx::query(&sql).execute(pool).await?;
28    Ok(())
29}
30
/// Row shape for the `SELECT MAX(version) AS latest_version` lookup in `migrate`.
#[derive(Debug, sqlx::FromRow)]
struct LatestMigration {
    /// Highest recorded version for the mount; `None` when the aggregate
    /// yields NULL.
    latest_version: Option<i64>,
}
35
36#[derive(Debug, thiserror::Error)]
37pub enum MigrationError {
38    #[error("sqlx error")]
39    DB(sqlx::Error),
40    #[error("bad query")]
41    BadQuery,
42    #[error("unable to parse migration script `{filename}`")]
43    Script { filename: String, error: String },
44    #[error("unable to execute migration script `{filename}`")]
45    Execution {
46        filename: String,
47        error: sqlx::Error,
48    },
49}
50
51impl From<sqlx::Error> for MigrationError {
52    fn from(err: sqlx::Error) -> Self {
53        Self::DB(err)
54    }
55}
56
57/// Apply [`MigrationScript`]'s for a backend by applying the mount id and prefix.
58///
59/// # Errors
60///
61/// Returns error if it fails to apply any of the migration scripts.
62pub async fn migrate(
63    pool: &EncryptedPool,
64    migrations: &[MigrationScript],
65    mount_id: &str,
66    prefix: &str,
67) -> Result<(), MigrationError> {
68    create_migrate_table(pool).await?;
69
70    let latest_migration: Option<LatestMigration> = sqlx::query_as(&format!(
71        "SELECT MAX(version) AS latest_version FROM {BACKEND_MIGRATIONS_TABLE} 
72        WHERE mount_id = ?"
73    ))
74    .bind(mount_id)
75    .fetch_optional(pool)
76    .await?;
77    let last_migration_version = latest_migration.and_then(|m| m.latest_version);
78
79    for (version, migration) in migrations.iter().enumerate() {
80        if let Some(last_migration_version) = last_migration_version {
81            if last_migration_version >= version as i64 {
82                continue;
83            }
84        }
85        // First check if scoped query is valid
86        let sql =
87            ScopedQuery::new(prefix, &migration.script).map_err(|_| MigrationError::BadQuery)?;
88        let checksum = Sha384::digest(sql.sql().as_bytes()).to_vec();
89
90        let mut tx = pool.begin().await?;
91
92        // Try to add new migration version for backend
93        sqlx::query(&format!(
94            "INSERT INTO {BACKEND_MIGRATIONS_TABLE} (
95        mount_id,
96        version,
97        description,
98        checksum,
99        created_at
100    ) VALUES (
101        ?,
102        ?,
103        ?,
104        ?,
105        ?
106    )"
107        ))
108        .bind(mount_id)
109        .bind(version as i64)
110        .bind(&migration.description)
111        .bind(checksum)
112        .bind(chrono::Utc::now())
113        .execute(&mut tx)
114        .await
115        .map_err(|_| MigrationError::BadQuery)?;
116
117        // Migration script
118        tx.execute(sql.sql()).await?;
119
120        tx.commit().await?;
121    }
122
123    Ok(())
124}
125
126/// Run migrations for a given backend.
127///
128/// # Errors
129///
130/// Returns error if the migration fails to read the migration file contents
131/// or fails to apply any of the migrations.
132pub async fn migrate_backend<M: rust_embed::RustEmbed>(
133    storage: &BackendStoragePool,
134) -> Result<(), MigrationError> {
135    let migrations = migration_scripts::<M>()?;
136
137    for migration in migrations {
138        storage
139            .query(&migration.script)
140            .map_err(|error| MigrationError::Execution {
141                filename: migration.description.clone(),
142                error,
143            })?
144            .execute()
145            .await
146            .map_err(|error| MigrationError::Execution {
147                filename: migration.description,
148                error,
149            })?;
150    }
151    Ok(())
152}
153
154/// Retrieve [`MigrationScript`]'s from type that implements [`rust_embed::RustEmbed`].
155///
156/// # Errors
157///
158/// Returns error if it is unable to parse the contents of any of the migration
159/// script files.
160pub fn migration_scripts<M: rust_embed::RustEmbed>() -> Result<Vec<MigrationScript>, MigrationError>
161{
162    let mut migrations = M::iter().collect::<Vec<_>>();
163    migrations.sort();
164
165    let mut migration_scripts = vec![];
166    for migration_file_name in migrations {
167        if let Some(migration) = M::get(&migration_file_name) {
168            let sql =
169                String::from_utf8(migration.data.to_vec()).map_err(|_| MigrationError::Script {
170                    error: "Unable to parse migration script to UTF-8".to_string(),
171                    filename: migration_file_name.to_string(),
172                })?;
173            migration_scripts.push(MigrationScript {
174                description: migration_file_name.to_string(),
175                script: sql,
176            });
177        } else {
178            return Err(MigrationError::Script {
179                filename: migration_file_name.to_string(),
180                error: "Unable to get migration script".to_string(),
181            });
182        }
183    }
184
185    Ok(migration_scripts)
186}
187
/// One row of the migration bookkeeping table.
#[derive(Debug, sqlx::FromRow)]
pub struct Migration {
    /// Mount the migration was applied for.
    pub mount_id: i64,
    /// Version of the migration; `migrate` uses the script's index in its
    /// input slice.
    pub version: i64,
    /// Description supplied with the migration script.
    pub description: String,
    /// SHA-384 digest of the scoped SQL that `migrate` executed.
    pub checksum: Vec<u8>,
    /// Time at which the row was recorded.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
196
197/// List applied migrations
198///
199/// # Errors
200///
201/// Returns error if query to retrieve migrations fails.
202pub async fn list_migrations(
203    pool: &EncryptedPool,
204    mount_id: &str,
205) -> Result<Vec<Migration>, sqlx::Error> {
206    sqlx::query_as(&format!(
207        "SELECT * FROM {BACKEND_MIGRATIONS_TABLE} WHERE mount_id = ?"
208    ))
209    .bind(mount_id)
210    .fetch_all(pool)
211    .await
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Row type for reading table names out of `sqlite_master`.
    #[derive(Debug, sqlx::FromRow, PartialEq, Eq)]
    struct Tables {
        name: String,
    }

    /// First migration: creates the `SECRETS` and `CONFIG` tables.
    const INIT_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS SECRETS (
    "key" TEXT NOT NULL,
    "version" INTEGER NOT NULL, 
    "value" TEXT,
    created_time TIMESTAMP NOT NULL,
    deleted BOOLEAN NOT NULL DEFAULT FALSE,
    destroyed BOOLEAN NOT NULL DEFAULT FALSE,
    PRIMARY KEY("key", "version"),
    CONSTRAINT destroyed_secret CHECK (
        -- If not destroyed then value is *not* null
        (NOT(destroyed) AND "value" IS NOT NULL) OR
        -- If destroyed then value is null
        (destroyed AND "value" IS NULL) 
    )
); 


CREATE TABLE IF NOT EXISTS CONFIG (
    lock INTEGER PRIMARY KEY DEFAULT 1,
    max_versions INTEGER NOT NULL DEFAULT 10,

    -- Used to ensure that maximum one config is ever inserted
    CONSTRAINT CONFIG_LOCK CHECK (lock=1)
); 
                
                "#;

    /// Second migration: creates the `USERS` table.
    const ADD_USER_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS USERS (
    uid INTEGER PRIMARY KEY,
    "name" TEXT NOT NULL
); 
                "#;

    /// Third migration: adds an `email` column to `USERS`.
    const ADD_USER_EMAIL_SQL: &str = r#"
ALTER TABLE USERS 
    ADD email TEXT; 
            "#;

    /// Builds a [`MigrationScript`] from a description and SQL text.
    fn script(description: &str, sql: &str) -> MigrationScript {
        MigrationScript {
            description: description.into(),
            script: sql.to_string(),
        }
    }

    /// Inserts `mount_id` into the dummy `MOUNTS` table.
    async fn add_mount(pool: &EncryptedPool, mount_id: &str) {
        sqlx::query("INSERT INTO MOUNTS (id) VALUES (?)")
            .bind(mount_id)
            .execute(pool)
            .await
            .unwrap();
    }

    /// Number of rows in the migrations table across all mounts.
    async fn recorded_total(pool: &EncryptedPool) -> usize {
        let rows: Vec<Migration> =
            sqlx::query_as(&format!("SELECT * FROM {BACKEND_MIGRATIONS_TABLE}"))
                .fetch_all(pool)
                .await
                .unwrap();
        rows.len()
    }

    /// Number of migrations recorded for a single mount.
    async fn recorded_for(pool: &EncryptedPool, mount_id: &str) -> usize {
        let rows: Vec<Migration> = list_migrations(pool, mount_id).await.unwrap();
        rows.len()
    }

    #[tokio::test]
    #[allow(clippy::too_many_lines)]
    async fn migration_works() {
        let pool = EncryptedPool::new(&":memory:".to_string());
        let master_key = pool.initialize().unwrap().unwrap();
        pool.unseal(master_key).unwrap();

        let first_mount = "12421412";
        let first_prefix = "foo_bar_";

        // create dummy mounts table
        sqlx::query("CREATE TABLE MOUNTS ( id INTEGER PRIMARY KEY )")
            .execute(&pool)
            .await
            .unwrap();
        add_mount(&pool, first_mount).await;

        // Two initial migrations are both applied.
        let mut migrations = vec![
            script("2022-12-12-init.sql", INIT_SQL),
            script("2022-12-14-add-user.sql", ADD_USER_SQL),
        ];
        migrate(&pool, &migrations, first_mount, first_prefix)
            .await
            .unwrap();
        assert_eq!(recorded_total(&pool).await, 2);

        // Appending a third migration applies only the new one.
        migrations.push(script("2022-12-16-add-user-email.sql", ADD_USER_EMAIL_SQL));
        migrate(&pool, &migrations, first_mount, first_prefix)
            .await
            .unwrap();
        assert_eq!(recorded_total(&pool).await, 3);

        // Run again and nothing should change
        migrate(&pool, &migrations, first_mount, first_prefix)
            .await
            .unwrap();
        assert_eq!(recorded_for(&pool, first_mount).await, 3);

        // List tables
        let tables: Vec<Tables> =
            sqlx::query_as("SELECT name FROM sqlite_master WHERE type='table'")
                .fetch_all(&pool)
                .await
                .unwrap();
        let expected_tables: Vec<Tables> = vec![
            "MOUNTS".to_string(),
            BACKEND_MIGRATIONS_TABLE.to_string(),
            // Tables from migrations
            format!("{first_prefix}SECRETS"),
            format!("{first_prefix}CONFIG"),
            format!("{first_prefix}USERS"),
        ]
        .into_iter()
        .map(|name| Tables { name })
        .collect();
        assert_eq!(tables, expected_tables);

        // Use new mount and it should work again
        let second_mount = "6789";
        let second_prefix = "foo_foo_";

        assert_eq!(recorded_for(&pool, second_mount).await, 0);

        add_mount(&pool, second_mount).await;
        migrate(&pool, &migrations, second_mount, second_prefix)
            .await
            .unwrap();
        assert_eq!(recorded_for(&pool, second_mount).await, 3);
    }
}