use std::time::Instant;
use burn::prelude::*;
use crate::config::{DataConfig, ModelConfig};
use crate::data::{GradientData, GeohData};
use crate::error::BrainHarmonyError;
use crate::model::decoder::VisionTransformerPredictor;
use crate::model::encoder::FlexVisionTransformer;
use crate::weights::{load_encoder_weights, load_predictor_weights, WeightMap};
/// Inference-time bundle for the Brain-Harmony model: a context encoder,
/// a predictor head, and the per-ROI conditioning tensors they both consume.
pub struct BrainHarmonyPredictor<B: Backend> {
    /// Context encoder producing token embeddings from the input signal.
    pub encoder: FlexVisionTransformer<B>,
    /// Predictor that maps encoder output to target-token embeddings.
    pub predictor: VisionTransformerPredictor<B>,
    /// Gradient conditioning table loaded from CSV; rank-2 tensor
    /// (rows presumably indexed by ROI — confirm against GradientData::to_tensor).
    pub gradient: Tensor<B, 2>,
    /// Geometric-harmonics conditioning table loaded from CSV; rank-2 tensor.
    pub geoh: Tensor<B, 2>,
    /// Architecture hyperparameters the modules were built from.
    pub model_cfg: ModelConfig,
    /// Data-shape configuration (e.g. `n_cortical_rois`, `signal_size`).
    pub data_cfg: DataConfig,
    // Kept private so callers go through `device()`.
    device: B::Device,
}
impl<B: Backend> BrainHarmonyPredictor<B> {
    /// Constructs fresh encoder/predictor modules and loads their parameters
    /// from the serialized weight file at `weights_path`.
    ///
    /// Returns the assembled predictor together with the weight-loading time
    /// in milliseconds (module construction is *not* included in the timing).
    ///
    /// # Errors
    /// * [`BrainHarmonyError::FileNotFound`] if `weights_path` does not exist.
    /// * [`BrainHarmonyError::GradientRoiMismatch`] /
    ///   [`BrainHarmonyError::GeohRoiMismatch`] if either CSV's ROI count
    ///   differs from `data_cfg.n_cortical_rois`.
    /// * Anything propagated from CSV parsing, module construction, or
    ///   weight loading.
    pub fn from_weights(
        weights_path: &str,
        gradient_csv_path: &str,
        geoh_csv_path: &str,
        model_cfg: &ModelConfig,
        data_cfg: &DataConfig,
        device: &B::Device,
    ) -> anyhow::Result<(Self, f64)> {
        // Fail fast with a domain-specific error before any parsing work.
        if !std::path::Path::new(weights_path).exists() {
            return Err(BrainHarmonyError::FileNotFound {
                kind: "weights",
                path: weights_path.into(),
            }
            .into());
        }
        let grad_data = GradientData::from_csv(gradient_csv_path)?;
        let geoh_data = GeohData::from_csv(geoh_csv_path)?;
        // Both conditioning tables must cover exactly the configured cortical ROIs.
        let expected_rois = data_cfg.n_cortical_rois;
        if grad_data.n_rois != expected_rois {
            return Err(BrainHarmonyError::GradientRoiMismatch {
                expected: expected_rois,
                got: grad_data.n_rois,
            }
            .into());
        }
        if geoh_data.n_rois != expected_rois {
            return Err(BrainHarmonyError::GeohRoiMismatch {
                expected: expected_rois,
                got: geoh_data.n_rois,
            }
            .into());
        }
        let gradient = grad_data.to_tensor::<B>(device);
        let geoh = geoh_data.to_tensor::<B>(device);
        // Positional arguments follow FlexVisionTransformer::new exactly; the
        // bare literals below (`1`, `true`) encode fixed choices whose meaning
        // is defined by that constructor — TODO confirm against its signature.
        let mut encoder = FlexVisionTransformer::new(
            data_cfg.signal_size,
            model_cfg.patch_size,
            1,
            model_cfg.embed_dim,
            model_cfg.depth,
            model_cfg.num_heads,
            model_cfg.mlp_ratio,
            true,
            model_cfg.norm_eps,
            model_cfg.grad_dim,
            model_cfg.geoh_dim,
            model_cfg.pred_emb_dim,
            &model_cfg.pos_mode,
            model_cfg.use_cls_token,
            true, device,
        )?;
        // Predictor's positional-embedding grid must match the encoder's patching.
        let num_patches_2d = encoder.patch_embed.num_patches_2d;
        let mut predictor = VisionTransformerPredictor::new(
            num_patches_2d,
            model_cfg.embed_dim,
            model_cfg.pred_emb_dim,
            model_cfg.pred_depth,
            model_cfg.num_heads,
            model_cfg.mlp_ratio,
            true,
            model_cfg.norm_eps,
            model_cfg.grad_dim,
            model_cfg.geoh_dim,
            &model_cfg.pos_mode,
            model_cfg.use_cls_token,
            device,
        )?;
        // Time only the weight I/O + parameter copy, not module construction.
        let t = Instant::now();
        let mut wm = WeightMap::from_file(weights_path)?;
        // Some checkpoints store the encoder under `target_encoder` (presumably
        // an EMA copy — confirm); probe a known key to pick the right prefix.
        let enc_prefix = if wm.has("target_encoder.blocks.0.norm1.weight") {
            "target_encoder"
        } else {
            "encoder"
        };
        load_encoder_weights(model_cfg, &mut wm, &mut encoder, enc_prefix, device)?;
        load_predictor_weights(model_cfg, &mut wm, &mut predictor, "predictor", device)?;
        let ms = t.elapsed().as_secs_f64() * 1000.0;
        println!(
            "Loaded encoder ({enc_prefix}) + predictor weights ({} remaining keys)",
            wm.remaining()
        );
        Ok((
            Self {
                encoder,
                predictor,
                gradient,
                geoh,
                model_cfg: model_cfg.clone(),
                data_cfg: data_cfg.clone(),
                device: device.clone(),
            },
            ms,
        ))
    }
pub fn predict(
&self,
x: Tensor<B, 4>,
enc_masks: &[Tensor<B, 2, Int>],
pred_masks: &[Tensor<B, 2, Int>],
) -> (Tensor<B, 3>, Tensor<B, 3>) {
let enc_out = self.encoder.forward(
x,
Some(&self.gradient),
Some(&self.geoh),
None,
Some(enc_masks),
None,
);
let pred_out = self.predictor.forward(
enc_out.clone(),
Some(&self.gradient),
Some(&self.geoh),
enc_masks,
pred_masks,
);
(enc_out, pred_out)
}
pub fn encode(&self, x: Tensor<B, 4>) -> Tensor<B, 3> {
self.encoder.forward(
x,
Some(&self.gradient),
Some(&self.geoh),
None,
None,
None,
)
}
pub fn describe(&self) -> String {
format!(
"Brain-Harmony encoder: {}-dim x {} layers predictor: {}-dim x {} layers",
self.model_cfg.embed_dim,
self.model_cfg.depth,
self.model_cfg.pred_emb_dim,
self.model_cfg.pred_depth,
)
}
    /// Borrows the backend device this predictor was constructed on.
    pub fn device(&self) -> &B::Device {
        &self.device
    }
}