use crate::config::{AutoGazeConfig, ConnectorConfig, GazeModelConfig, VisionModelConfig};
use crate::{FixationPoint, FixationSet, FrameFixationTrace};
use anyhow::{Context, Result, bail};
use burn::module::{Module, Param};
use burn::nn::conv::{Conv3d, Conv3dConfig};
use burn::nn::{
Embedding, EmbeddingConfig, LayerNorm, LayerNormConfig, Linear, LinearConfig, PaddingConfig3d,
};
use burn::tensor::activation;
use burn::tensor::backend::{Backend, ExecutionError};
use burn::tensor::module::interpolate;
use burn::tensor::ops::{InterpolateMode, InterpolateOptions, PadMode};
use burn::tensor::{Bool, Int, Tensor, TensorData};
use burn_store::{ModuleSnapshot, PyTorchToBurnAdapter, SafetensorsStore};
use std::path::Path;
/// One 3D-convolution block of the video trunk that can run in a streaming fashion.
///
/// The block owns a single [`Conv3d`] and remembers its temporal kernel extent so
/// that `forward` can either zero-pad the start of the time axis or prepend frames
/// cached from a previous call.
#[derive(Module, Debug)]
pub struct Conv3dBlockForStreaming<B: Backend> {
/// Convolution built with identical input and output channel counts.
pub conv3d: Conv3d<B>,
/// Temporal kernel size (clamped to at least 1 by `new`); `forward` uses
/// `temporal_patch_size - 1` frames of left context.
#[module(skip)]
temporal_patch_size: usize,
}
impl<B: Backend> Conv3dBlockForStreaming<B> {
    /// Creates a block whose convolution maps `hidden_dim` channels to `hidden_dim`
    /// channels.
    ///
    /// All sizes are clamped to at least 1. The temporal axis gets no built-in
    /// padding (it is handled explicitly in [`Self::forward`]); both spatial axes
    /// are padded by `(spatial_kernel_size - 1) / 2`.
    pub fn new(
        hidden_dim: usize,
        temporal_patch_size: usize,
        spatial_kernel_size: usize,
        device: &B::Device,
    ) -> Self {
        let channels = hidden_dim.max(1);
        let temporal = temporal_patch_size.max(1);
        let spatial = spatial_kernel_size.max(1);
        let spatial_pad = spatial_kernel_size.saturating_sub(1) / 2;

        let conv3d = Conv3dConfig::new([channels, channels], [temporal, spatial, spatial])
            .with_padding(PaddingConfig3d::Explicit(0, spatial_pad, spatial_pad))
            .init(device);

        Self {
            conv3d,
            temporal_patch_size: temporal,
        }
    }

    /// Applies the convolution followed by ReLU.
    ///
    /// When `use_cache` is set and `past_conv_values` is provided, the cached
    /// frames are concatenated in front of `x` along dimension 2. In every other
    /// case dimension 2 is left-padded with `temporal_patch_size - 1` zeros.
    ///
    /// Returns the activated output together with the last
    /// `temporal_patch_size - 1` entries of the (padded or concatenated) input
    /// along dimension 2, to be handed back as `past_conv_values` next time.
    pub fn forward(
        &self,
        x: Tensor<B, 5>,
        use_cache: bool,
        past_conv_values: Option<Tensor<B, 5>>,
    ) -> (Tensor<B, 5>, Tensor<B, 5>) {
        let context = self.temporal_patch_size.saturating_sub(1);

        // Cached frames only count when caching was requested.
        let cached = if use_cache { past_conv_values } else { None };
        let x = match cached {
            Some(past) => Tensor::cat(vec![past, x], 2),
            None => x.pad(
                [(0, 0), (0, 0), (context, 0), (0, 0), (0, 0)],
                PadMode::Constant(0.0),
            ),
        };

        let [batch, channels, time, height, width] = x.shape().dims::<5>();
        let new_past_conv_values = if context == 0 {
            // A temporal kernel of 1 needs no history: return an empty time axis.
            Tensor::<B, 5>::zeros([batch, channels, 0, height, width], &x.device())
        } else {
            x.clone().slice_dim(2, time.saturating_sub(context)..time)
        };

        let activated = activation::relu(self.conv3d.forward(x));
        (activated, new_past_conv_values)
    }
}
/// Shallow convolutional video encoder: a strided patchifying convolution, a layer
/// norm over channels, a stack of streaming 3D-conv blocks and a 1x1x1 output
/// projection.
#[derive(Module, Debug)]
pub struct ShallowVideoConvNet<B: Backend> {
/// Patchifying convolution from 3 input channels, with stride equal to its kernel.
pub temporal_conv: Conv3d<B>,
/// Layer norm applied over the channel dimension after patchifying.
pub norm: LayerNorm<B>,
/// Trunk blocks applied in order.
pub blocks: Vec<Conv3dBlockForStreaming<B>>,
/// Pointwise (1x1x1) projection from the hidden width to the output width.
pub out_proj: Conv3d<B>,
/// Temporal patch size taken from the config (at least 1).
#[module(skip)]
temporal_patch_size: usize,
}
impl<B: Backend> ShallowVideoConvNet<B> {
    /// Builds the encoder from `config`; every size read from the config is
    /// clamped to at least 1, including the number of trunk blocks.
    pub fn new(config: &VisionModelConfig, device: &B::Device) -> Self {
        let hidden_dim = config.hidden_dim.max(1);
        let out_dim = config.out_dim.max(1);
        let temporal_patch_size = config.temporal_patch_size.max(1);
        let patch = config.kernel_size.max(1);

        // Non-overlapping patches: the stride equals the kernel on every axis.
        let patch_shape = [temporal_patch_size, patch, patch];
        let temporal_conv = Conv3dConfig::new([3, hidden_dim], patch_shape)
            .with_stride(patch_shape)
            .init(device);
        let norm = LayerNormConfig::new(hidden_dim).init(device);

        let depth = config.depth.max(1);
        let trunk_temporal = config.trunk_temporal_kernel_size.max(1);
        let trunk_spatial = config.trunk_spatial_kernel_size.max(1);
        let mut blocks = Vec::with_capacity(depth);
        for _ in 0..depth {
            blocks.push(Conv3dBlockForStreaming::new(
                hidden_dim,
                trunk_temporal,
                trunk_spatial,
                device,
            ));
        }

        let out_proj = Conv3dConfig::new([hidden_dim, out_dim], [1, 1, 1]).init(device);

        Self {
            temporal_conv,
            norm,
            blocks,
            out_proj,
            temporal_patch_size,
        }
    }

    /// Encodes a video tensor.
    ///
    /// Dimensions 1 and 2 of `x` are swapped before the patchifying convolution.
    /// The result is layer-normalised over channels, passed through each trunk
    /// block (block `i` receives `past_conv_values[i]` when present) and finally
    /// projected by `out_proj`.
    ///
    /// Returns the projected features and one new cache tensor per block.
    pub fn forward(
        &self,
        x: Tensor<B, 5>,
        use_cache: bool,
        past_conv_values: Option<Vec<Tensor<B, 5>>>,
    ) -> (Tensor<B, 5>, Vec<Tensor<B, 5>>) {
        let patches = self.temporal_conv.forward(x.permute([0, 2, 1, 3, 4]));
        let [batch, channels, time, height, width] = patches.shape().dims::<5>();

        // Move channels last so the layer norm runs over them, then restore layout.
        let tokens = patches
            .permute([0, 2, 1, 3, 4])
            .reshape([batch * time, channels, height * width])
            .swap_dims(1, 2);
        let normed = self.norm.forward(tokens);
        let mut features = normed
            .swap_dims(1, 2)
            .reshape([batch, time, channels, height, width])
            .permute([0, 2, 1, 3, 4]);

        // Hand each block its cached state in order; missing entries become `None`.
        let mut cached = past_conv_values.into_iter().flatten();
        let mut new_past = Vec::with_capacity(self.blocks.len());
        for block in &self.blocks {
            let (out, state) = block.forward(features, use_cache, cached.next());
            features = out;
            new_past.push(state);
        }

        (self.out_proj.forward(features), new_past)
    }
}
/// Adds a learned positional embedding to the per-frame vision tokens.
#[derive(Module, Debug)]
pub struct Connector<B: Backend> {
/// Positional embedding of shape `[num_tokens, hidden_dim]`, initialised from a
/// standard normal distribution in `new`.
pub pos_embed: Param<Tensor<B, 2>>,
}
impl<B: Backend> Connector<B> {
    /// Creates the connector with a `[num_tokens, hidden_dim]` positional
    /// embedding (both clamped to at least 1) drawn from `Normal(0, 1)`.
    pub fn new(config: &ConnectorConfig, device: &B::Device) -> Self {
        Self {
            pos_embed: Param::from_tensor(Tensor::<B, 2>::random(
                [config.num_tokens.max(1), config.hidden_dim.max(1)],
                burn::tensor::Distribution::Normal(0.0, 1.0),
                device,
            )),
        }
    }

    /// Adds the positional embedding to `x`, whose dimensions are read as
    /// `[batch, time, tokens, dim]`.
    ///
    /// The embedding is reshaped to `[1, 1, tokens, dim]` and broadcast over the
    /// batch and time axes by the element-wise add, so no repeated copy of the
    /// embedding is materialised.
    ///
    /// # Panics
    ///
    /// Panics (inside the reshape) if `tokens * dim` differs from the number of
    /// elements in `pos_embed`.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        let [_, _, tokens, dim] = x.shape().dims::<4>();
        let pos = self.pos_embed.val().reshape([1, 1, tokens, dim]);
        x + pos
    }
}
/// Result of a gaze-generation run; every per-sample vector is indexed by batch row.
#[derive(Clone, Debug)]
pub struct AutoGazeGenerateOutput {
/// Selected token ids per batch row, concatenated over frames. Each id has
/// `frame_idx * num_vision_tokens_each_frame` added to it.
pub gazing_pos: Vec<Vec<i64>>,
/// Number of tokens recorded for each frame (taken from the first batch row).
pub num_gazing_each_frame: Vec<usize>,
/// Per batch row, `true` where the corresponding token was reported as not valid.
pub if_padded_gazing: Vec<Vec<bool>>,
/// Per batch row, the confidence recorded for each token.
pub confidences: Vec<Vec<f32>>,
}
/// Per-batch-row `(tokens, validity flags, confidences)` triple; the generation
/// loops destructure the result of `greedy_select_multi_tokens` in this order.
type GreedyTokenSelection = (Vec<Vec<i64>>, Vec<Vec<bool>>, Vec<Vec<f32>>);
/// Stopping parameters handed to `greedy_select_multi_tokens` on every decoding step.
#[derive(Clone, Copy, Debug)]
struct TaskLossStop {
/// Caller-supplied task-loss requirement, if any (presumably a threshold on the
/// predicted task loss — confirm against `greedy_select_multi_tokens`).
requirement: Option<f32>,
/// `true` only on the first decoding step of a frame.
is_first_token: bool,
}
/// Output of the uncached decoder forward pass.
#[derive(Debug)]
pub struct AutoGazeCausalLmOutput<B: Backend> {
/// Output of `lm_head`; the last dimension is `vocab_size * num_multi_token_pred`.
pub logits: Tensor<B, 3>,
/// Output of `task_loss_prediction_head`; the last dimension is `num_multi_token_pred`.
pub task_loss_prediction: Tensor<B, 3>,
/// Final (normalised) hidden states that both heads were applied to.
pub hidden_states: Tensor<B, 3>,
}
/// Key/value cache of a single attention layer.
#[derive(Clone, Debug)]
struct AutoGazePastKeyValue<B: Backend> {
/// Key buffer allocated as `[batch, num_key_value_heads, capacity, head_dim]`.
key: Tensor<B, 4>,
/// Value buffer with the same layout as `key`.
value: Tensor<B, 4>,
/// Number of positions along dimension 2 that have been written so far.
len: usize,
}
/// One cache entry per decoder layer, in layer order.
type AutoGazePastKeyValues<B> = Vec<AutoGazePastKeyValue<B>>;
/// Output of the cached decoder forward pass.
#[derive(Debug)]
struct AutoGazeCachedCausalLmOutput<B: Backend> {
/// Output of `lm_head` for the query positions of this call.
logits: Tensor<B, 3>,
/// Output of `task_loss_prediction_head` for the query positions of this call.
task_loss_prediction: Tensor<B, 3>,
/// Updated per-layer key/value caches to pass into the next call.
past_key_values: AutoGazePastKeyValues<B>,
}
/// Root-mean-square normalisation over the last dimension with a learned scale.
#[derive(Module, Debug)]
pub struct LlamaRmsNorm<B: Backend> {
/// Per-feature scale, initialised to ones.
pub weight: Param<Tensor<B, 1>>,
/// Stabiliser added under the square root (clamped to at least `1e-8` by `new`).
#[module(skip)]
eps: f32,
}
impl<B: Backend> LlamaRmsNorm<B> {
    /// Creates a norm over `width` features (at least 1) with unit scale; `eps`
    /// is raised to `1e-8` if it is smaller.
    pub fn new(width: usize, eps: f32, device: &B::Device) -> Self {
        let scale = Tensor::<B, 1>::ones([width.max(1)], device);
        Self {
            weight: Param::from_tensor(scale),
            eps: eps.max(1.0e-8),
        }
    }

    /// Normalises `x` along its last dimension and applies the learned scale.
    ///
    /// The mean of squares is obtained as `var + mean^2` from the biased
    /// variance and the mean, then `x / sqrt(var + mean^2 + eps) * weight` is
    /// returned.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let width = x.shape().dims::<3>()[2];
        let (variance, mean) = x.clone().var_mean_bias(2);
        let mean_square = variance.add(mean.powf_scalar(2.0));
        let rms = mean_square.add_scalar(self.eps).sqrt();
        let scale = self.weight.val().reshape([1, 1, width]);
        x.div(rms).mul(scale)
    }
}
/// Multi-head self-attention with rotary position embedding and optional
/// grouped key/value heads.
#[derive(Module, Debug)]
pub struct AutoGazeLlamaAttention<B: Backend> {
/// Query projection to `num_heads * head_dim` features.
pub q_proj: Linear<B>,
/// Key projection to `num_key_value_heads * head_dim` features.
pub k_proj: Linear<B>,
/// Value projection to `num_key_value_heads * head_dim` features.
pub v_proj: Linear<B>,
/// Output projection from `num_heads * head_dim` back to the hidden size.
pub o_proj: Linear<B>,
/// Number of query heads (at least 1).
#[module(skip)]
num_heads: usize,
/// Number of key/value heads (at least 1).
#[module(skip)]
num_key_value_heads: usize,
/// Width of each head (at least 1).
#[module(skip)]
head_dim: usize,
/// Rotary inverse frequencies produced by `llama3_inv_freq`; reshaped to
/// `head_dim / 2` entries when the rotary embedding is applied.
#[module(skip)]
inv_freq: Tensor<B, 1>,
}
impl<B: Backend> AutoGazeLlamaAttention<B> {
    /// Builds the attention layer from the decoder config. Head counts, head
    /// width and hidden size are clamped to at least 1; all four projections
    /// share the `attention_bias` setting.
    pub fn new(config: &crate::config::GazeDecoderConfig, device: &B::Device) -> Self {
        let num_heads = config.num_attention_heads.max(1);
        let num_key_value_heads = config.num_key_value_heads.max(1);
        let head_dim = config.head_dim.max(1);
        let inv_freq = llama3_inv_freq(config, device);
        Self {
            q_proj: LinearConfig::new(config.hidden_size.max(1), num_heads * head_dim)
                .with_bias(config.attention_bias)
                .init(device),
            k_proj: LinearConfig::new(config.hidden_size.max(1), num_key_value_heads * head_dim)
                .with_bias(config.attention_bias)
                .init(device),
            v_proj: LinearConfig::new(config.hidden_size.max(1), num_key_value_heads * head_dim)
                .with_bias(config.attention_bias)
                .init(device),
            o_proj: LinearConfig::new(num_heads * head_dim, config.hidden_size.max(1))
                .with_bias(config.attention_bias)
                .init(device),
            num_heads,
            num_key_value_heads,
            head_dim,
            inv_freq,
        }
    }

    /// Full-sequence attention without a key/value cache.
    ///
    /// Queries and keys receive the rotary embedding for `position_ids`, the
    /// key/value heads are expanded to match the query heads, and the additive
    /// bias from `causal_attention_bias` is applied before the softmax.
    pub fn forward(
        &self,
        hidden_states: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2, Int>>,
        position_ids: Tensor<B, 2, Int>,
    ) -> Tensor<B, 3> {
        let [batch, seq, _] = hidden_states.shape().dims::<3>();
        let q = split_heads(
            self.q_proj.forward(hidden_states.clone()),
            self.num_heads,
            self.head_dim,
        );
        let k = split_heads(
            self.k_proj.forward(hidden_states.clone()),
            self.num_key_value_heads,
            self.head_dim,
        );
        let v = split_heads(
            self.v_proj.forward(hidden_states),
            self.num_key_value_heads,
            self.head_dim,
        );
        let q = self.apply_rope(q, position_ids.clone());
        let k = self.repeat_kv_heads(self.apply_rope(k, position_ids));
        let v = self.repeat_kv_heads(v);
        let scores = q
            .matmul(k.swap_dims(2, 3))
            .div_scalar((self.head_dim as f32).sqrt().max(1.0));
        let bias = causal_attention_bias(batch, seq, attention_mask, &scores.device());
        let attn = activation::softmax(scores + bias, 3);
        let out = attn.matmul(v);
        self.o_proj.forward(merge_heads(out))
    }

    /// Attention for `hidden_states` as new query positions on top of an
    /// optional cache.
    ///
    /// With no `past`, key and value buffers of
    /// `max(cache_capacity, query_len, 1)` positions are allocated and the new
    /// keys/values are written at the start. With a `past`, they are written
    /// in place at `past.len..past.len + query_len`; the existing buffer must
    /// be large enough, otherwise the slice assignment fails.
    ///
    /// Returns the attention output and the updated cache. The cache always
    /// stores the un-expanded `num_key_value_heads` heads.
    fn forward_cached(
        &self,
        hidden_states: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2, Int>>,
        position_ids: Tensor<B, 2, Int>,
        past: Option<AutoGazePastKeyValue<B>>,
        cache_capacity: usize,
    ) -> (Tensor<B, 3>, AutoGazePastKeyValue<B>) {
        let [batch, query_len, _] = hidden_states.shape().dims::<3>();
        let q = split_heads(
            self.q_proj.forward(hidden_states.clone()),
            self.num_heads,
            self.head_dim,
        );
        let k = split_heads(
            self.k_proj.forward(hidden_states.clone()),
            self.num_key_value_heads,
            self.head_dim,
        );
        let v = split_heads(
            self.v_proj.forward(hidden_states),
            self.num_key_value_heads,
            self.head_dim,
        );
        let q = self.apply_rope(q, position_ids.clone());
        let next_k = self.apply_rope(k, position_ids);
        let next_v = v;
        let cache_device = next_k.device();
        let past_len = past.as_ref().map(|past| past.len).unwrap_or(0);
        let present_len = past_len + query_len;
        let capacity = cache_capacity.max(present_len).max(1);
        let (key_cache, value_cache, k, v) = if let Some(past) = past {
            // Append the new positions right after the ones already stored.
            let key_cache = past.key.slice_assign(
                [
                    0..batch,
                    0..self.num_key_value_heads,
                    past_len..present_len,
                    0..self.head_dim,
                ],
                next_k,
            );
            let value_cache = past.value.slice_assign(
                [
                    0..batch,
                    0..self.num_key_value_heads,
                    past_len..present_len,
                    0..self.head_dim,
                ],
                next_v,
            );
            (
                key_cache.clone(),
                value_cache.clone(),
                key_cache.slice_dim(2, 0..present_len),
                value_cache.slice_dim(2, 0..present_len),
            )
        } else {
            // First call: allocate the full-capacity buffers once. Positions
            // beyond `present_len` stay uninitialised and are never read.
            let key_cache = Tensor::<B, 4>::empty(
                [batch, self.num_key_value_heads, capacity, self.head_dim],
                &cache_device,
            )
            .slice_assign(
                [
                    0..batch,
                    0..self.num_key_value_heads,
                    0..present_len,
                    0..self.head_dim,
                ],
                next_k.clone(),
            );
            let value_cache = Tensor::<B, 4>::empty(
                [batch, self.num_key_value_heads, capacity, self.head_dim],
                &cache_device,
            )
            .slice_assign(
                [
                    0..batch,
                    0..self.num_key_value_heads,
                    0..present_len,
                    0..self.head_dim,
                ],
                next_v.clone(),
            );
            (key_cache, value_cache, next_k, next_v)
        };
        let present = AutoGazePastKeyValue {
            key: key_cache,
            value: value_cache,
            len: present_len,
        };
        let key_len = present_len;
        let k = self.repeat_kv_heads(k);
        let v = self.repeat_kv_heads(v);
        let scores = q
            .matmul(k.swap_dims(2, 3))
            .div_scalar((self.head_dim as f32).sqrt().max(1.0));
        let bias = causal_attention_bias_for_query(
            batch,
            query_len,
            key_len,
            past_len,
            attention_mask,
            &scores.device(),
        );
        let attn = activation::softmax(scores + bias, 3);
        let out = attn.matmul(v);
        (self.o_proj.forward(merge_heads(out)), present)
    }

    /// Expands `[batch, kv_heads, seq, dim]` keys or values to the number of
    /// query heads for grouped-query attention.
    ///
    /// Each key/value head is repeated consecutively (`k0, k0, k1, k1, ...`),
    /// so query head `h` is paired with key/value head `h / repeat`. Tiling the
    /// whole head axis (`k0, k1, k0, k1, ...`) would pair query heads with the
    /// wrong key/value head whenever more than one key/value head exists.
    fn repeat_kv_heads(&self, x: Tensor<B, 4>) -> Tensor<B, 4> {
        if self.num_key_value_heads == self.num_heads {
            return x;
        }
        let repeat = (self.num_heads / self.num_key_value_heads.max(1)).max(1);
        let [batch, kv_heads, seq, dim] = x.shape().dims::<4>();
        x.reshape([batch, kv_heads, 1, seq, dim])
            .repeat_dim(2, repeat)
            .reshape([batch, kv_heads * repeat, seq, dim])
    }

    /// Applies the rotary position embedding to `x` (`[batch, heads, seq, dim]`).
    ///
    /// The phase of each position is `position * inv_freq`, duplicated to
    /// cover both halves of the head dimension; the result is
    /// `x * cos + rotate_half(x) * sin`.
    fn apply_rope(&self, x: Tensor<B, 4>, position_ids: Tensor<B, 2, Int>) -> Tensor<B, 4> {
        let [batch, heads, seq, dim] = x.shape().dims::<4>();
        let half = dim / 2;
        let pos = position_ids.float().reshape([batch, 1, seq, 1]);
        let inv = self.inv_freq.clone().reshape([1, 1, 1, half]);
        let freqs = pos.mul(inv);
        let phases = Tensor::cat(vec![freqs.clone(), freqs], 3);
        let cos = phases.clone().cos().repeat_dim(1, heads);
        let sin = phases.sin().repeat_dim(1, heads);
        x.clone().mul(cos) + rotate_half(x).mul(sin)
    }
}
/// Gated feed-forward block: `down_proj(silu(gate_proj(x)) * up_proj(x))`.
#[derive(Module, Debug)]
pub struct AutoGazeLlamaMlp<B: Backend> {
/// Projection whose SiLU-activated output gates `up_proj`.
pub gate_proj: Linear<B>,
/// Projection from the hidden size to the intermediate size.
pub up_proj: Linear<B>,
/// Projection from the intermediate size back to the hidden size.
pub down_proj: Linear<B>,
}
impl<B: Backend> AutoGazeLlamaMlp<B> {
    /// Builds the three projections; the hidden and intermediate sizes are
    /// clamped to at least 1 and all layers share the `mlp_bias` setting.
    pub fn new(config: &crate::config::GazeDecoderConfig, device: &B::Device) -> Self {
        let hidden = config.hidden_size.max(1);
        let intermediate = config.intermediate_size.max(1);
        let use_bias = config.mlp_bias;

        let gate_proj = LinearConfig::new(hidden, intermediate)
            .with_bias(use_bias)
            .init(device);
        let up_proj = LinearConfig::new(hidden, intermediate)
            .with_bias(use_bias)
            .init(device);
        let down_proj = LinearConfig::new(intermediate, hidden)
            .with_bias(use_bias)
            .init(device);

        Self {
            gate_proj,
            up_proj,
            down_proj,
        }
    }

    /// Returns `down_proj(silu(gate_proj(x)) * up_proj(x))`.
    pub fn forward(&self, hidden_states: Tensor<B, 3>) -> Tensor<B, 3> {
        let gated = activation::silu(self.gate_proj.forward(hidden_states.clone()));
        let lifted = self.up_proj.forward(hidden_states);
        let mixed = gated.mul(lifted);
        self.down_proj.forward(mixed)
    }
}
/// One pre-norm decoder layer: attention and MLP sub-blocks, each with a residual
/// connection.
#[derive(Module, Debug)]
pub struct AutoGazeLlamaDecoderLayer<B: Backend> {
/// Self-attention sub-block.
pub self_attn: AutoGazeLlamaAttention<B>,
/// Feed-forward sub-block.
pub mlp: AutoGazeLlamaMlp<B>,
/// Norm applied to the input of the attention sub-block.
pub input_layernorm: LlamaRmsNorm<B>,
/// Norm applied to the input of the feed-forward sub-block.
pub post_attention_layernorm: LlamaRmsNorm<B>,
}
impl<B: Backend> AutoGazeLlamaDecoderLayer<B> {
    /// Builds the layer from the decoder config; both norms use the config's
    /// hidden size (at least 1) and `rms_norm_eps`.
    pub fn new(config: &crate::config::GazeDecoderConfig, device: &B::Device) -> Self {
        let width = config.hidden_size.max(1);
        let eps = config.rms_norm_eps;

        let self_attn = AutoGazeLlamaAttention::new(config, device);
        let mlp = AutoGazeLlamaMlp::new(config, device);
        let input_layernorm = LlamaRmsNorm::new(width, eps, device);
        let post_attention_layernorm = LlamaRmsNorm::new(width, eps, device);

        Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        }
    }

    /// Uncached layer pass: `h = x + attn(norm(x))`, then `h + mlp(norm(h))`.
    pub fn forward(
        &self,
        hidden_states: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2, Int>>,
        position_ids: Tensor<B, 2, Int>,
    ) -> Tensor<B, 3> {
        let normed = self.input_layernorm.forward(hidden_states.clone());
        let attended = self.self_attn.forward(normed, attention_mask, position_ids);
        let residual = hidden_states + attended;
        self.feed_forward_residual(residual)
    }

    /// Cached layer pass; same computation as [`Self::forward`] but attention
    /// reads and extends the `past` key/value cache, which is returned updated.
    fn forward_cached(
        &self,
        hidden_states: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2, Int>>,
        position_ids: Tensor<B, 2, Int>,
        past: Option<AutoGazePastKeyValue<B>>,
        cache_capacity: usize,
    ) -> (Tensor<B, 3>, AutoGazePastKeyValue<B>) {
        let normed = self.input_layernorm.forward(hidden_states.clone());
        let (attended, present) = self.self_attn.forward_cached(
            normed,
            attention_mask,
            position_ids,
            past,
            cache_capacity,
        );
        let residual = hidden_states + attended;
        (self.feed_forward_residual(residual), present)
    }

    /// Returns `h + mlp(post_attention_layernorm(h))`.
    fn feed_forward_residual(&self, hidden_states: Tensor<B, 3>) -> Tensor<B, 3> {
        let normed = self.post_attention_layernorm.forward(hidden_states.clone());
        hidden_states + self.mlp.forward(normed)
    }
}
/// Decoder backbone: token embedding table, a stack of decoder layers and a final norm.
#[derive(Module, Debug)]
pub struct AutoGazeLlamaModel<B: Backend> {
/// Embedding table for generated token ids.
pub embed_tokens: Embedding<B>,
/// Decoder layers applied in order.
pub layers: Vec<AutoGazeLlamaDecoderLayer<B>>,
/// Norm applied to the output of the last layer.
pub norm: LlamaRmsNorm<B>,
}
impl<B: Backend> AutoGazeLlamaModel<B> {
    /// Builds the backbone from the decoder config; layer count, vocabulary
    /// size and hidden size are each clamped to at least 1.
    pub fn new(config: &crate::config::GazeDecoderConfig, device: &B::Device) -> Self {
        let depth = config.num_hidden_layers.max(1);
        let mut layers = Vec::with_capacity(depth);
        for _ in 0..depth {
            layers.push(AutoGazeLlamaDecoderLayer::new(config, device));
        }

        let width = config.hidden_size.max(1);
        let embed_tokens = EmbeddingConfig::new(config.vocab_size.max(1), width).init(device);
        let norm = LlamaRmsNorm::new(width, config.rms_norm_eps, device);

        Self {
            embed_tokens,
            layers,
            norm,
        }
    }

    /// Runs every layer over already-embedded inputs and applies the final norm.
    pub fn forward(
        &self,
        inputs_embeds: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2, Int>>,
        position_ids: Tensor<B, 2, Int>,
    ) -> Tensor<B, 3> {
        let hidden_states = self.layers.iter().fold(inputs_embeds, |hidden, layer| {
            layer.forward(hidden, attention_mask.clone(), position_ids.clone())
        });
        self.norm.forward(hidden_states)
    }

    /// Cached variant of [`Self::forward`].
    ///
    /// Layer `i` receives `past_key_values[i]` when available and `None`
    /// otherwise. Returns the normalised hidden states and one updated cache
    /// entry per layer.
    fn forward_cached(
        &self,
        inputs_embeds: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2, Int>>,
        position_ids: Tensor<B, 2, Int>,
        past_key_values: Option<AutoGazePastKeyValues<B>>,
        cache_capacity: usize,
    ) -> (Tensor<B, 3>, AutoGazePastKeyValues<B>) {
        // Consume the per-layer caches in order; exhausted entries yield `None`.
        let mut previous = past_key_values.into_iter().flatten();
        let mut next_past = Vec::with_capacity(self.layers.len());
        let mut hidden_states = inputs_embeds;

        for layer in &self.layers {
            let (updated, present) = layer.forward_cached(
                hidden_states,
                attention_mask.clone(),
                position_ids.clone(),
                previous.next(),
                cache_capacity,
            );
            hidden_states = updated;
            next_past.push(present);
        }

        (self.norm.forward(hidden_states), next_past)
    }
}
/// Decoder backbone with two bias-free heads: token logits for several tokens per
/// step, and a task-loss prediction per predicted token.
#[derive(Module, Debug)]
pub struct AutoGazeLlamaForCausalLmMultiTokenPred<B: Backend> {
/// Decoder backbone.
pub model: AutoGazeLlamaModel<B>,
/// Head producing `vocab_size * num_multi_token_pred` logits per position.
pub lm_head: Linear<B>,
/// Head producing `num_multi_token_pred` values per position.
pub task_loss_prediction_head: Linear<B>,
/// Vocabulary size from the config (at least 1).
#[module(skip)]
vocab_size: usize,
/// Number of tokens predicted per decoding step (at least 1).
#[module(skip)]
num_multi_token_pred: usize,
}
impl<B: Backend> AutoGazeLlamaForCausalLmMultiTokenPred<B> {
    /// Builds the backbone and both heads; sizes from the config are clamped
    /// to at least 1 and neither head has a bias.
    pub fn new(config: &crate::config::GazeDecoderConfig, device: &B::Device) -> Self {
        let width = config.hidden_size.max(1);
        let vocab_size = config.vocab_size.max(1);
        let num_multi_token_pred = config.num_multi_token_pred.max(1);

        let model = AutoGazeLlamaModel::new(config, device);
        let lm_head = LinearConfig::new(width, vocab_size * num_multi_token_pred)
            .with_bias(false)
            .init(device);
        let task_loss_prediction_head = LinearConfig::new(width, num_multi_token_pred)
            .with_bias(false)
            .init(device);

        Self {
            model,
            lm_head,
            task_loss_prediction_head,
            vocab_size,
            num_multi_token_pred,
        }
    }

    /// Uncached pass: runs the backbone and applies both heads to its output.
    pub fn forward(
        &self,
        inputs_embeds: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2, Int>>,
        position_ids: Tensor<B, 2, Int>,
    ) -> AutoGazeCausalLmOutput<B> {
        let hidden_states = self
            .model
            .forward(inputs_embeds, attention_mask, position_ids);
        let (logits, task_loss_prediction) = self.apply_heads(hidden_states.clone());
        AutoGazeCausalLmOutput {
            logits,
            task_loss_prediction,
            hidden_states,
        }
    }

    /// Cached pass: like [`Self::forward`], but threads the per-layer
    /// key/value caches through the backbone and returns them updated instead
    /// of the hidden states.
    fn forward_cached(
        &self,
        inputs_embeds: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2, Int>>,
        position_ids: Tensor<B, 2, Int>,
        past_key_values: Option<AutoGazePastKeyValues<B>>,
        cache_capacity: usize,
    ) -> AutoGazeCachedCausalLmOutput<B> {
        let (hidden_states, past_key_values) = self.model.forward_cached(
            inputs_embeds,
            attention_mask,
            position_ids,
            past_key_values,
            cache_capacity,
        );
        let (logits, task_loss_prediction) = self.apply_heads(hidden_states);
        AutoGazeCachedCausalLmOutput {
            logits,
            task_loss_prediction,
            past_key_values,
        }
    }

    /// Applies `lm_head` and `task_loss_prediction_head` to the same hidden states.
    fn apply_heads(&self, hidden_states: Tensor<B, 3>) -> (Tensor<B, 3>, Tensor<B, 3>) {
        let logits = self.lm_head.forward(hidden_states.clone());
        let task_loss_prediction = self.task_loss_prediction_head.forward(hidden_states);
        (logits, task_loss_prediction)
    }
}
/// Complete gazing model: video encoder, positional connector and the
/// autoregressive gaze decoder, plus the configuration values used during generation.
#[derive(Module, Debug)]
pub struct AutoGazeGazingModel<B: Backend> {
/// Convolutional video encoder.
pub vision_model: ShallowVideoConvNet<B>,
/// Adds positional embeddings to the encoder's per-frame tokens.
pub connector: Connector<B>,
/// Decoder that predicts gaze tokens.
pub gaze_decoder: AutoGazeLlamaForCausalLmMultiTokenPred<B>,
/// Input image size from the config (at least 1).
#[module(skip)]
input_img_size: usize,
/// Number of vision tokens per frame from the config; used as the per-frame
/// offset added to generated token ids.
#[module(skip)]
num_vision_tokens_each_frame: usize,
/// Set from the vision model's `temporal_patch_size` (at least 1).
#[module(skip)]
frame_sampling_rate: usize,
/// Number of tokens the decoder predicts per step (at least 1).
#[module(skip)]
num_multi_token_pred: usize,
/// End-of-sequence token id passed to the greedy token selection.
#[module(skip)]
eos_token_id: i64,
}
impl<B: Backend> AutoGazeGazingModel<B> {
/// Builds the full model from `config`, constructing the vision encoder, the
/// connector and the gaze decoder in that order. Numeric settings copied out of
/// the config are clamped to at least 1, except `eos_token_id`, which is taken
/// as-is.
pub fn new(config: &GazeModelConfig, device: &B::Device) -> Self {
    let vision_config = &config.vision_model_config;
    let decoder_config = &config.gaze_decoder_config;

    let vision_model = ShallowVideoConvNet::new(vision_config, device);
    let connector = Connector::new(&config.connector_config, device);
    let gaze_decoder = AutoGazeLlamaForCausalLmMultiTokenPred::new(decoder_config, device);

    Self {
        vision_model,
        connector,
        gaze_decoder,
        input_img_size: config.input_img_size.max(1),
        num_vision_tokens_each_frame: config.num_vision_tokens_each_frame.max(1),
        frame_sampling_rate: vision_config.temporal_patch_size.max(1),
        num_multi_token_pred: decoder_config.num_multi_token_pred.max(1),
        eos_token_id: decoder_config.eos_token_id,
    }
}
/// Encodes a video into per-frame token embeddings with positions added.
///
/// The channel dimension is first adapted by `adapt_video_channels`, then the
/// vision model runs (optionally using and updating its streaming cache). The
/// features are rearranged to channels-last and the two spatial dimensions are
/// flattened into one token axis before the connector is applied.
///
/// Returns the embeddings and the vision model's new cache tensors.
///
/// # Panics
///
/// Panics if dimensions 3 and 4 of `video` (height and width) differ.
pub fn embed_video(
    &self,
    video: Tensor<B, 5>,
    use_cache: bool,
    past_conv_values: Option<Vec<Tensor<B, 5>>>,
) -> (Tensor<B, 4>, Vec<Tensor<B, 5>>) {
    let [_, _, channels, frame_height, frame_width] = video.shape().dims::<5>();
    let device = video.device();
    let video = adapt_video_channels(video, channels, &device);
    if frame_height != frame_width {
        panic!("AutoGaze Burn port currently expects square frames");
    }

    let (features, new_past) = self
        .vision_model
        .forward(video, use_cache, past_conv_values);

    // Bring time forward, then move channels to the end.
    let channels_last = features.swap_dims(1, 2).permute([0, 1, 3, 4, 2]);
    let [batch, time, rows, cols, dim] = channels_last.shape().dims::<5>();
    let tokens = channels_last.reshape([batch, time, rows * cols, dim]);

    (self.connector.forward(tokens), new_past)
}
/// Public entry point for the video preprocessing step: returns the result of
/// `self.resize_video(video)` unchanged.
pub fn prepare_video(&self, video: Tensor<B, 5>) -> Tensor<B, 5> {
self.resize_video(video)
}
/// Generates gaze tokens for `video`, choosing the decoding strategy by frame
/// count: more than one frame (dimension 1) uses the key/value-cached path,
/// otherwise the uncached path is used.
pub fn generate(
    &self,
    video: Tensor<B, 5>,
    max_gaze_tokens_each_frame: usize,
    task_loss_requirement: Option<f32>,
) -> AutoGazeGenerateOutput {
    let frame_count = video.shape().dims::<5>()[1];
    if frame_count <= 1 {
        return self.generate_uncached(video, max_gaze_tokens_each_frame, task_loss_requirement);
    }
    self.generate_cached(video, max_gaze_tokens_each_frame, task_loss_requirement)
}
/// Generates gaze tokens frame by frame, reusing the decoder's key/value cache.
///
/// The video is resized and embedded in one shot (without the convolution
/// cache). For each frame, the frame's embeddings — preceded by any token
/// embeddings produced at the end of the previous frame that have not been fed
/// to the decoder yet — are passed to `gaze_decoder.forward_cached`. Tokens are
/// then chosen by `greedy_select_multi_tokens` step by step until the longest
/// per-row token list reaches `max_gaze_tokens_each_frame` (treated as at least
/// 1) or every batch row is finished.
///
/// Returns the token ids (offset per frame by `num_vision_tokens_each_frame`),
/// the per-frame token counts, the padding flags and the confidences.
///
/// May panic through `embed_video` when the resized frames are not square.
pub fn generate_cached(
&self,
video: Tensor<B, 5>,
max_gaze_tokens_each_frame: usize,
task_loss_requirement: Option<f32>,
) -> AutoGazeGenerateOutput {
let video = self.resize_video(video);
let (video_embeds, _) = self.embed_video(video, false, None);
let [batch, frames, vision_tokens, dim] = video_embeds.shape().dims::<4>();
let device = video_embeds.device();
// Accumulated outputs, one inner vector per batch row.
let mut gazing_pos = vec![Vec::<i64>::new(); batch];
let mut if_padded_gazing = vec![Vec::<bool>::new(); batch];
let mut confidences = vec![Vec::<f32>::new(); batch];
let mut num_gazing_each_frame = Vec::with_capacity(frames);
let mut past_key_values: Option<AutoGazePastKeyValues<B>> = None;
// Token embeddings generated last in a frame; they are only fed to the decoder
// together with the next frame's embeddings.
let mut pending_query_embeds: Option<Tensor<B, 3>> = None;
// Host-side mirrors of the attention mask and position ids of the whole sequence.
let mut prefix_attention_mask = vec![vec![]; batch];
let mut prefix_position_ids = vec![vec![]; batch];
let mut pending_position_indices = vec![Vec::<usize>::new(); batch];
let max_tokens = max_gaze_tokens_each_frame.max(1);
// Cache size requested from the decoder: one slot per vision token and per
// allowed gaze token, for every frame.
let cache_capacity = frames * (vision_tokens + max_tokens);
for frame_idx in 0..frames {
// NOTE(review): presumably finalises the position ids of the tokens generated
// last in the previous frame — confirm against `commit_pending_position_ids`.
commit_pending_position_ids(
&prefix_attention_mask,
&mut prefix_position_ids,
&pending_position_indices,
);
pending_position_indices.iter_mut().for_each(Vec::clear);
let frame_embed = video_embeds
.clone()
.slice_dim(1, frame_idx..(frame_idx + 1))
.reshape([batch, vision_tokens, dim]);
// Register this frame's vision tokens: all attended, with position ids that
// continue from the number of unmasked entries so far.
for batch_idx in 0..batch {
let valid_start = prefix_attention_mask[batch_idx]
.iter()
.filter(|&&mask| mask != 0)
.count() as i64;
for valid_count in (valid_start..).take(vision_tokens) {
prefix_attention_mask[batch_idx].push(1);
prefix_position_ids[batch_idx].push(valid_count);
}
}
let initial_query_embeds = if let Some(pending) = pending_query_embeds.take() {
Tensor::cat(vec![pending, frame_embed], 1)
} else {
frame_embed
};
// Per-frame results, one inner vector per batch row.
let mut frame_tokens = vec![Vec::<i64>::new(); batch];
let mut frame_padded = vec![Vec::<bool>::new(); batch];
let mut frame_confidences = vec![Vec::<f32>::new(); batch];
let mut finished = vec![false; batch];
let mut is_first_token = true;
// Index in the mask/position vectors where this frame's generated tokens start.
let generation_prefix_len = prefix_attention_mask.first().map(Vec::len).unwrap_or(0);
// The local binding below shadows the helper function of the same name.
let generation_tail_positions =
generation_tail_positions(&prefix_position_ids, self.num_multi_token_pred);
let mut last_generated_indices = vec![Vec::<usize>::new(); batch];
let mut next_query_embeds = Some(initial_query_embeds);
while frame_tokens.iter().map(Vec::len).max().unwrap_or(0) < max_tokens
&& finished.iter().any(|done| !done)
{
let Some(query_embeds) = next_query_embeds.take() else {
break;
};
let query_len = query_embeds.shape().dims::<3>()[1];
let query_start = cached_sequence_len(&past_key_values);
let key_len = query_start + query_len;
let attention_mask =
attention_mask_tensor_or_none::<B>(&prefix_attention_mask, key_len, &device);
let position_ids = position_ids_slice_tensor_optimized::<B>(
&prefix_position_ids,
query_start,
query_len,
&device,
);
let outputs = self.gaze_decoder.forward_cached(
query_embeds,
attention_mask,
position_ids,
past_key_values,
cache_capacity,
);
past_key_values = Some(outputs.past_key_values);
// Only the last query position is used to predict the next tokens.
let last_logits = outputs
.logits
.slice_dim(1, query_len.saturating_sub(1)..query_len)
.reshape([
batch,
self.num_multi_token_pred,
self.gaze_decoder.vocab_size,
]);
let last_task = outputs
.task_loss_prediction
.slice_dim(1, query_len.saturating_sub(1)..query_len)
.reshape([batch, self.num_multi_token_pred]);
let (next_tokens, next_valid, next_confidences) = greedy_select_multi_tokens(
last_logits,
last_task,
&frame_tokens,
&finished,
self.eos_token_id,
max_tokens,
TaskLossStop {
requirement: task_loss_requirement,
is_first_token,
},
);
// The first row's count is used as the number of new tokens for every row.
let new_tokens = next_tokens.first().map(Vec::len).unwrap_or(0);
if new_tokens == 0 {
break;
}
let flat_tokens: Vec<i64> = next_tokens
.iter()
.flat_map(|tokens| tokens.iter().copied())
.collect();
let token_tensor = Tensor::<B, 2, Int>::from_data(
TensorData::new(flat_tokens, [batch, new_tokens]),
&device,
);
let token_embeds = self.gaze_decoder.model.embed_tokens.forward(token_tensor);
for batch_idx in 0..batch {
last_generated_indices[batch_idx].clear();
for local_idx in 0..new_tokens {
// Remember where this token lands in the mask/position vectors.
last_generated_indices[batch_idx]
.push(prefix_attention_mask[batch_idx].len());
let token = next_tokens[batch_idx][local_idx];
let valid = next_valid[batch_idx][local_idx];
let confidence = next_confidences[batch_idx][local_idx];
frame_tokens[batch_idx].push(token);
frame_padded[batch_idx].push(!valid);
frame_confidences[batch_idx].push(confidence);
// Attended while this frame is being generated; padded entries are
// masked out once the frame is complete (see below).
prefix_attention_mask[batch_idx].push(1);
let tail = &generation_tail_positions[batch_idx];
prefix_position_ids[batch_idx].push(tail[local_idx % tail.len()]);
// A token reported as not valid marks this batch row as finished.
if !valid {
finished[batch_idx] = true;
}
}
}
next_query_embeds = Some(token_embeds);
is_first_token = false;
}
let frame_count = frame_tokens.first().map(Vec::len).unwrap_or(0);
num_gazing_each_frame.push(frame_count);
let frame_offset = (frame_idx * self.num_vision_tokens_each_frame) as i64;
for batch_idx in 0..batch {
// Mask out padded tokens so later frames do not attend to them.
for (local_idx, padded) in frame_padded[batch_idx].iter().copied().enumerate() {
if padded {
prefix_attention_mask[batch_idx][generation_prefix_len + local_idx] = 0;
}
}
gazing_pos[batch_idx].extend(
frame_tokens[batch_idx]
.iter()
.map(|token| token + frame_offset),
);
if_padded_gazing[batch_idx].extend(frame_padded[batch_idx].iter().copied());
confidences[batch_idx].extend(frame_confidences[batch_idx].iter().copied());
}
// Carry the not-yet-consumed token embeddings and their indices to the next frame.
pending_position_indices = last_generated_indices;
pending_query_embeds = next_query_embeds;
}
AutoGazeGenerateOutput {
gazing_pos,
num_gazing_each_frame,
if_padded_gazing,
confidences,
}
}
/// Generates gaze tokens frame by frame without a key/value cache.
///
/// The video is resized and embedded in one shot (without the convolution
/// cache). The complete embedding sequence built so far (earlier frames, their
/// generated tokens and the current frame) is re-run through
/// `gaze_decoder.forward` on every decoding step. Tokens are chosen by
/// `greedy_select_multi_tokens` until the longest per-row token list reaches
/// `max_gaze_tokens_each_frame` (treated as at least 1) or every batch row is
/// finished.
///
/// Returns the token ids (offset per frame by `num_vision_tokens_each_frame`),
/// the per-frame token counts, the padding flags and the confidences.
///
/// May panic through `embed_video` when the resized frames are not square.
pub fn generate_uncached(
&self,
video: Tensor<B, 5>,
max_gaze_tokens_each_frame: usize,
task_loss_requirement: Option<f32>,
) -> AutoGazeGenerateOutput {
let video = self.resize_video(video);
let (video_embeds, _) = self.embed_video(video, false, None);
let [batch, frames, vision_tokens, dim] = video_embeds.shape().dims::<4>();
let device = video_embeds.device();
// Accumulated outputs, one inner vector per batch row.
let mut gazing_pos = vec![Vec::<i64>::new(); batch];
let mut if_padded_gazing = vec![Vec::<bool>::new(); batch];
let mut confidences = vec![Vec::<f32>::new(); batch];
let mut num_gazing_each_frame = Vec::with_capacity(frames);
// Embeddings of everything processed in earlier frames.
let mut prefix_embeds: Option<Tensor<B, 3>> = None;
// Host-side mirrors of the attention mask and position ids of the whole sequence.
let mut prefix_attention_mask = vec![vec![]; batch];
let mut prefix_position_ids = vec![vec![]; batch];
let mut pending_position_indices = vec![Vec::<usize>::new(); batch];
for frame_idx in 0..frames {
// NOTE(review): presumably finalises the position ids of the tokens generated
// last in the previous frame — confirm against `commit_pending_position_ids`.
commit_pending_position_ids(
&prefix_attention_mask,
&mut prefix_position_ids,
&pending_position_indices,
);
pending_position_indices.iter_mut().for_each(Vec::clear);
let frame_embed = video_embeds
.clone()
.slice_dim(1, frame_idx..(frame_idx + 1))
.reshape([batch, vision_tokens, dim]);
let mut sequence_embeds = if let Some(prefix) = prefix_embeds.take() {
Tensor::cat(vec![prefix, frame_embed.clone()], 1)
} else {
frame_embed.clone()
};
// Register this frame's vision tokens: all attended, with position ids that
// continue from the number of unmasked entries so far.
for batch_idx in 0..batch {
let valid_start = prefix_attention_mask[batch_idx]
.iter()
.filter(|&&mask| mask != 0)
.count() as i64;
for valid_count in (valid_start..).take(vision_tokens) {
prefix_attention_mask[batch_idx].push(1);
prefix_position_ids[batch_idx].push(valid_count);
}
}
// Per-frame results, one inner vector per batch row.
let mut frame_tokens = vec![Vec::<i64>::new(); batch];
let mut frame_padded = vec![Vec::<bool>::new(); batch];
let mut frame_confidences = vec![Vec::<f32>::new(); batch];
let mut finished = vec![false; batch];
let mut is_first_token = true;
let max_tokens = max_gaze_tokens_each_frame.max(1);
// Sequence length before any token of this frame is generated.
let generation_prefix_len = sequence_embeds.shape().dims::<3>()[1];
// The local binding below shadows the helper function of the same name.
let generation_tail_positions =
generation_tail_positions(&prefix_position_ids, self.num_multi_token_pred);
let mut last_generated_indices = vec![Vec::<usize>::new(); batch];
while frame_tokens.iter().map(Vec::len).max().unwrap_or(0) < max_tokens
&& finished.iter().any(|done| !done)
{
let seq_len = sequence_embeds.shape().dims::<3>()[1];
let attention_mask =
attention_mask_tensor_or_none::<B>(&prefix_attention_mask, seq_len, &device);
let position_ids =
position_ids_tensor_optimized::<B>(&prefix_position_ids, seq_len, &device);
// The whole sequence is re-processed on every step.
let outputs = self.gaze_decoder.forward(
sequence_embeds.clone(),
attention_mask,
position_ids,
);
// Only the last position is used to predict the next tokens.
let last_logits = outputs
.logits
.slice_dim(1, seq_len.saturating_sub(1)..seq_len)
.reshape([
batch,
self.num_multi_token_pred,
self.gaze_decoder.vocab_size,
]);
let last_task = outputs
.task_loss_prediction
.slice_dim(1, seq_len.saturating_sub(1)..seq_len)
.reshape([batch, self.num_multi_token_pred]);
let (next_tokens, next_valid, next_confidences) = greedy_select_multi_tokens(
last_logits,
last_task,
&frame_tokens,
&finished,
self.eos_token_id,
max_tokens,
TaskLossStop {
requirement: task_loss_requirement,
is_first_token,
},
);
// The first row's count is used as the number of new tokens for every row.
let new_tokens = next_tokens.first().map(Vec::len).unwrap_or(0);
if new_tokens == 0 {
break;
}
let flat_tokens: Vec<i64> = next_tokens
.iter()
.flat_map(|tokens| tokens.iter().copied())
.collect();
let token_tensor = Tensor::<B, 2, Int>::from_data(
TensorData::new(flat_tokens, [batch, new_tokens]),
&device,
);
let token_embeds = self.gaze_decoder.model.embed_tokens.forward(token_tensor);
sequence_embeds = Tensor::cat(vec![sequence_embeds, token_embeds], 1);
for batch_idx in 0..batch {
last_generated_indices[batch_idx].clear();
for local_idx in 0..new_tokens {
// Remember where this token lands in the mask/position vectors.
last_generated_indices[batch_idx]
.push(prefix_attention_mask[batch_idx].len());
let token = next_tokens[batch_idx][local_idx];
let valid = next_valid[batch_idx][local_idx];
let confidence = next_confidences[batch_idx][local_idx];
frame_tokens[batch_idx].push(token);
frame_padded[batch_idx].push(!valid);
frame_confidences[batch_idx].push(confidence);
// Attended while this frame is being generated; padded entries are
// masked out once the frame is complete (see below).
prefix_attention_mask[batch_idx].push(1);
let tail = &generation_tail_positions[batch_idx];
prefix_position_ids[batch_idx].push(tail[local_idx % tail.len()]);
// A token reported as not valid marks this batch row as finished.
if !valid {
finished[batch_idx] = true;
}
}
}
is_first_token = false;
}
let frame_count = frame_tokens.first().map(Vec::len).unwrap_or(0);
num_gazing_each_frame.push(frame_count);
let frame_offset = (frame_idx * self.num_vision_tokens_each_frame) as i64;
for batch_idx in 0..batch {
// Mask out padded tokens so later frames do not attend to them.
for (local_idx, padded) in frame_padded[batch_idx].iter().copied().enumerate() {
if padded {
prefix_attention_mask[batch_idx][generation_prefix_len + local_idx] = 0;
}
}
gazing_pos[batch_idx].extend(
frame_tokens[batch_idx]
.iter()
.map(|token| token + frame_offset),
);
if_padded_gazing[batch_idx].extend(frame_padded[batch_idx].iter().copied());
confidences[batch_idx].extend(frame_confidences[batch_idx].iter().copied());
}
// Carry the indices of the last generated tokens and the full sequence forward.
pending_position_indices = last_generated_indices;
prefix_embeds = Some(sequence_embeds);
}
AutoGazeGenerateOutput {
gazing_pos,
num_gazing_each_frame,
if_padded_gazing,
confidences,
}
}
/// Asynchronous counterpart of `generate`: videos with more than one frame
/// (dimension 1) go through `generate_cached_async`, everything else through
/// `generate_uncached_async`. Errors from the chosen path are returned as-is.
pub async fn generate_async(
    &self,
    video: Tensor<B, 5>,
    max_gaze_tokens_each_frame: usize,
    task_loss_requirement: Option<f32>,
) -> Result<AutoGazeGenerateOutput, ExecutionError> {
    let frame_count = video.shape().dims::<5>()[1];
    if frame_count <= 1 {
        return self
            .generate_uncached_async(video, max_gaze_tokens_each_frame, task_loss_requirement)
            .await;
    }
    self.generate_cached_async(video, max_gaze_tokens_each_frame, task_loss_requirement)
        .await
}
/// Generates gaze tokens frame by frame, feeding the decoder only the new
/// embeddings at each step and reusing the key/value cache returned by
/// `forward_cached`.
///
/// `max_gaze_tokens_each_frame` is clamped to at least 1. `task_loss_requirement`
/// is forwarded to the greedy selector through `TaskLossStop`.
///
/// The returned `AutoGazeGenerateOutput` holds, per batch row, the selected
/// token ids offset by `frame_idx * self.num_vision_tokens_each_frame`, the
/// padding flags and the confidences, plus the number of token slots emitted
/// for each frame (taken from the first batch row).
///
/// # Errors
///
/// Returns the `ExecutionError` produced while awaiting
/// `greedy_select_multi_tokens_async`.
pub async fn generate_cached_async(
&self,
video: Tensor<B, 5>,
max_gaze_tokens_each_frame: usize,
task_loss_requirement: Option<f32>,
) -> Result<AutoGazeGenerateOutput, ExecutionError> {
let video = self.resize_video(video);
let (video_embeds, _) = self.embed_video(video, false, None);
let [batch, frames, vision_tokens, dim] = video_embeds.shape().dims::<4>();
let device = video_embeds.device();
// Per-batch-row accumulators that become the final output.
let mut gazing_pos = vec![Vec::<i64>::new(); batch];
let mut if_padded_gazing = vec![Vec::<bool>::new(); batch];
let mut confidences = vec![Vec::<f32>::new(); batch];
let mut num_gazing_each_frame = Vec::with_capacity(frames);
let mut past_key_values: Option<AutoGazePastKeyValues<B>> = None;
// Embeddings of the tokens picked in the last decoding step of the previous
// frame; they have not been run through the decoder yet.
let mut pending_query_embeds: Option<Tensor<B, 3>> = None;
// Host-side attention mask and position ids for the whole sequence so far.
let mut prefix_attention_mask = vec![vec![]; batch];
let mut prefix_position_ids = vec![vec![]; batch];
// Indices (into the prefix rows) of the previous frame's last generated tokens.
let mut pending_position_indices = vec![Vec::<usize>::new(); batch];
let max_tokens = max_gaze_tokens_each_frame.max(1);
// Passed to `forward_cached`; presumably an upper bound on the cached
// sequence length (every frame's vision tokens plus its generated tokens).
let cache_capacity = frames * (vision_tokens + max_tokens);
for frame_idx in 0..frames {
// Now that the previous frame's padding mask is final, rewrite the position
// ids of its last generated tokens from the count of valid mask entries.
commit_pending_position_ids(
&prefix_attention_mask,
&mut prefix_position_ids,
&pending_position_indices,
);
pending_position_indices.iter_mut().for_each(Vec::clear);
let frame_embed = video_embeds
.clone()
.slice_dim(1, frame_idx..(frame_idx + 1))
.reshape([batch, vision_tokens, dim]);
// Append this frame's vision tokens: mask 1, positions continuing from the
// number of valid (non-zero mask) entries already in the row.
for batch_idx in 0..batch {
let valid_start = prefix_attention_mask[batch_idx]
.iter()
.filter(|&&mask| mask != 0)
.count() as i64;
for valid_count in (valid_start..).take(vision_tokens) {
prefix_attention_mask[batch_idx].push(1);
prefix_position_ids[batch_idx].push(valid_count);
}
}
// The first query of the frame is the not-yet-decoded token embeddings from
// the previous frame (if any) followed by this frame's vision embeddings.
let initial_query_embeds = if let Some(pending) = pending_query_embeds.take() {
Tensor::cat(vec![pending, frame_embed], 1)
} else {
frame_embed
};
let mut frame_tokens = vec![Vec::<i64>::new(); batch];
let mut frame_padded = vec![Vec::<bool>::new(); batch];
let mut frame_confidences = vec![Vec::<f32>::new(); batch];
let mut finished = vec![false; batch];
let mut is_first_token = true;
// Index in the prefix rows where this frame's generated tokens start.
let generation_prefix_len = prefix_attention_mask.first().map(Vec::len).unwrap_or(0);
// Provisional position ids for generated tokens: the last
// `num_multi_token_pred` positions of each row, reused cyclically below.
let generation_tail_positions =
generation_tail_positions(&prefix_position_ids, self.num_multi_token_pred);
let mut last_generated_indices = vec![Vec::<usize>::new(); batch];
let mut next_query_embeds = Some(initial_query_embeds);
// Decode until the longest row reaches `max_tokens` or every row is finished.
while frame_tokens.iter().map(Vec::len).max().unwrap_or(0) < max_tokens
&& finished.iter().any(|done| !done)
{
let Some(query_embeds) = next_query_embeds.take() else {
break;
};
let query_len = query_embeds.shape().dims::<3>()[1];
let query_start = cached_sequence_len(&past_key_values);
let key_len = query_start + query_len;
// `None` when every mask entry up to `key_len` is non-zero.
let attention_mask =
attention_mask_tensor_or_none::<B>(&prefix_attention_mask, key_len, &device);
let position_ids = position_ids_slice_tensor_optimized::<B>(
&prefix_position_ids,
query_start,
query_len,
&device,
);
let outputs = self.gaze_decoder.forward_cached(
query_embeds,
attention_mask,
position_ids,
past_key_values,
cache_capacity,
);
past_key_values = Some(outputs.past_key_values);
// Only the outputs at the last query position are used for selection.
let last_logits = outputs
.logits
.slice_dim(1, query_len.saturating_sub(1)..query_len)
.reshape([
batch,
self.num_multi_token_pred,
self.gaze_decoder.vocab_size,
]);
let last_task = outputs
.task_loss_prediction
.slice_dim(1, query_len.saturating_sub(1)..query_len)
.reshape([batch, self.num_multi_token_pred]);
let (next_tokens, next_valid, next_confidences) = greedy_select_multi_tokens_async(
last_logits,
last_task,
&frame_tokens,
&finished,
self.eos_token_id,
max_tokens,
TaskLossStop {
requirement: task_loss_requirement,
is_first_token,
},
)
.await?;
// The selector pads all rows to the same length, so the first row's length
// is the number of tokens added this step.
let new_tokens = next_tokens.first().map(Vec::len).unwrap_or(0);
if new_tokens == 0 {
break;
}
let flat_tokens: Vec<i64> = next_tokens
.iter()
.flat_map(|tokens| tokens.iter().copied())
.collect();
let token_tensor = Tensor::<B, 2, Int>::from_data(
TensorData::new(flat_tokens, [batch, new_tokens]),
&device,
);
let token_embeds = self.gaze_decoder.model.embed_tokens.forward(token_tensor);
for batch_idx in 0..batch {
last_generated_indices[batch_idx].clear();
for local_idx in 0..new_tokens {
last_generated_indices[batch_idx]
.push(prefix_attention_mask[batch_idx].len());
let token = next_tokens[batch_idx][local_idx];
let valid = next_valid[batch_idx][local_idx];
let confidence = next_confidences[batch_idx][local_idx];
frame_tokens[batch_idx].push(token);
frame_padded[batch_idx].push(!valid);
frame_confidences[batch_idx].push(confidence);
// Generated tokens stay attendable (mask 1) while the frame is being
// decoded; padded slots are zeroed once the frame is complete.
prefix_attention_mask[batch_idx].push(1);
let tail = &generation_tail_positions[batch_idx];
prefix_position_ids[batch_idx].push(tail[local_idx % tail.len()]);
// An invalid (padding) slot marks the row as finished for this frame.
if !valid {
finished[batch_idx] = true;
}
}
}
next_query_embeds = Some(token_embeds);
is_first_token = false;
}
let frame_count = frame_tokens.first().map(Vec::len).unwrap_or(0);
num_gazing_each_frame.push(frame_count);
let frame_offset = (frame_idx * self.num_vision_tokens_each_frame) as i64;
for batch_idx in 0..batch {
// Mask out the padded slots of this frame for all later decoding steps.
for (local_idx, padded) in frame_padded[batch_idx].iter().copied().enumerate() {
if padded {
prefix_attention_mask[batch_idx][generation_prefix_len + local_idx] = 0;
}
}
// The frame offset is added to every emitted token id, padded slots included.
gazing_pos[batch_idx].extend(
frame_tokens[batch_idx]
.iter()
.map(|token| token + frame_offset),
);
if_padded_gazing[batch_idx].extend(frame_padded[batch_idx].iter().copied());
confidences[batch_idx].extend(frame_confidences[batch_idx].iter().copied());
}
pending_position_indices = last_generated_indices;
// Whatever was selected last but not yet decoded is carried into the next frame.
pending_query_embeds = next_query_embeds;
}
Ok(AutoGazeGenerateOutput {
gazing_pos,
num_gazing_each_frame,
if_padded_gazing,
confidences,
})
}
/// Generates gaze tokens frame by frame without a key/value cache: every
/// decoding step runs `gaze_decoder.forward` over the full embedding sequence
/// accumulated so far.
///
/// `max_gaze_tokens_each_frame` is clamped to at least 1. `task_loss_requirement`
/// is forwarded to the greedy selector through `TaskLossStop`.
///
/// The returned `AutoGazeGenerateOutput` holds, per batch row, the selected
/// token ids offset by `frame_idx * self.num_vision_tokens_each_frame`, the
/// padding flags and the confidences, plus the number of token slots emitted
/// for each frame (taken from the first batch row).
///
/// # Errors
///
/// Returns the `ExecutionError` produced while awaiting
/// `greedy_select_multi_tokens_async`.
pub async fn generate_uncached_async(
&self,
video: Tensor<B, 5>,
max_gaze_tokens_each_frame: usize,
task_loss_requirement: Option<f32>,
) -> Result<AutoGazeGenerateOutput, ExecutionError> {
let video = self.resize_video(video);
let (video_embeds, _) = self.embed_video(video, false, None);
let [batch, frames, vision_tokens, dim] = video_embeds.shape().dims::<4>();
let device = video_embeds.device();
// Per-batch-row accumulators that become the final output.
let mut gazing_pos = vec![Vec::<i64>::new(); batch];
let mut if_padded_gazing = vec![Vec::<bool>::new(); batch];
let mut confidences = vec![Vec::<f32>::new(); batch];
let mut num_gazing_each_frame = Vec::with_capacity(frames);
// Embeddings of everything processed in earlier frames (vision + generated tokens).
let mut prefix_embeds: Option<Tensor<B, 3>> = None;
// Host-side attention mask and position ids matching the embedding sequence.
let mut prefix_attention_mask = vec![vec![]; batch];
let mut prefix_position_ids = vec![vec![]; batch];
// Indices (into the prefix rows) of the previous frame's last generated tokens.
let mut pending_position_indices = vec![Vec::<usize>::new(); batch];
for frame_idx in 0..frames {
// Now that the previous frame's padding mask is final, rewrite the position
// ids of its last generated tokens from the count of valid mask entries.
commit_pending_position_ids(
&prefix_attention_mask,
&mut prefix_position_ids,
&pending_position_indices,
);
pending_position_indices.iter_mut().for_each(Vec::clear);
let frame_embed = video_embeds
.clone()
.slice_dim(1, frame_idx..(frame_idx + 1))
.reshape([batch, vision_tokens, dim]);
// Full sequence for this frame: earlier embeddings followed by the new frame.
let mut sequence_embeds = if let Some(prefix) = prefix_embeds.take() {
Tensor::cat(vec![prefix, frame_embed.clone()], 1)
} else {
frame_embed.clone()
};
// Append this frame's vision tokens: mask 1, positions continuing from the
// number of valid (non-zero mask) entries already in the row.
for batch_idx in 0..batch {
let valid_start = prefix_attention_mask[batch_idx]
.iter()
.filter(|&&mask| mask != 0)
.count() as i64;
for valid_count in (valid_start..).take(vision_tokens) {
prefix_attention_mask[batch_idx].push(1);
prefix_position_ids[batch_idx].push(valid_count);
}
}
let mut frame_tokens = vec![Vec::<i64>::new(); batch];
let mut frame_padded = vec![Vec::<bool>::new(); batch];
let mut frame_confidences = vec![Vec::<f32>::new(); batch];
let mut finished = vec![false; batch];
let mut is_first_token = true;
let max_tokens = max_gaze_tokens_each_frame.max(1);
// Index in the sequence where this frame's generated tokens start.
let generation_prefix_len = sequence_embeds.shape().dims::<3>()[1];
// Provisional position ids for generated tokens: the last
// `num_multi_token_pred` positions of each row, reused cyclically below.
let generation_tail_positions =
generation_tail_positions(&prefix_position_ids, self.num_multi_token_pred);
let mut last_generated_indices = vec![Vec::<usize>::new(); batch];
// Decode until the longest row reaches `max_tokens` or every row is finished.
while frame_tokens.iter().map(Vec::len).max().unwrap_or(0) < max_tokens
&& finished.iter().any(|done| !done)
{
let seq_len = sequence_embeds.shape().dims::<3>()[1];
// `None` when every mask entry up to `seq_len` is non-zero.
let attention_mask =
attention_mask_tensor_or_none::<B>(&prefix_attention_mask, seq_len, &device);
let position_ids =
position_ids_tensor_optimized::<B>(&prefix_position_ids, seq_len, &device);
// The whole sequence is re-run; the clone keeps `sequence_embeds` available
// for the concatenation further down.
let outputs = self.gaze_decoder.forward(
sequence_embeds.clone(),
attention_mask,
position_ids,
);
// Only the outputs at the last sequence position are used for selection.
let last_logits = outputs
.logits
.slice_dim(1, seq_len.saturating_sub(1)..seq_len)
.reshape([
batch,
self.num_multi_token_pred,
self.gaze_decoder.vocab_size,
]);
let last_task = outputs
.task_loss_prediction
.slice_dim(1, seq_len.saturating_sub(1)..seq_len)
.reshape([batch, self.num_multi_token_pred]);
let (next_tokens, next_valid, next_confidences) = greedy_select_multi_tokens_async(
last_logits,
last_task,
&frame_tokens,
&finished,
self.eos_token_id,
max_tokens,
TaskLossStop {
requirement: task_loss_requirement,
is_first_token,
},
)
.await?;
// The selector pads all rows to the same length, so the first row's length
// is the number of tokens added this step.
let new_tokens = next_tokens.first().map(Vec::len).unwrap_or(0);
if new_tokens == 0 {
break;
}
let flat_tokens: Vec<i64> = next_tokens
.iter()
.flat_map(|tokens| tokens.iter().copied())
.collect();
let token_tensor = Tensor::<B, 2, Int>::from_data(
TensorData::new(flat_tokens, [batch, new_tokens]),
&device,
);
let token_embeds = self.gaze_decoder.model.embed_tokens.forward(token_tensor);
sequence_embeds = Tensor::cat(vec![sequence_embeds, token_embeds], 1);
for batch_idx in 0..batch {
last_generated_indices[batch_idx].clear();
for local_idx in 0..new_tokens {
last_generated_indices[batch_idx]
.push(prefix_attention_mask[batch_idx].len());
let token = next_tokens[batch_idx][local_idx];
let valid = next_valid[batch_idx][local_idx];
let confidence = next_confidences[batch_idx][local_idx];
frame_tokens[batch_idx].push(token);
frame_padded[batch_idx].push(!valid);
frame_confidences[batch_idx].push(confidence);
// Generated tokens stay attendable (mask 1) while the frame is being
// decoded; padded slots are zeroed once the frame is complete.
prefix_attention_mask[batch_idx].push(1);
let tail = &generation_tail_positions[batch_idx];
prefix_position_ids[batch_idx].push(tail[local_idx % tail.len()]);
// An invalid (padding) slot marks the row as finished for this frame.
if !valid {
finished[batch_idx] = true;
}
}
}
is_first_token = false;
}
let frame_count = frame_tokens.first().map(Vec::len).unwrap_or(0);
num_gazing_each_frame.push(frame_count);
let frame_offset = (frame_idx * self.num_vision_tokens_each_frame) as i64;
for batch_idx in 0..batch {
// Mask out the padded slots of this frame for all later decoding steps.
for (local_idx, padded) in frame_padded[batch_idx].iter().copied().enumerate() {
if padded {
prefix_attention_mask[batch_idx][generation_prefix_len + local_idx] = 0;
}
}
// The frame offset is added to every emitted token id, padded slots included.
gazing_pos[batch_idx].extend(
frame_tokens[batch_idx]
.iter()
.map(|token| token + frame_offset),
);
if_padded_gazing[batch_idx].extend(frame_padded[batch_idx].iter().copied());
confidences[batch_idx].extend(frame_confidences[batch_idx].iter().copied());
}
pending_position_indices = last_generated_indices;
prefix_embeds = Some(sequence_embeds);
}
Ok(AutoGazeGenerateOutput {
gazing_pos,
num_gazing_each_frame,
if_padded_gazing,
confidences,
})
}
/// Converts `video` to three channels and, if needed, resizes every frame to
/// `input_img_size x input_img_size` with bicubic interpolation.
///
/// `video` is read as `[batch, time, channels, height, width]`. When the
/// spatial size already matches, only the channel adaptation is applied.
fn resize_video(&self, video: Tensor<B, 5>) -> Tensor<B, 5> {
    let [batch, time, channels, height, width] = video.shape().dims::<5>();
    let device = video.device();
    let video = adapt_video_channels(video, channels, &device);

    let target = self.input_img_size;
    if (height, width) == (target, target) {
        return video;
    }

    // Fold time into the batch axis so the frames can be interpolated as images.
    let frames = video.reshape([batch * time, 3, height, width]);
    let options = InterpolateOptions::new(InterpolateMode::Bicubic).with_align_corners(false);
    let resized = interpolate(frames, [target, target], options);
    resized.reshape([batch, time, 3, target, target])
}
}
/// AutoGaze model bundle: the gazing network together with the configuration
/// it was built from.
#[derive(Module, Debug)]
pub struct NativeAutoGazeModel<B: Backend> {
/// Network that all `generate*` / `trace*` methods delegate to.
pub gazing_model: AutoGazeGazingModel<B>,
/// Configuration used to build the model; excluded from the `Module` derive
/// via `#[module(skip)]`.
#[module(skip)]
pub config: AutoGazeConfig,
}
/// Options controlling how safetensors weights are applied to the model.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct AutoGazeLoadOptions {
    /// Forwarded to the store's `allow_partial` setting; when `true`, apply
    /// errors reported by the store are not turned into a load failure.
    pub allow_partial: bool,
    /// Forwarded to the store's `validate` setting.
    pub validate: bool,
}

impl AutoGazeLoadOptions {
    /// Strict loading: no partial loads, validation enabled.
    pub const fn strict() -> Self {
        Self {
            validate: true,
            allow_partial: false,
        }
    }

    /// Permissive loading: partial loads allowed, validation disabled.
    pub const fn permissive() -> Self {
        Self {
            validate: false,
            allow_partial: true,
        }
    }

    /// Returns a copy of these options with `allow_partial` replaced.
    pub const fn with_allow_partial(self, allow_partial: bool) -> Self {
        Self {
            allow_partial,
            validate: self.validate,
        }
    }

    /// Returns a copy of these options with `validate` replaced.
    pub const fn with_validate(self, validate: bool) -> Self {
        Self {
            allow_partial: self.allow_partial,
            validate,
        }
    }
}

/// The default is [`AutoGazeLoadOptions::strict`].
impl Default for AutoGazeLoadOptions {
    fn default() -> Self {
        Self::strict()
    }
}
impl<B: Backend> NativeAutoGazeModel<B> {
pub fn new(config: &AutoGazeConfig, device: &B::Device) -> Self {
Self {
gazing_model: AutoGazeGazingModel::new(&config.gaze_model_config, device),
config: config.clone(),
}
}
pub fn load(dir: impl AsRef<Path>, device: &B::Device) -> Result<Self> {
Self::from_hf_dir(dir, device)
}
pub fn from_hf_dir(dir: impl AsRef<Path>, device: &B::Device) -> Result<Self> {
Self::from_hf_dir_with_options(dir, device, AutoGazeLoadOptions::strict())
}
pub fn from_hf_dir_with_options(
dir: impl AsRef<Path>,
device: &B::Device,
options: AutoGazeLoadOptions,
) -> Result<Self> {
let dir = dir.as_ref();
let config = AutoGazeConfig::from_json_file(dir.join("config.json"))?;
Self::from_config_and_safetensors_file(
&config,
dir.join("model.safetensors"),
device,
options,
)
}
pub fn from_config_and_safetensors_file(
config: &AutoGazeConfig,
path: impl Into<std::path::PathBuf>,
device: &B::Device,
options: AutoGazeLoadOptions,
) -> Result<Self> {
let mut model = Self::new(config, device);
let mut store = SafetensorsStore::from_file(path)
.with_from_adapter(PyTorchToBurnAdapter)
.allow_partial(options.allow_partial)
.validate(options.validate);
model.load_safetensors_store(&mut store, options)?;
Ok(model)
}
pub fn from_config_and_safetensors_bytes(
config: &AutoGazeConfig,
bytes: Vec<u8>,
device: &B::Device,
options: AutoGazeLoadOptions,
) -> Result<Self> {
let mut model = Self::new(config, device);
let mut store = SafetensorsStore::from_bytes(Some(bytes))
.with_from_adapter(PyTorchToBurnAdapter)
.allow_partial(options.allow_partial)
.validate(options.validate);
model.load_safetensors_store(&mut store, options)?;
Ok(model)
}
fn load_safetensors_store(
&mut self,
store: &mut SafetensorsStore,
options: AutoGazeLoadOptions,
) -> Result<()> {
let result = self
.load_from(store)
.context("load AutoGaze safetensors weights")?;
if !options.allow_partial && !result.errors.is_empty() {
bail!("failed to apply AutoGaze weights: {:?}", result.errors);
}
Ok(())
}
pub fn into_pipeline(self) -> crate::AutoGazePipeline<B> {
crate::AutoGazePipeline::new(self)
}
pub fn generate(
&self,
video: Tensor<B, 5>,
max_gaze_tokens_each_frame: usize,
) -> AutoGazeGenerateOutput {
self.gazing_model
.generate(video, max_gaze_tokens_each_frame, None)
}
pub fn generate_with_task_loss_requirement(
&self,
video: Tensor<B, 5>,
max_gaze_tokens_each_frame: usize,
task_loss_requirement: Option<f32>,
) -> AutoGazeGenerateOutput {
self.gazing_model
.generate(video, max_gaze_tokens_each_frame, task_loss_requirement)
}
pub async fn generate_with_task_loss_requirement_async(
&self,
video: Tensor<B, 5>,
max_gaze_tokens_each_frame: usize,
task_loss_requirement: Option<f32>,
) -> Result<AutoGazeGenerateOutput, ExecutionError> {
self.gazing_model
.generate_async(video, max_gaze_tokens_each_frame, task_loss_requirement)
.await
}
pub fn default_max_gaze_tokens_each_frame(&self) -> usize {
self.config
.inference_gazing_ratio()
.map(|ratio| {
(ratio.clamp(0.0, 1.0) * self.config.num_vision_tokens_each_frame.max(1) as f32)
.floor()
.max(1.0) as usize
})
.unwrap_or_else(|| self.gazing_model.num_multi_token_pred.max(1))
}
pub fn default_task_loss_requirement(&self) -> Option<f32> {
self.config.inference_task_loss_requirement()
}
pub fn infer(&self, video: Tensor<B, 5>) -> AutoGazeGenerateOutput {
self.generate_with_task_loss_requirement(
video,
self.default_max_gaze_tokens_each_frame(),
self.default_task_loss_requirement(),
)
}
pub async fn infer_async(
&self,
video: Tensor<B, 5>,
) -> Result<AutoGazeGenerateOutput, ExecutionError> {
self.generate_with_task_loss_requirement_async(
video,
self.default_max_gaze_tokens_each_frame(),
self.default_task_loss_requirement(),
)
.await
}
pub fn trace_video(
&self,
video: Tensor<B, 5>,
k: usize,
max_gaze_tokens_each_frame: usize,
) -> Vec<FrameFixationTrace> {
self.trace_video_with_task_loss_requirement(
video,
k,
max_gaze_tokens_each_frame,
self.default_task_loss_requirement(),
)
}
pub fn trace_video_with_task_loss_requirement(
&self,
video: Tensor<B, 5>,
k: usize,
max_gaze_tokens_each_frame: usize,
task_loss_requirement: Option<f32>,
) -> Vec<FrameFixationTrace> {
let trace_budget = max_gaze_tokens_each_frame.max(k.max(1));
let generated =
self.generate_with_task_loss_requirement(video, trace_budget, task_loss_requirement);
generated_to_traces(&generated, &self.config, trace_budget)
}
pub async fn trace_video_with_task_loss_requirement_async(
&self,
video: Tensor<B, 5>,
k: usize,
max_gaze_tokens_each_frame: usize,
task_loss_requirement: Option<f32>,
) -> Result<Vec<FrameFixationTrace>, ExecutionError> {
let trace_budget = max_gaze_tokens_each_frame.max(k.max(1));
let generated = self
.generate_with_task_loss_requirement_async(video, trace_budget, task_loss_requirement)
.await?;
Ok(generated_to_traces(&generated, &self.config, trace_budget))
}
pub fn trace_clip_from_frames(
&self,
frames: &[f32],
clip_len: usize,
channels: usize,
height: usize,
width: usize,
k: usize,
) -> FrameFixationTrace {
let device = self.gazing_model.connector.pos_embed.val().device();
let clip = Tensor::<B, 5>::from_data(
TensorData::new(
frames.to_vec(),
[
1,
clip_len.max(1),
channels.max(1),
height.max(1),
width.max(1),
],
),
&device,
);
self.trace_video(
clip,
k,
self.gazing_model.num_multi_token_pred.max(k.max(1)),
)
.into_iter()
.next()
.unwrap_or_else(|| FrameFixationTrace::new(vec![]))
}
}
/// Returns `video` with exactly three channels on dimension 2.
///
/// * 3 channels: returned unchanged.
/// * 1 channel: the channel is repeated three times.
/// * any other count: the first `min(channels, 3)` channels are copied into a
///   zero-filled three-channel tensor.
fn adapt_video_channels<B: Backend>(
    video: Tensor<B, 5>,
    channels: usize,
    device: &B::Device,
) -> Tensor<B, 5> {
    if channels == 3 {
        return video;
    }
    if channels == 1 {
        return video.repeat_dim(2, 3);
    }

    let [batch, time, _, height, width] = video.shape().dims::<5>();
    let zeros = TensorData::new(
        vec![0.0f32; batch * time * 3 * height * width],
        [batch, time, 3, height, width],
    );
    let canvas = Tensor::<B, 5>::from_data(zeros, device);
    let keep = channels.min(3);
    canvas.slice_assign(
        [0..batch, 0..time, 0..keep, 0..height, 0..width],
        video.slice_dim(2, 0..keep),
    )
}
impl<B: Backend> crate::AutoGazeTeacher for NativeAutoGazeModel<B> {
fn trace_clip(
&self,
frames: &[f32],
clip_len: usize,
channels: usize,
height: usize,
width: usize,
k: usize,
) -> FrameFixationTrace {
self.trace_clip_from_frames(frames, clip_len, channels, height, width, k)
}
}
/// Reshapes `[batch, seq, hidden]` into `[batch, heads, seq, head_dim]`.
///
/// `heads` and `head_dim` are each clamped to at least 1.
fn split_heads<B: Backend>(tokens: Tensor<B, 3>, heads: usize, head_dim: usize) -> Tensor<B, 4> {
    let [batch, seq, _] = tokens.shape().dims::<3>();
    let per_head_shape = [batch, seq, heads.max(1), head_dim.max(1)];
    let per_head = tokens.reshape(per_head_shape);
    // Move the head axis in front of the sequence axis.
    per_head.swap_dims(1, 2)
}
/// Inverse of `split_heads`: turns `[batch, heads, seq, head_dim]` back into
/// `[batch, seq, heads * head_dim]`.
fn merge_heads<B: Backend>(tokens: Tensor<B, 4>) -> Tensor<B, 3> {
    let [batch, heads, seq, head_dim] = tokens.shape().dims::<4>();
    let hidden = heads * head_dim;
    let seq_major = tokens.swap_dims(1, 2);
    seq_major.reshape([batch, seq, hidden])
}
/// Splits the last dimension in two halves `(a, b)` and returns `(-b, a)`
/// concatenated along that dimension.
fn rotate_half<B: Backend>(x: Tensor<B, 4>) -> Tensor<B, 4> {
    let width = x.shape().dims::<4>()[3];
    let mid = width / 2;
    let upper = x.clone().slice_dim(3, mid..width);
    let lower = x.slice_dim(3, 0..mid);
    let negated_upper = upper.mul_scalar(-1.0);
    Tensor::cat(vec![negated_upper, lower], 3)
}
/// Builds an additive attention bias for a square causal attention window.
///
/// Entries are `0` where key position <= query position and `-1e9` elsewhere.
/// When `attention_mask` is given, keys whose mask value is `0` receive an
/// additional `-1e9`.
fn causal_attention_bias<B: Backend>(
    batch: usize,
    seq_len: usize,
    attention_mask: Option<Tensor<B, 2, Int>>,
    device: &B::Device,
) -> Tensor<B, 4> {
    // Maps an "allowed" indicator (1.0 / 0.0) to a bias (0.0 / -1e9).
    let to_bias = |allowed: Tensor<B, 4>| allowed.sub_scalar(1.0).abs().mul_scalar(-1.0e9);
    let positions = || Tensor::<B, 1, Int>::arange(0..seq_len as i64, device);

    let query_positions = positions().reshape([1, 1, seq_len, 1]);
    let key_positions = positions().reshape([1, 1, 1, seq_len]);
    let causal_bias = to_bias(key_positions.lower_equal(query_positions).float());

    match attention_mask {
        Some(mask) => {
            let key_valid = mask.float().reshape([batch.max(1), 1, 1, seq_len]);
            causal_bias + to_bias(key_valid)
        }
        None => causal_bias,
    }
}
/// Builds an additive causal attention bias for a query block that starts at
/// absolute position `past_len` and attends over `key_len` keys.
///
/// Entries are `0` where key position <= query position and `-1e9` elsewhere.
/// When `attention_mask` is given, keys whose mask value is `0` receive an
/// additional `-1e9`.
fn causal_attention_bias_for_query<B: Backend>(
    batch: usize,
    query_len: usize,
    key_len: usize,
    past_len: usize,
    attention_mask: Option<Tensor<B, 2, Int>>,
    device: &B::Device,
) -> Tensor<B, 4> {
    // Maps an "allowed" indicator (1.0 / 0.0) to a bias (0.0 / -1e9).
    let to_bias = |allowed: Tensor<B, 4>| allowed.sub_scalar(1.0).abs().mul_scalar(-1.0e9);

    let query_positions =
        Tensor::<B, 1, Int>::arange(past_len as i64..(past_len + query_len) as i64, device)
            .reshape([1, 1, query_len, 1]);
    let key_positions =
        Tensor::<B, 1, Int>::arange(0..key_len as i64, device).reshape([1, 1, 1, key_len]);
    let causal_bias = to_bias(key_positions.lower_equal(query_positions).float());

    match attention_mask {
        Some(mask) => {
            let key_valid = mask.float().reshape([batch.max(1), 1, 1, key_len]);
            causal_bias + to_bias(key_valid)
        }
        None => causal_bias,
    }
}
/// Computes the rotary-embedding inverse frequencies for the gaze decoder.
///
/// The base frequencies are `1 / rope_theta^(2i / head_dim)` for
/// `i in 0..max(head_dim / 2, 1)`. When `rope_scaling` has
/// `rope_type == "llama3"`, each frequency is rescaled by wavelength:
/// wavelengths above the low-frequency threshold are divided by `factor`,
/// and wavelengths between the two thresholds are blended smoothly.
/// Missing scaling keys fall back to `factor = 1`, `low_freq_factor = 1`,
/// `high_freq_factor = 4` and
/// `original_max_position_embeddings = config.max_position_embeddings`.
fn llama3_inv_freq<B: Backend>(
    config: &crate::config::GazeDecoderConfig,
    device: &B::Device,
) -> Tensor<B, 1> {
    let head_dim = config.head_dim.max(1);
    let half = (head_dim / 2).max(1);
    let base = config.rope_theta.max(1.0);

    let mut inv_freq: Vec<f32> = (0..half)
        .map(|idx| {
            let dim_index = (idx * 2) as f32;
            1.0 / base.powf(dim_index / head_dim as f32)
        })
        .collect();

    // Only the "llama3" rope type rescales the base frequencies.
    let llama3_scaling = config.rope_scaling.as_ref().filter(|scaling| {
        let rope_type = scaling
            .get("rope_type")
            .and_then(|value| value.as_str())
            .unwrap_or("default");
        rope_type == "llama3"
    });

    if let Some(rope_scaling) = llama3_scaling {
        let number = |key: &str, fallback: f64| -> f32 {
            rope_scaling
                .get(key)
                .and_then(|value| value.as_f64())
                .unwrap_or(fallback) as f32
        };
        let factor = number("factor", 1.0);
        let low_freq_factor = number("low_freq_factor", 1.0);
        let high_freq_factor = number("high_freq_factor", 4.0);
        let original_max_position_embeddings = number(
            "original_max_position_embeddings",
            config.max_position_embeddings as f64,
        );

        let low_freq_wavelen = original_max_position_embeddings / low_freq_factor.max(1.0);
        let high_freq_wavelen =
            original_max_position_embeddings / high_freq_factor.max(low_freq_factor + 1.0e-6);

        for value in &mut inv_freq {
            let current = *value;
            let wavelen = 2.0 * std::f32::consts::PI / current.max(1.0e-12);
            let scaled = if wavelen > low_freq_wavelen {
                current / factor.max(1.0)
            } else {
                current
            };
            let in_transition_band = wavelen >= high_freq_wavelen && wavelen <= low_freq_wavelen;
            *value = if in_transition_band {
                let smooth_factor = (original_max_position_embeddings / wavelen
                    - low_freq_factor)
                    / (high_freq_factor - low_freq_factor).max(1.0e-6);
                (1.0 - smooth_factor) * scaled / factor.max(1.0) + smooth_factor * scaled
            } else {
                scaled
            };
        }
    }

    Tensor::<B, 1>::from_data(TensorData::new(inv_freq, [half]), device)
}
/// Uploads the first `seq_len` mask entries of every row as a
/// `[max(rows, 1), seq_len]` integer tensor.
///
/// Rows are concatenated as-is; the caller is expected to supply rows with at
/// least `seq_len` entries so the data matches the declared shape.
fn attention_mask_tensor<B: Backend>(
    mask_rows: &[Vec<i64>],
    seq_len: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let batch = mask_rows.len().max(1);
    let flat: Vec<i64> = mask_rows
        .iter()
        .flat_map(|row| row.iter().copied().take(seq_len))
        .collect();
    let data = TensorData::new(flat, [batch, seq_len]);
    Tensor::<B, 2, Int>::from_data(data, device)
}
/// Returns `None` when every row is fully valid over the first `seq_len`
/// entries (no masking needed); otherwise builds the mask tensor.
fn attention_mask_tensor_or_none<B: Backend>(
    mask_rows: &[Vec<i64>],
    seq_len: usize,
    device: &B::Device,
) -> Option<Tensor<B, 2, Int>> {
    let needs_mask = !attention_mask_rows_are_all_valid(mask_rows, seq_len);
    needs_mask.then(|| attention_mask_tensor(mask_rows, seq_len, device))
}
/// Returns `true` when there is at least one row and every row has at least
/// `seq_len` entries whose first `seq_len` values are all non-zero.
fn attention_mask_rows_are_all_valid(mask_rows: &[Vec<i64>], seq_len: usize) -> bool {
    if mask_rows.is_empty() {
        return false;
    }
    mask_rows.iter().all(|row| match row.get(..seq_len) {
        Some(prefix) => prefix.iter().all(|&mask| mask != 0),
        // Row is shorter than `seq_len`.
        None => false,
    })
}
/// Test-only reference implementation: uploads the first `seq_len` position
/// ids of every row as a `[max(rows, 1), seq_len]` integer tensor.
#[cfg(test)]
fn position_ids_tensor<B: Backend>(
    position_rows: &[Vec<i64>],
    seq_len: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let batch = position_rows.len().max(1);
    let flat: Vec<i64> = position_rows
        .iter()
        .flat_map(|row| row.iter().copied().take(seq_len))
        .collect();
    let data = TensorData::new(flat, [batch, seq_len]);
    Tensor::<B, 2, Int>::from_data(data, device)
}
/// Builds the position-id tensor for the first `seq_len` entries of every row
/// by delegating to `position_ids_slice_tensor_optimized` with a start of 0.
fn position_ids_tensor_optimized<B: Backend>(
    position_rows: &[Vec<i64>],
    seq_len: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let start = 0;
    position_ids_slice_tensor_optimized::<B>(position_rows, start, seq_len, device)
}
/// Uploads `len` position ids starting at `start` from every row as a
/// `[max(rows, 1), len]` integer tensor.
///
/// Rows are concatenated as-is; the caller is expected to supply rows that
/// cover `start..start + len` so the data matches the declared shape.
fn position_ids_slice_tensor<B: Backend>(
    position_rows: &[Vec<i64>],
    start: usize,
    len: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let batch = position_rows.len().max(1);
    let flat: Vec<i64> = position_rows
        .iter()
        .flat_map(|row| row.iter().copied().skip(start).take(len))
        .collect();
    let data = TensorData::new(flat, [batch, len]);
    Tensor::<B, 2, Int>::from_data(data, device)
}
/// Builds the `[max(rows, 1), len]` position-id tensor for
/// `start..start + len`, using cheaper constructions when possible:
///
/// 1. all rows hold the same consecutive run: an `arange` repeated per row;
/// 2. all rows hold the same values: a single uploaded row repeated per row;
/// 3. otherwise: the generic per-row upload.
fn position_ids_slice_tensor_optimized<B: Backend>(
    position_rows: &[Vec<i64>],
    start: usize,
    len: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    let batch = position_rows.len().max(1);
    let broadcast = |row: Tensor<B, 1, Int>| -> Tensor<B, 2, Int> {
        row.reshape([1, len]).repeat_dim(0, batch)
    };

    if let Some(first_value) = contiguous_position_start(position_rows, start, len) {
        let run = Tensor::<B, 1, Int>::arange(first_value..(first_value + len as i64), device);
        return broadcast(run);
    }
    if let Some(shared) = identical_position_slice(position_rows, start, len) {
        let row = Tensor::<B, 1, Int>::from_data(TensorData::new(shared, [len]), device);
        return broadcast(row);
    }
    position_ids_slice_tensor(position_rows, start, len, device)
}
/// Returns the first value when every row holds the same consecutive run
/// `first, first + 1, ...` over `start..start + len`, and `None` otherwise.
///
/// `first` is the first row's value at `start`, or `0` when that index is out
/// of range. `None` is also returned when there are no rows or any row is
/// shorter than `start + len`.
fn contiguous_position_start(position_rows: &[Vec<i64>], start: usize, len: usize) -> Option<i64> {
    let head = position_rows.first()?;
    let first = head.get(start).copied().unwrap_or(0);
    let end = start + len;
    let contiguous = position_rows.iter().all(|row| match row.get(start..end) {
        Some(window) => window
            .iter()
            .enumerate()
            .all(|(offset, &value)| value == first + offset as i64),
        None => false,
    });
    contiguous.then_some(first)
}
/// Returns a copy of `start..start + len` from the first row when every row
/// holds exactly the same values over that range.
///
/// Returns `None` when there are no rows, when any row is shorter than
/// `start + len`, or when the rows differ.
fn identical_position_slice(
    position_rows: &[Vec<i64>],
    start: usize,
    len: usize,
) -> Option<Vec<i64>> {
    let end = start + len;
    let reference = position_rows.first()?.get(start..end)?;
    let all_match = position_rows
        .iter()
        .all(|candidate| candidate.get(start..end) == Some(reference));
    all_match.then(|| reference.to_vec())
}
/// Number of positions already stored in the key/value cache: the `len` of
/// the first cache entry, or `0` when there is no cache or it is empty.
fn cached_sequence_len<B: Backend>(past_key_values: &Option<AutoGazePastKeyValues<B>>) -> usize {
    let Some(cache) = past_key_values.as_ref() else {
        return 0;
    };
    match cache.first() {
        Some(entry) => entry.len,
        None => 0,
    }
}
/// For every row, returns its last `max(num_multi_token_pred, 1)` position
/// ids (or the whole row when it is shorter). Empty rows yield `[0]`.
fn generation_tail_positions(
    position_rows: &[Vec<i64>],
    num_multi_token_pred: usize,
) -> Vec<Vec<i64>> {
    let tail_len = num_multi_token_pred.max(1);
    let mut tails = Vec::with_capacity(position_rows.len());
    for row in position_rows {
        let tail = match row.len() {
            0 => vec![0],
            row_len => row[row_len.saturating_sub(tail_len)..].to_vec(),
        };
        tails.push(tail);
    }
    tails
}
/// Rewrites the position ids at the pending indices of every row.
///
/// For each index listed in a row's `pending_rows` entry, the position id
/// becomes the number of non-zero mask entries at or before that index,
/// minus one. Indices outside the mask row or the position row are ignored,
/// and rows with no pending indices are left untouched.
fn commit_pending_position_ids(
    mask_rows: &[Vec<i64>],
    position_rows: &mut [Vec<i64>],
    pending_rows: &[Vec<usize>],
) {
    let rows = mask_rows
        .iter()
        .zip(position_rows.iter_mut())
        .zip(pending_rows);
    for ((mask_row, position_row), pending) in rows {
        if pending.is_empty() {
            continue;
        }
        // Running count of valid (non-zero) mask entries seen so far.
        let mut seen_valid = 0i64;
        for (idx, &mask) in mask_row.iter().enumerate() {
            if mask != 0 {
                seen_valid += 1;
            }
            if !pending.contains(&idx) {
                continue;
            }
            if let Some(slot) = position_row.get_mut(idx) {
                *slot = seen_valid.saturating_sub(1);
            }
        }
    }
}
/// Greedily selects up to `num_multi` tokens per batch row from multi-token
/// prediction logits.
///
/// `logits` is `[batch, num_multi, vocab]` and `task_loss` is
/// `[batch, num_multi]`. For each prediction head, tokens disallowed by the
/// builder (EOS, tokens already chosen, and every token of finished or full
/// rows) are masked to `-inf`, the arg-max token is picked, and its softmax
/// probability over the masked logits is recorded as confidence. The host-side
/// bookkeeping (stopping rule, padding) is done by `GreedySelectionBuilder`.
///
/// Selection stops early once no row can accept further tokens.
///
/// # Panics
///
/// Panics if the selected values cannot be read back as `f32` / `i64` vectors.
fn greedy_select_multi_tokens<B: Backend>(
    logits: Tensor<B, 3>,
    task_loss: Tensor<B, 2>,
    prior_tokens: &[Vec<i64>],
    finished: &[bool],
    eos_token_id: i64,
    max_tokens: usize,
    task_loss_stop: TaskLossStop,
) -> GreedyTokenSelection {
    let [batch, num_multi, vocab] = logits.shape().dims::<3>();
    let device = logits.device();
    let context = GreedySelectionContext {
        prior_tokens,
        finished,
        eos_token_id,
        max_tokens,
        task_loss_stop,
    };
    let mut builder = GreedySelectionBuilder::new(batch);
    for multi_idx in 0..num_multi {
        if !builder.has_active_rows(context) {
            break;
        }
        let mask_values = builder.disallowed_mask(vocab, context);
        let mask =
            Tensor::<B, 2, Bool>::from_data(TensorData::new(mask_values, [batch, vocab]), &device);
        let step_logits = logits
            .clone()
            .slice_dim(1, multi_idx..(multi_idx + 1))
            .squeeze_dim::<2>(1);
        let masked_logits = step_logits.mask_fill(mask, f32::NEG_INFINITY);
        // Softmax denominator over the tokens that are still allowed.
        let exp_sum = masked_logits.clone().exp().sum_dim(1);
        let best_indices = masked_logits
            .clone()
            .add(greedy_tie_breaker::<B>(batch, vocab, &device))
            .argmax(1);
        // Scores are read from the un-biased logits, not the tie-broken ones.
        let best_scores = masked_logits.gather(1, best_indices.clone());
        let confidences = best_scores.clone().exp().div(exp_sum);
        let step_task_loss = task_loss.clone().slice_dim(1, multi_idx..(multi_idx + 1));
        // Float read-backs are converted to f32 first so backends whose float
        // element is not f32 do not fail the typed `to_vec`, mirroring the
        // explicit i64 conversion used for the token ids.
        let best_scores = best_scores
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("convert selected logits to f32 vec");
        let best_tokens = best_indices
            .into_data()
            .convert::<i64>()
            .to_vec::<i64>()
            .expect("convert selected token ids to i64 vec");
        let confidences = confidences
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("convert selected confidences to f32 vec");
        let task_losses = step_task_loss
            .into_data()
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("convert selected task loss predictions to f32 vec");
        builder.push_step(
            multi_idx,
            &best_tokens,
            &best_scores,
            &confidences,
            &task_losses,
            context,
        );
    }
    builder.finish(eos_token_id)
}
/// Async variant of greedy multi-token selection: identical selection logic,
/// but tensor read-backs use `into_data_async`.
///
/// `logits` is `[batch, num_multi, vocab]` and `task_loss` is
/// `[batch, num_multi]`. For each prediction head, tokens disallowed by the
/// builder (EOS, tokens already chosen, and every token of finished or full
/// rows) are masked to `-inf`, the arg-max token is picked, and its softmax
/// probability over the masked logits is recorded as confidence. The host-side
/// bookkeeping (stopping rule, padding) is done by `GreedySelectionBuilder`.
///
/// # Errors
///
/// Returns the `ExecutionError` from any failed `into_data_async` read-back.
///
/// # Panics
///
/// Panics if the selected values cannot be read back as `f32` / `i64` vectors.
async fn greedy_select_multi_tokens_async<B: Backend>(
    logits: Tensor<B, 3>,
    task_loss: Tensor<B, 2>,
    prior_tokens: &[Vec<i64>],
    finished: &[bool],
    eos_token_id: i64,
    max_tokens: usize,
    task_loss_stop: TaskLossStop,
) -> Result<GreedyTokenSelection, ExecutionError> {
    let [batch, num_multi, vocab] = logits.shape().dims::<3>();
    let device = logits.device();
    let context = GreedySelectionContext {
        prior_tokens,
        finished,
        eos_token_id,
        max_tokens,
        task_loss_stop,
    };
    let mut builder = GreedySelectionBuilder::new(batch);
    for multi_idx in 0..num_multi {
        if !builder.has_active_rows(context) {
            break;
        }
        let mask_values = builder.disallowed_mask(vocab, context);
        let mask =
            Tensor::<B, 2, Bool>::from_data(TensorData::new(mask_values, [batch, vocab]), &device);
        let step_logits = logits
            .clone()
            .slice_dim(1, multi_idx..(multi_idx + 1))
            .squeeze_dim::<2>(1);
        let masked_logits = step_logits.mask_fill(mask, f32::NEG_INFINITY);
        // Softmax denominator over the tokens that are still allowed.
        let exp_sum = masked_logits.clone().exp().sum_dim(1);
        let best_indices = masked_logits
            .clone()
            .add(greedy_tie_breaker::<B>(batch, vocab, &device))
            .argmax(1);
        // Scores are read from the un-biased logits, not the tie-broken ones.
        let best_scores = masked_logits.gather(1, best_indices.clone());
        let confidences = best_scores.clone().exp().div(exp_sum);
        let step_task_loss = task_loss.clone().slice_dim(1, multi_idx..(multi_idx + 1));
        // Float read-backs are converted to f32 first so backends whose float
        // element is not f32 do not fail the typed `to_vec`, mirroring the
        // explicit i64 conversion used for the token ids.
        let best_scores = best_scores
            .into_data_async()
            .await?
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("convert selected logits to f32 vec");
        let best_tokens = best_indices
            .into_data_async()
            .await?
            .convert::<i64>()
            .to_vec::<i64>()
            .expect("convert selected token ids to i64 vec");
        let confidences = confidences
            .into_data_async()
            .await?
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("convert selected confidences to f32 vec");
        let task_losses = step_task_loss
            .into_data_async()
            .await?
            .convert::<f32>()
            .to_vec::<f32>()
            .expect("convert selected task loss predictions to f32 vec");
        builder.push_step(
            multi_idx,
            &best_tokens,
            &best_scores,
            &confidences,
            &task_losses,
            context,
        );
    }
    Ok(builder.finish(eos_token_id))
}
/// Builds a `[batch, vocab]` tensor whose entry for token `t` is `-1e-9 * t`,
/// a small penalty that grows with the token index and is added to the logits
/// before the arg-max.
fn greedy_tie_breaker<B: Backend>(batch: usize, vocab: usize, device: &B::Device) -> Tensor<B, 2> {
    let token_ids = Tensor::<B, 1, Int>::arange(0..vocab as i64, device);
    let penalties = token_ids.float().mul_scalar(-1.0e-9);
    let single_row = penalties.reshape([1, vocab]);
    single_row.repeat_dim(0, batch)
}
/// Read-only inputs shared by every step of one greedy multi-token selection
/// call; cheap to copy, so it is passed by value to the builder methods.
#[derive(Clone, Copy)]
struct GreedySelectionContext<'a> {
// Tokens already chosen for the current frame, one row per batch element.
prior_tokens: &'a [Vec<i64>],
// Rows flagged `true` receive no further tokens; missing entries count as `false`.
finished: &'a [bool],
// Token id that is always disallowed and that is emitted for stop / padding slots.
eos_token_id: i64,
// Per-row cap on prior tokens plus tokens selected in this call.
max_tokens: usize,
// Task-loss threshold and first-token flag used by the stopping rule.
task_loss_stop: TaskLossStop,
}
/// Host-side accumulator for one greedy multi-token selection call. Every
/// field holds one inner vector per batch row.
#[derive(Debug)]
struct GreedySelectionBuilder {
// Emitted token ids (the EOS id where the task-loss stop triggered).
per_batch_tokens: Vec<Vec<i64>>,
// Raw arg-max tokens picked in this call; they are masked out in later steps.
per_batch_disallowed: Vec<Vec<i64>>,
// `false` marks a slot where the task-loss stop replaced the token with EOS.
per_batch_valid: Vec<Vec<bool>>,
// Confidence recorded for each emitted slot (0.0 for stop slots).
per_batch_confidences: Vec<Vec<f32>>,
}
impl GreedySelectionBuilder {
    /// Creates an empty builder with one (empty) row per batch element.
    fn new(batch: usize) -> Self {
        Self {
            per_batch_tokens: vec![Vec::new(); batch],
            per_batch_disallowed: vec![Vec::new(); batch],
            per_batch_valid: vec![Vec::new(); batch],
            per_batch_confidences: vec![Vec::new(); batch],
        }
    }

    /// A row accepts no more tokens when it is flagged finished or when its
    /// prior tokens plus the tokens selected so far reach `max_tokens`.
    fn row_is_exhausted(&self, batch_idx: usize, context: GreedySelectionContext<'_>) -> bool {
        let finished = context.finished.get(batch_idx).copied().unwrap_or(false);
        finished
            || context.prior_tokens[batch_idx].len() + self.per_batch_tokens[batch_idx].len()
                >= context.max_tokens
    }

    /// Returns `true` while at least one row can still accept a token.
    fn has_active_rows(&self, context: GreedySelectionContext<'_>) -> bool {
        (0..self.per_batch_tokens.len())
            .any(|batch_idx| !self.row_is_exhausted(batch_idx, context))
    }

    /// Builds the flattened `[batch, vocab]` mask of tokens that must not be
    /// selected in the next step (`true` = disallowed).
    ///
    /// Exhausted rows are fully masked. Other rows mask the EOS token, every
    /// prior token and every token already picked in this call; ids that are
    /// negative or outside the vocabulary are ignored.
    fn disallowed_mask(&self, vocab: usize, context: GreedySelectionContext<'_>) -> Vec<bool> {
        let batch = self.per_batch_tokens.len();
        let mut values = vec![false; batch * vocab];
        for batch_idx in 0..batch {
            let row = &mut values[batch_idx * vocab..(batch_idx + 1) * vocab];
            if self.row_is_exhausted(batch_idx, context) {
                row.fill(true);
                continue;
            }
            let banned = std::iter::once(context.eos_token_id)
                .chain(context.prior_tokens[batch_idx].iter().copied())
                .chain(self.per_batch_disallowed[batch_idx].iter().copied());
            for token in banned {
                if token < 0 {
                    continue;
                }
                if let Some(slot) = row.get_mut(token as usize) {
                    *slot = true;
                }
            }
        }
        values
    }

    /// Records the outcome of one prediction head for every row that can
    /// still accept a token.
    ///
    /// Rows whose best score is not finite are skipped. When a task-loss
    /// requirement is set and the predicted loss is at or below it (except
    /// for the very first head of the first step), the slot is stored as EOS,
    /// marked invalid and given confidence `0.0`. Otherwise the token is
    /// stored with its confidence, falling back to `1.0` when the confidence
    /// is missing, non-finite or not positive.
    fn push_step(
        &mut self,
        multi_idx: usize,
        best_tokens: &[i64],
        best_scores: &[f32],
        confidences: &[f32],
        task_losses: &[f32],
        context: GreedySelectionContext<'_>,
    ) {
        // The stop rule never applies to the first head of a frame's first step.
        let exempt_from_stop = context.task_loss_stop.is_first_token && multi_idx == 0;
        for batch_idx in 0..self.per_batch_tokens.len() {
            if self.row_is_exhausted(batch_idx, context) {
                continue;
            }
            let (Some(&token), Some(&best_score)) =
                (best_tokens.get(batch_idx), best_scores.get(batch_idx))
            else {
                continue;
            };
            if !best_score.is_finite() {
                continue;
            }
            let task_loss = task_losses.get(batch_idx).copied().unwrap_or(f32::INFINITY);
            let stop_here = match context.task_loss_stop.requirement {
                Some(threshold) => !exempt_from_stop && task_loss <= threshold,
                None => false,
            };
            let (emitted, confidence) = if stop_here {
                (context.eos_token_id, 0.0)
            } else {
                let raw = confidences.get(batch_idx).copied().unwrap_or(1.0);
                let usable = raw.is_finite() && raw > 0.0;
                (token, if usable { raw } else { 1.0 })
            };
            // The raw token is always remembered so later heads cannot pick it again.
            self.per_batch_disallowed[batch_idx].push(token);
            self.per_batch_tokens[batch_idx].push(emitted);
            self.per_batch_valid[batch_idx].push(!stop_here);
            self.per_batch_confidences[batch_idx].push(confidence);
        }
    }

    /// Pads every row to the length of the longest row and returns
    /// `(tokens, valid flags, confidences)`.
    ///
    /// Padding slots use `eos_token_id`, `false` and `0.0`.
    fn finish(self, eos_token_id: i64) -> GreedyTokenSelection {
        let Self {
            mut per_batch_tokens,
            mut per_batch_valid,
            mut per_batch_confidences,
            ..
        } = self;
        let padded_len = per_batch_tokens.iter().map(Vec::len).max().unwrap_or(0);
        let rows = per_batch_tokens
            .iter_mut()
            .zip(per_batch_valid.iter_mut())
            .zip(per_batch_confidences.iter_mut());
        for ((tokens, valid), confidences) in rows {
            let missing = padded_len - tokens.len();
            tokens.extend(std::iter::repeat(eos_token_id).take(missing));
            valid.extend(std::iter::repeat(false).take(missing));
            confidences.extend(std::iter::repeat(0.0).take(missing));
        }
        (per_batch_tokens, per_batch_valid, per_batch_confidences)
    }
}
/// Scalar reference implementation of multi-token greedy selection, used by tests.
///
/// `scores` is a row-major `[batch, num_multi, vocab]` buffer and `task_losses` a row-major
/// `[batch, num_multi]` buffer. For each row and step the highest finite score among tokens
/// that are neither EOS, nor prior tokens, nor already chosen in this call is selected. If the
/// task-loss requirement is met (never for step 0 when `is_first_token` is set) EOS is emitted
/// instead, marked invalid with confidence `0.0`. A row stops when it is finished, when it
/// reaches `max_tokens`, or when no candidate is left. Rows are padded to a common length with
/// `(eos_token_id, false, 0.0)`.
///
/// The confidence is the softmax probability of the chosen token over the remaining candidates.
/// It is computed relative to the best score so that large logits cannot overflow `exp` (which
/// would yield NaN) and very negative logits cannot underflow the whole sum to zero.
#[cfg(test)]
fn greedy_select_multi_tokens_from_data(
    scores: Vec<f32>,
    task_losses: Vec<f32>,
    batch: usize,
    num_multi: usize,
    vocab: usize,
    context: GreedySelectionContext<'_>,
) -> GreedyTokenSelection {
    let GreedySelectionContext {
        prior_tokens,
        finished,
        eos_token_id,
        max_tokens,
        task_loss_stop,
    } = context;
    let mut per_batch_tokens: Vec<Vec<i64>> = vec![Vec::new(); batch];
    let mut per_batch_valid: Vec<Vec<bool>> = vec![Vec::new(); batch];
    let mut per_batch_confidences: Vec<Vec<f32>> = vec![Vec::new(); batch];
    for batch_idx in 0..batch {
        // `blocked[t]` is true when token `t` may not be selected for this row.
        let mut blocked = vec![false; vocab];
        let initially_blocked =
            std::iter::once(eos_token_id).chain(prior_tokens[batch_idx].iter().copied());
        for token in initially_blocked {
            let Ok(index) = usize::try_from(token) else {
                continue;
            };
            if let Some(slot) = blocked.get_mut(index) {
                *slot = true;
            }
        }
        for multi_idx in 0..num_multi {
            let exhausted = finished.get(batch_idx).copied().unwrap_or(false)
                || prior_tokens[batch_idx].len() + per_batch_tokens[batch_idx].len() >= max_tokens;
            if exhausted {
                break;
            }
            let base = (batch_idx * num_multi + multi_idx) * vocab;
            let candidates: Vec<(usize, f32)> = (0..vocab)
                .filter(|&index| !blocked[index])
                .map(|index| (index, scores[base + index]))
                .filter(|(_, score)| score.is_finite())
                .collect();
            // Strict `>` keeps the lowest index among equal scores.
            let mut best: Option<(usize, f32)> = None;
            for &(index, score) in &candidates {
                if best.map_or(true, |(_, top)| score > top) {
                    best = Some((index, score));
                }
            }
            let Some((best_index, best_score)) = best else {
                break;
            };
            // Shifted softmax denominator; the best candidate contributes exp(0) = 1, so the
            // sum is always at least 1 and the confidence is simply its reciprocal.
            let exp_sum: f32 = candidates
                .iter()
                .map(|&(_, score)| (score - best_score).exp())
                .sum();
            // The chosen token stays blocked even when EOS is emitted in its place.
            blocked[best_index] = true;
            let stop = task_loss_stop.requirement.is_some_and(|threshold| {
                !(task_loss_stop.is_first_token && multi_idx == 0)
                    && task_losses[batch_idx * num_multi + multi_idx] <= threshold
            });
            per_batch_tokens[batch_idx].push(if stop {
                eos_token_id
            } else {
                best_index as i64
            });
            per_batch_valid[batch_idx].push(!stop);
            per_batch_confidences[batch_idx].push(if stop { 0.0 } else { 1.0 / exp_sum });
        }
    }
    let padded_len = per_batch_tokens.iter().map(Vec::len).max().unwrap_or(0);
    let rows = per_batch_tokens
        .iter_mut()
        .zip(per_batch_valid.iter_mut())
        .zip(per_batch_confidences.iter_mut());
    for ((tokens, valid), confidences) in rows {
        tokens.resize(padded_len, eos_token_id);
        valid.resize(padded_len, false);
        confidences.resize(padded_len, 0.0);
    }
    (per_batch_tokens, per_batch_valid, per_batch_confidences)
}
/// Converts generated gaze tokens into one `FrameFixationTrace` per batch row.
///
/// `generated.num_gazing_each_frame` gives the number of consecutive entries of each row that
/// belong to each frame. Within a frame, padded entries only set the frame's stop probability
/// to `1.0`; all other entries are rebased to a frame-local token index (negative results are
/// clamped to `0`) and mapped to a fixation point via the multi-scale token layout. `k` is
/// forwarded unchanged to `FixationSet::new`.
fn generated_to_traces(
    generated: &AutoGazeGenerateOutput,
    config: &AutoGazeConfig,
    k: usize,
) -> Vec<FrameFixationTrace> {
    let layouts = scale_token_layouts(config);
    (0..generated.gazing_pos.len())
        .map(|batch_idx| {
            // Index of the first entry of the current frame within the row.
            let mut frame_start = 0usize;
            let frames = generated
                .num_gazing_each_frame
                .iter()
                .copied()
                .enumerate()
                .map(|(frame_idx, frame_len)| {
                    let mut points = Vec::new();
                    let mut stop_probability = 0.0f32;
                    for global_idx in frame_start..frame_start + frame_len {
                        // Generated positions are offset by the frame's token base.
                        let frame_base = (frame_idx * config.num_vision_tokens_each_frame) as i64;
                        let local_token = generated.gazing_pos[batch_idx][global_idx] - frame_base;
                        if generated.if_padded_gazing[batch_idx][global_idx] {
                            stop_probability = 1.0;
                            continue;
                        }
                        let confidence = generated.confidences[batch_idx][global_idx];
                        let point =
                            token_to_fixation_point(local_token.max(0) as usize, &layouts, confidence);
                        if let Some(point) = point {
                            points.push(point);
                        }
                    }
                    frame_start += frame_len;
                    FixationSet::new(points, stop_probability, k)
                })
                .collect::<Vec<_>>();
            FrameFixationTrace::new(frames)
        })
        .collect()
}
/// Describes how the tokens of one scale are arranged within a frame's token range.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct ScaleTokenLayout {
/// Number of consecutive token indices that belong to this scale.
token_count: usize,
/// Side length of the square grid used to turn a scale-local token index into a
/// row (`index / grid`) and column (`index % grid`).
grid: usize,
}
/// Maps a frame-local token index to the fixation point at the centre of its grid cell.
///
/// Scales occupy consecutive token ranges in the order given by `scale_layouts`. The token is
/// located in its scale, converted to a row/column on that scale's grid, and returned with
/// normalized centre coordinates and a cell extent of `1 / grid` (clamped to `[1e-6, 1]`).
/// Returns `None` when `token` lies beyond the last scale.
fn token_to_fixation_point(
    token: usize,
    scale_layouts: &[ScaleTokenLayout],
    confidence: f32,
) -> Option<FixationPoint> {
    // Peel off one scale's token range at a time until the index falls inside a scale.
    let mut local = token;
    for layout in scale_layouts {
        if local >= layout.token_count {
            local -= layout.token_count;
            continue;
        }
        let grid = layout.grid.max(1);
        let side = grid as f32;
        // NOTE(review): when `token_count > grid * grid` the row can reach `grid` or more, so
        // the centre's y coordinate can exceed 1.0; confirm callers tolerate that.
        let (row, col) = (local / grid, local % grid);
        let center_x = (col as f32 + 0.5) / side;
        let center_y = (row as f32 + 0.5) / side;
        let extent = (1.0 / side).clamp(1.0e-6, 1.0);
        return Some(FixationPoint::with_grid_extent(
            center_x, center_y, extent, extent, confidence, grid,
        ));
    }
    None
}
/// Derives the per-scale token layout of a frame from the configuration.
///
/// * No scales configured: a single layout covering all frame tokens (at least one).
/// * Otherwise each scale gets a `(scale / kernel_size)`-sided grid; if those grids add up to
///   exactly `num_vision_tokens_each_frame`, they are returned as is.
/// * Otherwise the frame tokens are split proportionally to `scale * scale` (rounded down),
///   with the last scale receiving the remainder, and each grid side is the rounded square
///   root of its token count.
fn scale_token_layouts(config: &AutoGazeConfig) -> Vec<ScaleTokenLayout> {
    let total_tokens = config.num_vision_tokens_each_frame;
    let scales = config.scale_values();
    if scales.is_empty() {
        let token_count = total_tokens.max(1);
        return vec![ScaleTokenLayout {
            token_count,
            grid: square_grid(token_count),
        }];
    }

    let patch_size = config
        .gaze_model_config
        .vision_model_config
        .kernel_size
        .max(1);
    let mut direct_layouts = Vec::with_capacity(scales.len());
    let mut direct_total = 0usize;
    for scale in scales.iter().copied() {
        let grid = (scale / patch_size).max(1);
        direct_total += grid * grid;
        direct_layouts.push(ScaleTokenLayout {
            token_count: grid * grid,
            grid,
        });
    }
    if direct_total == total_tokens {
        return direct_layouts;
    }

    // Fallback: distribute the configured token budget by relative scale area.
    let sum_sq: usize = scales.iter().copied().map(|scale| scale * scale).sum();
    let denominator = sum_sq.max(1) as f64;
    let last_index = scales.len() - 1;
    let mut assigned = 0usize;
    scales
        .iter()
        .copied()
        .enumerate()
        .map(|(index, scale)| {
            let token_count = if index == last_index {
                // The last scale absorbs whatever the rounded-down shares left over.
                total_tokens.saturating_sub(assigned)
            } else {
                let share =
                    ((scale * scale) as f64 / denominator * total_tokens as f64).floor() as usize;
                assigned += share;
                share
            };
            ScaleTokenLayout {
                token_count,
                grid: square_grid(token_count),
            }
        })
        .collect()
}
/// Returns the side length of the square grid closest to `token_count` cells.
///
/// The result is the rounded square root of `token_count`, never less than 1.
fn square_grid(token_count: usize) -> usize {
    let tokens = token_count.max(1) as f64;
    let nearest_side = tokens.sqrt().round();
    nearest_side.max(1.0) as usize
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "ndarray")]
use burn::module::ModuleMapper;
/// The four-scale layout maps tokens to cell centres and extents on each scale's own grid.
#[test]
fn token_to_fixation_point_preserves_multiscale_cells() {
    let mut config = AutoGazeConfig {
        scales: "32+64+112+224".to_string(),
        num_vision_tokens_each_frame: 265,
        ..Default::default()
    };
    config.gaze_model_config.num_vision_tokens_each_frame = 265;
    let scale_layouts = scale_token_layouts(&config);

    // Expected (token_count, grid) per scale: 4 + 16 + 49 + 196 = 265 tokens.
    let expected_layouts: Vec<ScaleTokenLayout> = [(4, 2), (16, 4), (49, 7), (196, 14)]
        .iter()
        .map(|&(token_count, grid)| ScaleTokenLayout { token_count, grid })
        .collect();
    assert_eq!(scale_layouts, expected_layouts);

    // First token of the coarsest scale: top-left cell of a 2x2 grid.
    let coarse = token_to_fixation_point(0, &scale_layouts, 1.0).expect("coarse token");
    assert_eq!((coarse.x, coarse.y), (0.25, 0.25));
    assert_eq!((coarse.cell_width(), coarse.cell_height()), (0.5, 0.5));
    assert_eq!(coarse.cell_grid(), Some(2));

    // First token of the second scale: top-left cell of a 4x4 grid.
    let mid = token_to_fixation_point(4, &scale_layouts, 1.0).expect("second-scale token");
    assert_eq!((mid.x, mid.y), (0.125, 0.125));
    assert_eq!((mid.cell_width(), mid.cell_height()), (0.25, 0.25));
    assert_eq!(mid.cell_grid(), Some(4));

    // Last column of the first row of the finest (14x14) scale.
    let fine_offset: usize = scale_layouts
        .iter()
        .take(3)
        .map(|layout| layout.token_count)
        .sum();
    let fine =
        token_to_fixation_point(fine_offset + 13, &scale_layouts, 1.0).expect("fine token");
    assert!((fine.x - 13.5 / 14.0).abs() < 1.0e-6);
    assert!((fine.y - 0.5 / 14.0).abs() < 1.0e-6);
    assert!((fine.cell_width() - 1.0 / 14.0).abs() < 1.0e-6);
    assert!((fine.cell_height() - 1.0 / 14.0).abs() < 1.0e-6);
    assert_eq!(fine.cell_grid(), Some(14));
}
/// Appending `tail[0][0], tail[0][1], tail[0][0]` from `generation_tail_positions(&rows, 2)`
/// to a row ending in `2, 3` must yield `2, 3, 2`, i.e. the tail repeats the last two
/// prefix positions.
#[cfg(feature = "ndarray")]
#[test]
fn generation_tail_positions_repeat_prefix_tail_for_generated_chunks() {
type B = burn::backend::NdArray<f32>;
let device = Default::default();
let mut position_rows = vec![vec![0, 1, 1, 2, 3]];
let tail = generation_tail_positions(&position_rows, 2);
// Simulate three generated positions by cycling through the two tail entries.
position_rows[0].extend([tail[0][0], tail[0][1], tail[0][0]]);
// NOTE(review): the `8` presumably is the target sequence length (5 prefix + 3 generated);
// confirm against `position_ids_tensor`.
let position_ids = position_ids_tensor::<B>(&position_rows, 8, &device)
.into_data()
.to_vec::<i64>()
.expect("position ids");
assert_eq!(position_ids, vec![0, 1, 1, 2, 3, 2, 3, 2]);
}
/// `position_ids_slice_tensor_optimized` must reproduce the input rows unchanged (flattened
/// row-major) for three input shapes: identical contiguous rows, identical non-contiguous
/// rows, and rows that differ per batch entry.
#[cfg(feature = "ndarray")]
#[test]
fn optimized_position_ids_preserve_contiguous_and_shared_rows() {
type B = burn::backend::NdArray<f32>;
let device = Default::default();
// Case 1: both rows identical and consecutive (4, 5, 6).
let contiguous = vec![vec![4, 5, 6], vec![4, 5, 6]];
let values = position_ids_slice_tensor_optimized::<B>(&contiguous, 0, 3, &device)
.into_data()
.to_vec::<i64>()
.expect("position ids");
assert_eq!(values, vec![4, 5, 6, 4, 5, 6]);
// Case 2: both rows identical but not consecutive.
let shared_non_contiguous = vec![vec![8, 13, 8], vec![8, 13, 8]];
let values =
position_ids_slice_tensor_optimized::<B>(&shared_non_contiguous, 0, 3, &device)
.into_data()
.to_vec::<i64>()
.expect("position ids");
assert_eq!(values, vec![8, 13, 8, 8, 13, 8]);
// Case 3: rows differ between batch entries.
let per_batch = vec![vec![8, 13, 8], vec![9, 14, 9]];
let values = position_ids_slice_tensor_optimized::<B>(&per_batch, 0, 3, &device)
.into_data()
.to_vec::<i64>()
.expect("position ids");
assert_eq!(values, vec![8, 13, 8, 9, 14, 9]);
}
/// `attention_mask_tensor_or_none` must return `None` for a mask made only of ones and
/// `Some(_)` as soon as the mask contains a zero.
#[cfg(feature = "ndarray")]
#[test]
fn attention_mask_upload_is_skipped_when_all_keys_are_valid() {
type B = burn::backend::NdArray<f32>;
let device = Default::default();
assert!(attention_mask_tensor_or_none::<B>(&[vec![1, 1, 1]], 3, &device).is_none());
assert!(attention_mask_tensor_or_none::<B>(&[vec![1, 0, 1]], 3, &device).is_some());
}
/// After `commit_pending_position_ids`, the positions must equal the running sum of the
/// attention mask minus one: mask `1,1,0,1,1,0,1` gives `0,1,1,2,3,3,4`.
#[test]
fn commit_pending_position_ids_uses_attention_cumsum() {
let masks = vec![vec![1, 1, 0, 1, 1, 0, 1]];
// The entries from index 4 on hold stale values (2) before the commit.
let mut positions = vec![vec![0, 1, 1, 2, 2, 2, 2]];
// NOTE(review): presumably the row indices whose position ids are still pending; confirm
// against `commit_pending_position_ids`.
let pending = vec![vec![4, 5, 6]];
commit_pending_position_ids(&masks, &mut positions, &pending);
assert_eq!(positions[0], vec![0, 1, 1, 2, 3, 3, 4]);
}
/// `generate_cached` and `generate_uncached` must produce the same positions, per-frame
/// counts and padding flags, and confidences that agree within `1e-5`, on a tiny model with
/// deterministic parameters.
#[cfg(feature = "ndarray")]
#[test]
fn cached_generation_matches_uncached_generation() {
type B = burn::backend::NdArray<f32>;
let device = Default::default();
let config = tiny_cache_test_config();
// Overwrite every float parameter with reproducible values so both runs share weights.
let mut mapper = DeterministicParamMapper { cursor: 0 };
let model = NativeAutoGazeModel::<B>::new(&config, &device).map(&mut mapper);
// Deterministic input in [-1, 1): 2 * 3 * 16 * 16 values shaped [1, 2, 3, 16, 16].
let values = (0..(2 * 3 * 16 * 16))
.map(|idx| ((idx % 251) as f32 / 125.0) - 1.0)
.collect::<Vec<_>>();
let video = Tensor::<B, 5>::from_data(TensorData::new(values, [1, 2, 3, 16, 16]), &device);
let uncached = model.gazing_model.generate_uncached(video.clone(), 4, None);
let cached = model.gazing_model.generate_cached(video, 4, None);
assert_eq!(cached.gazing_pos, uncached.gazing_pos);
assert_eq!(cached.num_gazing_each_frame, uncached.num_gazing_each_frame);
assert_eq!(cached.if_padded_gazing, uncached.if_padded_gazing);
assert_eq!(cached.confidences[0].len(), uncached.confidences[0].len());
// Confidences are floats, so compare with a tolerance instead of exact equality.
for (left, right) in cached.confidences[0].iter().zip(&uncached.confidences[0]) {
assert!((left - right).abs() < 1.0e-5);
}
}
/// Builds a minimal configuration (hidden size 8, 2 heads, 1 layer, 16x16 input, 2 frames)
/// for the cached-vs-uncached generation test.
#[cfg(feature = "ndarray")]
fn tiny_cache_test_config() -> AutoGazeConfig {
let hidden = 8;
let heads = 2;
AutoGazeConfig {
// With `kernel_size: 8` these scales give a 1x1 and a 2x2 grid, i.e. 1 + 4 = 5 tokens,
// matching `num_vision_tokens_each_frame` below.
scales: "8+16".to_string(),
max_num_frames: 2,
num_vision_tokens_each_frame: 5,
gaze_model_config: GazeModelConfig {
input_img_size: 16,
num_vision_tokens_each_frame: 5,
attn_mode: "sdpa".to_string(),
vision_model_config: VisionModelConfig {
hidden_dim: hidden,
out_dim: hidden,
depth: 1,
kernel_size: 8,
temporal_patch_size: 1,
trunk_temporal_kernel_size: 3,
trunk_spatial_kernel_size: 1,
},
connector_config: ConnectorConfig {
hidden_dim: hidden,
num_tokens: 4,
},
gaze_decoder_config: crate::config::GazeDecoderConfig {
// NOTE(review): presumably the 5 frame tokens plus the EOS id 5; confirm.
vocab_size: 6,
hidden_size: hidden,
intermediate_size: hidden * 2,
num_hidden_layers: 1,
num_attention_heads: heads,
num_key_value_heads: heads,
max_position_embeddings: 512,
bos_token_id: 0,
eos_token_id: 5,
head_dim: hidden / heads,
num_multi_token_pred: 2,
..crate::config::GazeDecoderConfig::default()
},
},
..AutoGazeConfig::default()
}
}
/// Module mapper that replaces float parameters with a reproducible value sequence.
#[cfg(feature = "ndarray")]
struct DeterministicParamMapper {
/// Number of parameter elements generated so far; used as the start index of the next
/// parameter's values.
cursor: usize,
}
#[cfg(feature = "ndarray")]
impl<B: Backend> ModuleMapper<B> for DeterministicParamMapper {
/// Replaces `param` with a tensor of the same shape and device whose element `idx` is
/// `(((start + idx) % 97) - 48) * 0.002`, where `start` is the cursor value before this
/// call. The cursor then advances by the parameter's element count.
fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
let tensor = param.val();
let shape = tensor.shape().dims::<D>();
let device = tensor.device();
let len = shape.iter().product::<usize>();
let start = self.cursor;
self.cursor += len;
// Values cycle through 97 steps in the range [-0.096, 0.096].
let values = (0..len)
.map(|idx| (((start + idx) % 97) as f32 - 48.0) * 0.002)
.collect::<Vec<_>>();
Param::from_tensor(Tensor::from_data(TensorData::new(values, shape), &device))
}
}
/// With `is_first_token` set, the task-loss requirement must be ignored for step 0 and
/// applied at step 1: the expected output is token 0 followed by EOS (3), with the second
/// entry invalid and at confidence 0. The tensor path must also agree with the scalar
/// reference `greedy_select_multi_tokens_from_data`.
#[cfg(feature = "ndarray")]
#[test]
fn greedy_selection_applies_task_loss_requirement_after_first_token() {
type B = burn::backend::NdArray<f32>;
let device = Default::default();
// Two prediction steps of four logits each (shape [1, 2, 4]).
let scores = vec![
10.0, 1.0, 0.0, -1.0, 0.0, 9.0, 1.0, -1.0,
];
// Both task losses are below the 0.5 requirement used below.
let task_losses = vec![0.1, 0.2];
let logits = Tensor::<B, 3>::from_data(TensorData::new(scores.clone(), [1, 2, 4]), &device);
let task_loss =
Tensor::<B, 2>::from_data(TensorData::new(task_losses.clone(), [1, 2]), &device);
// NOTE(review): the positional `3, 2` presumably are `eos_token_id` and `max_tokens`,
// mirroring the reference context below; confirm against `greedy_select_multi_tokens`.
let (tokens, valid, confidences) = greedy_select_multi_tokens(
logits,
task_loss,
&[Vec::new()],
&[false],
3,
2,
TaskLossStop {
requirement: Some(0.5),
is_first_token: true,
},
);
assert_eq!(tokens, vec![vec![0, 3]]);
assert_eq!(valid, vec![vec![true, false]]);
assert!(confidences[0][0] > 0.0);
assert_eq!(confidences[0][1], 0.0);
let reference = greedy_select_multi_tokens_from_data(
scores,
task_losses,
1,
2,
4,
GreedySelectionContext {
prior_tokens: &[Vec::new()],
finished: &[false],
eos_token_id: 3,
max_tokens: 2,
task_loss_stop: TaskLossStop {
requirement: Some(0.5),
is_first_token: true,
},
},
);
assert_eq!((tokens, valid), (reference.0, reference.1));
// Confidences are floats, so compare against the reference with a tolerance.
for (left, right) in confidences[0].iter().zip(&reference.2[0]) {
assert!((left - right).abs() < 1.0e-6);
}
}
/// A token whose slot was replaced by EOS because of the task-loss stop must still be
/// excluded from later steps: step 1 would pick token 1 (score 9.0) but emits EOS (4), and
/// step 2 must then pick token 2 (score 8.0) rather than token 1 (score 9.0) again.
#[cfg(feature = "ndarray")]
#[test]
fn greedy_selection_disallows_hidden_token_after_task_loss_stop() {
type B = burn::backend::NdArray<f32>;
let device = Default::default();
// Three prediction steps of five logits each (shape [1, 3, 5]).
let scores = vec![
8.0, 1.0, 0.0, -1.0, -2.0, 0.0, 9.0, 2.0, 1.0, -2.0, 0.0, 9.0, 8.0, 1.0, -2.0,
];
// Only the middle step's task loss (0.1) is below the 0.5 requirement.
let task_losses = vec![1.0, 0.1, 1.0];
let logits = Tensor::<B, 3>::from_data(TensorData::new(scores.clone(), [1, 3, 5]), &device);
let task_loss =
Tensor::<B, 2>::from_data(TensorData::new(task_losses.clone(), [1, 3]), &device);
// NOTE(review): the positional `4, 3` presumably are `eos_token_id` and `max_tokens`,
// mirroring the reference context below; confirm against `greedy_select_multi_tokens`.
let selected = greedy_select_multi_tokens(
logits,
task_loss,
&[Vec::new()],
&[false],
4,
3,
TaskLossStop {
requirement: Some(0.5),
is_first_token: true,
},
);
let reference = greedy_select_multi_tokens_from_data(
scores,
task_losses,
1,
3,
5,
GreedySelectionContext {
prior_tokens: &[Vec::new()],
finished: &[false],
eos_token_id: 4,
max_tokens: 3,
task_loss_stop: TaskLossStop {
requirement: Some(0.5),
is_first_token: true,
},
},
);
assert_eq!(selected.0, vec![vec![0, 4, 2]]);
assert_eq!(selected.0, reference.0);
assert_eq!(selected.1, reference.1);
}
/// When the per-scale grids do not add up to the configured token count, the layout falls
/// back to a proportional split that still sums to the configured total.
#[test]
fn scale_layout_falls_back_to_proportional_counts_for_mismatched_totals() {
    let mut config = AutoGazeConfig {
        scales: "32+64+224".to_string(),
        num_vision_tokens_each_frame: 10,
        ..Default::default()
    };
    // With a patch size of 16 the direct grids would hold 4 + 16 + 196 tokens, not 10.
    config.gaze_model_config.vision_model_config.kernel_size = 16;
    let layouts = scale_token_layouts(&config);

    let total_tokens: usize = layouts.iter().map(|layout| layout.token_count).sum();
    assert_eq!(total_tokens, 10);

    let grids: Vec<usize> = layouts.iter().map(|layout| layout.grid).collect();
    assert_eq!(grids, vec![1, 1, 3]);
}
/// Four non-padded tokens, one per scale, must all survive conversion and report the grid of
/// their own scale.
#[test]
fn generated_to_traces_preserves_all_non_padded_multiscale_tokens() {
    let mut config = AutoGazeConfig {
        scales: "32+64+112+224".to_string(),
        num_vision_tokens_each_frame: 265,
        ..Default::default()
    };
    config.gaze_model_config.num_vision_tokens_each_frame = 265;

    // A single frame with four tokens; the expected grids below are 2, 4, 7 and 14.
    let generated = AutoGazeGenerateOutput {
        gazing_pos: vec![vec![0, 4, 20, 69]],
        num_gazing_each_frame: vec![4],
        if_padded_gazing: vec![vec![false, false, false, false]],
        confidences: vec![vec![0.9, 0.8, 0.7, 0.6]],
    };
    let traces = generated_to_traces(&generated, &config, 4);

    let mut grids = Vec::new();
    for point in traces[0].frames[0].points.iter() {
        if point.confidence > 0.0 {
            grids.push(point.cell_grid());
        }
    }
    assert_eq!(grids, vec![Some(2), Some(4), Some(7), Some(14)]);
}
}