//! # burn_autogaze 0.21.6
//!
//! AutoGaze inference, fixation traces, and crisp mask visualization for Burn.
//! (Documentation header; this module holds the AutoGaze JSON config types.)
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

/// Top-level AutoGaze model configuration, mirroring the JSON layout read by
/// [`AutoGazeConfig::from_json_file`].
///
/// `#[serde(default)]` makes every field optional in the JSON: anything
/// missing falls back to the corresponding [`Default`] value below.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct AutoGazeConfig {
    /// Model family identifier (defaults to `"autogaze"`).
    pub model_type: String,
    /// Attention backend selector (defaults to `"sdpa"`).
    pub attn_mode: String,
    /// Plus-separated image scales, e.g. `"32+64+112+224"`; parsed by
    /// [`AutoGazeConfig::scale_values`].
    pub scales: String,
    pub num_vision_tokens_each_frame: usize,
    pub max_num_frames: usize,
    pub use_flash_attn: bool,
    pub has_task_loss_requirement_during_training: bool,
    /// Gates whether [`AutoGazeConfig::inference_task_loss_requirement`]
    /// returns a value at all.
    pub has_task_loss_requirement_during_inference: bool,
    /// Raw JSON sub-config; read by [`AutoGazeConfig::inference_gazing_ratio`].
    pub gazing_ratio_config: serde_json::Value,
    /// Raw JSON sub-config, kept uninterpreted in this file.
    pub gazing_ratio_each_frame_config: serde_json::Value,
    /// Raw JSON sub-config; read by
    /// [`AutoGazeConfig::inference_task_loss_requirement`].
    pub task_loss_requirement_config: serde_json::Value,
    /// Nested configuration of the gaze model itself.
    pub gaze_model_config: GazeModelConfig,
}

impl Default for AutoGazeConfig {
    fn default() -> Self {
        Self {
            model_type: "autogaze".to_string(),
            attn_mode: "sdpa".to_string(),
            scales: "224".to_string(),
            num_vision_tokens_each_frame: 196,
            max_num_frames: 16,
            use_flash_attn: false,
            has_task_loss_requirement_during_training: false,
            has_task_loss_requirement_during_inference: false,
            gazing_ratio_config: serde_json::json!({}),
            gazing_ratio_each_frame_config: serde_json::json!({}),
            task_loss_requirement_config: serde_json::json!({}),
            gaze_model_config: GazeModelConfig::default(),
        }
    }
}

impl AutoGazeConfig {
    /// Loads and deserializes an AutoGaze `config.json` from disk.
    ///
    /// # Errors
    /// Fails with path context when the file cannot be read or does not
    /// parse as an [`AutoGazeConfig`].
    pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        let raw = fs::read(path)
            .with_context(|| format!("read AutoGaze config from {}", path.display()))?;
        let config = serde_json::from_slice(&raw)
            .with_context(|| format!("parse AutoGaze config {}", path.display()))?;
        Ok(config)
    }

    /// Parses the plus-separated `scales` string into numeric values,
    /// silently skipping any segment that is not a valid `usize`.
    pub fn scale_values(&self) -> Vec<usize> {
        let mut values = Vec::new();
        for segment in self.scales.split('+') {
            if let Ok(scale) = segment.trim().parse::<usize>() {
                values.push(scale);
            }
        }
        values
    }

    /// Fixed gazing ratio to use at inference time, when the sub-config
    /// declares a `"fixed"` sampling strategy; otherwise `None`.
    pub fn inference_gazing_ratio(&self) -> Option<f32> {
        fixed_inference_value(&self.gazing_ratio_config, "gazing_ratio")
    }

    /// Fixed task-loss requirement at inference time; `None` unless the
    /// feature flag is enabled and the strategy is `"fixed"`.
    pub fn inference_task_loss_requirement(&self) -> Option<f32> {
        if !self.has_task_loss_requirement_during_inference {
            return None;
        }
        fixed_inference_value(&self.task_loss_requirement_config, "task_loss_requirement")
    }
}

fn fixed_inference_value(config: &serde_json::Value, value_key: &str) -> Option<f32> {
    let strategy = config
        .get("sample_strategy_during_inference")
        .and_then(serde_json::Value::as_str)?;
    if strategy != "fixed" {
        return None;
    }
    config
        .get("fixed")
        .and_then(|fixed| fixed.get(value_key))
        .and_then(serde_json::Value::as_f64)
        .map(|value| value as f32)
}

/// Configuration of the gaze model: input geometry plus three nested
/// component configs (vision encoder, connector, decoder).
///
/// `#[serde(default)]` lets any field be omitted from the JSON.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct GazeModelConfig {
    /// Square input image size in pixels (defaults to 224).
    pub input_img_size: usize,
    pub num_vision_tokens_each_frame: usize,
    /// Attention backend selector (defaults to `"sdpa"`).
    pub attn_mode: String,
    pub vision_model_config: VisionModelConfig,
    pub connector_config: ConnectorConfig,
    pub gaze_decoder_config: GazeDecoderConfig,
}

impl Default for GazeModelConfig {
    fn default() -> Self {
        Self {
            input_img_size: 224,
            num_vision_tokens_each_frame: 196,
            attn_mode: "sdpa".to_string(),
            vision_model_config: VisionModelConfig::default(),
            connector_config: ConnectorConfig::default(),
            gaze_decoder_config: GazeDecoderConfig::default(),
        }
    }
}

/// Configuration of the vision encoder component.
///
/// All fields are plain sizes consumed by the model builder; `#[serde(default)]`
/// lets any field be omitted from the JSON.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct VisionModelConfig {
    /// Internal feature width (defaults to 192).
    pub hidden_dim: usize,
    /// Output feature width (defaults to 192).
    pub out_dim: usize,
    /// Number of layers (defaults to 1).
    pub depth: usize,
    /// Patch/conv kernel size (defaults to 16).
    pub kernel_size: usize,
    pub temporal_patch_size: usize,
    pub trunk_temporal_kernel_size: usize,
    pub trunk_spatial_kernel_size: usize,
}

impl Default for VisionModelConfig {
    fn default() -> Self {
        Self {
            hidden_dim: 192,
            out_dim: 192,
            depth: 1,
            kernel_size: 16,
            temporal_patch_size: 1,
            trunk_temporal_kernel_size: 3,
            trunk_spatial_kernel_size: 3,
        }
    }
}

/// Configuration of the vision-to-decoder connector.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct ConnectorConfig {
    /// Connector feature width (defaults to 192).
    pub hidden_dim: usize,
    /// Number of tokens produced per frame (defaults to 196).
    pub num_tokens: usize,
}

impl Default for ConnectorConfig {
    fn default() -> Self {
        Self {
            hidden_dim: 192,
            num_tokens: 196,
        }
    }
}

/// Configuration of the gaze decoder — a LLaMA-style transformer config
/// (see the `"llama"` default `model_type` below).
///
/// `#[serde(default)]` lets any field be omitted from the JSON.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct GazeDecoderConfig {
    /// Decoder architecture identifier (defaults to `"llama"`).
    pub model_type: String,
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_hidden_layers: usize,
    pub num_attention_heads: usize,
    /// Key/value head count (grouped-query attention when smaller than
    /// `num_attention_heads`).
    pub num_key_value_heads: usize,
    /// Activation function name (defaults to `"silu"`).
    pub hidden_act: String,
    pub max_position_embeddings: usize,
    pub initializer_range: f32,
    pub rms_norm_eps: f32,
    pub use_cache: bool,
    pub bos_token_id: i64,
    pub eos_token_id: i64,
    /// RoPE base frequency (defaults to 10000.0).
    pub rope_theta: f32,
    /// Optional raw JSON RoPE-scaling block; `None` when absent.
    pub rope_scaling: Option<serde_json::Value>,
    pub attention_bias: bool,
    pub attention_dropout: f32,
    pub mlp_bias: bool,
    pub head_dim: usize,
    /// Attention backend selector (defaults to `"sdpa"`).
    pub attn_mode: String,
    pub num_multi_token_pred: usize,
}

impl Default for GazeDecoderConfig {
    fn default() -> Self {
        Self {
            model_type: "llama".to_string(),
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_hidden_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 32,
            hidden_act: "silu".to_string(),
            max_position_embeddings: 2048,
            initializer_range: 0.02,
            rms_norm_eps: 1.0e-6,
            use_cache: true,
            bos_token_id: 1,
            eos_token_id: 2,
            rope_theta: 10000.0,
            rope_scaling: None,
            attention_bias: false,
            attention_dropout: 0.0,
            mlp_bias: false,
            head_dim: 128,
            attn_mode: "sdpa".to_string(),
            num_multi_token_pred: 1,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Resolves the cached Hugging Face snapshot config without hardcoding a
    /// specific user's home directory: honors `HF_HOME` first, then falls
    /// back to the conventional `$HOME/.cache/huggingface`. Returns `None`
    /// when neither environment variable is set.
    fn cached_config_path() -> Option<std::path::PathBuf> {
        use std::path::PathBuf;
        let hf_home = std::env::var_os("HF_HOME")
            .map(PathBuf::from)
            .or_else(|| {
                std::env::var_os("HOME")
                    .map(|home| PathBuf::from(home).join(".cache").join("huggingface"))
            })?;
        Some(hf_home.join("hub").join(
            "models--nvidia--AutoGaze/snapshots/5100fae739ec1bf3f875914fa1b703846a18943a/config.json",
        ))
    }

    #[test]
    fn parses_cached_autogaze_config_shape() {
        // Previously this used the absolute path /home/mosure/... which only
        // worked on one machine; the path is now derived from the environment.
        let Some(path) = cached_config_path() else {
            eprintln!("skipping AutoGaze config parse: HF_HOME and HOME are both unset");
            return;
        };
        // Best-effort test: skip (not fail) when the snapshot isn't cached.
        if !path.exists() {
            eprintln!(
                "skipping AutoGaze config parse: missing Hugging Face config {}",
                path.display()
            );
            return;
        }
        let config = AutoGazeConfig::from_json_file(&path).expect("parse autogaze config");
        assert_eq!(config.model_type, "autogaze");
        assert_eq!(config.gaze_model_config.vision_model_config.hidden_dim, 192);
        assert_eq!(
            config.gaze_model_config.gaze_decoder_config.hidden_size,
            192
        );
        assert_eq!(
            config
                .gaze_model_config
                .gaze_decoder_config
                .num_hidden_layers,
            4
        );
        assert_eq!(config.scale_values(), vec![32, 64, 112, 224]);
        assert_eq!(config.inference_gazing_ratio(), Some(0.75));
        assert_eq!(config.inference_task_loss_requirement(), Some(0.7));
    }
}