pg_client_config/
lib.rs

1//!
//! Connection string parsing with support for service files
//! and a subset of psql environment variables.
4//!
//! *Note*: tokio-postgres 0.7.9 [introduced a change](https://github.com/3liz/pg-event-server/issues/1)
//! preventing `PGUSER` and the service configuration from setting the connection user.
//! The [release of tokio-postgres 0.7.10](https://github.com/sfackler/rust-postgres/blob/master/tokio-postgres/CHANGELOG.md#v0710---2023-08-25)
//! fixes this issue.
9//!
10//! ## Environment variables
11//!
12//! * `PGSERVICE` - Name of the postgres service used for connection params.
13//! * `PGSYSCONFDIR` - Location of the service files.
14//! * `PGSERVICEFILE` - Name of the service file.
15//! * `PGHOST` - behaves the same as the `host` connection parameter.
16//! * `PGPORT` - behaves the same as the `port` connection parameter.
17//! * `PGDATABASE` - behaves the same as the `dbname` connection parameter.
//! * `PGUSER` - behaves the same as the `user` connection parameter.
19//! * `PGOPTIONS` - behaves the same as the `options` parameter.
20//! * `PGAPPNAME` - behaves the same as the `application_name` connection parameter.
21//! * `PGCONNECT_TIMEOUT` - behaves the same as the `connect_timeout` connection parameter.
//! * `PGPASSFILE` - Specifies the name of the file used to store passwords.
23//!
24//! ## Passfile support 
25//!
//! Passfile lookup is currently only supported on Unix platforms and
//! requires the `with-passfile` feature to be enabled.
27//!
28//! ## Example
29//!
30//! ```no_run
31//! use pg_client_config::load_config;
32//!
33//! let config = load_config(Some("service=myservice")).unwrap();
34//! println!("{config:#?}");
35//! ```
36//!
37//! ## See also
38//!
39//! * [Pg service file](https://www.postgresql.org/docs/current/libpq-pgservice.html)
40//! * [Pg pass file](https://www.postgresql.org/docs/current/libpq-pgpass.html)
41//!
42
43use ini::Ini;
44use std::path::{Path, PathBuf};
45use std::str::FromStr;
46use std::time::Duration;
47use tokio_postgres::config::{ChannelBinding, Config, SslMode};
48
49#[cfg(all(target_family = "unix", feature = "with-passfile"))]
50mod passfile;
51
52#[cfg(not(all(target_family = "unix", feature = "with-passfile")))]
53mod passfile {
54    use super::*;
55    pub(crate) fn get_password_from_passfile(_: &mut Config) -> Result<()> {
56        Ok(())
57    }
58}
59
60/// Error while parsing service file or
61/// retrieving parameter from environment
62#[derive(thiserror::Error, Debug)]
63pub enum Error {
64    #[error("IO Error")]
65    IOError(#[from] std::io::Error),
66    #[error("Service File Error")]
67    PgServiceFileError(#[from] ini::Error),
68    #[error("Service file not found: {0}")]
69    PgServiceFileNotFound(String),
70    #[error("Definition of service {0} not found")]
71    PgServiceNotFound(String),
72    #[error("Invalid ssl mode, expecting 'prefer', 'require' or 'disable': found '{0}'")]
73    InvalidSslMode(String),
74    #[error("Invalid port, expecting integer, found '{0}'")]
75    InvalidPort(String),
76    #[error("Invalid connect_timeout, expecting number of secs, found '{0}'")]
77    InvalidConnectTimeout(String),
78    #[error("Invalid keepalives, '1' or '0', found '{0}'")]
79    InvalidKeepalives(String),
80    #[error("Invalid keepalives, expecting number of secs, found '{0}'")]
81    InvalidKeepalivesIdle(String),
82    #[error("Invalid Channel Binding, expecting 'prefer', 'require' or 'disable': found '{0}'")]
83    InvalidChannelBinding(String),
84    #[error("Missing service name in connection string")]
85    MissingServiceName,
86    #[error("Postgres config error")]
87    PostgresConfig(#[from] tokio_postgres::Error),
88    #[error("Invalid passfile mode")]
89    InvalidPassFileMode,
90    #[error("Error parsing passfile")]
91    PassfileParseError,
92    #[error("Pass file not found: {0}")]
93    PgPassFileNotFound(String),
94}
95
96pub type Result<T, E = Error> = std::result::Result<T, E>;
97
98/// Load postgres connection configuration
99///
100/// The configuration will handle PG environment variable.
101///
102/// If the connection string start with `service=<service>`
103/// the service will be searched in the file given by `PGSERVICEFILE`,
104/// or in `~/.pg_service.conf` and `PGSYSCONFDIR/pg_service.conf`.
105/// The remaining of the connection string is used directly for
106/// initializing [`Config`]
107///
108/// If the connection string do no start with `service=` the connection
109/// string is directly passed to [`Config`] along with parameters
110/// defined for any service defined in `PGSERVICE`.
111///
112/// If the connection string is None the [`Config`] is initialized
113/// from environment variables and/or service defined in `PGSERVICE`.
114///
115/// In all cases, parameters from the connection string take precedence.
116///
117pub fn load_config(config: Option<&str>) -> Result<Config> {
118    fn load_service_config(service: &str, cnxstr: &str) -> Result<Config> {
119        let mut config = if cnxstr.is_empty() {
120            Config::new()
121        } else {
122            Config::from_str(cnxstr)?
123        };
124        load_config_from_service(&mut config, service)?;
125        load_config_from_env(&mut config)?;
126        Ok(config)
127    }
128
129    if let Some(cnxstr) = config {
130        let cnxstr = cnxstr.trim_start();
131        if cnxstr.starts_with("service=") {
132            // Get the service name
133            // Assume the the tail is valid connection string
134            if let Some((service, tail)) = cnxstr.split_once('=').map(|(_, tail)| {
135                tail.split_once(|c: char| c.is_whitespace())
136                    .unwrap_or((tail, ""))
137            }) {
138                load_service_config(service, tail.trim())
139            } else {
140                Err(Error::MissingServiceName)
141            }
142        } else if let Ok(service) = std::env::var("PGSERVICE") {
143            // Service file defined
144            // But overridable from connection string
145            load_service_config(&service, cnxstr)
146        } else {
147            // No service defined
148            let mut config = Config::from_str(cnxstr)?;
149            load_config_from_env(&mut config)?;
150            Ok(config)
151        }
152    } else if let Ok(service) = std::env::var("PGSERVICE") {
153        load_service_config(&service, "")
154    } else {
155        // No service defined
156        // Initialize from env vars.
157        let mut config = Config::new();
158        load_config_from_env(&mut config)?;
159        Ok(config)
160    }
161    .and_then(|mut config| {
162        if config.get_password().is_none() {
163            passfile::get_password_from_passfile(&mut config)?;
164        }
165        Ok(config)
166    })
167}
168
169/// Load connection parameters from service config_file
170fn load_config_from_service(config: &mut Config, service_name: &str) -> Result<()> {
171    fn user_service_file() -> Option<PathBuf> {
172        std::env::var("PGSERVICEFILE")
173            .map(|path| Path::new(&path).into())
174            .or_else(|_| {
175                std::env::var("HOME").map(|path| Path::new(&path).join(".pg_service.conf"))
176            })
177            .ok()
178    }
179
180    fn sysconf_service_file() -> Option<PathBuf> {
181        std::env::var("PGSYSCONFDIR")
182            .map(|path| Path::new(&path).join("pg_service.conf"))
183            .ok()
184    }
185
186    fn get_service_params(config: &mut Config, path: &Path, service_name: &str) -> Result<bool> {
187        if path.exists() {
188            Ini::load_from_file(path)
189                .map_err(Error::from)
190                .and_then(|ini| {
191                    if let Some(params) = ini.section(Some(service_name)) {
192                        params
193                            .iter()
194                            .try_for_each(|(k, v)| set_parameter(config, k, v))
195                            .map(|_| true)
196                    } else {
197                        Ok(false)
198                    }
199                })
200        } else {
201            Err(Error::PgServiceFileNotFound(
202                path.to_string_lossy().into_owned(),
203            ))
204        }
205    }
206
207    let found = match user_service_file().and_then(|p| p.as_path().exists().then_some(p)) {
208        Some(path) => get_service_params(config, &path, service_name)?,
209        None => false,
210    } || match sysconf_service_file() {
211        Some(path) => get_service_params(config, &path, service_name)?,
212        None => false,
213    };
214
215    if !found {
216        Err(Error::PgServiceNotFound(service_name.into()))
217    } else {
218        Ok(())
219    }
220}
221
222/// Load configuration from environment variables
223fn load_config_from_env(config: &mut Config) -> Result<()> {
224    static ENV: [(&str, &str); 7] = [
225        ("PGHOST", "host"),
226        ("PGPORT", "port"),
227        ("PGDATABASE", "dbname"),
228        ("PGUSER", "user"),
229        ("PGOPTIONS", "options"),
230        ("PGAPPNAME", "application_name"),
231        ("PGCONNECT_TIMEOUT", "connect_timeout"),
232    ];
233
234    ENV.iter().try_for_each(|(varname, k)| {
235        if let Ok(v) = std::env::var(varname) {
236            set_parameter(config, k, &v)
237        } else {
238            Ok(())
239        }
240    })
241}
242
243fn set_parameter(config: &mut Config, k: &str, v: &str) -> Result<()> {
244    fn parse_ssl_mode(mode: &str) -> Result<SslMode> {
245        match mode {
246            "disable" => Ok(SslMode::Disable),
247            "prefer" => Ok(SslMode::Prefer),
248            "require" => Ok(SslMode::Require),
249            _ => Err(Error::InvalidSslMode(mode.into())),
250        }
251    }
252
253    fn parse_channel_binding(mode: &str) -> Result<ChannelBinding> {
254        match mode {
255            "disable" => Ok(ChannelBinding::Disable),
256            "prefer" => Ok(ChannelBinding::Prefer),
257            "require" => Ok(ChannelBinding::Require),
258            _ => Err(Error::InvalidChannelBinding(mode.into())),
259        }
260    }
261
262    match k {
263        // The following values may be set from
264        // environment variables
265        "user" => {
266            if config.get_user().is_none() {
267                config.user(v);
268            }
269        }
270        "password" => {
271            if config.get_password().is_none() {
272                config.password(v);
273            }
274        }
275        "dbname" => {
276            if config.get_dbname().is_none() {
277                config.dbname(v);
278            }
279        }
280        "options" => {
281            if config.get_options().is_none() {
282                config.options(v);
283            }
284        }
285        "host" | "hostaddr" => {
286            if config.get_hosts().is_empty() {
287                config.host(v);
288            }
289        }
290        "port" => {
291            if config.get_ports().is_empty() {
292                config.port(v.parse().map_err(|_| Error::InvalidPort(v.into()))?);
293            }
294        }
295        "application_name" => {
296            if config.get_application_name().is_none() {
297                config.application_name(v);
298            }
299        }
300        "connect_timeout" => {
301            if config.get_connect_timeout().is_none() {
302                config.connect_timeout(Duration::from_secs(
303                    v.parse()
304                        .map_err(|_| Error::InvalidConnectTimeout(v.into()))?,
305                ));
306            }
307        }
308        // The following are not set from environment variables
309        // values are always overriden (i.e service configuration takes
310        // precedence)
311        "sslmode" => {
312            config.ssl_mode(parse_ssl_mode(v)?);
313        }
314        "keepalives" => {
315            config.keepalives(match v {
316                "1" => Ok(true),
317                "0" => Ok(false),
318                _ => Err(Error::InvalidKeepalives(v.into())),
319            }?);
320        }
321        "keepalives_idle" => {
322            config.keepalives_idle(Duration::from_secs(
323                v.parse()
324                    .map_err(|_| Error::InvalidKeepalivesIdle(v.into()))?,
325            ));
326        }
327        "channel_binding" => {
328            config.channel_binding(parse_channel_binding(v)?);
329        }
330        _ => (),
331    }
332
333    Ok(())
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_postgres::config::Host;

    /// Point `PGSYSCONFDIR` at the `fixtures` directory of the crate.
    fn use_fixtures_as_sysconfdir() {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        let fixtures = Path::new(&manifest_dir).join("fixtures");
        std::env::set_var("PGSYSCONFDIR", fixtures.to_str().unwrap());
    }

    /// Without connection string nor service, every parameter comes from
    /// the environment.
    #[test]
    fn from_environment() {
        for (name, value) in [
            ("PGUSER", "foo"),
            ("PGHOST", "foo.com"),
            ("PGDATABASE", "foodb"),
            ("PGPORT", "1234"),
        ] {
            std::env::set_var(name, value);
        }

        let config = load_config(None).unwrap();

        assert_eq!(config.get_user(), Some("foo"));
        assert_eq!(config.get_ports(), [1234]);
        assert_eq!(config.get_hosts(), [Host::Tcp("foo.com".into())]);
        assert_eq!(config.get_dbname(), Some("foodb"));
    }

    /// Parameters of a service are read from the service file.
    #[test]
    fn from_service_file() {
        use_fixtures_as_sysconfdir();

        let config = load_config(Some("service=bar")).unwrap();

        assert_eq!(config.get_user(), Some("bar"));
        assert_eq!(config.get_ports(), [1234]);
        assert_eq!(config.get_hosts(), [Host::Tcp("bar.com".into())]);
        assert_eq!(config.get_dbname(), Some("bardb"));
    }

    /// Parameters of the connection string win over the service ones.
    #[test]
    fn service_override() {
        use_fixtures_as_sysconfdir();

        let config = load_config(Some("service=bar user=baz")).unwrap();

        assert_eq!(config.get_user(), Some("baz"));
    }
}
386        assert_eq!(config.get_user(), Some("baz"));
387    }
388}