1use git_proc::Build;
2
/// Cache key for a seed: the SHA-256 digest produced by the seed hash chain, hex-encoded into
/// the tag of the cached image reference.
type CacheKey = [u8; 32];
4
5#[derive(Clone, Debug, PartialEq)]
6pub enum CacheStatus {
7 Hit { reference: ociman::Reference },
8 Miss { reference: ociman::Reference },
9 Uncacheable,
10}
11
12impl CacheStatus {
13 async fn from_cache_key(
14 cache_key: Option<CacheKey>,
15 backend: &ociman::Backend,
16 instance_name: &crate::InstanceName,
17 ) -> Self {
18 match cache_key {
19 Some(key) => {
20 let reference = format!("pg-ephemeral/{}:{}", instance_name, hex::encode(key))
21 .parse()
22 .unwrap();
23 if backend.is_image_present(&reference).await {
24 Self::Hit { reference }
25 } else {
26 Self::Miss { reference }
27 }
28 }
29 None => Self::Uncacheable,
30 }
31 }
32
33 #[must_use]
34 pub fn reference(&self) -> Option<&ociman::Reference> {
35 match self {
36 Self::Hit { reference } | Self::Miss { reference } => Some(reference),
37 Self::Uncacheable => None,
38 }
39 }
40
41 #[must_use]
42 pub fn is_hit(&self) -> bool {
43 matches!(self, Self::Hit { .. })
44 }
45
46 #[must_use]
47 pub fn status_str(&self) -> &'static str {
48 match self {
49 Self::Hit { .. } => "hit",
50 Self::Miss { .. } => "miss",
51 Self::Uncacheable => "uncacheable",
52 }
53 }
54}
55
56#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
57#[serde(try_from = "String")]
58pub struct SeedName(String);
59
60impl SeedName {
61 #[must_use]
62 pub fn as_str(&self) -> &str {
63 &self.0
64 }
65}
66
67impl std::fmt::Display for SeedName {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.0)
70 }
71}
72
/// Error returned when a seed name is empty.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("Seed name cannot be empty")]
pub struct SeedNameError;
76
/// Error carrying a seed name that was declared more than once.
///
/// NOTE(review): not constructed in this file; presumably raised wherever seed definitions are
/// collected — confirm against callers.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("Duplicate seed name: {0}")]
pub struct DuplicateSeedName(pub SeedName);
80
81impl std::str::FromStr for SeedName {
82 type Err = SeedNameError;
83
84 fn from_str(value: &str) -> Result<Self, Self::Err> {
85 if value.is_empty() {
86 Err(SeedNameError)
87 } else {
88 Ok(Self(value.to_string()))
89 }
90 }
91}
92
93impl TryFrom<String> for SeedName {
94 type Error = SeedNameError;
95
96 fn try_from(value: String) -> Result<Self, Self::Error> {
97 if value.is_empty() {
98 Err(SeedNameError)
99 } else {
100 Ok(Self(value))
101 }
102 }
103}
104
105impl TryFrom<&str> for SeedName {
106 type Error = SeedNameError;
107
108 fn try_from(value: &str) -> Result<Self, Self::Error> {
109 value.parse()
110 }
111}
112
/// A program invocation: the executable and its argument list.
#[derive(Clone, Debug, PartialEq)]
pub struct Command {
    pub command: String,
    pub arguments: Vec<String>,
}

impl Command {
    /// Builds a command from anything string-like, converting each argument to an owned `String`.
    pub fn new(
        command: impl Into<String>,
        arguments: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let command: String = command.into();
        let arguments: Vec<String> = arguments.into_iter().map(Into::into).collect();
        Self { command, arguments }
    }
}
130
/// How a `Seed::Command` contributes to the seed cache key.
///
/// Deserialized from a map whose `type` field selects the variant in kebab-case:
/// `none`, `command-hash`, `key-command`, `key-script`.
#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum CommandCacheConfig {
    /// Not cacheable: stops the hash chain, so this seed and every later seed are uncacheable.
    None,
    /// Hashes the seed's own command and each of its arguments.
    CommandHash,
    /// Runs a separate command and hashes its stdout; a non-zero exit is a load error.
    KeyCommand {
        /// Program to run for the cache key.
        command: String,
        /// Arguments for `command`; empty when omitted.
        #[serde(default)]
        arguments: Vec<String>,
    },
    /// Runs `script` with `sh -e -c` and hashes its stdout; a non-zero exit is a load error.
    KeyScript { script: String },
}
147
/// A seed to be loaded; `Seed::load` reads its inputs, hashes them, and resolves a cache status.
#[derive(Clone, Debug, PartialEq)]
pub enum Seed {
    /// SQL read from a file on disk.
    SqlFile {
        path: std::path::PathBuf,
    },
    /// SQL read from `path` as it exists at `git_revision` (via `git show <revision>:<path>`).
    SqlFileGitRevision {
        git_revision: String,
        path: std::path::PathBuf,
    },
    /// An external command; `cache` controls how (or whether) it contributes to the cache key.
    Command {
        command: Command,
        cache: CommandCacheConfig,
    },
    /// Script text whose contents are hashed into the cache key.
    /// NOTE(review): how the script is executed is not visible in this file.
    Script {
        script: String,
    },
    /// Script text whose contents are hashed into the cache key.
    /// NOTE(review): presumably run inside the database container — confirm with the executor.
    ContainerScript {
        script: String,
    },
    /// CSV data read from `path`; presumably imported into `table` using `delimiter`.
    CsvFile {
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
        delimiter: char,
    },
}
173
174impl Seed {
175 async fn load(
176 &self,
177 name: SeedName,
178 hash_chain: &mut HashChain,
179 backend: &ociman::Backend,
180 instance_name: &crate::InstanceName,
181 ) -> Result<LoadedSeed, LoadError> {
182 match self {
183 Seed::SqlFile { path } => {
184 let content =
185 std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
186 name: name.clone(),
187 path: path.clone(),
188 source,
189 })?;
190
191 hash_chain.update(&content);
192
193 Ok(LoadedSeed::SqlFile {
194 cache_status: CacheStatus::from_cache_key(
195 hash_chain.cache_key(),
196 backend,
197 instance_name,
198 )
199 .await,
200 name,
201 path: path.clone(),
202 content,
203 })
204 }
205 Seed::SqlFileGitRevision { path, git_revision } => {
206 let output =
207 git_proc::show::new(&format!("{git_revision}:{}", path.to_str().unwrap()))
208 .build()
209 .stdout_capture()
210 .stderr_capture()
211 .accept_nonzero_exit()
212 .run()
213 .await
214 .map_err(|error| LoadError::GitRevision {
215 name: name.clone(),
216 path: path.clone(),
217 git_revision: git_revision.clone(),
218 message: error.to_string(),
219 })?;
220
221 if output.status.success() {
222 let content = String::from_utf8(output.stdout).map_err(|error| {
223 LoadError::GitRevision {
224 name: name.clone(),
225 path: path.clone(),
226 git_revision: git_revision.clone(),
227 message: error.to_string(),
228 }
229 })?;
230
231 hash_chain.update(&content);
232
233 Ok(LoadedSeed::SqlFileGitRevision {
234 cache_status: CacheStatus::from_cache_key(
235 hash_chain.cache_key(),
236 backend,
237 instance_name,
238 )
239 .await,
240 name,
241 path: path.clone(),
242 git_revision: git_revision.clone(),
243 content,
244 })
245 } else {
246 let message = String::from_utf8(output.stderr).map_err(|error| {
247 LoadError::GitRevision {
248 name: name.clone(),
249 path: path.clone(),
250 git_revision: git_revision.clone(),
251 message: error.to_string(),
252 }
253 })?;
254 Err(LoadError::GitRevision {
255 name,
256 path: path.clone(),
257 git_revision: git_revision.clone(),
258 message,
259 })
260 }
261 }
262 Seed::Command { command, cache } => {
263 let cache_key_output = match cache {
264 CommandCacheConfig::None => {
265 hash_chain.stop();
266 None
267 }
268 CommandCacheConfig::CommandHash => {
269 hash_chain.update(&command.command);
270 for argument in &command.arguments {
271 hash_chain.update(argument);
272 }
273 None
274 }
275 CommandCacheConfig::KeyCommand {
276 command: key_command,
277 arguments: key_arguments,
278 } => {
279 let output = cmd_proc::Command::new(key_command)
280 .arguments(key_arguments)
281 .stdout_capture()
282 .stderr_capture()
283 .accept_nonzero_exit()
284 .run()
285 .await
286 .map_err(|error| LoadError::KeyCommand {
287 name: name.clone(),
288 command: key_command.clone(),
289 message: error.to_string(),
290 })?;
291
292 if output.status.success() {
293 hash_chain.update(&output.stdout);
294 Some(output.stdout)
295 } else {
296 let message = String::from_utf8(output.stderr).map_err(|error| {
297 LoadError::KeyCommand {
298 name: name.clone(),
299 command: key_command.clone(),
300 message: error.to_string(),
301 }
302 })?;
303 return Err(LoadError::KeyCommand {
304 name,
305 command: key_command.clone(),
306 message,
307 });
308 }
309 }
310 CommandCacheConfig::KeyScript { script: key_script } => {
311 let output = cmd_proc::Command::new("sh")
312 .arguments(["-e", "-c"])
313 .argument(key_script)
314 .stdout_capture()
315 .stderr_capture()
316 .accept_nonzero_exit()
317 .run()
318 .await
319 .map_err(|error| LoadError::KeyScript {
320 name: name.clone(),
321 message: error.to_string(),
322 })?;
323
324 if output.status.success() {
325 hash_chain.update(&output.stdout);
326 Some(output.stdout)
327 } else {
328 let message = String::from_utf8(output.stderr).map_err(|error| {
329 LoadError::KeyScript {
330 name: name.clone(),
331 message: error.to_string(),
332 }
333 })?;
334 return Err(LoadError::KeyScript { name, message });
335 }
336 }
337 };
338
339 Ok(LoadedSeed::Command {
340 cache_status: CacheStatus::from_cache_key(
341 hash_chain.cache_key(),
342 backend,
343 instance_name,
344 )
345 .await,
346 cache_key_output,
347 name,
348 command: command.clone(),
349 })
350 }
351 Seed::Script { script } => {
352 hash_chain.update(script);
353
354 Ok(LoadedSeed::Script {
355 cache_status: CacheStatus::from_cache_key(
356 hash_chain.cache_key(),
357 backend,
358 instance_name,
359 )
360 .await,
361 name,
362 script: script.clone(),
363 })
364 }
365 Seed::ContainerScript { script } => {
366 hash_chain.update(script);
367
368 Ok(LoadedSeed::ContainerScript {
369 cache_status: CacheStatus::from_cache_key(
370 hash_chain.cache_key(),
371 backend,
372 instance_name,
373 )
374 .await,
375 name,
376 script: script.clone(),
377 })
378 }
379 Seed::CsvFile {
380 path,
381 table,
382 delimiter,
383 } => {
384 let content =
385 std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
386 name: name.clone(),
387 path: path.clone(),
388 source,
389 })?;
390
391 hash_chain.update(table.schema.as_ref());
392 hash_chain.update(table.table.as_ref());
393 hash_chain.update(&content);
394
395 Ok(LoadedSeed::CsvFile {
396 cache_status: CacheStatus::from_cache_key(
397 hash_chain.cache_key(),
398 backend,
399 instance_name,
400 )
401 .await,
402 name,
403 path: path.clone(),
404 table: table.clone(),
405 delimiter: *delimiter,
406 content,
407 })
408 }
409 }
410 }
411}
412
/// Failure to load a seed's inputs; every variant names the seed that failed.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// A SQL or CSV seed file could not be read from disk.
    #[error("Failed to load seed {name}: could not read file {path}: {source}")]
    FileRead {
        name: SeedName,
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// `git show` could not be run, exited unsuccessfully (`message` is its stderr), or
    /// produced output that is not valid UTF-8.
    #[error(
        "Failed to load seed {name}: could not read {path} at git revision {git_revision}: {message}"
    )]
    GitRevision {
        name: SeedName,
        path: std::path::PathBuf,
        git_revision: String,
        message: String,
    },
    /// The cache key command could not be run or exited unsuccessfully (`message` is its
    /// stderr, or the UTF-8 decoding error if stderr was not valid UTF-8).
    #[error("Failed to load seed {name}: cache key command {command} failed: {message}")]
    KeyCommand {
        name: SeedName,
        command: String,
        message: String,
    },
    /// The cache key script could not be run or exited unsuccessfully (`message` is its
    /// stderr, or the UTF-8 decoding error if stderr was not valid UTF-8).
    #[error("Failed to load seed {name}: cache key script failed: {message}")]
    KeyScript { name: SeedName, message: String },
}
439
/// A seed whose inputs have been read and hashed, together with its resolved cache status.
#[derive(Clone, Debug, PartialEq)]
pub enum LoadedSeed {
    /// SQL read from a file on disk.
    SqlFile {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        /// File contents as read at load time.
        content: String,
    },
    /// SQL read from a file at a git revision.
    SqlFileGitRevision {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        git_revision: String,
        /// Stdout of `git show <revision>:<path>`.
        content: String,
    },
    /// An external command.
    Command {
        cache_status: CacheStatus,
        /// Stdout of the cache key command or script; `None` for the `none` and
        /// `command-hash` cache configurations.
        cache_key_output: Option<Vec<u8>>,
        name: SeedName,
        command: Command,
    },
    /// Script text (see `Seed::Script`).
    Script {
        cache_status: CacheStatus,
        name: SeedName,
        script: String,
    },
    /// Script text (see `Seed::ContainerScript`).
    ContainerScript {
        cache_status: CacheStatus,
        name: SeedName,
        script: String,
    },
    /// CSV data read from a file, with its target table and delimiter.
    CsvFile {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
        delimiter: char,
        /// File contents as read at load time.
        content: String,
    },
}
480
481impl LoadedSeed {
482 #[must_use]
483 pub fn cache_status(&self) -> &CacheStatus {
484 match self {
485 Self::SqlFile { cache_status, .. }
486 | Self::SqlFileGitRevision { cache_status, .. }
487 | Self::Command { cache_status, .. }
488 | Self::Script { cache_status, .. }
489 | Self::ContainerScript { cache_status, .. }
490 | Self::CsvFile { cache_status, .. } => cache_status,
491 }
492 }
493
494 #[must_use]
495 pub fn name(&self) -> &SeedName {
496 match self {
497 Self::SqlFile { name, .. }
498 | Self::SqlFileGitRevision { name, .. }
499 | Self::Command { name, .. }
500 | Self::Script { name, .. }
501 | Self::ContainerScript { name, .. }
502 | Self::CsvFile { name, .. } => name,
503 }
504 }
505
506 fn variant_name(&self) -> &'static str {
507 match self {
508 Self::SqlFile { .. } => "sql-file",
509 Self::SqlFileGitRevision { .. } => "sql-file-git-revision",
510 Self::Command { .. } => "command",
511 Self::Script { .. } => "script",
512 Self::ContainerScript { .. } => "container-script",
513 Self::CsvFile { .. } => "csv-file",
514 }
515 }
516}
517
518struct HashChain {
519 hasher: Option<sha2::Sha256>,
520}
521
522impl HashChain {
523 fn new() -> Self {
524 use sha2::Digest;
525
526 Self {
527 hasher: Some(sha2::Sha256::new()),
528 }
529 }
530
531 fn update(&mut self, bytes: impl AsRef<[u8]>) {
532 use sha2::Digest;
533
534 if let Some(ref mut hasher) = self.hasher {
535 hasher.update(bytes)
536 }
537 }
538
539 fn cache_key(&self) -> Option<CacheKey> {
540 use sha2::Digest;
541
542 self.hasher
543 .as_ref()
544 .map(|hasher| hasher.clone().finalize().into())
545 }
546
547 fn stop(&mut self) {
548 self.hasher = None
549 }
550}
551
/// All seeds loaded for one image, each with its resolved cache status.
#[derive(Debug, PartialEq)]
pub struct LoadedSeeds<'a> {
    /// Image the seeds belong to; its string form is hashed into every cache key.
    image: &'a crate::image::Image,
    /// Loaded seeds, in the iteration order of the input seed map.
    seeds: Vec<LoadedSeed>,
}
557
558impl<'a> LoadedSeeds<'a> {
559 pub async fn load(
560 image: &'a crate::image::Image,
561 ssl_config: Option<&crate::definition::SslConfig>,
562 seeds: &indexmap::IndexMap<SeedName, Seed>,
563 backend: &ociman::Backend,
564 instance_name: &crate::InstanceName,
565 ) -> Result<Self, LoadError> {
566 let mut hash_chain = HashChain::new();
567 let mut loaded_seeds = Vec::new();
568
569 hash_chain.update(crate::VERSION_STR);
570 hash_chain.update(image.to_string());
571
572 match ssl_config {
573 Some(crate::definition::SslConfig::Generated { hostname }) => {
574 hash_chain.update("ssl:generated:");
575 hash_chain.update(hostname.as_str());
576 }
577 None => {
578 hash_chain.update("ssl:none");
579 }
580 }
581
582 for (name, seed) in seeds {
583 let loaded_seed = seed
584 .load(name.clone(), &mut hash_chain, backend, instance_name)
585 .await?;
586 loaded_seeds.push(loaded_seed);
587 }
588
589 Ok(Self {
590 image,
591 seeds: loaded_seeds,
592 })
593 }
594
595 pub fn iter_seeds(&self) -> impl Iterator<Item = &LoadedSeed> {
596 self.seeds.iter()
597 }
598
599 pub fn print(&self, instance_name: &crate::InstanceName) {
600 println!("Instance: {instance_name}");
601 println!("Image: {}", self.image);
602 println!("Version: {}", crate::VERSION_STR);
603 println!();
604
605 let mut table = comfy_table::Table::new();
606
607 table
608 .load_preset(comfy_table::presets::NOTHING)
609 .set_header(["Seed", "Type", "Status"]);
610
611 for seed in &self.seeds {
612 table.add_row([
613 seed.name().as_str(),
614 seed.variant_name(),
615 seed.cache_status().status_str(),
616 ]);
617 }
618
619 println!("{table}");
620 }
621
622 pub fn print_json(&self, instance_name: &crate::InstanceName) {
623 #[derive(serde::Serialize)]
624 struct Output<'a> {
625 instance: &'a str,
626 image: String,
627 version: &'a str,
628 seeds: Vec<SeedOutput<'a>>,
629 }
630
631 #[derive(serde::Serialize)]
632 struct SeedOutput<'a> {
633 name: &'a str,
634 r#type: &'a str,
635 status: &'a str,
636 #[serde(skip_serializing_if = "Option::is_none")]
637 reference: Option<String>,
638 }
639
640 let output = Output {
641 instance: &instance_name.to_string(),
642 image: self.image.to_string(),
643 version: crate::VERSION_STR,
644 seeds: self
645 .seeds
646 .iter()
647 .map(|seed| SeedOutput {
648 name: seed.name().as_str(),
649 r#type: seed.variant_name(),
650 status: seed.cache_status().status_str(),
651 reference: seed.cache_status().reference().map(|r| r.to_string()),
652 })
653 .collect(),
654 };
655
656 println!("{}", serde_json::to_string_pretty(&output).unwrap());
657 }
658}
659
#[cfg(test)]
mod test {
    use super::*;

    /// Reference shared by the cache status tests; the tag is an arbitrary 64-digit hex string.
    fn sample_reference() -> ociman::Reference {
        "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
            .parse()
            .unwrap()
    }

    /// Builds the `schema.sql` SQL-file seed used by the cache status tests.
    fn schema_seed(cache_status: CacheStatus) -> LoadedSeed {
        LoadedSeed::SqlFile {
            cache_status,
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        }
    }

    #[test]
    fn test_seed_name_rejects_empty_string() {
        let expected = Err(SeedNameError);
        assert_eq!("".parse::<SeedName>(), expected);
        assert_eq!(SeedName::try_from(""), expected);
        assert_eq!(SeedName::try_from(String::new()), expected);
    }

    #[test]
    fn test_seed_name_accepts_non_empty_string() {
        let expected = Ok(SeedName("valid-name".to_string()));
        assert_eq!("valid-name".parse::<SeedName>(), expected);
        assert_eq!(SeedName::try_from("valid-name"), expected);
        assert_eq!(SeedName::try_from("valid-name".to_string()), expected);
    }

    #[test]
    fn test_seed_name_display() {
        let name: SeedName = "test-seed".parse().unwrap();
        assert_eq!(name.as_str(), "test-seed");
        assert_eq!(name.to_string(), "test-seed");
    }

    #[test]
    fn test_cache_status_uncacheable() {
        let loaded_seed = LoadedSeed::Command {
            cache_status: CacheStatus::Uncacheable,
            cache_key_output: None,
            name: "run-migrations".parse().unwrap(),
            command: Command::new("migrate", ["up"]),
        };

        let status = loaded_seed.cache_status();
        assert!(status.reference().is_none());
        assert!(!status.is_hit());
    }

    #[test]
    fn test_cache_status_miss() {
        let reference = sample_reference();
        let loaded_seed = schema_seed(CacheStatus::Miss {
            reference: reference.clone(),
        });

        let status = loaded_seed.cache_status();
        assert_eq!(status.reference(), Some(&reference));
        assert!(!status.is_hit());
    }

    #[test]
    fn test_cache_status_hit() {
        let reference = sample_reference();
        let loaded_seed = schema_seed(CacheStatus::Hit {
            reference: reference.clone(),
        });

        let status = loaded_seed.cache_status();
        assert_eq!(status.reference(), Some(&reference));
        assert!(status.is_hit());
    }
}