1use super::InstanceName;
2use crate::definition::{Definition, SslConfig};
3use crate::image::Image;
4use crate::seed::{Command, Seed, SeedCacheConfig, SeedName};
5
/// Outcome of [`Config::resolve`]: the container backend to use and every resolved instance.
#[derive(Debug, PartialEq)]
pub struct Resolved {
    /// Backend to use: the explicit overwrite if given, else the config's `backend`,
    /// else `Selection::Auto`.
    pub backend_selection: ociman::backend::Selection,
    /// Fully resolved instances keyed by instance name.
    pub instances: super::InstanceMap,
}
17
/// A fully resolved instance, ready to be turned into a [`Definition`] via
/// [`Instance::definition`].
#[derive(Clone, Debug, PartialEq)]
pub struct Instance {
    /// Optional application name; `None` for instances built by `new` or resolved from config.
    pub application_name: Option<pg_client::config::ApplicationName>,
    /// Database name; `Database::POSTGRES` unless changed.
    pub database: pg_client::Database,
    /// Parameters carried over from the instance definition.
    pub parameters: pg_client::parameter::Map,
    /// Seeds keyed by name; `IndexMap` keeps them in insertion order.
    pub seeds: indexmap::IndexMap<SeedName, Seed>,
    /// SSL configuration, if any.
    pub ssl_config: Option<SslConfig>,
    /// Superuser; `User::POSTGRES` unless changed.
    pub superuser: pg_client::User,
    /// Container image to run.
    pub image: Image,
    /// Cross-container access flag; `false` unless changed.
    /// NOTE(review): presumably lets other containers reach the instance — confirm.
    pub cross_container_access: bool,
    /// How long to wait for the instance to become available; 10 seconds unless changed.
    pub wait_available_timeout: std::time::Duration,
}
30
31impl Instance {
32 #[must_use]
33 pub fn new(image: Image) -> Self {
34 Self {
35 application_name: None,
36 parameters: pg_client::parameter::Map::new(),
37 seeds: indexmap::IndexMap::new(),
38 ssl_config: None,
39 superuser: pg_client::User::POSTGRES,
40 database: pg_client::Database::POSTGRES,
41 image,
42 cross_container_access: false,
43 wait_available_timeout: std::time::Duration::from_secs(10),
44 }
45 }
46
47 #[must_use]
48 pub fn definition(
49 &self,
50 backend: ociman::Backend,
51 instance_name: &crate::InstanceName,
52 ) -> Definition {
53 Definition {
54 instance_name: instance_name.clone(),
55 application_name: self.application_name.clone(),
56 backend,
57 database: self.database.clone(),
58 parameters: self.parameters.clone(),
59 seeds: self.seeds.clone(),
60 ssl_config: self.ssl_config.clone(),
61 superuser: self.superuser.clone(),
62 image: self.image.clone(),
63 cross_container_access: self.cross_container_access,
64 wait_available_timeout: self.wait_available_timeout,
65 remove: true,
66 session_name: None,
67 transparent_workdir: None,
68 }
69 }
70}
71
/// Errors produced while loading or resolving a configuration.
#[derive(Debug, thiserror::Error, PartialEq)]
pub enum Error {
    /// The config file could not be read. Only the I/O error kind is kept (see [`IoError`]).
    #[error("Could not load config file: {0}")]
    IO(IoError),
    /// The contents could not be decoded into a [`Config`]: either invalid TOML, or a
    /// structure the config types reject (for example unknown fields).
    #[error("Decoding as toml failed: {0}")]
    TomlDecode(#[from] toml::de::Error),
    /// A required instance field was set neither by the overwrites, nor by the instance
    /// itself, nor by the top-level defaults.
    #[error("Instance {instance_name} does not specify {field} and no default applies")]
    MissingInstanceField {
        /// Instance that could not be resolved.
        instance_name: InstanceName,
        /// Name of the missing field, e.g. `"image"`.
        field: &'static str,
    },
}
84
/// Comparable stand-in for [`std::io::Error`], which does not implement `PartialEq`.
///
/// Only the [`std::io::ErrorKind`] survives conversion. Any OS error code or custom message
/// carried by the source error is discarded.
#[derive(Debug, PartialEq)]
pub struct IoError(pub std::io::ErrorKind);

impl std::fmt::Display for IoError {
    /// Renders the standard description of the wrapped kind, exactly as an `io::Error`
    /// built from that kind would.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let equivalent = std::io::Error::from(self.0);
        write!(f, "{equivalent}")
    }
}

impl std::error::Error for IoError {}

impl From<std::io::Error> for IoError {
    /// Keeps only the kind of `error`.
    fn from(error: std::io::Error) -> Self {
        let kind = error.kind();
        IoError(kind)
    }
}
101
/// A seed as written in the configuration file.
///
/// The `type` key selects the variant by its kebab-case name (e.g. `type = "sql-file"`).
/// Unknown keys are rejected. Entries are converted into a runtime [`Seed`] via
/// `From<SeedConfig>`.
#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case", deny_unknown_fields)]
pub enum SeedConfig {
    /// SQL read from a file.
    ///
    /// When the config is loaded via `Config::load_toml_file`, a relative `path` is joined onto
    /// the config file's directory. Setting `git_revision` converts to
    /// [`Seed::SqlFileGitRevision`]; otherwise the result is [`Seed::SqlFile`].
    SqlFile {
        path: std::path::PathBuf,
        git_revision: Option<String>,
    },
    /// An inline SQL statement.
    SqlStatement {
        statement: String,
    },
    /// An external command with its `arguments` (empty when omitted).
    ///
    /// `cache` is required. When loaded via `Config::load_toml_file`, relative command paths
    /// with more than one component (e.g. `./bin/migrate`) are rewritten relative to the config
    /// file's directory.
    Command {
        command: String,
        #[serde(default)]
        arguments: Vec<String>,
        cache: SeedCacheConfig,
    },
    /// A script; `cache` defaults to [`SeedCacheConfig::CommandHash`] when omitted.
    Script {
        script: String,
        #[serde(default)]
        cache: Option<SeedCacheConfig>,
    },
    /// A script converted into [`Seed::ContainerScript`].
    /// NOTE(review): presumably executed inside the database container — confirm.
    ContainerScript {
        script: String,
    },
    /// A CSV file loaded into `table`; `delimiter` defaults to `,` when omitted.
    ///
    /// A relative `path` is joined onto the config file's directory when loaded via
    /// `Config::load_toml_file`.
    CsvFile {
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
        delimiter: Option<char>,
    },
}
132
133impl From<SeedConfig> for Seed {
134 fn from(value: SeedConfig) -> Self {
135 match value {
136 SeedConfig::SqlFile { path, git_revision } => match git_revision {
137 Some(git_revision) => Seed::SqlFileGitRevision { git_revision, path },
138 None => Seed::SqlFile { path },
139 },
140 SeedConfig::SqlStatement { statement } => Seed::SqlStatement { statement },
141 SeedConfig::Command {
142 command,
143 arguments,
144 cache,
145 } => Seed::Command {
146 command: Command::new(command, arguments),
147 cache,
148 },
149 SeedConfig::Script { script, cache } => Seed::Script {
150 script,
151 cache: cache.unwrap_or(SeedCacheConfig::CommandHash),
152 },
153 SeedConfig::ContainerScript { script } => Seed::ContainerScript { script },
154 SeedConfig::CsvFile {
155 path,
156 table,
157 delimiter,
158 } => Seed::CsvFile {
159 path,
160 table,
161 delimiter: delimiter.unwrap_or(','),
162 },
163 }
164 }
165}
166
167impl From<&Seed> for SeedConfig {
168 fn from(value: &Seed) -> Self {
169 match value {
170 Seed::SqlFile { path } => SeedConfig::SqlFile {
171 path: path.clone(),
172 git_revision: None,
173 },
174 Seed::SqlFileGitRevision { git_revision, path } => SeedConfig::SqlFile {
175 path: path.clone(),
176 git_revision: Some(git_revision.clone()),
177 },
178 Seed::SqlStatement { statement } => SeedConfig::SqlStatement {
179 statement: statement.clone(),
180 },
181 Seed::Command { command, cache } => SeedConfig::Command {
182 command: command.command.clone(),
183 arguments: command.arguments.clone(),
184 cache: cache.clone(),
185 },
186 Seed::Script { script, cache } => SeedConfig::Script {
187 script: script.clone(),
188 cache: Some(cache.clone()),
189 },
190 Seed::ContainerScript { script } => SeedConfig::ContainerScript {
191 script: script.clone(),
192 },
193 Seed::CsvFile {
194 path,
195 table,
196 delimiter,
197 } => SeedConfig::CsvFile {
198 path: path.clone(),
199 table: table.clone(),
200 delimiter: Some(*delimiter),
201 },
202 }
203 }
204}
205
#[cfg(test)]
mod from_seed_tests {
    use super::*;

    /// Converts `config` into a [`Seed`] and back again.
    fn through_seed(config: SeedConfig) -> SeedConfig {
        let seed = Seed::from(config);
        SeedConfig::from(&seed)
    }

    /// Asserts that `config` survives the conversion to [`Seed`] and back unchanged.
    fn round_trip(config: SeedConfig) {
        let restored = through_seed(config.clone());
        assert_eq!(restored, config);
    }

    /// The `public.t` table used by the CSV tests.
    fn public_t() -> pg_client::QualifiedTable {
        pg_client::QualifiedTable {
            schema: pg_client::identifier::Schema::from_static_or_panic("public"),
            table: pg_client::identifier::Table::from_static_or_panic("t"),
        }
    }

    #[test]
    fn round_trip_sql_file_no_git() {
        round_trip(SeedConfig::SqlFile {
            path: "schema.sql".into(),
            git_revision: None,
        });
    }

    #[test]
    fn round_trip_sql_file_with_git() {
        round_trip(SeedConfig::SqlFile {
            path: "schema.sql".into(),
            git_revision: Some("abc1234".to_string()),
        });
    }

    #[test]
    fn round_trip_sql_statement() {
        round_trip(SeedConfig::SqlStatement {
            statement: "CREATE TABLE t (id INT)".to_string(),
        });
    }

    #[test]
    fn round_trip_command() {
        round_trip(SeedConfig::Command {
            command: "psql".to_string(),
            arguments: vec!["-c".to_string(), "SELECT 1".to_string()],
            cache: SeedCacheConfig::CommandHash,
        });
    }

    #[test]
    fn round_trip_script_with_explicit_cache() {
        round_trip(SeedConfig::Script {
            script: "psql -c 'SELECT 1'".to_string(),
            cache: Some(SeedCacheConfig::CommandHash),
        });
    }

    #[test]
    fn script_default_cache_is_recovered_explicitly() {
        let restored = through_seed(SeedConfig::Script {
            script: "x".to_string(),
            cache: None,
        });

        assert_eq!(
            restored,
            SeedConfig::Script {
                script: "x".to_string(),
                cache: Some(SeedCacheConfig::CommandHash),
            }
        );
    }

    #[test]
    fn round_trip_container_script() {
        round_trip(SeedConfig::ContainerScript {
            script: "apt-get install -y foo".to_string(),
        });
    }

    #[test]
    fn round_trip_csv_file_with_delimiter() {
        round_trip(SeedConfig::CsvFile {
            path: "data.csv".into(),
            table: public_t(),
            delimiter: Some(';'),
        });
    }

    #[test]
    fn csv_file_default_delimiter_is_recovered_explicitly() {
        let restored = through_seed(SeedConfig::CsvFile {
            path: "data.csv".into(),
            table: public_t(),
            delimiter: None,
        });

        assert_eq!(
            restored,
            SeedConfig::CsvFile {
                path: "data.csv".into(),
                table: public_t(),
                delimiter: Some(','),
            }
        );
    }
}
317
/// SSL settings as written in the config file.
///
/// During instance resolution this becomes `SslConfig::Generated` with the same `hostname`.
#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct SslConfigDefinition {
    /// Hostname handed to the generated SSL configuration.
    pub hostname: pg_client::config::HostName,
}
323
/// Settings for one instance, as written under `[instances.<name>]` in the config file.
///
/// The same shape is reused for the top-level defaults and for caller-supplied overwrites
/// when resolving instances.
#[derive(Debug, serde::Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct InstanceDefinition {
    /// Container image; when unset, the overwrite or top-level default is used.
    pub image: Option<Image>,
    /// Parameters passed through to the resolved instance; empty when omitted.
    #[serde(default)]
    pub parameters: pg_client::parameter::Map,
    /// Seeds keyed by name, kept in file order; empty when omitted.
    #[serde(default)]
    pub seeds: indexmap::IndexMap<SeedName, SeedConfig>,
    /// SSL settings; when unset, the overwrite or top-level default is used.
    pub ssl_config: Option<SslConfigDefinition>,
    /// Availability timeout, written as a human-readable duration (parsed by
    /// `humantime_serde`). When unset, the overwrite, then the top-level value, then 10 seconds
    /// is used.
    #[serde(default, with = "humantime_serde")]
    pub wait_available_timeout: Option<std::time::Duration>,
}
336
337impl InstanceDefinition {
338 #[must_use]
339 pub fn empty() -> Self {
340 Self {
341 image: None,
342 parameters: pg_client::parameter::Map::new(),
343 seeds: indexmap::IndexMap::new(),
344 ssl_config: None,
345 wait_available_timeout: None,
346 }
347 }
348
349 fn into_instance(
350 self,
351 instance_name: &InstanceName,
352 defaults: &InstanceDefinition,
353 overwrites: &InstanceDefinition,
354 ) -> Result<Instance, Error> {
355 let image = match overwrites
356 .image
357 .as_ref()
358 .or(self.image.as_ref())
359 .or(defaults.image.as_ref())
360 {
361 Some(image) => image.clone(),
362 None => {
363 return Err(Error::MissingInstanceField {
364 instance_name: instance_name.clone(),
365 field: "image",
366 });
367 }
368 };
369
370 let seeds = self
371 .seeds
372 .into_iter()
373 .map(|(name, seed_config)| (name, seed_config.into()))
374 .collect();
375
376 let ssl_config = overwrites
377 .ssl_config
378 .as_ref()
379 .or(self.ssl_config.as_ref())
380 .or(defaults.ssl_config.as_ref())
381 .map(|ssl_config_def| SslConfig::Generated {
382 hostname: ssl_config_def.hostname.clone(),
383 });
384
385 let wait_available_timeout = overwrites
386 .wait_available_timeout
387 .or(self.wait_available_timeout)
388 .or(defaults.wait_available_timeout)
389 .unwrap_or(std::time::Duration::from_secs(10));
390
391 Ok(Instance {
392 application_name: None,
393 database: pg_client::Database::POSTGRES,
394 parameters: self.parameters,
395 seeds,
396 ssl_config,
397 superuser: pg_client::User::POSTGRES,
398 image,
399 cross_container_access: false,
400 wait_available_timeout,
401 })
402 }
403}
404
/// Top-level contents of a TOML configuration file.
///
/// `image`, `ssl_config` and `wait_available_timeout` act as defaults for every instance.
/// See [`Config::resolve`] for the full precedence rules.
#[derive(Debug, serde::Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct Config {
    /// Default image for instances that do not set their own.
    image: Option<Image>,
    /// Container backend. An explicit overwrite passed to `resolve` takes precedence;
    /// without either, `Selection::Auto` is used.
    backend: Option<ociman::backend::Selection>,
    /// Default SSL settings for instances that do not set their own.
    ssl_config: Option<SslConfigDefinition>,
    /// Default availability timeout, parsed by `humantime_serde`.
    #[serde(default, with = "humantime_serde")]
    wait_available_timeout: Option<std::time::Duration>,
    /// Named instances. When absent, resolution creates a single instance named
    /// `InstanceName::default()`.
    instances: Option<std::collections::BTreeMap<InstanceName, InstanceDefinition>>,
}
415
416impl std::default::Default for Config {
417 fn default() -> Self {
418 Self {
419 image: Some(Image::default()),
420 backend: None,
421 ssl_config: None,
422 wait_available_timeout: None,
423 instances: None,
424 }
425 }
426}
427
428impl Config {
429 pub fn load_toml_file(
430 file: impl AsRef<std::path::Path>,
431 backend_overwrite: Option<ociman::backend::Selection>,
432 overwrites: &InstanceDefinition,
433 ) -> Result<Resolved, Error> {
434 let file = file.as_ref();
435 let base_dir = file
436 .parent()
437 .map(std::path::Path::to_path_buf)
438 .unwrap_or_default();
439
440 std::fs::read_to_string(file)
441 .map_err(|error| Error::IO(error.into()))
442 .and_then(Self::load_toml)
443 .map(|config| config.resolve_paths(&base_dir))
444 .and_then(|config| config.resolve(backend_overwrite, overwrites))
445 }
446
447 fn resolve_paths(mut self, base_dir: &std::path::Path) -> Self {
448 let resolve_path = |path: std::path::PathBuf| -> std::path::PathBuf {
449 if path.is_relative() {
450 base_dir.join(path)
451 } else {
452 path
453 }
454 };
455
456 let resolve_command = |command: &mut String| {
460 let path = std::path::Path::new(command.as_str());
461 if path.is_relative() && path.components().count() > 1 {
462 let stripped: std::path::PathBuf = path
465 .components()
466 .filter(|c| !matches!(c, std::path::Component::CurDir))
467 .collect();
468 *command = base_dir.join(stripped).to_string_lossy().into_owned();
469 }
470 };
471
472 if let Some(instances) = self.instances.as_mut() {
473 for instance in instances.values_mut() {
474 for seed in instance.seeds.values_mut() {
475 match seed {
476 SeedConfig::SqlFile { path, .. } => *path = resolve_path(path.clone()),
477 SeedConfig::Command { command, cache, .. } => {
478 resolve_command(command);
479 if let SeedCacheConfig::KeyCommand {
480 command: key_command,
481 ..
482 } = cache
483 {
484 resolve_command(key_command);
485 }
486 }
487 SeedConfig::Script { cache, .. } => {
488 if let Some(SeedCacheConfig::KeyCommand {
489 command: key_command,
490 ..
491 }) = cache
492 {
493 resolve_command(key_command);
494 }
495 }
496 SeedConfig::CsvFile { path, .. } => *path = resolve_path(path.clone()),
497 SeedConfig::ContainerScript { .. } | SeedConfig::SqlStatement { .. } => {}
498 }
499 }
500 }
501 }
502
503 self
504 }
505
506 pub fn load_toml(contents: impl AsRef<str>) -> Result<Config, Error> {
507 toml::from_str(contents.as_ref()).map_err(Error::TomlDecode)
508 }
509
510 pub fn resolve(
516 self,
517 backend_overwrite: Option<ociman::backend::Selection>,
518 overwrites: &InstanceDefinition,
519 ) -> Result<Resolved, Error> {
520 let backend_selection = backend_overwrite
521 .or(self.backend)
522 .unwrap_or(ociman::backend::Selection::Auto);
523
524 let defaults = InstanceDefinition {
525 image: self.image.clone(),
526 parameters: pg_client::parameter::Map::new(),
527 seeds: indexmap::IndexMap::new(),
528 ssl_config: self.ssl_config.clone(),
529 wait_available_timeout: self.wait_available_timeout,
530 };
531
532 let instances = match self.instances {
533 None => {
534 let instance_name = InstanceName::default();
535
536 InstanceDefinition::empty()
537 .into_instance(&instance_name, &defaults, overwrites)
538 .map(|instance| [(instance_name, instance)].into())?
539 }
540 Some(map) => {
541 let mut instance_map = std::collections::BTreeMap::new();
542
543 for (instance_name, instance_definition) in map {
544 let instance =
545 instance_definition.into_instance(&instance_name, &defaults, overwrites)?;
546
547 instance_map.insert(instance_name, instance);
548 }
549
550 instance_map
551 }
552 };
553
554 Ok(Resolved {
555 backend_selection,
556 instances,
557 })
558 }
559}
560
#[cfg(test)]
mod test {
    use super::*;

    /// Loads a config and returns one seed of instance `main`.
    ///
    /// Writes `contents` as `database.toml` into the temp directory `test_dir_name`, loads it
    /// through [`Config::load_toml_file`], and returns that directory together with the seed
    /// named `seed_name`.
    fn load_main_seed(
        test_dir_name: &str,
        contents: &str,
        seed_name: &str,
    ) -> (std::path::PathBuf, crate::seed::Seed) {
        let dir = std::env::temp_dir().join(test_dir_name);
        std::fs::create_dir_all(&dir).unwrap();

        let config_path = dir.join("database.toml");
        std::fs::write(&config_path, contents).unwrap();

        let resolved =
            Config::load_toml_file(&config_path, None, &InstanceDefinition::empty()).unwrap();

        let instance_name: crate::InstanceName = "main".parse().unwrap();
        let seed_name: crate::seed::SeedName = seed_name.parse().unwrap();
        let instance = resolved.instances.get(&instance_name).unwrap();
        let seed = instance.seeds[&seed_name].clone();

        (dir, seed)
    }

    #[test]
    fn sql_file_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-sql-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "sql-file"
                path = "db/structure.sql"
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::SqlFile {
                path: dir.join("db/structure.sql"),
            }
        );
    }

    #[test]
    fn command_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.migrate]
                type = "command"
                command = "./bin/migrate"
                arguments = ["up"]
                cache = { type = "none" }
            "#},
            "migrate",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new(
                    dir.join("bin/migrate").to_string_lossy(),
                    ["up"],
                ),
                cache: crate::seed::SeedCacheConfig::None,
            }
        );
    }

    #[test]
    fn bare_command_name_not_resolved() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-bare-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "command"
                command = "psql"
                arguments = ["-f", "schema.sql"]
                cache = { type = "command-hash" }
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new("psql", ["-f", "schema.sql"]),
                cache: crate::seed::SeedCacheConfig::CommandHash,
            }
        );
    }

    #[test]
    fn container_script_parsed() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-container-script",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.install-ext]
                type = "container-script"
                script = "apt-get update && apt-get install -y postgresql-15-cron"
            "#},
            "install-ext",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::ContainerScript {
                script: "apt-get update && apt-get install -y postgresql-15-cron".to_string(),
            }
        );
    }

    #[test]
    fn sql_statement_parsed() {
        let (_dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-sql-statement",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.create-users]
                type = "sql-statement"
                statement = "CREATE TABLE users (id INT)"
            "#},
            "create-users",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::SqlStatement {
                statement: "CREATE TABLE users (id INT)".to_string(),
            }
        );
    }

    #[test]
    fn csv_file_parsed() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-csv-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.users]
                type = "csv-file"
                path = "fixtures/users.csv"
                table = { schema = "public", table = "users" }
            "#},
            "users",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::CsvFile {
                path: dir.join("fixtures/users.csv"),
                table: pg_client::QualifiedTable {
                    schema: pg_client::identifier::Schema::PUBLIC,
                    table: "users".parse().unwrap(),
                },
                delimiter: ',',
            }
        );
    }
}