use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use thiserror::Error;
/// Settings parsed from an extension `.control` file.
///
/// Instances are built by [`ControlFile::from_str`] and the related constructors,
/// which read `key = value` lines, strip surrounding single quotes from values and
/// substitute `@CARGO_VERSION@` placeholders.
#[derive(Debug, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)]
pub struct ControlFile {
/// Value of the required `comment` key.
pub comment: String,
/// Value of the required `default_version` key.
pub default_version: String,
/// Value of the optional `module_pathname` key, if present.
pub module_pathname: Option<String>,
/// `true` only when the required `relocatable` key is exactly `true`.
pub relocatable: bool,
/// `true` only when the required `superuser` key is exactly `true`.
pub superuser: bool,
/// Value of the optional `schema` key, if present.
pub schema: Option<String>,
/// `true` only when the optional `trusted` key is exactly `true`; parsing rejects
/// `trusted = true` unless `superuser` is also `true`.
pub trusted: bool,
}
impl ControlFile {
    /// Parses the text of a `.control` file.
    ///
    /// Any `@CARGO_VERSION@` placeholder is replaced with the value of the
    /// `CARGO_PKG_VERSION` environment variable as read at runtime.
    ///
    /// # Errors
    ///
    /// * [`ControlFileError::MissingField`] if `comment`, `default_version`,
    ///   `relocatable` or `superuser` is absent (the first missing one, in that order).
    /// * [`ControlFileError::RedundantField`] if `trusted = true` is set without
    ///   `superuser = true`.
    /// * [`ControlFileError::MissingEnvvar`] if a placeholder is present but
    ///   `CARGO_PKG_VERSION` is not set.
    // An inherent `from_str` is offered instead of implementing `std::str::FromStr`,
    // so the lint that suggests the trait is deliberately silenced.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(input: &str) -> Result<Self, ControlFileError> {
        Self::from_str_with_version(input, None)
    }

    /// Parses control-file text, substituting `@CARGO_VERSION@` with `cargo_version`.
    ///
    /// # Errors
    ///
    /// The same parse errors as [`ControlFile::from_str`]. [`ControlFileError::MissingEnvvar`]
    /// cannot occur because the version is supplied explicitly.
    pub fn from_str_with_cargo_version(
        input: &str,
        cargo_version: &str,
    ) -> Result<Self, ControlFileError> {
        Self::from_str_with_version(input, Some(cargo_version))
    }

    /// Reads the control file at `path` and parses it like
    /// [`ControlFile::from_str_with_cargo_version`].
    ///
    /// # Errors
    ///
    /// [`ControlFileError::IOError`] if the file cannot be read; otherwise the parse
    /// errors of [`ControlFile::from_str_with_cargo_version`].
    pub fn from_path_with_cargo_version(
        path: impl AsRef<Path>,
        cargo_version: &str,
    ) -> Result<Self, ControlFileError> {
        let contents = std::fs::read_to_string(path)?;
        Self::from_str_with_cargo_version(&contents, cargo_version)
    }

    /// Shared parser behind the public constructors.
    ///
    /// Each `key = value` line contributes one entry. The value is trimmed and
    /// surrounding single quotes are stripped. Lines starting with `#` and lines
    /// without `=` are ignored, and a repeated key keeps its last value.
    /// `cargo_version` overrides the runtime `CARGO_PKG_VERSION` lookup used for
    /// `@CARGO_VERSION@` placeholders.
    fn from_str_with_version(
        input: &str,
        cargo_version: Option<&str>,
    ) -> Result<Self, ControlFileError> {
        // Replaces every `@CARGO_VERSION@` in `value`. It prefers the explicit version
        // and only consults the environment when a placeholder is actually present.
        fn substitute_placeholders(
            value: &str,
            cargo_version: Option<&str>,
        ) -> Result<String, ControlFileError> {
            const CARGO_VERSION: &str = "@CARGO_VERSION@";
            if !value.contains(CARGO_VERSION) {
                return Ok(value.to_owned());
            }
            let version = match cargo_version {
                Some(version) => version.to_owned(),
                None => std::env::var("CARGO_PKG_VERSION").map_err(|_| {
                    ControlFileError::MissingEnvvar("CARGO_PKG_VERSION".to_string())
                })?,
            };
            Ok(value.replace(CARGO_VERSION, &version))
        }

        // Removes a mandatory key, reporting it by name when absent.
        fn take_required(
            fields: &mut HashMap<&str, String>,
            field: &'static str,
        ) -> Result<String, ControlFileError> {
            fields.remove(field).ok_or(ControlFileError::MissingField { field })
        }

        let mut fields: HashMap<&str, String> = HashMap::new();
        for line in input.lines() {
            let line = line.trim();
            // Comment lines are skipped entirely, so a commented-out setting (even one
            // containing a placeholder) is neither recorded nor substituted.
            if line.starts_with('#') {
                continue;
            }
            // Split on the first `=` only: values that themselves contain `=`
            // (e.g. `comment = 'a = b'`) must be kept, not silently dropped.
            if let Some((key, value)) = line.split_once('=') {
                let value = value.trim().trim_matches('\'');
                fields.insert(key.trim(), substitute_placeholders(value, cargo_version)?);
            }
        }

        // Required keys are taken in a fixed order so the first missing one is reported.
        let comment = take_required(&mut fields, "comment")?;
        let default_version = take_required(&mut fields, "default_version")?;
        let module_pathname = fields.remove("module_pathname");
        let relocatable = take_required(&mut fields, "relocatable")? == "true";
        let superuser = take_required(&mut fields, "superuser")? == "true";
        let schema = fields.remove("schema");
        let trusted = matches!(fields.remove("trusted").as_deref(), Some("true"));

        // `trusted` is treated as redundant unless the extension also requires superuser.
        if trusted && !superuser {
            return Err(ControlFileError::RedundantField { field: "trusted" });
        }

        Ok(ControlFile {
            comment,
            default_version,
            module_pathname,
            relocatable,
            superuser,
            schema,
            trusted,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::{ControlFile, ControlFileError};

    // Minimal valid control file whose version comes from `@CARGO_VERSION@` substitution.
    const CONTROL_WITH_CARGO_VERSION: &str = "\
comment = 'test extension'
default_version = '@CARGO_VERSION@'
relocatable = false
superuser = false
";

    // Parses `input` with a fixed cargo version so no environment lookup is needed.
    fn parse(input: &str) -> Result<ControlFile, ControlFileError> {
        ControlFile::from_str_with_cargo_version(input, "0.0.0")
    }

    #[test]
    fn uses_the_supplied_cargo_version_for_substitution() {
        let control = parse(CONTROL_WITH_CARGO_VERSION).expect("control file should parse");
        assert_eq!(control.default_version, "0.0.0");
    }

    #[test]
    fn optional_fields_default_when_absent() {
        let control = parse(CONTROL_WITH_CARGO_VERSION).expect("control file should parse");
        assert_eq!(control.comment, "test extension");
        assert_eq!(control.module_pathname, None);
        assert_eq!(control.schema, None);
        assert!(!control.relocatable);
        assert!(!control.superuser);
        assert!(!control.trusted);
    }

    #[test]
    fn reports_missing_required_field() {
        let err = parse("default_version = '1.0'\nrelocatable = false\nsuperuser = false\n")
            .expect_err("comment is required");
        assert!(matches!(err, ControlFileError::MissingField { field: "comment" }));
    }

    #[test]
    fn rejects_trusted_without_superuser() {
        let input = "comment = 'x'\ndefault_version = '1.0'\nrelocatable = true\nsuperuser = false\ntrusted = true\n";
        let err = parse(input).expect_err("trusted requires superuser");
        assert!(matches!(err, ControlFileError::RedundantField { field: "trusted" }));
    }

    #[test]
    fn accepts_trusted_superuser_extension() {
        let input = "comment = 'x'\ndefault_version = '1.0'\nrelocatable = true\nsuperuser = true\ntrusted = true\n";
        let control = parse(input).expect("control file should parse");
        assert!(control.relocatable);
        assert!(control.superuser);
        assert!(control.trusted);
    }

    #[test]
    fn ignores_lines_without_an_assignment() {
        let input = format!("\nnot a setting\n{}", CONTROL_WITH_CARGO_VERSION);
        let control = parse(&input).expect("control file should parse");
        assert_eq!(control.comment, "test extension");
        assert_eq!(control.default_version, "0.0.0");
    }
}
impl From<ControlFile> for SqlGraphEntity<'_> {
fn from(val: ControlFile) -> Self {
SqlGraphEntity::ExtensionRoot(val)
}
}
/// Errors produced while reading or parsing a [`ControlFile`].
#[derive(Debug, Error)]
pub enum ControlFileError {
/// Reading the control file from the filesystem failed.
#[error("Filesystem error reading control file")]
IOError {
/// Underlying I/O error; `#[from]` lets `?` convert a `std::io::Error` into this variant.
#[from]
error: std::io::Error,
},
/// A required key (named by `field`) was not present in the control file.
#[error("Missing field in control file! Please add `{field}`.")]
MissingField { field: &'static str },
/// A key (named by `field`) was set where it is rejected as redundant,
/// e.g. `trusted` without `superuser`.
#[error("Redundant field in control file! Please remove `{field}`.")]
RedundantField { field: &'static str },
/// A needed environment variable was not set; the payload is the variable's name.
#[error("Missing environment variable: {0}")]
MissingEnvvar(String),
}
impl TryFrom<PathBuf> for ControlFile {
    type Error = ControlFileError;

    /// Reads the control file at `path` and parses it with [`ControlFile::from_str`].
    ///
    /// A read failure is reported as [`ControlFileError::IOError`].
    fn try_from(path: PathBuf) -> Result<Self, Self::Error> {
        let text = std::fs::read_to_string(&path)?;
        ControlFile::from_str(&text)
    }
}
impl TryFrom<&str> for ControlFile {
type Error = ControlFileError;
fn try_from(input: &str) -> Result<Self, Self::Error> {
Self::from_str(input)
}
}
impl ToSql for ControlFile {
    /// Returns the fixed banner comment that the control file contributes to generated SQL.
    ///
    /// The generation context is not consulted.
    fn to_sql(&self, _context: &super::PgrxSql) -> eyre::Result<String> {
        const BANNER: &str = r#"
/*
This file is auto generated by pgrx.
The ordering of items is not stable, it is driven by a dependency graph.
*/
"#;
        Ok(String::from(BANNER))
    }
}
impl SqlGraphIdentifier for ControlFile {
    /// Fixed label `"extension root"` used as this node's DOT identifier.
    fn dot_identifier(&self) -> String {
        String::from("extension root")
    }

    /// Fixed Rust-side identifier `"root"`.
    fn rust_identifier(&self) -> String {
        String::from("root")
    }

    /// No source file is associated with the control file.
    fn file(&self) -> Option<&str> {
        Option::None
    }

    /// No source line is associated with the control file.
    fn line(&self) -> Option<u32> {
        Option::None
    }
}