Skip to main content

pg_ephemeral/
config.rs

1use super::InstanceName;
2use crate::definition::{Definition, SslConfig};
3use crate::image::Image;
4use crate::seed::{Command, CommandCacheConfig, Seed, SeedName};
5
6#[derive(Clone, Debug, PartialEq)]
7pub struct Instance {
8    pub application_name: Option<pg_client::ApplicationName>,
9    pub backend: ociman::backend::Selection,
10    pub database: pg_client::Database,
11    pub seeds: indexmap::IndexMap<SeedName, Seed>,
12    pub ssl_config: Option<SslConfig>,
13    pub superuser: pg_client::User,
14    pub image: Image,
15    pub cross_container_access: bool,
16    pub wait_available_timeout: std::time::Duration,
17}
18
19impl Instance {
20    #[must_use]
21    pub fn new(backend: ociman::backend::Selection, image: Image) -> Self {
22        Self {
23            backend,
24            application_name: None,
25            seeds: indexmap::IndexMap::new(),
26            ssl_config: None,
27            superuser: pg_client::User::POSTGRES,
28            database: pg_client::Database::POSTGRES,
29            image,
30            cross_container_access: false,
31            wait_available_timeout: std::time::Duration::from_secs(10),
32        }
33    }
34
35    pub async fn definition(
36        &self,
37        instance_name: &crate::InstanceName,
38    ) -> Result<Definition, ociman::backend::resolve::Error> {
39        Ok(Definition {
40            instance_name: instance_name.clone(),
41            application_name: self.application_name.clone(),
42            backend: self.backend.resolve().await?,
43            database: self.database.clone(),
44            seeds: self.seeds.clone(),
45            ssl_config: self.ssl_config.clone(),
46            superuser: self.superuser.clone(),
47            image: self.image.clone(),
48            cross_container_access: self.cross_container_access,
49            wait_available_timeout: self.wait_available_timeout,
50            remove: true,
51        })
52    }
53}
54
55#[derive(Debug, thiserror::Error, PartialEq)]
56pub enum Error {
57    #[error("Could not load config file: {0}")]
58    IO(IoError),
59    #[error("Decoding as toml failed: {0}")]
60    TomlDecode(#[from] toml::de::Error),
61    #[error("Instance {instance_name} does not specify {field} and no default applies")]
62    MissingInstanceField {
63        instance_name: InstanceName,
64        field: &'static str,
65    },
66}
67
/// Comparable stand-in for [`std::io::Error`] that keeps only its
/// [`std::io::ErrorKind`].
///
/// `std::io::Error` does not implement `PartialEq`; storing just the kind
/// lets this type derive it. Any payload or message attached to the original
/// error is dropped by the conversion.
#[derive(Debug, PartialEq)]
pub struct IoError(pub std::io::ErrorKind);

impl std::fmt::Display for IoError {
    /// Writes the standard library's message for the stored error kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = std::io::Error::from(self.0).to_string();
        f.write_str(&message)
    }
}

impl std::error::Error for IoError {}

impl From<std::io::Error> for IoError {
    /// Keeps the kind of `error` and discards everything else.
    fn from(error: std::io::Error) -> Self {
        let kind = error.kind();
        IoError(kind)
    }
}
84
85#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
86#[serde(tag = "type", rename_all = "kebab-case")]
87pub enum SeedConfig {
88    SqlFile {
89        path: std::path::PathBuf,
90        git_revision: Option<String>,
91    },
92    Command {
93        command: String,
94        #[serde(default)]
95        arguments: Vec<String>,
96        cache: CommandCacheConfig,
97    },
98    Script {
99        script: String,
100    },
101    ContainerScript {
102        script: String,
103    },
104}
105
106impl From<SeedConfig> for Seed {
107    fn from(value: SeedConfig) -> Self {
108        match value {
109            SeedConfig::SqlFile { path, git_revision } => match git_revision {
110                Some(git_revision) => Seed::SqlFileGitRevision { git_revision, path },
111                None => Seed::SqlFile { path },
112            },
113            SeedConfig::Command {
114                command,
115                arguments,
116                cache,
117            } => Seed::Command {
118                command: Command::new(command, arguments),
119                cache,
120            },
121            SeedConfig::Script { script } => Seed::Script { script },
122            SeedConfig::ContainerScript { script } => Seed::ContainerScript { script },
123        }
124    }
125}
126
127#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
128#[serde(deny_unknown_fields)]
129pub struct SslConfigDefinition {
130    pub hostname: pg_client::HostName,
131}
132
133#[derive(Debug, serde::Deserialize, PartialEq)]
134#[serde(deny_unknown_fields)]
135pub struct InstanceDefinition {
136    pub backend: Option<ociman::backend::Selection>,
137    pub image: Option<Image>,
138    #[serde(default)]
139    pub seeds: indexmap::IndexMap<SeedName, SeedConfig>,
140    pub ssl_config: Option<SslConfigDefinition>,
141    #[serde(default, with = "humantime_serde")]
142    pub wait_available_timeout: Option<std::time::Duration>,
143}
144
145impl InstanceDefinition {
146    #[must_use]
147    pub fn empty() -> Self {
148        Self {
149            backend: None,
150            image: None,
151            seeds: indexmap::IndexMap::new(),
152            ssl_config: None,
153            wait_available_timeout: None,
154        }
155    }
156
157    fn into_instance(
158        self,
159        instance_name: &InstanceName,
160        defaults: &InstanceDefinition,
161        overwrites: &InstanceDefinition,
162    ) -> Result<Instance, Error> {
163        let image = match overwrites
164            .image
165            .as_ref()
166            .or(self.image.as_ref())
167            .or(defaults.image.as_ref())
168        {
169            Some(image) => image.clone(),
170            None => {
171                return Err(Error::MissingInstanceField {
172                    instance_name: instance_name.clone(),
173                    field: "image",
174                });
175            }
176        };
177
178        let backend = overwrites
179            .backend
180            .or(self.backend)
181            .or(defaults.backend)
182            .unwrap_or(ociman::backend::Selection::Auto);
183
184        let seeds = self
185            .seeds
186            .into_iter()
187            .map(|(name, seed_config)| (name, seed_config.into()))
188            .collect();
189
190        let ssl_config = overwrites
191            .ssl_config
192            .as_ref()
193            .or(self.ssl_config.as_ref())
194            .or(defaults.ssl_config.as_ref())
195            .map(|ssl_config_def| SslConfig::Generated {
196                hostname: ssl_config_def.hostname.clone(),
197            });
198
199        let wait_available_timeout = overwrites
200            .wait_available_timeout
201            .or(self.wait_available_timeout)
202            .or(defaults.wait_available_timeout)
203            .unwrap_or(std::time::Duration::from_secs(10));
204
205        Ok(Instance {
206            application_name: None,
207            backend,
208            database: pg_client::Database::POSTGRES,
209            seeds,
210            ssl_config,
211            superuser: pg_client::User::POSTGRES,
212            image,
213            cross_container_access: false,
214            wait_available_timeout,
215        })
216    }
217}
218
219#[derive(Debug, serde::Deserialize, PartialEq)]
220#[serde(deny_unknown_fields)]
221pub struct Config {
222    image: Option<Image>,
223    backend: Option<ociman::backend::Selection>,
224    ssl_config: Option<SslConfigDefinition>,
225    #[serde(default, with = "humantime_serde")]
226    wait_available_timeout: Option<std::time::Duration>,
227    instances: Option<std::collections::BTreeMap<InstanceName, InstanceDefinition>>,
228}
229
230impl std::default::Default for Config {
231    fn default() -> Self {
232        Self {
233            image: Some(Image::default()),
234            backend: None,
235            ssl_config: None,
236            wait_available_timeout: None,
237            instances: None,
238        }
239    }
240}
241
242impl Config {
243    pub fn load_toml_file(
244        file: impl AsRef<std::path::Path>,
245        overwrites: &InstanceDefinition,
246    ) -> Result<super::InstanceMap, Error> {
247        let file = file.as_ref();
248        let base_dir = file
249            .parent()
250            .map(std::path::Path::to_path_buf)
251            .unwrap_or_default();
252
253        std::fs::read_to_string(file)
254            .map_err(|error| Error::IO(error.into()))
255            .and_then(Self::load_toml)
256            .map(|config| config.resolve_paths(&base_dir))
257            .and_then(|config| config.instance_map(overwrites))
258    }
259
260    fn resolve_paths(mut self, base_dir: &std::path::Path) -> Self {
261        let resolve_path = |path: std::path::PathBuf| -> std::path::PathBuf {
262            if path.is_relative() {
263                base_dir.join(path)
264            } else {
265                path
266            }
267        };
268
269        // Resolve a command string if it looks like a relative file path (contains a
270        // path separator). Plain command names such as "sh" or "psql" are left alone
271        // so they continue to be resolved via PATH.
272        let resolve_command = |command: &mut String| {
273            let path = std::path::Path::new(command.as_str());
274            if path.is_relative() && path.components().count() > 1 {
275                // Strip leading CurDir (`.`) components so `./bin/foo` and `bin/foo`
276                // both produce the same absolute result after joining.
277                let stripped: std::path::PathBuf = path
278                    .components()
279                    .filter(|c| !matches!(c, std::path::Component::CurDir))
280                    .collect();
281                *command = base_dir.join(stripped).to_string_lossy().into_owned();
282            }
283        };
284
285        if let Some(instances) = self.instances.as_mut() {
286            for instance in instances.values_mut() {
287                for seed in instance.seeds.values_mut() {
288                    match seed {
289                        SeedConfig::SqlFile { path, .. } => *path = resolve_path(path.clone()),
290                        SeedConfig::Command { command, cache, .. } => {
291                            resolve_command(command);
292                            if let CommandCacheConfig::KeyCommand {
293                                command: key_command,
294                                ..
295                            } = cache
296                            {
297                                resolve_command(key_command);
298                            }
299                        }
300                        SeedConfig::Script { .. } | SeedConfig::ContainerScript { .. } => {}
301                    }
302                }
303            }
304        }
305
306        self
307    }
308
309    pub fn load_toml(contents: impl AsRef<str>) -> Result<Config, Error> {
310        toml::from_str(contents.as_ref()).map_err(Error::TomlDecode)
311    }
312
313    pub fn instance_map(
314        self,
315        overwrites: &InstanceDefinition,
316    ) -> Result<super::InstanceMap, Error> {
317        let defaults = InstanceDefinition {
318            backend: self.backend,
319            image: self.image.clone(),
320            seeds: indexmap::IndexMap::new(),
321            ssl_config: self.ssl_config.clone(),
322            wait_available_timeout: self.wait_available_timeout,
323        };
324
325        match self.instances {
326            None => {
327                let instance_name = InstanceName::default();
328
329                InstanceDefinition::empty()
330                    .into_instance(&instance_name, &defaults, overwrites)
331                    .map(|instance| [(instance_name, instance)].into())
332            }
333            Some(map) => {
334                let mut instance_map = std::collections::BTreeMap::new();
335
336                for (instance_name, instance_definition) in map {
337                    let instance =
338                        instance_definition.into_instance(&instance_name, &defaults, overwrites)?;
339
340                    instance_map.insert(instance_name, instance);
341                }
342
343                Ok(instance_map)
344            }
345        }
346    }
347}
348
#[cfg(test)]
mod test {
    use super::*;

    /// Writes `toml` to `database.toml` inside the temp sub-directory
    /// `dir_name`, loads it without overwrites and returns that directory
    /// together with the seed `seed_name` of instance `main`.
    fn load_main_seed(
        dir_name: &str,
        toml: &str,
        seed_name: &str,
    ) -> (std::path::PathBuf, crate::seed::Seed) {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();
        let config_path = dir.join("database.toml");
        std::fs::write(&config_path, toml).unwrap();

        let instance_map =
            Config::load_toml_file(&config_path, &InstanceDefinition::empty()).unwrap();

        let instance_name: crate::InstanceName = "main".parse().unwrap();
        let instance = instance_map.get(&instance_name).unwrap();
        let seed_name: crate::seed::SeedName = seed_name.parse().unwrap();

        (dir, instance.seeds[&seed_name].clone())
    }

    #[test]
    fn sql_file_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-sql-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "sql-file"
                path = "db/structure.sql"
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::SqlFile {
                path: dir.join("db/structure.sql"),
            }
        );
    }

    #[test]
    fn command_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.migrate]
                type = "command"
                command = "./bin/migrate"
                arguments = ["up"]
                cache = { type = "none" }
            "#},
            "migrate",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new(
                    dir.join("bin/migrate").to_string_lossy(),
                    ["up"],
                ),
                cache: crate::seed::CommandCacheConfig::None,
            }
        );
    }

    #[test]
    fn bare_command_name_not_resolved() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-bare-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "command"
                command = "psql"
                arguments = ["-f", "schema.sql"]
                cache = { type = "command-hash" }
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new("psql", ["-f", "schema.sql"]),
                cache: crate::seed::CommandCacheConfig::CommandHash,
            }
        );
    }

    #[test]
    fn container_script_parsed() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-container-script",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.install-ext]
                type = "container-script"
                script = "apt-get update && apt-get install -y postgresql-15-cron"
            "#},
            "install-ext",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::ContainerScript {
                script: "apt-get update && apt-get install -y postgresql-15-cron".to_string(),
            }
        );
    }
}