1use eyre::{eyre, WrapErr};
11use owo_colors::OwoColorize;
12use serde_derive::{Deserialize, Serialize};
13use std::collections::{BTreeMap, HashMap};
14use std::ffi::OsString;
15use std::fmt::{self, Display, Formatter};
16use std::io::ErrorKind;
17use std::path::{Path, PathBuf};
18use std::process::{Command, Stdio};
19use std::str::FromStr;
20use url::Url;
21
/// Default base port for Postgres instances: [`PgConfig::port`] adds the instance's major
/// version to the config's base port (with this default, Postgres 15 => 28815).
pub static BASE_POSTGRES_PORT_NO: u16 = 28800;
/// Default base port for test instances, combined the same way by [`PgConfig::test_port`].
pub static BASE_POSTGRES_TESTING_PORT_NO: u16 = 32200;
24
/// Returns command-line flags selecting the "C" locale, in a UTF-8 flavour where possible
/// (presumably for `initdb` — confirm at call sites).
///
/// On macOS this is always `--locale=C --lc-ctype=UTF-8`. Elsewhere `locale -a` is consulted:
/// if it lists `C.UTF-8` or `C.utf8` the result is `--locale=C.UTF-8`, otherwise (including when
/// `locale` cannot be run) it is plain `--locale=C`.
pub fn get_c_locale_flags() -> &'static [&'static str] {
    #[cfg(target_os = "macos")]
    {
        &["--locale=C", "--lc-ctype=UTF-8"]
    }
    #[cfg(not(target_os = "macos"))]
    {
        let has_c_utf8 = Command::new("locale")
            .arg("-a")
            .output()
            .map(|out| {
                String::from_utf8_lossy(&out.stdout)
                    .lines()
                    .any(|name| matches!(name, "C.UTF-8" | "C.utf8"))
            })
            .unwrap_or(false);

        if has_c_utf8 {
            &["--locale=C.UTF-8"]
        } else {
            &["--locale=C"]
        }
    }
}
47
48mod path_methods;
53pub use path_methods::{get_target_dir, prefix_path};
54
55#[derive(Clone, Debug)]
56pub struct PgVersion {
57 major: u16,
58 minor: u16,
59 url: Url,
60}
61
62impl PgVersion {
63 pub fn new(major: u16, minor: u16, url: Url) -> PgVersion {
64 PgVersion { major, minor, url }
65 }
66}
67
68impl Display for PgVersion {
69 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
70 write!(f, "{}.{}", self.major, self.minor)
71 }
72}
73
/// Access to one Postgres installation's `pg_config` information.
///
/// Version numbers and the URL come from `version` when it is set. Other queries are answered
/// from `known_props` when it is set (environment-described configs); otherwise a `pg_config`
/// executable is run.
#[derive(Clone, Debug)]
pub struct PgConfig {
    /// Version and URL known without running `pg_config`; when set, it supplies the major and
    /// minor version numbers and `url()`.
    version: Option<PgVersion>,
    /// Path of the `pg_config` executable to run. When `None` (and `known_props` is `None`),
    /// `$PG_CONFIG`, or failing that `pg_config` on the `$PATH`, is run instead.
    pg_config: Option<PathBuf>,
    /// Pre-recorded answers keyed by flag (e.g. `--bindir`); when set, no executable is run.
    known_props: Option<BTreeMap<String, String>>,
    /// Added to the major version to produce `port()`.
    base_port: u16,
    /// Added to the major version to produce `test_port()`.
    base_testing_port: u16,
}
82
83impl Display for PgConfig {
84 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
85 let major = self.major_version().expect("could not determine major version");
86 let minor = self.minor_version().expect("could not determine minor version");
87 let path = match self.pg_config.as_ref() {
88 Some(path) => path.display().to_string(),
89 None => self.version.as_ref().unwrap().url.to_string(),
90 };
91 write!(f, "{}.{}={}", major, minor, path)
92 }
93}
94
95impl Default for PgConfig {
96 fn default() -> Self {
97 PgConfig {
98 version: None,
99 pg_config: None,
100 known_props: None,
101 base_port: BASE_POSTGRES_PORT_NO,
102 base_testing_port: BASE_POSTGRES_TESTING_PORT_NO,
103 }
104 }
105}
106
107impl From<PgVersion> for PgConfig {
108 fn from(version: PgVersion) -> Self {
109 PgConfig { version: Some(version), pg_config: None, ..Default::default() }
110 }
111}
112
113impl PgConfig {
114 pub fn new(pg_config: PathBuf, base_port: u16, base_testing_port: u16) -> Self {
115 PgConfig {
116 version: None,
117 pg_config: Some(pg_config),
118 known_props: None,
119 base_port,
120 base_testing_port,
121 }
122 }
123
124 pub fn new_with_defaults(pg_config: PathBuf) -> Self {
125 PgConfig {
126 version: None,
127 pg_config: Some(pg_config),
128 known_props: None,
129 base_port: BASE_POSTGRES_PORT_NO,
130 base_testing_port: BASE_POSTGRES_TESTING_PORT_NO,
131 }
132 }
133
134 pub fn from_path() -> Self {
135 let path =
136 pathsearch::find_executable_in_path("pg_config").unwrap_or_else(|| "pg_config".into());
137 Self::new_with_defaults(path)
138 }
139
140 pub fn from_env() -> eyre::Result<Self> {
146 if !Self::is_in_environment() {
147 Err(eyre::eyre!("`PgConfig` not described in the environment"))
148 } else {
149 const PREFIX: &str = "PGX_PG_CONFIG_";
150
151 let mut known_props = BTreeMap::new();
152 for (k, v) in std::env::vars() {
153 if k.starts_with(PREFIX) {
154 let prop = format!("--{}", k.trim_start_matches(PREFIX).to_lowercase());
156 known_props.insert(prop, v);
157 }
158 }
159
160 Ok(Self {
161 version: None,
162 pg_config: None,
163 known_props: Some(known_props),
164 base_port: 0,
165 base_testing_port: 0,
166 })
167 }
168 }
169
170 pub fn is_in_environment() -> bool {
171 match std::env::var("PGX_PG_CONFIG_AS_ENV") {
172 Ok(value) => value == "true",
173 _ => false,
174 }
175 }
176
177 pub fn is_real(&self) -> bool {
178 self.pg_config.is_some()
179 }
180
181 pub fn label(&self) -> eyre::Result<String> {
182 Ok(format!("pg{}", self.major_version()?))
183 }
184
185 pub fn path(&self) -> Option<PathBuf> {
186 self.pg_config.clone()
187 }
188
189 pub fn parent_path(&self) -> PathBuf {
190 self.path().unwrap().parent().unwrap().to_path_buf()
191 }
192
193 fn parse_version_str(version_str: &str) -> eyre::Result<(u16, u16)> {
194 let version_parts = version_str.split_whitespace().collect::<Vec<&str>>();
195 let version = version_parts
196 .get(1)
197 .ok_or_else(|| eyre!("invalid version string: {}", version_str))?
198 .split('.')
199 .collect::<Vec<&str>>();
200 if version.len() < 2 {
201 return Err(eyre!("invalid version string: {}", version_str));
202 }
203 let major = u16::from_str(version[0])
204 .map_err(|e| eyre!("invalid major version number `{}`: {:?}", version[0], e))?;
205 let mut minor = version[1];
206 let mut end_index = minor.len();
207 for (i, c) in minor.chars().enumerate() {
208 if !c.is_ascii_digit() {
209 end_index = i;
210 break;
211 }
212 }
213 minor = &minor[0..end_index];
214 let minor = u16::from_str(minor)
215 .map_err(|e| eyre!("invalid minor version number `{}`: {:?}", minor, e))?;
216 return Ok((major, minor));
217 }
218
219 fn get_version(&self) -> eyre::Result<(u16, u16)> {
220 let version_string = self.run("--version")?;
221 Self::parse_version_str(&version_string)
222 }
223
224 pub fn major_version(&self) -> eyre::Result<u16> {
225 match &self.version {
226 Some(version) => Ok(version.major),
227 None => Ok(self.get_version()?.0),
228 }
229 }
230
231 pub fn minor_version(&self) -> eyre::Result<u16> {
232 match &self.version {
233 Some(version) => Ok(version.minor),
234 None => Ok(self.get_version()?.1),
235 }
236 }
237
238 pub fn version(&self) -> eyre::Result<String> {
239 let major = self.major_version()?;
240 let minor = self.minor_version()?;
241 let version = format!("{}.{}", major, minor);
242 Ok(version)
243 }
244
245 pub fn url(&self) -> Option<&Url> {
246 match &self.version {
247 Some(version) => Some(&version.url),
248 None => None,
249 }
250 }
251
252 pub fn port(&self) -> eyre::Result<u16> {
253 Ok(self.base_port + self.major_version()?)
254 }
255
256 pub fn test_port(&self) -> eyre::Result<u16> {
257 Ok(self.base_testing_port + self.major_version()?)
258 }
259
260 pub fn host(&self) -> &'static str {
261 "localhost"
262 }
263
264 pub fn bin_dir(&self) -> eyre::Result<PathBuf> {
265 Ok(Path::new(&self.run("--bindir")?).to_path_buf())
266 }
267
268 pub fn postmaster_path(&self) -> eyre::Result<PathBuf> {
269 let mut path = self.bin_dir()?;
270 path.push("postmaster");
271 Ok(path)
272 }
273
274 pub fn initdb_path(&self) -> eyre::Result<PathBuf> {
275 let mut path = self.bin_dir()?;
276 path.push("initdb");
277 Ok(path)
278 }
279
280 pub fn createdb_path(&self) -> eyre::Result<PathBuf> {
281 let mut path = self.bin_dir()?;
282 path.push("createdb");
283 Ok(path)
284 }
285
286 pub fn dropdb_path(&self) -> eyre::Result<PathBuf> {
287 let mut path = self.bin_dir()?;
288 path.push("dropdb");
289 Ok(path)
290 }
291
292 pub fn psql_path(&self) -> eyre::Result<PathBuf> {
293 let mut path = self.bin_dir()?;
294 path.push("psql");
295 Ok(path)
296 }
297
298 pub fn data_dir(&self) -> eyre::Result<PathBuf> {
299 let mut path = Pgx::home()?;
300 path.push(format!("data-{}", self.major_version()?));
301 Ok(path)
302 }
303
304 pub fn log_file(&self) -> eyre::Result<PathBuf> {
305 let mut path = Pgx::home()?;
306 path.push(format!("{}.log", self.major_version()?));
307 Ok(path)
308 }
309
310 pub fn includedir_server(&self) -> eyre::Result<PathBuf> {
311 Ok(self.run("--includedir-server")?.into())
312 }
313
314 pub fn pkglibdir(&self) -> eyre::Result<PathBuf> {
315 Ok(self.run("--pkglibdir")?.into())
316 }
317
318 pub fn sharedir(&self) -> eyre::Result<PathBuf> {
319 Ok(self.run("--sharedir")?.into())
320 }
321
322 pub fn cppflags(&self) -> eyre::Result<OsString> {
323 Ok(self.run("--cppflags")?.into())
324 }
325
326 pub fn extension_dir(&self) -> eyre::Result<PathBuf> {
327 let mut path = self.sharedir()?;
328 path.push("extension");
329 Ok(path)
330 }
331
332 fn run(&self, arg: &str) -> eyre::Result<String> {
333 if self.known_props.is_some() {
334 Ok(self
337 .known_props
338 .as_ref()
339 .unwrap()
340 .get(arg)
341 .ok_or_else(|| {
342 std::io::Error::new(
343 ErrorKind::InvalidData,
344 format!("`PgConfig` has no known property named {arg}"),
345 )
346 })
347 .cloned()?)
348 } else {
349 let pg_config = self.pg_config.clone().unwrap_or_else(|| {
352 std::env::var("PG_CONFIG").unwrap_or_else(|_| "pg_config".to_string()).into()
353 });
354
355 match Command::new(&pg_config).arg(arg).output() {
356 Ok(output) => Ok(String::from_utf8(output.stdout).unwrap().trim().to_string()),
357 Err(e) => match e.kind() {
358 ErrorKind::NotFound => Err(e).wrap_err_with(|| {
359 format!("Unable to find `{}` on the system $PATH", "pg_config".yellow())
360 }),
361 _ => Err(e.into()),
362 },
363 }
364 }
365 }
366}
367
/// The Postgres installations (`pg_config`s) managed by pgx, along with the base port numbers
/// they were configured with.
#[derive(Debug)]
pub struct Pgx {
    pg_configs: Vec<PgConfig>,
    // Base port handed to each `PgConfig` created by `Pgx::from_config`.
    base_port: u16,
    // Base testing port handed to each `PgConfig` created by `Pgx::from_config`.
    base_testing_port: u16,
}
374
375impl Default for Pgx {
376 fn default() -> Self {
377 Self {
378 pg_configs: vec![],
379 base_port: BASE_POSTGRES_PORT_NO,
380 base_testing_port: BASE_POSTGRES_TESTING_PORT_NO,
381 }
382 }
383}
384
/// On-disk format of `config.toml` in the pgx home directory, as read by `Pgx::from_config`.
#[derive(Debug, Serialize, Deserialize)]
struct ConfigToml {
    /// Paths to `pg_config` executables. `Pgx::from_config` uses only the values; the keys are
    /// presumably labels such as `pg15` (TODO: confirm against whatever writes this file).
    configs: HashMap<String, PathBuf>,
    /// Overrides [`BASE_POSTGRES_PORT_NO`] when present; omitted from the file when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    base_port: Option<u16>,
    /// Overrides [`BASE_POSTGRES_TESTING_PORT_NO`] when present; omitted from the file when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    base_testing_port: Option<u16>,
}
393
/// Chooses which configs `Pgx::iter` yields.
pub enum PgConfigSelector<'a> {
    /// Every managed config, or only the environment-described one when
    /// `PGX_PG_CONFIG_AS_ENV=true`.
    All,
    /// The single config whose label (e.g. `pg15`) matches.
    Specific(&'a str),
    /// The config described by `PGX_PG_CONFIG_*` environment variables.
    Environment,
}

impl<'a> PgConfigSelector<'a> {
    /// Interprets a user-supplied label. The exact string `"all"` selects
    /// [`PgConfigSelector::All`]; anything else is treated as a specific label.
    pub fn new(label: &'a str) -> Self {
        match label {
            "all" => Self::All,
            specific => Self::Specific(specific),
        }
    }
}
409
410impl Pgx {
411 pub fn new(base_port: u16, base_testing_port: u16) -> Self {
412 Pgx { pg_configs: vec![], base_port, base_testing_port }
413 }
414
415 pub fn from_config() -> eyre::Result<Self> {
416 match std::env::var("PGX_PG_CONFIG_PATH") {
417 Ok(pg_config) => {
418 let mut pgx = Pgx::default();
420 pgx.push(PgConfig::new(pg_config.into(), pgx.base_port, pgx.base_testing_port));
421 Ok(pgx)
422 }
423 Err(_) => {
424 let path = Pgx::config_toml()?;
426 if !path.exists() {
427 return Err(eyre!(
428 "{} not found. Have you run `{}` yet?",
429 path.display(),
430 "cargo pgx init".bold().yellow()
431 ));
432 }
433
434 match toml::from_str::<ConfigToml>(&std::fs::read_to_string(&path)?) {
435 Ok(configs) => {
436 let mut pgx = Pgx::new(
437 configs.base_port.unwrap_or(BASE_POSTGRES_PORT_NO),
438 configs.base_testing_port.unwrap_or(BASE_POSTGRES_TESTING_PORT_NO),
439 );
440
441 for (_, v) in configs.configs {
442 pgx.push(PgConfig::new(v, pgx.base_port, pgx.base_testing_port));
443 }
444 Ok(pgx)
445 }
446 Err(e) => {
447 Err(e).wrap_err_with(|| format!("Could not read `{}`", path.display()))
448 }
449 }
450 }
451 }
452 }
453
454 pub fn push(&mut self, pg_config: PgConfig) {
455 self.pg_configs.push(pg_config);
456 }
457
458 pub fn iter(
468 &self,
469 which: PgConfigSelector,
470 ) -> impl std::iter::Iterator<Item = eyre::Result<PgConfig>> {
471 match (which, PgConfig::is_in_environment()) {
472 (PgConfigSelector::All, true) | (PgConfigSelector::Environment, _) => {
473 vec![PgConfig::from_env()].into_iter()
474 }
475
476 (PgConfigSelector::All, _) => {
477 let mut configs = self.pg_configs.iter().collect::<Vec<_>>();
478 configs.sort_by(|a, b| {
479 a.major_version()
480 .expect("no major version")
481 .cmp(&b.major_version().expect("no major version"))
482 });
483
484 configs.into_iter().map(|c| Ok(c.clone())).collect::<Vec<_>>().into_iter()
485 }
486 (PgConfigSelector::Specific(label), _) => vec![self.get(label)].into_iter(),
487 }
488 }
489
490 pub fn get(&self, label: &str) -> eyre::Result<PgConfig> {
491 for pg_config in self.pg_configs.iter() {
492 if pg_config.label()? == label {
493 return Ok(pg_config.clone());
494 }
495 }
496 Err(eyre!("Postgres `{}` is not managed by pgx", label))
497 }
498
499 pub fn is_feature_flag(&self, label: &str) -> bool {
502 for v in SUPPORTED_MAJOR_VERSIONS {
503 if label == &format!("pg{}", v) {
504 return true;
505 }
506 }
507 false
508 }
509
510 pub fn home() -> Result<PathBuf, std::io::Error> {
511 std::env::var("PGX_HOME").map_or_else(
512 |_| {
513 let mut dir = match dirs::home_dir() {
514 Some(dir) => dir,
515 None => {
516 return Err(std::io::Error::new(
517 ErrorKind::NotFound,
518 "You don't seem to have a home directory",
519 ));
520 }
521 };
522 dir.push(".pgx");
523 if !dir.exists() {
524 if let Err(e) = std::fs::create_dir_all(&dir) {
525 return Err(std::io::Error::new(
526 ErrorKind::InvalidInput,
527 format!("could not create PGX_HOME at `{}`: {:?}", dir.display(), e),
528 ));
529 }
530 }
531
532 Ok(dir)
533 },
534 |v| Ok(v.into()),
535 )
536 }
537
538 pub fn postmaster_stub_dir() -> Result<PathBuf, std::io::Error> {
544 let mut stub_dir = Self::home()?;
545 stub_dir.push("postmaster_stubs");
546 Ok(stub_dir)
547 }
548
549 pub fn config_toml() -> Result<PathBuf, std::io::Error> {
550 let mut path = Pgx::home()?;
551 path.push("config.toml");
552 Ok(path)
553 }
554}
555
/// Postgres major versions supported here; `Pgx::is_feature_flag` accepts `pg<N>` for each.
pub const SUPPORTED_MAJOR_VERSIONS: &[u16] = &[11, 12, 13, 14, 15];
557
558pub fn createdb(
559 pg_config: &PgConfig,
560 dbname: &str,
561 is_test: bool,
562 if_not_exists: bool,
563) -> eyre::Result<bool> {
564 if if_not_exists && does_db_exist(pg_config, dbname)? {
565 return Ok(false);
566 }
567
568 println!("{} database {}", " Creating".bold().green(), dbname);
569 let mut command = Command::new(pg_config.createdb_path()?);
570 command
571 .env_remove("PGDATABASE")
572 .env_remove("PGHOST")
573 .env_remove("PGPORT")
574 .env_remove("PGUSER")
575 .arg("-h")
576 .arg(pg_config.host())
577 .arg("-p")
578 .arg(if is_test {
579 pg_config.test_port()?.to_string()
580 } else {
581 pg_config.port()?.to_string()
582 })
583 .arg(dbname)
584 .stdout(Stdio::piped())
585 .stderr(Stdio::piped());
586
587 let command_str = format!("{:?}", command);
588
589 let child = command.spawn().wrap_err_with(|| {
590 format!("Failed to spawn process for creating database using command: '{command_str}': ")
591 })?;
592
593 let output = child.wait_with_output().wrap_err_with(|| {
594 format!(
595 "failed waiting for spawned process to create database using command: '{command_str}': "
596 )
597 })?;
598
599 if !output.status.success() {
600 return Err(eyre!(
601 "problem running createdb: {}\n\n{}{}",
602 command_str,
603 String::from_utf8(output.stdout).unwrap(),
604 String::from_utf8(output.stderr).unwrap()
605 ));
606 }
607
608 Ok(true)
609}
610
611fn does_db_exist(pg_config: &PgConfig, dbname: &str) -> eyre::Result<bool> {
612 let mut command = Command::new(pg_config.psql_path()?);
613 command
614 .arg("-XqAt")
615 .env_remove("PGUSER")
616 .arg("-h")
617 .arg(pg_config.host())
618 .arg("-p")
619 .arg(pg_config.port()?.to_string())
620 .arg("template1")
621 .arg("-c")
622 .arg(&format!(
623 "select count(*) from pg_database where datname = '{}';",
624 dbname.replace("'", "''")
625 ))
626 .stdout(Stdio::piped())
627 .stderr(Stdio::piped());
628
629 let command_str = format!("{:?}", command);
630 let output = command.output()?;
631
632 if !output.status.success() {
633 return Err(eyre!(
634 "problem checking if database '{}' exists: {}\n\n{}{}",
635 dbname,
636 command_str,
637 String::from_utf8(output.stdout).unwrap(),
638 String::from_utf8(output.stderr).unwrap()
639 ));
640 } else {
641 let count = i32::from_str(&String::from_utf8(output.stdout).unwrap().trim())
642 .wrap_err("result is not a number")?;
643 Ok(count > 0)
644 }
645}
646
/// `parse_version_str` extracts `(major, minor)` from `pg_config --version` output, ignoring
/// any suffix after the minor version's digits, and rejects strings missing either number.
#[test]
fn parse_version() {
    let accepted: [(&str, u16, u16); 8] = [
        ("PostgreSQL 10.22", 10, 22),
        ("PostgreSQL 11.2", 11, 2),
        ("PostgreSQL 11.17", 11, 17),
        ("PostgreSQL 12.12", 12, 12),
        ("PostgreSQL 13.8", 13, 8),
        ("PostgreSQL 14.5", 14, 5),
        ("PostgreSQL 11.2-FOO-BAR+", 11, 2),
        ("PostgreSQL 10.22-", 10, 22),
    ];
    for (input, expected_major, expected_minor) in accepted {
        let parsed = PgConfig::parse_version_str(input).expect("Unable to parse version string");
        assert_eq!(parsed.0, expected_major, "Major version should match");
        assert_eq!(parsed.1, expected_minor, "Minor version should match");
    }

    // No version word, no minor number, an empty minor, a non-numeric minor, an empty major.
    let rejected = [
        "10.22",
        "PostgresSQL 10",
        "PostgresSQL 10.",
        "PostgresSQL 12.f",
        "PostgresSQL .53",
    ];
    for input in rejected {
        let _ = PgConfig::parse_version_str(input).expect_err("Parsed invalid version string");
    }
}
678
/// `PgConfig::from_env` refuses to build a config unless `PGX_PG_CONFIG_AS_ENV=true`, and
/// otherwise answers queries from `PGX_PG_CONFIG_*` variables.
#[test]
fn from_empty_env() -> eyre::Result<()> {
    /// Removes the listed variables when dropped, even if an assertion fails, so they don't
    /// leak into other tests running in the same process.
    struct EnvGuard(&'static [&'static str]);
    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for key in self.0 {
                std::env::remove_var(key);
            }
        }
    }

    const VARS: &[&str] = &[
        "PGX_PG_CONFIG_AS_ENV",
        "PGX_PG_CONFIG_VERSION",
        "PGX_PG_CONFIG_INCLUDEDIR-SERVER",
        "PGX_PG_CONFIG_CPPFLAGS",
    ];

    let pg_config = PgConfig::from_env();
    assert!(pg_config.is_err());

    let _guard = EnvGuard(VARS);
    std::env::set_var("PGX_PG_CONFIG_AS_ENV", "true");
    std::env::set_var("PGX_PG_CONFIG_VERSION", "PostgresSQL 15.1");
    std::env::set_var("PGX_PG_CONFIG_INCLUDEDIR-SERVER", "/path/to/server/headers");
    std::env::set_var("PGX_PG_CONFIG_CPPFLAGS", "some cpp flags");

    let pg_config = PgConfig::from_env().unwrap();
    assert_eq!(pg_config.major_version()?, 15, "Major version should match");
    assert_eq!(pg_config.minor_version()?, 1, "Minor version should match");
    assert_eq!(
        pg_config.includedir_server()?,
        PathBuf::from("/path/to/server/headers"),
        "includedir_server should match"
    );
    assert_eq!(pg_config.cppflags()?, OsString::from("some cpp flags"), "cppflags should match");

    // `--sharedir` was never provided, so asking for it must fail rather than guess.
    assert!(pg_config.sharedir().is_err());
    Ok(())
}
704}