1use git_proc::Build;
2
/// Raw 32-byte digest produced by finalizing the SHA-256 hash chain; it is
/// hex-encoded to form the tag of a cached image reference.
type CacheKey = [u8; 32];
4
/// Result of checking whether the image caching a seed's outcome exists.
#[derive(Clone, Debug, PartialEq)]
pub enum CacheStatus {
    /// A cache key was available and the backend reported the image as present.
    Hit { reference: ociman::Reference },
    /// A cache key was available but the backend did not report the image as present.
    Miss { reference: ociman::Reference },
    /// No cache key was available, so there is no image reference.
    Uncacheable,
}
11
12impl CacheStatus {
13 async fn from_cache_key(
14 cache_key: Option<CacheKey>,
15 backend: &ociman::Backend,
16 instance_name: &crate::InstanceName,
17 ) -> Self {
18 match cache_key {
19 Some(key) => {
20 let reference = format!("pg-ephemeral/{}:{}", instance_name, hex::encode(key))
21 .parse()
22 .unwrap();
23 if backend.is_image_present(&reference).await {
24 Self::Hit { reference }
25 } else {
26 Self::Miss { reference }
27 }
28 }
29 None => Self::Uncacheable,
30 }
31 }
32
33 #[must_use]
34 pub fn reference(&self) -> Option<&ociman::Reference> {
35 match self {
36 Self::Hit { reference } | Self::Miss { reference } => Some(reference),
37 Self::Uncacheable => None,
38 }
39 }
40
41 #[must_use]
42 pub fn is_hit(&self) -> bool {
43 matches!(self, Self::Hit { .. })
44 }
45
46 #[must_use]
47 pub fn status_str(&self) -> &'static str {
48 match self {
49 Self::Hit { .. } => "hit",
50 Self::Miss { .. } => "miss",
51 Self::Uncacheable => "uncacheable",
52 }
53 }
54}
55
/// Name of a seed.
///
/// Deserialization is routed through `TryFrom<String>`, which rejects the
/// empty string, so a deserialized name is never empty.
#[derive(Clone, Debug, Hash, Eq, PartialEq, serde::Deserialize, serde::Serialize)]
#[serde(try_from = "String")]
pub struct SeedName(String);
59
60impl SeedName {
61 #[must_use]
62 pub fn as_str(&self) -> &str {
63 &self.0
64 }
65}
66
67impl std::fmt::Display for SeedName {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "{}", self.0)
70 }
71}
72
/// Error returned when a [`SeedName`] is constructed from an empty string.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("Seed name cannot be empty")]
pub struct SeedNameError;
76
/// Error reporting a seed name that occurs more than once; carries the
/// offending name, which is included in the message.
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
#[error("Duplicate seed name: {0}")]
pub struct DuplicateSeedName(pub SeedName);
80
81impl std::str::FromStr for SeedName {
82 type Err = SeedNameError;
83
84 fn from_str(value: &str) -> Result<Self, Self::Err> {
85 if value.is_empty() {
86 Err(SeedNameError)
87 } else {
88 Ok(Self(value.to_string()))
89 }
90 }
91}
92
93impl TryFrom<String> for SeedName {
94 type Error = SeedNameError;
95
96 fn try_from(value: String) -> Result<Self, Self::Error> {
97 if value.is_empty() {
98 Err(SeedNameError)
99 } else {
100 Ok(Self(value))
101 }
102 }
103}
104
105impl TryFrom<&str> for SeedName {
106 type Error = SeedNameError;
107
108 fn try_from(value: &str) -> Result<Self, Self::Error> {
109 value.parse()
110 }
111}
112
/// A command together with its argument list.
#[derive(Clone, Debug, PartialEq)]
pub struct Command {
    /// The command itself.
    pub command: String,
    /// The arguments, in the order they were supplied.
    pub arguments: Vec<String>,
}

impl Command {
    /// Builds a [`Command`] from anything convertible into owned strings.
    ///
    /// The order of `arguments` is preserved.
    pub fn new(
        command: impl Into<String>,
        arguments: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        let command = command.into();
        let arguments = arguments.into_iter().map(Into::into).collect();

        Self { command, arguments }
    }
}
130
/// How the cache key of a [`Seed::Command`] is derived.
///
/// Deserialized as an internally tagged enum: the `type` field selects the
/// variant, using kebab-case variant names.
#[derive(Clone, Debug, serde::Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum CommandCacheConfig {
    /// No caching: loading the seed stops the hash chain, so this seed and
    /// every seed loaded afterwards on the same chain has no cache key.
    None,
    /// The seed's command string and each of its arguments are fed into the
    /// hash chain.
    CommandHash,
    /// A separate key command is run and its stdout is fed into the hash
    /// chain; a non-zero exit status is reported as a load error.
    KeyCommand {
        /// The key command to run.
        command: String,
        /// Arguments for the key command; empty when omitted.
        #[serde(default)]
        arguments: Vec<String>,
    },
    /// A script is run through `sh -e -c` and its stdout is fed into the
    /// hash chain; a non-zero exit status is reported as a load error.
    KeyScript { script: String },
}
147
/// Declarative description of one seed, before any of its inputs are read.
#[derive(Clone, Debug, PartialEq)]
pub enum Seed {
    /// SQL read from a file on disk when the seed is loaded.
    SqlFile {
        path: std::path::PathBuf,
    },
    /// SQL read from `<git_revision>:<path>` through `git_proc::show` when
    /// the seed is loaded.
    SqlFileGitRevision {
        git_revision: String,
        path: std::path::PathBuf,
    },
    /// A command, plus the strategy used to derive its cache key.
    Command {
        command: Command,
        cache: CommandCacheConfig,
    },
    /// A script whose text is fed into the hash chain when loaded.
    Script {
        script: String,
    },
    /// A script whose text is fed into the hash chain when loaded.
    // NOTE(review): presumably executed inside the container rather than on
    // the host — execution is not visible in this file; confirm.
    ContainerScript {
        script: String,
    },
    /// CSV content read from a file on disk, associated with a qualified table.
    CsvFile {
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
    },
}
172
173impl Seed {
174 async fn load(
175 &self,
176 name: SeedName,
177 hash_chain: &mut HashChain,
178 backend: &ociman::Backend,
179 instance_name: &crate::InstanceName,
180 ) -> Result<LoadedSeed, LoadError> {
181 match self {
182 Seed::SqlFile { path } => {
183 let content =
184 std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
185 name: name.clone(),
186 path: path.clone(),
187 source,
188 })?;
189
190 hash_chain.update(&content);
191
192 Ok(LoadedSeed::SqlFile {
193 cache_status: CacheStatus::from_cache_key(
194 hash_chain.cache_key(),
195 backend,
196 instance_name,
197 )
198 .await,
199 name,
200 path: path.clone(),
201 content,
202 })
203 }
204 Seed::SqlFileGitRevision { path, git_revision } => {
205 let output =
206 git_proc::show::new(&format!("{git_revision}:{}", path.to_str().unwrap()))
207 .build()
208 .stdout_capture()
209 .stderr_capture()
210 .accept_nonzero_exit()
211 .run()
212 .await
213 .map_err(|error| LoadError::GitRevision {
214 name: name.clone(),
215 path: path.clone(),
216 git_revision: git_revision.clone(),
217 message: error.to_string(),
218 })?;
219
220 if output.status.success() {
221 let content = String::from_utf8(output.stdout).map_err(|error| {
222 LoadError::GitRevision {
223 name: name.clone(),
224 path: path.clone(),
225 git_revision: git_revision.clone(),
226 message: error.to_string(),
227 }
228 })?;
229
230 hash_chain.update(&content);
231
232 Ok(LoadedSeed::SqlFileGitRevision {
233 cache_status: CacheStatus::from_cache_key(
234 hash_chain.cache_key(),
235 backend,
236 instance_name,
237 )
238 .await,
239 name,
240 path: path.clone(),
241 git_revision: git_revision.clone(),
242 content,
243 })
244 } else {
245 let message = String::from_utf8(output.stderr).map_err(|error| {
246 LoadError::GitRevision {
247 name: name.clone(),
248 path: path.clone(),
249 git_revision: git_revision.clone(),
250 message: error.to_string(),
251 }
252 })?;
253 Err(LoadError::GitRevision {
254 name,
255 path: path.clone(),
256 git_revision: git_revision.clone(),
257 message,
258 })
259 }
260 }
261 Seed::Command { command, cache } => {
262 let cache_key_output = match cache {
263 CommandCacheConfig::None => {
264 hash_chain.stop();
265 None
266 }
267 CommandCacheConfig::CommandHash => {
268 hash_chain.update(&command.command);
269 for argument in &command.arguments {
270 hash_chain.update(argument);
271 }
272 None
273 }
274 CommandCacheConfig::KeyCommand {
275 command: key_command,
276 arguments: key_arguments,
277 } => {
278 let output = cmd_proc::Command::new(key_command)
279 .arguments(key_arguments)
280 .stdout_capture()
281 .stderr_capture()
282 .accept_nonzero_exit()
283 .run()
284 .await
285 .map_err(|error| LoadError::KeyCommand {
286 name: name.clone(),
287 command: key_command.clone(),
288 message: error.to_string(),
289 })?;
290
291 if output.status.success() {
292 hash_chain.update(&output.stdout);
293 Some(output.stdout)
294 } else {
295 let message = String::from_utf8(output.stderr).map_err(|error| {
296 LoadError::KeyCommand {
297 name: name.clone(),
298 command: key_command.clone(),
299 message: error.to_string(),
300 }
301 })?;
302 return Err(LoadError::KeyCommand {
303 name,
304 command: key_command.clone(),
305 message,
306 });
307 }
308 }
309 CommandCacheConfig::KeyScript { script: key_script } => {
310 let output = cmd_proc::Command::new("sh")
311 .arguments(["-e", "-c"])
312 .argument(key_script)
313 .stdout_capture()
314 .stderr_capture()
315 .accept_nonzero_exit()
316 .run()
317 .await
318 .map_err(|error| LoadError::KeyScript {
319 name: name.clone(),
320 message: error.to_string(),
321 })?;
322
323 if output.status.success() {
324 hash_chain.update(&output.stdout);
325 Some(output.stdout)
326 } else {
327 let message = String::from_utf8(output.stderr).map_err(|error| {
328 LoadError::KeyScript {
329 name: name.clone(),
330 message: error.to_string(),
331 }
332 })?;
333 return Err(LoadError::KeyScript { name, message });
334 }
335 }
336 };
337
338 Ok(LoadedSeed::Command {
339 cache_status: CacheStatus::from_cache_key(
340 hash_chain.cache_key(),
341 backend,
342 instance_name,
343 )
344 .await,
345 cache_key_output,
346 name,
347 command: command.clone(),
348 })
349 }
350 Seed::Script { script } => {
351 hash_chain.update(script);
352
353 Ok(LoadedSeed::Script {
354 cache_status: CacheStatus::from_cache_key(
355 hash_chain.cache_key(),
356 backend,
357 instance_name,
358 )
359 .await,
360 name,
361 script: script.clone(),
362 })
363 }
364 Seed::ContainerScript { script } => {
365 hash_chain.update(script);
366
367 Ok(LoadedSeed::ContainerScript {
368 cache_status: CacheStatus::from_cache_key(
369 hash_chain.cache_key(),
370 backend,
371 instance_name,
372 )
373 .await,
374 name,
375 script: script.clone(),
376 })
377 }
378 Seed::CsvFile { path, table } => {
379 let content =
380 std::fs::read_to_string(path).map_err(|source| LoadError::FileRead {
381 name: name.clone(),
382 path: path.clone(),
383 source,
384 })?;
385
386 hash_chain.update(table.schema.as_ref());
387 hash_chain.update(table.table.as_ref());
388 hash_chain.update(&content);
389
390 Ok(LoadedSeed::CsvFile {
391 cache_status: CacheStatus::from_cache_key(
392 hash_chain.cache_key(),
393 backend,
394 instance_name,
395 )
396 .await,
397 name,
398 path: path.clone(),
399 table: table.clone(),
400 content,
401 })
402 }
403 }
404 }
405}
406
/// Failure to load a single seed; every variant names the seed involved.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// A seed file on disk could not be read.
    #[error("Failed to load seed {name}: could not read file {path}: {source}")]
    FileRead {
        name: SeedName,
        path: std::path::PathBuf,
        source: std::io::Error,
    },
    /// File content could not be obtained from a git revision; `message`
    /// holds the underlying error text or git's stderr.
    #[error(
        "Failed to load seed {name}: could not read {path} at git revision {git_revision}: {message}"
    )]
    GitRevision {
        name: SeedName,
        path: std::path::PathBuf,
        git_revision: String,
        message: String,
    },
    /// The cache key command could not be run or exited unsuccessfully.
    #[error("Failed to load seed {name}: cache key command {command} failed: {message}")]
    KeyCommand {
        name: SeedName,
        command: String,
        message: String,
    },
    /// The cache key script could not be run or exited unsuccessfully.
    #[error("Failed to load seed {name}: cache key script failed: {message}")]
    KeyScript { name: SeedName, message: String },
}
433
/// A seed whose inputs have been read and whose cache status has been
/// resolved. Each variant mirrors the [`Seed`] variant it was loaded from.
#[derive(Clone, Debug, PartialEq)]
pub enum LoadedSeed {
    /// SQL file with its content read from disk.
    SqlFile {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        content: String,
    },
    /// SQL file with its content read from a git revision.
    SqlFileGitRevision {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        git_revision: String,
        content: String,
    },
    /// Command seed.
    Command {
        cache_status: CacheStatus,
        /// Stdout of the cache key command or script when one was run;
        /// `None` for the other cache configurations.
        cache_key_output: Option<Vec<u8>>,
        name: SeedName,
        command: Command,
    },
    /// Script seed.
    Script {
        cache_status: CacheStatus,
        name: SeedName,
        script: String,
    },
    /// Container script seed.
    ContainerScript {
        cache_status: CacheStatus,
        name: SeedName,
        script: String,
    },
    /// CSV file with its content read from disk and the associated table.
    CsvFile {
        cache_status: CacheStatus,
        name: SeedName,
        path: std::path::PathBuf,
        table: pg_client::QualifiedTable,
        content: String,
    },
}
473
474impl LoadedSeed {
475 #[must_use]
476 pub fn cache_status(&self) -> &CacheStatus {
477 match self {
478 Self::SqlFile { cache_status, .. }
479 | Self::SqlFileGitRevision { cache_status, .. }
480 | Self::Command { cache_status, .. }
481 | Self::Script { cache_status, .. }
482 | Self::ContainerScript { cache_status, .. }
483 | Self::CsvFile { cache_status, .. } => cache_status,
484 }
485 }
486
487 #[must_use]
488 pub fn name(&self) -> &SeedName {
489 match self {
490 Self::SqlFile { name, .. }
491 | Self::SqlFileGitRevision { name, .. }
492 | Self::Command { name, .. }
493 | Self::Script { name, .. }
494 | Self::ContainerScript { name, .. }
495 | Self::CsvFile { name, .. } => name,
496 }
497 }
498
499 fn variant_name(&self) -> &'static str {
500 match self {
501 Self::SqlFile { .. } => "sql-file",
502 Self::SqlFileGitRevision { .. } => "sql-file-git-revision",
503 Self::Command { .. } => "command",
504 Self::Script { .. } => "script",
505 Self::ContainerScript { .. } => "container-script",
506 Self::CsvFile { .. } => "csv-file",
507 }
508 }
509}
510
/// Running SHA-256 over the inputs fed to it so far.
///
/// `hasher` is `None` once the chain has been stopped; from then on updates
/// are ignored and no cache key is produced.
struct HashChain {
    hasher: Option<sha2::Sha256>,
}
514
515impl HashChain {
516 fn new() -> Self {
517 use sha2::Digest;
518
519 Self {
520 hasher: Some(sha2::Sha256::new()),
521 }
522 }
523
524 fn update(&mut self, bytes: impl AsRef<[u8]>) {
525 use sha2::Digest;
526
527 if let Some(ref mut hasher) = self.hasher {
528 hasher.update(bytes)
529 }
530 }
531
532 fn cache_key(&self) -> Option<CacheKey> {
533 use sha2::Digest;
534
535 self.hasher
536 .as_ref()
537 .map(|hasher| hasher.clone().finalize().into())
538 }
539
540 fn stop(&mut self) {
541 self.hasher = None
542 }
543}
544
/// The loaded seeds of one instance, in load order, together with the image
/// they were loaded against.
#[derive(Debug, PartialEq)]
pub struct LoadedSeeds<'a> {
    image: &'a crate::image::Image,
    seeds: Vec<LoadedSeed>,
}
550
551impl<'a> LoadedSeeds<'a> {
552 pub async fn load(
553 image: &'a crate::image::Image,
554 ssl_config: Option<&crate::definition::SslConfig>,
555 seeds: &indexmap::IndexMap<SeedName, Seed>,
556 backend: &ociman::Backend,
557 instance_name: &crate::InstanceName,
558 ) -> Result<Self, LoadError> {
559 let mut hash_chain = HashChain::new();
560 let mut loaded_seeds = Vec::new();
561
562 hash_chain.update(crate::VERSION_STR);
563 hash_chain.update(image.to_string());
564
565 match ssl_config {
566 Some(crate::definition::SslConfig::Generated { hostname }) => {
567 hash_chain.update("ssl:generated:");
568 hash_chain.update(hostname.as_str());
569 }
570 None => {
571 hash_chain.update("ssl:none");
572 }
573 }
574
575 for (name, seed) in seeds {
576 let loaded_seed = seed
577 .load(name.clone(), &mut hash_chain, backend, instance_name)
578 .await?;
579 loaded_seeds.push(loaded_seed);
580 }
581
582 Ok(Self {
583 image,
584 seeds: loaded_seeds,
585 })
586 }
587
588 pub fn iter_seeds(&self) -> impl Iterator<Item = &LoadedSeed> {
589 self.seeds.iter()
590 }
591
592 pub fn print(&self, instance_name: &crate::InstanceName) {
593 println!("Instance: {instance_name}");
594 println!("Image: {}", self.image);
595 println!("Version: {}", crate::VERSION_STR);
596 println!();
597
598 let mut table = comfy_table::Table::new();
599
600 table
601 .load_preset(comfy_table::presets::NOTHING)
602 .set_header(["Seed", "Type", "Status"]);
603
604 for seed in &self.seeds {
605 table.add_row([
606 seed.name().as_str(),
607 seed.variant_name(),
608 seed.cache_status().status_str(),
609 ]);
610 }
611
612 println!("{table}");
613 }
614
615 pub fn print_json(&self, instance_name: &crate::InstanceName) {
616 #[derive(serde::Serialize)]
617 struct Output<'a> {
618 instance: &'a str,
619 image: String,
620 version: &'a str,
621 seeds: Vec<SeedOutput<'a>>,
622 }
623
624 #[derive(serde::Serialize)]
625 struct SeedOutput<'a> {
626 name: &'a str,
627 r#type: &'a str,
628 status: &'a str,
629 #[serde(skip_serializing_if = "Option::is_none")]
630 reference: Option<String>,
631 }
632
633 let output = Output {
634 instance: &instance_name.to_string(),
635 image: self.image.to_string(),
636 version: crate::VERSION_STR,
637 seeds: self
638 .seeds
639 .iter()
640 .map(|seed| SeedOutput {
641 name: seed.name().as_str(),
642 r#type: seed.variant_name(),
643 status: seed.cache_status().status_str(),
644 reference: seed.cache_status().reference().map(|r| r.to_string()),
645 })
646 .collect(),
647 };
648
649 println!("{}", serde_json::to_string_pretty(&output).unwrap());
650 }
651}
652
#[cfg(test)]
mod test {
    use super::*;

    /// Parses the image reference shared by the hit and miss tests.
    fn sample_reference() -> ociman::Reference {
        "pg-ephemeral/main:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
            .parse()
            .unwrap()
    }

    /// Builds a `sql-file` seed carrying the given cache status.
    fn sql_file_seed(cache_status: CacheStatus) -> LoadedSeed {
        LoadedSeed::SqlFile {
            cache_status,
            name: "schema".parse().unwrap(),
            path: "schema.sql".into(),
            content: "CREATE TABLE test();".to_string(),
        }
    }

    #[test]
    fn test_seed_name_rejects_empty_string() {
        let rejected: Result<SeedName, SeedNameError> = Err(SeedNameError);

        assert_eq!("".parse::<SeedName>(), rejected);
        assert_eq!(SeedName::try_from(""), rejected);
        assert_eq!(SeedName::try_from(String::new()), rejected);
    }

    #[test]
    fn test_seed_name_accepts_non_empty_string() {
        let expected: Result<SeedName, SeedNameError> = Ok(SeedName("valid-name".to_string()));

        assert_eq!("valid-name".parse::<SeedName>(), expected);
        assert_eq!(SeedName::try_from("valid-name"), expected);
        assert_eq!(SeedName::try_from("valid-name".to_string()), expected);
    }

    #[test]
    fn test_seed_name_display() {
        let name: SeedName = "test-seed".parse().unwrap();

        assert_eq!(name.to_string(), "test-seed");
        assert_eq!(name.as_str(), "test-seed");
    }

    #[test]
    fn test_cache_status_uncacheable() {
        let loaded_seed = LoadedSeed::Command {
            cache_status: CacheStatus::Uncacheable,
            cache_key_output: None,
            name: "run-migrations".parse().unwrap(),
            command: Command::new("migrate", ["up"]),
        };
        let status = loaded_seed.cache_status();

        assert!(status.reference().is_none());
        assert!(!status.is_hit());
    }

    #[test]
    fn test_cache_status_miss() {
        let reference = sample_reference();
        let loaded_seed = sql_file_seed(CacheStatus::Miss {
            reference: reference.clone(),
        });
        let status = loaded_seed.cache_status();

        assert_eq!(status.reference(), Some(&reference));
        assert!(!status.is_hit());
    }

    #[test]
    fn test_cache_status_hit() {
        let reference = sample_reference();
        let loaded_seed = sql_file_seed(CacheStatus::Hit {
            reference: reference.clone(),
        });
        let status = loaded_seed.cache_status();

        assert_eq!(status.reference(), Some(&reference));
        assert!(status.is_hit());
    }
}