1pub mod config;
2mod fix_by_copy;
3pub mod migration;
4#[cfg(test)]
5mod test;
6
7use std::{collections::HashMap, marker::PhantomData, sync::atomic::AtomicI64};
8
9use annotate_snippets::{Renderer, renderer::DecorStyle};
10use self_cell::MutBorrow;
11
12use crate::{
13 Table, Transaction,
14 lower::{self, list_writer::Alias},
15 migrate::{
16 config::Config,
17 fix_by_copy::fix_by_copy,
18 migration::{SchemaBuilder, TransactionMigrate},
19 },
20 pool::Pool,
21 schema::{from_macro, read::read_schema},
22 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
23};
24
/// Collects the tables that make up schema `S`.
///
/// A `&mut TableTypBuilder<Self>` is handed to [`Schema::typs`], which registers
/// each table through [`TableTypBuilder::table`].
pub struct TableTypBuilder<S> {
    /// Schema being assembled; its `tables` are keyed by each table's `Table::NAME`.
    pub(crate) ast: from_macro::Schema,
    /// Ties the builder to `S`, so only tables whose `Table::Schema` is `S` can be added.
    _p: PhantomData<S>,
}
29
30impl<S> Default for TableTypBuilder<S> {
31 fn default() -> Self {
32 Self {
33 ast: Default::default(),
34 _p: Default::default(),
35 }
36 }
37}
38
39impl<S> TableTypBuilder<S> {
40 pub fn table<T: Table<Schema = S>>(&mut self) {
41 let table = from_macro::Table::new::<T>();
42 let old = self.ast.tables.insert(T::NAME, table);
43 debug_assert!(old.is_none());
44 }
45}
46
/// One version of a database schema.
///
/// NOTE(review): presumably implemented by the schema macro, since the table list
/// is obtained through `from_macro::Schema::new` — confirm.
pub trait Schema: Sized + 'static {
    /// Version number of this schema.
    ///
    /// Compared against SQLite's `PRAGMA user_version` when migrating, and written
    /// there once a migration chain finishes at this schema.
    const VERSION: i64;
    /// Schema definition source text, used when rendering a schema-mismatch report.
    const SOURCE: &str;
    /// Path passed alongside `SOURCE` to the schema-mismatch report.
    ///
    /// NOTE(review): presumably the path of the file defining the schema — confirm.
    const PATH: &str;
    /// NOTE(review): looks like a (start, end) location of the schema definition —
    /// confirm its unit and what it indexes into.
    const SPAN: (usize, usize);
    /// Registers every table of this schema on `b`.
    fn typs(b: &mut TableTypBuilder<Self>);
}
54
/// A migration between two schema versions.
///
/// A value of this type is returned by the closure passed to [`Migrator::migrate`].
pub trait SchemaMigration<'a> {
    /// Schema the migration starts from.
    type From: Schema;
    /// Schema the database has after the migration.
    type To: Schema;

    /// Records this migration's table changes on `b`.
    ///
    /// The migrator executes what was collected after this returns: dropped
    /// tables, table renames, and extra indexes.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
61
62impl<S: Schema> Database<S> {
63 pub fn migrator(config: Config) -> Option<Migrator<S>> {
67 let pool = Pool::new(config);
68
69 let conn = pool.pop();
70 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
71 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
72 Some(
73 conn.borrow_mut()
74 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
75 .unwrap(),
76 )
77 });
78
79 let mut user_version = Some(user_version(txn.get()).unwrap());
80
81 if schema_version(txn.get()) == 0 {
83 user_version = None;
84
85 let schema = crate::schema::from_macro::Schema::new::<S>();
86
87 for (table_name, table) in schema.tables {
88 let table = table.to_db();
89 let create = table.create(lower::JoinableTable::Table(table_name), "id");
90 txn.get().execute(&create, []).unwrap();
91 for stmt in table.delayed_indices(table_name) {
92 txn.get().execute(&stmt, []).unwrap();
93 }
94 }
95 } else if user_version.unwrap() < S::VERSION {
96 return None;
98 }
99
100 debug_assert_eq!(
101 foreign_key_check(txn.get()),
102 None,
103 "foreign key constraint violated"
104 );
105
106 Some(Migrator {
107 user_version,
108 pool,
109 transaction: txn,
110 _p: PhantomData,
111 })
112 }
113}
114
/// Walks a database through a chain of schema migrations.
///
/// All steps run inside one exclusive transaction. Instances are created by
/// [`Database::migrator`]. `S` is the schema the chain is currently at.
pub struct Migrator<S> {
    /// Connection pool; handed over to the [`Database`] returned by `finish`.
    pool: Pool,
    /// The single transaction every migration step runs in; committed by `finish`.
    transaction: OwnedTransaction,
    /// `Some(v)` while the stored `PRAGMA user_version` `v` has not been reached by
    /// the chain yet. `None` for a fresh database, or once a step at version `v` has run.
    user_version: Option<i64>,
    /// Marks the schema `S` the chain is currently at.
    _p: PhantomData<S>,
}
129
impl<S: Schema> Migrator<S> {
    /// Runs `f` on a freshly spawned scoped thread with the migration transaction
    /// installed in `TXN`. The transaction is then moved back into `self`, which is
    /// returned.
    ///
    /// While `user_version` is `Some`, the schema being processed is the one matching
    /// the stored version. In that case `user_version` is cleared, and before `f` runs
    /// the database schema is checked against `S` with [`check_schema`] and its indexes
    /// are fixed with `fix_by_copy`.
    ///
    /// NOTE(review): the dedicated thread presumably keeps the `&'static mut` handed to
    /// `f` and the `TXN` state from outliving this call. `TXN` looks thread-local,
    /// given `with_borrow`/`set`/`take` — confirm.
    ///
    /// # Panics
    ///
    /// Panics if `user_version` is `Some` of a version other than `S::VERSION`. A panic
    /// on the worker thread is re-raised on the calling thread via `resume_unwind`.
    fn with_transaction(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
        assert!(self.user_version.is_none_or(|x| x == S::VERSION));
        let res = std::thread::scope(|s| {
            s.spawn(|| {
                // Install the transaction where `Transaction` methods (e.g. `execute`)
                // look it up.
                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                let txn = Transaction::new_ref();

                // First step at the stored version: verify the existing schema
                // before changing anything.
                if self.user_version.take().is_some() {
                    check_schema::<S>(txn, false);
                    fix_by_copy::<S>(txn, fix_by_copy::Detail::Indexes);
                }

                f(txn);

                // Reclaim the transaction so it can be moved back into `self`.
                let transaction = TXN.take().unwrap();

                transaction.into_owner()
            })
            .join()
        });
        // Put the transaction back, or propagate the worker's panic unchanged.
        match res {
            Ok(val) => self.transaction = val,
            Err(payload) => std::panic::resume_unwind(payload),
        }
        self
    }

    /// Applies the migration returned by `m`, moving the chain from schema `S` to `M::To`.
    ///
    /// The migration only runs when `user_version` is `None` or equals `S::VERSION`.
    /// Otherwise the stored version has not been reached yet, and only the type
    /// parameter advances.
    ///
    /// When it runs, these steps happen in order:
    /// 1. `m` is called with a [`TransactionMigrate`].
    /// 2. The migration it returns records its changes on a [`SchemaBuilder`].
    /// 3. The recorded drops are executed.
    /// 4. Each `tmp` table in `rename_map` is renamed to its target name.
    /// 5. The extra indexes are created.
    /// 6. Foreign keys are fixed for `M::To` with `fix_by_copy`.
    /// 7. `PRAGMA foreign_key_check` is run.
    ///
    /// # Panics
    ///
    /// Panics if a SQL statement fails. It also panics if the foreign key check reports
    /// a table that has no callback in `builder.foreign_key`.
    ///
    /// NOTE(review): a registered callback is presumably expected to report the
    /// violation itself (e.g. by panicking) — confirm.
    pub fn migrate<'x, M>(
        mut self,
        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
    ) -> Migrator<M::To>
    where
        M: SchemaMigration<'x, From = S>,
    {
        if self.user_version.is_none_or(|x| x == S::VERSION) {
            self = self.with_transaction(|txn| {
                let mut txn = TransactionMigrate {
                    inner: txn.copy(),
                    scope: Default::default(),
                    rename_map: HashMap::new(),
                    extra_index: Vec::new(),
                };
                let m = m(&mut txn);

                let mut builder = SchemaBuilder {
                    drop: vec![],
                    foreign_key: HashMap::new(),
                    inner: txn,
                };
                m.tables(&mut builder);
                let txn = builder.inner.inner;

                for drop in builder.drop {
                    txn.execute(&drop);
                }
                for (to, tmp) in builder.inner.rename_map {
                    txn.execute(&format!("ALTER TABLE main.{tmp} RENAME TO {}", Alias(to)));
                }
                for stmt in builder.inner.extra_index {
                    txn.execute(&stmt);
                }

                fix_by_copy::<M::To>(&Transaction::new(), fix_by_copy::Detail::ForeignKeys);

                // Take the transaction out of `TXN` to run the raw check on it. It must
                // be put back afterwards, because `with_transaction` takes it again.
                let transaction = TXN.take().unwrap();
                if let Some(fk) = foreign_key_check(transaction.get()) {
                    (builder.foreign_key.remove(&*fk).unwrap())();
                }

                TXN.set(Some(transaction));
            });
        }

        Migrator {
            user_version: self.user_version,
            pool: self.pool,
            transaction: self.transaction,
            _p: PhantomData,
        }
    }

    /// Runs `f` as an extra step at schema `S`, but only when `user_version` is `None`.
    ///
    /// That is the case for a fresh database, or after a migration in this chain has
    /// already been applied. Otherwise `f` is skipped.
    pub fn fixup(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
        if self.user_version.is_none() {
            self = self.with_transaction(f);
        }
        self
    }

    /// Finishes the chain at schema `S` and returns the migrated [`Database`].
    ///
    /// Returns `None` if `user_version` is still `Some` of a version other than
    /// `S::VERSION`, meaning the stored version never matched a schema of the chain.
    ///
    /// Otherwise:
    /// - the database schema is verified against `S` (a mismatch is reported as a
    ///   rust-query bug);
    /// - `PRAGMA optimize` is run;
    /// - `S::VERSION` is stored as the new `user_version`;
    /// - the transaction is committed.
    ///
    /// # Panics
    ///
    /// Panics if the schema check fails or any SQLite operation fails.
    pub fn finish(mut self) -> Option<Database<S>> {
        if self.user_version.is_some_and(|x| x != S::VERSION) {
            return None;
        }

        self = self.with_transaction(|txn| {
            check_schema::<S>(txn, true);
        });

        self.transaction
            .get()
            .execute_batch("PRAGMA optimize;")
            .unwrap();

        set_user_version(self.transaction.get(), S::VERSION).unwrap();
        // Read after all schema changes; it seeds the new `Database`'s counter.
        let schema_version = schema_version(self.transaction.get());
        self.transaction.with(|x| x.commit().unwrap());

        Some(Database {
            pool: self.pool,
            schema_version: AtomicI64::new(schema_version),
            schema: PhantomData,
            mut_lock: parking_lot::FairMutex::new(()),
        })
    }
}
266
267impl<S> Transaction<S> {
268 #[track_caller]
269 pub(crate) fn execute(&self, sql: &str) {
270 TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
271 .unwrap();
272 }
273}
274
275pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
276 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
277 .unwrap()
278}
279
280pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
282 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
283}
284
285fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
287 conn.pragma_update(None, "user_version", v)
288}
289
290pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>, sanity: bool) {
291 let from_macro = crate::schema::from_macro::Schema::new::<S>();
292 let from_db = read_schema(txn);
293 let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
294 if !report.is_empty() {
295 let renderer = if cfg!(test) {
296 Renderer::plain().anonymized_line_numbers(true)
297 } else {
298 Renderer::styled()
299 }
300 .decor_style(DecorStyle::Unicode);
301 if sanity {
302 unreachable!("THIS IS A RUST-QUERY BUG {}", renderer.render(&report));
303 } else {
304 panic!("{}", renderer.render(&report));
305 }
306 }
307}
308
309fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
310 let error = conn
311 .prepare("PRAGMA foreign_key_check")
312 .unwrap()
313 .query_map([], |row| row.get(2))
314 .unwrap()
315 .next();
316 error.transpose().unwrap()
317}
318
319impl<S> Transaction<S> {
320 #[cfg(test)]
321 pub(crate) fn schema(&self) -> Vec<String> {
322 TXN.with_borrow(|x| {
323 x.as_ref()
324 .unwrap()
325 .get()
326 .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
327 .unwrap()
328 .query_map([], |row| row.get::<_, Option<String>>("sql"))
329 .unwrap()
330 .flat_map(|x| x.unwrap())
331 .collect()
332 })
333 }
334}