Skip to main content

nuts_storable/
lib.rs

//! Minimal serialisation abstractions that let the sampler emit typed values without depending on any specific serialisation format.
2
3use std::collections::HashMap;
4
/// Unit of the raw `i64` counts carried by [`Value::DateTime64`] and
/// [`Value::TimeDelta64`], and by the matching [`ItemType`] variants.
///
/// `Hash` is derived alongside `Eq` so units (and therefore [`ItemType`]s) can
/// be used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DateTimeUnit {
    Seconds,
    Milliseconds,
    Microseconds,
    Nanoseconds,
}

/// Element type of a storable field, as reported by [`Storable::item_type`].
///
/// The variant names match the element types of [`Value`]'s variants; the
/// time-based variants additionally record the [`DateTimeUnit`] of their
/// `i64` counts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ItemType {
    U64,
    I64,
    F64,
    F32,
    Bool,
    String,
    DateTime64(DateTimeUnit),
    TimeDelta64(DateTimeUnit),
}
24
25#[derive(Debug, Clone, PartialEq)]
26pub enum Value {
27    U64(Vec<u64>),
28    I64(Vec<i64>),
29    F64(Vec<f64>),
30    F32(Vec<f32>),
31    Bool(Vec<bool>),
32    ScalarString(String),
33    DateTime64(DateTimeUnit, Vec<i64>),
34    TimeDelta64(DateTimeUnit, Vec<i64>),
35    ScalarU64(u64),
36    ScalarI64(i64),
37    ScalarF64(f64),
38    ScalarF32(f32),
39    ScalarBool(bool),
40    Strings(Vec<String>),
41}
42
43impl From<Vec<u64>> for Value {
44    fn from(value: Vec<u64>) -> Self {
45        Value::U64(value)
46    }
47}
48impl From<Vec<i64>> for Value {
49    fn from(value: Vec<i64>) -> Self {
50        Value::I64(value)
51    }
52}
53impl From<Vec<f64>> for Value {
54    fn from(value: Vec<f64>) -> Self {
55        Value::F64(value)
56    }
57}
58impl From<Vec<f32>> for Value {
59    fn from(value: Vec<f32>) -> Self {
60        Value::F32(value)
61    }
62}
63impl From<Vec<bool>> for Value {
64    fn from(value: Vec<bool>) -> Self {
65        Value::Bool(value)
66    }
67}
68impl From<u64> for Value {
69    fn from(value: u64) -> Self {
70        Value::ScalarU64(value)
71    }
72}
73impl From<i64> for Value {
74    fn from(value: i64) -> Self {
75        Value::ScalarI64(value)
76    }
77}
78impl From<f64> for Value {
79    fn from(value: f64) -> Self {
80        Value::ScalarF64(value)
81    }
82}
83impl From<f32> for Value {
84    fn from(value: f32) -> Self {
85        Value::ScalarF32(value)
86    }
87}
88impl From<bool> for Value {
89    fn from(value: bool) -> Self {
90        Value::ScalarBool(value)
91    }
92}
93
94pub trait HasDims {
95    fn dim_sizes(&self) -> HashMap<String, u64>;
96    fn coords(&self) -> HashMap<String, Value> {
97        HashMap::new()
98    }
99}
100
101/// Trait for types whose fields can be progressively written to a trace backend.
102///
103/// Each field in a `Storable` struct has a *primary dimension* — the dimension
104/// along which one entry is appended per event. For most fields this is the draw
105/// dimension: one value is recorded per MCMC draw. Fields annotated with
106/// `#[storable(event = "name")]` use a different primary dimension, meaning they
107/// only receive a value when that particular event occurs.
108///
109/// For example, divergence statistics use `event = "divergence"`: fields like
110/// `divergence_draw` or `divergence_message` only have a value on draws where a
111/// divergence actually occurred. Storage backends collect these values into a
112/// separate array whose second axis is named after the event (e.g.
113/// `"divergence"`) rather than `"draw"`, and whose length equals the number of
114/// events that occurred rather than the total number of draws.
115///
116/// The struct itself is responsible for ensuring that all fields sharing the
117/// same event dimension produce a value on exactly the same set of draws — the
118/// storage layer does not enforce this.
119pub trait Storable<P: HasDims + ?Sized>: Send + Sync {
120    fn names(parent: &P) -> Vec<&str>;
121    fn item_type(parent: &P, item: &str) -> ItemType;
122    fn dims<'a>(parent: &'a P, item: &str) -> Vec<&'a str>;
123
124    /// Return the name of the primary dimension for the given field, or `None`
125    /// if the field uses the default draw dimension.
126    fn event_dim(_parent: &P, _item: &str) -> Option<&'static str> {
127        None
128    }
129
130    fn get_all<'a>(&'a mut self, parent: &'a P) -> Vec<(&'a str, Option<Value>)>;
131}
132
133impl<P: HasDims> Storable<P> for Vec<f64> {
134    fn names(_parent: &P) -> Vec<&str> {
135        vec!["value"]
136    }
137
138    fn item_type(_parent: &P, _item: &str) -> ItemType {
139        ItemType::F64
140    }
141
142    fn dims<'a>(_parent: &'a P, _item: &str) -> Vec<&'a str> {
143        vec!["dim"]
144    }
145
146    fn event_dim(_parent: &P, _item: &str) -> Option<&'static str> {
147        None
148    }
149
150    fn get_all<'a>(&'a mut self, _parent: &'a P) -> Vec<(&'a str, Option<Value>)> {
151        vec![("value", Some(Value::F64(self.clone())))]
152    }
153}
154
155impl<P: HasDims> Storable<P> for () {
156    fn names(_parent: &P) -> Vec<&str> {
157        vec![]
158    }
159
160    fn item_type(_parent: &P, _item: &str) -> ItemType {
161        panic!("No items in unit type")
162    }
163
164    fn dims<'a>(_parent: &'a P, _item: &str) -> Vec<&'a str> {
165        panic!("No items in unit type")
166    }
167
168    fn event_dim(_parent: &P, _item: &str) -> Option<&'static str> {
169        None
170    }
171
172    fn get_all(&mut self, _parent: &P) -> Vec<(&str, Option<Value>)> {
173        vec![]
174    }
175}