1use std::{
2 borrow::Cow, fs::remove_file, io::ErrorKind as IoErrorKind, str::FromStr,
3 thread::available_parallelism, time::Duration,
4};
5
6use sqlx::{
7 sqlite::{
8 SqliteAutoVacuum, SqliteConnectOptions, SqliteJournalMode, SqliteLockingMode, SqlitePool,
9 SqlitePoolOptions, SqliteSynchronous,
10 },
11 ConnectOptions, Error as SqlxError, Row,
12};
13
14use super::SqliteBackend;
15use crate::{
16 backend::{
17 db_utils::{init_keys, random_profile_name},
18 ManageBackend,
19 },
20 error::Error,
21 future::{sleep, unblock, BoxFuture},
22 options::{IntoOptions, Options},
23 protect::{KeyCache, PassKey, StoreKeyMethod, StoreKeyReference},
24};
25
// Pool size used when the `min_connections` query parameter is absent.
const DEFAULT_MIN_CONNECTIONS: usize = 1;
// Lower and upper bounds applied to the detected parallelism when the
// `max_connections` query parameter is absent.
const DEFAULT_LOWER_MAX_CONNECTIONS: usize = 4;
const DEFAULT_UPPER_MAX_CONNECTIONS: usize = 8;
// Busy timeout used when the `busy_timeout` query parameter is absent.
const DEFAULT_BUSY_TIMEOUT: Duration = Duration::from_secs(5);
// Journal mode used when the `journal_mode` query parameter is absent.
const DEFAULT_JOURNAL_MODE: SqliteJournalMode = SqliteJournalMode::Wal;
// Locking mode used when the `locking_mode` query parameter is absent.
const DEFAULT_LOCKING_MODE: SqliteLockingMode = SqliteLockingMode::Normal;
// Synchronous setting used when the `synchronous` query parameter is absent.
const DEFAULT_SYNCHRONOUS: SqliteSynchronous = SqliteSynchronous::Full;
33
/// Configuration used to open, provision or remove a SQLite-backed store.
///
/// Instances are built by [`SqliteStoreOptions::new`], which fills every field
/// either from the supplied options or from the module-level defaults.
#[derive(Debug)]
pub struct SqliteStoreOptions {
    /// True when `path` is exactly `":memory:"`; no file is removed in that case.
    pub(crate) in_memory: bool,
    /// Host and path components of the parsed options, concatenated. This
    /// string is handed to `SqliteConnectOptions::from_str` when connecting.
    pub(crate) path: String,
    /// Busy timeout applied to each connection.
    pub(crate) busy_timeout: Duration,
    /// Maximum number of connections in the pool.
    pub(crate) max_connections: u32,
    /// Minimum number of connections in the pool.
    pub(crate) min_connections: u32,
    /// Journal mode applied to each connection.
    pub(crate) journal_mode: SqliteJournalMode,
    /// Locking mode applied to each connection.
    pub(crate) locking_mode: SqliteLockingMode,
    /// Whether connections are opened with the shared cache enabled.
    pub(crate) shared_cache: bool,
    /// Synchronous setting applied to each connection.
    pub(crate) synchronous: SqliteSynchronous,
}
47
48impl Default for SqliteStoreOptions {
49 fn default() -> Self {
50 Self::new(":memory:").expect("Error initializing with default options")
51 }
52}
53
54impl SqliteStoreOptions {
55 pub fn new<'a>(options: impl IntoOptions<'a>) -> Result<Self, Error> {
57 let mut opts = options.into_options()?;
58 let mut path = opts.host.to_string();
59 path.push_str(&opts.path);
60 let in_memory = path == ":memory:";
61
62 let busy_timeout = if let Some(timeout) = opts.query.remove("busy_timeout") {
63 Duration::from_millis(
64 timeout
65 .parse()
66 .map_err(err_map!(Input, "Error parsing 'busy_timeout' parameter"))?,
67 )
68 } else {
69 DEFAULT_BUSY_TIMEOUT
70 };
71 let max_connections = if let Some(max_conn) = opts.query.remove("max_connections") {
72 max_conn
73 .parse()
74 .map_err(err_map!(Input, "Error parsing 'max_connections' parameter"))?
75 } else {
76 #[allow(clippy::manual_clamp)]
77 {
78 available_parallelism()
79 .map_err(err_map!(
80 Unexpected,
81 "Error determining available parallelism"
82 ))?
83 .get()
84 .max(DEFAULT_LOWER_MAX_CONNECTIONS)
85 .min(DEFAULT_UPPER_MAX_CONNECTIONS) as u32
86 }
87 };
88 let min_connections = if let Some(min_conn) = opts.query.remove("min_connections") {
89 min_conn
90 .parse()
91 .map_err(err_map!(Input, "Error parsing 'min_connections' parameter"))?
92 } else {
93 DEFAULT_MIN_CONNECTIONS as u32
94 };
95 let journal_mode = if let Some(mode) = opts.query.remove("journal_mode") {
96 SqliteJournalMode::from_str(&mode)
97 .map_err(err_map!(Input, "Error parsing 'journal_mode' parameter"))?
98 } else {
99 DEFAULT_JOURNAL_MODE
100 };
101 let locking_mode = if let Some(mode) = opts.query.remove("locking_mode") {
102 SqliteLockingMode::from_str(&mode)
103 .map_err(err_map!(Input, "Error parsing 'locking_mode' parameter"))?
104 } else {
105 DEFAULT_LOCKING_MODE
106 };
107 let shared_cache = if let Some(cache) = opts.query.remove("cache") {
108 cache.eq_ignore_ascii_case("shared")
109 } else {
110 in_memory
111 };
112 let synchronous = if let Some(sync) = opts.query.remove("synchronous") {
113 SqliteSynchronous::from_str(&sync)
114 .map_err(err_map!(Input, "Error parsing 'synchronous' parameter"))?
115 } else {
116 DEFAULT_SYNCHRONOUS
117 };
118
119 Ok(Self {
120 in_memory,
121 path,
122 busy_timeout,
123 max_connections,
124 min_connections,
125 journal_mode,
126 locking_mode,
127 shared_cache,
128 synchronous,
129 })
130 }
131
132 async fn pool(&self, auto_create: bool) -> std::result::Result<SqlitePool, SqlxError> {
133 #[allow(unused_mut)]
134 let mut conn_opts = SqliteConnectOptions::from_str(self.path.as_ref())?
135 .create_if_missing(auto_create)
136 .auto_vacuum(SqliteAutoVacuum::Incremental)
137 .busy_timeout(self.busy_timeout)
138 .journal_mode(self.journal_mode)
139 .locking_mode(self.locking_mode)
140 .shared_cache(self.shared_cache)
141 .synchronous(self.synchronous);
142 #[cfg(feature = "log")]
143 {
144 conn_opts = conn_opts
145 .log_statements(log::LevelFilter::Debug)
146 .log_slow_statements(log::LevelFilter::Debug, Default::default());
147 }
148 SqlitePoolOptions::default()
149 .min_connections(self.min_connections)
153 .max_connections(self.max_connections)
154 .test_before_acquire(false)
155 .connect_with(conn_opts)
156 .await
157 }
158
159 pub async fn provision(
161 self,
162 method: StoreKeyMethod,
163 pass_key: PassKey<'_>,
164 profile: Option<String>,
165 recreate: bool,
166 ) -> Result<SqliteBackend, Error> {
167 if recreate && !self.in_memory {
168 try_remove_file(self.path.to_string()).await?;
169 }
170 let conn_pool = self
171 .pool(true)
172 .await
173 .map_err(err_map!(Backend, "Error creating database pool"))?;
174
175 if !recreate {
176 let mut conn = conn_pool.acquire().await?;
177 let found = sqlx::query_scalar::<_, i64>(
178 "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='config'",
179 )
180 .fetch_one(conn.as_mut())
181 .await
182 .map_err(err_map!(Backend, "Error checking for existing store"))?
183 == 1;
184 conn.return_to_pool().await;
185 if found {
186 return open_db(
187 conn_pool,
188 Some(method),
189 pass_key,
190 profile,
191 self.path.to_string(),
192 )
193 .await;
194 }
195 }
196 let default_profile = profile.unwrap_or_else(random_profile_name);
199 let key_cache = init_db(&conn_pool, &default_profile, method, pass_key).await?;
200
201 Ok(SqliteBackend::new(
202 conn_pool,
203 default_profile,
204 key_cache,
205 self.path.to_string(),
206 ))
207 }
208
209 pub async fn open(
211 self,
212 method: Option<StoreKeyMethod>,
213 pass_key: PassKey<'_>,
214 profile: Option<String>,
215 ) -> Result<SqliteBackend, Error> {
216 let conn_pool = match self.pool(false).await {
217 Ok(pool) => Ok(pool),
218 Err(SqlxError::Database(db_err)) => {
219 if db_err.code().expect("Expected SQLite error code") == "14" {
220 Err(err_msg!(
222 NotFound,
223 "The requested database path was not found"
224 ))
225 } else {
226 Err(err_msg!(Backend, "Error connecting to database pool")
227 .with_cause(SqlxError::Database(db_err)))
228 }
229 }
230 Err(err) => Err(err.into()),
231 }?;
232 open_db(conn_pool, method, pass_key, profile, self.path.to_string()).await
233 }
234
235 pub async fn remove(self) -> Result<bool, Error> {
237 if self.in_memory {
238 Ok(true)
239 } else {
240 try_remove_file(self.path.to_string()).await
241 }
242 }
243
244 pub fn in_memory() -> Self {
246 Self::from_path(":memory:")
247 }
248
249 pub fn from_path(path: &str) -> Self {
251 let opts = Options {
252 host: Cow::Borrowed(path),
253 ..Default::default()
254 };
255 Self::new(opts).unwrap()
256 }
257}
258
259impl<'a> ManageBackend<'a> for SqliteStoreOptions {
260 type Backend = SqliteBackend;
261
262 fn open_backend(
263 self,
264 method: Option<StoreKeyMethod>,
265 pass_key: PassKey<'a>,
266 profile: Option<String>,
267 ) -> BoxFuture<'a, Result<SqliteBackend, Error>> {
268 Box::pin(self.open(method, pass_key, profile))
269 }
270
271 fn provision_backend(
272 self,
273 method: StoreKeyMethod,
274 pass_key: PassKey<'a>,
275 profile: Option<String>,
276 recreate: bool,
277 ) -> BoxFuture<'a, Result<SqliteBackend, Error>> {
278 Box::pin(self.provision(method, pass_key, profile, recreate))
279 }
280
281 fn remove_backend(self) -> BoxFuture<'a, Result<bool, Error>> {
282 Box::pin(self.remove())
283 }
284}
285
286async fn init_db(
287 conn_pool: &SqlitePool,
288 profile_name: &str,
289 method: StoreKeyMethod,
290 pass_key: PassKey<'_>,
291) -> Result<KeyCache, Error> {
292 let (profile_key, enc_profile_key, store_key, store_key_ref) = unblock({
293 let pass_key = pass_key.into_owned();
294 move || init_keys(method, pass_key)
295 })
296 .await?;
297
298 let mut conn = conn_pool.acquire().await?;
299
300 sqlx::query(
301 r#"
302 BEGIN EXCLUSIVE TRANSACTION;
303
304 CREATE TABLE config (
305 name TEXT NOT NULL,
306 value TEXT,
307 PRIMARY KEY (name)
308 );
309 INSERT INTO config (name, value) VALUES
310 ("default_profile", ?1),
311 ("key", ?2),
312 ("version", "1");
313
314 CREATE TABLE profiles (
315 id INTEGER NOT NULL,
316 name TEXT NOT NULL,
317 reference TEXT NULL,
318 profile_key BLOB NULL,
319 PRIMARY KEY(id)
320 );
321 CREATE UNIQUE INDEX ix_profile_name ON profiles (name);
322
323 CREATE TABLE items (
324 id INTEGER NOT NULL,
325 profile_id INTEGER NOT NULL,
326 kind INTEGER NOT NULL,
327 category BLOB NOT NULL,
328 name BLOB NOT NULL,
329 value BLOB NOT NULL,
330 expiry DATETIME NULL,
331 PRIMARY KEY (id),
332 FOREIGN KEY (profile_id) REFERENCES profiles (id)
333 ON DELETE CASCADE ON UPDATE CASCADE
334 );
335 CREATE UNIQUE INDEX ix_items_uniq ON items (profile_id, kind, category, name);
336
337 CREATE TABLE items_tags (
338 id INTEGER NOT NULL,
339 item_id INTEGER NOT NULL,
340 name BLOB NOT NULL,
341 value BLOB NOT NULL,
342 plaintext BOOLEAN NOT NULL,
343 PRIMARY KEY (id),
344 FOREIGN KEY (item_id) REFERENCES items (id)
345 ON DELETE CASCADE ON UPDATE CASCADE
346 );
347 CREATE INDEX ix_items_tags_item_id ON items_tags (item_id);
348 CREATE INDEX ix_items_tags_name_enc ON items_tags (name, SUBSTR(value, 1, 12)) WHERE plaintext=0;
349 CREATE INDEX ix_items_tags_name_plain ON items_tags (name, value) WHERE plaintext=1;
350
351 INSERT INTO profiles (name, profile_key) VALUES (?1, ?3);
352
353 COMMIT;
354 "#,
355 )
356 .persistent(false)
357 .bind(profile_name)
358 .bind(store_key_ref)
359 .bind(enc_profile_key)
360 .execute(conn.as_mut())
361 .await.map_err(err_map!(Backend, "Error creating database tables"))?;
362
363 let row = sqlx::query("SELECT id FROM profiles WHERE name = ?1")
364 .persistent(false)
365 .bind(profile_name)
366 .fetch_one(conn.as_mut())
367 .await
368 .map_err(err_map!(Backend, "Error checking for existing profile"))?;
369 conn.return_to_pool().await;
370
371 let mut key_cache = KeyCache::new(store_key);
372 key_cache.add_profile_mut(profile_name.to_string(), row.try_get(0)?, profile_key);
373
374 Ok(key_cache)
375}
376
377async fn open_db(
378 conn_pool: SqlitePool,
379 method: Option<StoreKeyMethod>,
380 pass_key: PassKey<'_>,
381 profile: Option<String>,
382 path: String,
383) -> Result<SqliteBackend, Error> {
384 let mut conn = conn_pool.acquire().await?;
385 let mut ver_ok = false;
386 let mut default_profile: Option<String> = None;
387 let mut store_key_ref: Option<String> = None;
388
389 let config = sqlx::query(
390 r#"SELECT name, value FROM config
391 WHERE name IN ("default_profile", "key", "version")"#,
392 )
393 .fetch_all(conn.as_mut())
394 .await
395 .map_err(err_map!(Backend, "Error fetching store configuration"))?;
396 for row in config {
397 match row.try_get(0)? {
398 "default_profile" => {
399 default_profile.replace(row.try_get(1)?);
400 }
401 "key" => {
402 store_key_ref.replace(row.try_get(1)?);
403 }
404 "version" => {
405 if row.try_get::<&str, _>(1)? != "1" {
406 return Err(err_msg!(Unsupported, "Unsupported store version"));
407 }
408 ver_ok = true;
409 }
410 _ => (),
411 }
412 }
413 if !ver_ok {
414 return Err(err_msg!(Unsupported, "Store version not found"));
415 }
416 let profile = profile
417 .or(default_profile)
418 .ok_or_else(|| err_msg!(Unsupported, "Default store profile not found"))?;
419 let store_key = if let Some(store_key_ref) = store_key_ref {
420 let wrap_ref = StoreKeyReference::parse_uri(&store_key_ref)?;
421 if let Some(method) = method {
422 if !wrap_ref.compare_method(&method) {
423 return Err(err_msg!(Input, "Store key method mismatch"));
424 }
425 }
426 unblock({
427 let pass_key = pass_key.into_owned();
428 move || wrap_ref.resolve(pass_key)
429 })
430 .await?
431 } else {
432 return Err(err_msg!(Unsupported, "Store key not found"));
433 };
434
435 let mut key_cache = KeyCache::new(store_key);
436 let row = sqlx::query("SELECT id, profile_key FROM profiles WHERE name = ?1")
437 .bind(&profile)
438 .fetch_one(conn.as_mut())
439 .await?;
440 let profile_id = row.try_get(0)?;
441 let profile_key = key_cache.load_key(row.try_get(1)?).await?;
442 conn.return_to_pool().await;
443 key_cache.add_profile_mut(profile.clone(), profile_id, profile_key);
444
445 Ok(SqliteBackend::new(conn_pool, profile, key_cache, path))
446}
447
448async fn try_remove_file(path: String) -> Result<bool, Error> {
449 let mut retries = 0;
450 loop {
451 let path = path.clone();
452 if let Some(res) = unblock(move || match remove_file(path) {
453 Ok(()) => Ok(Some(true)),
454 Err(err) if err.kind() == IoErrorKind::NotFound => Ok(Some(false)),
455 #[cfg(target_os = "windows")]
456 Err(err) if err.raw_os_error() == Some(32) => Ok(None),
459 Err(err) => Err(err_msg!(Backend, "Error removing file").with_cause(err)),
460 })
461 .await?
462 {
463 break Ok(res);
464 } else {
465 sleep(Duration::from_millis(50)).await;
466 retries += 1;
467 if retries >= 10 {
468 return Err(err_msg!(Backend, "Error removing file: still in use"));
469 }
470 }
471 }
472}