// nuts_rs/storage/csv.rs
//! Store samples and statistics as CmdStan-compatible CSV files.
2
3use std::collections::HashMap;
4use std::fs::File;
5use std::io::{BufWriter, Write};
6use std::path::{Path, PathBuf};
7
8use anyhow::{Context, Result};
9use nuts_storable::{ItemType, Value};
10
11use crate::storage::{ChainStorage, StorageConfig, TraceStorage};
12use crate::{Math, Progress, Settings};
13
/// Configuration for CSV-based MCMC storage.
///
/// Produces CmdStan-compatible CSV output, one file per chain, named
/// `chain_{id}.csv` with `{id}` counting from 0.
///
/// Each file mirrors CmdStan's layout:
/// - a single header row of column names
/// - the standard Stan sampler-statistics columns (lp__, stepsize__,
///   treedepth__, ...)
/// - one column per expanded model parameter
/// - optionally, warmup draws before the post-warmup draws
pub struct CsvConfig {
    /// Directory that will receive the per-chain CSV files.
    output_dir: PathBuf,
    /// Number of digits printed after the decimal point.
    precision: usize,
    /// Whether warmup draws are written (defaults to true).
    store_warmup: bool,
}

impl CsvConfig {
    /// Build a configuration writing into `output_dir`.
    ///
    /// Defaults: 6 decimal places, warmup draws included.
    ///
    /// # Example
    ///
    /// ```rust
    /// use nuts_rs::CsvConfig;
    /// let config = CsvConfig::new("mcmc_output");
    /// ```
    pub fn new<P: AsRef<Path>>(output_dir: P) -> Self {
        let output_dir = output_dir.as_ref().to_path_buf();
        Self {
            output_dir,
            precision: 6,
            store_warmup: true,
        }
    }

    /// Override the number of decimal places used for floating point values.
    ///
    /// The default is 6.
    pub fn with_precision(self, precision: usize) -> Self {
        Self { precision, ..self }
    }

    /// Choose whether warmup draws are written to the files.
    ///
    /// When true (the default) warmup draws are included; when false only
    /// post-warmup draws are stored.
    pub fn store_warmup(self, store: bool) -> Self {
        Self {
            store_warmup: store,
            ..self
        }
    }
}
72
/// Trace-level CSV storage shared across all chains.
///
/// Holds the output settings plus the precomputed column layout that every
/// per-chain writer is initialized from.
pub struct CsvTraceStorage {
    /// Directory receiving the per-chain CSV files.
    output_dir: PathBuf,
    /// Digits after the decimal point for float columns.
    precision: usize,
    /// Whether warmup draws are written.
    store_warmup: bool,
    /// Expanded, Stan-style column names, one per parameter entry.
    parameter_names: Vec<String>,
    /// For each parameter column: the draw variable it is read from and the
    /// flat index into that variable's data.
    column_mapping: Vec<(String, usize)>,
}
81
/// CSV writer state for a single chain.
pub struct CsvChainStorage {
    /// Buffered handle on the chain's `chain_{id}.csv` file.
    writer: BufWriter<File>,
    /// Digits after the decimal point for float columns.
    precision: usize,
    /// Whether warmup draws are written.
    store_warmup: bool,
    /// Expanded, Stan-style column names, one per parameter entry.
    parameter_names: Vec<String>,
    /// For each parameter column: source draw variable and flat index.
    column_mapping: Vec<(String, usize)>,
    /// True until the first draw is recorded; triggers header emission.
    is_first_sample: bool,
    /// True once the header row has been written.
    headers_written: bool,
}
92
93impl CsvChainStorage {
94    /// Create a new CSV chain storage
95    fn new(
96        output_dir: &Path,
97        chain_id: u64,
98        precision: usize,
99        store_warmup: bool,
100        parameter_names: Vec<String>,
101        column_mapping: Vec<(String, usize)>,
102    ) -> Result<Self> {
103        std::fs::create_dir_all(output_dir)
104            .with_context(|| format!("Failed to create output directory: {:?}", output_dir))?;
105
106        let file_path = output_dir.join(format!("chain_{}.csv", chain_id));
107        let file = File::create(&file_path)
108            .with_context(|| format!("Failed to create CSV file: {:?}", file_path))?;
109        let writer = BufWriter::new(file);
110
111        Ok(Self {
112            writer,
113            precision,
114            store_warmup,
115            parameter_names,
116            column_mapping,
117            is_first_sample: true,
118            headers_written: false,
119        })
120    }
121
122    /// Write the CSV header row
123    fn write_header(&mut self) -> Result<()> {
124        if self.headers_written {
125            return Ok(());
126        }
127
128        // Standard CmdStan header format - only the core columns
129        let mut headers = vec![
130            "lp__".to_string(),
131            "accept_stat__".to_string(),
132            "stepsize__".to_string(),
133            "treedepth__".to_string(),
134            "n_leapfrog__".to_string(),
135            "divergent__".to_string(),
136            "energy__".to_string(),
137        ];
138
139        // Add parameter columns from the expanded parameter vector
140        for param_name in &self.parameter_names {
141            headers.push(param_name.clone());
142        }
143
144        // Write header row
145        writeln!(self.writer, "{}", headers.join(","))?;
146        self.headers_written = true;
147        Ok(())
148    }
149
150    /// Format a value for CSV output
151    fn format_value(&self, value: &Value) -> String {
152        match value {
153            Value::ScalarF64(v) => {
154                if v.is_nan() {
155                    "NA".to_string()
156                } else if v.is_infinite() {
157                    if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
158                } else {
159                    format!("{:.prec$}", v, prec = self.precision)
160                }
161            }
162            Value::ScalarF32(v) => {
163                if v.is_nan() {
164                    "NA".to_string()
165                } else if v.is_infinite() {
166                    if *v > 0.0 { "Inf" } else { "-Inf" }.to_string()
167                } else {
168                    format!("{:.prec$}", v, prec = self.precision)
169                }
170            }
171            Value::ScalarU64(v) => v.to_string(),
172            Value::ScalarI64(v) => v.to_string(),
173            Value::ScalarBool(v) => if *v { "1" } else { "0" }.to_string(),
174            Value::F64(vec) => {
175                // For vector values, we'll just use the first element for now
176                // A more sophisticated implementation would handle multi-dimensional parameters
177                if vec.is_empty() {
178                    "NA".to_string()
179                } else {
180                    self.format_value(&Value::ScalarF64(vec[0]))
181                }
182            }
183            Value::F32(vec) => {
184                if vec.is_empty() {
185                    "NA".to_string()
186                } else {
187                    self.format_value(&Value::ScalarF32(vec[0]))
188                }
189            }
190            Value::U64(vec) => {
191                if vec.is_empty() {
192                    "NA".to_string()
193                } else {
194                    vec[0].to_string()
195                }
196            }
197            Value::I64(vec) => {
198                if vec.is_empty() {
199                    "NA".to_string()
200                } else {
201                    vec[0].to_string()
202                }
203            }
204            Value::Bool(vec) => {
205                if vec.is_empty() {
206                    "NA".to_string()
207                } else {
208                    if vec[0] { "1" } else { "0" }.to_string()
209                }
210            }
211            Value::ScalarString(v) => v.clone(),
212            Value::Strings(vec) => {
213                if vec.is_empty() {
214                    "NA".to_string()
215                } else {
216                    vec[0].clone()
217                }
218            }
219            Value::DateTime64(_, _) => panic!("DateTime64 not supported in CSV output"),
220            Value::TimeDelta64(_, _) => panic!("TimeDelta64 not supported in CSV output"),
221        }
222    }
223
224    /// Write a single sample row to the CSV file
225    fn write_sample_row(
226        &mut self,
227        stats: &Vec<(&str, Option<Value>)>,
228        draws: &Vec<(&str, Option<Value>)>,
229        _info: &Progress,
230    ) -> Result<()> {
231        let mut row_values = Vec::new();
232
233        // Create lookup maps for quick access
234        let stats_map: HashMap<&str, &Option<Value>> = stats.iter().map(|(k, v)| (*k, v)).collect();
235        let draws_map: HashMap<&str, &Option<Value>> = draws.iter().map(|(k, v)| (*k, v)).collect();
236
237        // Helper function to get stat value
238        let get_stat_value = |name: &str| -> String {
239            stats_map
240                .get(name)
241                .and_then(|opt| opt.as_ref())
242                .map(|v| self.format_value(v))
243                .unwrap_or_else(|| "NA".to_string())
244        };
245
246        row_values.push(get_stat_value("logp"));
247        row_values.push(get_stat_value("mean_tree_accept"));
248        row_values.push(get_stat_value("step_size"));
249        row_values.push(get_stat_value("depth"));
250        row_values.push(get_stat_value("n_steps"));
251        let divergent_val = stats_map
252            .get("diverging")
253            .and_then(|opt| opt.as_ref())
254            .map(|v| match v {
255                Value::ScalarBool(true) => "1".to_string(),
256                Value::ScalarBool(false) => "0".to_string(),
257                _ => "0".to_string(),
258            })
259            .unwrap_or_else(|| "0".to_string());
260        row_values.push(divergent_val);
261
262        row_values.push(get_stat_value("energy"));
263
264        // Add parameter values using the column mapping
265        for (_param_name, (data_name, index)) in
266            self.parameter_names.iter().zip(&self.column_mapping)
267        {
268            if let Some(Some(data_value)) = draws_map.get(data_name.as_str()) {
269                let formatted_value = match data_value {
270                    Value::F64(vec) => {
271                        if *index < vec.len() {
272                            self.format_value(&Value::ScalarF64(vec[*index]))
273                        } else {
274                            "NA".to_string()
275                        }
276                    }
277                    Value::F32(vec) => {
278                        if *index < vec.len() {
279                            self.format_value(&Value::ScalarF32(vec[*index]))
280                        } else {
281                            "NA".to_string()
282                        }
283                    }
284                    Value::I64(vec) => {
285                        if *index < vec.len() {
286                            self.format_value(&Value::ScalarI64(vec[*index]))
287                        } else {
288                            "NA".to_string()
289                        }
290                    }
291                    Value::U64(vec) => {
292                        if *index < vec.len() {
293                            self.format_value(&Value::ScalarU64(vec[*index]))
294                        } else {
295                            "NA".to_string()
296                        }
297                    }
298                    // Handle scalar values (index should be 0)
299                    scalar_val if *index == 0 => self.format_value(scalar_val),
300                    _ => "NA".to_string(),
301                };
302                row_values.push(formatted_value);
303            } else {
304                row_values.push("NA".to_string());
305            }
306        }
307
308        // Write the row
309        writeln!(self.writer, "{}", row_values.join(","))?;
310        Ok(())
311    }
312}
313
314impl ChainStorage for CsvChainStorage {
315    type Finalized = ();
316
317    fn record_sample(
318        &mut self,
319        _settings: &impl Settings,
320        stats: Vec<(&str, Option<Value>)>,
321        draws: Vec<(&str, Option<Value>)>,
322        info: &Progress,
323    ) -> Result<()> {
324        // Skip warmup samples if not storing them
325        if info.tuning && !self.store_warmup {
326            return Ok(());
327        }
328
329        // Write header on first sample
330        if self.is_first_sample {
331            self.write_header()?;
332            self.is_first_sample = false;
333        }
334
335        self.write_sample_row(&stats, &draws, info)?;
336        Ok(())
337    }
338
339    fn finalize(mut self) -> Result<Self::Finalized> {
340        self.writer.flush().context("Failed to flush CSV file")?;
341        Ok(())
342    }
343
344    fn flush(&self) -> Result<()> {
345        // BufWriter doesn't provide a way to flush without mutable reference
346        // In practice, the buffer will be flushed when the file is closed
347        Ok(())
348    }
349
350    fn inspect(&self) -> Result<Option<Self::Finalized>> {
351        // For CSV storage, inspection does not produce a finalized result
352        self.flush()?;
353        Ok(None)
354    }
355}
356
357impl StorageConfig for CsvConfig {
358    type Storage = CsvTraceStorage;
359
360    fn new_trace<M: Math>(self, settings: &impl Settings, math: &M) -> Result<Self::Storage> {
361        // Generate parameter names and column mapping using coordinates
362        let (parameter_names, column_mapping) =
363            generate_parameter_names_and_mapping(settings, math)?;
364
365        Ok(CsvTraceStorage {
366            output_dir: self.output_dir,
367            precision: self.precision,
368            store_warmup: self.store_warmup,
369            parameter_names,
370            column_mapping,
371        })
372    }
373}
374
375/// Generate parameter column names and mapping using coordinates or Stan-compliant indexing
376fn generate_parameter_names_and_mapping<M: Math>(
377    settings: &impl Settings,
378    math: &M,
379) -> Result<(Vec<String>, Vec<(String, usize)>)> {
380    let data_dims = settings.data_dims_all(math);
381    let coords = math.coords();
382    let mut parameter_names = Vec::new();
383    let mut column_mapping = Vec::new();
384
385    for (var_name, var_dims) in data_dims {
386        let data_type = settings.data_type(math, &var_name);
387
388        // Only process vector types that could contain parameter values
389        if matches!(
390            data_type,
391            ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
392        ) {
393            let (column_names, indices) = generate_column_names_and_indices_for_variable(
394                &var_name, &var_dims, &coords, math,
395            )?;
396
397            for (name, index) in column_names.into_iter().zip(indices) {
398                parameter_names.push(name);
399                column_mapping.push((var_name.clone(), index));
400            }
401        }
402    }
403
404    // If no parameter names were generated, fall back to simple numbering
405    if parameter_names.is_empty() {
406        let dim_sizes = math.dim_sizes();
407        let param_count = dim_sizes.get("expanded_parameter").unwrap_or(&0);
408        for i in 0..*param_count {
409            parameter_names.push(format!("param_{}", i + 1));
410            // Try to find a data field that contains the parameters
411            let data_names = settings.data_names(math);
412            let mut found_field = false;
413            for data_name in &data_names {
414                let data_type = settings.data_type(math, data_name);
415                if matches!(
416                    data_type,
417                    ItemType::F64 | ItemType::F32 | ItemType::I64 | ItemType::U64
418                ) {
419                    column_mapping.push((data_name.clone(), i as usize));
420                    found_field = true;
421                    break;
422                }
423            }
424            if !found_field {
425                column_mapping.push(("unknown".to_string(), i as usize));
426            }
427        }
428    }
429
430    Ok((parameter_names, column_mapping))
431}
432
433/// Generate column names and indices for a single variable using its dimensions and coordinates
434fn generate_column_names_and_indices_for_variable<M: Math>(
435    var_name: &str,
436    var_dims: &[String],
437    coords: &HashMap<String, Value>,
438    math: &M,
439) -> Result<(Vec<String>, Vec<usize>)> {
440    let dim_sizes = math.dim_sizes();
441
442    if var_dims.is_empty() {
443        // Scalar variable
444        return Ok((vec![var_name.to_string()], vec![0]));
445    }
446
447    // Check if we have meaningful coordinate names for all dimensions
448    let has_meaningful_coords = var_dims.iter().all(|dim_name| {
449        coords.get(dim_name).is_some_and(
450            |coord_value| matches!(coord_value, Value::Strings(labels) if !labels.is_empty()),
451        )
452    });
453
454    // Get coordinate labels for each dimension
455    let mut dim_coords: Vec<Vec<String>> = Vec::new();
456    let mut dim_sizes_vec: Vec<usize> = Vec::new();
457
458    for dim_name in var_dims {
459        let size = *dim_sizes.get(dim_name).unwrap_or(&1) as usize;
460        dim_sizes_vec.push(size);
461
462        if has_meaningful_coords {
463            // Use coordinate names if available and meaningful
464            if let Some(coord_value) = coords.get(dim_name) {
465                match coord_value {
466                    Value::Strings(labels) => {
467                        dim_coords.push(labels.clone());
468                    }
469                    _ => {
470                        // Fallback to 1-based indexing (Stan format)
471                        dim_coords.push((1..=size).map(|i| i.to_string()).collect());
472                    }
473                }
474            } else {
475                // Fallback to 1-based indexing (Stan format)
476                dim_coords.push((1..=size).map(|i| i.to_string()).collect());
477            }
478        } else {
479            // Use Stan-compliant 1-based indexing
480            dim_coords.push((1..=size).map(|i| i.to_string()).collect());
481        }
482    }
483
484    // Generate Cartesian product using column-major order (Stan format)
485    let (coord_names, indices) =
486        cartesian_product_with_indices_column_major(&dim_coords, &dim_sizes_vec);
487
488    // Prepend variable name to each coordinate combination
489    let column_names: Vec<String> = coord_names
490        .into_iter()
491        .map(|coord| format!("{}.{}", var_name, coord))
492        .collect();
493
494    Ok((column_names, indices))
495}
496
/// Compute the Cartesian product of coordinate labels together with the
/// flat index of each combination (Stan ordering).
///
/// Stan calls this "column-major", but the effective layout is row-major:
/// the first index varies slowest and the last varies fastest, so a 2x3
/// array yields [1,1], [1,2], [1,3], [2,1], [2,2], [2,3]. Each returned
/// index is the row-major linear offset computed from `dim_sizes`.
///
/// Implemented as an iterative "odometer" over the multi-index. The
/// previous recursive version cloned a name string per leaf and threaded a
/// `&mut String` parameter it never mutated; this version needs neither.
fn cartesian_product_with_indices_column_major(
    coord_sets: &[Vec<String>],
    dim_sizes: &[usize],
) -> (Vec<String>, Vec<usize>) {
    if coord_sets.is_empty() {
        return (vec![], vec![]);
    }

    // Row-major strides: stride[i] = product of dim_sizes[i + 1..].
    let mut strides = vec![1usize; dim_sizes.len()];
    for i in (0..dim_sizes.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dim_sizes[i + 1];
    }

    let total: usize = coord_sets.iter().map(|set| set.len()).product();
    let mut names = Vec::with_capacity(total);
    let mut indices = Vec::with_capacity(total);

    // Odometer over the multi-index: the last dimension ticks fastest.
    let mut multi = vec![0usize; coord_sets.len()];
    for _ in 0..total {
        let name = multi
            .iter()
            .zip(coord_sets)
            .map(|(&i, set)| set[i].as_str())
            .collect::<Vec<_>>()
            .join(".");
        names.push(name);

        indices.push(multi.iter().zip(&strides).map(|(&i, &s)| i * s).sum());

        // Advance the odometer, carrying into slower dimensions.
        for dim in (0..multi.len()).rev() {
            multi[dim] += 1;
            if multi[dim] < coord_sets[dim].len() {
                break;
            }
            multi[dim] = 0;
        }
    }

    (names, indices)
}
578
579impl TraceStorage for CsvTraceStorage {
580    type ChainStorage = CsvChainStorage;
581    type Finalized = ();
582
583    fn initialize_trace_for_chain(&self, chain_id: u64) -> Result<Self::ChainStorage> {
584        CsvChainStorage::new(
585            &self.output_dir,
586            chain_id,
587            self.precision,
588            self.store_warmup,
589            self.parameter_names.clone(),
590            self.column_mapping.clone(),
591        )
592    }
593
594    fn finalize(
595        self,
596        traces: Vec<Result<<Self::ChainStorage as ChainStorage>::Finalized>>,
597    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
598        // Check for any errors in the chain finalizations
599        for trace_result in traces {
600            if let Err(err) = trace_result {
601                return Ok((Some(err), ()));
602            }
603        }
604        Ok((None, ()))
605    }
606
607    fn inspect(
608        &self,
609        traces: Vec<Result<Option<<Self::ChainStorage as ChainStorage>::Finalized>>>,
610    ) -> Result<(Option<anyhow::Error>, Self::Finalized)> {
611        // Check for any errors in the chain inspections
612        for trace_result in traces {
613            if let Err(err) = trace_result {
614                return Ok((Some(err), ()));
615            }
616        }
617        Ok((None, ()))
618    }
619}
620
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CpuLogpFunc, CpuMath, CpuMathError, DiagNutsSettings, LogpError, Model, Sampler};
    use anyhow::Result;
    use nuts_derive::Storable;
    use nuts_storable::{HasDims, Value};
    use rand::{Rng, RngExt};
    use std::collections::HashMap;
    use std::fs;
    use std::path::Path;
    use thiserror::Error;

    // Placeholder error type for the test logp functions; never actually
    // raised by the tests below.
    #[allow(dead_code)]
    #[derive(Debug, Error)]
    enum TestLogpError {
        #[error("Test error")]
        Test,
    }

    impl LogpError for TestLogpError {
        fn is_recoverable(&self) -> bool {
            false
        }
    }

    /// Test model with multi-dimensional coordinates
    #[derive(Clone)]
    struct MultiDimTestLogp {
        // Sizes of the "a" and "b" dimensions of the parameter matrix.
        dim_a: usize,
        dim_b: usize,
    }

    impl HasDims for MultiDimTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([
                ("a".to_string(), self.dim_a as u64),
                ("b".to_string(), self.dim_b as u64),
            ])
        }

        // Named labels for both dims, so the CSV header should use
        // coordinate-based column names (e.g. param_matrix.x.alpha).
        fn coords(&self) -> HashMap<String, Value> {
            HashMap::from([
                (
                    "a".to_string(),
                    Value::Strings(vec!["x".to_string(), "y".to_string()]),
                ),
                (
                    "b".to_string(),
                    Value::Strings(vec!["alpha".to_string(), "beta".to_string()]),
                ),
            ])
        }
    }

    #[derive(Storable)]
    struct MultiDimExpandedDraw {
        // Flattened (a, b) matrix of parameter values.
        #[storable(dims("a", "b"))]
        param_matrix: Vec<f64>,
        // A dimensionless value to check scalar column naming.
        scalar_value: f64,
    }

    impl CpuLogpFunc for MultiDimTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = MultiDimExpandedDraw;

        fn dim(&self) -> usize {
            self.dim_a * self.dim_b
        }

        // Standard normal log density: logp = -0.5 * sum(x_i^2), grad = -x.
        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mut logp = 0.0;
            for (i, &xi) in x.iter().enumerate() {
                logp -= 0.5 * xi * xi;
                grad[i] = -xi;
            }
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            Ok(MultiDimExpandedDraw {
                param_matrix: array.to_vec(),
                scalar_value: array.iter().sum(),
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            Some(Value::Strings(
                (0..self.dim()).map(|i| format!("theta{}", i + 1)).collect(),
            ))
        }
    }

    struct MultiDimTestModel {
        math: CpuMath<MultiDimTestLogp>,
    }

    impl Model for MultiDimTestModel {
        type Math<'model>
            = CpuMath<MultiDimTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        // Uniform random start positions in (-1, 1).
        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            for p in position.iter_mut() {
                *p = rng.random_range(-1.0..1.0);
            }
            Ok(())
        }
    }

    /// Test model without coordinates (fallback behavior)
    #[derive(Clone)]
    struct SimpleTestLogp {
        dim: usize,
    }

    impl HasDims for SimpleTestLogp {
        fn dim_sizes(&self) -> HashMap<String, u64> {
            HashMap::from([("simple_param".to_string(), self.dim as u64)])
        }
        // No coords() method - should use fallback
    }

    #[derive(Storable)]
    struct SimpleExpandedDraw {
        #[storable(dims("simple_param"))]
        values: Vec<f64>,
    }

    impl CpuLogpFunc for SimpleTestLogp {
        type LogpError = TestLogpError;
        type FlowParameters = ();
        type ExpandedVector = SimpleExpandedDraw;

        fn dim(&self) -> usize {
            self.dim
        }

        // Same standard normal density as the multi-dim model above.
        fn logp(&mut self, x: &[f64], grad: &mut [f64]) -> Result<f64, Self::LogpError> {
            let mut logp = 0.0;
            for (i, &xi) in x.iter().enumerate() {
                logp -= 0.5 * xi * xi;
                grad[i] = -xi;
            }
            Ok(logp)
        }

        fn expand_vector<R: Rng + ?Sized>(
            &mut self,
            _rng: &mut R,
            array: &[f64],
        ) -> Result<Self::ExpandedVector, CpuMathError> {
            Ok(SimpleExpandedDraw {
                values: array.to_vec(),
            })
        }

        fn vector_coord(&self) -> Option<Value> {
            Some(Value::Strings(vec![
                "param1".to_string(),
                "param2".to_string(),
                "param3".to_string(),
            ]))
        }
    }

    struct SimpleTestModel {
        math: CpuMath<SimpleTestLogp>,
    }

    impl Model for SimpleTestModel {
        type Math<'model>
            = CpuMath<SimpleTestLogp>
        where
            Self: 'model;

        fn math<R: Rng + ?Sized>(&self, _rng: &mut R) -> Result<Self::Math<'_>> {
            Ok(self.math.clone())
        }

        fn init_position<R: Rng + ?Sized>(&self, rng: &mut R, position: &mut [f64]) -> Result<()> {
            for p in position.iter_mut() {
                *p = rng.random_range(-1.0..1.0);
            }
            Ok(())
        }
    }

    // Returns the first line (the header row) of a CSV file.
    fn read_csv_header(path: &Path) -> Result<String> {
        let content = fs::read_to_string(path)?;
        content
            .lines()
            .next()
            .map(|s| s.to_string())
            .ok_or_else(|| anyhow::anyhow!("Empty CSV file"))
    }

    // End-to-end: run a short sampler and check that columns of a 2x2
    // parameter matrix are named from the coordinate labels, in Stan's
    // row-major ("column-major" in Stan terms) order.
    #[test]
    fn test_multidim_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("multidim_test");

        // Create model with 2x2 parameter matrix
        let model = MultiDimTestModel {
            math: CpuMath::new(MultiDimTestLogp { dim_a: 2, dim_b: 2 }),
        };

        let mut settings = DiagNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 10;
        settings.num_draws = 20;
        settings.seed = 42;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        // Wait for sampling to complete
        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        // Check that CSV file was created
        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        // Check header contains expected coordinate names
        let header = read_csv_header(&csv_file)?;

        // Should contain Cartesian product: x.alpha, x.beta, y.alpha, y.beta
        assert!(header.contains("param_matrix.x.alpha"));
        assert!(header.contains("param_matrix.x.beta"));
        assert!(header.contains("param_matrix.y.alpha"));
        assert!(header.contains("param_matrix.y.beta"));
        assert!(header.contains("scalar_value"));

        // Verify column order (Cartesian product should be in correct order)
        let columns: Vec<&str> = header.split(',').collect();
        let param_columns: Vec<&str> = columns
            .iter()
            .filter(|col| col.starts_with("param_matrix."))
            .cloned()
            .collect();

        assert_eq!(
            param_columns,
            vec![
                "param_matrix.x.alpha",
                "param_matrix.x.beta",
                "param_matrix.y.alpha",
                "param_matrix.y.beta"
            ]
        );

        Ok(())
    }

    // End-to-end: a model without coordinate metadata should fall back to
    // 1-based numeric column suffixes (values.1, values.2, ...).
    #[test]
    fn test_fallback_coordinate_naming() -> Result<()> {
        let temp_dir = tempfile::tempdir()?;
        let output_path = temp_dir.path().join("simple_test");

        // Create model with 3 parameters but no coordinate specification
        let model = SimpleTestModel {
            math: CpuMath::new(SimpleTestLogp { dim: 3 }),
        };

        let mut settings = DiagNutsSettings::default();
        settings.num_chains = 1;
        settings.num_tune = 5;
        settings.num_draws = 10;
        settings.seed = 123;

        let csv_config = CsvConfig::new(&output_path)
            .with_precision(6)
            .store_warmup(false);

        let mut sampler = Some(Sampler::new(model, settings, csv_config, 1, None)?);

        // Wait for sampling to complete
        while let Some(sampler_) = sampler.take() {
            match sampler_.wait_timeout(std::time::Duration::from_millis(100)) {
                crate::SamplerWaitResult::Trace(_) => break,
                crate::SamplerWaitResult::Timeout(s) => sampler = Some(s),
                crate::SamplerWaitResult::Err(err, _) => return Err(err),
            }
        }

        // Check that CSV file was created
        let csv_file = output_path.join("chain_0.csv");
        assert!(csv_file.exists());

        // Check header uses fallback numeric naming
        let header = read_csv_header(&csv_file)?;

        // Should fall back to 1-based indices since no coordinates provided
        assert!(header.contains("values.1"));
        assert!(header.contains("values.2"));
        assert!(header.contains("values.3"));

        Ok(())
    }

    // Unit test: 2x2 label product in Stan order with matching flat indices.
    #[test]
    fn test_cartesian_product_generation() {
        let coord_sets = vec![
            vec!["x".to_string(), "y".to_string()],
            vec!["alpha".to_string(), "beta".to_string()],
        ];
        let dim_sizes = vec![2, 2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["x.alpha", "x.beta", "y.alpha", "y.beta"]);

        assert_eq!(indices, vec![0, 1, 2, 3]);
    }

    // Unit test: a single dimension yields its labels verbatim (no joining).
    #[test]
    fn test_single_dimension_coordinates() {
        let coord_sets = vec![vec!["param1".to_string(), "param2".to_string()]];
        let dim_sizes = vec![2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["param1", "param2"]);
        assert_eq!(indices, vec![0, 1]);
    }

    // Unit test: a middle dimension of size 1 is kept in the joined names.
    #[test]
    fn test_three_dimension_cartesian_product() {
        let coord_sets = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["1".to_string()],
            vec!["i".to_string(), "j".to_string()],
        ];
        let dim_sizes = vec![2, 1, 2];

        let (names, indices) = cartesian_product_with_indices_column_major(&coord_sets, &dim_sizes);

        assert_eq!(names, vec!["a.1.i", "a.1.j", "b.1.i", "b.1.j"]);

        assert_eq!(indices, vec![0, 1, 2, 3]);
    }
}