1use anyhow::Result;
2use nuts_storable::{ItemType, Value};
3use std::collections::HashMap;
4
5use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
6use crate::{Progress, Settings};
7
/// Growable buffer holding every value recorded for one stat or draw column.
///
/// Values are appended in recording order, and array-valued entries are
/// flattened into the buffer. Columns whose item type is a date/time type are
/// stored in the `I64` variant (see `HashMapValue::new`).
#[derive(Clone, Debug, PartialEq)]
pub enum HashMapValue {
    /// 64-bit floating point column.
    F64(Vec<f64>),
    /// 32-bit floating point column.
    F32(Vec<f32>),
    /// Boolean column.
    Bool(Vec<bool>),
    /// Signed 64-bit integer column (also used for date/time item types).
    I64(Vec<i64>),
    /// Unsigned 64-bit integer column.
    U64(Vec<u64>),
    /// String column.
    String(Vec<String>),
}
18
19impl HashMapValue {
20 fn new(item_type: ItemType) -> Self {
22 match item_type {
23 ItemType::F64 => HashMapValue::F64(Vec::new()),
24 ItemType::F32 => HashMapValue::F32(Vec::new()),
25 ItemType::Bool => HashMapValue::Bool(Vec::new()),
26 ItemType::I64 => HashMapValue::I64(Vec::new()),
27 ItemType::U64 => HashMapValue::U64(Vec::new()),
28 ItemType::String => HashMapValue::String(Vec::new()),
29 ItemType::DateTime64(_) | ItemType::TimeDelta64(_) => HashMapValue::I64(Vec::new()),
30 }
31 }
32
33 fn push(&mut self, value: Value) {
35 match (self, value) {
36 (HashMapValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
38 (HashMapValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
39 (HashMapValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
40 (HashMapValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
41 (HashMapValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
42
43 (HashMapValue::F64(vec), Value::F64(v)) => vec.extend(v),
44 (HashMapValue::F32(vec), Value::F32(v)) => vec.extend(v),
45 (HashMapValue::U64(vec), Value::U64(v)) => vec.extend(v),
46 (HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
47 (HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),
48
49 (HashMapValue::String(vec), Value::Strings(v)) => vec.extend(v),
50 (HashMapValue::String(vec), Value::ScalarString(v)) => vec.push(v),
51 (HashMapValue::I64(vec), Value::DateTime64(_, v)) => vec.extend(v),
52 (HashMapValue::I64(vec), Value::TimeDelta64(_, v)) => vec.extend(v),
53
54 _ => panic!("Mismatched item type"),
55 }
56 }
57}
58
59#[derive(Clone)]
61pub struct HashMapTraceStorage {
62 draw_types: Vec<(String, ItemType)>,
63 param_types: Vec<(String, ItemType)>,
64}
65
66#[derive(Clone)]
68pub struct HashMapChainStorage {
69 warmup_stats: HashMap<String, HashMapValue>,
70 sample_stats: HashMap<String, HashMapValue>,
71 warmup_draws: HashMap<String, HashMapValue>,
72 sample_draws: HashMap<String, HashMapValue>,
73 last_sample_was_warmup: bool,
74}
75
76#[derive(Debug, Clone)]
78pub struct HashMapResult {
79 pub stats: HashMap<String, HashMapValue>,
81 pub draws: HashMap<String, HashMapValue>,
83}
84
85impl HashMapChainStorage {
86 fn new(param_types: &Vec<(String, ItemType)>, draw_types: &Vec<(String, ItemType)>) -> Self {
88 let warmup_stats = param_types
89 .iter()
90 .cloned()
91 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
92 .collect();
93
94 let sample_stats = param_types
95 .iter()
96 .cloned()
97 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
98 .collect();
99
100 let warmup_draws = draw_types
101 .iter()
102 .cloned()
103 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
104 .collect();
105
106 let sample_draws = draw_types
107 .iter()
108 .cloned()
109 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
110 .collect();
111
112 Self {
113 warmup_stats,
114 sample_stats,
115 warmup_draws,
116 sample_draws,
117 last_sample_was_warmup: true,
118 }
119 }
120
121 fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
123 if ["draw", "chain"].contains(&name) {
124 return Ok(());
125 }
126
127 let target_map = if is_warmup {
128 &mut self.warmup_stats
129 } else {
130 &mut self.sample_stats
131 };
132
133 if let Some(hash_value) = target_map.get_mut(name) {
134 hash_value.push(value);
135 } else {
136 panic!("Unknown param name: {}", name);
137 }
138 Ok(())
139 }
140
141 fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
143 if ["draw", "chain"].contains(&name) {
144 return Ok(());
145 }
146
147 let target_map = if is_warmup {
148 &mut self.warmup_draws
149 } else {
150 &mut self.sample_draws
151 };
152
153 if let Some(hash_value) = target_map.get_mut(name) {
154 hash_value.push(value);
155 } else {
156 panic!("Unknown posterior variable name: {}", name);
157 }
158 Ok(())
159 }
160}
161
162impl ChainStorage for HashMapChainStorage {
163 type Finalized = HashMapResult;
164
165 fn record_sample(
166 &mut self,
167 _settings: &impl Settings,
168 stats: Vec<(&str, Option<Value>)>,
169 draws: Vec<(&str, Option<Value>)>,
170 info: &Progress,
171 ) -> Result<()> {
172 let is_first_draw = self.last_sample_was_warmup && !info.tuning;
173 if is_first_draw {
174 self.last_sample_was_warmup = false;
175 }
176
177 for (name, value) in stats {
178 if let Some(value) = value {
179 self.push_param(name, value, info.tuning)?;
180 }
181 }
182 for (name, value) in draws {
183 if let Some(value) = value {
184 self.push_draw(name, value, info.tuning)?;
185 } else {
186 panic!("Missing draw value for {}", name);
187 }
188 }
189 Ok(())
190 }
191
192 fn finalize(self) -> Result<Self::Finalized> {
194 let mut combined_stats = HashMap::new();
196 let mut combined_draws = HashMap::new();
197
198 for (key, warmup_values) in self.warmup_stats {
200 let sample_values = &self.sample_stats[&key];
201 let mut combined = warmup_values.clone();
202
203 match (&mut combined, sample_values) {
204 (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
205 combined_vec.extend(sample_vec.iter().cloned());
206 }
207 (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
208 combined_vec.extend(sample_vec.iter().cloned());
209 }
210 (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
211 combined_vec.extend(sample_vec.iter().cloned());
212 }
213 (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
214 combined_vec.extend(sample_vec.iter().cloned());
215 }
216 (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
217 combined_vec.extend(sample_vec.iter().cloned());
218 }
219 _ => panic!("Type mismatch when combining stats for {}", key),
220 }
221
222 combined_stats.insert(key, combined);
223 }
224
225 for (key, warmup_values) in self.warmup_draws {
227 let sample_values = &self.sample_draws[&key];
228 let mut combined = warmup_values.clone();
229
230 match (&mut combined, sample_values) {
231 (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
232 combined_vec.extend(sample_vec.iter().cloned());
233 }
234 (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
235 combined_vec.extend(sample_vec.iter().cloned());
236 }
237 (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
238 combined_vec.extend(sample_vec.iter().cloned());
239 }
240 (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
241 combined_vec.extend(sample_vec.iter().cloned());
242 }
243 (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
244 combined_vec.extend(sample_vec.iter().cloned());
245 }
246 _ => panic!("Type mismatch when combining draws for {}", key),
247 }
248
249 combined_draws.insert(key, combined);
250 }
251
252 Ok(HashMapResult {
253 stats: combined_stats,
254 draws: combined_draws,
255 })
256 }
257
258 fn flush(&self) -> Result<()> {
260 Ok(())
261 }
262
263 fn inspect(&self) -> Result<Option<Self::Finalized>> {
264 self.clone().finalize().map(Some)
265 }
266}
267
/// Configuration for the in-memory, hash-map backed trace storage.
///
/// It currently has no options. Construct it with [`HashMapConfig::new`] or
/// [`Default::default`].
#[derive(Debug, Clone)]
pub struct HashMapConfig {}

impl Default for HashMapConfig {
    fn default() -> Self {
        Self::new()
    }
}

impl HashMapConfig {
    /// Creates the (option-less) in-memory storage configuration.
    #[must_use]
    pub fn new() -> Self {
        Self {}
    }
}
281
282impl StorageConfig for HashMapConfig {
283 type Storage = HashMapTraceStorage;
284
285 fn new_trace<M: crate::Math>(
286 self,
287 settings: &impl Settings,
288 math: &M,
289 ) -> Result<Self::Storage> {
290 Ok(HashMapTraceStorage {
291 param_types: settings.stat_types(math),
292 draw_types: settings.data_types(math),
293 })
294 }
295}
296
297impl TraceStorage for HashMapTraceStorage {
298 type ChainStorage = HashMapChainStorage;
299
300 type Finalized = Vec<HashMapResult>;
301
302 fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result<Self::ChainStorage> {
303 Ok(HashMapChainStorage::new(
304 &self.param_types,
305 &self.draw_types,
306 ))
307 }
308
309 fn finalize(
310 self,
311 traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
312 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
313 let mut results = Vec::new();
314 let mut first_error = None;
315
316 for trace in traces {
317 match trace {
318 Ok(result) => results.push(result),
319 Err(err) => {
320 if first_error.is_none() {
321 first_error = Some(err);
322 }
323 }
324 }
325 }
326
327 Ok((first_error, results))
328 }
329
330 fn inspect(
331 &self,
332 traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
333 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
334 self.clone()
335 .finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect())
336 }
337}