Function mean_pool 

pub fn mean_pool(
    hidden_states: &[f32],
    attention_mask: &[i32],
    seq_len: usize,
    dim: usize,
) -> Vec<f32>

Mean pooling over token positions, weighted by attention mask.

Takes the raw hidden-state output of shape [1 × seq_len × dim], flattened into a slice, and produces a single embedding by averaging over the positions that the attention mask marks as attended.
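
The following is a minimal sketch of how such a pooling step could be written, not the crate's actual source. The name `mean_pool_sketch` is hypothetical, and the sketch assumes a row-major layout (token `t` occupies `hidden_states[t * dim..(t + 1) * dim]`), a mask in which any nonzero value acts as a weight (typically 1 for attended tokens), and a zero vector as the result when no position is attended. The real function may handle these cases differently.

```rust
/// Illustrative mean pooling; a sketch under the assumptions stated above.
pub fn mean_pool_sketch(
    hidden_states: &[f32],
    attention_mask: &[i32],
    seq_len: usize,
    dim: usize,
) -> Vec<f32> {
    // Assumed invariants: one mask entry per token, `dim` values per token.
    assert_eq!(hidden_states.len(), seq_len * dim);
    assert_eq!(attention_mask.len(), seq_len);

    let mut pooled = vec![0.0f32; dim];
    let mut weight_sum = 0.0f32;

    for (t, &m) in attention_mask.iter().enumerate() {
        if m == 0 {
            continue; // masked-out (padding) position contributes nothing
        }
        let w = m as f32;
        weight_sum += w;

        // Row-major layout assumed: token t's vector is a contiguous run.
        let row = &hidden_states[t * dim..(t + 1) * dim];
        for (acc, &h) in pooled.iter_mut().zip(row) {
            *acc += w * h;
        }
    }

    // Assumption: return zeros rather than NaN when nothing is attended.
    if weight_sum > 0.0 {
        for acc in pooled.iter_mut() {
            *acc /= weight_sum;
        }
    }
    pooled
}
```

With `seq_len = 3`, `dim = 2`, hidden states `[1.0, 2.0, 3.0, 4.0, 100.0, 200.0]`, and mask `[1, 1, 0]`, the third token is ignored and the sketch averages only the first two rows.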