1use std::collections::HashMap;
8use std::fs::File;
9use std::io::{BufWriter, Write};
10use std::path::{Path, PathBuf};
11
12use anyhow::{Context, Result};
13use nuts_storable::{ItemType, Value};
14
15use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
16use crate::{Math, Progress, Settings};
17
/// Configuration for writing sampler output as CSV files, one file per chain.
///
/// Construct it with [`CsvConfig::new`] and adjust it with the builder methods.
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// Directory in which the per-chain `chain_{id}.csv` files are created.
    output_dir: PathBuf,
    /// Number of digits written after the decimal point for floating-point values.
    precision: usize,
    /// Whether draws produced while the sampler is tuning (warmup) are written.
    store_warmup: bool,
}
37
38impl CsvConfig {
39 pub fn new<P: AsRef<Path>>(output_dir: P) -> Self {
52 Self {
53 output_dir: output_dir.as_ref().to_path_buf(),
54 precision: 6,
55 store_warmup: true,
56 }
57 }
58
59 pub fn with_precision(mut self, precision: usize) -> Self {
63 self.precision = precision;
64 self
65 }
66
67 pub fn store_warmup(mut self, store: bool) -> Self {
72 self.store_warmup = store;
73 self
74 }
75}
76
/// Trace-level CSV storage: holds the output settings and the column layout shared by
/// every chain, and opens one `CsvChainStorage` per chain.
#[derive(Debug)]
pub struct CsvTraceStorage {
    /// Directory in which the per-chain CSV files are created.
    output_dir: PathBuf,
    /// Digits after the decimal point for floating-point values.
    precision: usize,
    /// Whether draws produced during tuning are written.
    store_warmup: bool,
    /// Header names of the per-parameter columns, in output order.
    parameter_names: Vec<String>,
    /// For each entry of `parameter_names`: the draw variable to read and the flat index
    /// into that variable's values.
    column_mapping: Vec<(String, usize)>,
}
85
/// Per-chain CSV writer: owns the open `chain_{id}.csv` file and the column layout.
#[derive(Debug)]
pub struct CsvChainStorage {
    /// Buffered handle to the chain's CSV file.
    writer: BufWriter<File>,
    /// Digits after the decimal point for floating-point values.
    precision: usize,
    /// Whether draws produced during tuning are written.
    store_warmup: bool,
    /// Header names of the per-parameter columns, in output order.
    parameter_names: Vec<String>,
    /// For each entry of `parameter_names`: the draw variable to read and the flat index
    /// into that variable's values.
    column_mapping: Vec<(String, usize)>,
    /// True until the first stored draw has been handled.
    is_first_sample: bool,
    /// Set once the header row has been written.
    headers_written: bool,
}
96
97impl CsvChainStorage {
98 fn new(
100 output_dir: &Path,
101 chain_id: u64,
102 precision: usize,
103 store_warmup: bool,
104 parameter_names: Vec<String>,
105 column_mapping: Vec<(String, usize)>,
106 ) -> Result<Self> {
107 std::fs::create_dir_all(output_dir)
108 .with_context(|| format!("Failed to create output directory: {:?}", output_dir))?;
109
110 let file_path = output_dir.join(format!("chain_{}.csv", chain_id));
111 let file = File::create(&file_path)
112 .with_context(|| format!("Failed to create CSV file: {:?}", file_path))?;
113 let writer = BufWriter::new(file);
114
115 Ok(Self {
116 writer,
117 precision,
118 store_warmup,
119 parameter_names,
120 column_mapping,
121 is_first_sample: true,
122 headers_written: false,
123 })
124 }
125
126 fn write_header(&mut self) -> Result<()> {
128 if self.headers_written {
129 return Ok(());
130 }
131
132 let mut headers = vec![
134 "lp__".to_string(),
135 "accept_stat__".to_string(),
136 "stepsize__".to_string(),
137 "treedepth__".to_string(),
138 "n_leapfrog__".to_string(),
139 "divergent__".to_string(),
140 "energy__".to_string(),
141 ];
142
143 for param_name in &self.parameter_names {
145 headers.push(param_name.clone());
146 }
147
148 writeln!(self.writer, "{}", headers.join(","))?;
150 self.headers_written = true;
151 Ok(())
152 }
153
154 fn format_value(&self, value: &Value) -> String {
156 match value {
157 Value::ScalarF64(v) => {
158 if v.is_nan() {
159 "NA".to_string()
160 } else if v.is_infinite() {
161 if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
162 } else {
163 format!("{:.prec$}", v, prec = self.precision)
164 }
165 }
166 Value::ScalarF32(v) => {
167 if v.is_nan() {
168 "NA".to_string()
169 } else if v.is_infinite() {
170 if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
171 } else {
172 format!("{:.prec$}", v, prec = self.precision)
173 }
174 }
175 Value::ScalarU64(v) => v.to_string(),
176 Value::ScalarI64(v) => v.to_string(),
177 Value::ScalarBool(v) => if *v { "1" } else { "0" }.to_string(),
178 Value::F64(vec) => {
179 if vec.is_empty() {
182 "NA".to_string()
183 } else {
184 self.format_value(&Value::ScalarF64(vec[0]))
185 }
186 }
187 Value::F32(vec) => {
188 if vec.is_empty() {
189 "NA".to_string()
190 } else {
191 self.format_value(&Value::ScalarF32(vec[0]))
192 }
193 }
194 Value::U64(vec) => {
195 if vec.is_empty() {
196 "NA".to_string()
197 } else {
198 vec[0].to_string()
199 }
200 }
201 Value::I64(vec) => {
202 if vec.is_empty() {
203 "NA".to_string()
204 } else {
205 vec[0].to_string()
206 }
207 }
208 Value::Bool(vec) => {
209 if vec.is_empty() {
210 "NA".to_string()
211 } else {
212 if vec[0] { "1" } else { "0" }.to_string()
213 }
214 }
215 Value::ScalarString(v) => v.clone(),
216 Value::Strings(vec) => {
217 if vec.is_empty() {
218 "NA".to_string()
219 } else {
220 vec[0].clone()
221 }
222 }
223 Value::DateTime64(_, _) => panic!("DateTime64 not supported in CSV output"),
224 Value::TimeDelta64(_, _) => panic!("TimeDelta64 not supported in CSV output"),
225 }
226 }
227
228 fn write_sample_row(
230 &mut self,
231 stats: &Vec<(&str, Option<Value>)>,
232 draws: &Vec<(&str, Option<Value>)>,
233 _info: &Progress,
234 ) -> Result<()> {
235 let mut row_values = Vec::new();
236
237 let stats_map: HashMap<&str, &Option<Value>> = stats.iter().map(|(k, v)| (*k, v)).collect();
239 let draws_map: HashMap<&str, &Option<Value>> = draws.iter().map(|(k, v)| (*k, v)).collect();
240
241 let get_stat_value = |name: &str| -> String {
243 stats_map
244 .get(name)
245 .and_then(|opt| opt.as_ref())
246 .map(|v| self.format_value(v))
247 .unwrap_or_else(|| "NA".to_string())
248 };
249
250 row_values.push(get_stat_value("logp"));
251 row_values.push(get_stat_value("mean_tree_accept"));
252 row_values.push(get_stat_value("step_size"));
253 row_values.push(get_stat_value("depth"));
254 row_values.push(get_stat_value("n_steps"));
255 let divergent_val = stats_map
256 .get("diverging")
257 .and_then(|opt| opt.as_ref())
258 .map(|v| match v {
259 Value::ScalarBool(true) => "1".to_string(),
260 Value::ScalarBool(false) => "0".to_string(),
261 _ => "0".to_string(),
262 })
263 .unwrap_or_else(|| "0".to_string());
264 row_values.push(divergent_val);
265
266 row_values.push(get_stat_value("energy"));
267
268 for (_param_name, (data_name, index)) in
270 self.parameter_names.iter().zip(&self.column_mapping)
271 {
272 if let Some(Some(data_value)) = draws_map.get(data_name.as_str()) {
273 let formatted_value = match data_value {
274 Value::F64(vec) => {
275 if *index < vec.len() {
276 self.format_value(&Value::ScalarF64(vec[*index]))
277 } else {
278 "NA".to_string()
279 }
280 }
281 Value::F32(vec) => {
282 if *index < vec.len() {
283 self.format_value(&Value::ScalarF32(vec[*index]))
284 } else {
285 "NA".to_string()
286 }
287 }
288 Value::I64(vec) => {
289 if *index < vec.len() {
290 self.format_value(&Value::ScalarI64(vec[*index]))
291 } else {
292 "NA".to_string()
293 }
294 }
295 Value::U64(vec) => {
296 if *index < vec.len() {
297 self.format_value(&Value::ScalarU64(vec[*index]))
298 } else {
299 "NA".to_string()
300 }
301 }
302 scalar_val if *index == 0 => self.format_value(scalar_val),
304 _ => "NA".to_string(),
305 };
306 row_values.push(formatted_value);
307 } else {
308 row_values.push("NA".to_string());
309 }
310 }
311
312 writeln!(self.writer, "{}", row_values.join(","))?;
314 Ok(())
315 }
316}
317
318impl ChainStorage for CsvChainStorage {
319 type Finalized = ();
320
321 fn record_sample(
322 &mut self,
323 _settings: &impl Settings,
324 stats: Vec<(&str, Option<Value>)>,
325 draws: Vec<(&str, Option<Value>)>,
326 info: &Progress,
327 ) -> Result<()> {
328 if info.tuning && !self.store_warmup {
330 return Ok(());
331 }
332
333 if self.is_first_sample {
335 self.write_header()?;
336 self.is_first_sample = false;
337 }
338
339 self.write_sample_row(&stats, &draws, info)?;
340 Ok(())
341 }
342
343 fn finalize(mut self) -> Result<Self::Finalized> {
344 self.writer.flush().context("Failed to flush CSV file")?;
345 Ok(())
346 }
347
348 fn flush(&self) -> Result<()> {
349 Ok(())
352 }
353
354 fn inspect(&self) -> Result<Option<Self::Finalized>> {
355 self.flush()?;
357 Ok(None)
358 }
359}
360
361impl StorageConfig for CsvConfig {
362 type Storage = CsvTraceStorage;
363
364 fn new_trace<M: Math>(self, settings: &impl Settings, math: &M) -> Result<Self::Storage> {
365 let (parameter_names, column_mapping) =
367 generate_parameter_names_and_mapping(settings, math)?;
368
369 Ok(CsvTraceStorage {
370 output_dir: self.output_dir,
371 precision: self.precision,
372 store_warmup: self.store_warmup,
373 parameter_names,
374 column_mapping,
375 })
376 }
377}
378
379fn generate_parameter_names_and_mapping<M: Math>(
381 settings: &impl Settings,
382 math: &M,
383) -> Result<(Vec<String>, Vec<(String, usize)>)> {
384 let data_dims = settings.data_dims_all(math);
385 let coords = math.coords();
386 let mut parameter_names = Vec::new();
387 let mut column_mapping = Vec::new();
388
389 for (var_name, var_dims) in data_dims {
390 let data_type = settings.data_type(math, &var_name);
391
392 if matches!(
394 data_type,
395 ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
396 ) {
397 let (column_names, indices) = generate_column_names_and_indices_for_variable(
398 &var_name, &var_dims, &coords, math,
399 )?;
400
401 for (name, index) in column_names.into_iter().zip(indices) {
402 parameter_names.push(name);
403 column_mapping.push((var_name.clone(), index));
404 }
405 }
406 }
407
408 if parameter_names.is_empty() {
410 let dim_sizes = math.dim_sizes();
411 let param_count = dim_sizes.get("expanded_parameter").unwrap_or(&0);
412 for i in 0..*param_count {
413 parameter_names.push(format!("param_{}", i + 1));
414 let data_names = settings.data_names(math);
416 let mut found_field = false;
417 for data_name in &data_names {
418 let data_type = settings.data_type(math, data_name);
419 if matches!(
420 data_type,
421 ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
422 ) {
423 column_mapping.push((data_name.clone(), i as usize));
424 found_field = true;
425 break;
426 }
427 }
428 if !found_field {
429 column_mapping.push(("unknown".to_string(), i as usize));
430 }
431 }
432 }
433
434 Ok((parameter_names, column_mapping))
435}
436
437fn generate_column_names_and_indices_for_variable<M: Math>(
439 var_name: &str,
440 var_dims: &[String],
441 coords: &HashMap<String, Value>,
442 math: &M,
443) -> Result<(Vec<String>, Vec<usize>)> {
444 let dim_sizes = math.dim_sizes();
445
446 if var_dims.is_empty() {
447 return Ok((vec![var_name.to_string()], vec![0]));
449 }
450
451 let has_meaningful_coords = var_dims.iter().all(|dim_name| {
453 coords.get(dim_name).is_some_and(
454 |coord_value| matches!(coord_value, Value::Strings(labels) if !labels.is_empty()),
455 )
456 });
457
458 let mut dim_coords: Vec<Vec<String>> = Vec::new();
460 let mut dim_sizes_vec: Vec<usize> = Vec::new();
461
462 for dim_name in var_dims {
463 let size = *dim_sizes.get(dim_name).unwrap_or(&1) as usize;
464 dim_sizes_vec.push(size);
465
466 if has_meaningful_coords {
467 if let Some(coord_value) = coords.get(dim_name) {
469 match coord_value {
470 Value::Strings(labels) => {
471 dim_coords.push(labels.clone());
472 }
473 _ => {
474 dim_coords.push((1..=size).map(|i| i.to_string()).collect());
476 }
477 }
478 } else {
479 dim_coords.push((1..=size).map(|i| i.to_string()).collect());
481 }
482 } else {
483 dim_coords.push((1..=size).map(|i| i.to_string()).collect());
485 }
486 }
487
488 let (coord_names, indices) =
490 cartesian_product_with_indices_column_major(&dim_coords, &dim_sizes_vec);
491
492 let column_names: Vec<String> = coord_names
494 .into_iter()
495 .map(|coord| format!("{}.{}", var_name, coord))
496 .collect();
497
498 Ok((column_names, indices))
499}
500
501fn cartesian_product_with_indices_column_major(
507 coord_sets: &[Vec<String>],
508 dim_sizes: &[usize],
509) -> (Vec<String>, Vec<usize>) {
510 if coord_sets.is_empty() {
511 return (vec![], vec![]);
512 }
513
514 if coord_sets.len() == 1 {
515 let indices: Vec<usize> = (0..coord_sets[0].len()).collect();
516 return (coord_sets[0].clone(), indices);
517 }
518
519 let mut names = vec![];
520 let mut indices = vec![];
521
522 cartesian_product_recursive_with_indices(
524 coord_sets,
525 dim_sizes,
526 0,
527 &mut String::new(),
528 &mut vec![],
529 &mut names,
530 &mut indices,
531 );
532
533 (names, indices)
534}
535
/// Depth-first helper for `cartesian_product_with_indices_column_major`.
///
/// At depth `dim_idx` it appends each label of `coord_sets[dim_idx]` to `current_name`
/// (dot-separated after the first dimension) and its position to `current_indices`, then
/// recurses. At the leaves it records the accumulated name and its row-major linear index,
/// using `dim_sizes` for the strides.
///
/// `current_name` and `current_indices` are used as scratch buffers. Both are restored to
/// their incoming contents before returning.
///
/// # Panics
/// Panics if `dim_sizes` is shorter than `coord_sets`.
fn cartesian_product_recursive_with_indices(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
    dim_idx: usize,
    current_name: &mut String,
    current_indices: &mut Vec<usize>,
    result_names: &mut Vec<String>,
    result_indices: &mut Vec<usize>,
) {
    if dim_idx == coord_sets.len() {
        result_names.push(current_name.clone());
        // Row-major: each coordinate is scaled by the product of all later dimension sizes.
        let linear_index: usize = current_indices
            .iter()
            .enumerate()
            .map(|(axis, &idx)| idx * dim_sizes[axis + 1..].iter().product::<usize>())
            .sum();
        result_indices.push(linear_index);
        return;
    }

    // Build names in place and truncate back afterwards instead of cloning per node.
    let prefix_len = current_name.len();
    for (coord_idx, coord) in coord_sets[dim_idx].iter().enumerate() {
        if dim_idx > 0 {
            current_name.push('.');
        }
        current_name.push_str(coord);
        current_indices.push(coord_idx);

        cartesian_product_recursive_with_indices(
            coord_sets,
            dim_sizes,
            dim_idx + 1,
            current_name,
            current_indices,
            result_names,
            result_indices,
        );

        current_indices.pop();
        current_name.truncate(prefix_len);
    }
}
582
583impl TraceStorage for CsvTraceStorage {
584 type ChainStorage = CsvChainStorage;
585 type Finalized = ();
586
587 fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage> {
588 CsvChainStorage::new(
589 &self.output_dir,
590 chain_id,
591 self.precision,
592 self.store_warmup,
593 self.parameter_names.clone(),
594 self.column_mapping.clone(),
595 )
596 }
597
598 fn finalize(
599 self,
600 traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
601 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
602 for trace_result in traces {
604 if let Err(err) = trace_result {
605 return Ok((Some(err), ()));
606 }
607 }
608 Ok((None, ()))
609 }
610
611 fn inspect(
612 &self,
613 traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
614 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
615 for trace_result in traces {
617 if let Err(err) = trace_result {
618 return Ok((Some(err), ()));
619 }
620 }
621 Ok((None, ()))
622 }
623}
624
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler,
        SamplerWaitResult,
    };
    use anyhow::Result;
    use nuts_derive::Storable;
    use nuts_storable::{HasDims, Value};
    use rand::Rng;
    use std::collections::HashMap;
    use std::fs;
    use std::path::Path;
    use std::time::Duration;
    use thiserror::Error;

    /// Builds an owned `Vec<String>` from string literals.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Error type required by `CpuLogpFunc`; the test densities below never return it.
    // The single variant is never constructed, hence the lint allowance.
    #[allow(dead_code)]
    #[derive(Debug, Error)]
    enum TestLogpError {
        #[error("Test error")]
        Test,
    }

    impl LogpError for TestLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    /// Standard-normal density whose draws are exposed as a `dim_a x dim_b` matrix
    /// with string coordinates on both axes.
    #[derive(Clone)]
    struct MultiDimTestLogp {
        dim_a: usize,
        dim_b: usize,
    }

    impl HasDims for MultiDimTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            let mut sizes = HashMap::new();
            sizes.insert("a".to_string(), self.dim_a as u64);
            sizes.insert("b".to_string(), self.dim_b as u64);
            sizes
        }

        fn coords(&self) -> HashMap<String, Value> {
            let mut coords = HashMap::new();
            coords.insert("a".to_string(), Value::Strings(strings(&["x", "y"])));
            coords.insert("b".to_string(), Value::Strings(strings(&["alpha", "beta"])));
            coords
        }
    }

    #[derive(Storable)]
    struct MultiDimExpandedDraw {
        #[storable(dims("a", "b"))]
        param_matrix: Vec<f64>,
        scalar_value: f64,
    }

    impl CpuLogpFunc for MultiDimTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = MultiDimExpandedDraw;

        fn dim(&self) -> usize {
            self.dim_a * self.dim_b
        }

        /// Independent standard normals: `logp = -sum(x^2) / 2`, `grad = -x`.
        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let logp = x.iter().zip(grad.iter_mut()).fold(0.0, |acc, (&xi, g)| {
                *g = -xi;
                acc - 0.5 * xi * xi
            });
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            let scalar_value: f64 = array.iter().sum();
            Ok(MultiDimExpandedDraw {
                param_matrix: array.to_vec(),
                scalar_value,
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            let names = (1..=self.dim()).map(|i| format!("theta{}", i)).collect();
            Some(Value::Strings(names))
        }
    }

    struct MultiDimTestModel {
        math: CpuMath<MultiDimTestLogp>,
    }

    impl Model for MultiDimTestModel {
        type Math<'model>
            = CpuMath<MultiDimTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            position
                .iter_mut()
                .for_each(|p| *p = rng.random_range(-1.0..1.0));
            Ok(())
        }
    }

    /// Standard-normal density over a vector dimension that has no string coordinates,
    /// so CSV column names fall back to 1-based numbers.
    #[derive(Clone)]
    struct SimpleTestLogp {
        dim: usize,
    }

    impl HasDims for SimpleTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            let mut sizes = HashMap::new();
            sizes.insert("simple_param".to_string(), self.dim as u64);
            sizes
        }
    }

    #[derive(Storable)]
    struct SimpleExpandedDraw {
        #[storable(dims("simple_param"))]
        values: Vec<f64>,
    }

    impl CpuLogpFunc for SimpleTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = SimpleExpandedDraw;

        fn dim(&self) -> usize {
            self.dim
        }

        /// Independent standard normals: `logp = -sum(x^2) / 2`, `grad = -x`.
        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let logp = x.iter().zip(grad.iter_mut()).fold(0.0, |acc, (&xi, g)| {
                *g = -xi;
                acc - 0.5 * xi * xi
            });
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            let values = array.to_vec();
            Ok(SimpleExpandedDraw { values })
        }

        fn vector_coord(&self) -> Option<Value> {
            Some(Value::Strings(strings(&["param1", "param2", "param3"])))
        }
    }

    struct SimpleTestModel {
        math: CpuMath<SimpleTestLogp>,
    }

    impl Model for SimpleTestModel {
        type Math<'model>
            = CpuMath<SimpleTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            position
                .iter_mut()
                .for_each(|p| *p = rng.random_range(-1.0..1.0));
            Ok(())
        }
    }

    /// Returns the first line (the header row) of the CSV file at `path`.
    fn read_csv_header(path: &Path) -> Result<String> {
        let content = fs::read_to_string(path)?;
        match content.lines().next() {
            Some(header) => Ok(header.to_owned()),
            None => Err(anyhow::anyhow!("Empty CSV file")),
        }
    }

    #[test]
    fn test_multidim_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("multidim_test");

        let model = MultiDimTestModel {
            math: CpuMath::new(MultiDimTestLogp { dim_a: 2, dim_b: 2 }),
        };

        let mut settings = DiagGradNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 10;
        settings.num_draws = 20;
        settings.seed = 42;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        // `wait_timeout` consumes the sampler and hands it back on timeout.
        let mut sampler = Sampler::new(model, settings, csv_config, 1, None)?;
        loop {
            match sampler.wait_timeout(Duration::from_millis(100)) {
                SamplerWaitResult::Trace(_) => break,
                SamplerWaitResult::Timeout(resumed) => sampler = resumed,
                SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        let header = read_csv_header(&csv_file)?;
        for expected in [
            "param_matrix.x.alpha",
            "param_matrix.x.beta",
            "param_matrix.y.alpha",
            "param_matrix.y.beta",
            "scalar_value",
        ] {
            assert!(header.contains(expected), "header is missing {}", expected);
        }

        // The matrix columns must appear in row-major order of the coordinates.
        let param_columns: Vec<&str> = header
            .split(',')
            .filter(|column| column.starts_with("param_matrix."))
            .collect();
        assert_eq!(
            param_columns,
            [
                "param_matrix.x.alpha",
                "param_matrix.x.beta",
                "param_matrix.y.alpha",
                "param_matrix.y.beta",
            ]
        );

        Ok(())
    }

    #[test]
    fn test_fallback_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("simple_test");

        let model = SimpleTestModel {
            math: CpuMath::new(SimpleTestLogp { dim: 3 }),
        };

        let mut settings = DiagGradNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 5;
        settings.num_draws = 10;
        settings.seed = 123;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        // `wait_timeout` consumes the sampler and hands it back on timeout.
        let mut sampler = Sampler::new(model, settings, csv_config, 1, None)?;
        loop {
            match sampler.wait_timeout(Duration::from_millis(100)) {
                SamplerWaitResult::Trace(_) => break,
                SamplerWaitResult::Timeout(resumed) => sampler = resumed,
                SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        // Without string coordinates the columns are numbered from 1.
        let header = read_csv_header(&csv_file)?;
        for expected in ["values.1", "values.2", "values.3"] {
            assert!(header.contains(expected), "header is missing {}", expected);
        }

        Ok(())
    }

    #[test]
    fn test_cartesian_product_generation() {
        let coord_sets = vec![strings(&["x", "y"]), strings(&["alpha", "beta"])];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &[2, 2]);

        assert_eq!(names, ["x.alpha", "x.beta", "y.alpha", "y.beta"]);
        assert_eq!(indices, [0, 1, 2, 3]);
    }

    #[test]
    fn test_single_dimension_coordinates() {
        let coord_sets = vec![strings(&["param1", "param2"])];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &[2]);

        assert_eq!(names, ["param1", "param2"]);
        assert_eq!(indices, [0, 1]);
    }

    #[test]
    fn test_three_dimension_cartesian_product() {
        let coord_sets = vec![strings(&["a", "b"]), strings(&["1"]), strings(&["i", "j"])];

        let (names, indices) =
            cartesian_product_with_indices_column_major(&coord_sets, &[2, 1, 2]);

        assert_eq!(names, ["a.1.i", "a.1.j", "b.1.i", "b.1.j"]);
        assert_eq!(indices, [0, 1, 2, 3]);
    }
}