use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};

use anyhow::{Context, Result};
use nuts_storable::{ItemType, Value};

use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
use crate::{Math, Progress, Settings};

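/// Configuration for a CSV storage backend that writes one Stan-style CSV
/// file per chain (`chain_<id>.csv`) into `output_dir`.
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (`model` and `settings` stand in for a user-provided model and sampler
/// settings):
///
/// ```ignore
/// let config = CsvConfig::new("./draws")
///     .with_precision(8)
///     .store_warmup(false);
/// let sampler = Sampler::new(model, settings, config, 1, None)?;
/// ```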
pub struct CsvConfig {
    output_dir: PathBuf,
    precision: usize,
    store_warmup: bool,
}

impl CsvConfig {
    pub fn new<P: AsRef<Path>>(output_dir: P) -> Self {
        Self {
            output_dir: output_dir.as_ref().to_path_buf(),
            precision: 6,
            store_warmup: true,
        }
    }

    pub fn with_precision(mut self, precision: usize) -> Self {
        self.precision = precision;
        self
    }

    pub fn store_warmup(mut self, store: bool) -> Self {
        self.store_warmup = store;
        self
    }
}

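/// Trace-level storage produced by [`CsvConfig`]: holds the resolved column
/// names and the `(variable name, flat index)` mapping shared by every
/// per-chain writer.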
pub struct CsvTraceStorage {
    output_dir: PathBuf,
    precision: usize,
    store_warmup: bool,
    parameter_names: Vec<String>,
    column_mapping: Vec<(String, usize)>,
}

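/// Per-chain CSV writer. Each chain owns a buffered `chain_<id>.csv` file and
/// appends one row per recorded sample.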
pub struct CsvChainStorage {
    writer: BufWriter<File>,
    precision: usize,
    store_warmup: bool,
    parameter_names: Vec<String>,
    column_mapping: Vec<(String, usize)>,
    is_first_sample: bool,
    headers_written: bool,
}

impl CsvChainStorage {
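    /// Creates the writer for one chain: ensures `output_dir` exists and
    /// opens `chain_<id>.csv` inside it. The header row is written lazily,
    /// on the first recorded sample.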
    fn new(
        output_dir: &Path,
        chain_id: u64,
        precision: usize,
        store_warmup: bool,
        parameter_names: Vec<String>,
        column_mapping: Vec<(String, usize)>,
    ) -> Result<Self> {
        std::fs::create_dir_all(output_dir)
            .with_context(|| format!("Failed to create output directory: {:?}", output_dir))?;

        let file_path = output_dir.join(format!("chain_{}.csv", chain_id));
        let file = File::create(&file_path)
            .with_context(|| format!("Failed to create CSV file: {:?}", file_path))?;
        let writer = BufWriter::new(file);

        Ok(Self {
            writer,
            precision,
            store_warmup,
            parameter_names,
            column_mapping,
            is_first_sample: true,
            headers_written: false,
        })
    }

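    /// Writes the header row once: the Stan-style sampler statistics columns
    /// (`lp__`, `accept_stat__`, ...) followed by one column per parameter.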
    fn write_header(&mut self) -> Result<()> {
        if self.headers_written {
            return Ok(());
        }

        let mut headers = vec![
            "lp__".to_string(),
            "accept_stat__".to_string(),
            "stepsize__".to_string(),
            "treedepth__".to_string(),
            "n_leapfrog__".to_string(),
            "divergent__".to_string(),
            "energy__".to_string(),
        ];

        for param_name in &self.parameter_names {
            headers.push(param_name.clone());
        }

        writeln!(self.writer, "{}", headers.join(","))?;
        self.headers_written = true;
        Ok(())
    }

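    /// Formats a single `Value` as a CSV cell: floats use `precision` decimal
    /// places with NaN as `NA` and infinities as `Inf`/`-Inf`, booleans become
    /// `1`/`0`, and non-empty vectors fall back to their first element.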
    fn format_value(&self, value: &Value) -> String {
        match value {
            Value::ScalarF64(v) => {
                if v.is_nan() {
                    "NA".to_string()
                } else if v.is_infinite() {
                    if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
                } else {
                    format!("{:.prec$}", v, prec = self.precision)
                }
            }
            Value::ScalarF32(v) => {
                if v.is_nan() {
                    "NA".to_string()
                } else if v.is_infinite() {
                    if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
                } else {
                    format!("{:.prec$}", v, prec = self.precision)
                }
            }
            Value::ScalarU64(v) => v.to_string(),
            Value::ScalarI64(v) => v.to_string(),
            Value::ScalarBool(v) => if *v { "1" } else { "0" }.to_string(),
            Value::F64(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    self.format_value(&Value::ScalarF64(vec[0]))
                }
            }
            Value::F32(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    self.format_value(&Value::ScalarF32(vec[0]))
                }
            }
            Value::U64(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    vec[0].to_string()
                }
            }
            Value::I64(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    vec[0].to_string()
                }
            }
            Value::Bool(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    if vec[0] { "1" } else { "0" }.to_string()
                }
            }
            Value::ScalarString(v) => v.clone(),
            Value::Strings(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    vec[0].clone()
                }
            }
        }
    }

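    /// Writes one data row: the seven sampler statistics in the same order as
    /// the header, then one cell per parameter column, looked up through
    /// `column_mapping` as `(variable name, flat index)` pairs. Missing or
    /// out-of-range entries are emitted as `NA`.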
    fn write_sample_row(
        &mut self,
        stats: &[(&str, Option<Value>)],
        draws: &[(&str, Option<Value>)],
        _info: &Progress,
    ) -> Result<()> {
        let mut row_values = Vec::new();

        let stats_map: HashMap<&str, &Option<Value>> =
            stats.iter().map(|(k, v)| (*k, v)).collect();
        let draws_map: HashMap<&str, &Option<Value>> =
            draws.iter().map(|(k, v)| (*k, v)).collect();

        let get_stat_value = |name: &str| -> String {
            stats_map
                .get(name)
                .and_then(|opt| opt.as_ref())
                .map(|v| self.format_value(v))
                .unwrap_or_else(|| "NA".to_string())
        };

        row_values.push(get_stat_value("logp"));
        row_values.push(get_stat_value("mean_tree_accept"));
        row_values.push(get_stat_value("step_size"));
        row_values.push(get_stat_value("depth"));
        row_values.push(get_stat_value("n_steps"));
        let divergent_val = stats_map
            .get("diverging")
            .and_then(|opt| opt.as_ref())
            .map(|v| match v {
                Value::ScalarBool(true) => "1".to_string(),
                Value::ScalarBool(false) => "0".to_string(),
                _ => "0".to_string(),
            })
            .unwrap_or_else(|| "0".to_string());
        row_values.push(divergent_val);

        row_values.push(get_stat_value("energy"));

        for (_param_name, (data_name, index)) in
            self.parameter_names.iter().zip(&self.column_mapping)
        {
            if let Some(Some(data_value)) = draws_map.get(data_name.as_str()) {
                let formatted_value = match data_value {
                    Value::F64(vec) => {
                        if *index < vec.len() {
                            self.format_value(&Value::ScalarF64(vec[*index]))
                        } else {
                            "NA".to_string()
                        }
                    }
                    Value::F32(vec) => {
                        if *index < vec.len() {
                            self.format_value(&Value::ScalarF32(vec[*index]))
                        } else {
                            "NA".to_string()
                        }
                    }
                    Value::I64(vec) => {
                        if *index < vec.len() {
                            self.format_value(&Value::ScalarI64(vec[*index]))
                        } else {
                            "NA".to_string()
                        }
                    }
                    Value::U64(vec) => {
                        if *index < vec.len() {
                            self.format_value(&Value::ScalarU64(vec[*index]))
                        } else {
                            "NA".to_string()
                        }
                    }
                    scalar_val if *index == 0 => self.format_value(scalar_val),
                    _ => "NA".to_string(),
                };
                row_values.push(formatted_value);
            } else {
                row_values.push("NA".to_string());
            }
        }

        writeln!(self.writer, "{}", row_values.join(","))?;
        Ok(())
    }
}

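// ChainStorage implementation: warmup draws are skipped when `store_warmup`
// is false, the header is written lazily before the first stored sample, and
// the buffered writer is only flushed on `finalize` (the by-reference
// `flush` is a no-op).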
impl ChainStorage for CsvChainStorage {
    type Finalized = ();

    fn record_sample(
        &mut self,
        _settings: &impl Settings,
        stats: Vec<(&str, Option<Value>)>,
        draws: Vec<(&str, Option<Value>)>,
        info: &Progress,
    ) -> Result<()> {
        if info.tuning && !self.store_warmup {
            return Ok(());
        }

        if self.is_first_sample {
            self.write_header()?;
            self.is_first_sample = false;
        }

        self.write_sample_row(&stats, &draws, info)?;
        Ok(())
    }

    fn finalize(mut self) -> Result<Self::Finalized> {
        self.writer.flush().context("Failed to flush CSV file")?;
        Ok(())
    }

    fn flush(&self) -> Result<()> {
        Ok(())
    }

    fn inspect(&self) -> Result<Option<Self::Finalized>> {
        self.flush()?;
        Ok(None)
    }
}

impl StorageConfig for CsvConfig {
    type Storage = CsvTraceStorage;

    fn new_trace<M: Math>(self, settings: &impl Settings, math: &M) -> Result<Self::Storage> {
        let (parameter_names, column_mapping) =
            generate_parameter_names_and_mapping(settings, math)?;

        Ok(CsvTraceStorage {
            output_dir: self.output_dir,
            precision: self.precision,
            store_warmup: self.store_warmup,
            parameter_names,
            column_mapping,
        })
    }
}

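/// Derives the parameter column names and the `(variable name, flat index)`
/// mapping from the model's dimension and coordinate metadata, keeping only
/// numeric variables. If nothing qualifies, falls back to generic `param_1`,
/// `param_2`, ... names sized by the `expanded_parameter` dimension.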
fn generate_parameter_names_and_mapping<M: Math>(
    settings: &impl Settings,
    math: &M,
) -> Result<(Vec<String>, Vec<(String, usize)>)> {
    let data_dims = settings.data_dims_all(math);
    let coords = math.coords();
    let mut parameter_names = Vec::new();
    let mut column_mapping = Vec::new();

    for (var_name, var_dims) in data_dims {
        let data_type = settings.data_type(math, &var_name);

        if matches!(
            data_type,
            ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
        ) {
            let (column_names, indices) = generate_column_names_and_indices_for_variable(
                &var_name, &var_dims, &coords, math,
            )?;

            for (name, index) in column_names.into_iter().zip(indices) {
                parameter_names.push(name);
                column_mapping.push((var_name.clone(), index));
            }
        }
    }

    if parameter_names.is_empty() {
        let dim_sizes = math.dim_sizes();
        let param_count = dim_sizes.get("expanded_parameter").unwrap_or(&0);
        for i in 0..*param_count {
            parameter_names.push(format!("param_{}", i + 1));
            let data_names = settings.data_names(math);
            let mut found_field = false;
            for data_name in &data_names {
                let data_type = settings.data_type(math, data_name);
                if matches!(
                    data_type,
                    ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
                ) {
                    column_mapping.push((data_name.clone(), i as usize));
                    found_field = true;
                    break;
                }
            }
            if !found_field {
                column_mapping.push(("unknown".to_string(), i as usize));
            }
        }
    }

    Ok((parameter_names, column_mapping))
}

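/// Builds the CSV column names and flat indices for a single variable.
/// Scalars keep the bare variable name; array variables get dot-joined
/// coordinate labels (for example `param_matrix.x.alpha`), and dimensions
/// without string coordinates fall back to 1-based positions such as
/// `values.1`, as exercised by the tests below.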
fn generate_column_names_and_indices_for_variable<M: Math>(
    var_name: &str,
    var_dims: &[String],
    coords: &HashMap<String, Value>,
    math: &M,
) -> Result<(Vec<String>, Vec<usize>)> {
    let dim_sizes = math.dim_sizes();

    if var_dims.is_empty() {
        return Ok((vec![var_name.to_string()], vec![0]));
    }

    let has_meaningful_coords = var_dims.iter().all(|dim_name| {
        coords.get(dim_name).is_some_and(
            |coord_value| matches!(coord_value, Value::Strings(labels) if !labels.is_empty()),
        )
    });

    let mut dim_coords: Vec<Vec<String>> = Vec::new();
    let mut dim_sizes_vec: Vec<usize> = Vec::new();

    for dim_name in var_dims {
        let size = *dim_sizes.get(dim_name).unwrap_or(&1) as usize;
        dim_sizes_vec.push(size);

        if has_meaningful_coords {
            if let Some(coord_value) = coords.get(dim_name) {
                match coord_value {
                    Value::Strings(labels) => {
                        dim_coords.push(labels.clone());
                    }
                    _ => {
                        dim_coords.push((1..=size).map(|i| i.to_string()).collect());
                    }
                }
            } else {
                dim_coords.push((1..=size).map(|i| i.to_string()).collect());
            }
        } else {
            dim_coords.push((1..=size).map(|i| i.to_string()).collect());
        }
    }

    let (coord_names, indices) =
        cartesian_product_with_indices_column_major(&dim_coords, &dim_sizes_vec);

    let column_names: Vec<String> = coord_names
        .into_iter()
        .map(|coord| format!("{}.{}", var_name, coord))
        .collect();

    Ok((column_names, indices))
}

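/// Builds every combination of the per-dimension coordinate labels, joined
/// with dots, together with the flat index of each combination. The last
/// dimension varies fastest and is assumed to have stride 1 in the flattened
/// variable.
///
/// A small sketch of the expected output (this mirrors
/// `test_cartesian_product_generation` below):
///
/// ```ignore
/// let coord_sets = vec![
///     vec!["x".to_string(), "y".to_string()],
///     vec!["alpha".to_string(), "beta".to_string()],
/// ];
/// let (names, indices) =
///     cartesian_product_with_indices_column_major(&coord_sets, &[2, 2]);
/// assert_eq!(names, vec!["x.alpha", "x.beta", "y.alpha", "y.beta"]);
/// assert_eq!(indices, vec![0, 1, 2, 3]);
/// ```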
fn cartesian_product_with_indices_column_major(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
) -> (Vec<String>, Vec<usize>) {
    if coord_sets.is_empty() {
        return (vec![], vec![]);
    }

    if coord_sets.len() == 1 {
        let indices: Vec<usize> = (0..coord_sets[0].len()).collect();
        return (coord_sets[0].clone(), indices);
    }

    let mut names = vec![];
    let mut indices = vec![];

    cartesian_product_recursive_with_indices(
        coord_sets,
        dim_sizes,
        0,
        &mut String::new(),
        &mut vec![],
        &mut names,
        &mut indices,
    );

    (names, indices)
}

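/// Depth-first helper for the cartesian product above: `current_name`
/// accumulates the dot-joined coordinate labels, and at each leaf the
/// per-dimension indices are folded into a single flat index using the
/// product of the trailing dimension sizes as each dimension's stride.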
fn cartesian_product_recursive_with_indices(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
    dim_idx: usize,
    current_name: &mut String,
    current_indices: &mut Vec<usize>,
    result_names: &mut Vec<String>,
    result_indices: &mut Vec<usize>,
) {
    if dim_idx == coord_sets.len() {
        result_names.push(current_name.clone());
        let mut linear_index = 0;
        for (i, &idx) in current_indices.iter().enumerate() {
            let mut stride = 1;
            for &size in &dim_sizes[i + 1..] {
                stride *= size;
            }
            linear_index += idx * stride;
        }
        result_indices.push(linear_index);
        return;
    }

    let is_first_dim = dim_idx == 0;

    for (coord_idx, coord) in coord_sets[dim_idx].iter().enumerate() {
        let mut new_name = current_name.clone();
        if !is_first_dim {
            new_name.push('.');
        }
        new_name.push_str(coord);

        current_indices.push(coord_idx);
        cartesian_product_recursive_with_indices(
            coord_sets,
            dim_sizes,
            dim_idx + 1,
            &mut new_name,
            current_indices,
            result_names,
            result_indices,
        );
        current_indices.pop();
    }
}

impl TraceStorage for CsvTraceStorage {
    type ChainStorage = CsvChainStorage;
    type Finalized = ();

    fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage> {
        CsvChainStorage::new(
            &self.output_dir,
            chain_id,
            self.precision,
            self.store_warmup,
            self.parameter_names.clone(),
            self.column_mapping.clone(),
        )
    }

    fn finalize(
        self,
        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
        for trace_result in traces {
            if let Err(err) = trace_result {
                return Ok((Some(err), ()));
            }
        }
        Ok((None, ()))
    }

    fn inspect(
        &self,
        traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
        for trace_result in traces {
            if let Err(err) = trace_result {
                return Ok((Some(err), ()));
            }
        }
        Ok((None, ()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler,
    };
    use anyhow::Result;
    use nuts_derive::Storable;
    use nuts_storable::{HasDims, Value};
    use rand::Rng;
    use std::collections::HashMap;
    use std::fs;
    use std::path::Path;
    use thiserror::Error;

    #[allow(dead_code)]
    #[derive(Debug, Error)]
    enum TestLogpError {
        #[error("Test error")]
        Test,
    }

    impl LogpError for TestLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    #[derive(Clone)]
    struct MultiDimTestLogp {
        dim_a: usize,
        dim_b: usize,
    }

    impl HasDims for MultiDimTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([
                ("a".to_string(), self.dim_a as u64),
                ("b".to_string(), self.dim_b as u64),
            ])
        }

        fn coords(&self) -> HashMap<String, Value> {
            HashMap::from([
                (
                    "a".to_string(),
                    Value::Strings(vec!["x".to_string(), "y".to_string()]),
                ),
                (
                    "b".to_string(),
                    Value::Strings(vec!["alpha".to_string(), "beta".to_string()]),
                ),
            ])
        }
    }

    #[derive(Storable)]
    struct MultiDimExpandedDraw {
        #[storable(dims("a", "b"))]
        param_matrix: Vec<f64>,
        scalar_value: f64,
    }

    impl CpuLogpFunc for MultiDimTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = MultiDimExpandedDraw;

        fn dim(&self) -> usize {
            self.dim_a * self.dim_b
        }

        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mut logp = 0.0;
            for (i, &xi) in x.iter().enumerate() {
                logp -= 0.5 * xi * xi;
                grad[i] = -xi;
            }
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            Ok(MultiDimExpandedDraw {
                param_matrix: array.to_vec(),
                scalar_value: array.iter().sum(),
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            Some(Value::Strings(
                (0..self.dim()).map(|i| format!("theta{}", i + 1)).collect(),
            ))
        }
    }

    struct MultiDimTestModel {
        math: CpuMath<MultiDimTestLogp>,
    }

    impl Model for MultiDimTestModel {
        type Math<'model>
            = CpuMath<MultiDimTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            for p in position.iter_mut() {
                *p = rng.random_range(-1.0..1.0);
            }
            Ok(())
        }
    }

    #[derive(Clone)]
    struct SimpleTestLogp {
        dim: usize,
    }

    impl HasDims for SimpleTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([("simple_param".to_string(), self.dim as u64)])
        }
    }

    #[derive(Storable)]
    struct SimpleExpandedDraw {
        #[storable(dims("simple_param"))]
        values: Vec<f64>,
    }

    impl CpuLogpFunc for SimpleTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = SimpleExpandedDraw;

        fn dim(&self) -> usize {
            self.dim
        }

        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mut logp = 0.0;
            for (i, &xi) in x.iter().enumerate() {
                logp -= 0.5 * xi * xi;
                grad[i] = -xi;
            }
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            Ok(SimpleExpandedDraw {
                values: array.to_vec(),
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            Some(Value::Strings(vec![
                "param1".to_string(),
                "param2".to_string(),
                "param3".to_string(),
            ]))
        }
    }

    struct SimpleTestModel {
        math: CpuMath<SimpleTestLogp>,
    }

    impl Model for SimpleTestModel {
        type Math<'model>
            = CpuMath<SimpleTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            for p in position.iter_mut() {
                *p = rng.random_range(-1.0..1.0);
            }
            Ok(())
        }
    }

    fn read_csv_header(path: &Path) -> Result<String> {
        let content = fs::read_to_string(path)?;
        content
            .lines()
            .next()
            .map(|s| s.to_string())
            .ok_or_else(|| anyhow::anyhow!("Empty CSV file"))
    }

    #[test]
    fn test_multidim_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("multidim_test");

        let model = MultiDimTestModel {
            math: CpuMath::new(MultiDimTestLogp { dim_a: 2, dim_b: 2 }),
        };

        let mut settings = DiagGradNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 10;
        settings.num_draws = 20;
        settings.seed = 42;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        let header = read_csv_header(&csv_file)?;

        assert!(header.contains("param_matrix.x.alpha"));
        assert!(header.contains("param_matrix.x.beta"));
        assert!(header.contains("param_matrix.y.alpha"));
        assert!(header.contains("param_matrix.y.beta"));
        assert!(header.contains("scalar_value"));

        let columns: Vec<&str> = header.split(',').collect();
        let param_columns: Vec<&str> = columns
            .iter()
            .filter(|col| col.starts_with("param_matrix."))
            .cloned()
            .collect();

        assert_eq!(
            param_columns,
            vec![
                "param_matrix.x.alpha",
                "param_matrix.x.beta",
                "param_matrix.y.alpha",
                "param_matrix.y.beta"
            ]
        );

        Ok(())
    }

    #[test]
    fn test_fallback_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("simple_test");

        let model = SimpleTestModel {
            math: CpuMath::new(SimpleTestLogp { dim: 3 }),
        };

        let mut settings = DiagGradNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 5;
        settings.num_draws = 10;
        settings.seed = 123;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        let header = read_csv_header(&csv_file)?;

        assert!(header.contains("values.1"));
        assert!(header.contains("values.2"));
        assert!(header.contains("values.3"));

        Ok(())
    }

    #[test]
    fn test_cartesian_product_generation() {
        let coord_sets = vec![
            vec!["x".to_string(), "y".to_string()],
            vec!["alpha".to_string(), "beta".to_string()],
        ];
        let dim_sizes = vec![2, 2];

        let (names, indices) =
            cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["x.alpha", "x.beta", "y.alpha", "y.beta"]);
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_single_dimension_coordinates() {
        let coord_sets = vec![vec!["param1".to_string(), "param2".to_string()]];
        let dim_sizes = vec![2];

        let (names, indices) =
            cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["param1", "param2"]);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_three_dimension_cartesian_product() {
        let coord_sets = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["1".to_string()],
            vec!["i".to_string(), "j".to_string()],
        ];
        let dim_sizes = vec![2, 1, 2];

        let (names, indices) =
            cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["a.1.i", "a.1.j", "b.1.i", "b.1.j"]);
        assert_eq!(indices, vec![0, 1, 2, 3]);
    }
}