1pub mod config;
2pub mod migration;
3
4use std::{
5 collections::{BTreeSet, HashMap},
6 marker::PhantomData,
7 sync::atomic::AtomicI64,
8};
9
10use annotate_snippets::{Renderer, renderer::DecorStyle};
11use rusqlite::{Connection, config::DbConfig};
12use sea_query::{Alias, ColumnDef, IntoTableRef, SqliteQueryBuilder};
13use self_cell::MutBorrow;
14
15use crate::{
16 Table, Transaction,
17 migrate::{
18 config::Config,
19 migration::{SchemaBuilder, TransactionMigrate},
20 },
21 pool::Pool,
22 schema::{
23 from_db, from_macro,
24 read::{read_index_names_for_table, read_schema},
25 },
26 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
29pub struct TableTypBuilder<S> {
30 pub(crate) ast: from_macro::Schema,
31 _p: PhantomData<S>,
32}
33
34impl<S> Default for TableTypBuilder<S> {
35 fn default() -> Self {
36 Self {
37 ast: Default::default(),
38 _p: Default::default(),
39 }
40 }
41}
42
43impl<S> TableTypBuilder<S> {
44 pub fn table<T: Table<Schema = S>>(&mut self) {
45 let table = from_macro::Table::new::<T>();
46 let old = self.ast.tables.insert(T::NAME.to_owned(), table);
47 debug_assert!(old.is_none());
48 }
49}
50
51pub trait Schema: Sized + 'static {
52 const VERSION: i64;
53 const SOURCE: &str;
54 const PATH: &str;
55 const SPAN: (usize, usize);
56 fn typs(b: &mut TableTypBuilder<Self>);
57}
58
59fn new_table_inner(
60 conn: &Connection,
61 table: &crate::schema::from_macro::Table,
62 alias: impl IntoTableRef,
63) {
64 let mut create = table.create();
65 create
66 .table(alias)
67 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
68 let mut sql = create.to_string(SqliteQueryBuilder);
69 sql.push_str(" STRICT");
70 conn.execute(&sql, []).unwrap();
71}
72
73pub trait SchemaMigration<'a> {
74 type From: Schema;
75 type To: Schema;
76
77 fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
78}
79
80impl<S: Schema> Database<S> {
81 pub fn migrator(config: Config) -> Option<Migrator<S>> {
85 let synchronous = config.synchronous.as_str();
86 let foreign_keys = config.foreign_keys.as_str();
87 let manager = config.manager.with_init(move |inner| {
88 inner.pragma_update(None, "journal_mode", "WAL")?;
89 inner.pragma_update(None, "synchronous", synchronous)?;
90 inner.pragma_update(None, "foreign_keys", foreign_keys)?;
91 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
92 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
93 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
94 Ok(())
95 });
96
97 use r2d2::ManageConnection;
98 let conn = manager.connect().unwrap();
99 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
100 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
101 Some(
102 conn.borrow_mut()
103 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
104 .unwrap(),
105 )
106 });
107
108 if schema_version(txn.get()) == 0 {
110 let schema = crate::schema::from_macro::Schema::new::<S>();
111
112 for (table_name, table) in &schema.tables {
113 let table_name_ref = Alias::new(table_name);
114 new_table_inner(txn.get(), table, table_name_ref);
115 for stmt in table.create_indices(table_name) {
116 txn.get().execute(&stmt, []).unwrap();
117 }
118 }
119 (config.init)(txn.get());
120 set_user_version(txn.get(), S::VERSION).unwrap();
121 }
122
123 let user_version = user_version(txn.get()).unwrap();
124 if user_version < S::VERSION {
126 return None;
127 }
128 debug_assert_eq!(
129 foreign_key_check(txn.get()),
130 None,
131 "foreign key constraint violated"
132 );
133
134 Some(Migrator {
135 indices_fixed: false,
136 manager,
137 transaction: txn,
138 _p: PhantomData,
139 })
140 }
141}
142
143pub struct Migrator<S> {
148 manager: r2d2_sqlite::SqliteConnectionManager,
149 transaction: OwnedTransaction,
150 indices_fixed: bool,
151 _p: PhantomData<S>,
152}
153
154impl<S: Schema> Migrator<S> {
155 pub fn migrate<'x, M>(
159 mut self,
160 m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
161 ) -> Migrator<M::To>
162 where
163 M: SchemaMigration<'x, From = S>,
164 {
165 if user_version(self.transaction.get()).unwrap() == S::VERSION {
166 let res = std::thread::scope(|s| {
167 s.spawn(|| {
168 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
169 let txn = Transaction::new_ref();
170
171 check_schema::<S>(txn);
172 if !self.indices_fixed {
173 fix_indices::<S>(txn);
174 self.indices_fixed = true;
175 }
176
177 let mut txn = TransactionMigrate {
178 inner: Transaction::new(),
179 scope: Default::default(),
180 rename_map: HashMap::new(),
181 extra_index: Vec::new(),
182 };
183 let m = m(&mut txn);
184
185 let mut builder = SchemaBuilder {
186 drop: vec![],
187 foreign_key: HashMap::new(),
188 inner: txn,
189 };
190 m.tables(&mut builder);
191
192 let transaction = TXN.take().unwrap();
193
194 for drop in builder.drop {
195 let sql = drop.to_string(SqliteQueryBuilder);
196 transaction.get().execute(&sql, []).unwrap();
197 }
198 for (to, tmp) in builder.inner.rename_map {
199 let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
200 let sql = rename.to_string(SqliteQueryBuilder);
201 transaction.get().execute(&sql, []).unwrap();
202 }
203 if let Some(fk) = foreign_key_check(transaction.get()) {
204 (builder.foreign_key.remove(&*fk).unwrap())();
205 }
206 #[allow(
207 unreachable_code,
208 reason = "rustc is stupid and thinks this is unreachable"
209 )]
210 for stmt in builder.inner.extra_index {
212 transaction.get().execute(&stmt, []).unwrap();
213 }
214 set_user_version(transaction.get(), M::To::VERSION).unwrap();
215
216 transaction.into_owner()
217 })
218 .join()
219 });
220 match res {
221 Ok(val) => self.transaction = val,
222 Err(payload) => std::panic::resume_unwind(payload),
223 }
224 }
225
226 Migrator {
227 indices_fixed: self.indices_fixed,
228 manager: self.manager,
229 transaction: self.transaction,
230 _p: PhantomData,
231 }
232 }
233
234 pub fn finish(mut self) -> Option<Database<S>> {
240 if user_version(self.transaction.get()).unwrap() != S::VERSION {
241 return None;
242 }
243
244 let res = std::thread::scope(|s| {
245 s.spawn(|| {
246 TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
247 let txn = Transaction::new_ref();
248
249 check_schema::<S>(txn);
250 if !self.indices_fixed {
251 fix_indices::<S>(txn);
252 self.indices_fixed = true;
253 }
254
255 TXN.take().unwrap().into_owner()
256 })
257 .join()
258 });
259 match res {
260 Ok(val) => self.transaction = val,
261 Err(payload) => std::panic::resume_unwind(payload),
262 }
263
264 self.transaction
266 .get()
267 .execute_batch("PRAGMA optimize;")
268 .unwrap();
269
270 let schema_version = schema_version(self.transaction.get());
271 self.transaction.with(|x| x.commit().unwrap());
272
273 Some(Database {
274 manager: Pool::new(self.manager),
275 schema_version: AtomicI64::new(schema_version),
276 schema: PhantomData,
277 mut_lock: parking_lot::FairMutex::new(()),
278 })
279 }
280}
281
282fn fix_indices<S: Schema>(txn: &Transaction<S>) {
283 let schema = read_schema(txn);
284 let expected_schema = crate::schema::from_macro::Schema::new::<S>();
285
286 fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
287 let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
288 let actual: BTreeSet<_> = actual.indices.values().collect();
289 expected == actual
290 }
291
292 for (name, table) in schema.tables {
293 let expected_table = &expected_schema.tables[&name];
294
295 if !check_eq(expected_table, &table) {
296 for index_name in read_index_names_for_table(&crate::Transaction::new(), &name) {
298 let sql = sea_query::Index::drop()
299 .name(index_name)
300 .build(SqliteQueryBuilder);
301 txn.execute(&sql);
302 }
303
304 for sql in expected_table.create_indices(&name) {
306 txn.execute(&sql);
307 }
308 }
309 }
310
311 let schema = read_schema(txn);
313 for (name, table) in schema.tables {
314 let expected_table = &expected_schema.tables[&name];
315 assert!(check_eq(expected_table, &table));
316 }
317}
318
319impl<S> Transaction<S> {
320 pub(crate) fn execute(&self, sql: &str) {
321 TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
322 .unwrap();
323 }
324}
325
326pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
327 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
328 .unwrap()
329}
330
331pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
333 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
334}
335
336fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
338 conn.pragma_update(None, "user_version", v)
339}
340
341pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
342 let from_macro = crate::schema::from_macro::Schema::new::<S>();
343 let from_db = read_schema(txn);
344 let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
345 if !report.is_empty() {
346 let renderer = if cfg!(test) {
347 Renderer::plain().anonymized_line_numbers(true)
348 } else {
349 Renderer::styled()
350 }
351 .decor_style(DecorStyle::Unicode);
352 panic!("{}", renderer.render(&report));
353 }
354}
355
356fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
357 let error = conn
358 .prepare("PRAGMA foreign_key_check")
359 .unwrap()
360 .query_map([], |row| row.get(2))
361 .unwrap()
362 .next();
363 error.transpose().unwrap()
364}
365
366impl<S> Transaction<S> {
367 #[cfg(test)]
368 pub(crate) fn schema(&self) -> Vec<String> {
369 TXN.with_borrow(|x| {
370 x.as_ref()
371 .unwrap()
372 .get()
373 .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
374 .unwrap()
375 .query_map([], |row| row.get("sql"))
376 .unwrap()
377 .map(|x| x.unwrap())
378 .collect()
379 })
380 }
381}