brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
use brainharmony::EmbeddingResult;
use std::io::Read;

/// Builds a synthetic [`EmbeddingResult`] used as exporter input in these tests.
///
/// The result describes `n_rois * n_time` patches, each with `embed_dim` values
/// (`shape == [n_rois * n_time, embed_dim]`). Element `k` of the flat
/// `embeddings` buffer is `k as f32 * 0.1`, so every exported value is
/// predictable. `ms_encode` is a fixed placeholder of `42.0`.
fn fake_embedding_result(n_rois: usize, n_time: usize, embed_dim: usize) -> EmbeddingResult {
    let patch_count = n_rois * n_time;
    let value_count = patch_count * embed_dim;

    // Each element is its own flat index scaled by 0.1.
    let mut embeddings = Vec::with_capacity(value_count);
    for idx in 0..value_count {
        embeddings.push(idx as f32 * 0.1);
    }

    EmbeddingResult {
        n_rois,
        n_time_patches: n_time,
        shape: vec![patch_count, embed_dim],
        ms_encode: 42.0,
        embeddings,
    }
}

/// Reads the whole file at `path` into a `String`.
///
/// # Panics
///
/// Panics if the file cannot be opened or its contents cannot be read as
/// UTF-8. The panic message names `path`, so a failing test points straight
/// at the missing or unreadable file.
fn read_file(path: &str) -> String {
    let mut file = std::fs::File::open(path)
        .unwrap_or_else(|e| panic!("failed to open {path}: {e}"));
    let mut contents = String::new();
    file.read_to_string(&mut contents)
        .unwrap_or_else(|e| panic!("failed to read {path}: {e}"));
    contents
}

/// The plain CSV export writes one `dim_<k>` header column per embedding
/// dimension, followed by one data row per patch (`n_rois * n_time`).
#[test]
fn csv_export_header_and_row_count() {
    let (n_rois, n_time, embed_dim) = (3, 2, 4);
    let result = fake_embedding_result(n_rois, n_time, embed_dim);
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("out.csv");
    let path_str = path.to_str().unwrap();

    brainharmony::save_embeddings_csv(&result, path_str).unwrap();

    let content = read_file(path_str);
    let lines: Vec<&str> = content.lines().collect();

    // Check the length first so an empty or truncated file fails with a clear
    // assertion rather than an index-out-of-bounds panic on `lines[0]`.
    assert_eq!(
        lines.len(),
        n_rois * n_time + 1,
        "expected a header plus one row per patch"
    );
    assert_eq!(lines[0], "dim_0,dim_1,dim_2,dim_3");
}

#[test]
fn csv_export_values_correct() {
    let result = fake_embedding_result(2, 1, 3);
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("vals.csv");
    let path_str = path.to_str().unwrap();

    brainharmony::save_embeddings_csv(&result, path_str).unwrap();

    let content = read_file(path_str);
    let lines: Vec<&str> = content.lines().collect();

    let row1: Vec<f32> = lines[1].split(',').map(|s| s.parse().unwrap()).collect();
    assert_eq!(row1.len(), 3);
    assert!((row1[0] - 0.0).abs() < 1e-6);
    assert!((row1[1] - 0.1).abs() < 1e-6);
    assert!((row1[2] - 0.2).abs() < 1e-6);
}

/// The metadata export prepends `roi_idx` and `time_idx` header columns
/// and still writes exactly one data row per patch.
#[test]
fn csv_with_metadata_has_roi_and_time_columns() {
    let (n_rois, n_time, embed_dim) = (3, 2, 4);
    let result = fake_embedding_result(n_rois, n_time, embed_dim);
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("meta.csv");
    let path_str = path.to_str().unwrap();

    brainharmony::csv_export::save_embeddings_csv_with_metadata(&result, path_str, n_rois, n_time)
        .unwrap();

    let content = read_file(path_str);
    let lines: Vec<&str> = content.lines().collect();

    // Check the length first so an empty file fails with a clear assertion
    // rather than an index panic on `lines[0]`.
    assert_eq!(
        lines.len(),
        n_rois * n_time + 1,
        "expected a header plus one row per patch"
    );
    assert_eq!(lines[0], "roi_idx,time_idx,dim_0,dim_1,dim_2,dim_3");
}

/// Metadata rows are emitted ROI-major: for each ROI, every time patch in
/// order, so `(roi_idx, time_idx)` runs `(0,0), (0,1), (0,2), (1,0), ...`.
#[test]
fn csv_with_metadata_roi_time_values() {
    let (n_rois, n_time) = (2, 3);
    let result = fake_embedding_result(n_rois, n_time, 2);
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("meta2.csv");
    let path_str = path.to_str().unwrap();

    brainharmony::csv_export::save_embeddings_csv_with_metadata(&result, path_str, n_rois, n_time)
        .unwrap();

    let content = read_file(path_str);
    let lines: Vec<&str> = content.lines().collect();

    let expected_pairs = [
        (0, 0), (0, 1), (0, 2),
        (1, 0), (1, 1), (1, 2),
    ];
    // Without this, missing rows would panic on indexing and extra rows
    // would go unnoticed.
    assert_eq!(
        lines.len(),
        expected_pairs.len() + 1,
        "expected a header plus one row per patch"
    );

    // Skip the header and pair each data row with its expected indices.
    for (line, &(roi, time)) in lines[1..].iter().zip(expected_pairs.iter()) {
        let cols: Vec<&str> = line.split(',').collect();
        assert!(cols.len() >= 2, "missing index columns in: {line}");
        assert_eq!(cols[0].parse::<usize>().unwrap(), roi, "bad roi_idx in: {line}");
        assert_eq!(cols[1].parse::<usize>().unwrap(), time, "bad time_idx in: {line}");
    }
}

/// Every line of the metadata export, header included, has the two index
/// columns plus one column per embedding dimension.
#[test]
fn csv_with_metadata_column_count() {
    let (n_rois, n_time, embed_dim) = (2, 2, 5);
    let result = fake_embedding_result(n_rois, n_time, embed_dim);
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("cols.csv");
    let path_str = path.to_str().unwrap();

    brainharmony::csv_export::save_embeddings_csv_with_metadata(&result, path_str, n_rois, n_time)
        .unwrap();

    let content = read_file(path_str);
    let lines: Vec<&str> = content.lines().collect();

    // Guard against a vacuous pass: an empty file would otherwise satisfy the
    // per-line loop below without checking anything.
    assert_eq!(
        lines.len(),
        n_rois * n_time + 1,
        "expected a header plus one row per patch"
    );

    for line in &lines {
        let ncols = line.split(',').count();
        assert_eq!(ncols, embed_dim + 2, "wrong column count in: {line}");
    }
}