rust_query/
migrate.rs

1pub mod config;
2pub mod migration;
3#[cfg(test)]
4mod test;
5
6use std::{
7    collections::{BTreeSet, HashMap},
8    marker::PhantomData,
9    sync::atomic::AtomicI64,
10};
11
12use annotate_snippets::{Renderer, renderer::DecorStyle};
13use rusqlite::config::DbConfig;
14use sea_query::{Alias, ColumnDef, IntoIden, SqliteQueryBuilder};
15use self_cell::MutBorrow;
16
17use crate::{
18    Table, Transaction,
19    alias::Scope,
20    migrate::{
21        config::Config,
22        migration::{SchemaBuilder, TransactionMigrate},
23    },
24    pool::Pool,
25    schema::{from_db, from_macro, read::read_schema},
26    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
29pub struct TableTypBuilder<S> {
30    pub(crate) ast: from_macro::Schema,
31    _p: PhantomData<S>,
32}
33
34impl<S> Default for TableTypBuilder<S> {
35    fn default() -> Self {
36        Self {
37            ast: Default::default(),
38            _p: Default::default(),
39        }
40    }
41}
42
43impl<S> TableTypBuilder<S> {
44    pub fn table<T: Table<Schema = S>>(&mut self) {
45        let table = from_macro::Table::new::<T>();
46        let old = self.ast.tables.insert(T::NAME, table);
47        debug_assert!(old.is_none());
48    }
49}
50
51pub trait Schema: Sized + 'static {
52    const VERSION: i64;
53    const SOURCE: &str;
54    const PATH: &str;
55    const SPAN: (usize, usize);
56    fn typs(b: &mut TableTypBuilder<Self>);
57}
58
59fn new_table_inner(table: &crate::schema::from_macro::Table, alias: impl IntoIden) -> String {
60    let alias = alias.into_iden();
61    let mut create = table.create();
62    create
63        .table(alias.clone())
64        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
65    let mut sql = create.to_string(SqliteQueryBuilder);
66    sql.push_str(" STRICT");
67    sql
68}
69
70pub trait SchemaMigration<'a> {
71    type From: Schema;
72    type To: Schema;
73
74    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
75}
76
77impl<S: Schema> Database<S> {
78    /// Create a [Migrator] to migrate a database.
79    ///
80    /// Returns [None] if the database `user_version` on disk is older than `S`.
81    pub fn migrator(config: Config) -> Option<Migrator<S>> {
82        let synchronous = config.synchronous.as_str();
83        let foreign_keys = config.foreign_keys.as_str();
84        let manager = config.manager.with_init(move |inner| {
85            inner.pragma_update(None, "journal_mode", "WAL")?;
86            inner.pragma_update(None, "synchronous", synchronous)?;
87            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
88            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
89            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
90            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
91            rusqlite::vtab::array::load_module(inner).unwrap();
92            Ok(())
93        });
94
95        use r2d2::ManageConnection;
96        let conn = manager.connect().unwrap();
97        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
98        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
99            Some(
100                conn.borrow_mut()
101                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
102                    .unwrap(),
103            )
104        });
105
106        let mut user_version = Some(user_version(txn.get()).unwrap());
107
108        // check if this database is newly created
109        if schema_version(txn.get()) == 0 {
110            user_version = None;
111
112            let schema = crate::schema::from_macro::Schema::new::<S>();
113
114            for (&table_name, table) in &schema.tables {
115                txn.get()
116                    .execute(&new_table_inner(table, table_name), [])
117                    .unwrap();
118                for stmt in table.delayed_indices(table_name) {
119                    txn.get().execute(&stmt, []).unwrap();
120                }
121            }
122            (config.init)(txn.get());
123        } else if user_version.unwrap() < S::VERSION {
124            // We can not migrate databases older than `S`
125            return None;
126        }
127
128        debug_assert_eq!(
129            foreign_key_check(txn.get()),
130            None,
131            "foreign key constraint violated"
132        );
133
134        Some(Migrator {
135            user_version,
136            manager,
137            transaction: txn,
138            _p: PhantomData,
139        })
140    }
141}
142
143/// [Migrator] is used to apply database migrations.
144///
145/// When all migrations are done, it can be turned into a [Database] instance with
146/// [Migrator::finish].
147pub struct Migrator<S> {
148    manager: r2d2_sqlite::SqliteConnectionManager,
149    transaction: OwnedTransaction,
150    // Initialized to the user version when the transaction starts.
151    // This is set to None if the schema user_version is updated.
152    // Fixups are only applied if the user_version is None.
153    // Indices are fixed before this is set to None.
154    user_version: Option<i64>,
155    _p: PhantomData<S>,
156}
157
158impl<S: Schema> Migrator<S> {
159    fn with_transaction(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
160        assert!(self.user_version.is_none_or(|x| x == S::VERSION));
161        let res = std::thread::scope(|s| {
162            s.spawn(|| {
163                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
164                let txn = Transaction::new_ref();
165
166                // check if this is the first migration that is applied
167                if self.user_version.take().is_some() {
168                    // we check the schema before doing any migrations
169                    check_schema::<S>(txn);
170                    // fixing indices before migrations can help with migration performance
171                    fix_indices::<S>(txn);
172                }
173
174                f(txn);
175
176                let transaction = TXN.take().unwrap();
177
178                transaction.into_owner()
179            })
180            .join()
181        });
182        match res {
183            Ok(val) => self.transaction = val,
184            Err(payload) => std::panic::resume_unwind(payload),
185        }
186        self
187    }
188
189    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
190    ///
191    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
192    pub fn migrate<'x, M>(
193        mut self,
194        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
195    ) -> Migrator<M::To>
196    where
197        M: SchemaMigration<'x, From = S>,
198    {
199        if self.user_version.is_none_or(|x| x == S::VERSION) {
200            self = self.with_transaction(|txn| {
201                let mut txn = TransactionMigrate {
202                    inner: txn.copy(),
203                    scope: Default::default(),
204                    rename_map: HashMap::new(),
205                    extra_index: Vec::new(),
206                };
207                let m = m(&mut txn);
208
209                let mut builder = SchemaBuilder {
210                    drop: vec![],
211                    foreign_key: HashMap::new(),
212                    inner: txn,
213                };
214                m.tables(&mut builder);
215
216                let transaction = TXN.take().unwrap();
217
218                for drop in builder.drop {
219                    let sql = drop.to_string(SqliteQueryBuilder);
220                    transaction.get().execute(&sql, []).unwrap();
221                }
222                for (to, tmp) in builder.inner.rename_map {
223                    let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
224                    let sql = rename.to_string(SqliteQueryBuilder);
225                    transaction.get().execute(&sql, []).unwrap();
226                }
227                #[allow(
228                    unreachable_code,
229                    reason = "rustc is stupid and thinks this is unreachable"
230                )]
231                if let Some(fk) = foreign_key_check(transaction.get()) {
232                    (builder.foreign_key.remove(&*fk).unwrap())();
233                }
234                // adding non unique indexes is fine to do after checking foreign keys
235                for stmt in builder.inner.extra_index {
236                    transaction.get().execute(&stmt, []).unwrap();
237                }
238
239                TXN.set(Some(transaction));
240            });
241        }
242
243        Migrator {
244            user_version: self.user_version,
245            manager: self.manager,
246            transaction: self.transaction,
247            _p: PhantomData,
248        }
249    }
250
251    /// Mutate the database as part of migrations.
252    ///
253    /// The closure will only be executed if the database got migrated to schema version `S`
254    /// by this [Migrator] instance.
255    /// If [Migrator::fixup] is used before all [Migrator::migrate], then the closures is only executed
256    /// when the database is created.
257    pub fn fixup(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
258        if self.user_version.is_none() {
259            self = self.with_transaction(f);
260        }
261        self
262    }
263
264    /// Commit the migration transaction and return a [Database].
265    ///
266    /// Returns [None] if the database schema version is newer than `S`.
267    ///
268    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
269    pub fn finish(mut self) -> Option<Database<S>> {
270        if self.user_version.is_some_and(|x| x != S::VERSION) {
271            return None;
272        }
273
274        // This checks that the schema is correct and fixes indices etc
275        self = self.with_transaction(|txn| {
276            // sanity check, this should never fail
277            check_schema::<S>(txn);
278        });
279
280        // adds an sqlite_stat1 table
281        self.transaction
282            .get()
283            .execute_batch("PRAGMA optimize;")
284            .unwrap();
285
286        set_user_version(self.transaction.get(), S::VERSION).unwrap();
287        let schema_version = schema_version(self.transaction.get());
288        self.transaction.with(|x| x.commit().unwrap());
289
290        Some(Database {
291            manager: Pool::new(self.manager),
292            schema_version: AtomicI64::new(schema_version),
293            schema: PhantomData,
294            mut_lock: parking_lot::FairMutex::new(()),
295        })
296    }
297}
298
299fn fix_indices<S: Schema>(txn: &Transaction<S>) {
300    let schema = read_schema(txn);
301    let expected_schema = crate::schema::from_macro::Schema::new::<S>();
302
303    fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
304        let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
305        let actual: BTreeSet<_> = actual.indices.values().collect();
306        expected == actual
307    }
308
309    for (&table_name, expected_table) in &expected_schema.tables {
310        let table = &schema.tables[table_name];
311
312        if !check_eq(expected_table, &table) {
313            // Unique constraints that are part of a table definition
314            // can not be dropped, so we assume the worst and just recreate
315            // the whole table.
316
317            let scope = Scope::default();
318            let tmp_name = scope.tmp_table();
319
320            txn.execute(&new_table_inner(expected_table, tmp_name));
321
322            let mut columns: Vec<_> = expected_table
323                .columns
324                .keys()
325                .map(|x| Alias::new(x))
326                .collect();
327            columns.push(Alias::new("id"));
328
329            txn.execute(
330                &sea_query::InsertStatement::new()
331                    .into_table(tmp_name)
332                    .columns(columns.clone())
333                    .select_from(
334                        sea_query::SelectStatement::new()
335                            .from(table_name)
336                            .columns(columns)
337                            .take(),
338                    )
339                    .unwrap()
340                    .build(SqliteQueryBuilder)
341                    .0,
342            );
343
344            txn.execute(
345                &sea_query::TableDropStatement::new()
346                    .table(table_name)
347                    .build(SqliteQueryBuilder),
348            );
349
350            txn.execute(
351                &sea_query::TableRenameStatement::new()
352                    .table(tmp_name, table_name)
353                    .build(SqliteQueryBuilder),
354            );
355            // Add the new non-unique indices
356            for sql in expected_table.delayed_indices(table_name) {
357                txn.execute(&sql);
358            }
359        }
360    }
361
362    // check that we solved the mismatch
363    let schema = read_schema(txn);
364    for (name, table) in schema.tables {
365        let expected_table = &expected_schema.tables[&*name];
366        assert!(check_eq(expected_table, &table));
367    }
368}
369
370impl<S> Transaction<S> {
371    #[track_caller]
372    pub(crate) fn execute(&self, sql: &str) {
373        TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
374            .unwrap();
375    }
376}
377
378pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
379    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
380        .unwrap()
381}
382
383// Read user version field from the SQLite db
384pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
385    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
386}
387
388// Set user version field from the SQLite db
389fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
390    conn.pragma_update(None, "user_version", v)
391}
392
393pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
394    let from_macro = crate::schema::from_macro::Schema::new::<S>();
395    let from_db = read_schema(txn);
396    let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
397    if !report.is_empty() {
398        let renderer = if cfg!(test) {
399            Renderer::plain().anonymized_line_numbers(true)
400        } else {
401            Renderer::styled()
402        }
403        .decor_style(DecorStyle::Unicode);
404        panic!("{}", renderer.render(&report));
405    }
406}
407
408fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
409    let error = conn
410        .prepare("PRAGMA foreign_key_check")
411        .unwrap()
412        .query_map([], |row| row.get(2))
413        .unwrap()
414        .next();
415    error.transpose().unwrap()
416}
417
418impl<S> Transaction<S> {
419    #[cfg(test)]
420    pub(crate) fn schema(&self) -> Vec<String> {
421        TXN.with_borrow(|x| {
422            x.as_ref()
423                .unwrap()
424                .get()
425                .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
426                .unwrap()
427                .query_map([], |row| row.get::<_, Option<String>>("sql"))
428                .unwrap()
429                .flat_map(|x| x.unwrap())
430                .collect()
431        })
432    }
433}