pglite_oxide/pglite/
config.rs1use std::collections::BTreeMap;
2
3use anyhow::{Result, bail, ensure};
4
5use crate::pglite::interface::DebugLevel;
6
7#[derive(Debug, Clone, Default, PartialEq, Eq)]
13pub struct PostgresConfig {
14 settings: BTreeMap<String, String>,
15}
16
17impl PostgresConfig {
18 pub fn new() -> Self {
20 Self::default()
21 }
22
23 pub fn set(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
25 self.settings.insert(name.into(), value.into());
26 self
27 }
28
29 pub(crate) fn insert(&mut self, name: impl Into<String>, value: impl Into<String>) {
30 self.settings.insert(name.into(), value.into());
31 }
32
33 pub(crate) fn validate(&self) -> Result<()> {
34 for (name, value) in &self.settings {
35 validate_guc_name(name)?;
36 ensure!(
37 !value.contains('\0'),
38 "Postgres config value for '{name}' must not contain NUL bytes"
39 );
40 }
41 Ok(())
42 }
43
44 pub(crate) fn iter(&self) -> impl Iterator<Item = (&str, &str)> {
45 self.settings
46 .iter()
47 .map(|(name, value)| (name.as_str(), value.as_str()))
48 }
49
50 #[cfg(feature = "extensions")]
51 pub(crate) fn stable_entries(&self) -> Vec<(String, String)> {
52 self.settings
53 .iter()
54 .map(|(name, value)| (name.clone(), value.clone()))
55 .collect()
56 }
57}
58
59#[derive(Debug, Clone, PartialEq, Eq)]
60pub(crate) struct StartupConfig {
61 pub(crate) username: String,
62 pub(crate) database: String,
63 pub(crate) debug_level: Option<DebugLevel>,
64 pub(crate) relaxed_durability: bool,
65 pub(crate) extra_args: Vec<String>,
66}
67
68impl Default for StartupConfig {
69 fn default() -> Self {
70 Self {
71 username: "postgres".to_owned(),
72 database: "template1".to_owned(),
73 debug_level: None,
74 relaxed_durability: false,
75 extra_args: Vec::new(),
76 }
77 }
78}
79
80impl StartupConfig {
81 pub(crate) fn validate(&self) -> Result<()> {
82 validate_startup_value("username", &self.username)?;
83 validate_startup_value("database", &self.database)?;
84 if let Some(level) = self.debug_level {
85 ensure!(
86 level <= 5,
87 "Postgres debug level must be between 0 and 5, got {level}"
88 );
89 }
90 for arg in &self.extra_args {
91 ensure!(
92 !arg.contains('\0'),
93 "Postgres startup argument must not contain NUL bytes"
94 );
95 }
96 Ok(())
97 }
98}
99
100fn validate_guc_name(name: &str) -> Result<()> {
101 ensure!(!name.is_empty(), "Postgres config name must not be empty");
102 ensure!(
103 !name.contains('\0') && !name.contains('='),
104 "Postgres config name '{name}' must not contain NUL bytes or '='"
105 );
106
107 for part in name.split('.') {
108 if part.is_empty() {
109 bail!("Postgres config name '{name}' contains an empty identifier part");
110 }
111 let mut chars = part.chars();
112 let first = chars.next().expect("part is non-empty");
113 if !(first == '_' || first.is_ascii_alphabetic()) {
114 bail!("Postgres config name '{name}' must start each identifier with a letter or '_'");
115 }
116 if chars.any(|ch| !(ch == '_' || ch.is_ascii_alphanumeric())) {
117 bail!("Postgres config name '{name}' may only contain letters, digits, '_', and '.'");
118 }
119 }
120
121 Ok(())
122}
123
124fn validate_startup_value(name: &str, value: &str) -> Result<()> {
125 ensure!(
126 !value.is_empty(),
127 "Postgres startup {name} must not be empty"
128 );
129 ensure!(
130 !value.contains('\0'),
131 "Postgres startup {name} must not contain NUL bytes"
132 );
133 Ok(())
134}
135
#[cfg(test)]
mod tests {
    use super::{PostgresConfig, StartupConfig};

    /// Validates a config holding only `name = "off"` and returns the rejection message.
    fn guc_name_error(name: &str) -> String {
        PostgresConfig::new()
            .set(name, "off")
            .validate()
            .expect_err("invalid GUC name should be rejected")
            .to_string()
    }

    #[test]
    fn validates_builtin_and_extension_guc_names() {
        PostgresConfig::new()
            .set("synchronous_commit", "off")
            .set("pg_stat_statements.track", "all")
            .validate()
            .unwrap();
    }

    #[test]
    fn rejects_invalid_guc_names_before_startup() {
        let err = guc_name_error("bad=name");
        assert!(err.contains("must not contain"));
    }

    #[test]
    fn rejects_malformed_guc_identifiers() {
        assert!(guc_name_error("").contains("must not be empty"));
        assert!(guc_name_error("pg_stat_statements.").contains("empty identifier part"));
        assert!(guc_name_error("1work_mem").contains("must start each identifier"));
        assert!(guc_name_error("work-mem").contains("may only contain"));
    }

    #[test]
    fn rejects_nul_bytes_in_guc_values() {
        let err = PostgresConfig::new()
            .set("work_mem", "64\0MB")
            .validate()
            .expect_err("NUL in a value should be rejected");
        assert!(err.to_string().contains("must not contain NUL bytes"));
    }

    #[test]
    fn default_startup_config_is_valid() {
        StartupConfig::default().validate().unwrap();
    }

    #[test]
    fn rejects_empty_or_nul_startup_values() {
        let empty_user = StartupConfig {
            username: String::new(),
            ..StartupConfig::default()
        };
        let err = empty_user
            .validate()
            .expect_err("empty username should be rejected");
        assert!(err.to_string().contains("username must not be empty"));

        let nul_arg = StartupConfig {
            extra_args: vec!["-c\0".to_owned()],
            ..StartupConfig::default()
        };
        let err = nul_arg
            .validate()
            .expect_err("NUL in a startup argument should be rejected");
        assert!(err.to_string().contains("must not contain NUL bytes"));
    }
}