Skip to main content

pglite_oxide/pglite/
config.rs

1use std::collections::BTreeMap;
2
3use anyhow::{Result, bail, ensure};
4
5use crate::pglite::interface::DebugLevel;
6
/// Extra PostgreSQL startup settings (GUCs), applied through the normal
/// `postgres -c` GUC handling before the embedded backend starts.
///
/// These entries are appended to the generated PostgreSQL argv *after*
/// `pglite-oxide`'s default startup profile, so any setting named here takes
/// precedence over the corresponding default.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct PostgresConfig {
    /// GUC name -> value. A `BTreeMap` keeps iteration sorted by name and
    /// therefore deterministic.
    settings: BTreeMap<String, String>,
}
16
17impl PostgresConfig {
18    /// Create an empty startup configuration.
19    pub fn new() -> Self {
20        Self::default()
21    }
22
23    /// Set or replace one PostgreSQL GUC.
24    pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
25        self.settings.insert(name.into(), value.into());
26        self
27    }
28
29    pub(crate) fn insert(&mut self, name: impl Into<String>, value: impl Into<String>) {
30        self.settings.insert(name.into(), value.into());
31    }
32
33    pub(crate) fn validate(&self) -> Result<()> {
34        for (name, value) in &self.settings {
35            validate_guc_name(name)?;
36            ensure!(
37                !value.contains('\0'),
38                "Postgres config value for '{name}' must not contain NUL bytes"
39            );
40        }
41        Ok(())
42    }
43
44    pub(crate) fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
45        self.settings
46            .iter()
47            .map(|(name, value)| (name.as_str(), value.as_str()))
48    }
49
50    #[cfg(feature = "extensions")]
51    pub(crate) fn stable_entries(&self) -> Vec<(String, String)> {
52        self.settings
53            .iter()
54            .map(|(name, value)| (name.clone(), value.clone()))
55            .collect()
56    }
57}
58
59#[derive(Debug, Clone, PartialEq, Eq)]
60pub(crate) struct StartupConfig {
61    pub(crate) username: String,
62    pub(crate) database: String,
63    pub(crate) debug_level: Option<DebugLevel>,
64    pub(crate) relaxed_durability: bool,
65    pub(crate) extra_args: Vec<String>,
66}
67
68impl Default for StartupConfig {
69    fn default() -> Self {
70        Self {
71            username: "postgres".to_owned(),
72            database: "template1".to_owned(),
73            debug_level: None,
74            relaxed_durability: false,
75            extra_args: Vec::new(),
76        }
77    }
78}
79
80impl StartupConfig {
81    pub(crate) fn validate(&self) -> Result<()> {
82        validate_startup_value("username", &self.username)?;
83        validate_startup_value("database", &self.database)?;
84        if let Some(level) = self.debug_level {
85            ensure!(
86                level <= 5,
87                "Postgres debug level must be between 0 and 5, got {level}"
88            );
89        }
90        for arg in &self.extra_args {
91            ensure!(
92                !arg.contains('\0'),
93                "Postgres startup argument must not contain NUL bytes"
94            );
95        }
96        Ok(())
97    }
98}
99
100fn validate_guc_name(name: &str) -> Result<()> {
101    ensure!(!name.is_empty(), "Postgres config name must not be empty");
102    ensure!(
103        !name.contains('\0') && !name.contains('='),
104        "Postgres config name '{name}' must not contain NUL bytes or '='"
105    );
106
107    for part in name.split('.') {
108        if part.is_empty() {
109            bail!("Postgres config name '{name}' contains an empty identifier part");
110        }
111        let mut chars = part.chars();
112        let first = chars.next().expect("part is non-empty");
113        if !(first == '_' || first.is_ascii_alphabetic()) {
114            bail!("Postgres config name '{name}' must start each identifier with a letter or '_'");
115        }
116        if chars.any(|ch| !(ch == '_' || ch.is_ascii_alphanumeric())) {
117            bail!("Postgres config name '{name}' may only contain letters, digits, '_', and '.'");
118        }
119    }
120
121    Ok(())
122}
123
124fn validate_startup_value(name: &str, value: &str) -> Result<()> {
125    ensure!(
126        !value.is_empty(),
127        "Postgres startup {name} must not be empty"
128    );
129    ensure!(
130        !value.contains('\0'),
131        "Postgres startup {name} must not contain NUL bytes"
132    );
133    Ok(())
134}
135
#[cfg(test)]
mod tests {
    use super::{PostgresConfig, StartupConfig};

    /// Validate a config holding a single setting and return its error message.
    fn rejection_message(name: &str, value: &str) -> String {
        PostgresConfig::new()
            .set(name, value)
            .validate()
            .expect_err("setting should be rejected")
            .to_string()
    }

    #[test]
    fn validates_builtin_and_extension_guc_names() {
        PostgresConfig::new()
            .set("synchronous_commit", "off")
            .set("pg_stat_statements.track", "all")
            .validate()
            .unwrap();
    }

    #[test]
    fn rejects_invalid_guc_names_before_startup() {
        let err = PostgresConfig::new()
            .set("bad=name", "off")
            .validate()
            .expect_err("invalid GUC name should be rejected");
        assert!(err.to_string().contains("must not contain"));
    }

    #[test]
    fn rejects_empty_guc_name() {
        let message = rejection_message("", "on");
        assert!(message.contains("must not be empty"), "{message}");
    }

    #[test]
    fn rejects_empty_identifier_parts() {
        for name in [".track", "pg_stat_statements.", "pg_stat_statements..track"] {
            let message = rejection_message(name, "all");
            assert!(message.contains("empty identifier part"), "{name}: {message}");
        }
    }

    #[test]
    fn rejects_identifiers_with_invalid_leading_character() {
        for name in ["1work_mem", "ext.9track"] {
            let message = rejection_message(name, "on");
            assert!(message.contains("must start each identifier"), "{name}: {message}");
        }
    }

    #[test]
    fn rejects_disallowed_name_characters() {
        let message = rejection_message("work-mem", "64MB");
        assert!(message.contains("may only contain"), "{message}");
    }

    #[test]
    fn rejects_nul_bytes_in_values() {
        let message = rejection_message("work_mem", "64\0MB");
        assert!(message.contains("must not contain NUL bytes"), "{message}");
    }

    #[test]
    fn default_startup_config_is_valid() {
        StartupConfig::default().validate().unwrap();
    }

    #[test]
    fn rejects_invalid_startup_values() {
        let empty_user = StartupConfig {
            username: String::new(),
            ..StartupConfig::default()
        };
        let message = empty_user
            .validate()
            .expect_err("empty username should be rejected")
            .to_string();
        assert!(message.contains("username must not be empty"), "{message}");

        let nul_arg = StartupConfig {
            extra_args: vec!["-c\0".to_owned()],
            ..StartupConfig::default()
        };
        let message = nul_arg
            .validate()
            .expect_err("NUL in extra argument should be rejected")
            .to_string();
        assert!(message.contains("must not contain NUL bytes"), "{message}");
    }
}