1use anyhow::Result;
2use nuts_storable::{ItemType, Value};
3use std::collections::HashMap;
4
5use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
6use crate::{Progress, Settings};
7
/// In-memory column of values for a single stat or draw variable.
///
/// The variant is picked from the variable's `ItemType` when the buffer is
/// created. Scalar values are appended one element at a time; vector values are
/// flattened onto the end of the same `Vec`.
#[derive(Debug, Clone)]
pub enum HashMapValue {
    /// 64-bit floating point values.
    F64(Vec<f64>),
    /// 32-bit floating point values.
    F32(Vec<f32>),
    /// Boolean values.
    Bool(Vec<bool>),
    /// Signed 64-bit integer values.
    I64(Vec<i64>),
    /// Unsigned 64-bit integer values.
    U64(Vec<u64>),
    /// String values.
    String(Vec<String>),
}
18
19impl HashMapValue {
20 fn new(item_type: ItemType) -> Self {
22 match item_type {
23 ItemType::F64 => HashMapValue::F64(Vec::new()),
24 ItemType::F32 => HashMapValue::F32(Vec::new()),
25 ItemType::Bool => HashMapValue::Bool(Vec::new()),
26 ItemType::I64 => HashMapValue::I64(Vec::new()),
27 ItemType::U64 => HashMapValue::U64(Vec::new()),
28 ItemType::String => HashMapValue::String(Vec::new()),
29 }
30 }
31
32 fn push(&mut self, value: Value) {
34 match (self, value) {
35 (HashMapValue::F64(vec), Value::ScalarF64(v)) => vec.push(v),
37 (HashMapValue::F32(vec), Value::ScalarF32(v)) => vec.push(v),
38 (HashMapValue::U64(vec), Value::ScalarU64(v)) => vec.push(v),
39 (HashMapValue::Bool(vec), Value::ScalarBool(v)) => vec.push(v),
40 (HashMapValue::I64(vec), Value::ScalarI64(v)) => vec.push(v),
41
42 (HashMapValue::F64(vec), Value::F64(v)) => vec.extend(v),
43 (HashMapValue::F32(vec), Value::F32(v)) => vec.extend(v),
44 (HashMapValue::U64(vec), Value::U64(v)) => vec.extend(v),
45 (HashMapValue::Bool(vec), Value::Bool(v)) => vec.extend(v),
46 (HashMapValue::I64(vec), Value::I64(v)) => vec.extend(v),
47
48 _ => panic!("Mismatched item type"),
49 }
50 }
51}
52
53#[derive(Clone)]
55pub struct HashMapTraceStorage {
56 draw_types: Vec<(String, ItemType)>,
57 param_types: Vec<(String, ItemType)>,
58}
59
60#[derive(Clone)]
62pub struct HashMapChainStorage {
63 warmup_stats: HashMap<String, HashMapValue>,
64 sample_stats: HashMap<String, HashMapValue>,
65 warmup_draws: HashMap<String, HashMapValue>,
66 sample_draws: HashMap<String, HashMapValue>,
67 last_sample_was_warmup: bool,
68}
69
70#[derive(Debug, Clone)]
72pub struct HashMapResult {
73 pub stats: HashMap<String, HashMapValue>,
75 pub draws: HashMap<String, HashMapValue>,
77}
78
79impl HashMapChainStorage {
80 fn new(param_types: &Vec<(String, ItemType)>, draw_types: &Vec<(String, ItemType)>) -> Self {
82 let warmup_stats = param_types
83 .iter()
84 .cloned()
85 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
86 .collect();
87
88 let sample_stats = param_types
89 .iter()
90 .cloned()
91 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
92 .collect();
93
94 let warmup_draws = draw_types
95 .iter()
96 .cloned()
97 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
98 .collect();
99
100 let sample_draws = draw_types
101 .iter()
102 .cloned()
103 .map(|(name, item_type)| (name, HashMapValue::new(item_type)))
104 .collect();
105
106 Self {
107 warmup_stats,
108 sample_stats,
109 warmup_draws,
110 sample_draws,
111 last_sample_was_warmup: true,
112 }
113 }
114
115 fn push_param(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
117 if ["draw", "chain"].contains(&name) {
118 return Ok(());
119 }
120
121 let target_map = if is_warmup {
122 &mut self.warmup_stats
123 } else {
124 &mut self.sample_stats
125 };
126
127 if let Some(hash_value) = target_map.get_mut(name) {
128 hash_value.push(value);
129 } else {
130 panic!("Unknown param name: {}", name);
131 }
132 Ok(())
133 }
134
135 fn push_draw(&mut self, name: &str, value: Value, is_warmup: bool) -> Result<()> {
137 if ["draw", "chain"].contains(&name) {
138 return Ok(());
139 }
140
141 let target_map = if is_warmup {
142 &mut self.warmup_draws
143 } else {
144 &mut self.sample_draws
145 };
146
147 if let Some(hash_value) = target_map.get_mut(name) {
148 hash_value.push(value);
149 } else {
150 panic!("Unknown posterior variable name: {}", name);
151 }
152 Ok(())
153 }
154}
155
156impl ChainStorage for HashMapChainStorage {
157 type Finalized = HashMapResult;
158
159 fn record_sample(
160 &mut self,
161 _settings: &impl Settings,
162 stats: Vec<(&str, Option<Value>)>,
163 draws: Vec<(&str, Option<Value>)>,
164 info: &Progress,
165 ) -> Result<()> {
166 let is_first_draw = self.last_sample_was_warmup && !info.tuning;
167 if is_first_draw {
168 self.last_sample_was_warmup = false;
169 }
170
171 for (name, value) in stats {
172 if let Some(value) = value {
173 self.push_param(name, value, info.tuning)?;
174 }
175 }
176 for (name, value) in draws {
177 if let Some(value) = value {
178 self.push_draw(name, value, info.tuning)?;
179 } else {
180 panic!("Missing draw value for {}", name);
181 }
182 }
183 Ok(())
184 }
185
186 fn finalize(self) -> Result<Self::Finalized> {
188 let mut combined_stats = HashMap::new();
190 let mut combined_draws = HashMap::new();
191
192 for (key, warmup_values) in self.warmup_stats {
194 let sample_values = &self.sample_stats[&key];
195 let mut combined = warmup_values.clone();
196
197 match (&mut combined, sample_values) {
198 (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
199 combined_vec.extend(sample_vec.iter().cloned());
200 }
201 (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
202 combined_vec.extend(sample_vec.iter().cloned());
203 }
204 (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
205 combined_vec.extend(sample_vec.iter().cloned());
206 }
207 (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
208 combined_vec.extend(sample_vec.iter().cloned());
209 }
210 (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
211 combined_vec.extend(sample_vec.iter().cloned());
212 }
213 _ => panic!("Type mismatch when combining stats for {}", key),
214 }
215
216 combined_stats.insert(key, combined);
217 }
218
219 for (key, warmup_values) in self.warmup_draws {
221 let sample_values = &self.sample_draws[&key];
222 let mut combined = warmup_values.clone();
223
224 match (&mut combined, sample_values) {
225 (HashMapValue::F64(combined_vec), HashMapValue::F64(sample_vec)) => {
226 combined_vec.extend(sample_vec.iter().cloned());
227 }
228 (HashMapValue::F32(combined_vec), HashMapValue::F32(sample_vec)) => {
229 combined_vec.extend(sample_vec.iter().cloned());
230 }
231 (HashMapValue::Bool(combined_vec), HashMapValue::Bool(sample_vec)) => {
232 combined_vec.extend(sample_vec.iter().cloned());
233 }
234 (HashMapValue::I64(combined_vec), HashMapValue::I64(sample_vec)) => {
235 combined_vec.extend(sample_vec.iter().cloned());
236 }
237 (HashMapValue::U64(combined_vec), HashMapValue::U64(sample_vec)) => {
238 combined_vec.extend(sample_vec.iter().cloned());
239 }
240 _ => panic!("Type mismatch when combining draws for {}", key),
241 }
242
243 combined_draws.insert(key, combined);
244 }
245
246 Ok(HashMapResult {
247 stats: combined_stats,
248 draws: combined_draws,
249 })
250 }
251
252 fn flush(&self) -> Result<()> {
254 Ok(())
255 }
256
257 fn inspect(&self) -> Result<Option<Self::Finalized>> {
258 self.clone().finalize().map(Some)
259 }
260}
261
/// Configuration for the in-memory hash-map storage backend.
///
/// The backend has no options, so this type carries no fields.
#[derive(Default)]
pub struct HashMapConfig {}

impl HashMapConfig {
    /// Creates the (option-free) configuration; same as `HashMapConfig::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
275
276impl StorageConfig for HashMapConfig {
277 type Storage = HashMapTraceStorage;
278
279 fn new_trace<M: crate::Math>(
280 self,
281 settings: &impl Settings,
282 math: &M,
283 ) -> Result<Self::Storage> {
284 Ok(HashMapTraceStorage {
285 param_types: settings.stat_types(math),
286 draw_types: settings.data_types(math),
287 })
288 }
289}
290
291impl TraceStorage for HashMapTraceStorage {
292 type ChainStorage = HashMapChainStorage;
293
294 type Finalized = Vec<HashMapResult>;
295
296 fn initialize_trace_for_chain(&self, _chain_id: u64) -> Result<Self::ChainStorage> {
297 Ok(HashMapChainStorage::new(
298 &self.param_types,
299 &self.draw_types,
300 ))
301 }
302
303 fn finalize(
304 self,
305 traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
306 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
307 let mut results = Vec::new();
308 let mut first_error = None;
309
310 for trace in traces {
311 match trace {
312 Ok(result) => results.push(result),
313 Err(err) => {
314 if first_error.is_none() {
315 first_error = Some(err);
316 }
317 }
318 }
319 }
320
321 Ok((first_error, results))
322 }
323
324 fn inspect(
325 &self,
326 traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
327 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
328 self.clone()
329 .finalize(traces.into_iter().map(|r| r.map(|o| o.unwrap())).collect())
330 }
331}