use std::mem::replace;
use std::sync::Arc;
use std::{collections::HashMap, num::NonZero};
use anyhow::{Context, Result};
use nuts_storable::{ItemType, Value};
use zarrs::array::codec::{BloscCodec, BloscCodecConfiguration, BloscCodecConfigurationV1};
use zarrs::array::{Array, ArrayBuilder, DataType, FillValue, data_type};
use zarrs::metadata_ext::data_type::NumpyTimeUnit;
/// Owned, homogeneously typed storage for buffered sample values.
///
/// Each variant wraps a `Vec` of one supported element type; the variant is
/// fixed when the buffer is created and never changes afterwards.
#[derive(Clone, Debug)]
pub enum SampleBufferValue {
    F64(Vec<f64>),
    F32(Vec<f32>),
    Bool(Vec<bool>),
    I64(Vec<i64>),
    U64(Vec<u64>),
    String(Vec<String>),
}
impl SampleBufferValue {
    /// Returns the number of stored elements, regardless of variant.
    pub fn len(&self) -> usize {
        match self {
            SampleBufferValue::F64(vec) => vec.len(),
            SampleBufferValue::F32(vec) => vec.len(),
            SampleBufferValue::Bool(vec) => vec.len(),
            SampleBufferValue::I64(vec) => vec.len(),
            SampleBufferValue::U64(vec) => vec.len(),
            SampleBufferValue::String(vec) => vec.len(),
        }
    }

    /// Returns `true` when no elements are stored.
    ///
    /// Added for parity with `len` (clippy: `len_without_is_empty`).
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Accumulates pushed sample values and hands them out in fixed-size chunks.
pub struct SampleBuffer {
    /// Values collected for the chunk currently being filled.
    pub items: SampleBufferValue,
    /// Number of pushes recorded for the current chunk.
    pub len: usize,
    /// Push count at which the current chunk is full (the chunk size).
    pub full_at: usize,
    /// Zero-based index of the chunk currently being filled.
    pub current_chunk: usize,
}
/// A batch of values taken out of a `SampleBuffer`; may be partially filled
/// when produced by `SampleBuffer::reset` or `copy_as_chunk`.
#[derive(Debug)]
pub struct Chunk {
    /// Zero-based index of this chunk within the buffer's chunk sequence.
    pub chunk_idx: usize,
    /// Number of pushes this chunk actually holds.
    pub len: usize,
    /// Push count at which a chunk counts as full (the chunk size).
    pub full_at: usize,
    /// The buffered values themselves.
    pub values: SampleBufferValue,
}
impl Chunk {
    /// Whether this chunk holds exactly as many pushes as its chunk size.
    pub fn is_full(&self) -> bool {
        self.len == self.full_at
    }
}
impl SampleBuffer {
    /// Creates an empty buffer for `item_type` that completes a chunk after
    /// `chunk_size` pushes.
    ///
    /// # Panics
    /// Panics if `chunk_size` does not fit into `usize`, or if `item_type` is
    /// `DateTime64`/`TimeDelta64`, which this buffer does not support.
    pub fn new(item_type: ItemType, chunk_size: u64) -> Self {
        let chunk_size = chunk_size.try_into().expect("Chunk size too large");
        let inner = match item_type {
            ItemType::F64 => SampleBufferValue::F64(Vec::with_capacity(chunk_size)),
            ItemType::F32 => SampleBufferValue::F32(Vec::with_capacity(chunk_size)),
            ItemType::U64 => SampleBufferValue::U64(Vec::with_capacity(chunk_size)),
            ItemType::Bool => SampleBufferValue::Bool(Vec::with_capacity(chunk_size)),
            ItemType::I64 => SampleBufferValue::I64(Vec::with_capacity(chunk_size)),
            ItemType::String => SampleBufferValue::String(Vec::with_capacity(chunk_size)),
            ItemType::DateTime64(_) => panic!("DateTime64 type not supported in SampleBuffer"),
            ItemType::TimeDelta64(_) => panic!("TimeDelta64 type not supported in SampleBuffer"),
        };
        Self {
            items: inner,
            len: 0,
            full_at: chunk_size,
            current_chunk: 0,
        }
    }

    /// Flushes any pending values as a final (possibly partial) chunk and
    /// rewinds the chunk counter to zero. Returns `None` if nothing was
    /// pending.
    pub fn reset(&mut self) -> Option<Chunk> {
        if self.len == 0 {
            self.current_chunk = 0;
            return None;
        }
        let out = self.finish_chunk();
        self.current_chunk = 0;
        Some(out)
    }

    /// Takes the currently buffered values out as a `Chunk` and advances the
    /// chunk index. The produced chunk is partial when invoked via `reset`.
    pub fn finish_chunk(&mut self) -> Chunk {
        // Give the replacement buffer a full chunk's capacity. Using
        // `vec.len()` here (as before) under-allocates after a partial,
        // reset-triggered chunk, forcing reallocations on later pushes.
        let cap = self.full_at;
        let values = match &mut self.items {
            SampleBufferValue::F64(vec) => {
                SampleBufferValue::F64(replace(vec, Vec::with_capacity(cap)))
            }
            SampleBufferValue::F32(vec) => {
                SampleBufferValue::F32(replace(vec, Vec::with_capacity(cap)))
            }
            SampleBufferValue::U64(vec) => {
                SampleBufferValue::U64(replace(vec, Vec::with_capacity(cap)))
            }
            SampleBufferValue::Bool(vec) => {
                SampleBufferValue::Bool(replace(vec, Vec::with_capacity(cap)))
            }
            SampleBufferValue::I64(vec) => {
                SampleBufferValue::I64(replace(vec, Vec::with_capacity(cap)))
            }
            SampleBufferValue::String(vec) => {
                SampleBufferValue::String(replace(vec, Vec::with_capacity(cap)))
            }
        };
        let output = Chunk {
            chunk_idx: self.current_chunk,
            len: self.len,
            values,
            full_at: self.full_at,
        };
        self.current_chunk += 1;
        self.len = 0;
        output
    }

    /// Returns a clone of the pending values as a `Chunk` without consuming
    /// them or advancing the chunk index. `None` if nothing is pending.
    pub fn copy_as_chunk(&self) -> Option<Chunk> {
        if self.len == 0 {
            return None;
        }
        Some(Chunk {
            chunk_idx: self.current_chunk,
            len: self.len,
            // The derived `Clone` on `SampleBufferValue` replaces the previous
            // hand-written per-variant clone match.
            values: self.items.clone(),
            full_at: self.full_at,
        })
    }

    /// Appends one draw's worth of data. Scalar values push a single element;
    /// vector values extend by all their elements — either way this counts as
    /// one push toward `full_at`. Returns the finished chunk when the buffer
    /// becomes full.
    ///
    /// # Panics
    /// Panics if the buffer is already full, or if the value's type does not
    /// match the buffer's item type.
    pub fn push(&mut self, item: Value) -> Option<Chunk> {
        assert!(self.len < self.full_at);
        match (&mut self.items, item) {
            (SampleBufferValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
            (SampleBufferValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
            (SampleBufferValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
            (SampleBufferValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
            (SampleBufferValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
            (SampleBufferValue::F64(vec), Value::F64(v)) => vec.extend(v),
            (SampleBufferValue::F32(vec), Value::F32(v)) => vec.extend(v),
            (SampleBufferValue::U64(vec), Value::U64(v)) => vec.extend(v),
            (SampleBufferValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
            (SampleBufferValue::I64(vec), Value::I64(v)) => vec.extend(v),
            (SampleBufferValue::String(vec), Value::ScalarString(s)) => vec.push(s),
            // Previously missing: vector-valued strings, for parity with the
            // `extend` arms of every other element type above.
            (SampleBufferValue::String(vec), Value::Strings(v)) => vec.extend(v),
            _ => panic!("Mismatched item type"),
        }
        self.len += 1;
        if self.len == self.full_at {
            Some(self.finish_chunk())
        } else {
            None
        }
    }

    /// Total number of pushes since the last reset: all finished chunks plus
    /// the current partial one. Assumes every finished chunk was full, which
    /// holds because partial chunks are only produced by `reset` (which zeroes
    /// the counter).
    pub fn total_pushed(&self) -> u64 {
        self.current_chunk as u64 * self.full_at as u64 + self.len as u64
    }
}
pub fn value_to_zarr_coord_params(coord: &Value) -> (DataType, usize, FillValue) {
match coord {
Value::F64(v) => (data_type::float64(), v.len(), FillValue::from(f64::NAN)),
Value::F32(v) => (data_type::float32(), v.len(), FillValue::from(f32::NAN)),
Value::U64(v) => (data_type::uint64(), v.len(), FillValue::from(0u64)),
Value::I64(v) => (data_type::int64(), v.len(), FillValue::from(0i64)),
Value::Bool(v) => (data_type::bool(), v.len(), FillValue::from(false)),
Value::Strings(v) => (data_type::string(), v.len(), FillValue::from("")),
Value::DateTime64(unit, v) => {
let unit = match unit {
nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second,
nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond,
nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond,
nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond,
};
let scale_factor = NonZero::new(1).unwrap();
(
data_type::numpy_datetime64(unit, scale_factor),
v.len(),
FillValue::from(i64::MIN),
)
}
Value::TimeDelta64(unit, v) => {
let unit = match unit {
nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second,
nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond,
nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond,
nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond,
};
let scale_factor = NonZero::new(1).unwrap();
(
data_type::numpy_timedelta64(unit, scale_factor),
v.len(),
FillValue::from(i64::MIN),
)
}
_ => panic!("Unsupported coordinate type"),
}
}
pub fn create_arrays<TStorage: ?Sized>(
store: Arc<TStorage>,
group_path: &str,
item_types: &Vec<(String, ItemType)>,
item_dims: &Vec<(String, String, Vec<String>)>,
n_chains: u64,
n_draws: u64,
dim_sizes: &HashMap<String, u64>,
draw_chunk_size: u64,
) -> Result<HashMap<String, Array<TStorage>>> {
let mut arrays = HashMap::new();
for ((name1, item_type), (name2, primary_dim, extra_dims)) in
item_types.iter().zip(item_dims.iter())
{
assert!(name1 == name2);
let name = name1;
if ["draw", "chain"].contains(&name.as_str()) {
continue;
}
let dims = std::iter::once("chain".to_string())
.chain(std::iter::once(primary_dim.clone()))
.chain(extra_dims.iter().cloned());
let extra_shape: Result<Vec<u64>> = extra_dims
.iter()
.map(|dim| {
dim_sizes
.get(dim)
.ok_or_else(|| {
anyhow::anyhow!("Unknown dimension size for dimension {}", dim)
.context(format!("Could not write {}/{}", group_path, name))
})
.copied()
})
.collect();
let extra_shape = extra_shape?;
let shape: Vec<u64> = std::iter::once(n_chains)
.chain(std::iter::once(n_draws))
.chain(extra_shape.clone())
.collect();
let zarr_type = match item_type {
ItemType::F64 => data_type::float64(),
ItemType::F32 => data_type::float32(),
ItemType::U64 => data_type::uint64(),
ItemType::I64 => data_type::int64(),
ItemType::Bool => data_type::bool(),
ItemType::String => data_type::string(),
ItemType::DateTime64(unit) => {
let unit = match unit {
nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second,
nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond,
nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond,
nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond,
};
let scale_factor = NonZero::new(1).unwrap();
data_type::numpy_datetime64(unit, scale_factor)
}
ItemType::TimeDelta64(unit) => {
let unit = match unit {
nuts_storable::DateTimeUnit::Seconds => NumpyTimeUnit::Second,
nuts_storable::DateTimeUnit::Milliseconds => NumpyTimeUnit::Millisecond,
nuts_storable::DateTimeUnit::Microseconds => NumpyTimeUnit::Microsecond,
nuts_storable::DateTimeUnit::Nanoseconds => NumpyTimeUnit::Nanosecond,
};
let scale_factor = NonZero::new(1).unwrap();
data_type::numpy_timedelta64(unit, scale_factor)
}
};
let fill_value = match item_type {
ItemType::F64 => FillValue::from(f64::NAN),
ItemType::F32 => FillValue::from(f32::NAN),
ItemType::U64 => FillValue::from(0u64),
ItemType::I64 => FillValue::from(0i64),
ItemType::Bool => FillValue::from(false),
ItemType::String => FillValue::from(""),
ItemType::DateTime64(_) => FillValue::new_optional_null(),
ItemType::TimeDelta64(_) => FillValue::new_optional_null(),
};
let grid: Vec<u64> = std::iter::once(1)
.chain(std::iter::once(draw_chunk_size))
.chain(extra_shape)
.map(|size| size.max(1))
.collect();
let codec = {
if let Some(typesize) = zarr_type.fixed_size() {
let config = BloscCodecConfiguration::V1(BloscCodecConfigurationV1 {
cname: zarrs::array::codec::BloscCompressor::Zstd,
clevel: 3u8.try_into().unwrap(),
shuffle: zarrs::array::codec::BloscShuffleMode::Shuffle,
blocksize: 0,
typesize: Some(typesize),
});
BloscCodec::new_with_configuration(&config)?
} else {
let config = BloscCodecConfiguration::V1(BloscCodecConfigurationV1 {
cname: zarrs::array::codec::BloscCompressor::Zstd,
clevel: 3u8.try_into().unwrap(),
shuffle: zarrs::array::codec::BloscShuffleMode::NoShuffle,
blocksize: 0,
typesize: None,
});
BloscCodec::new_with_configuration(&config)
.context("Failed to create Blosc codec")?
}
};
let array = ArrayBuilder::new(shape, grid, zarr_type, fill_value)
.bytes_to_bytes_codecs(vec![Arc::new(codec)])
.dimension_names(Some(dims))
.build(store.clone(), &format!("{}/{}", group_path, name))
.context("Failed to build Zarr array")?;
arrays.insert(name.to_string(), array);
}
Ok(arrays)
}