1use std::path::Path;
2
3use anyhow::{Context, Result, bail, ensure};
4use rusqlite::{Connection, OpenFlags};
5
6use super::entry::{Migration, MigrationEntry, SqlMigration, apply_migration_and_verify_schema};
7use super::{MigratorBuilder, SchemaHash, SchemaHashes, schema};
8
9#[derive(Debug)]
32pub struct Migrator {
33 retired_migrations: Vec<SqlMigration>,
34 active_migrations: Vec<Migration>,
35 expected_schema_hashes: Vec<SchemaHash>,
36}
37
38impl Migrator {
39 pub(super) fn empty() -> Self {
44 Self {
45 retired_migrations: Vec::new(),
46 active_migrations: Vec::new(),
47 expected_schema_hashes: Vec::new(),
48 }
49 }
50
51 pub fn builder() -> Result<MigratorBuilder> {
53 MigratorBuilder::new()
54 }
55
56 pub(super) fn next_version(&self) -> usize {
61 self.expected_schema_hashes.len() + 1
62 }
63
64 pub(super) fn push_retired_unchecked(
71 &mut self,
72 migration: SqlMigration,
73 schema_hash: SchemaHash,
74 ) {
75 assert!(
76 self.active_migrations.is_empty(),
77 "cannot add retired migration after active migrations have started"
78 );
79 self.retired_migrations.push(migration);
80 self.expected_schema_hashes.push(schema_hash);
81 }
82
83 pub(super) fn push_active_unchecked(&mut self, migration: Migration, schema_hash: SchemaHash) {
89 self.active_migrations.push(migration);
90 self.expected_schema_hashes.push(schema_hash);
91 }
92
93 pub(super) fn validate(&self) -> Result<()> {
98 let migration_count = self.retired_migrations.len() + self.active_migrations.len();
99 ensure!(
100 !self.expected_schema_hashes.is_empty(),
101 "cannot build migrator without migrations"
102 );
103 ensure!(
104 self.expected_schema_hashes.len() == migration_count,
105 "migrator schema hash count {} must match migration count {migration_count}",
106 self.expected_schema_hashes.len()
107 );
108 Ok(())
109 }
110
111 pub fn schema_hashes(&self) -> SchemaHashes<'_> {
116 SchemaHashes(&self.expected_schema_hashes)
117 }
118
119 pub fn bootstrap(&self, database_filepath: impl AsRef<Path>) -> Result<()> {
125 let database_filepath = database_filepath.as_ref();
126 ensure!(
127 !fs_err::exists(database_filepath).with_context(|| {
128 format!("failed to check database path {}", database_filepath.display())
129 })?,
130 "database already exists: {}",
131 database_filepath.display()
132 );
133
134 let mut conn = Connection::open_with_flags(
135 database_filepath,
136 OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
137 )
138 .with_context(|| format!("failed to create database {}", database_filepath.display()))?;
139
140 self.apply_missing_migrations(&mut conn, 0)
141 }
142
143 pub fn migrate(&self, database_filepath: impl AsRef<Path>) -> Result<()> {
150 let database_filepath = existing_database_path(database_filepath.as_ref())?;
151 let mut conn =
152 Connection::open_with_flags(database_filepath, OpenFlags::SQLITE_OPEN_READ_WRITE)
153 .with_context(|| {
154 format!("failed to open database {}", database_filepath.display())
155 })?;
156
157 self.migrate_connection(&mut conn)
158 }
159
160 pub fn verify_latest_schema(&self, database_filepath: impl AsRef<Path>) -> Result<()> {
169 let database_filepath = existing_database_path(database_filepath.as_ref())?;
170 let conn =
171 Connection::open_with_flags(database_filepath, OpenFlags::SQLITE_OPEN_READ_WRITE)
172 .with_context(|| {
173 format!("failed to open existing database {}", database_filepath.display())
174 })?;
175
176 self.verify_latest_connection_schema(&conn)
177 }
178
179 fn migrate_connection(&self, conn: &mut Connection) -> Result<()> {
180 let current_version = self.version_check(conn)?;
181 ensure!(current_version > 0, "database has not been bootstrapped; run bootstrap first");
182
183 self.apply_missing_migrations(conn, current_version)
184 }
185
186 fn apply_missing_migrations(
187 &self,
188 conn: &mut Connection,
189 current_version: usize,
190 ) -> Result<()> {
191 let retired_versions = self.retired_migrations.len();
192
193 let mut applied_version = current_version;
194 if applied_version == 0 {
195 for (idx, migration) in self.retired_migrations.iter().enumerate() {
196 let version = idx + 1;
197 self.apply_migration(conn, version, migration)?;
198 applied_version = version;
199 }
200 }
201
202 let active_start = applied_version.saturating_sub(retired_versions);
203 for (idx, migration) in self.active_migrations.iter().enumerate().skip(active_start) {
204 let version = retired_versions + idx + 1;
205 self.apply_migration(conn, version, migration)?;
206 }
207
208 Ok(())
209 }
210
211 fn verify_latest_connection_schema(&self, conn: &Connection) -> Result<()> {
212 let current_version = self.version_check(conn)?;
213 let total_versions = self.expected_schema_hashes.len();
214
215 ensure!(
216 current_version == total_versions,
217 "database version {current_version} is older than migrator version {total_versions}; \
218 run the migrate command first"
219 );
220
221 Ok(())
222 }
223
224 fn version_check(&self, conn: &Connection) -> Result<usize> {
230 let current_version =
231 schema::get_version(conn).context("failed to read database version")?;
232 let total_versions = self.expected_schema_hashes.len();
233
234 ensure!(
235 current_version <= total_versions,
236 "database version {current_version} is newer than migrator version {total_versions}"
237 );
238
239 let retired_versions = self.retired_migrations.len();
240 if current_version > 0 && current_version < retired_versions {
241 let name = self.migration_name(current_version).unwrap_or("<unknown>");
242 bail!(
243 "database version {current_version} \"{name}\" is inside the retired migration \
244 range; retired migrations can only initialize new databases"
245 );
246 }
247
248 if current_version > 0 {
249 self.verify_current_schema(conn, current_version)?;
250 }
251
252 Ok(current_version)
253 }
254
255 fn apply_migration(
257 &self,
258 conn: &mut Connection,
259 version: usize,
260 migration: &impl MigrationEntry,
261 ) -> Result<()> {
262 let name = migration.name();
263 let expected = self.expected_schema_hashes[version - 1];
264 apply_migration_and_verify_schema(conn, version, migration, expected)
265 .with_context(|| format!("failed to apply migration {version} \"{name}\""))
266 }
267
268 fn verify_current_schema(&self, conn: &Connection, version: usize) -> Result<()> {
270 let name = self.migration_name(version).unwrap_or("<unknown>");
271 let expected = self.expected_schema_hashes[version - 1];
272 let actual = SchemaHash::new(conn).with_context(|| {
273 format!("failed to compute schema hash at database version {version} \"{name}\"")
274 })?;
275
276 ensure!(
277 actual == expected,
278 "schema hash mismatch at database version {version} \"{name}\": expected {expected}, \
279 got {actual}"
280 );
281 Ok(())
282 }
283
284 fn migration_name(&self, version: usize) -> Option<&'static str> {
286 if version == 0 {
287 return None;
288 }
289
290 if version <= self.retired_migrations.len() {
291 return Some(self.retired_migrations[version - 1].name());
292 }
293
294 self.active_migrations
295 .get(version - self.retired_migrations.len() - 1)
296 .map(MigrationEntry::name)
297 }
298}
299
300fn existing_database_path(database_filepath: &Path) -> Result<&Path> {
301 let metadata = fs_err::metadata(database_filepath)
302 .with_context(|| format!("failed to read database {}", database_filepath.display()))?;
303 ensure!(
304 metadata.is_file(),
305 "database path is not a file: {}",
306 database_filepath.display()
307 );
308 Ok(database_filepath)
309}
310
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use anyhow::Result;
    use rusqlite::{Connection, Transaction};

    use super::super::{Migrator, schema};

    // ---- code migrations used as fixtures ----------------------------------------------------

    /// Code migration: indexes `items.value`.
    fn add_items_index(tx: &Transaction<'_>) -> Result<()> {
        tx.execute_batch("CREATE INDEX idx_items_value ON items(value);")
            .map_err(Into::into)
    }

    /// Code migration: adds a nullable `height` column to `items`.
    fn add_item_height(tx: &Transaction<'_>) -> Result<()> {
        tx.execute_batch("ALTER TABLE items ADD COLUMN height INTEGER;")
            .map_err(Into::into)
    }

    /// Code migration whose resulting schema depends on the data: it creates the `unexpected`
    /// table only when `items` has rows. Used to provoke a schema hash mismatch.
    fn create_extra_table_when_items_exist(tx: &Transaction<'_>) -> Result<()> {
        let item_count: i64 = tx.query_row("SELECT COUNT(*) FROM items", [], |row| row.get(0))?;
        if item_count <= 0 {
            return Ok(());
        }
        tx.execute_batch("CREATE TABLE unexpected (id INTEGER PRIMARY KEY);")?;
        Ok(())
    }

    /// Code migration: creates the `items` table.
    fn create_items_table(tx: &Transaction<'_>) -> Result<()> {
        tx.execute_batch("CREATE TABLE items (id INTEGER PRIMARY KEY, value TEXT);")
            .map_err(Into::into)
    }

    // ---- helpers -----------------------------------------------------------------------------

    /// Reports whether `sqlite_master` contains an object named `name`.
    fn object_exists(conn: &Connection, name: &str) -> Result<bool> {
        let sql = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE name = ?1)";
        Ok(conn.query_row(sql, [name], |row| row.get::<_, bool>(0))?)
    }

    /// Builds a migrator with a single SQL migration creating an id-only `items` table.
    fn items_migrator() -> Result<Migrator> {
        let migrator = Migrator::builder()?
            .push_sql("create items", "CREATE TABLE items (id INTEGER PRIMARY KEY);")?
            .build()?;
        Ok(migrator)
    }

    /// Opens the test database and runs `sql` against it as one batch.
    fn seed(db: &TestDatabase, sql: &str) -> Result<()> {
        let conn = db.open()?;
        conn.execute_batch(sql)?;
        Ok(())
    }

    /// Opens and immediately closes a connection, leaving a schema-less database file behind.
    fn create_empty_file(db: &TestDatabase) -> Result<()> {
        drop(db.open()?);
        Ok(())
    }

    /// Asserts that the top-level message of `err` contains `needle`.
    fn assert_message_contains(err: &anyhow::Error, needle: &str) {
        let message = err.to_string();
        assert!(message.contains(needle), "error {message:?} does not mention {needle:?}");
    }

    /// Throwaway database path in the system temp directory. The file and its WAL/SHM sidecars
    /// are removed on creation and again on drop.
    struct TestDatabase {
        path: PathBuf,
    }

    impl TestDatabase {
        fn new(name: &str) -> Self {
            let file_name =
                format!("miden-node-db-migrator-{name}-{}.sqlite3", std::process::id());
            let db = Self { path: std::env::temp_dir().join(file_name) };
            db.remove_files();
            db
        }

        fn path(&self) -> &Path {
            &self.path
        }

        fn open(&self) -> Result<Connection> {
            Ok(Connection::open(&self.path)?)
        }

        fn remove_files(&self) {
            // Best effort: any of these files may legitimately be absent.
            let _ = fs_err::remove_file(&self.path);
            for sidecar in ["sqlite3-wal", "sqlite3-shm"] {
                let _ = fs_err::remove_file(self.path.with_extension(sidecar));
            }
        }
    }

    impl Drop for TestDatabase {
        fn drop(&mut self) {
            self.remove_files();
        }
    }

    // ---- bootstrap ---------------------------------------------------------------------------

    #[test]
    fn bootstraps_new_database_through_retired_and_code() -> Result<()> {
        let migrator = Migrator::builder()?
            .push_retired(
                "create items",
                "CREATE TABLE items (id INTEGER PRIMARY KEY, value TEXT);",
            )?
            .push_code("add item height", add_item_height)?
            .build()?;
        let db = TestDatabase::new("bootstraps_new_database_through_retired_and_code");

        migrator.bootstrap(db.path())?;

        let conn = db.open()?;
        assert_eq!(schema::get_version(&conn)?, 2);
        // The column added by the code migration must be usable.
        conn.execute("INSERT INTO items (id, value, height) VALUES (1, 'a', 10)", [])?;
        Ok(())
    }

    #[test]
    fn bootstraps_new_database_with_code_only_migration() -> Result<()> {
        let migrator =
            Migrator::builder()?.push_code("create items", create_items_table)?.build()?;
        let db = TestDatabase::new("bootstraps_new_database_with_code_only_migration");

        migrator.bootstrap(db.path())?;

        let conn = db.open()?;
        assert_eq!(schema::get_version(&conn)?, 1);
        conn.execute("INSERT INTO items (id, value) VALUES (1, 'a')", [])?;
        Ok(())
    }

    #[test]
    fn bootstraps_new_database_with_sql_only_migration() -> Result<()> {
        let migrator = Migrator::builder()?
            .push_sql("create items", "CREATE TABLE items (id INTEGER PRIMARY KEY, value TEXT);")?
            .build()?;
        let db = TestDatabase::new("bootstraps_new_database_with_sql_only_migration");

        migrator.bootstrap(db.path())?;

        let conn = db.open()?;
        assert_eq!(schema::get_version(&conn)?, 1);
        conn.execute("INSERT INTO items (id, value) VALUES (1, 'a')", [])?;
        Ok(())
    }

    #[test]
    fn bootstrap_rejects_existing_database_file() -> Result<()> {
        let migrator = items_migrator()?;
        let db = TestDatabase::new("bootstrap_rejects_existing_database_file");
        create_empty_file(&db)?;

        let err = migrator.bootstrap(db.path()).expect_err("existing database should fail");
        assert_message_contains(&err, "database already exists");
        Ok(())
    }

    // ---- migrate -----------------------------------------------------------------------------

    #[test]
    fn applies_missing_code_migrations_to_existing_database() -> Result<()> {
        let migrator = Migrator::builder()?
            .push_retired(
                "create items",
                "CREATE TABLE items (id INTEGER PRIMARY KEY, value TEXT);",
            )?
            .push_code("index item values", add_items_index)?
            .build()?;
        let db = TestDatabase::new("applies_missing_code_migrations_to_existing_database");
        // Simulate a database that already went through the retired migration.
        seed(
            &db,
            "CREATE TABLE items (id INTEGER PRIMARY KEY, value TEXT);
             PRAGMA user_version = 1;",
        )?;

        migrator.migrate(db.path())?;

        let conn = db.open()?;
        assert_eq!(schema::get_version(&conn)?, 2);
        assert!(object_exists(&conn, "idx_items_value")?);
        Ok(())
    }

    #[test]
    fn migrate_rejects_missing_database() -> Result<()> {
        let migrator = items_migrator()?;
        let db = TestDatabase::new("migrate_rejects_missing_database");

        let err = migrator.migrate(db.path()).expect_err("missing database should fail");
        assert_message_contains(&err, "failed to read database");
        // `migrate` must not create the file as a side effect.
        assert!(!db.path().exists());
        Ok(())
    }

    #[test]
    fn migrate_rejects_unbootstrapped_database() -> Result<()> {
        let migrator = items_migrator()?;
        let db = TestDatabase::new("migrate_rejects_unbootstrapped_database");
        create_empty_file(&db)?;

        let err = migrator.migrate(db.path()).expect_err("unbootstrapped database should fail");
        assert_message_contains(&err, "database has not been bootstrapped");
        Ok(())
    }

    #[test]
    fn rejects_existing_database_inside_retired_migration_range() -> Result<()> {
        let migrator = Migrator::builder()?
            .push_retired("create items", "CREATE TABLE items (id INTEGER PRIMARY KEY);")?
            .push_retired("create notes", "CREATE TABLE notes (id INTEGER PRIMARY KEY);")?
            .build()?;
        let db = TestDatabase::new("rejects_existing_database_inside_retired_migration_range");
        seed(
            &db,
            "CREATE TABLE items (id INTEGER PRIMARY KEY);
             PRAGMA user_version = 1;",
        )?;

        let err = migrator.migrate(db.path()).expect_err("migration should fail");
        assert_message_contains(&err, "inside the retired migration range");
        Ok(())
    }

    #[test]
    fn verifies_current_schema_before_applying_missing_migrations() -> Result<()> {
        let migrator = Migrator::builder()?
            .push_retired("create items", "CREATE TABLE items (id INTEGER PRIMARY KEY);")?
            .build()?;
        let db = TestDatabase::new("verifies_current_schema_before_applying_missing_migrations");
        migrator.bootstrap(db.path())?;
        // Drift the schema behind the migrator's back.
        seed(&db, "CREATE TABLE tampered (id INTEGER PRIMARY KEY);")?;

        let err = migrator.migrate(db.path()).expect_err("migration should fail");
        assert_message_contains(&err, "schema hash mismatch at database version 1");
        Ok(())
    }

    #[test]
    fn rolls_back_code_migration_when_schema_hash_mismatches() -> Result<()> {
        let migrator = Migrator::builder()?
            .push_retired("create items", "CREATE TABLE items (id INTEGER PRIMARY KEY);")?
            .push_code("conditionally create extra", create_extra_table_when_items_exist)?
            .build()?;
        let db = TestDatabase::new("rolls_back_code_migration_when_schema_hash_mismatches");
        seed(
            &db,
            "CREATE TABLE items (id INTEGER PRIMARY KEY);
             INSERT INTO items (id) VALUES (1);
             PRAGMA user_version = 1;",
        )?;

        let err = migrator.migrate(db.path()).expect_err("migration should fail");
        assert_message_contains(&err, "failed to apply migration 2");
        let mentions_mismatch =
            err.chain().any(|cause| cause.to_string().contains("schema hash mismatch"));
        assert!(mentions_mismatch);

        // Nothing from the failed migration may survive.
        let conn = db.open()?;
        assert_eq!(schema::get_version(&conn)?, 1);
        assert!(!object_exists(&conn, "unexpected")?);
        Ok(())
    }

    // ---- verify_latest_schema ----------------------------------------------------------------

    #[test]
    fn verify_latest_schema_accepts_current_database() -> Result<()> {
        let migrator = items_migrator()?;
        let db = TestDatabase::new("verify_latest_schema_accepts_current_database");
        migrator.bootstrap(db.path())?;

        migrator.verify_latest_schema(db.path())?;
        Ok(())
    }

    #[test]
    fn verify_latest_schema_rejects_schema_hash_mismatch() -> Result<()> {
        let migrator = items_migrator()?;
        let db = TestDatabase::new("verify_latest_schema_rejects_schema_hash_mismatch");
        seed(
            &db,
            "CREATE TABLE different (id INTEGER PRIMARY KEY);
             PRAGMA user_version = 1;",
        )?;

        let err = migrator.verify_latest_schema(db.path()).expect_err("schema drift should fail");
        assert_message_contains(&err, "schema hash mismatch");
        Ok(())
    }

    #[test]
    fn verify_latest_schema_rejects_missing_migrations_without_applying_them() -> Result<()> {
        let migrator = Migrator::builder()?
            .push_sql("create items", "CREATE TABLE items (id INTEGER PRIMARY KEY, value TEXT);")?
            .push_code("index item values", add_items_index)?
            .build()?;
        let db = TestDatabase::new("verify_latest_schema_rejects_missing_migrations");
        seed(
            &db,
            "CREATE TABLE items (id INTEGER PRIMARY KEY, value TEXT);
             PRAGMA user_version = 1;",
        )?;

        let err = migrator.verify_latest_schema(db.path()).expect_err("old database should fail");
        assert_message_contains(&err, "run the migrate command first");

        // Verification must leave the database untouched.
        let conn = db.open()?;
        assert_eq!(schema::get_version(&conn)?, 1);
        assert!(!object_exists(&conn, "idx_items_value")?);
        Ok(())
    }

    #[test]
    fn verify_latest_schema_rejects_missing_database_without_creating_it() -> Result<()> {
        let migrator = items_migrator()?;
        let db = TestDatabase::new("verify_latest_schema_rejects_missing_database");

        let err = migrator
            .verify_latest_schema(db.path())
            .expect_err("missing database should fail");
        assert_message_contains(&err, "failed to read database");
        assert!(!db.path().exists());
        Ok(())
    }
}