use std::sync::Arc;
use ndarray::{Array2, Array3};
#[cfg(test)]
use object_store::local::LocalFileSystem;
use rand::RngExt;
use tempfile::TempDir;
use zarrs::array::{ArrayBuilder, DataType, FillValue};
use zarrs::array_subset::ArraySubset;
use zarrs::filesystem::FilesystemStore;
use zarrs::group::GroupBuilder;
#[cfg(test)]
use zarrs::storage::AsyncReadableListableStorage;
use zarrs::storage::ReadableWritableListableStorage;
#[cfg(test)]
use zarrs_object_store::AsyncObjectStore;
/// Describes how the cells of a generated 2-D layer are populated.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub(crate) enum FillStrategy {
    /// Every cell holds the given value.
    Constant(f32),
    /// Cells hold `1.0, 2.0, 3.0, ...` in row-major order.
    Sequential,
    /// Cells are sampled uniformly from the inclusive range `[min, max]` (fields in that order).
    Random(f32, f32),
    /// Cell `(i, j)` holds `f(i, j)`, where `i` is the row index and `j` the column index.
    Custom(fn(u64, u64) -> f32),
    /// Cells are copied verbatim, in row-major order; the length must equal the cell count.
    Values(Vec<f32>),
}
/// A named layer together with the strategy used to fill its cells.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub(crate) struct LayerConfig {
    /// Array name; the builder writes the layer at `/<name>`.
    name: String,
    /// How the layer's cells are populated.
    fill: FillStrategy,
}
#[allow(dead_code)]
impl LayerConfig {
    /// Creates a layer named `name` whose cells are populated according to `fill`.
    pub(crate) fn new(name: impl Into<String>, fill: FillStrategy) -> Self {
        Self {
            name: name.into(),
            fill,
        }
    }

    /// Creates a layer in which every cell equals `value`.
    pub(crate) fn constant(name: impl Into<String>, value: f32) -> Self {
        Self::new(name, FillStrategy::Constant(value))
    }

    /// Creates a layer holding `1.0, 2.0, 3.0, ...` in row-major order.
    pub(crate) fn sequential(name: impl Into<String>) -> Self {
        Self::new(name, FillStrategy::Sequential)
    }

    /// Creates a layer sampled uniformly at random from the inclusive range `[min, max]`.
    pub(crate) fn random(name: impl Into<String>, min: f32, max: f32) -> Self {
        Self::new(name, FillStrategy::Random(min, max))
    }

    /// Creates a layer whose cell `(i, j)` (row `i`, column `j`) equals `fill_fn(i, j)`.
    pub(crate) fn custom(name: impl Into<String>, fill_fn: fn(u64, u64) -> f32) -> Self {
        Self::new(name, FillStrategy::Custom(fill_fn))
    }

    /// Creates a layer from explicit row-major `values`.
    ///
    /// The vector's length must equal the number of cells in the array; otherwise building
    /// the store reports an error.
    pub(crate) fn values(name: impl Into<String>, values: Vec<f32>) -> Self {
        Self::new(name, FillStrategy::Values(values))
    }

    /// Creates a layer in which every cell is `1.0`.
    pub(crate) fn ones(name: impl Into<String>) -> Self {
        Self::constant(name, 1.0)
    }

    /// Creates a layer in which every cell is `0.0`.
    pub(crate) fn zeros(name: impl Into<String>) -> Self {
        Self::constant(name, 0.0)
    }
}
/// Builder for temporary Zarr stores made of a root group plus one 2-D `(ni, nj)` array per
/// configured layer.
///
/// Configure it through the chained setters, then call `build` to write the store.
pub(crate) struct ZarrTestBuilder {
    /// Number of rows (the `"y"` dimension).
    ni: u64,
    /// Number of columns (the `"x"` dimension).
    nj: u64,
    /// Chunk extent along the rows.
    ci: u64,
    /// Chunk extent along the columns.
    cj: u64,
    /// Layers to create, in insertion order; each becomes an array at `/<name>`.
    layers: Vec<LayerConfig>,
    /// Element data type recorded for every layer array.
    dtype: DataType,
    /// Fill value recorded for every layer array.
    fill_value: FillValue,
}
impl Default for ZarrTestBuilder {
fn default() -> Self {
Self::new()
}
}
#[allow(dead_code)]
impl ZarrTestBuilder {
    /// Creates a builder for an 8x8 array chunked 4x4, with `Float32` elements, a NaN fill
    /// value and no layers.
    pub(crate) fn new() -> Self {
        Self {
            ni: 8,
            nj: 8,
            ci: 4,
            cj: 4,
            layers: Vec::new(),
            dtype: DataType::Float32,
            fill_value: FillValue::from(zarrs::array::ZARR_NAN_F32),
        }
    }

    /// Sets the array shape to `ni` rows by `nj` columns, leaving the chunk shape unchanged.
    pub(crate) fn dimensions(mut self, ni: u64, nj: u64) -> Self {
        self.ni = ni;
        self.nj = nj;
        self
    }

    /// Sets the chunk shape to `ci` rows by `cj` columns, leaving the array shape unchanged.
    pub(crate) fn chunks(mut self, ci: u64, cj: u64) -> Self {
        self.ci = ci;
        self.cj = cj;
        self
    }

    /// Sets both the array shape (`ni` x `nj`) and the chunk shape (`ci` x `cj`).
    pub(crate) fn shape(mut self, ni: u64, nj: u64, ci: u64, cj: u64) -> Self {
        self.ni = ni;
        self.nj = nj;
        self.ci = ci;
        self.cj = cj;
        self
    }

    /// Appends a single layer. Layers are written in insertion order.
    pub(crate) fn layer(mut self, layer: LayerConfig) -> Self {
        self.layers.push(layer);
        self
    }

    /// Appends every layer in `layers`, preserving their order.
    pub(crate) fn layers(mut self, layers: Vec<LayerConfig>) -> Self {
        self.layers.extend(layers);
        self
    }

    /// Overrides the element data type recorded for every layer.
    ///
    /// NOTE(review): layer data is always generated as `f32` and the fill value stays
    /// `ZARR_NAN_F32`, so a type other than `DataType::Float32` will presumably be rejected
    /// when `build` writes the layers. Confirm before relying on it.
    pub(crate) fn data_type(mut self, dtype: DataType) -> Self {
        self.dtype = dtype;
        self
    }

    /// Writes the configured store into a fresh temporary directory and returns it.
    ///
    /// The root group metadata is written first. Then one 2-D `("y", "x")` array is written
    /// per configured layer, at `/<name>`. The store lives inside the returned `TempDir`.
    ///
    /// # Errors
    ///
    /// Returns an error if layers are configured with a zero chunk extent, or with an array
    /// extent that is not a multiple of its chunk extent. It also returns an error if a
    /// layer's fill strategy is invalid (see `generate_data`), or if creating or writing the
    /// store fails.
    pub(crate) fn build(self) -> Result<TempDir, Box<dyn std::error::Error>> {
        // Only layer writes depend on the chunk geometry, so a layer-less build stays valid
        // for any shape.
        if !self.layers.is_empty() {
            self.validate_shape()?;
        }
        let tmp_path = TempDir::new()?;
        let store: ReadableWritableListableStorage =
            Arc::new(FilesystemStore::new(tmp_path.path())?);
        GroupBuilder::new()
            .build(store.clone(), "/")?
            .store_metadata()?;
        for layer_config in &self.layers {
            self.create_layer(&store, layer_config)?;
        }
        Ok(tmp_path)
    }

    /// Creates the array `/<name>` described by `config` and writes its metadata. It then
    /// stores the generated data across every chunk.
    fn create_layer(
        &self,
        store: &ReadableWritableListableStorage,
        config: &LayerConfig,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let array = ArrayBuilder::new(
            vec![self.ni, self.nj],
            vec![self.ci, self.cj],
            self.dtype.clone(),
            self.fill_value.clone(),
        )
        .dimension_names(["y", "x"].into())
        .build(store.clone(), &format!("/{}", config.name))?;
        array.store_metadata()?;
        let data = self.generate_data(&config.fill)?;
        // Chunk-index ranges covering the whole array. `validate_shape` guarantees these
        // divisions are exact, so the region matches the `(ni, nj)` data.
        let subset =
            ArraySubset::new_with_ranges(&[0..(self.ni / self.ci), 0..(self.nj / self.cj)]);
        array.store_chunks_ndarray(&subset, data)?;
        Ok(())
    }

    /// Rejects chunk extents that are zero or that do not evenly divide the array extents.
    ///
    /// Layers are written as a whole number of chunks. Any remainder would make the written
    /// chunk region smaller than the generated data.
    fn validate_shape(&self) -> Result<(), Box<dyn std::error::Error>> {
        for (axis, extent, chunk) in [("i", self.ni, self.ci), ("j", self.nj, self.cj)] {
            if chunk == 0 {
                return Err(format!("chunk size along {} must be non-zero", axis).into());
            }
            if extent % chunk != 0 {
                return Err(format!(
                    "array size {} along {} is not a multiple of chunk size {}",
                    extent, axis, chunk
                )
                .into());
            }
        }
        Ok(())
    }

    /// Produces the `(ni, nj)` row-major data for one layer according to `fill`.
    ///
    /// # Errors
    ///
    /// Returns an error if a `Random` range is empty, NaN-bounded or not finite. It also
    /// returns an error if a `Values` vector's length differs from `ni * nj`.
    fn generate_data(
        &self,
        fill: &FillStrategy,
    ) -> Result<Array2<f32>, Box<dyn std::error::Error>> {
        let size = (self.ni * self.nj) as usize;
        let values = match fill {
            FillStrategy::Constant(val) => vec![*val; size],
            FillStrategy::Sequential => (1..=size).map(|x| x as f32).collect(),
            FillStrategy::Random(min, max) => {
                // `random_range` panics on an empty or non-finite range; report it instead.
                // A NaN or infinite span (from NaN/infinite bounds) is not finite, and a
                // negative span means `min > max`.
                let span = max - min;
                if !span.is_finite() || span < 0.0 {
                    return Err(format!(
                        "Random fill range {}..={} is empty or not finite",
                        min, max
                    )
                    .into());
                }
                let mut rng = rand::rng();
                (0..size).map(|_| rng.random_range(*min..=*max)).collect()
            }
            FillStrategy::Custom(func) => {
                // Outer loop over rows, inner over columns: row-major order.
                let mut values = Vec::with_capacity(size);
                for i in 0..self.ni {
                    for j in 0..self.nj {
                        values.push(func(i, j));
                    }
                }
                values
            }
            FillStrategy::Values(vals) => {
                if vals.len() != size {
                    return Err(format!(
                        "Values vector length {} doesn't match array size {}",
                        vals.len(),
                        size
                    )
                    .into());
                }
                vals.clone()
            }
        };
        let data = Array2::from_shape_vec((self.ni as usize, self.nj as usize), values)?;
        Ok(data)
    }
}
/// Creates a temporary store with a single `cost` layer of ones, `ni` x `nj` cells chunked
/// `ci` x `cj`.
///
/// # Panics
///
/// Panics if the store cannot be built.
pub(crate) fn uniform_cost_zarr(ni: u64, nj: u64, ci: u64, cj: u64) -> TempDir {
    let builder = ZarrTestBuilder::new()
        .dimensions(ni, nj)
        .chunks(ci, cj)
        .layer(LayerConfig::ones("cost"));
    builder.build().expect("Failed to create uniform cost zarr")
}
/// Creates a temporary store with layers `A`, `B` and `C`, each filled with ones, of
/// `ni` x `nj` cells chunked `ci` x `cj`.
///
/// # Panics
///
/// Panics if the store cannot be built.
pub(crate) fn three_layer_ones(ni: u64, nj: u64, ci: u64, cj: u64) -> TempDir {
    let all_ones = vec![
        LayerConfig::ones("A"),
        LayerConfig::ones("B"),
        LayerConfig::ones("C"),
    ];
    ZarrTestBuilder::new()
        .shape(ni, nj, ci, cj)
        .layers(all_ones)
        .build()
        .expect("Failed to create three-layer zarr")
}
/// Creates a temporary store with one layer per name in `layers`. Each layer is filled with
/// uniform random values in `[0.0, 1.0]` and has `ni` x `nj` cells chunked `ci` x `cj`.
///
/// # Panics
///
/// Panics if the store cannot be built.
#[allow(dead_code)]
pub(crate) fn multi_variable_random(
    ni: u64,
    nj: u64,
    ci: u64,
    cj: u64,
    layers: &[&str],
) -> TempDir {
    let random_layers = layers
        .iter()
        .map(|&name| LayerConfig::random(name, 0.0, 1.0))
        .collect();
    ZarrTestBuilder::new()
        .shape(ni, nj, ci, cj)
        .layers(random_layers)
        .build()
        .expect("Failed to create multi-variable zarr")
}
/// Creates a temporary store with one sequentially-filled layer (`1.0, 2.0, ...` row-major)
/// per name in `layers`, each of `ni` x `nj` cells chunked `ci` x `cj`.
///
/// # Panics
///
/// Panics if the store cannot be built.
#[allow(dead_code)]
pub(crate) fn sequential_layers(ni: u64, nj: u64, ci: u64, cj: u64, layers: &[&str]) -> TempDir {
    let configs = layers
        .iter()
        .map(|&name| LayerConfig::sequential(name))
        .collect();
    ZarrTestBuilder::new()
        .shape(ni, nj, ci, cj)
        .layers(configs)
        .build()
        .expect("Failed to create sequential zarr")
}
/// Returns a builder for a 4x4 array chunked 2x2, with no layers yet.
pub(crate) fn preset_small() -> ZarrTestBuilder {
    ZarrTestBuilder::new().shape(4, 4, 2, 2)
}
/// Returns a builder for a 16x16 array chunked 4x4, with no layers yet.
#[allow(dead_code)]
pub(crate) fn preset_medium() -> ZarrTestBuilder {
    ZarrTestBuilder::new().shape(16, 16, 4, 4)
}
/// Returns a builder for a 128x128 array chunked 32x32, with no layers yet.
#[allow(dead_code)]
pub(crate) fn preset_large() -> ZarrTestBuilder {
    ZarrTestBuilder::new().shape(128, 128, 32, 32)
}
/// Returns a builder preloaded with three cost-surface layers, keeping the default shape:
/// `A` (sequential values), `B` (constant `2.0`) and `C` (ones).
pub(crate) fn preset_cost_surface() -> ZarrTestBuilder {
    let surface = vec![
        LayerConfig::sequential("A"),
        LayerConfig::constant("B", 2.0),
        LayerConfig::ones("C"),
    ];
    ZarrTestBuilder::new().layers(surface)
}
pub(crate) fn multi_variable_zarr() -> TempDir {
let ni = 8;
let nj = 8;
let ci = 4;
let cj = 4;
let tmp_path = TempDir::new().unwrap();
let store: ReadableWritableListableStorage = std::sync::Arc::new(
zarrs::filesystem::FilesystemStore::new(tmp_path.path())
.expect("could not open filesystem store"),
);
zarrs::group::GroupBuilder::new()
.build(store.clone(), "/")
.unwrap()
.store_metadata()
.unwrap();
for array_path in ["/A", "/B", "/C", "/cost"] {
let array = zarrs::array::ArrayBuilder::new(
vec![1, ni, nj], vec![1, ci, cj], zarrs::array::DataType::Float32,
zarrs::array::FillValue::from(zarrs::array::ZARR_NAN_F32),
)
.dimension_names(["band", "y", "x"].into())
.build(store.clone(), array_path)
.unwrap();
array.store_metadata().unwrap();
let mut rng = rand::rng();
let mut a = vec![];
for _x in 0..(ni * nj) {
a.push(rng.random_range(0.0..=1.0));
}
let data: Array3<f32> =
ndarray::Array::from_shape_vec((1, ni.try_into().unwrap(), nj.try_into().unwrap()), a)
.unwrap();
array
.store_chunks_ndarray(
&zarrs::array_subset::ArraySubset::new_with_ranges(&[
0..1,
0..(ni / ci),
0..(nj / cj),
]),
data,
)
.unwrap();
}
tmp_path
}
pub(crate) fn constant_value_cost_zarr(cost_value: f32) -> TempDir {
let (ni, nj) = (8, 8);
let (ci, cj) = (4, 4);
let tmp_path = TempDir::new().unwrap();
let store: zarrs::storage::ReadableWritableListableStorage = std::sync::Arc::new(
zarrs::filesystem::FilesystemStore::new(tmp_path.path())
.expect("could not open filesystem store"),
);
zarrs::group::GroupBuilder::new()
.build(store.clone(), "/")
.unwrap()
.store_metadata()
.unwrap();
let array = zarrs::array::ArrayBuilder::new(
vec![1, ni, nj], vec![1, ci, cj], zarrs::array::DataType::Float32,
zarrs::array::FillValue::from(zarrs::array::ZARR_NAN_F32),
)
.dimension_names(["band", "y", "x"].into())
.build(store.clone(), "/cost")
.unwrap();
array.store_metadata().unwrap();
let (uni, unj): (usize, usize) = (ni.try_into().unwrap(), nj.try_into().unwrap());
let data: Array3<f32> =
ndarray::Array::from_shape_vec((1, uni, unj), vec![cost_value; uni * unj]).unwrap();
array
.store_chunks_ndarray(
&zarrs::array_subset::ArraySubset::new_with_ranges(&[0..1, 0..(ni / ci), 0..(nj / cj)]),
data,
)
.unwrap();
tmp_path
}
pub(crate) fn cost_as_index_zarr((ni, nj): (u64, u64), (ci, cj): (u64, u64)) -> TempDir {
let tmp_path = TempDir::new().unwrap();
let store: zarrs::storage::ReadableWritableListableStorage = std::sync::Arc::new(
zarrs::filesystem::FilesystemStore::new(tmp_path.path())
.expect("could not open filesystem store"),
);
zarrs::group::GroupBuilder::new()
.build(store.clone(), "/")
.unwrap()
.store_metadata()
.unwrap();
let array = zarrs::array::ArrayBuilder::new(
vec![1, ni, nj], vec![1, ci, cj], zarrs::array::DataType::Float32,
zarrs::array::FillValue::from(zarrs::array::ZARR_NAN_F32),
)
.dimension_names(["band", "y", "x"].into())
.build(store.clone(), "/cost")
.unwrap();
array.store_metadata().unwrap();
let a: Vec<f32> = (0..ni * nj).map(|x| x as f32).collect();
let data: Array3<f32> =
ndarray::Array::from_shape_vec((1, ni.try_into().unwrap(), nj.try_into().unwrap()), a)
.unwrap();
array
.store_chunks_ndarray(
&zarrs::array_subset::ArraySubset::new_with_ranges(&[0..1, 0..(ni / ci), 0..(nj / cj)]),
data,
)
.unwrap();
tmp_path
}
pub(crate) fn specific_layers_zarr(
(ni, nj): (u64, u64),
(ci, cj): (u64, u64),
friction_layer_weight: f32,
invariant_layer_cost: f32,
) -> TempDir {
let tmp_path = TempDir::new().unwrap();
let store: zarrs::storage::ReadableWritableListableStorage = std::sync::Arc::new(
zarrs::filesystem::FilesystemStore::new(tmp_path.path())
.expect("could not open filesystem store"),
);
zarrs::group::GroupBuilder::new()
.build(store.clone(), "/")
.unwrap()
.store_metadata()
.unwrap();
let a_vals: Vec<f32> = (1..=(ni * nj)).map(|x| x as f32).collect();
let a_data: Array3<f32> =
ndarray::Array::from_shape_vec((1, ni.try_into().unwrap(), nj.try_into().unwrap()), a_vals)
.unwrap();
let b_vals: Vec<f32> = vec![friction_layer_weight; ni as usize * nj as usize];
let b_data: Array3<f32> =
ndarray::Array::from_shape_vec((1, ni.try_into().unwrap(), nj.try_into().unwrap()), b_vals)
.unwrap();
let c_vals: Vec<f32> = vec![invariant_layer_cost; ni as usize * nj as usize];
let c_data: Array3<f32> =
ndarray::Array::from_shape_vec((1, ni.try_into().unwrap(), nj.try_into().unwrap()), c_vals)
.unwrap();
for (path, data) in [("/A", a_data), ("/B", b_data), ("/C", c_data)] {
let array = zarrs::array::ArrayBuilder::new(
vec![1, ni, nj], vec![1, ci, cj], zarrs::array::DataType::Float32,
zarrs::array::FillValue::from(zarrs::array::ZARR_NAN_F32),
)
.dimension_names(["band", "y", "x"].into())
.build(store.clone(), path)
.unwrap();
array.store_metadata().unwrap();
array
.store_chunks_ndarray(
&zarrs::array_subset::ArraySubset::new_with_ranges(&[0..1, 0..1, 0..1]),
data,
)
.unwrap();
}
tmp_path
}
/// Wraps the local directory at `path` as async, listable Zarr storage backed by
/// `object_store`'s `LocalFileSystem`.
///
/// # Panics
///
/// Panics if the local filesystem store cannot be opened at `path`.
#[cfg(test)]
pub(crate) fn async_storage_for(path: &std::path::Path) -> AsyncReadableListableStorage {
    let local_fs = LocalFileSystem::new_with_prefix(path)
        .expect("could not open local filesystem store");
    Arc::new(AsyncObjectStore::new(local_fs))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens the filesystem store rooted at `root` and asserts that each array named in
    /// `layers` can be opened.
    fn assert_layers_open(root: &std::path::Path, layers: &[&str]) {
        let store = Arc::new(FilesystemStore::new(root).unwrap());
        for layer in layers {
            let opened = zarrs::array::Array::open(store.clone(), &format!("/{}", layer));
            assert!(opened.is_ok(), "Layer {} should exist", layer);
        }
    }

    #[test]
    fn test_builder_basic() {
        let dir = ZarrTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::ones("test"))
            .build()
            .unwrap();
        assert!(dir.path().exists());
    }

    #[test]
    fn test_builder_multiple_layers() {
        let dir = ZarrTestBuilder::new()
            .dimensions(8, 8)
            .chunks(4, 4)
            .layer(LayerConfig::ones("A"))
            .layer(LayerConfig::sequential("B"))
            .layer(LayerConfig::constant("C", 5.0))
            .build()
            .unwrap();
        let root = dir.path();
        assert!(root.exists());
        assert_layers_open(root, &["A", "B", "C"]);
    }

    #[test]
    fn test_custom_fill() {
        let dir = ZarrTestBuilder::new()
            .dimensions(4, 4)
            .chunks(2, 2)
            .layer(LayerConfig::custom("custom", |i, j| (i * 10 + j) as f32))
            .build()
            .unwrap();
        assert!(dir.path().exists());
    }

    #[test]
    fn test_uniform_cost_helper() {
        let dir = uniform_cost_zarr(4, 4, 2, 2);
        let root = dir.path();
        assert!(root.exists());
        let store = Arc::new(FilesystemStore::new(root).unwrap());
        assert!(zarrs::array::Array::open(store, "/cost").is_ok());
    }

    #[test]
    fn test_three_layer_ones_helper() {
        let dir = three_layer_ones(4, 4, 2, 2);
        let root = dir.path();
        assert!(root.exists());
        assert_layers_open(root, &["A", "B", "C"]);
    }

    #[test]
    fn test_preset_small() {
        let dir = preset_small()
            .layer(LayerConfig::ones("test"))
            .build()
            .unwrap();
        assert!(dir.path().exists());
    }

    #[test]
    fn test_preset_cost_surface() {
        let dir = preset_cost_surface()
            .dimensions(8, 8)
            .chunks(4, 4)
            .build()
            .unwrap();
        let root = dir.path();
        assert!(root.exists());
        assert_layers_open(root, &["A", "B", "C"]);
    }
}