nuts_rs/storage/
hashmap.rs

1use anyhow::Result;
2use nuts_storable::{ItemType, Value};
3use std::collections::HashMap;
4
5use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
6use crate::{Progress, Settings};
7
/// Column of recorded values for one stat or draw variable, in recording order.
///
/// Scalar values add one element per recorded sample; array values are appended element by
/// element, so consecutive samples are concatenated into one flat vector.
///
/// Date/time item types (`ItemType::DateTime64` / `ItemType::TimeDelta64`) have no variant of
/// their own: their raw `i64` payloads are stored in [`HashMapValue::I64`] and the time unit is
/// not retained.
#[derive(Clone, Debug, PartialEq)]
pub enum HashMapValue {
    /// 64-bit floating point values.
    F64(Vec<f64>),
    /// 32-bit floating point values.
    F32(Vec<f32>),
    /// Boolean values.
    Bool(Vec<bool>),
    /// 64-bit signed integers; also holds raw date/time payloads.
    I64(Vec<i64>),
    /// 64-bit unsigned integers.
    U64(Vec<u64>),
    /// String values.
    String(Vec<String>),
}
18
19impl HashMapValue {
20    /// Create a new empty HashMapValue of the specified type
21    fn new(item_type: ItemType) -> Self {
22        match item_type {
23            ItemType::F64 => HashMapValue::F64(Vec::new()),
24            ItemType::F32 => HashMapValue::F32(Vec::new()),
25            ItemType::Bool => HashMapValue::Bool(Vec::new()),
26            ItemType::I64 => HashMapValue::I64(Vec::new()),
27            ItemType::U64 => HashMapValue::U64(Vec::new()),
28            ItemType::String => HashMapValue::String(Vec::new()),
29            ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => HashMapValue::I64(Vec::new()),
30        }
31    }
32
33    /// Push a value to the internal vector
34    fn push(&mut self, value: Value) {
35        match (self, value) {
36            // Scalar values - store as single element vectors for array types
37            (HashMapValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
38            (HashMapValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
39            (HashMapValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
40            (HashMapValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
41            (HashMapValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
42
43            (HashMapValue::F64(vec), Value::F64(v)) => vec.extend(v),
44            (HashMapValue::F32(vec), Value::F32(v)) => vec.extend(v),
45            (HashMapValue::U64(vec), Value::U64(v)) => vec.extend(v),
46            (HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
47            (HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),
48
49            (HashMapValue::String(vec), Value::Strings(v)) => vec.extend(v),
50            (HashMapValue::String(vec), Value::ScalarString(v)) => vec.push(v),
51            (HashMapValue::I64(vec), Value::DateTime64(_, v)) => vec.extend(v),
52            (HashMapValue::I64(vec), Value::TimeDelta64(_, v)) => vec.extend(v),
53
54            _ => panic!("Mismatched item type"),
55        }
56    }
57}
58
59/// Main storage for HashMap MCMC traces
60#[derive(Clone)]
61pub struct HashMapTraceStorage {
62    draw_types: Vec<(String, ItemType)>,
63    param_types: Vec<(String, ItemType)>,
64}
65
66/// Per-chain storage for HashMap MCMC traces
67#[derive(Clone)]
68pub struct HashMapChainStorage {
69    warmup_stats: HashMap<String, HashMapValue>,
70    sample_stats: HashMap<String, HashMapValue>,
71    warmup_draws: HashMap<String, HashMapValue>,
72    sample_draws: HashMap<String, HashMapValue>,
73    last_sample_was_warmup: bool,
74}
75
76/// Final result containing the collected samples
77#[derive(Debug, Clone)]
78pub struct HashMapResult {
79    /// HashMap containing sampler stats including warmup samples
80    pub stats: HashMap<String, HashMapValue>,
81    /// HashMap containing draws including warmup samples
82    pub draws: HashMap<String, HashMapValue>,
83}
84
85impl HashMapChainStorage {
86    /// Create a new chain storage with HashMaps for parameters and samples
87    fn new(param_types: &Vec<(String, ItemType)>, draw_types: &Vec<(String, ItemType)>) -> Self {
88        let warmup_stats = param_types
89            .iter()
90            .cloned()
91            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
92            .collect();
93
94        let sample_stats = param_types
95            .iter()
96            .cloned()
97            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
98            .collect();
99
100        let warmup_draws = draw_types
101            .iter()
102            .cloned()
103            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
104            .collect();
105
106        let sample_draws = draw_types
107            .iter()
108            .cloned()
109            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
110            .collect();
111
112        Self {
113            warmup_stats,
114            sample_stats,
115            warmup_draws,
116            sample_draws,
117            last_sample_was_warmup: true,
118        }
119    }
120
121    /// Store a parameter value
122    fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
123        if ["draw", "chain"].contains(&name) {
124            return Ok(());
125        }
126
127        let target_map = if is_warmup {
128            &mut self.warmup_stats
129        } else {
130            &mut self.sample_stats
131        };
132
133        if let Some(hash_value) = target_map.get_mut(name) {
134            hash_value.push(value);
135        } else {
136            panic!("Unknown param name: {}", name);
137        }
138        Ok(())
139    }
140
141    /// Store a draw value
142    fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
143        if ["draw", "chain"].contains(&name) {
144            return Ok(());
145        }
146
147        let target_map = if is_warmup {
148            &mut self.warmup_draws
149        } else {
150            &mut self.sample_draws
151        };
152
153        if let Some(hash_value) = target_map.get_mut(name) {
154            hash_value.push(value);
155        } else {
156            panic!("Unknown posterior variable name: {}", name);
157        }
158        Ok(())
159    }
160}
161
162impl ChainStorage for HashMapChainStorage {
163    type Finalized = HashMapResult;
164
165    fn record_sample(
166        &mut self,
167        _settings: &impl Settings,
168        stats: Vec<(&str, Option<Value>)>,
169        draws: Vec<(&str, Option<Value>)>,
170        info: &Progress,
171    ) -> Result<()> {
172        let is_first_draw = self.last_sample_was_warmup && !info.tuning;
173        if is_first_draw {
174            self.last_sample_was_warmup = false;
175        }
176
177        for (name, value) in stats {
178            if let Some(value) = value {
179                self.push_param(name, value, info.tuning)?;
180            }
181        }
182        for (name, value) in draws {
183            if let Some(value) = value {
184                self.push_draw(name, value, info.tuning)?;
185            } else {
186                panic!("Missing draw value for {}", name);
187            }
188        }
189        Ok(())
190    }
191
192    /// Finalize storage and return the collected samples
193    fn finalize(self) -> Result<Self::Finalized> {
194        // Combine warmup and sample data
195        let mut combined_stats = HashMap::new();
196        let mut combined_draws = HashMap::new();
197
198        // Combine stats
199        for (key, warmup_values) in self.warmup_stats {
200            let sample_values = &self.sample_stats[&key];
201            let mut combined = warmup_values.clone();
202
203            match (&mut combined, sample_values) {
204                (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
205                    combined_vec.extend(sample_vec.iter().cloned());
206                }
207                (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
208                    combined_vec.extend(sample_vec.iter().cloned());
209                }
210                (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
211                    combined_vec.extend(sample_vec.iter().cloned());
212                }
213                (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
214                    combined_vec.extend(sample_vec.iter().cloned());
215                }
216                (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
217                    combined_vec.extend(sample_vec.iter().cloned());
218                }
219                _ => panic!("Type mismatch when combining stats for {}", key),
220            }
221
222            combined_stats.insert(key, combined);
223        }
224
225        // Combine draws
226        for (key, warmup_values) in self.warmup_draws {
227            let sample_values = &self.sample_draws[&key];
228            let mut combined = warmup_values.clone();
229
230            match (&mut combined, sample_values) {
231                (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
232                    combined_vec.extend(sample_vec.iter().cloned());
233                }
234                (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
235                    combined_vec.extend(sample_vec.iter().cloned());
236                }
237                (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
238                    combined_vec.extend(sample_vec.iter().cloned());
239                }
240                (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
241                    combined_vec.extend(sample_vec.iter().cloned());
242                }
243                (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
244                    combined_vec.extend(sample_vec.iter().cloned());
245                }
246                _ => panic!("Type mismatch when combining draws for {}", key),
247            }
248
249            combined_draws.insert(key, combined);
250        }
251
252        Ok(HashMapResult {
253            stats: combined_stats,
254            draws: combined_draws,
255        })
256    }
257
258    /// Flush - no-op for HashMap storage since everything is in memory
259    fn flush(&self) -> Result<()> {
260        Ok(())
261    }
262
263    fn inspect(&self) -> Result<Option<Self::Finalized>> {
264        self.clone().finalize().map(Some)
265    }
266}
267
/// Configuration for the in-memory [`HashMapTraceStorage`] backend.
///
/// It has no options; construct it with `HashMapConfig::new()` or `HashMapConfig::default()`.
#[derive(Debug, Clone, Copy)]
pub struct HashMapConfig {}
269
270impl Default for HashMapConfig {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276impl HashMapConfig {
277    pub fn new() -> Self {
278        Self {}
279    }
280}
281
282impl StorageConfig for HashMapConfig {
283    type Storage = HashMapTraceStorage;
284
285    fn new_trace<M: crate::Math>(
286        self,
287        settings: &impl Settings,
288        math: &M,
289    ) -> Result<Self::Storage> {
290        Ok(HashMapTraceStorage {
291            param_types: settings.stat_types(math),
292            draw_types: settings.data_types(math),
293        })
294    }
295}
296
297impl TraceStorage for HashMapTraceStorage {
298    type ChainStorage = HashMapChainStorage;
299
300    type Finalized = Vec<HashMapResult>;
301
302    fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result<Self::ChainStorage> {
303        Ok(HashMapChainStorage::new(
304            &self.param_types,
305            &self.draw_types,
306        ))
307    }
308
309    fn finalize(
310        self,
311        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
312    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
313        let mut results = Vec::new();
314        let mut first_error = None;
315
316        for trace in traces {
317            match trace {
318                Ok(result) => results.push(result),
319                Err(err) => {
320                    if first_error.is_none() {
321                        first_error = Some(err);
322                    }
323                }
324            }
325        }
326
327        Ok((first_error, results))
328    }
329
330    fn inspect(
331        &self,
332        traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
333    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
334        self.clone()
335            .finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect())
336    }
337}