1use std::{
2 cell::RefCell, convert::Infallible, iter::zip, marker::PhantomData, sync::atomic::AtomicI64,
3};
4
5use rusqlite::ErrorCode;
6use sea_query::{
7 Alias, CommonTableExpression, DeleteStatement, Expr, ExprTrait, InsertStatement, IntoTableRef,
8 SelectStatement, SqliteQueryBuilder, UpdateStatement, WithClause,
9};
10use sea_query_rusqlite::RusqliteBinder;
11use self_cell::{MutBorrow, self_cell};
12
13use crate::{
14 IntoExpr, IntoSelect, Table, TableRow,
15 joinable::DynJoinable,
16 migrate::{Schema, check_schema, schema_version, user_version},
17 migration::Config,
18 mutable::Mutable,
19 pool::Pool,
20 private::{Joinable, Reader},
21 query::{OwnedRows, Query, track_stmt},
22 rows::Rows,
23 value::{DynTypedExpr, MyTyp, OptTable, SecretFromSql, ValueBuilder},
24 writable::TableInsert,
25};
26
/// Handle to a SQLite database whose schema is described by `S`.
pub struct Database<S> {
    /// Pool of connections; transactions `pop` a connection and `push` it back when done.
    pub(crate) manager: Pool,
    /// Last SQLite `schema_version` that passed `check_schema::<S>` in
    /// `Transaction::new_checked`; a differing value triggers re-validation.
    pub(crate) schema_version: AtomicI64,
    pub(crate) schema: PhantomData<S>,
    /// Held while a `transaction_mut` closure runs, serializing writers within this process.
    pub(crate) mut_lock: parking_lot::FairMutex<()>,
}
45
46impl<S: Schema> Database<S> {
47 pub fn new(config: Config) -> Self {
52 let Some(m) = Self::migrator(config) else {
53 panic!("schema version {}, but got an older version", S::VERSION)
54 };
55 let Some(m) = m.finish() else {
56 panic!("schema version {}, but got a new version", S::VERSION)
57 };
58 m
59 }
60}
61
62use rusqlite::Connection;
// Slot for the transaction borrowed from the connection. It becomes `None` once
// `OwnedTransaction::with` has moved the transaction out to commit or roll it back.
type RTransaction<'x> = Option<rusqlite::Transaction<'x>>;

// Self-referential pair of a pooled `Connection` (the owner) and a transaction that
// borrows it (the dependent). `self_cell` lets both be stored together, e.g. inside the
// thread-local `TXN` slot.
self_cell!(
    pub struct OwnedTransaction {
        owner: MutBorrow<Connection>,

        // Covariant, so `borrow_dependent` is available (used by `OwnedTransaction::get`).
        #[covariant]
        dependent: RTransaction,
    }
);

// SAFETY: NOTE(review): the justification is not visible in this chunk — confirm.
// Presumed rationale: the `rusqlite::Transaction` only borrows the `Connection` owned by
// this same cell, so moving the cell moves both halves together. The type is asserted
// `!Sync` below, so no shared reference can be used from two threads at once.
unsafe impl Send for OwnedTransaction {}
assert_not_impl_any! {OwnedTransaction: Sync}
81
thread_local! {
    // The transaction currently open on this thread, with its slab of `OwnedRows`.
    // `Transaction::new_checked` fills it, and the `Database::transaction*_local` methods
    // `take` it back after the user closure returns. `Transaction` handles are
    // zero-sized, so every operation reaches the real `rusqlite` transaction through here.
    pub(crate) static TXN: RefCell<Option<TransactionWithRows>> = const { RefCell::new(None) };
}
85
impl OwnedTransaction {
    /// Borrows the open `rusqlite` transaction.
    ///
    /// # Panics
    /// Panics if the transaction slot is `None`. Every `OwnedTransaction::new` call in this
    /// file fills it with `Some`, and only `with` (which consumes `self`) empties it.
    pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
        self.borrow_dependent().as_ref().unwrap()
    }

    /// Moves the transaction out of its slot and passes it by value to `f`. Callers in
    /// this file use `f` to `commit` or `rollback`. Afterwards the underlying connection
    /// is returned.
    ///
    /// # Panics
    /// Panics if the transaction slot is already `None`.
    pub(crate) fn with(
        mut self,
        f: impl FnOnce(rusqlite::Transaction<'_>),
    ) -> rusqlite::Connection {
        self.with_dependent_mut(|_, b| f(b.take().unwrap()));
        // Strip the `MutBorrow` wrapper to recover the plain connection.
        self.into_owner().into_inner()
    }
}
99
// Slab of `OwnedRows` values that borrow from the owning transaction.
// NOTE(review): presumably one entry per live query iterator — confirm against `crate::query`.
type OwnedRowsVec<'x> = slab::Slab<OwnedRows<'x>>;
// Pairs an `OwnedTransaction` with the `OwnedRows` borrowing from it; this is the value
// stored in the thread-local `TXN`. The dependent is `#[not_covariant]`, so self_cell
// only exposes it through the `with_dependent*` closures, not `borrow_dependent`.
self_cell!(
    pub struct TransactionWithRows {
        owner: OwnedTransaction,

        #[not_covariant]
        dependent: OwnedRowsVec,
    }
);
109
110impl TransactionWithRows {
111 pub(crate) fn new_empty(txn: OwnedTransaction) -> Self {
112 Self::new(txn, |_| slab::Slab::new())
113 }
114
115 pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
116 self.borrow_owner().get()
117 }
118}
119
120impl<S: Send + Sync + Schema> Database<S> {
121 #[doc = include_str!("database/transaction.md")]
122 pub fn transaction<R: Send>(&self, f: impl Send + FnOnce(&'static Transaction<S>) -> R) -> R {
123 let res = std::thread::scope(|scope| scope.spawn(|| self.transaction_local(f)).join());
124 match res {
125 Ok(val) => val,
126 Err(payload) => std::panic::resume_unwind(payload),
127 }
128 }
129
130 pub(crate) fn transaction_local<R>(&self, f: impl FnOnce(&'static Transaction<S>) -> R) -> R {
132 let conn = self.manager.pop();
133
134 let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
135 Some(conn.borrow_mut().transaction().unwrap())
136 });
137
138 let res = f(Transaction::new_checked(owned, &self.schema_version));
139
140 let owned = TXN.take().unwrap().into_owner();
141 self.manager.push(owned.into_owner().into_inner());
142
143 res
144 }
145
146 #[doc = include_str!("database/transaction_mut.md")]
147 pub fn transaction_mut<O: Send, E: Send>(
148 &self,
149 f: impl Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
150 ) -> Result<O, E> {
151 let join_res =
152 std::thread::scope(|scope| scope.spawn(|| self.transaction_mut_local(f)).join());
153
154 match join_res {
155 Ok(val) => val,
156 Err(payload) => std::panic::resume_unwind(payload),
157 }
158 }
159
160 pub(crate) fn transaction_mut_local<O, E>(
161 &self,
162 f: impl FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
163 ) -> Result<O, E> {
164 let guard = self.mut_lock.lock();
168
169 let conn = self.manager.pop();
170
171 let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
172 let txn = conn
173 .borrow_mut()
174 .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
175 .unwrap();
176 Some(txn)
177 });
178 let res = f(Transaction::new_checked(owned, &self.schema_version));
180
181 drop(guard);
184
185 let owned = TXN.take().unwrap().into_owner();
186
187 let conn = if res.is_ok() {
188 owned.with(|x| x.commit().unwrap())
189 } else {
190 owned.with(|x| x.rollback().unwrap())
191 };
192 self.manager.push(conn);
193
194 res
195 }
196
197 #[doc = include_str!("database/transaction_mut_ok.md")]
198 pub fn transaction_mut_ok<R: Send>(
199 &self,
200 f: impl Send + FnOnce(&'static mut Transaction<S>) -> R,
201 ) -> R {
202 self.transaction_mut(|txn| Ok::<R, Infallible>(f(txn)))
203 .unwrap()
204 }
205
206 pub fn rusqlite_connection(&self) -> rusqlite::Connection {
219 let conn = self.manager.pop();
220 conn.pragma_update(None, "foreign_keys", "ON").unwrap();
221 conn
222 }
223}
224
/// Handle to the transaction that is open on the current thread.
///
/// The handle is zero-sized. Its operations reach the actual `rusqlite` transaction
/// through the thread-local `TXN` slot.
pub struct Transaction<S> {
    pub(crate) _p2: PhantomData<S>,
    // `PhantomData<*const ()>` makes the handle `!Send` and `!Sync`. The transaction it
    // refers to lives in this thread's `TXN`, so the handle must not leave the thread.
    pub(crate) _local: PhantomData<*const ()>,
}
234
235impl<S> Transaction<S> {
236 pub(crate) fn new() -> Self {
237 Self {
238 _p2: PhantomData,
239 _local: PhantomData,
240 }
241 }
242
243 pub(crate) fn copy(&self) -> Self {
244 Self::new()
245 }
246
247 pub(crate) fn new_ref() -> &'static mut Self {
248 Box::leak(Box::new(Self::new()))
250 }
251}
252
253impl<S: Schema> Transaction<S> {
254 pub(crate) fn new_checked(txn: OwnedTransaction, expected: &AtomicI64) -> &'static mut Self {
256 let schema_version = schema_version(txn.get());
257 if schema_version != expected.load(std::sync::atomic::Ordering::Relaxed) {
260 if user_version(txn.get()).unwrap() != S::VERSION {
261 panic!("The database user_version changed unexpectedly")
262 }
263
264 TXN.set(Some(TransactionWithRows::new_empty(txn)));
265 check_schema::<S>(Self::new_ref());
266 expected.store(schema_version, std::sync::atomic::Ordering::Relaxed);
267 } else {
268 TXN.set(Some(TransactionWithRows::new_empty(txn)));
269 }
270
271 const {
272 assert!(size_of::<Self>() == 0);
273 }
274 Self::new_ref()
275 }
276}
277
impl<S> Transaction<S> {
    /// Runs `f` with a fresh `Query` builder (its `Rows` start from a default AST) and
    /// returns whatever `f` returns.
    pub fn query<'t, R>(&'t self, f: impl FnOnce(&mut Query<'t, '_, S>) -> R) -> R {
        let q = Rows {
            phantom: PhantomData,
            ast: Default::default(),
            _p: PhantomData,
        };
        f(&mut Query {
            q,
            phantom: PhantomData,
        })
    }

    /// Executes `val` as a query and returns its first output row.
    ///
    /// # Panics
    /// Panics if the query yields no rows.
    pub fn query_one<O: 'static>(&self, val: impl IntoSelect<'static, S, Out = O>) -> O {
        self.query(|e| e.into_iter(val.into_select()).next().unwrap())
    }

    /// Evaluates `val` with `query_one` and converts the result with `T::out_to_lazy`.
    pub fn lazy<'t, T: OptTable>(&'t self, val: impl IntoExpr<'static, S, Typ = T>) -> T::Lazy<'t> {
        T::out_to_lazy(self.query_one(val.into_expr()))
    }

    /// Joins `val` in a query and returns an iterator that converts each resulting row
    /// with `Transaction::lazy`.
    pub fn lazy_iter<'t, T: Table<Schema = S>>(
        &'t self,
        val: impl Joinable<'static, Typ = T>,
    ) -> LazyIter<'t, T> {
        let val = DynJoinable::new(val);
        self.query(|rows| {
            let table = rows.join(val);
            LazyIter {
                txn: self,
                iter: rows.into_iter(table),
            }
        })
    }

    /// Selects the mutable view of `val` (via `T::select_opt_mutable` and `query_one`) and
    /// converts it with `T::into_mutable`. Requires `&mut self`.
    pub fn mutable<'t, T: OptTable<Schema = S>>(
        &'t mut self,
        val: impl IntoExpr<'static, S, Typ = T>,
    ) -> T::Mutable<'t> {
        let x = self.query_one(T::select_opt_mutable(val.into_expr()));
        T::into_mutable(x)
    }

    /// Joins `val` and collects a `Mutable` for every resulting row. Requires `&mut self`.
    pub fn mutable_vec<'t, T: Table<Schema = S>>(
        &'t mut self,
        val: impl Joinable<'static, Typ = T>,
    ) -> Vec<Mutable<'t, T>> {
        let val = DynJoinable::new(val);
        self.query(|rows| {
            let val = rows.join(val);
            rows.into_vec((T::select_mutable(val.clone()), val))
                .into_iter()
                .map(T::into_mutable)
                .collect()
        })
    }
}
379
/// Iterator returned by `Transaction::lazy_iter`. It yields the lazy form of each
/// `TableRow` produced by the wrapped query iterator.
pub struct LazyIter<'t, T: Table> {
    txn: &'t Transaction<T::Schema>,
    iter: crate::query::Iter<'t, TableRow<T>>,
}
384
385impl<'t, T: Table> Iterator for LazyIter<'t, T> {
386 type Item = <T as MyTyp>::Lazy<'t>;
387
388 fn next(&mut self) -> Option<Self::Item> {
389 self.iter.next().map(|x| self.txn.lazy(x))
390 }
391}
392
impl<S: 'static> Transaction<S> {
    /// Inserts `val` into `T`'s table and returns the new row id. On a constraint
    /// violation it returns `T::Conflict`. Delegates to `try_insert_private` without an
    /// explicit id.
    pub fn insert<T: Table<Schema = S>>(
        &mut self,
        val: impl TableInsert<T = T>,
    ) -> Result<TableRow<T>, T::Conflict> {
        try_insert_private(T::NAME.into_table_ref(), None, val.into_insert())
    }

    /// `insert` for tables whose `Conflict` type is `Infallible`. The pattern below is
    /// irrefutable, so no unwrap is needed.
    pub fn insert_ok<T: Table<Schema = S, Conflict = Infallible>>(
        &mut self,
        val: impl TableInsert<T = T>,
    ) -> TableRow<T> {
        let Ok(row) = self.insert(val);
        row
    }

    /// `insert` for tables whose `Conflict` type is the conflicting `TableRow<T>`. It
    /// returns the new row, or the existing row the insert conflicted with.
    pub fn find_or_insert<T: Table<Schema = S, Conflict = TableRow<T>>>(
        &mut self,
        val: impl TableInsert<T = T>,
    ) -> TableRow<T> {
        match self.insert(val) {
            Ok(row) => row,
            Err(row) => row,
        }
    }

    /// Updates the single row `row` with `val`.
    ///
    /// The generated SQL has the shape
    /// `WITH cte(..) AS (SELECT <new values>) UPDATE main.<T> SET col = (SELECT field FROM cte), ..
    /// WHERE <T::ID> IN (<row id subquery>)`.
    ///
    /// # Errors
    /// Returns `T::get_conflict_unchecked(..)` when SQLite reports a constraint violation
    /// that carries a message.
    ///
    /// # Panics
    /// Panics if the statement changes a number of rows other than 1, or on any other
    /// SQLite error.
    #[deprecated = "Use `Mutable::unique` instead"]
    pub fn update<T: Table<Schema = S>>(
        &mut self,
        row: impl IntoExpr<'static, S, Typ = T>,
        val: T::Update,
    ) -> Result<(), T::Conflict> {
        let mut id = ValueBuilder::default();
        let row = row.into_expr();
        let (id, _) = id.simple_one(DynTypedExpr::erase(&row));

        let val = T::apply_try_update(val, row);
        let mut reader = Reader::default();
        T::read(&val, &mut reader);
        let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();

        // All new column values are computed once in a CTE named "cte".
        let (select, col_fields) = ValueBuilder::default().simple(col_exprs);
        let cte = CommonTableExpression::new()
            .query(select)
            .columns(col_fields.clone())
            .table_name(Alias::new("cte"))
            .to_owned();
        let with_clause = WithClause::new().cte(cte).to_owned();

        let mut update = UpdateStatement::new()
            .table(("main", T::NAME))
            .cond_where(Expr::col(("main", T::NAME, T::ID)).in_subquery(id))
            .to_owned();

        // Each column is set from its CTE field with a scalar subquery.
        for (name, field) in zip(col_names, col_fields) {
            let select = SelectStatement::new()
                .from(Alias::new("cte"))
                .column(field)
                .to_owned();
            let value = sea_query::Expr::SubQuery(
                None,
                Box::new(sea_query::SubQueryStatement::SelectStatement(select)),
            );
            update.value(Alias::new(name), value);
        }

        let (query, args) = update.with(with_clause).build_rusqlite(SqliteQueryBuilder);

        let res = TXN.with_borrow(|txn| {
            let txn = txn.as_ref().unwrap().get();

            let mut stmt = txn.prepare_cached(&query).unwrap();
            stmt.execute(&*args.as_params())
        });

        // The `TXN` borrow has ended here, so `get_conflict_unchecked` may use it again.
        match res {
            Ok(1) => Ok(()),
            Ok(n) => panic!("unexpected number of updates: {n}"),
            Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
                if kind.code == ErrorCode::ConstraintViolation =>
            {
                Err(T::get_conflict_unchecked(self, &val))
            }
            Err(err) => panic!("{err:?}"),
        }
    }

    /// `update` for updates that cannot conflict (`T::UpdateOk`).
    ///
    /// # Panics
    /// Hits `unreachable!` if the underlying `update` returns `Err`.
    #[deprecated = "Use Transaction::mutable instead"]
    pub fn update_ok<T: Table<Schema = S>>(
        &mut self,
        row: impl IntoExpr<'static, S, Typ = T>,
        val: T::UpdateOk,
    ) {
        #[expect(deprecated)]
        match self.update(row, T::update_into_try_update(val)) {
            Ok(val) => val,
            Err(_) => {
                unreachable!("update can not fail")
            }
        }
    }

    /// Consumes the `'static` mutable handle and returns a `TransactionWeak` for the same
    /// thread-local transaction. The leaked box holds only `PhantomData`.
    pub fn downgrade(&'static mut self) -> &'static mut TransactionWeak<S> {
        Box::leak(Box::new(TransactionWeak { inner: PhantomData }))
    }
}
573
/// Handle obtained from `Transaction::downgrade`. In this file its methods are `delete`,
/// `delete_ok` and `rusqlite_transaction`. Like `Transaction`, it holds only
/// `PhantomData`, so it inherits `Transaction`'s `!Send`/`!Sync`.
pub struct TransactionWeak<S> {
    inner: PhantomData<Transaction<S>>,
}
583
impl<S: Schema> TransactionWeak<S> {
    /// Deletes the row `val` from `T`'s table.
    ///
    /// First, for every column in schema `S` whose foreign key points at
    /// `T::NAME`.`T::ID`, it checks whether `val` appears in that column. If any check
    /// is true, it returns `Err(T::get_referer_unchecked())` without deleting.
    /// Otherwise it returns `Ok(true)` if a row was deleted and `Ok(false)` if none matched.
    ///
    /// # Panics
    /// Panics on SQLite errors or if more than one row is deleted.
    pub fn delete<T: Table<Schema = S>>(&mut self, val: TableRow<T>) -> Result<bool, T::Referer> {
        let schema = crate::schema::from_macro::Schema::new::<S>();

        // One `SELECT val IN (SELECT col FROM table)` query per referencing column.
        let mut checks = vec![];
        for (&table_name, table) in &schema.tables {
            for col in table.columns.iter().filter_map(|(col_name, col)| {
                let col = &col.def;
                col.fk
                    .as_ref()
                    .is_some_and(|(t, c)| t == T::NAME && c == T::ID)
                    .then_some(col_name)
            }) {
                let stmt = SelectStatement::new()
                    .expr(
                        val.in_subquery(
                            SelectStatement::new()
                                .from(table_name)
                                .column(Alias::new(col))
                                .take(),
                        ),
                    )
                    .take();
                checks.push(stmt.build_rusqlite(SqliteQueryBuilder));
            }
        }

        let stmt = DeleteStatement::new()
            .from_table(("main", T::NAME))
            .cond_where(Expr::col(("main", T::NAME, T::ID)).eq(val.inner.idx))
            .take();

        let (query, args) = stmt.build_rusqlite(SqliteQueryBuilder);

        TXN.with_borrow(|txn| {
            let txn = txn.as_ref().unwrap().get();

            // Any referencing row aborts the delete before it runs.
            for (query, args) in checks {
                let mut stmt = txn.prepare_cached(&query).unwrap();
                match stmt.query_one(&*args.as_params(), |r| r.get(0)) {
                    Ok(true) => return Err(T::get_referer_unchecked()),
                    Ok(false) => {}
                    Err(err) => panic!("{err:?}"),
                }
            }

            let mut stmt = txn.prepare_cached(&query).unwrap();
            match stmt.execute(&*args.as_params()) {
                Ok(0) => Ok(false),
                Ok(1) => Ok(true),
                Ok(n) => {
                    panic!("unexpected number of deletes {n}")
                }
                Err(err) => panic!("{err:?}"),
            }
        })
    }

    /// `delete` for tables whose `Referer` type is `Infallible`. The pattern below is
    /// irrefutable.
    pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
        &mut self,
        val: TableRow<T>,
    ) -> bool {
        let Ok(res) = self.delete(val);
        res
    }

    /// Calls `f` with the raw `rusqlite` transaction currently installed in `TXN`.
    ///
    /// # Panics
    /// Panics if no transaction is installed on this thread.
    pub fn rusqlite_transaction<R>(&mut self, f: impl FnOnce(&rusqlite::Transaction) -> R) -> R {
        TXN.with_borrow(|txn| f(txn.as_ref().unwrap().get()))
    }
}
677
678pub fn try_insert_private<T: Table>(
679 table: sea_query::TableRef,
680 idx: Option<i64>,
681 val: T::Insert,
682) -> Result<TableRow<T>, T::Conflict> {
683 let mut reader = Reader::default();
684 T::read(&val, &mut reader);
685 if let Some(idx) = idx {
686 reader.col(T::ID, idx);
687 }
688 let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
689 let is_empty = col_names.is_empty();
690
691 let (select, _) = ValueBuilder::default().simple(col_exprs);
692
693 let mut insert = InsertStatement::new();
694 insert.into_table(table);
695 insert.columns(col_names.into_iter().map(Alias::new));
696 if is_empty {
697 insert.or_default_values();
699 } else {
700 insert.select_from(select).unwrap();
701 }
702 insert.returning_col(T::ID);
703
704 let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
705
706 let res = TXN.with_borrow(|txn| {
707 let txn = txn.as_ref().unwrap().get();
708 track_stmt(txn, &sql, &values);
709
710 let mut statement = txn.prepare_cached(&sql).unwrap();
711 let mut res = statement
712 .query_map(&*values.as_params(), |row| {
713 Ok(TableRow::<T>::from_sql(row.get_ref(T::ID)?)?)
714 })
715 .unwrap();
716
717 res.next().unwrap()
718 });
719
720 match res {
721 Ok(id) => {
722 if let Some(idx) = idx {
723 assert_eq!(idx, id.inner.idx);
724 }
725 Ok(id)
726 }
727 Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
728 if kind.code == ErrorCode::ConstraintViolation =>
729 {
730 Err(T::get_conflict_unchecked(&Transaction::new(), &val))
732 }
733 Err(err) => panic!("{err:?}"),
734 }
735}