1pub mod config;
2pub mod migration;
3#[cfg(test)]
4mod test;
5
6use std::{
7 collections::{BTreeSet, HashMap},
8 marker::PhantomData,
9 sync::atomic::AtomicI64,
10};
11
12use annotate_snippets::{Renderer, renderer::DecorStyle};
13use rusqlite::config::DbConfig;
14use sea_query::{Alias, ColumnDef, IntoIden, SqliteQueryBuilder};
15use self_cell::MutBorrow;
16
17use crate::{
18 Table, Transaction,
19 alias::Scope,
20 migrate::{
21 config::Config,
22 migration::{SchemaBuilder, TransactionMigrate},
23 },
24 pool::Pool,
25 schema::{from_db, from_macro, read::read_schema},
26 transaction::{Database, OwnedTransaction, TXN, TransactionWithRows},
27};
28
/// Collects the table definitions that make up schema `S`.
///
/// A mutable reference to this builder is handed to [`Schema::typs`], which registers each
/// table through [`TableTypBuilder::table`].
pub struct TableTypBuilder<S> {
    /// Schema under construction; `table` inserts into its `tables` map keyed by `T::NAME`.
    pub(crate) ast: from_macro::Schema,
    /// Ties the builder to schema `S` without storing a value of it.
    _p: PhantomData<S>,
}
33
34impl<S> Default for TableTypBuilder<S> {
35 fn default() -> Self {
36 Self {
37 ast: Default::default(),
38 _p: Default::default(),
39 }
40 }
41}
42
43impl<S> TableTypBuilder<S> {
44 pub fn table<T: Table<Schema = S>>(&mut self) {
45 let table = from_macro::Table::new::<T>();
46 let old = self.ast.tables.insert(T::NAME, table);
47 debug_assert!(old.is_none());
48 }
49}
50
/// A versioned database schema.
///
/// `VERSION`, `SOURCE` and `PATH` are passed to the schema diff in `check_schema`, which
/// renders a report when the database does not match the declared tables.
pub trait Schema: Sized + 'static {
    /// Version number of this schema; `Migrator::finish` stores it as the database's
    /// `user_version`.
    const VERSION: i64;
    /// NOTE(review): presumably the source text of the schema definition; it is passed to
    /// the diff to render mismatch reports — confirm.
    const SOURCE: &str;
    /// NOTE(review): presumably the path of the file containing `SOURCE`; passed to the
    /// diff for mismatch reports — confirm.
    const PATH: &str;
    /// NOTE(review): presumably a (start, end) span inside `SOURCE`; not used in this file
    /// chunk — confirm.
    const SPAN: (usize, usize);
    /// Registers every table of this schema with `b`.
    fn typs(b: &mut TableTypBuilder<Self>);
}
58
59fn new_table_inner(table: &crate::schema::from_macro::Table, alias: impl IntoIden) -> String {
60 let alias = alias.into_iden();
61 let mut create = table.create();
62 create
63 .table(alias.clone())
64 .col(ColumnDef::new(Alias::new("id")).integer().primary_key());
65 let mut sql = create.to_string(SqliteQueryBuilder);
66 sql.push_str(" STRICT");
67 sql
68}
69
/// A migration step from schema `From` to schema `To`.
///
/// A value of this type is returned by the closure passed to `Migrator::migrate`. That
/// method then calls `tables` with a `SchemaBuilder` and afterwards executes the drops,
/// renames and extra index statements collected in it.
pub trait SchemaMigration<'a> {
    /// Schema the migration starts from.
    type From: Schema;
    /// Schema the migration produces.
    type To: Schema;

    /// Consumes the migration and feeds it into `b`.
    fn tables(self, b: &mut SchemaBuilder<'a, Self::From>);
}
76
77impl<S: Schema> Database<S> {
78 pub fn migrator(config: Config) -> Option<Migrator<S>> {
82 let synchronous = config.synchronous.as_str();
83 let foreign_keys = config.foreign_keys.as_str();
84 let manager = config.manager.with_init(move |inner| {
85 inner.pragma_update(None, "journal_mode", "WAL")?;
86 inner.pragma_update(None, "synchronous", synchronous)?;
87 inner.pragma_update(None, "foreign_keys", foreign_keys)?;
88 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DDL, false)?;
89 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DQS_DML, false)?;
90 inner.set_db_config(DbConfig::SQLITE_DBCONFIG_DEFENSIVE, true)?;
91
92 #[cfg(feature = "bundled")]
93 inner.create_scalar_function(
94 "floor",
95 1,
96 rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
97 |ctx| {
98 assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
99 let res = ctx.get::<Option<f64>>(0)?.map(|x| x.floor());
100 Ok(res)
101 },
102 )?;
103
104 #[cfg(feature = "bundled")]
105 inner.create_scalar_function(
106 "ceil",
107 1,
108 rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
109 |ctx| {
110 assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
111 let res = ctx.get::<Option<f64>>(0)?.map(|x| x.ceil());
112 Ok(res)
113 },
114 )?;
115
116 #[cfg(feature = "jiff-02")]
117 inner.create_scalar_function(
118 "timestamp_add_nanosecond",
119 2,
120 rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
121 |ctx| {
122 use crate::value::DbTyp;
123 assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
124 if matches!(ctx.get_raw(0), rusqlite::types::ValueRef::Null)
125 || matches!(ctx.get_raw(1), rusqlite::types::ValueRef::Null)
126 {
127 return Ok(None);
128 }
129
130 let timestamp = jiff::Timestamp::from_sql(ctx.get_raw(0))?;
131 let seconds = ctx.get::<i64>(1)?;
132 let new = timestamp + jiff::SignedDuration::from_nanos(seconds);
133 let sea_query::Value::String(Some(res)) = jiff::Timestamp::out_to_value(new)
134 else {
135 unreachable!("func always returns some string")
136 };
137 Ok(Some(res))
138 },
139 )?;
140
141 #[cfg(feature = "jiff-02")]
142 inner.create_scalar_function(
143 "timestamp_subsec_nanosecond",
144 1,
145 rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
146 |ctx| {
147 use crate::value::DbTyp;
148 assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
149 if matches!(ctx.get_raw(0), rusqlite::types::ValueRef::Null) {
150 return Ok(None);
151 }
152
153 let timestamp = jiff::Timestamp::from_sql(ctx.get_raw(0))?;
154 Ok(Some(timestamp.subsec_nanosecond()))
155 },
156 )?;
157
158 #[cfg(feature = "jiff-02")]
159 inner.create_scalar_function(
160 "timestamp_to_second",
161 1,
162 rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
163 |ctx| {
164 use crate::value::DbTyp;
165 assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
166 if matches!(ctx.get_raw(0), rusqlite::types::ValueRef::Null) {
167 return Ok(None);
168 }
169
170 let timestamp = jiff::Timestamp::from_sql(ctx.get_raw(0))?;
171 Ok(Some(timestamp.as_second()))
172 },
173 )?;
174
175 #[cfg(feature = "jiff-02")]
176 inner.create_scalar_function(
177 "timestamp_to_date",
178 2,
179 rusqlite::functions::FunctionFlags::SQLITE_DETERMINISTIC,
180 |ctx| {
181 use jiff::fmt::temporal;
182
183 use crate::value::DbTyp;
184 assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
185 if matches!(ctx.get_raw(0), rusqlite::types::ValueRef::Null)
186 || matches!(ctx.get_raw(1), rusqlite::types::ValueRef::Null)
187 {
188 return Ok(None);
189 }
190
191 static PARSER: temporal::DateTimeParser = temporal::DateTimeParser::new();
192
193 let timestamp = jiff::Timestamp::from_sql(ctx.get_raw(0))?;
194 let timezone = PARSER
195 .parse_time_zone(ctx.get_raw(1).as_str()?)
196 .expect("time zone was serialized with jiff");
197 let date = timezone.to_datetime(timestamp).date();
198 let sea_query::Value::String(Some(res)) = jiff::civil::Date::out_to_value(date)
199 else {
200 unreachable!("func always returns some string")
201 };
202 Ok(Some(res))
203 },
204 )?;
205
206 Ok(())
207 });
208
209 use r2d2::ManageConnection;
210 let conn = manager.connect().unwrap();
211 conn.pragma_update(None, "foreign_keys", "OFF").unwrap();
212 let txn = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
213 Some(
214 conn.borrow_mut()
215 .transaction_with_behavior(rusqlite::TransactionBehavior::Exclusive)
216 .unwrap(),
217 )
218 });
219
220 let mut user_version = Some(user_version(txn.get()).unwrap());
221
222 if schema_version(txn.get()) == 0 {
224 user_version = None;
225
226 let schema = crate::schema::from_macro::Schema::new::<S>();
227
228 for (&table_name, table) in &schema.tables {
229 txn.get()
230 .execute(&new_table_inner(table, table_name), [])
231 .unwrap();
232 for stmt in table.delayed_indices(table_name) {
233 txn.get().execute(&stmt, []).unwrap();
234 }
235 }
236 (config.init)(txn.get());
237 } else if user_version.unwrap() < S::VERSION {
238 return None;
240 }
241
242 debug_assert_eq!(
243 foreign_key_check(txn.get()),
244 None,
245 "foreign key constraint violated"
246 );
247
248 Some(Migrator {
249 user_version,
250 manager,
251 transaction: txn,
252 _p: PhantomData,
253 })
254 }
255}
256
/// Step-by-step migration of a database towards schema `S`, created by `Database::migrator`.
///
/// All steps run inside one exclusive transaction, which is only committed by `finish`.
pub struct Migrator<S> {
    /// Connection manager used to build the final `Database` pool in `finish`.
    manager: r2d2_sqlite::SqliteConnectionManager,
    /// The exclusive transaction that every migration step runs in.
    transaction: OwnedTransaction,
    /// Whether the migration steps should run:
    /// - `Some(v)`: an existing database stored at version `v` that the chain has not
    ///   reached yet; `migrate` skips steps whose source version differs from `v`.
    /// - `None`: the database was freshly created or a migration has already been applied,
    ///   so every further step runs.
    user_version: Option<i64>,
    _p: PhantomData<S>,
}
271
impl<S: Schema> Migrator<S> {
    /// Runs `f` against the migration transaction on a freshly spawned scoped thread.
    ///
    /// Steps:
    /// 1. The owned transaction is installed in that thread's `TXN` thread-local.
    /// 2. `f` receives the `&'static mut Transaction<S>` returned by `Transaction::new_ref`.
    /// 3. The transaction is taken back out of `TXN` and stored in `self` again.
    ///
    /// If `user_version` is still `Some`, this is the first step executed against an existing
    /// database; the assertion below restricts that value to `S::VERSION`. In that case the
    /// value is cleared, and before `f` runs:
    /// - the stored schema is verified with `check_schema`;
    /// - indices are rebuilt with `fix_indices`.
    ///
    /// # Panics
    /// Panics if `user_version` is `Some` but differs from `S::VERSION`. A panic on the
    /// spawned thread is re-raised on the calling thread.
    fn with_transaction(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
        assert!(self.user_version.is_none_or(|x| x == S::VERSION));
        // NOTE(review): the dedicated thread presumably guarantees that `TXN` starts empty and
        // that the `'static` reference handed to `f` cannot be used after the scope ends —
        // confirm against `Transaction::new_ref`.
        let res = std::thread::scope(|s| {
            s.spawn(|| {
                TXN.set(Some(TransactionWithRows::new_empty(self.transaction)));
                let txn = Transaction::new_ref();

                // `take` both tests and clears the marker, so these checks run at most once
                // per migration chain.
                if self.user_version.take().is_some() {
                    check_schema::<S>(txn);
                    fix_indices::<S>(txn);
                }

                f(txn);

                let transaction = TXN.take().unwrap();

                transaction.into_owner()
            })
            .join()
        });
        match res {
            Ok(val) => self.transaction = val,
            Err(payload) => std::panic::resume_unwind(payload),
        }
        self
    }

    /// Applies the migration produced by `m`, advancing the chain to schema `M::To`.
    ///
    /// The step only runs when `user_version` is `None` or equals `S::VERSION`. Otherwise
    /// the step is skipped and `user_version` is carried over unchanged to the returned
    /// migrator.
    ///
    /// When it runs, the step proceeds as follows:
    /// 1. `m` is called with a fresh `TransactionMigrate` and returns the migration.
    /// 2. The migration fills a `SchemaBuilder` through `SchemaMigration::tables`.
    /// 3. With the transaction temporarily taken out of `TXN`, the following are executed:
    ///    - the queued DROP statements;
    ///    - renames of each temporary table to its final name;
    ///    - a foreign-key check;
    ///    - the extra index statements.
    ///
    /// # Panics
    /// Panics if any statement fails. It also panics if a foreign-key violation references
    /// a table for which no callback was registered in the builder.
    pub fn migrate<'x, M>(
        mut self,
        m: impl Send + FnOnce(&mut TransactionMigrate<S>) -> M,
    ) -> Migrator<M::To>
    where
        M: SchemaMigration<'x, From = S>,
    {
        if self.user_version.is_none_or(|x| x == S::VERSION) {
            self = self.with_transaction(|txn| {
                let mut txn = TransactionMigrate {
                    inner: txn.copy(),
                    scope: Default::default(),
                    rename_map: HashMap::new(),
                    extra_index: Vec::new(),
                };
                let m = m(&mut txn);

                let mut builder = SchemaBuilder {
                    drop: vec![],
                    foreign_key: HashMap::new(),
                    inner: txn,
                };
                m.tables(&mut builder);

                // Take the transaction out of the thread-local for the raw statements below;
                // it is reinstalled at the end of the closure.
                let transaction = TXN.take().unwrap();

                for drop in builder.drop {
                    let sql = drop.to_string(SqliteQueryBuilder);
                    transaction.get().execute(&sql, []).unwrap();
                }
                // `rename_map` maps final table name -> temporary table; move each into place.
                for (to, tmp) in builder.inner.rename_map {
                    let rename = sea_query::Table::rename().table(tmp, Alias::new(to)).take();
                    let sql = rename.to_string(SqliteQueryBuilder);
                    transaction.get().execute(&sql, []).unwrap();
                }
                // `foreign_key_check` yields the referenced table of the first violation; the
                // callback registered for that table is invoked.
                // NOTE(review): the `unreachable_code` allowance suggests these callbacks
                // diverge (e.g. panic with a report) — confirm.
                #[allow(
                    unreachable_code,
                    reason = "rustc is stupid and thinks this is unreachable"
                )]
                if let Some(fk) = foreign_key_check(transaction.get()) {
                    (builder.foreign_key.remove(&*fk).unwrap())();
                }
                for stmt in builder.inner.extra_index {
                    transaction.get().execute(&stmt, []).unwrap();
                }

                TXN.set(Some(transaction));
            });
        }

        Migrator {
            user_version: self.user_version,
            manager: self.manager,
            transaction: self.transaction,
            _p: PhantomData,
        }
    }

    /// Runs the fix-up closure `f` at schema `S`, but only if `user_version` is `None`.
    ///
    /// `user_version` is `None` when the database was freshly created or a migration in this
    /// chain has already been applied. For a database that has not reached `S` yet, this is
    /// a no-op.
    pub fn fixup(mut self, f: impl Send + FnOnce(&'static mut Transaction<S>)) -> Self {
        if self.user_version.is_none() {
            self = self.with_transaction(f);
        }
        self
    }

    /// Completes the migration chain and returns the ready-to-use database.
    ///
    /// Returns `None` if `user_version` is still `Some` and differs from `S::VERSION`, i.e.
    /// the chain never reached the database's stored version (e.g. the file was written by a
    /// newer schema).
    ///
    /// Otherwise the method:
    /// 1. verifies the schema (via `with_transaction` and `check_schema`);
    /// 2. runs `PRAGMA optimize`;
    /// 3. stores `S::VERSION` as `user_version`;
    /// 4. records the current `schema_version` for the returned `Database`;
    /// 5. commits the migration transaction.
    ///
    /// # Panics
    /// Panics if the schema does not match `S` or any of the statements fail.
    pub fn finish(mut self) -> Option<Database<S>> {
        if self.user_version.is_some_and(|x| x != S::VERSION) {
            return None;
        }

        self = self.with_transaction(|txn| {
            check_schema::<S>(txn);
        });

        self.transaction
            .get()
            .execute_batch("PRAGMA optimize;")
            .unwrap();

        set_user_version(self.transaction.get(), S::VERSION).unwrap();
        let schema_version = schema_version(self.transaction.get());
        self.transaction.with(|x| x.commit().unwrap());

        Some(Database {
            manager: Pool::new(self.manager),
            schema_version: AtomicI64::new(schema_version),
            schema: PhantomData,
            mut_lock: parking_lot::FairMutex::new(()),
        })
    }
}
412
413fn fix_indices<S: Schema>(txn: &Transaction<S>) {
414 let schema = read_schema(txn);
415 let expected_schema = crate::schema::from_macro::Schema::new::<S>();
416
417 fn check_eq(expected: &from_macro::Table, actual: &from_db::Table) -> bool {
418 let expected: BTreeSet<_> = expected.indices.iter().map(|idx| &idx.def).collect();
419 let actual: BTreeSet<_> = actual.indices.values().collect();
420 expected == actual
421 }
422
423 for (&table_name, expected_table) in &expected_schema.tables {
424 let table = &schema.tables[table_name];
425
426 if !check_eq(expected_table, table) {
427 let scope = Scope::default();
432 let tmp_name = scope.tmp_table();
433
434 txn.execute(&new_table_inner(expected_table, tmp_name));
435
436 let mut columns: Vec<_> = expected_table.columns.keys().map(Alias::new).collect();
437 columns.push(Alias::new("id"));
438
439 txn.execute(
440 &sea_query::InsertStatement::new()
441 .into_table(tmp_name)
442 .columns(columns.clone())
443 .select_from(
444 sea_query::SelectStatement::new()
445 .from(table_name)
446 .columns(columns)
447 .take(),
448 )
449 .unwrap()
450 .build(SqliteQueryBuilder)
451 .0,
452 );
453
454 txn.execute(
455 &sea_query::TableDropStatement::new()
456 .table(table_name)
457 .build(SqliteQueryBuilder),
458 );
459
460 txn.execute(
461 &sea_query::TableRenameStatement::new()
462 .table(tmp_name, table_name)
463 .build(SqliteQueryBuilder),
464 );
465 for sql in expected_table.delayed_indices(table_name) {
467 txn.execute(&sql);
468 }
469 }
470 }
471
472 let schema = read_schema(txn);
474 for (name, table) in schema.tables {
475 let expected_table = &expected_schema.tables[&*name];
476 assert!(check_eq(expected_table, &table));
477 }
478}
479
480impl<S> Transaction<S> {
481 #[track_caller]
482 pub(crate) fn execute(&self, sql: &str) {
483 TXN.with_borrow(|txn| txn.as_ref().unwrap().get().execute(sql, []))
484 .unwrap();
485 }
486}
487
488pub fn schema_version(conn: &rusqlite::Transaction) -> i64 {
489 conn.pragma_query_value(None, "schema_version", |r| r.get(0))
490 .unwrap()
491}
492
493pub fn user_version(conn: &rusqlite::Transaction) -> Result<i64, rusqlite::Error> {
495 conn.query_row("PRAGMA user_version", [], |row| row.get(0))
496}
497
498fn set_user_version(conn: &rusqlite::Transaction, v: i64) -> Result<(), rusqlite::Error> {
500 conn.pragma_update(None, "user_version", v)
501}
502
503pub(crate) fn check_schema<S: Schema>(txn: &Transaction<S>) {
504 let from_macro = crate::schema::from_macro::Schema::new::<S>();
505 let from_db = read_schema(txn);
506 let report = from_db.diff(from_macro, S::SOURCE, S::PATH, S::VERSION);
507 if !report.is_empty() {
508 let renderer = if cfg!(test) {
509 Renderer::plain().anonymized_line_numbers(true)
510 } else {
511 Renderer::styled()
512 }
513 .decor_style(DecorStyle::Unicode);
514 panic!("{}", renderer.render(&report));
515 }
516}
517
518fn foreign_key_check(conn: &rusqlite::Transaction) -> Option<String> {
519 let error = conn
520 .prepare("PRAGMA foreign_key_check")
521 .unwrap()
522 .query_map([], |row| row.get(2))
523 .unwrap()
524 .next();
525 error.transpose().unwrap()
526}
527
528impl<S> Transaction<S> {
529 #[cfg(test)]
530 pub(crate) fn schema(&self) -> Vec<String> {
531 TXN.with_borrow(|x| {
532 x.as_ref()
533 .unwrap()
534 .get()
535 .prepare("SELECT sql FROM 'main'.'sqlite_schema'")
536 .unwrap()
537 .query_map([], |row| row.get::<_, Option<String>>("sql"))
538 .unwrap()
539 .flat_map(|x| x.unwrap())
540 .collect()
541 })
542 }
543}