// fastembed/pooling.rs

use ndarray::{s, Array2, ArrayView, Axis, Dim, Dimension, IxDynImpl};

/// Strategy used to reduce token-level model output to one embedding per input.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum Pooling {
    /// CLS pooling; see [`cls`].
    Cls,
    /// Attention-masked mean pooling; see [`mean`].
    Mean,
}

9impl Default for Pooling {
10    /// Change this to define the default pooling strategy.
11    ///
12    /// Currently this is set to [`Self::Cls`] for backward compatibility.
13    fn default() -> Self {
14        Self::Cls
15    }
16}
17
18pub fn cls(tensor: &ArrayView<f32, Dim<IxDynImpl>>) -> anyhow::Result<Array2<f32>> {
19    match tensor.dim().ndim() {
20        2 => Ok(tensor.slice(s![.., ..]).to_owned()),
21        3 => Ok(tensor.slice(s![.., 0, ..]).to_owned()),
22        _ => Err(anyhow::Error::msg(format!(
23            "Invalid output shape: {shape:?}. Expected 2D or 3D tensor.",
24            shape = tensor.dim()
25        ))),
26    }
27}
28
29/// Pool the previous layer output by taking the element-wise arithmetic mean of the token-level embeddings after applying the attention mask.
30/// * `token_embeddings` - token embeddings in form of a tensor output of the encoding.
31/// * `attention_mask_array` - is the same mask generated by Tokenizer and used for encoding.
32// Please refer to the original python implementation for more details:
33// https://github.com/UKPLab/sentence-transformers/blob/c0fc0e8238f7f48a1e92dc90f6f96c86f69f1e02/sentence_transformers/models/Pooling.py#L151
34pub fn mean(
35    token_embeddings: &ArrayView<f32, Dim<IxDynImpl>>,
36    attention_mask_array: Array2<i64>,
37) -> anyhow::Result<Array2<f32>> {
38    let attention_mask_original_dim = attention_mask_array.dim();
39
40    if token_embeddings.dim().ndim() == 2 {
41        // There are no means to speak of if the Axis(1) is missing.
42        // Typically we'll see a dimension of (batch_size, feature_count) here.
43        // It can be assumed that pooling is already done within the model.
44        return Ok(token_embeddings.slice(s![.., ..]).to_owned());
45    } else if token_embeddings.dim().ndim() != 3 {
46        return Err(anyhow::Error::msg(format!(
47            "Invalid output shape: {shape:?}. Expected 2D or 3D tensor.",
48            shape = token_embeddings.dim()
49        )));
50    }
51
52    let token_embeddings =
53        // If the token_embeddings is 3D, return the whole thing.
54        // Using `slice` here to assert the dimension.
55        token_embeddings
56            .slice(s![.., .., ..]);
57
58    // Compute attention mask
59    let attention_mask = attention_mask_array
60        .insert_axis(ndarray::Axis(2))
61        .broadcast(token_embeddings.dim())
62        .ok_or_else(|| {
63            anyhow::Error::msg(format!(
64                "Could not broadcast attention mask from {:?} to {:?}",
65                attention_mask_original_dim,
66                token_embeddings.dim()
67            ))
68        })?
69        .mapv(|x| x as f32);
70
71    let masked_tensor = &attention_mask * &token_embeddings;
72    let sum = masked_tensor.sum_axis(ndarray::Axis(1));
73    let mask_sum = attention_mask.sum_axis(ndarray::Axis(1));
74    let mask_sum = mask_sum.mapv(|x| if x == 0f32 { 1.0 } else { x });
75    Ok(&sum / &mask_sum)
76}