use anyhow::Result;
use nuts_storable::{ItemType, Value};
use std::collections::HashMap;
use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
use crate::{Progress, Settings};
#[derive(Clone, Debug)]
pub enum HashMapValue {
F64(Vec<f64>),
F32(Vec<f32>),
Bool(Vec<bool>),
I64(Vec<i64>),
U64(Vec<u64>),
String(Vec<String>),
}
impl HashMapValue {
fn new(item_type: ItemType) -> Self {
match item_type {
ItemType::F64 => HashMapValue::F64(Vec::new()),
ItemType::F32 => HashMapValue::F32(Vec::new()),
ItemType::Bool => HashMapValue::Bool(Vec::new()),
ItemType::I64 => HashMapValue::I64(Vec::new()),
ItemType::U64 => HashMapValue::U64(Vec::new()),
ItemType::String => HashMapValue::String(Vec::new()),
ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => HashMapValue::I64(Vec::new()),
}
}
fn push(&mut self, value: Value) {
match (self, value) {
(HashMapValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
(HashMapValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
(HashMapValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
(HashMapValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
(HashMapValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
(HashMapValue::F64(vec), Value::F64(v)) => vec.extend(v),
(HashMapValue::F32(vec), Value::F32(v)) => vec.extend(v),
(HashMapValue::U64(vec), Value::U64(v)) => vec.extend(v),
(HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
(HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),
(HashMapValue::String(vec), Value::Strings(v)) => vec.extend(v),
(HashMapValue::String(vec), Value::ScalarString(v)) => vec.push(v),
(HashMapValue::I64(vec), Value::DateTime64(_, v)) => vec.extend(v),
(HashMapValue::I64(vec), Value::TimeDelta64(_, v)) => vec.extend(v),
_ => panic!("Mismatched item type"),
}
}
}
#[derive(Clone)]
pub struct HashMapTraceStorage {
draw_types: Vec<(String, ItemType)>,
param_types: Vec<(String, ItemType)>,
}
#[derive(Clone)]
pub struct HashMapChainStorage {
warmup_stats: HashMap<String, HashMapValue>,
sample_stats: HashMap<String, HashMapValue>,
warmup_draws: HashMap<String, HashMapValue>,
sample_draws: HashMap<String, HashMapValue>,
last_sample_was_warmup: bool,
}
#[derive(Debug, Clone)]
pub struct HashMapResult {
pub stats: HashMap<String, HashMapValue>,
pub draws: HashMap<String, HashMapValue>,
}
impl HashMapChainStorage {
fn new(param_types: &[(String, ItemType)], draw_types: &[(String, ItemType)]) -> Self {
let warmup_stats = param_types
.iter()
.cloned()
.map(|(name, item_type)| (name, HashMapValue::new(item_type)))
.collect();
let sample_stats = param_types
.iter()
.cloned()
.map(|(name, item_type)| (name, HashMapValue::new(item_type)))
.collect();
let warmup_draws = draw_types
.iter()
.cloned()
.map(|(name, item_type)| (name, HashMapValue::new(item_type)))
.collect();
let sample_draws = draw_types
.iter()
.cloned()
.map(|(name, item_type)| (name, HashMapValue::new(item_type)))
.collect();
Self {
warmup_stats,
sample_stats,
warmup_draws,
sample_draws,
last_sample_was_warmup: true,
}
}
fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
if ["draw", "chain"].contains(&name) {
return Ok(());
}
let target_map = if is_warmup {
&mut self.warmup_stats
} else {
&mut self.sample_stats
};
if let Some(hash_value) = target_map.get_mut(name) {
hash_value.push(value);
} else {
panic!("Unknown param name: {}", name);
}
Ok(())
}
fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
if ["draw", "chain"].contains(&name) {
return Ok(());
}
let target_map = if is_warmup {
&mut self.warmup_draws
} else {
&mut self.sample_draws
};
if let Some(hash_value) = target_map.get_mut(name) {
hash_value.push(value);
} else {
panic!("Unknown posterior variable name: {}", name);
}
Ok(())
}
}
impl ChainStorage for HashMapChainStorage {
type Finalized = HashMapResult;
fn record_sample(
&mut self,
_settings: &impl Settings,
stats: Vec<(&str, Option<Value>)>,
draws: Vec<(&str, Option<Value>)>,
info: &Progress,
) -> Result<()> {
let is_first_draw = self.last_sample_was_warmup && !info.tuning;
if is_first_draw {
self.last_sample_was_warmup = false;
}
for (name, value) in stats {
if let Some(value) = value {
self.push_param(name, value, info.tuning)?;
}
}
for (name, value) in draws {
if let Some(value) = value {
self.push_draw(name, value, info.tuning)?;
} else {
panic!("Missing draw value for {}", name);
}
}
Ok(())
}
fn finalize(self) -> Result<Self::Finalized> {
let mut combined_stats = HashMap::new();
let mut combined_draws = HashMap::new();
for (key, warmup_values) in self.warmup_stats {
let sample_values = &self.sample_stats[&key];
let mut combined = warmup_values.clone();
match (&mut combined, sample_values) {
(HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
(HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
(HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
(HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
(HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
_ => panic!("Type mismatch when combining stats for {}", key),
}
combined_stats.insert(key, combined);
}
for (key, warmup_values) in self.warmup_draws {
let sample_values = &self.sample_draws[&key];
let mut combined = warmup_values.clone();
match (&mut combined, sample_values) {
(HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
(HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
(HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
(HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
(HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
combined_vec.extend(sample_vec.iter().cloned());
}
_ => panic!("Type mismatch when combining draws for {}", key),
}
combined_draws.insert(key, combined);
}
Ok(HashMapResult {
stats: combined_stats,
draws: combined_draws,
})
}
fn flush(&self) -> Result<()> {
Ok(())
}
fn inspect(&self) -> Result<Option<Self::Finalized>> {
self.clone().finalize().map(Some)
}
}
pub struct HashMapConfig {}
impl Default for HashMapConfig {
fn default() -> Self {
Self::new()
}
}
impl HashMapConfig {
pub fn new() -> Self {
Self {}
}
}
impl StorageConfig for HashMapConfig {
type Storage = HashMapTraceStorage;
fn new_trace<M: crate::Math>(
self,
settings: &impl Settings,
math: &M,
) -> Result<Self::Storage> {
Ok(HashMapTraceStorage {
param_types: settings.stat_types(math),
draw_types: settings.data_types(math),
})
}
}
impl TraceStorage for HashMapTraceStorage {
type ChainStorage = HashMapChainStorage;
type Finalized = Vec<HashMapResult>;
fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result<Self::ChainStorage> {
Ok(HashMapChainStorage::new(
&self.param_types,
&self.draw_types,
))
}
fn finalize(
self,
traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
let mut results = Vec::new();
let mut first_error = None;
for trace in traces {
match trace {
Ok(result) => results.push(result),
Err(err) => {
if first_error.is_none() {
first_error = Some(err);
}
}
}
}
Ok((first_error, results))
}
fn inspect(
&self,
traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
self.clone()
.finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect())
}
}