Skip to main content

pg_ephemeral/
definition.rs

1use std::os::fd::FromRawFd;
2
3use crate::Container;
4use crate::seed::{
5    Command, CommandCacheConfig, DuplicateSeedName, LoadError, LoadedSeed, LoadedSeeds, Seed,
6    SeedName,
7};
8
9#[derive(Debug, thiserror::Error)]
10pub enum SeedApplyError {
11    #[error("Failed to apply command seed")]
12    Command(#[from] cmd_proc::CommandError),
13    #[error("Failed to apply SQL seed")]
14    Sql(#[from] sqlx::Error),
15}
16
17#[derive(Clone, Debug, PartialEq)]
18pub enum SslConfig {
19    Generated {
20        hostname: pg_client::config::HostName,
21    },
22    // UserProvided { ca_cert: PathBuf, server_cert: PathBuf, server_key: PathBuf },
23}
24
25#[derive(Clone, Debug, PartialEq)]
26pub struct Definition {
27    pub instance_name: crate::InstanceName,
28    pub application_name: Option<pg_client::config::ApplicationName>,
29    pub backend: ociman::Backend,
30    pub database: pg_client::Database,
31    pub seeds: indexmap::IndexMap<SeedName, Seed>,
32    pub ssl_config: Option<SslConfig>,
33    pub superuser: pg_client::User,
34    pub image: crate::image::Image,
35    pub cross_container_access: bool,
36    pub wait_available_timeout: std::time::Duration,
37    pub remove: bool,
38}
39
40impl Definition {
41    #[must_use]
42    pub fn new(
43        backend: ociman::backend::Backend,
44        image: crate::image::Image,
45        instance_name: crate::InstanceName,
46    ) -> Self {
47        Self {
48            instance_name,
49            backend,
50            application_name: None,
51            seeds: indexmap::IndexMap::new(),
52            ssl_config: None,
53            superuser: pg_client::User::POSTGRES,
54            database: pg_client::Database::POSTGRES,
55            image,
56            cross_container_access: false,
57            wait_available_timeout: std::time::Duration::from_secs(10),
58            remove: true,
59        }
60    }
61
62    #[must_use]
63    pub fn remove(self, remove: bool) -> Self {
64        Self { remove, ..self }
65    }
66
67    #[must_use]
68    pub fn image(self, image: crate::image::Image) -> Self {
69        Self { image, ..self }
70    }
71
72    pub fn add_seed(self, name: SeedName, seed: Seed) -> Result<Self, DuplicateSeedName> {
73        let mut seeds = self.seeds.clone();
74
75        if seeds.contains_key(&name) {
76            return Err(DuplicateSeedName(name));
77        }
78
79        seeds.insert(name, seed);
80        Ok(Self { seeds, ..self })
81    }
82
83    pub fn apply_file(
84        self,
85        name: SeedName,
86        path: std::path::PathBuf,
87    ) -> Result<Self, DuplicateSeedName> {
88        self.add_seed(name, Seed::SqlFile { path })
89    }
90
91    pub async fn load_seeds(
92        &self,
93        instance_name: &crate::InstanceName,
94    ) -> Result<LoadedSeeds<'_>, LoadError> {
95        LoadedSeeds::load(
96            &self.image,
97            self.ssl_config.as_ref(),
98            &self.seeds,
99            &self.backend,
100            instance_name,
101        )
102        .await
103    }
104
105    pub async fn print_cache_status(
106        &self,
107        instance_name: &crate::InstanceName,
108        json: bool,
109    ) -> Result<(), crate::container::Error> {
110        let loaded_seeds = self.load_seeds(instance_name).await?;
111        if json {
112            loaded_seeds.print_json(instance_name);
113        } else {
114            loaded_seeds.print(instance_name);
115        }
116        Ok(())
117    }
118
119    #[must_use]
120    pub fn superuser(self, user: pg_client::User) -> Self {
121        Self {
122            superuser: user,
123            ..self
124        }
125    }
126
127    pub fn apply_file_from_git_revision(
128        self,
129        name: SeedName,
130        path: std::path::PathBuf,
131        git_revision: impl Into<String>,
132    ) -> Result<Self, DuplicateSeedName> {
133        self.add_seed(
134            name,
135            Seed::SqlFileGitRevision {
136                git_revision: git_revision.into(),
137                path,
138            },
139        )
140    }
141
142    pub fn apply_command(
143        self,
144        name: SeedName,
145        command: Command,
146        cache: CommandCacheConfig,
147    ) -> Result<Self, DuplicateSeedName> {
148        self.add_seed(name, Seed::Command { command, cache })
149    }
150
151    pub fn apply_script(
152        self,
153        name: SeedName,
154        script: impl Into<String>,
155    ) -> Result<Self, DuplicateSeedName> {
156        self.add_seed(
157            name,
158            Seed::Script {
159                script: script.into(),
160            },
161        )
162    }
163
164    pub fn apply_container_script(
165        self,
166        name: SeedName,
167        script: impl Into<String>,
168    ) -> Result<Self, DuplicateSeedName> {
169        self.add_seed(
170            name,
171            Seed::ContainerScript {
172                script: script.into(),
173            },
174        )
175    }
176
177    pub fn apply_csv_file(
178        self,
179        name: SeedName,
180        path: std::path::PathBuf,
181        table: pg_client::QualifiedTable,
182    ) -> Result<Self, DuplicateSeedName> {
183        self.add_seed(
184            name,
185            Seed::CsvFile {
186                path,
187                table,
188                delimiter: ',',
189            },
190        )
191    }
192
193    #[must_use]
194    pub fn ssl_config(self, ssl_config: SslConfig) -> Self {
195        Self {
196            ssl_config: Some(ssl_config),
197            ..self
198        }
199    }
200
201    #[must_use]
202    pub fn cross_container_access(self, enabled: bool) -> Self {
203        Self {
204            cross_container_access: enabled,
205            ..self
206        }
207    }
208
209    #[must_use]
210    pub fn wait_available_timeout(self, timeout: std::time::Duration) -> Self {
211        Self {
212            wait_available_timeout: timeout,
213            ..self
214        }
215    }
216
217    #[must_use]
218    pub fn to_ociman_definition(&self) -> ociman::Definition {
219        ociman::Definition::new(self.backend.clone(), (&self.image).into())
220    }
221
222    pub async fn with_container<T>(
223        &self,
224        mut action: impl AsyncFnMut(&Container) -> T,
225    ) -> Result<T, crate::container::Error> {
226        let (last_cache_hit, uncached_seeds) = self.populate_cache(&self.instance_name).await?;
227
228        let boot_definition = match &last_cache_hit {
229            Some(reference) => self
230                .clone()
231                .image(crate::image::Image::Explicit(reference.clone())),
232            None => self.clone(),
233        };
234
235        let mut db_container = Container::run_definition(&boot_definition).await;
236
237        if last_cache_hit.is_some() {
238            db_container
239                .set_superuser_password(
240                    db_container
241                        .client_config
242                        .session
243                        .password
244                        .as_ref()
245                        .unwrap(),
246                )
247                .await?;
248        }
249
250        db_container.wait_available().await?;
251
252        for seed in &uncached_seeds {
253            self.apply_loaded_seed(&db_container, seed).await?;
254        }
255
256        let result = action(&db_container).await;
257
258        db_container.stop().await;
259
260        Ok(result)
261    }
262
263    /// Populate cache images for seeds.
264    ///
265    /// Returns a tuple of:
266    /// - The last cache hit reference (if any), which can be used to boot from
267    /// - The loaded seeds that could not be cached because the cache chain was broken
268    pub async fn populate_cache(
269        &self,
270        instance_name: &crate::InstanceName,
271    ) -> Result<(Option<ociman::Reference>, Vec<LoadedSeed>), crate::container::Error> {
272        let loaded_seeds = self.load_seeds(instance_name).await?;
273
274        let mut previous_cache_reference: Option<&ociman::Reference> = None;
275        let mut seeds_iter = loaded_seeds.iter_seeds().peekable();
276
277        while let Some(seed) = seeds_iter.next() {
278            let Some(cache_reference) = seed.cache_status().reference() else {
279                // Uncacheable seed - cache chain is broken, return remaining seeds
280                let mut remaining = vec![seed.clone()];
281                remaining.extend(seeds_iter.cloned());
282                return Ok((previous_cache_reference.cloned(), remaining));
283            };
284
285            if seed.cache_status().is_hit() {
286                previous_cache_reference = Some(cache_reference);
287                continue;
288            }
289
290            let caching_image = previous_cache_reference
291                .map(|reference| crate::image::Image::Explicit(reference.clone()))
292                .unwrap_or_else(|| self.image.clone());
293
294            if let LoadedSeed::ContainerScript { script, .. } = seed {
295                log::info!("Applying container-script seed: {}", seed.name());
296
297                let base_image: ociman::image::Reference = (&caching_image).into();
298                let build_dir = create_container_script_build_dir(&base_image, script);
299
300                ociman::image::BuildDefinition::from_directory(
301                    &self.backend,
302                    cache_reference.clone(),
303                    &build_dir,
304                )
305                .build()
306                .await;
307
308                std::fs::remove_dir_all(&build_dir)
309                    .expect("failed to clean up container-script build directory");
310            } else {
311                let caching_definition = self.clone().remove(false).image(caching_image);
312
313                let mut container = Container::run_definition(&caching_definition).await;
314
315                if previous_cache_reference.is_some() {
316                    container
317                        .set_superuser_password(
318                            container.client_config.session.password.as_ref().unwrap(),
319                        )
320                        .await?;
321                }
322
323                container.wait_available().await?;
324
325                self.apply_loaded_seed(&container, seed).await?;
326                container.stop_commit_remove(cache_reference).await;
327            }
328
329            log::info!("Committed cache image: {cache_reference}");
330
331            previous_cache_reference = Some(cache_reference);
332        }
333
334        Ok((previous_cache_reference.cloned(), Vec::new()))
335    }
336
337    pub async fn run_integration_server(
338        &self,
339        result_fd: std::os::fd::RawFd,
340        control_fd: std::os::fd::RawFd,
341    ) -> Result<(), crate::container::Error> {
342        self.with_container(async |container| {
343            // SAFETY: The parent process guarantees these are valid, exclusively-owned FDs
344            // inherited via the process spawn protocol.
345            let result_owned = unsafe { std::os::fd::OwnedFd::from_raw_fd(result_fd) };
346            let control_owned = unsafe { std::os::fd::OwnedFd::from_raw_fd(control_fd) };
347
348            let mut result_file = std::fs::File::from(result_owned);
349            let json = serde_json::to_string(&container.client_config).unwrap();
350
351            use std::io::Write;
352            writeln!(result_file, "{json}").expect("Failed to write config to result pipe");
353            drop(result_file);
354
355            log::info!("Integration server is running, waiting for EOF on control pipe");
356
357            let control_fd = tokio::io::unix::AsyncFd::new(control_owned)
358                .expect("Failed to register control pipe with tokio");
359
360            let _ = control_fd.readable().await.unwrap();
361
362            log::info!("Integration server received EOF on control pipe, exiting");
363        })
364        .await
365    }
366
367    async fn apply_loaded_seed(
368        &self,
369        db_container: &Container,
370        loaded_seed: &LoadedSeed,
371    ) -> Result<(), SeedApplyError> {
372        log::info!("Applying seed: {}", loaded_seed.name());
373        match loaded_seed {
374            LoadedSeed::SqlFile { content, .. } => db_container.apply_sql(content).await?,
375            LoadedSeed::SqlFileGitRevision { content, .. } => {
376                db_container.apply_sql(content).await?
377            }
378            LoadedSeed::Command { command, .. } => {
379                self.execute_command(db_container, command).await?
380            }
381            LoadedSeed::Script { script, .. } => self.execute_script(db_container, script).await?,
382            LoadedSeed::ContainerScript { script, .. } => {
383                db_container.exec_container_script(script).await?
384            }
385            LoadedSeed::CsvFile {
386                table,
387                delimiter,
388                content,
389                ..
390            } => db_container.apply_csv(table, *delimiter, content).await?,
391        }
392
393        Ok(())
394    }
395
396    async fn execute_command(
397        &self,
398        db_container: &Container,
399        command: &Command,
400    ) -> Result<(), cmd_proc::CommandError> {
401        cmd_proc::Command::new(&command.command)
402            .arguments(&command.arguments)
403            .envs(db_container.pg_env())
404            .env(&crate::ENV_DATABASE_URL, db_container.database_url())
405            .status()
406            .await
407    }
408
409    async fn execute_script(
410        &self,
411        db_container: &Container,
412        script: &str,
413    ) -> Result<(), cmd_proc::CommandError> {
414        cmd_proc::Command::new("sh")
415            .arguments(["-e", "-c"])
416            .argument(script)
417            .envs(db_container.pg_env())
418            .env(&crate::ENV_DATABASE_URL, db_container.database_url())
419            .status()
420            .await
421    }
422
423    pub async fn schema_dump(
424        &self,
425        client_config: &pg_client::Config,
426        pg_schema_dump: &pg_client::PgSchemaDump,
427    ) -> String {
428        let (effective_config, mounts) = apply_ociman_mounts(client_config);
429
430        let bytes = self
431            .to_ociman_definition()
432            .entrypoint("pg_dump")
433            .arguments(pg_schema_dump.arguments())
434            .environment_variables(effective_config.to_pg_env())
435            .mounts(mounts)
436            .run_capture_only_stdout()
437            .await;
438
439        crate::convert_schema(&bytes)
440    }
441}
442
443#[must_use]
444pub fn apply_ociman_mounts(
445    client_config: &pg_client::Config,
446) -> (pg_client::Config, Vec<ociman::Mount>) {
447    let owned_client_config = client_config.clone();
448
449    match client_config.ssl_root_cert {
450        Some(ref ssl_root_cert) => match ssl_root_cert {
451            pg_client::config::SslRootCert::File(file) => {
452                let host =
453                    std::fs::canonicalize(file).expect("could not canonicalize ssl root path");
454
455                let mut container_path = std::path::PathBuf::new();
456
457                container_path.push("/pg_ephemeral");
458                container_path.push(file.file_name().unwrap());
459
460                let mounts = vec![ociman::Mount::from(format!(
461                    "type=bind,ro,source={},target={}",
462                    host.to_str().unwrap(),
463                    container_path.to_str().unwrap()
464                ))];
465
466                (
467                    pg_client::Config {
468                        ssl_root_cert: Some(container_path.into()),
469                        ..owned_client_config
470                    },
471                    mounts,
472                )
473            }
474            pg_client::config::SslRootCert::System => (owned_client_config, vec![]),
475        },
476        None => (owned_client_config, vec![]),
477    }
478}
479
480fn create_container_script_build_dir(
481    base_image: &ociman::image::Reference,
482    script: &str,
483) -> std::path::PathBuf {
484    use rand::RngExt;
485
486    let suffix: String = rand::rng()
487        .sample_iter(rand::distr::Alphanumeric)
488        .take(16)
489        .map(char::from)
490        .collect();
491
492    let dir = std::env::temp_dir().join(format!("pg-ephemeral-build-{suffix}"));
493    std::fs::create_dir(&dir).expect("failed to create container-script build directory");
494
495    std::fs::write(dir.join("script.sh"), script).expect("failed to write container-script");
496
497    std::fs::write(
498        dir.join("Dockerfile"),
499        format!("FROM {base_image}\nCOPY script.sh /tmp/pg-ephemeral-script.sh\nRUN sh -e /tmp/pg-ephemeral-script.sh && rm /tmp/pg-ephemeral-script.sh\n"),
500    )
501    .expect("failed to write Dockerfile");
502
503    dir
504}
505
#[cfg(test)]
mod test {
    use super::*;

    fn test_backend() -> ociman::Backend {
        ociman::Backend::Podman {
            version: semver::Version::new(4, 0, 0),
        }
    }

    fn test_instance_name() -> crate::InstanceName {
        "test".parse().unwrap()
    }

    /// Seedless definition shared by every test below.
    fn empty_definition() -> Definition {
        Definition::new(test_backend(), crate::Image::default(), test_instance_name())
    }

    fn test_qualified_table() -> pg_client::QualifiedTable {
        pg_client::QualifiedTable {
            schema: pg_client::identifier::Schema::PUBLIC,
            table: "users".parse().unwrap(),
        }
    }

    #[test]
    fn test_add_seed_rejects_duplicate() {
        let seed_name: SeedName = "test-seed".parse().unwrap();
        let first = Seed::SqlFile {
            path: "file1.sql".into(),
        };
        let second = Seed::SqlFile {
            path: "file2.sql".into(),
        };

        let definition = empty_definition()
            .add_seed(seed_name.clone(), first)
            .unwrap();

        assert_eq!(
            definition.add_seed(seed_name.clone(), second),
            Err(DuplicateSeedName(seed_name))
        );
    }

    #[test]
    fn test_add_seed_allows_different_names() {
        let first = Seed::SqlFile {
            path: "file1.sql".into(),
        };
        let second = Seed::SqlFile {
            path: "file2.sql".into(),
        };

        let definition = empty_definition()
            .add_seed("seed1".parse().unwrap(), first)
            .unwrap();

        assert!(definition.add_seed("seed2".parse().unwrap(), second).is_ok());
    }

    #[test]
    fn test_apply_file_rejects_duplicate() {
        let seed_name: SeedName = "test-seed".parse().unwrap();

        let definition = empty_definition()
            .apply_file(seed_name.clone(), "file1.sql".into())
            .unwrap();

        assert_eq!(
            definition.apply_file(seed_name.clone(), "file2.sql".into()),
            Err(DuplicateSeedName(seed_name))
        );
    }

    #[test]
    fn test_apply_command_adds_seed() {
        let definition = empty_definition()
            .apply_command(
                "test-command".parse().unwrap(),
                Command::new("echo", vec!["test"]),
                CommandCacheConfig::CommandHash,
            )
            .expect("adding a command seed to an empty definition must succeed");

        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_command_rejects_duplicate() {
        let seed_name: SeedName = "test-command".parse().unwrap();

        let definition = empty_definition()
            .apply_command(
                seed_name.clone(),
                Command::new("echo", vec!["test1"]),
                CommandCacheConfig::CommandHash,
            )
            .unwrap();

        assert_eq!(
            definition.apply_command(
                seed_name.clone(),
                Command::new("echo", vec!["test2"]),
                CommandCacheConfig::CommandHash,
            ),
            Err(DuplicateSeedName(seed_name))
        );
    }

    #[test]
    fn test_apply_script_adds_seed() {
        let definition = empty_definition()
            .apply_script("test-script".parse().unwrap(), "echo test")
            .expect("adding a script seed to an empty definition must succeed");

        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_script_rejects_duplicate() {
        let seed_name: SeedName = "test-script".parse().unwrap();

        let definition = empty_definition()
            .apply_script(seed_name.clone(), "echo test1")
            .unwrap();

        assert_eq!(
            definition.apply_script(seed_name.clone(), "echo test2"),
            Err(DuplicateSeedName(seed_name))
        );
    }

    #[test]
    fn test_apply_container_script_adds_seed() {
        let definition = empty_definition()
            .apply_container_script(
                "install-ext".parse().unwrap(),
                "apt-get update && apt-get install -y postgresql-17-cron",
            )
            .expect("adding a container-script seed to an empty definition must succeed");

        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_container_script_rejects_duplicate() {
        let seed_name: SeedName = "install-ext".parse().unwrap();

        let definition = empty_definition()
            .apply_container_script(seed_name.clone(), "apt-get update")
            .unwrap();

        assert_eq!(
            definition.apply_container_script(seed_name.clone(), "apt-get update"),
            Err(DuplicateSeedName(seed_name))
        );
    }

    #[test]
    fn test_apply_csv_file_adds_seed() {
        let definition = empty_definition()
            .apply_csv_file(
                "import-users".parse().unwrap(),
                "fixtures/users.csv".into(),
                test_qualified_table(),
            )
            .expect("adding a CSV seed to an empty definition must succeed");

        assert_eq!(definition.seeds.len(), 1);
    }

    #[test]
    fn test_apply_csv_file_rejects_duplicate() {
        let seed_name: SeedName = "import-users".parse().unwrap();

        let definition = empty_definition()
            .apply_csv_file(
                seed_name.clone(),
                "fixtures/users.csv".into(),
                test_qualified_table(),
            )
            .unwrap();

        assert_eq!(
            definition.apply_csv_file(
                seed_name.clone(),
                "fixtures/other.csv".into(),
                test_qualified_table(),
            ),
            Err(DuplicateSeedName(seed_name))
        );
    }
}