1use std::{
2 cell::RefCell, convert::Infallible, iter::zip, marker::PhantomData, sync::atomic::AtomicI64,
3};
4
5use rusqlite::ErrorCode;
6use sea_query::{
7 Alias, CommonTableExpression, DeleteStatement, Expr, ExprTrait, InsertStatement, IntoTableRef,
8 SelectStatement, SqliteQueryBuilder, UpdateStatement, WithClause,
9};
10use sea_query_rusqlite::RusqliteBinder;
11use self_cell::{MutBorrow, self_cell};
12
13use crate::{
14 IntoExpr, IntoSelect, Table, TableRow,
15 joinable::DynJoinable,
16 migrate::{Schema, check_schema, schema_version, user_version},
17 migration::Config,
18 mutable::Mutable,
19 pool::Pool,
20 private::{Joinable, Reader},
21 query::{OwnedRows, Query, track_stmt},
22 rows::Rows,
23 value::{DynTypedExpr, MyTyp, OptTable, SecretFromSql, ValueBuilder},
24 writable::TableInsert,
25};
26
/// A handle to a SQLite database whose schema is described by `S`.
///
/// Transactions take a connection from `manager`; write transactions additionally
/// hold `mut_lock` while the user closure runs.
pub struct Database<S> {
    // Connection pool; the transaction methods `pop` a connection and `push` it back.
    pub(crate) manager: Pool,
    // The `schema_version` value compared against each new transaction's connection in
    // `Transaction::new_checked`; updated there after a successful `check_schema`.
    pub(crate) schema_version: AtomicI64,
    pub(crate) schema: PhantomData<S>,
    // Locked for the duration of a write-transaction closure so writers take turns
    // (a `FairMutex` hands the lock over in request order).
    pub(crate) mut_lock: parking_lot::FairMutex<()>,
}
45
46impl<S: Schema> Database<S> {
47 pub fn new(config: Config) -> Self {
52 let Some(m) = Self::migrator(config) else {
53 panic!("schema version {}, but got an older version", S::VERSION)
54 };
55 let Some(m) = m.finish() else {
56 panic!("schema version {}, but got a new version", S::VERSION)
57 };
58 m
59 }
60}
61
62use rusqlite::Connection;
// Dependent half of `OwnedTransaction`: `Some` while the transaction is live; it is
// `take`n (leaving `None`) by `OwnedTransaction::with` to be committed or rolled back.
type RTransaction<'x> = Option<rusqlite::Transaction<'x>>;

// Self-referential pair of a connection and the transaction that borrows it, so both
// can be stored and moved as one value.
self_cell!(
    pub struct OwnedTransaction {
        owner: MutBorrow<Connection>,

        #[covariant]
        dependent: RTransaction,
    }
);
73
// SAFETY (reviewer note — confirm): `rusqlite::Transaction` holds a `&Connection`, and
// `Connection` is `Send` but not `Sync`, so `Send` is not derived automatically. The
// borrowed connection lives inside this same cell, so moving the cell moves the
// connection together with its only borrower. The assertion below guarantees the cell
// is never shared across threads by reference.
unsafe impl Send for OwnedTransaction {}
assert_not_impl_any! {OwnedTransaction: Sync}
81
thread_local! {
    // The transaction installed on this thread, together with the `OwnedRows` borrowed
    // from it. `Transaction::new_checked` sets it; the `Database` transaction methods
    // `take` it back after the user closure returns. Statement execution in this module
    // reads the live `rusqlite::Transaction` from here.
    pub(crate) static TXN: RefCell<Option<TransactionWithRows>> = const { RefCell::new(None) };
}
85
86impl OwnedTransaction {
87 pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
88 self.borrow_dependent().as_ref().unwrap()
89 }
90
91 pub(crate) fn with(
92 mut self,
93 f: impl FnOnce(rusqlite::Transaction<'_>),
94 ) -> rusqlite::Connection {
95 self.with_dependent_mut(|_, b| f(b.take().unwrap()));
96 self.into_owner().into_inner()
97 }
98}
99
// Slab of `OwnedRows` values borrowing the transaction (presumably open query results
// kept alive across calls — confirm in `crate::query`).
type OwnedRowsVec<'x> = slab::Slab<OwnedRows<'x>>;
// Self-referential pair of an `OwnedTransaction` and the rows borrowed from it; this is
// the value stored in the `TXN` thread-local.
self_cell!(
    pub struct TransactionWithRows {
        owner: OwnedTransaction,

        #[not_covariant]
        dependent: OwnedRowsVec,
    }
);
109
110impl TransactionWithRows {
111 pub(crate) fn new_empty(txn: OwnedTransaction) -> Self {
112 Self::new(txn, |_| slab::Slab::new())
113 }
114
115 pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
116 self.borrow_owner().get()
117 }
118}
119
120impl<S: Send + Sync + Schema> Database<S> {
121 #[doc = include_str!("database/transaction.md")]
122 pub fn transaction<R: Send>(&self, f: impl Send + FnOnce(&'static Transaction<S>) -> R) -> R {
123 let res = std::thread::scope(|scope| scope.spawn(|| self.transaction_local(f)).join());
124 match res {
125 Ok(val) => val,
126 Err(payload) => std::panic::resume_unwind(payload),
127 }
128 }
129
130 pub(crate) fn transaction_local<R>(&self, f: impl FnOnce(&'static Transaction<S>) -> R) -> R {
132 let conn = self.manager.pop();
133
134 let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
135 Some(conn.borrow_mut().transaction().unwrap())
136 });
137
138 let res = f(Transaction::new_checked(owned, &self.schema_version));
139
140 let owned = TXN.take().unwrap().into_owner();
141 self.manager.push(owned.into_owner().into_inner());
142
143 res
144 }
145
146 #[doc = include_str!("database/transaction_mut.md")]
147 pub fn transaction_mut<O: Send, E: Send>(
148 &self,
149 f: impl Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
150 ) -> Result<O, E> {
151 let join_res =
152 std::thread::scope(|scope| scope.spawn(|| self.transaction_mut_local(f)).join());
153
154 match join_res {
155 Ok(val) => val,
156 Err(payload) => std::panic::resume_unwind(payload),
157 }
158 }
159
160 pub(crate) fn transaction_mut_local<O, E>(
161 &self,
162 f: impl FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
163 ) -> Result<O, E> {
164 let guard = self.mut_lock.lock();
168
169 let conn = self.manager.pop();
170
171 let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
172 let txn = conn
173 .borrow_mut()
174 .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
175 .unwrap();
176 Some(txn)
177 });
178 let res = f(Transaction::new_checked(owned, &self.schema_version));
180
181 drop(guard);
184
185 let owned = TXN.take().unwrap().into_owner();
186
187 let conn = if res.is_ok() {
188 owned.with(|x| x.commit().unwrap())
189 } else {
190 owned.with(|x| x.rollback().unwrap())
191 };
192 self.manager.push(conn);
193
194 res
195 }
196
197 #[doc = include_str!("database/transaction_mut_ok.md")]
198 pub fn transaction_mut_ok<R: Send>(
199 &self,
200 f: impl Send + FnOnce(&'static mut Transaction<S>) -> R,
201 ) -> R {
202 self.transaction_mut(|txn| Ok::<R, Infallible>(f(txn)))
203 .unwrap()
204 }
205
206 pub fn rusqlite_connection(&self) -> rusqlite::Connection {
219 let conn = self.manager.pop();
220 conn.pragma_update(None, "foreign_keys", "ON").unwrap();
221 conn
222 }
223}
224
/// Zero-sized handle to the transaction open on the current thread, for schema `S`.
///
/// The actual `rusqlite` transaction lives in the `TXN` thread-local; this type only
/// carries the schema type and thread-affinity.
pub struct Transaction<S> {
    pub(crate) _p2: PhantomData<S>,
    // `*const ()` is neither `Send` nor `Sync`, which keeps the handle on the thread
    // whose `TXN` it refers to.
    pub(crate) _local: PhantomData<*const ()>,
}
234
235impl<S> Transaction<S> {
236 pub(crate) fn new() -> Self {
237 Self {
238 _p2: PhantomData,
239 _local: PhantomData,
240 }
241 }
242
243 pub(crate) fn new_ref() -> &'static mut Self {
244 Box::leak(Box::new(Self::new()))
246 }
247}
248
249impl<S: Schema> Transaction<S> {
250 pub(crate) fn new_checked(txn: OwnedTransaction, expected: &AtomicI64) -> &'static mut Self {
252 let schema_version = schema_version(txn.get());
253 if schema_version != expected.load(std::sync::atomic::Ordering::Relaxed) {
256 if user_version(txn.get()).unwrap() != S::VERSION {
257 panic!("The database user_version changed unexpectedly")
258 }
259
260 TXN.set(Some(TransactionWithRows::new_empty(txn)));
261 check_schema::<S>(Self::new_ref());
262 expected.store(schema_version, std::sync::atomic::Ordering::Relaxed);
263 } else {
264 TXN.set(Some(TransactionWithRows::new_empty(txn)));
265 }
266
267 const {
268 assert!(size_of::<Self>() == 0);
269 }
270 Self::new_ref()
271 }
272}
273
impl<S> Transaction<S> {
    /// Runs `f` with a new, empty `Query` on this transaction and returns its result.
    pub fn query<'t, R>(&'t self, f: impl FnOnce(&mut Query<'t, '_, S>) -> R) -> R {
        // Start from an empty AST; `f` extends it (e.g. with `join`).
        let q = Rows {
            phantom: PhantomData,
            ast: Default::default(),
            _p: PhantomData,
        };
        f(&mut Query {
            q,
            phantom: PhantomData,
        })
    }

    /// Evaluates `val` in a query and returns the first result.
    ///
    /// # Panics
    ///
    /// Panics if the query produces no rows.
    pub fn query_one<O: 'static>(&self, val: impl IntoSelect<'static, S, Out = O>) -> O {
        self.query(|e| e.into_iter(val.into_select()).next().unwrap())
    }

    /// Evaluates `val` with `query_one` and wraps the result with `T::out_to_lazy`.
    pub fn lazy<'t, T: OptTable>(&'t self, val: impl IntoExpr<'static, S, Typ = T>) -> T::Lazy<'t> {
        T::out_to_lazy(self.query_one(val.into_expr()))
    }

    /// Joins every row of `val` and returns an iterator that turns each row into its
    /// lazy form via `Transaction::lazy`.
    pub fn lazy_iter<'t, T: Table<Schema = S>>(
        &'t self,
        val: impl Joinable<'static, Typ = T>,
    ) -> LazyIter<'t, T> {
        let val = DynJoinable::new(val);
        self.query(|rows| {
            let table = rows.join(val);
            LazyIter {
                txn: self,
                iter: rows.into_iter(table),
            }
        })
    }

    /// Selects `val` with `T::select_opt_mutable` and converts the result with
    /// `T::into_mutable`.
    ///
    /// Takes `&mut self`, presumably so that the returned mutable view cannot overlap
    /// with other access to this transaction (confirm).
    pub fn mutable<'t, T: OptTable<Schema = S>>(
        &'t mut self,
        val: impl IntoExpr<'static, S, Typ = T>,
    ) -> T::Mutable<'t> {
        let x = self.query_one(T::select_opt_mutable(val.into_expr()));
        T::into_mutable(x)
    }

    /// Loads every row of `val` and converts each into a `Mutable` handle.
    pub fn mutable_vec<'t, T: Table<Schema = S>>(
        &'t mut self,
        val: impl Joinable<'static, Typ = T>,
    ) -> Vec<Mutable<'t, T>> {
        let val = DynJoinable::new(val);
        self.query(|rows| {
            let val = rows.join(val);
            // Collect eagerly so the query is finished before the handles are returned.
            rows.into_vec(T::select_mutable(val))
                .into_iter()
                .map(T::into_mutable)
                .collect()
        })
    }
}
375
/// Iterator returned by `Transaction::lazy_iter`: yields the lazy form of every row
/// produced by the underlying query.
pub struct LazyIter<'t, T: Table> {
    // Transaction used to build each lazy value.
    txn: &'t Transaction<T::Schema>,
    // Underlying query iterator over the matching row ids.
    iter: crate::query::Iter<'t, TableRow<T>>,
}
380
381impl<'t, T: Table> Iterator for LazyIter<'t, T> {
382 type Item = <T as MyTyp>::Lazy<'t>;
383
384 fn next(&mut self) -> Option<Self::Item> {
385 self.iter.next().map(|x| self.txn.lazy(x))
386 }
387}
388
389impl<S: 'static> Transaction<S> {
390 pub fn insert<T: Table<Schema = S>>(
412 &mut self,
413 val: impl TableInsert<T = T>,
414 ) -> Result<TableRow<T>, T::Conflict> {
415 try_insert_private(T::NAME.into_table_ref(), None, val.into_insert())
416 }
417
418 pub fn insert_ok<T: Table<Schema = S, Conflict = Infallible>>(
423 &mut self,
424 val: impl TableInsert<T = T>,
425 ) -> TableRow<T> {
426 let Ok(row) = self.insert(val);
427 row
428 }
429
430 pub fn find_or_insert<T: Table<Schema = S, Conflict = TableRow<T>>>(
450 &mut self,
451 val: impl TableInsert<T = T>,
452 ) -> TableRow<T> {
453 match self.insert(val) {
454 Ok(row) => row,
455 Err(row) => row,
456 }
457 }
458
459 pub fn update<T: Table<Schema = S>>(
482 &mut self,
483 row: impl IntoExpr<'static, S, Typ = T>,
484 val: T::Update,
485 ) -> Result<(), T::Conflict> {
486 let mut id = ValueBuilder::default();
487 let row = row.into_expr();
488 let (id, _) = id.simple_one(DynTypedExpr::erase(&row));
489
490 let val = T::apply_try_update(val, row);
491 let mut reader = Reader::default();
492 T::read(&val, &mut reader);
493 let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
494
495 let (select, col_fields) = ValueBuilder::default().simple(col_exprs);
496 let cte = CommonTableExpression::new()
497 .query(select)
498 .columns(col_fields.clone())
499 .table_name(Alias::new("cte"))
500 .to_owned();
501 let with_clause = WithClause::new().cte(cte).to_owned();
502
503 let mut update = UpdateStatement::new()
504 .table(("main", T::NAME))
505 .cond_where(Expr::col(("main", T::NAME, T::ID)).in_subquery(id))
506 .to_owned();
507
508 for (name, field) in zip(col_names, col_fields) {
509 let select = SelectStatement::new()
510 .from(Alias::new("cte"))
511 .column(field)
512 .to_owned();
513 let value = sea_query::Expr::SubQuery(
514 None,
515 Box::new(sea_query::SubQueryStatement::SelectStatement(select)),
516 );
517 update.value(Alias::new(name), value);
518 }
519
520 let (query, args) = update.with(with_clause).build_rusqlite(SqliteQueryBuilder);
521
522 let res = TXN.with_borrow(|txn| {
523 let txn = txn.as_ref().unwrap().get();
524
525 let mut stmt = txn.prepare_cached(&query).unwrap();
526 stmt.execute(&*args.as_params())
527 });
528
529 match res {
530 Ok(1) => Ok(()),
531 Ok(n) => panic!("unexpected number of updates: {n}"),
532 Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
533 if kind.code == ErrorCode::ConstraintViolation =>
534 {
535 Err(T::get_conflict_unchecked(self, &val))
537 }
538 Err(err) => panic!("{err:?}"),
539 }
540 }
541
542 #[deprecated = "Use Transaction::mutable instead"]
548 pub fn update_ok<T: Table<Schema = S>>(
549 &mut self,
550 row: impl IntoExpr<'static, S, Typ = T>,
551 val: T::UpdateOk,
552 ) {
553 match self.update(row, T::update_into_try_update(val)) {
554 Ok(val) => val,
555 Err(_) => {
556 unreachable!("update can not fail")
557 }
558 }
559 }
560
561 pub fn downgrade(&'static mut self) -> &'static mut TransactionWeak<S> {
563 Box::leak(Box::new(TransactionWeak { inner: PhantomData }))
565 }
566}
567
/// Handle obtained from `Transaction::downgrade`. It exposes row deletion and raw
/// `rusqlite` access to the transaction held in the `TXN` thread-local.
pub struct TransactionWeak<S> {
    // Zero-sized. It records the schema and, through `PhantomData`, inherits
    // `Transaction`'s non-`Send`/non-`Sync` auto traits.
    inner: PhantomData<Transaction<S>>,
}
577
impl<S: Schema> TransactionWeak<S> {
    /// Deletes the row `val` from table `T`.
    ///
    /// First, every column in schema `S` whose foreign key targets `T`'s id column is
    /// checked for a reference to `val`; the row is only deleted if none is found.
    ///
    /// Returns `Ok(true)` if a row was deleted and `Ok(false)` if no row matched.
    ///
    /// # Errors
    ///
    /// Returns `Err(T::get_referer_unchecked())` if some column still references `val`.
    /// Nothing is deleted in that case.
    ///
    /// # Panics
    ///
    /// Panics on any SQLite error, or if more than one row is deleted.
    pub fn delete<T: Table<Schema = S>>(&mut self, val: TableRow<T>) -> Result<bool, T::Referer> {
        let schema = crate::schema::from_macro::Schema::new::<S>();

        // One `SELECT <val> IN (SELECT <col> FROM <table>)` query per referencing column.
        let mut checks = vec![];
        for (&table_name, table) in &schema.tables {
            // Names of the columns in `table` with a foreign key to `T::NAME`.`T::ID`.
            for col in table.columns.iter().filter_map(|(col_name, col)| {
                let col = &col.def;
                col.fk
                    .as_ref()
                    .is_some_and(|(t, c)| t == T::NAME && c == T::ID)
                    .then_some(col_name)
            }) {
                // NOTE(review): `table_name` is unqualified here, while the DELETE below
                // targets `main`; a same-named temp table would be searched first by
                // SQLite — confirm this is intended.
                let stmt = SelectStatement::new()
                    .expr(
                        val.in_subquery(
                            SelectStatement::new()
                                .from(table_name)
                                .column(Alias::new(col))
                                .take(),
                        ),
                    )
                    .take();
                checks.push(stmt.build_rusqlite(SqliteQueryBuilder));
            }
        }

        let stmt = DeleteStatement::new()
            .from_table(("main", T::NAME))
            .cond_where(Expr::col(("main", T::NAME, T::ID)).eq(val.inner.idx))
            .take();

        let (query, args) = stmt.build_rusqlite(SqliteQueryBuilder);

        TXN.with_borrow(|txn| {
            let txn = txn.as_ref().unwrap().get();

            // Any referencing column aborts the delete; `return` exits this closure,
            // which is the value of the whole function.
            for (query, args) in checks {
                let mut stmt = txn.prepare_cached(&query).unwrap();
                match stmt.query_one(&*args.as_params(), |r| r.get(0)) {
                    Ok(true) => return Err(T::get_referer_unchecked()),
                    Ok(false) => {}
                    Err(err) => panic!("{err:?}"),
                }
            }

            let mut stmt = txn.prepare_cached(&query).unwrap();
            match stmt.execute(&*args.as_params()) {
                Ok(0) => Ok(false),
                Ok(1) => Ok(true),
                Ok(n) => {
                    panic!("unexpected number of deletes {n}")
                }
                Err(err) => panic!("{err:?}"),
            }
        })
    }

    /// Deletes `val` from a table whose rows cannot be referenced
    /// (`T::Referer = Infallible`). Returns whether a row was deleted.
    pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
        &mut self,
        val: TableRow<T>,
    ) -> bool {
        let Ok(res) = self.delete(val);
        res
    }

    /// Runs `f` with the raw `rusqlite::Transaction` stored in this thread's `TXN`.
    ///
    /// # Panics
    ///
    /// Panics if no transaction is installed in `TXN`.
    pub fn rusqlite_transaction<R>(&mut self, f: impl FnOnce(&rusqlite::Transaction) -> R) -> R {
        TXN.with_borrow(|txn| f(txn.as_ref().unwrap().get()))
    }
}
671
/// Inserts `val` into `table` and returns the id of the new row.
///
/// When `idx` is `Some`, it is written to the id column explicitly, and the returned id
/// is asserted to equal it. If there are no columns to write, `DEFAULT VALUES` is used;
/// otherwise the values come from a `SELECT` built by `ValueBuilder`. The id is read
/// back with `RETURNING`.
///
/// # Errors
///
/// Returns `Err` with `T`'s conflict value when SQLite reports a constraint violation.
///
/// # Panics
///
/// Panics on any other SQLite error, if no transaction is installed in `TXN`, or if the
/// returned id differs from `idx`.
pub fn try_insert_private<T: Table>(
    table: sea_query::TableRef,
    idx: Option<i64>,
    val: T::Insert,
) -> Result<TableRow<T>, T::Conflict> {
    let mut reader = Reader::default();
    T::read(&val, &mut reader);
    if let Some(idx) = idx {
        reader.col(T::ID, idx);
    }
    let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
    let is_empty = col_names.is_empty();

    let (select, _) = ValueBuilder::default().simple(col_exprs);

    let mut insert = InsertStatement::new();
    insert.into_table(table);
    insert.columns(col_names.into_iter().map(Alias::new));
    if is_empty {
        // `INSERT ... SELECT` needs at least one column; fall back to `DEFAULT VALUES`.
        insert.or_default_values();
    } else {
        insert.select_from(select).unwrap();
    }
    insert.returning_col(T::ID);

    let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);

    let res = TXN.with_borrow(|txn| {
        let txn = txn.as_ref().unwrap().get();
        track_stmt(txn, &sql, &values);

        let mut statement = txn.prepare_cached(&sql).unwrap();
        let mut res = statement
            .query_map(&*values.as_params(), |row| {
                Ok(TableRow::<T>::from_sql(row.get_ref(T::ID)?)?)
            })
            .unwrap();

        // The first `RETURNING` row (or the error from executing the insert).
        res.next().unwrap()
    });

    match res {
        Ok(id) => {
            if let Some(idx) = idx {
                assert_eq!(idx, id.inner.idx);
            }
            Ok(id)
        }
        Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
            if kind.code == ErrorCode::ConstraintViolation =>
        {
            Err(T::get_conflict_unchecked(&Transaction::new(), &val))
        }
        Err(err) => panic!("{err:?}"),
    }
}
729}