//! # burn_autogaze 0.21.2
//!
//! AutoGaze inference, fixation traces, and crisp mask visualization for Burn.
mod config;
mod model;
mod nodes;
mod pipeline;
mod pyramid;
mod safetensors_io;
mod teacher;
mod trace;
mod visualization;
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
mod wasm;

pub use config::{
    AutoGazeConfig, ConnectorConfig, GazeDecoderConfig, GazeModelConfig, VisionModelConfig,
};
pub use model::{
    AutoGazeCausalLmOutput, AutoGazeGazingModel, AutoGazeGenerateOutput, AutoGazeLoadOptions,
    Connector, Conv3dBlockForStreaming, NativeAutoGazeModel, ShallowVideoConvNet,
};
pub use nodes::{
    AutoGazeInputNode, AutoGazeOutputNode, AutoGazePipelinePacket, AutoGazeRgbaClip,
    AutoGazeTensorClip, AutoGazeTensorClipShape, AutoGazeTensorPipeline,
    AutoGazeTensorPipelineConfig, FnOutputNode, RgbaClipInput, TensorClipInput, VecOutputNode,
};
pub use pipeline::{
    AUTO_GAZE_IMAGE_MEAN, AUTO_GAZE_IMAGE_STD, AUTO_GAZE_RESCALE_FACTOR, AutoGazeClipShape,
    AutoGazeEmbedOutput, AutoGazeInferenceMode, AutoGazePipeline, AutoGazeRgbaClipShape,
    AutoGazeTile, AutoGazeTileLayout, rgba_clip_to_tensor,
};
pub use pyramid::{
    ImagePyramidLevel, ImagePyramidMask, ImagePyramidMaskOptions, ImagePyramidTokens,
    SparseImagePyramidTokens, apply_image_mask, fixation_image_mask_tensor,
    frame_fixation_masks_tensor, image_pyramid_masks, sparsify_image_pyramid_tokens,
    tokenize_masked_image_pyramid,
};
pub use safetensors_io::AutoGazeTraceStore;
pub use teacher::AutoGazeTeacher;
pub use trace::{FixationBounds, FixationPoint, FixationSet, FrameFixationTrace};
pub use visualization::{
    AutoGazeVisualization, AutoGazeVisualizationMode, AutoGazeVisualizationState,
    fixation_alpha_mask, fixation_scale_mask_rgba, visualize_fixations_rgba,
};
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
pub use wasm::*;

/// Convenience alias for [`NativeAutoGazeModel`] on Burn's `NdArray<f32>` backend.
///
/// Only available when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
pub type NdArrayAutoGazeModel = NativeAutoGazeModel<burn::backend::NdArray<f32>>;

/// Convenience alias for [`AutoGazePipeline`] on Burn's `NdArray<f32>` backend.
///
/// Only available when the `ndarray` feature is enabled.
#[cfg(feature = "ndarray")]
pub type NdArrayAutoGazePipeline = AutoGazePipeline<burn::backend::NdArray<f32>>;

/// Convenience alias for [`NativeAutoGazeModel`] on Burn's `Cuda<f32, i32>` backend.
///
/// Only available when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub type CudaAutoGazeModel = NativeAutoGazeModel<burn::backend::Cuda<f32, i32>>;

/// Convenience alias for [`AutoGazePipeline`] on Burn's `Cuda<f32, i32>` backend.
///
/// Only available when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
pub type CudaAutoGazePipeline = AutoGazePipeline<burn::backend::Cuda<f32, i32>>;

/// Convenience alias for [`NativeAutoGazeModel`] on Burn's `Wgpu<f32, i32>` backend.
///
/// Available when either the `wgpu` or the `webgpu` feature is enabled.
#[cfg(any(feature = "wgpu", feature = "webgpu"))]
pub type WgpuAutoGazeModel = NativeAutoGazeModel<burn::backend::Wgpu<f32, i32>>;

/// Convenience alias for [`AutoGazePipeline`] on Burn's `Wgpu<f32, i32>` backend.
///
/// Available when either the `wgpu` or the `webgpu` feature is enabled.
#[cfg(any(feature = "wgpu", feature = "webgpu"))]
pub type WgpuAutoGazePipeline = AutoGazePipeline<burn::backend::Wgpu<f32, i32>>;

/// Convenience alias for [`NativeAutoGazeModel`] on Burn's `WebGpu<f32, i32>` backend.
///
/// Only available when the `webgpu` feature is enabled.
#[cfg(feature = "webgpu")]
pub type WebGpuAutoGazeModel = NativeAutoGazeModel<burn::backend::WebGpu<f32, i32>>;

/// Convenience alias for [`AutoGazePipeline`] on Burn's `WebGpu<f32, i32>` backend.
///
/// Only available when the `webgpu` feature is enabled.
#[cfg(feature = "webgpu")]
pub type WebGpuAutoGazePipeline = AutoGazePipeline<burn::backend::WebGpu<f32, i32>>;