1use std::{
2 collections::{HashMap, HashSet},
3 convert::Infallible,
4 marker::PhantomData,
5 ops::Deref,
6 path::Path,
7 sync::atomic::AtomicI64,
8};
9
10use rusqlite::{Connection, config::DbConfig};
11use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder, TableDropStatement};
12use self_cell::MutBorrow;
13
14use crate::{
15 IntoExpr, Lazy, Table, TableRow, Transaction,
16 alias::{Scope, TmpTable},
17 hash,
18 schema_pragma::read_schema,
19 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows, try_insert_private},
20};
21
/// Collects the table definitions that belong to the schema type `S`.
///
/// Tables are added one at a time through `TableTypBuilder::table`, which
/// stores them in `ast` keyed by the table name.
pub struct TableTypBuilder<S> {
    /// Schema description accumulated so far.
    pub(crate) ast: hash::Schema,
    /// Ties the builder to the schema type `S` without storing a value of it.
    _p: PhantomData<S>,
}
26
27impl<S> Default for TableTypBuilder<S> {
28 fn default() -> Self {
29 Self {
30 ast: Default::default(),
31 _p: Default::default(),
32 }
33 }
34}
35
36impl<S> TableTypBuilder<S> {
37 pub fn table<T: Table<Schema = S>>(&mut self) {
38 let table = hash::Table::new::<T>();
39 let old = self.ast.tables.insert(T::NAME.to_owned(), table);
40 debug_assert!(old.is_none());
41 }
42}
43
/// A database schema: a version number plus a way to enumerate its tables.
pub trait Schema: Sized + 'static {
    /// Version of this schema. It is compared against (and written to)
    /// SQLite's `user_version` pragma when opening and migrating a database.
    const VERSION: i64;
    /// Describes the schema by adding its tables to the builder `b`
    /// (implementations are expected to call `TableTypBuilder::table`).
    fn typs(b: &mut TableTypBuilder<Self>);
}
48
/// Describes how the rows of one table are turned into rows of its successor.
///
/// A value implementing this trait is what the user closure returns for each
/// old row in `TransactionMigrate::migrate` / `migrate_optional`.
pub trait Migration {
    /// Schema that the old table belongs to.
    type FromSchema: 'static;
    /// Table that rows are read from.
    type From: Table<Schema = Self::FromSchema>;
    /// Table that rows are written to; it must name `From` as its `MigrateFrom`.
    type To: Table<MigrateFrom = Self::From>;
    /// Error returned to the caller when inserting a migrated row fails.
    type Conflict;

    /// Builds the insert value for the new table from the migration value
    /// `val` and an expression referring to the old row `prev`.
    #[doc(hidden)]
    fn prepare(
        val: Self,
        prev: crate::Expr<'static, Self::FromSchema, Self::From>,
    ) -> <Self::To as Table>::Insert;
    /// Converts the old row whose migrated version could not be inserted into
    /// the conflict value that is reported to the caller.
    #[doc(hidden)]
    fn map_conflict(val: TableRow<Self::From>) -> Self::Conflict;
}
63
/// Transaction handle that is passed to the user while migrating table data.
///
/// It dereferences to a [`Transaction`] on the old schema and additionally
/// keeps track of the temporary tables that hold the migrated rows.
pub struct TransactionMigrate<FromSchema> {
    /// Transaction on the old schema, used to read the existing rows.
    inner: Transaction<FromSchema>,
    /// Source of fresh temporary table names (see `Scope::tmp_table`).
    scope: Scope,
    /// Maps the final table name to the temporary table that currently holds
    /// the new version of that table.
    rename_map: HashMap<&'static str, TmpTable>,
    /// Index-creation statements collected while creating the temporary
    /// tables; they are executed later by `Migrator::migrate`.
    extra_index: Vec<String>,
}
72
73impl<FromSchema> Deref for TransactionMigrate<FromSchema> {
74 type Target = Transaction<FromSchema>;
75
76 fn deref(&self) -> &Self::Target {
77 &self.inner
78 }
79}
80
impl<FromSchema: 'static> TransactionMigrate<FromSchema> {
    /// Returns the temporary table that holds the new version of table `T`,
    /// creating that temporary table the first time it is requested.
    ///
    /// Index statements produced while creating the table are appended to
    /// `extra_index`. Panics if no transaction is installed in `TXN`.
    fn new_table_name<T: Table>(&mut self) -> TmpTable {
        *self.rename_map.entry(T::NAME).or_insert_with(|| {
            let new_table_name = self.scope.tmp_table();
            TXN.with_borrow(|txn| {
                let conn = txn.as_ref().unwrap().get();
                let table = crate::hash::Table::new::<T>();
                // The table is created under its temporary name, while the
                // index statements are generated against the final name `T::NAME`.
                let extra_indices = new_table_inner(conn, &table, new_table_name, T::NAME);
                self.extra_index.extend(extra_indices);
            });
            new_table_name
        })
    }

    /// Returns the rows of the old table `M::From` that do not yet have a row
    /// with the same `idx` in the temporary table `new_name`.
    fn unmigrated<M: Migration<FromSchema = FromSchema>>(
        &self,
        new_name: TmpTable,
    ) -> impl Iterator<Item = TableRow<M::From>> {
        // All rows that currently exist in the old table.
        let data = self.inner.query(|rows| {
            let old = rows.join_private::<M::From>();
            rows.into_vec(old)
        });

        // All rows that were already written to the temporary table. They are
        // read back typed as `M::From` rows so that their `idx` can be compared.
        let migrated = Transaction::new().query(|rows| {
            let new = rows.join_tmp::<M::From>(new_name);
            rows.into_vec(new)
        });
        let migrated: HashSet<_> = migrated.into_iter().map(|x| x.inner.idx).collect();

        data.into_iter()
            .filter(move |row| !migrated.contains(&row.inner.idx))
    }

    /// Migrates the rows of `M::From` for which `f` returns `Some`.
    ///
    /// `f` is called once for every old row that has not been migrated yet.
    /// Rows for which it returns `None` are not copied to the new table.
    ///
    /// # Errors
    ///
    /// Stops at the first row whose insert fails and returns
    /// `M::map_conflict` of that old row.
    pub fn migrate_optional<'t, M: Migration<FromSchema = FromSchema>>(
        &'t mut self,
        mut f: impl FnMut(Lazy<'t, M::From>) -> Option<M>,
    ) -> Result<(), M::Conflict> {
        let new_name = self.new_table_name::<M::To>();

        for row in self.unmigrated::<M>(new_name) {
            if let Some(new) = f(self.lazy(row)) {
                try_insert_private::<M::To>(
                    new_name.into_table_ref(),
                    // Pass the old row's `idx` along with the insert; `unmigrated`
                    // matches old and new rows on this value.
                    Some(row.inner.idx),
                    M::prepare(new, row.into_expr()),
                )
                .map_err(|_| M::map_conflict(row))?;
            };
        }
        Ok(())
    }

    /// Migrates every not-yet-migrated row of `M::From` using `f`.
    ///
    /// On success the returned [`Migrated`] value carries a callback that does
    /// nothing when applied to the schema builder.
    ///
    /// # Errors
    ///
    /// Same as [`TransactionMigrate::migrate_optional`].
    pub fn migrate<'t, M: Migration<FromSchema = FromSchema>>(
        &'t mut self,
        mut f: impl FnMut(Lazy<'t, M::From>) -> M,
    ) -> Result<Migrated<'static, FromSchema, M::To>, M::Conflict> {
        self.migrate_optional::<M>(|x| Some(f(x)))?;

        Ok(Migrated {
            _p: PhantomData,
            f: Box::new(|_| {}),
            _local: PhantomData,
        })
    }

    /// Like [`TransactionMigrate::migrate`], for migrations whose conflict
    /// type is [`Infallible`] and therefore cannot fail.
    pub fn migrate_ok<'t, M: Migration<FromSchema = FromSchema, Conflict = Infallible>>(
        &'t mut self,
        f: impl FnMut(Lazy<'t, M::From>) -> M,
    ) -> Migrated<'static, FromSchema, M::To> {
        // The error type is uninhabited, so this pattern is irrefutable.
        let Ok(res) = self.migrate(f);
        res
    }
}
170
/// Collects the structural changes (new, dropped and checked tables) of one
/// schema migration step.
pub struct SchemaBuilder<'t, FromSchema> {
    /// Migration transaction that owns the temporary tables created so far.
    inner: TransactionMigrate<FromSchema>,
    /// `DROP TABLE` statements queued by `drop_table`.
    drop: Vec<TableDropStatement>,
    /// Handlers keyed by table name; `Migrator::migrate` invokes the handler
    /// whose key is the table name reported by the foreign key check.
    foreign_key: HashMap<&'static str, Box<dyn 't + FnOnce() -> Infallible>>,
}
176
177impl<'t, FromSchema: 'static> SchemaBuilder<'t, FromSchema> {
178 pub fn foreign_key<To: Table>(&mut self, err: impl 't + FnOnce() -> Infallible) {
179 self.inner.new_table_name::<To>();
180
181 self.foreign_key.insert(To::NAME, Box::new(err));
182 }
183
184 pub fn create_empty<To: Table>(&mut self) {
185 self.inner.new_table_name::<To>();
186 }
187
188 pub fn drop_table<T: Table>(&mut self) {
189 let name = Alias::new(T::NAME);
190 let step = sea_query::Table::drop().table(name).take();
191 self.drop.push(step);
192 }
193}
194
195fn new_table_inner(
196 conn: &Connection,
197 table: &crate::hash::Table,
198 alias: impl IntoTableRef,
199 index_table: &str,
200) -> Vec<String> {
201 let mut extra_indices = Vec::new();
202 let mut create = table.create(&mut extra_indices);
203 create
204 .table(alias)
205 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
206 let mut sql = create.to_string(SqliteQueryBuilder);
207 sql.push_str(" STRICT");
208 conn.execute(&sql, []).unwrap();
209
210 let index_table_ref = Alias::new(index_table);
211 extra_indices
212 .into_iter()
213 .enumerate()
214 .map(|(index_num, mut index)| {
215 index
216 .table(index_table_ref.clone())
217 .name(format!("{index_table}_index_{index_num}"))
218 .to_string(SqliteQueryBuilder)
219 })
220 .collect()
221}
222
/// One step in a chain of schema migrations, from schema `From` to schema `To`.
pub trait SchemaMigration<'a> {
    /// Schema that the database has before this step.
    type From: Schema;
    /// Schema that the database has after this step.
    type To: Schema;

    /// Applies the structural changes of this step to the builder `b`.
    /// It is called by `Migrator::migrate` after the user closure has run.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
229
/// Settings used to open a database; created with `Config::open` or
/// `Config::open_in_memory` and consumed by `Database::migrator`.
pub struct Config {
    /// Connection manager that knows where the database lives.
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// Callback that `Database::migrator` runs once on a newly created
    /// database, after its tables have been created.
    init: Box<dyn FnOnce(&rusqlite::Transaction)>,
    /// Value for the `synchronous` pragma (defaults to `Synchronous::Full`).
    pub synchronous: Synchronous,
    /// Value for the `foreign_keys` pragma (defaults to `ForeignKeys::SQLite`).
    pub foreign_keys: ForeignKeys,
}
251
/// Value for SQLite's `synchronous` pragma.
///
/// `Database::migrator` applies it to every connection it sets up.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Synchronous {
    /// Sets `PRAGMA synchronous = FULL`.
    Full,

    /// Sets `PRAGMA synchronous = NORMAL`.
    Normal,
}

impl Synchronous {
    /// Returns the keyword that is passed to the `synchronous` pragma.
    fn as_str(self) -> &'static str {
        match self {
            Self::Full => "FULL",
            Self::Normal => "NORMAL",
        }
    }
}
277
/// Value for SQLite's `foreign_keys` pragma.
///
/// `Database::migrator` applies it to every connection it sets up.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForeignKeys {
    /// Sets `PRAGMA foreign_keys = OFF`, so SQLite itself does not enforce
    /// foreign key constraints.
    Rust,

    /// Sets `PRAGMA foreign_keys = ON`, so SQLite enforces foreign key
    /// constraints.
    SQLite,
}

impl ForeignKeys {
    /// Returns the keyword that is passed to the `foreign_keys` pragma.
    fn as_str(self) -> &'static str {
        match self {
            Self::Rust => "OFF",
            Self::SQLite => "ON",
        }
    }
}
312
313impl Config {
314 pub fn open(p: impl AsRef<Path>) -> Self {
321 let manager = r2d2_sqlite::SqliteConnectionManager::file(p);
322 Self::open_internal(manager)
323 }
324
325 pub fn open_in_memory() -> Self {
327 let manager = r2d2_sqlite::SqliteConnectionManager::memory();
328 Self::open_internal(manager)
329 }
330
331 fn open_internal(manager: r2d2_sqlite::SqliteConnectionManager) -> Self {
332 Self {
333 manager,
334 init: Box::new(|_| {}),
335 synchronous: Synchronous::Full,
336 foreign_keys: ForeignKeys::SQLite,
337 }
338 }
339
340 pub fn init_stmt(mut self, sql: &'static str) -> Self {
344 self.init = Box::new(move |txn| {
345 (self.init)(txn);
346
347 txn.execute_batch(sql)
348 .expect("raw sql statement to populate db failed");
349 });
350 self
351 }
352}
353
354impl<S: Schema> Database<S> {
355 pub fn migrator(config: Config) -> Option<Migrator<S>> {
359 let synchronous = config.synchronous.as_str();
360 let foreign_keys = config.foreign_keys.as_str();
361 let manager = config.manager.with_init(move |inner| {
362 inner.pragma_update(None, "journal_mode", "WAL")?;
363 inner.pragma_update(None, "synchronous", synchronous)?;
364 inner.pragma_update(None, "foreign_keys", foreign_keys)?;
365 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
366 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
367 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
368 Ok(())
369 });
370
371 use r2d2::ManageConnection;
372 let conn = manager.connect().unwrap();
373 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
374 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
375 Some(
376 conn.borrow_mut()
377 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
378 .unwrap(),
379 )
380 });
381
382 if schema_version(txn.get()) == 0 {
384 let schema = crate::hash::Schema::new::<S>();
385
386 for (table_name, table) in &schema.tables {
387 let table_name_ref = Alias::new(table_name);
388 let extra_indices = new_table_inner(txn.get(), table, table_name_ref, table_name);
389 for stmt in extra_indices {
390 txn.get().execute(&stmt, []).unwrap();
391 }
392 }
393 (config.init)(txn.get());
394 set_user_version(txn.get(), S::VERSION).unwrap();
395 }
396
397 let user_version = user_version(txn.get()).unwrap();
398 if user_version < S::VERSION {
400 return None;
401 }
402 debug_assert_eq!(
403 foreign_key_check(txn.get()),
404 None,
405 "foreign key constraint violated"
406 );
407
408 Some(Migrator {
409 manager,
410 transaction: txn,
411 _p: PhantomData,
412 })
413 }
414}
415
/// A database that is in the middle of being migrated; `S` is the schema that
/// the next migration step starts from.
pub struct Migrator<S> {
    /// Connection manager, handed over to the final `Database`.
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// Transaction in which all migration steps run; it is committed by
    /// `Migrator::finish`.
    transaction: OwnedTransaction,
    /// Ties the migrator to the schema type `S` without storing a value of it.
    _p: PhantomData<S>,
}
425
/// Token for the new table `T` of a schema migration step.
///
/// It is produced by `TransactionMigrate::migrate` or
/// `Migrated::map_fk_err` and later applied to the [`SchemaBuilder`].
pub struct Migrated<'t, FromSchema, T> {
    /// Ties the token to the table type `T` without storing a value of it.
    _p: PhantomData<T>,
    /// Action that `apply` runs on the schema builder.
    f: Box<dyn 't + FnOnce(&mut SchemaBuilder<'t, FromSchema>)>,
    /// Raw-pointer marker; it makes this type neither `Send` nor `Sync`.
    _local: PhantomData<*const ()>,
}
434
435impl<'t, FromSchema: 'static, T: Table> Migrated<'t, FromSchema, T> {
436 pub fn map_fk_err(err: impl 't + FnOnce() -> Infallible) -> Self {
440 Self {
441 _p: PhantomData,
442 f: Box::new(|x| x.foreign_key::<T>(err)),
443 _local: PhantomData,
444 }
445 }
446
447 #[doc(hidden)]
448 pub fn apply(self, b: &mut SchemaBuilder<'t, FromSchema>) {
449 (self.f)(b)
450 }
451}
452
453impl<S: Schema> Migrator<S> {
454 pub fn migrate<'x, M>(
458 mut self,
459 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
460 ) -> Migrator<M::To>
461 where
462 M: SchemaMigration<'x, From = S>,
463 {
464 if user_version(self.transaction.get()).unwrap() == S::VERSION {
465 let res = std::thread::scope(|s| {
466 s.spawn(|| {
467 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
468
469 check_schema::<S>();
470
471 let mut txn = TransactionMigrate {
472 inner: Transaction::new(),
473 scope: Default::default(),
474 rename_map: HashMap::new(),
475 extra_index: Vec::new(),
476 };
477 let m = m(&mut txn);
478
479 let mut builder = SchemaBuilder {
480 drop: vec![],
481 foreign_key: HashMap::new(),
482 inner: txn,
483 };
484 m.tables(&mut builder);
485
486 let transaction = TXN.take().unwrap();
487
488 for drop in builder.drop {
489 let sql = drop.to_string(SqliteQueryBuilder);
490 transaction.get().execute(&sql, []).unwrap();
491 }
492 for (to, tmp) in builder.inner.rename_map {
493 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
494 let sql = rename.to_string(SqliteQueryBuilder);
495 transaction.get().execute(&sql, []).unwrap();
496 }
497 if let Some(fk) = foreign_key_check(transaction.get()) {
498 (builder.foreign_key.remove(&*fk).unwrap())();
499 }
500 #[allow(
501 unreachable_code,
502 reason = "rustc is stupid and thinks this is unreachable"
503 )]
504 for stmt in builder.inner.extra_index {
506 transaction.get().execute(&stmt, []).unwrap();
507 }
508 set_user_version(transaction.get(), M::To::VERSION).unwrap();
509
510 transaction.into_owner()
511 })
512 .join()
513 });
514 match res {
515 Ok(val) => self.transaction = val,
516 Err(payload) => std::panic::resume_unwind(payload),
517 }
518 }
519
520 Migrator {
521 manager: self.manager,
522 transaction: self.transaction,
523 _p: PhantomData,
524 }
525 }
526
527 pub fn finish(mut self) -> Option<Database<S>> {
533 let conn = &self.transaction;
534 if user_version(conn.get()).unwrap() != S::VERSION {
535 return None;
536 }
537
538 let res = std::thread::scope(|s| {
539 s.spawn(|| {
540 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
541 check_schema::<S>();
542 TXN.take().unwrap().into_owner()
543 })
544 .join()
545 });
546 match res {
547 Ok(val) => self.transaction = val,
548 Err(payload) => std::panic::resume_unwind(payload),
549 }
550
551 self.transaction
553 .get()
554 .execute_batch("PRAGMA optimize;")
555 .unwrap();
556
557 let schema_version = schema_version(self.transaction.get());
558 self.transaction.with(|x| x.commit().unwrap());
559
560 Some(Database {
561 manager: self.manager,
562 schema_version: AtomicI64::new(schema_version),
563 schema: PhantomData,
564 mut_lock: parking_lot::FairMutex::new(()),
565 })
566 }
567}
568
569pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
570 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
571 .unwrap()
572}
573
574pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
576 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
577}
578
579fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
581 conn.pragma_update(None, "user_version", v)
582}
583
/// Asserts that the schema found in the database equals the schema `S`.
///
/// Both sides are passed through `normalize` before they are compared. The
/// database side is read through `crate::Transaction::new()`.
///
/// # Panics
///
/// Panics (with a diff from `pretty_assertions`) when the schemas differ.
pub(crate) fn check_schema<S: Schema>() {
    pretty_assertions::assert_eq!(
        // Left: what the Rust types declare. Right: what the database contains.
        crate::hash::Schema::new::<S>().normalize(),
        read_schema(&crate::Transaction::new()).normalize(),
        "schema is different (expected left, but got right)",
    );
}
592
593fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
594 let error = conn
595 .prepare("PRAGMA foreign_key_check")
596 .unwrap()
597 .query_map([], |row| row.get(2))
598 .unwrap()
599 .next();
600 error.transpose().unwrap()
601}
602
/// Smoke test: a second in-memory database can be opened while the migrator
/// of the first one is still alive.
#[test]
fn open_multiple() {
    // NOTE(review): `v0::Empty` used below is presumably generated from this
    // `vN` module by the `schema` attribute macro — confirm against the macro.
    #[crate::migration::schema(Empty)]
    pub mod vN {}

    // Bound to named variables so that `_a` is not dropped before `_b` is created.
    let _a = Database::<v0::Empty>::migrator(Config::open_in_memory());
    let _b = Database::<v0::Empty>::migrator(Config::open_in_memory());
}