Skip to main content

pg_ephemeral/
seed.rs

1use git_proc::Build;
2
3type CacheKey = [u8; 32];
4
5#[derive(Clone, Debug, PartialEq)]
6pub enum CacheStatus {
7    Hit { reference: ociman::Reference },
8    Miss { reference: ociman::Reference },
9    Uncacheable,
10}
11
12impl CacheStatus {
13    async fn from_cache_key(
14        cache_key: Option<CacheKey>,
15        backend: &ociman::Backend,
16        instance_name: &crate::InstanceName,
17    ) -> Self {
18        match cache_key {
19            Some(key) => {
20                let reference = format!("pg-ephemeral/{}:{}", instance_name, hex::encode(key))
21                    .parse()
22                    .unwrap();
23                if backend.is_image_present(&reference).await {
24                    Self::Hit { reference }
25                } else {
26                    Self::Miss { reference }
27                }
28            }
29            None => Self::Uncacheable,
30        }
31    }
32
33    #[must_use]
34    pub fn reference(&self) -> Option<&ociman::Reference> {
35        match self {
36            Self::Hit { reference } | Self::Miss { reference } => Some(reference),
37            Self::Uncacheable => None,
38        }
39    }
40
41    #[must_use]
42    pub fn is_hit(&self) -> bool {
43        matches!(self, Self::Hit { .. })
44    }
45
46    #[must_use]
47    pub fn status_str(&self) -> &'static str {
48        match self {
49            Self::Hit { .. } => "hit",
50            Self::Miss { .. } => "miss",
51            Self::Uncacheable => "uncacheable",
52        }
53    }
54}
55
56#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
57#[serde(try_from = "String")]
58pub struct SeedName(String);
59
60impl SeedName {
61    #[must_use]
62    pub fn as_str(&self) -> &str {
63        &self.0
64    }
65}
66
67impl std::fmt::Display for SeedName {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "{}", self.0)
70    }
71}
72
73#[derive(Debug, PartialEq, Eq, thiserror::Error)]
74#[error("Seed name cannot be empty")]
75pub struct SeedNameError;
76
77#[derive(Debug, PartialEq, Eq, thiserror::Error)]
78#[error("Duplicate seed name: {0}")]
79pub struct DuplicateSeedName(pub SeedName);
80
81impl std::str::FromStr for SeedName {
82    type Err = SeedNameError;
83
84    fn from_str(value: &str) -> Result<Self, Self::Err> {
85        if value.is_empty() {
86            Err(SeedNameError)
87        } else {
88            Ok(Self(value.to_string()))
89        }
90    }
91}
92
93impl TryFrom<String> for SeedName {
94    type Error = SeedNameError;
95
96    fn try_from(value: String) -> Result<Self, Self::Error> {
97        if value.is_empty() {
98            Err(SeedNameError)
99        } else {
100            Ok(Self(value))
101        }
102    }
103}
104
105impl TryFrom<&str> for SeedName {
106    type Error = SeedNameError;
107
108    fn try_from(value: &str) -> Result<Self, Self::Error> {
109        value.parse()
110    }
111}
112
/// An executable name together with its ordered argument list.
#[derive(Clone, Debug, PartialEq)]
pub struct Command {
    pub command: String,
    pub arguments: Vec<String>,
}

impl Command {
    /// Build a command from anything string-like, converting each argument to an
    /// owned `String` while preserving order.
    pub fn new(
        command: impl Into<String>,
        arguments: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let command = command.into();
        let arguments: Vec<String> = arguments.into_iter().map(Into::into).collect();

        Self { command, arguments }
    }
}
130
131#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
132#[serde(tag = "type", rename_all = "kebab-case")]
133pub enum CommandCacheConfig {
134    /// Disable caching, breaks the cache chain
135    None,
136    /// Hash the command and arguments
137    CommandHash,
138    /// Run a command to get cache key input
139    KeyCommand {
140        command: String,
141        #[serde(default)]
142        arguments: Vec<String>,
143    },
144    /// Run a script to get cache key input
145    KeyScript { script: String },
146}
147
148#[derive(Clone, Debug, PartialEq)]
149pub enum Seed {
150    SqlFile {
151        path: std::path::PathBuf,
152    },
153    SqlFileGitRevision {
154        git_revision: String,
155        path: std::path::PathBuf,
156    },
157    Command {
158        command: Command,
159        cache: CommandCacheConfig,
160    },
161    Script {
162        script: String,
163    },
164    ContainerScript {
165        script: String,
166    },
167    CsvFile {
168        path: std::path::PathBuf,
169        table: pg_client::QualifiedTable,
170        delimiter: char,
171    },
172}
173
174impl Seed {
175    async fn load(
176        &self,
177        name: SeedName,
178        hash_chain: &mut HashChain,
179        backend: &ociman::Backend,
180        instance_name: &crate::InstanceName,
181    ) -> Result<LoadedSeed, LoadError> {
182        match self {
183            Seed::SqlFile { path } => {
184                let content =
185                    std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
186                        name: name.clone(),
187                        path: path.clone(),
188                        source,
189                    })?;
190
191                hash_chain.update(&content);
192
193                Ok(LoadedSeed::SqlFile {
194                    cache_status: CacheStatus::from_cache_key(
195                        hash_chain.cache_key(),
196                        backend,
197                        instance_name,
198                    )
199                    .await,
200                    name,
201                    path: path.clone(),
202                    content,
203                })
204            }
205            Seed::SqlFileGitRevision { path, git_revision } => {
206                let output =
207                    git_proc::show::new(&format!("{git_revision}:{}", path.to_str().unwrap()))
208                        .build()
209                        .stdout_capture()
210                        .stderr_capture()
211                        .accept_nonzero_exit()
212                        .run()
213                        .await
214                        .map_err(|error| LoadError::GitRevision {
215                            name: name.clone(),
216                            path: path.clone(),
217                            git_revision: git_revision.clone(),
218                            message: error.to_string(),
219                        })?;
220
221                if output.status.success() {
222                    let content = String::from_utf8(output.stdout).map_err(|error| {
223                        LoadError::GitRevision {
224                            name: name.clone(),
225                            path: path.clone(),
226                            git_revision: git_revision.clone(),
227                            message: error.to_string(),
228                        }
229                    })?;
230
231                    hash_chain.update(&content);
232
233                    Ok(LoadedSeed::SqlFileGitRevision {
234                        cache_status: CacheStatus::from_cache_key(
235                            hash_chain.cache_key(),
236                            backend,
237                            instance_name,
238                        )
239                        .await,
240                        name,
241                        path: path.clone(),
242                        git_revision: git_revision.clone(),
243                        content,
244                    })
245                } else {
246                    let message = String::from_utf8(output.stderr).map_err(|error| {
247                        LoadError::GitRevision {
248                            name: name.clone(),
249                            path: path.clone(),
250                            git_revision: git_revision.clone(),
251                            message: error.to_string(),
252                        }
253                    })?;
254                    Err(LoadError::GitRevision {
255                        name,
256                        path: path.clone(),
257                        git_revision: git_revision.clone(),
258                        message,
259                    })
260                }
261            }
262            Seed::Command { command, cache } => {
263                let cache_key_output = match cache {
264                    CommandCacheConfig::None => {
265                        hash_chain.stop();
266                        None
267                    }
268                    CommandCacheConfig::CommandHash => {
269                        hash_chain.update(&command.command);
270                        for argument in &command.arguments {
271                            hash_chain.update(argument);
272                        }
273                        None
274                    }
275                    CommandCacheConfig::KeyCommand {
276                        command: key_command,
277                        arguments: key_arguments,
278                    } => {
279                        let output = cmd_proc::Command::new(key_command)
280                            .arguments(key_arguments)
281                            .stdout_capture()
282                            .stderr_capture()
283                            .accept_nonzero_exit()
284                            .run()
285                            .await
286                            .map_err(|error| LoadError::KeyCommand {
287                                name: name.clone(),
288                                command: key_command.clone(),
289                                message: error.to_string(),
290                            })?;
291
292                        if output.status.success() {
293                            hash_chain.update(&output.stdout);
294                            Some(output.stdout)
295                        } else {
296                            let message = String::from_utf8(output.stderr).map_err(|error| {
297                                LoadError::KeyCommand {
298                                    name: name.clone(),
299                                    command: key_command.clone(),
300                                    message: error.to_string(),
301                                }
302                            })?;
303                            return Err(LoadError::KeyCommand {
304                                name,
305                                command: key_command.clone(),
306                                message,
307                            });
308                        }
309                    }
310                    CommandCacheConfig::KeyScript { script: key_script } => {
311                        let output = cmd_proc::Command::new("sh")
312                            .arguments(["-e", "-c"])
313                            .argument(key_script)
314                            .stdout_capture()
315                            .stderr_capture()
316                            .accept_nonzero_exit()
317                            .run()
318                            .await
319                            .map_err(|error| LoadError::KeyScript {
320                                name: name.clone(),
321                                message: error.to_string(),
322                            })?;
323
324                        if output.status.success() {
325                            hash_chain.update(&output.stdout);
326                            Some(output.stdout)
327                        } else {
328                            let message = String::from_utf8(output.stderr).map_err(|error| {
329                                LoadError::KeyScript {
330                                    name: name.clone(),
331                                    message: error.to_string(),
332                                }
333                            })?;
334                            return Err(LoadError::KeyScript { name, message });
335                        }
336                    }
337                };
338
339                Ok(LoadedSeed::Command {
340                    cache_status: CacheStatus::from_cache_key(
341                        hash_chain.cache_key(),
342                        backend,
343                        instance_name,
344                    )
345                    .await,
346                    cache_key_output,
347                    name,
348                    command: command.clone(),
349                })
350            }
351            Seed::Script { script } => {
352                hash_chain.update(script);
353
354                Ok(LoadedSeed::Script {
355                    cache_status: CacheStatus::from_cache_key(
356                        hash_chain.cache_key(),
357                        backend,
358                        instance_name,
359                    )
360                    .await,
361                    name,
362                    script: script.clone(),
363                })
364            }
365            Seed::ContainerScript { script } => {
366                hash_chain.update(script);
367
368                Ok(LoadedSeed::ContainerScript {
369                    cache_status: CacheStatus::from_cache_key(
370                        hash_chain.cache_key(),
371                        backend,
372                        instance_name,
373                    )
374                    .await,
375                    name,
376                    script: script.clone(),
377                })
378            }
379            Seed::CsvFile {
380                path,
381                table,
382                delimiter,
383            } => {
384                let content =
385                    std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
386                        name: name.clone(),
387                        path: path.clone(),
388                        source,
389                    })?;
390
391                hash_chain.update(table.schema.as_ref());
392                hash_chain.update(table.table.as_ref());
393                hash_chain.update(&content);
394
395                Ok(LoadedSeed::CsvFile {
396                    cache_status: CacheStatus::from_cache_key(
397                        hash_chain.cache_key(),
398                        backend,
399                        instance_name,
400                    )
401                    .await,
402                    name,
403                    path: path.clone(),
404                    table: table.clone(),
405                    delimiter: *delimiter,
406                    content,
407                })
408            }
409        }
410    }
411}
412
413#[derive(Debug, thiserror::Error)]
414pub enum LoadError {
415    #[error("Failed to load seed {name}: could not read file {path}: {source}")]
416    FileRead {
417        name: SeedName,
418        path: std::path::PathBuf,
419        source: std::io::Error,
420    },
421    #[error(
422        "Failed to load seed {name}: could not read {path} at git revision {git_revision}: {message}"
423    )]
424    GitRevision {
425        name: SeedName,
426        path: std::path::PathBuf,
427        git_revision: String,
428        message: String,
429    },
430    #[error("Failed to load seed {name}: cache key command {command} failed: {message}")]
431    KeyCommand {
432        name: SeedName,
433        command: String,
434        message: String,
435    },
436    #[error("Failed to load seed {name}: cache key script failed: {message}")]
437    KeyScript { name: SeedName, message: String },
438}
439
440#[derive(Clone, Debug, PartialEq)]
441pub enum LoadedSeed {
442    SqlFile {
443        cache_status: CacheStatus,
444        name: SeedName,
445        path: std::path::PathBuf,
446        content: String,
447    },
448    SqlFileGitRevision {
449        cache_status: CacheStatus,
450        name: SeedName,
451        path: std::path::PathBuf,
452        git_revision: String,
453        content: String,
454    },
455    Command {
456        cache_status: CacheStatus,
457        cache_key_output: Option<Vec<u8>>,
458        name: SeedName,
459        command: Command,
460    },
461    Script {
462        cache_status: CacheStatus,
463        name: SeedName,
464        script: String,
465    },
466    ContainerScript {
467        cache_status: CacheStatus,
468        name: SeedName,
469        script: String,
470    },
471    CsvFile {
472        cache_status: CacheStatus,
473        name: SeedName,
474        path: std::path::PathBuf,
475        table: pg_client::QualifiedTable,
476        delimiter: char,
477        content: String,
478    },
479}
480
481impl LoadedSeed {
482    #[must_use]
483    pub fn cache_status(&self) -> &CacheStatus {
484        match self {
485            Self::SqlFile { cache_status, .. }
486            | Self::SqlFileGitRevision { cache_status, .. }
487            | Self::Command { cache_status, .. }
488            | Self::Script { cache_status, .. }
489            | Self::ContainerScript { cache_status, .. }
490            | Self::CsvFile { cache_status, .. } => cache_status,
491        }
492    }
493
494    #[must_use]
495    pub fn name(&self) -> &SeedName {
496        match self {
497            Self::SqlFile { name, .. }
498            | Self::SqlFileGitRevision { name, .. }
499            | Self::Command { name, .. }
500            | Self::Script { name, .. }
501            | Self::ContainerScript { name, .. }
502            | Self::CsvFile { name, .. } => name,
503        }
504    }
505
506    fn variant_name(&self) -> &'static str {
507        match self {
508            Self::SqlFile { .. } => "sql-file",
509            Self::SqlFileGitRevision { .. } => "sql-file-git-revision",
510            Self::Command { .. } => "command",
511            Self::Script { .. } => "script",
512            Self::ContainerScript { .. } => "container-script",
513            Self::CsvFile { .. } => "csv-file",
514        }
515    }
516}
517
518struct HashChain {
519    hasher: Option<sha2::Sha256>,
520}
521
522impl HashChain {
523    fn new() -> Self {
524        use sha2::Digest;
525
526        Self {
527            hasher: Some(sha2::Sha256::new()),
528        }
529    }
530
531    fn update(&mut self, bytes: impl AsRef<[u8]>) {
532        use sha2::Digest;
533
534        if let Some(ref mut hasher) = self.hasher {
535            hasher.update(bytes)
536        }
537    }
538
539    fn cache_key(&self) -> Option<CacheKey> {
540        use sha2::Digest;
541
542        self.hasher
543            .as_ref()
544            .map(|hasher| hasher.clone().finalize().into())
545    }
546
547    fn stop(&mut self) {
548        self.hasher = None
549    }
550}
551
552#[derive(Debug, PartialEq)]
553pub struct LoadedSeeds<'a> {
554    image: &'a crate::image::Image,
555    seeds: Vec<LoadedSeed>,
556}
557
558impl<'a> LoadedSeeds<'a> {
559    pub async fn load(
560        image: &'a crate::image::Image,
561        ssl_config: Option<&crate::definition::SslConfig>,
562        seeds: &indexmap::IndexMap<SeedName, Seed>,
563        backend: &ociman::Backend,
564        instance_name: &crate::InstanceName,
565    ) -> Result<Self, LoadError> {
566        let mut hash_chain = HashChain::new();
567        let mut loaded_seeds = Vec::new();
568
569        hash_chain.update(crate::VERSION_STR);
570        hash_chain.update(image.to_string());
571
572        match ssl_config {
573            Some(crate::definition::SslConfig::Generated { hostname }) => {
574                hash_chain.update("ssl:generated:");
575                hash_chain.update(hostname.as_str());
576            }
577            None => {
578                hash_chain.update("ssl:none");
579            }
580        }
581
582        for (name, seed) in seeds {
583            let loaded_seed = seed
584                .load(name.clone(), &mut hash_chain, backend, instance_name)
585                .await?;
586            loaded_seeds.push(loaded_seed);
587        }
588
589        Ok(Self {
590            image,
591            seeds: loaded_seeds,
592        })
593    }
594
595    pub fn iter_seeds(&self) -> impl Iterator<Item = &LoadedSeed> {
596        self.seeds.iter()
597    }
598
599    pub fn print(&self, instance_name: &crate::InstanceName) {
600        println!("Instance: {instance_name}");
601        println!("Image:    {}", self.image);
602        println!("Version:  {}", crate::VERSION_STR);
603        println!();
604
605        let mut table = comfy_table::Table::new();
606
607        table
608            .load_preset(comfy_table::presets::NOTHING)
609            .set_header(["Seed", "Type", "Status"]);
610
611        for seed in &self.seeds {
612            table.add_row([
613                seed.name().as_str(),
614                seed.variant_name(),
615                seed.cache_status().status_str(),
616            ]);
617        }
618
619        println!("{table}");
620    }
621
622    pub fn print_json(&self, instance_name: &crate::InstanceName) {
623        #[derive(serde::Serialize)]
624        struct Output<'a> {
625            instance: &'a str,
626            image: String,
627            version: &'a str,
628            seeds: Vec<SeedOutput<'a>>,
629        }
630
631        #[derive(serde::Serialize)]
632        struct SeedOutput<'a> {
633            name: &'a str,
634            r#type: &'a str,
635            status: &'a str,
636            #[serde(skip_serializing_if = "Option::is_none")]
637            reference: Option<String>,
638        }
639
640        let output = Output {
641            instance: &instance_name.to_string(),
642            image: self.image.to_string(),
643            version: crate::VERSION_STR,
644            seeds: self
645                .seeds
646                .iter()
647                .map(|seed| SeedOutput {
648                    name: seed.name().as_str(),
649                    r#type: seed.variant_name(),
650                    status: seed.cache_status().status_str(),
651                    reference: seed.cache_status().reference().map(|r| r.to_string()),
652                })
653                .collect(),
654        };
655
656        println!("{}", serde_json::to_string_pretty(&output).unwrap());
657    }
658}
659
#[cfg(test)]
mod test {
    use super::*;

    /// A well-formed image reference shared by the cache-status tests.
    fn sample_reference() -> ociman::Reference {
        "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
            .parse()
            .unwrap()
    }

    #[test]
    fn test_seed_name_rejects_empty_string() {
        assert_eq!("".parse::<SeedName>(), Err(SeedNameError));
        assert_eq!(SeedName::try_from(""), Err(SeedNameError));
        assert_eq!(SeedName::try_from(String::new()), Err(SeedNameError));
    }

    #[test]
    fn test_seed_name_accepts_non_empty_string() {
        assert_eq!(
            "valid-name".parse::<SeedName>(),
            Ok(SeedName("valid-name".to_string()))
        );
        assert_eq!(
            SeedName::try_from("valid-name"),
            Ok(SeedName("valid-name".to_string()))
        );
        assert_eq!(
            SeedName::try_from("valid-name".to_string()),
            Ok(SeedName("valid-name".to_string()))
        );
    }

    #[test]
    fn test_seed_name_display() {
        let name: SeedName = "test-seed".parse().unwrap();
        assert_eq!(name.to_string(), "test-seed");
        assert_eq!(name.as_str(), "test-seed");
    }

    /// Deserialization must go through the same non-empty validation as `parse`.
    #[test]
    fn test_seed_name_deserialize() {
        assert!(serde_json::from_str::<SeedName>(r#""""#).is_err());
        assert_eq!(
            serde_json::from_str::<SeedName>(r#""schema""#).unwrap(),
            SeedName("schema".to_string())
        );
    }

    /// Pins the config contract: kebab-case `type` tags and defaulted `arguments`.
    #[test]
    fn test_command_cache_config_deserialize() {
        assert_eq!(
            serde_json::from_str::<CommandCacheConfig>(r#"{"type":"none"}"#).unwrap(),
            CommandCacheConfig::None
        );
        assert_eq!(
            serde_json::from_str::<CommandCacheConfig>(r#"{"type":"command-hash"}"#).unwrap(),
            CommandCacheConfig::CommandHash
        );
        assert_eq!(
            serde_json::from_str::<CommandCacheConfig>(
                r#"{"type":"key-command","command":"git"}"#
            )
            .unwrap(),
            CommandCacheConfig::KeyCommand {
                command: "git".to_string(),
                arguments: Vec::new(),
            }
        );
        assert_eq!(
            serde_json::from_str::<CommandCacheConfig>(
                r#"{"type":"key-script","script":"date"}"#
            )
            .unwrap(),
            CommandCacheConfig::KeyScript {
                script: "date".to_string(),
            }
        );
    }

    #[test]
    fn test_cache_status_uncacheable() {
        let loaded_seed = LoadedSeed::Command {
            cache_status: CacheStatus::Uncacheable,
            cache_key_output: None,
            name: "run-migrations".parse().unwrap(),
            command: Command::new("migrate", ["up"]),
        };

        assert!(loaded_seed.cache_status().reference().is_none());
        assert!(!loaded_seed.cache_status().is_hit());
    }

    #[test]
    fn test_cache_status_miss() {
        let reference = sample_reference();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Miss {
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert!(!loaded_seed.cache_status().is_hit());
    }

    #[test]
    fn test_cache_status_hit() {
        let reference = sample_reference();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Hit {
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert!(loaded_seed.cache_status().is_hit());
    }

    /// The status labels are user-visible in both table and JSON output.
    #[test]
    fn test_cache_status_str() {
        let reference = sample_reference();

        assert_eq!(
            CacheStatus::Hit {
                reference: reference.clone(),
            }
            .status_str(),
            "hit"
        );
        assert_eq!(CacheStatus::Miss { reference }.status_str(), "miss");
        assert_eq!(CacheStatus::Uncacheable.status_str(), "uncacheable");
    }

    #[test]
    fn test_loaded_seed_name_and_variant() {
        let loaded_seed = LoadedSeed::Script {
            cache_status: CacheStatus::Uncacheable,
            name: "bootstrap".parse().unwrap(),
            script: "echo hi".to_string(),
        };

        assert_eq!(loaded_seed.name().as_str(), "bootstrap");
        assert_eq!(loaded_seed.variant_name(), "script");
        assert_eq!(loaded_seed.cache_status(), &CacheStatus::Uncacheable);
    }
}