use crate::error::OxenError;
use async_tempfile::TempFile;
use async_trait::async_trait;
use aws_config::meta::region::RegionProviderChain;
use aws_sdk_s3::error::SdkError;
use aws_sdk_s3::operation::head_object::HeadObjectError;
use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart, Delete, ObjectIdentifier};
use aws_sdk_s3::{Client, config::Region, primitives::ByteStream};
use bytes::Bytes;
use futures::StreamExt;
use log;
use std::path::Path;
use std::sync::Arc;
use tokio::fs::{File, create_dir_all};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::OnceCell;
use tokio_stream::Stream;
use tokio_util::io::StreamReader;
use super::version_store::{LocalFilePath, VersionStore};
use crate::constants::VERSION_FILE_NAME;
use crate::util::hasher;
use crate::view::versions::CleanCorruptedVersionsResult;
use xxhash_rust::xxh3::Xxh3;
/// Uploads of at most this many bytes (100 MiB) are sent with a single
/// `PutObject`; larger uploads go through a multipart upload.
const DEFAULT_ONESHOT_SIZE: u64 = 100 * 1024 * 1024;

/// [`VersionStore`] implementation backed by an S3 bucket.
///
/// Every object belonging to version `hash` lives under `{prefix}/{hash}/`:
/// - `{prefix}/{hash}/{VERSION_FILE_NAME}`: the assembled file,
/// - `{prefix}/{hash}/chunks/{offset}`: chunks uploaded piecewise,
/// - `{prefix}/{hash}/{derived_filename}`: derived files.
#[derive(Debug)]
pub struct S3VersionStore {
    /// Lazily built S3 client. Only a *successful* initialization is cached.
    /// A transient failure (network, credentials) is therefore retried on the
    /// next call instead of being returned for the lifetime of the store.
    client: OnceCell<Arc<Client>>,
    /// Bucket holding all objects of this store.
    bucket: String,
    /// Key prefix (without trailing slash) under which version directories live.
    prefix: String,
    /// Size threshold between single-request and multipart uploads.
    oneshot_size: u64,
}

impl S3VersionStore {
    /// Creates a store rooted at `prefix` inside `bucket`.
    ///
    /// No network I/O happens here; the client is built on first use by
    /// [`S3VersionStore::client`].
    pub fn new(bucket: impl Into<String>, prefix: impl Into<String>) -> Self {
        Self {
            client: OnceCell::new(),
            bucket: bucket.into(),
            prefix: prefix.into(),
            oneshot_size: DEFAULT_ONESHOT_SIZE,
        }
    }

    /// Returns the shared S3 client, building it on the first successful call.
    ///
    /// A bootstrap client is configured from the default provider chain (with
    /// `us-west-1` as fallback region). It is used to ask `GetBucketLocation`
    /// for the bucket's region, and the real client is configured for that region.
    ///
    /// # Errors
    /// Returns an error if the bucket location cannot be fetched. The failure
    /// is not cached, so a later call attempts initialization again.
    pub async fn client(&self) -> Result<Arc<Client>, OxenError> {
        let client = self
            .client
            .get_or_try_init(|| async {
                let region_provider =
                    RegionProviderChain::default_provider().or_else("us-west-1");
                let base_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
                    .region(region_provider)
                    .load()
                    .await;
                let bootstrap = Client::new(&base_config);
                let location = bootstrap
                    .get_bucket_location()
                    .bucket(&self.bucket)
                    .send()
                    .await
                    .map_err(|err| {
                        OxenError::basic_str(format!("Failed to get bucket location: {err:?}"))
                    })?;
                let detected_region = Self::region_from_location_constraint(
                    location.location_constraint().map(|loc| loc.as_str()),
                );
                let real_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
                    .region(Region::new(detected_region))
                    .load()
                    .await;
                Ok::<Arc<Client>, OxenError>(Arc::new(Client::new(&real_config)))
            })
            .await?;
        Ok(Arc::clone(client))
    }

    /// Maps a `GetBucketLocation` location constraint to an SDK region name.
    ///
    /// S3 reports `us-east-1` buckets with a null (or empty) constraint, and
    /// legacy `eu-west-1` buckets with the value `EU`. Neither is a valid
    /// region name for client configuration.
    fn region_from_location_constraint(constraint: Option<&str>) -> String {
        match constraint {
            None | Some("") => "us-east-1".to_string(),
            Some("EU") => "eu-west-1".to_string(),
            Some(region) => region.to_string(),
        }
    }

    /// Test-only constructor that injects a pre-built client (e.g. one pointed
    /// at a local S3-compatible server), skipping region discovery.
    #[cfg(test)]
    pub fn new_with_client(client: Arc<Client>, bucket: String, prefix: String) -> Self {
        Self {
            client: OnceCell::new_with(Some(client)),
            bucket,
            prefix,
            oneshot_size: DEFAULT_ONESHOT_SIZE,
        }
    }

    /// Key prefix (no trailing slash) holding every object of version `hash`.
    fn version_dir(&self, hash: &str) -> String {
        format!("{}/{}", self.prefix, hash)
    }

    /// Key of the assembled version file for `hash`.
    fn generate_key(&self, hash: &str) -> String {
        format!("{}/{}", self.version_dir(hash), VERSION_FILE_NAME)
    }

    /// Key of the chunk of `hash` that starts at byte `offset`.
    fn chunk_key(&self, hash: &str, offset: u64) -> String {
        format!("{}/chunks/{}", self.version_dir(hash), offset)
    }

    /// Prefix (with trailing slash) shared by all chunk keys of `hash`.
    fn chunks_prefix(&self, hash: &str) -> String {
        format!("{}/chunks/", self.version_dir(hash))
    }

    /// Lists every object key beginning with `prefix`, following
    /// `ListObjectsV2` pagination until the listing is complete.
    ///
    /// # Errors
    /// Fails on any S3 error, or if a page claims to be truncated but carries
    /// no continuation token. Without that check, the listing would restart
    /// from the beginning and loop forever.
    async fn list_objects_with_prefix(&self, prefix: &str) -> Result<Vec<String>, OxenError> {
        let client = self.client().await?;
        let mut keys = Vec::new();
        let mut continuation_token: Option<String> = None;
        loop {
            let mut req = client.list_objects_v2().bucket(&self.bucket).prefix(prefix);
            if let Some(token) = &continuation_token {
                req = req.continuation_token(token);
            }
            let resp = req.send().await?;
            keys.extend(
                resp.contents
                    .unwrap_or_default()
                    .into_iter()
                    .filter_map(|obj| obj.key),
            );
            if !resp.is_truncated.unwrap_or(false) {
                break;
            }
            let next = resp.next_continuation_token.ok_or_else(|| {
                OxenError::basic_str(format!(
                    "S3 listing of prefix '{prefix}' was truncated but returned no continuation token"
                ))
            })?;
            continuation_token = Some(next);
        }
        Ok(keys)
    }

    /// Deletes `keys` in batches of 1000, S3's per-request limit for
    /// `DeleteObjects`.
    ///
    /// # Errors
    /// Returns [`OxenError::DeleteFailure`] listing `(key, message)` pairs if
    /// S3 reports per-key failures for a batch. Batches after the failing one
    /// are not attempted.
    async fn delete_objects(&self, keys: Vec<String>) -> Result<(), OxenError> {
        let client = self.client().await?;
        for batch in keys.chunks(1000) {
            let objects: Vec<ObjectIdentifier> = batch
                .iter()
                .map(|k| ObjectIdentifier::builder().key(k).build())
                .collect::<Result<_, _>>()?;
            let delete = Delete::builder().set_objects(Some(objects)).build()?;
            let resp = client
                .delete_objects()
                .bucket(&self.bucket)
                .delete(delete)
                .send()
                .await?;
            if let Some(errors) = resp.errors
                && !errors.is_empty()
            {
                let key_failures = errors
                    .iter()
                    .map(|e| {
                        (
                            e.key.as_deref().unwrap_or("?").into(),
                            e.message.as_deref().unwrap_or("?").into(),
                        )
                    })
                    .collect::<Vec<_>>();
                return Err(OxenError::DeleteFailure(key_failures));
            }
        }
        Ok(())
    }
}
#[async_trait]
impl VersionStore for S3VersionStore {
/// Verifies the store is usable.
///
/// First checks that the bucket is reachable with `HeadBucket`. Then proves
/// write and delete permission by writing a small probe object under the
/// prefix and deleting it again.
async fn init(&self) -> Result<(), OxenError> {
    let client = self.client().await?;

    let bucket_info = client
        .head_bucket()
        .bucket(&self.bucket)
        .send()
        .await
        .map_err(|err| {
            OxenError::basic_str(format!(
                "Cannot access S3 bucket '{}': {err}",
                self.bucket
            ))
        })?;
    log::debug!("Successfully got S3 bucket {bucket_info:?}");

    let probe_key = format!("{}/_permission_check", self.prefix);
    let probe_body = ByteStream::from("permission-check".as_bytes().to_vec());
    let probe_write = client
        .put_object()
        .bucket(&self.bucket)
        .key(&probe_key)
        .body(probe_body)
        .send()
        .await;
    if let Err(err) = probe_write {
        return Err(OxenError::basic_str(format!(
            "S3 write permission check failed: {err}"
        )));
    }

    // Clean up the probe so it doesn't linger among the version directories.
    client
        .delete_object()
        .bucket(&self.bucket)
        .key(&probe_key)
        .send()
        .await
        .map_err(|err| {
            OxenError::basic_str(format!("Failed to delete _permission_check: {err}"))
        })?;
    Ok(())
}
/// Streams `reader` into the version object for `hash`, verifying the content hash.
///
/// Payloads of at most `self.oneshot_size` bytes (judged by the declared `size`)
/// are buffered and written with one `PutObject`. Larger payloads are split into
/// parts and uploaded concurrently as a multipart upload. The multipart upload
/// is completed only if every part succeeded and the xxh3-128 digest of the
/// streamed bytes equals `hash`. Otherwise in-flight part tasks are aborted and
/// the multipart upload is aborted.
///
/// # Errors
/// Returns an `Upload` error on read failures, task panics, or hash mismatch.
/// S3 errors are converted through `OxenError`'s `From` impls.
async fn store_version_from_reader(
    &self,
    hash: &str,
    reader: Box<dyn tokio::io::AsyncRead + Send + Unpin>,
    size: u64,
) -> Result<(), OxenError> {
    // S3 multipart limits: every part but the last must be >= 5 MiB, no part
    // may exceed 5 GiB, and one upload may have at most 10,000 parts.
    const MIN_PART_SIZE: usize = 5 * 1024 * 1024;
    const MAX_PART_SIZE: usize = 5 * 1024 * 1024 * 1024;
    const MAX_PARTS: usize = 10_000;
    // Cap on concurrently running part uploads. Each holds one part-sized buffer.
    const MAX_CONCURRENT_UPLOADS: usize = 16;

    let client = self.client().await?;
    let key = self.generate_key(hash);
    let mut reader = tokio::io::BufReader::new(reader);

    // Small payload: buffer it, verify it, and write it in one request.
    if size <= self.oneshot_size {
        let mut buf = Vec::with_capacity(size as usize);
        tokio::io::AsyncReadExt::read_to_end(&mut reader, &mut buf)
            .await
            .map_err(|e| OxenError::upload(&format!("Failed to read: {e}")))?;
        let computed = hasher::hash_buffer(&buf);
        if computed != hash {
            return Err(OxenError::upload(&format!(
                "store_version_from_reader hash mismatch: expected {hash}, computed {computed}"
            )));
        }
        client
            .put_object()
            .bucket(&self.bucket)
            .key(&key)
            .body(ByteStream::from(buf))
            .send()
            .await?;
        return Ok(());
    }

    // Large payload: use the smallest part size that keeps `size` within
    // MAX_PARTS parts, clamped to S3's per-part bounds.
    let part_size = (size as usize)
        .div_ceil(MAX_PARTS)
        .clamp(MIN_PART_SIZE, MAX_PART_SIZE);
    let upload = client
        .create_multipart_upload()
        .bucket(&self.bucket)
        .key(&key)
        .send()
        .await?;
    let upload_id = upload
        .upload_id()
        .ok_or_else(|| OxenError::upload("S3 multipart upload missing upload_id"))?
        .to_string();

    let mut uploads = futures::stream::FuturesUnordered::new();
    let mut completed_parts = Vec::new();
    let mut part_num: i32 = 1;
    let mut running_hash = Xxh3::new();
    let result: Result<(), OxenError> = async {
        loop {
            let mut buf = vec![0u8; part_size];
            let n = read_full(&mut reader, &mut buf).await?;
            if n == 0 {
                break;
            }
            buf.truncate(n);
            running_hash.update(&buf);
            uploads.push(tokio::spawn(upload_part(
                client.clone(),
                self.bucket.clone(),
                key.clone(),
                upload_id.clone(),
                part_num,
                buf,
            )));
            part_num += 1;
            // Back-pressure: wait for a part to finish before reading the next one.
            while uploads.len() >= MAX_CONCURRENT_UPLOADS {
                match uploads.next().await {
                    Some(Ok(result)) => completed_parts.push(result?),
                    Some(Err(e)) => {
                        return Err(OxenError::upload(&format!("Upload task panicked: {e}")));
                    }
                    None => break,
                }
            }
        }
        // Drain the parts that are still in flight.
        while let Some(join_result) = uploads.next().await {
            match join_result {
                Ok(result) => completed_parts.push(result?),
                Err(e) => return Err(OxenError::upload(&format!("Upload task panicked: {e}"))),
            }
        }
        // Check the content before the object becomes visible.
        let computed = format!("{:x}", running_hash.digest128());
        if computed != hash {
            return Err(OxenError::upload(&format!(
                "store_version_from_reader hash mismatch: expected {hash}, computed {computed}"
            )));
        }
        Ok(())
    }
    .await;

    match result {
        Ok(()) => {
            // Parts finish out of order; CompleteMultipartUpload needs ascending order.
            completed_parts.sort_unstable_by_key(|p| p.part_number);
            let completed = CompletedMultipartUpload::builder()
                .set_parts(Some(completed_parts))
                .build();
            let completion = client
                .complete_multipart_upload()
                .bucket(&self.bucket)
                .key(&key)
                .upload_id(&upload_id)
                .multipart_upload(completed)
                .send()
                .await;
            if let Err(err) = completion {
                // Abort so the stored parts don't linger as an incomplete upload.
                let _ = client
                    .abort_multipart_upload()
                    .bucket(&self.bucket)
                    .key(&key)
                    .upload_id(&upload_id)
                    .send()
                    .await;
                return Err(OxenError::from(err));
            }
            Ok(())
        }
        Err(e) => {
            for handle in uploads.iter() {
                handle.abort();
            }
            // Best-effort cleanup; the original error is the one worth reporting.
            let _ = client
                .abort_multipart_upload()
                .bucket(&self.bucket)
                .key(&key)
                .upload_id(&upload_id)
                .send()
                .await;
            Err(e)
        }
    }
}
/// Writes `data` as the version object for `hash` with a single `PutObject`.
///
/// # Errors
/// Returns an error that includes the underlying S3 failure if the upload fails.
async fn store_version(&self, hash: &str, data: &[u8]) -> Result<(), OxenError> {
    let client = self.client().await?;
    log::debug!("Storing version to S3");
    let key = self.generate_key(hash);
    let body = ByteStream::from(data.to_vec());
    client
        .put_object()
        .bucket(&self.bucket)
        .key(&key)
        .body(body)
        .send()
        .await
        .map_err(|e| OxenError::basic_str(format!("failed to store version in S3: {e}")))?;
    Ok(())
}
/// Stores `derived_data` as `derived_filename` inside the version directory of
/// `orig_hash`, next to the original version file.
async fn store_version_derived(
    &self,
    orig_hash: &str,
    derived_filename: &str,
    derived_data: &[u8],
) -> Result<(), OxenError> {
    let client = self.client().await?;
    let derived_key = format!("{}/{}", self.version_dir(orig_hash), derived_filename);
    let payload = ByteStream::from(derived_data.to_vec());
    let put = client
        .put_object()
        .bucket(&self.bucket)
        .key(&derived_key)
        .body(payload)
        .send()
        .await;
    if let Err(err) = put {
        return Err(OxenError::basic_str(format!(
            "failed to store derived version file in S3: {err}"
        )));
    }
    log::debug!("Saved derived version file {derived_key}");
    Ok(())
}
/// Returns the size in bytes of the version object for `hash`, via `HeadObject`.
///
/// # Errors
/// Fails if the request fails, the response lacks a content length, or the
/// reported length is negative.
async fn get_version_size(&self, hash: &str) -> Result<u64, OxenError> {
    let client = self.client().await?;
    let key = self.generate_key(hash);
    let resp = client
        .head_object()
        .bucket(&self.bucket)
        .key(&key)
        .send()
        .await
        .map_err(|e| OxenError::basic_str(format!("S3 head_object failed: {e}")))?;
    let length = resp
        .content_length()
        .ok_or_else(|| OxenError::basic_str("S3 object missing content_length"))?;
    // The SDK reports a signed length; reject a negative value instead of
    // letting `as u64` wrap it into an enormous size.
    u64::try_from(length).map_err(|_| {
        OxenError::basic_str(format!("S3 object has invalid content_length {length}"))
    })
}
/// Downloads the entire version object for `hash` into memory.
async fn get_version(&self, hash: &str) -> Result<Vec<u8>, OxenError> {
    let client = self.client().await?;
    let object = client
        .get_object()
        .bucket(&self.bucket)
        .key(self.generate_key(hash))
        .send()
        .await
        .map_err(|e| OxenError::basic_str(format!("S3 get_object failed: {e}")))?;
    let aggregated = object
        .body
        .collect()
        .await
        .map_err(|e| OxenError::basic_str(format!("S3 read body failed: {e}")))?;
    Ok(aggregated.into_bytes().to_vec())
}
/// Returns the size in bytes of derived file `derived_filename` of `orig_hash`.
///
/// # Errors
/// Fails if the request fails, the response lacks a content length, or the
/// reported length is negative.
async fn get_version_derived_size(
    &self,
    orig_hash: &str,
    derived_filename: &str,
) -> Result<u64, OxenError> {
    let client = self.client().await?;
    let key = format!("{}/{}", self.version_dir(orig_hash), derived_filename);
    let resp = client
        .head_object()
        .bucket(&self.bucket)
        .key(&key)
        .send()
        .await
        .map_err(|e| OxenError::basic_str(format!("S3 head_object failed: {e}")))?;
    let length = resp
        .content_length()
        .ok_or_else(|| OxenError::basic_str("S3 object missing content_length"))?;
    // The SDK reports a signed length; reject a negative value instead of
    // letting `as u64` wrap it into an enormous size.
    u64::try_from(length).map_err(|_| {
        OxenError::basic_str(format!("S3 object has invalid content_length {length}"))
    })
}
async fn get_version_stream(
&self,
hash: &str,
) -> Result<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send + Unpin>, OxenError>
{
let client = self.client().await?;
let key = self.generate_key(hash);
let resp = client
.get_object()
.bucket(&self.bucket)
.key(&key)
.send()
.await
.map_err(|e| OxenError::basic_str(format!("S3 get_object failed: {e}")))?;
let adapter = ByteStreamAdapter { inner: resp.body };
Ok(Box::new(adapter) as Box<_>)
}
async fn get_version_derived_stream(
&self,
orig_hash: &str,
derived_filename: &str,
) -> Result<Box<dyn Stream<Item = Result<Bytes, std::io::Error>> + Send + Unpin>, OxenError>
{
let client = self.client().await?;
let key = format!("{}/{}", self.version_dir(orig_hash), derived_filename);
let resp = client
.get_object()
.bucket(&self.bucket)
.key(key)
.send()
.await
.map_err(|e| OxenError::basic_str(format!("S3 get_object failed: {e}")))?;
let adapter = ByteStreamAdapter { inner: resp.body };
Ok(Box::new(adapter) as Box<_>)
}
/// Reports whether derived file `derived_filename` exists for `orig_hash`.
///
/// A `NotFound` service error from `HeadObject` means "absent". Every other
/// failure is returned as an error.
async fn derived_version_exists(
    &self,
    orig_hash: &str,
    derived_filename: &str,
) -> Result<bool, OxenError> {
    let client = self.client().await?;
    let derived_key = format!("{}/{}", self.version_dir(orig_hash), derived_filename);
    let head = client
        .head_object()
        .bucket(&self.bucket)
        .key(derived_key)
        .send()
        .await;
    match head {
        Ok(_) => Ok(true),
        Err(SdkError::ServiceError(ctx)) if matches!(ctx.err(), HeadObjectError::NotFound(_)) => {
            Ok(false)
        }
        Err(SdkError::ServiceError(ctx)) => {
            let service_err = ctx.err();
            Err(OxenError::basic_str(format!(
                "derived_exists failed with S3 head_object error: {service_err:?}"
            )))
        }
        Err(other) => Err(OxenError::basic_str(format!(
            "derived_exists failed with S3 head_object error: {other:?}"
        ))),
    }
}
async fn get_version_path(&self, hash: &str) -> Result<LocalFilePath, OxenError> {
let data = self.get_version(hash).await?;
let mut tmp = TempFile::new()
.await
.map_err(|e| OxenError::basic_str(format!("Failed to create temp file: {e}")))?;
tmp.write_all(&data)
.await
.map_err(|e| OxenError::basic_str(format!("Failed to write temp file: {e}")))?;
Ok(LocalFilePath::Temp(tmp))
}
/// Streams the version object for `hash` into `dest_path`.
///
/// Missing parent directories are created. The destination file is created
/// (or truncated) before the object is fetched.
async fn copy_version_to_path(&self, hash: &str, dest_path: &Path) -> Result<(), OxenError> {
    // Large write buffer: fewer, bigger writes while streaming from S3.
    const WRITE_BUFFER_BYTES: usize = 10 * 1024 * 1024;

    if let Some(parent_dir) = dest_path.parent() {
        create_dir_all(parent_dir)
            .await
            .map_err(|e| OxenError::basic_str(format!("Failed to create parent dirs: {e}")))?;
    }
    let dest_file = match File::create(dest_path).await {
        Ok(f) => f,
        Err(e) => return Err(OxenError::basic_str(format!("Failed to create file: {e}"))),
    };
    let mut sink = tokio::io::BufWriter::with_capacity(WRITE_BUFFER_BYTES, dest_file);
    let source = self.get_version_stream(hash).await?;
    let mut source = StreamReader::new(source);
    tokio::io::copy_buf(&mut source, &mut sink)
        .await
        .map_err(|e| OxenError::basic_str(format!("Failed to copy S3 stream to file: {e}")))?;
    Ok(())
}
/// Stores `data` as the chunk of `hash` that starts at byte `offset`.
async fn store_version_chunk(
    &self,
    hash: &str,
    offset: u64,
    data: Bytes,
) -> Result<(), OxenError> {
    let client = self.client().await?;
    let request = client
        .put_object()
        .bucket(&self.bucket)
        .key(self.chunk_key(hash, offset))
        .body(ByteStream::from(data));
    request.send().await?;
    Ok(())
}
/// Reads `size` bytes of the version object for `hash`, starting at `offset`,
/// with a ranged `GetObject`.
///
/// A zero `size` returns an empty vector without contacting S3.
async fn get_version_chunk(
    &self,
    hash: &str,
    offset: u64,
    size: u64,
) -> Result<Vec<u8>, OxenError> {
    let Some(last_relative_byte) = size.checked_sub(1) else {
        return Ok(Vec::new());
    };
    let client = self.client().await?;
    let key = self.generate_key(hash);
    // HTTP byte ranges are inclusive on both ends.
    let Some(end) = offset.checked_add(last_relative_byte) else {
        return Err(OxenError::basic_str(
            "get_version_chunk: offset + size overflows u64",
        ));
    };
    let object = client
        .get_object()
        .bucket(&self.bucket)
        .key(&key)
        .range(format!("bytes={offset}-{end}"))
        .send()
        .await?;
    let aggregated = object.body.collect().await.map_err(|e| {
        OxenError::basic_str(format!("get_version_chunk: body collect failed: {e}"))
    })?;
    Ok(aggregated.into_bytes().to_vec())
}
async fn list_version_chunks(&self, hash: &str) -> Result<Vec<u64>, OxenError> {
let prefix = self.chunks_prefix(hash);
let keys = self.list_objects_with_prefix(&prefix).await?;
let mut offsets = Vec::with_capacity(keys.len());
for key in &keys {
if let Some(offset_str) = key.strip_prefix(&prefix)
&& let Ok(offset) = offset_str.parse::<u64>()
{
offsets.push(offset);
}
}
offsets.sort();
Ok(offsets)
}
/// Reports whether the version object for `hash` exists.
///
/// A `NotFound` service error from `HeadObject` means "absent". Every other
/// failure is returned as an error.
async fn version_exists(&self, hash: &str) -> Result<bool, OxenError> {
    let client = self.client().await?;
    let head = client
        .head_object()
        .bucket(&self.bucket)
        .key(self.generate_key(hash))
        .send()
        .await;
    match head {
        Ok(_) => Ok(true),
        Err(SdkError::ServiceError(ctx)) if matches!(ctx.err(), HeadObjectError::NotFound(_)) => {
            Ok(false)
        }
        Err(SdkError::ServiceError(ctx)) => {
            let service_err = ctx.err();
            Err(OxenError::basic_str(format!(
                "version_exists failed with S3 head_object error: {service_err:?}"
            )))
        }
        Err(other) => Err(OxenError::basic_str(format!(
            "version_exists failed with S3 head_object error: {other:?}"
        ))),
    }
}
/// Deletes every object under the version directory of `hash`: the version
/// file, its chunks and its derived files. A version with no objects is a no-op.
async fn delete_version(&self, hash: &str) -> Result<(), OxenError> {
    let version_root = format!("{}/", self.version_dir(hash));
    let doomed = self.list_objects_with_prefix(&version_root).await?;
    self.delete_objects(doomed).await
}
async fn list_versions(&self) -> Result<Vec<String>, OxenError> {
let client = self.client().await?;
let base = format!("{}/", self.prefix);
let mut hashes = Vec::new();
let mut continuation_token: Option<String> = None;
loop {
let mut req = client
.list_objects_v2()
.bucket(&self.bucket)
.prefix(&base)
.delimiter("/");
if let Some(token) = &continuation_token {
req = req.continuation_token(token);
}
let resp = req.send().await?;
if let Some(common_prefixes) = resp.common_prefixes {
for cp in common_prefixes {
if let Some(hash) = cp
.prefix
.as_deref()
.and_then(|p| p.strip_prefix(&base))
.and_then(|s| s.strip_suffix('/'))
{
hashes.push(hash.to_string());
}
}
}
if resp.is_truncated.unwrap_or(false) {
continuation_token = resp.next_continuation_token;
} else {
break;
}
}
Ok(hashes)
}
async fn combine_version_chunks(&self, hash: &str) -> Result<(), OxenError> {
let offsets = self.list_version_chunks(hash).await?;
if offsets.is_empty() {
return Ok(());
}
log::debug!("combine_version_chunks found {} chunks", offsets.len());
let client = self.client().await?;
let key = self.generate_key(hash);
let upload = client
.create_multipart_upload()
.bucket(&self.bucket)
.key(&key)
.send()
.await?;
let upload_id = upload
.upload_id()
.ok_or_else(|| OxenError::upload("S3 multipart upload missing upload_id"))?
.to_string();
const MIN_PART_SIZE: usize = 5 * 1024 * 1024; let mut part_buf: Vec<u8> = Vec::new();
let mut part_num: i32 = 1;
let mut completed_parts: Vec<CompletedPart> = Vec::new();
let result: Result<(), OxenError> = async {
for (i, offset) in offsets.iter().enumerate() {
let is_last_chunk = i == offsets.len() - 1;
let resp = client
.get_object()
.bucket(&self.bucket)
.key(self.chunk_key(hash, *offset))
.send()
.await?;
let chunk_bytes = resp
.body
.collect()
.await
.map_err(|e| {
OxenError::basic_str(format!(
"Failed to read chunk body at offset {offset}: {e}"
))
})?
.into_bytes();
part_buf.extend_from_slice(&chunk_bytes);
while part_buf.len() >= MIN_PART_SIZE || (is_last_chunk && !part_buf.is_empty()) {
let drain_len = if part_buf.len() >= MIN_PART_SIZE && !is_last_chunk {
MIN_PART_SIZE
} else if is_last_chunk && part_buf.len() <= MIN_PART_SIZE {
part_buf.len()
} else {
MIN_PART_SIZE
};
let part_data: Vec<u8> = part_buf.drain(..drain_len).collect();
let part = upload_part(
client.clone(),
self.bucket.clone(),
key.clone(),
upload_id.clone(),
part_num,
part_data,
)
.await?;
completed_parts.push(part);
part_num += 1;
if is_last_chunk && part_buf.is_empty() {
break;
}
}
}
Ok(())
}
.await;
if let Err(e) = result {
let _ = client
.abort_multipart_upload()
.bucket(&self.bucket)
.key(&key)
.upload_id(&upload_id)
.send()
.await;
return Err(e);
}
let completed = CompletedMultipartUpload::builder()
.set_parts(Some(completed_parts))
.build();
client
.complete_multipart_upload()
.bucket(&self.bucket)
.key(&key)
.upload_id(&upload_id)
.multipart_upload(completed)
.send()
.await?;
let chunk_keys = self
.list_objects_with_prefix(&self.chunks_prefix(hash))
.await?;
if !chunk_keys.is_empty() {
self.delete_objects(chunk_keys).await?;
}
Ok(())
}
/// Not supported for the S3 backend yet; always returns an error rather than
/// reporting a (false) clean result.
async fn clean_corrupted_versions(
    &self,
    _dry_run: bool,
) -> Result<CleanCorruptedVersionsResult, OxenError> {
    let unsupported = "S3VersionStore clean_corrupted_versions not yet implemented";
    Err(OxenError::basic_str(unsupported))
}
fn storage_kind(&self) -> crate::storage::StorageKind {
crate::storage::StorageKind::S3
}
}
/// Uploads `data` as part number `part_num` of multipart upload `upload_id`.
///
/// Returns the [`CompletedPart`] (part number + ETag) needed later for
/// `CompleteMultipartUpload`. Takes owned arguments so it can run in a
/// spawned task.
async fn upload_part(
    client: Arc<Client>,
    bucket: String,
    key: String,
    upload_id: String,
    part_num: i32,
    data: Vec<u8>,
) -> Result<CompletedPart, OxenError> {
    let body = ByteStream::from(data);
    let output = client
        .upload_part()
        .bucket(bucket)
        .key(key)
        .upload_id(upload_id)
        .part_number(part_num)
        .body(body)
        .send()
        .await?;
    let Some(etag) = output.e_tag() else {
        return Err(OxenError::upload("S3 upload_part response missing ETag"));
    };
    Ok(CompletedPart::builder()
        .part_number(part_num)
        .e_tag(etag)
        .build())
}
/// Reads from `reader` until `buf` is full or EOF is reached.
///
/// Returns the number of bytes placed in `buf`. This is less than
/// `buf.len()` only when the reader hit EOF.
async fn read_full(
    reader: &mut (dyn tokio::io::AsyncRead + Send + Unpin),
    buf: &mut [u8],
) -> Result<usize, OxenError> {
    let mut filled = 0;
    loop {
        let remaining = &mut buf[filled..];
        if remaining.is_empty() {
            return Ok(filled);
        }
        let got = reader
            .read(remaining)
            .await
            .map_err(|e| OxenError::upload(&format!("Failed to read from reader: {e}")))?;
        if got == 0 {
            return Ok(filled);
        }
        filled += got;
    }
}
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
/// Adapts an SDK [`ByteStream`] into a `Stream` of `io::Result<Bytes>`, the
/// item type expected by tokio I/O helpers such as `StreamReader`.
pub struct ByteStreamAdapter {
    /// The S3 response body being forwarded.
    inner: ByteStream,
}

impl Stream for ByteStreamAdapter {
    type Item = Result<Bytes, io::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Data and end-of-stream pass through unchanged. Body errors become
        // `io::Error`s carrying the SDK error's display text.
        Pin::new(&mut self.inner)
            .poll_next(cx)
            .map_err(|e| io::Error::other(format!("{e}")))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::version_store::VersionStore;
    use std::net::SocketAddr;
    use tokio::net::TcpListener;

    /// Bucket created on the local S3 server before each test.
    const TEST_BUCKET: &str = "test-bucket";

    /// Starts an in-process `s3s` server backed by a temp directory and returns
    /// a store pointed at it. The temp dir and server task are returned too;
    /// keep them alive for the duration of the test.
    async fn setup() -> (
        S3VersionStore,
        async_tempfile::TempDir,
        tokio::task::JoinHandle<()>,
    ) {
        let root = async_tempfile::TempDir::new().await.unwrap();
        let service = {
            let fs = s3s_fs::FileSystem::new(root.dir_path()).unwrap();
            let mut builder = s3s::service::S3ServiceBuilder::new(fs);
            builder.set_auth(s3s::auth::SimpleAuth::from_single("test", "test"));
            builder.build()
        };
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            // Accept until the listener fails, serving each connection on its own task.
            while let Ok((socket, _)) = listener.accept().await {
                let service = service.clone();
                tokio::spawn(async move {
                    let io = hyper_util::rt::TokioIo::new(socket);
                    let _ = hyper::server::conn::http1::Builder::new()
                        .serve_connection(io, service)
                        .await;
                });
            }
        });

        let client = build_test_client(addr);
        client.create_bucket().bucket(TEST_BUCKET).send().await.unwrap();
        let store = S3VersionStore::new_with_client(
            Arc::new(client),
            TEST_BUCKET.to_string(),
            "test-namespace/test-repo".to_string(),
        );
        (store, root, server)
    }

    /// Builds a path-style S3 client for the local server at `addr`, using
    /// static credentials that match the server's `SimpleAuth`.
    fn build_test_client(addr: SocketAddr) -> Client {
        let credentials =
            aws_sdk_s3::config::Credentials::new("test", "test", None, None, "test");
        let endpoint = format!("http://{addr}");
        Client::from_conf(
            aws_sdk_s3::Config::builder()
                .behavior_version_latest()
                .region(aws_sdk_s3::config::Region::new("us-east-1"))
                .endpoint_url(endpoint)
                .credentials_provider(credentials)
                .force_path_style(true)
                .build(),
        )
    }

    /// A tiny payload round-trips through `store_version_from_reader`.
    #[tokio::test]
    async fn test_store_and_get_small_version_from_reader() {
        let (store, _tmp, _server) = setup().await;
        let payload: &[u8] = b"hello world";
        let hash = hasher::hash_buffer(payload);
        let reader = Box::new(std::io::Cursor::new(payload.to_vec()));
        store
            .store_version_from_reader(&hash, reader, payload.len() as u64)
            .await
            .unwrap();
        assert_eq!(store.get_version(&hash).await.unwrap(), payload);
    }

    /// A 20 MiB payload round-trips byte-for-byte. It is still below
    /// DEFAULT_ONESHOT_SIZE, so this takes the single-request path.
    #[tokio::test]
    async fn test_store_and_get_large_version_from_reader() {
        let (store, _tmp, _server) = setup().await;
        let payload = vec![42u8; 20 * 1024 * 1024];
        let hash = hasher::hash_buffer(&payload);
        let size = payload.len() as u64;
        store
            .store_version_from_reader(&hash, Box::new(std::io::Cursor::new(payload.clone())), size)
            .await
            .unwrap();
        let fetched = store.get_version(&hash).await.unwrap();
        assert_eq!(fetched.len(), payload.len());
        assert_eq!(fetched, payload);
    }

    /// A payload that is an exact multiple of 1 MiB (16 MiB) round-trips
    /// intact.
    #[tokio::test]
    async fn test_store_version_from_reader_exact_part_boundary() {
        let (store, _tmp, _server) = setup().await;
        let payload = vec![7u8; 16 * 1024 * 1024];
        let hash = hasher::hash_buffer(&payload);
        let size = payload.len() as u64;
        store
            .store_version_from_reader(&hash, Box::new(std::io::Cursor::new(payload.clone())), size)
            .await
            .unwrap();
        assert_eq!(store.get_version(&hash).await.unwrap(), payload);
    }

    /// An empty reader stores an empty version.
    #[tokio::test]
    async fn test_store_version_from_reader_empty() {
        let (store, _tmp, _server) = setup().await;
        let hash = hasher::hash_buffer(&[]);
        store
            .store_version_from_reader(&hash, Box::new(std::io::Cursor::new(Vec::new())), 0)
            .await
            .unwrap();
        assert!(store.get_version(&hash).await.unwrap().is_empty());
    }

    /// `copy_version_to_path` creates missing parent directories and writes
    /// the full object.
    #[tokio::test]
    async fn test_copy_version_to_path_streams_to_dest() {
        let (store, _tmp, _server) = setup().await;
        let version_hash = "eeedef1234567890";
        let payload = b"streamed to destination";
        store.store_version(version_hash, payload).await.unwrap();

        let dest_root = async_tempfile::TempDir::new().await.unwrap();
        let dest_path = dest_root.dir_path().join("subdir/output.bin");
        store
            .copy_version_to_path(version_hash, &dest_path)
            .await
            .unwrap();
        let written = tokio::fs::read(&dest_path).await.unwrap();
        assert_eq!(written, payload);
    }

    /// A stored chunk lands at its chunk key with the exact bytes.
    #[tokio::test]
    async fn test_store_version_chunk() {
        let (store, _tmp, _server) = setup().await;
        let hash = "abcdef1234567890abcdef1234567890";
        let chunk_data = Bytes::from_static(b"hello chunk");
        store
            .store_version_chunk(hash, 0, chunk_data.clone())
            .await
            .expect("store_version_chunk should succeed");

        let client = store.client().await.expect("client should succeed");
        let object = client
            .get_object()
            .bucket(&store.bucket)
            .key(store.chunk_key(hash, 0))
            .send()
            .await
            .expect("chunk object should exist");
        let stored = object
            .body
            .collect()
            .await
            .expect("body should collect")
            .into_bytes();
        assert_eq!(stored, chunk_data);
    }

    /// Chunks at different offsets are stored independently.
    #[tokio::test]
    async fn test_store_version_chunk_multiple_offsets() {
        let (store, _tmp, _server) = setup().await;
        let hash = "abcdef1234567890abcdef1234567890";
        store
            .store_version_chunk(hash, 0, Bytes::from_static(b"chunk-0"))
            .await
            .expect("store chunk at offset 0 should succeed");
        store
            .store_version_chunk(hash, 1024, Bytes::from_static(b"chunk-1024"))
            .await
            .expect("store chunk at offset 1024 should succeed");

        let client = store.client().await.expect("client should succeed");
        let first = client
            .get_object()
            .bucket(&store.bucket)
            .key(store.chunk_key(hash, 0))
            .send()
            .await
            .expect("chunk at offset 0 should exist")
            .body
            .collect()
            .await
            .expect("body should collect")
            .into_bytes();
        assert_eq!(&first[..], b"chunk-0");

        let second = client
            .get_object()
            .bucket(&store.bucket)
            .key(store.chunk_key(hash, 1024))
            .send()
            .await
            .expect("chunk at offset 1024 should exist")
            .body
            .collect()
            .await
            .expect("body should collect")
            .into_bytes();
        assert_eq!(&second[..], b"chunk-1024");
    }

    /// Deleting a version removes its main object and all of its chunks.
    #[tokio::test]
    async fn test_delete_version_removes_data_and_chunks() {
        let (store, _tmp, _server) = setup().await;
        let hash = "abcdef1234567890abcdef1234567890";
        store.store_version(hash, b"main data").await.unwrap();
        for (offset, body) in [(0, &b"chunk-0"[..]), (1024, &b"chunk-1024"[..])] {
            store
                .store_version_chunk(hash, offset, Bytes::from_static(body))
                .await
                .unwrap();
        }
        assert!(store.version_exists(hash).await.unwrap());

        store
            .delete_version(hash)
            .await
            .expect("delete_version should succeed");
        assert!(!store.version_exists(hash).await.unwrap());

        let version_root = format!("{}/", store.version_dir(hash));
        let remaining = store.list_objects_with_prefix(&version_root).await.unwrap();
        assert!(
            remaining.is_empty(),
            "expected no objects, found: {:?}",
            remaining
        );
    }

    /// Deleting a version that was never stored succeeds.
    #[tokio::test]
    async fn test_delete_version_missing_is_noop() {
        let (store, _tmp, _server) = setup().await;
        store
            .delete_version("deadbeefdeadbeefdeadbeefdeadbeef")
            .await
            .expect("delete of missing version should succeed");
    }

    /// Chunk offsets are listed in ascending order.
    #[tokio::test]
    async fn test_list_version_chunks() {
        let (store, _tmp, _server) = setup().await;
        let hash = "abcdef1234567890abcdef1234567890";
        let chunks: [(u64, &'static [u8]); 3] = [
            (0, b"chunk-0"),
            (10240, b"chunk-10240"),
            (20480, b"chunk-20480"),
        ];
        for (offset, body) in chunks {
            store
                .store_version_chunk(hash, offset, Bytes::from_static(body))
                .await
                .unwrap();
        }
        assert_eq!(
            store.list_version_chunks(hash).await.unwrap(),
            vec![0, 10240, 20480]
        );
    }

    /// A version with no chunks lists none.
    #[tokio::test]
    async fn test_list_version_chunks_empty() {
        let (store, _tmp, _server) = setup().await;
        let listed = store
            .list_version_chunks("abcdef1234567890abcdef1234567890")
            .await
            .unwrap();
        assert!(listed.is_empty());
    }

    /// Chunk listings of different hashes don't leak into each other.
    #[tokio::test]
    async fn test_list_version_chunks_isolates_by_hash() {
        let (store, _tmp, _server) = setup().await;
        let hash_a = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let hash_b = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
        let uploads: [(&str, u64, &'static [u8]); 3] = [
            (hash_a, 0, b"a-chunk"),
            (hash_b, 0, b"b-chunk-0"),
            (hash_b, 512, b"b-chunk-512"),
        ];
        for (hash, offset, body) in uploads {
            store
                .store_version_chunk(hash, offset, Bytes::from_static(body))
                .await
                .unwrap();
        }
        assert_eq!(store.list_version_chunks(hash_a).await.unwrap(), vec![0]);
        assert_eq!(store.list_version_chunks(hash_b).await.unwrap(), vec![0, 512]);
    }

    /// Deleting one version leaves other versions untouched.
    #[tokio::test]
    async fn test_delete_version_isolates_by_hash() {
        let (store, _tmp, _server) = setup().await;
        let hash_a = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let hash_b = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
        store.store_version(hash_a, b"a data").await.unwrap();
        store.store_version(hash_b, b"b data").await.unwrap();

        store.delete_version(hash_a).await.unwrap();
        assert!(!store.version_exists(hash_a).await.unwrap());
        assert!(store.version_exists(hash_b).await.unwrap());
    }

    /// A ranged read from the middle of a version returns exactly that range.
    #[tokio::test]
    async fn test_get_version_chunk_mid_file() {
        let (store, _tmp, _server) = setup().await;
        let hash = "abcdef1234567890abcdef1234567890";
        let payload: Vec<u8> = (0..100u8).collect();
        store.store_version(hash, &payload).await.unwrap();
        let slice = store.get_version_chunk(hash, 10, 20).await.unwrap();
        assert_eq!(slice, payload[10..30]);
    }

    /// A ranged read at offset 0 returns the leading bytes.
    #[tokio::test]
    async fn test_get_version_chunk_from_start() {
        let (store, _tmp, _server) = setup().await;
        let hash = "abcdef1234567890abcdef1234567890";
        store.store_version(hash, b"hello world!").await.unwrap();
        let prefix_bytes = store.get_version_chunk(hash, 0, 5).await.unwrap();
        assert_eq!(&prefix_bytes[..], b"hello");
    }

    /// A zero-length read returns nothing, even for a missing version.
    #[tokio::test]
    async fn test_get_version_chunk_zero_size() {
        let (store, _tmp, _server) = setup().await;
        let nothing = store
            .get_version_chunk("abcdef1234567890abcdef1234567890", 0, 0)
            .await
            .unwrap();
        assert!(nothing.is_empty());
    }

    /// A range starting past the end of the object is an error.
    #[tokio::test]
    async fn test_get_version_chunk_past_eof_errors() {
        let (store, _tmp, _server) = setup().await;
        let hash = "abcdef1234567890abcdef1234567890";
        store.store_version(hash, b"small").await.unwrap();
        let result = store.get_version_chunk(hash, 1000, 10).await;
        assert!(
            result.is_err(),
            "expected error for offset past EOF, got {result:?}"
        );
    }

    /// Combining concatenates chunks in offset order into the version object
    /// and removes the chunks afterwards.
    #[tokio::test]
    async fn test_combine_version_chunks() {
        let (store, _tmp, _server) = setup().await;
        let combined = b"chunk-0chunk-10240chunk-20480";
        let hash = hasher::hash_buffer(combined);
        store
            .store_version_chunk(&hash, 0, Bytes::from_static(b"chunk-0"))
            .await
            .expect("store chunk 0");
        store
            .store_version_chunk(&hash, 10240, Bytes::from_static(b"chunk-10240"))
            .await
            .expect("store chunk 10240");
        store
            .store_version_chunk(&hash, 20480, Bytes::from_static(b"chunk-20480"))
            .await
            .expect("store chunk 20480");

        store
            .combine_version_chunks(&hash)
            .await
            .expect("combine_version_chunks should succeed");

        let client = store.client().await.expect("client");
        let object = client
            .get_object()
            .bucket(&store.bucket)
            .key(store.generate_key(&hash))
            .send()
            .await
            .expect("VERSION object should exist");
        let assembled = object.body.collect().await.expect("body").into_bytes();
        assert_eq!(&assembled[..], combined);

        let chunk_keys = store
            .list_objects_with_prefix(&store.chunks_prefix(&hash))
            .await
            .expect("list chunks");
        assert!(
            chunk_keys.is_empty(),
            "chunks should be deleted after combine, found: {chunk_keys:?}"
        );
    }

    /// On the single-request path, a wrong hash is rejected and nothing is written.
    #[tokio::test]
    async fn test_store_version_from_reader_hash_mismatch_oneshot() {
        let (store, _tmp, _server) = setup().await;
        let payload = b"hello world";
        let wrong_hash = "deadbeefdeadbeefdeadbeefdeadbeef";
        let reader = Box::new(std::io::Cursor::new(payload.to_vec()));
        let result = store
            .store_version_from_reader(wrong_hash, reader, payload.len() as u64)
            .await;
        assert!(
            matches!(result, Err(OxenError::Upload(_))),
            "expected Upload error for hash mismatch, got {result:?}"
        );
        assert!(
            !store.version_exists(wrong_hash).await.unwrap(),
            "no object should exist at the mismatched hash's key"
        );
    }

    /// On the multipart path (forced by a tiny threshold), a wrong hash is
    /// rejected and the upload leaves no object behind.
    #[tokio::test]
    async fn test_store_version_from_reader_hash_mismatch_multipart() {
        let (mut store, _tmp, _server) = setup().await;
        store.oneshot_size = 512;
        let payload = vec![42u8; 1024];
        let wrong_hash = "deadbeefdeadbeefdeadbeefdeadbeef";
        let reader = Box::new(std::io::Cursor::new(payload.clone()));
        let result = store
            .store_version_from_reader(wrong_hash, reader, payload.len() as u64)
            .await;
        assert!(
            matches!(result, Err(OxenError::Upload(_))),
            "expected Upload error for hash mismatch, got {result:?}"
        );
        assert!(
            !store.version_exists(wrong_hash).await.unwrap(),
            "no object should exist at the mismatched hash's key"
        );
    }

    /// Stored versions are listed in key order, regardless of insertion order.
    #[tokio::test]
    async fn test_list_versions() {
        let (store, _tmp, _server) = setup().await;
        for (hash, body) in [("cccc", b"c data"), ("aaaa", b"a data"), ("bbbb", b"b data")] {
            store.store_version(hash, body).await.unwrap();
        }
        assert_eq!(
            store.list_versions().await.unwrap(),
            vec!["aaaa", "bbbb", "cccc"]
        );
    }

    /// An empty store lists no versions.
    #[tokio::test]
    async fn test_list_versions_empty() {
        let (store, _tmp, _server) = setup().await;
        assert!(store.list_versions().await.unwrap().is_empty());
    }

    /// Chunks and derived files under a version directory don't produce extra
    /// entries: the version is listed exactly once.
    #[tokio::test]
    async fn test_list_versions_collapses_chunks_and_derived() {
        let (store, _tmp, _server) = setup().await;
        let hash = "abcdef1234567890abcdef1234567890";
        store.store_version(hash, b"main").await.unwrap();
        store
            .store_version_chunk(hash, 0, Bytes::from_static(b"chunk-0"))
            .await
            .unwrap();
        store
            .store_version_chunk(hash, 1024, Bytes::from_static(b"chunk-1024"))
            .await
            .unwrap();
        store
            .store_version_derived(hash, "thumb.jpg", b"thumbnail bytes")
            .await
            .unwrap();
        assert_eq!(store.list_versions().await.unwrap(), vec![hash]);
    }
}