use brainharmony::EmbeddingResult;
use std::io::Read;
/// Builds a synthetic [`EmbeddingResult`] to feed the CSV export functions.
///
/// The result describes `n_rois * n_time` patches with `embed_dim` values each.
/// The flat `embeddings` buffer is filled so that the value at index `i` is
/// `i as f32 * 0.1`, which lets tests predict every exported number.
fn fake_embedding_result(n_rois: usize, n_time: usize, embed_dim: usize) -> EmbeddingResult {
    let patch_count = n_rois * n_time;
    let total_values = patch_count * embed_dim;

    // Deterministic ramp: 0.0, 0.1, 0.2, ...
    let mut values: Vec<f32> = Vec::with_capacity(total_values);
    for idx in 0..total_values {
        values.push(idx as f32 * 0.1);
    }

    EmbeddingResult {
        embeddings: values,
        shape: vec![patch_count, embed_dim],
        n_rois,
        n_time_patches: n_time,
        // Fixed timing value for the synthetic result.
        ms_encode: 42.0,
    }
}
/// Reads the whole file at `path` into a `String`.
///
/// # Panics
///
/// Panics if the file cannot be opened or if reading it as UTF-8 text fails.
fn read_file(path: &str) -> String {
    let mut file = std::fs::File::open(path).unwrap();
    let mut contents = String::new();
    file.read_to_string(&mut contents).unwrap();
    contents
}
/// Checks the header line and total line count written by `save_embeddings_csv`.
///
/// The synthetic result has 3 * 2 = 6 patches of 4 values, and the test expects
/// a `dim_0..dim_3` header followed by enough rows to make 7 lines in total.
#[test]
fn csv_export_header_and_row_count() {
    let embedding = fake_embedding_result(3, 2, 4);

    // The temp directory handle must stay alive until the file has been read back.
    let tmp = tempfile::tempdir().unwrap();
    let csv_file = tmp.path().join("out.csv");
    let csv_path = csv_file.to_str().unwrap();
    brainharmony::save_embeddings_csv(&embedding, csv_path).unwrap();

    let text = read_file(csv_path);
    let rows: Vec<&str> = text.lines().collect();
    assert_eq!(rows[0], "dim_0,dim_1,dim_2,dim_3");
    assert_eq!(rows.len(), 7);
}
/// Checks that the first data row written by `save_embeddings_csv` carries the
/// first `embed_dim` values of the synthetic ramp (0.0, 0.1, 0.2).
#[test]
fn csv_export_values_correct() {
    let embedding = fake_embedding_result(2, 1, 3);

    let tmp = tempfile::tempdir().unwrap();
    let csv_file = tmp.path().join("vals.csv");
    let csv_path = csv_file.to_str().unwrap();
    brainharmony::save_embeddings_csv(&embedding, csv_path).unwrap();

    let text = read_file(csv_path);
    let rows: Vec<&str> = text.lines().collect();

    // rows[0] is the header; rows[1] is the first data row.
    let first_row: Vec<f32> = rows[1]
        .split(',')
        .map(|cell| cell.parse::<f32>().unwrap())
        .collect();
    assert_eq!(first_row.len(), 3);

    // Compare with a tolerance because the values are parsed back from text.
    let expected = [0.0_f32, 0.1, 0.2];
    for (actual, wanted) in first_row.iter().zip(expected.iter()) {
        assert!((actual - wanted).abs() < 1e-6);
    }
}
/// Checks that `save_embeddings_csv_with_metadata` prefixes the header with
/// `roi_idx` and `time_idx` columns and still writes 7 lines for 3 * 2 patches.
#[test]
fn csv_with_metadata_has_roi_and_time_columns() {
    let embedding = fake_embedding_result(3, 2, 4);

    let tmp = tempfile::tempdir().unwrap();
    let csv_file = tmp.path().join("meta.csv");
    let csv_path = csv_file.to_str().unwrap();
    brainharmony::csv_export::save_embeddings_csv_with_metadata(&embedding, csv_path, 3, 2)
        .unwrap();

    let text = read_file(csv_path);
    let rows: Vec<&str> = text.lines().collect();
    assert_eq!(rows[0], "roi_idx,time_idx,dim_0,dim_1,dim_2,dim_3");
    assert_eq!(rows.len(), 7);
}
/// Checks the `roi_idx` / `time_idx` values of every data row written by
/// `save_embeddings_csv_with_metadata` for 2 ROIs and 3 time patches.
///
/// The expected order is ROI-major: (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).
#[test]
fn csv_with_metadata_roi_time_values() {
    let embedding = fake_embedding_result(2, 3, 2);

    let tmp = tempfile::tempdir().unwrap();
    let csv_file = tmp.path().join("meta2.csv");
    let csv_path = csv_file.to_str().unwrap();
    brainharmony::csv_export::save_embeddings_csv_with_metadata(&embedding, csv_path, 2, 3)
        .unwrap();

    let text = read_file(csv_path);
    let rows: Vec<&str> = text.lines().collect();

    // Row 0 is the header, so data rows start at index 1.
    let mut row_idx = 1;
    for roi in 0..2_usize {
        for time in 0..3_usize {
            let fields: Vec<&str> = rows[row_idx].split(',').collect();
            assert_eq!(fields[0].parse::<usize>().unwrap(), roi);
            assert_eq!(fields[1].parse::<usize>().unwrap(), time);
            row_idx += 1;
        }
    }
}
/// Checks that every line written by `save_embeddings_csv_with_metadata`
/// (header included) has `embed_dim + 2` comma-separated columns.
#[test]
fn csv_with_metadata_column_count() {
    let embed_dim = 5;
    let embedding = fake_embedding_result(2, 2, embed_dim);

    let tmp = tempfile::tempdir().unwrap();
    let csv_file = tmp.path().join("cols.csv");
    let csv_path = csv_file.to_str().unwrap();
    brainharmony::csv_export::save_embeddings_csv_with_metadata(&embedding, csv_path, 2, 2)
        .unwrap();

    let text = read_file(csv_path);
    // Two metadata columns precede the embedding dimensions on every line.
    for line in text.lines() {
        let ncols = line.split(',').count();
        assert_eq!(ncols, embed_dim + 2, "wrong column count in: {line}");
    }
}