Skip to main content

pg_ephemeral/
config.rs

1use super::InstanceName;
2use crate::definition::{Definition, SslConfig};
3use crate::image::Image;
4use crate::seed::{Command, CommandCacheConfig, Seed, SeedName};
5
6#[derive(Clone, Debug, PartialEq)]
7pub struct Instance {
8    pub application_name: Option<pg_client::config::ApplicationName>,
9    pub backend: ociman::backend::Selection,
10    pub database: pg_client::Database,
11    pub seeds: indexmap::IndexMap<SeedName, Seed>,
12    pub ssl_config: Option<SslConfig>,
13    pub superuser: pg_client::User,
14    pub image: Image,
15    pub cross_container_access: bool,
16    pub wait_available_timeout: std::time::Duration,
17}
18
19impl Instance {
20    #[must_use]
21    pub fn new(backend: ociman::backend::Selection, image: Image) -> Self {
22        Self {
23            backend,
24            application_name: None,
25            seeds: indexmap::IndexMap::new(),
26            ssl_config: None,
27            superuser: pg_client::User::POSTGRES,
28            database: pg_client::Database::POSTGRES,
29            image,
30            cross_container_access: false,
31            wait_available_timeout: std::time::Duration::from_secs(10),
32        }
33    }
34
35    pub async fn definition(
36        &self,
37        instance_name: &crate::InstanceName,
38    ) -> Result<Definition, ociman::backend::resolve::Error> {
39        Ok(Definition {
40            instance_name: instance_name.clone(),
41            application_name: self.application_name.clone(),
42            backend: self.backend.resolve().await?,
43            database: self.database.clone(),
44            seeds: self.seeds.clone(),
45            ssl_config: self.ssl_config.clone(),
46            superuser: self.superuser.clone(),
47            image: self.image.clone(),
48            cross_container_access: self.cross_container_access,
49            wait_available_timeout: self.wait_available_timeout,
50            remove: true,
51        })
52    }
53}
54
/// Errors produced while loading a config file and building its instances.
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error {
    /// The config file could not be read. Only the I/O error kind is kept
    /// (see [`IoError`]) so this enum can derive `PartialEq`.
    #[error("Could not load config file: {0}")]
    IO(IoError),
    /// The file contents could not be deserialized into a [`Config`].
    #[error("Decoding as toml failed: {0}")]
    TomlDecode(#[from] toml::de::Error),
    /// None of the sources consulted for an instance (overwrites, the instance
    /// definition, the top-level defaults) provided a required field.
    #[error("Instance {instance_name} does not specify {field} and no default applies")]
    MissingInstanceField {
        /// Instance that is missing the field.
        instance_name: InstanceName,
        /// Config key of the missing field, e.g. `"image"`.
        field: &'static str,
    },
}
67
/// An I/O failure reduced to its [`std::io::ErrorKind`].
///
/// `std::io::Error` is not `PartialEq`, so only the kind is stored; this keeps
/// the type comparable. The trade-off is that the OS message and any custom
/// payload of the original error are dropped.
#[derive(Debug, PartialEq)]
pub struct IoError(pub std::io::ErrorKind);

impl std::fmt::Display for IoError {
    /// Renders the kind the same way a `std::io::Error` built from it would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let IoError(kind) = *self;
        let rendered = std::io::Error::from(kind);
        write!(f, "{}", rendered)
    }
}

impl std::error::Error for IoError {}

impl From<std::io::Error> for IoError {
    /// Keeps only the kind of `source`.
    fn from(source: std::io::Error) -> Self {
        let kind = source.kind();
        Self(kind)
    }
}
84
/// A seed entry as written in the config file, selected by its `type` key.
///
/// The `type` values are the kebab-case variant names: `sql-file`, `command`,
/// `script`, `container-script` and `csv-file`. When loading from a file,
/// relative paths are resolved by `Config::resolve_paths`. The entry is then
/// converted into a [`Seed`] via `From`.
///
/// NOTE(review): unlike the other config types in this module, this enum has
/// no `deny_unknown_fields`. A misspelled key inside a seed table is therefore
/// presumably ignored rather than rejected — confirm whether that is intended.
#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum SeedConfig {
    /// SQL file at `path`. With `git_revision` set, it becomes
    /// `Seed::SqlFileGitRevision`; otherwise it becomes `Seed::SqlFile`.
    SqlFile {
        path: std::path::PathBuf,
        git_revision: Option<String>,
    },
    /// External `command` with `arguments` (empty when omitted).
    ///
    /// `cache` is required because it has no serde default. Relative,
    /// path-like commands are resolved against the config file's directory.
    Command {
        command: String,
        #[serde(default)]
        arguments: Vec<String>,
        cache: CommandCacheConfig,
    },
    /// Inline `script`; passed through unchanged as `Seed::Script`.
    Script {
        script: String,
    },
    /// Inline `script` for `Seed::ContainerScript`; presumably executed inside
    /// the container, per the name — TODO confirm.
    ContainerScript {
        script: String,
    },
    /// CSV file at `path` targeting `table`; `path` is resolved like a SQL file path.
    CsvFile {
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
    },
}
109
110impl From<SeedConfig> for Seed {
111    fn from(value: SeedConfig) -> Self {
112        match value {
113            SeedConfig::SqlFile { path, git_revision } => match git_revision {
114                Some(git_revision) => Seed::SqlFileGitRevision { git_revision, path },
115                None => Seed::SqlFile { path },
116            },
117            SeedConfig::Command {
118                command,
119                arguments,
120                cache,
121            } => Seed::Command {
122                command: Command::new(command, arguments),
123                cache,
124            },
125            SeedConfig::Script { script } => Seed::Script { script },
126            SeedConfig::ContainerScript { script } => Seed::ContainerScript { script },
127            SeedConfig::CsvFile { path, table } => Seed::CsvFile { path, table },
128        }
129    }
130}
131
/// The `ssl_config` table of a config file.
///
/// When an instance is built, it is converted into `SslConfig::Generated` with
/// the same `hostname`. Specifying it presumably enables a generated certificate
/// for that host — confirm against `SslConfig`.
#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct SslConfigDefinition {
    /// Hostname handed on to `SslConfig::Generated`.
    pub hostname: pg_client::config::HostName,
}
137
138#[derive(Debug, serde::Deserialize, PartialEq)]
139#[serde(deny_unknown_fields)]
140pub struct InstanceDefinition {
141    pub backend: Option<ociman::backend::Selection>,
142    pub image: Option<Image>,
143    #[serde(default)]
144    pub seeds: indexmap::IndexMap<SeedName, SeedConfig>,
145    pub ssl_config: Option<SslConfigDefinition>,
146    #[serde(default, with = "humantime_serde")]
147    pub wait_available_timeout: Option<std::time::Duration>,
148}
149
150impl InstanceDefinition {
151    #[must_use]
152    pub fn empty() -> Self {
153        Self {
154            backend: None,
155            image: None,
156            seeds: indexmap::IndexMap::new(),
157            ssl_config: None,
158            wait_available_timeout: None,
159        }
160    }
161
162    fn into_instance(
163        self,
164        instance_name: &InstanceName,
165        defaults: &InstanceDefinition,
166        overwrites: &InstanceDefinition,
167    ) -> Result<Instance, Error> {
168        let image = match overwrites
169            .image
170            .as_ref()
171            .or(self.image.as_ref())
172            .or(defaults.image.as_ref())
173        {
174            Some(image) => image.clone(),
175            None => {
176                return Err(Error::MissingInstanceField {
177                    instance_name: instance_name.clone(),
178                    field: "image",
179                });
180            }
181        };
182
183        let backend = overwrites
184            .backend
185            .or(self.backend)
186            .or(defaults.backend)
187            .unwrap_or(ociman::backend::Selection::Auto);
188
189        let seeds = self
190            .seeds
191            .into_iter()
192            .map(|(name, seed_config)| (name, seed_config.into()))
193            .collect();
194
195        let ssl_config = overwrites
196            .ssl_config
197            .as_ref()
198            .or(self.ssl_config.as_ref())
199            .or(defaults.ssl_config.as_ref())
200            .map(|ssl_config_def| SslConfig::Generated {
201                hostname: ssl_config_def.hostname.clone(),
202            });
203
204        let wait_available_timeout = overwrites
205            .wait_available_timeout
206            .or(self.wait_available_timeout)
207            .or(defaults.wait_available_timeout)
208            .unwrap_or(std::time::Duration::from_secs(10));
209
210        Ok(Instance {
211            application_name: None,
212            backend,
213            database: pg_client::Database::POSTGRES,
214            seeds,
215            ssl_config,
216            superuser: pg_client::User::POSTGRES,
217            image,
218            cross_container_access: false,
219            wait_available_timeout,
220        })
221    }
222}
223
224#[derive(Debug, serde::Deserialize, PartialEq)]
225#[serde(deny_unknown_fields)]
226pub struct Config {
227    image: Option<Image>,
228    backend: Option<ociman::backend::Selection>,
229    ssl_config: Option<SslConfigDefinition>,
230    #[serde(default, with = "humantime_serde")]
231    wait_available_timeout: Option<std::time::Duration>,
232    instances: Option<std::collections::BTreeMap<InstanceName, InstanceDefinition>>,
233}
234
235impl std::default::Default for Config {
236    fn default() -> Self {
237        Self {
238            image: Some(Image::default()),
239            backend: None,
240            ssl_config: None,
241            wait_available_timeout: None,
242            instances: None,
243        }
244    }
245}
246
247impl Config {
248    pub fn load_toml_file(
249        file: impl AsRef<std::path::Path>,
250        overwrites: &InstanceDefinition,
251    ) -> Result<super::InstanceMap, Error> {
252        let file = file.as_ref();
253        let base_dir = file
254            .parent()
255            .map(std::path::Path::to_path_buf)
256            .unwrap_or_default();
257
258        std::fs::read_to_string(file)
259            .map_err(|error| Error::IO(error.into()))
260            .and_then(Self::load_toml)
261            .map(|config| config.resolve_paths(&base_dir))
262            .and_then(|config| config.instance_map(overwrites))
263    }
264
265    fn resolve_paths(mut self, base_dir: &std::path::Path) -> Self {
266        let resolve_path = |path: std::path::PathBuf| -> std::path::PathBuf {
267            if path.is_relative() {
268                base_dir.join(path)
269            } else {
270                path
271            }
272        };
273
274        // Resolve a command string if it looks like a relative file path (contains a
275        // path separator). Plain command names such as "sh" or "psql" are left alone
276        // so they continue to be resolved via PATH.
277        let resolve_command = |command: &mut String| {
278            let path = std::path::Path::new(command.as_str());
279            if path.is_relative() && path.components().count() > 1 {
280                // Strip leading CurDir (`.`) components so `./bin/foo` and `bin/foo`
281                // both produce the same absolute result after joining.
282                let stripped: std::path::PathBuf = path
283                    .components()
284                    .filter(|c| !matches!(c, std::path::Component::CurDir))
285                    .collect();
286                *command = base_dir.join(stripped).to_string_lossy().into_owned();
287            }
288        };
289
290        if let Some(instances) = self.instances.as_mut() {
291            for instance in instances.values_mut() {
292                for seed in instance.seeds.values_mut() {
293                    match seed {
294                        SeedConfig::SqlFile { path, .. } => *path = resolve_path(path.clone()),
295                        SeedConfig::Command { command, cache, .. } => {
296                            resolve_command(command);
297                            if let CommandCacheConfig::KeyCommand {
298                                command: key_command,
299                                ..
300                            } = cache
301                            {
302                                resolve_command(key_command);
303                            }
304                        }
305                        SeedConfig::CsvFile { path, .. } => *path = resolve_path(path.clone()),
306                        SeedConfig::Script { .. } | SeedConfig::ContainerScript { .. } => {}
307                    }
308                }
309            }
310        }
311
312        self
313    }
314
315    pub fn load_toml(contents: impl AsRef<str>) -> Result<Config, Error> {
316        toml::from_str(contents.as_ref()).map_err(Error::TomlDecode)
317    }
318
319    pub fn instance_map(
320        self,
321        overwrites: &InstanceDefinition,
322    ) -> Result<super::InstanceMap, Error> {
323        let defaults = InstanceDefinition {
324            backend: self.backend,
325            image: self.image.clone(),
326            seeds: indexmap::IndexMap::new(),
327            ssl_config: self.ssl_config.clone(),
328            wait_available_timeout: self.wait_available_timeout,
329        };
330
331        match self.instances {
332            None => {
333                let instance_name = InstanceName::default();
334
335                InstanceDefinition::empty()
336                    .into_instance(&instance_name, &defaults, overwrites)
337                    .map(|instance| [(instance_name, instance)].into())
338            }
339            Some(map) => {
340                let mut instance_map = std::collections::BTreeMap::new();
341
342                for (instance_name, instance_definition) in map {
343                    let instance =
344                        instance_definition.into_instance(&instance_name, &defaults, overwrites)?;
345
346                    instance_map.insert(instance_name, instance);
347                }
348
349                Ok(instance_map)
350            }
351        }
352    }
353}
354
#[cfg(test)]
mod test {
    use super::*;

    /// Writes `contents` as `database.toml` into the temp-dir subdirectory
    /// `dir_name` and loads it with no overwrites. Returns that directory and
    /// the seed named `seed_name` of instance `main`.
    fn load_main_seed(
        dir_name: &str,
        contents: &str,
        seed_name: &str,
    ) -> (std::path::PathBuf, crate::seed::Seed) {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();

        let config_path = dir.join("database.toml");
        std::fs::write(&config_path, contents).unwrap();

        let instance_map =
            Config::load_toml_file(&config_path, &InstanceDefinition::empty()).unwrap();

        let instance_name: crate::InstanceName = "main".parse().unwrap();
        let seed_name: crate::seed::SeedName = seed_name.parse().unwrap();
        let seed = instance_map.get(&instance_name).unwrap().seeds[&seed_name].clone();

        (dir, seed)
    }

    #[test]
    fn sql_file_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-sql-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "sql-file"
                path = "db/structure.sql"
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::SqlFile {
                path: dir.join("db/structure.sql"),
            }
        );
    }

    #[test]
    fn command_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.migrate]
                type = "command"
                command = "./bin/migrate"
                arguments = ["up"]
                cache = { type = "none" }
            "#},
            "migrate",
        );

        let expected_command =
            crate::seed::Command::new(dir.join("bin/migrate").to_string_lossy(), ["up"]);

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: expected_command,
                cache: crate::seed::CommandCacheConfig::None,
            }
        );
    }

    #[test]
    fn bare_command_name_not_resolved() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-bare-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "command"
                command = "psql"
                arguments = ["-f", "schema.sql"]
                cache = { type = "command-hash" }
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new("psql", ["-f", "schema.sql"]),
                cache: crate::seed::CommandCacheConfig::CommandHash,
            }
        );
    }

    #[test]
    fn container_script_parsed() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-container-script",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.install-ext]
                type = "container-script"
                script = "apt-get update && apt-get install -y postgresql-15-cron"
            "#},
            "install-ext",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::ContainerScript {
                script: "apt-get update && apt-get install -y postgresql-15-cron".to_string(),
            }
        );
    }

    #[test]
    fn csv_file_parsed() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-csv-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.users]
                type = "csv-file"
                path = "fixtures/users.csv"
                table = { schema = "public", table = "users" }
            "#},
            "users",
        );

        let expected_table = pg_client::QualifiedTable {
            schema: pg_client::identifier::Schema::PUBLIC,
            table: "users".parse().unwrap(),
        };

        assert_eq!(
            seed,
            crate::seed::Seed::CsvFile {
                path: dir.join("fixtures/users.csv"),
                table: expected_table,
            }
        );
    }
}