homestar_runtime/
db.rs

1//! (Default) sqlite database integration and setup.
2
3use crate::{
4    db::utils::Health,
5    settings,
6    workflow::{self, StoredReceipt},
7    Receipt,
8};
9use anyhow::Result;
10use byte_unit::{AdjustedByte, Byte, ByteUnit};
11use diesel::{
12    dsl::now,
13    r2d2::{self, CustomizeConnection, ManageConnection},
14    BelongingToDsl, Connection as SingleConnection, ExpressionMethods, OptionalExtension, QueryDsl,
15    RunQueryDsl, SelectableHelper, SqliteConnection,
16};
17use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
18use dotenvy::dotenv;
19use homestar_invocation::Pointer;
20use libipld::Cid;
21use std::{env, sync::Arc, time::Duration};
22use tokio::fs;
23use tracing::info;
24
25#[allow(missing_docs, unused_imports)]
26#[rustfmt::skip]
27pub mod schema;
28pub(crate) mod utils;
29
/// Diesel migrations embedded into the binary at compile time from the
/// crate's `migrations/` directory; applied by [`Database::setup`].
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/");
/// SQL batch executed on Sqlite connections by [`ConnectionCustomizer`]
/// (and once, up front, on a standalone connection in
/// [`Database::setup_connection_pool`] for [`Db`]).
const PRAGMAS: &str = "
PRAGMA journal_mode = WAL;          -- better write-concurrency
PRAGMA synchronous = NORMAL;        -- fsync only in critical moments
PRAGMA wal_autocheckpoint = 1000;   -- write WAL changes back every 1000 pages, for an in average 1MB WAL file. May affect readers if number is increased
PRAGMA busy_timeout = 1000;         -- sleep if the database is busy
PRAGMA foreign_keys = ON;           -- enforce foreign keys
";

/// Database environment variable.
///
/// Name of the environment variable holding the database url; read by
/// [`Database::set_url`], [`Database::url`] and [`Db::size`].
pub(crate) const ENV: &str = "DATABASE_URL";
41
42/// A Sqlite connection [pool].
43///
44/// [pool]: r2d2::Pool
45pub(crate) type Pool = r2d2::Pool<r2d2::ConnectionManager<SqliteConnection>>;
46/// A [connection] from the Sqlite connection [pool].
47///
48/// [connection]: r2d2::PooledConnection
49/// [pool]: r2d2::Pool
50pub(crate) type Connection =
51    r2d2::PooledConnection<r2d2::ConnectionManager<diesel::SqliteConnection>>;
52
53/// The database object, which wraps an inner [Arc] to the connection pool.
54#[derive(Debug)]
55pub struct Db {
56    /// The [Arc]'ed connection pool.
57    pub(crate) pool: Arc<Pool>,
58    /// The database URL.
59    pub(crate) url: String,
60}
61
62impl Clone for Db {
63    fn clone(&self) -> Self {
64        Self {
65            pool: Arc::clone(&self.pool),
66            url: self.url.clone(),
67        }
68    }
69}
70
71impl Db {
72    /// Get size of SQlite file in megabytes (via async call).
73    pub async fn size() -> Result<AdjustedByte> {
74        let url = env::var(ENV)?;
75        let len = fs::metadata(url).await?.len();
76        let byte = Byte::from_bytes(len);
77        let byte_unit = byte.get_adjusted_unit(ByteUnit::MB);
78        Ok(byte_unit)
79    }
80}
81
82/// Database trait for working with different Sqlite connection pool and
83/// connection configurations.
84pub trait Database: Send + Sync + Clone {
85    /// Establish a pooled connection to Sqlite database.
86    fn setup_connection_pool(
87        settings: &settings::Node,
88        database_url: Option<String>,
89    ) -> Result<Self>
90    where
91        Self: Sized;
92    /// Get a pooled connection for the database.
93    fn conn(&self) -> Result<Connection>;
94
95    /// Set database url.
96    ///
97    /// Contains a minimal side-effect to set the env if not already set.
98    fn set_url(database_url: Option<String>) -> Option<String> {
99        database_url.map_or_else(
100            || dotenv().ok().and_then(|_| env::var(ENV).ok()),
101            |url| {
102                env::set_var(ENV, &url);
103                Some(url)
104            },
105        )
106    }
107
108    /// Get database url.
109    fn url() -> Result<String> {
110        Ok(env::var(ENV)?)
111    }
112
113    /// Test a Sqlite connection to the database and run pending migrations.
114    fn setup(url: &str) -> Result<SqliteConnection> {
115        info!(
116            subject = "database",
117            category = "homestar.init",
118            "setting up database at {}, running migrations if needed",
119            url
120        );
121        let mut connection = SqliteConnection::establish(url)?;
122        let _ = connection.run_pending_migrations(MIGRATIONS);
123
124        Ok(connection)
125    }
126
127    /// Check if the database is up.
128    fn health_check(conn: &mut Connection) -> Result<Health, diesel::result::Error> {
129        diesel::sql_query("SELECT 1").execute(conn)?;
130        Ok(Health { healthy: true })
131    }
132
133    /// Commit a receipt to the database, updating two tables
134    /// within a transaction.
135    fn commit_receipt(
136        workflow_cid: Cid,
137        receipt: Receipt,
138        conn: &mut Connection,
139    ) -> Result<Receipt, diesel::result::Error> {
140        conn.transaction::<_, diesel::result::Error, _>(|conn| {
141            if let Some(returned) = Self::store_receipt(receipt.clone(), conn)? {
142                Self::store_workflow_receipt(workflow_cid, returned.cid(), conn)?;
143                Ok(returned)
144            } else {
145                Self::store_workflow_receipt(workflow_cid, receipt.cid(), conn)?;
146                Ok(receipt)
147            }
148        })
149    }
150
151    /// Store receipt given a connection to the database pool.
152    ///
153    /// On conflicts, do nothing.
154    fn store_receipt(
155        receipt: Receipt,
156        conn: &mut Connection,
157    ) -> Result<Option<Receipt>, diesel::result::Error> {
158        diesel::insert_into(schema::receipts::table)
159            .values(&receipt)
160            .on_conflict(schema::receipts::cid)
161            .do_nothing()
162            .get_result(conn)
163            .optional()
164    }
165
166    /// Store receipts given a connection to the Database pool.
167    fn store_receipts(
168        receipts: Vec<Receipt>,
169        conn: &mut Connection,
170    ) -> Result<usize, diesel::result::Error> {
171        receipts.iter().try_fold(0, |acc, receipt| {
172            if let Some(res) = diesel::insert_into(schema::receipts::table)
173                .values(receipt)
174                .on_conflict(schema::receipts::cid)
175                .do_nothing()
176                .execute(conn)
177                .optional()?
178            {
179                Ok::<_, diesel::result::Error>(acc + res)
180            } else {
181                Ok(acc)
182            }
183        })
184    }
185
186    /// Find receipts given a set of [Instruction] [Pointer]s, which is indexed.
187    ///
188    /// [Instruction]: homestar_invocation::task::Instruction
189    fn find_instruction_pointers(
190        pointers: &Vec<Pointer>,
191        conn: &mut Connection,
192    ) -> Result<Vec<Receipt>, diesel::result::Error> {
193        schema::receipts::dsl::receipts
194            .filter(schema::receipts::instruction.eq_any(pointers))
195            .load(conn)
196    }
197
198    /// Find receipt for a given [Instruction] Cid, which is indexed.
199    ///
200    /// [Instruction]: homestar_invocation::task::Instruction
201    fn find_instruction_by_cid(
202        cid: Cid,
203        conn: &mut Connection,
204    ) -> Result<Receipt, diesel::result::Error> {
205        schema::receipts::dsl::receipts
206            .filter(schema::receipts::instruction.eq(Pointer::new(cid)))
207            .first(conn)
208    }
209
210    /// Find a receipt for a given Cid.
211    fn find_receipt_by_cid(
212        cid: Cid,
213        conn: &mut Connection,
214    ) -> Result<Receipt, diesel::result::Error> {
215        schema::receipts::dsl::receipts
216            .filter(schema::receipts::cid.eq(Pointer::new(cid)))
217            .select(Receipt::as_select())
218            .get_result(conn)
219    }
220
221    /// Find receipts given a batch of [Receipt] [Pointer]s.
222    fn find_receipt_pointers(
223        pointers: &Vec<Pointer>,
224        conn: &mut Connection,
225    ) -> Result<Vec<Receipt>, diesel::result::Error> {
226        schema::receipts::dsl::receipts
227            .filter(schema::receipts::cid.eq_any(pointers))
228            .load(conn)
229    }
230
231    /// Store localized workflow cid and information, e.g. number of tasks.
232    ///
233    /// On conflicts, do nothing.
234    /// Otherwise, return the stored workflow.
235    fn store_workflow(
236        workflow: workflow::Stored,
237        conn: &mut Connection,
238    ) -> Result<workflow::Stored, diesel::result::Error> {
239        if let Some(stored) = diesel::insert_into(schema::workflows::table)
240            .values(&workflow)
241            .on_conflict(schema::workflows::cid)
242            .do_nothing()
243            .get_result(conn)
244            .optional()?
245        {
246            Ok(stored)
247        } else {
248            Ok(workflow)
249        }
250    }
251
252    /// Update workflow status given a Cid to the workflow.
253    fn set_workflow_status(
254        workflow_cid: Cid,
255        status: workflow::Status,
256        conn: &mut Connection,
257    ) -> Result<(), diesel::result::Error> {
258        diesel::update(schema::workflows::dsl::workflows)
259            .filter(schema::workflows::cid.eq(Pointer::new(workflow_cid)))
260            .set(schema::workflows::status.eq(status))
261            .execute(conn)?;
262
263        Ok(())
264    }
265
266    /// Store workflow Cid and [Receipt] Cid in the database for inner join.
267    fn store_workflow_receipt(
268        workflow_cid: Cid,
269        receipt_cid: Cid,
270        conn: &mut Connection,
271    ) -> Result<Option<usize>, diesel::result::Error> {
272        let value = StoredReceipt::new(Pointer::new(workflow_cid), Pointer::new(receipt_cid));
273        diesel::insert_into(schema::workflows_receipts::table)
274            .values(&value)
275            .on_conflict((
276                schema::workflows_receipts::workflow_cid,
277                schema::workflows_receipts::receipt_cid,
278            ))
279            .do_nothing()
280            .execute(conn)
281            .optional()
282    }
283
284    /// Store series of receipts for a workflow Cid in the
285    /// [schema::workflows_receipts] table.
286    ///
287    /// NOTE: We cannot do batch inserts with `on_conflict`, so we add
288    /// each one 1-by-1:
289    /// <https://github.com/diesel-rs/diesel/issues/3114>
290    fn store_workflow_receipts(
291        workflow_cid: Cid,
292        receipts: &[Cid],
293        conn: &mut Connection,
294    ) -> Result<usize, diesel::result::Error> {
295        receipts.iter().try_fold(0, |acc, receipt| {
296            if let Some(res) = Self::store_workflow_receipt(workflow_cid, *receipt, conn)? {
297                Ok::<_, diesel::result::Error>(acc + res)
298            } else {
299                Ok(acc)
300            }
301        })
302    }
303
304    /// Select workflow given a Cid to the workflow.
305    fn select_workflow(
306        cid: Cid,
307        conn: &mut Connection,
308    ) -> Result<workflow::Stored, diesel::result::Error> {
309        schema::workflows::dsl::workflows
310            .filter(schema::workflows::cid.eq(Pointer::new(cid)))
311            .select(workflow::Stored::as_select())
312            .get_result(conn)
313    }
314
315    /// Return workflow information with number of receipts emitted.
316    fn get_workflow_info(
317        workflow_cid: Cid,
318        conn: &mut Connection,
319    ) -> Result<(Option<String>, workflow::Info), diesel::result::Error> {
320        let workflow = Self::select_workflow(workflow_cid, conn)?;
321        let associated_receipts = workflow::StoredReceipt::belonging_to(&workflow)
322            .select(schema::workflows_receipts::receipt_cid)
323            .load(conn)?;
324
325        let cids = associated_receipts
326            .into_iter()
327            .map(|pointer: Pointer| pointer.cid())
328            .collect();
329
330        let name = workflow.name.clone();
331        let info = workflow::Info::new(workflow, cids);
332
333        Ok((name, info))
334    }
335
336    /// Update the local (view) name of a workflow.
337    fn update_local_name(name: &str, conn: &mut Connection) -> Result<(), diesel::result::Error> {
338        diesel::update(schema::workflows::dsl::workflows)
339            .filter(schema::workflows::created_at.lt(now))
340            .set(schema::workflows::name.eq(name))
341            .execute(conn)?;
342
343        Ok(())
344    }
345}
346
347impl Database for Db {
348    fn setup_connection_pool(
349        settings: &settings::Node,
350        database_url: Option<String>,
351    ) -> Result<Self> {
352        let database_url = Self::set_url(database_url).unwrap_or_else(|| {
353            settings
354                .db
355                .url
356                .as_ref()
357                .map_or_else(|| "homestar.db".to_string(), |url| url.to_string())
358        });
359
360        Self::setup(&database_url)?;
361        let manager = r2d2::ConnectionManager::<SqliteConnection>::new(database_url.clone());
362
363        // setup PRAGMAs
364        manager
365            .connect()
366            .and_then(|mut conn| ConnectionCustomizer.on_acquire(&mut conn))?;
367
368        let pool = r2d2::Pool::builder()
369            // Max number of conns.
370            .max_size(settings.db.max_pool_size)
371            // Never maintain idle connections
372            .min_idle(Some(0))
373            // Close connections after 30 seconds of idle time
374            .idle_timeout(Some(Duration::from_secs(30)))
375            .connection_customizer(Box::new(ConnectionCustomizer))
376            .build(manager)
377            .expect("DATABASE_URL must be set to an SQLite DB file");
378
379        Ok(Db {
380            pool: Arc::new(pool),
381            url: database_url,
382        })
383    }
384
385    fn conn(&self) -> Result<Connection> {
386        let conn = self.pool.get()?;
387        Ok(conn)
388    }
389}
390
391/// Database connection options.
392#[derive(Debug, Clone, PartialEq)]
393pub(crate) struct ConnectionCustomizer;
394
395impl<C> CustomizeConnection<C, r2d2::Error> for ConnectionCustomizer
396where
397    C: diesel::Connection,
398{
399    fn on_acquire(&self, conn: &mut C) -> Result<(), r2d2::Error> {
400        conn.batch_execute(PRAGMAS).map_err(r2d2::Error::QueryError)
401    }
402}
403
// Tests for connection-pool setup.
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_utils::db::MemoryDb;

    // Checks that connections from a `MemoryDb` pool have the `PRAGMAS`
    // batch applied.
    #[homestar_runtime_proc_macro::db_async_test]
    fn check_pragmas_memory_db() {
        // NOTE(review): `TestSettings` is not imported in this module;
        // presumably the `db_async_test` attribute macro provides it — confirm.
        let settings = TestSettings::load();

        let db = MemoryDb::setup_connection_pool(settings.node(), None).unwrap();
        let mut conn = db.conn().unwrap();

        // `PRAGMAS` requests WAL, but SQLite keeps in-memory databases in
        // `memory` journal mode, which is what is asserted here (assumes
        // `MemoryDb` is backed by an in-memory database — TODO confirm).
        let journal_mode = diesel::dsl::sql::<diesel::sql_types::Text>("PRAGMA journal_mode")
            .load::<String>(&mut conn)
            .unwrap();

        assert_eq!(journal_mode, vec!["memory".to_string()]);

        // `PRAGMA foreign_keys = ON` reads back as "1".
        let fk_mode = diesel::dsl::sql::<diesel::sql_types::Text>("PRAGMA foreign_keys")
            .load::<String>(&mut conn)
            .unwrap();

        assert_eq!(fk_mode, vec!["1".to_string()]);

        // Matches `PRAGMA busy_timeout = 1000` from `PRAGMAS`.
        let busy_timeout = diesel::dsl::sql::<diesel::sql_types::Text>("PRAGMA busy_timeout")
            .load::<String>(&mut conn)
            .unwrap();

        assert_eq!(busy_timeout, vec!["1000".to_string()]);
    }
}