/// Hyper-parameters for sampling masks over a grid of `n_rois` ROIs by `n_time_patches`
/// time patches.
///
/// Grid cells are addressed by the flat index `roi * n_time_patches + time_patch`.
#[derive(Debug, Clone, PartialEq)]
pub struct MaskConfig {
    /// Number of ROIs (grid rows).
    pub n_rois: usize,
    /// Number of time patches per ROI (grid columns).
    pub n_time_patches: usize,
    /// `(lo, hi)` range from which the context (encoder) block's ROI fraction and time
    /// fraction are each sampled uniformly.
    pub enc_mask_scale: (f64, f64),
    /// `(lo, hi)` range for the ROI fraction used to size each predictor mask.
    pub pred_mask_r_scale: (f64, f64),
    /// `(lo, hi)` range for the time fraction used to size each predictor mask.
    pub pred_mask_t_scale: (f64, f64),
    /// Minimum number of cells requested for each sampled mask.
    pub min_keep: usize,
    /// Optional RNG seed; when `Some`, the mask sampler reseeds `fastrand` before sampling.
    pub seed: Option<u64>,
}
impl Default for MaskConfig {
fn default() -> Self {
Self {
n_rois: 450,
n_time_patches: 10,
enc_mask_scale: (0.84, 1.0),
pred_mask_r_scale: (0.45, 0.6),
pred_mask_t_scale: (0.0, 0.4),
min_keep: 4,
seed: None,
}
}
}
/// Builds a [`MaskConfig`] for an `n_rois` by `n_time_patches` grid, leaving every other field
/// at its `Default` value.
///
/// The seed comes from the `BRAINJEPA_MASK_SEED` environment variable, parsed as a `u64`.
/// If the variable is missing, not valid Unicode, or not a valid `u64`, it silently falls
/// back to `42`. The returned seed is therefore always `Some`, which makes sampling with
/// this config deterministic.
pub fn mask_config_for(n_rois: usize, n_time_patches: usize) -> MaskConfig {
    const FALLBACK_SEED: u64 = 42;
    let chosen_seed = match std::env::var("BRAINJEPA_MASK_SEED") {
        Ok(raw) => raw.parse::<u64>().unwrap_or(FALLBACK_SEED),
        Err(_) => FALLBACK_SEED,
    };
    MaskConfig {
        seed: Some(chosen_seed),
        n_time_patches,
        n_rois,
        ..MaskConfig::default()
    }
}
/// Returns every flat index `0..n_rois * n_time_patches` in ascending order, i.e. a mask
/// that selects every cell of the grid.
pub fn full_context_mask(n_rois: usize, n_time_patches: usize) -> Vec<i64> {
    let cell_count = n_rois * n_time_patches;
    let mut every_cell = Vec::with_capacity(cell_count);
    every_cell.extend((0..cell_count).map(|i| i as i64));
    every_cell
}
/// Samples a random rectangular block of the `n_rois` by `n_time_patches` grid.
///
/// Returns the block's flat indices (`roi * n_time_patches + time_patch`), sorted ascending.
///
/// The block spans `round(n_rois * roi_frac)` ROIs and `round(n_time_patches * time_frac)`
/// time patches. Each side is clamped to `[1, grid extent]`, and the block is placed at a
/// uniformly random position that keeps it inside the grid.
///
/// If the block has fewer than `min_keep` cells, extra distinct cells are drawn uniformly
/// from the whole grid until `min_keep` is reached. The target is capped at the grid size,
/// so an oversized `min_keep` cannot loop forever.
///
/// Returns an empty vector when the grid is empty (`n_rois == 0` or `n_time_patches == 0`).
pub fn random_block_mask(
    n_rois: usize,
    n_time_patches: usize,
    roi_frac: f64,
    time_frac: f64,
    min_keep: usize,
) -> Vec<i64> {
    let total = n_rois * n_time_patches;
    if total == 0 {
        // Nothing to sample. This also avoids drawing from the empty range `..0`, which
        // fastrand documents as a panic.
        return Vec::new();
    }

    // Clamp side lengths to the grid so fractions above 1.0 cannot produce indices past
    // the end of the grid. For fractions in [0, 1] this matches the previous `.max(1)`.
    let n_r = ((n_rois as f64 * roi_frac).round() as usize).clamp(1, n_rois);
    let n_t = ((n_time_patches as f64 * time_frac).round() as usize).clamp(1, n_time_patches);
    let r_start = fastrand::usize(..=(n_rois - n_r));
    let t_start = fastrand::usize(..=(n_time_patches - n_t));

    let target = min_keep.min(total);
    let mut indices = Vec::with_capacity((n_r * n_t).max(target));
    for r in r_start..r_start + n_r {
        let row_base = r * n_time_patches;
        indices.extend((t_start..t_start + n_t).map(|t| (row_base + t) as i64));
    }

    // Top up with distinct random cells. `target <= total` guarantees the loop can finish.
    while indices.len() < target {
        let idx = fastrand::usize(..total) as i64;
        if !indices.contains(&idx) {
            indices.push(idx);
        }
    }

    indices.sort_unstable();
    indices
}
/// Samples one context (encoder) mask and three predictor masks over the
/// `cfg.n_rois` by `cfg.n_time_patches` grid of flat indices.
///
/// * The context mask is a single rectangular block from [`random_block_mask`]. Its ROI
///   fraction and time fraction are each drawn independently from `cfg.enc_mask_scale`.
/// * Each predictor mask is a random subset (not a block) of the cells outside the
///   context mask, so it never overlaps the context mask. The three predictor masks may
///   overlap each other.
/// * A predictor mask's size is `round(total * frac_r * frac_t)`. `frac_r` and `frac_t` are
///   drawn from `cfg.pred_mask_r_scale` and `cfg.pred_mask_t_scale`. The size is raised to
///   `cfg.min_keep` and then capped at the number of free cells, so it can be empty.
///
/// Every returned mask is sorted ascending.
///
/// If `cfg.seed` is `Some`, `fastrand::seed` is called first. Repeated calls with the same
/// seeded config should therefore reproduce the same masks.
/// NOTE(review): this reseeds fastrand's shared generator, which also affects other
/// `fastrand` users on the same thread — confirm this is intended.
pub fn jepa_masks(cfg: &MaskConfig) -> (Vec<i64>, Vec<Vec<i64>>) {
    if let Some(s) = cfg.seed {
        fastrand::seed(s);
    }

    let (rois, patches) = (cfg.n_rois, cfg.n_time_patches);
    let total = rois * patches;

    // Context block: ROI fraction is drawn before time fraction (RNG order matters).
    let (enc_lo, enc_hi) = cfg.enc_mask_scale;
    let roi_share = uniform(enc_lo, enc_hi);
    let time_share = uniform(enc_lo, enc_hi);
    let context = random_block_mask(rois, patches, roi_share, time_share, cfg.min_keep);

    let in_context: std::collections::HashSet<i64> = context.iter().copied().collect();
    let free_cells: Vec<i64> = (0..total as i64)
        .filter(|cell| !in_context.contains(cell))
        .collect();

    let targets: Vec<Vec<i64>> = (0..3)
        .map(|_| {
            let share_r = uniform(cfg.pred_mask_r_scale.0, cfg.pred_mask_r_scale.1);
            let share_t = uniform(cfg.pred_mask_t_scale.0, cfg.pred_mask_t_scale.1);
            let wanted = ((total as f64 * share_r * share_t).round() as usize)
                .max(cfg.min_keep)
                .min(free_cells.len());
            let mut picked = free_cells.clone();
            shuffle(&mut picked);
            picked.truncate(wanted);
            picked.sort();
            picked
        })
        .collect();

    (context, targets)
}
/// Linearly maps one draw from `fastrand::f64()` onto the span between `lo` and `hi`.
///
/// Assuming `fastrand::f64` yields values in `[0, 1)`, the result lies in `[lo, hi)` when
/// `lo <= hi`.
fn uniform(lo: f64, hi: f64) -> f64 {
    let unit = fastrand::f64();
    let span = hi - lo;
    lo + unit * span
}
/// Shuffles `v` in place with a Fisher–Yates pass from the back, drawing each swap partner
/// from the global `fastrand` generator.
fn shuffle(v: &mut [i64]) {
    let mut end = v.len();
    while end > 1 {
        end -= 1;
        let partner = fastrand::usize(..=end);
        v.swap(end, partner);
    }
}