// nuts_rs/sampler_stats.rs

1use std::collections::HashMap;
2
3use nuts_storable::{HasDims, Storable, Value};
4
5use crate::Math;
6
7#[derive(Clone)]
8pub struct StatsDims {
9    n_dim: u64,
10    coord: Option<Value>,
11}
12
13impl HasDims for StatsDims {
14    fn dim_sizes(&self) -> std::collections::HashMap<String, u64> {
15        std::collections::HashMap::from([("unconstrained_parameter".to_string(), self.n_dim)])
16    }
17
18    fn coords(&self) -> HashMap<String, Value> {
19        if let Some(coord) = &self.coord {
20            return HashMap::from([("unconstrained_parameter".to_string(), coord.clone())]);
21        }
22        HashMap::new()
23    }
24}
25
26impl<M: Math> From<&M> for StatsDims {
27    fn from(math: &M) -> Self {
28        StatsDims {
29            n_dim: math.dim() as u64,
30            coord: math.vector_coord(),
31        }
32    }
33}
34
35pub trait SamplerStats<M: Math> {
36    type Stats: Storable<StatsDims>;
37    type StatsOptions: Copy + Send + Sync;
38
39    fn extract_stats(&self, math: &mut M, opt: Self::StatsOptions) -> Self::Stats;
40}