use serde::{Deserialize, Serialize};
use crate::error::{ConfigError, MaeError};
use crate::virtual_aug::Xorshift64;
/// Settings that control how a 2-D window is cut into patches and how many of
/// those patches are masked.
///
/// NOTE(review): the `Mae` prefix presumably stands for masked-autoencoder
/// pretraining — confirm against the crate-level docs.
///
/// The struct is (de)serializable via serde. Deserialized values are not
/// checked automatically; call [`MaePretrainConfig::validate`] before use.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MaePretrainConfig {
/// Fraction of patches to mask. `validate` requires a finite value strictly
/// inside `(0.0, 1.0)`.
pub mask_ratio: f64,
/// Patch extent along the time axis, in samples. `validate` requires `>= 1`.
pub patch_time: usize,
/// Patch extent along the subcarrier axis, in samples. `validate` requires `>= 1`.
pub patch_subc: usize,
/// Seed handed to the random generator that shuffles patch order when masking.
pub seed: u64,
}
impl Default for MaePretrainConfig {
fn default() -> Self {
MaePretrainConfig {
mask_ratio: 0.80,
patch_time: 30,
patch_subc: 3,
seed: 42,
}
}
}
impl MaePretrainConfig {
    /// Checks the configuration's own invariants, independent of any window.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] built with `ConfigError::invalid_value` for the
    /// first failing field, checked in this order:
    /// * `mask_ratio` is not a finite value strictly inside `(0.0, 1.0)`;
    /// * `patch_time` is zero;
    /// * `patch_subc` is zero.
    pub fn validate(&self) -> Result<(), ConfigError> {
        let ratio = self.mask_ratio;
        // NaN and both infinities fail at least one of these comparisons, so
        // non-finite ratios are rejected without a separate `is_finite` test.
        let ratio_in_open_unit = ratio > 0.0 && ratio < 1.0;
        if !ratio_in_open_unit {
            return Err(ConfigError::invalid_value(
                "mask_ratio",
                format!("must be in (0.0, 1.0), got {}", self.mask_ratio),
            ));
        }

        let patch_sizes = [
            ("patch_time", self.patch_time),
            ("patch_subc", self.patch_subc),
        ];
        for (field, size) in patch_sizes {
            if size == 0 {
                return Err(ConfigError::invalid_value(field, "must be >= 1"));
            }
        }
        Ok(())
    }

    /// Checks that a `time` x `subc` window can be tiled by this patch size.
    ///
    /// The time axis is checked first, then the subcarrier axis.
    ///
    /// # Errors
    ///
    /// Propagates the [`MaeError`] reported by the per-axis check for the first
    /// axis that fails.
    pub fn validate_for_window(&self, time: usize, subc: usize) -> Result<(), MaeError> {
        check_axis("time", time, self.patch_time)
            .and_then(|()| check_axis("subcarrier", subc, self.patch_subc))
    }

    /// Returns `(time, subc)` rounded down to the nearest multiple of the patch
    /// size on each axis, i.e. the largest window shape that patches tile exactly.
    ///
    /// # Panics
    ///
    /// Panics with an integer division by zero if `patch_time` or `patch_subc`
    /// is zero.
    #[must_use]
    pub fn cropped_window_shape(&self, time: usize, subc: usize) -> (usize, usize) {
        let whole_patches_time = time / self.patch_time;
        let whole_patches_subc = subc / self.patch_subc;
        (
            whole_patches_time * self.patch_time,
            whole_patches_subc * self.patch_subc,
        )
    }

    /// Returns how many patches a `time` x `subc` window splits into.
    ///
    /// # Errors
    ///
    /// Returns whatever [`MaePretrainConfig::validate_for_window`] reports when
    /// the window cannot be tiled by the configured patch size.
    pub fn num_patches(&self, time: usize, subc: usize) -> Result<usize, MaeError> {
        self.validate_for_window(time, subc).map(|()| {
            let along_time = time / self.patch_time;
            let along_subc = subc / self.patch_subc;
            along_time * along_subc
        })
    }

    /// Returns the number of patches to mask out of `n_patches`: the product
    /// `mask_ratio * n_patches` rounded to the nearest integer and capped at
    /// `n_patches`.
    #[must_use]
    pub fn num_masked(&self, n_patches: usize) -> usize {
        let rounded = (self.mask_ratio * n_patches as f64).round() as usize;
        std::cmp::min(rounded, n_patches)
    }

    /// Splits `window` into patches and draws a mask over them in one call.
    ///
    /// `window` is passed to `patchify` together with `time`, `subc` and this
    /// configuration; the mask is drawn by `random_mask` from `mask_ratio` and
    /// `seed` for the resulting number of patches.
    ///
    /// # Errors
    ///
    /// Propagates any [`MaeError`] from `patchify` or `random_mask`.
    pub fn mask_window(
        &self,
        window: &[f32],
        time: usize,
        subc: usize,
    ) -> Result<(PatchGrid, MaskIndices), MaeError> {
        let grid = patchify(window, time, subc, self)?;
        let n_patches = grid.n_patches();
        random_mask(n_patches, self.mask_ratio, self.seed).map(|mask| (grid, mask))
    }
}
/// A window that has been cut into equally sized rectangular patches.
///
/// Patches are stored in row-major order over the patch grid: the patch at
/// grid position `(pt, ps)` lives at index `pt * n_patches_subc + ps`. Each
/// patch is itself a flat, row-major buffer of `patch_time` rows with
/// `patch_subc` values per row.
#[derive(Debug, Clone, PartialEq)]
pub struct PatchGrid {
    /// Number of time steps covered by one patch.
    pub patch_time: usize,
    /// Number of subcarriers covered by one patch.
    pub patch_subc: usize,
    /// Number of patches along the time axis.
    pub n_patches_time: usize,
    /// Number of patches along the subcarrier axis.
    pub n_patches_subc: usize,
    /// Flattened patch contents, one inner vector per patch.
    pub patches: Vec<Vec<f32>>,
}

impl PatchGrid {
    /// Total number of patches described by the grid dimensions.
    #[must_use]
    pub fn n_patches(&self) -> usize {
        let Self {
            n_patches_time: along_time,
            n_patches_subc: along_subc,
            ..
        } = *self;
        along_time * along_subc
    }

    /// Number of samples in a single patch.
    #[must_use]
    pub fn patch_len(&self) -> usize {
        let Self {
            patch_time: rows,
            patch_subc: cols,
            ..
        } = *self;
        rows * cols
    }

    /// Shape `(time, subc)` of the window that the grid tiles.
    #[must_use]
    pub fn window_shape(&self) -> (usize, usize) {
        let time = self.n_patches_time * self.patch_time;
        let subc = self.n_patches_subc * self.patch_subc;
        (time, subc)
    }
}
/// Partition of patch indices into a masked set and a visible set.
///
/// When produced by `random_mask`, both vectors are sorted in ascending order
/// and together contain every index in `0..n_patches` exactly once.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaskIndices {
/// Indices of the patches selected for masking.
pub masked: Vec<usize>,
/// Indices of the patches that were not selected for masking.
pub visible: Vec<usize>,
}
/// Cuts a flat, row-major `time` x `subc` window into non-overlapping patches.
///
/// `window[t * subc + s]` is read as the sample at time step `t` and
/// subcarrier `s`. Patches are emitted in row-major order over the patch grid
/// (time-major, then subcarrier), and each patch is a flat row-major buffer of
/// `cfg.patch_time * cfg.patch_subc` samples.
///
/// `cfg` is expected to have passed `MaePretrainConfig::validate` already;
/// only the window-specific checks are performed here.
///
/// # Errors
///
/// * [`MaeError::WindowShapeMismatch`] if `window.len() != time * subc`. When
///   the product does not fit in `usize`, `expected` is reported as
///   `usize::MAX`.
/// * Any error from `MaePretrainConfig::validate_for_window` when the patch
///   size does not tile the window.
/// * [`MaeError::NonFiniteValue`] with the row and column of the first NaN or
///   infinite sample.
pub fn patchify(
    window: &[f32],
    time: usize,
    subc: usize,
    cfg: &MaePretrainConfig,
) -> Result<PatchGrid, MaeError> {
    // Saturate rather than overflow: a plain `time * subc` panics in debug
    // builds and wraps in release builds, where the wrapped value could
    // accidentally equal `window.len()`. No slice can hold `usize::MAX`
    // `f32`s, so a saturated product always lands in the mismatch branch.
    let expected = time.saturating_mul(subc);
    if window.len() != expected {
        return Err(MaeError::WindowShapeMismatch {
            time,
            subc,
            expected,
            actual: window.len(),
        });
    }
    cfg.validate_for_window(time, subc)?;
    if let Some(idx) = window.iter().position(|v| !v.is_finite()) {
        return Err(MaeError::NonFiniteValue {
            row: idx / subc,
            col: idx % subc,
            value: window[idx],
        });
    }

    let (patch_time, patch_subc) = (cfg.patch_time, cfg.patch_subc);
    let n_patches_time = time / patch_time;
    let n_patches_subc = subc / patch_subc;
    let patch_len = patch_time * patch_subc;

    let mut patches = Vec::with_capacity(n_patches_time * n_patches_subc);
    for pt in 0..n_patches_time {
        let first_row = pt * patch_time;
        for ps in 0..n_patches_subc {
            let first_col = ps * patch_subc;
            let mut patch = Vec::with_capacity(patch_len);
            // Copy the patch one window row at a time; each row contributes
            // `patch_subc` contiguous samples.
            for row in first_row..first_row + patch_time {
                let start = row * subc + first_col;
                patch.extend_from_slice(&window[start..start + patch_subc]);
            }
            patches.push(patch);
        }
    }

    Ok(PatchGrid {
        patch_time,
        patch_subc,
        n_patches_time,
        n_patches_subc,
        patches,
    })
}
/// Reassembles the flat, row-major window from every patch stored in `grid`.
///
/// Window positions that no stored patch covers are left at `0.0`.
#[must_use]
pub fn unpatchify(grid: &PatchGrid) -> Vec<f32> {
    // `None` means "no selection": every stored patch is written back.
    let keep_every_patch: Option<&[usize]> = None;
    let background = 0.0;
    unpatchify_select(grid, keep_every_patch, background)
}
/// Reassembles the window from only the patches whose indices appear in
/// `visible`; every other position is set to `fill`.
#[must_use]
pub fn unpatchify_visible(grid: &PatchGrid, visible: &[usize], fill: f32) -> Vec<f32> {
    let selection = Some(visible);
    unpatchify_select(grid, selection, fill)
}
/// Writes patches from `grid` back into a flat, row-major window.
///
/// The output starts out filled with `fill`. When `keep` is `None` every
/// stored patch is copied back; when it is `Some(indices)` only patches whose
/// index appears in `indices` are copied. Indices that do not correspond to a
/// stored patch are ignored, and duplicates are harmless.
///
/// # Panics
///
/// Panics if `grid` is internally inconsistent, for example if a stored patch
/// is shorter than `patch_time * patch_subc` or `patches` holds more entries
/// than the grid dimensions describe.
fn unpatchify_select(grid: &PatchGrid, keep: Option<&[usize]>, fill: f32) -> Vec<f32> {
    let (time, subc) = grid.window_shape();
    let mut window = vec![fill; time * subc];

    // Build the membership table once so the per-patch test is O(1) instead
    // of a linear scan of `keep` for every patch.
    let keep_mask: Option<Vec<bool>> = keep.map(|indices| {
        let mut mask = vec![false; grid.patches.len()];
        for &idx in indices {
            if let Some(slot) = mask.get_mut(idx) {
                *slot = true;
            }
        }
        mask
    });

    for (p, patch) in grid.patches.iter().enumerate() {
        if let Some(mask) = &keep_mask {
            if !mask[p] {
                continue;
            }
        }
        // Recover the patch's position in the patch grid from its flat index.
        let pt = p / grid.n_patches_subc;
        let ps = p % grid.n_patches_subc;
        for lt in 0..grid.patch_time {
            let t = pt * grid.patch_time + lt;
            let row_start = t * subc + ps * grid.patch_subc;
            let local_start = lt * grid.patch_subc;
            window[row_start..row_start + grid.patch_subc]
                .copy_from_slice(&patch[local_start..local_start + grid.patch_subc]);
        }
    }
    window
}
/// Randomly splits the patch indices `0..n_patches` into masked and visible sets.
///
/// The number of masked patches is `mask_ratio * n_patches` rounded to the
/// nearest integer and capped at `n_patches`. The indices are shuffled with a
/// Fisher-Yates pass driven by an `Xorshift64` generator seeded with `seed`,
/// so the same arguments always produce the same split. Both returned vectors
/// are sorted in ascending order.
///
/// # Errors
///
/// Returns [`MaeError::InvalidMaskRatio`] unless `mask_ratio` is a finite
/// value strictly inside `(0.0, 1.0)`.
pub fn random_mask(n_patches: usize, mask_ratio: f64, seed: u64) -> Result<MaskIndices, MaeError> {
    // NaN and both infinities fail at least one comparison, so they are
    // rejected together with out-of-range finite values.
    let ratio_in_open_unit = mask_ratio > 0.0 && mask_ratio < 1.0;
    if !ratio_in_open_unit {
        return Err(MaeError::InvalidMaskRatio { ratio: mask_ratio });
    }

    let rounded = (mask_ratio * n_patches as f64).round() as usize;
    let n_masked = rounded.min(n_patches);

    let mut shuffled: Vec<usize> = Vec::with_capacity(n_patches);
    shuffled.extend(0..n_patches);

    let mut rng = Xorshift64::new(seed);
    // Walk from the last slot down to slot 1, swapping each slot with a
    // randomly chosen slot at or below it.
    let mut upper = n_patches;
    while upper > 1 {
        upper -= 1;
        let pick = (rng.next_u64() % (upper as u64 + 1)) as usize;
        shuffled.swap(upper, pick);
    }

    let (masked_part, visible_part) = shuffled.split_at(n_masked);
    let mut masked = masked_part.to_vec();
    let mut visible = visible_part.to_vec();
    masked.sort_unstable();
    visible.sort_unstable();
    Ok(MaskIndices { masked, visible })
}
/// Checks that a window of length `window` along `axis` can be tiled exactly
/// by patches of length `patch`.
///
/// `axis` is a label that is only carried into the error value.
///
/// # Errors
///
/// * [`MaeError::PatchExceedsWindow`] if `patch > window`.
/// * [`MaeError::NotDivisible`] if `window` is not a multiple of `patch`;
///   `crop` is then the largest multiple of `patch` that fits in `window`.
///   A `patch` of zero is also reported this way (with `remainder == window`
///   and `crop == 0`), because no window can be tiled by empty patches.
fn check_axis(axis: &'static str, window: usize, patch: usize) -> Result<(), MaeError> {
    if patch > window {
        return Err(MaeError::PatchExceedsWindow {
            axis,
            patch,
            window,
        });
    }
    // `checked_rem` yields `None` only for `patch == 0`; the plain `%`
    // operator would panic there, which is reachable when the configuration
    // was never validated.
    match window.checked_rem(patch) {
        Some(0) => Ok(()),
        Some(remainder) => Err(MaeError::NotDivisible {
            axis,
            window,
            patch,
            remainder,
            crop: window - remainder,
        }),
        None => Err(MaeError::NotDivisible {
            axis,
            window,
            patch,
            remainder: window,
            crop: 0,
        }),
    }
}