nuts_rs/storage/
hashmap.rs

1use anyhow::Result;
2use nuts_storable::{ItemType, Value};
3use std::collections::HashMap;
4
5use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
6use crate::{Progress, Settings};
7
/// Container for the collected values of one stat or draw variable.
///
/// Each variant is a flat buffer of a single element type; array-valued
/// samples are appended element by element. Equality compares the variant and
/// the contents element-wise (so floating-point `NaN` entries never compare
/// equal).
#[derive(Clone, Debug, PartialEq)]
pub enum HashMapValue {
    /// 64-bit floating point values.
    F64(Vec<f64>),
    /// 32-bit floating point values.
    F32(Vec<f32>),
    /// Boolean values.
    Bool(Vec<bool>),
    /// Signed 64-bit integer values.
    I64(Vec<i64>),
    /// Unsigned 64-bit integer values.
    U64(Vec<u64>),
    /// String values.
    String(Vec<String>),
}
18
19impl HashMapValue {
20    /// Create a new empty HashMapValue of the specified type
21    fn new(item_type: ItemType) -> Self {
22        match item_type {
23            ItemType::F64 => HashMapValue::F64(Vec::new()),
24            ItemType::F32 => HashMapValue::F32(Vec::new()),
25            ItemType::Bool => HashMapValue::Bool(Vec::new()),
26            ItemType::I64 => HashMapValue::I64(Vec::new()),
27            ItemType::U64 => HashMapValue::U64(Vec::new()),
28            ItemType::String => HashMapValue::String(Vec::new()),
29        }
30    }
31
32    /// Push a value to the internal vector
33    fn push(&mut self, value: Value) {
34        match (self, value) {
35            // Scalar values - store as single element vectors for array types
36            (HashMapValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
37            (HashMapValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
38            (HashMapValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
39            (HashMapValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
40            (HashMapValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
41
42            (HashMapValue::F64(vec), Value::F64(v)) => vec.extend(v),
43            (HashMapValue::F32(vec), Value::F32(v)) => vec.extend(v),
44            (HashMapValue::U64(vec), Value::U64(v)) => vec.extend(v),
45            (HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
46            (HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),
47
48            _ => panic!("Mismatched item type"),
49        }
50    }
51}
52
53/// Main storage for HashMap MCMC traces
54#[derive(Clone)]
55pub struct HashMapTraceStorage {
56    draw_types: Vec<(String, ItemType)>,
57    param_types: Vec<(String, ItemType)>,
58}
59
60/// Per-chain storage for HashMap MCMC traces
61#[derive(Clone)]
62pub struct HashMapChainStorage {
63    warmup_stats: HashMap<String, HashMapValue>,
64    sample_stats: HashMap<String, HashMapValue>,
65    warmup_draws: HashMap<String, HashMapValue>,
66    sample_draws: HashMap<String, HashMapValue>,
67    last_sample_was_warmup: bool,
68}
69
70/// Final result containing the collected samples
71#[derive(Debug, Clone)]
72pub struct HashMapResult {
73    /// HashMap containing sampler stats including warmup samples
74    pub stats: HashMap<String, HashMapValue>,
75    /// HashMap containing draws including warmup samples
76    pub draws: HashMap<String, HashMapValue>,
77}
78
79impl HashMapChainStorage {
80    /// Create a new chain storage with HashMaps for parameters and samples
81    fn new(param_types: &Vec<(String, ItemType)>, draw_types: &Vec<(String, ItemType)>) -> Self {
82        let warmup_stats = param_types
83            .iter()
84            .cloned()
85            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
86            .collect();
87
88        let sample_stats = param_types
89            .iter()
90            .cloned()
91            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
92            .collect();
93
94        let warmup_draws = draw_types
95            .iter()
96            .cloned()
97            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
98            .collect();
99
100        let sample_draws = draw_types
101            .iter()
102            .cloned()
103            .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
104            .collect();
105
106        Self {
107            warmup_stats,
108            sample_stats,
109            warmup_draws,
110            sample_draws,
111            last_sample_was_warmup: true,
112        }
113    }
114
115    /// Store a parameter value
116    fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
117        if ["draw", "chain"].contains(&name) {
118            return Ok(());
119        }
120
121        let target_map = if is_warmup {
122            &mut self.warmup_stats
123        } else {
124            &mut self.sample_stats
125        };
126
127        if let Some(hash_value) = target_map.get_mut(name) {
128            hash_value.push(value);
129        } else {
130            panic!("Unknown param name: {}", name);
131        }
132        Ok(())
133    }
134
135    /// Store a draw value
136    fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
137        if ["draw", "chain"].contains(&name) {
138            return Ok(());
139        }
140
141        let target_map = if is_warmup {
142            &mut self.warmup_draws
143        } else {
144            &mut self.sample_draws
145        };
146
147        if let Some(hash_value) = target_map.get_mut(name) {
148            hash_value.push(value);
149        } else {
150            panic!("Unknown posterior variable name: {}", name);
151        }
152        Ok(())
153    }
154}
155
156impl ChainStorage for HashMapChainStorage {
157    type Finalized = HashMapResult;
158
159    fn record_sample(
160        &mut self,
161        _settings: &impl Settings,
162        stats: Vec<(&str, Option<Value>)>,
163        draws: Vec<(&str, Option<Value>)>,
164        info: &Progress,
165    ) -> Result<()> {
166        let is_first_draw = self.last_sample_was_warmup && !info.tuning;
167        if is_first_draw {
168            self.last_sample_was_warmup = false;
169        }
170
171        for (name, value) in stats {
172            if let Some(value) = value {
173                self.push_param(name, value, info.tuning)?;
174            }
175        }
176        for (name, value) in draws {
177            if let Some(value) = value {
178                self.push_draw(name, value, info.tuning)?;
179            } else {
180                panic!("Missing draw value for {}", name);
181            }
182        }
183        Ok(())
184    }
185
186    /// Finalize storage and return the collected samples
187    fn finalize(self) -> Result<Self::Finalized> {
188        // Combine warmup and sample data
189        let mut combined_stats = HashMap::new();
190        let mut combined_draws = HashMap::new();
191
192        // Combine stats
193        for (key, warmup_values) in self.warmup_stats {
194            let sample_values = &self.sample_stats[&key];
195            let mut combined = warmup_values.clone();
196
197            match (&mut combined, sample_values) {
198                (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
199                    combined_vec.extend(sample_vec.iter().cloned());
200                }
201                (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
202                    combined_vec.extend(sample_vec.iter().cloned());
203                }
204                (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
205                    combined_vec.extend(sample_vec.iter().cloned());
206                }
207                (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
208                    combined_vec.extend(sample_vec.iter().cloned());
209                }
210                (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
211                    combined_vec.extend(sample_vec.iter().cloned());
212                }
213                _ => panic!("Type mismatch when combining stats for {}", key),
214            }
215
216            combined_stats.insert(key, combined);
217        }
218
219        // Combine draws
220        for (key, warmup_values) in self.warmup_draws {
221            let sample_values = &self.sample_draws[&key];
222            let mut combined = warmup_values.clone();
223
224            match (&mut combined, sample_values) {
225                (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
226                    combined_vec.extend(sample_vec.iter().cloned());
227                }
228                (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
229                    combined_vec.extend(sample_vec.iter().cloned());
230                }
231                (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
232                    combined_vec.extend(sample_vec.iter().cloned());
233                }
234                (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
235                    combined_vec.extend(sample_vec.iter().cloned());
236                }
237                (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
238                    combined_vec.extend(sample_vec.iter().cloned());
239                }
240                _ => panic!("Type mismatch when combining draws for {}", key),
241            }
242
243            combined_draws.insert(key, combined);
244        }
245
246        Ok(HashMapResult {
247            stats: combined_stats,
248            draws: combined_draws,
249        })
250    }
251
252    /// Flush - no-op for HashMap storage since everything is in memory
253    fn flush(&self) -> Result<()> {
254        Ok(())
255    }
256
257    fn inspect(&self) -> Result<Option<Self::Finalized>> {
258        self.clone().finalize().map(Some)
259    }
260}
261
/// Configuration for the in-memory HashMap storage backend.
///
/// It has no options; build it with [`HashMapConfig::new`] or
/// [`Default::default`], which are equivalent.
#[derive(Debug, Clone, Default)]
pub struct HashMapConfig {}

impl HashMapConfig {
    /// Create the (option-free) configuration.
    pub fn new() -> Self {
        Self {}
    }
}
275
276impl StorageConfig for HashMapConfig {
277    type Storage = HashMapTraceStorage;
278
279    fn new_trace<M: crate::Math>(
280        self,
281        settings: &impl Settings,
282        math: &M,
283    ) -> Result<Self::Storage> {
284        Ok(HashMapTraceStorage {
285            param_types: settings.stat_types(math),
286            draw_types: settings.data_types(math),
287        })
288    }
289}
290
291impl TraceStorage for HashMapTraceStorage {
292    type ChainStorage = HashMapChainStorage;
293
294    type Finalized = Vec<HashMapResult>;
295
296    fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result<Self::ChainStorage> {
297        Ok(HashMapChainStorage::new(
298            &self.param_types,
299            &self.draw_types,
300        ))
301    }
302
303    fn finalize(
304        self,
305        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
306    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
307        let mut results = Vec::new();
308        let mut first_error = None;
309
310        for trace in traces {
311            match trace {
312                Ok(result) => results.push(result),
313                Err(err) => {
314                    if first_error.is_none() {
315                        first_error = Some(err);
316                    }
317                }
318            }
319        }
320
321        Ok((first_error, results))
322    }
323
324    fn inspect(
325        &self,
326        traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
327    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
328        self.clone()
329            .finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect())
330    }
331}