Skip to main content

pg_ephemeral/
definition.rs

1use std::os::fd::FromRawFd;
2
3use crate::Container;
4use crate::seed::{
5    Command, CommandCacheConfig, DuplicateSeedName, LoadError, LoadedSeed, LoadedSeeds, Seed,
6    SeedName,
7};
8
9#[derive(Debug, thiserror::Error)]
10pub enum SeedApplyError {
11    #[error("Failed to apply command seed")]
12    Command(#[from] cmd_proc::CommandError),
13    #[error("Failed to apply SQL seed")]
14    Sql(#[from] sqlx::Error),
15}
16
17#[derive(Clone, Debug, PartialEq)]
18pub enum SslConfig {
19    Generated { hostname: pg_client::HostName },
20    // UserProvided { ca_cert: PathBuf, server_cert: PathBuf, server_key: PathBuf },
21}
22
23#[derive(Clone, Debug, PartialEq)]
24pub struct Definition {
25    pub instance_name: crate::InstanceName,
26    pub application_name: Option<pg_client::ApplicationName>,
27    pub backend: ociman::Backend,
28    pub database: pg_client::Database,
29    pub seeds: indexmap::IndexMap<SeedName, Seed>,
30    pub ssl_config: Option<SslConfig>,
31    pub superuser: pg_client::User,
32    pub image: crate::image::Image,
33    pub cross_container_access: bool,
34    pub wait_available_timeout: std::time::Duration,
35    pub remove: bool,
36}
37
38impl Definition {
39    #[must_use]
40    pub fn new(
41        backend: ociman::backend::Backend,
42        image: crate::image::Image,
43        instance_name: crate::InstanceName,
44    ) -> Self {
45        Self {
46            instance_name,
47            backend,
48            application_name: None,
49            seeds: indexmap::IndexMap::new(),
50            ssl_config: None,
51            superuser: pg_client::User::POSTGRES,
52            database: pg_client::Database::POSTGRES,
53            image,
54            cross_container_access: false,
55            wait_available_timeout: std::time::Duration::from_secs(10),
56            remove: true,
57        }
58    }
59
60    #[must_use]
61    pub fn remove(self, remove: bool) -> Self {
62        Self { remove, ..self }
63    }
64
65    #[must_use]
66    pub fn image(self, image: crate::image::Image) -> Self {
67        Self { image, ..self }
68    }
69
70    pub fn add_seed(self, name: SeedName, seed: Seed) -> Result<Self, DuplicateSeedName> {
71        let mut seeds = self.seeds.clone();
72
73        if seeds.contains_key(&name) {
74            return Err(DuplicateSeedName(name));
75        }
76
77        seeds.insert(name, seed);
78        Ok(Self { seeds, ..self })
79    }
80
81    pub fn apply_file(
82        self,
83        name: SeedName,
84        path: std::path::PathBuf,
85    ) -> Result<Self, DuplicateSeedName> {
86        self.add_seed(name, Seed::SqlFile { path })
87    }
88
89    pub async fn load_seeds(
90        &self,
91        instance_name: &crate::InstanceName,
92    ) -> Result<LoadedSeeds<'_>, LoadError> {
93        LoadedSeeds::load(
94            &self.image,
95            self.ssl_config.as_ref(),
96            &self.seeds,
97            &self.backend,
98            instance_name,
99        )
100        .await
101    }
102
103    pub async fn print_cache_status(
104        &self,
105        instance_name: &crate::InstanceName,
106        json: bool,
107    ) -> Result<(), crate::container::Error> {
108        let loaded_seeds = self.load_seeds(instance_name).await?;
109        if json {
110            loaded_seeds.print_json(instance_name);
111        } else {
112            loaded_seeds.print(instance_name);
113        }
114        Ok(())
115    }
116
117    #[must_use]
118    pub fn superuser(self, user: pg_client::User) -> Self {
119        Self {
120            superuser: user,
121            ..self
122        }
123    }
124
125    pub fn apply_file_from_git_revision(
126        self,
127        name: SeedName,
128        path: std::path::PathBuf,
129        git_revision: impl Into<String>,
130    ) -> Result<Self, DuplicateSeedName> {
131        self.add_seed(
132            name,
133            Seed::SqlFileGitRevision {
134                git_revision: git_revision.into(),
135                path,
136            },
137        )
138    }
139
140    pub fn apply_command(
141        self,
142        name: SeedName,
143        command: Command,
144        cache: CommandCacheConfig,
145    ) -> Result<Self, DuplicateSeedName> {
146        self.add_seed(name, Seed::Command { command, cache })
147    }
148
149    pub fn apply_script(
150        self,
151        name: SeedName,
152        script: impl Into<String>,
153    ) -> Result<Self, DuplicateSeedName> {
154        self.add_seed(
155            name,
156            Seed::Script {
157                script: script.into(),
158            },
159        )
160    }
161
162    pub fn apply_container_script(
163        self,
164        name: SeedName,
165        script: impl Into<String>,
166    ) -> Result<Self, DuplicateSeedName> {
167        self.add_seed(
168            name,
169            Seed::ContainerScript {
170                script: script.into(),
171            },
172        )
173    }
174
175    #[must_use]
176    pub fn ssl_config(self, ssl_config: SslConfig) -> Self {
177        Self {
178            ssl_config: Some(ssl_config),
179            ..self
180        }
181    }
182
183    #[must_use]
184    pub fn cross_container_access(self, enabled: bool) -> Self {
185        Self {
186            cross_container_access: enabled,
187            ..self
188        }
189    }
190
191    #[must_use]
192    pub fn wait_available_timeout(self, timeout: std::time::Duration) -> Self {
193        Self {
194            wait_available_timeout: timeout,
195            ..self
196        }
197    }
198
199    #[must_use]
200    pub fn to_ociman_definition(&self) -> ociman::Definition {
201        ociman::Definition::new(self.backend.clone(), (&self.image).into())
202    }
203
204    pub async fn with_container<T>(
205        &self,
206        mut action: impl AsyncFnMut(&Container) -> T,
207    ) -> Result<T, crate::container::Error> {
208        let (last_cache_hit, uncached_seeds) = self.populate_cache(&self.instance_name).await?;
209
210        let boot_definition = match &last_cache_hit {
211            Some(reference) => self
212                .clone()
213                .image(crate::image::Image::Explicit(reference.clone())),
214            None => self.clone(),
215        };
216
217        let mut db_container = Container::run_definition(&boot_definition).await;
218
219        if last_cache_hit.is_some() {
220            db_container
221                .set_superuser_password(db_container.client_config.password.as_ref().unwrap())
222                .await?;
223        }
224
225        db_container.wait_available().await?;
226
227        for seed in &uncached_seeds {
228            self.apply_loaded_seed(&db_container, seed).await?;
229        }
230
231        let result = action(&db_container).await;
232
233        db_container.stop().await;
234
235        Ok(result)
236    }
237
238    /// Populate cache images for seeds.
239    ///
240    /// Returns a tuple of:
241    /// - The last cache hit reference (if any), which can be used to boot from
242    /// - The loaded seeds that could not be cached because the cache chain was broken
243    pub async fn populate_cache(
244        &self,
245        instance_name: &crate::InstanceName,
246    ) -> Result<(Option<ociman::Reference>, Vec<LoadedSeed>), crate::container::Error> {
247        let loaded_seeds = self.load_seeds(instance_name).await?;
248
249        let mut previous_cache_reference: Option<&ociman::Reference> = None;
250        let mut seeds_iter = loaded_seeds.iter_seeds().peekable();
251
252        while let Some(seed) = seeds_iter.next() {
253            let Some(cache_reference) = seed.cache_status().reference() else {
254                // Uncacheable seed - cache chain is broken, return remaining seeds
255                let mut remaining = vec![seed.clone()];
256                remaining.extend(seeds_iter.cloned());
257                return Ok((previous_cache_reference.cloned(), remaining));
258            };
259
260            if seed.cache_status().is_hit() {
261                previous_cache_reference = Some(cache_reference);
262                continue;
263            }
264
265            let caching_image = previous_cache_reference
266                .map(|reference| crate::image::Image::Explicit(reference.clone()))
267                .unwrap_or_else(|| self.image.clone());
268
269            if let LoadedSeed::ContainerScript { script, .. } = seed {
270                log::info!("Applying container-script seed: {}", seed.name());
271
272                let base_image: ociman::image::Reference = (&caching_image).into();
273                let build_dir = create_container_script_build_dir(&base_image, script);
274
275                ociman::image::BuildDefinition::from_directory(
276                    &self.backend,
277                    cache_reference.clone(),
278                    &build_dir,
279                )
280                .build()
281                .await;
282
283                std::fs::remove_dir_all(&build_dir)
284                    .expect("failed to clean up container-script build directory");
285            } else {
286                let caching_definition = self.clone().remove(false).image(caching_image);
287
288                let mut container = Container::run_definition(&caching_definition).await;
289
290                if previous_cache_reference.is_some() {
291                    container
292                        .set_superuser_password(container.client_config.password.as_ref().unwrap())
293                        .await?;
294                }
295
296                container.wait_available().await?;
297
298                self.apply_loaded_seed(&container, seed).await?;
299                container.stop_commit_remove(cache_reference).await;
300            }
301
302            log::info!("Committed cache image: {cache_reference}");
303
304            previous_cache_reference = Some(cache_reference);
305        }
306
307        Ok((previous_cache_reference.cloned(), Vec::new()))
308    }
309
310    pub async fn run_integration_server(
311        &self,
312        result_fd: std::os::fd::RawFd,
313        control_fd: std::os::fd::RawFd,
314    ) -> Result<(), crate::container::Error> {
315        self.with_container(async |container| {
316            // SAFETY: The parent process guarantees these are valid, exclusively-owned FDs
317            // inherited via the process spawn protocol.
318            let result_owned = unsafe { std::os::fd::OwnedFd::from_raw_fd(result_fd) };
319            let control_owned = unsafe { std::os::fd::OwnedFd::from_raw_fd(control_fd) };
320
321            let mut result_file = std::fs::File::from(result_owned);
322            let json = serde_json::to_string(&container.client_config).unwrap();
323
324            use std::io::Write;
325            writeln!(result_file, "{json}").expect("Failed to write config to result pipe");
326            drop(result_file);
327
328            log::info!("Integration server is running, waiting for EOF on control pipe");
329
330            let control_fd = tokio::io::unix::AsyncFd::new(control_owned)
331                .expect("Failed to register control pipe with tokio");
332
333            let _ = control_fd.readable().await.unwrap();
334
335            log::info!("Integration server received EOF on control pipe, exiting");
336        })
337        .await
338    }
339
340    async fn apply_loaded_seed(
341        &self,
342        db_container: &Container,
343        loaded_seed: &LoadedSeed,
344    ) -> Result<(), SeedApplyError> {
345        log::info!("Applying seed: {}", loaded_seed.name());
346        match loaded_seed {
347            LoadedSeed::SqlFile { content, .. } => db_container.apply_sql(content).await?,
348            LoadedSeed::SqlFileGitRevision { content, .. } => {
349                db_container.apply_sql(content).await?
350            }
351            LoadedSeed::Command { command, .. } => {
352                self.execute_command(db_container, command).await?
353            }
354            LoadedSeed::Script { script, .. } => self.execute_script(db_container, script).await?,
355            LoadedSeed::ContainerScript { script, .. } => {
356                db_container.exec_container_script(script).await?
357            }
358        }
359
360        Ok(())
361    }
362
363    async fn execute_command(
364        &self,
365        db_container: &Container,
366        command: &Command,
367    ) -> Result<(), cmd_proc::CommandError> {
368        cmd_proc::Command::new(&command.command)
369            .arguments(&command.arguments)
370            .envs(db_container.pg_env())
371            .env(&crate::ENV_DATABASE_URL, db_container.database_url())
372            .status()
373            .await
374    }
375
376    async fn execute_script(
377        &self,
378        db_container: &Container,
379        script: &str,
380    ) -> Result<(), cmd_proc::CommandError> {
381        cmd_proc::Command::new("sh")
382            .arguments(["-e", "-c"])
383            .argument(script)
384            .envs(db_container.pg_env())
385            .env(&crate::ENV_DATABASE_URL, db_container.database_url())
386            .status()
387            .await
388    }
389
390    pub async fn schema_dump(
391        &self,
392        client_config: &pg_client::Config,
393        pg_schema_dump: &pg_client::PgSchemaDump,
394    ) -> String {
395        let (effective_config, mounts) = apply_ociman_mounts(client_config);
396
397        let bytes = self
398            .to_ociman_definition()
399            .entrypoint("pg_dump")
400            .arguments(pg_schema_dump.arguments())
401            .environment_variables(effective_config.to_pg_env())
402            .mounts(mounts)
403            .run_capture_only_stdout()
404            .await;
405
406        crate::convert_schema(&bytes)
407    }
408}
409
410#[must_use]
411pub fn apply_ociman_mounts(
412    client_config: &pg_client::Config,
413) -> (pg_client::Config, Vec<ociman::Mount>) {
414    let owned_client_config = client_config.clone();
415
416    match client_config.ssl_root_cert {
417        Some(ref ssl_root_cert) => match ssl_root_cert {
418            pg_client::SslRootCert::File(file) => {
419                let host =
420                    std::fs::canonicalize(file).expect("could not canonicalize ssl root path");
421
422                let mut container_path = std::path::PathBuf::new();
423
424                container_path.push("/pg_ephemeral");
425                container_path.push(file.file_name().unwrap());
426
427                let mounts = vec![ociman::Mount::from(format!(
428                    "type=bind,ro,source={},target={}",
429                    host.to_str().unwrap(),
430                    container_path.to_str().unwrap()
431                ))];
432
433                (
434                    pg_client::Config {
435                        ssl_root_cert: Some(container_path.into()),
436                        ..owned_client_config
437                    },
438                    mounts,
439                )
440            }
441            pg_client::SslRootCert::System => (owned_client_config, vec![]),
442        },
443        None => (owned_client_config, vec![]),
444    }
445}
446
447fn create_container_script_build_dir(
448    base_image: &ociman::image::Reference,
449    script: &str,
450) -> std::path::PathBuf {
451    use rand::RngExt;
452
453    let suffix: String = rand::rng()
454        .sample_iter(rand::distr::Alphanumeric)
455        .take(16)
456        .map(char::from)
457        .collect();
458
459    let dir = std::env::temp_dir().join(format!("pg-ephemeral-build-{suffix}"));
460    std::fs::create_dir(&dir).expect("failed to create container-script build directory");
461
462    std::fs::write(dir.join("script.sh"), script).expect("failed to write container-script");
463
464    std::fs::write(
465        dir.join("Dockerfile"),
466        format!("FROM {base_image}\nCOPY script.sh /tmp/pg-ephemeral-script.sh\nRUN sh -e /tmp/pg-ephemeral-script.sh && rm /tmp/pg-ephemeral-script.sh\n"),
467    )
468    .expect("failed to write Dockerfile");
469
470    dir
471}
472
#[cfg(test)]
mod test {
    use super::*;

    fn test_backend() -> ociman::Backend {
        ociman::Backend::Podman {
            version: semver::Version::new(4, 0, 0),
        }
    }

    fn test_instance_name() -> crate::InstanceName {
        "test".parse().unwrap()
    }

    /// Seedless definition every test starts from.
    fn empty_definition() -> Definition {
        Definition::new(test_backend(), crate::Image::default(), test_instance_name())
    }

    /// Parses `name` into a `SeedName`, panicking on invalid input.
    fn seed_name(name: &str) -> SeedName {
        name.parse().unwrap()
    }

    /// SQL file seed pointing at `path`.
    fn sql_file(path: &str) -> Seed {
        Seed::SqlFile { path: path.into() }
    }

    #[test]
    fn test_add_seed_rejects_duplicate() {
        let name = seed_name("test-seed");

        let definition = empty_definition()
            .add_seed(name.clone(), sql_file("file1.sql"))
            .unwrap();

        let result = definition.add_seed(name.clone(), sql_file("file2.sql"));

        assert_eq!(result, Err(DuplicateSeedName(name)));
    }

    #[test]
    fn test_add_seed_allows_different_names() {
        let definition = empty_definition()
            .add_seed(seed_name("seed1"), sql_file("file1.sql"))
            .unwrap();

        let result = definition.add_seed(seed_name("seed2"), sql_file("file2.sql"));

        assert!(result.is_ok());
    }

    #[test]
    fn test_apply_file_rejects_duplicate() {
        let name = seed_name("test-seed");

        let definition = empty_definition()
            .apply_file(name.clone(), "file1.sql".into())
            .unwrap();

        let result = definition.apply_file(name.clone(), "file2.sql".into());

        assert_eq!(result, Err(DuplicateSeedName(name)));
    }

    #[test]
    fn test_apply_command_adds_seed() {
        let result = empty_definition().apply_command(
            seed_name("test-command"),
            Command::new("echo", vec!["test"]),
            CommandCacheConfig::CommandHash,
        );

        assert!(result.is_ok());
        let definition = result.unwrap();
        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_command_rejects_duplicate() {
        let name = seed_name("test-command");

        let definition = empty_definition()
            .apply_command(
                name.clone(),
                Command::new("echo", vec!["test1"]),
                CommandCacheConfig::CommandHash,
            )
            .unwrap();

        let result = definition.apply_command(
            name.clone(),
            Command::new("echo", vec!["test2"]),
            CommandCacheConfig::CommandHash,
        );

        assert_eq!(result, Err(DuplicateSeedName(name)));
    }

    #[test]
    fn test_apply_script_adds_seed() {
        let result = empty_definition().apply_script(seed_name("test-script"), "echo test");

        assert!(result.is_ok());
        let definition = result.unwrap();
        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_script_rejects_duplicate() {
        let name = seed_name("test-script");

        let definition = empty_definition()
            .apply_script(name.clone(), "echo test1")
            .unwrap();

        let result = definition.apply_script(name.clone(), "echo test2");

        assert_eq!(result, Err(DuplicateSeedName(name)));
    }

    #[test]
    fn test_apply_container_script_adds_seed() {
        let result = empty_definition().apply_container_script(
            seed_name("install-ext"),
            "apt-get update && apt-get install -y postgresql-17-cron",
        );

        assert!(result.is_ok());
        let definition = result.unwrap();
        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_container_script_rejects_duplicate() {
        let name = seed_name("install-ext");

        let definition = empty_definition()
            .apply_container_script(name.clone(), "apt-get update")
            .unwrap();

        let result = definition.apply_container_script(name.clone(), "apt-get update");

        assert_eq!(result, Err(DuplicateSeedName(name)));
    }
}