Skip to main content

rorm_db/
database.rs

1//! [`Database`] struct and several common operations
2
3use std::sync::Arc;
4
5use rorm_declaration::config::DatabaseDriver;
6use rorm_sql::conditional;
7use rorm_sql::delete::Delete;
8use rorm_sql::insert::Insert;
9use rorm_sql::join_table::JoinTableData;
10use rorm_sql::ordering::OrderByEntry;
11#[cfg(feature = "postgres-only")]
12use rorm_sql::select::LockingClause;
13use rorm_sql::select_column::SelectColumnData;
14use rorm_sql::update::Update;
15use rorm_sql::value::Value;
16use tracing::warn;
17
18use crate::error::Error;
19use crate::executor::AffectedRows;
20use crate::executor::All;
21use crate::executor::Executor;
22use crate::executor::Nothing;
23use crate::executor::One;
24use crate::executor::QueryStrategy;
25use crate::internal::any::AnyPool;
26use crate::query_type::GetLimitClause;
27use crate::row::Row;
28use crate::transaction::Transaction;
29use crate::transaction::TransactionError;
30
31/**
32Type alias for [`SelectColumnData`]..
33
34As all databases use currently the same fields, a type alias is sufficient.
35*/
36pub type ColumnSelector<'a> = SelectColumnData<'a>;
37
38/**
39Type alias for [`JoinTableData`].
40
41As all databases use currently the same fields, a type alias is sufficient.
42*/
43pub type JoinTable<'until_build, 'post_build> = JoinTableData<'until_build, 'post_build>;
44
45/// Configuration use in [`Database::connect`].
46#[derive(Debug)]
47#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
48pub struct DatabaseConfiguration {
49    /// The driver and its corresponding settings
50    pub driver: DatabaseDriver,
51
52    /// Minimal connections to initialize upfront.
53    ///
54    /// Must be greater than `0` and can't be larger than `max_connections`.
55    pub min_connections: u32,
56
57    /// Maximum connections that allowed to be created.
58    ///
59    /// Must be greater than `0`.
60    pub max_connections: u32,
61}
62
63impl DatabaseConfiguration {
64    /**
65    Create a new database configuration with some defaults set.
66
67    **Defaults**:
68    - `min_connections`: 1
69    - `max_connections`: 10
70
71    **Parameter**:
72    - `driver`: [`DatabaseDriver`]: Configuration of the database driver.
73    */
74    pub fn new(driver: DatabaseDriver) -> Self {
75        DatabaseConfiguration {
76            driver,
77            min_connections: 1,
78            max_connections: 10,
79        }
80    }
81}
82
83/// Handle to a pool of database connections
84///
85/// Executing sql statements is done through the [`Executor`] trait
86/// which is implemented on `&Database`.
87///
88/// Common operations are implemented as functions in the [`database`](self) module.
89///
90/// Cloning is cheap i.e. two `Arc`s.
91#[derive(Clone)]
92pub struct Database(pub(crate) AnyPool, Arc<()>);
93
94impl Database {
95    /// Connects to the database using `configuration`
96    pub async fn connect(configuration: DatabaseConfiguration) -> Result<Self, Error> {
97        Ok(Self(AnyPool::connect(configuration).await?, Arc::new(())))
98    }
99
100    /// Starts a new transaction
101    ///
102    /// `&mut Transaction` implements [`Executor`] like `&Database` does
103    /// but its database operations can be reverted using [`Transaction::rollback`]
104    /// or simply dropping the transaction without calling [`Transaction::commit`].
105    pub async fn start_transaction(&self) -> Result<Transaction, Error> {
106        Ok(Transaction::new(self.0.begin().await?))
107    }
108
109    /// Closes the database connection
110    ///
111    /// While calling this method is not strictly necessary,
112    /// terminating your program without it
113    /// might result in some final queries not being flushed properly.
114    ///
115    /// This method consumes the database handle,
116    /// but actually all handles created using `clone` will become invalid after this call.
117    /// This means any further operation would result in an `Err`
118    pub async fn close(self) {
119        self.0.close().await;
120    }
121}
122
123impl Drop for Database {
124    /// Checks whether [`Database::close`] has been called before the last instance is dropped
125    fn drop(&mut self) {
126        // The use of strong_count should be correct:
127        // - the arc is private and we don't create WeakRefs
128        // => when observing a strong_count of 1, there can't be any remaining refs
129        if Arc::strong_count(&self.1) == 1 && !self.0.is_closed() {
130            warn!("Database has been dropped without calling close. This might cause the last queries to not be flushed properly");
131        }
132    }
133}
134
135/// Executes a simple `SELECT` query.
136///
137/// It is generic over a [`QueryStrategy`] which specifies how and how many rows to query.
138///
139/// **Parameter**:
140/// - `model`: Model to query.
141/// - `columns`: Columns to retrieve values from.
142/// - `joins`: Join tables expressions.
143/// - `conditions`: Optional conditions to apply.
144/// - `order_by_clause`: Columns to order the rows by.
145/// - `limit`: Optional limit / offset to apply to the query.
146///   Depending on the query strategy, this is either [`LimitClause`](rorm_sql::limit_clause::LimitClause)
147///   (for [`All`] and [`Stream`](crate::executor::Stream))
148///   or a simple [`u64`] (for [`One`] and [`Optional`](crate::executor::Optional)).
149#[allow(clippy::too_many_arguments)] // TODO: refactor this API, clippy is right
150pub fn query<'result, 'db: 'result, 'post_query: 'result, Q: QueryStrategy + GetLimitClause>(
151    executor: impl Executor<'db>,
152    model: &str,
153    columns: &[ColumnSelector<'_>],
154    joins: &[JoinTable<'_, 'post_query>],
155    conditions: Option<&conditional::Condition<'post_query>>,
156    order_by_clause: &[OrderByEntry<'_>],
157    limit: Option<Q::LimitOrOffset>,
158    distinct: bool,
159    #[cfg(feature = "postgres-only")] locking_clause: Option<LockingClause>,
160) -> Q::Result<'result> {
161    let columns: Vec<_> = columns
162        .iter()
163        .map(|c| {
164            executor.dialect().select_column(
165                c.table_name,
166                c.column_name,
167                c.select_alias,
168                c.aggregation,
169            )
170        })
171        .collect();
172    let joins: Vec<_> = joins
173        .iter()
174        .map(|j| {
175            executor.dialect().join_table(
176                j.join_type,
177                j.table_name,
178                j.join_alias,
179                j.join_condition.clone(),
180            )
181        })
182        .collect();
183    let mut q = executor
184        .dialect()
185        .select(&columns, model, &joins, order_by_clause);
186
187    if let Some(condition) = conditions {
188        q = q.where_clause(condition);
189    }
190
191    if let Some(limit) = Q::get_limit_clause(limit) {
192        q = q.limit_clause(limit);
193    }
194
195    if distinct {
196        q = q.distinct();
197    }
198
199    #[cfg(feature = "postgres-only")]
200    if let Some(x) = locking_clause {
201        q = q.locking_clause(x);
202    }
203
204    let (query_string, bind_params) = q.build();
205
206    executor.execute::<Q>(query_string, bind_params)
207}
208
209/// Inserts a single row and returns columns from it.
210///
211/// **Parameter**:
212/// - `model`: Table to insert to
213/// - `columns`: Columns to set `values` for.
214/// - `values`: Values to bind to the corresponding columns.
215/// - `returning`: Columns to query from the inserted row.
216pub async fn insert_returning(
217    executor: impl Executor<'_>,
218    model: &str,
219    columns: &[&str],
220    values: &[Value<'_>],
221    returning: &[&str],
222) -> Result<Row, Error> {
223    generic_insert::<One>(executor, model, columns, values, Some(returning)).await
224}
225
226/// Inserts a single row.
227///
228/// **Parameter**:
229/// - `model`: Table to insert to
230/// - `columns`: Columns to set `values` for.
231/// - `values`: Values to bind to the corresponding columns.
232pub async fn insert(
233    executor: impl Executor<'_>,
234    model: &str,
235    columns: &[&str],
236    values: &[Value<'_>],
237) -> Result<(), Error> {
238    generic_insert::<Nothing>(executor, model, columns, values, None).await
239}
240
241/// Generic implementation of:
242/// - [`Database::insert`]
243/// - [`Database::insert_returning`]
244pub(crate) fn generic_insert<'result, 'db: 'result, 'post_query: 'result, Q: QueryStrategy>(
245    executor: impl Executor<'db>,
246    model: &str,
247    columns: &[&str],
248    values: &[Value<'post_query>],
249    returning: Option<&[&str]>,
250) -> Q::Result<'result> {
251    let values = &[values];
252    let q = executor.dialect().insert(model, columns, values, returning);
253
254    let (query_string, bind_params): (_, Vec<Value<'post_query>>) = q.build();
255
256    executor.execute::<Q>(query_string, bind_params)
257}
258
259/// This method is used to bulk insert rows.
260///
261/// If one insert statement fails, the complete operation will be rolled back.
262///
263/// **Parameter**:
264/// - `model`: Table to insert to
265/// - `columns`: Columns to set `rows` for.
266/// - `rows`: List of values to bind to the corresponding columns.
267/// - `transaction`: Optional transaction to execute the query on.
268pub async fn insert_bulk(
269    executor: impl Executor<'_>,
270    model: &str,
271    columns: &[&str],
272    rows: &[&[Value<'_>]],
273) -> Result<(), Error> {
274    let mut guard = executor.ensure_transaction().await?;
275    let tr: &mut Transaction = guard.get_transaction();
276
277    for chunk in rows.chunks(25) {
278        let mut insert = tr.dialect().insert(model, columns, chunk, None);
279        insert = insert.rollback_transaction();
280        let (insert_query, insert_params) = insert.build();
281
282        tr.execute::<Nothing>(insert_query, insert_params).await?;
283    }
284
285    guard.commit().await.map_err(|x| match x {
286        TransactionError::Database(x) => x,
287        TransactionError::Hook(_) => {
288            unreachable!("Potentially create transaction does not use hooks")
289        }
290    })?;
291
292    Ok(())
293}
294
295/// This method is used to bulk insert rows.
296///
297/// If one insert statement fails, the complete operation will be rolled back.
298///
299/// **Parameter**:
300/// - `model`: Table to insert to
301/// - `columns`: Columns to set `rows` for.
302/// - `rows`: List of values to bind to the corresponding columns.
303/// - `transaction`: Optional transaction to execute the query on.
304pub async fn insert_bulk_returning(
305    executor: impl Executor<'_>,
306    model: &str,
307    columns: &[&str],
308    rows: &[&[Value<'_>]],
309    returning: &[&str],
310) -> Result<Vec<Row>, Error> {
311    let mut guard = executor.ensure_transaction().await?;
312    let tr: &mut Transaction = guard.get_transaction();
313
314    let mut inserted = Vec::with_capacity(rows.len());
315    for chunk in rows.chunks(25) {
316        let mut insert = tr.dialect().insert(model, columns, chunk, Some(returning));
317        insert = insert.rollback_transaction();
318        let (insert_query, insert_params) = insert.build();
319
320        inserted.extend(tr.execute::<All>(insert_query, insert_params).await?);
321    }
322
323    guard.commit().await.map_err(|x| match x {
324        TransactionError::Database(x) => x,
325        TransactionError::Hook(_) => {
326            unreachable!("Potentially create transaction does not use hooks")
327        }
328    })?;
329
330    Ok(inserted)
331}
332
333/// This method is used to delete rows from a table.
334///
335/// **Parameter**:
336/// - `model`: Name of the model to delete rows from
337/// - `condition`: Optional condition to apply.
338/// - `transaction`: Optional transaction to execute the query on.
339///
340/// **Returns** the rows affected of the delete statement. Note that this also includes
341/// relations, etc.
342pub async fn delete<'post_build>(
343    executor: impl Executor<'_>,
344    model: &str,
345    condition: Option<&conditional::Condition<'post_build>>,
346) -> Result<u64, Error> {
347    let mut q = executor.dialect().delete(model);
348    if let Some(condition) = condition {
349        q = q.where_clause(condition);
350    }
351
352    let (query_string, bind_params) = q.build();
353
354    executor
355        .execute::<AffectedRows>(query_string, bind_params)
356        .await
357}
358
359/// This method is used to update rows in a table.
360///
361/// **Parameter**:
362/// - `model`: Name of the model to update rows from
363/// - `updates`: A list of updates. An update is a tuple that consists of a list of columns to
364///   update as well as the value to set to the columns.
365/// - `condition`: Optional condition to apply.
366/// - `transaction`: Optional transaction to execute the query on.
367///
368/// **Returns** the rows affected from the update statement. Note that this also includes
369/// relations, etc.
370pub async fn update<'post_build>(
371    executor: impl Executor<'_>,
372    model: &str,
373    updates: &[(&str, Value<'post_build>)],
374    condition: Option<&conditional::Condition<'post_build>>,
375) -> Result<u64, Error> {
376    let mut stmt = executor.dialect().update(model);
377
378    for (column, value) in updates {
379        stmt = stmt.add_update(column, *value);
380    }
381
382    if let Some(cond) = condition {
383        stmt = stmt.where_clause(cond);
384    }
385
386    let (query_string, bind_params) = stmt.build()?;
387
388    executor
389        .execute::<AffectedRows>(query_string, bind_params)
390        .await
391}