use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use ndarray::{ArrayD, IxDyn, SliceInfoElem};
use tokio::sync::RwLock;
use tracing::{error, trace};
use zarrs::array::{Array, DataType, ElementOwned};
use zarrs::array_subset::ArraySubset;
use zarrs::storage::{AsyncReadableListableStorage, AsyncReadableListableStorageTraits};
use crate::error::{Error, Result};
/// Element types that an [`AsyncLazySubset`] can hand back to callers.
///
/// Implementors supply the fill value written into cells that lie outside the
/// source array, and a conversion from `f64`, which is the intermediate type
/// every on-disk value passes through before it is stored as `Self`.
pub(crate) trait AsyncLazyElement: ElementOwned + Clone + Send + Sync + 'static {
    /// Value used to pad cells that fall outside the source array.
    const NAN_VALUE: Self;

    /// Converts an `f64` intermediate into this element type.
    fn from_f64(v: f64) -> Self;
}
/// `f32` elements: padding is `f32::NAN`, and `f64` inputs are narrowed with `as`.
impl AsyncLazyElement for f32 {
    const NAN_VALUE: Self = f32::NAN;

    /// Narrows `v` to `f32` using a primitive `as` cast.
    #[inline]
    fn from_f64(v: f64) -> f32 {
        v as f32
    }
}
/// `f64` elements: padding is `f64::NAN`, and the conversion is the identity.
impl AsyncLazyElement for f64 {
    const NAN_VALUE: Self = f64::NAN;

    /// Returns `v` unchanged; the intermediate type already is `f64`.
    #[inline]
    fn from_f64(v: f64) -> f64 {
        v
    }
}
/// A fixed region of a store whose variables are read on demand and cached.
///
/// Each variable is loaded the first time it is requested through `get` and
/// is served from the in-memory cache on later requests.
pub(crate) struct AsyncLazySubset<T: AsyncLazyElement> {
    /// Storage handle that variables are opened from (as `/{varname}`).
    source: AsyncReadableListableStorage,
    /// Region that every variable is read over.
    subset: ArraySubset,
    /// Loaded arrays keyed by variable name, shared behind an async read/write lock.
    data: Arc<RwLock<HashMap<String, ArrayD<T>>>>,
}
/// Compact one-line rendering: the element type name and the subset.
impl<T: AsyncLazyElement> fmt::Display for AsyncLazySubset<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let element_type = std::any::type_name::<T>();
        let subset = &self.subset;
        write!(
            f,
            "AsyncLazySubset<{}> {{ subset: {:?} }}",
            element_type, subset,
        )
    }
}
/// Diagnostic rendering: the subset, the element type, and the names of the
/// variables currently cached.
impl<T: AsyncLazyElement> fmt::Debug for AsyncLazySubset<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `fmt` is synchronous, so the cache is only probed with `try_read`;
        // when the lock cannot be taken immediately a placeholder is shown.
        let cached_vars = self.data.try_read().map_or_else(
            |_| "<locked>".to_string(),
            |cache| {
                // Sorted so the output does not depend on hash-map iteration order.
                let mut names: Vec<&str> = cache.keys().map(String::as_str).collect();
                names.sort();
                format!("{:?}", names)
            },
        );
        let element_type = std::any::type_name::<T>();

        let mut out = f.debug_struct("AsyncLazySubset");
        out.field("subset", &self.subset);
        out.field("element_type", &element_type);
        out.field("cached_variables", &cached_vars);
        out.finish()
    }
}
impl<T: AsyncLazyElement> AsyncLazySubset<T> {
/// Creates a lazy view of `subset` over `source` with an empty cache.
///
/// Nothing is read from `source` here.
// NOTE(review): `dead_code` is allowed presumably because this is not yet
// called outside the tests — confirm and drop the attribute once it is.
#[allow(dead_code)]
pub(crate) fn new(source: AsyncReadableListableStorage, subset: ArraySubset) -> Self {
    trace!("Creating AsyncLazySubset for subset: {:?}", subset);
    let data = Arc::new(RwLock::new(HashMap::new()));
    Self {
        source,
        subset,
        data,
    }
}
#[allow(dead_code)]
pub(crate) fn subset(&self) -> &ArraySubset {
&self.subset
}
/// Returns the data for `varname` over this subset, loading it on first use.
///
/// The cache is probed under the read lock first. On a miss the write lock is
/// taken, the cache is probed again (another task may have filled it while
/// this one waited), and only then is the variable loaded and stored.
///
/// The write lock stays held while the variable is being loaded, so loads on
/// one instance run one at a time and other callers wait until the load
/// finishes.
///
/// # Errors
///
/// Propagates any error from loading the variable; nothing is cached in that
/// case.
#[allow(dead_code)]
pub(crate) async fn get(&self, varname: &str) -> Result<ArrayD<T>> {
    trace!("AsyncLazySubset::get(\"{}\")", varname);

    // Fast path: the read guard is a temporary and is released at the end of
    // this statement, before the write lock is requested below.
    let fast_hit = self.data.read().await.get(varname).cloned();
    if let Some(hit) = fast_hit {
        trace!("Cache hit for variable \"{}\"", varname);
        return Ok(hit);
    }

    let mut entries = self.data.write().await;

    // Re-check under the write lock before doing any I/O.
    let raced_hit = entries.get(varname).cloned();
    if let Some(hit) = raced_hit {
        trace!(
            "Cache hit for variable \"{}\" (populated while waiting for write lock)",
            varname
        );
        return Ok(hit);
    }

    trace!(
        "Loading variable \"{}\" for subset {:?}",
        varname, self.subset
    );
    let loaded = self.load_variable(varname).await?;
    entries.insert(varname.to_string(), loaded.clone());
    Ok(loaded)
}
/// Reads `subset` from `array` and converts every element to `T`.
///
/// The array's on-disk data type is inspected at runtime. Each supported
/// numeric type is read in its native representation, widened to `f64` with an
/// `as` cast, and then passed through [`AsyncLazyElement::from_f64`]. All
/// fixed-width integer types from 8 to 64 bits (signed and unsigned) and both
/// float widths are accepted. 64-bit integers whose magnitude exceeds 2^53
/// cannot be represented exactly in `f64` and are rounded by the cast.
///
/// `varname` is only used to label log and error messages.
///
/// # Errors
///
/// - `Error::IO` when the underlying read fails.
/// - `Error::UnsupportedDataType` when the on-disk type is not one of the
///   numeric types listed above.
async fn retrieve_subset<S: AsyncReadableListableStorageTraits + 'static + ?Sized>(
    &self,
    array: &Array<S>,
    varname: &str,
    subset: &ArraySubset,
) -> Result<ArrayD<T>> {
    trace!(
        "Retrieving in-bounds subset {:?} for \"{}\" (on-disk type: {:?})",
        subset,
        varname,
        array.data_type()
    );

    // A macro rather than a generic helper: the native element type differs
    // per match arm and the `as f64` cast needs a concrete primitive type.
    macro_rules! read_and_convert {
        ($native:ty) => {{
            let native_data = array
                .async_retrieve_array_subset_ndarray::<$native>(subset)
                .await
                .map_err(|err| {
                    Error::IO(std::io::Error::other(format!(
                        "Failed to retrieve subset for variable '{}': {}",
                        varname, err
                    )))
                })?;
            Ok(native_data.mapv(|v| T::from_f64(v as f64)))
        }};
    }

    match array.data_type() {
        DataType::Float32 => read_and_convert!(f32),
        DataType::Float64 => read_and_convert!(f64),
        DataType::Int8 => read_and_convert!(i8),
        DataType::Int16 => read_and_convert!(i16),
        DataType::Int32 => read_and_convert!(i32),
        DataType::Int64 => read_and_convert!(i64),
        DataType::UInt8 => read_and_convert!(u8),
        DataType::UInt16 => read_and_convert!(u16),
        DataType::UInt32 => read_and_convert!(u32),
        DataType::UInt64 => read_and_convert!(u64),
        other => Err(Error::UnsupportedDataType(
            format!("{:?}", other),
            varname.to_string(),
        )),
    }
}
async fn load_variable(&self, varname: &str) -> Result<ArrayD<T>> {
let array = Array::async_open(self.source.clone(), &format!("/{varname}"))
.await
.inspect_err(|err| error!("Failed to open variable '{}': {}", varname, err))?;
let source_shape = array.shape().to_vec();
let ndim = source_shape.len();
let subset_start = self.subset.start();
let subset_shape: Vec<u64> = self.subset.shape().to_vec();
let mut clipped_start: Vec<u64> = Vec::with_capacity(ndim);
let mut clipped_shape: Vec<u64> = Vec::with_capacity(ndim);
let mut fully_outside = false;
for dim in 0..ndim {
let req_start = subset_start[dim];
let req_end = req_start + subset_shape[dim]; let src_end = source_shape[dim];
if req_start >= src_end || req_end == 0 {
fully_outside = true;
break;
}
let eff_start = req_start; let eff_end = req_end.min(src_end);
clipped_start.push(eff_start);
clipped_shape.push(eff_end - eff_start);
}
if fully_outside {
let out_shape: Vec<usize> = subset_shape.iter().map(|&d| d as usize).collect();
trace!(
"Subset for \"{}\" is fully outside source bounds; returning all-NaN",
varname
);
return Ok(ArrayD::<T>::from_elem(IxDyn(&out_shape), T::NAN_VALUE));
}
let needs_padding = clipped_shape != subset_shape;
if !needs_padding {
trace!("No padding needed for \"{}\"", varname);
return self.retrieve_subset(&array, varname, &self.subset).await;
}
trace!(
"Padding needed for \"{}\": clipped {:?} vs requested {:?}",
varname, clipped_shape, subset_shape
);
let clipped_subset =
ArraySubset::new_with_start_shape(clipped_start, clipped_shape.clone()).map_err(
|err| {
Error::IO(std::io::Error::other(format!(
"Failed to build clipped subset for '{}': {}",
varname, err
)))
},
)?;
let loaded = self
.retrieve_subset(&array, varname, &clipped_subset)
.await?;
let out_shape: Vec<usize> = subset_shape.iter().map(|&d| d as usize).collect();
let mut output = ArrayD::<T>::from_elem(IxDyn(&out_shape), T::NAN_VALUE);
let slice_info: Vec<SliceInfoElem> = clipped_shape
.iter()
.map(|&len| SliceInfoElem::Slice {
start: 0,
end: Some(len as isize),
step: 1,
})
.collect();
output
.slice_mut(slice_info.as_slice())
.assign(&loaded.view());
Ok(output)
}
}
#[cfg(test)]
mod tests {
    use super::super::samples::FeatureDataType;
    use super::*;
    use crate::routing::features::samples::{FeaturesTestBuilder, LayerConfig};

    /// Builds a 2-D subset from its start corner and extent.
    fn subset_2d(start: [u64; 2], shape: [u64; 2]) -> ArraySubset {
        ArraySubset::new_with_start_shape(start.to_vec(), shape.to_vec()).unwrap()
    }

    /// A constant layer read over its full extent has the right shape and values.
    #[tokio::test]
    async fn get_returns_correct_shape_and_values() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::constant("A", 3.0))
            .build()
            .unwrap();
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&storage), subset_2d([0, 0], [4, 4]));

        let values = lazy.get("A").await.unwrap();

        assert_eq!(values.shape(), &[4, 4]);
        assert!(values.iter().all(|&v| v == 3.0));
    }

    /// A second request for the same variable yields the same array.
    #[tokio::test]
    async fn get_same_variable_twice_is_identical() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::sequential("A"))
            .build()
            .unwrap();
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&storage), subset_2d([0, 0], [4, 4]));

        let first = lazy.get("A").await.unwrap();
        let second = lazy.get("A").await.unwrap();

        assert_eq!(first, second);
    }

    /// Requesting a variable that is not in the store is an error.
    #[tokio::test]
    async fn get_missing_variable_returns_error() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::ones("A"))
            .build()
            .unwrap();
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&storage), subset_2d([0, 0], [4, 4]));

        let outcome = lazy.get("NONEXISTENT").await;

        assert!(outcome.is_err(), "Expected an error for missing variable");
    }

    /// A 6x6 request over a 4x4 layer keeps the data in the top-left 4x4 and
    /// pads every other cell with NaN.
    #[tokio::test]
    async fn get_pads_out_of_bounds_with_nan_f32() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::constant("A", 7.0))
            .build()
            .unwrap();
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&storage), subset_2d([0, 0], [6, 6]));

        let data = lazy.get("A").await.unwrap();

        assert_eq!(data.shape(), &[6, 6]);
        for i in 0..6 {
            for j in 0..6 {
                let cell = data[[i, j]];
                if i < 4 && j < 4 {
                    assert_eq!(cell, 7.0, "in-bounds cell [{i},{j}]");
                } else {
                    assert!(
                        cell.is_nan(),
                        "out-of-bounds cell [{i},{j}] should be NaN, got {}",
                        cell
                    );
                }
            }
        }
    }

    /// The `f64` element type round-trips a Float64 layer.
    #[tokio::test]
    async fn get_f64_works() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::constant("elev", 1.5).with_dtype(FeatureDataType::Float64))
            .build()
            .unwrap();
        let lazy = AsyncLazySubset::<f64>::new(Arc::clone(&storage), subset_2d([0, 0], [4, 4]));

        let data = lazy.get("elev").await.unwrap();

        assert_eq!(data.shape(), &[4, 4]);
        assert!(data.iter().all(|&v| (v - 1.5).abs() < 1e-6));
    }

    /// A subset that starts past the end of the layer comes back as all NaN.
    #[tokio::test]
    async fn fully_outside_returns_all_nan() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::ones("A"))
            .build()
            .unwrap();
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&storage), subset_2d([10, 10], [3, 3]));

        let data = lazy.get("A").await.unwrap();

        assert_eq!(data.shape(), &[3, 3]);
        assert!(data.iter().all(|v| v.is_nan()));
    }

    /// Two different variables requested through `join!` both resolve correctly.
    #[tokio::test]
    async fn concurrent_get_different_variables() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::constant("A", 1.0))
            .layer(LayerConfig::constant("B", 2.0))
            .build()
            .unwrap();
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&storage), subset_2d([0, 0], [4, 4]));

        let (ra, rb) = tokio::join!(lazy.get("A"), lazy.get("B"));

        assert!(ra.unwrap().iter().all(|&v| v == 1.0));
        assert!(rb.unwrap().iter().all(|&v| v == 2.0));
    }

    /// Independent instances on separate tasks all read the same data.
    #[tokio::test(flavor = "multi_thread")]
    async fn parallel_loads_match() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::sequential("A"))
            .build()
            .unwrap();
        let subset = subset_2d([0, 0], [4, 4]);

        let reference = AsyncLazySubset::<f32>::new(Arc::clone(&storage), subset.clone())
            .get("A")
            .await
            .unwrap();

        let mut handles = Vec::with_capacity(4);
        for _ in 0..4 {
            let src = Arc::clone(&storage);
            let sub = subset.clone();
            handles.push(tokio::spawn(async move {
                let lazy = AsyncLazySubset::<f32>::new(src, sub);
                lazy.get("A").await
            }));
        }

        for handle in handles {
            let data = handle.await.expect("task panicked").unwrap();
            assert_eq!(data, reference);
        }
    }

    /// One instance shared by several tasks gives every task the same array.
    #[tokio::test(flavor = "multi_thread")]
    async fn shared_instance_across_tasks() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::sequential("A"))
            .build()
            .unwrap();
        let lazy = Arc::new(AsyncLazySubset::<f32>::new(
            Arc::clone(&storage),
            subset_2d([0, 0], [4, 4]),
        ));

        let mut handles = Vec::with_capacity(8);
        for _ in 0..8 {
            let shared = Arc::clone(&lazy);
            handles.push(tokio::spawn(async move { shared.get("A").await }));
        }

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.expect("task panicked").unwrap());
        }

        let (first, rest) = results
            .split_first()
            .expect("eight tasks were spawned, so there is at least one result");
        for result in rest {
            assert_eq!(result, first);
        }
    }

    /// Layers stored with different on-disk dtypes all come back as the same
    /// `f32` values.
    #[tokio::test]
    async fn get_converts_mixed_dtypes_to_common_type() {
        let (_tmp, storage) = FeaturesTestBuilder::new()
            .dimensions(2, 3)
            .chunks(2, 3)
            .layer(LayerConfig::sequential("from_f32"))
            .layer(LayerConfig::sequential("from_f64").with_dtype(FeatureDataType::Float64))
            .layer(LayerConfig::sequential("from_i16").with_dtype(FeatureDataType::Int16))
            .layer(LayerConfig::sequential("from_i32").with_dtype(FeatureDataType::Int32))
            .layer(LayerConfig::sequential("from_u8").with_dtype(FeatureDataType::UInt8))
            .layer(LayerConfig::sequential("from_u32").with_dtype(FeatureDataType::UInt32))
            .build()
            .unwrap();
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&storage), subset_2d([0, 0], [2, 3]));

        let expected: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let varnames = [
            "from_f32", "from_f64", "from_i16", "from_i32", "from_u8", "from_u32",
        ];
        for varname in varnames {
            let data = lazy.get(varname).await.unwrap();
            assert_eq!(data.shape(), &[2, 3], "wrong shape for {varname}");
            let flat: Vec<f32> = data.iter().copied().collect();
            assert_eq!(flat, expected, "wrong values for {varname}");
        }
    }
}