Skip to main content

nuts_rs/storage/
hashmap.rs

//! In-memory storage backend that accumulates draws and statistics into plain Rust `HashMap`s.
2
3use anyhow::Result;
4use nuts_storable::{ItemType, Value};
5use std::collections::HashMap;
6
7use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
8use crate::{Progress, Settings};
9
/// Column-style container holding every recorded value of one variable.
///
/// Each variant wraps a flat `Vec` of a single element type. Values are kept
/// in the order in which they were appended. Two containers compare equal
/// when they are the same variant and hold equal elements.
#[derive(Clone, Debug, PartialEq)]
pub enum HashMapValue {
    /// 64-bit floating point values.
    F64(Vec<f64>),
    /// 32-bit floating point values.
    F32(Vec<f32>),
    /// Boolean values.
    Bool(Vec<bool>),
    /// Signed 64-bit integer values (also the backing store for
    /// `DateTime64` and `TimeDelta64` items).
    I64(Vec<i64>),
    /// Unsigned 64-bit integer values.
    U64(Vec<u64>),
    /// Owned string values.
    String(Vec<String>),
}
20
21impl HashMapValue {
22    /// Create a new empty HashMapValue of the specified type
23    fn new(item_type: ItemType) -> Self {
24        match item_type {
25            ItemType::F64 => HashMapValue::F64(Vec::new()),
26            ItemType::F32 => HashMapValue::F32(Vec::new()),
27            ItemType::Bool => HashMapValue::Bool(Vec::new()),
28            ItemType::I64 => HashMapValue::I64(Vec::new()),
29            ItemType::U64 => HashMapValue::U64(Vec::new()),
30            ItemType::String => HashMapValue::String(Vec::new()),
31            ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => HashMapValue::I64(Vec::new()),
32        }
33    }
34
35    /// Push a value to the internal vector
36    fn push(&mut self, value: Value) {
37        match (self, value) {
38            // Scalar values - store as single element vectors for array types
39            (HashMapValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
40            (HashMapValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
41            (HashMapValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
42            (HashMapValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
43            (HashMapValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
44
45            (HashMapValue::F64(vec), Value::F64(v)) => vec.extend(v),
46            (HashMapValue::F32(vec), Value::F32(v)) => vec.extend(v),
47            (HashMapValue::U64(vec), Value::U64(v)) => vec.extend(v),
48            (HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
49            (HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),
50
51            (HashMapValue::String(vec), Value::Strings(v)) => vec.extend(v),
52            (HashMapValue::String(vec), Value::ScalarString(v)) => vec.push(v),
53            (HashMapValue::I64(vec), Value::DateTime64(_, v)) => vec.extend(v),
54            (HashMapValue::I64(vec), Value::TimeDelta64(_, v)) => vec.extend(v),
55
56            _ => panic!("Mismatched item type"),
57        }
58    }
59}
60
61/// Main storage for HashMap MCMC traces
62#[derive(Clone)]
63pub struct HashMapTraceStorage {
64    draw_types: Vec<(String, ItemType)>,
65    param_types: Vec<(String, ItemType)>,
66}
67
68/// Per-chain storage for HashMap MCMC traces
69#[derive(Clone)]
70pub struct HashMapChainStorage {
71    warmup_stats: HashMap<String, HashMapValue>,
72    sample_stats: HashMap<String, HashMapValue>,
73    warmup_draws: HashMap<String, HashMapValue>,
74    sample_draws: HashMap<String, HashMapValue>,
75    last_sample_was_warmup: bool,
76}
77
78/// Final result containing the collected samples
79#[derive(Debug, Clone)]
80pub struct HashMapResult {
81    /// HashMap containing sampler stats including warmup samples
82    pub stats: HashMap<String, HashMapValue>,
83    /// HashMap containing draws including warmup samples
84    pub draws: HashMap<String, HashMapValue>,
85}
86
87impl HashMapChainStorage {
88    /// Create a new chain storage with HashMaps for parameters and samples
89    fn new(param_types: &[(String, ItemType)], draw_types: &[(String, ItemType)]) -> Self {
90        let warmup_stats = param_types
91            .iter()
92            .cloned()
93            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
94            .collect();
95
96        let sample_stats = param_types
97            .iter()
98            .cloned()
99            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
100            .collect();
101
102        let warmup_draws = draw_types
103            .iter()
104            .cloned()
105            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
106            .collect();
107
108        let sample_draws = draw_types
109            .iter()
110            .cloned()
111            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
112            .collect();
113
114        Self {
115            warmup_stats,
116            sample_stats,
117            warmup_draws,
118            sample_draws,
119            last_sample_was_warmup: true,
120        }
121    }
122
123    /// Store a parameter value
124    fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
125        if ["draw", "chain"].contains(&name) {
126            return Ok(());
127        }
128
129        let target_map = if is_warmup {
130            &mut self.warmup_stats
131        } else {
132            &mut self.sample_stats
133        };
134
135        if let Some(hash_value) = target_map.get_mut(name) {
136            hash_value.push(value);
137        } else {
138            panic!("Unknown param name: {}", name);
139        }
140        Ok(())
141    }
142
143    /// Store a draw value
144    fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
145        if ["draw", "chain"].contains(&name) {
146            return Ok(());
147        }
148
149        let target_map = if is_warmup {
150            &mut self.warmup_draws
151        } else {
152            &mut self.sample_draws
153        };
154
155        if let Some(hash_value) = target_map.get_mut(name) {
156            hash_value.push(value);
157        } else {
158            panic!("Unknown posterior variable name: {}", name);
159        }
160        Ok(())
161    }
162}
163
164impl ChainStorage for HashMapChainStorage {
165    type Finalized = HashMapResult;
166
167    fn record_sample(
168        &mut self,
169        _settings: &impl Settings,
170        stats: Vec<(&str, Option<Value>)>,
171        draws: Vec<(&str, Option<Value>)>,
172        info: &Progress,
173    ) -> Result<()> {
174        let is_first_draw = self.last_sample_was_warmup && !info.tuning;
175        if is_first_draw {
176            self.last_sample_was_warmup = false;
177        }
178
179        for (name, value) in stats {
180            if let Some(value) = value {
181                self.push_param(name, value, info.tuning)?;
182            }
183        }
184        for (name, value) in draws {
185            if let Some(value) = value {
186                self.push_draw(name, value, info.tuning)?;
187            } else {
188                panic!("Missing draw value for {}", name);
189            }
190        }
191        Ok(())
192    }
193
194    /// Finalize storage and return the collected samples
195    fn finalize(self) -> Result<Self::Finalized> {
196        // Combine warmup and sample data
197        let mut combined_stats = HashMap::new();
198        let mut combined_draws = HashMap::new();
199
200        // Combine stats
201        for (key, warmup_values) in self.warmup_stats {
202            let sample_values = &self.sample_stats[&key];
203            let mut combined = warmup_values.clone();
204
205            match (&mut combined, sample_values) {
206                (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
207                    combined_vec.extend(sample_vec.iter().cloned());
208                }
209                (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
210                    combined_vec.extend(sample_vec.iter().cloned());
211                }
212                (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
213                    combined_vec.extend(sample_vec.iter().cloned());
214                }
215                (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
216                    combined_vec.extend(sample_vec.iter().cloned());
217                }
218                (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
219                    combined_vec.extend(sample_vec.iter().cloned());
220                }
221                _ => panic!("Type mismatch when combining stats for {}", key),
222            }
223
224            combined_stats.insert(key, combined);
225        }
226
227        // Combine draws
228        for (key, warmup_values) in self.warmup_draws {
229            let sample_values = &self.sample_draws[&key];
230            let mut combined = warmup_values.clone();
231
232            match (&mut combined, sample_values) {
233                (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
234                    combined_vec.extend(sample_vec.iter().cloned());
235                }
236                (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
237                    combined_vec.extend(sample_vec.iter().cloned());
238                }
239                (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
240                    combined_vec.extend(sample_vec.iter().cloned());
241                }
242                (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
243                    combined_vec.extend(sample_vec.iter().cloned());
244                }
245                (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
246                    combined_vec.extend(sample_vec.iter().cloned());
247                }
248                _ => panic!("Type mismatch when combining draws for {}", key),
249            }
250
251            combined_draws.insert(key, combined);
252        }
253
254        Ok(HashMapResult {
255            stats: combined_stats,
256            draws: combined_draws,
257        })
258    }
259
260    /// Flush - no-op for HashMap storage since everything is in memory
261    fn flush(&self) -> Result<()> {
262        Ok(())
263    }
264
265    fn inspect(&self) -> Result<Option<Self::Finalized>> {
266        self.clone().finalize().map(Some)
267    }
268}
269
/// Configuration for the in-memory HashMap storage backend.
///
/// The type has no fields; it only selects this backend.
#[derive(Clone, Debug, Default)]
pub struct HashMapConfig {}

impl HashMapConfig {
    /// Create a new configuration; equivalent to `HashMapConfig::default()`.
    pub fn new() -> Self {
        Self {}
    }
}
283
284impl StorageConfig for HashMapConfig {
285    type Storage = HashMapTraceStorage;
286
287    fn new_trace<M: crate::Math>(
288        self,
289        settings: &impl Settings,
290        math: &M,
291    ) -> Result<Self::Storage> {
292        Ok(HashMapTraceStorage {
293            param_types: settings.stat_types(math),
294            draw_types: settings.data_types(math),
295        })
296    }
297}
298
299impl TraceStorage for HashMapTraceStorage {
300    type ChainStorage = HashMapChainStorage;
301
302    type Finalized = Vec<HashMapResult>;
303
304    fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result<Self::ChainStorage> {
305        Ok(HashMapChainStorage::new(
306            &self.param_types,
307            &self.draw_types,
308        ))
309    }
310
311    fn finalize(
312        self,
313        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
314    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
315        let mut results = Vec::new();
316        let mut first_error = None;
317
318        for trace in traces {
319            match trace {
320                Ok(result) => results.push(result),
321                Err(err) => {
322                    if first_error.is_none() {
323                        first_error = Some(err);
324                    }
325                }
326            }
327        }
328
329        Ok((first_error, results))
330    }
331
332    fn inspect(
333        &self,
334        traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
335    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
336        self.clone()
337            .finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect())
338    }
339}