1use crate::{
4 db::utils::Health,
5 settings,
6 workflow::{self, StoredReceipt},
7 Receipt,
8};
9use anyhow::Result;
10use byte_unit::{AdjustedByte, Byte, ByteUnit};
11use diesel::{
12 dsl::now,
13 r2d2::{self, CustomizeConnection, ManageConnection},
14 BelongingToDsl, Connection as SingleConnection, ExpressionMethods, OptionalExtension, QueryDsl,
15 RunQueryDsl, SelectableHelper, SqliteConnection,
16};
17use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
18use dotenvy::dotenv;
19use homestar_invocation::Pointer;
20use libipld::Cid;
21use std::{env, sync::Arc, time::Duration};
22use tokio::fs;
23use tracing::info;
24
/// Diesel table definitions (`receipts`, `workflows`, `workflows_receipts`, …)
/// used by the queries in this module.
// Lints and rustfmt are disabled for this module, presumably because it is
// generated by `diesel print-schema` — NOTE(review): confirm before editing it
// by hand.
#[allow(missing_docs, unused_imports)]
#[rustfmt::skip]
pub mod schema;
/// Database helpers, including the [`Health`] type returned by health checks.
pub(crate) mod utils;

/// Migrations embedded into the binary at compile time from `migrations/`,
/// applied by `Database::setup`.
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations/");
/// SQLite pragmas executed on every connection the pool acquires
/// (via `ConnectionCustomizer`). The SQL `--` comments explain each setting.
const PRAGMAS: &str = "
PRAGMA journal_mode = WAL; -- better write-concurrency
PRAGMA synchronous = NORMAL; -- fsync only in critical moments
PRAGMA wal_autocheckpoint = 1000; -- write WAL changes back every 1000 pages, for an in average 1MB WAL file. May affect readers if number is increased
PRAGMA busy_timeout = 1000; -- sleep if the database is busy
PRAGMA foreign_keys = ON; -- enforce foreign keys
";

/// Name of the environment variable that holds the database URL.
pub(crate) const ENV: &str = "DATABASE_URL";
41
42pub(crate) type Pool = r2d2::Pool<r2d2::ConnectionManager<SqliteConnection>>;
46pub(crate) type Connection =
51 r2d2::PooledConnection<r2d2::ConnectionManager<diesel::SqliteConnection>>;
52
53#[derive(Debug)]
55pub struct Db {
56 pub(crate) pool: Arc<Pool>,
58 pub(crate) url: String,
60}
61
62impl Clone for Db {
63 fn clone(&self) -> Self {
64 Self {
65 pool: Arc::clone(&self.pool),
66 url: self.url.clone(),
67 }
68 }
69}
70
71impl Db {
72 pub async fn size() -> Result<AdjustedByte> {
74 let url = env::var(ENV)?;
75 let len = fs::metadata(url).await?.len();
76 let byte = Byte::from_bytes(len);
77 let byte_unit = byte.get_adjusted_unit(ByteUnit::MB);
78 Ok(byte_unit)
79 }
80}
81
82pub trait Database: Send + Sync + Clone {
85 fn setup_connection_pool(
87 settings: &settings::Node,
88 database_url: Option<String>,
89 ) -> Result<Self>
90 where
91 Self: Sized;
92 fn conn(&self) -> Result<Connection>;
94
95 fn set_url(database_url: Option<String>) -> Option<String> {
99 database_url.map_or_else(
100 || dotenv().ok().and_then(|_| env::var(ENV).ok()),
101 |url| {
102 env::set_var(ENV, &url);
103 Some(url)
104 },
105 )
106 }
107
108 fn url() -> Result<String> {
110 Ok(env::var(ENV)?)
111 }
112
113 fn setup(url: &str) -> Result<SqliteConnection> {
115 info!(
116 subject = "database",
117 category = "homestar.init",
118 "setting up database at {}, running migrations if needed",
119 url
120 );
121 let mut connection = SqliteConnection::establish(url)?;
122 let _ = connection.run_pending_migrations(MIGRATIONS);
123
124 Ok(connection)
125 }
126
127 fn health_check(conn: &mut Connection) -> Result<Health, diesel::result::Error> {
129 diesel::sql_query("SELECT 1").execute(conn)?;
130 Ok(Health { healthy: true })
131 }
132
133 fn commit_receipt(
136 workflow_cid: Cid,
137 receipt: Receipt,
138 conn: &mut Connection,
139 ) -> Result<Receipt, diesel::result::Error> {
140 conn.transaction::<_, diesel::result::Error, _>(|conn| {
141 if let Some(returned) = Self::store_receipt(receipt.clone(), conn)? {
142 Self::store_workflow_receipt(workflow_cid, returned.cid(), conn)?;
143 Ok(returned)
144 } else {
145 Self::store_workflow_receipt(workflow_cid, receipt.cid(), conn)?;
146 Ok(receipt)
147 }
148 })
149 }
150
151 fn store_receipt(
155 receipt: Receipt,
156 conn: &mut Connection,
157 ) -> Result<Option<Receipt>, diesel::result::Error> {
158 diesel::insert_into(schema::receipts::table)
159 .values(&receipt)
160 .on_conflict(schema::receipts::cid)
161 .do_nothing()
162 .get_result(conn)
163 .optional()
164 }
165
166 fn store_receipts(
168 receipts: Vec<Receipt>,
169 conn: &mut Connection,
170 ) -> Result<usize, diesel::result::Error> {
171 receipts.iter().try_fold(0, |acc, receipt| {
172 if let Some(res) = diesel::insert_into(schema::receipts::table)
173 .values(receipt)
174 .on_conflict(schema::receipts::cid)
175 .do_nothing()
176 .execute(conn)
177 .optional()?
178 {
179 Ok::<_, diesel::result::Error>(acc + res)
180 } else {
181 Ok(acc)
182 }
183 })
184 }
185
186 fn find_instruction_pointers(
190 pointers: &Vec<Pointer>,
191 conn: &mut Connection,
192 ) -> Result<Vec<Receipt>, diesel::result::Error> {
193 schema::receipts::dsl::receipts
194 .filter(schema::receipts::instruction.eq_any(pointers))
195 .load(conn)
196 }
197
198 fn find_instruction_by_cid(
202 cid: Cid,
203 conn: &mut Connection,
204 ) -> Result<Receipt, diesel::result::Error> {
205 schema::receipts::dsl::receipts
206 .filter(schema::receipts::instruction.eq(Pointer::new(cid)))
207 .first(conn)
208 }
209
210 fn find_receipt_by_cid(
212 cid: Cid,
213 conn: &mut Connection,
214 ) -> Result<Receipt, diesel::result::Error> {
215 schema::receipts::dsl::receipts
216 .filter(schema::receipts::cid.eq(Pointer::new(cid)))
217 .select(Receipt::as_select())
218 .get_result(conn)
219 }
220
221 fn find_receipt_pointers(
223 pointers: &Vec<Pointer>,
224 conn: &mut Connection,
225 ) -> Result<Vec<Receipt>, diesel::result::Error> {
226 schema::receipts::dsl::receipts
227 .filter(schema::receipts::cid.eq_any(pointers))
228 .load(conn)
229 }
230
231 fn store_workflow(
236 workflow: workflow::Stored,
237 conn: &mut Connection,
238 ) -> Result<workflow::Stored, diesel::result::Error> {
239 if let Some(stored) = diesel::insert_into(schema::workflows::table)
240 .values(&workflow)
241 .on_conflict(schema::workflows::cid)
242 .do_nothing()
243 .get_result(conn)
244 .optional()?
245 {
246 Ok(stored)
247 } else {
248 Ok(workflow)
249 }
250 }
251
252 fn set_workflow_status(
254 workflow_cid: Cid,
255 status: workflow::Status,
256 conn: &mut Connection,
257 ) -> Result<(), diesel::result::Error> {
258 diesel::update(schema::workflows::dsl::workflows)
259 .filter(schema::workflows::cid.eq(Pointer::new(workflow_cid)))
260 .set(schema::workflows::status.eq(status))
261 .execute(conn)?;
262
263 Ok(())
264 }
265
266 fn store_workflow_receipt(
268 workflow_cid: Cid,
269 receipt_cid: Cid,
270 conn: &mut Connection,
271 ) -> Result<Option<usize>, diesel::result::Error> {
272 let value = StoredReceipt::new(Pointer::new(workflow_cid), Pointer::new(receipt_cid));
273 diesel::insert_into(schema::workflows_receipts::table)
274 .values(&value)
275 .on_conflict((
276 schema::workflows_receipts::workflow_cid,
277 schema::workflows_receipts::receipt_cid,
278 ))
279 .do_nothing()
280 .execute(conn)
281 .optional()
282 }
283
284 fn store_workflow_receipts(
291 workflow_cid: Cid,
292 receipts: &[Cid],
293 conn: &mut Connection,
294 ) -> Result<usize, diesel::result::Error> {
295 receipts.iter().try_fold(0, |acc, receipt| {
296 if let Some(res) = Self::store_workflow_receipt(workflow_cid, *receipt, conn)? {
297 Ok::<_, diesel::result::Error>(acc + res)
298 } else {
299 Ok(acc)
300 }
301 })
302 }
303
304 fn select_workflow(
306 cid: Cid,
307 conn: &mut Connection,
308 ) -> Result<workflow::Stored, diesel::result::Error> {
309 schema::workflows::dsl::workflows
310 .filter(schema::workflows::cid.eq(Pointer::new(cid)))
311 .select(workflow::Stored::as_select())
312 .get_result(conn)
313 }
314
315 fn get_workflow_info(
317 workflow_cid: Cid,
318 conn: &mut Connection,
319 ) -> Result<(Option<String>, workflow::Info), diesel::result::Error> {
320 let workflow = Self::select_workflow(workflow_cid, conn)?;
321 let associated_receipts = workflow::StoredReceipt::belonging_to(&workflow)
322 .select(schema::workflows_receipts::receipt_cid)
323 .load(conn)?;
324
325 let cids = associated_receipts
326 .into_iter()
327 .map(|pointer: Pointer| pointer.cid())
328 .collect();
329
330 let name = workflow.name.clone();
331 let info = workflow::Info::new(workflow, cids);
332
333 Ok((name, info))
334 }
335
336 fn update_local_name(name: &str, conn: &mut Connection) -> Result<(), diesel::result::Error> {
338 diesel::update(schema::workflows::dsl::workflows)
339 .filter(schema::workflows::created_at.lt(now))
340 .set(schema::workflows::name.eq(name))
341 .execute(conn)?;
342
343 Ok(())
344 }
345}
346
347impl Database for Db {
348 fn setup_connection_pool(
349 settings: &settings::Node,
350 database_url: Option<String>,
351 ) -> Result<Self> {
352 let database_url = Self::set_url(database_url).unwrap_or_else(|| {
353 settings
354 .db
355 .url
356 .as_ref()
357 .map_or_else(|| "homestar.db".to_string(), |url| url.to_string())
358 });
359
360 Self::setup(&database_url)?;
361 let manager = r2d2::ConnectionManager::<SqliteConnection>::new(database_url.clone());
362
363 manager
365 .connect()
366 .and_then(|mut conn| ConnectionCustomizer.on_acquire(&mut conn))?;
367
368 let pool = r2d2::Pool::builder()
369 .max_size(settings.db.max_pool_size)
371 .min_idle(Some(0))
373 .idle_timeout(Some(Duration::from_secs(30)))
375 .connection_customizer(Box::new(ConnectionCustomizer))
376 .build(manager)
377 .expect("DATABASE_URL must be set to an SQLite DB file");
378
379 Ok(Db {
380 pool: Arc::new(pool),
381 url: database_url,
382 })
383 }
384
385 fn conn(&self) -> Result<Connection> {
386 let conn = self.pool.get()?;
387 Ok(conn)
388 }
389}
390
391#[derive(Debug, Clone, PartialEq)]
393pub(crate) struct ConnectionCustomizer;
394
395impl<C> CustomizeConnection<C, r2d2::Error> for ConnectionCustomizer
396where
397 C: diesel::Connection,
398{
399 fn on_acquire(&self, conn: &mut C) -> Result<(), r2d2::Error> {
400 conn.batch_execute(PRAGMAS).map_err(r2d2::Error::QueryError)
401 }
402}
403
#[cfg(test)]
mod test {
    use super::*;
    use crate::test_utils::db::MemoryDb;

    /// Checks that the settings from `PRAGMAS` are in effect on a pooled
    /// in-memory connection.
    ///
    /// The journal mode is expected to be `memory` rather than the `WAL`
    /// requested by `PRAGMAS`, because the SQLite database here is in-memory.
    // NOTE(review): `TestSettings` is not imported in this module; it is
    // presumably brought into scope by the `db_async_test` macro — confirm.
    #[homestar_runtime_proc_macro::db_async_test]
    fn check_pragmas_memory_db() {
        let settings = TestSettings::load();

        let db = MemoryDb::setup_connection_pool(settings.node(), None).unwrap();
        let mut conn = db.conn().unwrap();

        let journal_mode = diesel::dsl::sql::<diesel::sql_types::Text>("PRAGMA journal_mode")
            .load::<String>(&mut conn)
            .unwrap();

        assert_eq!(journal_mode, vec!["memory".to_string()]);

        // `PRAGMA foreign_keys = ON` reads back as "1".
        let fk_mode = diesel::dsl::sql::<diesel::sql_types::Text>("PRAGMA foreign_keys")
            .load::<String>(&mut conn)
            .unwrap();

        assert_eq!(fk_mode, vec!["1".to_string()]);

        let busy_timeout = diesel::dsl::sql::<diesel::sql_types::Text>("PRAGMA busy_timeout")
            .load::<String>(&mut conn)
            .unwrap();

        assert_eq!(busy_timeout, vec!["1000".to_string()]);
    }
}