burn_autogaze 0.21.6

AutoGaze inference, fixation traces, and crisp mask visualization for Burn
Documentation
use crate::{FixationPoint, fixation_alpha_mask};
use anyhow::{Result, ensure};
use burn::tensor::backend::Backend;
use burn::tensor::module::{adaptive_avg_pool2d, interpolate};
use burn::tensor::ops::{InterpolateMode, InterpolateOptions};
use burn::tensor::{Int, Tensor, TensorData};
use std::ops::Range;

/// Grid resolution (in cells) of one level of an image pyramid.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ImagePyramidLevel {
    /// Number of cell rows at this level.
    pub height: usize,
    /// Number of cell columns at this level.
    pub width: usize,
}

impl ImagePyramidLevel {
    /// Creates a level with a `height` x `width` cell grid.
    pub const fn new(height: usize, width: usize) -> Self {
        Self { height, width }
    }

    /// Number of tokens this level contributes, i.e. `height * width`.
    pub const fn token_count(&self) -> usize {
        let Self { height, width } = *self;
        height * width
    }

    /// Copy of this level with each dimension raised to at least 1, so the
    /// grid is never empty.
    fn normalized(self) -> Self {
        let rows = self.height.max(1);
        let cols = self.width.max(1);
        Self::new(rows, cols)
    }
}

/// Options controlling which pyramid cells count as active.
///
/// The default threshold is `0.0`, so any cell with positive pooled mask
/// density is active.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct ImagePyramidMaskOptions {
    /// A cell is active when its pooled mask density is strictly greater
    /// than this value.
    pub threshold: f32,
}

/// Pooled coverage of a mask at a single pyramid level, as produced by
/// [`image_pyramid_masks`].
#[derive(Clone, Debug)]
pub struct ImagePyramidMask<B: Backend> {
    /// Grid resolution of this level. `image_pyramid_masks` normalizes it to
    /// at least 1x1.
    pub level: ImagePyramidLevel,
    /// Mean mask value per cell, clamped to `[0, 1]`. Its shape is
    /// `[batch, 1, level.height, level.width]`.
    pub density: Tensor<B, 4>,
    /// `1.0` where `density` is strictly above the options threshold, `0.0`
    /// elsewhere. Same shape as `density`.
    pub active: Tensor<B, 4>,
}

/// Dense tokens for every cell of every pyramid level, as produced by
/// [`tokenize_masked_image_pyramid`].
#[derive(Clone, Debug)]
pub struct ImagePyramidTokens<B: Backend> {
    /// Pooled image features scaled by the cell weight. Shape is
    /// `[batch, total_tokens, channels]`, with levels concatenated in input
    /// order.
    pub tokens: Tensor<B, 3>,
    /// Per-token weight: mask density where the cell is active, else `0.0`.
    /// Shape is `[batch, total_tokens, 1]`.
    pub weights: Tensor<B, 3>,
    /// Normalized (at least 1x1) levels, in the same order as the tokens.
    pub levels: Vec<ImagePyramidLevel>,
    /// For each entry in `levels`, the token-axis range its cells occupy.
    pub level_token_ranges: Vec<Range<usize>>,
}

/// The highest-weight subset of an [`ImagePyramidTokens`], as produced by
/// [`sparsify_image_pyramid_tokens`].
#[derive(Clone, Debug)]
pub struct SparseImagePyramidTokens<B: Backend> {
    /// Selected tokens, shape `[batch, k, channels]`.
    pub tokens: Tensor<B, 3>,
    /// Weights of the selected tokens, shape `[batch, k, 1]`.
    pub weights: Tensor<B, 3>,
    /// Position of each selected token in the dense token axis, shape
    /// `[batch, k]`.
    pub indices: Tensor<B, 2, Int>,
}

/// Rasterizes `points` into a binary mask tensor of shape
/// `[batch, 1, height, width]`.
///
/// The alpha mask from [`fixation_alpha_mask`] is thresholded: any positive
/// alpha becomes `1.0` and everything else becomes `0.0`. Every batch entry
/// receives the same mask.
///
/// # Errors
///
/// Returns an error when `batch`, `height`, or `width` is zero.
pub fn fixation_image_mask_tensor<B: Backend>(
    batch: usize,
    height: usize,
    width: usize,
    points: &[FixationPoint],
    device: &B::Device,
) -> Result<Tensor<B, 4>> {
    ensure!(batch > 0, "mask batch must be nonzero");
    ensure!(height > 0 && width > 0, "mask dimensions must be nonzero");

    let binary: Vec<_> = fixation_alpha_mask(width, height, points, 1.0)
        .into_iter()
        .map(|alpha| if alpha > 0 { 1.0 } else { 0.0 })
        .collect();
    let single = Tensor::<B, 4>::from_data(TensorData::new(binary, [1, 1, height, width]), device);

    // A single-entry batch needs no repetition.
    Ok(match batch {
        1 => single,
        copies => single.repeat_dim(0, copies),
    })
}

/// Builds one binary mask per trace for frame `frame_index` and stacks them
/// as `[traces.len(), 1, height, width]`.
///
/// A trace with no entry at `frame_index` contributes an all-zero mask
/// rather than an error. Positive alpha values from [`fixation_alpha_mask`]
/// map to `1.0`; all others map to `0.0`.
///
/// # Errors
///
/// Returns an error when `traces` is empty or `height`/`width` is zero.
pub fn frame_fixation_masks_tensor<B: Backend>(
    traces: &[crate::FrameFixationTrace],
    frame_index: usize,
    height: usize,
    width: usize,
    device: &B::Device,
) -> Result<Tensor<B, 4>> {
    ensure!(
        !traces.is_empty(),
        "at least one fixation trace is required"
    );
    ensure!(height > 0 && width > 0, "mask dimensions must be nonzero");

    let trace_count = traces.len();
    let mut pixels = Vec::with_capacity(trace_count * height * width);
    pixels.extend(traces.iter().flat_map(|trace| {
        // A missing frame falls back to the empty slice, i.e. no fixations.
        let points = trace
            .frames
            .get(frame_index)
            .map(|set| set.points.as_slice())
            .unwrap_or_default();
        fixation_alpha_mask(width, height, points, 1.0)
            .into_iter()
            .map(|alpha| if alpha > 0 { 1.0 } else { 0.0 })
    }));

    let data = TensorData::new(pixels, [trace_count, 1, height, width]);
    Ok(Tensor::<B, 4>::from_data(data, device))
}

/// Keeps `image` where `mask` is 1 and substitutes `fill_value` where it is 0.
///
/// The mask goes through `single_channel_mask`: channels are averaged,
/// the mask is resized with nearest-neighbour interpolation, broadcast over
/// batch, and clamped to `[0, 1]`. It is then repeated over the image
/// channels. Fractional mask values `m` blend linearly as
/// `image * m + fill_value * (1 - m)`.
///
/// # Errors
///
/// Returns an error if the image has a zero-sized dimension, or if the mask
/// batch is neither 1 nor the image batch.
pub fn apply_image_mask<B: Backend>(
    image: Tensor<B, 4>,
    mask: Tensor<B, 4>,
    fill_value: f32,
) -> Result<Tensor<B, 4>> {
    let dims = image.shape().dims::<4>();
    let [batch, channels, height, width] = dims;
    ensure_image_shape(batch, channels, height, width)?;

    let keep = image_channel_mask(mask, batch, channels, height, width)?;
    let discard = keep.clone().mul_scalar(-1.0).add_scalar(1.0);
    // The background must be built before `image` is consumed by the product below.
    let background = image.zeros_like().add_scalar(fill_value);
    let foreground = image * keep;
    Ok(foreground + background * discard)
}

/// Pools `mask` onto each pyramid level and marks which cells are active.
///
/// The mask is first normalized to one channel at `image_height` x
/// `image_width`. For each level (normalized to at least 1x1), `density` is
/// the adaptive-average-pooled mask clamped to `[0, 1]`. `active` is `1.0`
/// where density is strictly above `options.threshold`. An empty `levels`
/// slice yields an empty vector.
///
/// # Errors
///
/// Returns an error if the image dimensions or the mask batch are zero, or if
/// the mask cannot be normalized (see `single_channel_mask`).
pub fn image_pyramid_masks<B: Backend>(
    mask: Tensor<B, 4>,
    image_height: usize,
    image_width: usize,
    levels: &[ImagePyramidLevel],
    options: ImagePyramidMaskOptions,
) -> Result<Vec<ImagePyramidMask<B>>> {
    ensure!(
        image_height > 0 && image_width > 0,
        "image dimensions must be nonzero"
    );
    let mask_batch = mask.shape().dims::<4>()[0];
    ensure!(mask_batch > 0, "mask batch must be nonzero");
    let full_res = single_channel_mask(mask, mask_batch, image_height, image_width)?;

    let mut pyramid = Vec::with_capacity(levels.len());
    for &raw in levels {
        let level = raw.normalized();
        let cells = [level.height, level.width];
        let density = adaptive_avg_pool2d(full_res.clone(), cells).clamp(0.0, 1.0);
        let active = density.clone().greater_elem(options.threshold).float();
        pyramid.push(ImagePyramidMask {
            level,
            density,
            active,
        });
    }
    Ok(pyramid)
}

/// Converts a masked image into one token per cell across all pyramid levels.
///
/// For each level (normalized to at least 1x1), both the mask and the image
/// are adaptive-average-pooled to the level grid. A cell's weight is its
/// clamped mask density when that density is strictly above
/// `options.threshold`, and `0.0` otherwise. Its token is the pooled image
/// feature multiplied by that weight. Levels are concatenated in input order
/// along the token axis, and `level_token_ranges` records each level's span.
///
/// # Errors
///
/// Returns an error if `levels` is empty, the image has a zero-sized
/// dimension, or the mask cannot be normalized to the image
/// (see `single_channel_mask`).
pub fn tokenize_masked_image_pyramid<B: Backend>(
    image: Tensor<B, 4>,
    mask: Tensor<B, 4>,
    levels: &[ImagePyramidLevel],
    options: ImagePyramidMaskOptions,
) -> Result<ImagePyramidTokens<B>> {
    ensure!(
        !levels.is_empty(),
        "image pyramid must contain at least one level"
    );
    let [batch, channels, height, width] = image.shape().dims::<4>();
    ensure_image_shape(batch, channels, height, width)?;
    let mask = single_channel_mask(mask, batch, height, width)?;

    let normalized_levels: Vec<ImagePyramidLevel> =
        levels.iter().map(|level| level.normalized()).collect();
    let level_count = normalized_levels.len();
    let mut token_tensors = Vec::with_capacity(level_count);
    let mut weight_tensors = Vec::with_capacity(level_count);
    let mut level_token_ranges = Vec::with_capacity(level_count);
    let mut start = 0usize;

    for &level in &normalized_levels {
        let cells = [level.height, level.width];
        let count = level.token_count();

        let density = adaptive_avg_pool2d(mask.clone(), cells).clamp(0.0, 1.0);
        let keep = density.clone().greater_elem(options.threshold).float();
        let weights = density * keep;
        // Broadcast the single-channel weights over every image channel.
        let scaled = adaptive_avg_pool2d(image.clone(), cells)
            * weights.clone().repeat_dim(1, channels);

        // [batch, C, h, w] -> [batch, h*w, C]: one row per cell.
        token_tensors.push(scaled.reshape([batch, channels, count]).swap_dims(1, 2));
        weight_tensors.push(weights.reshape([batch, 1, count]).swap_dims(1, 2));
        level_token_ranges.push(start..start + count);
        start += count;
    }

    Ok(ImagePyramidTokens {
        tokens: Tensor::cat(token_tensors, 1),
        weights: Tensor::cat(weight_tensors, 1),
        levels: normalized_levels,
        level_token_ranges,
    })
}

/// Keeps the `max_tokens` highest-weight tokens of each batch entry.
///
/// Selection uses `topk` over `tokens.weights`, so the result is ordered by
/// descending weight rather than by pyramid position. `max_tokens` is clamped
/// to `1..=total_tokens`, so a request for zero tokens still yields one.
/// `indices` holds each kept token's position in the dense token axis. It can
/// be mapped back to a level through the input's `level_token_ranges`.
///
/// # Errors
///
/// Returns an error when there are no tokens, the token channel count is
/// zero, or `weights` is not shaped `[batch, total_tokens, 1]` to match
/// `tokens`.
pub fn sparsify_image_pyramid_tokens<B: Backend>(
    tokens: ImagePyramidTokens<B>,
    max_tokens: usize,
) -> Result<SparseImagePyramidTokens<B>> {
    let [batch, total_tokens, channels] = tokens.tokens.shape().dims::<3>();
    ensure!(total_tokens > 0, "image pyramid tokens must be nonempty");
    ensure!(channels > 0, "image pyramid token channels must be nonzero");
    // The fields are public, so `weights` may not match `tokens`. Without this
    // check, `squeeze_dim`/`topk`/`gather` below would fail by panicking
    // instead of returning an error.
    let weight_dims = tokens.weights.shape().dims::<3>();
    ensure!(
        weight_dims == [batch, total_tokens, 1],
        "image pyramid weights must have shape [{}, {}, 1] to match tokens, got {:?}",
        batch,
        total_tokens,
        weight_dims
    );

    let k = max_tokens.max(1).min(total_tokens);
    let scores = tokens.weights.clone().squeeze_dim::<2>(2);
    let (_values, indices) = scores.topk_with_indices(k, 1);
    // Token gather needs the indices repeated across the channel axis.
    let token_indices = indices
        .clone()
        .unsqueeze_dim::<3>(2)
        .repeat_dim(2, channels);
    let weight_indices = indices.clone().unsqueeze_dim::<3>(2);
    Ok(SparseImagePyramidTokens {
        tokens: tokens.tokens.gather(1, token_indices),
        weights: tokens.weights.gather(1, weight_indices),
        indices,
    })
}

/// Validates that every image dimension is nonzero.
///
/// Checks batch, then channels, then the spatial dimensions, and reports the
/// first failure.
fn ensure_image_shape(batch: usize, channels: usize, height: usize, width: usize) -> Result<()> {
    ensure!(batch != 0, "image batch must be nonzero");
    ensure!(channels != 0, "image channels must be nonzero");
    let spatial_ok = height != 0 && width != 0;
    ensure!(spatial_ok, "image dimensions must be nonzero");
    Ok(())
}

/// Normalizes `mask` to a single-channel `[batch, 1, height, width]` tensor
/// clamped to `[0, 1]`.
///
/// Multi-channel masks are averaged over channels. Masks with a different
/// spatial size are resized with nearest-neighbour interpolation. A batch-1
/// mask is repeated to `batch`.
///
/// # Errors
///
/// Returns an error if the mask batch is neither 1 nor `batch`, or if the
/// mask has zero channels or a zero-sized spatial dimension.
fn single_channel_mask<B: Backend>(
    mask: Tensor<B, 4>,
    batch: usize,
    height: usize,
    width: usize,
) -> Result<Tensor<B, 4>> {
    let [mask_batch, mask_channels, mask_height, mask_width] = mask.shape().dims::<4>();
    ensure!(
        mask_batch == 1 || mask_batch == batch,
        "mask batch must be 1 or match image batch"
    );
    ensure!(mask_channels > 0, "mask channels must be nonzero");
    // An empty mask has nothing to interpolate from. Reject it here rather
    // than let the resize/pooling steps fail.
    ensure!(
        mask_height > 0 && mask_width > 0,
        "mask dimensions must be nonzero"
    );
    let mask = if mask_channels == 1 {
        mask
    } else {
        mask.mean_dim(1)
    };
    let mask = if mask_height == height && mask_width == width {
        mask
    } else {
        interpolate(
            mask,
            [height, width],
            InterpolateOptions::new(InterpolateMode::Nearest),
        )
    };
    // Repeat after resizing so the interpolation runs on one batch entry only.
    let mask = if mask_batch == 1 && batch > 1 {
        mask.repeat_dim(0, batch)
    } else {
        mask
    };
    Ok(mask.clamp(0.0, 1.0))
}

/// Normalizes `mask` with `single_channel_mask`, then repeats it across
/// `channels`, producing a `[batch, channels, height, width]` tensor.
fn image_channel_mask<B: Backend>(
    mask: Tensor<B, 4>,
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
) -> Result<Tensor<B, 4>> {
    let single = single_channel_mask(mask, batch, height, width)?;
    let expanded = match channels {
        1 => single,
        n => single.repeat_dim(1, n),
    };
    Ok(expanded)
}

// Unit tests. They are compiled only when the `ndarray` feature is enabled,
// because every test runs on the NdArray backend.
#[cfg(all(test, feature = "ndarray"))]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    // CPU backend with f32 elements, shared by every test below.
    type B = NdArray<f32>;

    /// A single fixation rasterized on a 4x4 grid lights exactly the
    /// top-left 2x2 cells.
    #[test]
    fn fixation_mask_tensor_matches_crisp_cell_bounds() {
        let device = Default::default();
        // NOTE(review): presumably (center_x, center_y, extent_w, extent_h, weight)
        // in normalized coordinates; confirm against `FixationPoint::with_extent`.
        let point = FixationPoint::with_extent(0.25, 0.25, 0.5, 0.5, 1.0);

        let mask = fixation_image_mask_tensor::<B>(1, 4, 4, &[point], &device).expect("mask");
        let values = mask.into_data().to_vec::<f32>().expect("f32 mask");

        // Row-major layout: index = y * width + x.
        for y in 0..4 {
            for x in 0..4 {
                let expected = if x < 2 && y < 2 { 1.0 } else { 0.0 };
                assert_eq!(values[y * 4 + x], expected, "{x},{y}");
            }
        }
    }

    /// Each trace becomes its own batch entry. Trace 0 should light only the
    /// left pixel of a 1x2 frame, and trace 1 only the right pixel.
    #[test]
    fn frame_fixation_masks_tensor_keeps_batches_separate() {
        let device = Default::default();
        let left = FixationPoint::with_extent(0.25, 0.5, 0.5, 1.0, 1.0);
        let right = FixationPoint::with_extent(0.75, 0.5, 0.5, 1.0, 1.0);
        let traces = vec![
            crate::FrameFixationTrace::new(vec![crate::FixationSet::new(vec![left], 1.0, 1)]),
            crate::FrameFixationTrace::new(vec![crate::FixationSet::new(vec![right], 1.0, 1)]),
        ];

        let mask = frame_fixation_masks_tensor::<B>(&traces, 0, 1, 2, &device).expect("mask");
        let values = mask.into_data().to_vec::<f32>().expect("mask values");

        // Layout [2, 1, 1, 2]: batch 0 = [1, 0], batch 1 = [0, 1].
        assert_eq!(values, vec![1.0, 0.0, 0.0, 1.0]);
    }

    /// The mask keeps the diagonal pixels; the off-diagonal pixels take the
    /// fill value -1.
    #[test]
    fn apply_image_mask_preserves_selected_pixels_and_fills_the_rest() {
        let device = Default::default();
        let image = Tensor::<B, 4>::from_data(
            TensorData::new(vec![1.0, 2.0, 3.0, 4.0], [1, 1, 2, 2]),
            &device,
        );
        let mask = Tensor::<B, 4>::from_data(
            TensorData::new(vec![1.0, 0.0, 0.0, 1.0], [1, 1, 2, 2]),
            &device,
        );

        let masked = apply_image_mask(image, mask, -1.0).expect("masked image");
        let values = masked.into_data().to_vec::<f32>().expect("f32 image");

        assert_eq!(values, vec![1.0, -1.0, -1.0, 4.0]);
    }

    /// Two-level pyramid (1x1 then 2x2) over a mask covering only the top-left
    /// pixel.
    ///
    /// - 1x1 level: density = mean(mask) = 0.25, and the token is
    ///   mean(image) * density = 2.5 * 0.25 = 0.625.
    /// - 2x2 level: density equals the mask, so only cell 0 (weight 1.0,
    ///   token 1.0 * 1.0) survives.
    #[test]
    fn tokenizes_masked_image_pyramid_with_density_weights() {
        let device = Default::default();
        let image = Tensor::<B, 4>::from_data(
            TensorData::new(vec![1.0, 2.0, 3.0, 4.0], [1, 1, 2, 2]),
            &device,
        );
        let mask = Tensor::<B, 4>::from_data(
            TensorData::new(vec![1.0, 0.0, 0.0, 0.0], [1, 1, 2, 2]),
            &device,
        );

        let tokens = tokenize_masked_image_pyramid(
            image,
            mask,
            &[ImagePyramidLevel::new(1, 1), ImagePyramidLevel::new(2, 2)],
            ImagePyramidMaskOptions::default(),
        )
        .expect("tokens");
        let token_values = tokens
            .tokens
            .clone()
            .into_data()
            .to_vec::<f32>()
            .expect("token values");
        let weight_values = tokens
            .weights
            .clone()
            .into_data()
            .to_vec::<f32>()
            .expect("weight values");

        assert_eq!(tokens.level_token_ranges, vec![0..1, 1..5]);
        assert_eq!(weight_values, vec![0.25, 1.0, 0.0, 0.0, 0.0]);
        assert_eq!(token_values, vec![0.625, 1.0, 0.0, 0.0, 0.0]);
    }

    /// On a 2x2 level the weights equal the mask [0.1, 0.9, 0.2, 0.7].
    /// Keeping the top 2 returns 0.9 then 0.7, in descending order.
    #[test]
    fn sparsifies_image_pyramid_tokens_by_highest_mask_density() {
        let device = Default::default();
        let image = Tensor::<B, 4>::from_data(
            TensorData::new(vec![1.0, 2.0, 3.0, 4.0], [1, 1, 2, 2]),
            &device,
        );
        let mask = Tensor::<B, 4>::from_data(
            TensorData::new(vec![0.1, 0.9, 0.2, 0.7], [1, 1, 2, 2]),
            &device,
        );
        let tokens = tokenize_masked_image_pyramid(
            image,
            mask,
            &[ImagePyramidLevel::new(2, 2)],
            ImagePyramidMaskOptions::default(),
        )
        .expect("tokens");

        let sparse = sparsify_image_pyramid_tokens(tokens, 2).expect("sparse tokens");
        let weights = sparse.weights.into_data().to_vec::<f32>().expect("weights");

        assert_eq!(weights, vec![0.9, 0.7]);
    }
}