Skip to main content

pg_ephemeral/
definition.rs

1use std::os::fd::FromRawFd;
2
3use crate::Container;
4use crate::seed::{
5    Command, CommandCacheConfig, DuplicateSeedName, LoadError, LoadedSeed, LoadedSeeds, Seed,
6    SeedName,
7};
8
/// Failure while applying a single loaded seed to a running container.
///
/// The `#[from]` conversions let `?` lift host-command errors (`cmd_proc`) and SQL errors
/// (`sqlx`) straight into this type.
#[derive(Debug, thiserror::Error)]
pub enum SeedApplyError {
    /// A command-style seed failed. `execute_command` and `execute_script` report through
    /// `cmd_proc::CommandError`; NOTE(review): container-script exec presumably does too — confirm.
    #[error("Failed to apply command seed")]
    Command(#[from] cmd_proc::CommandError),
    /// An SQL error — presumably raised by `Container::apply_sql` for SQL-file seeds (confirm).
    #[error("Failed to apply SQL seed")]
    Sql(#[from] sqlx::Error),
}
16
/// SSL settings for the database container.
///
/// Forwarded to `LoadedSeeds::load`; how the server applies it is not visible in this file.
#[derive(Clone, Debug, PartialEq)]
pub enum SslConfig {
    /// Certificate material is generated rather than supplied — presumably a certificate
    /// issued for `hostname` (TODO confirm against the code that consumes `SslConfig`).
    Generated {
        hostname: pg_client::config::HostName,
    },
    // Not implemented yet: a variant for caller-supplied certificate files, sketched as
    // UserProvided { ca_cert: PathBuf, server_cert: PathBuf, server_key: PathBuf },
}
24
/// Everything needed to boot an ephemeral PostgreSQL container and seed it.
///
/// Created with `Definition::new` and refined with the consuming builder methods
/// (`image`, `add_seed`, `ssl_config`, ...).
#[derive(Clone, Debug, PartialEq)]
pub struct Definition {
    /// Instance name; `with_container` populates the seed cache under this name.
    pub instance_name: crate::InstanceName,
    /// Optional client application name. Defaults to `None` and is not read in this file.
    pub application_name: Option<pg_client::config::ApplicationName>,
    /// Container backend, used to build cache images and `ociman` definitions.
    pub backend: ociman::Backend,
    /// Database name. Defaults to `pg_client::Database::POSTGRES`; not read in this file.
    pub database: pg_client::Database,
    /// Seeds keyed by name, kept in insertion order.
    pub seeds: indexmap::IndexMap<SeedName, Seed>,
    /// Optional SSL configuration, forwarded to `LoadedSeeds::load`.
    pub ssl_config: Option<SslConfig>,
    /// Superuser. Defaults to `pg_client::User::POSTGRES`.
    pub superuser: pg_client::User,
    /// Image to boot. Replaced by a committed cache image when one is available.
    pub image: crate::image::Image,
    /// Defaults to `false`; not read in this file. Presumably allows other containers to
    /// reach this one (TODO confirm).
    pub cross_container_access: bool,
    /// Defaults to 10 seconds. Presumably bounds `Container::wait_available` (TODO confirm).
    pub wait_available_timeout: std::time::Duration,
    /// Defaults to `true`. Cache-populating containers set it to `false` before being
    /// committed with `stop_commit_remove`. Presumably controls removal on stop (confirm).
    pub remove: bool,
}
39
40impl Definition {
41    #[must_use]
42    pub fn new(
43        backend: ociman::backend::Backend,
44        image: crate::image::Image,
45        instance_name: crate::InstanceName,
46    ) -> Self {
47        Self {
48            instance_name,
49            backend,
50            application_name: None,
51            seeds: indexmap::IndexMap::new(),
52            ssl_config: None,
53            superuser: pg_client::User::POSTGRES,
54            database: pg_client::Database::POSTGRES,
55            image,
56            cross_container_access: false,
57            wait_available_timeout: std::time::Duration::from_secs(10),
58            remove: true,
59        }
60    }
61
62    #[must_use]
63    pub fn remove(self, remove: bool) -> Self {
64        Self { remove, ..self }
65    }
66
67    #[must_use]
68    pub fn image(self, image: crate::image::Image) -> Self {
69        Self { image, ..self }
70    }
71
72    pub fn add_seed(self, name: SeedName, seed: Seed) -> Result<Self, DuplicateSeedName> {
73        let mut seeds = self.seeds.clone();
74
75        if seeds.contains_key(&name) {
76            return Err(DuplicateSeedName(name));
77        }
78
79        seeds.insert(name, seed);
80        Ok(Self { seeds, ..self })
81    }
82
83    pub fn apply_file(
84        self,
85        name: SeedName,
86        path: std::path::PathBuf,
87    ) -> Result<Self, DuplicateSeedName> {
88        self.add_seed(name, Seed::SqlFile { path })
89    }
90
91    pub async fn load_seeds(
92        &self,
93        instance_name: &crate::InstanceName,
94    ) -> Result<LoadedSeeds<'_>, LoadError> {
95        LoadedSeeds::load(
96            &self.image,
97            self.ssl_config.as_ref(),
98            &self.seeds,
99            &self.backend,
100            instance_name,
101        )
102        .await
103    }
104
105    pub async fn print_cache_status(
106        &self,
107        instance_name: &crate::InstanceName,
108        json: bool,
109    ) -> Result<(), crate::container::Error> {
110        let loaded_seeds = self.load_seeds(instance_name).await?;
111        if json {
112            loaded_seeds.print_json(instance_name);
113        } else {
114            loaded_seeds.print(instance_name);
115        }
116        Ok(())
117    }
118
119    #[must_use]
120    pub fn superuser(self, user: pg_client::User) -> Self {
121        Self {
122            superuser: user,
123            ..self
124        }
125    }
126
127    pub fn apply_file_from_git_revision(
128        self,
129        name: SeedName,
130        path: std::path::PathBuf,
131        git_revision: impl Into<String>,
132    ) -> Result<Self, DuplicateSeedName> {
133        self.add_seed(
134            name,
135            Seed::SqlFileGitRevision {
136                git_revision: git_revision.into(),
137                path,
138            },
139        )
140    }
141
142    pub fn apply_command(
143        self,
144        name: SeedName,
145        command: Command,
146        cache: CommandCacheConfig,
147    ) -> Result<Self, DuplicateSeedName> {
148        self.add_seed(name, Seed::Command { command, cache })
149    }
150
151    pub fn apply_script(
152        self,
153        name: SeedName,
154        script: impl Into<String>,
155    ) -> Result<Self, DuplicateSeedName> {
156        self.add_seed(
157            name,
158            Seed::Script {
159                script: script.into(),
160            },
161        )
162    }
163
164    pub fn apply_container_script(
165        self,
166        name: SeedName,
167        script: impl Into<String>,
168    ) -> Result<Self, DuplicateSeedName> {
169        self.add_seed(
170            name,
171            Seed::ContainerScript {
172                script: script.into(),
173            },
174        )
175    }
176
177    #[must_use]
178    pub fn ssl_config(self, ssl_config: SslConfig) -> Self {
179        Self {
180            ssl_config: Some(ssl_config),
181            ..self
182        }
183    }
184
185    #[must_use]
186    pub fn cross_container_access(self, enabled: bool) -> Self {
187        Self {
188            cross_container_access: enabled,
189            ..self
190        }
191    }
192
193    #[must_use]
194    pub fn wait_available_timeout(self, timeout: std::time::Duration) -> Self {
195        Self {
196            wait_available_timeout: timeout,
197            ..self
198        }
199    }
200
201    #[must_use]
202    pub fn to_ociman_definition(&self) -> ociman::Definition {
203        ociman::Definition::new(self.backend.clone(), (&self.image).into())
204    }
205
206    pub async fn with_container<T>(
207        &self,
208        mut action: impl AsyncFnMut(&Container) -> T,
209    ) -> Result<T, crate::container::Error> {
210        let (last_cache_hit, uncached_seeds) = self.populate_cache(&self.instance_name).await?;
211
212        let boot_definition = match &last_cache_hit {
213            Some(reference) => self
214                .clone()
215                .image(crate::image::Image::Explicit(reference.clone())),
216            None => self.clone(),
217        };
218
219        let mut db_container = Container::run_definition(&boot_definition).await;
220
221        if last_cache_hit.is_some() {
222            db_container
223                .set_superuser_password(
224                    db_container
225                        .client_config
226                        .session
227                        .password
228                        .as_ref()
229                        .unwrap(),
230                )
231                .await?;
232        }
233
234        db_container.wait_available().await?;
235
236        for seed in &uncached_seeds {
237            self.apply_loaded_seed(&db_container, seed).await?;
238        }
239
240        let result = action(&db_container).await;
241
242        db_container.stop().await;
243
244        Ok(result)
245    }
246
247    /// Populate cache images for seeds.
248    ///
249    /// Returns a tuple of:
250    /// - The last cache hit reference (if any), which can be used to boot from
251    /// - The loaded seeds that could not be cached because the cache chain was broken
252    pub async fn populate_cache(
253        &self,
254        instance_name: &crate::InstanceName,
255    ) -> Result<(Option<ociman::Reference>, Vec<LoadedSeed>), crate::container::Error> {
256        let loaded_seeds = self.load_seeds(instance_name).await?;
257
258        let mut previous_cache_reference: Option<&ociman::Reference> = None;
259        let mut seeds_iter = loaded_seeds.iter_seeds().peekable();
260
261        while let Some(seed) = seeds_iter.next() {
262            let Some(cache_reference) = seed.cache_status().reference() else {
263                // Uncacheable seed - cache chain is broken, return remaining seeds
264                let mut remaining = vec![seed.clone()];
265                remaining.extend(seeds_iter.cloned());
266                return Ok((previous_cache_reference.cloned(), remaining));
267            };
268
269            if seed.cache_status().is_hit() {
270                previous_cache_reference = Some(cache_reference);
271                continue;
272            }
273
274            let caching_image = previous_cache_reference
275                .map(|reference| crate::image::Image::Explicit(reference.clone()))
276                .unwrap_or_else(|| self.image.clone());
277
278            if let LoadedSeed::ContainerScript { script, .. } = seed {
279                log::info!("Applying container-script seed: {}", seed.name());
280
281                let base_image: ociman::image::Reference = (&caching_image).into();
282                let build_dir = create_container_script_build_dir(&base_image, script);
283
284                ociman::image::BuildDefinition::from_directory(
285                    &self.backend,
286                    cache_reference.clone(),
287                    &build_dir,
288                )
289                .build()
290                .await;
291
292                std::fs::remove_dir_all(&build_dir)
293                    .expect("failed to clean up container-script build directory");
294            } else {
295                let caching_definition = self.clone().remove(false).image(caching_image);
296
297                let mut container = Container::run_definition(&caching_definition).await;
298
299                if previous_cache_reference.is_some() {
300                    container
301                        .set_superuser_password(
302                            container.client_config.session.password.as_ref().unwrap(),
303                        )
304                        .await?;
305                }
306
307                container.wait_available().await?;
308
309                self.apply_loaded_seed(&container, seed).await?;
310                container.stop_commit_remove(cache_reference).await?;
311            }
312
313            log::info!("Committed cache image: {cache_reference}");
314
315            previous_cache_reference = Some(cache_reference);
316        }
317
318        Ok((previous_cache_reference.cloned(), Vec::new()))
319    }
320
321    pub async fn run_integration_server(
322        &self,
323        result_fd: std::os::fd::RawFd,
324        control_fd: std::os::fd::RawFd,
325    ) -> Result<(), crate::container::Error> {
326        self.with_container(async |container| {
327            // SAFETY: The parent process guarantees these are valid, exclusively-owned FDs
328            // inherited via the process spawn protocol.
329            let result_owned = unsafe { std::os::fd::OwnedFd::from_raw_fd(result_fd) };
330            let control_owned = unsafe { std::os::fd::OwnedFd::from_raw_fd(control_fd) };
331
332            let mut result_file = std::fs::File::from(result_owned);
333            let json = serde_json::to_string(&container.client_config).unwrap();
334
335            use std::io::Write;
336            writeln!(result_file, "{json}").expect("Failed to write config to result pipe");
337            drop(result_file);
338
339            log::info!("Integration server is running, waiting for EOF on control pipe");
340
341            let control_fd = tokio::io::unix::AsyncFd::new(control_owned)
342                .expect("Failed to register control pipe with tokio");
343
344            let _ = control_fd.readable().await.unwrap();
345
346            log::info!("Integration server received EOF on control pipe, exiting");
347        })
348        .await
349    }
350
351    async fn apply_loaded_seed(
352        &self,
353        db_container: &Container,
354        loaded_seed: &LoadedSeed,
355    ) -> Result<(), SeedApplyError> {
356        log::info!("Applying seed: {}", loaded_seed.name());
357        match loaded_seed {
358            LoadedSeed::SqlFile { content, .. } => db_container.apply_sql(content).await?,
359            LoadedSeed::SqlFileGitRevision { content, .. } => {
360                db_container.apply_sql(content).await?
361            }
362            LoadedSeed::Command { command, .. } => {
363                self.execute_command(db_container, command).await?
364            }
365            LoadedSeed::Script { script, .. } => self.execute_script(db_container, script).await?,
366            LoadedSeed::ContainerScript { script, .. } => {
367                db_container.exec_container_script(script).await?
368            }
369        }
370
371        Ok(())
372    }
373
374    async fn execute_command(
375        &self,
376        db_container: &Container,
377        command: &Command,
378    ) -> Result<(), cmd_proc::CommandError> {
379        cmd_proc::Command::new(&command.command)
380            .arguments(&command.arguments)
381            .envs(db_container.pg_env())
382            .env(&crate::ENV_DATABASE_URL, db_container.database_url())
383            .status()
384            .await
385    }
386
387    async fn execute_script(
388        &self,
389        db_container: &Container,
390        script: &str,
391    ) -> Result<(), cmd_proc::CommandError> {
392        cmd_proc::Command::new("sh")
393            .arguments(["-e", "-c"])
394            .argument(script)
395            .envs(db_container.pg_env())
396            .env(&crate::ENV_DATABASE_URL, db_container.database_url())
397            .status()
398            .await
399    }
400
401    pub async fn schema_dump(
402        &self,
403        client_config: &pg_client::Config,
404        pg_schema_dump: &pg_client::PgSchemaDump,
405    ) -> String {
406        let (effective_config, mounts) = apply_ociman_mounts(client_config);
407
408        let bytes = self
409            .to_ociman_definition()
410            .entrypoint("pg_dump")
411            .arguments(pg_schema_dump.arguments())
412            .environment_variables(effective_config.to_pg_env())
413            .mounts(mounts)
414            .run_capture_only_stdout()
415            .await;
416
417        crate::convert_schema(&bytes)
418    }
419}
420
421#[must_use]
422pub fn apply_ociman_mounts(
423    client_config: &pg_client::Config,
424) -> (pg_client::Config, Vec<ociman::Mount>) {
425    let owned_client_config = client_config.clone();
426
427    match client_config.ssl_root_cert {
428        Some(ref ssl_root_cert) => match ssl_root_cert {
429            pg_client::config::SslRootCert::File(file) => {
430                let host =
431                    std::fs::canonicalize(file).expect("could not canonicalize ssl root path");
432
433                let mut container_path = std::path::PathBuf::new();
434
435                container_path.push("/pg_ephemeral");
436                container_path.push(file.file_name().unwrap());
437
438                let mounts = vec![ociman::Mount::from(format!(
439                    "type=bind,ro,source={},target={}",
440                    host.to_str().unwrap(),
441                    container_path.to_str().unwrap()
442                ))];
443
444                (
445                    pg_client::Config {
446                        ssl_root_cert: Some(container_path.into()),
447                        ..owned_client_config
448                    },
449                    mounts,
450                )
451            }
452            pg_client::config::SslRootCert::System => (owned_client_config, vec![]),
453        },
454        None => (owned_client_config, vec![]),
455    }
456}
457
458fn create_container_script_build_dir(
459    base_image: &ociman::image::Reference,
460    script: &str,
461) -> std::path::PathBuf {
462    use rand::RngExt;
463
464    let suffix: String = rand::rng()
465        .sample_iter(rand::distr::Alphanumeric)
466        .take(16)
467        .map(char::from)
468        .collect();
469
470    let dir = std::env::temp_dir().join(format!("pg-ephemeral-build-{suffix}"));
471    std::fs::create_dir(&dir).expect("failed to create container-script build directory");
472
473    std::fs::write(dir.join("script.sh"), script).expect("failed to write container-script");
474
475    std::fs::write(
476        dir.join("Dockerfile"),
477        format!("FROM {base_image}\nCOPY script.sh /tmp/pg-ephemeral-script.sh\nRUN sh -e /tmp/pg-ephemeral-script.sh && rm /tmp/pg-ephemeral-script.sh\n"),
478    )
479    .expect("failed to write Dockerfile");
480
481    dir
482}
483
#[cfg(test)]
mod test {
    // Unit tests for the consuming seed-builder API. Each `add_seed`/`apply_*` method adds
    // exactly one seed and rejects a second seed under an existing name.
    use super::*;

    fn test_backend() -> ociman::Backend {
        ociman::Backend::Podman {
            version: semver::Version::new(4, 0, 0),
        }
    }

    fn test_instance_name() -> crate::InstanceName {
        "test".parse().unwrap()
    }

    /// A seedless definition using the test backend, default image and test instance name.
    fn empty_definition() -> Definition {
        Definition::new(test_backend(), crate::Image::default(), test_instance_name())
    }

    /// Shorthand for an SQL-file seed pointing at `path`.
    fn sql_file(path: &str) -> Seed {
        Seed::SqlFile { path: path.into() }
    }

    #[test]
    fn test_add_seed_rejects_duplicate() {
        let seed_name: SeedName = "test-seed".parse().unwrap();

        let definition = empty_definition()
            .add_seed(seed_name.clone(), sql_file("file1.sql"))
            .unwrap();

        assert_eq!(
            definition.add_seed(seed_name.clone(), sql_file("file2.sql")),
            Err(DuplicateSeedName(seed_name))
        );
    }

    #[test]
    fn test_add_seed_allows_different_names() {
        let definition = empty_definition()
            .add_seed("seed1".parse().unwrap(), sql_file("file1.sql"))
            .unwrap();

        let second = definition.add_seed("seed2".parse().unwrap(), sql_file("file2.sql"));

        assert!(second.is_ok());
    }

    #[test]
    fn test_apply_file_rejects_duplicate() {
        let seed_name: SeedName = "test-seed".parse().unwrap();

        let definition = empty_definition()
            .apply_file(seed_name.clone(), "file1.sql".into())
            .unwrap();

        assert_eq!(
            definition.apply_file(seed_name.clone(), "file2.sql".into()),
            Err(DuplicateSeedName(seed_name))
        );
    }

    #[test]
    fn test_apply_command_adds_seed() {
        let definition = empty_definition()
            .apply_command(
                "test-command".parse().unwrap(),
                Command::new("echo", vec!["test"]),
                CommandCacheConfig::CommandHash,
            )
            .expect("a command seed can be added to an empty definition");

        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_command_rejects_duplicate() {
        let seed_name: SeedName = "test-command".parse().unwrap();

        let definition = empty_definition()
            .apply_command(
                seed_name.clone(),
                Command::new("echo", vec!["test1"]),
                CommandCacheConfig::CommandHash,
            )
            .unwrap();

        let duplicate = definition.apply_command(
            seed_name.clone(),
            Command::new("echo", vec!["test2"]),
            CommandCacheConfig::CommandHash,
        );

        assert_eq!(duplicate, Err(DuplicateSeedName(seed_name)));
    }

    #[test]
    fn test_apply_script_adds_seed() {
        let definition = empty_definition()
            .apply_script("test-script".parse().unwrap(), "echo test")
            .expect("a script seed can be added to an empty definition");

        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_script_rejects_duplicate() {
        let seed_name: SeedName = "test-script".parse().unwrap();

        let definition = empty_definition()
            .apply_script(seed_name.clone(), "echo test1")
            .unwrap();

        assert_eq!(
            definition.apply_script(seed_name.clone(), "echo test2"),
            Err(DuplicateSeedName(seed_name))
        );
    }

    #[test]
    fn test_apply_container_script_adds_seed() {
        let definition = empty_definition()
            .apply_container_script(
                "install-ext".parse().unwrap(),
                "apt-get update && apt-get install -y postgresql-17-cron",
            )
            .expect("a container-script seed can be added to an empty definition");

        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_container_script_rejects_duplicate() {
        let seed_name: SeedName = "install-ext".parse().unwrap();

        let definition = empty_definition()
            .apply_container_script(seed_name.clone(), "apt-get update")
            .unwrap();

        assert_eq!(
            definition.apply_container_script(seed_name.clone(), "apt-get update"),
            Err(DuplicateSeedName(seed_name))
        );
    }
}