1use std::collections::HashMap;
4use std::fs::File;
5use std::io::{BufWriter, Write};
6use std::path::{Path, PathBuf};
7
8use anyhow::{Context, Result};
9use nuts_storable::{ItemType, Value};
10
11use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
12use crate::{Math, Progress, Settings};
13
/// Configuration for storing sampler traces as one CSV file per chain.
///
/// Create it with [`CsvConfig::new`] and adjust it with the builder-style
/// methods, each of which consumes the configuration and returns the
/// updated value.
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// Directory that receives the `chain_<id>.csv` files.
    output_dir: PathBuf,
    /// Number of digits written after the decimal point for floating-point values.
    precision: usize,
    /// Whether draws recorded while the sampler is still tuning are written.
    store_warmup: bool,
}

impl CsvConfig {
    /// Creates a configuration that writes into `output_dir`.
    ///
    /// The defaults are 6 digits of precision and warmup draws included.
    pub fn new<P: AsRef<Path>>(output_dir: P) -> Self {
        Self {
            output_dir: output_dir.as_ref().to_path_buf(),
            precision: 6,
            store_warmup: true,
        }
    }

    /// Sets the number of digits written after the decimal point.
    #[must_use = "builder methods return the updated configuration"]
    pub fn with_precision(mut self, precision: usize) -> Self {
        self.precision = precision;
        self
    }

    /// Chooses whether draws from the tuning (warmup) phase are written.
    #[must_use = "builder methods return the updated configuration"]
    pub fn store_warmup(mut self, store: bool) -> Self {
        self.store_warmup = store;
        self
    }
}
72
/// Trace-level CSV storage: holds the settings shared by all chains and
/// creates one [`CsvChainStorage`] per chain.
#[derive(Debug)]
pub struct CsvTraceStorage {
    /// Directory that receives the per-chain CSV files.
    output_dir: PathBuf,
    /// Number of digits written after the decimal point for floating-point values.
    precision: usize,
    /// Whether draws recorded during tuning are written.
    store_warmup: bool,
    /// Header names of the draw columns, in output order.
    parameter_names: Vec<String>,
    /// For each draw column: the name of the draw field that supplies the
    /// value and the flat index of the value inside that field.
    column_mapping: Vec<(String, usize)>,
}
81
/// CSV storage for a single chain: owns the open output file and appends one
/// row per recorded draw.
#[derive(Debug)]
pub struct CsvChainStorage {
    /// Buffered writer over the chain's CSV file.
    writer: BufWriter<File>,
    /// Number of digits written after the decimal point for floating-point values.
    precision: usize,
    /// Whether draws recorded during tuning are written.
    store_warmup: bool,
    /// Header names of the draw columns, in output order.
    parameter_names: Vec<String>,
    /// For each draw column: the name of the draw field that supplies the
    /// value and the flat index of the value inside that field.
    column_mapping: Vec<(String, usize)>,
    /// `true` until the first row has been recorded.
    is_first_sample: bool,
    /// `true` once the header line has been written.
    headers_written: bool,
}
92
93impl CsvChainStorage {
94 fn new(
96 output_dir: &Path,
97 chain_id: u64,
98 precision: usize,
99 store_warmup: bool,
100 parameter_names: Vec<String>,
101 column_mapping: Vec<(String, usize)>,
102 ) -> Result<Self> {
103 std::fs::create_dir_all(output_dir)
104 .with_context(|| format!("Failed to create output directory: {:?}", output_dir))?;
105
106 let file_path = output_dir.join(format!("chain_{}.csv", chain_id));
107 let file = File::create(&file_path)
108 .with_context(|| format!("Failed to create CSV file: {:?}", file_path))?;
109 let writer = BufWriter::new(file);
110
111 Ok(Self {
112 writer,
113 precision,
114 store_warmup,
115 parameter_names,
116 column_mapping,
117 is_first_sample: true,
118 headers_written: false,
119 })
120 }
121
122 fn write_header(&mut self) -> Result<()> {
124 if self.headers_written {
125 return Ok(());
126 }
127
128 let mut headers = vec![
130 "lp__".to_string(),
131 "accept_stat__".to_string(),
132 "stepsize__".to_string(),
133 "treedepth__".to_string(),
134 "n_leapfrog__".to_string(),
135 "divergent__".to_string(),
136 "energy__".to_string(),
137 ];
138
139 for param_name in &self.parameter_names {
141 headers.push(param_name.clone());
142 }
143
144 writeln!(self.writer, "{}", headers.join(","))?;
146 self.headers_written = true;
147 Ok(())
148 }
149
150 fn format_value(&self, value: &Value) -> String {
152 match value {
153 Value::ScalarF64(v) => {
154 if v.is_nan() {
155 "NA".to_string()
156 } else if v.is_infinite() {
157 if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
158 } else {
159 format!("{:.prec$}", v, prec = self.precision)
160 }
161 }
162 Value::ScalarF32(v) => {
163 if v.is_nan() {
164 "NA".to_string()
165 } else if v.is_infinite() {
166 if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
167 } else {
168 format!("{:.prec$}", v, prec = self.precision)
169 }
170 }
171 Value::ScalarU64(v) => v.to_string(),
172 Value::ScalarI64(v) => v.to_string(),
173 Value::ScalarBool(v) => if *v { "1" } else { "0" }.to_string(),
174 Value::F64(vec) => {
175 if vec.is_empty() {
178 "NA".to_string()
179 } else {
180 self.format_value(&Value::ScalarF64(vec[0]))
181 }
182 }
183 Value::F32(vec) => {
184 if vec.is_empty() {
185 "NA".to_string()
186 } else {
187 self.format_value(&Value::ScalarF32(vec[0]))
188 }
189 }
190 Value::U64(vec) => {
191 if vec.is_empty() {
192 "NA".to_string()
193 } else {
194 vec[0].to_string()
195 }
196 }
197 Value::I64(vec) => {
198 if vec.is_empty() {
199 "NA".to_string()
200 } else {
201 vec[0].to_string()
202 }
203 }
204 Value::Bool(vec) => {
205 if vec.is_empty() {
206 "NA".to_string()
207 } else {
208 if vec[0] { "1" } else { "0" }.to_string()
209 }
210 }
211 Value::ScalarString(v) => v.clone(),
212 Value::Strings(vec) => {
213 if vec.is_empty() {
214 "NA".to_string()
215 } else {
216 vec[0].clone()
217 }
218 }
219 Value::DateTime64(_, _) => panic!("DateTime64 not supported in CSV output"),
220 Value::TimeDelta64(_, _) => panic!("TimeDelta64 not supported in CSV output"),
221 }
222 }
223
224 fn write_sample_row(
226 &mut self,
227 stats: &Vec<(&str, Option<Value>)>,
228 draws: &Vec<(&str, Option<Value>)>,
229 _info: &Progress,
230 ) -> Result<()> {
231 let mut row_values = Vec::new();
232
233 let stats_map: HashMap<&str, &Option<Value>> = stats.iter().map(|(k, v)| (*k, v)).collect();
235 let draws_map: HashMap<&str, &Option<Value>> = draws.iter().map(|(k, v)| (*k, v)).collect();
236
237 let get_stat_value = |name: &str| -> String {
239 stats_map
240 .get(name)
241 .and_then(|opt| opt.as_ref())
242 .map(|v| self.format_value(v))
243 .unwrap_or_else(|| "NA".to_string())
244 };
245
246 row_values.push(get_stat_value("logp"));
247 row_values.push(get_stat_value("mean_tree_accept"));
248 row_values.push(get_stat_value("step_size"));
249 row_values.push(get_stat_value("depth"));
250 row_values.push(get_stat_value("n_steps"));
251 let divergent_val = stats_map
252 .get("diverging")
253 .and_then(|opt| opt.as_ref())
254 .map(|v| match v {
255 Value::ScalarBool(true) => "1".to_string(),
256 Value::ScalarBool(false) => "0".to_string(),
257 _ => "0".to_string(),
258 })
259 .unwrap_or_else(|| "0".to_string());
260 row_values.push(divergent_val);
261
262 row_values.push(get_stat_value("energy"));
263
264 for (_param_name, (data_name, index)) in
266 self.parameter_names.iter().zip(&self.column_mapping)
267 {
268 if let Some(Some(data_value)) = draws_map.get(data_name.as_str()) {
269 let formatted_value = match data_value {
270 Value::F64(vec) => {
271 if *index < vec.len() {
272 self.format_value(&Value::ScalarF64(vec[*index]))
273 } else {
274 "NA".to_string()
275 }
276 }
277 Value::F32(vec) => {
278 if *index < vec.len() {
279 self.format_value(&Value::ScalarF32(vec[*index]))
280 } else {
281 "NA".to_string()
282 }
283 }
284 Value::I64(vec) => {
285 if *index < vec.len() {
286 self.format_value(&Value::ScalarI64(vec[*index]))
287 } else {
288 "NA".to_string()
289 }
290 }
291 Value::U64(vec) => {
292 if *index < vec.len() {
293 self.format_value(&Value::ScalarU64(vec[*index]))
294 } else {
295 "NA".to_string()
296 }
297 }
298 scalar_val if *index == 0 => self.format_value(scalar_val),
300 _ => "NA".to_string(),
301 };
302 row_values.push(formatted_value);
303 } else {
304 row_values.push("NA".to_string());
305 }
306 }
307
308 writeln!(self.writer, "{}", row_values.join(","))?;
310 Ok(())
311 }
312}
313
314impl ChainStorage for CsvChainStorage {
315 type Finalized = ();
316
317 fn record_sample(
318 &mut self,
319 _settings: &impl Settings,
320 stats: Vec<(&str, Option<Value>)>,
321 draws: Vec<(&str, Option<Value>)>,
322 info: &Progress,
323 ) -> Result<()> {
324 if info.tuning && !self.store_warmup {
326 return Ok(());
327 }
328
329 if self.is_first_sample {
331 self.write_header()?;
332 self.is_first_sample = false;
333 }
334
335 self.write_sample_row(&stats, &draws, info)?;
336 Ok(())
337 }
338
339 fn finalize(mut self) -> Result<Self::Finalized> {
340 self.writer.flush().context("Failed to flush CSV file")?;
341 Ok(())
342 }
343
344 fn flush(&self) -> Result<()> {
345 Ok(())
348 }
349
350 fn inspect(&self) -> Result<Option<Self::Finalized>> {
351 self.flush()?;
353 Ok(None)
354 }
355}
356
357impl StorageConfig for CsvConfig {
358 type Storage = CsvTraceStorage;
359
360 fn new_trace<M: Math>(self, settings: &impl Settings, math: &M) -> Result<Self::Storage> {
361 let (parameter_names, column_mapping) =
363 generate_parameter_names_and_mapping(settings, math)?;
364
365 Ok(CsvTraceStorage {
366 output_dir: self.output_dir,
367 precision: self.precision,
368 store_warmup: self.store_warmup,
369 parameter_names,
370 column_mapping,
371 })
372 }
373}
374
375fn generate_parameter_names_and_mapping<M: Math>(
377 settings: &impl Settings,
378 math: &M,
379) -> Result<(Vec<String>, Vec<(String, usize)>)> {
380 let data_dims = settings.data_dims_all(math);
381 let coords = math.coords();
382 let mut parameter_names = Vec::new();
383 let mut column_mapping = Vec::new();
384
385 for (var_name, var_dims) in data_dims {
386 let data_type = settings.data_type(math, &var_name);
387
388 if matches!(
390 data_type,
391 ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
392 ) {
393 let (column_names, indices) = generate_column_names_and_indices_for_variable(
394 &var_name, &var_dims, &coords, math,
395 )?;
396
397 for (name, index) in column_names.into_iter().zip(indices) {
398 parameter_names.push(name);
399 column_mapping.push((var_name.clone(), index));
400 }
401 }
402 }
403
404 if parameter_names.is_empty() {
406 let dim_sizes = math.dim_sizes();
407 let param_count = dim_sizes.get("expanded_parameter").unwrap_or(&0);
408 for i in 0..*param_count {
409 parameter_names.push(format!("param_{}", i + 1));
410 let data_names = settings.data_names(math);
412 let mut found_field = false;
413 for data_name in &data_names {
414 let data_type = settings.data_type(math, data_name);
415 if matches!(
416 data_type,
417 ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
418 ) {
419 column_mapping.push((data_name.clone(), i as usize));
420 found_field = true;
421 break;
422 }
423 }
424 if !found_field {
425 column_mapping.push(("unknown".to_string(), i as usize));
426 }
427 }
428 }
429
430 Ok((parameter_names, column_mapping))
431}
432
433fn generate_column_names_and_indices_for_variable<M: Math>(
435 var_name: &str,
436 var_dims: &[String],
437 coords: &HashMap<String, Value>,
438 math: &M,
439) -> Result<(Vec<String>, Vec<usize>)> {
440 let dim_sizes = math.dim_sizes();
441
442 if var_dims.is_empty() {
443 return Ok((vec![var_name.to_string()], vec![0]));
445 }
446
447 let has_meaningful_coords = var_dims.iter().all(|dim_name| {
449 coords.get(dim_name).is_some_and(
450 |coord_value| matches!(coord_value, Value::Strings(labels) if !labels.is_empty()),
451 )
452 });
453
454 let mut dim_coords: Vec<Vec<String>> = Vec::new();
456 let mut dim_sizes_vec: Vec<usize> = Vec::new();
457
458 for dim_name in var_dims {
459 let size = *dim_sizes.get(dim_name).unwrap_or(&1) as usize;
460 dim_sizes_vec.push(size);
461
462 if has_meaningful_coords {
463 if let Some(coord_value) = coords.get(dim_name) {
465 match coord_value {
466 Value::Strings(labels) => {
467 dim_coords.push(labels.clone());
468 }
469 _ => {
470 dim_coords.push((1..=size).map(|i| i.to_string()).collect());
472 }
473 }
474 } else {
475 dim_coords.push((1..=size).map(|i| i.to_string()).collect());
477 }
478 } else {
479 dim_coords.push((1..=size).map(|i| i.to_string()).collect());
481 }
482 }
483
484 let (coord_names, indices) =
486 cartesian_product_with_indices_column_major(&dim_coords, &dim_sizes_vec);
487
488 let column_names: Vec<String> = coord_names
490 .into_iter()
491 .map(|coord| format!("{}.{}", var_name, coord))
492 .collect();
493
494 Ok((column_names, indices))
495}
496
/// Enumerates every combination of coordinate labels, one label per
/// dimension, together with the flat index of that combination.
///
/// Names join the labels with `.` and are produced with the first dimension
/// varying slowest. The flat index uses `dim_sizes` as the array shape with
/// the **last dimension varying fastest** (row-major / C order), which is
/// what the function computes despite the `column_major` in its name.
///
/// If `dim_sizes` has fewer entries than `coord_sets`, the number of labels
/// of the affected dimension is used as its size. An empty `coord_sets`, or
/// any empty label set, yields no combinations.
fn cartesian_product_with_indices_column_major(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
) -> (Vec<String>, Vec<usize>) {
    let mut names = Vec::new();
    let mut indices = Vec::new();

    if coord_sets.is_empty() {
        return (names, indices);
    }

    cartesian_product_recursive_with_indices(
        coord_sets,
        dim_sizes,
        0,
        &mut String::new(),
        &mut Vec::with_capacity(coord_sets.len()),
        &mut names,
        &mut indices,
    );

    (names, indices)
}

/// Recursive worker for [`cartesian_product_with_indices_column_major`].
///
/// `current_name` and `current_indices` hold the labels and positions chosen
/// for the dimensions before `dim_idx`. Both are extended for the duration of
/// each recursive call and restored afterwards, so they are unchanged when
/// this function returns. Completed combinations are appended to
/// `result_names` and `result_indices`.
fn cartesian_product_recursive_with_indices(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
    dim_idx: usize,
    current_name: &mut String,
    current_indices: &mut Vec<usize>,
    result_names: &mut Vec<String>,
    result_indices: &mut Vec<usize>,
) {
    if dim_idx == coord_sets.len() {
        // Walk the dimensions from last to first so each stride is built
        // from the previous one instead of being recomputed from scratch.
        let mut linear_index = 0;
        let mut stride = 1;
        for (i, &idx) in current_indices.iter().enumerate().rev() {
            linear_index += idx * stride;
            if i > 0 {
                let extent = dim_sizes
                    .get(i)
                    .copied()
                    .unwrap_or_else(|| coord_sets.get(i).map_or(1, Vec::len));
                stride *= extent;
            }
        }

        result_names.push(current_name.clone());
        result_indices.push(linear_index);
        return;
    }

    let prefix_len = current_name.len();

    for (coord_idx, coord) in coord_sets[dim_idx].iter().enumerate() {
        // Only labels after the first dimension are preceded by a separator.
        if dim_idx != 0 {
            current_name.push('.');
        }
        current_name.push_str(coord);
        current_indices.push(coord_idx);

        cartesian_product_recursive_with_indices(
            coord_sets,
            dim_sizes,
            dim_idx + 1,
            current_name,
            current_indices,
            result_names,
            result_indices,
        );

        current_indices.pop();
        current_name.truncate(prefix_len);
    }
}
578
579impl TraceStorage for CsvTraceStorage {
580 type ChainStorage = CsvChainStorage;
581 type Finalized = ();
582
583 fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage> {
584 CsvChainStorage::new(
585 &self.output_dir,
586 chain_id,
587 self.precision,
588 self.store_warmup,
589 self.parameter_names.clone(),
590 self.column_mapping.clone(),
591 )
592 }
593
594 fn finalize(
595 self,
596 traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
597 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
598 for trace_result in traces {
600 if let Err(err) = trace_result {
601 return Ok((Some(err), ()));
602 }
603 }
604 Ok((None, ()))
605 }
606
607 fn inspect(
608 &self,
609 traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
610 ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
611 for trace_result in traces {
613 if let Err(err) = trace_result {
614 return Ok((Some(err), ()));
615 }
616 }
617 Ok((None, ()))
618 }
619}
620
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CpuLogpFunc, CpuMath, CpuMathError, DiagNutsSettings, LogpError, Model, Sampler};
    use anyhow::Result;
    use nuts_derive::Storable;
    use nuts_storable::{HasDims, Value};
    use rand::{Rng, RngExt};
    use std::collections::HashMap;
    use std::fs;
    use std::path::Path;
    use thiserror::Error;

    /// Error type required by `CpuLogpFunc`.
    // The test densities never fail, so the variant is never constructed.
    #[allow(dead_code)]
    #[derive(Debug, Error)]
    enum TestLogpError {
        #[error("Test error")]
        Test,
    }

    impl LogpError for TestLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    /// Standard-normal density whose expanded draw is an `a` x `b` matrix
    /// with string coordinates on both dimensions.
    #[derive(Clone)]
    struct MultiDimTestLogp {
        dim_a: usize,
        dim_b: usize,
    }

    impl HasDims for MultiDimTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([
                ("a".to_string(), self.dim_a as u64),
                ("b".to_string(), self.dim_b as u64),
            ])
        }

        fn coords(&self) -> HashMap<String, Value> {
            HashMap::from([
                (
                    "a".to_string(),
                    Value::Strings(vec!["x".to_string(), "y".to_string()]),
                ),
                (
                    "b".to_string(),
                    Value::Strings(vec!["alpha".to_string(), "beta".to_string()]),
                ),
            ])
        }
    }

    #[derive(Storable)]
    struct MultiDimExpandedDraw {
        #[storable(dims("a", "b"))]
        param_matrix: Vec<f64>,
        scalar_value: f64,
    }

    impl CpuLogpFunc for MultiDimTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = MultiDimExpandedDraw;

        fn dim(&self) -> usize {
            self.dim_a * self.dim_b
        }

        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mut logp = 0.0;
            for (i, &xi) in x.iter().enumerate() {
                logp -= 0.5 * xi * xi;
                grad[i] = -xi;
            }
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            Ok(MultiDimExpandedDraw {
                param_matrix: array.to_vec(),
                scalar_value: array.iter().sum(),
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            Some(Value::Strings(
                (0..self.dim()).map(|i| format!("theta{}", i + 1)).collect(),
            ))
        }
    }

    struct MultiDimTestModel {
        math: CpuMath<MultiDimTestLogp>,
    }

    impl Model for MultiDimTestModel {
        type Math<'model>
            = CpuMath<MultiDimTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            for p in position.iter_mut() {
                *p = rng.random_range(-1.0..1.0);
            }
            Ok(())
        }
    }

    /// Standard-normal density with a single dimension that has a size but no
    /// coordinates, used to exercise the numeric fallback labels.
    #[derive(Clone)]
    struct SimpleTestLogp {
        dim: usize,
    }

    impl HasDims for SimpleTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([("simple_param".to_string(), self.dim as u64)])
        }
    }

    #[derive(Storable)]
    struct SimpleExpandedDraw {
        #[storable(dims("simple_param"))]
        values: Vec<f64>,
    }

    impl CpuLogpFunc for SimpleTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = SimpleExpandedDraw;

        fn dim(&self) -> usize {
            self.dim
        }

        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mut logp = 0.0;
            for (i, &xi) in x.iter().enumerate() {
                logp -= 0.5 * xi * xi;
                grad[i] = -xi;
            }
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            Ok(SimpleExpandedDraw {
                values: array.to_vec(),
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            // One label per parameter, so the label count always matches `dim`.
            Some(Value::Strings(
                (1..=self.dim).map(|i| format!("param{}", i)).collect(),
            ))
        }
    }

    struct SimpleTestModel {
        math: CpuMath<SimpleTestLogp>,
    }

    impl Model for SimpleTestModel {
        type Math<'model>
            = CpuMath<SimpleTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            for p in position.iter_mut() {
                *p = rng.random_range(-1.0..1.0);
            }
            Ok(())
        }
    }

    /// Returns the first line of the file at `path`, or an error if the file
    /// has no lines.
    fn read_csv_header(path: &Path) -> Result<String> {
        let content = fs::read_to_string(path)?;
        content
            .lines()
            .next()
            .map(|s| s.to_string())
            .ok_or_else(|| anyhow::anyhow!("Empty CSV file"))
    }

    #[test]
    fn test_multidim_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("multidim_test");

        let model = MultiDimTestModel {
            math: CpuMath::new(MultiDimTestLogp { dim_a: 2, dim_b: 2 }),
        };

        let mut settings = DiagNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 10;
        settings.num_draws = 20;
        settings.seed = 42;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        // Poll until the sampler hands back the finished trace.
        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        let header = read_csv_header(&csv_file)?;

        assert!(header.contains("param_matrix.x.alpha"));
        assert!(header.contains("param_matrix.x.beta"));
        assert!(header.contains("param_matrix.y.alpha"));
        assert!(header.contains("param_matrix.y.beta"));
        assert!(header.contains("scalar_value"));

        // The matrix columns must appear with the last dimension varying fastest.
        let columns: Vec<&str> = header.split(',').collect();
        let param_columns: Vec<&str> = columns
            .iter()
            .filter(|col| col.starts_with("param_matrix."))
            .copied()
            .collect();

        assert_eq!(
            param_columns,
            vec![
                "param_matrix.x.alpha",
                "param_matrix.x.beta",
                "param_matrix.y.alpha",
                "param_matrix.y.beta"
            ]
        );

        Ok(())
    }

    #[test]
    fn test_fallback_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("simple_test");

        let model = SimpleTestModel {
            math: CpuMath::new(SimpleTestLogp { dim: 3 }),
        };

        let mut settings = DiagNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 5;
        settings.num_draws = 10;
        settings.seed = 123;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        // Poll until the sampler hands back the finished trace.
        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        let header = read_csv_header(&csv_file)?;

        // Without coordinates the dimension is labelled 1, 2, 3.
        assert!(header.contains("values.1"));
        assert!(header.contains("values.2"));
        assert!(header.contains("values.3"));

        Ok(())
    }

    #[test]
    fn test_cartesian_product_generation() {
        let coord_sets = vec![
            vec!["x".to_string(), "y".to_string()],
            vec!["alpha".to_string(), "beta".to_string()],
        ];
        let dim_sizes = vec![2, 2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["x.alpha", "x.beta", "y.alpha", "y.beta"]);

        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_single_dimension_coordinates() {
        let coord_sets = vec![vec!["param1".to_string(), "param2".to_string()]];
        let dim_sizes = vec![2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["param1", "param2"]);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_three_dimension_cartesian_product() {
        let coord_sets = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["1".to_string()],
            vec!["i".to_string(), "j".to_string()],
        ];
        let dim_sizes = vec![2, 1, 2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["a.1.i", "a.1.j", "b.1.i", "b.1.j"]);

        assert_eq!(indices, vec![0, 1, 2, 3]);
    }
}