brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
//! Brain data loading and preprocessing utilities.
//!
//! Loads fMRI signals (CSV or safetensors), brain gradient mappings, and
//! geometric harmonics (CSV), and provides signal standardization.
use std::path::Path;

use burn::prelude::*;

use crate::error::BrainHarmonyError;

/// Brain gradient coordinates loaded from CSV.
///
/// Values are stored row-major: the entry for ROI `r` and gradient axis `g`
/// lives at `values[r * grad_dim + g]`.
#[derive(Debug, Clone, PartialEq)]
pub struct GradientData {
    /// Gradient values: `[n_rois, grad_dim]` flattened row-major.
    pub values: Vec<f32>,
    /// Number of ROIs (rows).
    pub n_rois: usize,
    /// Number of gradient axes (columns).
    pub grad_dim: usize,
}

impl GradientData {
    /// Read a gradient mapping from `path`.
    ///
    /// Every numeric row is one ROI and every column one gradient axis. Blank lines,
    /// `#` comment lines, and rows without any parseable number are ignored.
    ///
    /// # Errors
    /// Propagates the errors of the shared CSV loader: missing file, I/O failure,
    /// rows of differing length, or no numeric rows at all.
    pub fn from_csv(path: &str) -> crate::error::Result<Self> {
        let kind = "gradient CSV";
        load_csv_data(path, kind)
    }

    /// Build a `[n_rois, grad_dim]` tensor on `device` from a copy of the stored values.
    pub fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2> {
        let shape = vec![self.n_rois, self.grad_dim];
        let data = TensorData::new(self.values.to_vec(), shape);
        Tensor::<B, 2>::from_data(data, device)
    }
}

/// Geometric harmonics data loaded from CSV.
///
/// Values are stored row-major: the entry for ROI `r` and eigenmode `k`
/// lives at `values[r * geoh_dim + k]`.
#[derive(Debug, Clone, PartialEq)]
pub struct GeohData {
    /// Geometric harmonics values: `[n_rois, geoh_dim]` flattened row-major.
    pub values: Vec<f32>,
    /// Number of ROIs (rows).
    pub n_rois: usize,
    /// Number of eigenmode columns.
    pub geoh_dim: usize,
}

impl GeohData {
    /// Load geometric harmonics from a CSV file.
    ///
    /// Expected format: the first non-blank, non-comment line is a header and is skipped.
    /// The first column of every following row is an index and is skipped. Each remaining
    /// row is one ROI whose columns are eigenmode values. Blank lines, lines starting with
    /// `#`, and rows with no parseable value after the index column are ignored.
    ///
    /// # Errors
    /// - `FileNotFound` if `path` does not exist.
    /// - An I/O error (converted via `?`) if the file cannot be read as text.
    /// - `InconsistentCsvRow` if a row's value count differs from the first data row
    ///   (`row` is the 1-based physical line number).
    /// - `EmptyCsv` if no data rows remain.
    pub fn from_csv(path: &str) -> crate::error::Result<Self> {
        let p = Path::new(path);
        if !p.exists() {
            return Err(BrainHarmonyError::FileNotFound {
                kind: "geometric harmonics CSV",
                path: p.to_path_buf(),
            });
        }

        let content = std::fs::read_to_string(p)?;
        let mut values = Vec::new();
        let mut n_rois = 0usize;
        let mut geoh_dim = 0usize;
        let mut header_skipped = false;

        for (line_no, line) in content.lines().enumerate() {
            let line = line.trim();
            if line.is_empty() || line.starts_with('#') {
                continue;
            }
            // The header is the first *meaningful* line. That need not be physical line 0
            // when the file starts with blank lines or `#` comments. A flag rather than
            // `line_no == 0` ensures a numeric header (e.g. `,0,1,2` for integer column
            // labels) is never ingested as an ROI row.
            if !header_skipped {
                header_skipped = true;
                continue;
            }
            // Skip first column (index), parse remaining as floats
            let parts: Vec<f32> = line
                .split(',')
                .skip(1)
                .filter_map(|s| s.trim().parse::<f32>().ok())
                .collect();
            if parts.is_empty() {
                continue;
            }
            if geoh_dim == 0 {
                geoh_dim = parts.len();
            } else if parts.len() != geoh_dim {
                return Err(BrainHarmonyError::InconsistentCsvRow {
                    path: p.to_path_buf(),
                    row: line_no + 1,
                    expected: geoh_dim,
                    got: parts.len(),
                });
            }
            values.extend_from_slice(&parts);
            n_rois += 1;
        }

        if n_rois == 0 {
            return Err(BrainHarmonyError::EmptyCsv {
                path: p.to_path_buf(),
            });
        }

        Ok(Self {
            values,
            n_rois,
            geoh_dim,
        })
    }

    /// Convert to a burn tensor: [n_rois, geoh_dim]
    pub fn to_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 2> {
        Tensor::<B, 2>::from_data(
            TensorData::new(self.values.clone(), vec![self.n_rois, self.geoh_dim]),
            device,
        )
    }
}

/// Preprocessed signal input ready for the model.
///
/// Produced by the CSV and safetensors signal loaders in this module.
#[derive(Debug)]
pub struct SignalInput<B: Backend> {
    /// Signal data, rank 4: `[batch, dim1, n_rois, signal_length]`.
    ///
    /// The CSV loader and rank-2 safetensors inputs always yield `[1, 1, n_rois,
    /// signal_length]`. Rank-3 safetensors inputs keep their leading batch dimension
    /// (with dim 1 inserted as size 1). Rank-4 safetensors inputs are passed through
    /// as-is, so dim 1 is not checked to be 1.
    pub data: Tensor<B, 4>,
    /// Number of ROIs (size of dim 2 of `data`).
    pub n_rois: usize,
    /// Number of time points (size of dim 3 of `data`).
    pub signal_length: usize,
}

/// Load an fMRI signal from a safetensors file.
///
/// The tensor is read from the key `"signal"`, falling back to `"fmri"`. Accepted layouts:
///
/// - `[n_rois, n_time]` → returned as `[1, 1, n_rois, n_time]`
/// - `[batch, n_rois, n_time]` → returned as `[batch, 1, n_rois, n_time]`
/// - `[batch, dim1, n_rois, n_time]` → returned unchanged
///
/// Values stored as F32, F64, BF16 or F16 are converted to `f32`.
///
/// # Errors
/// Fails if the file does not exist or cannot be read, or is not valid safetensors.
/// It also fails if the file contains neither key, stores another dtype, or holds a
/// tensor whose rank is not 2, 3 or 4.
pub fn load_signal_safetensors<B: Backend>(
    path: &str,
    device: &B::Device,
) -> anyhow::Result<SignalInput<B>> {
    let p = Path::new(path);
    if !p.exists() {
        return Err(BrainHarmonyError::FileNotFound {
            kind: "signal input",
            path: p.to_path_buf(),
        }
        .into());
    }

    let bytes = std::fs::read(p)?;
    let st = safetensors::SafeTensors::deserialize(&bytes)?;

    // Prefer "signal", fall back to "fmri". Name both keys in the error. Otherwise a file
    // that has neither would produce a message implying only "fmri" was expected.
    let view = st
        .tensor("signal")
        .or_else(|_| st.tensor("fmri"))
        .map_err(|e| anyhow::anyhow!("missing 'signal' or 'fmri' key: {e}"))?;
    let shape = view.shape().to_vec();
    let data_bytes = view.data();

    // Decode raw little-endian bytes to f32. safetensors checks at deserialization time
    // that each buffer holds exactly `n_elements * dtype_size` bytes, so `chunks_exact`
    // leaves no remainder.
    let f32s: Vec<f32> = match view.dtype() {
        safetensors::Dtype::F32 => data_bytes
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect(),
        // numpy's default float dtype; narrowed to f32 to match the model input.
        safetensors::Dtype::F64 => data_bytes
            .chunks_exact(8)
            .map(|b| {
                f64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]) as f32
            })
            .collect(),
        safetensors::Dtype::BF16 => data_bytes
            .chunks_exact(2)
            .map(|b| half::bf16::from_le_bytes([b[0], b[1]]).to_f32())
            .collect(),
        safetensors::Dtype::F16 => data_bytes
            .chunks_exact(2)
            .map(|b| half::f16::from_le_bytes([b[0], b[1]]).to_f32())
            .collect(),
        other => anyhow::bail!("unsupported dtype {:?}", other),
    };

    let (n_rois, signal_length, tensor) = match shape.len() {
        2 => {
            let t = Tensor::<B, 2>::from_data(TensorData::new(f32s, shape.clone()), device);
            (shape[0], shape[1], t.unsqueeze_dim::<3>(0).unsqueeze_dim::<4>(0))
        }
        3 => {
            let t = Tensor::<B, 3>::from_data(TensorData::new(f32s, shape.clone()), device);
            // Insert the size-1 dimension after the batch axis.
            (shape[1], shape[2], t.unsqueeze_dim::<4>(1))
        }
        4 => {
            let t = Tensor::<B, 4>::from_data(TensorData::new(f32s, shape.clone()), device);
            (shape[2], shape[3], t)
        }
        _ => anyhow::bail!("unexpected signal tensor rank: {}", shape.len()),
    };

    Ok(SignalInput {
        data: tensor,
        n_rois,
        signal_length,
    })
}

/// Load signal from a raw CSV (rows = ROIs, columns = time points).
pub fn load_signal_csv<B: Backend>(
    path: &str,
    device: &B::Device,
) -> crate::error::Result<SignalInput<B>> {
    let p = Path::new(path);
    if !p.exists() {
        return Err(BrainHarmonyError::FileNotFound {
            kind: "signal CSV",
            path: p.to_path_buf(),
        });
    }

    let content = std::fs::read_to_string(p)?;
    let mut values = Vec::new();
    let mut n_rois = 0usize;
    let mut n_time = 0usize;

    for (line_no, line) in content.lines().enumerate() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let parts: Vec<f32> = line
            .split(',')
            .filter_map(|s| s.trim().parse::<f32>().ok())
            .collect();
        if parts.is_empty() {
            continue;
        }
        if n_time == 0 {
            n_time = parts.len();
        } else if parts.len() != n_time {
            return Err(BrainHarmonyError::InconsistentCsvRow {
                path: p.to_path_buf(),
                row: line_no + 1,
                expected: n_time,
                got: parts.len(),
            });
        }
        values.extend_from_slice(&parts);
        n_rois += 1;
    }

    if n_rois == 0 {
        return Err(BrainHarmonyError::EmptyCsv {
            path: p.to_path_buf(),
        });
    }

    let t = Tensor::<B, 2>::from_data(
        TensorData::new(values, vec![n_rois, n_time]),
        device,
    )
    .unsqueeze_dim::<3>(0)
    .unsqueeze_dim::<4>(0);

    Ok(SignalInput {
        data: t,
        n_rois,
        signal_length: n_time,
    })
}

/// Standardize each sample independently to zero mean and unit variance.
///
/// For every index along dim 0 (the batch axis), the mean and population standard
/// deviation are computed over that sample's `c * h * w` elements. The sample is then
/// mapped to `(x - mean) / (std + 1e-8)`. The epsilon keeps constant (zero-variance)
/// samples finite. For a batch of size 1 this is the same as standardizing the whole
/// tensor.
///
/// Tensors with zero elements are returned unchanged.
pub fn standardize<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    let [b, c, h, w] = x.dims();
    let per_sample = c * h * w;
    if b == 0 || per_sample == 0 {
        return x;
    }

    // Flatten each sample to one row. Statistics are then reduced along a single axis and
    // broadcast back with the `[b, 1]` shape that `mean_dim` keeps. Everything stays on
    // the device, with no scalar round-trips to the host.
    let flat = x.reshape([b, per_sample]);
    let mean = flat.clone().mean_dim(1); // [b, 1]
    let centered = flat.sub(mean);
    let var = centered.clone().powf_scalar(2.0f32).mean_dim(1); // [b, 1], population variance
    let std = var.sqrt().add_scalar(1e-8);
    centered.div(std).reshape([b, c, h, w])
}

/// Internal: parse a numeric CSV into a row-major matrix (rows = ROIs, cols = features).
///
/// Parsing rules:
/// - every line and every comma-separated cell is whitespace-trimmed;
/// - blank lines and lines starting with `#` are ignored;
/// - cells that do not parse as `f32` are dropped, and a row left with no values
///   (for instance a textual header) is ignored;
/// - the first retained row fixes the column count, and any later row with a
///   different count is an error.
///
/// `kind` names the file in a `FileNotFound` error.
///
/// # Errors
/// - `FileNotFound` if the file does not exist.
/// - An I/O error from reading the file.
/// - `InconsistentCsvRow`, carrying the 1-based physical line number.
/// - `EmptyCsv` if no numeric row is found.
fn load_csv_data(path: &str, kind: &'static str) -> crate::error::Result<GradientData> {
    let csv_path = Path::new(path);
    if !csv_path.exists() {
        return Err(BrainHarmonyError::FileNotFound {
            kind,
            path: csv_path.to_path_buf(),
        });
    }

    let text = std::fs::read_to_string(csv_path)?;
    let mut flat: Vec<f32> = Vec::new();
    let mut rows = 0usize;
    // `None` until the first numeric row establishes the column count.
    let mut width: Option<usize> = None;

    for (idx, raw_line) in text.lines().enumerate() {
        let trimmed = raw_line.trim();
        let skippable = trimmed.is_empty() || trimmed.starts_with('#');
        if skippable {
            continue;
        }

        let cells: Vec<f32> = trimmed
            .split(',')
            .map(str::trim)
            .filter_map(|cell| cell.parse::<f32>().ok())
            .collect();
        if cells.is_empty() {
            continue;
        }

        match width {
            None => width = Some(cells.len()),
            Some(expected) if expected != cells.len() => {
                return Err(BrainHarmonyError::InconsistentCsvRow {
                    path: csv_path.to_path_buf(),
                    row: idx + 1,
                    expected,
                    got: cells.len(),
                });
            }
            Some(_) => {}
        }

        flat.extend(cells);
        rows += 1;
    }

    match width {
        Some(grad_dim) => Ok(GradientData {
            values: flat,
            n_rois: rows,
            grad_dim,
        }),
        None => Err(BrainHarmonyError::EmptyCsv {
            path: csv_path.to_path_buf(),
        }),
    }
}