Skip to main content

argos_arpa/
archivist.rs

//! A module to handle the postgres database connection.
//! All DB interactions should pass through Archivist.
//!
//! Every function modifying the DB (i.e. not ones that only _get_ data) will
//! automatically start a transaction if there is not already one active.
//! No function should commit a transaction, except for `commit_transaction`.
7
8use crate::{ARPAError, config::Config};
9use log::{info, warn};
10use std::{fmt::Debug, fs::read_to_string};
11
12pub mod data_types;
13mod error;
14pub mod table;
15
16pub use error::ArchivistError;
17use sqlx::{
18    FromRow, PgConnection, Pool, Postgres, Transaction,
19    postgres::{PgPoolOptions, PgRow},
20};
21use table::TableItem;
22
23type Result<T> = std::result::Result<T, ArchivistError>;
24
25/// This keeps a live connection to the database and acts as your friend in
26/// getting and posting data.
27///
28/// For any queries that modify the DB, a transaction
29/// _will_ be used, and if the user has not explicitly started one, a warning
30/// will be issued (though not an error). NB: If a transaction goes out of scope,
31/// it is rolled back.
32///
33/// All tables are accessible _only_ through the `Table` enum.
34pub struct Archivist {
35    pool: Pool<Postgres>,
36    config: Config,
37
38    /// This is here so that potentially destructive app commands always go
39    /// through transactions.
40    current_transaction: Option<Transaction<'static, Postgres>>,
41}
42
43impl Archivist {
44    /// Initializes a new connection to the database.
45    ///
46    /// # Errors
47    /// Fails if setup data is missing. Forwards errors from `sqlx`.
48    pub async fn new(
49        config_path: impl AsRef<std::path::Path>,
50        sql_setup_dir: impl AsRef<std::path::Path>,
51    ) -> std::result::Result<Self, ARPAError> {
52        info!("Reading config \"{}\"...", config_path.as_ref().display());
53        let config = Config::load(config_path)?;
54
55        let pool = PgPoolOptions::new()
56            .max_connections(config.database.pool_connections)
57            .acquire_timeout(std::time::Duration::from_millis(
58                config.database.connection_timeout,
59            ))
60            .connect(&config.database.url)
61            .await
62            .map_err(ArchivistError::from)?;
63
64        info!("Connected to database!");
65
66        // Setup from sql directory
67        info!(
68            "Reading setup dir \"{}\"...",
69            sql_setup_dir.as_ref().display()
70        );
71        let files = std::fs::read_dir(sql_setup_dir)?
72            .flat_map(|entry| entry.map(|e| read_to_string(e.path())))
73            .flatten()
74            .collect::<Vec<_>>();
75
76        for file in files {
77            for sql in file.split(';') {
78                sqlx::query(sql)
79                    .execute(&pool)
80                    .await
81                    .map_err(ArchivistError::from)?;
82            }
83        }
84        info!("Finished setup!");
85
86        Ok(Self {
87            pool,
88            config,
89            current_transaction: None,
90        })
91    }
92
93    /// Starts a new transaction. Returns an error if there is a previous
94    /// transaction still live.
95    /// # Errors
96    /// Fails if there is already a live transaction
97    pub async fn start_transaction(&mut self) -> Result<()> {
98        if self.current_transaction.is_some() {
99            return Err(ArchivistError::TransactionAlreadyLive);
100        }
101
102        self.current_transaction = Some(self.pool.begin().await?);
103        Ok(())
104    }
105
106    /// Commits a currently live transaction. Returns an error if there is none
107    /// present.
108    /// # Errors
109    /// Fails if there is no live transaction. Forwards errors from `sqlx`.
110    pub async fn commit_transaction(&mut self) -> Result<()> {
111        self.current_transaction
112            .take()
113            .ok_or(ArchivistError::NoTransactionToCommit)?
114            .commit()
115            .await?;
116
117        Ok(())
118    }
119
120    /// Undos a currently live transaction. Returns an error if there is none
121    /// present.
122    /// # Errors
123    /// Fails if there is no live transaction. Forwards errors from `sqlx`.
124    pub async fn rollback_transaction(&mut self) -> Result<()> {
125        self.current_transaction
126            .take()
127            .ok_or(ArchivistError::NoTransactionToRollback)?
128            .rollback()
129            .await?;
130
131        Ok(())
132    }
133
134    /// Checks whether a row with `id` exists in `table`.
135    /// # Errors
136    /// Forwards errors from `sqlx`.
137    pub async fn exists<T: TableItem>(&self, id: i32) -> Result<bool> {
138        let exists = T::exists(&self.pool, id).await?;
139        Ok(exists)
140    }
141
142    /// Same as `exists`, but returns a result instead of an option.
143    /// # Errors
144    /// Fails if the id does not exist. Forwards errors from `sqlx`.
145    pub async fn assert_exists<T: TableItem>(&self, id: i32) -> Result<()> {
146        if self.exists::<T>(id).await? {
147            Ok(())
148        } else {
149            Err(ArchivistError::MissingID(T::TABLE, id))
150        }
151    }
152
153    /// Returns an error if the provided item collides with anything.
154    /// # Errors
155    /// Fails if there is a collision. Forwards errors from `sqlx`.
156    pub async fn assert_unique<T: TableItem>(&self, item: &T) -> Result<()> {
157        item.check_unique(&self.pool).await?.map_or(Ok(()), |id| {
158            Err(ArchivistError::EntryAlreadyExists(T::TABLE.to_string(), id))
159        })
160    }
161
162    /// Gets an item whose id you know.
163    ///
164    /// # Errors
165    /// Forwards errors from `sqlx`.
166    pub async fn get<T: TableItem>(&self, id: i32) -> Result<T> {
167        self.assert_exists::<T>(id).await?;
168        T::select_by_id(&self.pool, id)
169            .await
170            .map_err(ArchivistError::Sqlx)
171    }
172
173    /// Gets all items from `T::TABLE`.
174    /// # Errors
175    /// Forwards errors from `sqlx`.
176    pub async fn get_all<T: TableItem>(&self) -> Result<Vec<T>> {
177        T::select_all(&self.pool)
178            .await
179            .map_err(ArchivistError::Sqlx)
180    }
181
182    /// Finds an item from `T::TABLE`, fulfilling a `where`-condition.
183    ///
184    /// This is essentially just wrapping a query like `select T from TABLE
185    /// where CONDITION;`.
186    ///
187    /// Due to the flexibility in the `condition` parameter, this is not run
188    /// via any macro, and as such cannot be compile-time tested. Use
189    /// responsibly.
190    ///
191    /// # Errors
192    /// Forwards errors from `sqlx`.
193    pub async fn find<T: TableItem>(&self, condition: &str, ) -> Result<Option<T>> {
194        let query = format!("select id from {} where {};", T::TABLE, condition);
195
196        let opt_id: Option<(i32,)> =
197            sqlx::query_as(&query).fetch_optional(&self.pool).await?;
198
199        let id = match opt_id {
200            None => return Ok(None),
201            Some((i,)) => i,
202        };
203
204        T::select_by_id(&self.pool, id)
205            .await
206            .map_err(ArchivistError::Sqlx)
207            .map(|i| Some(i))
208    }
209
210    /// Adds a new entry to `T::TABLE`, making sure no unique fields are
211    /// duplicated.
212    ///
213    /// Returns the id of the newly inserted item.
214    /// # Errors
215    /// Fails if there are collisions in the table. Forwards errors from `sqlx`.
216    pub async fn insert<T: TableItem>(&mut self, item: T) -> Result<i32> {
217        self.assert_unique(&item).await?;
218        let tx = self.get_transaction().await?;
219
220        item.insert(tx).await.map_err(ArchivistError::Sqlx)
221    }
222
223    /// Update an entry with the given `id` in the given `table`. `value` in
224    /// this case is a string like `number = 2`, i.e. both the column and the
225    /// actual value.
226    ///
227    /// Remember that string values need to be incased in single quotes.
228    ///
229    /// Due to the flexibility in the `value` parameter, this is not run via
230    /// any macro, and as such cannot be compile-time tested. Use responsibly.
231    ///
232    /// # Errors
233    /// Forwards errors from `sqlx`.
234    pub async fn update<T: TableItem>(&mut self, id: i32, value: &str, ) -> Result<()> {
235        self.assert_exists::<T>(id).await?;
236
237        let query = format!("update {} set {value} where id={id};", T::TABLE,);
238
239        let tx = self.get_transaction().await?;
240        sqlx::query(&query).execute(tx).await?;
241
242        Ok(())
243    }
244
245    /// Updates all columns for a the row with the supplied `id`.
246    ///
247    /// # Errors
248    /// Forwards errors from `sqlx`.
249    pub async fn update_from_cache<T: TableItem>(&mut self, cache: &T, id: i32,) -> Result<()> {
250        self.assert_exists::<T>(id).await?;
251        let tx = self.get_transaction().await?;
252
253        cache.update(tx, id).await.map_err(ArchivistError::Sqlx)
254    }
255
256    /// Deletes an item from a table. Make sure you are providing the correct
257    /// type, as there is no way of checking your intentions!
258    ///
259    /// # Errors
260    /// Fails if `id` does not exist. Forwards errors from `sqlx`.
261    pub async fn delete<T: TableItem>(&mut self, id: i32) -> Result<()> {
262        if !self.exists::<T>(id).await? {
263            warn!(
264                "Entry with id {id} does not exists and thus cannot be removed"
265            );
266            return Ok(());
267        }
268
269        let tx = self.get_transaction().await?;
270        T::delete(tx, id).await.map_err(ArchivistError::Sqlx)
271    }
272
273    /// Gets the indicated values from `table`, for one row if it meets
274    /// `condition`.
275    ///
276    /// This may be preferred if you want only a specific value instead of the
277    /// whole item, or a value that is not present in the rust-end struct, but
278    /// is stored in the table (e.g. a password hash).
279    ///
280    /// Due to the flexibility in the parameters, this is not run via any
281    /// macro, and as such cannot be compile-time tested. Use responsibly.
282    ///
283    /// # Errors
284    /// Forwards errors from `sqlx`.
285    pub async fn get_special<T: TableItem, U>(&self, columns: &str, condition: &str, ) -> Result<Option<U>> where for<'r> U: FromRow<'r, PgRow> + Send + Unpin {
286        let query = format!(
287            "select {columns} from {} where {condition} limit 1;",
288            T::TABLE,
289        );
290
291        let item = sqlx::query_as(&query).fetch_optional(&self.pool).await?;
292
293        Ok(item)
294    }
295
296    /// Returns the currently live transaction. If there is none present, it
297    /// first creates one.
298    async fn get_transaction(&mut self) -> Result<&mut PgConnection> {
299        if self.current_transaction.is_none() {
300            warn!("Started implicit transaction.");
301            self.current_transaction = Some(self.pool.begin().await?);
302        }
303
304        Ok(self.current_transaction.as_mut().unwrap())
305    }
306
307    /// The current configuration.
308    pub const fn config(&self) -> &Config {
309        &self.config
310    }
311}
312
313impl Debug for Archivist {
314    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315        f.debug_struct("Archivist")
316            .field("live:", &self.current_transaction.is_some())
317            .finish_non_exhaustive()
318    }
319}