use std::path::Path;
use std::time::Instant;
use burn::prelude::*;
use crate::config::{DataConfig, ModelConfig};
use crate::data::{self, GradientData, GeohData, SignalInput};
use crate::error::BrainHarmonyError;
use crate::model::encoder::FlexVisionTransformer;
use crate::weights::{load_encoder_weights, WeightMap};
/// Output of a single encode pass: flattened patch embeddings plus the
/// metadata needed to reinterpret and save them.
pub struct EmbeddingResult {
/// Flattened f32 embedding values, laid out to match `shape`.
pub embeddings: Vec<f32>,
/// Logical tensor shape of `embeddings` — `[n_patches, embed_dim]` as
/// produced by `encode_input`.
pub shape: Vec<usize>,
/// Number of ROIs in the input signal.
pub n_rois: usize,
/// Number of temporal patches (signal_length / patch_size).
pub n_time_patches: usize,
/// Encoder forward-pass wall time in milliseconds.
pub ms_encode: f64,
}
impl EmbeddingResult {
    /// Total number of patch tokens (ROIs × time patches).
    pub fn n_patches(&self) -> usize {
        self.n_rois * self.n_time_patches
    }

    /// Embedding dimensionality, read from `shape[1]`; returns 0 when the
    /// shape has fewer than two entries.
    pub fn embed_dim(&self) -> usize {
        self.shape.get(1).copied().unwrap_or(0)
    }

    /// Write the embeddings to `path` as a safetensors file containing a
    /// single little-endian f32 tensor named `"embeddings"`.
    ///
    /// # Errors
    /// Returns an error if safetensors serialization or the file write fails.
    pub fn save_safetensors(&self, path: &str) -> anyhow::Result<()> {
        use safetensors::{Dtype, View};
        use std::borrow::Cow;

        // Minimal `View` adapter over a raw little-endian f32 byte buffer,
        // so we can serialize without constructing a framework tensor.
        struct RawTensor {
            data: Vec<u8>,
            shape: Vec<usize>,
        }
        impl View for RawTensor {
            fn dtype(&self) -> Dtype { Dtype::F32 }
            fn shape(&self) -> &[usize] { &self.shape }
            fn data(&self) -> Cow<'_, [u8]> { Cow::Borrowed(&self.data) }
            fn data_len(&self) -> usize { self.data.len() }
        }

        // Preallocate the exact byte count: `flat_map(...).collect()` has a
        // zero lower size hint, so it would grow-and-copy repeatedly for
        // large embedding buffers.
        let mut bytes: Vec<u8> = Vec::with_capacity(self.embeddings.len() * 4);
        for value in &self.embeddings {
            bytes.extend_from_slice(&value.to_le_bytes());
        }

        let tensor = RawTensor {
            data: bytes,
            shape: self.shape.clone(),
        };
        let pairs: Vec<(&str, RawTensor)> = vec![("embeddings", tensor)];
        let out = safetensors::serialize(pairs, None)?;
        std::fs::write(path, out)?;
        Ok(())
    }
}
/// Inference wrapper bundling the Brain-Harmony transformer encoder with the
/// fixed per-ROI conditioning tensors it is always called with.
pub struct BrainHarmonyEncoder<B: Backend> {
/// The loaded vision-transformer encoder.
encoder: FlexVisionTransformer<B>,
/// Per-ROI gradient features loaded from CSV.
gradient: Tensor<B, 2>,
/// Per-ROI geoh features loaded from CSV.
geoh: Tensor<B, 2>,
/// Model hyperparameters the encoder was constructed from.
pub model_cfg: ModelConfig,
/// Data-shape configuration (ROI count, signal size, ...).
pub data_cfg: DataConfig,
/// Device all of this encoder's tensors are allocated on.
device: B::Device,
}
impl<B: Backend> BrainHarmonyEncoder<B> {
/// Build an encoder from a safetensors weight file plus gradient / geoh CSV
/// sidecars. Returns the ready encoder and the weight-load time in ms.
///
/// # Errors
/// * `FileNotFound` when `weights_path` does not exist.
/// * `GradientRoiMismatch` / `GeohRoiMismatch` when either CSV's ROI count
///   differs from `data_cfg.n_cortical_rois`.
/// * Any error surfaced by CSV parsing, model construction, or weight loading.
pub fn from_weights(
weights_path: &str,
gradient_csv_path: &str,
geoh_csv_path: &str,
model_cfg: &ModelConfig,
data_cfg: &DataConfig,
device: &B::Device,
) -> anyhow::Result<(Self, f64)> {
// Fail fast with a domain error rather than an opaque I/O error later.
if !Path::new(weights_path).exists() {
return Err(BrainHarmonyError::FileNotFound {
kind: "weights",
path: weights_path.into(),
}
.into());
}
// Load the anatomical side information and validate both ROI counts
// against the configured cortical ROI count before touching the model.
let grad_data = GradientData::from_csv(gradient_csv_path)?;
let geoh_data = GeohData::from_csv(geoh_csv_path)?;
let expected_rois = data_cfg.n_cortical_rois;
if grad_data.n_rois != expected_rois {
return Err(BrainHarmonyError::GradientRoiMismatch {
expected: expected_rois,
got: grad_data.n_rois,
}
.into());
}
if geoh_data.n_rois != expected_rois {
return Err(BrainHarmonyError::GeohRoiMismatch {
expected: expected_rois,
got: geoh_data.n_rois,
}
.into());
}
let gradient = grad_data.to_tensor::<B>(device);
let geoh = geoh_data.to_tensor::<B>(device);
// NOTE(review): the bare literals below are positional arguments —
// presumably 1 = input channels, true/false = boolean construction
// flags; confirm against FlexVisionTransformer::new's signature.
let mut encoder = FlexVisionTransformer::new(
data_cfg.signal_size,
model_cfg.patch_size,
1, model_cfg.embed_dim,
model_cfg.depth,
model_cfg.num_heads,
model_cfg.mlp_ratio,
true, model_cfg.norm_eps,
model_cfg.grad_dim,
model_cfg.geoh_dim,
model_cfg.pred_emb_dim,
&model_cfg.pos_mode,
model_cfg.use_cls_token,
false, device,
)?;
// Only weight-file parsing + loading is timed; model construction above
// is excluded from the returned duration.
let t = Instant::now();
let mut wm = WeightMap::from_file(weights_path)?;
// Checkpoints may namespace weights under "target_encoder" or plain
// "encoder"; probe one known key to pick the right prefix.
let prefix = if wm.has("target_encoder.blocks.0.norm1.weight") {
"target_encoder"
} else {
"encoder"
};
load_encoder_weights(model_cfg, &mut wm, &mut encoder, prefix, device)?;
let ms = t.elapsed().as_secs_f64() * 1000.0;
println!("Loaded encoder weights ({} remaining keys)", wm.remaining());
Ok((
Self {
encoder,
gradient,
geoh,
model_cfg: model_cfg.clone(),
data_cfg: data_cfg.clone(),
device: device.clone(),
},
ms,
))
}
/// Human-readable one-line summary of the encoder's key hyperparameters.
pub fn describe(&self) -> String {
    let cfg = &self.model_cfg;
    format!(
        "Brain-Harmony encoder embed_dim={} depth={} heads={} patch={}",
        cfg.embed_dim, cfg.depth, cfg.num_heads, cfg.patch_size,
    )
}
/// Load a signal from a safetensors file and encode it.
pub fn encode_safetensors(&self, path: &str) -> anyhow::Result<EmbeddingResult> {
    self.encode_input(data::load_signal_safetensors::<B>(path, &self.device)?)
}
/// Load a signal from a CSV file and encode it.
pub fn encode_csv(&self, csv_path: &str) -> anyhow::Result<EmbeddingResult> {
    self.encode_input(data::load_signal_csv::<B>(csv_path, &self.device)?)
}
/// Encode an already-loaded 4-D signal tensor (batch, channel, roi, time).
pub fn encode_tensor(&self, data: Tensor<B, 4>) -> anyhow::Result<EmbeddingResult> {
    let dims = data.dims();
    let input = SignalInput {
        data,
        n_rois: dims[2],
        signal_length: dims[3],
    };
    self.encode_input(input)
}
/// Encode several safetensors files, stopping at the first failure.
pub fn encode_safetensors_batch(
    &self,
    paths: &[impl AsRef<str>],
) -> anyhow::Result<Vec<EmbeddingResult>> {
    let mut results = Vec::with_capacity(paths.len());
    for path in paths {
        let input = data::load_signal_safetensors::<B>(path.as_ref(), &self.device)?;
        results.push(self.encode_input(input)?);
    }
    Ok(results)
}
/// Encode several CSV files, stopping at the first failure.
pub fn encode_csv_batch(
    &self,
    paths: &[impl AsRef<str>],
) -> anyhow::Result<Vec<EmbeddingResult>> {
    let mut results = Vec::with_capacity(paths.len());
    for path in paths {
        let input = data::load_signal_csv::<B>(path.as_ref(), &self.device)?;
        results.push(self.encode_input(input)?);
    }
    Ok(results)
}
/// The device this encoder's tensors are allocated on.
pub fn device(&self) -> &B::Device {
&self.device
}
/// Standardize the input signal, run the encoder forward pass with the
/// fixed gradient/geoh conditioning, and copy the embeddings to host f32.
fn encode_input(&self, input: SignalInput<B>) -> anyhow::Result<EmbeddingResult> {
let x = data::standardize(input.data);
// Truncating integer division: any signal tail shorter than patch_size
// is presumably dropped by the encoder as well — TODO confirm.
let n_time_patches = input.signal_length / self.model_cfg.patch_size;
let t = Instant::now();
// The three trailing None arguments are optional encoder inputs that
// this path never supplies.
let enc_out = self.encoder.forward(
x,
Some(&self.gradient),
Some(&self.geoh),
None, None, None, );
// Wall-clock encode time only — excludes standardization above and the
// device-to-host copy below.
let ms_encode = t.elapsed().as_secs_f64() * 1000.0;
let [_b, n_patches, embed_dim] = enc_out.dims();
// NOTE(review): squeeze::<2>() collapses the batch dimension, so this
// assumes a batch of exactly 1 (larger batches would likely fail here),
// and the batch size is dropped from `shape` — confirm intended.
let embeddings = tensor_data_to_f32(enc_out.squeeze::<2>().into_data())
.map_err(|e| BrainHarmonyError::TensorConversion { reason: e })?;
Ok(EmbeddingResult {
embeddings,
shape: vec![n_patches, embed_dim],
n_rois: input.n_rois,
n_time_patches,
ms_encode,
})
}
}
/// Convert Burn tensor data into a flat `Vec<f32>`.
///
/// Tries, in order:
/// 1. a direct f32 read,
/// 2. Burn's own dtype conversion to f32,
/// 3. a manual half-precision decode — gated on the payload actually being
///    f16 or bf16.
///
/// The previous fallback decoded ANY even-length byte buffer as f16, which
/// silently produced garbage for other 2-byte dtypes (e.g. bf16 or i16)
/// whenever the conversion above failed; the dtype check fixes that and a
/// dedicated bf16 branch decodes brain-float data correctly.
///
/// # Errors
/// Returns a message describing the unconvertible payload (size and dtype).
fn tensor_data_to_f32(data: burn::tensor::TensorData) -> Result<Vec<f32>, String> {
    use burn::tensor::DType;

    if let Ok(v) = data.to_vec::<f32>() {
        return Ok(v);
    }
    let converted = data.clone().convert::<f32>();
    if let Ok(v) = converted.to_vec::<f32>() {
        return Ok(v);
    }
    // Manual half-precision fallback: only valid when the buffer really
    // holds 2-byte half floats.
    let bytes = &data.bytes;
    if bytes.len() % 2 == 0 {
        match data.dtype {
            DType::F16 => {
                return Ok(bytes
                    .chunks_exact(2)
                    .map(|c| half::f16::from_le_bytes([c[0], c[1]]).to_f32())
                    .collect());
            }
            DType::BF16 => {
                return Ok(bytes
                    .chunks_exact(2)
                    .map(|c| half::bf16::from_le_bytes([c[0], c[1]]).to_f32())
                    .collect());
            }
            _ => {}
        }
    }
    Err(format!(
        "cannot convert tensor data ({} bytes, dtype {:?}) to f32",
        bytes.len(),
        data.dtype
    ))
}