use ndarray::{s, Array2, ArrayView, Dim, Dimension, IxDynImpl};

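/// Strategy for collapsing a model's per-token embeddings into one vector per input.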
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Pooling {
    /// Use the embedding of the first ([CLS]) token.
    Cls,
    /// Average the token embeddings, weighted by the attention mask.
    Mean,
}

impl Default for Pooling {
    fn default() -> Self {
        Self::Cls
    }
}

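/// CLS pooling: returns the first token's embedding from a `[batch, seq_len, hidden]`
/// tensor. A 2D `[batch, hidden]` tensor is assumed to be pooled already and is
/// returned unchanged.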
pub fn cls(tensor: &ArrayView<f32, Dim<IxDynImpl>>) -> anyhow::Result<Array2<f32>> {
    match tensor.dim().ndim() {
        // Already pooled: copy the [batch, hidden] tensor through unchanged.
        2 => Ok(tensor.slice(s![.., ..]).to_owned()),
        // Take the first (CLS) token along the sequence axis.
        3 => Ok(tensor.slice(s![.., 0, ..]).to_owned()),
        _ => Err(anyhow::Error::msg(format!(
            "Invalid output shape: {shape:?}. Expected 2D or 3D tensor.",
            shape = tensor.dim()
        ))),
    }
}

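/// Mean pooling: averages token embeddings over the sequence axis, weighted by the
/// attention mask so padding tokens do not contribute. A 2D `[batch, hidden]` tensor
/// is assumed to be pooled already and is returned unchanged.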
pub fn mean(
    token_embeddings: &ArrayView<f32, Dim<IxDynImpl>>,
    attention_mask_array: Array2<i64>,
) -> anyhow::Result<Array2<f32>> {
    let attention_mask_original_dim = attention_mask_array.dim();

    if token_embeddings.dim().ndim() == 2 {
        // Already pooled: copy the [batch, hidden] tensor through unchanged.
        return Ok(token_embeddings.slice(s![.., ..]).to_owned());
    } else if token_embeddings.dim().ndim() != 3 {
        return Err(anyhow::Error::msg(format!(
            "Invalid output shape: {shape:?}. Expected 2D or 3D tensor.",
            shape = token_embeddings.dim()
        )));
    }

    // Reinterpret the dynamic-dimension view as a fixed [batch, seq_len, hidden] view.
    let token_embeddings = token_embeddings.slice(s![.., .., ..]);

    // Expand the [batch, seq_len] mask to [batch, seq_len, 1] and broadcast it across
    // the hidden dimension so it can be applied element-wise to the embeddings.
    let attention_mask = attention_mask_array
        .insert_axis(ndarray::Axis(2))
        .broadcast(token_embeddings.dim())
        .ok_or_else(|| {
            anyhow::Error::msg(format!(
                "Could not broadcast attention mask from {:?} to {:?}",
                attention_mask_original_dim,
                token_embeddings.dim()
            ))
        })?
        .mapv(|x| x as f32);

    // Zero out padded positions, then sum the embeddings over the sequence axis.
    let masked_tensor = &attention_mask * &token_embeddings;
    let sum = masked_tensor.sum_axis(ndarray::Axis(1));
    // Divide by the number of attended tokens, guarding against division by zero
    // for rows whose mask is all zeros.
    let mask_sum = attention_mask.sum_axis(ndarray::Axis(1));
    let mask_sum = mask_sum.mapv(|x| if x == 0f32 { 1.0 } else { x });
    Ok(&sum / &mask_sum)
}
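
// A minimal usage sketch, not part of the original module: it assumes only the
// `ndarray` and `anyhow` crates already imported above, and hand-builds a tiny
// [batch = 1, seq_len = 3, hidden = 2] tensor to show how `cls` and `mean` behave
// when the last token is padding. The module and test names are illustrative.
#[cfg(test)]
mod pooling_sketch {
    use super::*;
    use ndarray::array;

    #[test]
    fn cls_and_mean_on_a_tiny_batch() -> anyhow::Result<()> {
        // One sequence of three tokens with two hidden features each; the third
        // token is padding (mask = 0) and should not influence mean pooling.
        let hidden = array![[[1.0_f32, 2.0], [3.0, 4.0], [100.0, 100.0]]].into_dyn();
        let mask = array![[1_i64, 1, 0]];

        // CLS pooling keeps only the first token of each sequence.
        let cls_pooled = cls(&hidden.view())?;
        assert_eq!(cls_pooled, array![[1.0_f32, 2.0]]);

        // Mean pooling averages the two attended tokens: (1+3)/2 = 2, (2+4)/2 = 3.
        let mean_pooled = mean(&hidden.view(), mask)?;
        assert_eq!(mean_pooled, array![[2.0_f32, 3.0]]);
        Ok(())
    }
}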