brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
//! Brain-Harmony encoder inference — produce brain embeddings.
//!
//! The encoder maps brain signals to latent representations using the
//! FlexVisionTransformer with brain gradient + geometric harmonics
//! positional embeddings.
use std::path::Path;
use std::time::Instant;

use burn::prelude::*;

use crate::config::{DataConfig, ModelConfig};
use crate::data::{self, GradientData, GeohData, SignalInput};
use crate::error::BrainHarmonyError;
use crate::model::encoder::FlexVisionTransformer;
use crate::weights::{load_encoder_weights, WeightMap};

// -- Output types -----------------------------------------------------------------

/// Encoder embedding output.
/// Encoder embedding output.
///
/// `embeddings` is stored flat in row-major order; patch `p`, dimension `d`
/// lives at `embeddings[p * embed_dim + d]`.
pub struct EmbeddingResult {
    /// Latent embeddings: row-major f32, shape [n_patches, embed_dim]
    pub embeddings: Vec<f32>,
    /// Shape: [n_patches, embed_dim]
    pub shape: Vec<usize>,
    /// Number of ROI patches
    pub n_rois: usize,
    /// Number of temporal patches
    pub n_time_patches: usize,
    /// Encoding time in milliseconds
    pub ms_encode: f64,
}

impl EmbeddingResult {
    /// Total number of output patches (n_rois * n_time_patches).
    pub fn n_patches(&self) -> usize {
        self.n_rois * self.n_time_patches
    }

    /// Output embedding dimension (0 if `shape` has fewer than two entries).
    pub fn embed_dim(&self) -> usize {
        self.shape.get(1).copied().unwrap_or(0)
    }

    /// Save embeddings to a safetensors file under the key `"embeddings"`.
    ///
    /// # Errors
    /// Returns an error if serialization fails or the file cannot be written.
    pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
        use safetensors::{Dtype, View};
        use std::borrow::Cow;

        // Minimal `View` impl so we can hand safetensors raw little-endian
        // f32 bytes without going through a full tensor type.
        struct RawTensor {
            data: Vec<u8>,
            shape: Vec<usize>,
        }
        impl View for RawTensor {
            fn dtype(&self) -> Dtype { Dtype::F32 }
            fn shape(&self) -> &[usize] { &self.shape }
            fn data(&self) -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
            fn data_len(&self) -> usize { self.data.len() }
        }

        // Exact size is known up front (4 bytes per f32); preallocate to
        // avoid repeated Vec growth during the byte conversion.
        let mut bytes = Vec::with_capacity(self.embeddings.len() * 4);
        for f in &self.embeddings {
            bytes.extend_from_slice(&f.to_le_bytes());
        }
        let tensor = RawTensor {
            data: bytes,
            shape: self.shape.clone(),
        };

        let pairs: Vec<(&str, RawTensor)> = vec![("embeddings", tensor)];
        let out = safetensors::serialize(pairs, None)?;
        std::fs::write(path, out)?;
        Ok(())
    }
}

// -- BrainHarmonyEncoder ----------------------------------------------------------

/// Brain-Harmony encoder for producing latent embeddings from brain signals.
/// Brain-Harmony encoder for producing latent embeddings from brain signals.
pub struct BrainHarmonyEncoder<B: Backend> {
    /// Backbone transformer that maps patched signals to embeddings.
    encoder: FlexVisionTransformer<B>,
    /// Brain-gradient positional features; ROI count is validated against
    /// `data_cfg.n_cortical_rois` at load time.
    gradient: Tensor<B, 2>,
    /// Geometric-harmonics positional features; ROI count validated likewise.
    geoh: Tensor<B, 2>,
    /// Model hyperparameters the encoder was built with.
    pub model_cfg: ModelConfig,
    /// Data configuration (ROI count, signal size, ...).
    pub data_cfg: DataConfig,
    /// Burn device all tensors live on.
    device: B::Device,
}

impl<B: Backend> BrainHarmonyEncoder<B> {
    /// Load encoder from safetensors weights, gradient CSV, and geometric harmonics CSV.
    ///
    /// Both positional-embedding tables are validated against
    /// `data_cfg.n_cortical_rois` before the transformer is built.
    ///
    /// Returns `(encoder, weight_load_ms)`.
    ///
    /// # Errors
    /// Fails if the weights file is missing, a CSV cannot be parsed, the ROI
    /// counts mismatch, or transformer construction / weight loading fails.
    pub fn from_weights(
        weights_path: &str,
        gradient_csv_path: &str,
        geoh_csv_path: &str,
        model_cfg: &ModelConfig,
        data_cfg: &DataConfig,
        device: &B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        if !Path::new(weights_path).exists() {
            return Err(BrainHarmonyError::FileNotFound {
                kind: "weights",
                path: weights_path.into(),
            }
            .into());
        }

        let grad_data = GradientData::from_csv(gradient_csv_path)?;
        let geoh_data = GeohData::from_csv(geoh_csv_path)?;

        // Both positional tables must cover exactly the configured ROI set.
        let expected_rois = data_cfg.n_cortical_rois;
        if grad_data.n_rois != expected_rois {
            return Err(BrainHarmonyError::GradientRoiMismatch {
                expected: expected_rois,
                got: grad_data.n_rois,
            }
            .into());
        }
        if geoh_data.n_rois != expected_rois {
            return Err(BrainHarmonyError::GeohRoiMismatch {
                expected: expected_rois,
                got: geoh_data.n_rois,
            }
            .into());
        }

        let gradient = grad_data.to_tensor::<B>(device);
        let geoh = geoh_data.to_tensor::<B>(device);

        let mut encoder = FlexVisionTransformer::new(
            data_cfg.signal_size,
            model_cfg.patch_size,
            1, // in_chans: signals are single-channel
            model_cfg.embed_dim,
            model_cfg.depth,
            model_cfg.num_heads,
            model_cfg.mlp_ratio,
            true, // qkv_bias
            model_cfg.norm_eps,
            model_cfg.grad_dim,
            model_cfg.geoh_dim,
            model_cfg.pred_emb_dim,
            &model_cfg.pos_mode,
            model_cfg.use_cls_token,
            false, // use_decoder for encoder-only mode
            device,
        )?;

        let t = Instant::now();
        let mut wm = WeightMap::from_file(weights_path)?;
        // Checkpoints may store weights under either prefix; probe a key
        // present in every encoder variant to pick the right one.
        let prefix = if wm.has("target_encoder.blocks.0.norm1.weight") {
            "target_encoder"
        } else {
            "encoder"
        };
        load_encoder_weights(model_cfg, &mut wm, &mut encoder, prefix, device)?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;

        println!("Loaded encoder weights ({} remaining keys)", wm.remaining());

        Ok((
            Self {
                encoder,
                gradient,
                geoh,
                model_cfg: model_cfg.clone(),
                data_cfg: data_cfg.clone(),
                device: device.clone(),
            },
            ms,
        ))
    }

    /// One-line description of the loaded encoder.
    pub fn describe(&self) -> String {
        format!(
            "Brain-Harmony encoder  embed_dim={}  depth={}  heads={}  patch={}",
            self.model_cfg.embed_dim,
            self.model_cfg.depth,
            self.model_cfg.num_heads,
            self.model_cfg.patch_size,
        )
    }

    /// Encode signal from a safetensors file.
    pub fn encode_safetensors(&self, path: &str) -> anyhow::Result<EmbeddingResult> {
        let input = data::load_signal_safetensors::<B>(path, &self.device)?;
        self.encode_input(input)
    }

    /// Encode signal from a CSV file.
    pub fn encode_csv(&self, csv_path: &str) -> anyhow::Result<EmbeddingResult> {
        let input = data::load_signal_csv::<B>(csv_path, &self.device)?;
        self.encode_input(input)
    }

    /// Encode a raw tensor input of shape `[batch, channels, n_rois, signal_length]`.
    /// Only batch size 1 is supported (see `encode_input`).
    pub fn encode_tensor(&self, data: Tensor<B, 4>) -> anyhow::Result<EmbeddingResult> {
        let [_b, _c, n_rois, signal_length] = data.dims();
        let input = SignalInput { data, n_rois, signal_length };
        self.encode_input(input)
    }

    /// Encode multiple safetensors files; fails fast on the first error.
    pub fn encode_safetensors_batch(
        &self,
        paths: &[impl AsRef<str>],
    ) -> anyhow::Result<Vec<EmbeddingResult>> {
        // Delegate to the single-file path so loading logic lives in one place.
        paths
            .iter()
            .map(|p| self.encode_safetensors(p.as_ref()))
            .collect()
    }

    /// Encode multiple CSV files; fails fast on the first error.
    pub fn encode_csv_batch(
        &self,
        paths: &[impl AsRef<str>],
    ) -> anyhow::Result<Vec<EmbeddingResult>> {
        paths
            .iter()
            .map(|p| self.encode_csv(p.as_ref()))
            .collect()
    }

    /// Reference to the Burn device this encoder was loaded on.
    pub fn device(&self) -> &B::Device {
        &self.device
    }

    /// Standardize the signal, run the transformer, and flatten the output.
    fn encode_input(&self, input: SignalInput<B>) -> anyhow::Result<EmbeddingResult> {
        let x = data::standardize(input.data);

        // Integer division: a trailing partial patch is dropped here —
        // presumably matching the model's patchifier; TODO(review) confirm.
        let n_time_patches = input.signal_length / self.model_cfg.patch_size;

        let t = Instant::now();
        let enc_out = self.encoder.forward(
            x,
            Some(&self.gradient),
            Some(&self.geoh),
            None, // use default patch size
            None, // no masks
            None, // no attention mask
        );
        let ms_encode = t.elapsed().as_secs_f64() * 1000.0;

        let [b, n_patches, embed_dim] = enc_out.dims();
        // squeeze::<2>() below can only collapse a size-1 dimension; with a
        // larger batch it would panic deep inside Burn, so fail cleanly here.
        if b != 1 {
            anyhow::bail!("encode_input expects batch size 1, got {b}");
        }

        let embeddings = tensor_data_to_f32(enc_out.squeeze::<2>().into_data())
            .map_err(|e| BrainHarmonyError::TensorConversion { reason: e })?;

        Ok(EmbeddingResult {
            embeddings,
            shape: vec![n_patches, embed_dim],
            n_rois: input.n_rois,
            n_time_patches,
            ms_encode,
        })
    }
}

/// Convert TensorData bytes to Vec<f32>, handling both f32 and f16 element types.
/// Convert TensorData bytes to Vec<f32>, handling f32, f16, and bf16 element types.
///
/// # Errors
/// Returns a descriptive message when the dtype is neither f32-convertible
/// by Burn nor a half-precision type we can decode manually.
fn tensor_data_to_f32(data: burn::tensor::TensorData) -> Result<Vec<f32>, String> {
    use burn::tensor::DType;

    // Fast path: payload is already f32.
    if let Ok(v) = data.to_vec::<f32>() {
        return Ok(v);
    }
    // Let Burn convert when it knows how (covers most numeric dtypes).
    let converted = data.clone().convert::<f32>();
    if let Ok(v) = converted.to_vec::<f32>() {
        return Ok(v);
    }
    // Manual fallback for half-precision payloads. Gate on the declared
    // dtype instead of the previous even-byte-length heuristic, which would
    // silently misinterpret e.g. f64 or integer data as f16.
    match data.dtype {
        DType::F16 => Ok(data
            .bytes
            .chunks_exact(2)
            .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
            .collect()),
        DType::BF16 => Ok(data
            .bytes
            .chunks_exact(2)
            .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
            .collect()),
        other => Err(format!(
            "cannot convert tensor data (dtype {other:?}, {} bytes) to f32",
            data.bytes.len()
        )),
    }
}