forest/utils/sqlite/
mod.rs1#![allow(dead_code)]
9
10#[cfg(test)]
11mod tests;
12
13use anyhow::Context as _;
14use sqlx::{
15 SqlitePool,
16 query::Query,
17 sqlite::{
18 SqliteArguments, SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode,
19 SqliteSynchronous,
20 },
21};
22use std::{cmp::Ordering, path::Path, time::Instant};
23
/// A `sqlx` SQLite query with its bound arguments, parameterized by the lifetime `'q`.
///
/// This is the item type of the DDL and migration lists accepted by `init_db`.
pub type SqliteQuery<'q> = Query<'q, sqlx::Sqlite, SqliteArguments<'q>>;
25
26pub async fn open_file(file: &Path) -> anyhow::Result<SqlitePool> {
28 if let Some(dir) = file.parent()
29 && !dir.is_dir()
30 {
31 std::fs::create_dir_all(dir)?;
32 }
33 let options = SqliteConnectOptions::new().filename(file);
34 Ok(open(options).await?)
35}
36
37pub async fn open_memory() -> sqlx::Result<SqlitePool> {
39 open(
40 SqliteConnectOptions::new()
41 .in_memory(true)
42 .shared_cache(true),
43 )
44 .await
45}
46
47pub async fn open(options: SqliteConnectOptions) -> sqlx::Result<SqlitePool> {
49 let options = options
50 .synchronous(SqliteSynchronous::Normal)
51 .pragma("temp_store", "memory")
52 .pragma("mmap_size", "30000000000")
53 .auto_vacuum(SqliteAutoVacuum::None)
54 .journal_mode(SqliteJournalMode::Wal)
55 .pragma("journal_size_limit", "0") .foreign_keys(true)
57 .read_only(false);
58 SqlitePool::connect_with(options).await
59}
60
61pub async fn init_db<'q>(
68 db: &SqlitePool,
69 name: &str,
70 ddls: impl IntoIterator<Item = SqliteQuery<'q>>,
71 version_migrations: Vec<SqliteQuery<'q>>,
72) -> anyhow::Result<()> {
73 let schema_version = version_migrations.len() + 1;
74
75 let init = async |db: &SqlitePool, schema_version| {
76 let mut tx = db.begin().await?;
77 sqlx::query("CREATE TABLE IF NOT EXISTS _meta (version UINT64 NOT NULL UNIQUE)")
78 .execute(tx.as_mut())
79 .await?;
80 for i in 1..=schema_version {
81 sqlx::query("INSERT OR IGNORE INTO _meta (version) VALUES (?)")
82 .bind(i as i64)
83 .execute(tx.as_mut())
84 .await?;
85 }
86 for ddl in ddls.into_iter() {
87 ddl.execute(tx.as_mut()).await?;
88 }
89 tx.commit().await
90 };
91
92 if sqlx::query("SELECT name FROM sqlite_master WHERE type='table' AND name='_meta';")
93 .fetch_optional(db)
94 .await
95 .map_err(|e| anyhow::anyhow!("error looking for {name} database _meta table: {e}"))?
96 .is_none()
97 {
98 init(db, schema_version).await?;
99 }
100
101 let found_version: u64 = sqlx::query_scalar("SELECT max(version) FROM _meta")
102 .fetch_optional(db)
103 .await?
104 .with_context(|| format!("invalid {name} database version: no version found"))?;
105 anyhow::ensure!(found_version > 0, "schema version should be 1 based");
106
107 let run_vacuum = match found_version.cmp(&(schema_version as _)) {
108 Ordering::Greater => {
109 anyhow::bail!(
110 "invalid {name} database version: version {found_version} is greater than the number of migrations {schema_version}"
111 );
112 }
113 Ordering::Equal => false,
114 Ordering::Less => true,
115 };
116
117 for (from_version, to_version, migration) in version_migrations
122 .into_iter()
123 .enumerate()
124 .map(|(i, m)| (i + 1, i + 2, m))
125 .skip(found_version as usize - 1)
127 {
128 tracing::info!("Migrating {name} database to version {to_version}");
129 let now = Instant::now();
130 let mut tx = db.begin().await?;
131 migration.execute(tx.as_mut()).await?;
132 sqlx::query("INSERT OR IGNORE INTO _meta (version) VALUES (?)")
133 .bind(to_version as i64)
134 .execute(tx.as_mut())
135 .await?;
136 tx.commit().await?;
137 tracing::info!(
138 "Successfully migrated {name} database from version {from_version} to {to_version} in {}",
139 humantime::format_duration(now.elapsed())
140 );
141 }
142
143 if run_vacuum {
144 tracing::info!(
148 "Performing {name} database vacuum and wal checkpointing to free up space after the migration"
149 );
150 if let Err(e) = sqlx::query("VACUUM").execute(db).await {
151 tracing::warn!("error vacuuming {name} database: {e}")
152 }
153 if let Err(e) = sqlx::query("PRAGMA wal_checkpoint(TRUNCATE)")
154 .execute(db)
155 .await
156 {
157 tracing::warn!("error checkpointing {name} database wal: {e}")
158 }
159 }
160
161 Ok(())
162}