Skip to main content

rlx_embed/
pooling.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Pooling and L2 normalization for sentence/image embeddings.
17
/// How per-token hidden states are collapsed into a single vector per sequence.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Pooling {
    /// Take the hidden state at the first token position (`[CLS]`).
    Cls,
    /// Average the hidden states of the tokens selected by the attention mask.
    Mean,
}
26
27/// Pool `[batch, seq, hidden]` hidden states into `[batch, hidden]` and L2-normalize.
28pub fn pool_embeddings(
29    hidden: &[f32],
30    attention_mask: &[&[u32]],
31    batch: usize,
32    seq: usize,
33    hidden_size: usize,
34    pooling: Pooling,
35) -> Vec<Vec<f32>> {
36    let mut out = Vec::with_capacity(batch);
37    for bi in 0..batch {
38        let mut pooled = vec![0f32; hidden_size];
39        match pooling {
40            Pooling::Cls => {
41                pooled.copy_from_slice(
42                    &hidden[bi * seq * hidden_size..bi * seq * hidden_size + hidden_size],
43                );
44            }
45            Pooling::Mean => {
46                let count: f32 = attention_mask[bi].iter().map(|&v| v as f32).sum();
47                let inv = 1.0 / count.max(1.0);
48                for si in 0..seq {
49                    if attention_mask[bi][si] > 0 {
50                        let off = (bi * seq + si) * hidden_size;
51                        for j in 0..hidden_size {
52                            pooled[j] += hidden[off + j];
53                        }
54                    }
55                }
56                for v in &mut pooled {
57                    *v *= inv;
58                }
59            }
60        }
61        l2_normalize_in_place(&mut pooled);
62        out.push(pooled);
63    }
64    out
65}
66
/// Rescale `v` in place so its Euclidean (L2) norm is approximately 1.
///
/// Every element is multiplied by `1 / (‖v‖₂ + 1e-12)`, matching fastembed's
/// normalization. The `1e-12` term keeps the divisor non-zero, so an all-zero (or
/// empty) slice is left as zeros rather than turned into NaN.
pub fn l2_normalize_in_place(v: &mut [f32]) {
    let sum_of_squares: f32 = v.iter().map(|&x| x * x).sum();
    let scale = (sum_of_squares.sqrt() + 1e-12).recip();
    v.iter_mut().for_each(|x| *x *= scale);
}