/// Strategy used by `pool_embeddings` to collapse the per-token hidden states
/// of one sequence into a single vector.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Pooling {
/// Take the hidden state of the first token of each sequence.
Cls,
/// Average the hidden states of the tokens whose attention-mask entry is
/// greater than zero.
Mean,
}
/// Pools per-token hidden states into one L2-normalised vector per sequence.
///
/// `hidden` is read as a flat, row-major `[batch, seq, hidden_size]` buffer:
/// the value for sequence `b`, token `t`, feature `j` lives at index
/// `(b * seq + t) * hidden_size + j`.
///
/// * [`Pooling::Cls`] copies the hidden state of the first token of each
///   sequence; `attention_mask` is not consulted.
/// * [`Pooling::Mean`] sums the hidden states of the first `seq` tokens whose
///   mask entry is non-zero and multiplies by `1 / max(sum(mask), 1)`, where
///   the sum runs over every entry of that sequence's mask row. For masks
///   containing only 0 and 1 this is the mean over attended tokens; a fully
///   masked sequence yields the zero vector.
///
/// Every pooled vector is then passed through `l2_normalize_in_place`.
///
/// Returns `batch` vectors, each of length `hidden_size`.
///
/// # Panics
///
/// Panics if `hidden` is too short for the requested dimensions, or, for
/// [`Pooling::Mean`], if `attention_mask` has fewer than `batch` rows or a
/// row has fewer than `seq` entries.
pub fn pool_embeddings(
    hidden: &[f32],
    attention_mask: &[&[u32]],
    batch: usize,
    seq: usize,
    hidden_size: usize,
    pooling: Pooling,
) -> Vec<Vec<f32>> {
    (0..batch)
        .map(|row| {
            let mut embedding = match pooling {
                Pooling::Cls => {
                    // The first token of sequence `row` starts at this offset.
                    let first = row * seq * hidden_size;
                    hidden[first..first + hidden_size].to_vec()
                }
                Pooling::Mean => {
                    let mask = attention_mask[row];
                    // The divisor is the sum of the raw mask values, floored at 1
                    // so a fully masked sequence does not divide by zero.
                    let weight: f32 = mask.iter().map(|&m| m as f32).sum();
                    let scale = 1.0 / weight.max(1.0);

                    let mut acc = vec![0f32; hidden_size];
                    let attended = mask[..seq]
                        .iter()
                        .enumerate()
                        .filter(|&(_, &m)| m != 0);
                    for (tok, _) in attended {
                        let start = (row * seq + tok) * hidden_size;
                        let token = &hidden[start..start + hidden_size];
                        for (sum, &value) in acc.iter_mut().zip(token) {
                            *sum += value;
                        }
                    }
                    acc.iter_mut().for_each(|sum| *sum *= scale);
                    acc
                }
            };
            l2_normalize_in_place(&mut embedding);
            embedding
        })
        .collect()
}
/// Rescales `v` in place by `1 / (‖v‖₂ + 1e-12)`.
///
/// The small constant added to the norm keeps the reciprocal finite when
/// every element is zero, so an all-zero slice stays all-zero. An empty
/// slice is left untouched.
pub fn l2_normalize_in_place(v: &mut [f32]) {
    let sum_of_squares: f32 = v.iter().map(|x| x * x).sum();
    let scale = 1.0 / (sum_of_squares.sqrt() + 1e-12);
    v.iter_mut().for_each(|x| *x *= scale);
}