1use ini::Ini;
44use std::path::{Path, PathBuf};
45use std::str::FromStr;
46use std::time::Duration;
47use tokio_postgres::config::{ChannelBinding, Config, SslMode};
48
// Real password-file support: only compiled on unix targets when the
// `with-passfile` feature is enabled. The module body lives in a separate
// file and is expected to provide `get_password_from_passfile`.
#[cfg(all(target_family = "unix", feature = "with-passfile"))]
mod passfile;
51
52#[cfg(not(all(target_family = "unix", feature = "with-passfile")))]
53mod passfile {
54 use super::*;
55 pub(crate) fn get_password_from_passfile(_: &mut Config) -> Result<()> {
56 Ok(())
57 }
58}
59
60#[derive(thiserror::Error, Debug)]
63pub enum Error {
64 #[error("IO Error")]
65 IOError(#[from] std::io::Error),
66 #[error("Service File Error")]
67 PgServiceFileError(#[from] ini::Error),
68 #[error("Service file not found: {0}")]
69 PgServiceFileNotFound(String),
70 #[error("Definition of service {0} not found")]
71 PgServiceNotFound(String),
72 #[error("Invalid ssl mode, expecting 'prefer', 'require' or 'disable': found '{0}'")]
73 InvalidSslMode(String),
74 #[error("Invalid port, expecting integer, found '{0}'")]
75 InvalidPort(String),
76 #[error("Invalid connect_timeout, expecting number of secs, found '{0}'")]
77 InvalidConnectTimeout(String),
78 #[error("Invalid keepalives, '1' or '0', found '{0}'")]
79 InvalidKeepalives(String),
80 #[error("Invalid keepalives, expecting number of secs, found '{0}'")]
81 InvalidKeepalivesIdle(String),
82 #[error("Invalid Channel Binding, expecting 'prefer', 'require' or 'disable': found '{0}'")]
83 InvalidChannelBinding(String),
84 #[error("Missing service name in connection string")]
85 MissingServiceName,
86 #[error("Postgres config error")]
87 PostgresConfig(#[from] tokio_postgres::Error),
88 #[error("Invalid passfile mode")]
89 InvalidPassFileMode,
90 #[error("Error parsing passfile")]
91 PassfileParseError,
92 #[error("Pass file not found: {0}")]
93 PgPassFileNotFound(String),
94}
95
/// Result alias used by this module; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
97
98pub fn load_config(config: Option<&str>) -> Result<Config> {
118 fn load_service_config(service: &str, cnxstr: &str) -> Result<Config> {
119 let mut config = if cnxstr.is_empty() {
120 Config::new()
121 } else {
122 Config::from_str(cnxstr)?
123 };
124 load_config_from_service(&mut config, service)?;
125 load_config_from_env(&mut config)?;
126 Ok(config)
127 }
128
129 if let Some(cnxstr) = config {
130 let cnxstr = cnxstr.trim_start();
131 if cnxstr.starts_with("service=") {
132 if let Some((service, tail)) = cnxstr.split_once('=').map(|(_, tail)| {
135 tail.split_once(|c: char| c.is_whitespace())
136 .unwrap_or((tail, ""))
137 }) {
138 load_service_config(service, tail.trim())
139 } else {
140 Err(Error::MissingServiceName)
141 }
142 } else if let Ok(service) = std::env::var("PGSERVICE") {
143 load_service_config(&service, cnxstr)
146 } else {
147 let mut config = Config::from_str(cnxstr)?;
149 load_config_from_env(&mut config)?;
150 Ok(config)
151 }
152 } else if let Ok(service) = std::env::var("PGSERVICE") {
153 load_service_config(&service, "")
154 } else {
155 let mut config = Config::new();
158 load_config_from_env(&mut config)?;
159 Ok(config)
160 }
161 .and_then(|mut config| {
162 if config.get_password().is_none() {
163 passfile::get_password_from_passfile(&mut config)?;
164 }
165 Ok(config)
166 })
167}
168
169fn load_config_from_service(config: &mut Config, service_name: &str) -> Result<()> {
171 fn user_service_file() -> Option<PathBuf> {
172 std::env::var("PGSERVICEFILE")
173 .map(|path| Path::new(&path).into())
174 .or_else(|_| {
175 std::env::var("HOME").map(|path| Path::new(&path).join(".pg_service.conf"))
176 })
177 .ok()
178 }
179
180 fn sysconf_service_file() -> Option<PathBuf> {
181 std::env::var("PGSYSCONFDIR")
182 .map(|path| Path::new(&path).join("pg_service.conf"))
183 .ok()
184 }
185
186 fn get_service_params(config: &mut Config, path: &Path, service_name: &str) -> Result<bool> {
187 if path.exists() {
188 Ini::load_from_file(path)
189 .map_err(Error::from)
190 .and_then(|ini| {
191 if let Some(params) = ini.section(Some(service_name)) {
192 params
193 .iter()
194 .try_for_each(|(k, v)| set_parameter(config, k, v))
195 .map(|_| true)
196 } else {
197 Ok(false)
198 }
199 })
200 } else {
201 Err(Error::PgServiceFileNotFound(
202 path.to_string_lossy().into_owned(),
203 ))
204 }
205 }
206
207 let found = match user_service_file().and_then(|p| p.as_path().exists().then_some(p)) {
208 Some(path) => get_service_params(config, &path, service_name)?,
209 None => false,
210 } || match sysconf_service_file() {
211 Some(path) => get_service_params(config, &path, service_name)?,
212 None => false,
213 };
214
215 if !found {
216 Err(Error::PgServiceNotFound(service_name.into()))
217 } else {
218 Ok(())
219 }
220}
221
222fn load_config_from_env(config: &mut Config) -> Result<()> {
224 static ENV: [(&str, &str); 7] = [
225 ("PGHOST", "host"),
226 ("PGPORT", "port"),
227 ("PGDATABASE", "dbname"),
228 ("PGUSER", "user"),
229 ("PGOPTIONS", "options"),
230 ("PGAPPNAME", "application_name"),
231 ("PGCONNECT_TIMEOUT", "connect_timeout"),
232 ];
233
234 ENV.iter().try_for_each(|(varname, k)| {
235 if let Ok(v) = std::env::var(varname) {
236 set_parameter(config, k, &v)
237 } else {
238 Ok(())
239 }
240 })
241}
242
243fn set_parameter(config: &mut Config, k: &str, v: &str) -> Result<()> {
244 fn parse_ssl_mode(mode: &str) -> Result<SslMode> {
245 match mode {
246 "disable" => Ok(SslMode::Disable),
247 "prefer" => Ok(SslMode::Prefer),
248 "require" => Ok(SslMode::Require),
249 _ => Err(Error::InvalidSslMode(mode.into())),
250 }
251 }
252
253 fn parse_channel_binding(mode: &str) -> Result<ChannelBinding> {
254 match mode {
255 "disable" => Ok(ChannelBinding::Disable),
256 "prefer" => Ok(ChannelBinding::Prefer),
257 "require" => Ok(ChannelBinding::Require),
258 _ => Err(Error::InvalidChannelBinding(mode.into())),
259 }
260 }
261
262 match k {
263 "user" => {
266 if config.get_user().is_none() {
267 config.user(v);
268 }
269 }
270 "password" => {
271 if config.get_password().is_none() {
272 config.password(v);
273 }
274 }
275 "dbname" => {
276 if config.get_dbname().is_none() {
277 config.dbname(v);
278 }
279 }
280 "options" => {
281 if config.get_options().is_none() {
282 config.options(v);
283 }
284 }
285 "host" | "hostaddr" => {
286 if config.get_hosts().is_empty() {
287 config.host(v);
288 }
289 }
290 "port" => {
291 if config.get_ports().is_empty() {
292 config.port(v.parse().map_err(|_| Error::InvalidPort(v.into()))?);
293 }
294 }
295 "application_name" => {
296 if config.get_application_name().is_none() {
297 config.application_name(v);
298 }
299 }
300 "connect_timeout" => {
301 if config.get_connect_timeout().is_none() {
302 config.connect_timeout(Duration::from_secs(
303 v.parse()
304 .map_err(|_| Error::InvalidConnectTimeout(v.into()))?,
305 ));
306 }
307 }
308 "sslmode" => {
312 config.ssl_mode(parse_ssl_mode(v)?);
313 }
314 "keepalives" => {
315 config.keepalives(match v {
316 "1" => Ok(true),
317 "0" => Ok(false),
318 _ => Err(Error::InvalidKeepalives(v.into())),
319 }?);
320 }
321 "keepalives_idle" => {
322 config.keepalives_idle(Duration::from_secs(
323 v.parse()
324 .map_err(|_| Error::InvalidKeepalivesIdle(v.into()))?,
325 ));
326 }
327 "channel_binding" => {
328 config.channel_binding(parse_channel_binding(v)?);
329 }
330 _ => (),
331 }
332
333 Ok(())
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_postgres::config::Host;

    /// Points `PGSYSCONFDIR` at the crate's `fixtures` directory so the
    /// service file stored there is picked up by `load_config`.
    fn use_fixtures_sysconfdir() {
        let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap();
        let fixtures = Path::new(&manifest_dir).join("fixtures");
        std::env::set_var("PGSYSCONFDIR", fixtures.to_str().unwrap());
    }

    /// With no connection string, parameters come from the environment.
    #[test]
    fn from_environment() {
        let vars = [
            ("PGUSER", "foo"),
            ("PGHOST", "foo.com"),
            ("PGDATABASE", "foodb"),
            ("PGPORT", "1234"),
        ];
        for (name, value) in vars.iter() {
            std::env::set_var(name, value);
        }

        let config = load_config(None).unwrap();

        assert_eq!(config.get_user(), Some("foo"));
        assert_eq!(config.get_ports(), [1234]);
        assert_eq!(config.get_hosts(), [Host::Tcp("foo.com".into())]);
        assert_eq!(config.get_dbname(), Some("foodb"));
    }

    /// `service=bar` resolves the parameters of service `bar` from the
    /// fixture service file.
    #[test]
    fn from_service_file() {
        use_fixtures_sysconfdir();

        let config = load_config(Some("service=bar")).unwrap();

        assert_eq!(config.get_user(), Some("bar"));
        assert_eq!(config.get_ports(), [1234]);
        assert_eq!(config.get_hosts(), [Host::Tcp("bar.com".into())]);
        assert_eq!(config.get_dbname(), Some("bardb"));
    }

    /// Parameters given after `service=...` take precedence over the values
    /// defined by the service.
    #[test]
    fn service_override() {
        use_fixtures_sysconfdir();

        let config = load_config(Some("service=bar user=baz")).unwrap();

        assert_eq!(config.get_user(), Some("baz"));
    }
}