burn_autogaze 0.21.6

AutoGaze inference, fixation traces, and crisp mask visualization for Burn
Documentation
use super::layout::{
    AutoGazeScaleTokenLayout, scale_token_index, scale_token_layouts, token_to_fixation_point,
};
use super::{AutoGazeGenerateOutput, AutoGazeScaleTokenMask, GeneratedFrameFixations};
use crate::config::AutoGazeConfig;
use crate::{FixationPoint, FixationSet, FrameFixationTrace};

/// Builds one [`FrameFixationTrace`] per batch entry of `generated`.
///
/// For every batch index in `generated.gazing_pos`, the per-frame fixations
/// are decoded with [`generated_to_frame_fixations`] and each frame's points
/// and stop probability are handed to `FixationSet::with_min_len` together
/// with `k`.
pub(super) fn generated_to_traces(
    generated: &AutoGazeGenerateOutput,
    config: &AutoGazeConfig,
    k: usize,
) -> Vec<FrameFixationTrace> {
    (0..generated.gazing_pos.len())
        .map(|batch_idx| {
            let fixation_sets = generated_to_frame_fixations(generated, config, batch_idx)
                .into_iter()
                .map(|decoded| {
                    FixationSet::with_min_len(decoded.points, decoded.stop_probability, k)
                })
                .collect();
            FrameFixationTrace::new(fixation_sets)
        })
        .collect()
}

/// Decodes the fixations of every frame for a single batch entry.
///
/// Frames are laid out back to back in the flat token sequence; the length of
/// each frame's slice comes from `generated.num_gazing_each_frame`, and a
/// running start offset is advanced by that length after each frame.
///
/// Returns one [`GeneratedFrameFixations`] per entry of
/// `generated.num_gazing_each_frame`, in frame order.
pub(crate) fn generated_to_frame_fixations(
    generated: &AutoGazeGenerateOutput,
    config: &AutoGazeConfig,
    batch_idx: usize,
) -> Vec<GeneratedFrameFixations> {
    let layouts = scale_token_layouts(config);
    // Start of the current frame's slice within the flat per-batch sequence.
    let mut frame_start = 0usize;
    generated
        .num_gazing_each_frame
        .iter()
        .enumerate()
        .map(|(frame_idx, &frame_len)| {
            let decoded = generated_frame_fixations_from_layouts(
                generated,
                config,
                batch_idx,
                frame_idx,
                frame_start,
                frame_len,
                &layouts,
            );
            frame_start += frame_len;
            decoded
        })
        .collect()
}

/// Decodes the fixations of one frame of one batch entry.
///
/// Returns `None` when `batch_idx` is not a valid index into
/// `generated.gazing_pos` or `frame_idx` is not a valid index into
/// `generated.num_gazing_each_frame`. Otherwise the frame's start offset is
/// the sum of the lengths of all preceding frames.
pub(crate) fn generated_frame_fixations(
    generated: &AutoGazeGenerateOutput,
    config: &AutoGazeConfig,
    batch_idx: usize,
    frame_idx: usize,
) -> Option<GeneratedFrameFixations> {
    let frame_lengths = &generated.num_gazing_each_frame;
    let batch_in_range = batch_idx < generated.gazing_pos.len();
    let frame_in_range = frame_idx < frame_lengths.len();
    if !(batch_in_range && frame_in_range) {
        return None;
    }

    // Tokens of earlier frames precede this frame in the flat sequence.
    let frame_start: usize = frame_lengths.iter().take(frame_idx).sum();
    let frame_len = frame_lengths[frame_idx];
    let layouts = scale_token_layouts(config);
    let decoded = generated_frame_fixations_from_layouts(
        generated,
        config,
        batch_idx,
        frame_idx,
        frame_start,
        frame_len,
        &layouts,
    );
    Some(decoded)
}

/// Decodes the `frame_len` token slots starting at `cursor` for one frame.
///
/// For each slot:
/// - a slot with no token (missing batch row or index past the row's end) is
///   skipped;
/// - a slot flagged as padded, or with no padding flag available, sets the
///   frame's stop probability to `1.0` and is skipped;
/// - otherwise the token is made frame-relative by subtracting
///   `frame_idx * config.num_vision_tokens_each_frame`; negative results are
///   skipped, and the rest are converted with `token_to_fixation_point`
///   using the slot's confidence (`1.0` when none is recorded).
///
/// The stop probability stays `0.0` when no slot was treated as padded.
fn generated_frame_fixations_from_layouts(
    generated: &AutoGazeGenerateOutput,
    config: &AutoGazeConfig,
    batch_idx: usize,
    frame_idx: usize,
    cursor: usize,
    frame_len: usize,
    scale_layouts: &[AutoGazeScaleTokenLayout],
) -> GeneratedFrameFixations {
    let batch_tokens = generated.gazing_pos.get(batch_idx);
    let batch_padding = generated.if_padded_gazing.get(batch_idx);
    let batch_confidences = generated.confidences.get(batch_idx);

    let mut stop_probability = 0.0f32;
    let mut points = Vec::new();

    for global_idx in (0..frame_len).map(|offset| cursor + offset) {
        let raw_token = match batch_tokens.and_then(|row| row.get(global_idx)) {
            Some(&value) => value,
            None => continue,
        };

        // A slot without a padding flag is treated the same as a padded one.
        let is_padded = match batch_padding.and_then(|row| row.get(global_idx)) {
            Some(&flag) => flag,
            None => true,
        };
        if is_padded {
            stop_probability = 1.0;
            continue;
        }

        // Shift the token so it is relative to the start of this frame.
        let frame_offset = (frame_idx * config.num_vision_tokens_each_frame) as i64;
        let relative_token = raw_token - frame_offset;
        if relative_token < 0 {
            continue;
        }

        let confidence = batch_confidences
            .and_then(|row| row.get(global_idx))
            .map_or(1.0, |&value| value);
        // `extend` over an `Option` appends the point only when one was produced.
        points.extend(token_to_fixation_point(
            relative_token as usize,
            scale_layouts,
            confidence,
        ));
    }

    GeneratedFrameFixations {
        points,
        stop_probability,
    }
}

/// Collects the decoded fixation points for every batch entry and frame.
///
/// The result is indexed as `[batch][frame][point]`; the per-frame stop
/// probability produced by [`generated_to_frame_fixations`] is discarded.
pub(crate) fn generated_to_frame_points(
    generated: &AutoGazeGenerateOutput,
    config: &AutoGazeConfig,
) -> Vec<Vec<Vec<FixationPoint>>> {
    let batch_count = generated.gazing_pos.len();
    let mut per_batch = Vec::with_capacity(batch_count);
    for batch_idx in 0..batch_count {
        let per_frame: Vec<Vec<FixationPoint>> =
            generated_to_frame_fixations(generated, config, batch_idx)
                .into_iter()
                .map(|decoded| decoded.points)
                .collect();
        per_batch.push(per_frame);
    }
    per_batch
}

/// Builds, for every batch entry, one [`AutoGazeScaleTokenMask`] per scale
/// layout marking which tokens were generated in each frame.
///
/// Every mask starts with all entries `false` for each frame listed in
/// `generated.num_gazing_each_frame`. Token slots are then walked frame by
/// frame: slots with no token, slots flagged as padded (or lacking a padding
/// flag), and tokens that become negative after subtracting the frame offset
/// are skipped. Remaining tokens are resolved with `scale_token_index`, and
/// the matching `(scale, frame, local token)` entry is set to `true`.
pub(super) fn generated_scale_token_masks(
    generated: &AutoGazeGenerateOutput,
    config: &AutoGazeConfig,
) -> Vec<Vec<AutoGazeScaleTokenMask>> {
    let layouts = scale_token_layouts(config);
    let frame_count = generated.num_gazing_each_frame.len();

    (0..generated.gazing_pos.len())
        .map(|batch_idx| {
            // One all-false mask per scale, with a row for every frame.
            let mut masks: Vec<_> = layouts
                .iter()
                .map(|layout| AutoGazeScaleTokenMask {
                    grid: layout.grid,
                    token_count: layout.token_count,
                    frames: vec![vec![false; layout.token_count]; frame_count],
                })
                .collect();

            let batch_tokens = generated.gazing_pos.get(batch_idx);
            let batch_padding = generated.if_padded_gazing.get(batch_idx);

            // Start of the current frame's slice within the flat sequence.
            let mut frame_start = 0usize;
            for (frame_idx, &frame_len) in generated.num_gazing_each_frame.iter().enumerate() {
                for global_idx in (0..frame_len).map(|offset| frame_start + offset) {
                    let raw_token = match batch_tokens.and_then(|row| row.get(global_idx)) {
                        Some(&value) => value,
                        None => continue,
                    };

                    // A slot without a padding flag counts as padded.
                    let is_padded = match batch_padding.and_then(|row| row.get(global_idx)) {
                        Some(&flag) => flag,
                        None => true,
                    };
                    if is_padded {
                        continue;
                    }

                    // Shift the token so it is relative to the start of this frame.
                    let frame_offset =
                        (frame_idx * config.num_vision_tokens_each_frame) as i64;
                    let relative_token = raw_token - frame_offset;
                    if relative_token < 0 {
                        continue;
                    }

                    if let Some((scale_idx, local_token)) =
                        scale_token_index(relative_token as usize, &layouts)
                    {
                        masks[scale_idx].frames[frame_idx][local_token] = true;
                    }
                }
                frame_start += frame_len;
            }

            masks
        })
        .collect()
}