Skip to main content

pg_ephemeral/
config.rs

1use super::InstanceName;
2use crate::definition::{Definition, SslConfig};
3use crate::image::Image;
4use crate::seed::{Command, Seed, SeedCacheConfig, SeedName};
5
/// Outcome of loading or constructing a [`Config`]: a resolved backend
/// selection paired with the per-instance map.
///
/// Backend selection is a single global property — resolved once at startup
/// and shared across every instance — so it lives here alongside the map
/// rather than as a per-[`Instance`] field.
#[derive(Debug, PartialEq)]
pub struct Resolved {
    /// Backend to use for every instance; see [`Config::resolve`] for the
    /// precedence between the CLI, the TOML `backend` key and `Auto`.
    pub backend_selection: ociman::backend::Selection,
    /// Fully resolved instances, keyed by instance name.
    pub instances: super::InstanceMap,
}
17
/// A fully resolved PostgreSQL instance configuration.
///
/// Built either directly with [`Instance::new`] or from TOML via
/// [`Config::resolve`], and turned into a [`Definition`] with
/// [`Instance::definition`].
#[derive(Clone, Debug, PartialEq)]
pub struct Instance {
    /// Application name; `None` from [`Instance::new`] and from config resolution.
    pub application_name: Option<pg_client::config::ApplicationName>,
    /// Database name; `pg_client::Database::POSTGRES` by default.
    pub database: pg_client::Database,
    /// Server parameters taken from the instance definition.
    pub parameters: pg_client::parameter::Map,
    /// Seeds keyed by name, kept in insertion (declaration) order.
    pub seeds: indexmap::IndexMap<SeedName, Seed>,
    /// SSL setup, if any.
    pub ssl_config: Option<SslConfig>,
    /// Superuser; `pg_client::User::POSTGRES` by default.
    pub superuser: pg_client::User,
    /// Image passed through to the [`Definition`].
    pub image: Image,
    /// Passed through to the [`Definition`]; `false` by default.
    pub cross_container_access: bool,
    /// Availability wait timeout passed through to the [`Definition`];
    /// 10 seconds by default.
    pub wait_available_timeout: std::time::Duration,
}
30
31impl Instance {
32    #[must_use]
33    pub fn new(image: Image) -> Self {
34        Self {
35            application_name: None,
36            parameters: pg_client::parameter::Map::new(),
37            seeds: indexmap::IndexMap::new(),
38            ssl_config: None,
39            superuser: pg_client::User::POSTGRES,
40            database: pg_client::Database::POSTGRES,
41            image,
42            cross_container_access: false,
43            wait_available_timeout: std::time::Duration::from_secs(10),
44        }
45    }
46
47    #[must_use]
48    pub fn definition(
49        &self,
50        backend: ociman::Backend,
51        instance_name: &crate::InstanceName,
52    ) -> Definition {
53        Definition {
54            instance_name: instance_name.clone(),
55            application_name: self.application_name.clone(),
56            backend,
57            database: self.database.clone(),
58            parameters: self.parameters.clone(),
59            seeds: self.seeds.clone(),
60            ssl_config: self.ssl_config.clone(),
61            superuser: self.superuser.clone(),
62            image: self.image.clone(),
63            cross_container_access: self.cross_container_access,
64            wait_available_timeout: self.wait_available_timeout,
65            remove: true,
66            session_name: None,
67            transparent_workdir: None,
68        }
69    }
70}
71
/// Errors produced while loading or resolving a [`Config`].
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error {
    /// Reading the config file failed.
    ///
    /// Only the [`std::io::ErrorKind`] is kept (see [`IoError`]), which lets
    /// this enum derive `PartialEq`; `std::io::Error` itself is not `PartialEq`.
    #[error("Could not load config file: {0}")]
    IO(IoError),
    /// The contents are not a valid [`Config`] document: a TOML syntax error,
    /// an unknown key (the config types use `deny_unknown_fields`), or a
    /// value of the wrong shape.
    #[error("Decoding as toml failed: {0}")]
    TomlDecode(#[from] toml::de::Error),
    /// A required field was not set by the overwrites, the instance
    /// definition, or the top-level defaults. In this file only `image` is
    /// checked this way.
    #[error("Instance {instance_name} does not specify {field} and no default applies")]
    MissingInstanceField {
        instance_name: InstanceName,
        field: &'static str,
    },
}
84
/// Comparable stand-in for [`std::io::Error`] that keeps only its
/// [`std::io::ErrorKind`], so that [`Error`] can derive `PartialEq`.
#[derive(Debug, PartialEq)]
pub struct IoError(pub std::io::ErrorKind);

impl std::fmt::Display for IoError {
    /// Renders exactly what a `std::io::Error` built from the stored kind renders.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let IoError(kind) = *self;
        let error = std::io::Error::from(kind);
        write!(f, "{error}")
    }
}

impl std::error::Error for IoError {}

impl From<std::io::Error> for IoError {
    /// Keeps the error kind and discards any message or inner source.
    fn from(source: std::io::Error) -> Self {
        IoError(source.kind())
    }
}
101
/// A seed entry as written in the config file.
///
/// Encoded as an internally tagged table whose `type` key is the kebab-case
/// variant name, e.g. `type = "sql-file"` or `type = "container-script"`.
/// Unknown keys are rejected. Convert it to the runtime [`Seed`] with `From`.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case", deny_unknown_fields)]
pub enum SeedConfig {
    /// Seed from the SQL file at `path`.
    ///
    /// A relative `path` is resolved against the config file's directory
    /// when loaded via `Config::load_toml_file`. With `git_revision` set, it
    /// converts to [`Seed::SqlFileGitRevision`].
    SqlFile {
        path: std::path::PathBuf,
        git_revision: Option<String>,
    },
    /// Seed from a literal SQL `statement`.
    SqlStatement {
        statement: String,
    },
    /// Seed from running `command` with `arguments` (empty when omitted).
    ///
    /// `cache` is required. A `command` containing a path separator is
    /// resolved against the config file's directory when loaded via
    /// `Config::load_toml_file`; a bare name is left for `PATH` lookup.
    Command {
        command: String,
        #[serde(default)]
        arguments: Vec<String>,
        cache: SeedCacheConfig,
    },
    /// Seed from a `script` string.
    ///
    /// `cache` is optional and becomes [`SeedCacheConfig::CommandHash`] when
    /// converted to a [`Seed`].
    Script {
        script: String,
        #[serde(default)]
        cache: Option<SeedCacheConfig>,
    },
    /// A `script` carried through unchanged as [`Seed::ContainerScript`].
    /// NOTE(review): presumably executed inside the container; confirm in
    /// the seed module.
    ContainerScript {
        script: String,
    },
    /// Seed from the CSV file at `path` into `table`.
    ///
    /// A relative `path` is resolved like [`SeedConfig::SqlFile`]. `delimiter`
    /// becomes `,` when omitted and converted to a [`Seed`].
    CsvFile {
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
        delimiter: Option<char>,
    },
}
132
133impl From<SeedConfig> for Seed {
134    fn from(value: SeedConfig) -> Self {
135        match value {
136            SeedConfig::SqlFile { path, git_revision } => match git_revision {
137                Some(git_revision) => Seed::SqlFileGitRevision { git_revision, path },
138                None => Seed::SqlFile { path },
139            },
140            SeedConfig::SqlStatement { statement } => Seed::SqlStatement { statement },
141            SeedConfig::Command {
142                command,
143                arguments,
144                cache,
145            } => Seed::Command {
146                command: Command::new(command, arguments),
147                cache,
148            },
149            SeedConfig::Script { script, cache } => Seed::Script {
150                script,
151                cache: cache.unwrap_or(SeedCacheConfig::CommandHash),
152            },
153            SeedConfig::ContainerScript { script } => Seed::ContainerScript { script },
154            SeedConfig::CsvFile {
155                path,
156                table,
157                delimiter,
158            } => Seed::CsvFile {
159                path,
160                table,
161                delimiter: delimiter.unwrap_or(','),
162            },
163        }
164    }
165}
166
167impl From<&Seed> for SeedConfig {
168    fn from(value: &Seed) -> Self {
169        match value {
170            Seed::SqlFile { path } => SeedConfig::SqlFile {
171                path: path.clone(),
172                git_revision: None,
173            },
174            Seed::SqlFileGitRevision { git_revision, path } => SeedConfig::SqlFile {
175                path: path.clone(),
176                git_revision: Some(git_revision.clone()),
177            },
178            Seed::SqlStatement { statement } => SeedConfig::SqlStatement {
179                statement: statement.clone(),
180            },
181            Seed::Command { command, cache } => SeedConfig::Command {
182                command: command.command.clone(),
183                arguments: command.arguments.clone(),
184                cache: cache.clone(),
185            },
186            Seed::Script { script, cache } => SeedConfig::Script {
187                script: script.clone(),
188                cache: Some(cache.clone()),
189            },
190            Seed::ContainerScript { script } => SeedConfig::ContainerScript {
191                script: script.clone(),
192            },
193            Seed::CsvFile {
194                path,
195                table,
196                delimiter,
197            } => SeedConfig::CsvFile {
198                path: path.clone(),
199                table: table.clone(),
200                delimiter: Some(*delimiter),
201            },
202        }
203    }
204}
205
#[cfg(test)]
mod from_seed_tests {
    use super::*;

    /// Sends `config` through a [`Seed`] and back into a [`SeedConfig`].
    fn through_seed(config: SeedConfig) -> SeedConfig {
        let seed = Seed::from(config);
        SeedConfig::from(&seed)
    }

    /// Asserts that converting `config` to a [`Seed`] and back is lossless.
    fn round_trip(config: SeedConfig) {
        assert_eq!(through_seed(config.clone()), config);
    }

    /// The `public.t` table shared by the CSV cases.
    fn public_t() -> pg_client::QualifiedTable {
        pg_client::QualifiedTable {
            schema: pg_client::identifier::Schema::from_static_or_panic("public"),
            table: pg_client::identifier::Table::from_static_or_panic("t"),
        }
    }

    #[test]
    fn round_trip_sql_file_no_git() {
        round_trip(SeedConfig::SqlFile {
            path: "schema.sql".into(),
            git_revision: None,
        });
    }

    #[test]
    fn round_trip_sql_file_with_git() {
        round_trip(SeedConfig::SqlFile {
            path: "schema.sql".into(),
            git_revision: Some("abc1234".to_string()),
        });
    }

    #[test]
    fn round_trip_sql_statement() {
        round_trip(SeedConfig::SqlStatement {
            statement: "CREATE TABLE t (id INT)".to_string(),
        });
    }

    #[test]
    fn round_trip_command() {
        round_trip(SeedConfig::Command {
            command: "psql".to_string(),
            arguments: vec!["-c".to_string(), "SELECT 1".to_string()],
            cache: SeedCacheConfig::CommandHash,
        });
    }

    #[test]
    fn round_trip_script_with_explicit_cache() {
        round_trip(SeedConfig::Script {
            script: "psql -c 'SELECT 1'".to_string(),
            cache: Some(SeedCacheConfig::CommandHash),
        });
    }

    #[test]
    fn script_default_cache_is_recovered_explicitly() {
        let restored = through_seed(SeedConfig::Script {
            script: "x".to_string(),
            cache: None,
        });

        assert_eq!(
            restored,
            SeedConfig::Script {
                script: "x".to_string(),
                cache: Some(SeedCacheConfig::CommandHash),
            }
        );
    }

    #[test]
    fn round_trip_container_script() {
        round_trip(SeedConfig::ContainerScript {
            script: "apt-get install -y foo".to_string(),
        });
    }

    #[test]
    fn round_trip_csv_file_with_delimiter() {
        round_trip(SeedConfig::CsvFile {
            path: "data.csv".into(),
            table: public_t(),
            delimiter: Some(';'),
        });
    }

    #[test]
    fn csv_file_default_delimiter_is_recovered_explicitly() {
        let restored = through_seed(SeedConfig::CsvFile {
            path: "data.csv".into(),
            table: public_t(),
            delimiter: None,
        });

        assert_eq!(
            restored,
            SeedConfig::CsvFile {
                path: "data.csv".into(),
                table: public_t(),
                delimiter: Some(','),
            }
        );
    }
}
317
/// TOML form of an SSL setup; only the hostname is configurable.
///
/// During resolution it becomes [`SslConfig::Generated`] with this hostname.
#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct SslConfigDefinition {
    /// Hostname copied into [`SslConfig::Generated`].
    pub hostname: pg_client::config::HostName,
}
323
/// One `[instances.<name>]` table from the config file.
///
/// The same type also carries the `overwrites` argument and the top-level
/// defaults during resolution. `image`, `ssl_config` and
/// `wait_available_timeout` are merged across those three sources;
/// `parameters` and `seeds` are taken from the instance's own definition only.
#[derive(Debug, serde::Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct InstanceDefinition {
    /// Image for this instance; falls back to the top-level `image`.
    pub image: Option<Image>,
    /// Server parameters; empty when omitted.
    #[serde(default)]
    pub parameters: pg_client::parameter::Map,
    /// Seeds keyed by name, in the order they appear in the file; empty when omitted.
    #[serde(default)]
    pub seeds: indexmap::IndexMap<SeedName, SeedConfig>,
    /// SSL setup; falls back to the top-level `ssl_config`.
    pub ssl_config: Option<SslConfigDefinition>,
    /// Availability wait timeout in `humantime` format (e.g. `"30s"`); falls
    /// back to the top-level value, then to 10 seconds.
    #[serde(default, with = "humantime_serde")]
    pub wait_available_timeout: Option<std::time::Duration>,
}
336
337impl InstanceDefinition {
338    #[must_use]
339    pub fn empty() -> Self {
340        Self {
341            image: None,
342            parameters: pg_client::parameter::Map::new(),
343            seeds: indexmap::IndexMap::new(),
344            ssl_config: None,
345            wait_available_timeout: None,
346        }
347    }
348
349    fn into_instance(
350        self,
351        instance_name: &InstanceName,
352        defaults: &InstanceDefinition,
353        overwrites: &InstanceDefinition,
354    ) -> Result<Instance, Error> {
355        let image = match overwrites
356            .image
357            .as_ref()
358            .or(self.image.as_ref())
359            .or(defaults.image.as_ref())
360        {
361            Some(image) => image.clone(),
362            None => {
363                return Err(Error::MissingInstanceField {
364                    instance_name: instance_name.clone(),
365                    field: "image",
366                });
367            }
368        };
369
370        let seeds = self
371            .seeds
372            .into_iter()
373            .map(|(name, seed_config)| (name, seed_config.into()))
374            .collect();
375
376        let ssl_config = overwrites
377            .ssl_config
378            .as_ref()
379            .or(self.ssl_config.as_ref())
380            .or(defaults.ssl_config.as_ref())
381            .map(|ssl_config_def| SslConfig::Generated {
382                hostname: ssl_config_def.hostname.clone(),
383            });
384
385        let wait_available_timeout = overwrites
386            .wait_available_timeout
387            .or(self.wait_available_timeout)
388            .or(defaults.wait_available_timeout)
389            .unwrap_or(std::time::Duration::from_secs(10));
390
391        Ok(Instance {
392            application_name: None,
393            database: pg_client::Database::POSTGRES,
394            parameters: self.parameters,
395            seeds,
396            ssl_config,
397            superuser: pg_client::User::POSTGRES,
398            image,
399            cross_container_access: false,
400            wait_available_timeout,
401        })
402    }
403}
404
/// Top-level contents of a TOML config file.
///
/// `image`, `ssl_config` and `wait_available_timeout` act as defaults for
/// every instance; an instance's own value, or an overwrite, takes precedence.
/// `backend` selects the container backend globally. When `instances` is
/// absent, resolution produces a single instance named `InstanceName::default()`.
#[derive(Debug, serde::Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct Config {
    /// Default image for instances that do not set one.
    image: Option<Image>,
    /// Backend selection; a CLI overwrite wins, and `Auto` applies when unset.
    backend: Option<ociman::backend::Selection>,
    /// Default SSL setup for instances that do not set one.
    ssl_config: Option<SslConfigDefinition>,
    /// Default availability wait timeout in `humantime` format (e.g. `"30s"`).
    #[serde(default, with = "humantime_serde")]
    wait_available_timeout: Option<std::time::Duration>,
    /// Named instance definitions, one per `[instances.<name>]` table.
    instances: Option<std::collections::BTreeMap<InstanceName, InstanceDefinition>>,
}
415
416impl std::default::Default for Config {
417    fn default() -> Self {
418        Self {
419            image: Some(Image::default()),
420            backend: None,
421            ssl_config: None,
422            wait_available_timeout: None,
423            instances: None,
424        }
425    }
426}
427
428impl Config {
429    pub fn load_toml_file(
430        file: impl AsRef<std::path::Path>,
431        backend_overwrite: Option<ociman::backend::Selection>,
432        overwrites: &InstanceDefinition,
433    ) -> Result<Resolved, Error> {
434        let file = file.as_ref();
435        let base_dir = file
436            .parent()
437            .map(std::path::Path::to_path_buf)
438            .unwrap_or_default();
439
440        std::fs::read_to_string(file)
441            .map_err(|error| Error::IO(error.into()))
442            .and_then(Self::load_toml)
443            .map(|config| config.resolve_paths(&base_dir))
444            .and_then(|config| config.resolve(backend_overwrite, overwrites))
445    }
446
447    fn resolve_paths(mut self, base_dir: &std::path::Path) -> Self {
448        let resolve_path = |path: std::path::PathBuf| -> std::path::PathBuf {
449            if path.is_relative() {
450                base_dir.join(path)
451            } else {
452                path
453            }
454        };
455
456        // Resolve a command string if it looks like a relative file path (contains a
457        // path separator). Plain command names such as "sh" or "psql" are left alone
458        // so they continue to be resolved via PATH.
459        let resolve_command = |command: &mut String| {
460            let path = std::path::Path::new(command.as_str());
461            if path.is_relative() && path.components().count() > 1 {
462                // Strip leading CurDir (`.`) components so `./bin/foo` and `bin/foo`
463                // both produce the same absolute result after joining.
464                let stripped: std::path::PathBuf = path
465                    .components()
466                    .filter(|c| !matches!(c, std::path::Component::CurDir))
467                    .collect();
468                *command = base_dir.join(stripped).to_string_lossy().into_owned();
469            }
470        };
471
472        if let Some(instances) = self.instances.as_mut() {
473            for instance in instances.values_mut() {
474                for seed in instance.seeds.values_mut() {
475                    match seed {
476                        SeedConfig::SqlFile { path, .. } => *path = resolve_path(path.clone()),
477                        SeedConfig::Command { command, cache, .. } => {
478                            resolve_command(command);
479                            if let SeedCacheConfig::KeyCommand {
480                                command: key_command,
481                                ..
482                            } = cache
483                            {
484                                resolve_command(key_command);
485                            }
486                        }
487                        SeedConfig::Script { cache, .. } => {
488                            if let Some(SeedCacheConfig::KeyCommand {
489                                command: key_command,
490                                ..
491                            }) = cache
492                            {
493                                resolve_command(key_command);
494                            }
495                        }
496                        SeedConfig::CsvFile { path, .. } => *path = resolve_path(path.clone()),
497                        SeedConfig::ContainerScript { .. } | SeedConfig::SqlStatement { .. } => {}
498                    }
499                }
500            }
501        }
502
503        self
504    }
505
506    pub fn load_toml(contents: impl AsRef<str>) -> Result<Config, Error> {
507        toml::from_str(contents.as_ref()).map_err(Error::TomlDecode)
508    }
509
510    /// Resolve this config into a [`Resolved`] outcome, applying the
511    /// CLI-level `backend_overwrite` and per-instance `overwrites`.
512    ///
513    /// Backend selection precedence (highest first): CLI `--backend`,
514    /// `Config.backend` from TOML, [`ociman::backend::Selection::Auto`].
515    pub fn resolve(
516        self,
517        backend_overwrite: Option<ociman::backend::Selection>,
518        overwrites: &InstanceDefinition,
519    ) -> Result<Resolved, Error> {
520        let backend_selection = backend_overwrite
521            .or(self.backend)
522            .unwrap_or(ociman::backend::Selection::Auto);
523
524        let defaults = InstanceDefinition {
525            image: self.image.clone(),
526            parameters: pg_client::parameter::Map::new(),
527            seeds: indexmap::IndexMap::new(),
528            ssl_config: self.ssl_config.clone(),
529            wait_available_timeout: self.wait_available_timeout,
530        };
531
532        let instances = match self.instances {
533            None => {
534                let instance_name = InstanceName::default();
535
536                InstanceDefinition::empty()
537                    .into_instance(&instance_name, &defaults, overwrites)
538                    .map(|instance| [(instance_name, instance)].into())?
539            }
540            Some(map) => {
541                let mut instance_map = std::collections::BTreeMap::new();
542
543                for (instance_name, instance_definition) in map {
544                    let instance =
545                        instance_definition.into_instance(&instance_name, &defaults, overwrites)?;
546
547                    instance_map.insert(instance_name, instance);
548                }
549
550                instance_map
551            }
552        };
553
554        Ok(Resolved {
555            backend_selection,
556            instances,
557        })
558    }
559}
560
#[cfg(test)]
mod test {
    use super::*;

    /// Writes `contents` to `database.toml` in `dir_name` under the system temp
    /// dir and loads it with no backend or instance overwrites.
    ///
    /// Returns the directory (for building expected absolute paths) together
    /// with the resolved config.
    fn load_fixture(dir_name: &str, contents: &str) -> (std::path::PathBuf, Resolved) {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();

        let config_path = dir.join("database.toml");
        std::fs::write(&config_path, contents).unwrap();

        let resolved =
            Config::load_toml_file(&config_path, None, &InstanceDefinition::empty()).unwrap();

        (dir, resolved)
    }

    /// Looks up seed `seed_name` of the `main` instance, panicking if it is absent.
    fn main_seed<'a>(resolved: &'a Resolved, seed_name: &str) -> &'a crate::seed::Seed {
        let instance_name: crate::InstanceName = "main".parse().unwrap();
        let seed_name: crate::seed::SeedName = seed_name.parse().unwrap();

        &resolved.instances.get(&instance_name).unwrap().seeds[&seed_name]
    }

    #[test]
    fn sql_file_path_resolved_relative_to_config() {
        let (dir, resolved) = load_fixture(
            "pg-ephemeral-config-test-sql-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "sql-file"
                path = "db/structure.sql"
            "#},
        );

        assert_eq!(
            main_seed(&resolved, "schema"),
            &crate::seed::Seed::SqlFile {
                path: dir.join("db/structure.sql"),
            }
        );
    }

    #[test]
    fn command_path_resolved_relative_to_config() {
        let (dir, resolved) = load_fixture(
            "pg-ephemeral-config-test-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.migrate]
                type = "command"
                command = "./bin/migrate"
                arguments = ["up"]
                cache = { type = "none" }
            "#},
        );

        assert_eq!(
            main_seed(&resolved, "migrate"),
            &crate::seed::Seed::Command {
                command: crate::seed::Command::new(
                    dir.join("bin/migrate").to_string_lossy(),
                    ["up"],
                ),
                cache: crate::seed::SeedCacheConfig::None,
            }
        );
    }

    #[test]
    fn bare_command_name_not_resolved() {
        let (_dir, resolved) = load_fixture(
            "pg-ephemeral-config-test-bare-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "command"
                command = "psql"
                arguments = ["-f", "schema.sql"]
                cache = { type = "command-hash" }
            "#},
        );

        assert_eq!(
            main_seed(&resolved, "schema"),
            &crate::seed::Seed::Command {
                command: crate::seed::Command::new("psql", ["-f", "schema.sql"]),
                cache: crate::seed::SeedCacheConfig::CommandHash,
            }
        );
    }

    #[test]
    fn container_script_parsed() {
        let (_dir, resolved) = load_fixture(
            "pg-ephemeral-config-test-container-script",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.install-ext]
                type = "container-script"
                script = "apt-get update && apt-get install -y postgresql-15-cron"
            "#},
        );

        assert_eq!(
            main_seed(&resolved, "install-ext"),
            &crate::seed::Seed::ContainerScript {
                script: "apt-get update && apt-get install -y postgresql-15-cron".to_string(),
            }
        );
    }

    #[test]
    fn sql_statement_parsed() {
        let (_dir, resolved) = load_fixture(
            "pg-ephemeral-config-test-sql-statement",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.create-users]
                type = "sql-statement"
                statement = "CREATE TABLE users (id INT)"
            "#},
        );

        assert_eq!(
            main_seed(&resolved, "create-users"),
            &crate::seed::Seed::SqlStatement {
                statement: "CREATE TABLE users (id INT)".to_string(),
            }
        );
    }

    #[test]
    fn csv_file_parsed() {
        let (dir, resolved) = load_fixture(
            "pg-ephemeral-config-test-csv-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.users]
                type = "csv-file"
                path = "fixtures/users.csv"
                table = { schema = "public", table = "users" }
            "#},
        );

        assert_eq!(
            main_seed(&resolved, "users"),
            &crate::seed::Seed::CsvFile {
                path: dir.join("fixtures/users.csv"),
                table: pg_client::QualifiedTable {
                    schema: pg_client::identifier::Schema::PUBLIC,
                    table: "users".parse().unwrap(),
                },
                delimiter: ',',
            }
        );
    }
}