1use std::{
2 cell::RefCell, convert::Infallible, iter::zip, marker::PhantomData, sync::atomic::AtomicI64,
3};
4
5use rusqlite::ErrorCode;
6use sea_query::{
7 Alias, CommonTableExpression, DeleteStatement, Expr, ExprTrait, InsertStatement, IntoTableRef,
8 SelectStatement, SqliteQueryBuilder, UpdateStatement, WithClause,
9};
10use sea_query_rusqlite::RusqliteBinder;
11use self_cell::{MutBorrow, self_cell};
12
13use crate::{
14 IntoExpr, IntoSelect, Table, TableRow,
15 migrate::{Schema, check_schema, schema_version, user_version},
16 migration::Config,
17 mutable::Mutable,
18 pool::Pool,
19 private::{IntoJoinable, Reader},
20 query::{OwnedRows, Query, track_stmt},
21 rows::Rows,
22 value::{DynTypedExpr, MyTyp, OptTable, SecretFromSql, ValueBuilder},
23 writable::TableInsert,
24};
25
/// Handle to an SQLite database used with the schema type `S`.
///
/// Each transaction takes a connection from `manager` and gives it back afterwards
/// (except for [`Database::rusqlite_connection`], which hands the connection to the caller).
pub struct Database<S> {
    /// Pool of connections; transactions `pop` a connection and `push` it back when done.
    pub(crate) manager: Pool,
    /// Last schema version that was verified against `S`. Every transaction compares it with
    /// the connection's current `schema_version` and re-runs the schema check on mismatch.
    pub(crate) schema_version: AtomicI64,
    /// Ties this handle to the schema type `S` without storing a value of it.
    pub(crate) schema: PhantomData<S>,
    /// Locked by `transaction_mut_local` while the user closure runs, so at most one
    /// mutable-transaction closure on this handle executes at a time.
    pub(crate) mut_lock: parking_lot::FairMutex<()>,
}
44
45impl<S: Schema> Database<S> {
46 pub fn new(config: Config) -> Self {
51 let Some(m) = Self::migrator(config) else {
52 panic!("schema version {}, but got an older version", S::VERSION)
53 };
54 let Some(m) = m.finish() else {
55 panic!("schema version {}, but got a new version", S::VERSION)
56 };
57 m
58 }
59}
60
61use rusqlite::Connection;
// Slot holding the open transaction inside `OwnedTransaction`. It is `Some` from
// construction until `OwnedTransaction::with` takes the transaction out.
type RTransaction<'x> = Option<rusqlite::Transaction<'x>>;

self_cell!(
    /// A [`Connection`] bundled together with a transaction that borrows it, so the pair can
    /// be stored and moved around as one value.
    pub struct OwnedTransaction {
        owner: MutBorrow<Connection>,

        #[covariant]
        dependent: RTransaction,
    }
);

// SAFETY (review note — confirm): the dependent `rusqlite::Transaction` only borrows the
// `Connection` owned by the same cell, so the borrower and the borrowed value always move
// between threads together. The type is also asserted to be `!Sync` right below, so it can
// never be accessed from two threads at once.
unsafe impl Send for OwnedTransaction {}
assert_not_impl_any! {OwnedTransaction: Sync}
80
thread_local! {
    /// The transaction (plus the `OwnedRows` borrowing from it) that is active on this thread.
    ///
    /// `Transaction::new_checked` installs it. Query and mutation helpers read it through
    /// `TXN.with_borrow`, and the `Database::transaction*_local` functions take it back out
    /// when the user closure has finished.
    pub(crate) static TXN: RefCell<Option<TransactionWithRows>> = const { RefCell::new(None) };
}
84
85impl OwnedTransaction {
86 pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
87 self.borrow_dependent().as_ref().unwrap()
88 }
89
90 pub(crate) fn with(
91 mut self,
92 f: impl FnOnce(rusqlite::Transaction<'_>),
93 ) -> rusqlite::Connection {
94 self.with_dependent_mut(|_, b| f(b.take().unwrap()));
95 self.into_owner().into_inner()
96 }
97}
98
// Slab of `OwnedRows` values that borrow from the transaction owned by `TransactionWithRows`.
type OwnedRowsVec<'x> = slab::Slab<OwnedRows<'x>>;
self_cell!(
    /// An [`OwnedTransaction`] (the owner) together with a slab of `OwnedRows` (the
    /// dependent) that borrow from it.
    pub struct TransactionWithRows {
        owner: OwnedTransaction,

        #[not_covariant]
        dependent: OwnedRowsVec,
    }
);
108
109impl TransactionWithRows {
110 pub(crate) fn new_empty(txn: OwnedTransaction) -> Self {
111 Self::new(txn, |_| slab::Slab::new())
112 }
113
114 pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
115 self.borrow_owner().get()
116 }
117}
118
119impl<S: Send + Sync + Schema> Database<S> {
120 #[doc = include_str!("database/transaction.md")]
121 pub fn transaction<R: Send>(&self, f: impl Send + FnOnce(&'static Transaction<S>) -> R) -> R {
122 let res = std::thread::scope(|scope| scope.spawn(|| self.transaction_local(f)).join());
123 match res {
124 Ok(val) => val,
125 Err(payload) => std::panic::resume_unwind(payload),
126 }
127 }
128
129 pub(crate) fn transaction_local<R>(&self, f: impl FnOnce(&'static Transaction<S>) -> R) -> R {
131 let conn = self.manager.pop();
132
133 let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
134 Some(conn.borrow_mut().transaction().unwrap())
135 });
136
137 let res = f(Transaction::new_checked(owned, &self.schema_version));
138
139 let owned = TXN.take().unwrap().into_owner();
140 self.manager.push(owned.into_owner().into_inner());
141
142 res
143 }
144
145 #[doc = include_str!("database/transaction_mut.md")]
146 pub fn transaction_mut<O: Send, E: Send>(
147 &self,
148 f: impl Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
149 ) -> Result<O, E> {
150 let join_res =
151 std::thread::scope(|scope| scope.spawn(|| self.transaction_mut_local(f)).join());
152
153 match join_res {
154 Ok(val) => val,
155 Err(payload) => std::panic::resume_unwind(payload),
156 }
157 }
158
159 pub(crate) fn transaction_mut_local<O, E>(
160 &self,
161 f: impl FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
162 ) -> Result<O, E> {
163 let guard = self.mut_lock.lock();
167
168 let conn = self.manager.pop();
169
170 let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
171 let txn = conn
172 .borrow_mut()
173 .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
174 .unwrap();
175 Some(txn)
176 });
177 let res = f(Transaction::new_checked(owned, &self.schema_version));
179
180 drop(guard);
183
184 let owned = TXN.take().unwrap().into_owner();
185
186 let conn = if res.is_ok() {
187 owned.with(|x| x.commit().unwrap())
188 } else {
189 owned.with(|x| x.rollback().unwrap())
190 };
191 self.manager.push(conn);
192
193 res
194 }
195
196 #[doc = include_str!("database/transaction_mut_ok.md")]
197 pub fn transaction_mut_ok<R: Send>(
198 &self,
199 f: impl Send + FnOnce(&'static mut Transaction<S>) -> R,
200 ) -> R {
201 self.transaction_mut(|txn| Ok::<R, Infallible>(f(txn)))
202 .unwrap()
203 }
204
205 pub fn rusqlite_connection(&self) -> rusqlite::Connection {
218 let conn = self.manager.pop();
219 conn.pragma_update(None, "foreign_keys", "ON").unwrap();
220 conn
221 }
222}
223
/// Handle to the transaction that is active on the current thread.
///
/// The handle itself is zero-sized (it only holds `PhantomData`). The actual `rusqlite`
/// transaction is stored in the thread-local `TXN`.
pub struct Transaction<S> {
    pub(crate) _p2: PhantomData<S>,
    // A raw-pointer marker makes the handle neither `Send` nor `Sync`, keeping it on the
    // thread whose `TXN` holds the real transaction.
    pub(crate) _local: PhantomData<*const ()>,
}
233
234impl<S> Transaction<S> {
235 pub(crate) fn new() -> Self {
236 Self {
237 _p2: PhantomData,
238 _local: PhantomData,
239 }
240 }
241
242 pub(crate) fn copy(&self) -> Self {
243 Self::new()
244 }
245
246 pub(crate) fn new_ref() -> &'static mut Self {
247 Box::leak(Box::new(Self::new()))
249 }
250}
251
252impl<S: Schema> Transaction<S> {
253 pub(crate) fn new_checked(txn: OwnedTransaction, expected: &AtomicI64) -> &'static mut Self {
255 let schema_version = schema_version(txn.get());
256 if schema_version != expected.load(std::sync::atomic::Ordering::Relaxed) {
259 if user_version(txn.get()).unwrap() != S::VERSION {
260 panic!("The database user_version changed unexpectedly")
261 }
262
263 TXN.set(Some(TransactionWithRows::new_empty(txn)));
264 check_schema::<S>(Self::new_ref());
265 expected.store(schema_version, std::sync::atomic::Ordering::Relaxed);
266 } else {
267 TXN.set(Some(TransactionWithRows::new_empty(txn)));
268 }
269
270 const {
271 assert!(size_of::<Self>() == 0);
272 }
273 Self::new_ref()
274 }
275}
276
277impl<S> Transaction<S> {
278 pub fn query<'t, R>(&'t self, f: impl FnOnce(&mut Query<'t, '_, S>) -> R) -> R {
291 let q = Rows {
296 phantom: PhantomData,
297 ast: Default::default(),
298 _p: PhantomData,
299 };
300 f(&mut Query {
301 q,
302 phantom: PhantomData,
303 })
304 }
305
306 pub fn query_one<O: 'static>(&self, val: impl IntoSelect<'static, S, Out = O>) -> O {
319 self.query(|e| e.into_iter(val.into_select()).next().unwrap())
320 }
321
322 pub fn lazy<'t, T: OptTable>(&'t self, val: impl IntoExpr<'static, S, Typ = T>) -> T::Lazy<'t> {
330 T::out_to_lazy(self.query_one(val.into_expr()))
331 }
332
333 pub fn lazy_iter<'t, T: Table<Schema = S>>(
337 &'t self,
338 val: impl IntoJoinable<'static, S, Typ = T>,
339 ) -> LazyIter<'t, T> {
340 let val = val.into_joinable();
341 self.query(|rows| {
342 let table = rows.join(val);
343 LazyIter {
344 txn: self,
345 iter: rows.into_iter(table),
346 }
347 })
348 }
349
350 pub fn mutable<'t, T: OptTable<Schema = S>>(
354 &'t mut self,
355 val: impl IntoExpr<'static, S, Typ = T>,
356 ) -> T::Mutable<'t> {
357 let x = self.query_one(T::select_opt_mutable(val.into_expr()));
358 T::into_mutable(x)
359 }
360
361 pub fn mutable_vec<'t, T: Table<Schema = S>>(
365 &'t mut self,
366 val: impl IntoJoinable<'static, S, Typ = T>,
367 ) -> Vec<Mutable<'t, T>> {
368 let val = val.into_joinable();
369 self.query(|rows| {
370 let val = rows.join(val);
371 rows.into_vec((T::select_mutable(val.clone()), val))
372 .into_iter()
373 .map(T::into_mutable)
374 .collect()
375 })
376 }
377}
378
/// Iterator returned by `Transaction::lazy_iter`.
///
/// Each row id produced by `iter` is converted with `Transaction::lazy` on `txn`.
pub struct LazyIter<'t, T: Table> {
    txn: &'t Transaction<T::Schema>,
    iter: crate::query::Iter<'t, TableRow<T>>,
}
383
384impl<'t, T: Table> Iterator for LazyIter<'t, T> {
385 type Item = <T as MyTyp>::Lazy<'t>;
386
387 fn next(&mut self) -> Option<Self::Item> {
388 self.iter.next().map(|x| self.txn.lazy(x))
389 }
390}
391
392impl<S: 'static> Transaction<S> {
393 pub fn insert<T: Table<Schema = S>>(
415 &mut self,
416 val: impl TableInsert<T = T>,
417 ) -> Result<TableRow<T>, T::Conflict> {
418 try_insert_private(T::NAME.into_table_ref(), None, val.into_insert())
419 }
420
421 pub fn insert_ok<T: Table<Schema = S, Conflict = Infallible>>(
426 &mut self,
427 val: impl TableInsert<T = T>,
428 ) -> TableRow<T> {
429 let Ok(row) = self.insert(val);
430 row
431 }
432
433 pub fn find_or_insert<T: Table<Schema = S, Conflict = TableRow<T>>>(
453 &mut self,
454 val: impl TableInsert<T = T>,
455 ) -> TableRow<T> {
456 match self.insert(val) {
457 Ok(row) => row,
458 Err(row) => row,
459 }
460 }
461
462 #[deprecated = "Use `Mutable::unique` instead"]
485 pub fn update<T: Table<Schema = S>>(
486 &mut self,
487 row: impl IntoExpr<'static, S, Typ = T>,
488 val: T::Update,
489 ) -> Result<(), T::Conflict> {
490 let mut id = ValueBuilder::default();
491 let row = row.into_expr();
492 let (id, _) = id.simple_one(DynTypedExpr::erase(&row));
493
494 let val = T::apply_try_update(val, row);
495 let mut reader = Reader::default();
496 T::read(&val, &mut reader);
497 let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
498
499 let (select, col_fields) = ValueBuilder::default().simple(col_exprs);
500 let cte = CommonTableExpression::new()
501 .query(select)
502 .columns(col_fields.clone())
503 .table_name(Alias::new("cte"))
504 .to_owned();
505 let with_clause = WithClause::new().cte(cte).to_owned();
506
507 let mut update = UpdateStatement::new()
508 .table(("main", T::NAME))
509 .cond_where(Expr::col(("main", T::NAME, T::ID)).in_subquery(id))
510 .to_owned();
511
512 for (name, field) in zip(col_names, col_fields) {
513 let select = SelectStatement::new()
514 .from(Alias::new("cte"))
515 .column(field)
516 .to_owned();
517 let value = sea_query::Expr::SubQuery(
518 None,
519 Box::new(sea_query::SubQueryStatement::SelectStatement(select)),
520 );
521 update.value(Alias::new(name), value);
522 }
523
524 let (query, args) = update.with(with_clause).build_rusqlite(SqliteQueryBuilder);
525
526 let res = TXN.with_borrow(|txn| {
527 let txn = txn.as_ref().unwrap().get();
528
529 let mut stmt = txn.prepare_cached(&query).unwrap();
530 stmt.execute(&*args.as_params())
531 });
532
533 match res {
534 Ok(1) => Ok(()),
535 Ok(n) => panic!("unexpected number of updates: {n}"),
536 Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
537 if kind.code == ErrorCode::ConstraintViolation =>
538 {
539 Err(T::get_conflict_unchecked(self, &val))
541 }
542 Err(err) => panic!("{err:?}"),
543 }
544 }
545
546 #[deprecated = "Use Transaction::mutable instead"]
552 pub fn update_ok<T: Table<Schema = S>>(
553 &mut self,
554 row: impl IntoExpr<'static, S, Typ = T>,
555 val: T::UpdateOk,
556 ) {
557 #[expect(deprecated)]
558 match self.update(row, T::update_into_try_update(val)) {
559 Ok(val) => val,
560 Err(_) => {
561 unreachable!("update can not fail")
562 }
563 }
564 }
565
566 pub fn downgrade(&'static mut self) -> &'static mut TransactionWeak<S> {
568 Box::leak(Box::new(TransactionWeak { inner: PhantomData }))
570 }
571}
572
/// Handle obtained by consuming a `&'static mut Transaction<S>` via `Transaction::downgrade`.
///
/// It exposes row deletion and raw `rusqlite` access. It is zero-sized and, through its
/// `PhantomData<Transaction<S>>`, inherits `Transaction`'s auto traits.
pub struct TransactionWeak<S> {
    inner: PhantomData<Transaction<S>>,
}
582
583impl<S: Schema> TransactionWeak<S> {
584 pub fn delete<T: Table<Schema = S>>(&mut self, val: TableRow<T>) -> Result<bool, T::Referer> {
591 let schema = crate::schema::from_macro::Schema::new::<S>();
592
593 let mut checks = vec![];
597 for (&table_name, table) in &schema.tables {
598 for col in table.columns.iter().filter_map(|(col_name, col)| {
599 let col = &col.def;
600 col.fk
601 .as_ref()
602 .is_some_and(|(t, c)| t == T::NAME && c == T::ID)
603 .then_some(col_name)
604 }) {
605 let stmt = SelectStatement::new()
606 .expr(
607 val.in_subquery(
608 SelectStatement::new()
609 .from(table_name)
610 .column(Alias::new(col))
611 .take(),
612 ),
613 )
614 .take();
615 checks.push(stmt.build_rusqlite(SqliteQueryBuilder));
616 }
617 }
618
619 let stmt = DeleteStatement::new()
620 .from_table(("main", T::NAME))
621 .cond_where(Expr::col(("main", T::NAME, T::ID)).eq(val.inner.idx))
622 .take();
623
624 let (query, args) = stmt.build_rusqlite(SqliteQueryBuilder);
625
626 TXN.with_borrow(|txn| {
627 let txn = txn.as_ref().unwrap().get();
628
629 for (query, args) in checks {
630 let mut stmt = txn.prepare_cached(&query).unwrap();
631 match stmt.query_one(&*args.as_params(), |r| r.get(0)) {
632 Ok(true) => return Err(T::get_referer_unchecked()),
633 Ok(false) => {}
634 Err(err) => panic!("{err:?}"),
635 }
636 }
637
638 let mut stmt = txn.prepare_cached(&query).unwrap();
639 match stmt.execute(&*args.as_params()) {
640 Ok(0) => Ok(false),
641 Ok(1) => Ok(true),
642 Ok(n) => {
643 panic!("unexpected number of deletes {n}")
644 }
645 Err(err) => panic!("{err:?}"),
646 }
647 })
648 }
649
650 pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
656 &mut self,
657 val: TableRow<T>,
658 ) -> bool {
659 let Ok(res) = self.delete(val);
660 res
661 }
662
663 pub fn rusqlite_transaction<R>(&mut self, f: impl FnOnce(&rusqlite::Transaction) -> R) -> R {
673 TXN.with_borrow(|txn| f(txn.as_ref().unwrap().get()))
674 }
675}
676
677pub fn try_insert_private<T: Table>(
678 table: sea_query::TableRef,
679 idx: Option<i64>,
680 val: T::Insert,
681) -> Result<TableRow<T>, T::Conflict> {
682 let mut reader = Reader::default();
683 T::read(&val, &mut reader);
684 if let Some(idx) = idx {
685 reader.col(T::ID, idx);
686 }
687 let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
688 let is_empty = col_names.is_empty();
689
690 let (select, _) = ValueBuilder::default().simple(col_exprs);
691
692 let mut insert = InsertStatement::new();
693 insert.into_table(table);
694 insert.columns(col_names.into_iter().map(Alias::new));
695 if is_empty {
696 insert.or_default_values();
698 } else {
699 insert.select_from(select).unwrap();
700 }
701 insert.returning_col(T::ID);
702
703 let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
704
705 let res = TXN.with_borrow(|txn| {
706 let txn = txn.as_ref().unwrap().get();
707 track_stmt(txn, &sql, &values);
708
709 let mut statement = txn.prepare_cached(&sql).unwrap();
710 let mut res = statement
711 .query_map(&*values.as_params(), |row| {
712 Ok(TableRow::<T>::from_sql(row.get_ref(T::ID)?)?)
713 })
714 .unwrap();
715
716 res.next().unwrap()
717 });
718
719 match res {
720 Ok(id) => {
721 if let Some(idx) = idx {
722 assert_eq!(idx, id.inner.idx);
723 }
724 Ok(id)
725 }
726 Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
727 if kind.code == ErrorCode::ConstraintViolation =>
728 {
729 Err(T::get_conflict_unchecked(&Transaction::new(), &val))
731 }
732 Err(err) => panic!("{err:?}"),
733 }
734}