brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
//! # brainharmony-rs — Brain-Harmony multimodal brain foundation model inference in Rust
//!
//! Pure-Rust inference for the [Brain-Harmony](https://github.com/eugenehp/Brain-Harmony)
//! multimodal brain foundation model, built on [Burn 0.20](https://burn.dev).
//!
//! Brain-Harmony unifies morphology (T1 MRI) and function (fMRI) into 1D tokens
//! using a Vision Transformer with:
//! - **Brain gradient + geometric harmonics positioning** for spatial embeddings
//! - **Flexible patch embedding** via Conv2d with dynamic patch size
//! - **JEPA architecture** (encoder + predictor with momentum target)
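/*!

The momentum target named in the last bullet is commonly maintained as an
exponential moving average (EMA) of the online encoder's weights. A minimal
std-only sketch of that update, assuming flat `f32` parameter slices. The names
`ema_update` and `momentum` are hypothetical, are not part of this crate, and the
upstream training code may differ:

```rust
/// Hypothetical helper (not part of this crate): in-place EMA update,
/// target <- momentum * target + (1 - momentum) * online.
fn ema_update(target: &mut [f32], online: &[f32], momentum: f32) {
    assert_eq!(target.len(), online.len(), "parameter counts must match");
    for (t, o) in target.iter_mut().zip(online) {
        // Blend the old target weight with the current online weight.
        *t = momentum * *t + (1.0 - momentum) * *o;
    }
}
```

With `momentum` close to 1 the target weights drift slowly toward the online
weights, which is the usual motivation for this kind of update.
*/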
//!
//! ## Three entry points
//!
//! | Type | Loads | Use case |
//! |---|---|---|
//! | [`BrainHarmonyEncoder`] | encoder only | produce latent embeddings |
//! | [`ClassificationHead`] | classification layer | downstream classification |
//! | [`MLPHead`] | 3-layer MLP head | stage 2 finetuning |
//!
//! ## Quick start — encode brain signal
//!
//! ```rust,ignore
//! use brainharmony::{BrainHarmonyEncoder, ModelConfig, DataConfig};
//!
//! // `B` is the caller's Burn backend type and `device` the matching device handle.
//! let (enc, ms) = BrainHarmonyEncoder::<B>::from_weights(
//!     "model.safetensors",
//!     "gradient_mapping_400.csv",
//!     "schaefer400_roi_eigenmodes.csv",
//!     &ModelConfig::default(),
//!     &DataConfig::default(),
//!     &device,
//! )?;
//! let result = enc.encode_safetensors("data/signal.safetensors")?;
//! result.save_safetensors("embeddings.safetensors")?;
//! ```
//!
//! ## Backends
//!
//! | Feature | Backend | Notes |
//! |---|---|---|
//! | `ndarray` (default) | CPU (NdArray + Rayon) | Add `blas-accelerate` on macOS |
//! | `wgpu` | GPU (Metal / Vulkan) | `--no-default-features --features wgpu` |
//! | `wgpu-f16` | GPU (half precision) | `--no-default-features --features wgpu-f16` |

// -- Thread configuration ---------------------------------------------------------

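/// Configure the global Rayon thread pool.
///
/// Pass `Some(n)` with `n > 0` to request `n` threads. `None` or `Some(0)`
/// keeps Rayon's default thread count.
///
/// Call this **once**, before any model operations. If the global pool has
/// already been initialised, the build error is ignored and the existing
/// pool is left unchanged.
///
/// Returns the actual number of threads in the pool.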
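/**

The `Option<usize>` convention used here can be illustrated without Rayon. A
minimal std-only sketch, assuming the default is the machine's available
parallelism (Rayon's real default also honours the `RAYON_NUM_THREADS`
environment variable). `resolve_thread_count` is a hypothetical helper and is
not part of this crate:

```rust
use std::thread;

/// Hypothetical helper: a positive request wins; `None` or `Some(0)`
/// falls back to the available parallelism, or 1 if it cannot be queried.
fn resolve_thread_count(requested: Option<usize>) -> usize {
    match requested {
        Some(n) if n > 0 => n,
        _ => thread::available_parallelism().map(|n| n.get()).unwrap_or(1),
    }
}
```

For example, `resolve_thread_count(Some(4))` yields 4, while `None` and
`Some(0)` both yield the fallback value.
*/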
pub fn init_threads(n: Option<usize>) -> usize {
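    let mut builder = rayon::ThreadPoolBuilder::new();
    // `None` and `Some(0)` both fall through to Rayon's default thread count.
    if let Some(count) = n.filter(|&c| c > 0) {
        builder = builder.num_threads(count);
    }
    // Ignore the error returned when the global pool was already initialised.
    let _ = builder.build_global();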
    rayon::current_num_threads()
}

// -- Public modules ---------------------------------------------------------------

pub mod classification;
pub mod config;
pub mod csv_export;
pub mod data;
pub mod error;
pub mod hf_download;
pub mod inference;
pub mod masks;
pub mod model;
pub mod predictor_api;
pub mod prelude;
pub mod weights;

// -- Flat re-exports --------------------------------------------------------------

// Configs
pub use config::{DataConfig, ModelConfig, YamlConfig};

// Encoder-only inference
pub use inference::{BrainHarmonyEncoder, EmbeddingResult};

// Encoder + Predictor (JEPA evaluation)
pub use predictor_api::BrainHarmonyPredictor;

// Classification heads
pub use classification::{ClassificationHead, MLPHead, predict_classes};

// Data types
pub use data::{GradientData, GeohData, SignalInput};

// Masking
pub use masks::{MaskConfig, full_context_mask, jepa_masks};

// Errors
pub use error::{BrainHarmonyError, Result};

// Model internals (advanced usage)
pub use model::encoder::apply_masks;

// Weights
pub use weights::{WeightFilter, WeightMap};

// CSV export
pub use csv_export::save_embeddings_csv;

// HuggingFace download
pub use hf_download::{resolve as resolve_weights, ResolvedWeights, DEFAULT_REPO};