Skip to main content

pg_ephemeral/
config.rs

1use super::InstanceName;
2use crate::definition::{Definition, SslConfig};
3use crate::image::Image;
4use crate::seed::{Command, CommandCacheConfig, Seed, SeedName};
5
6#[derive(Clone, Debug, PartialEq)]
7pub struct Instance {
8    pub application_name: Option<pg_client::config::ApplicationName>,
9    pub backend: ociman::backend::Selection,
10    pub database: pg_client::Database,
11    pub seeds: indexmap::IndexMap<SeedName, Seed>,
12    pub ssl_config: Option<SslConfig>,
13    pub superuser: pg_client::User,
14    pub image: Image,
15    pub cross_container_access: bool,
16    pub wait_available_timeout: std::time::Duration,
17}
18
19impl Instance {
20    #[must_use]
21    pub fn new(backend: ociman::backend::Selection, image: Image) -> Self {
22        Self {
23            backend,
24            application_name: None,
25            seeds: indexmap::IndexMap::new(),
26            ssl_config: None,
27            superuser: pg_client::User::POSTGRES,
28            database: pg_client::Database::POSTGRES,
29            image,
30            cross_container_access: false,
31            wait_available_timeout: std::time::Duration::from_secs(10),
32        }
33    }
34
35    pub async fn definition(
36        &self,
37        instance_name: &crate::InstanceName,
38    ) -> Result<Definition, ociman::backend::resolve::Error> {
39        Ok(Definition {
40            instance_name: instance_name.clone(),
41            application_name: self.application_name.clone(),
42            backend: self.backend.resolve().await?,
43            database: self.database.clone(),
44            seeds: self.seeds.clone(),
45            ssl_config: self.ssl_config.clone(),
46            superuser: self.superuser.clone(),
47            image: self.image.clone(),
48            cross_container_access: self.cross_container_access,
49            wait_available_timeout: self.wait_available_timeout,
50            remove: true,
51        })
52    }
53}
54
55#[derive(Debug, thiserror::Error, PartialEq)]
56pub enum Error {
57    #[error("Could not load config file: {0}")]
58    IO(IoError),
59    #[error("Decoding as toml failed: {0}")]
60    TomlDecode(#[from] toml::de::Error),
61    #[error("Instance {instance_name} does not specify {field} and no default applies")]
62    MissingInstanceField {
63        instance_name: InstanceName,
64        field: &'static str,
65    },
66}
67
/// Comparable wrapper around an I/O failure.
///
/// `std::io::Error` does not implement `PartialEq`. Keeping only its
/// [`std::io::ErrorKind`] lets this type derive it.
#[derive(PartialEq, Debug)]
pub struct IoError(pub std::io::ErrorKind);

impl std::fmt::Display for IoError {
    /// Renders the standard-library description of the wrapped error kind.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let error = std::io::Error::from(self.0);
        write!(f, "{}", error)
    }
}

impl std::error::Error for IoError {}

impl From<std::io::Error> for IoError {
    /// Keeps only the kind of `error`; its message and source are dropped.
    fn from(error: std::io::Error) -> Self {
        let kind = error.kind();
        IoError(kind)
    }
}
84
85#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
86#[serde(tag = "type", rename_all = "kebab-case")]
87pub enum SeedConfig {
88    SqlFile {
89        path: std::path::PathBuf,
90        git_revision: Option<String>,
91    },
92    Command {
93        command: String,
94        #[serde(default)]
95        arguments: Vec<String>,
96        cache: CommandCacheConfig,
97    },
98    Script {
99        script: String,
100    },
101    ContainerScript {
102        script: String,
103    },
104    CsvFile {
105        path: std::path::PathBuf,
106        table: pg_client::QualifiedTable,
107        delimiter: Option<char>,
108    },
109}
110
111impl From<SeedConfig> for Seed {
112    fn from(value: SeedConfig) -> Self {
113        match value {
114            SeedConfig::SqlFile { path, git_revision } => match git_revision {
115                Some(git_revision) => Seed::SqlFileGitRevision { git_revision, path },
116                None => Seed::SqlFile { path },
117            },
118            SeedConfig::Command {
119                command,
120                arguments,
121                cache,
122            } => Seed::Command {
123                command: Command::new(command, arguments),
124                cache,
125            },
126            SeedConfig::Script { script } => Seed::Script { script },
127            SeedConfig::ContainerScript { script } => Seed::ContainerScript { script },
128            SeedConfig::CsvFile {
129                path,
130                table,
131                delimiter,
132            } => Seed::CsvFile {
133                path,
134                table,
135                delimiter: delimiter.unwrap_or(','),
136            },
137        }
138    }
139}
140
141#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
142#[serde(deny_unknown_fields)]
143pub struct SslConfigDefinition {
144    pub hostname: pg_client::config::HostName,
145}
146
147#[derive(Debug, serde::Deserialize, PartialEq)]
148#[serde(deny_unknown_fields)]
149pub struct InstanceDefinition {
150    pub backend: Option<ociman::backend::Selection>,
151    pub image: Option<Image>,
152    #[serde(default)]
153    pub seeds: indexmap::IndexMap<SeedName, SeedConfig>,
154    pub ssl_config: Option<SslConfigDefinition>,
155    #[serde(default, with = "humantime_serde")]
156    pub wait_available_timeout: Option<std::time::Duration>,
157}
158
159impl InstanceDefinition {
160    #[must_use]
161    pub fn empty() -> Self {
162        Self {
163            backend: None,
164            image: None,
165            seeds: indexmap::IndexMap::new(),
166            ssl_config: None,
167            wait_available_timeout: None,
168        }
169    }
170
171    fn into_instance(
172        self,
173        instance_name: &InstanceName,
174        defaults: &InstanceDefinition,
175        overwrites: &InstanceDefinition,
176    ) -> Result<Instance, Error> {
177        let image = match overwrites
178            .image
179            .as_ref()
180            .or(self.image.as_ref())
181            .or(defaults.image.as_ref())
182        {
183            Some(image) => image.clone(),
184            None => {
185                return Err(Error::MissingInstanceField {
186                    instance_name: instance_name.clone(),
187                    field: "image",
188                });
189            }
190        };
191
192        let backend = overwrites
193            .backend
194            .or(self.backend)
195            .or(defaults.backend)
196            .unwrap_or(ociman::backend::Selection::Auto);
197
198        let seeds = self
199            .seeds
200            .into_iter()
201            .map(|(name, seed_config)| (name, seed_config.into()))
202            .collect();
203
204        let ssl_config = overwrites
205            .ssl_config
206            .as_ref()
207            .or(self.ssl_config.as_ref())
208            .or(defaults.ssl_config.as_ref())
209            .map(|ssl_config_def| SslConfig::Generated {
210                hostname: ssl_config_def.hostname.clone(),
211            });
212
213        let wait_available_timeout = overwrites
214            .wait_available_timeout
215            .or(self.wait_available_timeout)
216            .or(defaults.wait_available_timeout)
217            .unwrap_or(std::time::Duration::from_secs(10));
218
219        Ok(Instance {
220            application_name: None,
221            backend,
222            database: pg_client::Database::POSTGRES,
223            seeds,
224            ssl_config,
225            superuser: pg_client::User::POSTGRES,
226            image,
227            cross_container_access: false,
228            wait_available_timeout,
229        })
230    }
231}
232
233#[derive(Debug, serde::Deserialize, PartialEq)]
234#[serde(deny_unknown_fields)]
235pub struct Config {
236    image: Option<Image>,
237    backend: Option<ociman::backend::Selection>,
238    ssl_config: Option<SslConfigDefinition>,
239    #[serde(default, with = "humantime_serde")]
240    wait_available_timeout: Option<std::time::Duration>,
241    instances: Option<std::collections::BTreeMap<InstanceName, InstanceDefinition>>,
242}
243
244impl std::default::Default for Config {
245    fn default() -> Self {
246        Self {
247            image: Some(Image::default()),
248            backend: None,
249            ssl_config: None,
250            wait_available_timeout: None,
251            instances: None,
252        }
253    }
254}
255
256impl Config {
257    pub fn load_toml_file(
258        file: impl AsRef<std::path::Path>,
259        overwrites: &InstanceDefinition,
260    ) -> Result<super::InstanceMap, Error> {
261        let file = file.as_ref();
262        let base_dir = file
263            .parent()
264            .map(std::path::Path::to_path_buf)
265            .unwrap_or_default();
266
267        std::fs::read_to_string(file)
268            .map_err(|error| Error::IO(error.into()))
269            .and_then(Self::load_toml)
270            .map(|config| config.resolve_paths(&base_dir))
271            .and_then(|config| config.instance_map(overwrites))
272    }
273
274    fn resolve_paths(mut self, base_dir: &std::path::Path) -> Self {
275        let resolve_path = |path: std::path::PathBuf| -> std::path::PathBuf {
276            if path.is_relative() {
277                base_dir.join(path)
278            } else {
279                path
280            }
281        };
282
283        // Resolve a command string if it looks like a relative file path (contains a
284        // path separator). Plain command names such as "sh" or "psql" are left alone
285        // so they continue to be resolved via PATH.
286        let resolve_command = |command: &mut String| {
287            let path = std::path::Path::new(command.as_str());
288            if path.is_relative() && path.components().count() > 1 {
289                // Strip leading CurDir (`.`) components so `./bin/foo` and `bin/foo`
290                // both produce the same absolute result after joining.
291                let stripped: std::path::PathBuf = path
292                    .components()
293                    .filter(|c| !matches!(c, std::path::Component::CurDir))
294                    .collect();
295                *command = base_dir.join(stripped).to_string_lossy().into_owned();
296            }
297        };
298
299        if let Some(instances) = self.instances.as_mut() {
300            for instance in instances.values_mut() {
301                for seed in instance.seeds.values_mut() {
302                    match seed {
303                        SeedConfig::SqlFile { path, .. } => *path = resolve_path(path.clone()),
304                        SeedConfig::Command { command, cache, .. } => {
305                            resolve_command(command);
306                            if let CommandCacheConfig::KeyCommand {
307                                command: key_command,
308                                ..
309                            } = cache
310                            {
311                                resolve_command(key_command);
312                            }
313                        }
314                        SeedConfig::CsvFile { path, .. } => *path = resolve_path(path.clone()),
315                        SeedConfig::Script { .. } | SeedConfig::ContainerScript { .. } => {}
316                    }
317                }
318            }
319        }
320
321        self
322    }
323
324    pub fn load_toml(contents: impl AsRef<str>) -> Result<Config, Error> {
325        toml::from_str(contents.as_ref()).map_err(Error::TomlDecode)
326    }
327
328    pub fn instance_map(
329        self,
330        overwrites: &InstanceDefinition,
331    ) -> Result<super::InstanceMap, Error> {
332        let defaults = InstanceDefinition {
333            backend: self.backend,
334            image: self.image.clone(),
335            seeds: indexmap::IndexMap::new(),
336            ssl_config: self.ssl_config.clone(),
337            wait_available_timeout: self.wait_available_timeout,
338        };
339
340        match self.instances {
341            None => {
342                let instance_name = InstanceName::default();
343
344                InstanceDefinition::empty()
345                    .into_instance(&instance_name, &defaults, overwrites)
346                    .map(|instance| [(instance_name, instance)].into())
347            }
348            Some(map) => {
349                let mut instance_map = std::collections::BTreeMap::new();
350
351                for (instance_name, instance_definition) in map {
352                    let instance =
353                        instance_definition.into_instance(&instance_name, &defaults, overwrites)?;
354
355                    instance_map.insert(instance_name, instance);
356                }
357
358                Ok(instance_map)
359            }
360        }
361    }
362}
363
#[cfg(test)]
mod test {
    use super::*;

    /// Writes `contents` to `database.toml` in the temp-dir subdirectory
    /// `dir_name` and loads it without overwrites.
    ///
    /// Returns that directory together with the seed named `seed_name` of the
    /// instance `main`.
    fn load_main_seed(
        dir_name: &str,
        contents: &str,
        seed_name: &str,
    ) -> (std::path::PathBuf, crate::seed::Seed) {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();
        let config_path = dir.join("database.toml");
        std::fs::write(&config_path, contents).unwrap();

        let instance_map =
            Config::load_toml_file(&config_path, &InstanceDefinition::empty()).unwrap();

        let instance_name: crate::InstanceName = "main".parse().unwrap();
        let instance = instance_map.get(&instance_name).unwrap();
        let seed_name: crate::seed::SeedName = seed_name.parse().unwrap();

        (dir, instance.seeds[&seed_name].clone())
    }

    #[test]
    fn sql_file_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-sql-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "sql-file"
                path = "db/structure.sql"
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::SqlFile {
                path: dir.join("db/structure.sql"),
            }
        );
    }

    #[test]
    fn command_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.migrate]
                type = "command"
                command = "./bin/migrate"
                arguments = ["up"]
                cache = { type = "none" }
            "#},
            "migrate",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new(
                    dir.join("bin/migrate").to_string_lossy(),
                    ["up"],
                ),
                cache: crate::seed::CommandCacheConfig::None,
            }
        );
    }

    #[test]
    fn bare_command_name_not_resolved() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-bare-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "command"
                command = "psql"
                arguments = ["-f", "schema.sql"]
                cache = { type = "command-hash" }
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new("psql", ["-f", "schema.sql"]),
                cache: crate::seed::CommandCacheConfig::CommandHash,
            }
        );
    }

    #[test]
    fn container_script_parsed() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-container-script",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.install-ext]
                type = "container-script"
                script = "apt-get update && apt-get install -y postgresql-15-cron"
            "#},
            "install-ext",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::ContainerScript {
                script: "apt-get update && apt-get install -y postgresql-15-cron".to_string(),
            }
        );
    }

    #[test]
    fn csv_file_parsed() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-csv-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.users]
                type = "csv-file"
                path = "fixtures/users.csv"
                table = { schema = "public", table = "users" }
            "#},
            "users",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::CsvFile {
                path: dir.join("fixtures/users.csv"),
                table: pg_client::QualifiedTable {
                    schema: pg_client::identifier::Schema::PUBLIC,
                    table: "users".parse().unwrap(),
                },
                delimiter: ',',
            }
        );
    }
}