1use super::InstanceName;
2use crate::definition::{Definition, SslConfig};
3use crate::image::Image;
4use crate::seed::{Command, CommandCacheConfig, Seed, SeedName};
5
6#[derive(Clone, Debug, PartialEq)]
7pub struct Instance {
8 pub application_name: Option<pg_client::config::ApplicationName>,
9 pub backend: ociman::backend::Selection,
10 pub database: pg_client::Database,
11 pub seeds: indexmap::IndexMap<SeedName, Seed>,
12 pub ssl_config: Option<SslConfig>,
13 pub superuser: pg_client::User,
14 pub image: Image,
15 pub cross_container_access: bool,
16 pub wait_available_timeout: std::time::Duration,
17}
18
19impl Instance {
20 #[must_use]
21 pub fn new(backend: ociman::backend::Selection, image: Image) -> Self {
22 Self {
23 backend,
24 application_name: None,
25 seeds: indexmap::IndexMap::new(),
26 ssl_config: None,
27 superuser: pg_client::User::POSTGRES,
28 database: pg_client::Database::POSTGRES,
29 image,
30 cross_container_access: false,
31 wait_available_timeout: std::time::Duration::from_secs(10),
32 }
33 }
34
35 pub async fn definition(
36 &self,
37 instance_name: &crate::InstanceName,
38 ) -> Result<Definition, ociman::backend::resolve::Error> {
39 Ok(Definition {
40 instance_name: instance_name.clone(),
41 application_name: self.application_name.clone(),
42 backend: self.backend.resolve().await?,
43 database: self.database.clone(),
44 seeds: self.seeds.clone(),
45 ssl_config: self.ssl_config.clone(),
46 superuser: self.superuser.clone(),
47 image: self.image.clone(),
48 cross_container_access: self.cross_container_access,
49 wait_available_timeout: self.wait_available_timeout,
50 remove: true,
51 })
52 }
53}
54
55#[derive(Debug, thiserror::Error, PartialEq)]
56pub enum Error {
57 #[error("Could not load config file: {0}")]
58 IO(IoError),
59 #[error("Decoding as toml failed: {0}")]
60 TomlDecode(#[from] toml::de::Error),
61 #[error("Instance {instance_name} does not specify {field} and no default applies")]
62 MissingInstanceField {
63 instance_name: InstanceName,
64 field: &'static str,
65 },
66}
67
/// Comparable stand-in for [`std::io::Error`] that keeps only the error kind.
///
/// Everything else carried by the original error is dropped on conversion, which is what allows
/// this type to derive `PartialEq`.
#[derive(Debug, PartialEq)]
pub struct IoError(pub std::io::ErrorKind);

impl std::fmt::Display for IoError {
    /// Formats the kind the same way a kind-only [`std::io::Error`] is displayed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(kind) = self;
        let error = std::io::Error::from(*kind);

        f.write_fmt(format_args!("{}", error))
    }
}

impl std::error::Error for IoError {}

impl From<std::io::Error> for IoError {
    /// Keeps the kind of `error` and discards the rest.
    fn from(error: std::io::Error) -> Self {
        let kind = error.kind();

        IoError(kind)
    }
}
84
85#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
86#[serde(tag = "type", rename_all = "kebab-case")]
87pub enum SeedConfig {
88 SqlFile {
89 path: std::path::PathBuf,
90 git_revision: Option<String>,
91 },
92 Command {
93 command: String,
94 #[serde(default)]
95 arguments: Vec<String>,
96 cache: CommandCacheConfig,
97 },
98 Script {
99 script: String,
100 },
101 ContainerScript {
102 script: String,
103 },
104 CsvFile {
105 path: std::path::PathBuf,
106 table: pg_client::QualifiedTable,
107 delimiter: Option<char>,
108 },
109}
110
111impl From<SeedConfig> for Seed {
112 fn from(value: SeedConfig) -> Self {
113 match value {
114 SeedConfig::SqlFile { path, git_revision } => match git_revision {
115 Some(git_revision) => Seed::SqlFileGitRevision { git_revision, path },
116 None => Seed::SqlFile { path },
117 },
118 SeedConfig::Command {
119 command,
120 arguments,
121 cache,
122 } => Seed::Command {
123 command: Command::new(command, arguments),
124 cache,
125 },
126 SeedConfig::Script { script } => Seed::Script { script },
127 SeedConfig::ContainerScript { script } => Seed::ContainerScript { script },
128 SeedConfig::CsvFile {
129 path,
130 table,
131 delimiter,
132 } => Seed::CsvFile {
133 path,
134 table,
135 delimiter: delimiter.unwrap_or(','),
136 },
137 }
138 }
139}
140
141#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
142#[serde(deny_unknown_fields)]
143pub struct SslConfigDefinition {
144 pub hostname: pg_client::config::HostName,
145}
146
147#[derive(Debug, serde::Deserialize, PartialEq)]
148#[serde(deny_unknown_fields)]
149pub struct InstanceDefinition {
150 pub backend: Option<ociman::backend::Selection>,
151 pub image: Option<Image>,
152 #[serde(default)]
153 pub seeds: indexmap::IndexMap<SeedName, SeedConfig>,
154 pub ssl_config: Option<SslConfigDefinition>,
155 #[serde(default, with = "humantime_serde")]
156 pub wait_available_timeout: Option<std::time::Duration>,
157}
158
159impl InstanceDefinition {
160 #[must_use]
161 pub fn empty() -> Self {
162 Self {
163 backend: None,
164 image: None,
165 seeds: indexmap::IndexMap::new(),
166 ssl_config: None,
167 wait_available_timeout: None,
168 }
169 }
170
171 fn into_instance(
172 self,
173 instance_name: &InstanceName,
174 defaults: &InstanceDefinition,
175 overwrites: &InstanceDefinition,
176 ) -> Result<Instance, Error> {
177 let image = match overwrites
178 .image
179 .as_ref()
180 .or(self.image.as_ref())
181 .or(defaults.image.as_ref())
182 {
183 Some(image) => image.clone(),
184 None => {
185 return Err(Error::MissingInstanceField {
186 instance_name: instance_name.clone(),
187 field: "image",
188 });
189 }
190 };
191
192 let backend = overwrites
193 .backend
194 .or(self.backend)
195 .or(defaults.backend)
196 .unwrap_or(ociman::backend::Selection::Auto);
197
198 let seeds = self
199 .seeds
200 .into_iter()
201 .map(|(name, seed_config)| (name, seed_config.into()))
202 .collect();
203
204 let ssl_config = overwrites
205 .ssl_config
206 .as_ref()
207 .or(self.ssl_config.as_ref())
208 .or(defaults.ssl_config.as_ref())
209 .map(|ssl_config_def| SslConfig::Generated {
210 hostname: ssl_config_def.hostname.clone(),
211 });
212
213 let wait_available_timeout = overwrites
214 .wait_available_timeout
215 .or(self.wait_available_timeout)
216 .or(defaults.wait_available_timeout)
217 .unwrap_or(std::time::Duration::from_secs(10));
218
219 Ok(Instance {
220 application_name: None,
221 backend,
222 database: pg_client::Database::POSTGRES,
223 seeds,
224 ssl_config,
225 superuser: pg_client::User::POSTGRES,
226 image,
227 cross_container_access: false,
228 wait_available_timeout,
229 })
230 }
231}
232
233#[derive(Debug, serde::Deserialize, PartialEq)]
234#[serde(deny_unknown_fields)]
235pub struct Config {
236 image: Option<Image>,
237 backend: Option<ociman::backend::Selection>,
238 ssl_config: Option<SslConfigDefinition>,
239 #[serde(default, with = "humantime_serde")]
240 wait_available_timeout: Option<std::time::Duration>,
241 instances: Option<std::collections::BTreeMap<InstanceName, InstanceDefinition>>,
242}
243
244impl std::default::Default for Config {
245 fn default() -> Self {
246 Self {
247 image: Some(Image::default()),
248 backend: None,
249 ssl_config: None,
250 wait_available_timeout: None,
251 instances: None,
252 }
253 }
254}
255
256impl Config {
257 pub fn load_toml_file(
258 file: impl AsRef<std::path::Path>,
259 overwrites: &InstanceDefinition,
260 ) -> Result<super::InstanceMap, Error> {
261 let file = file.as_ref();
262 let base_dir = file
263 .parent()
264 .map(std::path::Path::to_path_buf)
265 .unwrap_or_default();
266
267 std::fs::read_to_string(file)
268 .map_err(|error| Error::IO(error.into()))
269 .and_then(Self::load_toml)
270 .map(|config| config.resolve_paths(&base_dir))
271 .and_then(|config| config.instance_map(overwrites))
272 }
273
274 fn resolve_paths(mut self, base_dir: &std::path::Path) -> Self {
275 let resolve_path = |path: std::path::PathBuf| -> std::path::PathBuf {
276 if path.is_relative() {
277 base_dir.join(path)
278 } else {
279 path
280 }
281 };
282
283 let resolve_command = |command: &mut String| {
287 let path = std::path::Path::new(command.as_str());
288 if path.is_relative() && path.components().count() > 1 {
289 let stripped: std::path::PathBuf = path
292 .components()
293 .filter(|c| !matches!(c, std::path::Component::CurDir))
294 .collect();
295 *command = base_dir.join(stripped).to_string_lossy().into_owned();
296 }
297 };
298
299 if let Some(instances) = self.instances.as_mut() {
300 for instance in instances.values_mut() {
301 for seed in instance.seeds.values_mut() {
302 match seed {
303 SeedConfig::SqlFile { path, .. } => *path = resolve_path(path.clone()),
304 SeedConfig::Command { command, cache, .. } => {
305 resolve_command(command);
306 if let CommandCacheConfig::KeyCommand {
307 command: key_command,
308 ..
309 } = cache
310 {
311 resolve_command(key_command);
312 }
313 }
314 SeedConfig::CsvFile { path, .. } => *path = resolve_path(path.clone()),
315 SeedConfig::Script { .. } | SeedConfig::ContainerScript { .. } => {}
316 }
317 }
318 }
319 }
320
321 self
322 }
323
324 pub fn load_toml(contents: impl AsRef<str>) -> Result<Config, Error> {
325 toml::from_str(contents.as_ref()).map_err(Error::TomlDecode)
326 }
327
328 pub fn instance_map(
329 self,
330 overwrites: &InstanceDefinition,
331 ) -> Result<super::InstanceMap, Error> {
332 let defaults = InstanceDefinition {
333 backend: self.backend,
334 image: self.image.clone(),
335 seeds: indexmap::IndexMap::new(),
336 ssl_config: self.ssl_config.clone(),
337 wait_available_timeout: self.wait_available_timeout,
338 };
339
340 match self.instances {
341 None => {
342 let instance_name = InstanceName::default();
343
344 InstanceDefinition::empty()
345 .into_instance(&instance_name, &defaults, overwrites)
346 .map(|instance| [(instance_name, instance)].into())
347 }
348 Some(map) => {
349 let mut instance_map = std::collections::BTreeMap::new();
350
351 for (instance_name, instance_definition) in map {
352 let instance =
353 instance_definition.into_instance(&instance_name, &defaults, overwrites)?;
354
355 instance_map.insert(instance_name, instance);
356 }
357
358 Ok(instance_map)
359 }
360 }
361 }
362}
363
#[cfg(test)]
mod test {
    use super::*;

    /// Writes `toml` as `database.toml` into the directory `dir_name` below the system temp
    /// directory, loads it without overwrites and returns that directory together with the seed
    /// `seed_name` of the instance `main`.
    fn load_main_seed(
        dir_name: &str,
        toml: &str,
        seed_name: &str,
    ) -> (std::path::PathBuf, crate::seed::Seed) {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();
        let config_path = dir.join("database.toml");
        std::fs::write(&config_path, toml).unwrap();

        let instance_map =
            Config::load_toml_file(&config_path, &InstanceDefinition::empty()).unwrap();

        let instance_name: crate::InstanceName = "main".parse().unwrap();
        let instance = instance_map.get(&instance_name).unwrap();
        let seed_name: crate::seed::SeedName = seed_name.parse().unwrap();
        let seed = instance.seeds[&seed_name].clone();

        (dir, seed)
    }

    #[test]
    fn sql_file_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-sql-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "sql-file"
                path = "db/structure.sql"
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::SqlFile {
                path: dir.join("db/structure.sql"),
            }
        );
    }

    #[test]
    fn command_path_resolved_relative_to_config() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.migrate]
                type = "command"
                command = "./bin/migrate"
                arguments = ["up"]
                cache = { type = "none" }
            "#},
            "migrate",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new(
                    dir.join("bin/migrate").to_string_lossy(),
                    ["up"],
                ),
                cache: crate::seed::CommandCacheConfig::None,
            }
        );
    }

    #[test]
    fn bare_command_name_not_resolved() {
        let (_, seed) = load_main_seed(
            "pg-ephemeral-config-test-bare-command",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.schema]
                type = "command"
                command = "psql"
                arguments = ["-f", "schema.sql"]
                cache = { type = "command-hash" }
            "#},
            "schema",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::Command {
                command: crate::seed::Command::new("psql", ["-f", "schema.sql"]),
                cache: crate::seed::CommandCacheConfig::CommandHash,
            }
        );
    }

    #[test]
    fn container_script_parsed() {
        let (_, seed) = load_main_seed(
            "pg-ephemeral-config-test-container-script",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.install-ext]
                type = "container-script"
                script = "apt-get update && apt-get install -y postgresql-15-cron"
            "#},
            "install-ext",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::ContainerScript {
                script: "apt-get update && apt-get install -y postgresql-15-cron".to_string(),
            }
        );
    }

    #[test]
    fn csv_file_parsed() {
        let (dir, seed) = load_main_seed(
            "pg-ephemeral-config-test-csv-file",
            indoc::indoc! {r#"
                image = "15.6"

                [instances.main.seeds.users]
                type = "csv-file"
                path = "fixtures/users.csv"
                table = { schema = "public", table = "users" }
            "#},
            "users",
        );

        assert_eq!(
            seed,
            crate::seed::Seed::CsvFile {
                path: dir.join("fixtures/users.csv"),
                table: pg_client::QualifiedTable {
                    schema: pg_client::identifier::Schema::PUBLIC,
                    table: "users".parse().unwrap(),
                },
                delimiter: ',',
            }
        );
    }
}