pub fn mean_pool(
hidden_states: &[f32],
attention_mask: &[i32],
seq_len: usize,
dim: usize,
) -> Vec<f32>
Mean pooling over token positions, weighted by attention mask.
Takes the raw hidden-state output of shape [1 × seq_len × dim], flattened into a single slice, and produces one embedding of length dim by averaging the hidden vectors across the attended positions.
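The sketch below shows what masked mean pooling computes: for each of the dim components, it sums mask-weighted hidden values over the sequence and then divides by the sum of the mask. It is a hypothetical stand-alone re-implementation (named `mean_pool_sketch`), not the crate's actual code. It assumes the flattened buffer is row-major (token t occupies `hidden_states[t * dim..(t + 1) * dim]`) and that a mask value of 0 means "ignore this token". It also assumes a sentence-transformers-style clamp of the denominator to 1e-9, so a fully masked input yields the zero vector rather than NaN.

```rust
/// Hypothetical sketch of masked mean pooling, for illustration only.
/// Assumes `hidden_states` is row-major [1 × seq_len × dim] and that a
/// mask value of 0 excludes the token from the average.
pub fn mean_pool_sketch(
    hidden_states: &[f32],
    attention_mask: &[i32],
    seq_len: usize,
    dim: usize,
) -> Vec<f32> {
    assert_eq!(hidden_states.len(), seq_len * dim, "hidden_states must hold seq_len * dim values");
    assert_eq!(attention_mask.len(), seq_len, "attention_mask must hold seq_len values");

    // Running per-dimension sum of mask-weighted hidden values.
    let mut sum = vec![0.0f32; dim];
    // Running sum of mask weights (the number of attended tokens for a 0/1 mask).
    let mut weight_total = 0.0f32;

    for (t, &m) in attention_mask.iter().enumerate() {
        if m == 0 {
            continue; // padding or otherwise masked-out token
        }
        let w = m as f32;
        let row = &hidden_states[t * dim..(t + 1) * dim];
        for (acc, &h) in sum.iter_mut().zip(row) {
            *acc += w * h;
        }
        weight_total += w;
    }

    // Assumption: clamp the denominator so an all-zero mask gives a zero
    // vector instead of dividing by zero.
    let denom = weight_total.max(1e-9);
    for x in sum.iter_mut() {
        *x /= denom;
    }
    sum
}
```

For example, with seq_len = 3, dim = 2, hidden states [1.0, 2.0, 3.0, 4.0, 100.0, 100.0] and mask [1, 1, 0], the third (padding) token is ignored and the call returns [2.0, 3.0]. Dividing by the mask sum rather than by seq_len keeps padding from diluting the embedding.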