1use git_proc::Build;
2
/// SHA-256 digest identifying a seed's cumulative inputs, as returned by `HashChain::cache_key`.
type CacheKey = [u8; 32];
4
/// Outcome of looking up a seed's cached image, computed by `CacheStatus::from_cache_key`.
#[derive(Clone, Debug, PartialEq)]
pub enum CacheStatus {
    /// The backend reported an image for `reference` as present.
    Hit { reference: ociman::Reference },
    /// The backend reported no image for `reference`; `reference` is the one that was looked up.
    Miss { reference: ociman::Reference },
    /// The seed has no cache key (the hash chain was stopped by this or an earlier seed), so no
    /// image reference exists to look up.
    Uncacheable,
}
11
12impl CacheStatus {
13 async fn from_cache_key(
14 cache_key: Option<CacheKey>,
15 backend: &ociman::Backend,
16 instance_name: &crate::InstanceName,
17 ) -> Self {
18 match cache_key {
19 Some(key) => {
20 let reference = format!("pg-ephemeral/{}:{}", instance_name, hex::encode(key))
21 .parse()
22 .unwrap();
23 if backend.is_image_present(&reference).await {
24 Self::Hit { reference }
25 } else {
26 Self::Miss { reference }
27 }
28 }
29 None => Self::Uncacheable,
30 }
31 }
32
33 #[must_use]
34 pub fn reference(&self) -> Option<&ociman::Reference> {
35 match self {
36 Self::Hit { reference } | Self::Miss { reference } => Some(reference),
37 Self::Uncacheable => None,
38 }
39 }
40
41 #[must_use]
42 pub fn is_hit(&self) -> bool {
43 matches!(self, Self::Hit { .. })
44 }
45
46 #[must_use]
47 pub fn status_str(&self) -> &'static str {
48 match self {
49 Self::Hit { .. } => "hit",
50 Self::Miss { .. } => "miss",
51 Self::Uncacheable => "uncacheable",
52 }
53 }
54}
55
/// Maximum length of a [`SeedName`], measured in bytes.
pub const SEED_NAME_MAX_LENGTH: usize = 63;
58
/// Reasons a string is rejected as a [`SeedName`].
///
/// When several rules are violated, the first failing check in `validate_seed_name` is
/// reported. The checks run in this order: empty, too long, leading dash, trailing dash,
/// invalid character.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeedNameError {
    /// The input has zero bytes.
    Empty,
    /// The input is longer than [`SEED_NAME_MAX_LENGTH`] bytes.
    TooLong,
    /// The input contains a byte outside `a-z`, `0-9` and `-`; any non-ASCII text is rejected.
    InvalidCharacter,
    /// The first byte is `-`.
    StartsWithDash,
    /// The last byte is `-`.
    EndsWithDash,
}
73
74impl SeedNameError {
75 #[must_use]
76 const fn message(&self) -> &'static str {
77 match self {
78 Self::Empty => "seed name cannot be empty",
79 Self::TooLong => "seed name exceeds maximum length of 63 bytes",
80 Self::InvalidCharacter => {
81 "seed name must contain only lowercase ASCII alphanumeric characters or dashes"
82 }
83 Self::StartsWithDash => "seed name cannot start with a dash",
84 Self::EndsWithDash => "seed name cannot end with a dash",
85 }
86 }
87}
88
89impl std::fmt::Display for SeedNameError {
90 fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
91 write!(formatter, "{}", self.message())
92 }
93}
94
95impl std::error::Error for SeedNameError {}
96
97const fn validate_seed_name(input: &str) -> Option<SeedNameError> {
98 let bytes = input.as_bytes();
99
100 if bytes.is_empty() {
101 return Some(SeedNameError::Empty);
102 }
103
104 if bytes.len() > SEED_NAME_MAX_LENGTH {
105 return Some(SeedNameError::TooLong);
106 }
107
108 if bytes[0] == b'-' {
109 return Some(SeedNameError::StartsWithDash);
110 }
111
112 if bytes[bytes.len() - 1] == b'-' {
113 return Some(SeedNameError::EndsWithDash);
114 }
115
116 let mut index = 0;
117
118 while index < bytes.len() {
119 let byte = bytes[index];
120 if !(byte.is_ascii_lowercase() || byte.is_ascii_digit() || byte == b'-') {
121 return Some(SeedNameError::InvalidCharacter);
122 }
123 index += 1;
124 }
125
126 None
127}
128
/// A validated seed name.
///
/// Valid names are non-empty and at most [`SEED_NAME_MAX_LENGTH`] bytes. They contain only
/// `a-z`, `0-9` and `-`, and neither start nor end with `-`. The inner field is private and
/// every constructor in this module runs `validate_seed_name`. Deserialization goes through
/// `TryFrom<String>`, so invalid names are rejected there too.
///
/// The `Cow` lets `SeedName::from_static_or_panic` borrow a `'static` literal. Parsed names
/// own their text.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(try_from = "String")]
pub struct SeedName(std::borrow::Cow<'static, str>);
132
133impl SeedName {
134 #[must_use]
142 pub const fn from_static_or_panic(input: &'static str) -> Self {
143 match validate_seed_name(input) {
144 Some(error) => panic!("{}", error.message()),
145 None => Self(std::borrow::Cow::Borrowed(input)),
146 }
147 }
148
149 #[must_use]
151 pub fn as_str(&self) -> &str {
152 &self.0
153 }
154}
155
156impl std::fmt::Display for SeedName {
157 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
158 write!(f, "{}", self.0)
159 }
160}
161
162impl AsRef<str> for SeedName {
163 fn as_ref(&self) -> &str {
164 &self.0
165 }
166}
167
/// Error carrying a [`SeedName`] that was declared more than once.
///
/// NOTE(review): not constructed in this section; presumably produced where seed maps are
/// assembled — confirm at the call site.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("Duplicate seed name: {0}")]
pub struct DuplicateSeedName(pub SeedName);
171
172impl std::str::FromStr for SeedName {
173 type Err = SeedNameError;
174
175 fn from_str(value: &str) -> Result<Self, Self::Err> {
176 match validate_seed_name(value) {
177 Some(error) => Err(error),
178 None => Ok(Self(std::borrow::Cow::Owned(value.to_owned()))),
179 }
180 }
181}
182
183impl TryFrom<String> for SeedName {
184 type Error = SeedNameError;
185
186 fn try_from(value: String) -> Result<Self, Self::Error> {
187 match validate_seed_name(&value) {
188 Some(error) => Err(error),
189 None => Ok(Self(std::borrow::Cow::Owned(value))),
190 }
191 }
192}
193
/// An external program invocation: the program to run and its arguments.
#[derive(Clone, Debug, PartialEq)]
pub struct Command {
    /// Program name or path.
    pub command: String,
    /// Arguments passed to the program, in order.
    pub arguments: Vec<String>,
}

impl Command {
    /// Builds a command from any string-like program and argument values, converting each into
    /// an owned `String`.
    pub fn new(
        command: impl Into<String>,
        arguments: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let command: String = command.into();
        let arguments: Vec<String> = arguments.into_iter().map(Into::into).collect();
        Self { command, arguments }
    }
}
211
/// How a [`Seed::Command`] contributes to the cache key.
///
/// Deserialized as an internally tagged value whose `type` field is the kebab-case variant name
/// (`none`, `command-hash`, `key-command`, `key-script`).
#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum CommandCacheConfig {
    /// Disables caching: loading the seed stops the hash chain, so this seed and every seed after
    /// it are uncacheable. (This is `CommandCacheConfig::None`, not `Option::None`.)
    None,
    /// Hashes the seed command's program name and arguments; only the command-line text is
    /// hashed, not anything the command reads.
    CommandHash,
    /// Runs `command` with `arguments` and hashes its stdout; a non-zero exit fails loading.
    /// The seed command itself is not hashed in this mode.
    KeyCommand {
        command: String,
        /// Empty when omitted.
        #[serde(default)]
        arguments: Vec<String>,
    },
    /// Runs `script` via `sh -e -c` and hashes its stdout; a non-zero exit fails loading.
    /// The seed command itself is not hashed in this mode.
    KeyScript { script: String },
}
228
/// One seeding step before its inputs are loaded; see `Seed::load`.
#[derive(Clone, Debug, PartialEq)]
pub enum Seed {
    /// SQL read from `path` on the local filesystem.
    SqlFile {
        path: std::path::PathBuf,
    },
    /// SQL read from `path` as it exists at `git_revision`, via `git show <git_revision>:<path>`.
    SqlFileGitRevision {
        git_revision: String,
        path: std::path::PathBuf,
    },
    /// An external command; `cache` controls how it contributes to the cache key.
    Command {
        command: Command,
        cache: CommandCacheConfig,
    },
    /// Inline script text; only the text is hashed. How it is executed is decided elsewhere.
    Script {
        script: String,
    },
    /// Inline script text; only the text is hashed.
    /// NOTE(review): name suggests it runs inside the container — confirm at the executor.
    ContainerScript {
        script: String,
    },
    /// CSV data read from `path`, presumably loaded into `table`, with fields split on `delimiter`.
    CsvFile {
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
        delimiter: char,
    },
}
254
255impl Seed {
256 async fn load(
257 &self,
258 name: SeedName,
259 hash_chain: &mut HashChain,
260 backend: &ociman::Backend,
261 instance_name: &crate::InstanceName,
262 ) -> Result<LoadedSeed, LoadError> {
263 match self {
264 Seed::SqlFile { path } => {
265 let content =
266 std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
267 name: name.clone(),
268 path: path.clone(),
269 source,
270 })?;
271
272 hash_chain.update(&content);
273
274 Ok(LoadedSeed::SqlFile {
275 cache_status: CacheStatus::from_cache_key(
276 hash_chain.cache_key(),
277 backend,
278 instance_name,
279 )
280 .await,
281 name,
282 path: path.clone(),
283 content,
284 })
285 }
286 Seed::SqlFileGitRevision { path, git_revision } => {
287 let output =
288 git_proc::show::new(&format!("{git_revision}:{}", path.to_str().unwrap()))
289 .build()
290 .stdout_capture()
291 .stderr_capture()
292 .accept_nonzero_exit()
293 .run()
294 .await
295 .map_err(|error| LoadError::GitRevision {
296 name: name.clone(),
297 path: path.clone(),
298 git_revision: git_revision.clone(),
299 message: error.to_string(),
300 })?;
301
302 if output.status.success() {
303 let content = String::from_utf8(output.stdout).map_err(|error| {
304 LoadError::GitRevision {
305 name: name.clone(),
306 path: path.clone(),
307 git_revision: git_revision.clone(),
308 message: error.to_string(),
309 }
310 })?;
311
312 hash_chain.update(&content);
313
314 Ok(LoadedSeed::SqlFileGitRevision {
315 cache_status: CacheStatus::from_cache_key(
316 hash_chain.cache_key(),
317 backend,
318 instance_name,
319 )
320 .await,
321 name,
322 path: path.clone(),
323 git_revision: git_revision.clone(),
324 content,
325 })
326 } else {
327 let message = String::from_utf8(output.stderr).map_err(|error| {
328 LoadError::GitRevision {
329 name: name.clone(),
330 path: path.clone(),
331 git_revision: git_revision.clone(),
332 message: error.to_string(),
333 }
334 })?;
335 Err(LoadError::GitRevision {
336 name,
337 path: path.clone(),
338 git_revision: git_revision.clone(),
339 message,
340 })
341 }
342 }
343 Seed::Command { command, cache } => {
344 let cache_key_output = match cache {
345 CommandCacheConfig::None => {
346 hash_chain.stop();
347 None
348 }
349 CommandCacheConfig::CommandHash => {
350 hash_chain.update(&command.command);
351 for argument in &command.arguments {
352 hash_chain.update(argument);
353 }
354 None
355 }
356 CommandCacheConfig::KeyCommand {
357 command: key_command,
358 arguments: key_arguments,
359 } => {
360 let output = cmd_proc::Command::new(key_command)
361 .arguments(key_arguments)
362 .stdout_capture()
363 .stderr_capture()
364 .accept_nonzero_exit()
365 .run()
366 .await
367 .map_err(|error| LoadError::KeyCommand {
368 name: name.clone(),
369 command: key_command.clone(),
370 message: error.to_string(),
371 })?;
372
373 if output.status.success() {
374 hash_chain.update(&output.stdout);
375 Some(output.stdout)
376 } else {
377 let message = String::from_utf8(output.stderr).map_err(|error| {
378 LoadError::KeyCommand {
379 name: name.clone(),
380 command: key_command.clone(),
381 message: error.to_string(),
382 }
383 })?;
384 return Err(LoadError::KeyCommand {
385 name,
386 command: key_command.clone(),
387 message,
388 });
389 }
390 }
391 CommandCacheConfig::KeyScript { script: key_script } => {
392 let output = cmd_proc::Command::new("sh")
393 .arguments(["-e", "-c"])
394 .argument(key_script)
395 .stdout_capture()
396 .stderr_capture()
397 .accept_nonzero_exit()
398 .run()
399 .await
400 .map_err(|error| LoadError::KeyScript {
401 name: name.clone(),
402 message: error.to_string(),
403 })?;
404
405 if output.status.success() {
406 hash_chain.update(&output.stdout);
407 Some(output.stdout)
408 } else {
409 let message = String::from_utf8(output.stderr).map_err(|error| {
410 LoadError::KeyScript {
411 name: name.clone(),
412 message: error.to_string(),
413 }
414 })?;
415 return Err(LoadError::KeyScript { name, message });
416 }
417 }
418 };
419
420 Ok(LoadedSeed::Command {
421 cache_status: CacheStatus::from_cache_key(
422 hash_chain.cache_key(),
423 backend,
424 instance_name,
425 )
426 .await,
427 cache_key_output,
428 name,
429 command: command.clone(),
430 })
431 }
432 Seed::Script { script } => {
433 hash_chain.update(script);
434
435 Ok(LoadedSeed::Script {
436 cache_status: CacheStatus::from_cache_key(
437 hash_chain.cache_key(),
438 backend,
439 instance_name,
440 )
441 .await,
442 name,
443 script: script.clone(),
444 })
445 }
446 Seed::ContainerScript { script } => {
447 hash_chain.update(script);
448
449 Ok(LoadedSeed::ContainerScript {
450 cache_status: CacheStatus::from_cache_key(
451 hash_chain.cache_key(),
452 backend,
453 instance_name,
454 )
455 .await,
456 name,
457 script: script.clone(),
458 })
459 }
460 Seed::CsvFile {
461 path,
462 table,
463 delimiter,
464 } => {
465 let content =
466 std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
467 name: name.clone(),
468 path: path.clone(),
469 source,
470 })?;
471
472 hash_chain.update(table.schema.as_ref());
473 hash_chain.update(table.table.as_ref());
474 hash_chain.update(&content);
475
476 Ok(LoadedSeed::CsvFile {
477 cache_status: CacheStatus::from_cache_key(
478 hash_chain.cache_key(),
479 backend,
480 instance_name,
481 )
482 .await,
483 name,
484 path: path.clone(),
485 table: table.clone(),
486 delimiter: *delimiter,
487 content,
488 })
489 }
490 }
491 }
492}
493
/// Errors raised while loading a seed's inputs in `Seed::load`; every variant names the seed.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// A SQL or CSV seed file could not be read as a UTF-8 string.
    #[error("Failed to load seed {name}: could not read file {path}: {source}")]
    FileRead {
        name: SeedName,
        path: std::path::PathBuf,
        /// Underlying I/O error; because the field is named `source`, thiserror also exposes it
        /// through `Error::source`.
        source: std::io::Error,
    },
    /// `git show` for `path` at `git_revision` could not be run, failed, or produced non-UTF-8
    /// output. `message` holds the specific reason.
    #[error(
        "Failed to load seed {name}: could not read {path} at git revision {git_revision}: {message}"
    )]
    GitRevision {
        name: SeedName,
        path: std::path::PathBuf,
        git_revision: String,
        message: String,
    },
    /// The cache key command (`CommandCacheConfig::KeyCommand`) could not be run or failed.
    #[error("Failed to load seed {name}: cache key command {command} failed: {message}")]
    KeyCommand {
        name: SeedName,
        command: String,
        message: String,
    },
    /// The cache key script (`CommandCacheConfig::KeyScript`) could not be run or failed.
    #[error("Failed to load seed {name}: cache key script failed: {message}")]
    KeyScript { name: SeedName, message: String },
}
520
/// A [`Seed`] whose inputs have been read and hashed, paired with its resolved [`CacheStatus`].
#[derive(Clone, Debug, PartialEq)]
pub enum LoadedSeed {
    /// SQL read from `path`; `content` is the file's full text.
    SqlFile {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        content: String,
    },
    /// SQL read from `path` at `git_revision`; `content` is what `git show` printed.
    SqlFileGitRevision {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        git_revision: String,
        content: String,
    },
    /// An external command. `cache_key_output` is the raw stdout of the key command or key script.
    /// It is `None` for the `none` and `command-hash` cache modes.
    Command {
        cache_status: CacheStatus,
        cache_key_output: Option<Vec<u8>>,
        name: SeedName,
        command: Command,
    },
    /// Inline script text.
    Script {
        cache_status: CacheStatus,
        name: SeedName,
        script: String,
    },
    /// Inline script text of the container-script kind.
    ContainerScript {
        cache_status: CacheStatus,
        name: SeedName,
        script: String,
    },
    /// CSV text read from `path`, targeting `table`, with fields split on `delimiter`.
    CsvFile {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
        delimiter: char,
        content: String,
    },
}
561
562impl LoadedSeed {
563 #[must_use]
564 pub fn cache_status(&self) -> &CacheStatus {
565 match self {
566 Self::SqlFile { cache_status, .. }
567 | Self::SqlFileGitRevision { cache_status, .. }
568 | Self::Command { cache_status, .. }
569 | Self::Script { cache_status, .. }
570 | Self::ContainerScript { cache_status, .. }
571 | Self::CsvFile { cache_status, .. } => cache_status,
572 }
573 }
574
575 #[must_use]
576 pub fn name(&self) -> &SeedName {
577 match self {
578 Self::SqlFile { name, .. }
579 | Self::SqlFileGitRevision { name, .. }
580 | Self::Command { name, .. }
581 | Self::Script { name, .. }
582 | Self::ContainerScript { name, .. }
583 | Self::CsvFile { name, .. } => name,
584 }
585 }
586
587 fn variant_name(&self) -> &'static str {
588 match self {
589 Self::SqlFile { .. } => "sql-file",
590 Self::SqlFileGitRevision { .. } => "sql-file-git-revision",
591 Self::Command { .. } => "command",
592 Self::Script { .. } => "script",
593 Self::ContainerScript { .. } => "container-script",
594 Self::CsvFile { .. } => "csv-file",
595 }
596 }
597}
598
599struct HashChain {
600 hasher: Option<sha2::Sha256>,
601}
602
603impl HashChain {
604 fn new() -> Self {
605 use sha2::Digest;
606
607 Self {
608 hasher: Some(sha2::Sha256::new()),
609 }
610 }
611
612 fn update(&mut self, bytes: impl AsRef<[u8]>) {
613 use sha2::Digest;
614
615 if let Some(ref mut hasher) = self.hasher {
616 hasher.update(bytes)
617 }
618 }
619
620 fn cache_key(&self) -> Option<CacheKey> {
621 use sha2::Digest;
622
623 self.hasher
624 .as_ref()
625 .map(|hasher| hasher.clone().finalize().into())
626 }
627
628 fn stop(&mut self) {
629 self.hasher = None
630 }
631}
632
/// The ordered set of loaded seeds for one image, produced by `LoadedSeeds::load`.
#[derive(Debug, PartialEq)]
pub struct LoadedSeeds<'a> {
    /// Base image, borrowed from the caller; its rendering is part of every seed's cache key.
    image: &'a crate::image::Image,
    /// Loaded seeds, in the iteration order of the input map.
    seeds: Vec<LoadedSeed>,
}
638
639impl<'a> LoadedSeeds<'a> {
640 pub async fn load(
641 image: &'a crate::image::Image,
642 ssl_config: Option<&crate::definition::SslConfig>,
643 seeds: &indexmap::IndexMap<SeedName, Seed>,
644 backend: &ociman::Backend,
645 instance_name: &crate::InstanceName,
646 ) -> Result<Self, LoadError> {
647 let mut hash_chain = HashChain::new();
648 let mut loaded_seeds = Vec::new();
649
650 hash_chain.update(crate::VERSION_STR);
651 hash_chain.update(image.to_string());
652
653 match ssl_config {
654 Some(crate::definition::SslConfig::Generated { hostname }) => {
655 hash_chain.update("ssl:generated:");
656 hash_chain.update(hostname.as_str());
657 }
658 None => {
659 hash_chain.update("ssl:none");
660 }
661 }
662
663 for (name, seed) in seeds {
664 let loaded_seed = seed
665 .load(name.clone(), &mut hash_chain, backend, instance_name)
666 .await?;
667 loaded_seeds.push(loaded_seed);
668 }
669
670 Ok(Self {
671 image,
672 seeds: loaded_seeds,
673 })
674 }
675
676 pub fn iter_seeds(&self) -> impl Iterator<Item = &LoadedSeed> {
677 self.seeds.iter()
678 }
679
680 pub fn print(&self, instance_name: &crate::InstanceName) {
681 println!("Instance: {instance_name}");
682 println!("Image: {}", self.image);
683 println!("Version: {}", crate::VERSION_STR);
684 println!();
685
686 let mut table = comfy_table::Table::new();
687
688 table
689 .load_preset(comfy_table::presets::NOTHING)
690 .set_header(["Seed", "Type", "Status"]);
691
692 for seed in &self.seeds {
693 table.add_row([
694 seed.name().as_str(),
695 seed.variant_name(),
696 seed.cache_status().status_str(),
697 ]);
698 }
699
700 println!("{table}");
701 }
702
703 pub fn print_json(&self, instance_name: &crate::InstanceName) {
704 #[derive(serde::Serialize)]
705 struct Output<'a> {
706 instance: &'a str,
707 image: String,
708 version: &'a str,
709 seeds: Vec<SeedOutput<'a>>,
710 }
711
712 #[derive(serde::Serialize)]
713 struct SeedOutput<'a> {
714 name: &'a str,
715 r#type: &'a str,
716 status: &'a str,
717 #[serde(skip_serializing_if = "Option::is_none")]
718 reference: Option<String>,
719 }
720
721 let output = Output {
722 instance: instance_name.as_ref(),
723 image: self.image.to_string(),
724 version: crate::VERSION_STR,
725 seeds: self
726 .seeds
727 .iter()
728 .map(|seed| SeedOutput {
729 name: seed.name().as_str(),
730 r#type: seed.variant_name(),
731 status: seed.cache_status().status_str(),
732 reference: seed.cache_status().reference().map(|r| r.to_string()),
733 })
734 .collect(),
735 };
736
737 println!("{}", serde_json::to_string_pretty(&output).unwrap());
738 }
739}
740
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_valid_simple() {
        let name: SeedName = "schema".parse().unwrap();
        assert_eq!(name.to_string(), "schema");
        assert_eq!(name.as_str(), "schema");
    }

    #[test]
    fn parse_valid_with_dash() {
        let name: SeedName = "create-users-table".parse().unwrap();
        assert_eq!(name.to_string(), "create-users-table");
    }

    #[test]
    fn parse_valid_single_char() {
        let name: SeedName = "a".parse().unwrap();
        assert_eq!(name.to_string(), "a");
    }

    #[test]
    fn parse_valid_numeric() {
        let name: SeedName = "123".parse().unwrap();
        assert_eq!(name.to_string(), "123");
    }

    #[test]
    fn parse_valid_max_length() {
        let input = "a".repeat(SEED_NAME_MAX_LENGTH);
        let name: SeedName = input.parse().unwrap();
        assert_eq!(name.to_string(), input);
    }

    #[test]
    fn parse_empty_fails() {
        assert_eq!("".parse::<SeedName>(), Err(SeedNameError::Empty));
        assert_eq!(SeedName::try_from(String::new()), Err(SeedNameError::Empty));
    }

    #[test]
    fn parse_too_long_fails() {
        let input = "a".repeat(SEED_NAME_MAX_LENGTH + 1);
        assert_eq!(input.parse::<SeedName>(), Err(SeedNameError::TooLong));
    }

    #[test]
    fn try_from_string_too_long_fails() {
        assert_eq!(
            SeedName::try_from("a".repeat(SEED_NAME_MAX_LENGTH + 1)),
            Err(SeedNameError::TooLong)
        );
    }

    #[test]
    fn parse_starts_with_dash_fails() {
        assert_eq!(
            "-schema".parse::<SeedName>(),
            Err(SeedNameError::StartsWithDash)
        );
    }

    #[test]
    fn parse_dash_only_fails() {
        assert_eq!("-".parse::<SeedName>(), Err(SeedNameError::StartsWithDash));
    }

    #[test]
    fn parse_ends_with_dash_fails() {
        assert_eq!(
            "schema-".parse::<SeedName>(),
            Err(SeedNameError::EndsWithDash)
        );
    }

    #[test]
    fn parse_uppercase_fails() {
        assert_eq!(
            "Schema".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_underscore_fails() {
        assert_eq!(
            "create_table".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_space_fails() {
        assert_eq!(
            "my seed".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn parse_non_ascii_fails() {
        assert_eq!(
            "séed".parse::<SeedName>(),
            Err(SeedNameError::InvalidCharacter)
        );
    }

    #[test]
    fn try_from_string_valid() {
        assert_eq!(
            SeedName::try_from("valid-name".to_string()),
            Ok(SeedName::from_static_or_panic("valid-name"))
        );
    }

    #[test]
    fn from_static_or_panic_works() {
        const NAME: SeedName = SeedName::from_static_or_panic("my-seed");
        assert_eq!(NAME.as_str(), "my-seed");
    }

    #[test]
    #[should_panic(expected = "seed name cannot start with a dash")]
    fn from_static_or_panic_rejects_invalid() {
        let _ = SeedName::from_static_or_panic("-seed");
    }

    #[test]
    fn seed_name_error_display_uses_message() {
        assert_eq!(SeedNameError::Empty.to_string(), "seed name cannot be empty");
        // The TooLong text is a literal; make sure it still states the real limit.
        assert!(SeedNameError::TooLong
            .to_string()
            .contains(&SEED_NAME_MAX_LENGTH.to_string()));
    }

    #[test]
    fn seed_name_serde_validates() {
        let name: SeedName = serde_json::from_str("\"schema\"").unwrap();
        assert_eq!(name.as_str(), "schema");
        assert_eq!(serde_json::to_string(&name).unwrap(), "\"schema\"");
        assert!(serde_json::from_str::<SeedName>("\"Schema\"").is_err());
        assert!(serde_json::from_str::<SeedName>("\"\"").is_err());
    }

    #[test]
    fn duplicate_seed_name_display() {
        let error = DuplicateSeedName(SeedName::from_static_or_panic("schema"));
        assert_eq!(error.to_string(), "Duplicate seed name: schema");
    }

    #[test]
    fn command_new_collects_arguments() {
        let command = Command::new("migrate", ["up", "--all"]);
        assert_eq!(command.command, "migrate");
        assert_eq!(
            command.arguments,
            vec!["up".to_string(), "--all".to_string()]
        );
    }

    #[test]
    fn hash_chain_is_deterministic_and_input_sensitive() {
        let mut first = HashChain::new();
        let mut second = HashChain::new();
        let mut third = HashChain::new();
        first.update("seed");
        second.update("seed");
        third.update("other");
        assert!(first.cache_key().is_some());
        assert_eq!(first.cache_key(), second.cache_key());
        assert_ne!(first.cache_key(), third.cache_key());
    }

    #[test]
    fn hash_chain_stop_is_permanent() {
        let mut chain = HashChain::new();
        chain.update("before");
        chain.stop();
        chain.update("after");
        assert_eq!(chain.cache_key(), None);
    }

    #[test]
    fn test_cache_status_uncacheable() {
        let loaded_seed = LoadedSeed::Command {
            cache_status: CacheStatus::Uncacheable,
            cache_key_output: None,
            name: "run-migrations".parse().unwrap(),
            command: Command::new("migrate", ["up"]),
        };

        assert!(loaded_seed.cache_status().reference().is_none());
        assert!(!loaded_seed.cache_status().is_hit());
        assert_eq!(loaded_seed.cache_status().status_str(), "uncacheable");
    }

    #[test]
    fn test_cache_status_miss() {
        let reference: ociman::Reference =
            "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
                .parse()
                .unwrap();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Miss {
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert!(!loaded_seed.cache_status().is_hit());
        assert_eq!(loaded_seed.cache_status().status_str(), "miss");
    }

    #[test]
    fn test_cache_status_hit() {
        let reference: ociman::Reference =
            "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
                .parse()
                .unwrap();

        let loaded_seed = LoadedSeed::SqlFile {
            cache_status: CacheStatus::Hit {
                reference: reference.clone(),
            },
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        };

        assert_eq!(loaded_seed.cache_status().reference(), Some(&reference));
        assert!(loaded_seed.cache_status().is_hit());
        assert_eq!(loaded_seed.cache_status().status_str(), "hit");
    }
}