rust_query/
migrate.rs

1use std::{
2    collections::{HashMap, HashSet},
3    convert::Infallible,
4    marker::PhantomData,
5    ops::{Deref, Not},
6    path::Path,
7    sync::atomic::AtomicI64,
8};
9
10use rusqlite::{Connection, config::DbConfig};
11use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
12use self_cell::MutBorrow;
13
14use crate::{
15    FromExpr, IntoExpr, Table, TableRow, Transaction,
16    alias::{Scope, TmpTable},
17    hash,
18    schema_pragma::read_schema,
19    transaction::{Database, OwnedTransaction, TXN, try_insert_private},
20};
21
22pub struct TableTypBuilder<S> {
23    pub(crate) ast: hash::Schema,
24    _p: PhantomData<S>,
25}
26
27impl<S> Default for TableTypBuilder<S> {
28    fn default() -> Self {
29        Self {
30            ast: Default::default(),
31            _p: Default::default(),
32        }
33    }
34}
35
36impl<S> TableTypBuilder<S> {
37    pub fn table<T: Table<Schema = S>>(&mut self) {
38        let mut b = hash::TypBuilder::default();
39        T::typs(&mut b);
40        self.ast.tables.insert((T::NAME.to_owned(), b.ast));
41    }
42}
43
44pub trait Schema: Sized + 'static {
45    const VERSION: i64;
46    fn typs(b: &mut TableTypBuilder<Self>);
47}
48
49pub trait Migration {
50    type FromSchema: 'static;
51    type From: Table<Schema = Self::FromSchema>;
52    type To: Table<MigrateFrom = Self::From>;
53    type Conflict;
54
55    #[doc(hidden)]
56    fn prepare(
57        val: Self,
58        prev: crate::Expr<'static, Self::FromSchema, Self::From>,
59    ) -> <Self::To as Table>::Insert;
60    #[doc(hidden)]
61    fn map_conflict(val: TableRow<Self::From>) -> Self::Conflict;
62}
63
64/// Transaction type for use in migrations.
65pub struct TransactionMigrate<FromSchema> {
66    inner: Transaction<FromSchema>,
67    scope: Scope,
68    rename_map: HashMap<&'static str, TmpTable>,
69}
70
71impl<FromSchema> Deref for TransactionMigrate<FromSchema> {
72    type Target = Transaction<FromSchema>;
73
74    fn deref(&self) -> &Self::Target {
75        &self.inner
76    }
77}
78
79impl<FromSchema> TransactionMigrate<FromSchema> {
80    fn new_table_name<T: Table>(&mut self) -> TmpTable {
81        *self.rename_map.entry(T::NAME).or_insert_with(|| {
82            let new_table_name = self.scope.tmp_table();
83            TXN.with_borrow(|txn| new_table::<T>(txn.as_ref().unwrap().get(), new_table_name));
84            new_table_name
85        })
86    }
87
88    fn unmigrated<M: Migration<FromSchema = FromSchema>, Out>(
89        &self,
90        new_name: TmpTable,
91    ) -> impl Iterator<Item = (i64, Out)>
92    where
93        Out: FromExpr<FromSchema, M::From>,
94    {
95        let data = self.inner.query(|rows| {
96            let old = rows.join(<M::From as Table>::TOKEN);
97            rows.into_vec((&old, Out::from_expr(&old)))
98        });
99
100        let migrated = Transaction::new().query(|rows| {
101            let new = rows.join_tmp::<M::From>(new_name);
102            rows.into_vec(new)
103        });
104        let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();
105
106        data.into_iter().filter_map(move |(row, data)| {
107            migrated
108                .contains(&row.inner.idx)
109                .not()
110                .then_some((row.inner.idx, data))
111        })
112    }
113
114    /// Migrate some rows to the new schema.
115    ///
116    /// This will return an error when there is a conflict.
117    /// The error type depends on the number of unique constraints that the
118    /// migration can violate:
119    /// - 0 => [Infallible]
120    /// - 1.. => `TableRow<T::From>` (row in the old table that could not be migrated)
121    pub fn migrate_optional<
122        M: Migration<FromSchema = FromSchema>,
123        X: FromExpr<FromSchema, M::From>,
124    >(
125        &mut self,
126        mut f: impl FnMut(X) -> Option<M>,
127    ) -> Result<(), M::Conflict> {
128        let new_name = self.new_table_name::<M::To>();
129
130        for (idx, x) in self.unmigrated::<M, X>(new_name) {
131            if let Some(new) = f(x) {
132                try_insert_private::<M::To>(
133                    new_name.into_table_ref(),
134                    Some(idx),
135                    M::prepare(new, TableRow::new(idx).into_expr()),
136                )
137                .map_err(|_| M::map_conflict(TableRow::new(idx)))?;
138            };
139        }
140        Ok(())
141    }
142
143    /// Migrate all rows to the new schema.
144    ///
145    /// Conflict errors work the same as in [Self::migrate_optional].
146    ///
147    /// However, this method will return [Migrated] when all rows are migrated.
148    /// This can then be used as proof that there will be no foreign key violations.
149    pub fn migrate<M: Migration<FromSchema = FromSchema>, X: FromExpr<FromSchema, M::From>>(
150        &mut self,
151        mut f: impl FnMut(X) -> M,
152    ) -> Result<Migrated<'static, FromSchema, M::To>, M::Conflict> {
153        self.migrate_optional::<M, X>(|x| Some(f(x)))?;
154
155        Ok(Migrated {
156            _p: PhantomData,
157            f: Box::new(|_| {}),
158            _local: PhantomData,
159        })
160    }
161
162    /// Helper method for [Self::migrate].
163    ///
164    /// It can only be used when the migration is known to never cause unique constraint conflicts.
165    pub fn migrate_ok<
166        M: Migration<FromSchema = FromSchema, Conflict = Infallible>,
167        X: FromExpr<FromSchema, M::From>,
168    >(
169        &mut self,
170        f: impl FnMut(X) -> M,
171    ) -> Migrated<'static, FromSchema, M::To> {
172        let Ok(res) = self.migrate(f);
173        res
174    }
175}
176
177pub struct SchemaBuilder<'t, FromSchema> {
178    inner: TransactionMigrate<FromSchema>,
179    drop: Vec<TableDropStatement>,
180    foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
181}
182
183impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
184    pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
185        self.inner.new_table_name::<To>();
186
187        self.foreign_key.insert(To::NAME, Box::new(err));
188    }
189
190    pub fn create_empty<To: Table>(&mut self) {
191        self.inner.new_table_name::<To>();
192    }
193
194    pub fn drop_table<T: Table>(&mut self) {
195        let name = Alias::new(T::NAME);
196        let step = sea_query::Table::drop().table(name).take();
197        self.drop.push(step);
198    }
199}
200
201fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
202    let mut f = crate::hash::TypBuilder::default();
203    T::typs(&mut f);
204    new_table_inner(conn, &f.ast, alias);
205}
206
207fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
208    let mut create = table.create();
209    create
210        .table(alias)
211        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
212    let mut sql = create.to_string(SqliteQueryBuilder);
213    sql.push_str(" STRICT");
214    conn.execute(&sql, []).unwrap();
215}
216
217pub trait SchemaMigration<'a> {
218    type From: Schema;
219    type To: Schema;
220
221    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
222}
223
224/// [Config] is used to open a database from a file or in memory.
225///
226/// This is the first step in the [Config] -> [Migrator] -> [Database] chain to
227/// get a [Database] instance.
228///
229/// # Sqlite config
230///
231/// Sqlite is configured to be in [WAL mode](https://www.sqlite.org/wal.html).
232/// The effect of this mode is that there can be any number of readers with one concurrent writer.
233/// What is nice about this is that a `&`[crate::Transaction] can always be made immediately.
234/// Making a `&mut`[crate::Transaction] has to wait until all other `&mut`[crate::Transaction]s are finished.
235pub struct Config {
236    manager: r2d2_sqlite::SqliteConnectionManager,
237    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
238    /// Configure how often SQLite will synchronize the database to disk.
239    ///
240    /// The default is [Synchronous::Full].
241    pub synchronous: Synchronous,
242    /// Configure how foreign keys should be checked.
243    pub foreign_keys: ForeignKeys,
244}
245
/// <https://www.sqlite.org/pragma.html#pragma_synchronous>
///
/// Note that the database uses WAL mode, so make sure to read the WAL specific section.
///
/// The default is [Synchronous::Full].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Synchronous {
    /// SQLite will fsync after every transaction.
    ///
    /// Transactions are durable, even following a power failure or hard reboot.
    #[default]
    Full,

    /// SQLite will only do essential fsync to prevent corruption.
    ///
    /// The database will not rollback transactions due to application crashes, but it might rollback due to a hardware reset or power loss.
    /// Use this when performance is more important than durability.
    Normal,
}

impl Synchronous {
    /// Value passed to `PRAGMA synchronous`.
    fn as_str(self) -> &'static str {
        match self {
            Synchronous::Full => "FULL",
            Synchronous::Normal => "NORMAL",
        }
    }
}
271
/// Which method should be used to check foreign-key constraints.
///
/// The default is [ForeignKeys::SQLite], but this is likely to change to [ForeignKeys::Rust].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ForeignKeys {
    /// Foreign-key constraints are checked by rust-query only.
    ///
    /// Most foreign-key checks are done at compile time and are thus completely free.
    /// However, some runtime checks are required for deletes.
    Rust,

    /// Foreign-key constraints are checked by SQLite in addition to the checks done by rust-query.
    ///
    /// This is useful when using rust-query with [crate::TransactionWeak::rusqlite_transaction]
    /// or when other software can write to the database.
    /// Both can result in "dangling" foreign keys (which point at a non-existent row) if written incorrectly.
    /// Dangling foreign keys can result in wrong results, but these dangling foreign keys can also turn
    /// into "false" foreign keys if a new record is inserted that makes the foreign key valid.
    /// This is a lot worse than a dangling foreign key, because it is generally not possible to detect.
    ///
    /// With the [ForeignKeys::SQLite] option, rust-query will prevent creating such false foreign keys
    /// and panic instead.
    /// The downside is that indexes are required on all foreign keys to make the checks efficient.
    #[default]
    SQLite,
}

impl ForeignKeys {
    /// Value passed to `PRAGMA foreign_keys`.
    fn as_str(self) -> &'static str {
        match self {
            ForeignKeys::Rust => "OFF",
            ForeignKeys::SQLite => "ON",
        }
    }
}
306
307impl Config {
308    /// Open a database that is stored in a file.
309    /// Creates the database if it does not exist.
310    ///
311    /// Opening the same database multiple times at the same time is fine,
312    /// as long as they migrate to or use the same schema.
313    /// All locking is done by sqlite, so connections can even be made using different client implementations.
314    pub fn open(p: impl AsRef<Path>) -> Self {
315        let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
316        Self::open_internal(manager)
317    }
318
319    /// Creates a new empty database in memory.
320    pub fn open_in_memory() -> Self {
321        let manager = r2d2_sqlite::SqliteConnectionManager::memory();
322        Self::open_internal(manager)
323    }
324
325    fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
326        Self {
327            manager,
328            init: Box::new(|_| {}),
329            synchronous: Synchronous::Full,
330            foreign_keys: ForeignKeys::SQLite,
331        }
332    }
333
334    /// Append a raw sql statement to be executed if the database was just created.
335    ///
336    /// The statement is executed after creating the empty database and executing all previous statements.
337    pub fn init_stmt(mut self, sql: &'static str) -> Self {
338        self.init = Box::new(move |txn| {
339            (self.init)(txn);
340
341            txn.execute_batch(sql)
342                .expect("raw sql statement to populate db failed");
343        });
344        self
345    }
346}
347
348impl<S: Schema> Database<S> {
349    /// Create a [Migrator] to migrate a database.
350    ///
351    /// Returns [None] if the database `user_version` on disk is older than `S`.
352    pub fn migrator(config: Config) -> Option<Migrator<S>> {
353        let synchronous = config.synchronous.as_str();
354        let foreign_keys = config.foreign_keys.as_str();
355        let manager = config.manager.with_init(move |inner| {
356            inner.pragma_update(None, "journal_mode", "WAL")?;
357            inner.pragma_update(None, "synchronous", synchronous)?;
358            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
359            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
360            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
361            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
362            Ok(())
363        });
364
365        use r2d2::ManageConnection;
366        let conn = manager.connect().unwrap();
367        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
368        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
369            Some(
370                conn.borrow_mut()
371                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
372                    .unwrap(),
373            )
374        });
375
376        // check if this database is newly created
377        if schema_version(txn.get()) == 0 {
378            let mut b = TableTypBuilder::default();
379            S::typs(&mut b);
380
381            for (table_name, table) in &*b.ast.tables {
382                new_table_inner(txn.get(), table, Alias::new(table_name));
383            }
384            (config.init)(txn.get());
385            set_user_version(txn.get(), S::VERSION).unwrap();
386        }
387
388        let user_version = user_version(txn.get()).unwrap();
389        // We can not migrate databases older than `S`
390        if user_version < S::VERSION {
391            return None;
392        }
393        debug_assert_eq!(
394            foreign_key_check(txn.get()),
395            None,
396            "foreign key constraint violated"
397        );
398
399        Some(Migrator {
400            manager,
401            transaction: txn,
402            _p: PhantomData,
403        })
404    }
405}
406
407/// [Migrator] is used to apply database migrations.
408///
409/// When all migrations are done, it can be turned into a [Database] instance with
410/// [Migrator::finish].
411pub struct Migrator<S> {
412    manager: r2d2_sqlite::SqliteConnectionManager,
413    transaction: OwnedTransaction,
414    _p: PhantomData<S>,
415}
416
417/// [Migrated] provides a proof of migration.
418///
419/// This only needs to be provided for tables that are migrated from a previous table.
420pub struct Migrated<'t, FromSchema, T> {
421    _p: PhantomData<T>,
422    f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
423    _local: PhantomData<*const ()>,
424}
425
426impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
427    /// Don't migrate the remaining rows.
428    ///
429    /// This can cause foreign key constraint violations, which is why an error callback needs to be provided.
430    pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
431        Self {
432            _p: PhantomData,
433            f: Box::new(|x| x.foreign_key::<T>(err)),
434            _local: PhantomData,
435        }
436    }
437
438    #[doc(hidden)]
439    pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
440        (self.f)(b)
441    }
442}
443
444impl<S: Schema> Migrator<S> {
445    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
446    ///
447    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
448    pub fn migrate<'x, M>(
449        mut self,
450        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
451    ) -> Migrator<M::To>
452    where
453        M: SchemaMigration<'x, From = S>,
454    {
455        if user_version(self.transaction.get()).unwrap() == S::VERSION {
456            let res = std::thread::scope(|s| {
457                s.spawn(|| {
458                    TXN.set(Some(self.transaction));
459
460                    check_schema::<S>();
461
462                    let mut txn = TransactionMigrate {
463                        inner: Transaction::new(),
464                        scope: Default::default(),
465                        rename_map: HashMap::new(),
466                    };
467                    let m = m(&mut txn);
468
469                    let mut builder = SchemaBuilder {
470                        drop: vec![],
471                        foreign_key: HashMap::new(),
472                        inner: txn,
473                    };
474                    m.tables(&mut builder);
475
476                    let transaction = TXN.take().unwrap();
477
478                    for drop in builder.drop {
479                        let sql = drop.to_string(SqliteQueryBuilder);
480                        transaction.get().execute(&sql, []).unwrap();
481                    }
482                    for (to, tmp) in builder.inner.rename_map {
483                        let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
484                        let sql = rename.to_string(SqliteQueryBuilder);
485                        transaction.get().execute(&sql, []).unwrap();
486                    }
487                    if let Some(fk) = foreign_key_check(transaction.get()) {
488                        (builder.foreign_key.remove(&*fk).unwrap())();
489                    }
490                    #[allow(
491                        unreachable_code,
492                        reason = "rustc is stupid and thinks this is unreachable"
493                    )]
494                    set_user_version(transaction.get(), M::To::VERSION).unwrap();
495
496                    transaction
497                })
498                .join()
499            });
500            match res {
501                Ok(val) => self.transaction = val,
502                Err(payload) => std::panic::resume_unwind(payload),
503            }
504        }
505
506        Migrator {
507            manager: self.manager,
508            transaction: self.transaction,
509            _p: PhantomData,
510        }
511    }
512
513    /// Commit the migration transaction and return a [Database].
514    ///
515    /// Returns [None] if the database schema version is newer than `S`.
516    ///
517    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
518    pub fn finish(mut self) -> Option<Database<S>> {
519        let conn = &self.transaction;
520        if user_version(conn.get()).unwrap() != S::VERSION {
521            return None;
522        }
523
524        let res = std::thread::scope(|s| {
525            s.spawn(|| {
526                TXN.set(Some(self.transaction));
527                check_schema::<S>();
528                TXN.take().unwrap()
529            })
530            .join()
531        });
532        match res {
533            Ok(val) => self.transaction = val,
534            Err(payload) => std::panic::resume_unwind(payload),
535        }
536
537        // adds an sqlite_stat1 table
538        self.transaction
539            .get()
540            .execute_batch("PRAGMA optimize;")
541            .unwrap();
542
543        let schema_version = schema_version(self.transaction.get());
544        self.transaction.with(|x| x.commit().unwrap());
545
546        Some(Database {
547            manager: self.manager,
548            schema_version: AtomicI64::new(schema_version),
549            schema: PhantomData,
550            mut_lock: parking_lot::FairMutex::new(()),
551        })
552    }
553}
554
555pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
556    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
557        .unwrap()
558}
559
560// Read user version field from the SQLite db
561pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
562    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
563}
564
565// Set user version field from the SQLite db
566fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
567    conn.pragma_update(None, "user_version", v)
568}
569
570pub(crate) fn check_schema<S: Schema>() {
571    let mut b = TableTypBuilder::default();
572    S::typs(&mut b);
573    pretty_assertions::assert_eq!(
574        b.ast,
575        read_schema(&crate::Transaction::new()),
576        "schema is different (expected left, but got right)",
577    );
578}
579
580fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
581    let error = conn
582        .prepare("PRAGMA foreign_key_check")
583        .unwrap()
584        .query_map([], |row| row.get(2))
585        .unwrap()
586        .next();
587    error.transpose().unwrap()
588}
589
/// Two independent in-memory databases can be opened at the same time and both
/// produce a [Migrator].
#[test]
fn open_multiple() {
    #[crate::migration::schema(Empty)]
    pub mod vN {}

    let a = Database::<v0::Empty>::migrator(Config::open_in_memory());
    let b = Database::<v0::Empty>::migrator(Config::open_in_memory());
    // A fresh database gets `user_version = S::VERSION`, so `migrator` must not return None.
    assert!(a.is_some(), "first in-memory database should produce a migrator");
    assert!(b.is_some(), "second in-memory database should produce a migrator");
}