use anyhow::{bail, Context, Result};
use ndarray::Array2;
use std::path::Path;
/// Reads a delimited text file (comma, semicolon or tab separated) of samples
/// into a channel-major matrix.
///
/// Parsing rules:
/// - A leading UTF-8 byte-order mark is ignored.
/// - Blank lines and lines whose first non-whitespace character is `#` are skipped.
/// - The delimiter is guessed from the first remaining line.
/// - If any field of that first line does not parse as a number, the line is
///   treated as a header and supplies the channel names.
/// - The first column whose values are strictly increasing from row to row is
///   treated as the timestamp column: it is excluded from the returned data and
///   the sampling frequency is derived from it as `1 / (t[1] - t[0])`.
///
/// # Returns
///
/// `(data, ch_names, sfreq)` where
/// - `data` has shape `(n_channels, n_samples)`,
/// - `ch_names` holds the header names of the kept columns, or `Ch0`, `Ch1`, …
///   when the file has no header,
/// - `sfreq` is `0.0` when no timestamp column was found (which is always the
///   case for files with fewer than two data rows).
///
/// # Errors
///
/// Fails when the file cannot be read, contains no usable lines, contains no
/// data rows, has a row whose column count differs from the first line, or has
/// a field that is not a valid number. Row numbers in error messages are
/// 1-based and count data rows only (header, blank and comment lines excluded).
pub fn read_eeg<P: AsRef<Path>>(path: P) -> Result<(Array2<f32>, Vec<String>, f32)> {
    let path = path.as_ref();
    let raw = std::fs::read_to_string(path)
        .with_context(|| format!("reading CSV: {}", path.display()))?;
    // U+FEFF is not whitespace, so `trim` would leave it glued to the first
    // field and turn a numeric first row into a bogus header.
    let text = raw.strip_prefix('\u{feff}').unwrap_or(raw.as_str());

    let lines: Vec<&str> = text
        .lines()
        .filter(|l| {
            let t = l.trim_start();
            !t.is_empty() && !t.starts_with('#')
        })
        .collect();
    if lines.is_empty() {
        bail!("CSV file is empty");
    }

    let delim = detect_delimiter(lines[0]);
    let first_fields: Vec<&str> = lines[0].split(delim).map(|s| s.trim()).collect();
    let has_header = first_fields.iter().any(|f| f.parse::<f64>().is_err());
    let (header, data_lines) = if has_header {
        let hdr: Vec<String> = first_fields.iter().map(|s| s.to_string()).collect();
        (Some(hdr), &lines[1..])
    } else {
        (None, &lines[..])
    };
    if data_lines.is_empty() {
        bail!("CSV file has no data rows");
    }

    let n_cols = first_fields.len();
    let n_rows = data_lines.len();

    // Values stay in f64 until the very end: narrowing to f32 first would make
    // neighbouring large timestamps (e.g. epoch seconds) compare equal and
    // would degrade the precision of the sampling interval.
    let mut samples: Vec<f64> = Vec::with_capacity(n_rows * n_cols);
    for (line_idx, line) in data_lines.iter().enumerate() {
        let fields: Vec<&str> = line.split(delim).map(|s| s.trim()).collect();
        if fields.len() != n_cols {
            bail!("Row {} has {} columns, expected {}", line_idx + 1, fields.len(), n_cols);
        }
        for field in &fields {
            let val = field
                .parse::<f64>()
                .with_context(|| format!("parsing value '{}' at row {}", field, line_idx + 1))?;
            samples.push(val);
        }
    }

    // First strictly increasing column wins. Using `>` (rather than rejecting
    // on `<=`) means a NaN anywhere in a column disqualifies it.
    let timestamp_col: Option<usize> = if n_rows >= 2 {
        (0..n_cols).find(|&col| {
            (1..n_rows).all(|row| samples[row * n_cols + col] > samples[(row - 1) * n_cols + col])
        })
    } else {
        None
    };

    let sfreq: f32 = match timestamp_col {
        // A detected timestamp column implies at least two rows, so both
        // indices below are in range.
        Some(ts_col) => {
            let dt = samples[n_cols + ts_col] - samples[ts_col];
            if dt > 0.0 {
                (1.0 / dt) as f32
            } else {
                0.0
            }
        }
        None => 0.0,
    };

    let data_cols: Vec<usize> = (0..n_cols)
        .filter(|&c| Some(c) != timestamp_col)
        .collect();
    let n_ch = data_cols.len();

    // Transpose from row-major file layout to (channel, sample).
    let mut data = Array2::<f32>::zeros((n_ch, n_rows));
    for (ch_out, &col) in data_cols.iter().enumerate() {
        for row in 0..n_rows {
            data[[ch_out, row]] = samples[row * n_cols + col] as f32;
        }
    }

    let ch_names: Vec<String> = match header {
        Some(ref hdr) => data_cols.iter().map(|&c| hdr[c].clone()).collect(),
        None => (0..n_ch).map(|i| format!("Ch{}", i)).collect(),
    };

    Ok((data, ch_names, sfreq))
}
/// Guesses the field separator of a delimited-text line.
///
/// Counts tabs, commas and semicolons in `line` and picks:
/// - tab, when at least one tab is present and tabs are at least as frequent
///   as both commas and semicolons;
/// - otherwise semicolon, when semicolons strictly outnumber commas;
/// - otherwise comma (also the fallback when none of the three occurs).
fn detect_delimiter(line: &str) -> char {
    let occurrences = |sep: char| line.chars().filter(|&c| c == sep).count();
    let tabs = occurrences('\t');
    let commas = occurrences(',');
    let semis = occurrences(';');

    match (tabs, commas, semis) {
        (t, c, s) if t > 0 && t >= c && t >= s => '\t',
        // Counts are unsigned, so `s > c` already guarantees `s > 0`.
        (_, c, s) if s > c => ';',
        _ => ',',
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Creates `<temp dir>/<dir_name>/<file_name>`, writes each entry of `rows`
    /// followed by a newline, and returns `(directory, file path)`.
    fn write_fixture(
        dir_name: &str,
        file_name: &str,
        rows: &[&str],
    ) -> (std::path::PathBuf, std::path::PathBuf) {
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();
        let path = dir.join(file_name);
        let mut file = std::fs::File::create(&path).unwrap();
        for row in rows {
            writeln!(file, "{}", row).unwrap();
        }
        (dir, path)
    }

    /// A header plus a strictly increasing first column: the first column is
    /// consumed as the timestamp and the remaining two become channels.
    #[test]
    fn read_csv_with_header_and_timestamp() {
        let (dir, path) = write_fixture(
            "exg_csv_test",
            "test.csv",
            &[
                "time,ch1,ch2",
                "0.000,1.0,2.0",
                "0.004,3.0,4.0",
                "0.008,5.0,6.0",
            ],
        );

        let (data, names, sfreq) = read_eeg(&path).unwrap();

        assert_eq!(names, vec!["ch1", "ch2"]);
        assert_eq!(data.dim(), (2, 3));
        approx::assert_abs_diff_eq!(sfreq, 250.0, epsilon = 1.0);
        approx::assert_abs_diff_eq!(data[[0, 0]], 1.0, epsilon = 1e-6);
        approx::assert_abs_diff_eq!(data[[1, 2]], 6.0, epsilon = 1e-6);

        // Best-effort cleanup; a failure here must not fail the test.
        std::fs::remove_dir_all(&dir).ok();
    }

    /// All-numeric first line and no increasing column: every column is kept
    /// as a channel.
    #[test]
    fn read_csv_no_header() {
        let (dir, path) = write_fixture(
            "exg_csv_test2",
            "test2.csv",
            &["5.0,2.0,9.0", "1.0,1.0,3.0"],
        );

        let (data, names, _sfreq) = read_eeg(&path).unwrap();

        assert_eq!(names.len(), 3);
        assert_eq!(data.dim(), (3, 2));

        std::fs::remove_dir_all(&dir).ok();
    }

    /// Each supported separator is recognised when it is the only one present.
    #[test]
    fn detect_delimiter_works() {
        let cases = [("a,b,c", ','), ("a\tb\tc", '\t'), ("a;b;c", ';')];
        for (line, expected) in cases.iter() {
            assert_eq!(detect_delimiter(line), *expected);
        }
    }
}