argos_arpa/archivist.rs
//! A module to handle the postgres database connection.
//! All DB interactions should pass through `Archivist`.
//!
//! Every function modifying the DB (i.e. not ones that only _get_ data) will
//! automatically start a transaction if there is not already one active.
//! No function should commit a transaction, except for `commit_transaction`.
7
8use crate::{ARPAError, config::Config};
9use log::{info, warn};
10use std::{fmt::Debug, fs::read_to_string};
11
12pub mod data_types;
13mod error;
14pub mod table;
15
16pub use error::ArchivistError;
17use sqlx::{
18 FromRow, PgConnection, Pool, Postgres, Transaction,
19 postgres::{PgPoolOptions, PgRow},
20};
21use table::TableItem;
22
/// Shorthand for results produced by this module, whose error type is
/// [`ArchivistError`].
type Result<T> = std::result::Result<T, ArchivistError>;
24
25/// This keeps a live connection to the database and acts as your friend in
26/// getting and posting data.
27///
28/// For any queries that modify the DB, a transaction
29/// _will_ be used, and if the user has not explicitly started one, a warning
30/// will be issued (though not an error). NB: If a transaction goes out of scope,
31/// it is rolled back.
32///
33/// All tables are accessible _only_ through the `Table` enum.
34pub struct Archivist {
35 pool: Pool<Postgres>,
36 config: Config,
37
38 /// This is here so that potentially destructive app commands always go
39 /// through transactions.
40 current_transaction: Option<Transaction<'static, Postgres>>,
41}
42
43impl Archivist {
44 /// Initializes a new connection to the database.
45 ///
46 /// # Errors
47 /// Fails if setup data is missing. Forwards errors from `sqlx`.
48 pub async fn new(
49 config_path: impl AsRef<std::path::Path>,
50 sql_setup_dir: impl AsRef<std::path::Path>,
51 ) -> std::result::Result<Self, ARPAError> {
52 info!("Reading config \"{}\"...", config_path.as_ref().display());
53 let config = Config::load(config_path)?;
54
55 let pool = PgPoolOptions::new()
56 .max_connections(config.database.pool_connections)
57 .acquire_timeout(std::time::Duration::from_millis(
58 config.database.connection_timeout,
59 ))
60 .connect(&config.database.url)
61 .await
62 .map_err(ArchivistError::from)?;
63
64 info!("Connected to database!");
65
66 // Setup from sql directory
67 info!(
68 "Reading setup dir \"{}\"...",
69 sql_setup_dir.as_ref().display()
70 );
71 let files = std::fs::read_dir(sql_setup_dir)?
72 .flat_map(|entry| entry.map(|e| read_to_string(e.path())))
73 .flatten()
74 .collect::<Vec<_>>();
75
76 for file in files {
77 for sql in file.split(';') {
78 sqlx::query(sql)
79 .execute(&pool)
80 .await
81 .map_err(ArchivistError::from)?;
82 }
83 }
84 info!("Finished setup!");
85
86 Ok(Self {
87 pool,
88 config,
89 current_transaction: None,
90 })
91 }
92
93 /// Starts a new transaction. Returns an error if there is a previous
94 /// transaction still live.
95 /// # Errors
96 /// Fails if there is already a live transaction
97 pub async fn start_transaction(&mut self) -> Result<()> {
98 if self.current_transaction.is_some() {
99 return Err(ArchivistError::TransactionAlreadyLive);
100 }
101
102 self.current_transaction = Some(self.pool.begin().await?);
103 Ok(())
104 }
105
106 /// Commits a currently live transaction. Returns an error if there is none
107 /// present.
108 /// # Errors
109 /// Fails if there is no live transaction. Forwards errors from `sqlx`.
110 pub async fn commit_transaction(&mut self) -> Result<()> {
111 self.current_transaction
112 .take()
113 .ok_or(ArchivistError::NoTransactionToCommit)?
114 .commit()
115 .await?;
116
117 Ok(())
118 }
119
120 /// Undos a currently live transaction. Returns an error if there is none
121 /// present.
122 /// # Errors
123 /// Fails if there is no live transaction. Forwards errors from `sqlx`.
124 pub async fn rollback_transaction(&mut self) -> Result<()> {
125 self.current_transaction
126 .take()
127 .ok_or(ArchivistError::NoTransactionToRollback)?
128 .rollback()
129 .await?;
130
131 Ok(())
132 }
133
134 /// Checks whether a row with `id` exists in `table`.
135 /// # Errors
136 /// Forwards errors from `sqlx`.
137 pub async fn exists<T: TableItem>(&self, id: i32) -> Result<bool> {
138 let exists = T::exists(&self.pool, id).await?;
139 Ok(exists)
140 }
141
142 /// Same as `exists`, but returns a result instead of an option.
143 /// # Errors
144 /// Fails if the id does not exist. Forwards errors from `sqlx`.
145 pub async fn assert_exists<T: TableItem>(&self, id: i32) -> Result<()> {
146 if self.exists::<T>(id).await? {
147 Ok(())
148 } else {
149 Err(ArchivistError::MissingID(T::TABLE, id))
150 }
151 }
152
153 /// Returns an error if the provided item collides with anything.
154 /// # Errors
155 /// Fails if there is a collision. Forwards errors from `sqlx`.
156 pub async fn assert_unique<T: TableItem>(&self, item: &T) -> Result<()> {
157 item.check_unique(&self.pool).await?.map_or(Ok(()), |id| {
158 Err(ArchivistError::EntryAlreadyExists(T::TABLE.to_string(), id))
159 })
160 }
161
162 /// Gets an item whose id you know.
163 ///
164 /// # Errors
165 /// Forwards errors from `sqlx`.
166 pub async fn get<T: TableItem>(&self, id: i32) -> Result<T> {
167 self.assert_exists::<T>(id).await?;
168 T::select_by_id(&self.pool, id)
169 .await
170 .map_err(ArchivistError::Sqlx)
171 }
172
173 /// Gets all items from `T::TABLE`.
174 /// # Errors
175 /// Forwards errors from `sqlx`.
176 pub async fn get_all<T: TableItem>(&self) -> Result<Vec<T>> {
177 T::select_all(&self.pool)
178 .await
179 .map_err(ArchivistError::Sqlx)
180 }
181
182 /// Finds an item from `T::TABLE`, fulfilling a `where`-condition.
183 ///
184 /// This is essentially just wrapping a query like `select T from TABLE
185 /// where CONDITION;`.
186 ///
187 /// Due to the flexibility in the `condition` parameter, this is not run
188 /// via any macro, and as such cannot be compile-time tested. Use
189 /// responsibly.
190 ///
191 /// # Errors
192 /// Forwards errors from `sqlx`.
193 pub async fn find<T: TableItem>(&self, condition: &str, ) -> Result<Option<T>> {
194 let query = format!("select id from {} where {};", T::TABLE, condition);
195
196 let opt_id: Option<(i32,)> =
197 sqlx::query_as(&query).fetch_optional(&self.pool).await?;
198
199 let id = match opt_id {
200 None => return Ok(None),
201 Some((i,)) => i,
202 };
203
204 T::select_by_id(&self.pool, id)
205 .await
206 .map_err(ArchivistError::Sqlx)
207 .map(|i| Some(i))
208 }
209
210 /// Adds a new entry to `T::TABLE`, making sure no unique fields are
211 /// duplicated.
212 ///
213 /// Returns the id of the newly inserted item.
214 /// # Errors
215 /// Fails if there are collisions in the table. Forwards errors from `sqlx`.
216 pub async fn insert<T: TableItem>(&mut self, item: T) -> Result<i32> {
217 self.assert_unique(&item).await?;
218 let tx = self.get_transaction().await?;
219
220 item.insert(tx).await.map_err(ArchivistError::Sqlx)
221 }
222
223 /// Update an entry with the given `id` in the given `table`. `value` in
224 /// this case is a string like `number = 2`, i.e. both the column and the
225 /// actual value.
226 ///
227 /// Remember that string values need to be incased in single quotes.
228 ///
229 /// Due to the flexibility in the `value` parameter, this is not run via
230 /// any macro, and as such cannot be compile-time tested. Use responsibly.
231 ///
232 /// # Errors
233 /// Forwards errors from `sqlx`.
234 pub async fn update<T: TableItem>(&mut self, id: i32, value: &str, ) -> Result<()> {
235 self.assert_exists::<T>(id).await?;
236
237 let query = format!("update {} set {value} where id={id};", T::TABLE,);
238
239 let tx = self.get_transaction().await?;
240 sqlx::query(&query).execute(tx).await?;
241
242 Ok(())
243 }
244
245 /// Updates all columns for a the row with the supplied `id`.
246 ///
247 /// # Errors
248 /// Forwards errors from `sqlx`.
249 pub async fn update_from_cache<T: TableItem>(&mut self, cache: &T, id: i32,) -> Result<()> {
250 self.assert_exists::<T>(id).await?;
251 let tx = self.get_transaction().await?;
252
253 cache.update(tx, id).await.map_err(ArchivistError::Sqlx)
254 }
255
256 /// Deletes an item from a table. Make sure you are providing the correct
257 /// type, as there is no way of checking your intentions!
258 ///
259 /// # Errors
260 /// Fails if `id` does not exist. Forwards errors from `sqlx`.
261 pub async fn delete<T: TableItem>(&mut self, id: i32) -> Result<()> {
262 if !self.exists::<T>(id).await? {
263 warn!(
264 "Entry with id {id} does not exists and thus cannot be removed"
265 );
266 return Ok(());
267 }
268
269 let tx = self.get_transaction().await?;
270 T::delete(tx, id).await.map_err(ArchivistError::Sqlx)
271 }
272
273 /// Gets the indicated values from `table`, for one row if it meets
274 /// `condition`.
275 ///
276 /// This may be preferred if you want only a specific value instead of the
277 /// whole item, or a value that is not present in the rust-end struct, but
278 /// is stored in the table (e.g. a password hash).
279 ///
280 /// Due to the flexibility in the parameters, this is not run via any
281 /// macro, and as such cannot be compile-time tested. Use responsibly.
282 ///
283 /// # Errors
284 /// Forwards errors from `sqlx`.
285 pub async fn get_special<T: TableItem, U>(&self, columns: &str, condition: &str, ) -> Result<Option<U>> where for<'r> U: FromRow<'r, PgRow> + Send + Unpin {
286 let query = format!(
287 "select {columns} from {} where {condition} limit 1;",
288 T::TABLE,
289 );
290
291 let item = sqlx::query_as(&query).fetch_optional(&self.pool).await?;
292
293 Ok(item)
294 }
295
296 /// Returns the currently live transaction. If there is none present, it
297 /// first creates one.
298 async fn get_transaction(&mut self) -> Result<&mut PgConnection> {
299 if self.current_transaction.is_none() {
300 warn!("Started implicit transaction.");
301 self.current_transaction = Some(self.pool.begin().await?);
302 }
303
304 Ok(self.current_transaction.as_mut().unwrap())
305 }
306
307 /// The current configuration.
308 pub const fn config(&self) -> &Config {
309 &self.config
310 }
311}
312
313impl Debug for Archivist {
314 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
315 f.debug_struct("Archivist")
316 .field("live:", &self.current_transaction.is_some())
317 .finish_non_exhaustive()
318 }
319}