use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use crate::{
storage::{StorageError, StorageProvider},
superfile::{LazyByteSource, LazyByteSourceError},
};
/// Maximum number of follow-up `get_range` requests `StorageRangeSource::range`
/// issues after a request comes back short (fewer bytes than asked for) before
/// failing with `LazyByteSourceError::ShortRead`. Counts every partial read, not
/// just reads that made no progress (an empty chunk fails immediately).
const MAX_SHORT_READ_RETRIES: u32 = 4;
/// A [`LazyByteSource`] that reads byte ranges of a single object through a
/// [`StorageProvider`].
///
/// The object size is cached in an atomic so that `tail` (which takes `&self`)
/// can record the size it discovers after construction. A cached size of `0`
/// is treated as "unknown": `range` skips its up-front bounds check in that case.
#[derive(Debug)]
pub struct StorageRangeSource {
// Backend used for every `get_range` / `tail` request.
storage: Arc<dyn StorageProvider>,
// Object identifier passed verbatim to the storage provider.
uri: String,
// Cached object size in bytes; `0` means not yet known.
size: AtomicU64,
}
impl StorageRangeSource {
    /// Creates a source for `uri`, issuing one `head` request to learn its size.
    ///
    /// Note: an object whose `head` reports size `0` is indistinguishable from
    /// one whose size is unknown, so `range` will not bounds-check it.
    ///
    /// # Errors
    ///
    /// Propagates the [`StorageError`] returned by `head`.
    pub async fn new(
        storage: Arc<dyn StorageProvider>,
        uri: impl Into<String>,
    ) -> Result<Self, StorageError> {
        let uri: String = uri.into();
        let object_size = storage.head(&uri).await?.size;
        Ok(Self::with_known_size(storage, uri, object_size))
    }

    /// Creates a source whose size the caller already knows; performs no I/O.
    pub fn with_known_size(
        storage: Arc<dyn StorageProvider>,
        uri: impl Into<String>,
        size: u64,
    ) -> Self {
        Self {
            storage,
            uri: uri.into(),
            size: AtomicU64::new(size),
        }
    }

    /// Creates a source with no size information; performs no I/O.
    ///
    /// `size()` reports `0` until a `tail` call records the real size, and until
    /// then `range` issues requests without an up-front bounds check.
    pub fn with_unknown_size(storage: Arc<dyn StorageProvider>, uri: impl Into<String>) -> Self {
        Self::with_known_size(storage, uri, 0)
    }

    /// The object identifier this source reads from.
    pub fn uri(&self) -> &str {
        self.uri.as_str()
    }
}
#[async_trait]
impl LazyByteSource for StorageRangeSource {
    /// Returns the cached object size in bytes, or `0` while it is still unknown.
    fn size(&self) -> u64 {
        self.size.load(Ordering::Acquire)
    }

    /// Reads exactly `len` bytes starting at byte offset `start`.
    ///
    /// When the size is known (non-zero), the request is bounds-checked before
    /// any I/O. If the provider returns fewer bytes than requested, the remainder
    /// is requested again from where the previous chunk ended, with at most
    /// `MAX_SHORT_READ_RETRIES` follow-up requests. A range satisfied by a single
    /// chunk is returned as-is; otherwise the chunks are copied into one buffer.
    ///
    /// # Errors
    ///
    /// - [`LazyByteSourceError::OutOfBounds`] if `start + len` exceeds the known
    ///   size, or overflows `u64` (checked even when the size is unknown).
    /// - [`LazyByteSourceError::ShortRead`] if the provider returns an empty chunk
    ///   before the range is complete, or the follow-up budget is exhausted.
    /// - Errors from the provider's `get_range`, converted with `?`.
    async fn range(&self, start: u64, len: u64) -> Result<Bytes, LazyByteSourceError> {
        let known = self.size.load(Ordering::Acquire);
        let out_of_bounds = move || LazyByteSourceError::OutOfBounds {
            start,
            len,
            size: known,
        };
        // Must be checked even when the size is unknown (`known == 0`): the
        // bounds check below is skipped then, and an unchecked `start + len`
        // would panic in debug builds or wrap to an inverted range in release.
        let end = start.checked_add(len).ok_or_else(out_of_bounds)?;
        if known > 0 && end > known {
            return Err(out_of_bounds());
        }
        if len == 0 {
            return Ok(Bytes::new());
        }
        let short_read = move |got: usize| LazyByteSourceError::ShortRead {
            start,
            requested: len,
            got: got as u64,
        };

        // NOTE(review): `as` truncates on targets where usize is narrower than u64.
        let want = len as usize;
        let mut cursor = start;
        let mut filled = 0usize;
        let mut parts: Vec<Bytes> = Vec::new();
        let mut retries = 0u32;
        while filled < want {
            let chunk = self.storage.get_range(&self.uri, cursor..end).await?;
            if chunk.is_empty() {
                // No progress at all: the object ends before the requested range.
                return Err(short_read(filled));
            }
            // Never keep more than is still owed, even if the provider over-delivers.
            let take = chunk.len().min(want - filled);
            filled += take;
            cursor += take as u64;
            parts.push(chunk.slice(0..take));
            if filled < want {
                retries += 1;
                if retries > MAX_SHORT_READ_RETRIES {
                    return Err(short_read(filled));
                }
            }
        }
        if parts.len() == 1 {
            return Ok(parts.pop().expect("len checked == 1"));
        }
        let mut out = BytesMut::with_capacity(want);
        for p in &parts {
            out.extend_from_slice(p);
        }
        Ok(out.freeze())
    }

    /// Fetches the last `len` bytes through the provider's `tail` and caches the
    /// total object size it reports, so subsequent `range` calls are bounds-checked.
    ///
    /// # Errors
    ///
    /// Errors from the provider's `tail`, converted with `?`.
    async fn tail(&self, len: u64) -> Result<(Bytes, u64), LazyByteSourceError> {
        let (bytes, total) = self.storage.tail(&self.uri, len).await?;
        self.size.store(total, Ordering::Release);
        Ok((bytes, total))
    }
}
#[cfg(test)]
mod tests {
    use std::{error::Error, ops::Range, sync::atomic::AtomicUsize, time::SystemTime};

    use object_store::MultipartUpload;

    use super::*;
    use crate::storage::ObjectMeta;

    /// In-memory provider whose `get_range` returns at most `chunk_cap` bytes per
    /// call and nothing past `obj_size`, to simulate short reads and truncated
    /// objects. `calls` counts `get_range` invocations.
    #[derive(Debug)]
    struct ChunkedStorage {
        blob: Bytes,
        chunk_cap: usize,
        obj_size: usize,
        calls: AtomicUsize,
    }

    impl ChunkedStorage {
        fn new(blob: Bytes, chunk_cap: usize, obj_size: usize) -> Self {
            Self {
                blob,
                chunk_cap,
                obj_size,
                calls: AtomicUsize::new(0),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::Acquire)
        }
    }

    fn permanent(uri: &str, msg: &'static str) -> StorageError {
        let boxed: Box<dyn Error + Send + Sync> = msg.into();
        StorageError::Permanent {
            uri: uri.into(),
            source: boxed,
        }
    }

    #[async_trait]
    impl StorageProvider for ChunkedStorage {
        async fn head(&self, _uri: &str) -> Result<ObjectMeta, StorageError> {
            Ok(ObjectMeta {
                size: self.obj_size as u64,
                etag: None,
                last_modified: SystemTime::UNIX_EPOCH,
            })
        }

        async fn get(&self, uri: &str) -> Result<(Bytes, ObjectMeta), StorageError> {
            Err(permanent(uri, "get unimplemented"))
        }

        async fn get_range(&self, _uri: &str, range: Range<u64>) -> Result<Bytes, StorageError> {
            self.calls.fetch_add(1, Ordering::AcqRel);
            let start = range.start as usize;
            let req = (range.end - range.start) as usize;
            let available = self.obj_size.saturating_sub(start);
            let take = req.min(self.chunk_cap).min(available);
            Ok(self.blob.slice(start..start + take))
        }

        async fn put_atomic(
            &self,
            uri: &str,
            _bytes: Bytes,
        ) -> Result<Option<String>, StorageError> {
            Err(permanent(uri, "put_atomic unimplemented"))
        }

        async fn put_if_match(
            &self,
            uri: &str,
            _bytes: Bytes,
            _expected_etag: Option<&str>,
        ) -> Result<Option<String>, StorageError> {
            Err(permanent(uri, "put_if_match unimplemented"))
        }

        async fn put_multipart(&self, uri: &str) -> Result<Box<dyn MultipartUpload>, StorageError> {
            Err(permanent(uri, "put_multipart unimplemented"))
        }

        async fn delete(&self, _uri: &str) -> Result<(), StorageError> {
            Ok(())
        }
    }

    #[tokio::test]
    async fn range_completes_a_short_read_by_refetching_the_tail() {
        let blob = Bytes::from((0u8..=255).cycle().take(4096).collect::<Vec<u8>>());
        let storage = Arc::new(ChunkedStorage::new(blob.clone(), 1000, blob.len()));
        let src = StorageRangeSource::with_known_size(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg.sf.parquet",
            blob.len() as u64,
        );
        let got = src.range(0, blob.len() as u64).await.expect("range");
        assert_eq!(got.len(), blob.len());
        assert_eq!(got.as_ref(), blob.as_ref());
        assert!(
            storage.call_count() >= 5,
            "expected multiple GETs to complete the short read, got {}",
            storage.call_count()
        );
    }

    #[tokio::test]
    async fn range_completes_short_read_for_interior_range() {
        let blob = Bytes::from((0u8..=255).cycle().take(4096).collect::<Vec<u8>>());
        let storage = Arc::new(ChunkedStorage::new(blob.clone(), 700, blob.len()));
        let src = StorageRangeSource::with_known_size(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg.sf.parquet",
            blob.len() as u64,
        );
        let (start, len) = (1024u64, 2048u64);
        let got = src.range(start, len).await.expect("range");
        assert_eq!(got.as_ref(), &blob[start as usize..(start + len) as usize]);
    }

    #[tokio::test]
    async fn range_gives_up_after_exhausting_short_read_retries() {
        // 500-byte chunks would need 9 GETs for 4096 bytes, but only the initial
        // GET plus MAX_SHORT_READ_RETRIES follow-ups are allowed.
        let blob = Bytes::from((0u8..=255).cycle().take(4096).collect::<Vec<u8>>());
        let storage = Arc::new(ChunkedStorage::new(blob.clone(), 500, blob.len()));
        let src = StorageRangeSource::with_known_size(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg.sf.parquet",
            blob.len() as u64,
        );
        let err = src
            .range(0, blob.len() as u64)
            .await
            .expect_err("retry budget must be enforced");
        let allowed_gets = MAX_SHORT_READ_RETRIES as usize + 1;
        match err {
            LazyByteSourceError::ShortRead {
                start,
                requested,
                got,
            } => {
                assert_eq!(start, 0);
                assert_eq!(requested, blob.len() as u64);
                assert_eq!(got, (allowed_gets * 500) as u64);
            }
            other => panic!("expected ShortRead, got {other:?}"),
        }
        assert_eq!(storage.call_count(), allowed_gets);
    }

    #[tokio::test]
    async fn range_surfaces_short_read_when_object_is_truncated() {
        let blob = Bytes::from(vec![7u8; 2048]);
        let storage = Arc::new(ChunkedStorage::new(blob, 4096, 2048));
        let src = StorageRangeSource::with_known_size(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg.sf.parquet",
            4096,
        );
        let err = src
            .range(0, 4096)
            .await
            .expect_err("must reject a permanently short read");
        match err {
            LazyByteSourceError::ShortRead {
                start,
                requested,
                got,
            } => {
                assert_eq!(start, 0);
                assert_eq!(requested, 4096);
                assert_eq!(got, 2048);
            }
            other => panic!("expected ShortRead, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn range_zero_length_is_empty_without_io() {
        let storage = Arc::new(ChunkedStorage::new(Bytes::from(vec![0u8; 16]), 16, 16));
        let src = StorageRangeSource::with_known_size(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg.sf.parquet",
            16,
        );
        let got = src.range(8, 0).await.expect("zero-length range");
        assert!(got.is_empty());
        assert_eq!(storage.call_count(), 0);
    }

    #[tokio::test]
    async fn new_heads_and_caches_size() {
        let blob = Bytes::from(vec![3u8; 512]);
        let storage = Arc::new(ChunkedStorage::new(blob.clone(), 512, blob.len()));
        let src = StorageRangeSource::new(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg.sf.parquet",
        )
        .await
        .expect("new heads ok");
        assert_eq!(src.size(), blob.len() as u64);
        assert_eq!(src.uri(), "seg.sf.parquet");
    }

    #[tokio::test]
    async fn unknown_size_tail_discovers_and_patches_size() {
        let blob = Bytes::from((0u8..=255).cycle().take(1024).collect::<Vec<u8>>());
        let storage = Arc::new(ChunkedStorage::new(blob.clone(), 4096, blob.len()));
        let src = StorageRangeSource::with_unknown_size(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg.sf.parquet",
        );
        assert_eq!(src.size(), 0);
        let (tail_bytes, total) = src.tail(64).await.expect("tail");
        assert_eq!(total, blob.len() as u64);
        assert_eq!(tail_bytes.as_ref(), &blob[blob.len() - 64..]);
        assert_eq!(src.size(), blob.len() as u64);
        let err = src
            .range(blob.len() as u64, 8)
            .await
            .expect_err("past-end range must be OutOfBounds");
        match err {
            LazyByteSourceError::OutOfBounds { start, len, size } => {
                assert_eq!(start, blob.len() as u64);
                assert_eq!(len, 8);
                assert_eq!(size, blob.len() as u64);
            }
            other => panic!("expected OutOfBounds, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn range_out_of_bounds_when_size_known() {
        let storage = Arc::new(ChunkedStorage::new(Bytes::from(vec![0u8; 100]), 100, 100));
        let src = StorageRangeSource::with_known_size(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg.sf.parquet",
            100,
        );
        let err = src.range(90, 20).await.expect_err("90+20 > 100");
        assert!(
            matches!(err, LazyByteSourceError::OutOfBounds { .. }),
            "expected OutOfBounds, got {err:?}"
        );
        assert_eq!(storage.call_count(), 0);
    }

    // Purely synchronous: nothing here awaits, so no async runtime is needed.
    #[test]
    fn debug_renders_struct_name_and_uri() {
        let storage = Arc::new(ChunkedStorage::new(Bytes::from(vec![0u8; 8]), 8, 8));
        let src = StorageRangeSource::with_known_size(
            Arc::clone(&storage) as Arc<dyn StorageProvider>,
            "seg-debug.sf.parquet",
            8,
        );
        let dbg = format!("{src:?}");
        assert!(dbg.contains("StorageRangeSource"), "got {dbg}");
        assert!(dbg.contains("seg-debug.sf.parquet"), "got {dbg}");
    }
}