use std::io;
use crate::{
zarr_get, zarr_get_range_from_end, zarr_get_range_from_offset, zarr_has,
zarr_get_status, zarr_get_range_from_end_status, zarr_get_range_from_offset_status, zarr_has_status,
};
use crate::zarr_types::ZarrPeekResult;
use futures::{stream, StreamExt, TryStreamExt};
use zarrs::storage::{
byte_range::{ByteRange, ByteRangeIterator},
AsyncMaybeBytesIterator, AsyncReadableStorageTraits, Bytes, MaybeBytes, StorageError, StoreKey,
};
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use quick_cache::sync::Cache;
/// Process-wide registry of byte caches, keyed by store name.
///
/// The outer `OnceLock` is initialised lazily by `get_or_init_store_cache`;
/// the `Mutex` guards the map while a store's cache is looked up or inserted.
/// Each entry is an `Arc` so callers can keep using a cache after the lock is
/// released.
static ZARR_STORE_CACHES: OnceLock<Mutex<HashMap<String, Arc<Cache<String, Bytes>>>>> =
OnceLock::new();
/// Compile-time switch for the byte cache.
///
/// While this is `false`, `get` returns the result of `zarr_get` directly and
/// `get_partial_many` never creates or consults a cache.
const ZARR_CACHE_ENABLED: bool = false;
/// Returns the shared cache registered under `name`, creating and registering
/// a fresh one (built with `Cache::new(10000)`) the first time a name is seen.
///
/// # Panics
///
/// Panics if the registry mutex has been poisoned.
fn get_or_init_store_cache(name: &str) -> Arc<Cache<String, Bytes>> {
    let registry = ZARR_STORE_CACHES.get_or_init(|| Mutex::new(HashMap::new()));
    let mut caches = registry.lock().unwrap();

    // Look up by `&str` first so a hit never has to allocate an owned key.
    if let Some(existing) = caches.get(name) {
        return Arc::clone(existing);
    }

    let created = Arc::new(Cache::new(10000));
    caches.insert(name.to_string(), Arc::clone(&created));
    created
}
/// Builds the cache key for `key`, optionally qualified by a byte range.
///
/// Key formats:
/// * no range: `key` unchanged;
/// * `FromStart(start, Some(len))`: `key:start:end`, where `end` is the
///   inclusive last offset `start + len - 1`;
/// * `Suffix(n)`: `key:-n`.
///
/// The inclusive end is computed in `i128`, so a zero-length range at offset 0
/// yields `key:0:-1` instead of underflowing, and `start + len` cannot
/// overflow. Every distinct `(start, len)` pair therefore maps to a distinct
/// key.
///
/// # Panics
///
/// Panics with "Unsupported ByteRange variant" for any other range, i.e. one
/// matched by neither pattern above (such as `FromStart(_, None)`).
fn normalize_key(key: &str, byte_range: Option<ByteRange>) -> String {
    match byte_range {
        None => key.to_string(),
        Some(ByteRange::FromStart(start, Some(len))) => {
            // Widen before the arithmetic: `start + len - 1` in the original
            // integer width underflows for (0, 0) and can overflow near MAX.
            let last = i128::from(start) + i128::from(len) - 1;
            format!("{}:{}:{}", key, start, last)
        }
        Some(ByteRange::Suffix(suffix_length)) => format!("{}:-{}", key, suffix_length),
        _ => panic!("Unsupported ByteRange variant"),
    }
}
/// Creates the error returned when a store request is still pending and the
/// caller asked not to wait: a `StorageError::IOError` wrapping an
/// `io::Error` of kind `TimedOut` with the message "too slow".
fn make_storage_error() -> StorageError {
    let timed_out = io::Error::new(io::ErrorKind::TimedOut, "too slow");
    StorageError::IOError(Arc::new(timed_out))
}
/// Returns `true` only when `err` is a `StorageError::IOError` whose inner
/// `io::Error` has kind `TimedOut`; every other error yields `false`.
fn is_storage_error_timed_out(err: &StorageError) -> bool {
    matches!(
        err,
        StorageError::IOError(inner) if inner.kind() == io::ErrorKind::TimedOut
    )
}
/// Returns `true` when a codec error was caused by a timeout: either a wrapped
/// storage error that `is_storage_error_timed_out` accepts, or a codec-level
/// I/O error of kind `TimedOut`. All other codec errors yield `false`.
fn is_codec_error_timed_out(err: &zarrs::array::CodecError) -> bool {
    use zarrs::array::CodecError;

    if let CodecError::StorageError(storage_err) = err {
        return is_storage_error_timed_out(storage_err);
    }
    matches!(
        err,
        CodecError::IOError(io_err) if io_err.kind() == io::ErrorKind::TimedOut
    )
}
/// Reports whether an array-level error ultimately stems from a timeout.
///
/// Storage errors are checked with `is_storage_error_timed_out`, codec errors
/// with `is_codec_error_timed_out`; any other `ArrayError` variant is treated
/// as not timed out.
pub fn is_timed_out_zarrs_error(err: &zarrs::array::ArrayError) -> bool {
    use zarrs::array::ArrayError;

    if let ArrayError::StorageError(storage_err) = err {
        is_storage_error_timed_out(storage_err)
    } else if let ArrayError::CodecError(codec_err) = err {
        is_codec_error_timed_out(codec_err)
    } else {
        false
    }
}
/// Async readable store that forwards every request to the crate's `zarr_*`
/// functions for a named store.
pub struct AsyncZarritaStore {
/// Store identifier passed as the first argument to every `zarr_*` call.
store_name: String,
/// When `false`, `has` and the byte-range fetches first check the matching
/// `*_status` function and return a timed-out `StorageError` if the request
/// is still `Pending`, rather than awaiting it.
wait_for_store_gets: bool,
}
impl AsyncZarritaStore {
#[must_use]
pub fn new(store_name: String, wait_for_store_gets: bool) -> Self {
Self {
store_name,
wait_for_store_gets,
}
}
async fn fetch_byte_range(
&self,
key: &str,
byte_range: ByteRange,
) -> Result<Bytes, StorageError> {
match byte_range {
ByteRange::FromStart(start, Some(len)) => {
if !self.wait_for_store_gets {
let promise_status = zarr_get_range_from_offset_status(
&self.store_name,
key,
start as u32,
len as u32,
);
if promise_status == ZarrPeekResult::Pending {
return Err(make_storage_error());
}
}
Ok(zarr_get_range_from_offset(
&self.store_name,
key,
start as u32,
len as u32,
)
.await)
}
ByteRange::Suffix(suffix_length) => {
if !self.wait_for_store_gets {
let promise_status = zarr_get_range_from_end_status(
&self.store_name,
key,
suffix_length as u32,
);
if promise_status == ZarrPeekResult::Pending {
return Err(make_storage_error());
}
}
Ok(zarr_get_range_from_end(
&self.store_name,
key,
suffix_length as u32,
)
.await)
}
_ => panic!("Unsupported ByteRange variant"),
}
}
pub async fn has(&self, key: &StoreKey) -> Result<bool, StorageError> {
if !self.wait_for_store_gets {
let promise_status = zarr_has_status(&self.store_name, key.as_str());
if promise_status == ZarrPeekResult::Pending {
return Err(make_storage_error());
}
}
let has = zarr_has(&self.store_name, key.as_str()).await;
Ok(has)
}
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl AsyncReadableStorageTraits for AsyncZarritaStore {
    /// Reads the whole value stored under `key`.
    ///
    /// With the cache disabled the bytes come straight from `zarr_get`. With it
    /// enabled, a cached value is returned if present; otherwise the key's
    /// existence is checked (`Ok(None)` when absent), the value is fetched and
    /// then cached.
    ///
    /// # Errors
    ///
    /// On the cached path, returns a timed-out `StorageError` when
    /// `wait_for_store_gets` is `false` and either the existence check or the
    /// read is still `Pending`.
    async fn get(&self, key: &StoreKey) -> Result<MaybeBytes, StorageError> {
        if !ZARR_CACHE_ENABLED {
            // NOTE(review): this shortcut neither calls `has` nor honours
            // `wait_for_store_gets`, unlike the cached path below; confirm
            // that is intended.
            let bytes = zarr_get(&self.store_name, key.as_str()).await;
            return Ok(Some(bytes));
        }

        let cache_key = normalize_key(key.as_str(), None);
        let cache = get_or_init_store_cache(&self.store_name);
        if let Some(cached) = cache.get(&cache_key) {
            return Ok(Some(cached.clone()));
        }

        // Propagate a pending/timed-out existence check to the caller; it is an
        // expected outcome in non-waiting mode, not a reason to panic.
        if !self.has(key).await? {
            return Ok(None);
        }
        if !self.wait_for_store_gets
            && zarr_get_status(&self.store_name, key.as_str()) == ZarrPeekResult::Pending
        {
            return Err(make_storage_error());
        }

        let bytes = zarr_get(&self.store_name, key.as_str()).await;
        cache.insert(cache_key, bytes.clone());
        Ok(Some(bytes))
    }

    /// Reads several byte ranges of `key`, one result per requested range and
    /// in request order.
    ///
    /// Ranges are fetched sequentially. A failure for one range is reported as
    /// that range's `Err` item in the returned stream; the call itself always
    /// returns `Ok(Some(..))`. When the cache is enabled, successful reads are
    /// cached per range and served from the cache on later requests.
    async fn get_partial_many<'a>(
        &'a self,
        key: &StoreKey,
        byte_ranges: ByteRangeIterator<'a>,
    ) -> Result<AsyncMaybeBytesIterator<'a>, StorageError> {
        let cache = ZARR_CACHE_ENABLED.then(|| get_or_init_store_cache(&self.store_name));
        let mut results = Vec::new();

        for byte_range in byte_ranges {
            // Without a cache there is no need to build a cache key at all.
            let Some(store_cache) = cache.as_ref() else {
                results.push(self.fetch_byte_range(key.as_str(), byte_range).await);
                continue;
            };

            let cache_key = normalize_key(key.as_str(), Some(byte_range));
            if let Some(cached) = store_cache.get(&cache_key) {
                results.push(Ok(cached.clone()));
                continue;
            }

            let fetched = self.fetch_byte_range(key.as_str(), byte_range).await;
            if let Ok(bytes) = &fetched {
                store_cache.insert(cache_key, bytes.clone());
            }
            results.push(fetched);
        }

        Ok(Some(Box::pin(stream::iter(results))))
    }

    /// Always returns `Ok(None)`; this store does not look up the size of a key.
    async fn size_key(&self, _key: &StoreKey) -> Result<Option<u64>, StorageError> {
        Ok(None)
    }

    /// This store serves byte-range reads, so this always returns `true`.
    fn supports_get_partial(&self) -> bool {
        true
    }
}