use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::trace;
use zarrs::array::{Array, DataType, ElementOwned};
use zarrs::array_subset::ArraySubset;
use zarrs::storage::AsyncReadableListableStorage;
use zarrs::storage::{ReadableListableStorage, ReadableListableStorageTraits};
use crate::error::{Error, Result};
/// A lazily-populated, per-variable cache of one fixed [`ArraySubset`] read from a
/// synchronous Zarr store.
pub(crate) struct LazySubset<T> {
    /// Store that the variables' arrays are opened from.
    source: ReadableListableStorage,
    /// Region that is read from every requested variable.
    subset: ArraySubset,
    /// Arrays that have already been loaded, keyed by variable name.
    data: HashMap<String, ndarray::ArrayD<T>>,
}
impl<T> fmt::Display for LazySubset<T> {
    /// Shows only the subset; the storage handle and cached arrays are elided as `...`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!(
            "LazySubset {{ subset: {:?}, ... }}",
            self.subset
        ))
    }
}
impl<T: ElementOwned> LazySubset<T> {
    /// Creates an empty cache over `subset`; nothing is read from `source` here.
    pub(super) fn new(source: ReadableListableStorage, subset: ArraySubset) -> Self {
        trace!("Creating LazySubset for subset: {:?}", subset);
        Self {
            data: HashMap::default(),
            source,
            subset,
        }
    }

    /// Region that is read from every requested variable.
    pub(crate) fn subset(&self) -> &ArraySubset {
        &self.subset
    }
}
impl LazySubset<f32> {
    /// Returns this instance's subset of variable `varname`, converted to `f32`.
    ///
    /// On first access the array at `/{varname}` is opened in `self.source`, the subset is
    /// read and converted (see `load_as_f32`), and the result is cached. Later calls return
    /// a clone (a full copy) of the cached array without touching storage.
    ///
    /// # Errors
    ///
    /// Returns `Error::IO` if the array cannot be opened, its data type is not supported, or
    /// the subset cannot be retrieved. Nothing is cached when an error is returned.
    pub(crate) fn get(
        &mut self,
        varname: &str,
    ) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>> {
        trace!("Getting data subset for variable: {}", varname);
        if let Some(cached) = self.data.get(varname) {
            trace!("Data for variable {} already loaded", varname);
            return Ok(cached.clone());
        }
        trace!(
            "Loading data subset ({:?}) for variable: {}",
            self.subset, varname
        );
        let variable =
            Array::open(self.source.clone(), &format!("/{varname}")).map_err(|err| {
                Error::IO(std::io::Error::other(format!(
                    "Failed to open layer '{varname}': {err}"
                )))
            })?;
        let values = self.load_as_f32(&variable, varname)?;
        self.data.insert(varname.to_string(), values.clone());
        Ok(values)
    }

    /// Reads `self.subset` of `variable` and converts every element to `f32`.
    ///
    /// Non-`f32` types are converted with `as`, which rounds to the nearest `f32`. Large
    /// 32/64-bit integers and `f64` values may therefore lose precision.
    ///
    /// # Errors
    ///
    /// Returns `Error::IO` for data types other than the float/integer types matched below,
    /// or if the subset cannot be retrieved.
    fn load_as_f32<TStorage: ?Sized + ReadableListableStorageTraits + 'static>(
        &self,
        variable: &Array<TStorage>,
        varname: &str,
    ) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>> {
        match variable.data_type() {
            // Already f32: return the retrieved array as-is instead of copying it through an
            // identity `mapv`.
            DataType::Float32 => self.retrieve::<f32, TStorage>(variable, varname),
            DataType::Float64 => {
                self.retrieve_and_convert::<f64, TStorage, _>(variable, varname, |v| v as f32)
            }
            DataType::Int8 => {
                self.retrieve_and_convert::<i8, TStorage, _>(variable, varname, |v| v as f32)
            }
            DataType::Int16 => {
                self.retrieve_and_convert::<i16, TStorage, _>(variable, varname, |v| v as f32)
            }
            DataType::Int32 => {
                self.retrieve_and_convert::<i32, TStorage, _>(variable, varname, |v| v as f32)
            }
            DataType::Int64 => {
                self.retrieve_and_convert::<i64, TStorage, _>(variable, varname, |v| v as f32)
            }
            DataType::UInt8 => {
                self.retrieve_and_convert::<u8, TStorage, _>(variable, varname, |v| v as f32)
            }
            DataType::UInt16 => {
                self.retrieve_and_convert::<u16, TStorage, _>(variable, varname, |v| v as f32)
            }
            DataType::UInt32 => {
                self.retrieve_and_convert::<u32, TStorage, _>(variable, varname, |v| v as f32)
            }
            DataType::UInt64 => {
                self.retrieve_and_convert::<u64, TStorage, _>(variable, varname, |v| v as f32)
            }
            other => Err(Error::IO(std::io::Error::other(format!(
                "Unsupported data type {:?} for layer '{varname}'",
                other
            )))),
        }
    }

    /// Reads `self.subset` of `variable` as elements of type `T`, without conversion.
    ///
    /// # Errors
    ///
    /// Returns `Error::IO` wrapping the zarrs error if the subset cannot be retrieved.
    fn retrieve<T, TStorage>(
        &self,
        variable: &Array<TStorage>,
        varname: &str,
    ) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, ndarray::Dim<ndarray::IxDynImpl>>>
    where
        T: ElementOwned,
        TStorage: ?Sized + ReadableListableStorageTraits + 'static,
    {
        variable
            .retrieve_array_subset_ndarray::<T>(&self.subset)
            .map_err(|err| {
                Error::IO(std::io::Error::other(format!(
                    "Failed to retrieve array subset for layer '{varname}': {err}"
                )))
            })
    }

    /// Reads `self.subset` of `variable` as `T` and maps each element through `converter`.
    ///
    /// # Errors
    ///
    /// Same as `retrieve`.
    fn retrieve_and_convert<T, TStorage, F>(
        &self,
        variable: &Array<TStorage>,
        varname: &str,
        converter: F,
    ) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<f32>, ndarray::Dim<ndarray::IxDynImpl>>>
    where
        T: ElementOwned + Clone,
        TStorage: ?Sized + ReadableListableStorageTraits + 'static,
        F: FnMut(T) -> f32,
    {
        let raw = self.retrieve::<T, TStorage>(variable, varname)?;
        Ok(raw.mapv(converter))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dataset::samples;
    use std::sync::Arc;
    use zarrs::storage::ReadableListableStorage;

    /// Loading one variable yields an array with exactly the requested subset shape.
    #[test]
    fn sample() {
        let path = samples::multi_variable_random(1, 8, 8, 1, 4, 4, &["A", "B", "C", "cost"]);
        let filesystem = zarrs::filesystem::FilesystemStore::new(&path).unwrap();
        let store: ReadableListableStorage = Arc::new(filesystem);
        let region = ArraySubset::new_with_start_shape(vec![0, 0, 0], vec![1, 2, 2]).unwrap();

        let mut lazy = LazySubset::<f32>::new(store, region);
        let layer_a = lazy.get("A").unwrap();

        assert_eq!(layer_a.shape(), &[1, 2, 2]);
    }
}
/// Element types that `AsyncLazySubset` may cache, plus conversions from both float widths.
// NOTE(review): dead_code is allowed because nothing visible here calls
// `from_f32`/`from_f64` yet — confirm whether they are still needed.
#[allow(dead_code)]
trait LazySubsetElement: ElementOwned + Clone + Send + Sync {
    /// Converts an `f64` into `Self`; narrower types may lose precision.
    fn from_f64(value: f64) -> Self;

    /// Converts an `f32` into `Self`.
    fn from_f32(value: f32) -> Self;
}
impl LazySubsetElement for f32 {
    /// Narrows with `as`, rounding to the nearest representable `f32`.
    fn from_f64(value: f64) -> Self {
        value as Self
    }

    /// Identity conversion.
    fn from_f32(value: f32) -> Self {
        value
    }
}
impl LazySubsetElement for f64 {
    /// Identity conversion.
    fn from_f64(value: f64) -> Self {
        value
    }

    /// Widens losslessly from `f32`.
    fn from_f32(value: f32) -> Self {
        Self::from(value)
    }
}
/// Async per-variable cache of one fixed [`ArraySubset`] read from an async Zarr store.
///
/// The cache sits behind an `Arc<RwLock<..>>`, so lookups and inserts go through `&self`.
// NOTE(review): dead_code is allowed because, as far as this file shows, only the
// `tests_async` module uses this type — confirm.
#[allow(dead_code)]
struct AsyncLazySubset<T: LazySubsetElement> {
    /// Async store that the variables' arrays are opened from.
    source: AsyncReadableListableStorage,
    /// Region that is read from every requested variable.
    subset: ArraySubset,
    /// Arrays that have already been loaded, keyed by variable name.
    #[allow(clippy::type_complexity)]
    data: Arc<RwLock<HashMap<String, ndarray::ArrayD<T>>>>,
}
impl<T: LazySubsetElement> fmt::Display for AsyncLazySubset<T> {
    /// Shows only the subset; the storage handle and cached arrays are elided as `...`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!(
            "AsyncLazySubset {{ subset: {:?}, ... }}",
            self.subset
        ))
    }
}
impl<T: LazySubsetElement> fmt::Debug for AsyncLazySubset<T> {
    /// Debug view with the subset and the element type name; storage and cache are omitted.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let element_type = std::any::type_name::<T>();
        let mut builder = f.debug_struct("AsyncLazySubset");
        builder.field("subset", &self.subset);
        builder.field("element_type", &element_type);
        builder.finish()
    }
}
#[allow(dead_code)]
impl<T: LazySubsetElement> AsyncLazySubset<T> {
fn new(source: AsyncReadableListableStorage, subset: ArraySubset) -> Self {
trace!("Creating AsyncLazySubset for subset: {:?}", subset);
AsyncLazySubset {
source,
subset,
data: Arc::new(RwLock::new(HashMap::new())),
}
}
fn subset(&self) -> &ArraySubset {
&self.subset
}
async fn get(
&self,
varname: &str,
) -> Result<ndarray::ArrayBase<ndarray::OwnedRepr<T>, ndarray::Dim<ndarray::IxDynImpl>>> {
trace!("Getting data subset for variable: {}", varname);
{
let data_read = self.data.read().await;
if let Some(cached) = data_read.get(varname) {
trace!("Data for variable {} already loaded", varname);
return Ok(cached.clone());
}
}
trace!(
"Loading data subset ({:?}) for variable: {}",
self.subset, varname
);
let variable = Array::async_open(self.source.clone(), &format!("/{varname}")).await?;
let values = variable
.async_retrieve_array_subset_ndarray(&self.subset)
.await
.expect("Failed to retrieve array subset");
let mut data_write = self.data.write().await;
data_write.insert(varname.to_string(), values.clone());
Ok(values)
}
}
#[cfg(test)]
mod tests_async {
    use super::*;
    use crate::dataset::samples::{self, LayerConfig, ZarrTestBuilder};

    /// The full 1x4x4 region that every fixture below is built with.
    fn whole_grid() -> ArraySubset {
        ArraySubset::new_with_start_shape(vec![0, 0, 0], vec![1, 4, 4]).unwrap()
    }

    /// A constant layer reads back with the requested shape and its constant value.
    #[tokio::test]
    async fn get_returns_correct_shape_and_values() {
        let fixture = ZarrTestBuilder::new()
            .dimensions(1, 4, 4)
            .chunks(1, 2, 2)
            .layer(LayerConfig::constant("A", 3.0))
            .build()
            .unwrap();
        let store = samples::async_storage_for(fixture.path());
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&store), whole_grid());

        let layer = lazy.get("A").await.unwrap();

        assert_eq!(layer.shape(), &[1, 4, 4]);
        assert!(layer.iter().all(|&v| v == 3.0));
    }

    /// A cached second read equals the first read.
    #[tokio::test]
    async fn get_same_variable_twice_is_identical() {
        let fixture = ZarrTestBuilder::new()
            .dimensions(1, 4, 4)
            .chunks(1, 2, 2)
            .layer(LayerConfig::sequential("A", 1))
            .build()
            .unwrap();
        let store = samples::async_storage_for(fixture.path());
        let lazy = AsyncLazySubset::<f32>::new(Arc::clone(&store), whole_grid());

        let initial = lazy.get("A").await.unwrap();
        let repeated = lazy.get("A").await.unwrap();

        assert_eq!(initial, repeated);
    }

    /// Two instances sharing one store can load different layers concurrently.
    #[tokio::test]
    async fn two_instances_shared_source_join() {
        let fixture = ZarrTestBuilder::new()
            .dimensions(1, 4, 4)
            .chunks(1, 2, 2)
            .layer(LayerConfig::constant("A", 1.0))
            .layer(LayerConfig::constant("B", 2.0))
            .build()
            .unwrap();
        let store = samples::async_storage_for(fixture.path());
        let region = whole_grid();
        let first_view = AsyncLazySubset::<f32>::new(Arc::clone(&store), region.clone());
        let second_view = AsyncLazySubset::<f32>::new(Arc::clone(&store), region);

        let (layer_a, layer_b) = tokio::join!(first_view.get("A"), second_view.get("B"));

        assert!(layer_a.unwrap().iter().all(|&v| v == 1.0));
        assert!(layer_b.unwrap().iter().all(|&v| v == 2.0));
    }

    /// Two instances loading the same layer concurrently see identical data.
    #[tokio::test]
    async fn two_instances_same_layer_join_are_identical() {
        let fixture = ZarrTestBuilder::new()
            .dimensions(1, 4, 4)
            .chunks(1, 2, 2)
            .layer(LayerConfig::sequential("A", 1))
            .build()
            .unwrap();
        let store = samples::async_storage_for(fixture.path());
        let region = whole_grid();
        let first_view = AsyncLazySubset::<f32>::new(Arc::clone(&store), region.clone());
        let second_view = AsyncLazySubset::<f32>::new(Arc::clone(&store), region);

        let (from_first, from_second) =
            tokio::join!(first_view.get("A"), second_view.get("A"));

        assert_eq!(from_first.unwrap(), from_second.unwrap());
    }

    /// Independently spawned tasks sharing one store all read the constant layer correctly.
    #[tokio::test(flavor = "multi_thread")]
    async fn many_instances_parallel_spawn_shared_source() {
        let fixture = ZarrTestBuilder::new()
            .dimensions(1, 4, 4)
            .chunks(1, 2, 2)
            .layer(LayerConfig::constant("A", 5.0))
            .build()
            .unwrap();
        let store = samples::async_storage_for(fixture.path());
        let region = whole_grid();

        let mut tasks = Vec::new();
        for _ in 0..4 {
            let task_store = Arc::clone(&store);
            let task_region = region.clone();
            tasks.push(tokio::spawn(async move {
                AsyncLazySubset::<f32>::new(task_store, task_region)
                    .get("A")
                    .await
            }));
        }

        for task in tasks {
            let layer = task.await.expect("task panicked").unwrap();
            assert_eq!(layer.shape(), &[1, 4, 4]);
            assert!(layer.iter().all(|&v| v == 5.0));
        }
    }

    /// Reads from spawned tasks match a sequential reference read.
    #[tokio::test(flavor = "multi_thread")]
    async fn parallel_results_match_sequential_reference() {
        let fixture = ZarrTestBuilder::new()
            .dimensions(1, 4, 4)
            .chunks(1, 2, 2)
            .layer(LayerConfig::sequential("A", 1))
            .build()
            .unwrap();
        let store = samples::async_storage_for(fixture.path());
        let region = whole_grid();
        let expected = AsyncLazySubset::<f32>::new(Arc::clone(&store), region.clone())
            .get("A")
            .await
            .unwrap();

        let mut tasks = Vec::new();
        for _ in 0..4 {
            let task_store = Arc::clone(&store);
            let task_region = region.clone();
            tasks.push(tokio::spawn(async move {
                AsyncLazySubset::<f32>::new(task_store, task_region)
                    .get("A")
                    .await
            }));
        }

        for task in tasks {
            let layer = task.await.expect("task panicked").unwrap();
            assert_eq!(layer, expected);
        }
    }
}