1use std::borrow::Cow;
2use std::str::FromStr;
3use std::time::Duration;
4
5use sqlx::{
6 postgres::{PgConnectOptions, PgConnection, PgPool, PgPoolOptions, Postgres},
7 ConnectOptions, Connection, Error as SqlxError, Executor, Row, Transaction,
8};
9
10use crate::{
11 backend::{
12 db_utils::{init_keys, random_profile_name},
13 ManageBackend,
14 },
15 error::Error,
16 future::{unblock, BoxFuture},
17 options::IntoOptions,
18 protect::{KeyCache, PassKey, ProfileId, StoreKeyMethod, StoreKeyReference},
19};
20
21use super::PostgresBackend;
22
/// Default pool acquire timeout in seconds, used when the URI has no
/// `connect_timeout` query parameter.
const DEFAULT_CONNECT_TIMEOUT: u64 = 30;
/// Default pool idle timeout in seconds, used when the URI has no
/// `idle_timeout` query parameter.
const DEFAULT_IDLE_TIMEOUT: u64 = 300;
/// Default minimum pool size, used when the URI has no `min_connections`
/// query parameter.
const DEFAULT_MIN_CONNECTIONS: u32 = 0;
/// Default maximum pool size, used when the URI has no `max_connections`
/// query parameter.
const DEFAULT_MAX_CONNECTIONS: u32 = 10;
27
/// Connection settings for a PostgreSQL-backed store, as parsed by
/// `PostgresStoreOptions::new`.
// NOTE(review): the derived `Debug` prints `uri` and `admin_uri` verbatim, and
// those URIs can embed passwords (e.g. `user:pass@host`); avoid logging this.
#[derive(Debug)]
pub struct PostgresStoreOptions {
    /// Timeout passed to the pool's `acquire_timeout`.
    pub(crate) connect_timeout: Duration,
    /// Timeout passed to the pool's `idle_timeout`.
    pub(crate) idle_timeout: Duration,
    /// Upper bound on the number of pooled connections.
    pub(crate) max_connections: u32,
    /// Lower bound on the number of pooled connections.
    pub(crate) min_connections: u32,
    /// Store URI, with the query parameters consumed by `new` removed.
    pub(crate) uri: String,
    /// URI for administrative connections: admin credentials substituted in
    /// (when provided) and the path replaced by `/postgres`.
    pub(crate) admin_uri: String,
    /// Host portion of the connection options.
    pub(crate) host: String,
    /// Target database name (the URI path without its leading `/`).
    pub(crate) name: String,
    /// Database user for the store; `postgres` when the URI omits a user.
    pub(crate) username: String,
    /// Optional schema, applied to pooled connections through `search_path`.
    pub(crate) schema: Option<String>,
}
42
43impl PostgresStoreOptions {
44 pub fn new<'a, O>(options: O) -> Result<Self, Error>
46 where
47 O: IntoOptions<'a>,
48 {
49 let mut opts = options.into_options()?;
50 let connect_timeout = if let Some(timeout) = opts.query.remove("connect_timeout") {
51 timeout
52 .parse()
53 .map_err(err_map!(Input, "Error parsing 'connect_timeout' parameter"))?
54 } else {
55 DEFAULT_CONNECT_TIMEOUT
56 };
57 let idle_timeout = if let Some(timeout) = opts.query.remove("idle_timeout") {
58 timeout
59 .parse()
60 .map_err(err_map!(Input, "Error parsing 'idle_timeout' parameter"))?
61 } else {
62 DEFAULT_IDLE_TIMEOUT
63 };
64 let max_connections = if let Some(max_conn) = opts.query.remove("max_connections") {
65 max_conn
66 .parse()
67 .map_err(err_map!(Input, "Error parsing 'max_connections' parameter"))?
68 } else {
69 DEFAULT_MAX_CONNECTIONS
70 };
71 let min_connections = if let Some(min_conn) = opts.query.remove("min_connections") {
72 min_conn
73 .parse()
74 .map_err(err_map!(Input, "Error parsing 'min_connections' parameter"))?
75 } else {
76 DEFAULT_MIN_CONNECTIONS
77 };
78 let schema = opts.query.remove("schema");
79 let admin_acct = opts.query.remove("admin_account");
80 let admin_pass = opts.query.remove("admin_password");
81 let username = match opts.user.as_ref() {
82 "" => "postgres".to_owned(),
83 a => a.to_owned(),
84 };
85 let uri = opts.clone().into_uri();
86 if admin_acct.is_some() || admin_pass.is_some() {
87 if let Some(admin_acct) = admin_acct {
88 opts.user = Cow::Owned(admin_acct);
89 }
90 if let Some(admin_pass) = admin_pass {
91 opts.password = Cow::Owned(admin_pass);
92 }
93 }
94 let host = opts.host.to_string();
95 let path = opts.path.as_ref();
96 if path.len() < 2 {
97 return Err(err_msg!(Input, "Missing database name"));
98 }
99 let name = path[1..].to_string();
100 if let Some(schema) = schema.as_ref() {
101 _validate_ident(schema, "schema")?;
102 }
103 _validate_ident(&name, "database")?;
104 _validate_ident(&username, "username")?;
105 opts.path = Cow::Borrowed("/postgres");
107 Ok(Self {
108 connect_timeout: Duration::from_secs(connect_timeout),
109 idle_timeout: Duration::from_secs(idle_timeout),
110 max_connections,
111 min_connections,
112 uri,
113 admin_uri: opts.into_uri(),
114 host,
115 name,
116 username,
117 schema,
118 })
119 }
120
121 async fn pool(&self) -> Result<PgPool, SqlxError> {
122 #[allow(unused_mut)]
123 let mut conn_opts = PgConnectOptions::from_str(self.uri.as_str())?;
124 #[cfg(feature = "log")]
125 {
126 conn_opts = conn_opts
127 .log_statements(log::LevelFilter::Debug)
128 .log_slow_statements(log::LevelFilter::Debug, Default::default());
129 }
130 if let Some(s) = self.schema.as_ref() {
131 conn_opts = conn_opts.options([("search_path", s)]);
133 }
134 PgPoolOptions::default()
135 .acquire_timeout(self.connect_timeout)
136 .idle_timeout(self.idle_timeout)
137 .max_connections(self.max_connections)
138 .min_connections(self.min_connections)
139 .test_before_acquire(false)
140 .connect_with(conn_opts)
141 .await
142 }
143
144 pub(crate) async fn create_db_pool(&self) -> Result<PgPool, Error> {
145 match self.pool().await {
147 Ok(pool) => Ok(pool),
148 Err(SqlxError::Database(db_err)) if db_err.code() == Some(Cow::Borrowed("3D000")) => {
149 let mut admin_conn = PgConnection::connect(self.admin_uri.as_ref())
152 .await
153 .map_err(err_map!(
154 Backend,
155 "Error creating admin connection to database"
156 ))?;
157 let create_q = format!(
159 "CREATE DATABASE \"{}\" OWNER \"{}\"",
160 self.name, self.username
161 );
162 match admin_conn.execute(create_q.as_str()).await {
163 Ok(_) => (),
164 Err(SqlxError::Database(db_err))
165 if db_err.code() == Some(Cow::Borrowed("23505"))
166 || db_err.code() == Some(Cow::Borrowed("42P04")) =>
167 {
168 }
173 Err(err) => {
174 admin_conn.close().await?;
175 return Err(err_msg!(Backend, "Error creating database").with_cause(err));
176 }
177 }
178 admin_conn.close().await?;
179 Ok(self.pool().await?)
180 }
181 Err(err) => Err(err_msg!(Backend, "Error opening database").with_cause(err)),
182 }
183 }
184
185 pub async fn provision(
187 self,
188 method: StoreKeyMethod,
189 pass_key: PassKey<'_>,
190 profile: Option<String>,
191 recreate: bool,
192 ) -> Result<PostgresBackend, Error> {
193 let conn_pool = self.create_db_pool().await?;
194 let mut conn = conn_pool.acquire().await?;
195 let mut txn = conn.begin().await?;
196
197 if recreate {
198 reset_db(&mut txn).await?;
200 } else {
201 let count = if let Some(schema) = self.schema.as_ref() {
203 sqlx::query_scalar::<_, i64>(
204 "SELECT COUNT(*) FROM information_schema.tables
205 WHERE table_schema=?1 AND table_name='config'",
206 )
207 .persistent(false)
208 .bind(schema)
209 .fetch_one(txn.as_mut())
210 .await
211 .map_err(err_map!(Backend, "Error checking for existing store"))?
212 } else {
213 sqlx::query_scalar::<_, i64>(
214 "SELECT COUNT(*) FROM information_schema.tables
215 WHERE table_schema=ANY (CURRENT_SCHEMAS(false)) AND table_name='config'",
216 )
217 .persistent(false)
218 .fetch_one(txn.as_mut())
219 .await
220 .map_err(err_map!(Backend, "Error checking for existing store"))?
221 };
222 if count > 0 {
223 return open_db(
225 conn_pool,
226 Some(method),
227 pass_key,
228 profile,
229 self.host,
230 self.name,
231 )
232 .await;
233 }
234 }
235
236 let (profile_key, enc_profile_key, store_key, store_key_ref) = unblock({
239 let pass_key = pass_key.into_owned();
240 move || init_keys(method, pass_key)
241 })
242 .await?;
243 let default_profile = profile.unwrap_or_else(random_profile_name);
244 let profile_id = init_db(
245 txn,
246 &default_profile,
247 store_key_ref,
248 enc_profile_key,
249 self.schema.as_ref().unwrap_or(&self.username),
250 )
251 .await?;
252 conn.return_to_pool().await;
253
254 let mut key_cache = KeyCache::new(store_key);
255 key_cache.add_profile_mut(default_profile.clone(), profile_id, profile_key);
256
257 Ok(PostgresBackend::new(
258 conn_pool,
259 default_profile,
260 key_cache,
261 self.host,
262 self.name,
263 ))
264 }
265
266 pub async fn open(
268 self,
269 method: Option<StoreKeyMethod>,
270 pass_key: PassKey<'_>,
271 profile: Option<String>,
272 ) -> Result<PostgresBackend, Error> {
273 let pool = match self.pool().await {
274 Ok(p) => Ok(p),
275 Err(SqlxError::Database(db_err)) if db_err.code() == Some(Cow::Borrowed("3D000")) => {
276 Err(err_msg!(NotFound, "The requested database was not found"))
279 }
280 Err(err) => Err(err_msg!(Backend, "Error connecting to database pool").with_cause(err)),
281 }?;
282 open_db(pool, method, pass_key, profile, self.host, self.name).await
283 }
284
285 pub async fn remove(self) -> Result<bool, Error> {
287 let mut admin_conn = PgConnection::connect(self.admin_uri.as_ref())
288 .await
289 .map_err(err_map!(
290 Backend,
291 "Error creating admin connection to database"
292 ))?;
293 let drop_q = format!("DROP DATABASE \"{}\"", self.name);
296 let res = match admin_conn.execute(drop_q.as_str()).await {
297 Ok(_) => Ok(true),
298 Err(SqlxError::Database(db_err)) if db_err.code() == Some(Cow::Borrowed("3D000")) => {
299 Ok(false)
301 }
302 Err(err) => Err(err_msg!(Backend, "Error removing database").with_cause(err)),
303 }?;
304 admin_conn.close().await?;
305 Ok(res)
306 }
307}
308
309impl<'a> ManageBackend<'a> for PostgresStoreOptions {
310 type Backend = PostgresBackend;
311
312 fn open_backend(
313 self,
314 method: Option<StoreKeyMethod>,
315 pass_key: PassKey<'_>,
316 profile: Option<String>,
317 ) -> BoxFuture<'a, Result<PostgresBackend, Error>> {
318 let pass_key = pass_key.into_owned();
319 Box::pin(self.open(method, pass_key, profile))
320 }
321
322 fn provision_backend(
323 self,
324 method: StoreKeyMethod,
325 pass_key: PassKey<'_>,
326 profile: Option<String>,
327 recreate: bool,
328 ) -> BoxFuture<'a, Result<PostgresBackend, Error>> {
329 let pass_key = pass_key.into_owned();
330 Box::pin(self.provision(method, pass_key, profile, recreate))
331 }
332
333 fn remove_backend(self) -> BoxFuture<'a, Result<bool, Error>> {
334 Box::pin(self.remove())
335 }
336}
337
338pub(crate) async fn init_db(
339 mut txn: Transaction<'_, Postgres>,
340 profile_name: &str,
341 store_key_ref: String,
342 enc_profile_key: Vec<u8>,
343 schema: &str,
344) -> Result<ProfileId, Error> {
345 txn.execute(
346 format!(r#"
347 CREATE SCHEMA IF NOT EXISTS "{schema}";
348
349 CREATE TABLE "{schema}".config (
350 name TEXT NOT NULL,
351 value TEXT,
352 PRIMARY KEY(name)
353 );
354
355 CREATE TABLE "{schema}".profiles (
356 id BIGSERIAL,
357 name TEXT NOT NULL,
358 reference TEXT NULL,
359 profile_key BYTEA NULL,
360 PRIMARY KEY(id)
361 );
362 CREATE UNIQUE INDEX ix_profile_name ON "{schema}".profiles(name);
363
364 CREATE TABLE "{schema}".items (
365 id BIGSERIAL,
366 profile_id BIGINT NOT NULL,
367 kind SMALLINT NOT NULL,
368 category BYTEA NOT NULL,
369 name BYTEA NOT NULL,
370 value BYTEA NOT NULL,
371 expiry TIMESTAMP NULL,
372 PRIMARY KEY(id),
373 FOREIGN KEY(profile_id) REFERENCES "{schema}".profiles(id)
374 ON DELETE CASCADE ON UPDATE CASCADE
375 );
376 CREATE UNIQUE INDEX ix_items_uniq ON "{schema}".items(profile_id, kind, category, name);
377
378 CREATE TABLE "{schema}".items_tags (
379 id BIGSERIAL,
380 item_id BIGINT NOT NULL,
381 name BYTEA NOT NULL,
382 value BYTEA NOT NULL,
383 plaintext SMALLINT NOT NULL,
384 PRIMARY KEY(id),
385 FOREIGN KEY(item_id) REFERENCES "{schema}".items(id)
386 ON DELETE CASCADE ON UPDATE CASCADE
387 );
388 CREATE INDEX ix_items_tags_item_id ON "{schema}".items_tags(item_id);
389 CREATE INDEX ix_items_tags_name_enc ON "{schema}".items_tags(name, SUBSTR(value, 1, 12)) INCLUDE (item_id) WHERE plaintext=0;
390 CREATE INDEX ix_items_tags_name_plain ON "{schema}".items_tags(name, value) INCLUDE (item_id) WHERE plaintext=1;
391 "#).as_str(),
392 )
393 .await
394 .map_err(err_map!(Backend, "Error creating database tables"))?;
395
396 sqlx::query(
397 "INSERT INTO config (name, value) VALUES
398 ('default_profile', $1),
399 ('key', $2),
400 ('version', '1')",
401 )
402 .persistent(false)
403 .bind(profile_name)
404 .bind(store_key_ref)
405 .execute(txn.as_mut())
406 .await
407 .map_err(err_map!(Backend, "Error inserting configuration"))?;
408
409 let profile_id =
410 sqlx::query_scalar("INSERT INTO profiles (name, profile_key) VALUES ($1, $2) RETURNING id")
411 .bind(profile_name)
412 .bind(enc_profile_key)
413 .fetch_one(txn.as_mut())
414 .await
415 .map_err(err_map!(Backend, "Error inserting default profile"))?;
416
417 txn.commit().await?;
418
419 Ok(profile_id)
420}
421
422pub(crate) async fn reset_db(conn: &mut PgConnection) -> Result<(), Error> {
423 conn.execute(
424 "
425 DROP TABLE IF EXISTS
426 config, profiles,
427 profile_keys, keys,
428 items, items_tags;
429 ",
430 )
431 .await?;
432 Ok(())
433}
434
435pub(crate) async fn open_db(
436 conn_pool: PgPool,
437 method: Option<StoreKeyMethod>,
438 pass_key: PassKey<'_>,
439 profile: Option<String>,
440 host: String,
441 name: String,
442) -> Result<PostgresBackend, Error> {
443 let mut conn = conn_pool.acquire().await?;
444 let mut ver_ok = false;
445 let mut default_profile: Option<String> = None;
446 let mut store_key_ref: Option<String> = None;
447
448 let config = sqlx::query(
449 r#"SELECT name, value FROM config
450 WHERE name IN ('default_profile', 'key', 'version')"#,
451 )
452 .fetch_all(conn.as_mut())
453 .await
454 .map_err(err_map!(Backend, "Error fetching store configuration"))?;
455 for row in config {
456 match row.try_get(0)? {
457 "default_profile" => {
458 default_profile.replace(row.try_get(1)?);
459 }
460 "key" => {
461 store_key_ref.replace(row.try_get(1)?);
462 }
463 "version" => {
464 if row.try_get::<&str, _>(1)? != "1" {
465 return Err(err_msg!(Unsupported, "Unsupported store version"));
466 }
467 ver_ok = true;
468 }
469 _ => (),
470 }
471 }
472 if !ver_ok {
473 return Err(err_msg!(Unsupported, "Store version not found"));
474 }
475 let profile = profile
476 .or(default_profile)
477 .ok_or_else(|| err_msg!(Unsupported, "Default store profile not found"))?;
478 let store_key = if let Some(store_key_ref) = store_key_ref {
479 let wrap_ref = StoreKeyReference::parse_uri(&store_key_ref)?;
480 if let Some(method) = method {
481 if !wrap_ref.compare_method(&method) {
482 return Err(err_msg!(Input, "Store key method mismatch"));
483 }
484 }
485 unblock({
486 let pass_key = pass_key.into_owned();
487 move || wrap_ref.resolve(pass_key)
488 })
489 .await?
490 } else {
491 return Err(err_msg!(Unsupported, "Store key not found"));
492 };
493
494 let mut key_cache = KeyCache::new(store_key);
495 let row = sqlx::query("SELECT id, profile_key FROM profiles WHERE name = $1")
496 .bind(&profile)
497 .fetch_one(conn.as_mut())
498 .await?;
499 let profile_id = row.try_get(0)?;
500 let profile_key = key_cache.load_key(row.try_get(1)?).await?;
501 conn.return_to_pool().await;
502
503 key_cache.add_profile_mut(profile.clone(), profile_id, profile_key);
504
505 Ok(PostgresBackend::new(
506 conn_pool, profile, key_cache, host, name,
507 ))
508}
509
510fn _validate_ident(ident: &str, name: &str) -> Result<(), Error> {
514 if ident.is_empty() {
515 Err(err_msg!(Input, "{name} identifier is empty"))
516 } else if ident.find(['"', '\0']).is_some() {
517 Err(err_msg!(
518 Input,
519 "Invalid character in {name} identifier: '\"' and '\\0' are disallowed"
520 ))
521 } else {
522 Ok(())
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Options consumed by `PostgresStoreOptions::new` are parsed and stripped,
    /// while unknown query parameters survive on both URIs.
    #[test]
    fn postgres_parse_uri() {
        let uri = concat!(
            "postgres://user:pass@host/db_name",
            "?admin_account=user2&admin_password=pass2",
            "&connect_timeout=9&max_connections=23&min_connections=32",
            "&idle_timeout=99",
            "&test=1"
        );
        let parsed = PostgresStoreOptions::new(uri).unwrap();

        // Pool sizing and timeouts.
        assert_eq!(parsed.max_connections, 23);
        assert_eq!(parsed.min_connections, 32);
        assert_eq!(parsed.connect_timeout, Duration::from_secs(9));
        assert_eq!(parsed.idle_timeout, Duration::from_secs(99));

        // Store URI keeps the user credentials; admin URI swaps them and the path.
        assert_eq!(parsed.uri, "postgres://user:pass@host/db_name?test=1");
        assert_eq!(
            parsed.admin_uri,
            "postgres://user2:pass2@host/postgres?test=1"
        );
    }
}