use std::fmt;
use std::fs::File;
use std::io::{BufWriter, Write};
use crate::physics::{PhysicalData, PhysicalQuantity};
use crate::solver::SimulationResult;
use super::Exporter;
/// Failure modes of the CSV export routines in this module.
#[derive(Debug)]
pub enum CsvError {
    /// An underlying I/O operation (file creation, write, flush) failed.
    Io(std::io::Error),
    /// The simulation result holds no usable data; also returned by the
    /// extraction helpers when a state carries no concentration entry.
    EmptyResult,
    /// The number of species names supplied by the caller does not match the
    /// number of concentration values found in the result.
    SpeciesCountMismatch {
        /// Number of species found in the simulation data.
        expected: usize,
        /// Number of species names that were actually supplied.
        got: usize,
    },
}
impl fmt::Display for CsvError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CsvError::Io(e) => write!(f, "CSV I/O error: {e}"),
CsvError::EmptyResult => {
write!(
f,
"CSV export failed: SimulationResult contains no time points"
)
}
CsvError::SpeciesCountMismatch { expected, got } => write!(
f,
"CSV export failed: expected {expected} species names, got {got}"
),
}
}
}
impl std::error::Error for CsvError {
    /// Exposes the wrapped I/O error as the cause; the other variants have
    /// no underlying source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let CsvError::Io(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl From<std::io::Error> for CsvError {
fn from(e: std::io::Error) -> Self {
CsvError::Io(e)
}
}
/// Formatting options applied by [`CsvExporter`] when writing files.
#[derive(Debug, Clone)]
pub struct CsvConfig {
    /// Character written between columns, in the header as well as in every
    /// data row.
    pub separator: char,
    /// Number of digits after the decimal point; values are written in
    /// scientific (`e`) notation.
    pub precision: usize,
}
impl Default for CsvConfig {
fn default() -> Self {
Self {
separator: ';', precision: 6, }
}
}
/// Exporter that writes simulation results as delimiter-separated text files.
///
/// The derived `Default` uses [`CsvConfig::default`].
#[derive(Debug, Clone, Default)]
pub struct CsvExporter {
    /// Separator and numeric precision used for every file written.
    pub config: CsvConfig,
}
impl CsvExporter {
pub fn new(config: CsvConfig) -> Self {
Self { config }
}
}
impl Exporter for CsvExporter {
    type Error = CsvError;

    /// Writes a two-column file (`time (s)`, `c_outlet (mol/L)`) to `path`.
    ///
    /// `n_points` selects how many rows are written; see
    /// `compute_sample_indices` for the sampling rules (`None` or `Some(0)`
    /// keep every time point). Values are written in scientific notation with
    /// `config.precision` fractional digits, separated by `config.separator`.
    ///
    /// # Errors
    ///
    /// * [`CsvError::EmptyResult`] if `result` has no time points, if the
    ///   state trajectory has no entry for a sampled time point, or if a
    ///   sampled state carries no concentration.
    /// * [`CsvError::Io`] if the file cannot be created, written, or flushed.
    ///
    /// The file is created before the rows are extracted, so a failure while
    /// writing rows can leave a partially written file behind.
    fn export_single(
        &self,
        result: &SimulationResult,
        n_points: Option<usize>,
        path: &str,
    ) -> Result<(), CsvError> {
        if result.time_points.is_empty() {
            return Err(CsvError::EmptyResult);
        }
        let indices = compute_sample_indices(result.time_points.len(), n_points);

        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);
        let sep = self.config.separator;
        let prec = self.config.precision;

        writeln!(writer, "time (s){sep}c_outlet (mol/L)")?;
        for idx in indices {
            // `idx` is always below `time_points.len()` by construction.
            let t = result.time_points[idx];
            // The trajectory may be shorter than the time axis; report that as
            // an error instead of panicking on an out-of-bounds index.
            let state = result
                .state_trajectory
                .get(idx)
                .ok_or(CsvError::EmptyResult)?;
            let c = extract_scalar_concentration(state)?;
            writeln!(writer, "{t:.prec$e}{sep}{c:.prec$e}")?;
        }

        // Flush explicitly: errors raised while dropping a `BufWriter` are lost.
        writer.flush()?;
        Ok(())
    }

    /// Writes a multi-species file to `path`.
    ///
    /// Columns are `time (s)`, `c_total (mol/L)`, then one `<name> (mol/L)`
    /// column per entry of `species_names`. `c_total` is the sum of the
    /// concentrations written on the same row. Row sampling, number format and
    /// separator follow the same rules as [`Exporter::export_single`].
    ///
    /// The species count is taken from the first state of the trajectory and
    /// validated before the file is created. Later rows are written as
    /// extracted, without a per-row length check.
    ///
    /// # Errors
    ///
    /// * [`CsvError::EmptyResult`] if `result` has no time points, if the
    ///   state trajectory is empty or has no entry for a sampled time point,
    ///   or if a state carries no concentration.
    /// * [`CsvError::SpeciesCountMismatch`] if `species_names.len()` differs
    ///   from the number of concentrations in the first state.
    /// * [`CsvError::Io`] if the file cannot be created, written, or flushed.
    fn export_multi(
        &self,
        result: &SimulationResult,
        n_points: Option<usize>,
        species_names: &[&str],
        path: &str,
    ) -> Result<(), CsvError> {
        if result.time_points.is_empty() {
            return Err(CsvError::EmptyResult);
        }

        // A non-empty time axis does not guarantee a non-empty trajectory.
        let first_state = result
            .state_trajectory
            .first()
            .ok_or(CsvError::EmptyResult)?;
        let n_species = extract_vector_concentrations(first_state)?.len();
        if species_names.len() != n_species {
            return Err(CsvError::SpeciesCountMismatch {
                expected: n_species,
                got: species_names.len(),
            });
        }

        let indices = compute_sample_indices(result.time_points.len(), n_points);

        let file = File::create(path)?;
        let mut writer = BufWriter::new(file);
        let sep = self.config.separator;
        let prec = self.config.precision;

        write!(writer, "time (s){sep}c_total (mol/L)")?;
        for name in species_names {
            write!(writer, "{sep}{name} (mol/L)")?;
        }
        writeln!(writer)?;

        for idx in indices {
            // `idx` is always below `time_points.len()` by construction.
            let t = result.time_points[idx];
            let state = result
                .state_trajectory
                .get(idx)
                .ok_or(CsvError::EmptyResult)?;
            let concentrations = extract_vector_concentrations(state)?;
            let c_total: f64 = concentrations.iter().sum();

            write!(writer, "{t:.prec$e}{sep}{c_total:.prec$e}")?;
            for c in &concentrations {
                write!(writer, "{sep}{c:.prec$e}")?;
            }
            writeln!(writer)?;
        }

        // Flush explicitly: errors raised while dropping a `BufWriter` are lost.
        writer.flush()?;
        Ok(())
    }
}
/// Selects which of `total` time-point indices should be exported.
///
/// * `None`, `Some(0)`, or a request of at least `total` points keeps every
///   index `0..total`.
/// * `Some(1)` (with more than one point available) keeps only index `0`.
/// * Otherwise `n` indices are spread evenly over `0..=total - 1` using
///   integer division, so the first and last indices are always included.
fn compute_sample_indices(total: usize, n_points: Option<usize>) -> Vec<usize> {
    // Anything that is not a genuine down-sampling request keeps all points.
    let requested = match n_points {
        Some(n) if n != 0 && n < total => n,
        _ => return (0..total).collect(),
    };

    if requested == 1 {
        return vec![0];
    }

    // Here 2 <= requested < total, so neither subtraction can underflow and
    // the divisor is non-zero.
    let span = total - 1;
    let steps = requested - 1;

    // For k == steps the quotient is exactly `span`, so the final index is
    // `total - 1` without any extra fix-up.
    (0..requested).map(|k| (k * span) / steps).collect()
}
/// Reads a single concentration value from `state`.
///
/// * `Scalar` data is returned as-is.
/// * `Vector` data yields its last element, or `0.0` when the vector is empty.
/// * Any other data kind yields `0.0`.
///
/// NOTE(review): the `0.0` fallbacks silently hide unexpected data shapes;
/// confirm with the callers that this is intended.
///
/// # Errors
///
/// Returns [`CsvError::EmptyResult`] when `state` has no concentration entry.
fn extract_scalar_concentration(state: &crate::physics::PhysicalState) -> Result<f64, CsvError> {
    match state.get(PhysicalQuantity::Concentration) {
        None => Err(CsvError::EmptyResult),
        Some(PhysicalData::Scalar(c)) => Ok(*c),
        Some(PhysicalData::Vector(v)) => Ok(v.iter().next_back().copied().unwrap_or(0.0)),
        Some(_) => Ok(0.0),
    }
}
/// Reads one concentration per species from `state`.
///
/// * `Vector` data is copied element by element.
/// * `Matrix` data yields its last row, one value per column.
/// * `Scalar` data yields a one-element vector.
/// * Any other data kind yields an empty vector.
///
/// # Errors
///
/// Returns [`CsvError::EmptyResult`] when `state` has no concentration entry
/// or when the concentration matrix has zero rows.
fn extract_vector_concentrations(
    state: &crate::physics::PhysicalState,
) -> Result<Vec<f64>, CsvError> {
    match state.get(PhysicalQuantity::Concentration) {
        None => Err(CsvError::EmptyResult),
        Some(PhysicalData::Vector(v)) => Ok(v.iter().copied().collect()),
        Some(PhysicalData::Matrix(m)) => {
            let row_count = m.nrows();
            if row_count == 0 {
                return Err(CsvError::EmptyResult);
            }
            let last_row = row_count - 1;
            Ok((0..m.ncols()).map(|col| m[(last_row, col)]).collect())
        }
        Some(PhysicalData::Scalar(c)) => Ok(vec![*c]),
        Some(_) => Ok(Vec::new()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::physics::{PhysicalData, PhysicalQuantity, PhysicalState};
    use crate::solver::SimulationResult;

    /// Returns a path for `name` inside the platform temporary directory.
    ///
    /// The process id is part of the file name so that concurrent test
    /// processes do not overwrite each other's files, and no Unix-only `/tmp`
    /// location is assumed.
    fn temp_csv(name: &str) -> String {
        std::env::temp_dir()
            .join(format!("{}_{name}", std::process::id()))
            .to_str()
            .expect("temporary directory path is valid UTF-8")
            .to_owned()
    }

    /// Builds a result with `n_steps` scalar states: t = 0.1 * i, c = 0.001 * i.
    fn make_single_result(n_steps: usize) -> SimulationResult {
        let time_points: Vec<f64> = (0..n_steps).map(|i| i as f64 * 0.1).collect();
        let state_trajectory: Vec<PhysicalState> = (0..n_steps)
            .map(|i| {
                let c = i as f64 * 0.001;
                PhysicalState::new(PhysicalQuantity::Concentration, PhysicalData::Scalar(c))
            })
            .collect();
        let final_state = state_trajectory.last().unwrap().clone();
        SimulationResult::new(time_points, state_trajectory, final_state)
    }

    /// Builds a result with `n_steps` vector states of `n_species` values each:
    /// t = 0.1 * i, c[s] = 0.001 * i + 0.0001 * s.
    fn make_multi_result(n_steps: usize, n_species: usize) -> SimulationResult {
        let time_points: Vec<f64> = (0..n_steps).map(|i| i as f64 * 0.1).collect();
        let state_trajectory: Vec<PhysicalState> = (0..n_steps)
            .map(|i| {
                use nalgebra::DVector;
                let concs: Vec<f64> = (0..n_species)
                    .map(|s| i as f64 * 0.001 + s as f64 * 0.0001)
                    .collect();
                PhysicalState::new(
                    PhysicalQuantity::Concentration,
                    PhysicalData::Vector(DVector::from_vec(concs)),
                )
            })
            .collect();
        let final_state = state_trajectory.last().unwrap().clone();
        SimulationResult::new(time_points, state_trajectory, final_state)
    }

    #[test]
    fn test_sample_indices_none_returns_all() {
        let indices = compute_sample_indices(10, None);
        assert_eq!(indices, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_sample_indices_n_equals_total() {
        let indices = compute_sample_indices(5, Some(5));
        assert_eq!(indices.len(), 5);
        assert_eq!(indices[0], 0);
        assert_eq!(*indices.last().unwrap(), 4);
    }

    #[test]
    fn test_sample_indices_n_greater_than_total() {
        let indices = compute_sample_indices(5, Some(100));
        assert_eq!(indices.len(), 5);
    }

    #[test]
    fn test_sample_indices_n_one() {
        let indices = compute_sample_indices(100, Some(1));
        assert_eq!(indices, vec![0]);
    }

    #[test]
    fn test_sample_indices_uniform_and_last_included() {
        let indices = compute_sample_indices(100, Some(5));
        assert_eq!(indices.len(), 5);
        assert_eq!(indices[0], 0);
        assert_eq!(*indices.last().unwrap(), 99);
    }

    #[test]
    fn test_sample_indices_stride_correctness() {
        let indices = compute_sample_indices(10, Some(3));
        assert_eq!(indices.len(), 3);
        assert_eq!(indices[0], 0);
        // 1 * (10 - 1) / (3 - 1) == 4 with integer division.
        assert_eq!(indices[1], 4);
        assert_eq!(indices[2], 9);
    }

    #[test]
    fn test_export_single_creates_file() {
        let result = make_single_result(10);
        let exporter = CsvExporter::default();
        let path = temp_csv("test_single_creates.csv");
        exporter.export_single(&result, None, &path).unwrap();
        assert!(std::path::Path::new(&path).exists());
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_export_single_header() {
        let result = make_single_result(5);
        let exporter = CsvExporter::default();
        let path = temp_csv("test_single_header.csv");
        exporter.export_single(&result, None, &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        let first_line = content.lines().next().unwrap();
        assert_eq!(first_line, "time (s);c_outlet (mol/L)");
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_export_single_row_count_all_points() {
        let n_steps = 20;
        let result = make_single_result(n_steps);
        let exporter = CsvExporter::default();
        let path = temp_csv("test_single_rows_all.csv");
        exporter.export_single(&result, None, &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        let line_count = content.lines().count();
        // One header line plus one row per time point.
        assert_eq!(line_count, n_steps + 1);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_export_single_row_count_subsampled() {
        let result = make_single_result(100);
        let exporter = CsvExporter::default();
        let path = temp_csv("test_single_rows_sub.csv");
        exporter.export_single(&result, Some(5), &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        let line_count = content.lines().count();
        // One header line plus the five sampled rows.
        assert_eq!(line_count, 6);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_export_single_values_correctness() {
        let result = make_single_result(3);
        let exporter = CsvExporter::default();
        let path = temp_csv("test_single_values.csv");
        exporter.export_single(&result, None, &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        let lines: Vec<&str> = content.lines().collect();
        assert!(lines[1].starts_with("0.000000e0;0.000000e0"));
        // Second data row: t = 0.1, c = 0.001.
        assert!(lines[2].contains("1.000000e-1"));
        assert!(lines[2].contains("1.000000e-3"));
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_export_single_empty_result_error() {
        let result = SimulationResult::new(vec![], vec![], PhysicalState::empty());
        let exporter = CsvExporter::default();
        // The export fails before the file is created, so nothing to clean up.
        let path = temp_csv("unused_single_empty.csv");
        let err = exporter.export_single(&result, None, &path).unwrap_err();
        assert!(matches!(err, CsvError::EmptyResult));
    }

    #[test]
    fn test_export_single_invalid_path_error() {
        let result = make_single_result(5);
        let exporter = CsvExporter::default();
        let err = exporter
            .export_single(&result, None, "/nonexistent_dir/file.csv")
            .unwrap_err();
        assert!(matches!(err, CsvError::Io(_)));
    }

    #[test]
    fn test_export_multi_header() {
        let result = make_multi_result(5, 2);
        let exporter = CsvExporter::default();
        let path = temp_csv("test_multi_header.csv");
        exporter
            .export_multi(&result, None, &["Ascorbic", "Erythorbic"], &path)
            .unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        let first_line = content.lines().next().unwrap();
        assert_eq!(
            first_line,
            "time (s);c_total (mol/L);Ascorbic (mol/L);Erythorbic (mol/L)"
        );
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_export_multi_envelope_equals_sum() {
        let n_species = 3;
        let result = make_multi_result(10, n_species);
        let exporter = CsvExporter::default();
        let path = temp_csv("test_multi_envelope.csv");
        let names = ["A", "B", "C"];
        exporter.export_multi(&result, None, &names, &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        for line in content.lines().skip(1) {
            // Drop the time column; what remains is c_total followed by the
            // per-species values.
            let cols: Vec<f64> = line
                .split(';')
                .skip(1)
                .map(|s| s.trim().parse::<f64>().unwrap())
                .collect();
            let c_total = cols[0];
            let sum: f64 = cols[1..].iter().sum();
            assert!(
                (c_total - sum).abs() < 1e-12,
                "c_total {c_total} != sum {sum}"
            );
        }
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_export_multi_species_count_mismatch() {
        let result = make_multi_result(5, 2);
        let exporter = CsvExporter::default();
        // The export fails before the file is created, so nothing to clean up.
        let path = temp_csv("unused_multi_mismatch.csv");
        let err = exporter
            .export_multi(&result, None, &["A", "B", "C"], &path)
            .unwrap_err();
        assert!(matches!(
            err,
            CsvError::SpeciesCountMismatch {
                expected: 2,
                got: 3
            }
        ));
    }

    #[test]
    fn test_export_multi_subsampled() {
        let result = make_multi_result(50, 2);
        let exporter = CsvExporter::default();
        let path = temp_csv("test_multi_sub.csv");
        exporter
            .export_multi(&result, Some(10), &["A", "B"], &path)
            .unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        // One header line plus the ten sampled rows.
        assert_eq!(content.lines().count(), 11);
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_custom_separator() {
        let result = make_single_result(3);
        let config = CsvConfig {
            separator: ',',
            precision: 6,
        };
        let exporter = CsvExporter::new(config);
        let path = temp_csv("test_custom_sep.csv");
        exporter.export_single(&result, None, &path).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        let first_line = content.lines().next().unwrap();
        assert_eq!(first_line, "time (s),c_outlet (mol/L)");
        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn test_csvconfig_default() {
        let config = CsvConfig::default();
        assert_eq!(config.separator, ';');
        assert_eq!(config.precision, 6);
    }

    #[test]
    fn test_csv_error_display_empty() {
        let err = CsvError::EmptyResult;
        let msg = err.to_string();
        assert!(msg.contains("no time points"));
    }

    #[test]
    fn test_csv_error_display_mismatch() {
        let err = CsvError::SpeciesCountMismatch {
            expected: 2,
            got: 5,
        };
        let msg = err.to_string();
        assert!(msg.contains("expected 2") && msg.contains("got 5"));
    }

    #[test]
    fn test_csv_error_source_io() {
        use std::error::Error;
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err = CsvError::Io(io_err);
        assert!(err.source().is_some());
    }

    #[test]
    fn test_csv_error_source_empty() {
        use std::error::Error;
        let err = CsvError::EmptyResult;
        assert!(err.source().is_none());
    }
}