use burn::tensor::backend::Backend;
use burn::tensor::linalg::vector_normalize;
use burn::tensor::{Int, Tensor};

/// Strategy for collapsing a sequence of token embeddings into a single
/// sentence-level embedding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Use the hidden state of the first ([CLS]) token.
    Cls,
    /// Average hidden states over the sequence dimension.
    Mean,
    /// Take the element-wise maximum over the sequence dimension.
    Max,
    /// Average hidden states weighted by token position (later tokens weigh more).
    WeightedMean,
    /// Use the hidden state of the last token in the sequence.
    LastToken,
}

/// Configuration shared by all pooling strategies.
#[derive(Debug, Clone, Copy)]
pub struct PoolingConfig {
    /// When `true`, L2-normalize the pooled embedding.
    pub normalize: bool,
}

impl Default for PoolingConfig {
    fn default() -> Self {
        Self { normalize: true }
    }
}

/// Pools token-level hidden states into fixed-size sentence embeddings,
/// with the strategy selected at runtime.
#[derive(Clone)]
pub struct DynamicPooler<B: Backend> {
    strategy: PoolingStrategy,
    config: PoolingConfig,
    _marker: core::marker::PhantomData<B>,
}

impl<B: Backend> DynamicPooler<B> {
    /// Creates a pooler with the given strategy and configuration.
    pub fn new(strategy: PoolingStrategy, config: PoolingConfig) -> Self {
        Self {
            strategy,
            config,
            _marker: core::marker::PhantomData,
        }
    }

    /// Pools `hidden_states` (`[batch, seq_len, hidden]`) into a
    /// `[batch, hidden]` embedding. With an `attention_mask`
    /// (1.0 = real token, 0.0 = padding), `Mean`/`WeightedMean` skip padded
    /// positions; `LastToken` assumes left-padded inputs.
    pub fn pool(
        &self,
        hidden_states: Tensor<B, 3>,
        attention_mask: Option<Tensor<B, 2>>,
    ) -> Tensor<B, 2> {
        let [batch_size, seq_len, hidden_size] = hidden_states.dims();
        let pooled = match self.strategy {
            PoolingStrategy::Cls => hidden_states
                .slice([0..batch_size, 0..1, 0..hidden_size])
                .reshape([batch_size, hidden_size]),
            PoolingStrategy::Mean | PoolingStrategy::WeightedMean => {
                let device = hidden_states.device();
                let mut weights: Tensor<B, 3> = if self.strategy == PoolingStrategy::WeightedMean {
                    // 1-based position weights emphasize later tokens.
                    Tensor::<B, 1, Int>::arange(1..seq_len as i64 + 1, &device)
                        .float()
                        .reshape([1, seq_len, 1])
                } else {
                    Tensor::ones([1, seq_len, 1], &device)
                };
                // Zero out padded positions when a mask is supplied.
                if let Some(mask) = attention_mask {
                    weights = weights * mask.unsqueeze_dim::<3>(2);
                }
                let summed = (hidden_states * weights.clone()).sum_dim(1);
                let total = weights.sum_dim(1).clamp_min(1e-9);
                (summed / total).reshape([batch_size, hidden_size])
            }
            PoolingStrategy::Max => hidden_states.max_dim(1).reshape([batch_size, hidden_size]),
            PoolingStrategy::LastToken => hidden_states
                .slice([0..batch_size, (seq_len - 1)..seq_len, 0..hidden_size])
                .reshape([batch_size, hidden_size]),
        };
        if self.config.normalize {
            vector_normalize(pooled, 2.0, 1, 1e-6)
        } else {
            pooled
        }
    }
}
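
// Minimal usage sketch; `MyBackend`, `hidden_states`, and `attention_mask`
// are placeholders for whatever encoder stack this pooler sits behind:
//
//     let pooler = DynamicPooler::<MyBackend>::new(
//         PoolingStrategy::Mean,
//         PoolingConfig::default(),
//     );
//     let embeddings = pooler.pool(hidden_states, Some(attention_mask));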

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::ndarray::NdArray;

    #[test]
    fn mean_pooling_reduces_dim() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let tensor =
            Tensor::<NdArray<f32>, 3>::from_data([[[1.0f32, 2.0, 3.0], [3.0, 4.0, 5.0]]], &device);
        let pooler = DynamicPooler::new(PoolingStrategy::Mean, PoolingConfig { normalize: false });
        let pooled = pooler.pool(tensor, None);
        assert_eq!(pooled.dims(), [1, 3]);
    }
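
    // A sketch of the masked-mean behaviour, assuming the mask convention
    // documented on `pool` (1.0 = real token, 0.0 = padding): the padded
    // position must not contribute to the average.
    #[test]
    fn masked_mean_ignores_padding() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let tensor =
            Tensor::<NdArray<f32>, 3>::from_data([[[2.0f32, 4.0], [100.0, 100.0]]], &device);
        let mask = Tensor::<NdArray<f32>, 2>::from_data([[1.0f32, 0.0]], &device);
        let pooler = DynamicPooler::new(PoolingStrategy::Mean, PoolingConfig { normalize: false });
        let pooled = pooler.pool(tensor, Some(mask));
        let values = pooled.into_data().to_vec::<f32>().unwrap();
        assert!((values[0] - 2.0).abs() < 1e-5);
        assert!((values[1] - 4.0).abs() < 1e-5);
    }

    // With `normalize: true`, the pooled vector should come back with
    // (approximately) unit L2 norm.
    #[test]
    fn normalized_output_has_unit_norm() {
        let device = <NdArray<f32> as Backend>::Device::default();
        let tensor =
            Tensor::<NdArray<f32>, 3>::from_data([[[3.0f32, 4.0, 0.0], [1.0, 1.0, 1.0]]], &device);
        let pooler = DynamicPooler::new(PoolingStrategy::Cls, PoolingConfig { normalize: true });
        let pooled = pooler.pool(tensor, None);
        let norm_sq: f32 = (pooled.clone() * pooled).sum().into_scalar();
        assert!((norm_sq - 1.0).abs() < 1e-3);
    }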
}