use burn::prelude::*;
/// Settings that control how `jepa_masks` samples the encoder mask and the predictor masks
/// over a grid of `n_rois * n_time_patches` tokens.
#[derive(Debug, Clone)]
pub struct MaskConfig {
    /// Number of ROIs, i.e. rows of the token grid.
    pub n_rois: usize,
    /// Number of time patches, i.e. columns of the token grid.
    pub n_time_patches: usize,
    /// `(low, high)` range from which the encoder block's ROI fraction and time fraction
    /// are each drawn.
    pub enc_mask_scale: (f64, f64),
    /// `(low, high)` range for the ROI factor of a predictor mask's size.
    pub pred_mask_r_scale: (f64, f64),
    /// `(low, high)` range for the time factor of a predictor mask's size.
    pub pred_mask_t_scale: (f64, f64),
    /// Lower bound on the number of token indices a mask should contain.
    pub min_keep: usize,
    /// When set, `jepa_masks` seeds the `fastrand` generator with this value before sampling.
    pub seed: Option<u64>,
}

impl Default for MaskConfig {
    /// Returns the stock configuration: a 400 x 18 grid, encoder scale `0.84..1.0`,
    /// predictor ROI scale `0.45..0.6`, predictor time scale `0.0..0.4`, at least four
    /// tokens per mask, and no fixed seed.
    fn default() -> Self {
        let enc_mask_scale = (0.84, 1.0);
        let pred_mask_r_scale = (0.45, 0.6);
        let pred_mask_t_scale = (0.0, 0.4);

        MaskConfig {
            n_rois: 400,
            n_time_patches: 18,
            enc_mask_scale,
            pred_mask_r_scale,
            pred_mask_t_scale,
            min_keep: 4,
            seed: None,
        }
    }
}
/// Builds a mask that selects every token of the grid.
///
/// The result has shape `[1, n_rois * n_time_patches]` and holds the indices
/// `0, 1, ..., n_rois * n_time_patches - 1` in ascending order.
pub fn full_context_mask<B: Backend>(
    n_rois: usize,
    n_time_patches: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let token_count = n_rois * n_time_patches;
    let upper = token_count as i64;
    let every_token = (0..upper).collect::<Vec<i64>>();

    let flat = Tensor::<B, 1, Int>::from_data(
        TensorData::new(every_token, vec![token_count]),
        device,
    );
    // Prepend a leading axis of size one.
    flat.unsqueeze_dim::<2>(0)
}
/// Samples one contiguous rectangular block of the ROI x time-patch grid and returns the
/// flattened indices of the tokens inside it.
///
/// A token at ROI `r` and time patch `t` has index `r * n_time_patches + t`.
///
/// # Arguments
/// * `n_rois`, `n_time_patches` - dimensions of the token grid.
/// * `roi_frac`, `time_frac` - fraction of each axis the block should span. The resulting
///   extent is rounded, then clamped to at least one and at most the full axis.
/// * `min_keep` - minimum number of indices to return. If the block is smaller, distinct
///   random tokens from anywhere in the grid are added. The value is capped at the number
///   of tokens in the grid.
/// * `device` - device on which the index tensor is created.
///
/// # Returns
/// An integer tensor of shape `[1, k]` holding `k` distinct indices in ascending order.
/// For a grid with zero tokens `k` is zero.
///
/// Randomness comes from the `fastrand` free functions, so the draw order
/// (ROI offset, time offset, then any fill-in tokens) determines the result for a given
/// generator state.
pub fn random_block_mask<B: Backend>(
    n_rois: usize,
    n_time_patches: usize,
    roi_frac: f64,
    time_frac: f64,
    min_keep: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let total = n_rois * n_time_patches;

    // Clamp each extent to the grid: without the upper bound a fraction above 1.0 (or an
    // empty axis) would produce indices that do not exist.
    let n_r = ((n_rois as f64 * roi_frac).round() as usize)
        .max(1)
        .min(n_rois);
    let n_t = ((n_time_patches as f64 * time_frac).round() as usize)
        .max(1)
        .min(n_time_patches);

    // The extents never exceed the axis lengths, so these subtractions cannot underflow.
    let r_start = fastrand::usize(..=(n_rois - n_r));
    let t_start = fastrand::usize(..=(n_time_patches - n_t));

    let mut indices: Vec<i64> = Vec::with_capacity((n_r * n_t).max(min_keep.min(total)));
    for r in r_start..(r_start + n_r) {
        for t in t_start..(t_start + n_t) {
            indices.push((r * n_time_patches + t) as i64);
        }
    }

    // More distinct tokens than the grid holds can never be collected; capping the
    // requirement guarantees the fill loop terminates (and is skipped for an empty grid).
    let required = min_keep.min(total);
    while indices.len() < required {
        let idx = fastrand::usize(..total) as i64;
        if !indices.contains(&idx) {
            indices.push(idx);
        }
    }

    indices.sort_unstable();
    let k = indices.len();
    Tensor::<B, 1, Int>::from_data(TensorData::new(indices, vec![k]), device)
        .unsqueeze_dim::<2>(0)
}
/// Samples one encoder mask and three predictor masks over the token grid described by `cfg`.
///
/// The encoder mask is a rectangular block from `random_block_mask`, with its ROI and time
/// fractions each drawn from `cfg.enc_mask_scale`. Every predictor mask is a uniformly
/// random subset (without replacement) of the tokens that are NOT in the encoder mask. Its
/// size is `n * frac_r * frac_t` rounded, raised to at least `cfg.min_keep`, and capped at
/// the number of tokens outside the encoder mask. Predictor masks are sampled independently
/// and may overlap each other; they are empty when the encoder mask covers the whole grid.
///
/// When `cfg.seed` is set the `fastrand` generator is re-seeded on every call, so repeated
/// calls with the same configuration are intended to reproduce the same masks.
///
/// # Returns
/// `(encoder_mask, predictor_masks)`, where every tensor has shape `[1, k]` and holds
/// ascending token indices.
pub fn jepa_masks<B: Backend>(
    cfg: &MaskConfig,
    device: &B::Device,
) -> (Tensor<B, 2, Int>, Vec<Tensor<B, 2, Int>>) {
    /// Number of predictor masks produced per call.
    const N_PRED_MASKS: usize = 3;

    if let Some(seed) = cfg.seed {
        fastrand::seed(seed);
    }

    let n_r = cfg.n_rois;
    let n_t = cfg.n_time_patches;
    let n = n_r * n_t;

    let enc_roi_frac = uniform(cfg.enc_mask_scale.0, cfg.enc_mask_scale.1);
    let enc_time_frac = uniform(cfg.enc_mask_scale.0, cfg.enc_mask_scale.1);
    let enc_mask =
        random_block_mask::<B>(n_r, n_t, enc_roi_frac, enc_time_frac, cfg.min_keep, device);

    // Read the indices back through the converting iterator. A strict `to_vec::<i64>()`
    // fails when the backend stores integers in another width, and swallowing that error
    // would leave the set empty, letting predictor masks overlap the encoder mask.
    // The data buffer is flat, so no reshape is needed before reading it.
    let enc_data = enc_mask.clone().into_data();
    let enc_set: std::collections::HashSet<i64> = enc_data.iter::<i64>().collect();

    // Tokens outside the encoder mask, in ascending order.
    let complement: Vec<i64> = (0..n as i64).filter(|i| !enc_set.contains(i)).collect();

    let mut pred_masks = Vec::with_capacity(N_PRED_MASKS);
    for _ in 0..N_PRED_MASKS {
        let frac_r = uniform(cfg.pred_mask_r_scale.0, cfg.pred_mask_r_scale.1);
        let frac_t = uniform(cfg.pred_mask_t_scale.0, cfg.pred_mask_t_scale.1);
        let target_count = ((n as f64 * frac_r * frac_t).round() as usize)
            .max(cfg.min_keep)
            .min(complement.len());

        // Shuffle a copy and keep the first `target_count` entries.
        let mut sampled = complement.clone();
        fastrand_shuffle(&mut sampled);
        sampled.truncate(target_count);
        sampled.sort_unstable();

        let k = sampled.len();
        let mask = Tensor::<B, 1, Int>::from_data(TensorData::new(sampled, vec![k]), device)
            .unsqueeze_dim::<2>(0);
        pred_masks.push(mask);
    }

    (enc_mask, pred_masks)
}
/// Draws a value between `lo` and `hi` by scaling one `fastrand::f64()` sample.
fn uniform(lo: f64, hi: f64) -> f64 {
    let span = hi - lo;
    let unit = fastrand::f64();
    lo + unit * span
}
/// Shuffles `v` in place, walking from the last position down to index 1 and swapping each
/// position with a randomly chosen index at or below it.
fn fastrand_shuffle(v: &mut [i64]) {
    let mut upper = v.len();
    while upper > 1 {
        upper -= 1;
        let pick = fastrand::usize(..=upper);
        v.swap(upper, pick);
    }
}