// nuts_rs/storage/csv.rs

1//! CSV storage backend for nuts-rs that outputs CmdStan-compatible CSV files
2//!
3//! This module provides a CSV storage backend that saves MCMC samples and
4//! statistics in a format compatible with CmdStan, allowing existing Stan
5//! analysis tools and libraries to read nuts-rs results.
6
7use std::collections::HashMap;
8use std::fs::File;
9use std::io::{BufWriter, Write};
10use std::path::{Path, PathBuf};
11
12use anyhow::{Context, Result};
13use nuts_storable::{ItemType, Value};
14
15use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
16use crate::{Math, Progress, Settings};
17
/// Configuration for CSV-based MCMC storage.
///
/// This storage backend creates Stan-style CSV files with one file per chain.
/// Files are named `chain_{id}.csv` where `{id}` is the chain number starting from 0.
///
/// Each file contains:
/// - A header row with column names
/// - One row per sample (warmup samples are included only when enabled)
/// - The standard Stan sampler statistics (lp__, accept_stat__, stepsize__,
///   treedepth__, n_leapfrog__, divergent__, energy__)
/// - One column per expanded parameter
pub struct CsvConfig {
    /// Directory where CSV files will be written (created on demand).
    output_dir: PathBuf,
    /// Number of decimal places for floating point values (default: 6).
    precision: usize,
    /// Whether to store warmup samples (default: true)
    store_warmup: bool,
}
37
38impl CsvConfig {
39    /// Create a new CSV configuration.
40    ///
41    /// # Arguments
42    ///
43    /// * `output_dir` - Directory where CSV files will be written
44    ///
45    /// # Example
46    ///
47    /// ```rust
48    /// use nuts_rs::CsvConfig;
49    /// let config = CsvConfig::new("mcmc_output");
50    /// ```
51    pub fn new<P: AsRef<Path>>(output_dir: P) -> Self {
52        Self {
53            output_dir: output_dir.as_ref().to_path_buf(),
54            precision: 6,
55            store_warmup: true,
56        }
57    }
58
59    /// Set the precision (number of decimal places) for floating point values.
60    ///
61    /// Default is 6 decimal places.
62    pub fn with_precision(mut self, precision: usize) -> Self {
63        self.precision = precision;
64        self
65    }
66
67    /// Configure whether to store warmup samples.
68    ///
69    /// When true (default), warmup samples are included with negative sample IDs.
70    /// When false, only post-warmup samples are stored.
71    pub fn store_warmup(mut self, store: bool) -> Self {
72        self.store_warmup = store;
73        self
74    }
75}
76
/// Main CSV storage managing multiple chains.
///
/// Holds the configuration plus the precomputed column layout that every
/// per-chain writer shares.
pub struct CsvTraceStorage {
    // Directory the per-chain CSV files are written into.
    output_dir: PathBuf,
    // Decimal places used when formatting floats.
    precision: usize,
    // Whether warmup draws are written at all.
    store_warmup: bool,
    // Expanded parameter column names, in header order.
    parameter_names: Vec<String>,
    // For each parameter column: which draw buffer it comes from and the
    // element index within that buffer.
    column_mapping: Vec<(String, usize)>, // (data_name, index_in_data)
}
85
/// Per-chain CSV storage: one instance (and one output file) per chain.
pub struct CsvChainStorage {
    // Buffered writer over the chain's `chain_{id}.csv` file.
    writer: BufWriter<File>,
    // Decimal places used when formatting floats.
    precision: usize,
    // Whether warmup draws are written at all.
    store_warmup: bool,
    // Expanded parameter column names, in header order.
    parameter_names: Vec<String>,
    // For each parameter column: source draw buffer and element index.
    column_mapping: Vec<(String, usize)>, // (data_name, index_in_data)
    // True until the first sample is recorded; triggers header writing.
    is_first_sample: bool,
    // Set once the header row has been written.
    headers_written: bool,
}
96
97impl CsvChainStorage {
98    /// Create a new CSV chain storage
99    fn new(
100        output_dir: &Path,
101        chain_id: u64,
102        precision: usize,
103        store_warmup: bool,
104        parameter_names: Vec<String>,
105        column_mapping: Vec<(String, usize)>,
106    ) -> Result<Self> {
107        std::fs::create_dir_all(output_dir)
108            .with_context(|| format!("Failed to create output directory: {:?}", output_dir))?;
109
110        let file_path = output_dir.join(format!("chain_{}.csv", chain_id));
111        let file = File::create(&file_path)
112            .with_context(|| format!("Failed to create CSV file: {:?}", file_path))?;
113        let writer = BufWriter::new(file);
114
115        Ok(Self {
116            writer,
117            precision,
118            store_warmup,
119            parameter_names,
120            column_mapping,
121            is_first_sample: true,
122            headers_written: false,
123        })
124    }
125
126    /// Write the CSV header row
127    fn write_header(&mut self) -> Result<()> {
128        if self.headers_written {
129            return Ok(());
130        }
131
132        // Standard CmdStan header format - only the core columns
133        let mut headers = vec![
134            "lp__".to_string(),
135            "accept_stat__".to_string(),
136            "stepsize__".to_string(),
137            "treedepth__".to_string(),
138            "n_leapfrog__".to_string(),
139            "divergent__".to_string(),
140            "energy__".to_string(),
141        ];
142
143        // Add parameter columns from the expanded parameter vector
144        for param_name in &self.parameter_names {
145            headers.push(param_name.clone());
146        }
147
148        // Write header row
149        writeln!(self.writer, "{}", headers.join(","))?;
150        self.headers_written = true;
151        Ok(())
152    }
153
154    /// Format a value for CSV output
155    fn format_value(&self, value: &Value) -> String {
156        match value {
157            Value::ScalarF64(v) => {
158                if v.is_nan() {
159                    "NA".to_string()
160                } else if v.is_infinite() {
161                    if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
162                } else {
163                    format!("{:.prec$}", v, prec = self.precision)
164                }
165            }
166            Value::ScalarF32(v) => {
167                if v.is_nan() {
168                    "NA".to_string()
169                } else if v.is_infinite() {
170                    if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
171                } else {
172                    format!("{:.prec$}", v, prec = self.precision)
173                }
174            }
175            Value::ScalarU64(v) => v.to_string(),
176            Value::ScalarI64(v) => v.to_string(),
177            Value::ScalarBool(v) => if *v { "1" } else { "0" }.to_string(),
178            Value::F64(vec) => {
179                // For vector values, we'll just use the first element for now
180                // A more sophisticated implementation would handle multi-dimensional parameters
181                if vec.is_empty() {
182                    "NA".to_string()
183                } else {
184                    self.format_value(&Value::ScalarF64(vec[0]))
185                }
186            }
187            Value::F32(vec) => {
188                if vec.is_empty() {
189                    "NA".to_string()
190                } else {
191                    self.format_value(&Value::ScalarF32(vec[0]))
192                }
193            }
194            Value::U64(vec) => {
195                if vec.is_empty() {
196                    "NA".to_string()
197                } else {
198                    vec[0].to_string()
199                }
200            }
201            Value::I64(vec) => {
202                if vec.is_empty() {
203                    "NA".to_string()
204                } else {
205                    vec[0].to_string()
206                }
207            }
208            Value::Bool(vec) => {
209                if vec.is_empty() {
210                    "NA".to_string()
211                } else {
212                    if vec[0] { "1" } else { "0" }.to_string()
213                }
214            }
215            Value::ScalarString(v) => v.clone(),
216            Value::Strings(vec) => {
217                if vec.is_empty() {
218                    "NA".to_string()
219                } else {
220                    vec[0].clone()
221                }
222            }
223            Value::DateTime64(_, _) => panic!("DateTime64 not supported in CSV output"),
224            Value::TimeDelta64(_, _) => panic!("TimeDelta64 not supported in CSV output"),
225        }
226    }
227
228    /// Write a single sample row to the CSV file
229    fn write_sample_row(
230        &mut self,
231        stats: &Vec<(&str, Option<Value>)>,
232        draws: &Vec<(&str, Option<Value>)>,
233        _info: &Progress,
234    ) -> Result<()> {
235        let mut row_values = Vec::new();
236
237        // Create lookup maps for quick access
238        let stats_map: HashMap<&str, &Option<Value>> = stats.iter().map(|(k, v)| (*k, v)).collect();
239        let draws_map: HashMap<&str, &Option<Value>> = draws.iter().map(|(k, v)| (*k, v)).collect();
240
241        // Helper function to get stat value
242        let get_stat_value = |name: &str| -> String {
243            stats_map
244                .get(name)
245                .and_then(|opt| opt.as_ref())
246                .map(|v| self.format_value(v))
247                .unwrap_or_else(|| "NA".to_string())
248        };
249
250        row_values.push(get_stat_value("logp"));
251        row_values.push(get_stat_value("mean_tree_accept"));
252        row_values.push(get_stat_value("step_size"));
253        row_values.push(get_stat_value("depth"));
254        row_values.push(get_stat_value("n_steps"));
255        let divergent_val = stats_map
256            .get("diverging")
257            .and_then(|opt| opt.as_ref())
258            .map(|v| match v {
259                Value::ScalarBool(true) => "1".to_string(),
260                Value::ScalarBool(false) => "0".to_string(),
261                _ => "0".to_string(),
262            })
263            .unwrap_or_else(|| "0".to_string());
264        row_values.push(divergent_val);
265
266        row_values.push(get_stat_value("energy"));
267
268        // Add parameter values using the column mapping
269        for (_param_name, (data_name, index)) in
270            self.parameter_names.iter().zip(&self.column_mapping)
271        {
272            if let Some(Some(data_value)) = draws_map.get(data_name.as_str()) {
273                let formatted_value = match data_value {
274                    Value::F64(vec) => {
275                        if *index < vec.len() {
276                            self.format_value(&Value::ScalarF64(vec[*index]))
277                        } else {
278                            "NA".to_string()
279                        }
280                    }
281                    Value::F32(vec) => {
282                        if *index < vec.len() {
283                            self.format_value(&Value::ScalarF32(vec[*index]))
284                        } else {
285                            "NA".to_string()
286                        }
287                    }
288                    Value::I64(vec) => {
289                        if *index < vec.len() {
290                            self.format_value(&Value::ScalarI64(vec[*index]))
291                        } else {
292                            "NA".to_string()
293                        }
294                    }
295                    Value::U64(vec) => {
296                        if *index < vec.len() {
297                            self.format_value(&Value::ScalarU64(vec[*index]))
298                        } else {
299                            "NA".to_string()
300                        }
301                    }
302                    // Handle scalar values (index should be 0)
303                    scalar_val if *index == 0 => self.format_value(scalar_val),
304                    _ => "NA".to_string(),
305                };
306                row_values.push(formatted_value);
307            } else {
308                row_values.push("NA".to_string());
309            }
310        }
311
312        // Write the row
313        writeln!(self.writer, "{}", row_values.join(","))?;
314        Ok(())
315    }
316}
317
318impl ChainStorage for CsvChainStorage {
319    type Finalized = ();
320
321    fn record_sample(
322        &mut self,
323        _settings: &impl Settings,
324        stats: Vec<(&str, Option<Value>)>,
325        draws: Vec<(&str, Option<Value>)>,
326        info: &Progress,
327    ) -> Result<()> {
328        // Skip warmup samples if not storing them
329        if info.tuning && !self.store_warmup {
330            return Ok(());
331        }
332
333        // Write header on first sample
334        if self.is_first_sample {
335            self.write_header()?;
336            self.is_first_sample = false;
337        }
338
339        self.write_sample_row(&stats, &draws, info)?;
340        Ok(())
341    }
342
343    fn finalize(mut self) -> Result<Self::Finalized> {
344        self.writer.flush().context("Failed to flush CSV file")?;
345        Ok(())
346    }
347
348    fn flush(&self) -> Result<()> {
349        // BufWriter doesn't provide a way to flush without mutable reference
350        // In practice, the buffer will be flushed when the file is closed
351        Ok(())
352    }
353
354    fn inspect(&self) -> Result<Option<Self::Finalized>> {
355        // For CSV storage, inspection does not produce a finalized result
356        self.flush()?;
357        Ok(None)
358    }
359}
360
361impl StorageConfig for CsvConfig {
362    type Storage = CsvTraceStorage;
363
364    fn new_trace<M: Math>(self, settings: &impl Settings, math: &M) -> Result<Self::Storage> {
365        // Generate parameter names and column mapping using coordinates
366        let (parameter_names, column_mapping) =
367            generate_parameter_names_and_mapping(settings, math)?;
368
369        Ok(CsvTraceStorage {
370            output_dir: self.output_dir,
371            precision: self.precision,
372            store_warmup: self.store_warmup,
373            parameter_names,
374            column_mapping,
375        })
376    }
377}
378
379/// Generate parameter column names and mapping using coordinates or Stan-compliant indexing
380fn generate_parameter_names_and_mapping<M: Math>(
381    settings: &impl Settings,
382    math: &M,
383) -> Result<(Vec<String>, Vec<(String, usize)>)> {
384    let data_dims = settings.data_dims_all(math);
385    let coords = math.coords();
386    let mut parameter_names = Vec::new();
387    let mut column_mapping = Vec::new();
388
389    for (var_name, var_dims) in data_dims {
390        let data_type = settings.data_type(math, &var_name);
391
392        // Only process vector types that could contain parameter values
393        if matches!(
394            data_type,
395            ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
396        ) {
397            let (column_names, indices) = generate_column_names_and_indices_for_variable(
398                &var_name, &var_dims, &coords, math,
399            )?;
400
401            for (name, index) in column_names.into_iter().zip(indices) {
402                parameter_names.push(name);
403                column_mapping.push((var_name.clone(), index));
404            }
405        }
406    }
407
408    // If no parameter names were generated, fall back to simple numbering
409    if parameter_names.is_empty() {
410        let dim_sizes = math.dim_sizes();
411        let param_count = dim_sizes.get("expanded_parameter").unwrap_or(&0);
412        for i in 0..*param_count {
413            parameter_names.push(format!("param_{}", i + 1));
414            // Try to find a data field that contains the parameters
415            let data_names = settings.data_names(math);
416            let mut found_field = false;
417            for data_name in &data_names {
418                let data_type = settings.data_type(math, data_name);
419                if matches!(
420                    data_type,
421                    ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
422                ) {
423                    column_mapping.push((data_name.clone(), i as usize));
424                    found_field = true;
425                    break;
426                }
427            }
428            if !found_field {
429                column_mapping.push(("unknown".to_string(), i as usize));
430            }
431        }
432    }
433
434    Ok((parameter_names, column_mapping))
435}
436
437/// Generate column names and indices for a single variable using its dimensions and coordinates
438fn generate_column_names_and_indices_for_variable<M: Math>(
439    var_name: &str,
440    var_dims: &[String],
441    coords: &HashMap<String, Value>,
442    math: &M,
443) -> Result<(Vec<String>, Vec<usize>)> {
444    let dim_sizes = math.dim_sizes();
445
446    if var_dims.is_empty() {
447        // Scalar variable
448        return Ok((vec![var_name.to_string()], vec![0]));
449    }
450
451    // Check if we have meaningful coordinate names for all dimensions
452    let has_meaningful_coords = var_dims.iter().all(|dim_name| {
453        coords.get(dim_name).is_some_and(
454            |coord_value| matches!(coord_value, Value::Strings(labels) if !labels.is_empty()),
455        )
456    });
457
458    // Get coordinate labels for each dimension
459    let mut dim_coords: Vec<Vec<String>> = Vec::new();
460    let mut dim_sizes_vec: Vec<usize> = Vec::new();
461
462    for dim_name in var_dims {
463        let size = *dim_sizes.get(dim_name).unwrap_or(&1) as usize;
464        dim_sizes_vec.push(size);
465
466        if has_meaningful_coords {
467            // Use coordinate names if available and meaningful
468            if let Some(coord_value) = coords.get(dim_name) {
469                match coord_value {
470                    Value::Strings(labels) => {
471                        dim_coords.push(labels.clone());
472                    }
473                    _ => {
474                        // Fallback to 1-based indexing (Stan format)
475                        dim_coords.push((1..=size).map(|i| i.to_string()).collect());
476                    }
477                }
478            } else {
479                // Fallback to 1-based indexing (Stan format)
480                dim_coords.push((1..=size).map(|i| i.to_string()).collect());
481            }
482        } else {
483            // Use Stan-compliant 1-based indexing
484            dim_coords.push((1..=size).map(|i| i.to_string()).collect());
485        }
486    }
487
488    // Generate Cartesian product using column-major order (Stan format)
489    let (coord_names, indices) =
490        cartesian_product_with_indices_column_major(&dim_coords, &dim_sizes_vec);
491
492    // Prepend variable name to each coordinate combination
493    let column_names: Vec<String> = coord_names
494        .into_iter()
495        .map(|coord| format!("{}.{}", var_name, coord))
496        .collect();
497
498    Ok((column_names, indices))
499}
500
/// Compute the Cartesian product of coordinate labels together with the
/// linear index of each combination.
///
/// Ordering matches Stan's CSV layout (which Stan calls "column-major" but is
/// row-major iteration): the first index changes slowest and the last index
/// changes fastest. A 2x3 array produces [1,1], [1,2], [1,3], [2,1], [2,2], [2,3].
///
/// Linear indices use row-major strides derived from `dim_sizes`
/// (`stride[i] = dim_sizes[i+1] * ... * dim_sizes[d-1]`), precomputed once
/// instead of being re-derived for every combination.
fn cartesian_product_with_indices_column_major(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
) -> (Vec<String>, Vec<usize>) {
    if coord_sets.is_empty() {
        return (vec![], vec![]);
    }

    // Precompute strides right-to-left.
    let ndim = dim_sizes.len();
    let mut strides = vec![1usize; ndim];
    for i in (0..ndim.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dim_sizes[i + 1];
    }

    // Grow the product one dimension at a time, keeping names and linear
    // indices in lockstep. Starting from a single empty prefix.
    let mut names = vec![String::new()];
    let mut indices = vec![0usize];
    for (dim, labels) in coord_sets.iter().enumerate() {
        let stride = strides.get(dim).copied().unwrap_or(1);
        let mut next_names = Vec::with_capacity(names.len() * labels.len());
        let mut next_indices = Vec::with_capacity(indices.len() * labels.len());

        for (prefix, &base) in names.iter().zip(&indices) {
            for (offset, label) in labels.iter().enumerate() {
                let name = if dim == 0 {
                    label.clone()
                } else {
                    format!("{}.{}", prefix, label)
                };
                next_names.push(name);
                next_indices.push(base + offset * stride);
            }
        }

        names = next_names;
        indices = next_indices;
    }

    (names, indices)
}
535
/// Depth-first walk over all coordinate combinations.
///
/// `current_name` / `current_indices` hold the prefix built so far; both are
/// restored to their incoming state before returning, so the same buffers are
/// reused across the whole traversal without per-node cloning.
fn cartesian_product_recursive_with_indices(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
    dim_idx: usize,
    current_name: &mut String,
    current_indices: &mut Vec<usize>,
    result_names: &mut Vec<String>,
    result_indices: &mut Vec<usize>,
) {
    // Leaf: one full combination has been assembled.
    if dim_idx == coord_sets.len() {
        result_names.push(current_name.clone());
        // Row-major linear index: accumulate the stride right-to-left instead
        // of recomputing the stride product for every position.
        let mut linear_index = 0;
        let mut stride = 1;
        for (i, &idx) in current_indices.iter().enumerate().rev() {
            linear_index += idx * stride;
            stride *= dim_sizes[i];
        }
        result_indices.push(linear_index);
        return;
    }

    for (coord_idx, coord) in coord_sets[dim_idx].iter().enumerate() {
        // Extend the name in place and undo the extension afterwards; this
        // avoids allocating a fresh String for every tree node.
        let saved_len = current_name.len();
        if dim_idx != 0 {
            current_name.push('.');
        }
        current_name.push_str(coord);
        current_indices.push(coord_idx);

        cartesian_product_recursive_with_indices(
            coord_sets,
            dim_sizes,
            dim_idx + 1,
            current_name,
            current_indices,
            result_names,
            result_indices,
        );

        current_indices.pop();
        current_name.truncate(saved_len);
    }
}
582
583impl TraceStorage for CsvTraceStorage {
584    type ChainStorage = CsvChainStorage;
585    type Finalized = ();
586
587    fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage> {
588        CsvChainStorage::new(
589            &self.output_dir,
590            chain_id,
591            self.precision,
592            self.store_warmup,
593            self.parameter_names.clone(),
594            self.column_mapping.clone(),
595        )
596    }
597
598    fn finalize(
599        self,
600        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
601    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
602        // Check for any errors in the chain finalizations
603        for trace_result in traces {
604            if let Err(err) = trace_result {
605                return Ok((Some(err), ()));
606            }
607        }
608        Ok((None, ()))
609    }
610
611    fn inspect(
612        &self,
613        traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
614    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
615        // Check for any errors in the chain inspections
616        for trace_result in traces {
617            if let Err(err) = trace_result {
618                return Ok((Some(err), ()));
619            }
620        }
621        Ok((None, ()))
622    }
623}
624
625#[cfg(test)]
626mod tests {
627    use super::*;
628    use crate::{
629        CpuLogpFunc, CpuMath, CpuMathError, DiagGradNutsSettings, LogpError, Model, Sampler,
630    };
631    use anyhow::Result;
632    use nuts_derive::Storable;
633    use nuts_storable::{HasDims, Value};
634    use rand::Rng;
635    use std::collections::HashMap;
636    use std::fs;
637    use std::path::Path;
638    use thiserror::Error;
639
    /// Error type for the test logp functions; never actually constructed.
    #[allow(dead_code)]
    #[derive(Debug, Error)]
    enum TestLogpError {
        #[error("Test error")]
        Test,
    }
646
    impl LogpError for TestLogpError {
        fn is_recoverable(&self) -> bool {
            // Treat any test error as fatal.
            false
        }
    }
652
    /// Test model with multi-dimensional coordinates: an `a` x `b` matrix of
    /// standard-normal parameters.
    #[derive(Clone)]
    struct MultiDimTestLogp {
        dim_a: usize,
        dim_b: usize,
    }
659
660    impl HasDims for MultiDimTestLogp {
661        fn dim_sizes(&self) -> HashMap<String, u64> {
662            HashMap::from([
663                ("a".to_string(), self.dim_a as u64),
664                ("b".to_string(), self.dim_b as u64),
665            ])
666        }
667
668        fn coords(&self) -> HashMap<String, Value> {
669            HashMap::from([
670                (
671                    "a".to_string(),
672                    Value::Strings(vec!["x".to_string(), "y".to_string()]),
673                ),
674                (
675                    "b".to_string(),
676                    Value::Strings(vec!["alpha".to_string(), "beta".to_string()]),
677                ),
678            ])
679        }
680    }
681
    /// Expanded draw for the multi-dimensional test model.
    #[derive(Storable)]
    struct MultiDimExpandedDraw {
        // Flattened `a` x `b` parameter matrix.
        #[storable(dims("a", "b"))]
        param_matrix: Vec<f64>,
        // Sum of all parameters; exercises a scalar column.
        scalar_value: f64,
    }
688
689    impl CpuLogpFunc for MultiDimTestLogp {
690        type LogpError = TestLogpError;
691        type FlowParameters = ();
692        type ExpandedVector = MultiDimExpandedDraw;
693
694        fn dim(&self) -> usize {
695            self.dim_a * self.dim_b
696        }
697
698        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
699            let mut logp = 0.0;
700            for (i, &xi) in x.iter().enumerate() {
701                logp -= 0.5 * xi * xi;
702                grad[i] = -xi;
703            }
704            Ok(logp)
705        }
706
707        fn expand_vector<R: Rng + ?Sized>(
708            &mut self,
709            _rng: &mut R,
710            array: &[f64],
711        ) -> Result<Self::ExpandedVector, CpuMathError> {
712            Ok(MultiDimExpandedDraw {
713                param_matrix: array.to_vec(),
714                scalar_value: array.iter().sum(),
715            })
716        }
717
718        fn vector_coord(&self) -> Option<Value> {
719            Some(Value::Strings(
720                (0..self.dim()).map(|i| format!("theta{}", i + 1)).collect(),
721            ))
722        }
723    }
724
    /// Model wrapper exposing `MultiDimTestLogp` through the `Model` trait.
    struct MultiDimTestModel {
        math: CpuMath<MultiDimTestLogp>,
    }
728
729    impl Model for MultiDimTestModel {
730        type Math<'model>
731            = CpuMath<MultiDimTestLogp>
732        where
733            Self: 'model;
734
735        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
736            Ok(self.math.clone())
737        }
738
739        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
740            for p in position.iter_mut() {
741                *p = rng.random_range(-1.0..1.0);
742            }
743            Ok(())
744        }
745    }
746
    /// Test model without coordinates, to exercise the fallback 1-based
    /// column naming.
    #[derive(Clone)]
    struct SimpleTestLogp {
        dim: usize,
    }
752
753    impl HasDims for SimpleTestLogp {
754        fn dim_sizes(&self) -> HashMap<String, u64> {
755            HashMap::from([("simple_param".to_string(), self.dim as u64)])
756        }
757        // No coords() method - should use fallback
758    }
759
    /// Expanded draw for the simple (coordinate-free) test model.
    #[derive(Storable)]
    struct SimpleExpandedDraw {
        // Raw parameter vector, stored under the `simple_param` dimension.
        #[storable(dims("simple_param"))]
        values: Vec<f64>,
    }
765
766    impl CpuLogpFunc for SimpleTestLogp {
767        type LogpError = TestLogpError;
768        type FlowParameters = ();
769        type ExpandedVector = SimpleExpandedDraw;
770
771        fn dim(&self) -> usize {
772            self.dim
773        }
774
775        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
776            let mut logp = 0.0;
777            for (i, &xi) in x.iter().enumerate() {
778                logp -= 0.5 * xi * xi;
779                grad[i] = -xi;
780            }
781            Ok(logp)
782        }
783
784        fn expand_vector<R: Rng + ?Sized>(
785            &mut self,
786            _rng: &mut R,
787            array: &[f64],
788        ) -> Result<Self::ExpandedVector, CpuMathError> {
789            Ok(SimpleExpandedDraw {
790                values: array.to_vec(),
791            })
792        }
793
794        fn vector_coord(&self) -> Option<Value> {
795            Some(Value::Strings(vec![
796                "param1".to_string(),
797                "param2".to_string(),
798                "param3".to_string(),
799            ]))
800        }
801    }
802
    /// Model wrapper exposing `SimpleTestLogp` through the `Model` trait.
    struct SimpleTestModel {
        math: CpuMath<SimpleTestLogp>,
    }
806
807    impl Model for SimpleTestModel {
808        type Math<'model>
809            = CpuMath<SimpleTestLogp>
810        where
811            Self: 'model;
812
813        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
814            Ok(self.math.clone())
815        }
816
817        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
818            for p in position.iter_mut() {
819                *p = rng.random_range(-1.0..1.0);
820            }
821            Ok(())
822        }
823    }
824
825    fn read_csv_header(path: &Path) -> Result<String> {
826        let content = fs::read_to_string(path)?;
827        content
828            .lines()
829            .next()
830            .map(|s| s.to_string())
831            .ok_or_else(|| anyhow::anyhow!("Empty CSV file"))
832    }
833
    /// End-to-end check: sampling a 2x2 matrix model must produce CSV columns
    /// named from the `a`/`b` coordinate labels, in Cartesian-product order
    /// (first index slowest).
    #[test]
    fn test_multidim_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("multidim_test");

        // Create model with 2x2 parameter matrix
        let model = MultiDimTestModel {
            math: CpuMath::new(MultiDimTestLogp { dim_a: 2, dim_b: 2 }),
        };

        // Keep the run tiny; warmup is excluded from the output.
        let mut settings = DiagGradNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 10;
        settings.num_draws = 20;
        settings.seed = 42;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        // Wait for sampling to complete
        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        // Check that CSV file was created
        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        // Check header contains expected coordinate names
        let header = read_csv_header(&csv_file)?;

        // Should contain Cartesian product: x.alpha, x.beta, y.alpha, y.beta
        assert!(header.contains("param_matrix.x.alpha"));
        assert!(header.contains("param_matrix.x.beta"));
        assert!(header.contains("param_matrix.y.alpha"));
        assert!(header.contains("param_matrix.y.beta"));
        assert!(header.contains("scalar_value"));

        // Verify column order (Cartesian product should be in correct order)
        let columns: Vec<&str> = header.split(',').collect();
        let param_columns: Vec<&str> = columns
            .iter()
            .filter(|col| col.starts_with("param_matrix."))
            .cloned()
            .collect();

        assert_eq!(
            param_columns,
            vec![
                "param_matrix.x.alpha",
                "param_matrix.x.beta",
                "param_matrix.y.alpha",
                "param_matrix.y.beta"
            ]
        );

        Ok(())
    }
899
    /// End-to-end check: a model that declares no coordinates must get
    /// Stan-style 1-based fallback column names (`values.1`, `values.2`, ...).
    #[test]
    fn test_fallback_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("simple_test");

        // Create model with 3 parameters but no coordinate specification
        let model = SimpleTestModel {
            math: CpuMath::new(SimpleTestLogp { dim: 3 }),
        };

        // Keep the run tiny; warmup is excluded from the output.
        let mut settings = DiagGradNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 5;
        settings.num_draws = 10;
        settings.seed = 123;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        // Wait for sampling to complete
        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        // Check that CSV file was created
        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        // Check header uses fallback numeric naming
        let header = read_csv_header(&csv_file)?;

        // Should fall back to 1-based indices since no coordinates provided
        assert!(header.contains("values.1"));
        assert!(header.contains("values.2"));
        assert!(header.contains("values.3"));

        Ok(())
    }
945
946    #[test]
947    fn test_cartesian_product_generation() {
948        let coord_sets = vec![
949            vec!["x".to_string(), "y".to_string()],
950            vec!["alpha".to_string(), "beta".to_string()],
951        ];
952        let dim_sizes = vec![2, 2];
953
954        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);
955
956        assert_eq!(names, vec!["x.alpha", "x.beta", "y.alpha", "y.beta"]);
957
958        assert_eq!(indices, vec![0, 1, 2, 3]);
959    }
960
961    #[test]
962    fn test_single_dimension_coordinates() {
963        let coord_sets = vec![vec!["param1".to_string(), "param2".to_string()]];
964        let dim_sizes = vec![2];
965
966        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);
967
968        assert_eq!(names, vec!["param1", "param2"]);
969        assert_eq!(indices, vec![0, 1]);
970    }
971
972    #[test]
973    fn test_three_dimension_cartesian_product() {
974        let coord_sets = vec![
975            vec!["a".to_string(), "b".to_string()],
976            vec!["1".to_string()],
977            vec!["i".to_string(), "j".to_string()],
978        ];
979        let dim_sizes = vec![2, 1, 2];
980
981        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);
982
983        assert_eq!(names, vec!["a.1.i", "a.1.j", "b.1.i", "b.1.j"]);
984
985        assert_eq!(indices, vec![0, 1, 2, 3]);
986    }
987}