1pub mod config;
2pub mod migration;
3#[cfg(test)]
4mod test;
5
6use std::{
7 collections::{BTreeSet, HashMap},
8 marker::PhantomData,
9 sync::atomic::AtomicI64,
10};
11
12use annotate_snippets::{Renderer, renderer::DecorStyle};
13use rusqlite::config::DbConfig;
14use sea_query::{Alias, ColumnDef, IntoIden, SqliteQueryBuilder};
15use self_cell::MutBorrow;
16
17use crate::{
18 Table, Transaction,
19 alias::Scope,
20 migrate::{
21 config::Config,
22 migration::{SchemaBuilder, TransactionMigrate},
23 },
24 pool::Pool,
25 schema::{from_db, from_macro, read::read_schema},
26 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
29pub struct TableTypBuilder<S> {
30 pub(crate) ast: from_macro::Schema,
31 _p: PhantomData<S>,
32}
33
34impl<S> Default for TableTypBuilder<S> {
35 fn default() -> Self {
36 Self {
37 ast: Default::default(),
38 _p: Default::default(),
39 }
40 }
41}
42
43impl<S> TableTypBuilder<S> {
44 pub fn table<T: Table<Schema = S>>(&mut self) {
45 let table = from_macro::Table::new::<T>();
46 let old = self.ast.tables.insert(T::NAME, table);
47 debug_assert!(old.is_none());
48 }
49}
50
51pub trait Schema: Sized + 'static {
52 const VERSION: i64;
53 const SOURCE: &str;
54 const PATH: &str;
55 const SPAN: (usize, usize);
56 fn typs(b: &mut TableTypBuilder<Self>);
57}
58
59fn new_table_inner(table: &crate::schema::from_macro::Table, alias: impl IntoIden) -> String {
60 let alias = alias.into_iden();
61 let mut create = table.create();
62 create
63 .table(alias.clone())
64 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
65 let mut sql = create.to_string(SqliteQueryBuilder);
66 sql.push_str(" STRICT");
67 sql
68}
69
70pub trait SchemaMigration<'a> {
71 type From: Schema;
72 type To: Schema;
73
74 fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
75}
76
77impl<S: Schema> Database<S> {
78 pub fn migrator(config: Config) -> Option<Migrator<S>> {
82 let synchronous = config.synchronous.as_str();
83 let foreign_keys = config.foreign_keys.as_str();
84 let manager = config.manager.with_init(move |inner| {
85 inner.pragma_update(None, "journal_mode", "WAL")?;
86 inner.pragma_update(None, "synchronous", synchronous)?;
87 inner.pragma_update(None, "foreign_keys", foreign_keys)?;
88 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
89 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
90 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
91 rusqlite::vtab::array::load_module(inner).unwrap();
92 Ok(())
93 });
94
95 use r2d2::ManageConnection;
96 let conn = manager.connect().unwrap();
97 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
98 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
99 Some(
100 conn.borrow_mut()
101 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
102 .unwrap(),
103 )
104 });
105
106 let mut user_version = Some(user_version(txn.get()).unwrap());
107
108 if schema_version(txn.get()) == 0 {
110 user_version = None;
111
112 let schema = crate::schema::from_macro::Schema::new::<S>();
113
114 for (&table_name, table) in &schema.tables {
115 txn.get()
116 .execute(&new_table_inner(table, table_name), [])
117 .unwrap();
118 for stmt in table.delayed_indices(table_name) {
119 txn.get().execute(&stmt, []).unwrap();
120 }
121 }
122 (config.init)(txn.get());
123 } else if user_version.unwrap() < S::VERSION {
124 return None;
126 }
127
128 debug_assert_eq!(
129 foreign_key_check(txn.get()),
130 None,
131 "foreign key constraint violated"
132 );
133
134 Some(Migrator {
135 user_version,
136 manager,
137 transaction: txn,
138 _p: PhantomData,
139 })
140 }
141}
142
143pub struct Migrator<S> {
148 manager: r2d2_sqlite::SqliteConnectionManager,
149 transaction: OwnedTransaction,
150 user_version: Option<i64>,
155 _p: PhantomData<S>,
156}
157
impl<S: Schema> Migrator<S> {
    /// Runs `f` on the migration transaction inside a scoped helper thread.
    ///
    /// Returns `self` with the (possibly modified) transaction stored back. On the helper
    /// thread, the transaction is placed in the `TXN` thread-local before
    /// `Transaction::new_ref()` is called, and it is taken back out of `TXN` after `f`
    /// returns.
    /// NOTE(review): the dedicated thread presumably keeps the `'static` reference handed
    /// to `f`, and the thread-local state, from leaking into the caller's thread — confirm.
    ///
    /// If `self.user_version` is still `Some`, it is cleared here. Before `f` runs, the
    /// existing database is first validated with `check_schema` and then has its indices
    /// repaired with `fix_indices`.
    ///
    /// # Panics
    /// Panics if `self.user_version` is `Some` with a value other than `S::VERSION`.
    /// A panic on the helper thread (from `f` or from the schema checks) is resumed on the
    /// calling thread.
    fn with_transaction(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
        assert!(self.user_version.is_none_or(|x| x == S::VERSION));
        let res = std::thread::scope(|s| {
            s.spawn(|| {
                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                let txn = Transaction::new_ref();

                // `take()` clears the recorded version, so the existing database is only
                // validated on the first call that reaches it.
                if self.user_version.take().is_some() {
                    check_schema::<S>(txn);
                    fix_indices::<S>(txn);
                }

                f(txn);

                let transaction = TXN.take().unwrap();

                transaction.into_owner()
            })
            .join()
        });
        match res {
            Ok(val) => self.transaction = val,
            Err(payload) => std::panic::resume_unwind(payload),
        }
        self
    }

    /// Applies the migration step produced by `m` and returns a migrator for `M::To`.
    ///
    /// The step runs only when `self.user_version` is `None` or equals `S::VERSION`.
    /// `None` means a freshly created database, or one that an earlier step of this chain
    /// has already reached. Otherwise the database is recorded at a different version, so
    /// the step is skipped and that `user_version` is carried over unchanged.
    pub fn migrate<'x, M>(
        mut self,
        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
    ) -> Migrator<M::To>
    where
        M: SchemaMigration<'x, From = S>,
    {
        if self.user_version.is_none_or(|x| x == S::VERSION) {
            self = self.with_transaction(|txn| {
                // Wrap the transaction so the migration closure can record table renames and
                // extra indices.
                let mut txn = TransactionMigrate {
                    inner: txn.copy(),
                    scope: Default::default(),
                    rename_map: HashMap::new(),
                    extra_index: Vec::new(),
                };
                let m = m(&mut txn);

                // Let the migration describe its table changes on the builder.
                let mut builder = SchemaBuilder {
                    drop: vec![],
                    foreign_key: HashMap::new(),
                    inner: txn,
                };
                m.tables(&mut builder);

                // Take the transaction out of `TXN` so the collected statements can run on it
                // directly. It is put back at the end so `with_transaction` can retrieve it.
                let transaction = TXN.take().unwrap();

                // Drop tables first, then move each temporary table (`tmp`) to its final
                // name (`to`).
                for drop in builder.drop {
                    let sql = drop.to_string(SqliteQueryBuilder);
                    transaction.get().execute(&sql, []).unwrap();
                }
                for (to, tmp) in builder.inner.rename_map {
                    let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
                    let sql = rename.to_string(SqliteQueryBuilder);
                    transaction.get().execute(&sql, []).unwrap();
                }
                // If a foreign-key violation is reported, call the handler registered in
                // `builder.foreign_key` for the reported table. Only the first violation is
                // looked at.
                // NOTE(review): the `unreachable_code` allow suggests these handlers diverge
                // (e.g. panic) — confirm.
                #[allow(
                    unreachable_code,
                    reason = "rustc is stupid and thinks this is unreachable"
                )]
                if let Some(fk) = foreign_key_check(transaction.get()) {
                    (builder.foreign_key.remove(&*fk).unwrap())();
                }
                // Extra indices are created last, after all renames have happened.
                for stmt in builder.inner.extra_index {
                    transaction.get().execute(&stmt, []).unwrap();
                }

                TXN.set(Some(transaction));
            });
        }

        Migrator {
            user_version: self.user_version,
            manager: self.manager,
            transaction: self.transaction,
            _p: PhantomData,
        }
    }

    /// Runs `f` on the migration transaction, but only when `self.user_version` is `None`.
    ///
    /// `None` means a fresh database, or a migration chain that has already reached the
    /// database's version. When a version is still recorded, `f` is skipped.
    pub fn fixup(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
        if self.user_version.is_none() {
            self = self.with_transaction(f);
        }
        self
    }

    /// Completes the migration and opens the database at schema `S`.
    ///
    /// Returns `None` if a `user_version` is still recorded and it differs from
    /// `S::VERSION`. Otherwise:
    /// 1. the schema is checked against `S`;
    /// 2. `PRAGMA optimize` is run;
    /// 3. `user_version` is set to `S::VERSION`;
    /// 4. the transaction is committed;
    /// 5. a [`Database`] backed by a connection pool is returned.
    ///
    /// # Panics
    /// Panics if the schema does not match `S`, or if any of the final statements fail.
    pub fn finish(mut self) -> Option<Database<S>> {
        if self.user_version.is_some_and(|x| x != S::VERSION) {
            return None;
        }

        self = self.with_transaction(|txn| {
            check_schema::<S>(txn);
        });

        self.transaction
            .get()
            .execute_batch("PRAGMA optimize;")
            .unwrap();

        set_user_version(self.transaction.get(), S::VERSION).unwrap();
        // Read inside the transaction, right before committing, so the new `Database`
        // starts from the schema version produced by this migration.
        let schema_version = schema_version(self.transaction.get());
        self.transaction.with(|x| x.commit().unwrap());

        Some(Database {
            manager: Pool::new(self.manager),
            schema_version: AtomicI64::new(schema_version),
            schema: PhantomData,
            mut_lock: parking_lot::FairMutex::new(()),
        })
    }
}
298
299fn fix_indices<S: Schema>(txn: &Transaction<S>) {
300 let schema = read_schema(txn);
301 let expected_schema = crate::schema::from_macro::Schema::new::<S>();
302
303 fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
304 let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
305 let actual: BTreeSet<_> = actual.indices.values().collect();
306 expected == actual
307 }
308
309 for (&table_name, expected_table) in &expected_schema.tables {
310 let table = &schema.tables[table_name];
311
312 if !check_eq(expected_table, &table) {
313 let scope = Scope::default();
318 let tmp_name = scope.tmp_table();
319
320 txn.execute(&new_table_inner(expected_table, tmp_name));
321
322 let mut columns: Vec<_> = expected_table
323 .columns
324 .keys()
325 .map(|x| Alias::new(x))
326 .collect();
327 columns.push(Alias::new("id"));
328
329 txn.execute(
330 &sea_query::InsertStatement::new()
331 .into_table(tmp_name)
332 .columns(columns.clone())
333 .select_from(
334 sea_query::SelectStatement::new()
335 .from(table_name)
336 .columns(columns)
337 .take(),
338 )
339 .unwrap()
340 .build(SqliteQueryBuilder)
341 .0,
342 );
343
344 txn.execute(
345 &sea_query::TableDropStatement::new()
346 .table(table_name)
347 .build(SqliteQueryBuilder),
348 );
349
350 txn.execute(
351 &sea_query::TableRenameStatement::new()
352 .table(tmp_name, table_name)
353 .build(SqliteQueryBuilder),
354 );
355 for sql in expected_table.delayed_indices(table_name) {
357 txn.execute(&sql);
358 }
359 }
360 }
361
362 let schema = read_schema(txn);
364 for (name, table) in schema.tables {
365 let expected_table = &expected_schema.tables[&*name];
366 assert!(check_eq(expected_table, &table));
367 }
368}
369
370impl<S> Transaction<S> {
371 #[track_caller]
372 pub(crate) fn execute(&self, sql: &str) {
373 TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
374 .unwrap();
375 }
376}
377
378pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
379 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
380 .unwrap()
381}
382
383pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
385 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
386}
387
388fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
390 conn.pragma_update(None, "user_version", v)
391}
392
393pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
394 let from_macro = crate::schema::from_macro::Schema::new::<S>();
395 let from_db = read_schema(txn);
396 let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
397 if !report.is_empty() {
398 let renderer = if cfg!(test) {
399 Renderer::plain().anonymized_line_numbers(true)
400 } else {
401 Renderer::styled()
402 }
403 .decor_style(DecorStyle::Unicode);
404 panic!("{}", renderer.render(&report));
405 }
406}
407
408fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
409 let error = conn
410 .prepare("PRAGMA foreign_key_check")
411 .unwrap()
412 .query_map([], |row| row.get(2))
413 .unwrap()
414 .next();
415 error.transpose().unwrap()
416}
417
418impl<S> Transaction<S> {
419 #[cfg(test)]
420 pub(crate) fn schema(&self) -> Vec<String> {
421 TXN.with_borrow(|x| {
422 x.as_ref()
423 .unwrap()
424 .get()
425 .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
426 .unwrap()
427 .query_map([], |row| row.get::<_, Option<String>>("sql"))
428 .unwrap()
429 .flat_map(|x| x.unwrap())
430 .collect()
431 })
432 }
433}