use crate::{
array::Array,
dtype::Dtype,
error::{
Error, RankMismatchPayload, Result, ShapePairMismatchPayload, UnknownEnumValuePayload,
try_with_capacity,
},
ops::{
arithmetic::{divide, maximum, multiply, subtract},
comparison::equal,
indexing::{take_along_axis, take_axis},
logical::select,
misc::{argmax, astype},
reduction::{max_axes, sum_axes},
shape::{broadcast_to, expand_dims_axes, squeeze_axes},
},
};
use super::{
fast::{layer_norm, rms_norm},
normalize::{DEFAULT_NORMALIZE_EPS, l2_normalize_eps},
scalar_like,
};
/// Epsilon passed to `layer_norm` by `pool_post` when layer normalization is requested.
pub const LAYER_NORM_EPS: f32 = 1e-5;
/// Epsilon passed to `rms_norm` by `pool_post` when RMS normalization is requested.
pub const RMS_NORM_EPS: f32 = 1e-5;
fn validate_token_embeddings_and_mask(
token_embeddings: &Array,
attention_mask: &Array,
) -> Result<()> {
let emb_shape = token_embeddings.shape();
let mask_shape = attention_mask.shape();
if emb_shape.len() != 3 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"token_embeddings must be rank-3 (batch, seq_len, hidden)",
emb_shape.len() as u32,
emb_shape,
)));
}
if mask_shape.len() != 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"attention_mask must be rank-2 (batch, seq_len)",
mask_shape.len() as u32,
mask_shape,
)));
}
if emb_shape[0] != mask_shape[0] || emb_shape[1] != mask_shape[1] {
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"token_embeddings (batch, seq_len) must match attention_mask (batch, seq_len)",
vec![emb_shape[0], emb_shape[1]],
vec![mask_shape[0], mask_shape[1]],
)));
}
Ok(())
}
fn validate_token_embeddings_rank3(token_embeddings: &Array) -> Result<()> {
let emb_shape = token_embeddings.shape();
if emb_shape.len() != 3 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"token_embeddings must be rank-3 (batch, seq_len, hidden)",
emb_shape.len() as u32,
emb_shape,
)));
}
Ok(())
}
/// Mask-weighted mean of `token_embeddings` over axis 1.
///
/// The attention mask is expanded with a trailing axis, broadcast to the shape
/// of `token_embeddings`, and cast to `Dtype::F32`. The embeddings are
/// multiplied by that mask and summed over axis 1, then divided by the summed
/// mask. The divisor is clamped to at least `1e-9`, so a row whose mask sums
/// to zero does not divide by zero.
///
/// # Errors
///
/// Returns the validation error from `validate_token_embeddings_and_mask`
/// when the shapes are incompatible, or any error raised by the array ops.
pub fn mean_pooling(token_embeddings: &Array, attention_mask: &Array) -> Result<Array> {
    validate_token_embeddings_and_mask(token_embeddings, attention_mask)?;

    let full_shape = token_embeddings.shape();

    // Build a float mask with the same shape as the embeddings.
    let column_mask = expand_dims_axes(attention_mask, &[-1])?;
    let wide_mask = broadcast_to(&column_mask, &full_shape.as_slice())?;
    let weights = astype(&wide_mask, Dtype::F32)?;

    // Numerator: sum of the embeddings that the mask keeps.
    let kept = multiply(token_embeddings, &weights)?;
    let numerator = sum_axes(&kept, &[1], false)?;

    // Denominator: summed mask, floored to avoid a zero divisor.
    let mask_total = sum_axes(&weights, &[1], false)?;
    let min_total = Array::full::<f32>(&(1,), 1e-9)?;
    divide(&numerator, &maximum(&mask_total, &min_total)?)
}
/// Selects, per batch row, the embedding at the `argmax` of the attention mask
/// along axis 1.
///
/// NOTE(review): for a 0/1 mask this is presumably the first unmasked token
/// (which also covers left-padded inputs); that relies on `argmax` returning
/// the first maximal index — confirm against the `argmax` contract.
///
/// # Errors
///
/// Returns the validation error from `validate_token_embeddings_and_mask`
/// when the shapes are incompatible, or any error raised by the array ops.
pub fn cls_pooling(token_embeddings: &Array, attention_mask: &Array) -> Result<Array> {
    validate_token_embeddings_and_mask(token_embeddings, attention_mask)?;

    let dims = token_embeddings.shape();
    let (batch, hidden) = (dims[0], dims[2]);

    let chosen = argmax(attention_mask, Some(1), false)?;

    // Expand the per-row index to (batch, 1, hidden) so it can gather along axis 1.
    let index = broadcast_to(
        &expand_dims_axes(&chosen, &[1, 2])?,
        &(batch, 1, hidden),
    )?;

    let picked = take_along_axis(token_embeddings, &index, 1)?;
    squeeze_axes(&picked, &[1])
}
/// Element-wise maximum of `token_embeddings` over axis 1, ignoring masked
/// positions.
///
/// Positions where the broadcast attention mask equals zero are replaced by
/// negative infinity before the reduction, so they cannot win the maximum.
/// A row whose mask is entirely zero therefore reduces to negative infinity.
///
/// # Errors
///
/// Returns the validation error from `validate_token_embeddings_and_mask`
/// when the shapes are incompatible, or any error raised by the array ops.
pub fn max_pooling(token_embeddings: &Array, attention_mask: &Array) -> Result<Array> {
    validate_token_embeddings_and_mask(token_embeddings, attention_mask)?;

    let full_shape = token_embeddings.shape();
    let target_dtype = token_embeddings.dtype()?;

    // Mask with the same shape and dtype as the embeddings.
    let column_mask = expand_dims_axes(attention_mask, &[-1])?;
    let wide_mask = broadcast_to(&column_mask, &full_shape.as_slice())?;
    let typed_mask = astype(&wide_mask, target_dtype)?;

    // Boolean array that is true wherever the mask is zero.
    let zero_scalar = scalar_like(0.0, token_embeddings)?;
    let padding = equal(&typed_mask, &zero_scalar)?;

    // Replace those positions with -inf, then reduce.
    let lowest = scalar_like(f32::NEG_INFINITY, token_embeddings)?;
    let candidates = select(&padding, &lowest, token_embeddings)?;
    max_axes(&candidates, &[1], false)
}
/// Selects, per batch row, the embedding of the last position whose mask is
/// non-zero.
///
/// The mask is reversed along axis 1 and `argmax` locates the first non-zero
/// entry of the reversed mask; that offset is converted back to a forward
/// index as `seq_len - offset - 1`. Rows whose mask is entirely zero are
/// steered to index 0. The gather reads from the embeddings multiplied by the
/// mask, so a gathered position whose mask is zero yields zeros.
///
/// NOTE(review): this relies on `argmax` returning the first maximal index —
/// confirm against the `argmax` contract.
///
/// # Errors
///
/// Returns the validation error from `validate_token_embeddings_and_mask`
/// when the shapes are incompatible, an allocation error from
/// `try_with_capacity`, or any error raised by the array ops.
pub fn last_token_pooling(token_embeddings: &Array, attention_mask: &Array) -> Result<Array> {
    validate_token_embeddings_and_mask(token_embeddings, attention_mask)?;

    let dims = token_embeddings.shape();
    let (batch, seq_len, hidden) = (dims[0], dims[1], dims[2]);

    // Reverse the integer mask along the sequence axis using descending indices.
    let int_mask = astype(attention_mask, Dtype::I32)?;
    let mut descending: Vec<i32> = try_with_capacity(seq_len)?;
    descending.extend((0..seq_len as i32).rev());
    let descending = Array::from_slice(&descending, &(seq_len,))?;
    let reversed_mask = take_axis(&int_mask, &descending, 1)?;

    // Offset (counted from the end) of the last non-zero mask entry.
    let raw_offset = argmax(&reversed_mask, Some(1), false)?;
    let offset_from_end = astype(&raw_offset, Dtype::I32)?;

    // Rows with no non-zero mask entry get the largest offset, i.e. forward index 0.
    let row_peak = max_axes(&reversed_mask, &[1], false)?;
    let zero = Array::full::<i32>(&(1,), 0)?;
    let all_padding = equal(&row_peak, &zero)?;
    let largest_offset = Array::full::<i32>(&(1,), seq_len as i32 - 1)?;
    let offset_from_end = select(&all_padding, &largest_offset, &offset_from_end)?;

    // forward index = seq_len - offset - 1
    let length = Array::full::<i32>(&(1,), seq_len as i32)?;
    let unit = Array::full::<i32>(&(1,), 1)?;
    let remaining = subtract(&length, &offset_from_end)?;
    let last_position = subtract(&remaining, &unit)?;

    // Zero out masked embeddings before gathering.
    let column_mask = expand_dims_axes(attention_mask, &[-1])?;
    let wide_mask = broadcast_to(&column_mask, &dims.as_slice())?;
    let typed_mask = astype(&wide_mask, token_embeddings.dtype()?)?;
    let visible = multiply(token_embeddings, &typed_mask)?;

    // Expand the per-row index to (batch, 1, hidden) and gather along axis 1.
    let index = expand_dims_axes(&last_position, &[1, 2])?;
    let index = broadcast_to(&index, &(batch, 1, hidden))?;
    let picked = take_along_axis(&visible, &index, 1)?;
    squeeze_axes(&picked, &[1])
}
/// Returns the embedding at index 0 of axis 1 for every batch row, with that
/// axis squeezed away. No attention mask is consulted.
///
/// # Errors
///
/// Returns [`Error::RankMismatch`] when `token_embeddings` is not rank-3, or
/// any error raised by the array ops.
pub fn first_token_pooling(token_embeddings: &Array) -> Result<Array> {
    validate_token_embeddings_rank3(token_embeddings)?;

    let first_position = Array::from_slice(&[0_i32], &(1,))?;
    squeeze_axes(&take_axis(token_embeddings, &first_position, 1)?, &[1])
}
/// Strategy that `pool` uses to turn token embeddings into its pooled output.
///
/// The `Display` output is the string returned by `as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
pub enum PoolingStrategy {
/// `pool` dispatches to `mean_pooling`.
Mean,
/// `pool` dispatches to `cls_pooling`.
Cls,
/// `pool` dispatches to `first_token_pooling` (the attention mask is not used).
First,
/// `pool` dispatches to `last_token_pooling`.
Last,
/// `pool` dispatches to `max_pooling`.
Max,
/// `pool` clones the token embeddings without pooling them.
None,
}
impl PoolingStrategy {
    /// Returns the lowercase name of the strategy.
    ///
    /// Every string returned here is accepted by [`Self::from_mode`].
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Mean => "mean",
            Self::Cls => "cls",
            Self::First => "first",
            Self::Last => "last",
            Self::Max => "max",
            Self::None => "none",
        }
    }

    /// Parses a pooling mode string into a strategy.
    ///
    /// Matching is exact (case-sensitive). Accepted spellings are `"cls"`,
    /// `"first"`, `"last"`, `"lasttoken"` (alias of `"last"`), `"max"`,
    /// `"mean"` and `"none"`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnknownEnumValue`] for any other string. The payload
    /// lists every accepted spelling, so the reported values agree with what
    /// this function actually parses.
    pub fn from_mode(mode: &str) -> Result<Self> {
        match mode {
            "cls" => Ok(Self::Cls),
            "mean" => Ok(Self::Mean),
            "max" => Ok(Self::Max),
            "lasttoken" | "last" => Ok(Self::Last),
            "first" => Ok(Self::First),
            "none" => Ok(Self::None),
            _ => Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
                "embeddings::PoolingStrategy",
                mode,
                // Keep in sync with the match arms above (alphabetical order).
                &["cls", "first", "last", "lasttoken", "max", "mean", "none"],
            ))),
        }
    }
}
/// Pools `token_embeddings` with the chosen `strategy`, then applies the
/// post-processing steps of `pool_post`.
///
/// `attention_mask` is ignored by `PoolingStrategy::First` and
/// `PoolingStrategy::None`; the latter clones the token embeddings as they
/// are. The remaining arguments are forwarded unchanged to `pool_post`.
///
/// # Errors
///
/// Returns any error produced by the selected pooling function or by
/// `pool_post`.
pub fn pool(
    token_embeddings: &Array,
    attention_mask: &Array,
    strategy: PoolingStrategy,
    normalize: bool,
    dimension: Option<usize>,
    apply_layer_norm: bool,
    apply_rms_norm: bool,
) -> Result<Array> {
    let reduced = match strategy {
        PoolingStrategy::Mean => mean_pooling(token_embeddings, attention_mask),
        PoolingStrategy::Max => max_pooling(token_embeddings, attention_mask),
        PoolingStrategy::Last => last_token_pooling(token_embeddings, attention_mask),
        PoolingStrategy::Cls => cls_pooling(token_embeddings, attention_mask),
        PoolingStrategy::First => first_token_pooling(token_embeddings),
        PoolingStrategy::None => token_embeddings.try_clone(),
    }?;

    pool_post(
        reduced,
        normalize,
        dimension,
        apply_layer_norm,
        apply_rms_norm,
    )
}
/// Applies the optional post-processing steps to an already pooled array.
///
/// Steps run in this order:
/// 1. Layer norm (`LAYER_NORM_EPS`) if `apply_layer_norm` is set; otherwise
///    RMS norm (`RMS_NORM_EPS`) if `apply_rms_norm` is set. When both flags
///    are set, only layer norm runs.
/// 2. Truncation of the last axis to `dimension`, when one is given.
/// 3. L2 normalization with `DEFAULT_NORMALIZE_EPS` if `normalize` is set.
///
/// # Errors
///
/// Returns any error raised by the individual steps.
pub fn pool_post(
    pooled: Array,
    normalize: bool,
    dimension: Option<usize>,
    apply_layer_norm: bool,
    apply_rms_norm: bool,
) -> Result<Array> {
    // Layer norm takes precedence over RMS norm.
    let normed = if apply_layer_norm {
        layer_norm(&pooled, None, None, LAYER_NORM_EPS)?
    } else if apply_rms_norm {
        rms_norm(&pooled, None, RMS_NORM_EPS)?
    } else {
        pooled
    };

    let sized = match dimension {
        Some(d) => truncate_last_dim(&normed, d)?,
        None => normed,
    };

    let finished = if normalize {
        l2_normalize_eps(&sized, DEFAULT_NORMALIZE_EPS)?
    } else {
        sized
    };
    Ok(finished)
}
/// Keeps only the first `dimension` entries along the last axis of `x`.
///
/// A rank-0 array, or an array whose last axis already has at most
/// `dimension` entries, is returned as a clone. Otherwise the indices
/// `0..dimension` are broadcast across the leading axes and gathered along
/// the last axis.
///
/// # Errors
///
/// Returns an allocation error from `try_with_capacity` or any error raised
/// by the array ops.
pub fn truncate_last_dim(x: &Array, dimension: usize) -> Result<Array> {
    let dims = x.shape();
    let rank = dims.len();

    // `||` short-circuits, so the index is only evaluated when rank > 0.
    if rank == 0 || dimension >= dims[rank - 1] {
        return x.try_clone();
    }
    let last_axis = rank - 1;

    let mut keep: Vec<i32> = try_with_capacity(dimension)?;
    keep.extend(0..dimension as i32);

    // Index array shaped (1, ..., 1, dimension) so it can broadcast over the leading axes.
    let mut index_shape = vec![1_usize; rank];
    index_shape[last_axis] = dimension;
    let index = Array::from_slice(&keep, &index_shape.as_slice())?;

    // Shape of the result: the input shape with the last axis shrunk.
    let mut out_shape = dims;
    out_shape[last_axis] = dimension;
    let index = broadcast_to(&index, &out_shape.as_slice())?;

    take_along_axis(x, &index, last_axis as i32)
}
// Unit tests live in a separate `tests` module file and are compiled only for test builds.
#[cfg(test)]
mod tests;