pgrx_sql_entity_graph/
control_file.rs1use super::{SqlGraphEntity, SqlGraphIdentifier, ToSql};
19use std::collections::HashMap;
20use std::path::{Path, PathBuf};
21use thiserror::Error;
22
/// The parsed contents of a PostgreSQL extension control file.
///
/// Values are stored as written in the file, with leading/trailing single
/// quotes removed and `@CARGO_VERSION@` substituted. A `ControlFile` also acts
/// as the root entity of the SQL entity graph (see the
/// `From<ControlFile> for SqlGraphEntity` impl).
// Field order is significant: the derived `PartialOrd`/`Ord`/`Hash` follow it.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ControlFile {
    /// Value of the required `comment` key.
    pub comment: String,
    /// Value of the required `default_version` key.
    pub default_version: String,
    /// Value of the optional `module_pathname` key.
    pub module_pathname: Option<String>,
    /// Whether the required `relocatable` key is set to true.
    pub relocatable: bool,
    /// Whether the required `superuser` key is set to true.
    pub superuser: bool,
    /// Value of the optional `schema` key.
    pub schema: Option<String>,
    /// Whether the optional `trusted` key is set to true; `false` when absent.
    pub trusted: bool,
}
45
46impl ControlFile {
47 #[allow(clippy::should_implement_trait)]
70 pub fn from_str(input: &str) -> Result<Self, ControlFileError> {
71 Self::from_str_with_version(input, None)
72 }
73
74 pub fn from_str_with_cargo_version(
75 input: &str,
76 cargo_version: &str,
77 ) -> Result<Self, ControlFileError> {
78 Self::from_str_with_version(input, Some(cargo_version))
79 }
80
81 pub fn from_path_with_cargo_version(
82 path: impl AsRef<Path>,
83 cargo_version: &str,
84 ) -> Result<Self, ControlFileError> {
85 let contents = std::fs::read_to_string(path)?;
86 Self::from_str_with_cargo_version(contents.as_str(), cargo_version)
87 }
88
89 fn from_str_with_version(
90 input: &str,
91 cargo_version: Option<&str>,
92 ) -> Result<Self, ControlFileError> {
93 fn do_var_replacements(
94 mut input: String,
95 cargo_version: Option<&str>,
96 ) -> Result<String, ControlFileError> {
97 const CARGO_VERSION: &str = "@CARGO_VERSION@";
98
99 if input.contains(CARGO_VERSION) {
101 let cargo_version = match cargo_version {
102 Some(cargo_version) => cargo_version.to_owned(),
103 None => std::env::var("CARGO_PKG_VERSION").map_err(|_| {
104 ControlFileError::MissingEnvvar("CARGO_PKG_VERSION".to_string())
105 })?,
106 };
107 input = input.replace(CARGO_VERSION, &cargo_version);
108 }
109
110 Ok(input)
111 }
112
113 let mut temp = HashMap::new();
114 for line in input.lines() {
115 let parts: Vec<&str> = line.split('=').collect();
116
117 if parts.len() != 2 {
118 continue;
119 }
120
121 let (k, v) = (parts.first().unwrap().trim(), parts.get(1).unwrap().trim());
122
123 let v = v.trim_start_matches('\'');
124 let v = v.trim_end_matches('\'');
125
126 temp.insert(k, do_var_replacements(v.to_string(), cargo_version)?);
127 }
128 let control_file = ControlFile {
129 comment: temp
130 .get("comment")
131 .ok_or(ControlFileError::MissingField { field: "comment" })?
132 .to_string(),
133 default_version: temp
134 .get("default_version")
135 .ok_or(ControlFileError::MissingField { field: "default_version" })?
136 .to_string(),
137 module_pathname: temp.get("module_pathname").map(|v| v.to_string()),
138 relocatable: temp
139 .get("relocatable")
140 .ok_or(ControlFileError::MissingField { field: "relocatable" })?
141 == "true",
142 superuser: temp
143 .get("superuser")
144 .ok_or(ControlFileError::MissingField { field: "superuser" })?
145 == "true",
146 schema: temp.get("schema").map(|v| v.to_string()),
147 trusted: if let Some(v) = temp.get("trusted") { v == "true" } else { false },
148 };
149
150 if !control_file.superuser && control_file.trusted {
151 return Err(ControlFileError::RedundantField { field: "trusted" });
153 }
154
155 Ok(control_file)
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::{ControlFile, ControlFileError};

    /// Minimal valid control file whose version comes from substitution.
    const CONTROL_WITH_CARGO_VERSION: &str = "\
comment = 'test extension'
default_version = '@CARGO_VERSION@'
relocatable = false
superuser = false
";

    /// Control file that sets every optional key it can legitimately set.
    const SUPERUSER_TRUSTED: &str = "\
comment = 'trusted extension'
default_version = '1.0'
module_pathname = '$libdir/trusted'
relocatable = true
superuser = true
trusted = true
";

    #[test]
    fn uses_the_supplied_cargo_version_for_substitution() {
        let control = ControlFile::from_str_with_cargo_version(CONTROL_WITH_CARGO_VERSION, "0.0.0")
            .expect("control file should parse");

        assert_eq!(control.default_version, "0.0.0");
    }

    #[test]
    fn optional_fields_are_absent_by_default() {
        let control = ControlFile::from_str_with_cargo_version(CONTROL_WITH_CARGO_VERSION, "0.0.0")
            .expect("control file should parse");

        assert_eq!(control.comment, "test extension");
        assert_eq!(control.module_pathname, None);
        assert_eq!(control.schema, None);
        assert!(!control.relocatable);
        assert!(!control.superuser);
        assert!(!control.trusted);
    }

    #[test]
    fn parses_optional_fields_when_present() {
        let control = ControlFile::from_str(SUPERUSER_TRUSTED).expect("control file should parse");

        assert_eq!(control.module_pathname.as_deref(), Some("$libdir/trusted"));
        assert!(control.relocatable);
        assert!(control.superuser);
        assert!(control.trusted);
    }

    #[test]
    fn reports_missing_required_field() {
        let err = ControlFile::from_str("comment = 'x'\nrelocatable = false\nsuperuser = false\n")
            .expect_err("default_version is required");

        assert!(matches!(err, ControlFileError::MissingField { field: "default_version" }));
    }

    #[test]
    fn rejects_trusted_without_superuser() {
        let input = format!("{CONTROL_WITH_CARGO_VERSION}trusted = true\n");
        let err = ControlFile::from_str_with_cargo_version(&input, "0.0.0")
            .expect_err("trusted without superuser is redundant");

        assert!(matches!(err, ControlFileError::RedundantField { field: "trusted" }));
    }
}
178
179impl From<ControlFile> for SqlGraphEntity<'_> {
180 fn from(val: ControlFile) -> Self {
181 SqlGraphEntity::ExtensionRoot(val)
182 }
183}
184
185#[derive(Debug, Error)]
187pub enum ControlFileError {
188 #[error("Filesystem error reading control file")]
189 IOError {
190 #[from]
191 error: std::io::Error,
192 },
193 #[error("Missing field in control file! Please add `{field}`.")]
194 MissingField { field: &'static str },
195 #[error("Redundant field in control file! Please remove `{field}`.")]
196 RedundantField { field: &'static str },
197 #[error("Missing environment variable: {0}")]
198 MissingEnvvar(String),
199}
200
201impl TryFrom<PathBuf> for ControlFile {
202 type Error = ControlFileError;
203
204 fn try_from(value: PathBuf) -> Result<Self, Self::Error> {
205 let contents = std::fs::read_to_string(value)?;
206 ControlFile::try_from(contents.as_str())
207 }
208}
209
210impl TryFrom<&str> for ControlFile {
211 type Error = ControlFileError;
212
213 fn try_from(input: &str) -> Result<Self, Self::Error> {
214 Self::from_str(input)
215 }
216}
217
218impl ToSql for ControlFile {
219 fn to_sql(&self, _context: &super::PgrxSql) -> eyre::Result<String> {
220 let comment = r#"
221/*
222This file is auto generated by pgrx.
223
224The ordering of items is not stable, it is driven by a dependency graph.
225*/
226"#;
227 Ok(comment.into())
228 }
229}
230
231impl SqlGraphIdentifier for ControlFile {
232 fn dot_identifier(&self) -> String {
233 "extension root".into()
234 }
235 fn rust_identifier(&self) -> String {
236 "root".into()
237 }
238
239 fn file(&self) -> Option<&str> {
240 None
241 }
242
243 fn line(&self) -> Option<u32> {
244 None
245 }
246}