1use crate::hash::DEFAULT_S3_MULTIPART_PART_SIZE;
6use crate::s3_check_cache::S3CheckCache;
7use std::any::Any;
8use std::path::PathBuf;
9use std::sync::Arc;
10
11use async_trait::async_trait;
12
/// Outcome of [`AsyncDataCache::copy_from`].
///
/// This is a plain fieldless tag, so it is `Copy` and fully comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CopyResult {
    /// The destination performed the copy itself (for S3, via a server-side `CopyObject`).
    ServerSideCopy,
    /// No direct copy path exists for this destination/source pair; nothing was copied.
    NotSupported,
}
21
22#[async_trait]
31pub trait AsyncDataCache: Send + Sync {
32 fn object_key(&self, hash: &str, algorithm: &str) -> String;
33 fn as_any(&self) -> &dyn Any;
34 async fn object_exists(&self, hash: &str, algorithm: &str) -> std::io::Result<bool>;
35 async fn put_object(
36 &self,
37 hash: &str,
38 algorithm: &str,
39 data: Vec<u8>,
40 ) -> std::io::Result<String>;
41 async fn get_object(&self, hash: &str, algorithm: &str) -> std::io::Result<Vec<u8>>;
42
43 async fn copy_from(
45 &self,
46 _source: &dyn AsyncDataCache,
47 _hash: &str,
48 _algorithm: &str,
49 ) -> std::io::Result<CopyResult> {
50 Ok(CopyResult::NotSupported)
51 }
52
53 fn multipart_part_size(&self) -> usize {
59 crate::hash::DEFAULT_S3_MULTIPART_PART_SIZE
60 }
61
62 fn as_multipart(&self) -> Option<&dyn MultipartDataCache> {
65 None
66 }
67
68 fn as_range_read(&self) -> Option<&dyn RangeReadDataCache> {
71 None
72 }
73
74 async fn copy_object_to_file(
77 &self,
78 hash: &str,
79 algorithm: &str,
80 dest: &std::path::Path,
81 ) -> std::io::Result<u64> {
82 let data = self.get_object(hash, algorithm).await?;
83 let len = data.len() as u64;
84 tokio::fs::write(dest, &data).await?;
85 Ok(len)
86 }
87
88 async fn write_object_to_file_at_offset(
91 &self,
92 hash: &str,
93 algorithm: &str,
94 dest: &std::path::Path,
95 offset: u64,
96 ) -> std::io::Result<u64> {
97 let data = self.get_object(hash, algorithm).await?;
98 let len = data.len() as u64;
99 let dest = dest.to_path_buf();
100 tokio::task::spawn_blocking(move || {
101 use std::io::{Seek, SeekFrom, Write};
102 let mut f = std::fs::OpenOptions::new().write(true).open(&dest)?;
103 f.seek(SeekFrom::Start(offset))?;
104 f.write_all(&data)?;
105 Ok::<_, std::io::Error>(len)
106 })
107 .await
108 .map_err(std::io::Error::other)?
109 }
110}
111
112#[async_trait]
114pub trait MultipartDataCache: AsyncDataCache {
115 async fn create_multipart_upload(&self, hash: &str, algorithm: &str)
116 -> std::io::Result<String>;
117 async fn upload_part(
118 &self,
119 hash: &str,
120 algorithm: &str,
121 upload_id: &str,
122 part_number: i32,
123 data: Vec<u8>,
124 ) -> std::io::Result<String>;
125 async fn complete_multipart_upload(
126 &self,
127 hash: &str,
128 algorithm: &str,
129 upload_id: &str,
130 parts: Vec<(i32, String)>,
131 ) -> std::io::Result<()>;
132 async fn abort_multipart_upload(
133 &self,
134 hash: &str,
135 algorithm: &str,
136 upload_id: &str,
137 ) -> std::io::Result<()>;
138}
139
140#[async_trait]
142pub trait RangeReadDataCache: AsyncDataCache {
143 async fn get_object_range(
144 &self,
145 hash: &str,
146 algorithm: &str,
147 start: u64,
148 end: u64,
149 ) -> std::io::Result<Vec<u8>>;
150
151 async fn stream_range_to_file_at_offset(
154 &self,
155 hash: &str,
156 algorithm: &str,
157 range_start: u64,
158 range_end: u64,
159 dest: &std::path::Path,
160 file_offset: u64,
161 ) -> std::io::Result<u64> {
162 let data = self
163 .get_object_range(hash, algorithm, range_start, range_end)
164 .await?;
165 let len = data.len() as u64;
166 let dest = dest.to_path_buf();
167 tokio::task::spawn_blocking(move || {
168 use std::io::{Seek, SeekFrom, Write};
169 let mut f = std::fs::OpenOptions::new().write(true).open(&dest)?;
170 f.seek(SeekFrom::Start(file_offset))?;
171 f.write_all(&data)?;
172 Ok::<_, std::io::Error>(len)
173 })
174 .await
175 .map_err(std::io::Error::other)?
176 }
177}
178
179pub struct FileSystemDataCache {
181 pub root_path: PathBuf,
182}
183
184impl FileSystemDataCache {
185 pub fn new(root_path: impl Into<PathBuf>) -> crate::Result<Self> {
186 let root_path = root_path.into();
187 if !root_path.is_absolute() {
188 return Err(crate::SnapshotError::Validation(
189 "root_path must be absolute".into(),
190 ));
191 }
192 std::fs::create_dir_all(&root_path)?;
193 Ok(Self { root_path })
194 }
195
196 fn object_path(&self, hash: &str, algorithm: &str) -> PathBuf {
197 self.root_path.join(format!("{hash}.{algorithm}"))
198 }
199}
200
201#[async_trait]
202impl AsyncDataCache for FileSystemDataCache {
203 fn object_key(&self, hash: &str, algorithm: &str) -> String {
204 self.object_path(hash, algorithm)
205 .to_string_lossy()
206 .into_owned()
207 }
208
209 fn as_any(&self) -> &dyn Any {
210 self
211 }
212
213 async fn object_exists(&self, hash: &str, algorithm: &str) -> std::io::Result<bool> {
214 Ok(self.object_path(hash, algorithm).exists())
215 }
216
217 async fn put_object(
218 &self,
219 hash: &str,
220 algorithm: &str,
221 data: Vec<u8>,
222 ) -> std::io::Result<String> {
223 let path = self.object_path(hash, algorithm);
224 tokio::fs::write(&path, &data).await?;
225 Ok(path.to_string_lossy().into_owned())
226 }
227
228 async fn get_object(&self, hash: &str, algorithm: &str) -> std::io::Result<Vec<u8>> {
229 tokio::fs::read(self.object_path(hash, algorithm)).await
230 }
231
232 async fn copy_object_to_file(
233 &self,
234 hash: &str,
235 algorithm: &str,
236 dest: &std::path::Path,
237 ) -> std::io::Result<u64> {
238 let src = self.object_path(hash, algorithm);
239 let dest = dest.to_path_buf();
240 tokio::task::spawn_blocking(move || std::fs::copy(&src, &dest))
241 .await
242 .map_err(std::io::Error::other)?
243 }
244
245 async fn write_object_to_file_at_offset(
246 &self,
247 hash: &str,
248 algorithm: &str,
249 dest: &std::path::Path,
250 offset: u64,
251 ) -> std::io::Result<u64> {
252 let src = self.object_path(hash, algorithm);
253 let dest = dest.to_path_buf();
254 tokio::task::spawn_blocking(move || {
255 use std::io::{Seek, SeekFrom};
256 let mut src_file = std::fs::File::open(&src)?;
257 let src_len = src_file.metadata()?.len();
258 let mut dest_file = std::fs::OpenOptions::new().write(true).open(&dest)?;
259 dest_file.seek(SeekFrom::Start(offset))?;
260 std::io::copy(&mut src_file, &mut dest_file)?;
261 Ok::<_, std::io::Error>(src_len)
262 })
263 .await
264 .map_err(std::io::Error::other)?
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a cache rooted at `<tmp>/<name>`.
    fn cache_at(tmp: &TempDir, name: &str) -> FileSystemDataCache {
        FileSystemDataCache::new(tmp.path().join(name)).unwrap()
    }

    #[test]
    fn new_creates_directory() {
        let tmp = TempDir::new().unwrap();
        let dir = tmp.path().join("cache");
        assert!(!dir.exists());
        let _cache = FileSystemDataCache::new(&dir).unwrap();
        assert!(dir.exists());
    }

    #[test]
    fn rejects_relative_path() {
        assert!(FileSystemDataCache::new("relative/path").is_err());
    }

    #[tokio::test]
    async fn put_and_get_object() {
        let tmp = TempDir::new().unwrap();
        let cache = cache_at(&tmp, "cache");
        cache
            .put_object("abc123", "xxh128", b"hello".to_vec())
            .await
            .unwrap();
        let fetched = cache.get_object("abc123", "xxh128").await.unwrap();
        assert_eq!(fetched, b"hello");
    }

    #[tokio::test]
    async fn object_exists_check() {
        let tmp = TempDir::new().unwrap();
        let cache = cache_at(&tmp, "cache");
        let before = cache.object_exists("abc123", "xxh128").await.unwrap();
        assert!(!before);
        cache
            .put_object("abc123", "xxh128", b"data".to_vec())
            .await
            .unwrap();
        let after = cache.object_exists("abc123", "xxh128").await.unwrap();
        assert!(after);
    }

    #[test]
    fn object_key_format() {
        let tmp = TempDir::new().unwrap();
        let cache = cache_at(&tmp, "cache");
        let key = AsyncDataCache::object_key(&cache, "abc123", "xxh128");
        assert!(key.ends_with("abc123.xxh128"));
    }

    #[tokio::test]
    async fn get_nonexistent_returns_error() {
        let tmp = TempDir::new().unwrap();
        let cache = cache_at(&tmp, "cache");
        let outcome = cache.get_object("missing", "xxh128").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn copy_from_default_returns_not_supported() {
        let tmp = TempDir::new().unwrap();
        let src = cache_at(&tmp, "src");
        let dst = cache_at(&tmp, "dst");
        let copied = dst.copy_from(&src, "abc", "xxh128").await.unwrap();
        assert_eq!(copied, CopyResult::NotSupported);
    }

    #[test]
    fn format_copy_source_regular_bucket() {
        assert_eq!(
            super::format_copy_source("my-bucket", "Data/abc123.xxh128"),
            "my-bucket/Data/abc123.xxh128"
        );
    }

    #[test]
    fn format_copy_source_access_point_arn() {
        let arn = "arn:aws:s3:us-west-2:123456789012:accesspoint/my-access-point";
        let expected =
            "arn:aws:s3:us-west-2:123456789012:accesspoint/my-access-point/object/Data/abc123.xxh128";
        assert_eq!(super::format_copy_source(arn, "Data/abc123.xxh128"), expected);
    }

    #[test]
    fn format_copy_source_outpost_arn() {
        let arn = "arn:aws:s3-outposts:us-west-2:123456789012:outpost/my-outpost";
        let expected =
            "arn:aws:s3-outposts:us-west-2:123456789012:outpost/my-outpost/object/Data/abc123.xxh128";
        assert_eq!(super::format_copy_source(arn, "Data/abc123.xxh128"), expected);
    }
}
362
363use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
364use tracing::warn;
365
/// Returns a pseudo-random `u64` for sampling decisions (not suitable for cryptography).
///
/// Each `RandomState::new()` carries freshly randomized hashing keys. The current
/// wall-clock nanoseconds are hashed through it as extra input; a clock before the epoch
/// contributes `0`.
fn rand_u64() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos() as u64)
        .unwrap_or(0);
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u64(nanos);
    hasher.finish()
}
379
380pub struct CacheValidationState {
381 hit_count: AtomicU64,
382 invalidated: AtomicBool,
383}
384
385impl Default for CacheValidationState {
386 fn default() -> Self {
387 Self::new()
388 }
389}
390
391impl CacheValidationState {
392 pub fn new() -> Self {
393 Self {
394 hit_count: AtomicU64::new(0),
395 invalidated: AtomicBool::new(false),
396 }
397 }
398
399 fn should_verify(&self) -> bool {
400 let count = self.hit_count.fetch_add(1, Ordering::Relaxed) + 1;
401 if count <= 100 {
402 return true;
403 }
404 rand_u64().is_multiple_of(100)
405 }
406
407 fn invalidate(&self) {
408 self.invalidated.store(true, Ordering::Relaxed);
409 }
410
411 pub fn is_invalidated(&self) -> bool {
412 self.invalidated.load(Ordering::Relaxed)
413 }
414}
415
416pub struct S3DataCache {
418 bucket: String,
419 key_prefix: String,
420 client: aws_sdk_s3::Client,
421 multipart_part_size: usize,
422 s3_check_cache: Option<Arc<S3CheckCache>>,
423 force_s3_check: bool,
424 expected_bucket_owner: Option<String>,
425 cache_validation: CacheValidationState,
426}
427
428impl S3DataCache {
429 pub fn new(bucket: String, key_prefix: String, client: aws_sdk_s3::Client) -> Self {
430 Self {
431 bucket,
432 key_prefix,
433 client,
434 multipart_part_size: DEFAULT_S3_MULTIPART_PART_SIZE,
435 s3_check_cache: None,
436 force_s3_check: false,
437 expected_bucket_owner: None,
438 cache_validation: CacheValidationState::new(),
439 }
440 }
441
442 pub fn with_multipart_part_size(mut self, size: usize) -> Self {
444 self.multipart_part_size = size;
445 self
446 }
447
448 pub fn with_s3_check_cache(mut self, cache: Option<Arc<S3CheckCache>>) -> Self {
450 self.s3_check_cache = cache;
451 self
452 }
453
454 pub fn with_force_s3_check(mut self, force: bool) -> Self {
456 self.force_s3_check = force;
457 self
458 }
459
460 pub fn with_expected_bucket_owner(mut self, owner: Option<String>) -> Self {
462 self.expected_bucket_owner = owner;
463 self
464 }
465
466 pub fn client(&self) -> &aws_sdk_s3::Client {
468 &self.client
469 }
470
471 pub fn expected_bucket_owner(&self) -> Option<&str> {
473 self.expected_bucket_owner.as_deref()
474 }
475
476 pub fn cache_key(&self, hash: &str, algorithm: &str) -> String {
478 format!("{}/{}", self.bucket, self.object_key(hash, algorithm))
479 }
480
481 pub fn is_cache_validation_invalidated(&self) -> bool {
484 self.cache_validation.is_invalidated()
485 }
486
487 pub async fn new_with_auto_account_id(
489 bucket: String,
490 key_prefix: String,
491 s3_client: aws_sdk_s3::Client,
492 sts_client: aws_sdk_sts::Client,
493 ) -> crate::Result<Self> {
494 let resp =
495 sts_client.get_caller_identity().send().await.map_err(|e| {
496 crate::SnapshotError::S3(format!("STS GetCallerIdentity failed: {e}"))
497 })?;
498 let account = resp
499 .account()
500 .ok_or_else(|| crate::SnapshotError::S3("STS response missing Account".into()))?
501 .to_string();
502 Ok(Self::new(bucket, key_prefix, s3_client).with_expected_bucket_owner(Some(account)))
503 }
504
505 pub fn check_cache_exists(&self, hash: &str, algorithm: &str) -> bool {
507 if self.force_s3_check {
508 return false;
509 }
510 if let Some(ref cache) = self.s3_check_cache {
511 cache.get_entry(&self.cache_key(hash, algorithm)).is_some()
512 } else {
513 false
514 }
515 }
516
517 pub fn record_in_check_cache(&self, hash: &str, algorithm: &str) {
519 if let Some(ref cache) = self.s3_check_cache {
520 let _ = cache.put_entry(&self.cache_key(hash, algorithm));
521 }
522 }
523
524 fn object_key(&self, hash: &str, algorithm: &str) -> String {
525 format!("{}/{hash}.{algorithm}", self.key_prefix)
526 }
527}
528
/// Builds the `x-amz-copy-source` value for an S3 `CopyObject` request.
///
/// Plain bucket names produce `bucket/key`. Access point and Outposts ARNs (anything
/// starting with `arn:`) produce `arn/object/key`, the form S3 expects when copying
/// through an ARN.
///
/// S3 requires the copy source to be URL-encoded, so the key is percent-encoded. RFC 3986
/// unreserved characters and `/` are left as-is, which keeps ordinary keys such as
/// `Data/abc123.xxh128` unchanged. Without this, keys containing spaces, `?`, `%`, `+` or
/// non-ASCII text would be misread by S3. The bucket/ARN part is passed through verbatim.
fn format_copy_source(bucket: &str, key: &str) -> String {
    let key = encode_copy_source_key(key);
    if bucket.starts_with("arn:") {
        format!("{bucket}/object/{key}")
    } else {
        format!("{bucket}/{key}")
    }
}

/// Percent-encodes every byte of `key` except `A-Z a-z 0-9 - . _ ~` and `/`.
fn encode_copy_source_key(key: &str) -> String {
    use std::fmt::Write;

    let mut encoded = String::with_capacity(key.len());
    for byte in key.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'/') {
            encoded.push(char::from(byte));
        } else {
            // Writing into a String cannot fail.
            let _ = write!(encoded, "%{byte:02X}");
        }
    }
    encoded
}
540
541#[async_trait]
542impl AsyncDataCache for S3DataCache {
543 fn object_key(&self, hash: &str, algorithm: &str) -> String {
544 format!("{}/{hash}.{algorithm}", self.key_prefix)
545 }
546
547 fn as_any(&self) -> &dyn Any {
548 self
549 }
550
551 fn multipart_part_size(&self) -> usize {
552 self.multipart_part_size
553 }
554
555 fn as_multipart(&self) -> Option<&dyn MultipartDataCache> {
556 Some(self)
557 }
558
559 fn as_range_read(&self) -> Option<&dyn RangeReadDataCache> {
560 Some(self)
561 }
562
563 async fn copy_from(
564 &self,
565 source: &dyn AsyncDataCache,
566 hash: &str,
567 algorithm: &str,
568 ) -> std::io::Result<CopyResult> {
569 let Some(src_s3) = source.as_any().downcast_ref::<S3DataCache>() else {
570 return Ok(CopyResult::NotSupported);
571 };
572 let src_key = src_s3.object_key(hash, algorithm);
573 let dst_key = self.object_key(hash, algorithm);
574 let copy_source = format_copy_source(&src_s3.bucket, &src_key);
575 self.client
576 .copy_object()
577 .bucket(&self.bucket)
578 .key(&dst_key)
579 .copy_source(©_source)
580 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
581 .send()
582 .await
583 .map_err(|e| std::io::Error::other(format!("S3 CopyObject failed: {e}")))?;
584 self.record_in_check_cache(hash, algorithm);
585 Ok(CopyResult::ServerSideCopy)
586 }
587
588 async fn object_exists(&self, hash: &str, algorithm: &str) -> std::io::Result<bool> {
589 if self.check_cache_exists(hash, algorithm) && !self.cache_validation.is_invalidated() {
590 if !self.cache_validation.should_verify() {
591 return Ok(true);
592 }
593 let key = AsyncDataCache::object_key(self, hash, algorithm);
595 return match self
596 .client
597 .head_object()
598 .bucket(&self.bucket)
599 .key(&key)
600 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
601 .send()
602 .await
603 {
604 Ok(_) => Ok(true),
605 Err(e) => {
606 if e.as_service_error().is_some_and(|se| se.is_not_found()) {
607 warn!(key = %key, "S3 check cache stale entry detected, invalidating cache");
608 self.cache_validation.invalidate();
609 Ok(false)
610 } else {
611 Err(std::io::Error::other(format!(
612 "S3 HeadObject failed for {key}: {e}"
613 )))
614 }
615 }
616 };
617 }
618 let key = AsyncDataCache::object_key(self, hash, algorithm);
619 match self
620 .client
621 .head_object()
622 .bucket(&self.bucket)
623 .key(&key)
624 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
625 .send()
626 .await
627 {
628 Ok(_) => {
629 self.record_in_check_cache(hash, algorithm);
630 Ok(true)
631 }
632 Err(e) => {
633 if e.as_service_error().is_some_and(|se| se.is_not_found()) {
634 Ok(false)
635 } else {
636 Err(std::io::Error::other(format!(
637 "S3 HeadObject failed for {key}: {e}"
638 )))
639 }
640 }
641 }
642 }
643
644 async fn put_object(
645 &self,
646 hash: &str,
647 algorithm: &str,
648 data: Vec<u8>,
649 ) -> std::io::Result<String> {
650 let key = AsyncDataCache::object_key(self, hash, algorithm);
651 let body = aws_sdk_s3::primitives::ByteStream::from(data);
652 self.client
653 .put_object()
654 .bucket(&self.bucket)
655 .key(&key)
656 .body(body)
657 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
658 .send()
659 .await
660 .map_err(|e| std::io::Error::other(format!("S3 PutObject failed for {key}: {e}")))?;
661 self.record_in_check_cache(hash, algorithm);
662 Ok(key)
663 }
664
665 async fn get_object(&self, hash: &str, algorithm: &str) -> std::io::Result<Vec<u8>> {
666 let key = AsyncDataCache::object_key(self, hash, algorithm);
667 let resp = self
668 .client
669 .get_object()
670 .bucket(&self.bucket)
671 .key(&key)
672 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
673 .send()
674 .await
675 .map_err(|e| std::io::Error::other(format!("S3 GetObject failed for {key}: {e}")))?;
676 let bytes = resp.body.collect().await.map_err(|e| {
677 std::io::Error::other(format!("S3 GetObject body read failed for {key}: {e}"))
678 })?;
679 Ok(bytes.to_vec())
680 }
681
682 async fn copy_object_to_file(
683 &self,
684 hash: &str,
685 algorithm: &str,
686 dest: &std::path::Path,
687 ) -> std::io::Result<u64> {
688 let key = AsyncDataCache::object_key(self, hash, algorithm);
689 let resp = self
690 .client
691 .get_object()
692 .bucket(&self.bucket)
693 .key(&key)
694 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
695 .send()
696 .await
697 .map_err(|e| std::io::Error::other(format!("S3 GetObject failed for {key}: {e}")))?;
698 let bytes = resp.body.collect().await.map_err(|e| {
699 std::io::Error::other(format!("S3 GetObject body read failed for {key}: {e}"))
700 })?;
701 let data = bytes.to_vec();
702 let len = data.len() as u64;
703 let dest = dest.to_path_buf();
704 tokio::task::spawn_blocking(move || std::fs::write(&dest, &data))
705 .await
706 .map_err(std::io::Error::other)??;
707 Ok(len)
708 }
709
710 async fn write_object_to_file_at_offset(
711 &self,
712 hash: &str,
713 algorithm: &str,
714 dest: &std::path::Path,
715 offset: u64,
716 ) -> std::io::Result<u64> {
717 let key = AsyncDataCache::object_key(self, hash, algorithm);
718 let resp = self
719 .client
720 .get_object()
721 .bucket(&self.bucket)
722 .key(&key)
723 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
724 .send()
725 .await
726 .map_err(|e| std::io::Error::other(format!("S3 GetObject failed for {key}: {e}")))?;
727 let bytes = resp.body.collect().await.map_err(|e| {
728 std::io::Error::other(format!("S3 GetObject body read failed for {key}: {e}"))
729 })?;
730 let data = bytes.to_vec();
731 let len = data.len() as u64;
732 let dest = dest.to_path_buf();
733 tokio::task::spawn_blocking(move || {
734 use std::io::{Seek, SeekFrom, Write};
735 let mut f = std::fs::OpenOptions::new().write(true).open(&dest)?;
736 f.seek(SeekFrom::Start(offset))?;
737 f.write_all(&data)?;
738 Ok::<_, std::io::Error>(len)
739 })
740 .await
741 .map_err(std::io::Error::other)?
742 }
743}
744
745#[async_trait]
746impl MultipartDataCache for S3DataCache {
747 async fn create_multipart_upload(
748 &self,
749 hash: &str,
750 algorithm: &str,
751 ) -> std::io::Result<String> {
752 let key = AsyncDataCache::object_key(self, hash, algorithm);
753 let resp = self
754 .client
755 .create_multipart_upload()
756 .bucket(&self.bucket)
757 .key(&key)
758 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
759 .send()
760 .await
761 .map_err(|e| {
762 std::io::Error::other(format!("S3 CreateMultipartUpload failed for {key}: {e}"))
763 })?;
764 resp.upload_id()
765 .map(|s| s.to_string())
766 .ok_or_else(|| std::io::Error::other("missing upload_id"))
767 }
768
769 async fn upload_part(
770 &self,
771 hash: &str,
772 algorithm: &str,
773 upload_id: &str,
774 part_number: i32,
775 data: Vec<u8>,
776 ) -> std::io::Result<String> {
777 let key = AsyncDataCache::object_key(self, hash, algorithm);
778 let body = aws_sdk_s3::primitives::ByteStream::from(data);
779 let resp = self
780 .client
781 .upload_part()
782 .bucket(&self.bucket)
783 .key(&key)
784 .upload_id(upload_id)
785 .part_number(part_number)
786 .body(body)
787 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
788 .send()
789 .await
790 .map_err(|e| {
791 std::io::Error::other(format!(
792 "S3 UploadPart failed for {key} part {part_number}: {e}"
793 ))
794 })?;
795 resp.e_tag()
796 .map(|s| s.to_string())
797 .ok_or_else(|| std::io::Error::other("missing ETag"))
798 }
799
800 async fn complete_multipart_upload(
801 &self,
802 hash: &str,
803 algorithm: &str,
804 upload_id: &str,
805 parts: Vec<(i32, String)>,
806 ) -> std::io::Result<()> {
807 let key = AsyncDataCache::object_key(self, hash, algorithm);
808 let completed_parts: Vec<_> = parts
809 .into_iter()
810 .map(|(num, etag)| {
811 aws_sdk_s3::types::CompletedPart::builder()
812 .part_number(num)
813 .e_tag(etag)
814 .build()
815 })
816 .collect();
817 let upload = aws_sdk_s3::types::CompletedMultipartUpload::builder()
818 .set_parts(Some(completed_parts))
819 .build();
820 self.client
821 .complete_multipart_upload()
822 .bucket(&self.bucket)
823 .key(&key)
824 .upload_id(upload_id)
825 .multipart_upload(upload)
826 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
827 .send()
828 .await
829 .map_err(|e| {
830 std::io::Error::other(format!("S3 CompleteMultipartUpload failed for {key}: {e}"))
831 })?;
832 self.record_in_check_cache(hash, algorithm);
833 Ok(())
834 }
835
836 async fn abort_multipart_upload(
837 &self,
838 hash: &str,
839 algorithm: &str,
840 upload_id: &str,
841 ) -> std::io::Result<()> {
842 let key = AsyncDataCache::object_key(self, hash, algorithm);
843 self.client
844 .abort_multipart_upload()
845 .bucket(&self.bucket)
846 .key(&key)
847 .upload_id(upload_id)
848 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
849 .send()
850 .await
851 .map_err(|e| {
852 std::io::Error::other(format!("S3 AbortMultipartUpload failed for {key}: {e}"))
853 })?;
854 Ok(())
855 }
856}
857
858#[async_trait]
859impl RangeReadDataCache for S3DataCache {
860 async fn get_object_range(
861 &self,
862 hash: &str,
863 algorithm: &str,
864 start: u64,
865 end: u64,
866 ) -> std::io::Result<Vec<u8>> {
867 let key = AsyncDataCache::object_key(self, hash, algorithm);
868 let range = format!("bytes={start}-{end}");
869 let resp = self
870 .client
871 .get_object()
872 .bucket(&self.bucket)
873 .key(&key)
874 .range(&range)
875 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
876 .send()
877 .await
878 .map_err(|e| {
879 std::io::Error::other(format!("S3 GetObject range failed for {key}: {e}"))
880 })?;
881 let bytes = resp.body.collect().await.map_err(|e| {
882 std::io::Error::other(format!(
883 "S3 GetObject range body read failed for {key}: {e}"
884 ))
885 })?;
886 Ok(bytes.to_vec())
887 }
888
889 async fn stream_range_to_file_at_offset(
890 &self,
891 hash: &str,
892 algorithm: &str,
893 range_start: u64,
894 range_end: u64,
895 dest: &std::path::Path,
896 file_offset: u64,
897 ) -> std::io::Result<u64> {
898 let key = AsyncDataCache::object_key(self, hash, algorithm);
899 let range = format!("bytes={range_start}-{range_end}");
900 let resp = self
901 .client
902 .get_object()
903 .bucket(&self.bucket)
904 .key(&key)
905 .range(&range)
906 .set_expected_bucket_owner(self.expected_bucket_owner.clone())
907 .send()
908 .await
909 .map_err(|e| {
910 std::io::Error::other(format!("S3 GetObject range failed for {key}: {e}"))
911 })?;
912 let bytes = resp.body.collect().await.map_err(|e| {
913 std::io::Error::other(format!(
914 "S3 GetObject range body read failed for {key}: {e}"
915 ))
916 })?;
917 let data = bytes.to_vec();
918 let len = data.len() as u64;
919 let dest = dest.to_path_buf();
920 tokio::task::spawn_blocking(move || {
921 use std::io::{Seek, SeekFrom, Write};
922 let mut f = std::fs::OpenOptions::new().write(true).open(&dest)?;
923 f.seek(SeekFrom::Start(file_offset))?;
924 f.write_all(&data)?;
925 Ok::<_, std::io::Error>(len)
926 })
927 .await
928 .map_err(std::io::Error::other)?
929 }
930}
931
#[cfg(test)]
mod cache_validation_tests {
    use super::*;

    #[test]
    fn cache_validation_first_100_always_verify() {
        let state = CacheValidationState::new();
        assert!((0..100).all(|_| state.should_verify()));
    }

    #[test]
    fn cache_validation_after_100_probabilistic() {
        let state = CacheValidationState::new();
        // Burn through the always-verify warm-up window.
        (0..100).for_each(|_| {
            let _ = state.should_verify();
        });
        let verified = (0..1000)
            .map(|_| state.should_verify())
            .filter(|&decision| decision)
            .count();
        assert!(
            verified < 1000,
            "expected some false results after 100, but all returned true"
        );
    }

    #[test]
    fn cache_validation_invalidate() {
        let state = CacheValidationState::new();
        assert!(!state.is_invalidated());
        state.invalidate();
        assert!(state.is_invalidated());
    }

    #[test]
    fn cache_validation_thread_safety() {
        let state = CacheValidationState::new();
        // Scoped threads borrow `state` directly; any worker panic propagates on scope exit.
        std::thread::scope(|scope| {
            for _ in 0..8 {
                scope.spawn(|| {
                    for _ in 0..200 {
                        let _ = state.should_verify();
                    }
                    state.invalidate();
                    let _ = state.is_invalidated();
                });
            }
        });
        assert!(state.is_invalidated());
    }
}