Skip to main content

pg_ephemeral/
definition.rs

1use std::os::fd::FromRawFd;
2
3use crate::Container;
4use crate::seed::{
5    Command, CommandCacheConfig, DuplicateSeedName, LoadError, LoadedSeed, LoadedSeeds, Seed,
6    SeedName,
7};
8
/// Errors that can occur while applying a single seed to a container.
///
/// Both variants carry `#[from]`, so the underlying error converts
/// automatically when propagated with `?`.
#[derive(Debug, thiserror::Error)]
pub enum SeedApplyError {
    /// A command run on behalf of a seed failed.
    #[error("Failed to apply command seed")]
    Command(#[from] cmd_proc::CommandError),
    /// A `sqlx` operation performed on behalf of a seed failed.
    #[error("Failed to apply SQL seed")]
    Sql(#[from] sqlx::Error),
}
16
/// SSL configuration that can be attached to a [`Definition`].
#[derive(Clone, Debug, PartialEq)]
pub enum SslConfig {
    /// Presumably SSL material is generated for `hostname`; the generation
    /// itself is not visible in this file — confirm against the consumer.
    Generated {
        /// Host name associated with this configuration.
        hostname: pg_client::config::HostName,
    },
    // NOTE(review): sketch of a variant that is not implemented.
    // UserProvided { ca_cert: PathBuf, server_cert: PathBuf, server_key: PathBuf },
}
24
/// Description of an ephemeral database instance: which image to run, on
/// which backend, and which seeds to apply to it.
///
/// Instances are created with [`Definition::new`] and customised through the
/// consuming builder-style methods on this type.
#[derive(Clone, Debug, PartialEq)]
pub struct Definition {
    /// Name of this instance; `with_container` passes it to `populate_cache`.
    pub instance_name: crate::InstanceName,
    /// Optional application name; `None` by default. Not read in this file.
    pub application_name: Option<pg_client::config::ApplicationName>,
    /// Backend used to build the `ociman` definition and cache images.
    pub backend: ociman::Backend,
    /// Database; defaults to `pg_client::Database::POSTGRES`.
    pub database: pg_client::Database,
    /// Seeds keyed by their unique name (duplicates are rejected by `add_seed`).
    pub seeds: indexmap::IndexMap<SeedName, Seed>,
    /// Optional SSL configuration; forwarded to seed loading.
    pub ssl_config: Option<SslConfig>,
    /// Superuser; defaults to `pg_client::User::POSTGRES`.
    pub superuser: pg_client::User,
    /// Base image; replaced by a cached image when a cache hit exists.
    pub image: crate::image::Image,
    /// Defaults to `false`. NOTE(review): its effect is implemented outside
    /// this file — confirm semantics there.
    pub cross_container_access: bool,
    /// Defaults to 10 seconds. Presumably bounds how long `wait_available`
    /// waits — confirm against `Container`.
    pub wait_available_timeout: std::time::Duration,
    /// Defaults to `true`; `populate_cache` sets it to `false` for the
    /// containers it commits as cache images. Presumably controls automatic
    /// container removal — confirm against `Container`.
    pub remove: bool,
}
39
40impl Definition {
41    #[must_use]
42    pub fn new(
43        backend: ociman::backend::Backend,
44        image: crate::image::Image,
45        instance_name: crate::InstanceName,
46    ) -> Self {
47        Self {
48            instance_name,
49            backend,
50            application_name: None,
51            seeds: indexmap::IndexMap::new(),
52            ssl_config: None,
53            superuser: pg_client::User::POSTGRES,
54            database: pg_client::Database::POSTGRES,
55            image,
56            cross_container_access: false,
57            wait_available_timeout: std::time::Duration::from_secs(10),
58            remove: true,
59        }
60    }
61
62    #[must_use]
63    pub fn remove(self, remove: bool) -> Self {
64        Self { remove, ..self }
65    }
66
67    #[must_use]
68    pub fn image(self, image: crate::image::Image) -> Self {
69        Self { image, ..self }
70    }
71
72    pub fn add_seed(self, name: SeedName, seed: Seed) -> Result<Self, DuplicateSeedName> {
73        let mut seeds = self.seeds.clone();
74
75        if seeds.contains_key(&name) {
76            return Err(DuplicateSeedName(name));
77        }
78
79        seeds.insert(name, seed);
80        Ok(Self { seeds, ..self })
81    }
82
83    pub fn apply_file(
84        self,
85        name: SeedName,
86        path: std::path::PathBuf,
87    ) -> Result<Self, DuplicateSeedName> {
88        self.add_seed(name, Seed::SqlFile { path })
89    }
90
91    pub async fn load_seeds(
92        &self,
93        instance_name: &crate::InstanceName,
94    ) -> Result<LoadedSeeds<'_>, LoadError> {
95        LoadedSeeds::load(
96            &self.image,
97            self.ssl_config.as_ref(),
98            &self.seeds,
99            &self.backend,
100            instance_name,
101        )
102        .await
103    }
104
105    pub async fn print_cache_status(
106        &self,
107        instance_name: &crate::InstanceName,
108        json: bool,
109    ) -> Result<(), crate::container::Error> {
110        let loaded_seeds = self.load_seeds(instance_name).await?;
111        if json {
112            loaded_seeds.print_json(instance_name);
113        } else {
114            loaded_seeds.print(instance_name);
115        }
116        Ok(())
117    }
118
119    #[must_use]
120    pub fn superuser(self, user: pg_client::User) -> Self {
121        Self {
122            superuser: user,
123            ..self
124        }
125    }
126
127    pub fn apply_file_from_git_revision(
128        self,
129        name: SeedName,
130        path: std::path::PathBuf,
131        git_revision: impl Into<String>,
132    ) -> Result<Self, DuplicateSeedName> {
133        self.add_seed(
134            name,
135            Seed::SqlFileGitRevision {
136                git_revision: git_revision.into(),
137                path,
138            },
139        )
140    }
141
142    pub fn apply_command(
143        self,
144        name: SeedName,
145        command: Command,
146        cache: CommandCacheConfig,
147    ) -> Result<Self, DuplicateSeedName> {
148        self.add_seed(name, Seed::Command { command, cache })
149    }
150
151    pub fn apply_script(
152        self,
153        name: SeedName,
154        script: impl Into<String>,
155    ) -> Result<Self, DuplicateSeedName> {
156        self.add_seed(
157            name,
158            Seed::Script {
159                script: script.into(),
160            },
161        )
162    }
163
164    pub fn apply_container_script(
165        self,
166        name: SeedName,
167        script: impl Into<String>,
168    ) -> Result<Self, DuplicateSeedName> {
169        self.add_seed(
170            name,
171            Seed::ContainerScript {
172                script: script.into(),
173            },
174        )
175    }
176
177    pub fn apply_csv_file(
178        self,
179        name: SeedName,
180        path: std::path::PathBuf,
181        table: pg_client::QualifiedTable,
182    ) -> Result<Self, DuplicateSeedName> {
183        self.add_seed(name, Seed::CsvFile { path, table })
184    }
185
186    #[must_use]
187    pub fn ssl_config(self, ssl_config: SslConfig) -> Self {
188        Self {
189            ssl_config: Some(ssl_config),
190            ..self
191        }
192    }
193
194    #[must_use]
195    pub fn cross_container_access(self, enabled: bool) -> Self {
196        Self {
197            cross_container_access: enabled,
198            ..self
199        }
200    }
201
202    #[must_use]
203    pub fn wait_available_timeout(self, timeout: std::time::Duration) -> Self {
204        Self {
205            wait_available_timeout: timeout,
206            ..self
207        }
208    }
209
210    #[must_use]
211    pub fn to_ociman_definition(&self) -> ociman::Definition {
212        ociman::Definition::new(self.backend.clone(), (&self.image).into())
213    }
214
215    pub async fn with_container<T>(
216        &self,
217        mut action: impl AsyncFnMut(&Container) -> T,
218    ) -> Result<T, crate::container::Error> {
219        let (last_cache_hit, uncached_seeds) = self.populate_cache(&self.instance_name).await?;
220
221        let boot_definition = match &last_cache_hit {
222            Some(reference) => self
223                .clone()
224                .image(crate::image::Image::Explicit(reference.clone())),
225            None => self.clone(),
226        };
227
228        let mut db_container = Container::run_definition(&boot_definition).await;
229
230        if last_cache_hit.is_some() {
231            db_container
232                .set_superuser_password(
233                    db_container
234                        .client_config
235                        .session
236                        .password
237                        .as_ref()
238                        .unwrap(),
239                )
240                .await?;
241        }
242
243        db_container.wait_available().await?;
244
245        for seed in &uncached_seeds {
246            self.apply_loaded_seed(&db_container, seed).await?;
247        }
248
249        let result = action(&db_container).await;
250
251        db_container.stop().await;
252
253        Ok(result)
254    }
255
256    /// Populate cache images for seeds.
257    ///
258    /// Returns a tuple of:
259    /// - The last cache hit reference (if any), which can be used to boot from
260    /// - The loaded seeds that could not be cached because the cache chain was broken
261    pub async fn populate_cache(
262        &self,
263        instance_name: &crate::InstanceName,
264    ) -> Result<(Option<ociman::Reference>, Vec<LoadedSeed>), crate::container::Error> {
265        let loaded_seeds = self.load_seeds(instance_name).await?;
266
267        let mut previous_cache_reference: Option<&ociman::Reference> = None;
268        let mut seeds_iter = loaded_seeds.iter_seeds().peekable();
269
270        while let Some(seed) = seeds_iter.next() {
271            let Some(cache_reference) = seed.cache_status().reference() else {
272                // Uncacheable seed - cache chain is broken, return remaining seeds
273                let mut remaining = vec![seed.clone()];
274                remaining.extend(seeds_iter.cloned());
275                return Ok((previous_cache_reference.cloned(), remaining));
276            };
277
278            if seed.cache_status().is_hit() {
279                previous_cache_reference = Some(cache_reference);
280                continue;
281            }
282
283            let caching_image = previous_cache_reference
284                .map(|reference| crate::image::Image::Explicit(reference.clone()))
285                .unwrap_or_else(|| self.image.clone());
286
287            if let LoadedSeed::ContainerScript { script, .. } = seed {
288                log::info!("Applying container-script seed: {}", seed.name());
289
290                let base_image: ociman::image::Reference = (&caching_image).into();
291                let build_dir = create_container_script_build_dir(&base_image, script);
292
293                ociman::image::BuildDefinition::from_directory(
294                    &self.backend,
295                    cache_reference.clone(),
296                    &build_dir,
297                )
298                .build()
299                .await;
300
301                std::fs::remove_dir_all(&build_dir)
302                    .expect("failed to clean up container-script build directory");
303            } else {
304                let caching_definition = self.clone().remove(false).image(caching_image);
305
306                let mut container = Container::run_definition(&caching_definition).await;
307
308                if previous_cache_reference.is_some() {
309                    container
310                        .set_superuser_password(
311                            container.client_config.session.password.as_ref().unwrap(),
312                        )
313                        .await?;
314                }
315
316                container.wait_available().await?;
317
318                self.apply_loaded_seed(&container, seed).await?;
319                container.stop_commit_remove(cache_reference).await;
320            }
321
322            log::info!("Committed cache image: {cache_reference}");
323
324            previous_cache_reference = Some(cache_reference);
325        }
326
327        Ok((previous_cache_reference.cloned(), Vec::new()))
328    }
329
330    pub async fn run_integration_server(
331        &self,
332        result_fd: std::os::fd::RawFd,
333        control_fd: std::os::fd::RawFd,
334    ) -> Result<(), crate::container::Error> {
335        self.with_container(async |container| {
336            // SAFETY: The parent process guarantees these are valid, exclusively-owned FDs
337            // inherited via the process spawn protocol.
338            let result_owned = unsafe { std::os::fd::OwnedFd::from_raw_fd(result_fd) };
339            let control_owned = unsafe { std::os::fd::OwnedFd::from_raw_fd(control_fd) };
340
341            let mut result_file = std::fs::File::from(result_owned);
342            let json = serde_json::to_string(&container.client_config).unwrap();
343
344            use std::io::Write;
345            writeln!(result_file, "{json}").expect("Failed to write config to result pipe");
346            drop(result_file);
347
348            log::info!("Integration server is running, waiting for EOF on control pipe");
349
350            let control_fd = tokio::io::unix::AsyncFd::new(control_owned)
351                .expect("Failed to register control pipe with tokio");
352
353            let _ = control_fd.readable().await.unwrap();
354
355            log::info!("Integration server received EOF on control pipe, exiting");
356        })
357        .await
358    }
359
360    async fn apply_loaded_seed(
361        &self,
362        db_container: &Container,
363        loaded_seed: &LoadedSeed,
364    ) -> Result<(), SeedApplyError> {
365        log::info!("Applying seed: {}", loaded_seed.name());
366        match loaded_seed {
367            LoadedSeed::SqlFile { content, .. } => db_container.apply_sql(content).await?,
368            LoadedSeed::SqlFileGitRevision { content, .. } => {
369                db_container.apply_sql(content).await?
370            }
371            LoadedSeed::Command { command, .. } => {
372                self.execute_command(db_container, command).await?
373            }
374            LoadedSeed::Script { script, .. } => self.execute_script(db_container, script).await?,
375            LoadedSeed::ContainerScript { script, .. } => {
376                db_container.exec_container_script(script).await?
377            }
378            LoadedSeed::CsvFile { table, content, .. } => {
379                db_container.apply_csv(table, content).await?
380            }
381        }
382
383        Ok(())
384    }
385
386    async fn execute_command(
387        &self,
388        db_container: &Container,
389        command: &Command,
390    ) -> Result<(), cmd_proc::CommandError> {
391        cmd_proc::Command::new(&command.command)
392            .arguments(&command.arguments)
393            .envs(db_container.pg_env())
394            .env(&crate::ENV_DATABASE_URL, db_container.database_url())
395            .status()
396            .await
397    }
398
399    async fn execute_script(
400        &self,
401        db_container: &Container,
402        script: &str,
403    ) -> Result<(), cmd_proc::CommandError> {
404        cmd_proc::Command::new("sh")
405            .arguments(["-e", "-c"])
406            .argument(script)
407            .envs(db_container.pg_env())
408            .env(&crate::ENV_DATABASE_URL, db_container.database_url())
409            .status()
410            .await
411    }
412
413    pub async fn schema_dump(
414        &self,
415        client_config: &pg_client::Config,
416        pg_schema_dump: &pg_client::PgSchemaDump,
417    ) -> String {
418        let (effective_config, mounts) = apply_ociman_mounts(client_config);
419
420        let bytes = self
421            .to_ociman_definition()
422            .entrypoint("pg_dump")
423            .arguments(pg_schema_dump.arguments())
424            .environment_variables(effective_config.to_pg_env())
425            .mounts(mounts)
426            .run_capture_only_stdout()
427            .await;
428
429        crate::convert_schema(&bytes)
430    }
431}
432
433#[must_use]
434pub fn apply_ociman_mounts(
435    client_config: &pg_client::Config,
436) -> (pg_client::Config, Vec<ociman::Mount>) {
437    let owned_client_config = client_config.clone();
438
439    match client_config.ssl_root_cert {
440        Some(ref ssl_root_cert) => match ssl_root_cert {
441            pg_client::config::SslRootCert::File(file) => {
442                let host =
443                    std::fs::canonicalize(file).expect("could not canonicalize ssl root path");
444
445                let mut container_path = std::path::PathBuf::new();
446
447                container_path.push("/pg_ephemeral");
448                container_path.push(file.file_name().unwrap());
449
450                let mounts = vec![ociman::Mount::from(format!(
451                    "type=bind,ro,source={},target={}",
452                    host.to_str().unwrap(),
453                    container_path.to_str().unwrap()
454                ))];
455
456                (
457                    pg_client::Config {
458                        ssl_root_cert: Some(container_path.into()),
459                        ..owned_client_config
460                    },
461                    mounts,
462                )
463            }
464            pg_client::config::SslRootCert::System => (owned_client_config, vec![]),
465        },
466        None => (owned_client_config, vec![]),
467    }
468}
469
470fn create_container_script_build_dir(
471    base_image: &ociman::image::Reference,
472    script: &str,
473) -> std::path::PathBuf {
474    use rand::RngExt;
475
476    let suffix: String = rand::rng()
477        .sample_iter(rand::distr::Alphanumeric)
478        .take(16)
479        .map(char::from)
480        .collect();
481
482    let dir = std::env::temp_dir().join(format!("pg-ephemeral-build-{suffix}"));
483    std::fs::create_dir(&dir).expect("failed to create container-script build directory");
484
485    std::fs::write(dir.join("script.sh"), script).expect("failed to write container-script");
486
487    std::fs::write(
488        dir.join("Dockerfile"),
489        format!("FROM {base_image}\nCOPY script.sh /tmp/pg-ephemeral-script.sh\nRUN sh -e /tmp/pg-ephemeral-script.sh && rm /tmp/pg-ephemeral-script.sh\n"),
490    )
491    .expect("failed to write Dockerfile");
492
493    dir
494}
495
#[cfg(test)]
mod test {
    use super::*;

    /// Backend value used by every test; nothing is ever run against it.
    fn test_backend() -> ociman::Backend {
        ociman::Backend::Podman {
            version: semver::Version::new(4, 0, 0),
        }
    }

    fn test_instance_name() -> crate::InstanceName {
        "test".parse().unwrap()
    }

    /// Definition without any seeds, shared as the starting point of each test.
    fn blank_definition() -> Definition {
        Definition::new(
            test_backend(),
            crate::Image::default(),
            test_instance_name(),
        )
    }

    fn test_qualified_table() -> pg_client::QualifiedTable {
        pg_client::QualifiedTable {
            schema: pg_client::identifier::Schema::PUBLIC,
            table: "users".parse().unwrap(),
        }
    }

    #[test]
    fn test_add_seed_rejects_duplicate() {
        let seed_name: SeedName = "test-seed".parse().unwrap();

        let first = Seed::SqlFile {
            path: "file1.sql".into(),
        };
        let second = Seed::SqlFile {
            path: "file2.sql".into(),
        };

        let definition = blank_definition()
            .add_seed(seed_name.clone(), first)
            .unwrap();

        let result = definition.add_seed(seed_name.clone(), second);

        assert_eq!(result, Err(DuplicateSeedName(seed_name)));
    }

    #[test]
    fn test_add_seed_allows_different_names() {
        let first = Seed::SqlFile {
            path: "file1.sql".into(),
        };
        let second = Seed::SqlFile {
            path: "file2.sql".into(),
        };

        let definition = blank_definition()
            .add_seed("seed1".parse().unwrap(), first)
            .unwrap();

        let result = definition.add_seed("seed2".parse().unwrap(), second);

        assert!(result.is_ok());
    }

    #[test]
    fn test_apply_file_rejects_duplicate() {
        let seed_name: SeedName = "test-seed".parse().unwrap();

        let definition = blank_definition()
            .apply_file(seed_name.clone(), "file1.sql".into())
            .unwrap();

        let result = definition.apply_file(seed_name.clone(), "file2.sql".into());

        assert_eq!(result, Err(DuplicateSeedName(seed_name)));
    }

    #[test]
    fn test_apply_command_adds_seed() {
        let result = blank_definition().apply_command(
            "test-command".parse().unwrap(),
            Command::new("echo", vec!["test"]),
            CommandCacheConfig::CommandHash,
        );

        assert!(result.is_ok());
        let definition = result.unwrap();
        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_command_rejects_duplicate() {
        let seed_name: SeedName = "test-command".parse().unwrap();

        let definition = blank_definition()
            .apply_command(
                seed_name.clone(),
                Command::new("echo", vec!["test1"]),
                CommandCacheConfig::CommandHash,
            )
            .unwrap();

        let result = definition.apply_command(
            seed_name.clone(),
            Command::new("echo", vec!["test2"]),
            CommandCacheConfig::CommandHash,
        );

        assert_eq!(result, Err(DuplicateSeedName(seed_name)));
    }

    #[test]
    fn test_apply_script_adds_seed() {
        let result = blank_definition().apply_script("test-script".parse().unwrap(), "echo test");

        assert!(result.is_ok());
        let definition = result.unwrap();
        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_script_rejects_duplicate() {
        let seed_name: SeedName = "test-script".parse().unwrap();

        let definition = blank_definition()
            .apply_script(seed_name.clone(), "echo test1")
            .unwrap();

        let result = definition.apply_script(seed_name.clone(), "echo test2");

        assert_eq!(result, Err(DuplicateSeedName(seed_name)));
    }

    #[test]
    fn test_apply_container_script_adds_seed() {
        let result = blank_definition().apply_container_script(
            "install-ext".parse().unwrap(),
            "apt-get update && apt-get install -y postgresql-17-cron",
        );

        assert!(result.is_ok());
        let definition = result.unwrap();
        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_container_script_rejects_duplicate() {
        let seed_name: SeedName = "install-ext".parse().unwrap();

        let definition = blank_definition()
            .apply_container_script(seed_name.clone(), "apt-get update")
            .unwrap();

        let result = definition.apply_container_script(seed_name.clone(), "apt-get update");

        assert_eq!(result, Err(DuplicateSeedName(seed_name)));
    }

    #[test]
    fn test_apply_csv_file_adds_seed() {
        let result = blank_definition().apply_csv_file(
            "import-users".parse().unwrap(),
            "fixtures/users.csv".into(),
            test_qualified_table(),
        );

        assert!(result.is_ok());
        let definition = result.unwrap();
        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_csv_file_rejects_duplicate() {
        let seed_name: SeedName = "import-users".parse().unwrap();

        let definition = blank_definition()
            .apply_csv_file(
                seed_name.clone(),
                "fixtures/users.csv".into(),
                test_qualified_table(),
            )
            .unwrap();

        let result = definition.apply_csv_file(
            seed_name.clone(),
            "fixtures/other.csv".into(),
            test_qualified_table(),
        );

        assert_eq!(result, Err(DuplicateSeedName(seed_name)));
    }
}