rust_query/
migrate.rs

1pub mod config;
2pub mod migration;
3
4use std::{
5    collections::{BTreeSet, HashMap},
6    marker::PhantomData,
7    sync::atomic::AtomicI64,
8};
9
10use annotate_snippets::{Renderer, renderer::DecorStyle};
11use rusqlite::{Connection, config::DbConfig};
12use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder};
13use self_cell::MutBorrow;
14
15use crate::{
16    Table, Transaction,
17    migrate::{
18        config::Config,
19        migration::{SchemaBuilder, TransactionMigrate},
20    },
21    pool::Pool,
22    schema::{
23        from_db, from_macro,
24        read::{read_index_names_for_table, read_schema},
25    },
26    transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
29pub struct TableTypBuilder<S> {
30    pub(crate) ast: from_macro::Schema,
31    _p: PhantomData<S>,
32}
33
34impl<S> Default for TableTypBuilder<S> {
35    fn default() -> Self {
36        Self {
37            ast: Default::default(),
38            _p: Default::default(),
39        }
40    }
41}
42
43impl<S> TableTypBuilder<S> {
44    pub fn table<T: Table<Schema = S>>(&mut self) {
45        let table = from_macro::Table::new::<T>();
46        let old = self.ast.tables.insert(T::NAME.to_owned(), table);
47        debug_assert!(old.is_none());
48    }
49}
50
51pub trait Schema: Sized + 'static {
52    const VERSION: i64;
53    const SOURCE: &str;
54    const PATH: &str;
55    const SPAN: (usize, usize);
56    fn typs(b: &mut TableTypBuilder<Self>);
57}
58
59fn new_table_inner(
60    conn: &Connection,
61    table: &crate::schema::from_macro::Table,
62    alias: impl IntoTableRef,
63) {
64    let mut create = table.create();
65    create
66        .table(alias)
67        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
68    let mut sql = create.to_string(SqliteQueryBuilder);
69    sql.push_str(" STRICT");
70    conn.execute(&sql, []).unwrap();
71}
72
73pub trait SchemaMigration<'a> {
74    type From: Schema;
75    type To: Schema;
76
77    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
78}
79
80impl<S: Schema> Database<S> {
81    /// Create a [Migrator] to migrate a database.
82    ///
83    /// Returns [None] if the database `user_version` on disk is older than `S`.
84    pub fn migrator(config: Config) -> Option<Migrator<S>> {
85        let synchronous = config.synchronous.as_str();
86        let foreign_keys = config.foreign_keys.as_str();
87        let manager = config.manager.with_init(move |inner| {
88            inner.pragma_update(None, "journal_mode", "WAL")?;
89            inner.pragma_update(None, "synchronous", synchronous)?;
90            inner.pragma_update(None, "foreign_keys", foreign_keys)?;
91            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
92            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
93            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
94            Ok(())
95        });
96
97        use r2d2::ManageConnection;
98        let conn = manager.connect().unwrap();
99        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
100        let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
101            Some(
102                conn.borrow_mut()
103                    .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
104                    .unwrap(),
105            )
106        });
107
108        // check if this database is newly created
109        if schema_version(txn.get()) == 0 {
110            let schema = crate::schema::from_macro::Schema::new::<S>();
111
112            for (table_name, table) in &schema.tables {
113                let table_name_ref = Alias::new(table_name);
114                new_table_inner(txn.get(), table, table_name_ref);
115                for stmt in table.create_indices(table_name) {
116                    txn.get().execute(&stmt, []).unwrap();
117                }
118            }
119            (config.init)(txn.get());
120            set_user_version(txn.get(), S::VERSION).unwrap();
121        }
122
123        let user_version = user_version(txn.get()).unwrap();
124        // We can not migrate databases older than `S`
125        if user_version < S::VERSION {
126            return None;
127        }
128        debug_assert_eq!(
129            foreign_key_check(txn.get()),
130            None,
131            "foreign key constraint violated"
132        );
133
134        Some(Migrator {
135            indices_fixed: false,
136            manager,
137            transaction: txn,
138            _p: PhantomData,
139        })
140    }
141}
142
143/// [Migrator] is used to apply database migrations.
144///
145/// When all migrations are done, it can be turned into a [Database] instance with
146/// [Migrator::finish].
147pub struct Migrator<S> {
148    manager: r2d2_sqlite::SqliteConnectionManager,
149    transaction: OwnedTransaction,
150    indices_fixed: bool,
151    _p: PhantomData<S>,
152}
153
154impl<S: Schema> Migrator<S> {
155    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
156    ///
157    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
158    pub fn migrate<'x, M>(
159        mut self,
160        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
161    ) -> Migrator<M::To>
162    where
163        M: SchemaMigration<'x, From = S>,
164    {
165        if user_version(self.transaction.get()).unwrap() == S::VERSION {
166            let res = std::thread::scope(|s| {
167                s.spawn(|| {
168                    TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
169                    let txn = Transaction::new_ref();
170
171                    check_schema::<S>(txn);
172                    if !self.indices_fixed {
173                        fix_indices::<S>(txn);
174                        self.indices_fixed = true;
175                    }
176
177                    let mut txn = TransactionMigrate {
178                        inner: Transaction::new(),
179                        scope: Default::default(),
180                        rename_map: HashMap::new(),
181                        extra_index: Vec::new(),
182                    };
183                    let m = m(&mut txn);
184
185                    let mut builder = SchemaBuilder {
186                        drop: vec![],
187                        foreign_key: HashMap::new(),
188                        inner: txn,
189                    };
190                    m.tables(&mut builder);
191
192                    let transaction = TXN.take().unwrap();
193
194                    for drop in builder.drop {
195                        let sql = drop.to_string(SqliteQueryBuilder);
196                        transaction.get().execute(&sql, []).unwrap();
197                    }
198                    for (to, tmp) in builder.inner.rename_map {
199                        let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
200                        let sql = rename.to_string(SqliteQueryBuilder);
201                        transaction.get().execute(&sql, []).unwrap();
202                    }
203                    if let Some(fk) = foreign_key_check(transaction.get()) {
204                        (builder.foreign_key.remove(&*fk).unwrap())();
205                    }
206                    #[allow(
207                        unreachable_code,
208                        reason = "rustc is stupid and thinks this is unreachable"
209                    )]
210                    // adding indexes is fine to do after checking foreign keys
211                    for stmt in builder.inner.extra_index {
212                        transaction.get().execute(&stmt, []).unwrap();
213                    }
214                    set_user_version(transaction.get(), M::To::VERSION).unwrap();
215
216                    transaction.into_owner()
217                })
218                .join()
219            });
220            match res {
221                Ok(val) => self.transaction = val,
222                Err(payload) => std::panic::resume_unwind(payload),
223            }
224        }
225
226        Migrator {
227            indices_fixed: self.indices_fixed,
228            manager: self.manager,
229            transaction: self.transaction,
230            _p: PhantomData,
231        }
232    }
233
234    /// Commit the migration transaction and return a [Database].
235    ///
236    /// Returns [None] if the database schema version is newer than `S`.
237    ///
238    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
239    pub fn finish(mut self) -> Option<Database<S>> {
240        if user_version(self.transaction.get()).unwrap() != S::VERSION {
241            return None;
242        }
243
244        let res = std::thread::scope(|s| {
245            s.spawn(|| {
246                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
247                let txn = Transaction::new_ref();
248
249                check_schema::<S>(txn);
250                if !self.indices_fixed {
251                    fix_indices::<S>(txn);
252                    self.indices_fixed = true;
253                }
254
255                TXN.take().unwrap().into_owner()
256            })
257            .join()
258        });
259        match res {
260            Ok(val) => self.transaction = val,
261            Err(payload) => std::panic::resume_unwind(payload),
262        }
263
264        // adds an sqlite_stat1 table
265        self.transaction
266            .get()
267            .execute_batch("PRAGMA optimize;")
268            .unwrap();
269
270        let schema_version = schema_version(self.transaction.get());
271        self.transaction.with(|x| x.commit().unwrap());
272
273        Some(Database {
274            manager: Pool::new(self.manager),
275            schema_version: AtomicI64::new(schema_version),
276            schema: PhantomData,
277            mut_lock: parking_lot::FairMutex::new(()),
278        })
279    }
280}
281
282fn fix_indices<S: Schema>(txn: &Transaction<S>) {
283    let schema = read_schema(txn);
284    let expected_schema = crate::schema::from_macro::Schema::new::<S>();
285
286    fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
287        let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
288        let actual: BTreeSet<_> = actual.indices.values().collect();
289        expected == actual
290    }
291
292    for (name, table) in schema.tables {
293        let expected_table = &expected_schema.tables[&name];
294
295        if !check_eq(expected_table, &table) {
296            // Delete all indices associated with the table
297            for index_name in read_index_names_for_table(&crate::Transaction::new(), &name) {
298                let sql = sea_query::Index::drop()
299                    .name(index_name)
300                    .build(SqliteQueryBuilder);
301                txn.execute(&sql);
302            }
303
304            // Add the new indices
305            for sql in expected_table.create_indices(&name) {
306                txn.execute(&sql);
307            }
308        }
309    }
310
311    // check that we solved the mismatch
312    let schema = read_schema(txn);
313    for (name, table) in schema.tables {
314        let expected_table = &expected_schema.tables[&name];
315        assert!(check_eq(expected_table, &table));
316    }
317}
318
319impl<S> Transaction<S> {
320    pub(crate) fn execute(&self, sql: &str) {
321        TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
322            .unwrap();
323    }
324}
325
326pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
327    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
328        .unwrap()
329}
330
331// Read user version field from the SQLite db
332pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
333    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
334}
335
336// Set user version field from the SQLite db
337fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
338    conn.pragma_update(None, "user_version", v)
339}
340
341pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
342    let from_macro = crate::schema::from_macro::Schema::new::<S>();
343    let from_db = read_schema(txn);
344    let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
345    if !report.is_empty() {
346        let renderer = if cfg!(test) {
347            Renderer::plain().anonymized_line_numbers(true)
348        } else {
349            Renderer::styled()
350        }
351        .decor_style(DecorStyle::Unicode);
352        panic!("{}", renderer.render(&report));
353    }
354}
355
356fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
357    let error = conn
358        .prepare("PRAGMA foreign_key_check")
359        .unwrap()
360        .query_map([], |row| row.get(2))
361        .unwrap()
362        .next();
363    error.transpose().unwrap()
364}
365
366impl<S> Transaction<S> {
367    #[cfg(test)]
368    pub(crate) fn schema(&self) -> Vec<String> {
369        TXN.with_borrow(|x| {
370            x.as_ref()
371                .unwrap()
372                .get()
373                .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
374                .unwrap()
375                .query_map([], |row| row.get("sql"))
376                .unwrap()
377                .map(|x| x.unwrap())
378                .collect()
379        })
380    }
381}