use std::collections::HashMap;
use std::iter::once;
use std::sync::Arc;
use anyhow::{Context, Result};
use nuts_storable::{ItemType, Value};
use zarrs::array::{ArrayBuilder, ArraySubset};
use zarrs::group::GroupBuilder;
use zarrs::storage::{ReadableWritableListableStorage, ReadableWritableListableStorageTraits};
use super::common::{Chunk, SampleBuffer, SampleBufferValue, value_to_zarr_coord_params};
use super::create_arrays;
use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
use crate::{Math, Progress, Settings};
/// Zarr array handle over a type-erased readable/writable/listable storage.
pub type Array = zarrs::array::Array<dyn ReadableWritableListableStorageTraits>;

/// Every Zarr array of one trace, keyed by variable name.
///
/// Warmup and post-warmup values are written to separate arrays, and sampler
/// statistics ("params") are kept apart from posterior draws.
struct ArrayCollection {
    // --- warmup phase ---
    /// Sampler statistics recorded while tuning.
    pub warmup_param_arrays: HashMap<String, Array>,
    /// Posterior draws recorded while tuning.
    pub warmup_draw_arrays: HashMap<String, Array>,
    // --- sampling phase ---
    /// Sampler statistics recorded after tuning.
    pub sample_param_arrays: HashMap<String, Array>,
    /// Posterior draws recorded after tuning.
    pub sample_draw_arrays: HashMap<String, Array>,
}
/// Write every coordinate in `coords` as a one-dimensional Zarr array at `{group}/{name}`.
///
/// Each array's single dimension is named after the coordinate itself. The whole
/// coordinate is stored as one chunk; empty coordinates only get metadata.
///
/// # Errors
/// Returns an error if an array cannot be created or its data/metadata cannot be written.
///
/// # Panics
/// Panics (`unreachable!`) for `Value` variants that are not handled below.
pub fn store_coords(
    store: ReadableWritableListableStorage,
    group: String,
    coords: &HashMap<String, Value>,
) -> Result<()> {
    for (name, coord) in coords {
        let (data_type, len, fill_value) = value_to_zarr_coord_params(coord);
        let len = len as u64;
        // Zarr requires every chunk-shape entry to be a positive integer, so an empty
        // coordinate still gets a (never written) chunk of length 1.
        let chunk_len = len.max(1);
        let path = format!("{}/{}", group, name);
        let coord_array = ArrayBuilder::new(vec![len], vec![chunk_len], data_type, fill_value)
            .dimension_names(Some(vec![name.to_string()]))
            .build(store.clone(), &path)
            .with_context(|| format!("Failed to create coordinate array {}", path))?;
        if len > 0 {
            // The whole coordinate fits in the single chunk at index 0.
            let chunk_indices = [0u64];
            match coord {
                Value::F64(v) => coord_array.store_chunk(&chunk_indices, v)?,
                Value::F32(v) => coord_array.store_chunk(&chunk_indices, v)?,
                Value::U64(v) => coord_array.store_chunk(&chunk_indices, v)?,
                Value::I64(v) => coord_array.store_chunk(&chunk_indices, v)?,
                Value::Bool(v) => coord_array.store_chunk(&chunk_indices, v)?,
                Value::Strings(v) => coord_array.store_chunk(&chunk_indices, v)?,
                Value::DateTime64(_, data) => coord_array.store_chunk(&chunk_indices, data)?,
                Value::TimeDelta64(_, data) => coord_array.store_chunk(&chunk_indices, data)?,
                _ => unreachable!(),
            }
        }
        coord_array
            .store_metadata()
            .with_context(|| format!("Failed to store metadata of coordinate array {}", path))?;
    }
    Ok(())
}
/// Trace-wide Zarr storage; hands out one [`ZarrChainStorage`] per chain.
pub struct ZarrTraceStorage {
    /// Arrays shared (via `Arc`) with every chain storage created from this trace.
    arrays: Arc<ArrayCollection>,
    /// Name and item type of every sampler statistic.
    param_types: Vec<(String, ItemType)>,
    /// Name and item type of every posterior variable.
    draw_types: Vec<(String, ItemType)>,
    /// Stat field name -> name of its event dimension (only stats that have one).
    event_dim_of_stat: HashMap<String, String>,
    /// Buffer length handed to each chain storage (the configured chunk size).
    draw_chunk_size: u64,
}
/// Per-chain writer that buffers draws and statistics and writes them to the
/// shared Zarr arrays chunk by chunk.
pub struct ZarrChainStorage {
    /// Index of this chain; used as the chunk index along the chain axis.
    chain: u64,
    /// Arrays of the whole trace, shared with the other chains.
    arrays: Arc<ArrayCollection>,
    /// One buffer per posterior variable.
    draw_buffers: HashMap<String, SampleBuffer>,
    /// One buffer per sampler statistic.
    stats_buffers: HashMap<String, SampleBuffer>,
    /// `true` until the first non-tuning sample is recorded; selects whether buffers
    /// are written to the warmup or the sample arrays.
    last_sample_was_warmup: bool,
    /// Stat field name -> event dimension name.
    event_dim_of_stat: HashMap<String, String>,
    /// Event dimension -> number of events counted during warmup, captured when
    /// the chain switches from tuning to sampling.
    warmup_event_counts: HashMap<String, u64>,
}
/// Write one buffered chunk of a single chain into `array`.
///
/// Axis 0 of `array` is indexed by `chain_chunk_index` and axis 1 by the buffer's
/// `chunk_idx`. All remaining axes are written in full (chunk index 0).
///
/// Write strategy:
/// - Empty chunks are skipped.
/// - Full chunks are written with `store_chunk`.
/// - Partially filled chunks are written as a subset starting at the chunk origin.
/// - String values always go through the array-subset API.
///
/// # Errors
/// Returns an error, annotated with the array path, if the write fails.
///
/// # Panics
/// Panics if `array` has fewer than two dimensions. Also panics if a partial chunk's
/// element count does not match the subset it is written to.
fn store_zarr_chunk(array: &Array, data: Chunk, chain_chunk_index: u64) -> Result<()> {
    let rank = array.chunk_grid().dimensionality();
    assert!(rank >= 2, "trace arrays need at least a chain and a draw axis");
    if data.values.len() == 0 {
        return Ok(());
    }
    let chunk_idx = data.chunk_idx;
    let len = data.len;

    if let SampleBufferValue::String(v) = data.values {
        // Cover `len` draws of this chain, starting at the chunk's first draw.
        // Trailing axes are written in full, mirroring the numeric partial-chunk path below.
        let start: Vec<u64> = once(chain_chunk_index)
            .chain(once(chunk_idx as u64 * data.full_at as u64))
            .chain(std::iter::repeat(0).take(rank - 2))
            .collect();
        let shape: Vec<u64> = once(1)
            .chain(once(len as u64))
            .chain(array.shape().iter().skip(2).copied())
            .collect();
        let subset = ArraySubset::new_with_start_shape(start, shape)
            .context("Failed to build string chunk subset")?;
        return array
            .store_array_subset(&subset, &v)
            .with_context(|| format!("Failed to store string chunk for {}", array.path()));
    }

    let chunk_indices: Vec<u64> = once(chain_chunk_index)
        .chain(once(chunk_idx as u64))
        .chain(std::iter::repeat(0).take(rank - 2))
        .collect();

    let result = if data.is_full() {
        match data.values {
            SampleBufferValue::F64(v) => array.store_chunk(&chunk_indices, &v),
            SampleBufferValue::F32(v) => array.store_chunk(&chunk_indices, &v),
            SampleBufferValue::U64(v) => array.store_chunk(&chunk_indices, &v),
            SampleBufferValue::I64(v) => array.store_chunk(&chunk_indices, &v),
            SampleBufferValue::Bool(v) => array.store_chunk(&chunk_indices, &v),
            SampleBufferValue::String(_) => unreachable!(),
        }
    } else {
        // Only the first `len` draws of this chain's chunk are filled; write them as a
        // subset that starts at the chunk origin and spans the trailing axes in full.
        let mut subset_shape = array.shape().to_vec();
        subset_shape[0] = 1;
        subset_shape[1] = len as u64;
        let chunk_subset = ArraySubset::new_with_shape(subset_shape);
        let expected = chunk_subset.num_elements_usize();
        match data.values {
            SampleBufferValue::F64(v) => {
                assert_eq!(v.len(), expected);
                array.store_chunk_subset(&chunk_indices, &chunk_subset, &v)
            }
            SampleBufferValue::F32(v) => {
                assert_eq!(v.len(), expected);
                array.store_chunk_subset(&chunk_indices, &chunk_subset, &v)
            }
            SampleBufferValue::U64(v) => {
                assert_eq!(v.len(), expected);
                array.store_chunk_subset(&chunk_indices, &chunk_subset, &v)
            }
            SampleBufferValue::I64(v) => {
                assert_eq!(v.len(), expected);
                array.store_chunk_subset(&chunk_indices, &chunk_subset, &v)
            }
            SampleBufferValue::Bool(v) => {
                assert_eq!(v.len(), expected);
                array.store_chunk_subset(&chunk_indices, &chunk_subset, &v)
            }
            SampleBufferValue::String(_) => unreachable!(),
        }
    };
    // Format the message only on failure; this runs for every chunk written.
    result.with_context(|| {
        format!(
            "Failed to store chunk for variable {} at chunk {} with length {}",
            array.path(),
            chunk_idx,
            len
        )
    })
}
impl ZarrChainStorage {
fn new(
arrays: Arc<ArrayCollection>,
param_types: &Vec<(String, ItemType)>,
draw_types: &Vec<(String, ItemType)>,
buffer_size: u64,
chain: u64,
event_dim_of_stat: HashMap<String, String>,
) -> Self {
let draw_buffers = draw_types
.iter()
.map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
.collect();
let stats_buffers = param_types
.iter()
.map(|(name, item_type)| (name.clone(), SampleBuffer::new(*item_type, buffer_size)))
.collect();
Self {
draw_buffers,
stats_buffers,
arrays,
chain,
last_sample_was_warmup: true,
event_dim_of_stat,
warmup_event_counts: HashMap::new(),
}
}
fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
if ["draw", "chain"].contains(&name) {
return Ok(());
}
let Some(buffer) = self.stats_buffers.get_mut(name) else {
panic!("Unknown param name: {}", name);
};
if let Some(chunk) = buffer.push(value) {
let array = if is_warmup {
&self.arrays.warmup_param_arrays[name]
} else {
&self.arrays.sample_param_arrays[name]
};
store_zarr_chunk(array, chunk, self.chain)?;
}
Ok(())
}
fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
if ["draw", "chain"].contains(&name) {
return Ok(());
}
let Some(buffer) = self.draw_buffers.get_mut(name) else {
panic!("Unknown posterior variable name: {}", name);
};
if let Some(chunk) = buffer.push(value) {
let array = if is_warmup {
&self.arrays.warmup_draw_arrays[name]
} else {
&self.arrays.sample_draw_arrays[name]
};
store_zarr_chunk(array, chunk, self.chain)?;
}
Ok(())
}
}
impl ZarrChainStorage {
    /// Number of entries counted so far by the stats buffer of each event dimension.
    ///
    /// When several stat fields share an event dimension, the first field found with a
    /// buffer (in `HashMap` iteration order) is used. This presumes fields sharing a
    /// dimension are always pushed together — TODO confirm.
    fn buffered_event_counts(&self) -> HashMap<String, u64> {
        let mut counts: HashMap<String, u64> = HashMap::new();
        for (field, dim) in &self.event_dim_of_stat {
            if counts.contains_key(dim) {
                continue;
            }
            if let Some(buffer) = self.stats_buffers.get(field) {
                counts.insert(dim.clone(), buffer.total_pushed());
            }
        }
        counts
    }

    /// `(warmup, sample)` event counts for every event dimension, attributing the
    /// currently buffered count to the phase the chain is in.
    fn split_event_counts(&self) -> HashMap<String, (u64, u64)> {
        let buffered = self.buffered_event_counts();
        self.event_dim_of_stat
            .values()
            .map(|dim| {
                let current = buffered.get(dim).copied().unwrap_or(0);
                let pair = if self.last_sample_was_warmup {
                    // Still tuning: everything counted so far belongs to warmup.
                    (current, 0)
                } else {
                    let warmup = self.warmup_event_counts.get(dim).copied().unwrap_or(0);
                    (warmup, current)
                };
                (dim.clone(), pair)
            })
            .collect()
    }
}

impl ChainStorage for ZarrChainStorage {
    /// Per event dimension: `(warmup_count, sample_count)`.
    type Finalized = HashMap<String, (u64, u64)>;

    /// Buffer one sampler iteration.
    ///
    /// The first non-tuning sample triggers the warmup-to-sampling switch:
    /// - the warmup event counts are remembered;
    /// - every partially filled buffer is drained into the warmup arrays;
    /// - subsequent writes go to the sample arrays.
    ///
    /// Statistics whose value is `None` are skipped. A `None` draw value panics.
    fn record_sample(
        &mut self,
        _settings: &impl Settings,
        stats: Vec<(&str, Option<Value>)>,
        draws: Vec<(&str, Option<Value>)>,
        info: &Progress,
    ) -> Result<()> {
        let is_first_draw = self.last_sample_was_warmup && !info.tuning;
        if is_first_draw {
            // Capture warmup event counts before the buffers are reset below.
            let warmup_counts = self.buffered_event_counts();
            self.warmup_event_counts.extend(warmup_counts);
            for (key, buffer) in self.draw_buffers.iter_mut() {
                if let Some(chunk) = buffer.reset() {
                    store_zarr_chunk(&self.arrays.warmup_draw_arrays[key], chunk, self.chain)?;
                }
            }
            for (key, buffer) in self.stats_buffers.iter_mut() {
                if let Some(chunk) = buffer.reset() {
                    store_zarr_chunk(&self.arrays.warmup_param_arrays[key], chunk, self.chain)?;
                }
            }
            self.last_sample_was_warmup = false;
        }
        for (name, value) in stats {
            if let Some(value) = value {
                self.push_param(name, value, info.tuning)?;
            }
        }
        for (name, value) in draws {
            if let Some(value) = value {
                self.push_draw(name, value, info.tuning)?;
            } else {
                panic!("Missing draw value for {}", name);
            }
        }
        Ok(())
    }

    /// Write the remaining buffered values to the arrays of the current phase and return
    /// the `(warmup, sample)` event counts of this chain.
    ///
    /// A chain finalized while still tuning reports its buffered counts as warmup counts
    /// and a sample count of 0.
    fn finalize(self) -> Result<Self::Finalized> {
        // Snapshot the counts before `reset()` drains the buffers.
        let counts = self.split_event_counts();
        let (draw_arrays, param_arrays) = if self.last_sample_was_warmup {
            (&self.arrays.warmup_draw_arrays, &self.arrays.warmup_param_arrays)
        } else {
            (&self.arrays.sample_draw_arrays, &self.arrays.sample_param_arrays)
        };
        for (key, mut buffer) in self.draw_buffers {
            if let Some(chunk) = buffer.reset() {
                store_zarr_chunk(&draw_arrays[&key], chunk, self.chain)?;
            }
        }
        for (key, mut buffer) in self.stats_buffers {
            if let Some(chunk) = buffer.reset() {
                store_zarr_chunk(&param_arrays[&key], chunk, self.chain)?;
            }
        }
        Ok(counts)
    }

    /// Current `(warmup, sample)` event counts without writing anything.
    fn inspect(&self) -> Result<Option<Self::Finalized>> {
        Ok(Some(self.split_event_counts()))
    }

    /// Write a copy of every non-empty buffer to the arrays of the current phase.
    ///
    /// Buffers are only read here (`&self`), so later writes of the same chunk overwrite
    /// the partial data written now.
    fn flush(&self) -> Result<()> {
        let (draw_arrays, param_arrays) = if self.last_sample_was_warmup {
            (&self.arrays.warmup_draw_arrays, &self.arrays.warmup_param_arrays)
        } else {
            (&self.arrays.sample_draw_arrays, &self.arrays.sample_param_arrays)
        };
        for (key, buffer) in &self.draw_buffers {
            if let Some(temp_chunk) = buffer.copy_as_chunk() {
                store_zarr_chunk(&draw_arrays[key], temp_chunk, self.chain)?;
            }
        }
        for (key, buffer) in &self.stats_buffers {
            if let Some(temp_chunk) = buffer.copy_as_chunk() {
                store_zarr_chunk(&param_arrays[key], temp_chunk, self.chain)?;
            }
        }
        Ok(())
    }
}
/// Builder-style configuration for writing a trace into a Zarr store.
pub struct ZarrConfig {
    /// Destination storage.
    store: ReadableWritableListableStorage,
    /// Group below which the trace is written; `None` means the store root.
    group_path: Option<String>,
    /// Passed to `create_arrays` as the chunk size and used as the per-chain buffer
    /// length (default 100).
    draw_chunk_size: u64,
    /// Whether warmup draws should be stored (default `true`).
    /// NOTE(review): `new_trace` does not read this flag; warmup arrays appear to be
    /// written regardless — confirm.
    store_warmup: bool,
}
impl ZarrConfig {
    /// Configuration that writes into `store` at the root group, uses 100 draws per
    /// chunk, and has warmup storage enabled.
    pub fn new(store: ReadableWritableListableStorage) -> Self {
        let draw_chunk_size = 100;
        Self {
            store,
            group_path: None,
            store_warmup: true,
            draw_chunk_size,
        }
    }

    /// Replace the draw chunk size (also the per-chain buffer length).
    pub fn with_chunk_size(self, chunk_size: u64) -> Self {
        Self {
            draw_chunk_size: chunk_size,
            ..self
        }
    }

    /// Write the trace below `path` instead of the store root.
    pub fn with_group_path<S: Into<String>>(self, path: S) -> Self {
        Self {
            group_path: Some(path.into()),
            ..self
        }
    }

    /// Set the `store_warmup` flag.
    pub fn store_warmup(self, store: bool) -> Self {
        Self {
            store_warmup: store,
            ..self
        }
    }
}
impl StorageConfig for ZarrConfig {
    type Storage = ZarrTraceStorage;

    /// Create the group hierarchy, all (still empty) arrays, and the coordinate arrays
    /// for a new trace.
    ///
    /// Below the configured group, four child groups are created: `warmup_posterior`,
    /// `warmup_sample_stats`, `posterior` and `sample_stats`. The posterior groups
    /// receive the model coordinates; the stats groups receive the stat coordinates.
    /// The root group's attributes record:
    /// - the crate name and version;
    /// - the sampler and adaptation kinds;
    /// - the serialized settings.
    ///
    /// # Errors
    /// Returns an error if any group, array, or metadata write fails, or if the settings
    /// cannot be serialized.
    fn new_trace<M: Math>(self, settings: &impl Settings, math: &M) -> Result<Self::Storage> {
        let n_chains = settings.num_chains() as u64;
        let n_tune = settings.hint_num_tune() as u64;
        let n_draws = settings.hint_num_draws() as u64;
        let param_types = settings.stat_types(math);
        let draw_types = settings.data_types(math);
        let draw_dim_sizes = math.dim_sizes();
        let stat_dim_sizes = settings.stat_dim_sizes(math);
        let stat_event_dims_vec = settings.stat_event_dims(math);
        // Stats with an event dimension use it as their second axis instead of `draw`.
        let param_dims: Vec<(String, String, Vec<String>)> = settings
            .stat_dims_all(math)
            .into_iter()
            .zip(stat_event_dims_vec.iter())
            .map(|((name, extra), (_, ev))| {
                (name, ev.as_deref().unwrap_or("draw").to_string(), extra)
            })
            .collect();
        let draw_dims: Vec<(String, String, Vec<String>)> = settings
            .data_dims_all(math)
            .into_iter()
            .map(|(name, extra)| (name, "draw".to_string(), extra))
            .collect();
        let event_dim_of_stat: HashMap<String, String> = stat_event_dims_vec
            .iter()
            .filter_map(|(name, opt)| opt.as_ref().map(|d| (name.clone(), d.clone())))
            .collect();

        // Zarr node paths are absolute ("/..."), and only the root "/" may end in '/'.
        // `root_path` is the trace group itself; `prefix` is prepended to child names.
        let configured = self.group_path.unwrap_or_default();
        let base = configured.trim_matches('/');
        let (root_path, prefix) = if base.is_empty() {
            ("/".to_string(), "/".to_string())
        } else {
            (format!("/{base}"), format!("/{base}/"))
        };

        let store = self.store;
        let draw_chunk_size = self.draw_chunk_size;
        // NOTE(review): `self.store_warmup` is not consulted; warmup arrays are always
        // created and written. Honoring it also requires changes in `ZarrChainStorage`.

        let mut root = GroupBuilder::new().build(store.clone(), &root_path)?;
        let attrs = root.attributes_mut();
        attrs.insert(
            "sampler".to_string(),
            serde_json::Value::String(env!("CARGO_PKG_NAME").to_string()),
        );
        attrs.insert(
            "sampler_version".to_string(),
            serde_json::Value::String(env!("CARGO_PKG_VERSION").to_string()),
        );
        attrs.insert(
            "sampler_kind".to_string(),
            serde_json::Value::String(settings.sampler_name().to_string()),
        );
        attrs.insert(
            "adaptation_kind".to_string(),
            serde_json::Value::String(settings.adaptation_name().to_string()),
        );
        attrs.insert(
            "sampler_settings".to_string(),
            serde_json::to_value(settings).context("Could not serialize sampler settings")?,
        );
        root.store_metadata()?;

        for child in [
            "warmup_posterior",
            "warmup_sample_stats",
            "posterior",
            "sample_stats",
        ] {
            GroupBuilder::new()
                .build(store.clone(), &format!("{prefix}{child}"))?
                .store_metadata()?;
        }

        let warmup_param_arrays = create_arrays(
            store.clone(),
            &format!("{prefix}warmup_sample_stats"),
            &param_types,
            &param_dims,
            n_chains,
            n_tune,
            &stat_dim_sizes,
            draw_chunk_size,
        )?;
        for array in warmup_param_arrays.values() {
            array.store_metadata()?;
        }
        let sample_param_arrays = create_arrays(
            store.clone(),
            &format!("{prefix}sample_stats"),
            &param_types,
            &param_dims,
            n_chains,
            n_draws,
            &stat_dim_sizes,
            draw_chunk_size,
        )?;
        for array in sample_param_arrays.values() {
            array.store_metadata()?;
        }
        let warmup_draw_arrays = create_arrays(
            store.clone(),
            &format!("{prefix}warmup_posterior"),
            &draw_types,
            &draw_dims,
            n_chains,
            n_tune,
            &draw_dim_sizes,
            draw_chunk_size,
        )?;
        for array in warmup_draw_arrays.values() {
            array.store_metadata()?;
        }
        let sample_draw_arrays = create_arrays(
            store.clone(),
            &format!("{prefix}posterior"),
            &draw_types,
            &draw_dims,
            n_chains,
            n_draws,
            &draw_dim_sizes,
            draw_chunk_size,
        )?;
        for array in sample_draw_arrays.values() {
            array.store_metadata()?;
        }
        let trace_storage = ArrayCollection {
            warmup_param_arrays,
            sample_param_arrays,
            warmup_draw_arrays,
            sample_draw_arrays,
        };

        let draw_coords = math.coords();
        let stat_coords = settings.stat_coords(math);
        store_coords(store.clone(), format!("{prefix}posterior"), &draw_coords)?;
        store_coords(store.clone(), format!("{prefix}warmup_posterior"), &draw_coords)?;
        store_coords(store.clone(), format!("{prefix}sample_stats"), &stat_coords)?;
        store_coords(store.clone(), format!("{prefix}warmup_sample_stats"), &stat_coords)?;

        Ok(ZarrTraceStorage {
            arrays: Arc::new(trace_storage),
            param_types,
            draw_types,
            draw_chunk_size,
            event_dim_of_stat,
        })
    }
}
impl TraceStorage for ZarrTraceStorage {
type ChainStorage = ZarrChainStorage;
type Finalized = ();
fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage> {
Ok(ZarrChainStorage::new(
self.arrays.clone(),
&self.param_types,
&self.draw_types,
self.draw_chunk_size,
chain_id as _,
self.event_dim_of_stat.clone(),
))
}
fn finalize(
self,
traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
let mut warmup_counts: HashMap<String, Vec<u64>> = HashMap::new();
let mut sample_counts: HashMap<String, Vec<u64>> = HashMap::new();
for trace in traces {
match trace {
Err(e) => return Ok((Some(e), ())),
Ok(c) => {
for (dim, (w, s)) in c {
warmup_counts.entry(dim.clone()).or_default().push(w);
sample_counts.entry(dim.clone()).or_default().push(s);
}
}
}
}
let max_sample: HashMap<String, u64> = sample_counts
.iter()
.map(|(dim, counts)| (dim.clone(), *counts.iter().max().unwrap_or(&0)))
.collect();
let max_warmup: HashMap<String, u64> = warmup_counts
.iter()
.map(|(dim, counts)| (dim.clone(), *counts.iter().max().unwrap_or(&0)))
.collect();
let mut arrays = Arc::try_unwrap(self.arrays).unwrap_or_else(|_| {
panic!("ArrayCollection still has multiple references at finalize")
});
for (field_name, event_dim) in &self.event_dim_of_stat {
if let Some(&max) = max_sample.get(event_dim) {
if let Some(array) = arrays.sample_param_arrays.get_mut(field_name) {
let mut shape = array.shape().to_vec();
shape[1] = max;
array
.set_shape(shape)
.context("Failed to resize event array")?
.store_metadata()
.context("Failed to store resized array metadata")?;
}
}
if let Some(&max) = max_warmup.get(event_dim) {
if let Some(array) = arrays.warmup_param_arrays.get_mut(field_name) {
let mut shape = array.shape().to_vec();
shape[1] = max;
array
.set_shape(shape)
.context("Failed to resize warmup event array")?
.store_metadata()
.context("Failed to store resized warmup array metadata")?;
}
}
}
Ok((None, ()))
}
fn inspect(
&self,
traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
for trace in traces {
if let Err(err) = trace {
return Ok((Some(err), ()));
};
}
Ok((None, ()))
}
}