rlx_embed/pooling.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Pooling and L2 normalization for sentence/image embeddings.
17
/// How per-token hidden states are collapsed into a single vector per sequence.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Pooling {
    /// Use the hidden state of the first token (the `[CLS]` position).
    Cls,
    /// Attention-mask-aware mean over the token hidden states.
    Mean,
}
26
27/// Pool `[batch, seq, hidden]` hidden states into `[batch, hidden]` and L2-normalize.
28pub fn pool_embeddings(
29 hidden: &[f32],
30 attention_mask: &[&[u32]],
31 batch: usize,
32 seq: usize,
33 hidden_size: usize,
34 pooling: Pooling,
35) -> Vec<Vec<f32>> {
36 let mut out = Vec::with_capacity(batch);
37 for bi in 0..batch {
38 let mut pooled = vec![0f32; hidden_size];
39 match pooling {
40 Pooling::Cls => {
41 pooled.copy_from_slice(
42 &hidden[bi * seq * hidden_size..bi * seq * hidden_size + hidden_size],
43 );
44 }
45 Pooling::Mean => {
46 let count: f32 = attention_mask[bi].iter().map(|&v| v as f32).sum();
47 let inv = 1.0 / count.max(1.0);
48 for si in 0..seq {
49 if attention_mask[bi][si] > 0 {
50 let off = (bi * seq + si) * hidden_size;
51 for j in 0..hidden_size {
52 pooled[j] += hidden[off + j];
53 }
54 }
55 }
56 for v in &mut pooled {
57 *v *= inv;
58 }
59 }
60 }
61 l2_normalize_in_place(&mut pooled);
62 out.push(pooled);
63 }
64 out
65}
66
/// Scale `v` in place to unit L2 length.
///
/// The divisor is `sqrt(sum(x^2)) + 1e-12` (the fastembed convention), so an
/// all-zero or empty slice is left unchanged instead of producing NaNs.
pub fn l2_normalize_in_place(v: &mut [f32]) {
    let sum_of_squares = v.iter().fold(0.0f32, |acc, &x| acc + x * x);
    let scale = (sum_of_squares.sqrt() + 1e-12).recip();
    v.iter_mut().for_each(|x| *x *= scale);
}