// nuts_rs/storage/csv.rs

//! CSV storage backend for nuts-rs that outputs CmdStan-compatible CSV files
//!
//! This module provides a CSV storage backend that saves MCMC samples and
//! statistics in a format compatible with CmdStan, allowing existing Stan
//! analysis tools and libraries to read nuts-rs results.
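//!
//! Schematically, each per-chain file starts with the sampler statistics columns
//! written by this backend, followed by one column per parameter element. With two
//! illustrative parameter columns (names and values below are made up), a file looks like:
//!
//! ```text
//! lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,theta.1,theta.2
//! -7.482190,0.995312,0.803214,3,7,0,8.110021,0.214501,-1.337802
//! ```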

use std::collections::HashMap;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};

use anyhow::{Context, Result};
use nuts_storable::{ItemType, Value};

use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
use crate::{Math, Progress, Settings};

/// Configuration for CSV-based MCMC storage.
///
/// This storage backend creates Stan-compatible CSV files with one file per chain.
/// Files are named `chain_{id}.csv`, where `{id}` is the chain number starting from 0.
///
/// The CSV format matches CmdStan output:
/// - A header row with column names
/// - Warmup draws (when enabled), followed by post-warmup draws
/// - The standard Stan sampler statistics (lp__, accept_stat__, stepsize__,
///   treedepth__, n_leapfrog__, divergent__, energy__)
/// - One column per expanded parameter element
pub struct CsvConfig {
    /// Directory where CSV files will be written
    output_dir: PathBuf,
    /// Number of decimal places for floating point values
    precision: usize,
    /// Whether to store warmup samples (default: true)
    store_warmup: bool,
}

impl CsvConfig {
    /// Create a new CSV configuration.
    ///
    /// # Arguments
    ///
    /// * `output_dir` - Directory where CSV files will be written
    ///
    /// # Example
    ///
    /// ```rust
    /// use nuts_rs::CsvConfig;
    /// let config = CsvConfig::new("mcmc_output");
    /// ```
    pub fn new<P: AsRef<Path>>(output_dir: P) -> Self {
        Self {
            output_dir: output_dir.as_ref().to_path_buf(),
            precision: 6,
            store_warmup: true,
        }
    }

    /// Set the precision (number of decimal places) for floating point values.
    ///
    /// Default is 6 decimal places.
    pub fn with_precision(mut self, precision: usize) -> Self {
        self.precision = precision;
        self
    }

    /// Configure whether to store warmup samples.
    ///
    /// When true (default), warmup samples are written to the CSV file.
    /// When false, only post-warmup samples are stored.
    pub fn store_warmup(mut self, store: bool) -> Self {
        self.store_warmup = store;
        self
    }
}
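
// A minimal usage sketch (not part of the original code): build a configuration that
// writes CSV files with 4 decimal places into `mcmc_output/` and skips warmup draws.
// It only uses the builder methods defined above.
#[allow(dead_code)]
fn example_csv_config() -> CsvConfig {
    CsvConfig::new("mcmc_output")
        .with_precision(4)
        .store_warmup(false)
}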

/// Main CSV storage managing multiple chains
pub struct CsvTraceStorage {
    output_dir: PathBuf,
    precision: usize,
    store_warmup: bool,
    parameter_names: Vec<String>,
    column_mapping: Vec<(String, usize)>, // (data_name, index_in_data)
}

/// Per-chain CSV storage
pub struct CsvChainStorage {
    writer: BufWriter<File>,
    precision: usize,
    store_warmup: bool,
    parameter_names: Vec<String>,
    column_mapping: Vec<(String, usize)>, // (data_name, index_in_data)
    is_first_sample: bool,
    headers_written: bool,
}

impl CsvChainStorage {
    /// Create a new CSV chain storage
    fn new(
        output_dir: &Path,
        chain_id: u64,
        precision: usize,
        store_warmup: bool,
        parameter_names: Vec<String>,
        column_mapping: Vec<(String, usize)>,
    ) -> Result<Self> {
        std::fs::create_dir_all(output_dir)
            .with_context(|| format!("Failed to create output directory: {:?}", output_dir))?;

        let file_path = output_dir.join(format!("chain_{}.csv", chain_id));
        let file = File::create(&file_path)
            .with_context(|| format!("Failed to create CSV file: {:?}", file_path))?;
        let writer = BufWriter::new(file);

        Ok(Self {
            writer,
            precision,
            store_warmup,
            parameter_names,
            column_mapping,
            is_first_sample: true,
            headers_written: false,
        })
    }

    /// Write the CSV header row
    fn write_header(&mut self) -> Result<()> {
        if self.headers_written {
            return Ok(());
        }

        // Standard CmdStan header format - only the core columns
        let mut headers = vec![
            "lp__".to_string(),
            "accept_stat__".to_string(),
            "stepsize__".to_string(),
            "treedepth__".to_string(),
            "n_leapfrog__".to_string(),
            "divergent__".to_string(),
            "energy__".to_string(),
        ];

        // Add parameter columns from the expanded parameter vector
        for param_name in &self.parameter_names {
            headers.push(param_name.clone());
        }

        // Write header row
        writeln!(self.writer, "{}", headers.join(","))?;
        self.headers_written = true;
        Ok(())
    }

    /// Format a value for CSV output
    fn format_value(&self, value: &Value) -> String {
        match value {
            Value::ScalarF64(v) => {
                if v.is_nan() {
                    "NA".to_string()
                } else if v.is_infinite() {
                    if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
                } else {
                    format!("{:.prec$}", v, prec = self.precision)
                }
            }
            Value::ScalarF32(v) => {
                if v.is_nan() {
                    "NA".to_string()
                } else if v.is_infinite() {
                    if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
                } else {
                    format!("{:.prec$}", v, prec = self.precision)
                }
            }
            Value::ScalarU64(v) => v.to_string(),
            Value::ScalarI64(v) => v.to_string(),
            Value::ScalarBool(v) => if *v { "1" } else { "0" }.to_string(),
            Value::F64(vec) => {
                // For vector values, we'll just use the first element for now
                // A more sophisticated implementation would handle multi-dimensional parameters
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    self.format_value(&Value::ScalarF64(vec[0]))
                }
            }
            Value::F32(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    self.format_value(&Value::ScalarF32(vec[0]))
                }
            }
            Value::U64(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    vec[0].to_string()
                }
            }
            Value::I64(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    vec[0].to_string()
                }
            }
            Value::Bool(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    if vec[0] { "1" } else { "0" }.to_string()
                }
            }
            Value::ScalarString(v) => v.clone(),
            Value::Strings(vec) => {
                if vec.is_empty() {
                    "NA".to_string()
                } else {
                    vec[0].clone()
                }
            }
        }
    }

    /// Write a single sample row to the CSV file
    fn write_sample_row(
        &mut self,
        stats: &Vec<(&str, Option<Value>)>,
        draws: &Vec<(&str, Option<Value>)>,
        _info: &Progress,
    ) -> Result<()> {
        let mut row_values = Vec::new();

        // Create lookup maps for quick access
        let stats_map: HashMap<&str, &Option<Value>> = stats.iter().map(|(k, v)| (*k, v)).collect();
        let draws_map: HashMap<&str, &Option<Value>> = draws.iter().map(|(k, v)| (*k, v)).collect();

        // Helper function to get stat value
        let get_stat_value = |name: &str| -> String {
            stats_map
                .get(name)
                .and_then(|opt| opt.as_ref())
                .map(|v| self.format_value(v))
                .unwrap_or_else(|| "NA".to_string())
        };

        // Sampler statistics, in the same order as the header columns:
        // lp__, accept_stat__, stepsize__, treedepth__, n_leapfrog__, divergent__, energy__
        row_values.push(get_stat_value("logp"));
        row_values.push(get_stat_value("mean_tree_accept"));
        row_values.push(get_stat_value("step_size"));
        row_values.push(get_stat_value("depth"));
        row_values.push(get_stat_value("n_steps"));
        let divergent_val = stats_map
            .get("diverging")
            .and_then(|opt| opt.as_ref())
            .map(|v| match v {
                Value::ScalarBool(true) => "1".to_string(),
                Value::ScalarBool(false) => "0".to_string(),
                _ => "0".to_string(),
            })
            .unwrap_or_else(|| "0".to_string());
        row_values.push(divergent_val);

        row_values.push(get_stat_value("energy"));

        // Add parameter values using the column mapping
        for (_param_name, (data_name, index)) in
            self.parameter_names.iter().zip(&self.column_mapping)
        {
            if let Some(Some(data_value)) = draws_map.get(data_name.as_str()) {
                let formatted_value = match data_value {
                    Value::F64(vec) => {
                        if *index < vec.len() {
                            self.format_value(&Value::ScalarF64(vec[*index]))
                        } else {
                            "NA".to_string()
                        }
                    }
                    Value::F32(vec) => {
                        if *index < vec.len() {
                            self.format_value(&Value::ScalarF32(vec[*index]))
                        } else {
                            "NA".to_string()
                        }
                    }
                    Value::I64(vec) => {
                        if *index < vec.len() {
                            self.format_value(&Value::ScalarI64(vec[*index]))
                        } else {
                            "NA".to_string()
                        }
                    }
                    Value::U64(vec) => {
                        if *index < vec.len() {
                            self.format_value(&Value::ScalarU64(vec[*index]))
                        } else {
                            "NA".to_string()
                        }
                    }
                    // Handle scalar values (index should be 0)
                    scalar_val if *index == 0 => self.format_value(scalar_val),
                    _ => "NA".to_string(),
                };
                row_values.push(formatted_value);
            } else {
                row_values.push("NA".to_string());
            }
        }

        // Write the row
        writeln!(self.writer, "{}", row_values.join(","))?;
        Ok(())
    }
}

impl ChainStorage for CsvChainStorage {
    type Finalized = ();

    fn record_sample(
        &mut self,
        _settings: &impl Settings,
        stats: Vec<(&str, Option<Value>)>,
        draws: Vec<(&str, Option<Value>)>,
        info: &Progress,
    ) -> Result<()> {
        // Skip warmup samples if not storing them
        if info.tuning && !self.store_warmup {
            return Ok(());
        }

        // Write header on first sample
        if self.is_first_sample {
            self.write_header()?;
            self.is_first_sample = false;
        }

        self.write_sample_row(&stats, &draws, info)?;
        Ok(())
    }

    fn finalize(mut self) -> Result<Self::Finalized> {
        self.writer.flush().context("Failed to flush CSV file")?;
        Ok(())
    }

    fn flush(&self) -> Result<()> {
        // `BufWriter::flush` requires a mutable reference, so this is a no-op here;
        // the buffer is flushed explicitly in `finalize` and implicitly when the
        // writer is dropped.
        Ok(())
    }

    fn inspect(&self) -> Result<Option<Self::Finalized>> {
        // For CSV storage, inspection does not produce a finalized result
        self.flush()?;
        Ok(None)
    }
}

impl StorageConfig for CsvConfig {
    type Storage = CsvTraceStorage;

    fn new_trace<M: Math>(self, settings: &impl Settings, math: &M) -> Result<Self::Storage> {
        // Generate parameter names and column mapping using coordinates
        let (parameter_names, column_mapping) =
            generate_parameter_names_and_mapping(settings, math)?;

        Ok(CsvTraceStorage {
            output_dir: self.output_dir,
            precision: self.precision,
            store_warmup: self.store_warmup,
            parameter_names,
            column_mapping,
        })
    }
}

/// Generate parameter column names and mapping using coordinates or Stan-compliant indexing
fn generate_parameter_names_and_mapping<M: Math>(
    settings: &impl Settings,
    math: &M,
) -> Result<(Vec<String>, Vec<(String, usize)>)> {
    let data_dims = settings.data_dims_all(math);
    let coords = math.coords();
    let mut parameter_names = Vec::new();
    let mut column_mapping = Vec::new();

    for (var_name, var_dims) in data_dims {
        let data_type = settings.data_type(math, &var_name);

        // Only process vector types that could contain parameter values
        if matches!(
            data_type,
            ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
        ) {
            let (column_names, indices) = generate_column_names_and_indices_for_variable(
                &var_name, &var_dims, &coords, math,
            )?;

            for (name, index) in column_names.into_iter().zip(indices) {
                parameter_names.push(name);
                column_mapping.push((var_name.clone(), index));
            }
        }
    }

    // If no parameter names were generated, fall back to simple numbering
    if parameter_names.is_empty() {
        let dim_sizes = math.dim_sizes();
        let param_count = dim_sizes.get("expanded_parameter").unwrap_or(&0);
        for i in 0..*param_count {
            parameter_names.push(format!("param_{}", i + 1));
            // Try to find a data field that contains the parameters
            let data_names = settings.data_names(math);
            let mut found_field = false;
            for data_name in &data_names {
                let data_type = settings.data_type(math, data_name);
                if matches!(
                    data_type,
                    ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
                ) {
                    column_mapping.push((data_name.clone(), i as usize));
                    found_field = true;
                    break;
                }
            }
            if !found_field {
                column_mapping.push(("unknown".to_string(), i as usize));
            }
        }
    }

    Ok((parameter_names, column_mapping))
}

/// Generate column names and indices for a single variable using its dimensions and coordinates
fn generate_column_names_and_indices_for_variable<M: Math>(
    var_name: &str,
    var_dims: &[String],
    coords: &HashMap<String, Value>,
    math: &M,
) -> Result<(Vec<String>, Vec<usize>)> {
    let dim_sizes = math.dim_sizes();

    if var_dims.is_empty() {
        // Scalar variable
        return Ok((vec![var_name.to_string()], vec![0]));
    }

    // Check if we have meaningful coordinate names for all dimensions
    let has_meaningful_coords = var_dims.iter().all(|dim_name| {
        coords.get(dim_name).is_some_and(
            |coord_value| matches!(coord_value, Value::Strings(labels) if !labels.is_empty()),
        )
    });

    // Get coordinate labels for each dimension
    let mut dim_coords: Vec<Vec<String>> = Vec::new();
    let mut dim_sizes_vec: Vec<usize> = Vec::new();

    for dim_name in var_dims {
        let size = *dim_sizes.get(dim_name).unwrap_or(&1) as usize;
        dim_sizes_vec.push(size);

        if has_meaningful_coords {
            // Use coordinate names if available and meaningful
            if let Some(coord_value) = coords.get(dim_name) {
                match coord_value {
                    Value::Strings(labels) => {
                        dim_coords.push(labels.clone());
                    }
                    _ => {
                        // Fallback to 1-based indexing (Stan format)
                        dim_coords.push((1..=size).map(|i| i.to_string()).collect());
                    }
                }
            } else {
                // Fallback to 1-based indexing (Stan format)
                dim_coords.push((1..=size).map(|i| i.to_string()).collect());
            }
        } else {
            // Use Stan-compliant 1-based indexing
            dim_coords.push((1..=size).map(|i| i.to_string()).collect());
        }
    }

    // Generate the Cartesian product of coordinate labels together with linear indices
    let (coord_names, indices) =
        cartesian_product_with_indices_column_major(&dim_coords, &dim_sizes_vec);

    // Prepend variable name to each coordinate combination
    let column_names: Vec<String> = coord_names
        .into_iter()
        .map(|coord| format!("{}.{}", var_name, coord))
        .collect();

    Ok((column_names, indices))
}

/// Compute the Cartesian product of coordinate labels, together with the linear index
/// of each combination.
///
/// Combinations are generated with the first index changing slowest and the last index
/// changing fastest. For example, a 2x3 array produces:
/// [1,1], [1,2], [1,3], [2,1], [2,2], [2,3].
fn cartesian_product_with_indices_column_major(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
) -> (Vec<String>, Vec<usize>) {
    if coord_sets.is_empty() {
        return (vec![], vec![]);
    }

    if coord_sets.len() == 1 {
        let indices: Vec<usize> = (0..coord_sets[0].len()).collect();
        return (coord_sets[0].clone(), indices);
    }

    let mut names = vec![];
    let mut indices = vec![];

    // Recurse over the dimensions; the first dimension is the outermost loop
    cartesian_product_recursive_with_indices(
        coord_sets,
        dim_sizes,
        0,
        &mut String::new(),
        &mut vec![],
        &mut names,
        &mut indices,
    );

    (names, indices)
}

fn cartesian_product_recursive_with_indices(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
    dim_idx: usize,
    current_name: &mut String,
    current_indices: &mut Vec<usize>,
    result_names: &mut Vec<String>,
    result_indices: &mut Vec<usize>,
) {
    if dim_idx == coord_sets.len() {
        result_names.push(current_name.clone());
        // Compute linear index from multi-dimensional indices
        let mut linear_index = 0;
        for (i, &idx) in current_indices.iter().enumerate() {
            let mut stride = 1;
            for &size in &dim_sizes[i + 1..] {
                stride *= size;
            }
            linear_index += idx * stride;
        }
        result_indices.push(linear_index);
        return;
    }

    let is_first_dim = dim_idx == 0;

    for (coord_idx, coord) in coord_sets[dim_idx].iter().enumerate() {
        let mut new_name = current_name.clone();
        if !is_first_dim {
            new_name.push('.');
        }
        new_name.push_str(coord);

        current_indices.push(coord_idx);
        cartesian_product_recursive_with_indices(
            coord_sets,
            dim_sizes,
            dim_idx + 1,
            &mut new_name,
            current_indices,
            result_names,
            result_indices,
        );
        current_indices.pop();
    }
}

impl TraceStorage for CsvTraceStorage {
    type ChainStorage = CsvChainStorage;
    type Finalized = ();

    fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage> {
        CsvChainStorage::new(
            &self.output_dir,
            chain_id,
            self.precision,
            self.store_warmup,
            self.parameter_names.clone(),
            self.column_mapping.clone(),
        )
    }

    fn finalize(
        self,
        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
        // Check for any errors in the chain finalizations
        for trace_result in traces {
            if let Err(err) = trace_result {
                return Ok((Some(err), ()));
            }
        }
        Ok((None, ()))
    }

    fn inspect(
        &self,
        traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
        // Check for any errors in the chain inspections
        for trace_result in traces {
            if let Err(err) = trace_result {
                return Ok((Some(err), ()));
            }
        }
        Ok((None, ()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler,
    };
    use anyhow::Result;
    use nuts_derive::Storable;
    use nuts_storable::{HasDims, Value};
    use rand::Rng;
    use std::collections::HashMap;
    use std::fs;
    use std::path::Path;
    use thiserror::Error;

    #[allow(dead_code)]
    #[derive(Debug, Error)]
    enum TestLogpError {
        #[error("Test error")]
        Test,
    }

    impl LogpError for TestLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    /// Test model with multi-dimensional coordinates
    #[derive(Clone)]
    struct MultiDimTestLogp {
        dim_a: usize,
        dim_b: usize,
    }

    impl HasDims for MultiDimTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([
                ("a".to_string(), self.dim_a as u64),
                ("b".to_string(), self.dim_b as u64),
            ])
        }

        fn coords(&self) -> HashMap<String, Value> {
            HashMap::from([
                (
                    "a".to_string(),
                    Value::Strings(vec!["x".to_string(), "y".to_string()]),
                ),
                (
                    "b".to_string(),
                    Value::Strings(vec!["alpha".to_string(), "beta".to_string()]),
                ),
            ])
        }
    }

    #[derive(Storable)]
    struct MultiDimExpandedDraw {
        #[storable(dims("a", "b"))]
        param_matrix: Vec<f64>,
        scalar_value: f64,
    }

    impl CpuLogpFunc for MultiDimTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = MultiDimExpandedDraw;

        fn dim(&self) -> usize {
            self.dim_a * self.dim_b
        }

        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mut logp = 0.0;
            for (i, &xi) in x.iter().enumerate() {
                logp -= 0.5 * xi * xi;
                grad[i] = -xi;
            }
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            Ok(MultiDimExpandedDraw {
                param_matrix: array.to_vec(),
                scalar_value: array.iter().sum(),
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            Some(Value::Strings(
                (0..self.dim()).map(|i| format!("theta{}", i + 1)).collect(),
            ))
        }
    }

    struct MultiDimTestModel {
        math: CpuMath<MultiDimTestLogp>,
    }

    impl Model for MultiDimTestModel {
        type Math<'model>
            = CpuMath<MultiDimTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            for p in position.iter_mut() {
                *p = rng.random_range(-1.0..1.0);
            }
            Ok(())
        }
    }

    /// Test model without coordinates (fallback behavior)
    #[derive(Clone)]
    struct SimpleTestLogp {
        dim: usize,
    }

    impl HasDims for SimpleTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([("simple_param".to_string(), self.dim as u64)])
        }
        // No coords() method - should use fallback
    }

    #[derive(Storable)]
    struct SimpleExpandedDraw {
        #[storable(dims("simple_param"))]
        values: Vec<f64>,
    }

    impl CpuLogpFunc for SimpleTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = SimpleExpandedDraw;

        fn dim(&self) -> usize {
            self.dim
        }

        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mut logp = 0.0;
            for (i, &xi) in x.iter().enumerate() {
                logp -= 0.5 * xi * xi;
                grad[i] = -xi;
            }
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            Ok(SimpleExpandedDraw {
                values: array.to_vec(),
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            Some(Value::Strings(vec![
                "param1".to_string(),
                "param2".to_string(),
                "param3".to_string(),
            ]))
        }
    }

    struct SimpleTestModel {
        math: CpuMath<SimpleTestLogp>,
    }

    impl Model for SimpleTestModel {
        type Math<'model>
            = CpuMath<SimpleTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            for p in position.iter_mut() {
                *p = rng.random_range(-1.0..1.0);
            }
            Ok(())
        }
    }

    fn read_csv_header(path: &Path) -> Result<String> {
        let content = fs::read_to_string(path)?;
        content
            .lines()
            .next()
            .map(|s| s.to_string())
            .ok_or_else(|| anyhow::anyhow!("Empty CSV file"))
    }

    #[test]
    fn test_multidim_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("multidim_test");

        // Create model with 2x2 parameter matrix
        let model = MultiDimTestModel {
            math: CpuMath::new(MultiDimTestLogp { dim_a: 2, dim_b: 2 }),
        };

        let mut settings = DiagGradNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 10;
        settings.num_draws = 20;
        settings.seed = 42;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        // Wait for sampling to complete
        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        // Check that CSV file was created
        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        // Check header contains expected coordinate names
        let header = read_csv_header(&csv_file)?;

        // Should contain Cartesian product: x.alpha, x.beta, y.alpha, y.beta
        assert!(header.contains("param_matrix.x.alpha"));
        assert!(header.contains("param_matrix.x.beta"));
        assert!(header.contains("param_matrix.y.alpha"));
        assert!(header.contains("param_matrix.y.beta"));
        assert!(header.contains("scalar_value"));

        // Verify column order (Cartesian product should be in correct order)
        let columns: Vec<&str> = header.split(',').collect();
        let param_columns: Vec<&str> = columns
            .iter()
            .filter(|col| col.starts_with("param_matrix."))
            .cloned()
            .collect();

        assert_eq!(
            param_columns,
            vec![
                "param_matrix.x.alpha",
                "param_matrix.x.beta",
                "param_matrix.y.alpha",
                "param_matrix.y.beta"
            ]
        );

        Ok(())
    }

    #[test]
    fn test_fallback_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("simple_test");

        // Create model with 3 parameters but no coordinate specification
        let model = SimpleTestModel {
            math: CpuMath::new(SimpleTestLogp { dim: 3 }),
        };

        let mut settings = DiagGradNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 5;
        settings.num_draws = 10;
        settings.seed = 123;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        // Wait for sampling to complete
        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        // Check that CSV file was created
        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        // Check header uses fallback numeric naming
        let header = read_csv_header(&csv_file)?;

        // Should fall back to 1-based indices since no coordinates provided
        assert!(header.contains("values.1"));
        assert!(header.contains("values.2"));
        assert!(header.contains("values.3"));

        Ok(())
    }
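
    // Sketch of an additional test (not in the original suite): per-chain files follow
    // the `chain_{id}.csv` naming pattern documented on `CsvConfig`, and the file is
    // created as soon as the chain storage is initialized.
    #[test]
    fn test_chain_file_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let _storage = CsvChainStorage::new(temp_dir.path(), 3, 6, true, vec![], vec![])?;
        assert!(temp_dir.path().join("chain_3.csv").exists());
        Ok(())
    }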

    #[test]
    fn test_cartesian_product_generation() {
        let coord_sets = vec![
            vec!["x".to_string(), "y".to_string()],
            vec!["alpha".to_string(), "beta".to_string()],
        ];
        let dim_sizes = vec![2, 2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["x.alpha", "x.beta", "y.alpha", "y.beta"]);

        assert_eq!(indices, vec![0, 1, 2, 3]);
    }
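
    // Sketch of an additional test (not in the original suite) for the special cases in
    // `format_value`: NaN becomes "NA", infinities become "Inf"/"-Inf", booleans become
    // "1"/"0", and finite floats use the configured precision.
    #[test]
    fn test_format_value_special_cases() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let storage = CsvChainStorage::new(temp_dir.path(), 0, 3, true, vec![], vec![])?;

        assert_eq!(storage.format_value(&Value::ScalarF64(1.0)), "1.000");
        assert_eq!(storage.format_value(&Value::ScalarF64(f64::NAN)), "NA");
        assert_eq!(storage.format_value(&Value::ScalarF64(f64::INFINITY)), "Inf");
        assert_eq!(storage.format_value(&Value::ScalarF64(f64::NEG_INFINITY)), "-Inf");
        assert_eq!(storage.format_value(&Value::ScalarBool(true)), "1");
        Ok(())
    }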

    #[test]
    fn test_single_dimension_coordinates() {
        let coord_sets = vec![vec!["param1".to_string(), "param2".to_string()]];
        let dim_sizes = vec![2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["param1", "param2"]);
        assert_eq!(indices, vec![0, 1]);
    }
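
    // Sketch of an additional test (not in the original suite): for a 2x3 variable the
    // first index changes slowest in the generated names, and the linear index of each
    // combination uses a stride equal to the product of the trailing dimension sizes.
    #[test]
    fn test_two_by_three_cartesian_product() {
        let coord_sets = vec![
            vec!["r1".to_string(), "r2".to_string()],
            vec!["c1".to_string(), "c2".to_string(), "c3".to_string()],
        ];
        let dim_sizes = vec![2, 3];

        let (names, indices) =
            cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(
            names,
            vec!["r1.c1", "r1.c2", "r1.c3", "r2.c1", "r2.c2", "r2.c3"]
        );
        assert_eq!(indices, vec![0, 1, 2, 3, 4, 5]);
    }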

    #[test]
    fn test_three_dimension_cartesian_product() {
        let coord_sets = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["1".to_string()],
            vec!["i".to_string(), "j".to_string()],
        ];
        let dim_sizes = vec![2, 1, 2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["a.1.i", "a.1.j", "b.1.i", "b.1.j"]);

        assert_eq!(indices, vec![0, 1, 2, 3]);
    }
    }
985}