
Function mean_pool 

pub fn mean_pool(embeddings: &Tensor, attention_mask: &Tensor) -> Result<Tensor>

Mean-pools token embeddings over the sequence dimension, respecting an attention mask.

Masked positions (where attention_mask is 0) are excluded from the average, so the mean is taken over real tokens only and padding does not dilute it.

  • embeddings: shape [batch, seq_len, hidden]
  • attention_mask: shape [batch, seq_len] with 1 for real tokens, 0 for padding
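The masking rule described above can be illustrated without a tensor library. This is a minimal sketch under stated assumptions: nested Vec<f32> stands in for Tensor, a String error stands in for the crate's Result error type, and mean_pool_vec is a hypothetical name, not part of this crate. Its handling of rows that contain only padding (clamping the token count so the result is zeros rather than NaN) is also an assumption and is not documented behavior of mean_pool.

```rust
/// Hypothetical plain-Rust sketch of masked mean pooling (not this crate's API).
/// embeddings: [batch][seq_len][hidden]; attention_mask: [batch][seq_len] (1 = real, 0 = padding).
/// Returns one pooled vector of length `hidden` per batch row.
pub fn mean_pool_vec(
    embeddings: &[Vec<Vec<f32>>],
    attention_mask: &[Vec<u32>],
) -> Result<Vec<Vec<f32>>, String> {
    if embeddings.len() != attention_mask.len() {
        return Err(format!(
            "batch mismatch: {} embedding rows vs {} mask rows",
            embeddings.len(),
            attention_mask.len()
        ));
    }

    let mut pooled = Vec::with_capacity(embeddings.len());
    for (b, (tokens, mask)) in embeddings.iter().zip(attention_mask).enumerate() {
        if tokens.len() != mask.len() {
            return Err(format!(
                "row {}: seq_len mismatch ({} tokens vs {} mask entries)",
                b,
                tokens.len(),
                mask.len()
            ));
        }
        // Hidden size taken from the first token (0 if the row is empty).
        let hidden = tokens.first().map_or(0, |t| t.len());
        let mut sum = vec![0.0f32; hidden];
        let mut count = 0.0f32;

        for (t, (token, &m)) in tokens.iter().zip(mask).enumerate() {
            if token.len() != hidden {
                return Err(format!("row {}, token {}: expected hidden size {}", b, t, hidden));
            }
            if m == 0 {
                continue; // padding: excluded from both the sum and the count
            }
            for (acc, &x) in sum.iter_mut().zip(token) {
                *acc += x;
            }
            count += 1.0;
        }

        // Assumption: clamp the count so an all-padding row yields zeros instead of NaN.
        let denom = count.max(1e-9);
        pooled.push(sum.into_iter().map(|s| s / denom).collect());
    }
    Ok(pooled)
}
```

Skipping masked positions is equivalent to multiplying the embeddings by the mask and dividing the sum by the mask total, which is how tensor-based implementations usually express it; for a 0/1 mask the two formulations give the same result.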