// wifi-densepose-train 0.3.2
//
// Training pipeline for WiFi-DensePose pose estimation
// Documentation
//! Masked-autoencoder (MAE) pretraining recipe for the ADR-150 RF foundation
//! encoder — ADR-152 §2.3 (amends ADR-150 §2.3).
//!
//! Implements the *measured* tokenization recipe from the UNSW MAE pretraining
//! study (arXiv [2511.18792](https://arxiv.org/abs/2511.18792), Nov 2025), the
//! largest heterogeneous CSI pretraining run to date (1,320,892 samples, 14
//! public datasets, 4 devices, 2.4/5/6 GHz, 20–160 MHz):
//!
//! - **80% masking ratio** over the patch grid.
//! - **Small (30, 3) patches** — 30 time steps × 3 subcarriers — measured
//!   **+4.7%** over (40, 5) patches by preserving fine temporal dynamics.
//! - Encoder capacity stays **ViT-Small-class (~15M params)**: ViT-Base adds
//!   only +0.4–0.9% over ViT-Small in-study, corroborating ADR-150's own
//!   finding that capacity hurts cross-subject transfer.
//! - Unseen-domain performance scales **log-linearly with pretraining data,
//!   unsaturated at 1.3M samples** — data aggregation outranks architecture
//!   work (ADR-152 §2.3).
//!
//! This module provides the GPU-free half of the recipe: configuration,
//! patchification, and deterministic random masking. The (future, ADR-150)
//! encoder consumes [`PatchGrid`] + [`MaskIndices`] to compute the masked
//! reconstruction loss (`L_masked_csi` in ADR-150 §2.3's loss stack).
//!
//! ## Axis convention
//!
//! A CSI window is `time × subcarriers`, row-major (`index = t * subc + sc`),
//! matching the crate's `[T, …, n_sc]` dataset layout (time first, subcarriers
//! last) and the UNSW "(30 time steps, 3 subcarriers)" patch framing. Patches
//! are indexed row-major over the patch grid (`p = pt * n_patches_subc + ps`),
//! and values within a patch are row-major time-major
//! (`local = lt * patch_subc + lsc`).
//!
//! ## Divisibility policy: error, never truncate
//!
//! Window dimensions **must** be exact multiples of the patch dimensions.
//! Non-divisible shapes return [`MaeError::NotDivisible`] instead of silently
//! truncating trailing samples (this crate never silently drops data). The
//! error names the largest divisible crop; use
//! [`MaePretrainConfig::cropped_window_shape`] to compute it and crop
//! explicitly before calling [`patchify`].
//!
//! ## Example
//!
//! ```rust
//! use wifi_densepose_train::mae::MaePretrainConfig;
//!
//! let cfg = MaePretrainConfig::default(); // 0.80 masking, (30, 3) patches
//! cfg.validate().expect("default recipe is valid");
//!
//! // 90 frames × 54 subcarriers → a 3 × 18 grid of (30, 3) patches.
//! let window = vec![0.25_f32; 90 * 54];
//! let (grid, mask) = cfg.mask_window(&window, 90, 54).unwrap();
//! assert_eq!(grid.n_patches(), 54);
//! assert_eq!(mask.masked.len(), 43); // round(0.80 * 54)
//! assert_eq!(mask.visible.len(), 11);
//! ```

use serde::{Deserialize, Serialize};

use crate::error::{ConfigError, MaeError};
use crate::virtual_aug::Xorshift64;

// ---------------------------------------------------------------------------
// MaePretrainConfig
// ---------------------------------------------------------------------------

/// Hyper-parameters for masked-CSI pretraining (ADR-152 §2.3).
///
/// Defaults are the measured-optimal UNSW recipe (arXiv 2511.18792); change
/// them only with benchmark evidence. Serializable so the recipe is recorded
/// in checkpoint metadata alongside [`crate::config::TrainingConfig`].
///
/// Every field is public and the type is deserializable, so an instance is
/// not guaranteed to be valid by construction. Call
/// [`MaePretrainConfig::validate`] after building or loading one.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MaePretrainConfig {
    /// Fraction of patches hidden from the encoder, in `(0, 1)`.
    ///
    /// The number of masked patches is `round(mask_ratio * n_patches)`, so
    /// the realised fraction can differ slightly from this value on small
    /// grids.
    ///
    /// Default: **0.80** (UNSW measured optimum).
    pub mask_ratio: f64,

    /// Patch extent along the time axis, in frames. Must be at least 1.
    /// Default: **30**.
    pub patch_time: usize,

    /// Patch extent along the subcarrier axis. Must be at least 1.
    /// Default: **3**.
    pub patch_subc: usize,

    /// Base seed for the deterministic mask sampler. Default: **42**.
    ///
    /// For per-sample masks derive a child seed (e.g.
    /// `seed ^ sample_idx as u64`) and pass it to [`random_mask`]; reusing one
    /// seed yields the identical mask for every sample.
    pub seed: u64,
}

impl Default for MaePretrainConfig {
    fn default() -> Self {
        MaePretrainConfig {
            mask_ratio: 0.80,
            patch_time: 30,
            patch_subc: 3,
            seed: 42,
        }
    }
}

impl MaePretrainConfig {
    /// Validate the shape-independent fields.
    ///
    /// # Validated invariants
    ///
    /// - `mask_ratio` must be finite and strictly inside `(0, 1)`.
    /// - `patch_time` and `patch_subc` must be at least 1.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] for the first violated invariant, checked in
    /// the order `mask_ratio`, `patch_time`, `patch_subc`.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if !self.mask_ratio.is_finite() || self.mask_ratio <= 0.0 || self.mask_ratio >= 1.0 {
            return Err(ConfigError::invalid_value(
                "mask_ratio",
                format!("must be in (0.0, 1.0), got {}", self.mask_ratio),
            ));
        }
        if self.patch_time == 0 {
            return Err(ConfigError::invalid_value("patch_time", "must be >= 1"));
        }
        if self.patch_subc == 0 {
            return Err(ConfigError::invalid_value("patch_subc", "must be >= 1"));
        }
        Ok(())
    }

    /// Check this recipe against a concrete `time × subc` window shape.
    ///
    /// The time axis is checked first, then the subcarrier axis. This method
    /// does not itself re-check the [`Self::validate`] invariants; run
    /// [`Self::validate`] first so that both patch extents are known to be
    /// non-zero.
    ///
    /// # Errors
    ///
    /// Errors if a patch dimension exceeds the window or if either axis is
    /// not an exact multiple of the patch extent (divisibility policy above).
    pub fn validate_for_window(&self, time: usize, subc: usize) -> Result<(), MaeError> {
        check_axis("time", time, self.patch_time)?;
        check_axis("subcarrier", subc, self.patch_subc)?;
        Ok(())
    }

    /// Largest `(time, subc)` crop of the given window that is exactly
    /// divisible by the patch dimensions.
    ///
    /// A component is 0 when the window is smaller than one patch along that
    /// axis. It is also 0 when the patch extent itself is 0 (an invalid
    /// recipe that [`Self::validate`] rejects): no crop is divisible by a
    /// zero-sized patch, and this infallible helper reports that as an empty
    /// crop rather than panicking on a division by zero.
    #[must_use]
    pub fn cropped_window_shape(&self, time: usize, subc: usize) -> (usize, usize) {
        // `extent - extent % patch` equals `(extent / patch) * patch` for
        // every non-zero patch; `checked_rem` yields `None` only for a zero
        // patch, which maps to the empty crop.
        let crop = |extent: usize, patch: usize| {
            extent.checked_rem(patch).map_or(0, |remainder| extent - remainder)
        };
        (crop(time, self.patch_time), crop(subc, self.patch_subc))
    }

    /// Number of patches a `time × subc` window yields under this recipe.
    ///
    /// # Errors
    ///
    /// Whatever [`Self::validate_for_window`] rejects for this shape.
    pub fn num_patches(&self, time: usize, subc: usize) -> Result<usize, MaeError> {
        self.validate_for_window(time, subc)?;
        Ok((time / self.patch_time) * (subc / self.patch_subc))
    }

    /// Exact number of masked patches for a grid of `n_patches`:
    /// `round(mask_ratio * n_patches)`, clamped to `[0, n_patches]`.
    ///
    /// This helper does not validate `mask_ratio`: the float-to-integer cast
    /// saturates, so a NaN or negative ratio yields 0 and a ratio above 1
    /// yields `n_patches`.
    #[must_use]
    pub fn num_masked(&self, n_patches: usize) -> usize {
        ((self.mask_ratio * n_patches as f64).round() as usize).min(n_patches)
    }

    /// Patchify `window` and draw the deterministic random mask in one step,
    /// using `self.seed`. See [`patchify`] and [`random_mask`].
    ///
    /// The window is patchified first, so a shape or value error is reported
    /// before the mask ratio is examined.
    ///
    /// # Errors
    ///
    /// Everything [`patchify`] rejects, plus [`MaeError::InvalidMaskRatio`]
    /// if `self.mask_ratio` is not finite or outside `(0, 1)` (the
    /// [`Self::validate`] rule) — a NaN ratio must never silently mask zero
    /// patches.
    pub fn mask_window(
        &self,
        window: &[f32],
        time: usize,
        subc: usize,
    ) -> Result<(PatchGrid, MaskIndices), MaeError> {
        let grid = patchify(window, time, subc, self)?;
        let mask = random_mask(grid.n_patches(), self.mask_ratio, self.seed)?;
        Ok((grid, mask))
    }
}

// ---------------------------------------------------------------------------
// PatchGrid / MaskIndices
// ---------------------------------------------------------------------------

/// A CSI window split into non-overlapping `patch_time × patch_subc` tiles
/// (see the module-level axis convention).
///
/// The accessor methods derive everything from the four extent/count fields;
/// they do not inspect `patches`.
#[derive(Debug, Clone, PartialEq)]
pub struct PatchGrid {
    /// Patch extent along the time axis.
    pub patch_time: usize,
    /// Patch extent along the subcarrier axis.
    pub patch_subc: usize,
    /// Number of patch rows (`time / patch_time`).
    pub n_patches_time: usize,
    /// Number of patch columns (`subc / patch_subc`).
    pub n_patches_subc: usize,
    /// Flattened patches, row-major over the grid; each inner `Vec` is one
    /// patch of length `patch_time * patch_subc`, row-major time-major.
    pub patches: Vec<Vec<f32>>,
}

impl PatchGrid {
    /// Patch count of the grid: rows times columns.
    #[must_use]
    pub fn n_patches(&self) -> usize {
        let (rows, cols) = (self.n_patches_time, self.n_patches_subc);
        rows * cols
    }

    /// How many scalars a single patch holds.
    #[must_use]
    pub fn patch_len(&self) -> usize {
        let (frames, subcarriers) = (self.patch_time, self.patch_subc);
        frames * subcarriers
    }

    /// The `(time, subc)` shape of the window this grid tiles.
    #[must_use]
    pub fn window_shape(&self) -> (usize, usize) {
        let time = self.n_patches_time * self.patch_time;
        let subc = self.n_patches_subc * self.patch_subc;
        (time, subc)
    }
}

/// Sorted, disjoint patch-index sets produced by [`random_mask`]. Together
/// they cover `0..n_patches` exactly.
///
/// The fields are public, so these properties hold for values returned by
/// [`random_mask`] but are not enforced on hand-built instances.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MaskIndices {
    /// Indices of patches hidden from the encoder (`round(ratio * n)` of them).
    /// Sorted ascending.
    pub masked: Vec<usize>,
    /// Indices of patches the encoder sees. Sorted ascending.
    pub visible: Vec<usize>,
}

// ---------------------------------------------------------------------------
// patchify / unpatchify
// ---------------------------------------------------------------------------

/// Decompose a row-major `time × subc` CSI window into the patch grid defined
/// by `cfg`.
///
/// Checks run in this order: slice length, patch/window compatibility, then
/// finiteness of every value. Only the patch extents of `cfg` are used here;
/// this function does not itself re-check that they are non-zero, so `cfg`
/// should already have passed [`MaePretrainConfig::validate`].
///
/// # Errors
///
/// - [`MaeError::WindowShapeMismatch`] if `window.len() != time * subc`. A
///   `time * subc` product that overflows `usize` is reported the same way
///   (with `expected` saturated to `usize::MAX`), since no slice can be that
///   long.
/// - [`MaeError::PatchExceedsWindow`] / [`MaeError::NotDivisible`] per the
///   module-level divisibility policy.
/// - [`MaeError::NonFiniteValue`] on the first NaN/±inf encountered —
///   corrupted CSI must be cleaned upstream, never masked over (cf. the
///   WiFlow-STD NaN-poisoning incident, ADR-152 §2.2).
pub fn patchify(
    window: &[f32],
    time: usize,
    subc: usize,
    cfg: &MaePretrainConfig,
) -> Result<PatchGrid, MaeError> {
    // Saturate instead of overflowing: a wrapped product could coincide with
    // the length of a short slice and let a nonsensical shape through (and a
    // debug build would panic on the plain multiplication).
    let expected = time.saturating_mul(subc);
    if window.len() != expected {
        return Err(MaeError::WindowShapeMismatch {
            time,
            subc,
            expected,
            actual: window.len(),
        });
    }
    cfg.validate_for_window(time, subc)?;
    if let Some(idx) = window.iter().position(|v| !v.is_finite()) {
        // Report the offending sample in window coordinates (row = time step,
        // col = subcarrier).
        return Err(MaeError::NonFiniteValue {
            row: idx / subc,
            col: idx % subc,
            value: window[idx],
        });
    }

    let n_patches_time = time / cfg.patch_time;
    let n_patches_subc = subc / cfg.patch_subc;
    let patch_len = cfg.patch_time * cfg.patch_subc;
    let mut patches = Vec::with_capacity(n_patches_time * n_patches_subc);
    for pt in 0..n_patches_time {
        for ps in 0..n_patches_subc {
            let col_start = ps * cfg.patch_subc;
            let mut patch = Vec::with_capacity(patch_len);
            // Each patch row is a contiguous run of `patch_subc` values in
            // the row-major window, so copy it one time step at a time.
            for lt in 0..cfg.patch_time {
                let row_start = (pt * cfg.patch_time + lt) * subc + col_start;
                patch.extend_from_slice(&window[row_start..row_start + cfg.patch_subc]);
            }
            patches.push(patch);
        }
    }

    Ok(PatchGrid {
        patch_time: cfg.patch_time,
        patch_subc: cfg.patch_subc,
        n_patches_time,
        n_patches_subc,
        patches,
    })
}

/// Rebuild the complete row-major `time × subc` window from a [`PatchGrid`].
/// Exact inverse of [`patchify`].
#[must_use]
pub fn unpatchify(grid: &PatchGrid) -> Vec<f32> {
    // No selection: every patch stored in the grid is written back. The 0.0
    // is only the initial buffer value, overwritten wherever a patch lands.
    let keep_all: Option<&[usize]> = None;
    unpatchify_select(grid, keep_all, 0.0)
}

/// Rebuild the window from only the patches whose indices appear in
/// `visible`; the region of every other patch holds `fill` (the standard MAE
/// "visible tokens + mask token" view of the input).
///
/// Entries of `visible` that do not name a patch of `grid` have no effect.
#[must_use]
pub fn unpatchify_visible(grid: &PatchGrid, visible: &[usize], fill: f32) -> Vec<f32> {
    let kept = Some(visible);
    unpatchify_select(grid, kept, fill)
}

/// Shared implementation of the `unpatchify*` functions.
///
/// Allocates a `time × subc` buffer pre-filled with `fill`, then copies each
/// selected patch back to its position in the row-major window.
///
/// - `keep == None`: every patch in `grid.patches` is written.
/// - `keep == Some(indices)`: only patches whose index appears in `indices`
///   are written; indices that do not name a patch, and duplicates, have no
///   effect.
///
/// # Panics
///
/// Panics if `grid` is internally inconsistent (its fields are public), e.g.
/// a patch shorter than `patch_time * patch_subc` or more patches than the
/// grid dimensions allow.
fn unpatchify_select(grid: &PatchGrid, keep: Option<&[usize]>, fill: f32) -> Vec<f32> {
    let (time, subc) = grid.window_shape();
    let mut window = vec![fill; time * subc];

    // Build the membership table once instead of scanning `keep` for every
    // patch, which made the selection quadratic in the patch count.
    let keep_flags: Option<Vec<bool>> = keep.map(|indices| {
        let mut flags = vec![false; grid.patches.len()];
        for &idx in indices {
            if let Some(flag) = flags.get_mut(idx) {
                *flag = true;
            }
        }
        flags
    });

    for (p, patch) in grid.patches.iter().enumerate() {
        if let Some(flags) = &keep_flags {
            if !flags[p] {
                continue;
            }
        }
        // Patch index -> (row, column) in the patch grid.
        let pt = p / grid.n_patches_subc;
        let ps = p % grid.n_patches_subc;
        for lt in 0..grid.patch_time {
            let t = pt * grid.patch_time + lt;
            let row_start = t * subc + ps * grid.patch_subc;
            let local_start = lt * grid.patch_subc;
            window[row_start..row_start + grid.patch_subc]
                .copy_from_slice(&patch[local_start..local_start + grid.patch_subc]);
        }
    }
    window
}

// ---------------------------------------------------------------------------
// random_mask
// ---------------------------------------------------------------------------

/// Sample a reproducible random mask over `n_patches` patches.
///
/// The number of masked patches is `round(mask_ratio * n_patches)`, capped at
/// `n_patches`. The selection comes from a Fisher–Yates shuffle driven by a
/// seeded [`Xorshift64`], so an identical `(n_patches, mask_ratio, seed)`
/// triple reproduces the identical mask. The two returned index lists are
/// each sorted ascending, share no element, and jointly cover `0..n_patches`.
///
/// # Errors
///
/// [`MaeError::InvalidMaskRatio`] if `mask_ratio` is not finite or outside
/// the open interval `(0, 1)` — the same rule as
/// [`MaePretrainConfig::validate`]. Erroring (never clamping) keeps the
/// module's error-not-silent policy: a NaN ratio would otherwise silently
/// mask zero patches and a ratio ≥ 1 would mask everything.
pub fn random_mask(n_patches: usize, mask_ratio: f64, seed: u64) -> Result<MaskIndices, MaeError> {
    let ratio_ok = mask_ratio.is_finite() && mask_ratio > 0.0 && mask_ratio < 1.0;
    if !ratio_ok {
        return Err(MaeError::InvalidMaskRatio { ratio: mask_ratio });
    }

    let rounded = (mask_ratio * n_patches as f64).round() as usize;
    let n_masked = rounded.min(n_patches);

    // Shuffle from the last position down to index 1, drawing one value per
    // step; the draw order is what makes the mask reproducible.
    let mut rng = Xorshift64::new(seed);
    let mut permutation: Vec<usize> = (0..n_patches).collect();
    let mut upper = n_patches;
    while upper > 1 {
        upper -= 1;
        let pick = (rng.next_u64() % (upper as u64 + 1)) as usize;
        permutation.swap(upper, pick);
    }

    // The first `n_masked` shuffled indices are hidden, the rest stay visible.
    let mut visible = permutation.split_off(n_masked);
    let mut masked = permutation;
    masked.sort_unstable();
    visible.sort_unstable();
    Ok(MaskIndices { masked, visible })
}

// ---------------------------------------------------------------------------
// helpers
// ---------------------------------------------------------------------------

/// Check one axis of a window against the patch extent along that axis.
///
/// `axis` is the label carried in the returned error, `window` the axis
/// length and `patch` the patch extent.
///
/// # Errors
///
/// - [`MaeError::NotDivisible`] with `crop: 0` if `patch` is 0: nothing is
///   divisible by a zero-sized patch, so the largest usable crop is empty.
///   (`remainder` is set to `window`, i.e. the whole axis is left over.)
/// - [`MaeError::PatchExceedsWindow`] if `patch > window`.
/// - [`MaeError::NotDivisible`] if `window` is not an exact multiple of
///   `patch`; `crop` is the largest multiple of `patch` that fits.
fn check_axis(axis: &'static str, window: usize, patch: usize) -> Result<(), MaeError> {
    // A zero patch extent can reach this point because the config fields are
    // public and deserializable; `window % 0` below would panic, so report it
    // as an error instead.
    if patch == 0 {
        return Err(MaeError::NotDivisible {
            axis,
            window,
            patch,
            remainder: window,
            crop: 0,
        });
    }
    if patch > window {
        return Err(MaeError::PatchExceedsWindow {
            axis,
            patch,
            window,
        });
    }
    let remainder = window % patch;
    if remainder != 0 {
        return Err(MaeError::NotDivisible {
            axis,
            window,
            patch,
            remainder,
            crop: window - remainder,
        });
    }
    Ok(())
}