Skip to main content

pg_ephemeral/
seed.rs

1use git_proc::Build;
2
3/// Content-addressed cache hash for a seed.
4///
5/// Wraps the SHA-256 digest produced by [`HashChain`] so the type expresses
6/// "seed hash", not "32 raw bytes". Round-trips through hex strings via
7/// [`std::fmt::Display`] / [`std::str::FromStr`] / serde.
8#[derive(Clone, Debug, Eq, Hash, PartialEq)]
9pub struct SeedHash(sha2::digest::Output<sha2::Sha256>);
10
11impl SeedHash {
12    #[must_use]
13    pub fn as_bytes(&self) -> &[u8] {
14        self.0.as_slice()
15    }
16}
17
18impl From<sha2::digest::Output<sha2::Sha256>> for SeedHash {
19    fn from(digest: sha2::digest::Output<sha2::Sha256>) -> Self {
20        Self(digest)
21    }
22}
23
24impl std::fmt::Display for SeedHash {
25    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        formatter.write_str(&hex::encode(self.0))
27    }
28}
29
30#[derive(Debug, thiserror::Error)]
31pub enum SeedHashError {
32    #[error("expected 64 hex characters, got {0}")]
33    InvalidLength(usize),
34    #[error("invalid hex character")]
35    InvalidHex,
36}
37
38impl std::str::FromStr for SeedHash {
39    type Err = SeedHashError;
40
41    fn from_str(input: &str) -> Result<Self, Self::Err> {
42        let input_len = input.len();
43        let decoded = hex::decode(input).map_err(|_| SeedHashError::InvalidHex)?;
44        sha2::digest::Output::<sha2::Sha256>::try_from(decoded.as_slice())
45            .map(Self)
46            .map_err(|_| SeedHashError::InvalidLength(input_len))
47    }
48}
49
50impl serde::Serialize for SeedHash {
51    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
52        serializer.serialize_str(&hex::encode(self.0))
53    }
54}
55
56impl<'de> serde::Deserialize<'de> for SeedHash {
57    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
58        let raw = String::deserialize(deserializer)?;
59        raw.parse().map_err(serde::de::Error::custom)
60    }
61}
62
#[cfg(test)]
mod seed_hash_tests {
    use super::*;

    /// SHA-256 of the empty input (well-known constant).
    const EMPTY_DIGEST_HEX: &str =
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

    fn empty_seed_hash() -> SeedHash {
        use sha2::Digest;
        sha2::Sha256::new().finalize().into()
    }

    #[test]
    fn from_digest_via_into() {
        // Confirms the From<Output<Sha256>> impl works through .into().
        let hash = empty_seed_hash();
        assert_eq!(hash.to_string(), EMPTY_DIGEST_HEX);
    }

    #[test]
    fn as_bytes_exposes_raw_digest() {
        let hash = empty_seed_hash();
        assert_eq!(hash.as_bytes().len(), 32);
        assert_eq!(hex::encode(hash.as_bytes()), EMPTY_DIGEST_HEX);
    }

    #[test]
    fn display_is_lowercase_hex() {
        let hash = empty_seed_hash();
        let rendered = hash.to_string();
        assert_eq!(rendered.len(), 64);
        assert!(
            rendered
                .chars()
                .all(|c| c.is_ascii_hexdigit() && !c.is_ascii_uppercase())
        );
        assert_eq!(rendered, EMPTY_DIGEST_HEX);
    }

    #[test]
    fn from_str_round_trip() {
        let parsed: SeedHash = EMPTY_DIGEST_HEX.parse().unwrap();
        assert_eq!(parsed.to_string(), EMPTY_DIGEST_HEX);
        assert_eq!(parsed, empty_seed_hash());
    }

    #[test]
    fn json_round_trip() {
        let hash = empty_seed_hash();
        let json = serde_json::to_string(&hash).unwrap();
        assert_eq!(json, format!("\"{EMPTY_DIGEST_HEX}\""));
        let parsed: SeedHash = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, hash);
    }

    #[test]
    fn json_rejects_malformed_hash() {
        // Deserialization routes through FromStr, so bad hex, wrong length and
        // non-string JSON values must all be rejected.
        assert!(serde_json::from_str::<SeedHash>("\"abcd\"").is_err());
        assert!(serde_json::from_str::<SeedHash>("\"zz\"").is_err());
        assert!(serde_json::from_str::<SeedHash>("42").is_err());
    }

    #[test]
    fn rejects_empty() {
        assert!(matches!(
            "".parse::<SeedHash>(),
            Err(SeedHashError::InvalidLength(0))
        ));
    }

    #[test]
    fn rejects_short_even_length() {
        assert!(matches!(
            "abcd".parse::<SeedHash>(),
            Err(SeedHashError::InvalidLength(4))
        ));
    }

    #[test]
    fn rejects_long_even_length() {
        let oversized = "a".repeat(66);
        assert!(matches!(
            oversized.parse::<SeedHash>(),
            Err(SeedHashError::InvalidLength(66))
        ));
    }

    #[test]
    fn rejects_odd_length() {
        // Odd length fails the hex decode before the length check applies.
        assert!(matches!(
            "abc".parse::<SeedHash>(),
            Err(SeedHashError::InvalidHex)
        ));
    }

    #[test]
    fn rejects_non_hex_character() {
        let mut bad = EMPTY_DIGEST_HEX.to_string();
        bad.replace_range(0..1, "z");
        assert!(matches!(
            bad.parse::<SeedHash>(),
            Err(SeedHashError::InvalidHex)
        ));
    }

    #[test]
    fn accepts_uppercase_hex_and_display_normalises() {
        // FromStr accepts either case (hex::decode is case-insensitive); Display
        // always emits lowercase, so parse + Display normalises the casing.
        let upper = EMPTY_DIGEST_HEX.to_uppercase();
        let parsed: SeedHash = upper.parse().unwrap();
        assert_eq!(parsed.to_string(), EMPTY_DIGEST_HEX);
        assert_eq!(parsed, empty_seed_hash());
    }
}
165
166#[derive(Clone, Debug, PartialEq)]
167pub enum CacheStatus {
168    Hit {
169        hash: SeedHash,
170        reference: ociman::Reference,
171        /// Labels read from the cache image during presence detection. Cached
172        /// here so callers (e.g. `cache status --json`) can decode the
173        /// pg-ephemeral metadata without issuing a second `inspect` round-trip
174        /// per cache layer.
175        labels: ociman::label::ImageLabels,
176    },
177    Miss {
178        hash: SeedHash,
179        reference: ociman::Reference,
180    },
181    Uncacheable,
182}
183
184impl CacheStatus {
185    async fn from_cache_key(
186        cache_key: Option<SeedHash>,
187        backend: &ociman::Backend,
188        instance_name: &crate::InstanceName,
189    ) -> Result<Self, LoadError> {
190        let Some(hash) = cache_key else {
191            return Ok(Self::Uncacheable);
192        };
193        let reference: ociman::Reference = format!("pg-ephemeral/{instance_name}:{hash}")
194            .parse()
195            .unwrap();
196        // Single inspect round-trip determines presence and (on Hit) returns
197        // the labels in one call. NotFound from the underlying inspect is the
198        // documented absence signal, so we map it to Miss instead of an error.
199        match backend.image_labels(&reference).await {
200            Ok(labels) => Ok(Self::Hit {
201                hash,
202                reference,
203                labels,
204            }),
205            Err(ociman::label::ImageError::Inspect {
206                source: ociman::InspectError::NotFound,
207                ..
208            }) => Ok(Self::Miss { hash, reference }),
209            Err(source) => Err(LoadError::InspectCacheImage { reference, source }),
210        }
211    }
212
213    #[must_use]
214    pub fn reference(&self) -> Option<&ociman::Reference> {
215        match self {
216            Self::Hit { reference, .. } | Self::Miss { reference, .. } => Some(reference),
217            Self::Uncacheable => None,
218        }
219    }
220
221    #[must_use]
222    pub fn hash(&self) -> Option<&SeedHash> {
223        match self {
224            Self::Hit { hash, .. } | Self::Miss { hash, .. } => Some(hash),
225            Self::Uncacheable => None,
226        }
227    }
228
229    #[must_use]
230    pub fn is_hit(&self) -> bool {
231        matches!(self, Self::Hit { .. })
232    }
233
234    #[must_use]
235    pub fn status_str(&self) -> &'static str {
236        match self {
237            Self::Hit { .. } => "hit",
238            Self::Miss { .. } => "miss",
239            Self::Uncacheable => "uncacheable",
240        }
241    }
242}
243
/// Maximum length of a seed name in bytes.
pub const SEED_NAME_MAX_LENGTH: usize = 63;

/// Error parsing a seed name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedNameError {
    /// Seed name cannot be empty.
    Empty,
    /// Seed name exceeds [`SEED_NAME_MAX_LENGTH`] bytes.
    TooLong,
    /// Seed name contains a byte other than `a-z`, `0-9` or `-`.
    InvalidCharacter,
    /// Seed name starts with a dash.
    StartsWithDash,
    /// Seed name ends with a dash.
    EndsWithDash,
}

impl SeedNameError {
    /// Human-readable description of the error.
    ///
    /// This is a `const fn` returning `&'static str` so it can feed the
    /// compile-time panic in `SeedName::from_static_or_panic`. As a result the
    /// `TooLong` text cannot be formatted from [`SEED_NAME_MAX_LENGTH`] and
    /// spells the limit out literally.
    #[must_use]
    const fn message(&self) -> &'static str {
        // Keep the literal "63" in the `TooLong` message honest: fail the build
        // if the constant changes without the message being updated.
        const _: () = assert!(
            SEED_NAME_MAX_LENGTH == 63,
            "update the TooLong message to match SEED_NAME_MAX_LENGTH"
        );

        match self {
            Self::Empty => "seed name cannot be empty",
            Self::TooLong => "seed name exceeds maximum length of 63 bytes",
            Self::InvalidCharacter => {
                "seed name must contain only lowercase ASCII alphanumeric characters or dashes"
            }
            Self::StartsWithDash => "seed name cannot start with a dash",
            Self::EndsWithDash => "seed name cannot end with a dash",
        }
    }
}

impl std::fmt::Display for SeedNameError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(self.message())
    }
}

impl std::error::Error for SeedNameError {}

/// Checks `input` against the seed-name rules, returning the first violation.
///
/// Checks run in a fixed order (empty, too long, leading dash, trailing dash,
/// invalid character). When several rules are broken, the reported error is
/// therefore deterministic; for example `"-"` is `StartsWithDash`.
const fn validate_seed_name(input: &str) -> Option<SeedNameError> {
    let bytes = input.as_bytes();

    if bytes.is_empty() {
        return Some(SeedNameError::Empty);
    }

    if bytes.len() > SEED_NAME_MAX_LENGTH {
        return Some(SeedNameError::TooLong);
    }

    if bytes[0] == b'-' {
        return Some(SeedNameError::StartsWithDash);
    }

    if bytes[bytes.len() - 1] == b'-' {
        return Some(SeedNameError::EndsWithDash);
    }

    // Iterators are not available in a `const fn`, hence the manual index loop.
    let mut index = 0;
    while index < bytes.len() {
        if !matches!(bytes[index], b'a'..=b'z' | b'0'..=b'9' | b'-') {
            return Some(SeedNameError::InvalidCharacter);
        }
        index += 1;
    }

    None
}
316
317#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
318#[serde(try_from = "String")]
319pub struct SeedName(std::borrow::Cow<'static, str>);
320
321impl SeedName {
322    /// Creates a new seed name from a static string.
323    ///
324    /// # Panics
325    ///
326    /// Panics if the input is empty, exceeds [`SEED_NAME_MAX_LENGTH`],
327    /// contains non-lowercase-alphanumeric/dash characters,
328    /// or starts/ends with a dash.
329    #[must_use]
330    pub const fn from_static_or_panic(input: &'static str) -> Self {
331        match validate_seed_name(input) {
332            Some(error) => panic!("{}", error.message()),
333            None => Self(std::borrow::Cow::Borrowed(input)),
334        }
335    }
336
337    /// Returns the seed name as a string slice.
338    #[must_use]
339    pub fn as_str(&self) -> &str {
340        &self.0
341    }
342}
343
344impl std::fmt::Display for SeedName {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        write!(f, "{}", self.0)
347    }
348}
349
350impl AsRef<str> for SeedName {
351    fn as_ref(&self) -> &str {
352        &self.0
353    }
354}
355
356#[derive(Debug, PartialEq, Eq, thiserror::Error)]
357#[error("Duplicate seed name: {0}")]
358pub struct DuplicateSeedName(pub SeedName);
359
360impl std::str::FromStr for SeedName {
361    type Err = SeedNameError;
362
363    fn from_str(value: &str) -> Result<Self, Self::Err> {
364        match validate_seed_name(value) {
365            Some(error) => Err(error),
366            None => Ok(Self(std::borrow::Cow::Owned(value.to_owned()))),
367        }
368    }
369}
370
371impl TryFrom<String> for SeedName {
372    type Error = SeedNameError;
373
374    fn try_from(value: String) -> Result<Self, Self::Error> {
375        match validate_seed_name(&value) {
376            Some(error) => Err(error),
377            None => Ok(Self(std::borrow::Cow::Owned(value))),
378        }
379    }
380}
381
/// An executable plus its argument list, as configured for a command seed.
#[derive(Clone, Debug, PartialEq)]
pub struct Command {
    pub command: String,
    pub arguments: Vec<String>,
}

impl Command {
    /// Builds a command from any string-like program and arguments.
    pub fn new(
        command: impl Into<String>,
        arguments: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let command: String = command.into();
        let arguments: Vec<String> = arguments
            .into_iter()
            .map(|argument| argument.into())
            .collect();
        Self { command, arguments }
    }
}
399
400#[derive(Clone, Debug, serde::Deserialize, serde::Serialize, PartialEq)]
401#[serde(tag = "type", rename_all = "kebab-case")]
402pub enum SeedCacheConfig {
403    /// Disable caching, breaks the cache chain
404    None,
405    /// Hash the command and arguments
406    CommandHash,
407    /// Run a command to get cache key input
408    KeyCommand {
409        command: String,
410        #[serde(default)]
411        arguments: Vec<String>,
412    },
413    /// Run a script to get cache key input
414    KeyScript { script: String },
415}
416
417#[derive(Clone, Debug, PartialEq)]
418pub enum Seed {
419    SqlFile {
420        path: std::path::PathBuf,
421    },
422    SqlFileGitRevision {
423        git_revision: String,
424        path: std::path::PathBuf,
425    },
426    SqlStatement {
427        statement: String,
428    },
429    Command {
430        command: Command,
431        cache: SeedCacheConfig,
432    },
433    Script {
434        script: String,
435        cache: SeedCacheConfig,
436    },
437    ContainerScript {
438        script: String,
439    },
440    CsvFile {
441        path: std::path::PathBuf,
442        table: pg_client::QualifiedTable,
443        delimiter: char,
444    },
445}
446
447/// Apply a cache strategy to the hash chain for a cacheable seed.
448///
449/// Always folds the seed's intrinsic `base` chunks (the command + arguments, or the
450/// script body) into the chain first, then layers the cache strategy on top:
451/// - `None`: stops the chain; nothing further cacheable.
452/// - `CommandHash`: `base` alone drives the cache key.
453/// - `KeyCommand` / `KeyScript`: runs the configured command/script and folds its
454///   stdout into the chain, adding an external cache key input alongside `base`.
455///
456/// stderr from the key command/script is inherited so users see it live.
457async fn apply_cache_config(
458    cache: &SeedCacheConfig,
459    hash_chain: &mut HashChain,
460    name: &SeedName,
461    base: &[&[u8]],
462) -> Result<(), LoadError> {
463    match cache {
464        SeedCacheConfig::None => {
465            hash_chain.stop();
466            Ok(())
467        }
468        SeedCacheConfig::CommandHash => {
469            for chunk in base {
470                hash_chain.update(chunk);
471            }
472            Ok(())
473        }
474        SeedCacheConfig::KeyCommand {
475            command: key_command,
476            arguments: key_arguments,
477        } => {
478            for chunk in base {
479                hash_chain.update(chunk);
480            }
481
482            let output = cmd_proc::Command::new(key_command)
483                .arguments(key_arguments)
484                .stdout_capture()
485                .run()
486                .await
487                .map_err(|source| LoadError::KeyCommand {
488                    name: name.clone(),
489                    command: key_command.clone(),
490                    source,
491                })?;
492
493            hash_chain.update(&output.bytes);
494            Ok(())
495        }
496        SeedCacheConfig::KeyScript { script: key_script } => {
497            for chunk in base {
498                hash_chain.update(chunk);
499            }
500
501            let output = cmd_proc::Command::new("sh")
502                .arguments(["-e", "-c"])
503                .argument(key_script)
504                .stdout_capture()
505                .run()
506                .await
507                .map_err(|source| LoadError::KeyScript {
508                    name: name.clone(),
509                    source,
510                })?;
511
512            hash_chain.update(&output.bytes);
513            Ok(())
514        }
515    }
516}
517
518impl Seed {
519    async fn load(
520        &self,
521        name: SeedName,
522        hash_chain: &mut HashChain,
523        backend: &ociman::Backend,
524        instance_name: &crate::InstanceName,
525    ) -> Result<LoadedSeed, LoadError> {
526        match self {
527            Seed::SqlFile { path } => {
528                let content =
529                    std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
530                        name: name.clone(),
531                        path: path.clone(),
532                        source,
533                    })?;
534
535                hash_chain.update(&content);
536
537                Ok(LoadedSeed::SqlFile {
538                    cache_status: CacheStatus::from_cache_key(
539                        hash_chain.cache_key(),
540                        backend,
541                        instance_name,
542                    )
543                    .await?,
544                    name,
545                    path: path.clone(),
546                    content,
547                })
548            }
549            Seed::SqlFileGitRevision { path, git_revision } => {
550                let output =
551                    git_proc::show::new(&format!("{git_revision}:{}", path.to_str().unwrap()))
552                        .build()
553                        .stdout_capture()
554                        .stderr_capture()
555                        .accept_nonzero_exit()
556                        .run()
557                        .await
558                        .map_err(|error| LoadError::GitRevision {
559                            name: name.clone(),
560                            path: path.clone(),
561                            git_revision: git_revision.clone(),
562                            message: error.to_string(),
563                        })?;
564
565                if output.status.success() {
566                    let content = String::from_utf8(output.stdout).map_err(|error| {
567                        LoadError::GitRevision {
568                            name: name.clone(),
569                            path: path.clone(),
570                            git_revision: git_revision.clone(),
571                            message: error.to_string(),
572                        }
573                    })?;
574
575                    hash_chain.update(&content);
576
577                    Ok(LoadedSeed::SqlFileGitRevision {
578                        cache_status: CacheStatus::from_cache_key(
579                            hash_chain.cache_key(),
580                            backend,
581                            instance_name,
582                        )
583                        .await?,
584                        name,
585                        path: path.clone(),
586                        git_revision: git_revision.clone(),
587                        content,
588                    })
589                } else {
590                    let message = String::from_utf8(output.stderr).map_err(|error| {
591                        LoadError::GitRevision {
592                            name: name.clone(),
593                            path: path.clone(),
594                            git_revision: git_revision.clone(),
595                            message: error.to_string(),
596                        }
597                    })?;
598                    Err(LoadError::GitRevision {
599                        name,
600                        path: path.clone(),
601                        git_revision: git_revision.clone(),
602                        message,
603                    })
604                }
605            }
606            Seed::SqlStatement { statement } => {
607                hash_chain.update(statement);
608
609                Ok(LoadedSeed::SqlStatement {
610                    cache_status: CacheStatus::from_cache_key(
611                        hash_chain.cache_key(),
612                        backend,
613                        instance_name,
614                    )
615                    .await?,
616                    name,
617                    statement: statement.clone(),
618                })
619            }
620            Seed::Command { command, cache } => {
621                let mut base: Vec<&[u8]> = Vec::with_capacity(1 + command.arguments.len());
622                base.push(command.command.as_bytes());
623                for argument in &command.arguments {
624                    base.push(argument.as_bytes());
625                }
626                apply_cache_config(cache, hash_chain, &name, &base).await?;
627
628                Ok(LoadedSeed::Command {
629                    cache_status: CacheStatus::from_cache_key(
630                        hash_chain.cache_key(),
631                        backend,
632                        instance_name,
633                    )
634                    .await?,
635                    name,
636                    command: command.clone(),
637                })
638            }
639            Seed::Script { script, cache } => {
640                apply_cache_config(cache, hash_chain, &name, &[script.as_bytes()]).await?;
641
642                Ok(LoadedSeed::Script {
643                    cache_status: CacheStatus::from_cache_key(
644                        hash_chain.cache_key(),
645                        backend,
646                        instance_name,
647                    )
648                    .await?,
649                    name,
650                    script: script.clone(),
651                })
652            }
653            Seed::ContainerScript { script } => {
654                hash_chain.update(script);
655
656                Ok(LoadedSeed::ContainerScript {
657                    cache_status: CacheStatus::from_cache_key(
658                        hash_chain.cache_key(),
659                        backend,
660                        instance_name,
661                    )
662                    .await?,
663                    name,
664                    script: script.clone(),
665                })
666            }
667            Seed::CsvFile {
668                path,
669                table,
670                delimiter,
671            } => {
672                let content =
673                    std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
674                        name: name.clone(),
675                        path: path.clone(),
676                        source,
677                    })?;
678
679                hash_chain.update(table.schema.as_ref());
680                hash_chain.update(table.table.as_ref());
681                hash_chain.update(&content);
682
683                Ok(LoadedSeed::CsvFile {
684                    cache_status: CacheStatus::from_cache_key(
685                        hash_chain.cache_key(),
686                        backend,
687                        instance_name,
688                    )
689                    .await?,
690                    name,
691                    path: path.clone(),
692                    table: table.clone(),
693                    delimiter: *delimiter,
694                    content,
695                })
696            }
697        }
698    }
699}
700
701#[derive(Debug, thiserror::Error)]
702pub enum LoadError {
703    #[error("Failed to load seed {name}: could not read file {path}: {source}")]
704    FileRead {
705        name: SeedName,
706        path: std::path::PathBuf,
707        source: std::io::Error,
708    },
709    #[error(
710        "Failed to load seed {name}: could not read {path} at git revision {git_revision}: {message}"
711    )]
712    GitRevision {
713        name: SeedName,
714        path: std::path::PathBuf,
715        git_revision: String,
716        message: String,
717    },
718    #[error("Failed to load seed {name}: cache key command {command} failed")]
719    KeyCommand {
720        name: SeedName,
721        command: String,
722        #[source]
723        source: cmd_proc::CommandError,
724    },
725    #[error("Failed to load seed {name}: cache key script failed")]
726    KeyScript {
727        name: SeedName,
728        #[source]
729        source: cmd_proc::CommandError,
730    },
731    #[error("Failed to inspect cache image {reference}")]
732    InspectCacheImage {
733        reference: ociman::Reference,
734        #[source]
735        source: ociman::label::ImageError,
736    },
737}
738
739#[derive(Clone, Debug, PartialEq)]
740pub enum LoadedSeed {
741    SqlFile {
742        cache_status: CacheStatus,
743        name: SeedName,
744        path: std::path::PathBuf,
745        content: String,
746    },
747    SqlFileGitRevision {
748        cache_status: CacheStatus,
749        name: SeedName,
750        path: std::path::PathBuf,
751        git_revision: String,
752        content: String,
753    },
754    SqlStatement {
755        cache_status: CacheStatus,
756        name: SeedName,
757        statement: String,
758    },
759    Command {
760        cache_status: CacheStatus,
761        name: SeedName,
762        command: Command,
763    },
764    Script {
765        cache_status: CacheStatus,
766        name: SeedName,
767        script: String,
768    },
769    ContainerScript {
770        cache_status: CacheStatus,
771        name: SeedName,
772        script: String,
773    },
774    CsvFile {
775        cache_status: CacheStatus,
776        name: SeedName,
777        path: std::path::PathBuf,
778        table: pg_client::QualifiedTable,
779        delimiter: char,
780        content: String,
781    },
782}
783
784impl LoadedSeed {
785    #[must_use]
786    pub fn cache_status(&self) -> &CacheStatus {
787        match self {
788            Self::SqlFile { cache_status, .. }
789            | Self::SqlFileGitRevision { cache_status, .. }
790            | Self::SqlStatement { cache_status, .. }
791            | Self::Command { cache_status, .. }
792            | Self::Script { cache_status, .. }
793            | Self::ContainerScript { cache_status, .. }
794            | Self::CsvFile { cache_status, .. } => cache_status,
795        }
796    }
797
798    #[must_use]
799    pub fn name(&self) -> &SeedName {
800        match self {
801            Self::SqlFile { name, .. }
802            | Self::SqlFileGitRevision { name, .. }
803            | Self::SqlStatement { name, .. }
804            | Self::Command { name, .. }
805            | Self::Script { name, .. }
806            | Self::ContainerScript { name, .. }
807            | Self::CsvFile { name, .. } => name,
808        }
809    }
810
811    fn variant_name(&self) -> &'static str {
812        match self {
813            Self::SqlFile { .. } => "sql-file",
814            Self::SqlFileGitRevision { .. } => "sql-file-git-revision",
815            Self::SqlStatement { .. } => "sql-statement",
816            Self::Command { .. } => "command",
817            Self::Script { .. } => "script",
818            Self::ContainerScript { .. } => "container-script",
819            Self::CsvFile { .. } => "csv-file",
820        }
821    }
822}
823
824struct HashChain {
825    hasher: Option<sha2::Sha256>,
826}
827
828impl HashChain {
829    fn new() -> Self {
830        use sha2::Digest;
831
832        Self {
833            hasher: Some(sha2::Sha256::new()),
834        }
835    }
836
837    fn update(&mut self, bytes: impl AsRef<[u8]>) {
838        use sha2::Digest;
839
840        if let Some(ref mut hasher) = self.hasher {
841            hasher.update(bytes)
842        }
843    }
844
845    fn cache_key(&self) -> Option<SeedHash> {
846        use sha2::Digest;
847
848        self.hasher
849            .as_ref()
850            .map(|hasher| hasher.clone().finalize().into())
851    }
852
853    fn stop(&mut self) {
854        self.hasher = None
855    }
856}
857
858#[derive(Debug, PartialEq)]
859pub struct LoadedSeeds<'a> {
860    image: &'a crate::image::Image,
861    seeds: Vec<LoadedSeed>,
862}
863
864impl<'a> LoadedSeeds<'a> {
865    pub async fn load(
866        image: &'a crate::image::Image,
867        ssl_config: Option<&crate::definition::SslConfig>,
868        parameters: &pg_client::parameter::Map,
869        seeds: &indexmap::IndexMap<SeedName, Seed>,
870        backend: &ociman::Backend,
871        instance_name: &crate::InstanceName,
872    ) -> Result<Self, LoadError> {
873        let mut hash_chain = HashChain::new();
874        let mut loaded_seeds = Vec::new();
875
876        hash_chain.update(crate::VERSION_STR);
877        hash_chain.update(image.to_string());
878
879        match ssl_config {
880            Some(crate::definition::SslConfig::Generated { hostname }) => {
881                hash_chain.update("ssl:generated:");
882                hash_chain.update(hostname.as_str());
883            }
884            None => {
885                hash_chain.update("ssl:none");
886            }
887        }
888
889        // PG parameters are passed as `-c name=value` flags at container
890        // start, so a parameter change yields a differently-running PG and
891        // must invalidate every cache layer. BTreeMap iteration is sorted,
892        // so the hash is deterministic for a given parameter set.
893        hash_chain.update("parameters:");
894        for (name, value) in parameters {
895            hash_chain.update(name.as_str());
896            hash_chain.update("=");
897            hash_chain.update(value.as_str());
898            hash_chain.update("\0");
899        }
900
901        for (name, seed) in seeds {
902            let loaded_seed = seed
903                .load(name.clone(), &mut hash_chain, backend, instance_name)
904                .await?;
905            loaded_seeds.push(loaded_seed);
906        }
907
908        Ok(Self {
909            image,
910            seeds: loaded_seeds,
911        })
912    }
913
914    pub fn iter_seeds(&self) -> impl Iterator<Item = &LoadedSeed> {
915        self.seeds.iter()
916    }
917
918    pub fn print(&self, instance_name: &crate::InstanceName) {
919        println!("Instance: {instance_name}");
920        println!("Image:    {}", self.image);
921        println!("Version:  {}", crate::VERSION_STR);
922        println!();
923
924        let mut table = comfy_table::Table::new();
925
926        table
927            .load_preset(comfy_table::presets::NOTHING)
928            .set_header(["Seed", "Type", "Status"]);
929
930        for seed in &self.seeds {
931            table.add_row([
932                seed.name().as_str(),
933                seed.variant_name(),
934                seed.cache_status().status_str(),
935            ]);
936        }
937
938        println!("{table}");
939    }
940
941    pub fn print_json(&self, instance_name: &crate::InstanceName) {
942        #[derive(serde::Serialize)]
943        struct Output<'a> {
944            instance: &'a str,
945            base_image: String,
946            version: &'a str,
947            summary: Summary,
948            seeds: Vec<SeedOutput<'a>>,
949        }
950
951        #[derive(serde::Serialize)]
952        struct Summary {
953            total: usize,
954            hits: usize,
955            misses: usize,
956            uncacheable: usize,
957        }
958
959        #[derive(serde::Serialize)]
960        struct SeedOutput<'a> {
961            name: &'a str,
962            r#type: &'a str,
963            status: &'a str,
964            #[serde(skip_serializing_if = "Option::is_none")]
965            cache_image: Option<String>,
966            #[serde(skip_serializing_if = "Option::is_none")]
967            reason: Option<&'static str>,
968            #[serde(skip_serializing_if = "Option::is_none")]
969            broken_by: Option<&'a str>,
970        }
971
972        let mut hits = 0;
973        let mut misses = 0;
974        let mut uncacheable = 0;
975        // Tracks the first seed that broke the cache chain (the only one whose
976        // own `cache = none` setting caused the break). Subsequent uncacheable
977        // seeds were broken by this predecessor, not by their own configuration.
978        let mut chain_breaker: Option<&str> = None;
979
980        let mut seed_outputs = Vec::with_capacity(self.seeds.len());
981
982        for seed in &self.seeds {
983            let status = seed.cache_status();
984            match status {
985                CacheStatus::Hit { .. } => hits += 1,
986                CacheStatus::Miss { .. } => misses += 1,
987                CacheStatus::Uncacheable => uncacheable += 1,
988            }
989
990            let (reason, broken_by) = match status {
991                CacheStatus::Uncacheable => match chain_breaker {
992                    Some(name) => (Some("chain_broken_by_predecessor"), Some(name)),
993                    None => {
994                        chain_breaker = Some(seed.name().as_str());
995                        (Some("cache_strategy_none"), None)
996                    }
997                },
998                CacheStatus::Hit { .. } | CacheStatus::Miss { .. } => (None, None),
999            };
1000
1001            seed_outputs.push(SeedOutput {
1002                name: seed.name().as_str(),
1003                r#type: seed.variant_name(),
1004                status: status.status_str(),
1005                cache_image: status.reference().map(ToString::to_string),
1006                reason,
1007                broken_by,
1008            });
1009        }
1010
1011        let output = Output {
1012            instance: instance_name.as_ref(),
1013            base_image: self.image.to_string(),
1014            version: crate::VERSION_STR,
1015            summary: Summary {
1016                total: self.seeds.len(),
1017                hits,
1018                misses,
1019                uncacheable,
1020            },
1021            seeds: seed_outputs,
1022        };
1023
1024        println!("{}", serde_json::to_string_pretty(&output).unwrap());
1025    }
1026}
1027
#[cfg(test)]
mod test {
    use super::*;

    /// A valid 64-character lowercase hex digest shared by the `SeedHash`
    /// and cache-status tests.
    const HASH_HEX: &str = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";

    /// Builds the hash/reference pair used by the hit and miss tests.
    fn sample_hash_and_reference() -> (SeedHash, ociman::Reference) {
        let hash: SeedHash = HASH_HEX.parse().unwrap();
        let reference: ociman::Reference = format!("pg-ephemeral/main:{HASH_HEX}")
            .parse()
            .unwrap();
        (hash, reference)
    }

    #[test]
    fn parse_valid_simple() {
        let name: SeedName = "schema".parse().unwrap();
        assert_eq!(name.to_string(), "schema");
        assert_eq!(name.as_str(), "schema");
    }

    #[test]
    fn parse_valid_with_dash() {
        let name: SeedName = "create-users-table".parse().unwrap();
        assert_eq!(name.to_string(), "create-users-table");
    }

    #[test]
    fn parse_valid_single_char() {
        let name: SeedName = "a".parse().unwrap();
        assert_eq!(name.to_string(), "a");
    }

    #[test]
    fn parse_valid_numeric() {
        let name: SeedName = "123".parse().unwrap();
        assert_eq!(name.to_string(), "123");
    }

    #[test]
    fn parse_valid_max_length() {
        let input = "a".repeat(SEED_NAME_MAX_LENGTH);
        let name: SeedName = input.parse().unwrap();
        assert_eq!(name.to_string(), input);
    }

    #[test]
    fn parse_empty_fails() {
        assert_eq!("".parse::<SeedName>(), Err(SeedNameError::Empty));
        assert_eq!(SeedName::try_from(String::new()), Err(SeedNameError::Empty));
    }

    #[test]
    fn parse_too_long_fails() {
        let input = "a".repeat(SEED_NAME_MAX_LENGTH + 1);
        assert_eq!(input.parse::<SeedName>(), Err(SeedNameError::TooLong));
    }

    #[test]
    fn parse_starts_with_dash_fails() {
        assert_eq!(
            "-schema".parse::<SeedName>(),
            Err(SeedNameError::StartsWithDash)
        );
    }

    #[test]
    fn parse_ends_with_dash_fails() {
        assert_eq!(
            "schema-".parse::<SeedName>(),
            Err(SeedNameError::EndsWithDash)
        );
    }

    #[test]
    fn parse_uppercase_fails() {
        assert_eq!(
            "Schema".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_underscore_fails() {
        assert_eq!(
            "create_table".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_space_fails() {
        assert_eq!(
            "my seed".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn try_from_string_valid() {
        assert_eq!(
            SeedName::try_from("valid-name".to_string()),
            Ok(SeedName::from_static_or_panic("valid-name"))
        );
    }

    #[test]
    fn from_static_or_panic_works() {
        const NAME: SeedName = SeedName::from_static_or_panic("my-seed");
        assert_eq!(NAME.as_str(), "my-seed");
    }

    #[test]
    fn seed_hash_display_round_trips() {
        let hash: SeedHash = HASH_HEX.parse().unwrap();
        assert_eq!(hash.to_string(), HASH_HEX);
        // 64 hex characters decode to a 32-byte SHA-256 digest.
        assert_eq!(hash.as_bytes().len(), 32);
        assert_eq!(hash.as_bytes()[0], 0x01);
        assert_eq!(hash.as_bytes()[1], 0x23);
    }

    #[test]
    fn seed_hash_parse_accepts_uppercase_and_displays_lowercase() {
        let hash: SeedHash = HASH_HEX.to_uppercase().parse().unwrap();
        assert_eq!(hash.to_string(), HASH_HEX);
    }

    #[test]
    fn seed_hash_parse_wrong_length_fails() {
        let short = &HASH_HEX[..62];
        assert!(matches!(
            short.parse::<SeedHash>(),
            Err(SeedHashError::InvalidLength(62))
        ));

        let long = format!("{HASH_HEX}00");
        assert!(matches!(
            long.parse::<SeedHash>(),
            Err(SeedHashError::InvalidLength(66))
        ));
    }

    #[test]
    fn seed_hash_parse_invalid_hex_fails() {
        let input = "z".repeat(64);
        assert!(matches!(
            input.parse::<SeedHash>(),
            Err(SeedHashError::InvalidHex)
        ));
    }

    #[test]
    fn seed_hash_serde_round_trips() {
        let hash: SeedHash = HASH_HEX.parse().unwrap();
        let json = serde_json::to_string(&hash).unwrap();
        assert_eq!(json, format!("\"{HASH_HEX}\""));

        let decoded: SeedHash = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded, hash);
    }

    #[test]
    fn test_cache_status_uncacheable() {
        let loaded_seed = LoadedSeed::Command {
            cache_status: CacheStatus::Uncacheable,
            name: "run-migrations".parse().unwrap(),
            command: Command::new("migrate", ["up"]),
        };

        assert!(loaded_seed.cache_status().reference().is_none());
        assert!(!loaded_seed.cache_status().is_hit());
    }

    #[test]
    fn test_cache_status_miss() {
        let (hash, reference) = sample_hash_and_reference();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Miss {
                hash: hash.clone(),
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert_eq!(loaded_seed.cache_status().hash(), Some(&hash));
        assert!(!loaded_seed.cache_status().is_hit());
    }

    #[test]
    fn test_cache_status_hit() {
        let (hash, reference) = sample_hash_and_reference();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Hit {
                hash: hash.clone(),
                reference: reference.clone(),
                labels: ociman::label::ImageLabels::default(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert_eq!(loaded_seed.cache_status().hash(), Some(&hash));
        assert!(loaded_seed.cache_status().is_hit());
    }
}