

Function mean_pool 

pub fn mean_pool(
    output_data: &[f32],
    seq_len: usize,
    hidden_dim: usize,
    attention_mask: &[i64],
    output_dim: usize,
) -> Vec<f32>

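Mean pooling over the sequence dimension, using an attention mask.

Ported from the create_embedding() function in examples/embeddings.rs. Averages the token embeddings over only the positions the attention mask marks as real tokens, so padding does not contribute to the result.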
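
A minimal sketch of the masked averaging described above, written as a hypothetical standalone function (mean_pool_sketch) rather than the crate's actual code. It assumes output_data is row-major with shape [seq_len, hidden_dim] (batch size 1), that output_dim keeps the first output_dim values of each token vector (capped at hidden_dim), and that an input with no real tokens yields a zero vector. None of these details is confirmed by this page.

```rust
/// Hypothetical sketch of masked mean pooling; NOT the crate's implementation.
/// Assumptions: `output_data` is row-major [seq_len, hidden_dim] (batch size 1),
/// `output_dim` keeps the leading dimensions (capped at `hidden_dim`),
/// and an all-padding input returns a zero vector.
fn mean_pool_sketch(
    output_data: &[f32],
    seq_len: usize,
    hidden_dim: usize,
    attention_mask: &[i64],
    output_dim: usize,
) -> Vec<f32> {
    let dim = output_dim.min(hidden_dim);
    let mut sum = vec![0.0f32; dim];
    let mut count = 0usize;

    for t in 0..seq_len.min(attention_mask.len()) {
        // Skip padding positions (mask value 0).
        if attention_mask[t] == 0 {
            continue;
        }
        // Row t starts at t * hidden_dim; take only the first `dim` values.
        let start = t * hidden_dim;
        let row = &output_data[start..start + dim];
        for (s, v) in sum.iter_mut().zip(row) {
            *s += *v;
        }
        count += 1;
    }

    // Divide by the number of real tokens, not by seq_len.
    if count > 0 {
        let n = count as f32;
        for s in sum.iter_mut() {
            *s /= n;
        }
    }
    sum
}

fn main() {
    // Two real tokens followed by one padding token, hidden_dim = 2.
    let data: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 100.0, 100.0];
    let mask: [i64; 3] = [1, 1, 0];
    let pooled = mean_pool_sketch(&data, 3, 2, &mask, 2);
    println!("{:?}", pooled); // prints "[2.0, 3.0]"
}
```

Summing only the masked-in rows and dividing by their count, rather than by seq_len, is what keeps the padding row (the 100.0 values) from affecting the average.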

§Arguments

  • output_data - Flattened ONNX output tensor
  • seq_len - Sequence length
  • hidden_dim - Hidden dimension size
  • attention_mask - Mask indicating real vs padded tokens (1 = real, 0 = padding)
  • output_dim - Target embedding dimension
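
The arguments above imply size relationships that callers must satisfy. The hypothetical validate_pool_inputs below spells them out as a sketch, assuming a batch size of 1 (so output_data holds exactly seq_len * hidden_dim values) and a mask containing only 0 and 1. output_dim is not checked because its exact effect is not specified on this page.

```rust
/// Hypothetical precondition check; NOT part of the crate.
/// Assumes batch size 1, so `output_data` holds exactly seq_len * hidden_dim values,
/// and that the mask uses only 0 (padding) and 1 (real token).
fn validate_pool_inputs(
    output_data: &[f32],
    seq_len: usize,
    hidden_dim: usize,
    attention_mask: &[i64],
) -> Result<(), String> {
    let expected = seq_len
        .checked_mul(hidden_dim)
        .ok_or_else(|| "seq_len * hidden_dim overflows usize".to_string())?;
    if output_data.len() != expected {
        return Err(format!(
            "output_data has {} values, expected {} ({} x {})",
            output_data.len(),
            expected,
            seq_len,
            hidden_dim
        ));
    }
    if attention_mask.len() != seq_len {
        return Err(format!(
            "attention_mask has {} entries, expected {}",
            attention_mask.len(),
            seq_len
        ));
    }
    // Reject any mask value other than 0 or 1.
    if let Some(bad) = attention_mask.iter().find(|&&m| m != 0 && m != 1) {
        return Err(format!("attention_mask contains {}, expected only 0 or 1", bad));
    }
    Ok(())
}

fn main() {
    let data = vec![0.0f32; 3 * 4];
    let mask: Vec<i64> = vec![1, 1, 0];
    assert!(validate_pool_inputs(&data, 3, 4, &mask).is_ok());
    // A mask that is one entry short is rejected.
    assert!(validate_pool_inputs(&data, 3, 4, &mask[..2]).is_err());
}
```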

§Returns

  • Pooled embedding vector (averaged over real tokens only)