1use git_proc::Build;
2
3type CacheKey = [u8; 32];
4
5#[derive(Clone, Debug, PartialEq)]
6pub enum CacheStatus {
7 Hit { reference: ociman::Reference },
8 Miss { reference: ociman::Reference },
9 Uncacheable,
10}
11
12impl CacheStatus {
13 async fn from_cache_key(
14 cache_key: Option<CacheKey>,
15 backend: &ociman::Backend,
16 instance_name: &crate::InstanceName,
17 ) -> Self {
18 match cache_key {
19 Some(key) => {
20 let reference = format!("pg-ephemeral/{}:{}", instance_name, hex::encode(key))
21 .parse()
22 .unwrap();
23 if backend.is_image_present(&reference).await {
24 Self::Hit { reference }
25 } else {
26 Self::Miss { reference }
27 }
28 }
29 None => Self::Uncacheable,
30 }
31 }
32
33 #[must_use]
34 pub fn reference(&self) -> Option<&ociman::Reference> {
35 match self {
36 Self::Hit { reference } | Self::Miss { reference } => Some(reference),
37 Self::Uncacheable => None,
38 }
39 }
40
41 #[must_use]
42 pub fn is_hit(&self) -> bool {
43 matches!(self, Self::Hit { .. })
44 }
45
46 #[must_use]
47 pub fn status_str(&self) -> &'static str {
48 match self {
49 Self::Hit { .. } => "hit",
50 Self::Miss { .. } => "miss",
51 Self::Uncacheable => "uncacheable",
52 }
53 }
54}
55
/// Maximum length of a seed name, in bytes.
pub const SEED_NAME_MAX_LENGTH: usize = 63;

// `SeedNameError::TooLong`'s message spells the limit out literally (a `const fn` returning
// `&'static str` cannot format it in), so fail the build if the two ever drift apart.
const _: () = assert!(
    SEED_NAME_MAX_LENGTH == 63,
    "update the SeedNameError::TooLong message to match SEED_NAME_MAX_LENGTH"
);

/// Reason a string was rejected as a seed name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedNameError {
    /// The input was empty.
    Empty,
    /// The input is longer than [`SEED_NAME_MAX_LENGTH`] bytes.
    TooLong,
    /// The input contains a byte other than `a-z`, `0-9` or `-`.
    InvalidCharacter,
    /// The input begins with `-`.
    StartsWithDash,
    /// The input ends with `-`.
    EndsWithDash,
}

impl SeedNameError {
    /// Human-readable description of the error; `const` so it can be used in const panics.
    #[must_use]
    const fn message(&self) -> &'static str {
        match self {
            Self::Empty => "seed name cannot be empty",
            Self::TooLong => "seed name exceeds maximum length of 63 bytes",
            Self::InvalidCharacter => {
                "seed name must contain only lowercase ASCII alphanumeric characters or dashes"
            }
            Self::StartsWithDash => "seed name cannot start with a dash",
            Self::EndsWithDash => "seed name cannot end with a dash",
        }
    }
}

impl std::fmt::Display for SeedNameError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str(self.message())
    }
}

impl std::error::Error for SeedNameError {}
96
97const fn validate_seed_name(input: &str) -> Option<SeedNameError> {
98 let bytes = input.as_bytes();
99
100 if bytes.is_empty() {
101 return Some(SeedNameError::Empty);
102 }
103
104 if bytes.len() > SEED_NAME_MAX_LENGTH {
105 return Some(SeedNameError::TooLong);
106 }
107
108 if bytes[0] == b'-' {
109 return Some(SeedNameError::StartsWithDash);
110 }
111
112 if bytes[bytes.len() - 1] == b'-' {
113 return Some(SeedNameError::EndsWithDash);
114 }
115
116 let mut index = 0;
117
118 while index < bytes.len() {
119 let byte = bytes[index];
120 if !(byte.is_ascii_lowercase() || byte.is_ascii_digit() || byte == b'-') {
121 return Some(SeedNameError::InvalidCharacter);
122 }
123 index += 1;
124 }
125
126 None
127}
128
129#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
130#[serde(try_from = "String")]
131pub struct SeedName(std::borrow::Cow<'static, str>);
132
133impl SeedName {
134 #[must_use]
142 pub const fn from_static_or_panic(input: &'static str) -> Self {
143 match validate_seed_name(input) {
144 Some(error) => panic!("{}", error.message()),
145 None => Self(std::borrow::Cow::Borrowed(input)),
146 }
147 }
148
149 #[must_use]
151 pub fn as_str(&self) -> &str {
152 &self.0
153 }
154}
155
156impl std::fmt::Display for SeedName {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 write!(f, "{}", self.0)
159 }
160}
161
162impl AsRef<str> for SeedName {
163 fn as_ref(&self) -> &str {
164 &self.0
165 }
166}
167
168#[derive(Debug, PartialEq, Eq, thiserror::Error)]
169#[error("Duplicate seed name: {0}")]
170pub struct DuplicateSeedName(pub SeedName);
171
172impl std::str::FromStr for SeedName {
173 type Err = SeedNameError;
174
175 fn from_str(value: &str) -> Result<Self, Self::Err> {
176 match validate_seed_name(value) {
177 Some(error) => Err(error),
178 None => Ok(Self(std::borrow::Cow::Owned(value.to_owned()))),
179 }
180 }
181}
182
183impl TryFrom<String> for SeedName {
184 type Error = SeedNameError;
185
186 fn try_from(value: String) -> Result<Self, Self::Error> {
187 match validate_seed_name(&value) {
188 Some(error) => Err(error),
189 None => Ok(Self(std::borrow::Cow::Owned(value))),
190 }
191 }
192}
193
/// A program name plus the arguments it is invoked with.
#[derive(Debug, Clone, PartialEq)]
pub struct Command {
    pub command: String,
    pub arguments: Vec<String>,
}

impl Command {
    /// Creates a command from any string-like program name and argument sequence.
    pub fn new(
        command: impl Into<String>,
        arguments: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let command = command.into();
        let arguments = arguments.into_iter().map(Into::into).collect();
        Self { command, arguments }
    }
}
211
212#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
213#[serde(tag = "type", rename_all = "kebab-case")]
214pub enum CommandCacheConfig {
215 None,
217 CommandHash,
219 KeyCommand {
221 command: String,
222 #[serde(default)]
223 arguments: Vec<String>,
224 },
225 KeyScript { script: String },
227}
228
229#[derive(Clone, Debug, PartialEq)]
230pub enum Seed {
231 SqlFile {
232 path: std::path::PathBuf,
233 },
234 SqlFileGitRevision {
235 git_revision: String,
236 path: std::path::PathBuf,
237 },
238 Command {
239 command: Command,
240 cache: CommandCacheConfig,
241 },
242 Script {
243 script: String,
244 },
245 ContainerScript {
246 script: String,
247 },
248}
249
250impl Seed {
251 async fn load(
252 &self,
253 name: SeedName,
254 hash_chain: &mut HashChain,
255 backend: &ociman::Backend,
256 instance_name: &crate::InstanceName,
257 ) -> Result<LoadedSeed, LoadError> {
258 match self {
259 Seed::SqlFile { path } => {
260 let content =
261 std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
262 name: name.clone(),
263 path: path.clone(),
264 source,
265 })?;
266
267 hash_chain.update(&content);
268
269 Ok(LoadedSeed::SqlFile {
270 cache_status: CacheStatus::from_cache_key(
271 hash_chain.cache_key(),
272 backend,
273 instance_name,
274 )
275 .await,
276 name,
277 path: path.clone(),
278 content,
279 })
280 }
281 Seed::SqlFileGitRevision { path, git_revision } => {
282 let output =
283 git_proc::show::new(&format!("{git_revision}:{}", path.to_str().unwrap()))
284 .build()
285 .stdout_capture()
286 .stderr_capture()
287 .accept_nonzero_exit()
288 .run()
289 .await
290 .map_err(|error| LoadError::GitRevision {
291 name: name.clone(),
292 path: path.clone(),
293 git_revision: git_revision.clone(),
294 message: error.to_string(),
295 })?;
296
297 if output.status.success() {
298 let content = String::from_utf8(output.stdout).map_err(|error| {
299 LoadError::GitRevision {
300 name: name.clone(),
301 path: path.clone(),
302 git_revision: git_revision.clone(),
303 message: error.to_string(),
304 }
305 })?;
306
307 hash_chain.update(&content);
308
309 Ok(LoadedSeed::SqlFileGitRevision {
310 cache_status: CacheStatus::from_cache_key(
311 hash_chain.cache_key(),
312 backend,
313 instance_name,
314 )
315 .await,
316 name,
317 path: path.clone(),
318 git_revision: git_revision.clone(),
319 content,
320 })
321 } else {
322 let message = String::from_utf8(output.stderr).map_err(|error| {
323 LoadError::GitRevision {
324 name: name.clone(),
325 path: path.clone(),
326 git_revision: git_revision.clone(),
327 message: error.to_string(),
328 }
329 })?;
330 Err(LoadError::GitRevision {
331 name,
332 path: path.clone(),
333 git_revision: git_revision.clone(),
334 message,
335 })
336 }
337 }
338 Seed::Command { command, cache } => {
339 let cache_key_output = match cache {
340 CommandCacheConfig::None => {
341 hash_chain.stop();
342 None
343 }
344 CommandCacheConfig::CommandHash => {
345 hash_chain.update(&command.command);
346 for argument in &command.arguments {
347 hash_chain.update(argument);
348 }
349 None
350 }
351 CommandCacheConfig::KeyCommand {
352 command: key_command,
353 arguments: key_arguments,
354 } => {
355 let output = cmd_proc::Command::new(key_command)
356 .arguments(key_arguments)
357 .stdout_capture()
358 .stderr_capture()
359 .accept_nonzero_exit()
360 .run()
361 .await
362 .map_err(|error| LoadError::KeyCommand {
363 name: name.clone(),
364 command: key_command.clone(),
365 message: error.to_string(),
366 })?;
367
368 if output.status.success() {
369 hash_chain.update(&output.stdout);
370 Some(output.stdout)
371 } else {
372 let message = String::from_utf8(output.stderr).map_err(|error| {
373 LoadError::KeyCommand {
374 name: name.clone(),
375 command: key_command.clone(),
376 message: error.to_string(),
377 }
378 })?;
379 return Err(LoadError::KeyCommand {
380 name,
381 command: key_command.clone(),
382 message,
383 });
384 }
385 }
386 CommandCacheConfig::KeyScript { script: key_script } => {
387 let output = cmd_proc::Command::new("sh")
388 .arguments(["-e", "-c"])
389 .argument(key_script)
390 .stdout_capture()
391 .stderr_capture()
392 .accept_nonzero_exit()
393 .run()
394 .await
395 .map_err(|error| LoadError::KeyScript {
396 name: name.clone(),
397 message: error.to_string(),
398 })?;
399
400 if output.status.success() {
401 hash_chain.update(&output.stdout);
402 Some(output.stdout)
403 } else {
404 let message = String::from_utf8(output.stderr).map_err(|error| {
405 LoadError::KeyScript {
406 name: name.clone(),
407 message: error.to_string(),
408 }
409 })?;
410 return Err(LoadError::KeyScript { name, message });
411 }
412 }
413 };
414
415 Ok(LoadedSeed::Command {
416 cache_status: CacheStatus::from_cache_key(
417 hash_chain.cache_key(),
418 backend,
419 instance_name,
420 )
421 .await,
422 cache_key_output,
423 name,
424 command: command.clone(),
425 })
426 }
427 Seed::Script { script } => {
428 hash_chain.update(script);
429
430 Ok(LoadedSeed::Script {
431 cache_status: CacheStatus::from_cache_key(
432 hash_chain.cache_key(),
433 backend,
434 instance_name,
435 )
436 .await,
437 name,
438 script: script.clone(),
439 })
440 }
441 Seed::ContainerScript { script } => {
442 hash_chain.update(script);
443
444 Ok(LoadedSeed::ContainerScript {
445 cache_status: CacheStatus::from_cache_key(
446 hash_chain.cache_key(),
447 backend,
448 instance_name,
449 )
450 .await,
451 name,
452 script: script.clone(),
453 })
454 }
455 }
456 }
457}
458
459#[derive(Debug, thiserror::Error)]
460pub enum LoadError {
461 #[error("Failed to load seed {name}: could not read file {path}: {source}")]
462 FileRead {
463 name: SeedName,
464 path: std::path::PathBuf,
465 source: std::io::Error,
466 },
467 #[error(
468 "Failed to load seed {name}: could not read {path} at git revision {git_revision}: {message}"
469 )]
470 GitRevision {
471 name: SeedName,
472 path: std::path::PathBuf,
473 git_revision: String,
474 message: String,
475 },
476 #[error("Failed to load seed {name}: cache key command {command} failed: {message}")]
477 KeyCommand {
478 name: SeedName,
479 command: String,
480 message: String,
481 },
482 #[error("Failed to load seed {name}: cache key script failed: {message}")]
483 KeyScript { name: SeedName, message: String },
484}
485
486#[derive(Clone, Debug, PartialEq)]
487pub enum LoadedSeed {
488 SqlFile {
489 cache_status: CacheStatus,
490 name: SeedName,
491 path: std::path::PathBuf,
492 content: String,
493 },
494 SqlFileGitRevision {
495 cache_status: CacheStatus,
496 name: SeedName,
497 path: std::path::PathBuf,
498 git_revision: String,
499 content: String,
500 },
501 Command {
502 cache_status: CacheStatus,
503 cache_key_output: Option<Vec<u8>>,
504 name: SeedName,
505 command: Command,
506 },
507 Script {
508 cache_status: CacheStatus,
509 name: SeedName,
510 script: String,
511 },
512 ContainerScript {
513 cache_status: CacheStatus,
514 name: SeedName,
515 script: String,
516 },
517}
518
519impl LoadedSeed {
520 #[must_use]
521 pub fn cache_status(&self) -> &CacheStatus {
522 match self {
523 Self::SqlFile { cache_status, .. }
524 | Self::SqlFileGitRevision { cache_status, .. }
525 | Self::Command { cache_status, .. }
526 | Self::Script { cache_status, .. }
527 | Self::ContainerScript { cache_status, .. } => cache_status,
528 }
529 }
530
531 #[must_use]
532 pub fn name(&self) -> &SeedName {
533 match self {
534 Self::SqlFile { name, .. }
535 | Self::SqlFileGitRevision { name, .. }
536 | Self::Command { name, .. }
537 | Self::Script { name, .. }
538 | Self::ContainerScript { name, .. } => name,
539 }
540 }
541
542 fn variant_name(&self) -> &'static str {
543 match self {
544 Self::SqlFile { .. } => "sql-file",
545 Self::SqlFileGitRevision { .. } => "sql-file-git-revision",
546 Self::Command { .. } => "command",
547 Self::Script { .. } => "script",
548 Self::ContainerScript { .. } => "container-script",
549 }
550 }
551}
552
553struct HashChain {
554 hasher: Option<sha2::Sha256>,
555}
556
557impl HashChain {
558 fn new() -> Self {
559 use sha2::Digest;
560
561 Self {
562 hasher: Some(sha2::Sha256::new()),
563 }
564 }
565
566 fn update(&mut self, bytes: impl AsRef<[u8]>) {
567 use sha2::Digest;
568
569 if let Some(ref mut hasher) = self.hasher {
570 hasher.update(bytes)
571 }
572 }
573
574 fn cache_key(&self) -> Option<CacheKey> {
575 use sha2::Digest;
576
577 self.hasher
578 .as_ref()
579 .map(|hasher| hasher.clone().finalize().into())
580 }
581
582 fn stop(&mut self) {
583 self.hasher = None
584 }
585}
586
587#[derive(Debug, PartialEq)]
588pub struct LoadedSeeds<'a> {
589 image: &'a crate::image::Image,
590 seeds: Vec<LoadedSeed>,
591}
592
593impl<'a> LoadedSeeds<'a> {
594 pub async fn load(
595 image: &'a crate::image::Image,
596 ssl_config: Option<&crate::definition::SslConfig>,
597 seeds: &indexmap::IndexMap<SeedName, Seed>,
598 backend: &ociman::Backend,
599 instance_name: &crate::InstanceName,
600 ) -> Result<Self, LoadError> {
601 let mut hash_chain = HashChain::new();
602 let mut loaded_seeds = Vec::new();
603
604 hash_chain.update(crate::VERSION_STR);
605 hash_chain.update(image.to_string());
606
607 match ssl_config {
608 Some(crate::definition::SslConfig::Generated { hostname }) => {
609 hash_chain.update("ssl:generated:");
610 hash_chain.update(hostname.as_str());
611 }
612 None => {
613 hash_chain.update("ssl:none");
614 }
615 }
616
617 for (name, seed) in seeds {
618 let loaded_seed = seed
619 .load(name.clone(), &mut hash_chain, backend, instance_name)
620 .await?;
621 loaded_seeds.push(loaded_seed);
622 }
623
624 Ok(Self {
625 image,
626 seeds: loaded_seeds,
627 })
628 }
629
630 pub fn iter_seeds(&self) -> impl Iterator<Item = &LoadedSeed> {
631 self.seeds.iter()
632 }
633
634 pub fn print(&self, instance_name: &crate::InstanceName) {
635 println!("Instance: {instance_name}");
636 println!("Image: {}", self.image);
637 println!("Version: {}", crate::VERSION_STR);
638 println!();
639
640 let mut table = comfy_table::Table::new();
641
642 table
643 .load_preset(comfy_table::presets::NOTHING)
644 .set_header(["Seed", "Type", "Status"]);
645
646 for seed in &self.seeds {
647 table.add_row([
648 seed.name().as_str(),
649 seed.variant_name(),
650 seed.cache_status().status_str(),
651 ]);
652 }
653
654 println!("{table}");
655 }
656
657 pub fn print_json(&self, instance_name: &crate::InstanceName) {
658 #[derive(serde::Serialize)]
659 struct Output<'a> {
660 instance: &'a str,
661 image: String,
662 version: &'a str,
663 seeds: Vec<SeedOutput<'a>>,
664 }
665
666 #[derive(serde::Serialize)]
667 struct SeedOutput<'a> {
668 name: &'a str,
669 r#type: &'a str,
670 status: &'a str,
671 #[serde(skip_serializing_if = "Option::is_none")]
672 reference: Option<String>,
673 }
674
675 let output = Output {
676 instance: instance_name.as_ref(),
677 image: self.image.to_string(),
678 version: crate::VERSION_STR,
679 seeds: self
680 .seeds
681 .iter()
682 .map(|seed| SeedOutput {
683 name: seed.name().as_str(),
684 r#type: seed.variant_name(),
685 status: seed.cache_status().status_str(),
686 reference: seed.cache_status().reference().map(|r| r.to_string()),
687 })
688 .collect(),
689 };
690
691 println!("{}", serde_json::to_string_pretty(&output).unwrap());
692 }
693}
694
#[cfg(test)]
mod test {
    use super::*;

    /// Reference shared by the cache-status tests; the tag is a 64-character hex string.
    fn sample_reference() -> ociman::Reference {
        "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
            .parse()
            .unwrap()
    }

    #[test]
    fn parse_valid_simple() {
        let name: SeedName = "schema".parse().unwrap();
        assert_eq!(name.to_string(), "schema");
        assert_eq!(name.as_str(), "schema");
    }

    #[test]
    fn parse_valid_with_dash() {
        let name: SeedName = "create-users-table".parse().unwrap();
        assert_eq!(name.to_string(), "create-users-table");
    }

    #[test]
    fn parse_valid_single_char() {
        let name: SeedName = "a".parse().unwrap();
        assert_eq!(name.to_string(), "a");
    }

    #[test]
    fn parse_valid_numeric() {
        let name: SeedName = "123".parse().unwrap();
        assert_eq!(name.to_string(), "123");
    }

    #[test]
    fn parse_valid_max_length() {
        let input = "a".repeat(SEED_NAME_MAX_LENGTH);
        let name: SeedName = input.parse().unwrap();
        assert_eq!(name.to_string(), input);
    }

    #[test]
    fn parse_empty_fails() {
        assert_eq!("".parse::<SeedName>(), Err(SeedNameError::Empty));
        assert_eq!(SeedName::try_from(String::new()), Err(SeedNameError::Empty));
    }

    #[test]
    fn parse_too_long_fails() {
        let input = "a".repeat(SEED_NAME_MAX_LENGTH + 1);
        assert_eq!(input.parse::<SeedName>(), Err(SeedNameError::TooLong));
    }

    #[test]
    fn parse_length_is_measured_in_bytes() {
        // 32 two-byte characters = 64 bytes, one past the limit.
        let input = "é".repeat(32);
        assert_eq!(input.parse::<SeedName>(), Err(SeedNameError::TooLong));
    }

    #[test]
    fn parse_starts_with_dash_fails() {
        assert_eq!(
            "-schema".parse::<SeedName>(),
            Err(SeedNameError::StartsWithDash)
        );
    }

    #[test]
    fn parse_single_dash_fails() {
        assert_eq!("-".parse::<SeedName>(), Err(SeedNameError::StartsWithDash));
    }

    #[test]
    fn parse_ends_with_dash_fails() {
        assert_eq!(
            "schema-".parse::<SeedName>(),
            Err(SeedNameError::EndsWithDash)
        );
    }

    #[test]
    fn parse_uppercase_fails() {
        assert_eq!(
            "Schema".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_underscore_fails() {
        assert_eq!(
            "create_table".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_space_fails() {
        assert_eq!(
            "my seed".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_non_ascii_fails() {
        assert_eq!(
            "schéma".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn from_str_and_try_from_agree() {
        let too_long = "a".repeat(SEED_NAME_MAX_LENGTH + 1);
        for input in ["schema", "", "-a", "a-", "A", too_long.as_str()] {
            assert_eq!(
                input.parse::<SeedName>(),
                SeedName::try_from(input.to_owned())
            );
        }
    }

    #[test]
    fn try_from_string_valid() {
        assert_eq!(
            SeedName::try_from("valid-name".to_string()),
            Ok(SeedName::from_static_or_panic("valid-name"))
        );
    }

    #[test]
    fn from_static_or_panic_works() {
        const NAME: SeedName = SeedName::from_static_or_panic("my-seed");
        assert_eq!(NAME.as_str(), "my-seed");
    }

    #[test]
    fn seed_name_error_message_mentions_limit() {
        let message = SeedNameError::TooLong.to_string();
        assert!(
            message.contains(&SEED_NAME_MAX_LENGTH.to_string()),
            "{message}"
        );
    }

    #[test]
    fn test_cache_status_uncacheable() {
        let loaded_seed = LoadedSeed::Command {
            cache_status: CacheStatus::Uncacheable,
            cache_key_output: None,
            name: "run-migrations".parse().unwrap(),
            command: Command::new("migrate", ["up"]),
        };

        assert!(loaded_seed.cache_status().reference().is_none());
        assert!(!loaded_seed.cache_status().is_hit());
    }

    #[test]
    fn test_cache_status_miss() {
        let reference = sample_reference();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Miss {
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert!(!loaded_seed.cache_status().is_hit());
    }

    #[test]
    fn test_cache_status_hit() {
        let reference = sample_reference();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Hit {
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert!(loaded_seed.cache_status().is_hit());
    }

    #[test]
    fn cache_status_str_labels() {
        assert_eq!(CacheStatus::Uncacheable.status_str(), "uncacheable");
        assert_eq!(
            CacheStatus::Miss {
                reference: sample_reference()
            }
            .status_str(),
            "miss"
        );
        assert_eq!(
            CacheStatus::Hit {
                reference: sample_reference()
            }
            .status_str(),
            "hit"
        );
    }

    #[test]
    fn loaded_seed_variant_names() {
        let name: SeedName = "seed".parse().unwrap();
        let script = LoadedSeed::Script {
            cache_status: CacheStatus::Uncacheable,
            name: name.clone(),
            script: "true".to_string(),
        };
        let container_script = LoadedSeed::ContainerScript {
            cache_status: CacheStatus::Uncacheable,
            name: name.clone(),
            script: "true".to_string(),
        };

        assert_eq!(script.variant_name(), "script");
        assert_eq!(container_script.variant_name(), "container-script");
        assert_eq!(script.name(), &name);
    }

    #[test]
    fn hash_chain_is_deterministic_and_order_sensitive() {
        fn key_of(parts: &[&str]) -> Option<CacheKey> {
            let mut chain = HashChain::new();
            for part in parts {
                chain.update(part);
            }
            chain.cache_key()
        }

        assert!(key_of(&["a", "b"]).is_some());
        assert_eq!(key_of(&["a", "b"]), key_of(&["a", "b"]));
        assert_ne!(key_of(&["a", "b"]), key_of(&["b", "a"]));
    }

    #[test]
    fn hash_chain_cache_key_does_not_consume_state() {
        let mut chain = HashChain::new();
        chain.update("a");
        let first = chain.cache_key();
        assert_eq!(chain.cache_key(), first);
        chain.update("b");
        assert_ne!(chain.cache_key(), first);
    }

    #[test]
    fn hash_chain_stop_is_permanent() {
        let mut chain = HashChain::new();
        chain.update("a");
        chain.stop();
        assert_eq!(chain.cache_key(), None);
        chain.update("b");
        assert_eq!(chain.cache_key(), None);
    }
}