rust_query/
migrate.rs

1use std::{marker::PhantomData, path::Path, sync::atomic::AtomicBool};
2
3use rusqlite::{config::DbConfig, Connection};
4use sea_query::{
5    Alias, ColumnDef, InsertStatement, IntoTableRef, SqliteQueryBuilder, TableDropStatement,
6    TableRenameStatement,
7};
8use sea_query_rusqlite::RusqliteBinder;
9
10use crate::{
11    alias::{Scope, TmpTable},
12    ast::MySelect,
13    dummy::{Cached, Cacher},
14    hash,
15    insert::Reader,
16    pragma::read_schema,
17    token::LocalClient,
18    transaction::Database,
19    value, Column, IntoColumn, Rows, Table,
20};
21
22pub type M<'a, From, To> = Box<
23    dyn 'a
24        + for<'t> FnOnce(
25            ::rust_query::Column<'t, <From as Table>::Schema, From>,
26        ) -> Alter<'t, 'a, From, To>,
27>;
28
29/// This is the type used to return table alterations in migrations.
30///
31/// Note that migrations allow you to use anything that implements [crate::Dummy] to specify the new values.
32/// In particular this allows mapping values using native rust with [crate::Dummy::map_dummy].
33///
34/// Take a look at the documentation of [crate::migration::schema] for more general information.
35///
36/// The purpose of wrapping migration results in [Alter] (and [Create]) is to dyn box the type so that type inference works.
37/// (Type inference is problematic with higher ranked generic returns from closures).
38/// Futhermore [Alter] (and [Create]) also have an implied bound of `'a: 't` which makes it easier to implement migrations.
39pub struct Alter<'t, 'a, From, To> {
40    _p: PhantomData<&'t &'a ()>,
41    inner: Box<dyn TableMigration<'t, 'a, From = From, To = To> + 't>,
42}
43
44impl<'t, 'a, From, To> Alter<'t, 'a, From, To> {
45    pub fn new(val: impl TableMigration<'t, 'a, From = From, To = To> + 't) -> Self {
46        Self {
47            _p: PhantomData,
48            inner: Box::new(val),
49        }
50    }
51}
52
53pub type C<'a, FromSchema, To> =
54    Box<dyn 'a + for<'t> FnOnce(&mut Rows<'t, FromSchema>) -> Create<'t, 'a, FromSchema, To>>;
55
56/// This is the type used to return table creations in migrations.
57///
58/// For more information take a look at [Alter].
59pub struct Create<'t, 'a, FromSchema, To> {
60    _p: PhantomData<&'t &'a ()>,
61    inner: Box<dyn TableCreation<'t, 'a, FromSchema = FromSchema, To = To> + 't>,
62}
63
64impl<'t, 'a, FromSchema, To: 'a> Create<'t, 'a, FromSchema, To> {
65    pub fn new(val: impl TableCreation<'t, 'a, FromSchema = FromSchema, To = To> + 't) -> Self {
66        Self {
67            _p: PhantomData,
68            inner: Box::new(val),
69        }
70    }
71
72    /// Use this if you want the new table to be empty.
73    pub fn empty(rows: &mut Rows<'t, FromSchema>) -> Self {
74        rows.filter(false);
75        Create::new(NeverCreate(PhantomData, PhantomData))
76    }
77}
78
79struct NeverCreate<FromSchema, To>(PhantomData<FromSchema>, PhantomData<To>);
80
81impl<'t, 'a, FromSchema, To> TableCreation<'t, 'a> for NeverCreate<FromSchema, To> {
82    type FromSchema = FromSchema;
83    type To = To;
84
85    fn prepare(
86        self: Box<Self>,
87        _: Cacher<'_, 't, Self::FromSchema>,
88    ) -> Box<dyn FnMut(crate::private::Row<'_, 't, 'a>, Reader<'_, 't, Self::FromSchema>) + 't>
89    where
90        'a: 't,
91    {
92        Box::new(|_, _| unreachable!())
93    }
94}
95
96pub struct TableTypBuilder<S> {
97    pub(crate) ast: hash::Schema,
98    _p: PhantomData<S>,
99}
100
101impl<S> Default for TableTypBuilder<S> {
102    fn default() -> Self {
103        Self {
104            ast: Default::default(),
105            _p: Default::default(),
106        }
107    }
108}
109
110impl<S> TableTypBuilder<S> {
111    pub fn table<T: Table<Schema = S>>(&mut self) {
112        let mut b = hash::TypBuilder::default();
113        T::typs(&mut b);
114        self.ast.tables.insert((T::NAME.to_owned(), b.ast));
115    }
116}
117
118pub trait Schema: Sized + 'static {
119    const VERSION: i64;
120    fn typs(b: &mut TableTypBuilder<Self>);
121}
122
123pub trait TableMigration<'t, 'a> {
124    type From: Table;
125    type To;
126
127    fn prepare(
128        self: Box<Self>,
129        prev: Cached<'t, Self::From>,
130        cacher: Cacher<'_, 't, <Self::From as Table>::Schema>,
131    ) -> Box<
132        dyn FnMut(crate::private::Row<'_, 't, 'a>, Reader<'_, 't, <Self::From as Table>::Schema>)
133            + 't,
134    >
135    where
136        'a: 't;
137}
138
139pub trait TableCreation<'t, 'a> {
140    type FromSchema;
141    type To;
142
143    fn prepare(
144        self: Box<Self>,
145        cacher: Cacher<'_, 't, Self::FromSchema>,
146    ) -> Box<dyn FnMut(crate::private::Row<'_, 't, 'a>, Reader<'_, 't, Self::FromSchema>) + 't>
147    where
148        'a: 't;
149}
150
151struct Wrapper<'t, 'a, From: Table, To>(
152    Box<dyn TableMigration<'t, 'a, From = From, To = To> + 't>,
153    Column<'t, From::Schema, From>,
154);
155
156impl<'t, 'a, From: Table, To> TableCreation<'t, 'a> for Wrapper<'t, 'a, From, To> {
157    type FromSchema = From::Schema;
158    type To = To;
159
160    fn prepare(
161        self: Box<Self>,
162        mut cacher: Cacher<'_, 't, Self::FromSchema>,
163    ) -> Box<dyn FnMut(crate::private::Row<'_, 't, 'a>, Reader<'_, 't, Self::FromSchema>) + 't>
164    where
165        'a: 't,
166    {
167        let db_id = cacher.cache(self.1);
168        let mut prepared = Box::new(self.0).prepare(db_id, cacher);
169        Box::new(move |row, reader| {
170            // keep the ID the same
171            reader.col(From::ID, row.get(db_id));
172            prepared(row, reader);
173        })
174    }
175}
176
177impl<'inner, S> Rows<'inner, S> {
178    fn cacher<'t>(&'_ self) -> Cacher<'_, 't, S> {
179        Cacher {
180            ast: &self.ast,
181            _p: PhantomData,
182        }
183    }
184}
185
186pub struct SchemaBuilder<'x, 'a> {
187    // this is used to create temporary table names
188    scope: Scope,
189    conn: &'x rusqlite::Transaction<'x>,
190    drop: Vec<TableDropStatement>,
191    rename: Vec<TableRenameStatement>,
192    _p: PhantomData<fn(&'a ()) -> &'a ()>,
193}
194
195impl<'a> SchemaBuilder<'_, 'a> {
196    pub fn migrate_table<From: Table, To: Table>(&mut self, m: M<'a, From, To>) {
197        self.create_inner::<From::Schema, To>(|rows| {
198            let db_id = From::join(rows);
199            let migration = m(db_id.clone());
200            Create::new(Wrapper(migration.inner, db_id))
201        });
202
203        self.drop.push(
204            sea_query::Table::drop()
205                .table(Alias::new(From::NAME))
206                .take(),
207        );
208    }
209
210    pub fn create_from<FromSchema, To: Table>(&mut self, f: C<'a, FromSchema, To>) {
211        self.create_inner::<FromSchema, To>(f);
212    }
213
214    fn create_inner<FromSchema, To: Table>(
215        &mut self,
216        f: impl for<'t> FnOnce(&mut Rows<'t, FromSchema>) -> Create<'t, 'a, FromSchema, To>,
217    ) {
218        let new_table_name = self.scope.tmp_table();
219        new_table::<To>(self.conn, new_table_name);
220
221        self.rename.push(
222            sea_query::Table::rename()
223                .table(new_table_name, Alias::new(To::NAME))
224                .take(),
225        );
226
227        let mut q = Rows::<FromSchema> {
228            phantom: PhantomData,
229            ast: MySelect::default(),
230        };
231        let create = f(&mut q);
232        let mut prepared = create.inner.prepare(q.cacher());
233
234        let select = q.ast.simple();
235        let (sql, values) = select.build_rusqlite(SqliteQueryBuilder);
236
237        // no caching here, migration is only executed once
238        let mut statement = self.conn.prepare(&sql).unwrap();
239        let mut rows = statement.query(&*values.as_params()).unwrap();
240
241        while let Some(row) = rows.next().unwrap() {
242            let row = crate::private::Row {
243                _p: PhantomData,
244                _p2: PhantomData,
245                row,
246            };
247
248            let new_ast = MySelect::default();
249            let reader = Reader {
250                ast: &new_ast,
251                _p: PhantomData,
252                _p2: PhantomData,
253            };
254            prepared(row, reader);
255
256            let mut insert = InsertStatement::new();
257            let names = new_ast.select.iter().map(|(_field, name)| *name);
258            insert.into_table(new_table_name);
259            insert.columns(names);
260            insert.select_from(new_ast.simple()).unwrap();
261
262            let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
263            let mut statement = self.conn.prepare_cached(&sql).unwrap();
264            statement.execute(&*values.as_params()).unwrap();
265        }
266    }
267
268    pub fn drop_table<T: Table>(&mut self) {
269        let name = Alias::new(T::NAME);
270        let step = sea_query::Table::drop().table(name).take();
271        self.drop.push(step);
272    }
273}
274
275fn new_table<T: Table>(conn: &Connection, alias: TmpTable) {
276    let mut f = crate::hash::TypBuilder::default();
277    T::typs(&mut f);
278    new_table_inner(conn, &f.ast, alias);
279}
280
281fn new_table_inner(conn: &Connection, table: &crate::hash::Table, alias: impl IntoTableRef) {
282    let mut create = table.create();
283    create
284        .table(alias)
285        .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
286    let mut sql = create.to_string(SqliteQueryBuilder);
287    sql.push_str(" STRICT");
288    conn.execute(&sql, []).unwrap();
289}
290
291pub trait Migration<'a> {
292    type From: Schema;
293    type To: Schema;
294
295    fn tables(self, b: &mut SchemaBuilder<'_, 'a>);
296}
297
298/// [Config] is used to open a database from a file or in memory.
299///
300/// This is the first step in the [Config] -> [Migrator] -> [Database] chain to
301/// get a [Database] instance.
302pub struct Config {
303    manager: r2d2_sqlite::SqliteConnectionManager,
304    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
305}
306
307static ALLOWED: AtomicBool = AtomicBool::new(true);
308
309impl Config {
310    /// Open a database that is stored in a file.
311    /// Creates the database if it does not exist.
312    ///
313    /// Opening the same database multiple times at the same time is fine,
314    /// as long as they migrate to or use the same schema.
315    /// All locking is done by sqlite, so connections can even be made using different client implementations.
316    pub fn open(p: impl AsRef<Path>) -> Self {
317        let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
318        Self::open_internal(manager)
319    }
320
321    /// Creates a new empty database in memory.
322    pub fn open_in_memory() -> Self {
323        let manager = r2d2_sqlite::SqliteConnectionManager::memory();
324        Self::open_internal(manager)
325    }
326
327    fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
328        assert!(ALLOWED.swap(false, std::sync::atomic::Ordering::Relaxed));
329        let manager = manager.with_init(|inner| {
330            inner.pragma_update(None, "journal_mode", "WAL")?;
331            inner.pragma_update(None, "synchronous", "NORMAL")?;
332            inner.pragma_update(None, "foreign_keys", "ON")?;
333            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
334            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
335            inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
336            Ok(())
337        });
338
339        Self {
340            manager,
341            init: Box::new(|_| {}),
342        }
343    }
344
345    /// Execute a raw sql statement if the database was just created.
346    ///
347    /// The statement is executed after creating the empty database and executingall previous statements.
348    pub fn init_stmt(mut self, sql: &'static str) -> Self {
349        self.init = Box::new(move |txn| {
350            (self.init)(txn);
351
352            txn.execute_batch(sql)
353                .expect("raw sql statement to populate db failed");
354        });
355        self
356    }
357}
358
359impl LocalClient {
360    /// Create a [Migrator] to migrate a database.
361    ///
362    /// Returns [None] if the database `user_version` on disk is older than `S`.
363    ///
364    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
365    pub fn migrator<'t, S: Schema>(&'t mut self, config: Config) -> Option<Migrator<'t, S>> {
366        use r2d2::ManageConnection;
367        let conn = self.conn.insert(config.manager.connect().unwrap());
368        conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
369
370        let conn = conn
371            .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
372            .unwrap();
373
374        // check if this database is newly created
375        if schema_version(&conn) == 0 {
376            let mut b = TableTypBuilder::default();
377            S::typs(&mut b);
378
379            for (table_name, table) in &*b.ast.tables {
380                new_table_inner(&conn, table, Alias::new(table_name));
381            }
382            (config.init)(&conn);
383            set_user_version(&conn, S::VERSION).unwrap();
384        }
385
386        let user_version = user_version(&conn).unwrap();
387        // We can not migrate databases older than `S`
388        if user_version < S::VERSION {
389            return None;
390        } else if user_version == S::VERSION {
391            foreign_key_check::<S>(&conn);
392        }
393
394        Some(Migrator {
395            manager: config.manager,
396            transaction: conn,
397            _p: PhantomData,
398            _local: PhantomData,
399        })
400    }
401}
402
403/// [Migrator] is used to apply database migrations.
404///
405/// When all migrations are done, it can be turned into a [Database] instance with
406/// [Migrator::finish].
407pub struct Migrator<'t, S> {
408    manager: r2d2_sqlite::SqliteConnectionManager,
409    transaction: rusqlite::Transaction<'t>,
410    _p: PhantomData<S>,
411    // We want to make sure that Migrator is always used with the same LocalClient
412    // so we make it local to the current thread.
413    // This is mostly important because the LocalClient can have a reference to our transaction.
414    _local: PhantomData<LocalClient>,
415}
416
417impl<'t, S: Schema> Migrator<'t, S> {
418    /// Apply a database migration if the current schema is `S` and return a [Migrator] for the next schema `N`.
419    ///
420    /// This function will panic if the schema on disk does not match what is expected for its `user_version`.
421    pub fn migrate<M, N: Schema>(self, m: M) -> Migrator<'t, N>
422    where
423        M: Migration<'t, From = S, To = N>,
424    {
425        let conn = &self.transaction;
426
427        if user_version(conn).unwrap() == S::VERSION {
428            let mut builder = SchemaBuilder {
429                scope: Default::default(),
430                conn,
431                drop: vec![],
432                rename: vec![],
433                _p: PhantomData,
434            };
435            m.tables(&mut builder);
436            for drop in builder.drop {
437                let sql = drop.to_string(SqliteQueryBuilder);
438                conn.execute(&sql, []).unwrap();
439            }
440            for rename in builder.rename {
441                let sql = rename.to_string(SqliteQueryBuilder);
442                conn.execute(&sql, []).unwrap();
443            }
444            foreign_key_check::<N>(conn);
445            set_user_version(conn, N::VERSION).unwrap();
446        }
447
448        Migrator {
449            manager: self.manager,
450            transaction: self.transaction,
451            _p: PhantomData,
452            _local: PhantomData,
453        }
454    }
455
456    /// Commit the migration transaction and return a [Database].
457    ///
458    /// Returns [None] if the database schema version is newer than `S`.
459    pub fn finish(self) -> Option<Database<S>> {
460        let conn = &self.transaction;
461        if user_version(conn).unwrap() != S::VERSION {
462            return None;
463        }
464
465        let schema_version = schema_version(conn);
466        self.transaction.commit().unwrap();
467
468        Some(Database {
469            manager: self.manager,
470            schema_version,
471            schema: PhantomData,
472        })
473    }
474}
475
476pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
477    conn.pragma_query_value(None, "schema_version", |r| r.get(0))
478        .unwrap()
479}
480
481// Read user version field from the SQLite db
482fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
483    conn.query_row("PRAGMA user_version", [], |row| row.get(0))
484}
485
486// Set user version field from the SQLite db
487fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
488    conn.pragma_update(None, "user_version", v)
489}
490
491fn foreign_key_check<S: Schema>(conn: &rusqlite::Transaction) {
492    let errors = conn
493        .prepare("PRAGMA foreign_key_check")
494        .unwrap()
495        .query_map([], |_| Ok(()))
496        .unwrap()
497        .count();
498    if errors != 0 {
499        panic!("migration violated foreign key constraint")
500    }
501
502    let mut b = TableTypBuilder::default();
503    S::typs(&mut b);
504    pretty_assertions::assert_eq!(
505        b.ast,
506        read_schema(conn),
507        "schema is different (expected left, but got right)",
508    );
509}
510
511/// Special table name that is used as souce of newly created tables.
512#[derive(Clone, Copy)]
513pub struct NoTable(());
514
515impl value::Typed for NoTable {
516    type Typ = NoTable;
517    fn build_expr(&self, _b: value::ValueBuilder) -> sea_query::SimpleExpr {
518        unreachable!("NoTable can not be constructed")
519    }
520}
521impl<S> IntoColumn<'_, S> for NoTable {
522    type Owned = Self;
523
524    fn into_owned(self) -> Self::Owned {
525        self
526    }
527}