1use anyhow::Result;
4use nuts_storable::{ItemType, Value};
5use std::collections::HashMap;
6
7use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
8use crate::{Progress, Settings};
9
/// One recorded column: a flat vector holding every value pushed for a single stat or draw.
///
/// Array-valued entries are flattened onto the end of the vector. Date/time item types have
/// no dedicated variant and are stored in [`HashMapValue::I64`].
#[derive(Clone, Debug, PartialEq)]
pub enum HashMapValue {
    F64(Vec<f64>),
    F32(Vec<f32>),
    Bool(Vec<bool>),
    I64(Vec<i64>),
    U64(Vec<u64>),
    String(Vec<String>),
}
20
21impl HashMapValue {
22 fn new(item_type: ItemType) -> Self {
24 match item_type {
25 ItemType::F64 => HashMapValue::F64(Vec::new()),
26 ItemType::F32 => HashMapValue::F32(Vec::new()),
27 ItemType::Bool => HashMapValue::Bool(Vec::new()),
28 ItemType::I64 => HashMapValue::I64(Vec::new()),
29 ItemType::U64 => HashMapValue::U64(Vec::new()),
30 ItemType::String => HashMapValue::String(Vec::new()),
31 ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => HashMapValue::I64(Vec::new()),
32 }
33 }
34
35 fn push(&mut self, value: Value) {
37 match (self, value) {
38 (HashMapValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
40 (HashMapValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
41 (HashMapValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
42 (HashMapValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
43 (HashMapValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
44
45 (HashMapValue::F64(vec), Value::F64(v)) => vec.extend(v),
46 (HashMapValue::F32(vec), Value::F32(v)) => vec.extend(v),
47 (HashMapValue::U64(vec), Value::U64(v)) => vec.extend(v),
48 (HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
49 (HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),
50
51 (HashMapValue::String(vec), Value::Strings(v)) => vec.extend(v),
52 (HashMapValue::String(vec), Value::ScalarString(v)) => vec.push(v),
53 (HashMapValue::I64(vec), Value::DateTime64(_, v)) => vec.extend(v),
54 (HashMapValue::I64(vec), Value::TimeDelta64(_, v)) => vec.extend(v),
55
56 _ => panic!("Mismatched item type"),
57 }
58 }
59}
60
61#[derive(Clone)]
63pub struct HashMapTraceStorage {
64 draw_types: Vec<(String, ItemType)>,
65 param_types: Vec<(String, ItemType)>,
66}
67
68#[derive(Clone)]
70pub struct HashMapChainStorage {
71 warmup_stats: HashMap<String, HashMapValue>,
72 sample_stats: HashMap<String, HashMapValue>,
73 warmup_draws: HashMap<String, HashMapValue>,
74 sample_draws: HashMap<String, HashMapValue>,
75 last_sample_was_warmup: bool,
76}
77
/// Finalized contents of one chain.
///
/// Every column holds the chain's warmup values followed by its post-warmup values.
#[derive(Debug, Clone)]
pub struct HashMapResult {
    /// Statistic columns, keyed by name.
    pub stats: HashMap<String, HashMapValue>,
    /// Draw columns, keyed by variable name.
    pub draws: HashMap<String, HashMapValue>,
}
86
87impl HashMapChainStorage {
88 fn new(param_types: &[(String, ItemType)], draw_types: &[(String, ItemType)]) -> Self {
90 let warmup_stats = param_types
91 .iter()
92 .cloned()
93 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
94 .collect();
95
96 let sample_stats = param_types
97 .iter()
98 .cloned()
99 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
100 .collect();
101
102 let warmup_draws = draw_types
103 .iter()
104 .cloned()
105 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
106 .collect();
107
108 let sample_draws = draw_types
109 .iter()
110 .cloned()
111 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
112 .collect();
113
114 Self {
115 warmup_stats,
116 sample_stats,
117 warmup_draws,
118 sample_draws,
119 last_sample_was_warmup: true,
120 }
121 }
122
123 fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
125 if ["draw", "chain"].contains(&name) {
126 return Ok(());
127 }
128
129 let target_map = if is_warmup {
130 &mut self.warmup_stats
131 } else {
132 &mut self.sample_stats
133 };
134
135 if let Some(hash_value) = target_map.get_mut(name) {
136 hash_value.push(value);
137 } else {
138 panic!("Unknown param name: {}", name);
139 }
140 Ok(())
141 }
142
143 fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
145 if ["draw", "chain"].contains(&name) {
146 return Ok(());
147 }
148
149 let target_map = if is_warmup {
150 &mut self.warmup_draws
151 } else {
152 &mut self.sample_draws
153 };
154
155 if let Some(hash_value) = target_map.get_mut(name) {
156 hash_value.push(value);
157 } else {
158 panic!("Unknown posterior variable name: {}", name);
159 }
160 Ok(())
161 }
162}
163
164impl ChainStorage for HashMapChainStorage {
165 type Finalized = HashMapResult;
166
167 fn record_sample(
168 &mut self,
169 _settings: &impl Settings,
170 stats: Vec<(&str, Option<Value>)>,
171 draws: Vec<(&str, Option<Value>)>,
172 info: &Progress,
173 ) -> Result<()> {
174 let is_first_draw = self.last_sample_was_warmup && !info.tuning;
175 if is_first_draw {
176 self.last_sample_was_warmup = false;
177 }
178
179 for (name, value) in stats {
180 if let Some(value) = value {
181 self.push_param(name, value, info.tuning)?;
182 }
183 }
184 for (name, value) in draws {
185 if let Some(value) = value {
186 self.push_draw(name, value, info.tuning)?;
187 } else {
188 panic!("Missing draw value for {}", name);
189 }
190 }
191 Ok(())
192 }
193
194 fn finalize(self) -> Result<Self::Finalized> {
196 let mut combined_stats = HashMap::new();
198 let mut combined_draws = HashMap::new();
199
200 for (key, warmup_values) in self.warmup_stats {
202 let sample_values = &self.sample_stats[&key];
203 let mut combined = warmup_values.clone();
204
205 match (&mut combined, sample_values) {
206 (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
207 combined_vec.extend(sample_vec.iter().cloned());
208 }
209 (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
210 combined_vec.extend(sample_vec.iter().cloned());
211 }
212 (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
213 combined_vec.extend(sample_vec.iter().cloned());
214 }
215 (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
216 combined_vec.extend(sample_vec.iter().cloned());
217 }
218 (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
219 combined_vec.extend(sample_vec.iter().cloned());
220 }
221 _ => panic!("Type mismatch when combining stats for {}", key),
222 }
223
224 combined_stats.insert(key, combined);
225 }
226
227 for (key, warmup_values) in self.warmup_draws {
229 let sample_values = &self.sample_draws[&key];
230 let mut combined = warmup_values.clone();
231
232 match (&mut combined, sample_values) {
233 (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
234 combined_vec.extend(sample_vec.iter().cloned());
235 }
236 (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
237 combined_vec.extend(sample_vec.iter().cloned());
238 }
239 (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
240 combined_vec.extend(sample_vec.iter().cloned());
241 }
242 (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
243 combined_vec.extend(sample_vec.iter().cloned());
244 }
245 (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
246 combined_vec.extend(sample_vec.iter().cloned());
247 }
248 _ => panic!("Type mismatch when combining draws for {}", key),
249 }
250
251 combined_draws.insert(key, combined);
252 }
253
254 Ok(HashMapResult {
255 stats: combined_stats,
256 draws: combined_draws,
257 })
258 }
259
260 fn flush(&self) -> Result<()> {
262 Ok(())
263 }
264
265 fn inspect(&self) -> Result<Option<Self::Finalized>> {
266 self.clone().finalize().map(Some)
267 }
268}
269
/// Configuration that selects the in-memory, `HashMap`-backed trace storage.
///
/// It currently carries no options.
#[derive(Debug, Clone)]
pub struct HashMapConfig {}
271
272impl Default for HashMapConfig {
273 fn default() -> Self {
274 Self::new()
275 }
276}
277
278impl HashMapConfig {
279 pub fn new() -> Self {
280 Self {}
281 }
282}
283
284impl StorageConfig for HashMapConfig {
285 type Storage = HashMapTraceStorage;
286
287 fn new_trace<M: crate::Math>(
288 self,
289 settings: &impl Settings,
290 math: &M,
291 ) -> Result<Self::Storage> {
292 Ok(HashMapTraceStorage {
293 param_types: settings.stat_types(math),
294 draw_types: settings.data_types(math),
295 })
296 }
297}
298
299impl TraceStorage for HashMapTraceStorage {
300 type ChainStorage = HashMapChainStorage;
301
302 type Finalized = Vec<HashMapResult>;
303
304 fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result<Self::ChainStorage> {
305 Ok(HashMapChainStorage::new(
306 &self.param_types,
307 &self.draw_types,
308 ))
309 }
310
311 fn finalize(
312 self,
313 traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
314 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
315 let mut results = Vec::new();
316 let mut first_error = None;
317
318 for trace in traces {
319 match trace {
320 Ok(result) => results.push(result),
321 Err(err) => {
322 if first_error.is_none() {
323 first_error = Some(err);
324 }
325 }
326 }
327 }
328
329 Ok((first_error, results))
330 }
331
332 fn inspect(
333 &self,
334 traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
335 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
336 self.clone()
337 .finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect())
338 }
339}