burn_autogaze 0.21.6

AutoGaze inference, fixation traces, and crisp mask visualization for Burn
Documentation
use crate::FixationPoint;
use crate::config::AutoGazeConfig;

#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AutoGazeScaleTokenLayout {
    pub token_count: usize,
    pub grid: usize,
}

#[derive(Clone, Debug)]
pub(super) struct GenerationCoverageTracker {
    scale_layouts: Vec<AutoGazeScaleTokenLayout>,
    grid: usize,
    covered: Vec<bool>,
    pub(super) covered_count: usize,
    stop_cells: usize,
}

impl GenerationCoverageTracker {
    pub(super) fn new(scale_layouts: &[AutoGazeScaleTokenLayout], stop_ratio: f64) -> Option<Self> {
        if !stop_ratio.is_finite() || stop_ratio <= 0.0 {
            return None;
        }

        let scale_layouts = normalize_scale_layouts(scale_layouts.to_vec(), 1);
        let grid = coverage_grid_for_layouts(&scale_layouts);
        let cells = grid.checked_mul(grid)?;
        let stop_cells =
            ((stop_ratio.clamp(0.0, 1.0) * cells as f64).ceil() as usize).clamp(1, cells);
        Some(Self {
            scale_layouts,
            grid,
            covered: vec![false; cells],
            covered_count: 0,
            stop_cells,
        })
    }

    pub(super) fn observe_token(&mut self, token: i64) -> bool {
        if token < 0 || self.covered_count >= self.stop_cells {
            return self.covered_count >= self.stop_cells;
        }

        let Some((scale_idx, local)) = scale_token_index(token as usize, &self.scale_layouts)
        else {
            return false;
        };
        let source_grid = self.scale_layouts[scale_idx].grid.max(1);
        let row = local / source_grid;
        let col = local % source_grid;
        let y0 = row.saturating_mul(self.grid) / source_grid;
        let x0 = col.saturating_mul(self.grid) / source_grid;
        let y1 = (row + 1).saturating_mul(self.grid).div_ceil(source_grid);
        let x1 = (col + 1).saturating_mul(self.grid).div_ceil(source_grid);

        for y in y0.min(self.grid)..y1.min(self.grid) {
            let row_offset = y * self.grid;
            for x in x0.min(self.grid)..x1.min(self.grid) {
                let idx = row_offset + x;
                if !self.covered[idx] {
                    self.covered[idx] = true;
                    self.covered_count += 1;
                }
            }
        }
        self.covered_count >= self.stop_cells
    }
}

pub(super) fn generation_coverage_trackers(
    batch: usize,
    stop_ratio: Option<f64>,
    scale_layouts: &[AutoGazeScaleTokenLayout],
) -> Option<Vec<GenerationCoverageTracker>> {
    let tracker = GenerationCoverageTracker::new(scale_layouts, stop_ratio?)?;
    Some(vec![tracker; batch])
}

pub(super) fn observe_generation_coverage(
    trackers: &mut Option<Vec<GenerationCoverageTracker>>,
    batch_idx: usize,
    token: i64,
) -> bool {
    trackers
        .as_mut()
        .and_then(|trackers| trackers.get_mut(batch_idx))
        .map(|tracker| tracker.observe_token(token))
        .unwrap_or(false)
}

pub(super) fn effective_generation_max_tokens(
    configured_max_tokens: usize,
    coverage_stop_ratio: Option<f64>,
    scale_layouts: &[AutoGazeScaleTokenLayout],
    num_multi_token_pred: usize,
) -> usize {
    let configured_max_tokens = configured_max_tokens.max(1);
    let Some(stop_ratio) = coverage_stop_ratio else {
        return configured_max_tokens;
    };
    if !stop_ratio.is_finite() || stop_ratio <= 0.0 || stop_ratio >= 1.0 {
        return configured_max_tokens;
    }

    let Some(finest_grid) = scale_layouts
        .iter()
        .filter(|layout| layout.token_count > 0)
        .map(|layout| layout.grid.max(1))
        .max()
    else {
        return configured_max_tokens;
    };
    let finest_cells = finest_grid.saturating_mul(finest_grid).max(1);
    let required_tokens = (stop_ratio.clamp(0.0, 1.0) * finest_cells as f64).ceil() as usize;
    let chunk = num_multi_token_pred.max(1);
    let chunk_aligned = required_tokens.max(1).div_ceil(chunk).saturating_mul(chunk);
    configured_max_tokens.min(chunk_aligned.max(1))
}

pub(super) fn normalize_scale_layouts(
    mut layouts: Vec<AutoGazeScaleTokenLayout>,
    fallback_token_count: usize,
) -> Vec<AutoGazeScaleTokenLayout> {
    layouts.retain(|layout| layout.token_count > 0);
    if layouts.is_empty() {
        let token_count = fallback_token_count.max(1);
        return vec![AutoGazeScaleTokenLayout {
            token_count,
            grid: square_grid(token_count),
        }];
    }

    for layout in &mut layouts {
        layout.grid = layout.grid.max(1);
    }
    layouts
}

fn coverage_grid_for_layouts(layouts: &[AutoGazeScaleTokenLayout]) -> usize {
    const MAX_COVERAGE_GRID: usize = 256;
    let max_grid = layouts.iter().map(|layout| layout.grid).max().unwrap_or(1);
    let mut grid = 1usize;
    for layout in layouts {
        let Some(next) = bounded_lcm(grid, layout.grid.max(1), MAX_COVERAGE_GRID) else {
            return max_grid.clamp(1, MAX_COVERAGE_GRID);
        };
        grid = next;
    }
    grid.max(max_grid.max(1)).min(MAX_COVERAGE_GRID)
}

fn bounded_lcm(left: usize, right: usize, max_value: usize) -> Option<usize> {
    let gcd = gcd(left.max(1), right.max(1));
    left.checked_div(gcd)?
        .checked_mul(right.max(1))
        .filter(|value| *value <= max_value)
}

fn gcd(mut left: usize, mut right: usize) -> usize {
    while right != 0 {
        let remainder = left % right;
        left = right;
        right = remainder;
    }
    left.max(1)
}

pub(super) fn scale_token_index(
    token: usize,
    scale_layouts: &[AutoGazeScaleTokenLayout],
) -> Option<(usize, usize)> {
    let mut offset = 0usize;
    for (scale_idx, layout) in scale_layouts.iter().enumerate() {
        if token < offset + layout.token_count {
            return Some((scale_idx, token - offset));
        }
        offset += layout.token_count;
    }
    None
}

pub(super) fn token_to_fixation_point(
    token: usize,
    scale_layouts: &[AutoGazeScaleTokenLayout],
    confidence: f32,
) -> Option<FixationPoint> {
    let (scale_idx, local) = scale_token_index(token, scale_layouts)?;
    let grid = scale_layouts[scale_idx].grid.max(1);
    let row = local / grid;
    let col = local % grid;
    let x = (col as f32 + 0.5) / grid as f32;
    let y = (row as f32 + 0.5) / grid as f32;
    let cell = (1.0 / grid as f32).clamp(1.0e-6, 1.0);
    Some(FixationPoint::with_grid_extent(
        x, y, cell, cell, confidence, grid,
    ))
}

pub fn scale_token_layouts(config: &AutoGazeConfig) -> Vec<AutoGazeScaleTokenLayout> {
    let scales = config.scale_values();
    if scales.is_empty() {
        let token_count = config.num_vision_tokens_each_frame.max(1);
        return vec![AutoGazeScaleTokenLayout {
            token_count,
            grid: square_grid(token_count),
        }];
    }

    let patch_size = config
        .gaze_model_config
        .vision_model_config
        .kernel_size
        .max(1);
    let direct_layouts = scales
        .iter()
        .map(|scale| {
            let grid = (scale / patch_size).max(1);
            AutoGazeScaleTokenLayout {
                token_count: grid * grid,
                grid,
            }
        })
        .collect::<Vec<_>>();
    let direct_tokens = direct_layouts
        .iter()
        .map(|layout| layout.token_count)
        .sum::<usize>();
    if direct_tokens == config.num_vision_tokens_each_frame {
        return direct_layouts;
    }

    let sum_sq: usize = scales.iter().map(|scale| scale * scale).sum();
    let mut counts = Vec::with_capacity(scales.len());
    let mut assigned = 0usize;
    for (index, scale) in scales.iter().copied().enumerate() {
        if index + 1 == scales.len() {
            counts.push(config.num_vision_tokens_each_frame.saturating_sub(assigned));
        } else {
            let count = ((scale * scale) as f64 / sum_sq.max(1) as f64
                * config.num_vision_tokens_each_frame as f64)
                .floor() as usize;
            counts.push(count);
            assigned += count;
        }
    }
    counts
        .into_iter()
        .map(|token_count| AutoGazeScaleTokenLayout {
            token_count,
            grid: square_grid(token_count),
        })
        .collect()
}

pub(super) fn square_grid(token_count: usize) -> usize {
    (token_count.max(1) as f64).sqrt().round().max(1.0) as usize
}