rust_query/
transaction.rs

1use std::{convert::Infallible, iter::zip, marker::PhantomData, ops::Deref, rc::Rc};
2
3use rusqlite::ErrorCode;
4use sea_query::{
5    Alias, CommonTableExpression, DeleteStatement, Expr, InsertStatement, IntoTableRef,
6    SelectStatement, SimpleExpr, SqliteQueryBuilder, UpdateStatement, WithClause,
7};
8use sea_query_rusqlite::RusqliteBinder;
9
10use crate::{
11    IntoExpr, IntoSelect, Table, TableRow,
12    client::LocalClient,
13    migrate::schema_version,
14    private::Reader,
15    query::Query,
16    rows::Rows,
17    value::{SecretFromSql, ValueBuilder},
18    writable::TableInsert,
19};
20
21/// [Database] is a proof that the database has been configured.
22///
23/// Creating a [Database] requires going through the steps to migrate an existing database to
24/// the required schema, or creating a new database from scratch (See also [crate::migration::Config]).
25/// Having done the setup to create a compatible database is sadly not a guarantee that the
26/// database will stay compatible for the lifetime of the [Database] struct.
27///
28/// That is why [Database] also stores the `schema_version`. This allows detecting non-malicious
29/// modifications to the schema and gives us the ability to panic when this is detected.
30/// Such non-malicious modification of the schema can happen for example if another [Database]
31/// instance is created with additional migrations (e.g. by another newer instance of your program).
32///
33/// For information on how to create transactions, please refer to [LocalClient].
34pub struct Database<S> {
35    pub(crate) manager: r2d2_sqlite::SqliteConnectionManager,
36    pub(crate) schema_version: i64,
37    pub(crate) schema: PhantomData<S>,
38}
39
40impl<S> Database<S> {
41    /// Create a new [rusqlite::Connection] to the database.
42    ///
43    /// You can do (almost) anything you want with this connection as it is almost completely isolated from all other
44    /// [rust_query] connections. The only thing you should not do here is changing the schema.
45    /// Schema changes are detected with the `schema_version` pragma and will result in a panic when creating a new
46    /// transaction.
47    pub fn rusqlite_connection(&self) -> rusqlite::Connection {
48        use r2d2::ManageConnection;
49        self.manager.connect().unwrap()
50    }
51}
52
53/// [Transaction] can be used to query the database.
54///
55/// From the perspective of a [Transaction] each [TransactionMut] is fully applied or not at all.
56/// Futhermore, the effects of [TransactionMut]s have a global order.
57/// So if we have mutations `A` and then `B`, it is impossible for a [Transaction] to see the effect of `B` without seeing the effect of `A`.
58///
59/// All [TableRow] references retrieved from the database live for at most `'a`.
60/// This makes these references effectively local to this [Transaction].
61pub struct Transaction<'t, S> {
62    pub(crate) transaction: Rc<rusqlite::Transaction<'t>>,
63    pub(crate) _p: PhantomData<fn(&'t ()) -> &'t ()>,
64    pub(crate) _p2: PhantomData<S>,
65    pub(crate) _local: PhantomData<LocalClient>,
66}
67
68impl<'t, S> Transaction<'t, S> {
69    pub(crate) fn new(raw: Rc<rusqlite::Transaction<'t>>) -> Self {
70        Self {
71            transaction: raw,
72            _p: PhantomData,
73            _p2: PhantomData,
74            _local: PhantomData,
75        }
76    }
77}
78
79/// Same as [Transaction], but allows inserting new rows.
80///
81/// [TransactionMut] always uses the latest version of the database, with the effects of all previous [TransactionMut]s applied.
82///
83/// To make mutations to the database permanent you need to use [TransactionMut::commit].
84/// This is to make sure that if a function panics while holding a mutable transaction, it will roll back those changes.
85pub struct TransactionMut<'t, S> {
86    pub(crate) inner: Transaction<'t, S>,
87}
88
89impl<'t, S> Deref for TransactionMut<'t, S> {
90    type Target = Transaction<'t, S>;
91
92    fn deref(&self) -> &Self::Target {
93        &self.inner
94    }
95}
96
97impl<'t, S> Transaction<'t, S> {
98    /// This will check the schema version and panic if it is not as expected
99    pub(crate) fn new_checked(txn: rusqlite::Transaction<'t>, expected: i64) -> Self {
100        if schema_version(&txn) != expected {
101            panic!("The database schema was updated unexpectedly")
102        }
103
104        Self::new(Rc::new(txn))
105    }
106
107    /// Execute a query with multiple results.
108    ///
109    /// ```
110    /// # use rust_query::{private::doctest::*};
111    /// # let mut client = get_client();
112    /// # let txn = get_txn(&mut client);
113    /// let user_names = txn.query(|rows| {
114    ///     let user = rows.join(User);
115    ///     rows.into_vec(user.name())
116    /// });
117    /// assert_eq!(user_names, vec!["Alice".to_owned()]);
118    /// ```
119    pub fn query<F, R>(&self, f: F) -> R
120    where
121        F: for<'inner> FnOnce(&mut Query<'t, 'inner, S>) -> R,
122    {
123        // Execution already happens in a [Transaction].
124        // and thus any [TransactionMut] that it might be borrowed
125        // from is borrowed immutably, which means the rows can not change.
126        let conn: &rusqlite::Connection = &self.transaction;
127        let q = Rows {
128            phantom: PhantomData,
129            ast: Default::default(),
130            _p: PhantomData,
131        };
132        f(&mut Query {
133            q,
134            phantom: PhantomData,
135            conn,
136        })
137    }
138
139    /// Retrieve a single result from the database.
140    ///
141    /// ```
142    /// # use rust_query::{private::doctest::*, IntoExpr};
143    /// # let mut client = rust_query::private::doctest::get_client();
144    /// # let txn = rust_query::private::doctest::get_txn(&mut client);
145    /// let res = txn.query_one("test".into_expr());
146    /// assert_eq!(res, "test");
147    /// ```
148    ///
149    /// Instead of using [Self::query_one] in a loop, it is better to
150    /// call [Self::query] and return all results at once.
151    pub fn query_one<'e, O>(&self, val: impl IntoSelect<'t, 't, S, Out = O>) -> O {
152        // Theoretically this doesn't even need to be in a transaction.
153        // We already have one though, so we must use it.
154        let mut res = self.query(|e| {
155            // Cast the static lifetime to any lifetime necessary, this is fine because we know the static lifetime
156            // can not be guaranteed by a query scope.
157            e.into_vec_private(val)
158        });
159        res.pop().unwrap()
160    }
161}
162
163impl<'t, S: 'static> TransactionMut<'t, S> {
164    /// Try inserting a value into the database.
165    ///
166    /// Returns [Ok] with a reference to the new inserted value or an [Err] with conflict information.
167    /// The type of conflict information depends on the number of unique constraints on the table:
168    /// - 0 unique constraints => [Infallible]
169    /// - 1 unique constraint => [Expr] reference to the conflicting table row.
170    /// - 2+ unique constraints => `()` no further information is provided.
171    ///
172    /// ```
173    /// # use rust_query::{private::doctest::*, IntoExpr};
174    /// # let mut client = rust_query::private::doctest::get_client();
175    /// # let mut txn = rust_query::private::doctest::get_txn(&mut client);
176    /// let res = txn.insert(User {
177    ///     name: "Bob",
178    /// });
179    /// assert!(res.is_ok());
180    /// let res = txn.insert(User {
181    ///     name: "Bob",
182    /// });
183    /// assert!(res.is_err(), "there is a unique constraint on the name");
184    /// ```
185    pub fn insert<T: Table<Schema = S>>(
186        &mut self,
187        val: impl TableInsert<'t, T = T>,
188    ) -> Result<TableRow<'t, T>, T::Conflict<'t>> {
189        try_insert_private(
190            &self.transaction,
191            Alias::new(T::NAME).into_table_ref(),
192            None,
193            val.into_insert(),
194        )
195    }
196
197    /// This is a convenience function to make using [TransactionMut::insert]
198    /// easier for tables without unique constraints.
199    ///
200    /// The new row is added to the table and the row reference is returned.
201    pub fn insert_ok<T: Table<Schema = S, Conflict<'t> = Infallible>>(
202        &mut self,
203        val: impl TableInsert<'t, T = T>,
204    ) -> TableRow<'t, T> {
205        let Ok(row) = self.insert(val);
206        row
207    }
208
209    /// This is a convenience function to make using [TransactionMut::insert]
210    /// easier for tables with exactly one unique constraints.
211    ///
212    /// The new row is inserted and the reference to the row is returned OR
213    /// an existing row is found which conflicts with the new row and a reference
214    /// to the conflicting row is returned.
215    ///
216    /// ```
217    /// # use rust_query::{private::doctest::*, IntoExpr};
218    /// # let mut client = rust_query::private::doctest::get_client();
219    /// # let mut txn = rust_query::private::doctest::get_txn(&mut client);
220    /// let bob = txn.insert(User {
221    ///     name: "Bob",
222    /// }).unwrap();
223    /// let bob2 = txn.find_or_insert(User {
224    ///     name: "Bob", // this will conflict with the existing row.
225    /// });
226    /// assert_eq!(bob, bob2);
227    /// ```
228    pub fn find_or_insert<T: Table<Schema = S, Conflict<'t> = TableRow<'t, T>>>(
229        &mut self,
230        val: impl TableInsert<'t, T = T>,
231    ) -> TableRow<'t, T> {
232        match self.insert(val) {
233            Ok(row) => row,
234            Err(row) => row,
235        }
236    }
237
238    /// Try updating a row in the database to have new column values.
239    ///
240    /// Updating can fail just like [TransactionMut::insert] because of unique constraint conflicts.
241    /// This happens when the new values are in conflict with an existing different row.
242    ///
243    /// When the update succeeds, this function returns [Ok<()>], when it fails it returns [Err] with one of
244    /// three conflict types:
245    /// - 0 unique constraints => [Infallible]
246    /// - 1 unique constraint => [Expr] reference to the conflicting table row.
247    /// - 2+ unique constraints => `()` no further information is provided.
248    ///
249    /// ```
250    /// # use rust_query::{private::doctest::*, IntoExpr, Update};
251    /// # let mut client = rust_query::private::doctest::get_client();
252    /// # let mut txn = rust_query::private::doctest::get_txn(&mut client);
253    /// let bob = txn.insert(User {
254    ///     name: "Bob",
255    /// }).unwrap();
256    /// txn.update(bob, User {
257    ///     name: Update::set("New Bob"),
258    /// }).unwrap();
259    /// ```
260    pub fn update<T: Table<Schema = S>>(
261        &mut self,
262        row: impl IntoExpr<'t, S, Typ = T>,
263        val: T::Update<'t>,
264    ) -> Result<(), T::Conflict<'t>> {
265        let mut id = ValueBuilder::default();
266        let row = row.into_expr();
267        let (id, _) = id.simple_one(row.inner.clone().erase());
268
269        let val = T::apply_try_update(val, row);
270        let mut reader = Reader::default();
271        T::read(&val, &mut reader);
272        let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
273
274        let (select, col_fields) = ValueBuilder::default().simple(col_exprs);
275        let cte = CommonTableExpression::new()
276            .query(select)
277            .columns(col_fields.clone())
278            .table_name(Alias::new("cte"))
279            .to_owned();
280        let with_clause = WithClause::new().cte(cte).to_owned();
281
282        let mut update = UpdateStatement::new()
283            .table((Alias::new("main"), Alias::new(T::NAME)))
284            .cond_where(
285                Expr::col((Alias::new("main"), Alias::new(T::NAME), Alias::new(T::ID)))
286                    .in_subquery(id),
287            )
288            .to_owned();
289
290        for (name, field) in zip(col_names, col_fields) {
291            let select = SelectStatement::new()
292                .from(Alias::new("cte"))
293                .column(field)
294                .to_owned();
295            let value = SimpleExpr::SubQuery(
296                None,
297                Box::new(sea_query::SubQueryStatement::SelectStatement(select)),
298            );
299            update.value(Alias::new(name), value);
300        }
301
302        let (query, args) = update.with(with_clause).build_rusqlite(SqliteQueryBuilder);
303
304        let mut stmt = self.transaction.prepare_cached(&query).unwrap();
305        match stmt.execute(&*args.as_params()) {
306            Ok(1) => Ok(()),
307            Ok(n) => panic!("unexpected number of updates: {n}"),
308            Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
309                if kind.code == ErrorCode::ConstraintViolation =>
310            {
311                // val looks like "UNIQUE constraint failed: playlist_track.playlist, playlist_track.track"
312                Err(T::get_conflict_unchecked(self, &val))
313            }
314            Err(err) => panic!("{:?}", err),
315        }
316    }
317
318    /// This is a convenience function to use [TransactionMut::update] for updates
319    /// that can not cause unique constraint violations.
320    ///
321    /// This method can be used for all tables, it just does not allow modifying
322    /// columns that are part of unique constraints.
323    pub fn update_ok<T: Table<Schema = S>>(
324        &mut self,
325        row: impl IntoExpr<'t, S, Typ = T>,
326        val: T::UpdateOk<'t>,
327    ) {
328        match self.update(row, T::update_into_try_update(val)) {
329            Ok(val) => val,
330            Err(_) => {
331                unreachable!("update can not fail")
332            }
333        }
334    }
335
336    /// Make the changes made in this [TransactionMut] permanent.
337    ///
338    /// If the [TransactionMut] is dropped without calling this function, then the changes are rolled back.
339    pub fn commit(self) {
340        Rc::into_inner(self.inner.transaction)
341            .unwrap()
342            .commit()
343            .unwrap();
344    }
345
346    /// Convert the [TransactionMut] into a [TransactionWeak] to allow deletions.
347    pub fn downgrade(self) -> TransactionWeak<'t, S> {
348        TransactionWeak { inner: self }
349    }
350}
351
352/// This is the weak version of [TransactionMut].
353///
354/// The reason that it is called `weak` is because [TransactionWeak] can not guarantee
355/// that [TableRow]s prove the existence of their particular row.
356///
357/// [TransactionWeak] is useful because it allowes deleting rows.
358pub struct TransactionWeak<'t, S> {
359    inner: TransactionMut<'t, S>,
360}
361
362impl<'t, S: 'static> TransactionWeak<'t, S> {
363    /// Try to delete a row from the database.
364    ///
365    /// This will return an [Err] if there is a row that references the row that is being deleted.
366    /// When this method returns [Ok] it will contain a [bool] that is either
367    /// - `true` if the row was just deleted.
368    /// - `false` if the row was deleted previously in this transaction.
369    pub fn delete<T: Table<Schema = S>>(
370        &mut self,
371        val: TableRow<'t, T>,
372    ) -> Result<bool, T::Referer> {
373        let stmt = DeleteStatement::new()
374            .from_table((Alias::new("main"), Alias::new(T::NAME)))
375            .cond_where(
376                Expr::col((Alias::new("main"), Alias::new(T::NAME), Alias::new(T::ID)))
377                    .eq(val.inner.idx),
378            )
379            .to_owned();
380
381        let (query, args) = stmt.build_rusqlite(SqliteQueryBuilder);
382        let mut stmt = self.inner.transaction.prepare_cached(&query).unwrap();
383
384        match stmt.execute(&*args.as_params()) {
385            Ok(0) => Ok(false),
386            Ok(1) => Ok(true),
387            Ok(n) => {
388                panic!("unexpected number of deletes {n}")
389            }
390            Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
391                if kind.code == ErrorCode::ConstraintViolation =>
392            {
393                // Some foreign key constraint got violated
394                Err(T::get_referer_unchecked())
395            }
396            Err(err) => panic!("{:?}", err),
397        }
398    }
399
400    /// Delete a row from the database.
401    ///
402    /// This is the infallible version of [TransactionWeak::delete].
403    ///
404    /// To be able to use this method you have to mark the table as `#[no_reference]` in the schema.
405    pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
406        &mut self,
407        val: TableRow<'t, T>,
408    ) -> bool {
409        let Ok(res) = self.delete(val);
410        res
411    }
412
413    /// This allows you to do (almost) anything you want with the internal [rusqlite::Transaction].
414    ///
415    /// Note that there are some things that you should not do with the transaction, such as:
416    /// - Changes to the schema, these will result in a panic as described in [Database].
417    /// - Changes to the connection configuration such as disabling foreign key checks.
418    ///
419    /// **When this method is used to break [rust_query] invariants, all other [rust_query] function calls
420    /// may result in a panic.**
421    pub fn rusqlite_transaction(&mut self) -> &rusqlite::Transaction {
422        &self.inner.transaction
423    }
424
425    /// Make the changes made in this [TransactionWeak] permanent.
426    ///
427    /// If the [TransactionWeak] is dropped without calling this function, then the changes are rolled back.
428    pub fn commit(self) {
429        self.inner.commit();
430    }
431}
432
433pub fn try_insert_private<'t, T: Table>(
434    transaction: &Rc<rusqlite::Transaction<'t>>,
435    table: sea_query::TableRef,
436    idx: Option<i64>,
437    val: T::Insert<'t>,
438) -> Result<TableRow<'t, T>, T::Conflict<'t>> {
439    let mut reader = Reader::default();
440    T::read(&val, &mut reader);
441    if let Some(idx) = idx {
442        reader.col(T::ID, idx);
443    }
444    let (col_names, col_exprs): (Vec<_>, Vec<_>) = reader.builder.into_iter().collect();
445    let is_empty = col_names.is_empty();
446
447    let (select, _) = ValueBuilder::default().simple(col_exprs);
448
449    let mut insert = InsertStatement::new();
450    insert.into_table(table);
451    insert.columns(col_names.into_iter().map(|name| Alias::new(name)));
452    if is_empty {
453        // select always has at least one column, so we leave it out when there are no columns
454        insert.or_default_values();
455    } else {
456        insert.select_from(select).unwrap();
457    }
458    insert.returning_col(Alias::new(T::ID));
459
460    let (sql, values) = insert.build_rusqlite(SqliteQueryBuilder);
461
462    let mut statement = transaction.prepare_cached(&sql).unwrap();
463    let mut res = statement
464        .query_map(&*values.as_params(), |row| {
465            Ok(TableRow::<'_, T>::from_sql(row.get_ref(T::ID)?)?)
466        })
467        .unwrap();
468
469    match res.next().unwrap() {
470        Ok(id) => Ok(id),
471        Err(rusqlite::Error::SqliteFailure(kind, Some(_val)))
472            if kind.code == ErrorCode::ConstraintViolation =>
473        {
474            // val looks like "UNIQUE constraint failed: playlist_track.playlist, playlist_track.track"
475            Err(T::get_conflict_unchecked(
476                &Transaction::new(transaction.clone()),
477                &val,
478            ))
479        }
480        Err(err) => panic!("{:?}", err),
481    }
482}