askar_storage/backend/postgres/
provision.rs

1use std::borrow::Cow;
2use std::str::FromStr;
3use std::time::Duration;
4
5use sqlx::{
6    postgres::{PgConnectOptions, PgConnection, PgPool, PgPoolOptions, Postgres},
7    ConnectOptions, Connection, Error as SqlxError, Executor, Row, Transaction,
8};
9
10use crate::{
11    backend::{
12        db_utils::{init_keys, random_profile_name},
13        ManageBackend,
14    },
15    error::Error,
16    future::{unblock, BoxFuture},
17    options::IntoOptions,
18    protect::{KeyCache, PassKey, ProfileId, StoreKeyMethod, StoreKeyReference},
19};
20
21use super::PostgresBackend;
22
/// Connection acquire timeout, in seconds, used when `connect_timeout` is not supplied.
const DEFAULT_CONNECT_TIMEOUT: u64 = 30;
/// Pool idle timeout, in seconds, used when `idle_timeout` is not supplied.
const DEFAULT_IDLE_TIMEOUT: u64 = 300;
/// Minimum pool size used when `min_connections` is not supplied.
const DEFAULT_MIN_CONNECTIONS: u32 = 0;
/// Maximum pool size used when `max_connections` is not supplied.
const DEFAULT_MAX_CONNECTIONS: u32 = 10;
27
/// Configuration options for PostgreSQL stores
///
/// Instances are built by `PostgresStoreOptions::new`, which consumes the recognised
/// query parameters and derives both the regular and the administrative connection URI.
#[derive(Debug)]
pub struct PostgresStoreOptions {
    /// Timeout applied when acquiring a connection from the pool
    pub(crate) connect_timeout: Duration,
    /// Idle timeout configured on the connection pool
    pub(crate) idle_timeout: Duration,
    /// Upper bound on the number of pooled connections
    pub(crate) max_connections: u32,
    /// Lower bound on the number of pooled connections
    pub(crate) min_connections: u32,
    /// Connection URI for the store database, with the recognised options removed
    pub(crate) uri: String,
    /// Connection URI for administrative tasks: admin credentials are substituted
    /// when supplied and the path is replaced by `/postgres`
    pub(crate) admin_uri: String,
    /// Host component taken from the provided options
    pub(crate) host: String,
    /// Database name, taken from the URI path without its leading character
    pub(crate) name: String,
    /// User name from the provided options, or `postgres` when none was given
    pub(crate) username: String,
    /// Optional schema; when set it is applied as the connection `search_path`
    pub(crate) schema: Option<String>,
}
42
43impl PostgresStoreOptions {
44    /// Initialize `PostgresStoreOptions` from a generic set of options
45    pub fn new<'a, O>(options: O) -> Result<Self, Error>
46    where
47        O: IntoOptions<'a>,
48    {
49        let mut opts = options.into_options()?;
50        let connect_timeout = if let Some(timeout) = opts.query.remove("connect_timeout") {
51            timeout
52                .parse()
53                .map_err(err_map!(Input, "Error parsing 'connect_timeout' parameter"))?
54        } else {
55            DEFAULT_CONNECT_TIMEOUT
56        };
57        let idle_timeout = if let Some(timeout) = opts.query.remove("idle_timeout") {
58            timeout
59                .parse()
60                .map_err(err_map!(Input, "Error parsing 'idle_timeout' parameter"))?
61        } else {
62            DEFAULT_IDLE_TIMEOUT
63        };
64        let max_connections = if let Some(max_conn) = opts.query.remove("max_connections") {
65            max_conn
66                .parse()
67                .map_err(err_map!(Input, "Error parsing 'max_connections' parameter"))?
68        } else {
69            DEFAULT_MAX_CONNECTIONS
70        };
71        let min_connections = if let Some(min_conn) = opts.query.remove("min_connections") {
72            min_conn
73                .parse()
74                .map_err(err_map!(Input, "Error parsing 'min_connections' parameter"))?
75        } else {
76            DEFAULT_MIN_CONNECTIONS
77        };
78        let schema = opts.query.remove("schema");
79        let admin_acct = opts.query.remove("admin_account");
80        let admin_pass = opts.query.remove("admin_password");
81        let username = match opts.user.as_ref() {
82            "" => "postgres".to_owned(),
83            a => a.to_owned(),
84        };
85        let uri = opts.clone().into_uri();
86        if admin_acct.is_some() || admin_pass.is_some() {
87            if let Some(admin_acct) = admin_acct {
88                opts.user = Cow::Owned(admin_acct);
89            }
90            if let Some(admin_pass) = admin_pass {
91                opts.password = Cow::Owned(admin_pass);
92            }
93        }
94        let host = opts.host.to_string();
95        let path = opts.path.as_ref();
96        if path.len() < 2 {
97            return Err(err_msg!(Input, "Missing database name"));
98        }
99        let name = path[1..].to_string();
100        if let Some(schema) = schema.as_ref() {
101            _validate_ident(schema, "schema")?;
102        }
103        _validate_ident(&name, "database")?;
104        _validate_ident(&username, "username")?;
105        // admin user selects the default database
106        opts.path = Cow::Borrowed("/postgres");
107        Ok(Self {
108            connect_timeout: Duration::from_secs(connect_timeout),
109            idle_timeout: Duration::from_secs(idle_timeout),
110            max_connections,
111            min_connections,
112            uri,
113            admin_uri: opts.into_uri(),
114            host,
115            name,
116            username,
117            schema,
118        })
119    }
120
121    async fn pool(&self) -> Result<PgPool, SqlxError> {
122        #[allow(unused_mut)]
123        let mut conn_opts = PgConnectOptions::from_str(self.uri.as_str())?;
124        #[cfg(feature = "log")]
125        {
126            conn_opts = conn_opts
127                .log_statements(log::LevelFilter::Debug)
128                .log_slow_statements(log::LevelFilter::Debug, Default::default());
129        }
130        if let Some(s) = self.schema.as_ref() {
131            // NB: schema is a validated identifier
132            conn_opts = conn_opts.options([("search_path", s)]);
133        }
134        PgPoolOptions::default()
135            .acquire_timeout(self.connect_timeout)
136            .idle_timeout(self.idle_timeout)
137            .max_connections(self.max_connections)
138            .min_connections(self.min_connections)
139            .test_before_acquire(false)
140            .connect_with(conn_opts)
141            .await
142    }
143
144    pub(crate) async fn create_db_pool(&self) -> Result<PgPool, Error> {
145        // try connecting normally in case the database exists
146        match self.pool().await {
147            Ok(pool) => Ok(pool),
148            Err(SqlxError::Database(db_err)) if db_err.code() == Some(Cow::Borrowed("3D000")) => {
149                // error 3D000 is INVALID CATALOG NAME in postgres,
150                // this indicates that the database does not exist
151                let mut admin_conn = PgConnection::connect(self.admin_uri.as_ref())
152                    .await
153                    .map_err(err_map!(
154                        Backend,
155                        "Error creating admin connection to database"
156                    ))?;
157                // self.name and self.username are validated identifiers
158                let create_q = format!(
159                    "CREATE DATABASE \"{}\" OWNER \"{}\"",
160                    self.name, self.username
161                );
162                match admin_conn.execute(create_q.as_str()).await {
163                    Ok(_) => (),
164                    Err(SqlxError::Database(db_err))
165                        if db_err.code() == Some(Cow::Borrowed("23505"))
166                            || db_err.code() == Some(Cow::Borrowed("42P04")) =>
167                    {
168                        // 23505 is 'duplicate key value violates unique constraint'
169                        // 42P04 is 'duplicate database error'
170                        // in either case, assume another connection created the database
171                        // before we could and continue
172                    }
173                    Err(err) => {
174                        admin_conn.close().await?;
175                        return Err(err_msg!(Backend, "Error creating database").with_cause(err));
176                    }
177                }
178                admin_conn.close().await?;
179                Ok(self.pool().await?)
180            }
181            Err(err) => Err(err_msg!(Backend, "Error opening database").with_cause(err)),
182        }
183    }
184
185    /// Provision a Postgres store from this set of configuration options
186    pub async fn provision(
187        self,
188        method: StoreKeyMethod,
189        pass_key: PassKey<'_>,
190        profile: Option<String>,
191        recreate: bool,
192    ) -> Result<PostgresBackend, Error> {
193        let conn_pool = self.create_db_pool().await?;
194        let mut conn = conn_pool.acquire().await?;
195        let mut txn = conn.begin().await?;
196
197        if recreate {
198            // remove expected tables
199            reset_db(&mut txn).await?;
200        } else {
201            // check for presence of config table
202            let count = if let Some(schema) = self.schema.as_ref() {
203                sqlx::query_scalar::<_, i64>(
204                    "SELECT COUNT(*) FROM information_schema.tables
205                        WHERE table_schema=?1 AND table_name='config'",
206                )
207                .persistent(false)
208                .bind(schema)
209                .fetch_one(txn.as_mut())
210                .await
211                .map_err(err_map!(Backend, "Error checking for existing store"))?
212            } else {
213                sqlx::query_scalar::<_, i64>(
214                    "SELECT COUNT(*) FROM information_schema.tables
215                    WHERE table_schema=ANY (CURRENT_SCHEMAS(false)) AND table_name='config'",
216                )
217                .persistent(false)
218                .fetch_one(txn.as_mut())
219                .await
220                .map_err(err_map!(Backend, "Error checking for existing store"))?
221            };
222            if count > 0 {
223                // proceed to open, will fail if the version doesn't match
224                return open_db(
225                    conn_pool,
226                    Some(method),
227                    pass_key,
228                    profile,
229                    self.host,
230                    self.name,
231                )
232                .await;
233            }
234        }
235
236        // no 'config' table, assume empty database
237
238        let (profile_key, enc_profile_key, store_key, store_key_ref) = unblock({
239            let pass_key = pass_key.into_owned();
240            move || init_keys(method, pass_key)
241        })
242        .await?;
243        let default_profile = profile.unwrap_or_else(random_profile_name);
244        let profile_id = init_db(
245            txn,
246            &default_profile,
247            store_key_ref,
248            enc_profile_key,
249            self.schema.as_ref().unwrap_or(&self.username),
250        )
251        .await?;
252        conn.return_to_pool().await;
253
254        let mut key_cache = KeyCache::new(store_key);
255        key_cache.add_profile_mut(default_profile.clone(), profile_id, profile_key);
256
257        Ok(PostgresBackend::new(
258            conn_pool,
259            default_profile,
260            key_cache,
261            self.host,
262            self.name,
263        ))
264    }
265
266    /// Open an existing Postgres store from this set of configuration options
267    pub async fn open(
268        self,
269        method: Option<StoreKeyMethod>,
270        pass_key: PassKey<'_>,
271        profile: Option<String>,
272    ) -> Result<PostgresBackend, Error> {
273        let pool = match self.pool().await {
274            Ok(p) => Ok(p),
275            Err(SqlxError::Database(db_err)) if db_err.code() == Some(Cow::Borrowed("3D000")) => {
276                // error 3D000 is INVALID CATALOG NAME in postgres,
277                // this indicates that the database does not exist
278                Err(err_msg!(NotFound, "The requested database was not found"))
279            }
280            Err(err) => Err(err_msg!(Backend, "Error connecting to database pool").with_cause(err)),
281        }?;
282        open_db(pool, method, pass_key, profile, self.host, self.name).await
283    }
284
285    /// Remove an existing Postgres store defined by these configuration options
286    pub async fn remove(self) -> Result<bool, Error> {
287        let mut admin_conn = PgConnection::connect(self.admin_uri.as_ref())
288            .await
289            .map_err(err_map!(
290                Backend,
291                "Error creating admin connection to database"
292            ))?;
293        // any character except NUL is allowed in an identifier.
294        // double quotes must be escaped, but we just disallow those
295        let drop_q = format!("DROP DATABASE \"{}\"", self.name);
296        let res = match admin_conn.execute(drop_q.as_str()).await {
297            Ok(_) => Ok(true),
298            Err(SqlxError::Database(db_err)) if db_err.code() == Some(Cow::Borrowed("3D000")) => {
299                // invalid catalog name is raised if the database does not exist
300                Ok(false)
301            }
302            Err(err) => Err(err_msg!(Backend, "Error removing database").with_cause(err)),
303        }?;
304        admin_conn.close().await?;
305        Ok(res)
306    }
307}
308
309impl<'a> ManageBackend<'a> for PostgresStoreOptions {
310    type Backend = PostgresBackend;
311
312    fn open_backend(
313        self,
314        method: Option<StoreKeyMethod>,
315        pass_key: PassKey<'_>,
316        profile: Option<String>,
317    ) -> BoxFuture<'a, Result<PostgresBackend, Error>> {
318        let pass_key = pass_key.into_owned();
319        Box::pin(self.open(method, pass_key, profile))
320    }
321
322    fn provision_backend(
323        self,
324        method: StoreKeyMethod,
325        pass_key: PassKey<'_>,
326        profile: Option<String>,
327        recreate: bool,
328    ) -> BoxFuture<'a, Result<PostgresBackend, Error>> {
329        let pass_key = pass_key.into_owned();
330        Box::pin(self.provision(method, pass_key, profile, recreate))
331    }
332
333    fn remove_backend(self) -> BoxFuture<'a, Result<bool, Error>> {
334        Box::pin(self.remove())
335    }
336}
337
338pub(crate) async fn init_db(
339    mut txn: Transaction<'_, Postgres>,
340    profile_name: &str,
341    store_key_ref: String,
342    enc_profile_key: Vec<u8>,
343    schema: &str,
344) -> Result<ProfileId, Error> {
345    txn.execute(
346        format!(r#"
347        CREATE SCHEMA IF NOT EXISTS "{schema}";
348
349        CREATE TABLE "{schema}".config (
350            name TEXT NOT NULL,
351            value TEXT,
352            PRIMARY KEY(name)
353        );
354
355        CREATE TABLE "{schema}".profiles (
356            id BIGSERIAL,
357            name TEXT NOT NULL,
358            reference TEXT NULL,
359            profile_key BYTEA NULL,
360            PRIMARY KEY(id)
361        );
362        CREATE UNIQUE INDEX ix_profile_name ON "{schema}".profiles(name);
363
364        CREATE TABLE "{schema}".items (
365            id BIGSERIAL,
366            profile_id BIGINT NOT NULL,
367            kind SMALLINT NOT NULL,
368            category BYTEA NOT NULL,
369            name BYTEA NOT NULL,
370            value BYTEA NOT NULL,
371            expiry TIMESTAMP NULL,
372            PRIMARY KEY(id),
373            FOREIGN KEY(profile_id) REFERENCES "{schema}".profiles(id)
374                ON DELETE CASCADE ON UPDATE CASCADE
375        );
376        CREATE UNIQUE INDEX ix_items_uniq ON "{schema}".items(profile_id, kind, category, name);
377
378        CREATE TABLE "{schema}".items_tags (
379            id BIGSERIAL,
380            item_id BIGINT NOT NULL,
381            name BYTEA NOT NULL,
382            value BYTEA NOT NULL,
383            plaintext SMALLINT NOT NULL,
384            PRIMARY KEY(id),
385            FOREIGN KEY(item_id) REFERENCES "{schema}".items(id)
386                ON DELETE CASCADE ON UPDATE CASCADE
387        );
388        CREATE INDEX ix_items_tags_item_id ON "{schema}".items_tags(item_id);
389        CREATE INDEX ix_items_tags_name_enc ON "{schema}".items_tags(name, SUBSTR(value, 1, 12)) INCLUDE (item_id) WHERE plaintext=0;
390        CREATE INDEX ix_items_tags_name_plain ON "{schema}".items_tags(name, value) INCLUDE (item_id) WHERE plaintext=1;
391    "#).as_str(),
392    )
393    .await
394    .map_err(err_map!(Backend, "Error creating database tables"))?;
395
396    sqlx::query(
397        "INSERT INTO config (name, value) VALUES
398            ('default_profile', $1),
399            ('key', $2),
400            ('version', '1')",
401    )
402    .persistent(false)
403    .bind(profile_name)
404    .bind(store_key_ref)
405    .execute(txn.as_mut())
406    .await
407    .map_err(err_map!(Backend, "Error inserting configuration"))?;
408
409    let profile_id =
410        sqlx::query_scalar("INSERT INTO profiles (name, profile_key) VALUES ($1, $2) RETURNING id")
411            .bind(profile_name)
412            .bind(enc_profile_key)
413            .fetch_one(txn.as_mut())
414            .await
415            .map_err(err_map!(Backend, "Error inserting default profile"))?;
416
417    txn.commit().await?;
418
419    Ok(profile_id)
420}
421
422pub(crate) async fn reset_db(conn: &mut PgConnection) -> Result<(), Error> {
423    conn.execute(
424        "
425        DROP TABLE IF EXISTS
426          config, profiles,
427          profile_keys, keys,
428          items, items_tags;
429        ",
430    )
431    .await?;
432    Ok(())
433}
434
435pub(crate) async fn open_db(
436    conn_pool: PgPool,
437    method: Option<StoreKeyMethod>,
438    pass_key: PassKey<'_>,
439    profile: Option<String>,
440    host: String,
441    name: String,
442) -> Result<PostgresBackend, Error> {
443    let mut conn = conn_pool.acquire().await?;
444    let mut ver_ok = false;
445    let mut default_profile: Option<String> = None;
446    let mut store_key_ref: Option<String> = None;
447
448    let config = sqlx::query(
449        r#"SELECT name, value FROM config
450        WHERE name IN ('default_profile', 'key', 'version')"#,
451    )
452    .fetch_all(conn.as_mut())
453    .await
454    .map_err(err_map!(Backend, "Error fetching store configuration"))?;
455    for row in config {
456        match row.try_get(0)? {
457            "default_profile" => {
458                default_profile.replace(row.try_get(1)?);
459            }
460            "key" => {
461                store_key_ref.replace(row.try_get(1)?);
462            }
463            "version" => {
464                if row.try_get::<&str, _>(1)? != "1" {
465                    return Err(err_msg!(Unsupported, "Unsupported store version"));
466                }
467                ver_ok = true;
468            }
469            _ => (),
470        }
471    }
472    if !ver_ok {
473        return Err(err_msg!(Unsupported, "Store version not found"));
474    }
475    let profile = profile
476        .or(default_profile)
477        .ok_or_else(|| err_msg!(Unsupported, "Default store profile not found"))?;
478    let store_key = if let Some(store_key_ref) = store_key_ref {
479        let wrap_ref = StoreKeyReference::parse_uri(&store_key_ref)?;
480        if let Some(method) = method {
481            if !wrap_ref.compare_method(&method) {
482                return Err(err_msg!(Input, "Store key method mismatch"));
483            }
484        }
485        unblock({
486            let pass_key = pass_key.into_owned();
487            move || wrap_ref.resolve(pass_key)
488        })
489        .await?
490    } else {
491        return Err(err_msg!(Unsupported, "Store key not found"));
492    };
493
494    let mut key_cache = KeyCache::new(store_key);
495    let row = sqlx::query("SELECT id, profile_key FROM profiles WHERE name = $1")
496        .bind(&profile)
497        .fetch_one(conn.as_mut())
498        .await?;
499    let profile_id = row.try_get(0)?;
500    let profile_key = key_cache.load_key(row.try_get(1)?).await?;
501    conn.return_to_pool().await;
502
503    key_cache.add_profile_mut(profile.clone(), profile_id, profile_key);
504
505    Ok(PostgresBackend::new(
506        conn_pool, profile, key_cache, host, name,
507    ))
508}
509
510/// Validate a postgres identifier.
511/// Any character except NUL is allowed in an identifier. Double quotes must be escaped,
512/// but we just disallow those instead.
513fn _validate_ident(ident: &str, name: &str) -> Result<(), Error> {
514    if ident.is_empty() {
515        Err(err_msg!(Input, "{name} identifier is empty"))
516    } else if ident.find(['"', '\0']).is_some() {
517        Err(err_msg!(
518            Input,
519            "Invalid character in {name} identifier: '\"' and '\\0' are disallowed"
520        ))
521    } else {
522        Ok(())
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;

    /// Recognised query parameters are consumed, unknown ones are kept, and the admin
    /// URI carries the admin credentials and the default `postgres` database.
    #[test]
    fn postgres_parse_uri() {
        let uri = concat!(
            "postgres://user:pass@host/db_name",
            "?admin_account=user2&admin_password=pass2",
            "&connect_timeout=9&max_connections=23&min_connections=32",
            "&idle_timeout=99",
            "&test=1",
        );
        let parsed = PostgresStoreOptions::new(uri).unwrap();

        // pool sizing
        assert_eq!(parsed.max_connections, 23);
        assert_eq!(parsed.min_connections, 32);

        // timeouts are interpreted as seconds
        assert_eq!(parsed.connect_timeout, Duration::from_secs(9));
        assert_eq!(parsed.idle_timeout, Duration::from_secs(99));

        // generated connection URIs
        let expected_uri = "postgres://user:pass@host/db_name?test=1";
        let expected_admin_uri = "postgres://user2:pass2@host/postgres?test=1";
        assert_eq!(parsed.uri, expected_uri);
        assert_eq!(parsed.admin_uri, expected_admin_uri);
    }
}
548}