1use sha2::{Digest, Sha384};
2use sqlx::Executor;
3
4use crate::scoped_queries::ScopedQuery;
5use crate::{BackendStoragePool, EncryptedPool};
6
/// Name of the table in the encrypted pool that records, per mount, which
/// migration versions have been applied (created by `create_migrate_table`).
const BACKEND_MIGRATIONS_TABLE: &str = "_BACKEND_STORAGE_MIGRATIONS";
8
/// A single migration: the SQL to run plus a human-readable label for it.
///
/// [`migration_scripts`] fills `description` with the embedded file name, and
/// [`migrate`] stores the description in the migrations table next to the
/// applied version.
#[derive(Debug)]
pub struct MigrationScript {
    /// SQL text of the migration; it may contain several statements.
    pub script: String,
    /// Label recorded for the migration (the file name when loaded from
    /// embedded assets).
    pub description: String,
}
14
15async fn create_migrate_table(pool: &EncryptedPool) -> Result<(), sqlx::Error> {
16 let sql = format!(
18 "CREATE TABLE IF NOT EXISTS {BACKEND_MIGRATIONS_TABLE}(
19 mount_id INTEGER NOT NULL REFERENCES MOUNTS(id) ON DELETE CASCADE ON UPDATE CASCADE,
20 version INTEGER NOT NULL,
21 description TEXT NOT NULL,
22 checksum BLOB NOT NULL,
23 created_at TIMESTAMP NOT NULL,
24 PRIMARY KEY(mount_id, version)
25 )"
26 );
27 sqlx::query(&sql).execute(pool).await?;
28 Ok(())
29}
30
/// Row shape for the `SELECT MAX(version) AS latest_version ...` query in
/// [`migrate`].
#[derive(Debug, sqlx::FromRow)]
struct LatestMigration {
    /// Highest recorded version for the mount, or `None` when nothing has been
    /// recorded yet (`MAX` over zero rows yields NULL).
    latest_version: Option<i64>,
}
35
36#[derive(Debug, thiserror::Error)]
37pub enum MigrationError {
38 #[error("sqlx error")]
39 DB(sqlx::Error),
40 #[error("bad query")]
41 BadQuery,
42 #[error("unable to parse migration script `{filename}`")]
43 Script { filename: String, error: String },
44 #[error("unable to execute migration script `{filename}`")]
45 Execution {
46 filename: String,
47 error: sqlx::Error,
48 },
49}
50
51impl From<sqlx::Error> for MigrationError {
52 fn from(err: sqlx::Error) -> Self {
53 Self::DB(err)
54 }
55}
56
57pub async fn migrate(
63 pool: &EncryptedPool,
64 migrations: &[MigrationScript],
65 mount_id: &str,
66 prefix: &str,
67) -> Result<(), MigrationError> {
68 create_migrate_table(pool).await?;
69
70 let latest_migration: Option<LatestMigration> = sqlx::query_as(&format!(
71 "SELECT MAX(version) AS latest_version FROM {BACKEND_MIGRATIONS_TABLE}
72 WHERE mount_id = ?"
73 ))
74 .bind(mount_id)
75 .fetch_optional(pool)
76 .await?;
77 let last_migration_version = latest_migration.and_then(|m| m.latest_version);
78
79 for (version, migration) in migrations.iter().enumerate() {
80 if let Some(last_migration_version) = last_migration_version {
81 if last_migration_version >= version as i64 {
82 continue;
83 }
84 }
85 let sql =
87 ScopedQuery::new(prefix, &migration.script).map_err(|_| MigrationError::BadQuery)?;
88 let checksum = Sha384::digest(sql.sql().as_bytes()).to_vec();
89
90 let mut tx = pool.begin().await?;
91
92 sqlx::query(&format!(
94 "INSERT INTO {BACKEND_MIGRATIONS_TABLE} (
95 mount_id,
96 version,
97 description,
98 checksum,
99 created_at
100 ) VALUES (
101 ?,
102 ?,
103 ?,
104 ?,
105 ?
106 )"
107 ))
108 .bind(mount_id)
109 .bind(version as i64)
110 .bind(&migration.description)
111 .bind(checksum)
112 .bind(chrono::Utc::now())
113 .execute(&mut tx)
114 .await
115 .map_err(|_| MigrationError::BadQuery)?;
116
117 tx.execute(sql.sql()).await?;
119
120 tx.commit().await?;
121 }
122
123 Ok(())
124}
125
126pub async fn migrate_backend<M: rust_embed::RustEmbed>(
133 storage: &BackendStoragePool,
134) -> Result<(), MigrationError> {
135 let migrations = migration_scripts::<M>()?;
136
137 for migration in migrations {
138 storage
139 .query(&migration.script)
140 .map_err(|error| MigrationError::Execution {
141 filename: migration.description.clone(),
142 error,
143 })?
144 .execute()
145 .await
146 .map_err(|error| MigrationError::Execution {
147 filename: migration.description,
148 error,
149 })?;
150 }
151 Ok(())
152}
153
154pub fn migration_scripts<M: rust_embed::RustEmbed>() -> Result<Vec<MigrationScript>, MigrationError>
161{
162 let mut migrations = M::iter().collect::<Vec<_>>();
163 migrations.sort();
164
165 let mut migration_scripts = vec![];
166 for migration_file_name in migrations {
167 if let Some(migration) = M::get(&migration_file_name) {
168 let sql =
169 String::from_utf8(migration.data.to_vec()).map_err(|_| MigrationError::Script {
170 error: "Unable to parse migration script to UTF-8".to_string(),
171 filename: migration_file_name.to_string(),
172 })?;
173 migration_scripts.push(MigrationScript {
174 description: migration_file_name.to_string(),
175 script: sql,
176 });
177 } else {
178 return Err(MigrationError::Script {
179 filename: migration_file_name.to_string(),
180 error: "Unable to get migration script".to_string(),
181 });
182 }
183 }
184
185 Ok(migration_scripts)
186}
187
/// A row of the migrations table, as returned by [`list_migrations`].
#[derive(Debug, sqlx::FromRow)]
pub struct Migration {
    /// Mount the migration was applied to.
    pub mount_id: i64,
    /// Zero-based position of the script in the list passed to [`migrate`].
    pub version: i64,
    /// Description of the applied script (the file name when it came from
    /// [`migration_scripts`]).
    pub description: String,
    /// SHA-384 digest of the scoped SQL that was executed.
    pub checksum: Vec<u8>,
    /// UTC time at which the migration was recorded.
    pub created_at: chrono::DateTime<chrono::Utc>,
}
196
197pub async fn list_migrations(
203 pool: &EncryptedPool,
204 mount_id: &str,
205) -> Result<Vec<Migration>, sqlx::Error> {
206 sqlx::query_as(&format!(
207 "SELECT * FROM {BACKEND_MIGRATIONS_TABLE} WHERE mount_id = ?"
208 ))
209 .bind(mount_id)
210 .fetch_all(pool)
211 .await
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Row shape for reading table names out of `sqlite_master`.
    #[derive(Debug, sqlx::FromRow, PartialEq, Eq)]
    struct Tables {
        name: String,
    }

    /// End-to-end check of [`migrate`] and [`list_migrations`]. It covers:
    /// - applying an initial set of two migrations;
    /// - applying one appended migration;
    /// - re-running as a no-op;
    /// - migrating a second mount under a different prefix.
    #[tokio::test]
    // One sequential scenario is kept in a single test because each step
    // depends on the database state left by the previous one.
    #[allow(clippy::too_many_lines)]
    async fn migration_works() {
        let pool = EncryptedPool::new(&":memory:".to_string());
        let master_key = pool.initialize().unwrap().unwrap();
        pool.unseal(master_key).unwrap();

        let mount_id = "12421412";
        let prefix = "foo_bar_";

        // The migrations table references MOUNTS(id), so create it and insert
        // the mount before migrating.
        sqlx::query("CREATE TABLE MOUNTS ( id INTEGER PRIMARY KEY )")
            .execute(&pool)
            .await
            .unwrap();
        sqlx::query("INSERT INTO MOUNTS (id) VALUES (?)")
            .bind(mount_id)
            .execute(&pool)
            .await
            .unwrap();

        let mut migrations = vec![
            MigrationScript {
                description: "2022-12-12-init.sql".into(),
                script: r#"
CREATE TABLE IF NOT EXISTS SECRETS (
    "key" TEXT NOT NULL,
    "version" INTEGER NOT NULL,
    "value" TEXT,
    created_time TIMESTAMP NOT NULL,
    deleted BOOLEAN NOT NULL DEFAULT FALSE,
    destroyed BOOLEAN NOT NULL DEFAULT FALSE,
    PRIMARY KEY("key", "version"),
    CONSTRAINT destroyed_secret CHECK (
        -- If not destroyed then value is *not* null
        (NOT(destroyed) AND "value" IS NOT NULL) OR
        -- If destroyed then value is null
        (destroyed AND "value" IS NULL)
    )
);


CREATE TABLE IF NOT EXISTS CONFIG (
    lock INTEGER PRIMARY KEY DEFAULT 1,
    max_versions INTEGER NOT NULL DEFAULT 10,

    -- Used to ensure that maximum one config is ever inserted
    CONSTRAINT CONFIG_LOCK CHECK (lock=1)
);

            "#
                .to_string(),
            },
            MigrationScript {
                description: "2022-12-14-add-user.sql".into(),
                script: r#"
CREATE TABLE IF NOT EXISTS USERS (
    uid INTEGER PRIMARY KEY,
    "name" TEXT NOT NULL
);
            "#
                .to_string(),
            },
        ];
        // First run: both scripts are applied and recorded.
        migrate(&pool, &migrations, mount_id, prefix).await.unwrap();

        let res: Vec<Migration> =
            sqlx::query_as(&format!("SELECT * FROM {BACKEND_MIGRATIONS_TABLE}"))
                .fetch_all(&pool)
                .await
                .unwrap();
        assert_eq!(res.len(), 2);

        // Appending a script: only the new one should run on the next call.
        migrations.push(MigrationScript {
            description: "2022-12-16-add-user-email.sql".into(),
            script: r#"
ALTER TABLE USERS
    ADD email TEXT;
            "#
            .to_string(),
        });
        migrate(&pool, &migrations, mount_id, prefix).await.unwrap();

        let res: Vec<Migration> =
            sqlx::query_as(&format!("SELECT * FROM {BACKEND_MIGRATIONS_TABLE}"))
                .fetch_all(&pool)
                .await
                .unwrap();
        assert_eq!(res.len(), 3);

        // Re-running with an unchanged list must not record anything new.
        migrate(&pool, &migrations, mount_id, prefix).await.unwrap();

        let res: Vec<Migration> = list_migrations(&pool, mount_id).await.unwrap();
        assert_eq!(res.len(), 3);

        // NOTE(review): this comparison assumes sqlite_master returns tables in
        // creation order (there is no ORDER BY) — confirm if it becomes flaky.
        let res: Vec<Tables> = sqlx::query_as("SELECT name FROM sqlite_master WHERE type='table'")
            .fetch_all(&pool)
            .await
            .unwrap();
        assert_eq!(
            res,
            vec![
                Tables {
                    name: "MOUNTS".to_string()
                },
                Tables {
                    name: BACKEND_MIGRATIONS_TABLE.to_string()
                },
                Tables {
                    name: format!("{prefix}SECRETS")
                },
                Tables {
                    name: format!("{prefix}CONFIG")
                },
                Tables {
                    name: format!("{prefix}USERS")
                }
            ]
        );

        // A second mount starts with no history and gets its own full run.
        let mount_id = "6789";
        let prefix = "foo_foo_";

        let res: Vec<Migration> = list_migrations(&pool, mount_id).await.unwrap();
        assert_eq!(res.len(), 0);

        sqlx::query("INSERT INTO MOUNTS (id) VALUES (?)")
            .bind(mount_id)
            .execute(&pool)
            .await
            .unwrap();
        migrate(&pool, &migrations, mount_id, prefix).await.unwrap();
        let res: Vec<Migration> = list_migrations(&pool, mount_id).await.unwrap();
        assert_eq!(res.len(), 3);
    }
}