Skip to main content

pg_ephemeral/
seed.rs

1use git_proc::Build;
2
3type CacheKey = [u8; 32];
4
5#[derive(Clone, Debug, PartialEq)]
6pub enum CacheStatus {
7    Hit { reference: ociman::Reference },
8    Miss { reference: ociman::Reference },
9    Uncacheable,
10}
11
12impl CacheStatus {
13    async fn from_cache_key(
14        cache_key: Option<CacheKey>,
15        backend: &ociman::Backend,
16        instance_name: &crate::InstanceName,
17    ) -> Self {
18        match cache_key {
19            Some(key) => {
20                let reference = format!("pg-ephemeral/{}:{}", instance_name, hex::encode(key))
21                    .parse()
22                    .unwrap();
23                if backend.is_image_present(&reference).await {
24                    Self::Hit { reference }
25                } else {
26                    Self::Miss { reference }
27                }
28            }
29            None => Self::Uncacheable,
30        }
31    }
32
33    #[must_use]
34    pub fn reference(&self) -> Option<&ociman::Reference> {
35        match self {
36            Self::Hit { reference } | Self::Miss { reference } => Some(reference),
37            Self::Uncacheable => None,
38        }
39    }
40
41    #[must_use]
42    pub fn is_hit(&self) -> bool {
43        matches!(self, Self::Hit { .. })
44    }
45
46    #[must_use]
47    pub fn status_str(&self) -> &'static str {
48        match self {
49            Self::Hit { .. } => "hit",
50            Self::Miss { .. } => "miss",
51            Self::Uncacheable => "uncacheable",
52        }
53    }
54}
55
56#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
57#[serde(try_from = "String")]
58pub struct SeedName(String);
59
60impl SeedName {
61    #[must_use]
62    pub fn as_str(&self) -> &str {
63        &self.0
64    }
65}
66
67impl std::fmt::Display for SeedName {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "{}", self.0)
70    }
71}
72
73#[derive(Debug, PartialEq, Eq, thiserror::Error)]
74#[error("Seed name cannot be empty")]
75pub struct SeedNameError;
76
77#[derive(Debug, PartialEq, Eq, thiserror::Error)]
78#[error("Duplicate seed name: {0}")]
79pub struct DuplicateSeedName(pub SeedName);
80
81impl std::str::FromStr for SeedName {
82    type Err = SeedNameError;
83
84    fn from_str(value: &str) -> Result<Self, Self::Err> {
85        if value.is_empty() {
86            Err(SeedNameError)
87        } else {
88            Ok(Self(value.to_string()))
89        }
90    }
91}
92
93impl TryFrom<String> for SeedName {
94    type Error = SeedNameError;
95
96    fn try_from(value: String) -> Result<Self, Self::Error> {
97        if value.is_empty() {
98            Err(SeedNameError)
99        } else {
100            Ok(Self(value))
101        }
102    }
103}
104
105impl TryFrom<&str> for SeedName {
106    type Error = SeedNameError;
107
108    fn try_from(value: &str) -> Result<Self, Self::Error> {
109        value.parse()
110    }
111}
112
/// An executable name plus the arguments it is invoked with.
#[derive(Clone, Debug, PartialEq)]
pub struct Command {
    pub command: String,
    pub arguments: Vec<String>,
}

impl Command {
    /// Creates a command, converting the program and each argument into owned strings.
    pub fn new(
        command: impl Into<String>,
        arguments: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let command: String = command.into();
        let arguments: Vec<String> = arguments.into_iter().map(Into::into).collect();

        Self { command, arguments }
    }
}
130
131#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
132#[serde(tag = "type", rename_all = "kebab-case")]
133pub enum CommandCacheConfig {
134    /// Disable caching, breaks the cache chain
135    None,
136    /// Hash the command and arguments
137    CommandHash,
138    /// Run a command to get cache key input
139    KeyCommand {
140        command: String,
141        #[serde(default)]
142        arguments: Vec<String>,
143    },
144    /// Run a script to get cache key input
145    KeyScript { script: String },
146}
147
148#[derive(Clone, Debug, PartialEq)]
149pub enum Seed {
150    SqlFile {
151        path: std::path::PathBuf,
152    },
153    SqlFileGitRevision {
154        git_revision: String,
155        path: std::path::PathBuf,
156    },
157    Command {
158        command: Command,
159        cache: CommandCacheConfig,
160    },
161    Script {
162        script: String,
163    },
164    ContainerScript {
165        script: String,
166    },
167    CsvFile {
168        path: std::path::PathBuf,
169        table: pg_client::QualifiedTable,
170    },
171}
172
173impl Seed {
174    async fn load(
175        &self,
176        name: SeedName,
177        hash_chain: &mut HashChain,
178        backend: &ociman::Backend,
179        instance_name: &crate::InstanceName,
180    ) -> Result<LoadedSeed, LoadError> {
181        match self {
182            Seed::SqlFile { path } => {
183                let content =
184                    std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
185                        name: name.clone(),
186                        path: path.clone(),
187                        source,
188                    })?;
189
190                hash_chain.update(&content);
191
192                Ok(LoadedSeed::SqlFile {
193                    cache_status: CacheStatus::from_cache_key(
194                        hash_chain.cache_key(),
195                        backend,
196                        instance_name,
197                    )
198                    .await,
199                    name,
200                    path: path.clone(),
201                    content,
202                })
203            }
204            Seed::SqlFileGitRevision { path, git_revision } => {
205                let output =
206                    git_proc::show::new(&format!("{git_revision}:{}", path.to_str().unwrap()))
207                        .build()
208                        .stdout_capture()
209                        .stderr_capture()
210                        .accept_nonzero_exit()
211                        .run()
212                        .await
213                        .map_err(|error| LoadError::GitRevision {
214                            name: name.clone(),
215                            path: path.clone(),
216                            git_revision: git_revision.clone(),
217                            message: error.to_string(),
218                        })?;
219
220                if output.status.success() {
221                    let content = String::from_utf8(output.stdout).map_err(|error| {
222                        LoadError::GitRevision {
223                            name: name.clone(),
224                            path: path.clone(),
225                            git_revision: git_revision.clone(),
226                            message: error.to_string(),
227                        }
228                    })?;
229
230                    hash_chain.update(&content);
231
232                    Ok(LoadedSeed::SqlFileGitRevision {
233                        cache_status: CacheStatus::from_cache_key(
234                            hash_chain.cache_key(),
235                            backend,
236                            instance_name,
237                        )
238                        .await,
239                        name,
240                        path: path.clone(),
241                        git_revision: git_revision.clone(),
242                        content,
243                    })
244                } else {
245                    let message = String::from_utf8(output.stderr).map_err(|error| {
246                        LoadError::GitRevision {
247                            name: name.clone(),
248                            path: path.clone(),
249                            git_revision: git_revision.clone(),
250                            message: error.to_string(),
251                        }
252                    })?;
253                    Err(LoadError::GitRevision {
254                        name,
255                        path: path.clone(),
256                        git_revision: git_revision.clone(),
257                        message,
258                    })
259                }
260            }
261            Seed::Command { command, cache } => {
262                let cache_key_output = match cache {
263                    CommandCacheConfig::None => {
264                        hash_chain.stop();
265                        None
266                    }
267                    CommandCacheConfig::CommandHash => {
268                        hash_chain.update(&command.command);
269                        for argument in &command.arguments {
270                            hash_chain.update(argument);
271                        }
272                        None
273                    }
274                    CommandCacheConfig::KeyCommand {
275                        command: key_command,
276                        arguments: key_arguments,
277                    } => {
278                        let output = cmd_proc::Command::new(key_command)
279                            .arguments(key_arguments)
280                            .stdout_capture()
281                            .stderr_capture()
282                            .accept_nonzero_exit()
283                            .run()
284                            .await
285                            .map_err(|error| LoadError::KeyCommand {
286                                name: name.clone(),
287                                command: key_command.clone(),
288                                message: error.to_string(),
289                            })?;
290
291                        if output.status.success() {
292                            hash_chain.update(&output.stdout);
293                            Some(output.stdout)
294                        } else {
295                            let message = String::from_utf8(output.stderr).map_err(|error| {
296                                LoadError::KeyCommand {
297                                    name: name.clone(),
298                                    command: key_command.clone(),
299                                    message: error.to_string(),
300                                }
301                            })?;
302                            return Err(LoadError::KeyCommand {
303                                name,
304                                command: key_command.clone(),
305                                message,
306                            });
307                        }
308                    }
309                    CommandCacheConfig::KeyScript { script: key_script } => {
310                        let output = cmd_proc::Command::new("sh")
311                            .arguments(["-e", "-c"])
312                            .argument(key_script)
313                            .stdout_capture()
314                            .stderr_capture()
315                            .accept_nonzero_exit()
316                            .run()
317                            .await
318                            .map_err(|error| LoadError::KeyScript {
319                                name: name.clone(),
320                                message: error.to_string(),
321                            })?;
322
323                        if output.status.success() {
324                            hash_chain.update(&output.stdout);
325                            Some(output.stdout)
326                        } else {
327                            let message = String::from_utf8(output.stderr).map_err(|error| {
328                                LoadError::KeyScript {
329                                    name: name.clone(),
330                                    message: error.to_string(),
331                                }
332                            })?;
333                            return Err(LoadError::KeyScript { name, message });
334                        }
335                    }
336                };
337
338                Ok(LoadedSeed::Command {
339                    cache_status: CacheStatus::from_cache_key(
340                        hash_chain.cache_key(),
341                        backend,
342                        instance_name,
343                    )
344                    .await,
345                    cache_key_output,
346                    name,
347                    command: command.clone(),
348                })
349            }
350            Seed::Script { script } => {
351                hash_chain.update(script);
352
353                Ok(LoadedSeed::Script {
354                    cache_status: CacheStatus::from_cache_key(
355                        hash_chain.cache_key(),
356                        backend,
357                        instance_name,
358                    )
359                    .await,
360                    name,
361                    script: script.clone(),
362                })
363            }
364            Seed::ContainerScript { script } => {
365                hash_chain.update(script);
366
367                Ok(LoadedSeed::ContainerScript {
368                    cache_status: CacheStatus::from_cache_key(
369                        hash_chain.cache_key(),
370                        backend,
371                        instance_name,
372                    )
373                    .await,
374                    name,
375                    script: script.clone(),
376                })
377            }
378            Seed::CsvFile { path, table } => {
379                let content =
380                    std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
381                        name: name.clone(),
382                        path: path.clone(),
383                        source,
384                    })?;
385
386                hash_chain.update(table.schema.as_ref());
387                hash_chain.update(table.table.as_ref());
388                hash_chain.update(&content);
389
390                Ok(LoadedSeed::CsvFile {
391                    cache_status: CacheStatus::from_cache_key(
392                        hash_chain.cache_key(),
393                        backend,
394                        instance_name,
395                    )
396                    .await,
397                    name,
398                    path: path.clone(),
399                    table: table.clone(),
400                    content,
401                })
402            }
403        }
404    }
405}
406
407#[derive(Debug, thiserror::Error)]
408pub enum LoadError {
409    #[error("Failed to load seed {name}: could not read file {path}: {source}")]
410    FileRead {
411        name: SeedName,
412        path: std::path::PathBuf,
413        source: std::io::Error,
414    },
415    #[error(
416        "Failed to load seed {name}: could not read {path} at git revision {git_revision}: {message}"
417    )]
418    GitRevision {
419        name: SeedName,
420        path: std::path::PathBuf,
421        git_revision: String,
422        message: String,
423    },
424    #[error("Failed to load seed {name}: cache key command {command} failed: {message}")]
425    KeyCommand {
426        name: SeedName,
427        command: String,
428        message: String,
429    },
430    #[error("Failed to load seed {name}: cache key script failed: {message}")]
431    KeyScript { name: SeedName, message: String },
432}
433
434#[derive(Clone, Debug, PartialEq)]
435pub enum LoadedSeed {
436    SqlFile {
437        cache_status: CacheStatus,
438        name: SeedName,
439        path: std::path::PathBuf,
440        content: String,
441    },
442    SqlFileGitRevision {
443        cache_status: CacheStatus,
444        name: SeedName,
445        path: std::path::PathBuf,
446        git_revision: String,
447        content: String,
448    },
449    Command {
450        cache_status: CacheStatus,
451        cache_key_output: Option<Vec<u8>>,
452        name: SeedName,
453        command: Command,
454    },
455    Script {
456        cache_status: CacheStatus,
457        name: SeedName,
458        script: String,
459    },
460    ContainerScript {
461        cache_status: CacheStatus,
462        name: SeedName,
463        script: String,
464    },
465    CsvFile {
466        cache_status: CacheStatus,
467        name: SeedName,
468        path: std::path::PathBuf,
469        table: pg_client::QualifiedTable,
470        content: String,
471    },
472}
473
474impl LoadedSeed {
475    #[must_use]
476    pub fn cache_status(&self) -> &CacheStatus {
477        match self {
478            Self::SqlFile { cache_status, .. }
479            | Self::SqlFileGitRevision { cache_status, .. }
480            | Self::Command { cache_status, .. }
481            | Self::Script { cache_status, .. }
482            | Self::ContainerScript { cache_status, .. }
483            | Self::CsvFile { cache_status, .. } => cache_status,
484        }
485    }
486
487    #[must_use]
488    pub fn name(&self) -> &SeedName {
489        match self {
490            Self::SqlFile { name, .. }
491            | Self::SqlFileGitRevision { name, .. }
492            | Self::Command { name, .. }
493            | Self::Script { name, .. }
494            | Self::ContainerScript { name, .. }
495            | Self::CsvFile { name, .. } => name,
496        }
497    }
498
499    fn variant_name(&self) -> &'static str {
500        match self {
501            Self::SqlFile { .. } => "sql-file",
502            Self::SqlFileGitRevision { .. } => "sql-file-git-revision",
503            Self::Command { .. } => "command",
504            Self::Script { .. } => "script",
505            Self::ContainerScript { .. } => "container-script",
506            Self::CsvFile { .. } => "csv-file",
507        }
508    }
509}
510
511struct HashChain {
512    hasher: Option<sha2::Sha256>,
513}
514
515impl HashChain {
516    fn new() -> Self {
517        use sha2::Digest;
518
519        Self {
520            hasher: Some(sha2::Sha256::new()),
521        }
522    }
523
524    fn update(&mut self, bytes: impl AsRef<[u8]>) {
525        use sha2::Digest;
526
527        if let Some(ref mut hasher) = self.hasher {
528            hasher.update(bytes)
529        }
530    }
531
532    fn cache_key(&self) -> Option<CacheKey> {
533        use sha2::Digest;
534
535        self.hasher
536            .as_ref()
537            .map(|hasher| hasher.clone().finalize().into())
538    }
539
540    fn stop(&mut self) {
541        self.hasher = None
542    }
543}
544
545#[derive(Debug, PartialEq)]
546pub struct LoadedSeeds<'a> {
547    image: &'a crate::image::Image,
548    seeds: Vec<LoadedSeed>,
549}
550
551impl<'a> LoadedSeeds<'a> {
552    pub async fn load(
553        image: &'a crate::image::Image,
554        ssl_config: Option<&crate::definition::SslConfig>,
555        seeds: &indexmap::IndexMap<SeedName, Seed>,
556        backend: &ociman::Backend,
557        instance_name: &crate::InstanceName,
558    ) -> Result<Self, LoadError> {
559        let mut hash_chain = HashChain::new();
560        let mut loaded_seeds = Vec::new();
561
562        hash_chain.update(crate::VERSION_STR);
563        hash_chain.update(image.to_string());
564
565        match ssl_config {
566            Some(crate::definition::SslConfig::Generated { hostname }) => {
567                hash_chain.update("ssl:generated:");
568                hash_chain.update(hostname.as_str());
569            }
570            None => {
571                hash_chain.update("ssl:none");
572            }
573        }
574
575        for (name, seed) in seeds {
576            let loaded_seed = seed
577                .load(name.clone(), &mut hash_chain, backend, instance_name)
578                .await?;
579            loaded_seeds.push(loaded_seed);
580        }
581
582        Ok(Self {
583            image,
584            seeds: loaded_seeds,
585        })
586    }
587
588    pub fn iter_seeds(&self) -> impl Iterator<Item = &LoadedSeed> {
589        self.seeds.iter()
590    }
591
592    pub fn print(&self, instance_name: &crate::InstanceName) {
593        println!("Instance: {instance_name}");
594        println!("Image:    {}", self.image);
595        println!("Version:  {}", crate::VERSION_STR);
596        println!();
597
598        let mut table = comfy_table::Table::new();
599
600        table
601            .load_preset(comfy_table::presets::NOTHING)
602            .set_header(["Seed", "Type", "Status"]);
603
604        for seed in &self.seeds {
605            table.add_row([
606                seed.name().as_str(),
607                seed.variant_name(),
608                seed.cache_status().status_str(),
609            ]);
610        }
611
612        println!("{table}");
613    }
614
615    pub fn print_json(&self, instance_name: &crate::InstanceName) {
616        #[derive(serde::Serialize)]
617        struct Output<'a> {
618            instance: &'a str,
619            image: String,
620            version: &'a str,
621            seeds: Vec<SeedOutput<'a>>,
622        }
623
624        #[derive(serde::Serialize)]
625        struct SeedOutput<'a> {
626            name: &'a str,
627            r#type: &'a str,
628            status: &'a str,
629            #[serde(skip_serializing_if = "Option::is_none")]
630            reference: Option<String>,
631        }
632
633        let output = Output {
634            instance: &instance_name.to_string(),
635            image: self.image.to_string(),
636            version: crate::VERSION_STR,
637            seeds: self
638                .seeds
639                .iter()
640                .map(|seed| SeedOutput {
641                    name: seed.name().as_str(),
642                    r#type: seed.variant_name(),
643                    status: seed.cache_status().status_str(),
644                    reference: seed.cache_status().reference().map(|r| r.to_string()),
645                })
646                .collect(),
647        };
648
649        println!("{}", serde_json::to_string_pretty(&output).unwrap());
650    }
651}
652
#[cfg(test)]
mod test {
    use super::*;

    /// Image reference shared by the hit/miss tests; the tag is an arbitrary 64-hex key.
    const SAMPLE_REFERENCE: &str =
        "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";

    fn sample_reference() -> ociman::Reference {
        SAMPLE_REFERENCE.parse().unwrap()
    }

    /// The `schema.sql` seed used by the hit/miss tests, with the given cache status.
    fn schema_seed(cache_status: CacheStatus) -> LoadedSeed {
        LoadedSeed::SqlFile {
            cache_status,
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        }
    }

    #[test]
    fn test_seed_name_rejects_empty_string() {
        let expected = Err(SeedNameError);

        assert_eq!("".parse::<SeedName>(), expected);
        assert_eq!(SeedName::try_from(""), expected);
        assert_eq!(SeedName::try_from(String::new()), expected);
    }

    #[test]
    fn test_seed_name_accepts_non_empty_string() {
        let expected = Ok(SeedName(String::from("valid-name")));

        assert_eq!("valid-name".parse::<SeedName>(), expected);
        assert_eq!(SeedName::try_from("valid-name"), expected);
        assert_eq!(SeedName::try_from(String::from("valid-name")), expected);
    }

    #[test]
    fn test_seed_name_display() {
        let name: SeedName = "test-seed".parse().unwrap();

        assert_eq!(name.to_string(), "test-seed");
        assert_eq!(name.as_str(), "test-seed");
    }

    #[test]
    fn test_cache_status_uncacheable() {
        let seed = LoadedSeed::Command {
            cache_status: CacheStatus::Uncacheable,
            cache_key_output: None,
            name: "run-migrations".parse().unwrap(),
            command: Command::new("migrate", ["up"]),
        };
        let status = seed.cache_status();

        assert!(status.reference().is_none());
        assert!(!status.is_hit());
    }

    #[test]
    fn test_cache_status_miss() {
        let reference = sample_reference();
        let seed = schema_seed(CacheStatus::Miss {
            reference: reference.clone(),
        });
        let status = seed.cache_status();

        assert_eq!(status.reference(), Some(&reference));
        assert!(!status.is_hit());
    }

    #[test]
    fn test_cache_status_hit() {
        let reference = sample_reference();
        let seed = schema_seed(CacheStatus::Hit {
            reference: reference.clone(),
        });
        let status = seed.cache_status();

        assert_eq!(status.reference(), Some(&reference));
        assert!(status.is_hit());
    }
}