// nuts_rs/sampler_stats.rs

//! Defines how per-draw diagnostics are extracted and labelled for downstream storage.
2
3use std::collections::HashMap;
4
5use nuts_storable::{HasDims, Storable, Value};
6
7use crate::Math;
8
9#[derive(Clone)]
10pub struct StatsDims {
11    n_dim: u64,
12    coord: Option<Value>,
13}
14
15impl HasDims for StatsDims {
16    fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
17        std::collections::HashMap::from([("unconstrained_parameter".to_string(), self.n_dim)])
18    }
19
20    fn coords(&self) -> HashMap<String, Value> {
21        if let Some(coord) = &self.coord {
22            return HashMap::from([("unconstrained_parameter".to_string(), coord.clone())]);
23        }
24        HashMap::new()
25    }
26}
27
28impl<M: Math> From<&M> for StatsDims {
29    fn from(math: &M) -> Self {
30        StatsDims {
31            n_dim: math.dim() as u64,
32            coord: math.vector_coord(),
33        }
34    }
35}
36
37pub trait SamplerStats<M: Math> {
38    type Stats: Storable<StatsDims>;
39    type StatsOptions: Copy + Send + Sync;
40
41    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats;
42}