burn_autogaze 0.21.6

AutoGaze inference, fixation traces, and crisp mask visualization for Burn
Documentation
use crate::{FixationPoint, FixationSet, FrameFixationTrace};
use anyhow::{Result, anyhow, ensure};
use half::f16;
use safetensors::{SafeTensors, tensor::TensorView};
use std::fs;
use std::path::Path;

/// In-memory store of fixation traces decoded from safetensors data.
///
/// Holds one [`FrameFixationTrace`] per clip and, when the source provided
/// them, the flattened per-frame visibility maps.
#[derive(Clone, Debug)]
pub struct AutoGazeTraceStore {
    /// One trace per clip, in the order of the first axis of the `fixations` tensor.
    traces: Vec<FrameFixationTrace>,
    /// Number of frames per clip (second axis of the `fixations` tensor).
    clip_len: usize,
    /// Number of fixation points per frame (third axis of the `fixations` tensor).
    k: usize,
    /// Flattened values of the optional `visibility_maps` tensor, laid out as
    /// `[clips, frames, height, width]`; `None` when that tensor was absent.
    visibility_maps: Option<Vec<f32>>,
    /// Height of a single visibility map; `0` when no maps were loaded.
    visibility_height: usize,
    /// Width of a single visibility map; `0` when no maps were loaded.
    visibility_width: usize,
}

impl AutoGazeTraceStore {
    /// Reads the safetensors file at `path` and builds a trace store from it.
    ///
    /// # Errors
    ///
    /// Fails when the file cannot be read, when its bytes do not deserialize
    /// as safetensors, or when [`Self::from_safetensors`] rejects the tensors.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
        let raw = fs::read(path)?;
        let parsed = SafeTensors::deserialize(&raw)?;
        Self::from_safetensors(&parsed)
    }

    /// Builds a trace store from already-deserialized safetensors data.
    ///
    /// Required tensors: `fixations` with shape `[clips, frames, k, 2]`,
    /// plus `scales`, `confidences` and `stop_probabilities`. The optional
    /// `visibility_maps` tensor must have shape
    /// `[clips, frames, height, width]` when present.
    ///
    /// Only `fixations` and `visibility_maps` are checked dimension by
    /// dimension; the other three tensors are checked by element count.
    ///
    /// # Errors
    ///
    /// Fails when a required tensor is missing, when a tensor uses a dtype
    /// that `tensor_to_f32` does not decode, or when a shape or element count
    /// is inconsistent with the `fixations` tensor.
    pub fn from_safetensors(tensors: &SafeTensors<'_>) -> Result<Self> {
        // All four required tensors are looked up before any validation runs,
        // so a missing tensor is reported ahead of a malformed one.
        let required = |name: &str, missing: &'static str| {
            tensors.tensor(name).map_err(|_| anyhow!(missing))
        };
        let fixations = required("fixations", "missing fixations tensor")?;
        let scales = required("scales", "missing scales tensor")?;
        let confidences = required("confidences", "missing confidences tensor")?;
        let stops = required("stop_probabilities", "missing stop_probabilities tensor")?;

        // Rank must be 4 and the trailing axis must hold exactly (x, y).
        let &[clips, clip_len, k, 2] = fixations.shape() else {
            return Err(anyhow!(
                "fixations tensor must have shape [clips, frames, k, 2]"
            ));
        };

        let fixation_values = tensor_to_f32(&fixations)?;
        let scale_values = tensor_to_f32(&scales)?;
        let confidence_values = tensor_to_f32(&confidences)?;
        let stop_values = tensor_to_f32(&stops)?;

        let Some(scalar_count) = clips
            .checked_mul(clip_len)
            .and_then(|frames| frames.checked_mul(k))
        else {
            return Err(anyhow!("trace store shape overflow"));
        };
        ensure!(
            scalar_count * 2 == fixation_values.len(),
            "fixations tensor data length does not match shape"
        );
        ensure!(
            scalar_count == scale_values.len(),
            "scales tensor must have shape [clips, frames, k]"
        );
        ensure!(
            scalar_count == confidence_values.len(),
            "confidences tensor must have shape [clips, frames, k]"
        );
        ensure!(
            clips * clip_len == stop_values.len(),
            "stop_probabilities tensor must have shape [clips, frames]"
        );

        // A failed lookup of the optional tensor is treated as "not provided".
        let (visibility_maps, visibility_height, visibility_width) =
            match tensors.tensor("visibility_maps") {
                Ok(view) => match view.shape() {
                    &[map_clips, map_frames, height, width]
                        if map_clips == clips && map_frames == clip_len =>
                    {
                        (Some(tensor_to_f32(&view)?), height, width)
                    }
                    _ => {
                        return Err(anyhow!(
                            "visibility_maps tensor must have shape [clips, frames, height, width]"
                        ));
                    }
                },
                Err(_) => (None, 0, 0),
            };

        // Walk the row-major buffers clip by clip, frame by frame, point by point.
        let traces: Vec<FrameFixationTrace> = (0..clips)
            .map(|clip_idx| {
                let frames: Vec<_> = (0..clip_len)
                    .map(|frame_idx| {
                        let frame_slot = clip_idx * clip_len + frame_idx;
                        let points: Vec<_> = (0..k)
                            .map(|point_idx| {
                                let scalar_idx = frame_slot * k + point_idx;
                                // Each point owns two consecutive fixation values.
                                let xy = scalar_idx * 2;
                                FixationPoint::new(
                                    fixation_values[xy],
                                    fixation_values[xy + 1],
                                    scale_values[scalar_idx],
                                    confidence_values[scalar_idx],
                                )
                            })
                            .collect();
                        FixationSet::new(points, stop_values[frame_slot], k)
                    })
                    .collect();
                FrameFixationTrace::new(frames)
            })
            .collect();

        Ok(Self {
            traces,
            clip_len,
            k,
            visibility_maps,
            visibility_height,
            visibility_width,
        })
    }

    /// Returns the trace of the clip at `index`, or `None` when out of range.
    pub fn trace(&self, index: usize) -> Option<&FrameFixationTrace> {
        self.traces.get(index)
    }

    /// Number of clips (traces) held by the store.
    pub fn len(&self) -> usize {
        self.traces.len()
    }

    /// Whether the store holds no traces at all.
    pub fn is_empty(&self) -> bool {
        self.traces.is_empty()
    }

    /// Number of frames in each clip.
    pub fn clip_len(&self) -> usize {
        self.clip_len
    }

    /// Number of fixation points stored per frame.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Returns `(height, width)` of the visibility maps, or `None` when the
    /// store was loaded without them.
    pub fn visibility_shape(&self) -> Option<(usize, usize)> {
        if self.visibility_maps.is_some() {
            Some((self.visibility_height, self.visibility_width))
        } else {
            None
        }
    }

    /// Returns the flattened visibility map of one frame.
    ///
    /// Yields `None` when no maps were loaded, when `clip_idx` or `frame_idx`
    /// is out of range, or when the requested span lies outside the buffer.
    pub fn visibility_map(&self, clip_idx: usize, frame_idx: usize) -> Option<&[f32]> {
        let maps = self.visibility_maps.as_deref()?;
        let in_range = clip_idx < self.traces.len() && frame_idx < self.clip_len;
        if !in_range {
            return None;
        }
        let frame_area = self.visibility_height * self.visibility_width;
        let start = (clip_idx * self.clip_len + frame_idx) * frame_area;
        maps.get(start..start + frame_area)
    }
}

/// Decodes a little-endian `F32` or `F16` tensor view into a `Vec<f32>`.
///
/// # Errors
///
/// Fails when the element count implied by the shape overflows `usize`, when
/// the byte length disagrees with the shape, or when the dtype is neither
/// `F32` nor `F16`.
fn tensor_to_f32(view: &TensorView<'_>) -> Result<Vec<f32>> {
    // Product of all dimensions, bailing out on the first overflow.
    let mut element_count = 1usize;
    for &dim in view.shape() {
        element_count = element_count
            .checked_mul(dim)
            .ok_or_else(|| anyhow!("safetensors trace store shape overflow"))?;
    }

    let bytes = view.data();
    match view.dtype() {
        safetensors::Dtype::F32 => {
            let Some(expected_bytes) = element_count.checked_mul(4) else {
                return Err(anyhow!("F32 tensor byte length overflow"));
            };
            ensure!(
                bytes.len() == expected_bytes,
                "F32 tensor byte length does not match shape"
            );
            let values = bytes
                .chunks_exact(4)
                .map(|quad| {
                    let mut word = [0u8; 4];
                    word.copy_from_slice(quad);
                    f32::from_le_bytes(word)
                })
                .collect();
            Ok(values)
        }
        safetensors::Dtype::F16 => {
            let Some(expected_bytes) = element_count.checked_mul(2) else {
                return Err(anyhow!("F16 tensor byte length overflow"));
            };
            ensure!(
                bytes.len() == expected_bytes,
                "F16 tensor byte length does not match shape"
            );
            let values = bytes
                .chunks_exact(2)
                .map(|pair| {
                    let mut half_word = [0u8; 2];
                    half_word.copy_from_slice(pair);
                    // Widen each half-precision value to single precision.
                    f16::from_bits(u16::from_le_bytes(half_word)).to_f32()
                })
                .collect();
            Ok(values)
        }
        other => Err(anyhow!(
            "unsupported dtype in safetensors trace store: {other:?}"
        )),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use safetensors::tensor::{Dtype, View, serialize_to_file};
    use tempfile::NamedTempFile;

    /// Minimal owned tensor implementing [`View`] so fixtures can be serialized.
    #[derive(Clone)]
    struct OwnedTensor {
        shape: Vec<usize>,
        data: Vec<u8>,
        dtype: Dtype,
    }

    impl View for OwnedTensor {
        fn dtype(&self) -> Dtype {
            self.dtype
        }

        fn shape(&self) -> &[usize] {
            self.shape.as_slice()
        }

        fn data(&self) -> std::borrow::Cow<'_, [u8]> {
            std::borrow::Cow::Borrowed(self.data.as_slice())
        }

        fn data_len(&self) -> usize {
            self.data.len()
        }
    }

    /// Builds an `F32` tensor whose payload is `values` in little-endian order.
    fn tensor_f32(shape: &[usize], values: &[f32]) -> OwnedTensor {
        OwnedTensor {
            shape: shape.to_vec(),
            data: values
                .iter()
                .flat_map(|value| value.to_le_bytes())
                .collect(),
            dtype: Dtype::F32,
        }
    }

    /// Pairs a tensor name with an `F32` tensor, ready for serialization.
    fn named(name: &str, shape: &[usize], values: &[f32]) -> (String, OwnedTensor) {
        (name.to_string(), tensor_f32(shape, values))
    }

    /// Serializes `entries` into a fresh temporary safetensors file.
    fn write_store(entries: Vec<(String, OwnedTensor)>) -> NamedTempFile {
        let file = NamedTempFile::new().expect("tempfile");
        serialize_to_file(entries, None, file.path()).expect("write safetensors");
        file
    }

    #[test]
    fn loads_trace_store_from_safetensors() {
        let file = write_store(vec![
            named("fixations", &[1, 2, 1, 2], &[0.25, 0.75, 0.5, 0.25]),
            named("scales", &[1, 2, 1], &[0.2, 0.3]),
            named("confidences", &[1, 2, 1], &[0.9, 0.8]),
            named("stop_probabilities", &[1, 2], &[0.1, 0.2]),
            named(
                "visibility_maps",
                &[1, 2, 2, 2],
                &[1.0, 0.0, 0.0, 1.0, 0.2, 0.4, 0.6, 0.8],
            ),
        ]);

        let store = AutoGazeTraceStore::from_file(file.path()).expect("load store");

        assert_eq!(store.len(), 1);
        assert_eq!(store.clip_len(), 2);
        assert_eq!(store.k(), 1);

        let first_clip = store.trace(0).expect("trace");
        assert_eq!(first_clip.frames[0].points[0].x, 0.25);
        assert_eq!(first_clip.frames[1].points[0].y, 0.25);

        assert_eq!(store.visibility_shape(), Some((2, 2)));
        let second_frame_map = store.visibility_map(0, 1).expect("visibility");
        assert_eq!(second_frame_map, &[0.2, 0.4, 0.6, 0.8]);
    }

    #[test]
    fn rejects_trace_store_tensor_shape_mismatches() {
        // `scales` carries one value where the fixations shape implies two.
        let file = write_store(vec![
            named("fixations", &[1, 2, 1, 2], &[0.25, 0.75, 0.5, 0.25]),
            named("scales", &[1, 1, 1], &[0.2]),
            named("confidences", &[1, 2, 1], &[0.9, 0.8]),
            named("stop_probabilities", &[1, 2], &[0.1, 0.2]),
        ]);

        let err = AutoGazeTraceStore::from_file(file.path()).expect_err("shape mismatch");
        let message = err.to_string();

        assert!(
            message.contains("scales tensor must have shape"),
            "unexpected error: {err:#}"
        );
    }
}