Skip to main content

pg_ephemeral/
seed.rs

1use git_proc::Build;
2
/// Raw 32-byte digest used as a cache key; rendered as hex to form the
/// cache image tag.
type CacheKey = [u8; 32];
4
/// Outcome of looking up the cache image for a seed.
#[derive(Clone, Debug, PartialEq)]
pub enum CacheStatus {
    /// The backend reported the image at `reference` as present.
    Hit { reference: ociman::Reference },
    /// A cache key exists, but the backend did not report the image at
    /// `reference` as present.
    Miss { reference: ociman::Reference },
    /// No cache key was available, so there is no image reference at all.
    Uncacheable,
}
11
12impl CacheStatus {
13    async fn from_cache_key(
14        cache_key: Option<CacheKey>,
15        backend: &ociman::Backend,
16        instance_name: &crate::InstanceName,
17    ) -> Self {
18        match cache_key {
19            Some(key) => {
20                let reference = format!("pg-ephemeral/{}:{}", instance_name, hex::encode(key))
21                    .parse()
22                    .unwrap();
23                if backend.is_image_present(&reference).await {
24                    Self::Hit { reference }
25                } else {
26                    Self::Miss { reference }
27                }
28            }
29            None => Self::Uncacheable,
30        }
31    }
32
33    #[must_use]
34    pub fn reference(&self) -> Option<&ociman::Reference> {
35        match self {
36            Self::Hit { reference } | Self::Miss { reference } => Some(reference),
37            Self::Uncacheable => None,
38        }
39    }
40
41    #[must_use]
42    pub fn is_hit(&self) -> bool {
43        matches!(self, Self::Hit { .. })
44    }
45
46    #[must_use]
47    pub fn status_str(&self) -> &'static str {
48        match self {
49            Self::Hit { .. } => "hit",
50            Self::Miss { .. } => "miss",
51            Self::Uncacheable => "uncacheable",
52        }
53    }
54}
55
/// Maximum length of a seed name in bytes.
///
/// Longer inputs are rejected with [`SeedNameError::TooLong`].
pub const SEED_NAME_MAX_LENGTH: usize = 63;
58
/// Reason a string was rejected as a seed name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedNameError {
    /// The input was the empty string.
    Empty,
    /// The input is longer than the maximum allowed length.
    TooLong,
    /// The input contains a byte outside the allowed character set.
    InvalidCharacter,
    /// The first character of the input is a dash.
    StartsWithDash,
    /// The last character of the input is a dash.
    EndsWithDash,
}

impl SeedNameError {
    /// Static human-readable description of the error.
    ///
    /// Kept `const` so it can feed a `panic!` raised during constant
    /// evaluation.
    #[must_use]
    const fn message(&self) -> &'static str {
        match *self {
            Self::EndsWithDash => "seed name cannot end with a dash",
            Self::StartsWithDash => "seed name cannot start with a dash",
            Self::InvalidCharacter => {
                "seed name must contain only lowercase ASCII alphanumeric characters or dashes"
            }
            Self::TooLong => "seed name exceeds maximum length of 63 bytes",
            Self::Empty => "seed name cannot be empty",
        }
    }
}

impl std::fmt::Display for SeedNameError {
    /// Writes the static message verbatim.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str(self.message())
    }
}

impl std::error::Error for SeedNameError {}
96
97const fn validate_seed_name(input: &str) -> Option<SeedNameError> {
98    let bytes = input.as_bytes();
99
100    if bytes.is_empty() {
101        return Some(SeedNameError::Empty);
102    }
103
104    if bytes.len() > SEED_NAME_MAX_LENGTH {
105        return Some(SeedNameError::TooLong);
106    }
107
108    if bytes[0] == b'-' {
109        return Some(SeedNameError::StartsWithDash);
110    }
111
112    if bytes[bytes.len() - 1] == b'-' {
113        return Some(SeedNameError::EndsWithDash);
114    }
115
116    let mut index = 0;
117
118    while index < bytes.len() {
119        let byte = bytes[index];
120        if !(byte.is_ascii_lowercase() || byte.is_ascii_digit() || byte == b'-') {
121            return Some(SeedNameError::InvalidCharacter);
122        }
123        index += 1;
124    }
125
126    None
127}
128
/// A seed name that has passed `validate_seed_name`.
///
/// Holds either a borrowed `'static` string (see
/// [`SeedName::from_static_or_panic`]) or an owned one (parsed or converted
/// at runtime). Deserialization is routed through `TryFrom<String>`, so
/// deserialized names are validated as well.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(try_from = "String")]
pub struct SeedName(std::borrow::Cow<'static, str>);
132
133impl SeedName {
134    /// Creates a new seed name from a static string.
135    ///
136    /// # Panics
137    ///
138    /// Panics if the input is empty, exceeds [`SEED_NAME_MAX_LENGTH`],
139    /// contains non-lowercase-alphanumeric/dash characters,
140    /// or starts/ends with a dash.
141    #[must_use]
142    pub const fn from_static_or_panic(input: &'static str) -> Self {
143        match validate_seed_name(input) {
144            Some(error) => panic!("{}", error.message()),
145            None => Self(std::borrow::Cow::Borrowed(input)),
146        }
147    }
148
149    /// Returns the seed name as a string slice.
150    #[must_use]
151    pub fn as_str(&self) -> &str {
152        &self.0
153    }
154}
155
156impl std::fmt::Display for SeedName {
157    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158        write!(f, "{}", self.0)
159    }
160}
161
162impl AsRef<str> for SeedName {
163    fn as_ref(&self) -> &str {
164        &self.0
165    }
166}
167
/// Error carrying a seed name that was found more than once.
///
/// Displays as `Duplicate seed name: <name>`.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("Duplicate seed name: {0}")]
pub struct DuplicateSeedName(pub SeedName);
171
172impl std::str::FromStr for SeedName {
173    type Err = SeedNameError;
174
175    fn from_str(value: &str) -> Result<Self, Self::Err> {
176        match validate_seed_name(value) {
177            Some(error) => Err(error),
178            None => Ok(Self(std::borrow::Cow::Owned(value.to_owned()))),
179        }
180    }
181}
182
183impl TryFrom<String> for SeedName {
184    type Error = SeedNameError;
185
186    fn try_from(value: String) -> Result<Self, Self::Error> {
187        match validate_seed_name(&value) {
188            Some(error) => Err(error),
189            None => Ok(Self(std::borrow::Cow::Owned(value))),
190        }
191    }
192}
193
/// A program name together with the arguments to pass to it.
#[derive(Clone, Debug, PartialEq)]
pub struct Command {
    /// Program to invoke.
    pub command: String,
    /// Arguments, in order.
    pub arguments: Vec<String>,
}

impl Command {
    /// Creates a command from anything convertible into owned strings.
    pub fn new(
        command: impl Into<String>,
        arguments: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let command = command.into();
        let arguments = arguments.into_iter().map(Into::into).collect();

        Self { command, arguments }
    }
}
211
/// How a [`Seed::Command`] contributes to the cache key.
///
/// Deserialized as an internally tagged enum: the `type` field selects the
/// variant by its kebab-case name.
#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum CommandCacheConfig {
    /// Disable caching, breaks the cache chain
    None,
    /// Hash the command and arguments
    CommandHash,
    /// Run a command to get cache key input
    ///
    /// On success the command's stdout is fed into the hash chain.
    KeyCommand {
        command: String,
        /// Arguments for the key command; empty when omitted.
        #[serde(default)]
        arguments: Vec<String>,
    },
    /// Run a script to get cache key input
    ///
    /// The script is run through `sh -e -c`; on success its stdout is fed
    /// into the hash chain.
    KeyScript { script: String },
}
228
/// A configured seed, before its inputs have been loaded.
#[derive(Clone, Debug, PartialEq)]
pub enum Seed {
    /// SQL read from `path` on the local filesystem at load time.
    SqlFile {
        path: std::path::PathBuf,
    },
    /// SQL read from `path` as it exists at `git_revision`, obtained through
    /// `git_proc::show` with the object spec `<git_revision>:<path>`.
    SqlFileGitRevision {
        git_revision: String,
        path: std::path::PathBuf,
    },
    /// A command, with `cache` deciding how it contributes to the cache key.
    Command {
        command: Command,
        cache: CommandCacheConfig,
    },
    /// A script whose text is fed into the hash chain at load time.
    // NOTE(review): how and where the script is executed is not visible in this file.
    Script {
        script: String,
    },
    /// A script whose text is fed into the hash chain at load time.
    // NOTE(review): presumably executed inside the container; confirm against the caller.
    ContainerScript {
        script: String,
    },
}
249
250impl Seed {
251    async fn load(
252        &self,
253        name: SeedName,
254        hash_chain: &mut HashChain,
255        backend: &ociman::Backend,
256        instance_name: &crate::InstanceName,
257    ) -> Result<LoadedSeed, LoadError> {
258        match self {
259            Seed::SqlFile { path } => {
260                let content =
261                    std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
262                        name: name.clone(),
263                        path: path.clone(),
264                        source,
265                    })?;
266
267                hash_chain.update(&content);
268
269                Ok(LoadedSeed::SqlFile {
270                    cache_status: CacheStatus::from_cache_key(
271                        hash_chain.cache_key(),
272                        backend,
273                        instance_name,
274                    )
275                    .await,
276                    name,
277                    path: path.clone(),
278                    content,
279                })
280            }
281            Seed::SqlFileGitRevision { path, git_revision } => {
282                let output =
283                    git_proc::show::new(&format!("{git_revision}:{}", path.to_str().unwrap()))
284                        .build()
285                        .stdout_capture()
286                        .stderr_capture()
287                        .accept_nonzero_exit()
288                        .run()
289                        .await
290                        .map_err(|error| LoadError::GitRevision {
291                            name: name.clone(),
292                            path: path.clone(),
293                            git_revision: git_revision.clone(),
294                            message: error.to_string(),
295                        })?;
296
297                if output.status.success() {
298                    let content = String::from_utf8(output.stdout).map_err(|error| {
299                        LoadError::GitRevision {
300                            name: name.clone(),
301                            path: path.clone(),
302                            git_revision: git_revision.clone(),
303                            message: error.to_string(),
304                        }
305                    })?;
306
307                    hash_chain.update(&content);
308
309                    Ok(LoadedSeed::SqlFileGitRevision {
310                        cache_status: CacheStatus::from_cache_key(
311                            hash_chain.cache_key(),
312                            backend,
313                            instance_name,
314                        )
315                        .await,
316                        name,
317                        path: path.clone(),
318                        git_revision: git_revision.clone(),
319                        content,
320                    })
321                } else {
322                    let message = String::from_utf8(output.stderr).map_err(|error| {
323                        LoadError::GitRevision {
324                            name: name.clone(),
325                            path: path.clone(),
326                            git_revision: git_revision.clone(),
327                            message: error.to_string(),
328                        }
329                    })?;
330                    Err(LoadError::GitRevision {
331                        name,
332                        path: path.clone(),
333                        git_revision: git_revision.clone(),
334                        message,
335                    })
336                }
337            }
338            Seed::Command { command, cache } => {
339                let cache_key_output = match cache {
340                    CommandCacheConfig::None => {
341                        hash_chain.stop();
342                        None
343                    }
344                    CommandCacheConfig::CommandHash => {
345                        hash_chain.update(&command.command);
346                        for argument in &command.arguments {
347                            hash_chain.update(argument);
348                        }
349                        None
350                    }
351                    CommandCacheConfig::KeyCommand {
352                        command: key_command,
353                        arguments: key_arguments,
354                    } => {
355                        let output = cmd_proc::Command::new(key_command)
356                            .arguments(key_arguments)
357                            .stdout_capture()
358                            .stderr_capture()
359                            .accept_nonzero_exit()
360                            .run()
361                            .await
362                            .map_err(|error| LoadError::KeyCommand {
363                                name: name.clone(),
364                                command: key_command.clone(),
365                                message: error.to_string(),
366                            })?;
367
368                        if output.status.success() {
369                            hash_chain.update(&output.stdout);
370                            Some(output.stdout)
371                        } else {
372                            let message = String::from_utf8(output.stderr).map_err(|error| {
373                                LoadError::KeyCommand {
374                                    name: name.clone(),
375                                    command: key_command.clone(),
376                                    message: error.to_string(),
377                                }
378                            })?;
379                            return Err(LoadError::KeyCommand {
380                                name,
381                                command: key_command.clone(),
382                                message,
383                            });
384                        }
385                    }
386                    CommandCacheConfig::KeyScript { script: key_script } => {
387                        let output = cmd_proc::Command::new("sh")
388                            .arguments(["-e", "-c"])
389                            .argument(key_script)
390                            .stdout_capture()
391                            .stderr_capture()
392                            .accept_nonzero_exit()
393                            .run()
394                            .await
395                            .map_err(|error| LoadError::KeyScript {
396                                name: name.clone(),
397                                message: error.to_string(),
398                            })?;
399
400                        if output.status.success() {
401                            hash_chain.update(&output.stdout);
402                            Some(output.stdout)
403                        } else {
404                            let message = String::from_utf8(output.stderr).map_err(|error| {
405                                LoadError::KeyScript {
406                                    name: name.clone(),
407                                    message: error.to_string(),
408                                }
409                            })?;
410                            return Err(LoadError::KeyScript { name, message });
411                        }
412                    }
413                };
414
415                Ok(LoadedSeed::Command {
416                    cache_status: CacheStatus::from_cache_key(
417                        hash_chain.cache_key(),
418                        backend,
419                        instance_name,
420                    )
421                    .await,
422                    cache_key_output,
423                    name,
424                    command: command.clone(),
425                })
426            }
427            Seed::Script { script } => {
428                hash_chain.update(script);
429
430                Ok(LoadedSeed::Script {
431                    cache_status: CacheStatus::from_cache_key(
432                        hash_chain.cache_key(),
433                        backend,
434                        instance_name,
435                    )
436                    .await,
437                    name,
438                    script: script.clone(),
439                })
440            }
441            Seed::ContainerScript { script } => {
442                hash_chain.update(script);
443
444                Ok(LoadedSeed::ContainerScript {
445                    cache_status: CacheStatus::from_cache_key(
446                        hash_chain.cache_key(),
447                        backend,
448                        instance_name,
449                    )
450                    .await,
451                    name,
452                    script: script.clone(),
453                })
454            }
455        }
456    }
457}
458
/// Failure while loading a seed's inputs; every variant names the seed.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// The SQL file at `path` could not be read.
    #[error("Failed to load seed {name}: could not read file {path}: {source}")]
    FileRead {
        name: SeedName,
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// `path` could not be obtained at `git_revision`; `message` holds the
    /// underlying error text or git's stderr.
    #[error(
        "Failed to load seed {name}: could not read {path} at git revision {git_revision}: {message}"
    )]
    GitRevision {
        name: SeedName,
        path: std::path::PathBuf,
        git_revision: String,
        message: String,
    },
    /// The cache key command could not be run or exited unsuccessfully.
    #[error("Failed to load seed {name}: cache key command {command} failed: {message}")]
    KeyCommand {
        name: SeedName,
        command: String,
        message: String,
    },
    /// The cache key script could not be run or exited unsuccessfully.
    #[error("Failed to load seed {name}: cache key script failed: {message}")]
    KeyScript { name: SeedName, message: String },
}
485
/// A seed whose inputs have been loaded and whose cache status is resolved.
///
/// Mirrors the variants of [`Seed`]; every variant carries the seed's name
/// and its [`CacheStatus`].
#[derive(Clone, Debug, PartialEq)]
pub enum LoadedSeed {
    /// SQL file read from the filesystem; `content` is the file's text.
    SqlFile {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        content: String,
    },
    /// SQL file read at a git revision; `content` is the text at that revision.
    SqlFileGitRevision {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        git_revision: String,
        content: String,
    },
    /// Command seed.
    Command {
        cache_status: CacheStatus,
        /// Stdout of the cache key command/script, when one was run.
        cache_key_output: Option<Vec<u8>>,
        name: SeedName,
        command: Command,
    },
    /// Script seed.
    Script {
        cache_status: CacheStatus,
        name: SeedName,
        script: String,
    },
    /// Container script seed.
    ContainerScript {
        cache_status: CacheStatus,
        name: SeedName,
        script: String,
    },
}
518
519impl LoadedSeed {
520    #[must_use]
521    pub fn cache_status(&self) -> &CacheStatus {
522        match self {
523            Self::SqlFile { cache_status, .. }
524            | Self::SqlFileGitRevision { cache_status, .. }
525            | Self::Command { cache_status, .. }
526            | Self::Script { cache_status, .. }
527            | Self::ContainerScript { cache_status, .. } => cache_status,
528        }
529    }
530
531    #[must_use]
532    pub fn name(&self) -> &SeedName {
533        match self {
534            Self::SqlFile { name, .. }
535            | Self::SqlFileGitRevision { name, .. }
536            | Self::Command { name, .. }
537            | Self::Script { name, .. }
538            | Self::ContainerScript { name, .. } => name,
539        }
540    }
541
542    fn variant_name(&self) -> &'static str {
543        match self {
544            Self::SqlFile { .. } => "sql-file",
545            Self::SqlFileGitRevision { .. } => "sql-file-git-revision",
546            Self::Command { .. } => "command",
547            Self::Script { .. } => "script",
548            Self::ContainerScript { .. } => "container-script",
549        }
550    }
551}
552
553struct HashChain {
554    hasher: Option<sha2::Sha256>,
555}
556
557impl HashChain {
558    fn new() -> Self {
559        use sha2::Digest;
560
561        Self {
562            hasher: Some(sha2::Sha256::new()),
563        }
564    }
565
566    fn update(&mut self, bytes: impl AsRef<[u8]>) {
567        use sha2::Digest;
568
569        if let Some(ref mut hasher) = self.hasher {
570            hasher.update(bytes)
571        }
572    }
573
574    fn cache_key(&self) -> Option<CacheKey> {
575        use sha2::Digest;
576
577        self.hasher
578            .as_ref()
579            .map(|hasher| hasher.clone().finalize().into())
580    }
581
582    fn stop(&mut self) {
583        self.hasher = None
584    }
585}
586
/// The loaded seeds of one instance, in configuration order, together with
/// the image they were resolved against.
#[derive(Debug, PartialEq)]
pub struct LoadedSeeds<'a> {
    /// Image whose string form was fed into the hash chain before any seed.
    image: &'a crate::image::Image,
    /// Loaded seeds, in the iteration order of the input map.
    seeds: Vec<LoadedSeed>,
}
592
593impl<'a> LoadedSeeds<'a> {
594    pub async fn load(
595        image: &'a crate::image::Image,
596        ssl_config: Option<&crate::definition::SslConfig>,
597        seeds: &indexmap::IndexMap<SeedName, Seed>,
598        backend: &ociman::Backend,
599        instance_name: &crate::InstanceName,
600    ) -> Result<Self, LoadError> {
601        let mut hash_chain = HashChain::new();
602        let mut loaded_seeds = Vec::new();
603
604        hash_chain.update(crate::VERSION_STR);
605        hash_chain.update(image.to_string());
606
607        match ssl_config {
608            Some(crate::definition::SslConfig::Generated { hostname }) => {
609                hash_chain.update("ssl:generated:");
610                hash_chain.update(hostname.as_str());
611            }
612            None => {
613                hash_chain.update("ssl:none");
614            }
615        }
616
617        for (name, seed) in seeds {
618            let loaded_seed = seed
619                .load(name.clone(), &mut hash_chain, backend, instance_name)
620                .await?;
621            loaded_seeds.push(loaded_seed);
622        }
623
624        Ok(Self {
625            image,
626            seeds: loaded_seeds,
627        })
628    }
629
630    pub fn iter_seeds(&self) -> impl Iterator<Item = &LoadedSeed> {
631        self.seeds.iter()
632    }
633
634    pub fn print(&self, instance_name: &crate::InstanceName) {
635        println!("Instance: {instance_name}");
636        println!("Image:    {}", self.image);
637        println!("Version:  {}", crate::VERSION_STR);
638        println!();
639
640        let mut table = comfy_table::Table::new();
641
642        table
643            .load_preset(comfy_table::presets::NOTHING)
644            .set_header(["Seed", "Type", "Status"]);
645
646        for seed in &self.seeds {
647            table.add_row([
648                seed.name().as_str(),
649                seed.variant_name(),
650                seed.cache_status().status_str(),
651            ]);
652        }
653
654        println!("{table}");
655    }
656
657    pub fn print_json(&self, instance_name: &crate::InstanceName) {
658        #[derive(serde::Serialize)]
659        struct Output<'a> {
660            instance: &'a str,
661            image: String,
662            version: &'a str,
663            seeds: Vec<SeedOutput<'a>>,
664        }
665
666        #[derive(serde::Serialize)]
667        struct SeedOutput<'a> {
668            name: &'a str,
669            r#type: &'a str,
670            status: &'a str,
671            #[serde(skip_serializing_if = "Option::is_none")]
672            reference: Option<String>,
673        }
674
675        let output = Output {
676            instance: instance_name.as_ref(),
677            image: self.image.to_string(),
678            version: crate::VERSION_STR,
679            seeds: self
680                .seeds
681                .iter()
682                .map(|seed| SeedOutput {
683                    name: seed.name().as_str(),
684                    r#type: seed.variant_name(),
685                    status: seed.cache_status().status_str(),
686                    reference: seed.cache_status().reference().map(|r| r.to_string()),
687                })
688                .collect(),
689        };
690
691        println!("{}", serde_json::to_string_pretty(&output).unwrap());
692    }
693}
694
// Unit tests for seed name parsing/validation and for the accessors on
// `LoadedSeed` / `CacheStatus`. None of them touch a backend, git, or the
// filesystem.
#[cfg(test)]
mod test {
    use super::*;

    // --- Seed names accepted by the validator ---

    #[test]
    fn parse_valid_simple() {
        let name: SeedName = "schema".parse().unwrap();
        assert_eq!(name.to_string(), "schema");
        assert_eq!(name.as_str(), "schema");
    }

    #[test]
    fn parse_valid_with_dash() {
        let name: SeedName = "create-users-table".parse().unwrap();
        assert_eq!(name.to_string(), "create-users-table");
    }

    #[test]
    fn parse_valid_single_char() {
        let name: SeedName = "a".parse().unwrap();
        assert_eq!(name.to_string(), "a");
    }

    #[test]
    fn parse_valid_numeric() {
        let name: SeedName = "123".parse().unwrap();
        assert_eq!(name.to_string(), "123");
    }

    // Boundary: exactly the maximum length is still accepted.
    #[test]
    fn parse_valid_max_length() {
        let input = "a".repeat(SEED_NAME_MAX_LENGTH);
        let name: SeedName = input.parse().unwrap();
        assert_eq!(name.to_string(), input);
    }

    // --- Seed names rejected by the validator ---

    // Covers both entry points: `FromStr` and `TryFrom<String>`.
    #[test]
    fn parse_empty_fails() {
        assert_eq!("".parse::<SeedName>(), Err(SeedNameError::Empty));
        assert_eq!(SeedName::try_from(String::new()), Err(SeedNameError::Empty));
    }

    // Boundary: one byte over the maximum length is rejected.
    #[test]
    fn parse_too_long_fails() {
        let input = "a".repeat(SEED_NAME_MAX_LENGTH + 1);
        assert_eq!(input.parse::<SeedName>(), Err(SeedNameError::TooLong));
    }

    #[test]
    fn parse_starts_with_dash_fails() {
        assert_eq!(
            "-schema".parse::<SeedName>(),
            Err(SeedNameError::StartsWithDash)
        );
    }

    #[test]
    fn parse_ends_with_dash_fails() {
        assert_eq!(
            "schema-".parse::<SeedName>(),
            Err(SeedNameError::EndsWithDash)
        );
    }

    #[test]
    fn parse_uppercase_fails() {
        assert_eq!(
            "Schema".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_underscore_fails() {
        assert_eq!(
            "create_table".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_space_fails() {
        assert_eq!(
            "my seed".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    // --- Constructors ---

    // An owned name and a borrowed static name with equal text compare equal.
    #[test]
    fn try_from_string_valid() {
        assert_eq!(
            SeedName::try_from("valid-name".to_string()),
            Ok(SeedName::from_static_or_panic("valid-name"))
        );
    }

    // Exercises the constructor in a `const` context.
    #[test]
    fn from_static_or_panic_works() {
        const NAME: SeedName = SeedName::from_static_or_panic("my-seed");
        assert_eq!(NAME.as_str(), "my-seed");
    }

    // --- Cache status accessors ---

    #[test]
    fn test_cache_status_uncacheable() {
        let loaded_seed = LoadedSeed::Command {
            cache_status: CacheStatus::Uncacheable,
            cache_key_output: None,
            name: "run-migrations".parse().unwrap(),
            command: Command::new("migrate", ["up"]),
        };

        assert!(loaded_seed.cache_status().reference().is_none());
        assert!(!loaded_seed.cache_status().is_hit());
    }

    // A miss still exposes the reference.
    #[test]
    fn test_cache_status_miss() {
        let reference: ociman::Reference =
            "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
                .parse()
                .unwrap();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Miss {
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert!(!loaded_seed.cache_status().is_hit());
    }

    #[test]
    fn test_cache_status_hit() {
        let reference: ociman::Reference =
            "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
                .parse()
                .unwrap();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Hit {
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert!(loaded_seed.cache_status().is_hit());
    }
}