use ndarray::{s, Array2, ArrayView, Dim, Dimension, IxDynImpl};
/// Pooling strategy selector.
///
/// The variants share their names with the `cls` and `mean` functions in this module;
/// presumably callers dispatch on this value to pick one of them (confirm at call sites).
///
/// The enum carries no data, so it is `Copy` and `Hash`: it can be passed by value and
/// used as a map/set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Pooling {
    /// Shares its name with `cls` (take the vector at token position 0).
    Cls,
    /// Shares its name with `mean` (attention-mask-weighted average over tokens).
    Mean,
}
impl Default for Pooling {
fn default() -> Self {
Self::Cls
}
}
pub fn cls(tensor: &ArrayView<f32, Dim<IxDynImpl>>) -> anyhow::Result<Array2<f32>> {
match tensor.dim().ndim() {
2 => Ok(tensor.slice(s![.., ..]).to_owned()),
3 => Ok(tensor.slice(s![.., 0, ..]).to_owned()),
_ => Err(anyhow::Error::msg(format!(
"Invalid output shape: {shape:?}. Expected 2D or 3D tensor.",
shape = tensor.dim()
))),
}
}
/// Mean pooling weighted by an attention mask.
///
/// * A 2-D `token_embeddings` is returned unchanged, as an owned copy; `attention_mask_array`
///   is ignored in that case.
/// * A 3-D `token_embeddings` of shape `(d0, d1, d2)` is multiplied element-wise by the mask
///   (broadcast to `(d0, d1)` and repeated along the last axis), summed over axis 1, and
///   divided by the per-row mask sum. The result has shape `(d0, d2)`. A row whose mask sums
///   to zero is divided by 1 instead, so it yields zeros rather than NaN.
///
/// NOTE: standard ndarray broadcasting applies to the mask, so a mask with a length-1 axis
/// (e.g. `(1, d1)`) is reused across that axis rather than rejected.
///
/// # Errors
///
/// Returns an error if `token_embeddings` is neither 2-D nor 3-D, or if the mask cannot be
/// broadcast to the first two dimensions of `token_embeddings`.
pub fn mean(
    token_embeddings: &ArrayView<f32, Dim<IxDynImpl>>,
    attention_mask_array: Array2<i64>,
) -> anyhow::Result<Array2<f32>> {
    match token_embeddings.ndim() {
        // Already one vector per row: there is no token axis to average over.
        2 => return Ok(token_embeddings.slice(s![.., ..]).to_owned()),
        3 => {}
        _ => {
            // `shape()` formats as a plain `[..]` slice, unlike `Debug` on the dynamic `Dim`.
            return Err(anyhow::Error::msg(format!(
                "Invalid output shape: {shape:?}. Expected 2D or 3D tensor.",
                shape = token_embeddings.shape()
            )));
        }
    }
    let token_embeddings = token_embeddings.slice(s![.., .., ..]);
    let (rows, tokens, _) = token_embeddings.dim();

    // Broadcast the mask only to (rows, tokens). The embedding axis is added below as a
    // length-1 view that broadcasts inside the product, so no full 3-D mask copy is built.
    let mask = attention_mask_array
        .broadcast((rows, tokens))
        .ok_or_else(|| {
            anyhow::Error::msg(format!(
                "Could not broadcast attention mask from {:?} to {:?}",
                attention_mask_array.dim(),
                token_embeddings.dim()
            ))
        })?
        .mapv(|x| x as f32);

    // (rows, tokens, d2) * (rows, tokens, 1), then summed over the token axis -> (rows, d2).
    let weighted = &token_embeddings * &mask.view().insert_axis(ndarray::Axis(2));
    let sum = weighted.sum_axis(ndarray::Axis(1));

    // Per-row mask totals as a (rows, 1) column that broadcasts across d2 in the division.
    // Zero totals are replaced by 1 to avoid dividing by zero.
    let counts = mask
        .sum_axis(ndarray::Axis(1))
        .mapv(|c| if c == 0f32 { 1.0 } else { c })
        .insert_axis(ndarray::Axis(1));

    Ok(&sum / &counts)
}