use burn::prelude::*;
use ndarray::{Array3, Axis};
use std::error::Error;
use std::fs::File;
use csv::Writer;
pub fn save_csv<T: std::fmt::Display>(
data: &Array3<T>,
filename: &str,
) -> Result<(), Box<dyn Error>> {
let mut wtr = Writer::from_writer(File::create(filename)?);
let n_dims = data.shape()[2];
let mut header: Vec<String> = vec!["chain".to_string(), "observation".to_string()];
header.extend((0..n_dims).map(|i| format!("dim_{}", i)));
wtr.write_record(&header)?;
for (chain_idx, chain) in data.axis_iter(Axis(0)).enumerate() {
for (obs_idx, obs) in chain.axis_iter(Axis(0)).enumerate() {
let mut row = vec![chain_idx.to_string(), obs_idx.to_string()];
row.extend(obs.iter().map(|v| v.to_string()));
wtr.write_record(&row)?;
}
}
wtr.flush()?;
Ok(())
}
pub fn save_csv_tensor<B>(
tensor: burn::tensor::Tensor<B, 3>,
filename: &str,
) -> Result<(), Box<dyn Error>>
where
B: Backend,
{
use csv::Writer;
use std::fs::File;
let shape = tensor.dims(); let data = tensor.to_data();
let (num_chains, num_obs, num_dims) = (shape[0], shape[1], shape[2]);
let flat: Vec<f32> = data
.to_vec()
.map_err(|e| format!("Converting data to Vec failed.\nData: {data:?}.\nError: {e:?}"))?;
let mut wtr = Writer::from_writer(File::create(filename)?);
let mut header = vec!["chain".to_string(), "observation".to_string()];
header.extend((0..num_dims).map(|i| format!("dim_{}", i)));
wtr.write_record(&header)?;
for chain_idx in 0..num_chains {
for obs_idx in 0..num_obs {
let mut row = vec![chain_idx.to_string(), obs_idx.to_string()];
let offset = chain_idx * num_obs * num_dims + obs_idx * num_dims;
let row_slice = &flat[offset..offset + num_dims];
row.extend(row_slice.iter().map(|v| v.to_string()));
wtr.write_record(&row)?;
}
}
wtr.flush()?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::{ndarray::NdArrayDevice, NdArray};
    use csv::Reader;
    use ndarray::arr3;
    use std::fs;
    use tempfile::NamedTempFile;

    #[test]
    fn test_save_csv_empty_data() {
        let data = arr3::<f32, 0, 0>(&[]);
        let file = NamedTempFile::new().expect("Could not create temp file");
        let filename = file.path().to_str().unwrap();
        let result = save_csv(&data, filename);
        assert!(
            result.is_ok(),
            "Saving empty data to CSV failed: {:?}",
            result
        );
        let contents = fs::read_to_string(filename).unwrap();
        assert_eq!(contents.trim(), "chain,observation");
    }

    #[test]
    fn test_save_csv_single_chain_single_obs() {
        let data = arr3(&[[[42.0]]]);
        let file = NamedTempFile::new().expect("Could not create temp file");
        let filename = file.path().to_str().unwrap();
        let result = save_csv(&data, filename);
        assert!(
            result.is_ok(),
            "Saving single chain with single obs to CSV failed: {:?}",
            result
        );
        let contents = fs::read_to_string(filename).unwrap();
        let expected = "chain,observation,dim_0\n0,0,42";
        assert_eq!(contents.trim(), expected);
    }

    #[test]
    fn test_save_csv_multi_chain() {
        let data = arr3(&[[[1, 2], [3, 4]], [[10, 20], [30, 40]]]);
        let file = NamedTempFile::new().expect("Could not create temp file");
        let filename = file.path().to_str().unwrap();
        let result = save_csv(&data, filename);
        assert!(result.is_ok());
        let contents = fs::read_to_string(filename).unwrap();
        let expected = "\
chain,observation,dim_0,dim_1
0,0,1,2
0,1,3,4
1,0,10,20
1,1,30,40";
        assert_eq!(contents.trim(), expected);
    }

    #[test]
    fn test_save_csv_tensor_data() -> Result<(), Box<dyn std::error::Error>> {
        let tensor = Tensor::<NdArray, 3, burn::tensor::Float>::from_floats(
            [[[1.0, 2.0], [3.0, 4.0]], [[1.1, 2.1], [3.1, 4.1]]],
            &NdArrayDevice::Cpu,
        );
        let file = NamedTempFile::new()?;
        let filename = file.path().to_str().unwrap();
        save_csv_tensor(tensor, filename)?;
        let contents = fs::read_to_string(filename)?;
        let mut rdr = Reader::from_reader(contents.as_bytes());

        // Compare the whole header so missing or extra columns are caught.
        let headers: Vec<&str> = rdr.headers()?.iter().collect();
        assert_eq!(headers, ["chain", "observation", "dim_0", "dim_1"]);

        let records: Vec<_> = rdr.records().collect::<Result<_, _>>()?;
        let expected = [
            vec!["0", "0", "1", "2"],
            vec!["0", "1", "3", "4"],
            vec!["1", "0", "1.1", "2.1"],
            vec!["1", "1", "3.1", "4.1"],
        ];
        assert_eq!(records.len(), expected.len());
        // f32 `Display` prints the shortest round-tripping form, so exact string
        // equality is reliable here. Substring matching would let "1" match "1.1".
        for (row, (record, exp)) in records.iter().zip(expected.iter()).enumerate() {
            let fields: Vec<&str> = record.iter().collect();
            assert_eq!(&fields, exp, "Row {} does not match", row);
        }
        Ok(())
    }
}