Skip to main content

lean_ctx/core/embeddings/
pooling.rs

//! Pooling strategies for transformer hidden states.
//!
//! Converts per-token hidden states `[seq_len × dim]` into a single
//! fixed-size embedding vector `[dim]`, and provides L2-normalization helpers.

/// Mean pooling over token positions, weighted by attention mask.
///
/// Takes the raw hidden-state output `[1 × seq_len × dim]` flattened into a
/// slice and produces a single embedding by averaging the `dim`-wide rows of
/// the attended positions. Token `pos` is read from
/// `hidden_states[pos * dim..(pos + 1) * dim]`.
///
/// * `attention_mask` — a position is attended when its entry is `> 0`;
///   positions past the end of the mask are treated as padding.
/// * `seq_len` — at most this many leading positions are considered.
/// * `dim` — hidden size; the returned vector always has length `dim`.
///
/// When no position is attended the result is all zeros; when `dim == 0` it
/// is empty. Attended positions whose complete row is not present in
/// `hidden_states` (truncated input) are skipped entirely, so they neither
/// contribute values nor inflate the divisor.
pub fn mean_pool(
    hidden_states: &[f32],
    attention_mask: &[i32],
    seq_len: usize,
    dim: usize,
) -> Vec<f32> {
    let mut sum = vec![0.0f32; dim];
    if dim == 0 {
        // `chunks_exact(0)` would panic, and there is nothing to average.
        return sum;
    }

    // `zip` stops at whichever runs out first — complete rows or mask
    // entries — and `take` caps the walk at `seq_len` positions. Stopping at
    // the last complete row is what keeps truncated tokens out of `count`.
    let tokens = hidden_states
        .chunks_exact(dim)
        .zip(attention_mask)
        .take(seq_len);

    let mut count: usize = 0;
    for (row, &mask) in tokens {
        if mask <= 0 {
            continue;
        }
        for (acc, &val) in sum.iter_mut().zip(row) {
            *acc += val;
        }
        count += 1;
    }

    if count > 0 {
        // Exact conversion for any realistic sequence length (< 2^24 tokens).
        let denom = count as f32;
        for val in &mut sum {
            *val /= denom;
        }
    }

    sum
}

/// Rescale `vec` in place to unit Euclidean (L2) length.
///
/// The slice is left unchanged when its norm is not strictly greater than
/// `f32::EPSILON`. This covers the zero vector and also a NaN norm, since the
/// comparison is false for NaN.
pub fn normalize_l2(vec: &mut [f32]) {
    let sum_of_squares = vec.iter().map(|component| component * component).sum::<f32>();
    let magnitude = sum_of_squares.sqrt();

    // Positive comparison on purpose: a NaN magnitude must fall through untouched.
    if magnitude > f32::EPSILON {
        vec.iter_mut().for_each(|component| *component /= magnitude);
    }
}

/// Return the Euclidean (L2) length of `vec`, i.e. the square root of the sum
/// of its squared components.
pub fn l2_norm(vec: &[f32]) -> f32 {
    let squares = vec.iter().map(|&component| component * component);
    f32::sqrt(squares.sum())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `actual` matches `expected` element-wise within `1e-6`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {:?} vs {:?}",
            actual,
            expected
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < 1e-6, "index {}: got {}, expected {}", i, a, e);
        }
    }

    #[test]
    fn mean_pool_basic() {
        // 2 tokens, 3 dimensions, all attended.
        let hidden = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_close(&mean_pool(&hidden, &[1, 1], 2, 3), &[2.5, 3.5, 4.5]);
    }

    #[test]
    fn mean_pool_with_padding() {
        // 3 tokens, 2 dimensions, last token is padding.
        let hidden = [1.0f32, 2.0, 3.0, 4.0, 0.0, 0.0];
        assert_close(&mean_pool(&hidden, &[1, 1, 0], 3, 2), &[2.0, 3.0]);
    }

    #[test]
    fn mean_pool_single_token() {
        assert_close(&mean_pool(&[5.0, 10.0], &[1], 1, 2), &[5.0, 10.0]);
    }

    #[test]
    fn mean_pool_all_masked() {
        let result = mean_pool(&[1.0, 2.0, 3.0, 4.0], &[0, 0], 2, 2);
        assert!(result.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn mean_pool_non_positive_mask_is_padding() {
        assert_close(&mean_pool(&[1.0, 2.0, 3.0, 4.0], &[-1, 1], 2, 2), &[3.0, 4.0]);
    }

    #[test]
    fn mean_pool_mask_shorter_than_seq_len() {
        // Missing mask entries are treated as padding.
        assert_close(&mean_pool(&[1.0, 2.0, 3.0, 4.0], &[1], 2, 2), &[1.0, 2.0]);
    }

    #[test]
    fn mean_pool_ignores_rows_beyond_seq_len() {
        let hidden = [1.0f32, 2.0, 3.0, 4.0, 100.0, 100.0];
        assert_close(&mean_pool(&hidden, &[1, 1, 1], 2, 2), &[2.0, 3.0]);
    }

    #[test]
    fn mean_pool_zero_dim_is_empty() {
        assert!(mean_pool(&[1.0, 2.0], &[1, 1], 2, 0).is_empty());
    }

    #[test]
    fn normalize_l2_basic() {
        let mut v = [3.0f32, 4.0];
        normalize_l2(&mut v);
        assert_close(&v, &[0.6, 0.8]);
        assert!((l2_norm(&v) - 1.0).abs() < 1e-5);
    }

    #[test]
    fn normalize_l2_already_normalized() {
        let mut v = [1.0f32, 0.0, 0.0];
        normalize_l2(&mut v);
        assert_close(&v, &[1.0, 0.0, 0.0]);
    }

    #[test]
    fn normalize_l2_zero_vector() {
        let mut v = [0.0f32; 3];
        normalize_l2(&mut v);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn normalize_l2_empty_slice() {
        let mut v: [f32; 0] = [];
        normalize_l2(&mut v);
        assert!(v.is_empty());
    }

    #[test]
    fn l2_norm_basic() {
        assert!((l2_norm(&[3.0, 4.0]) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn l2_norm_unit() {
        assert!((l2_norm(&[1.0, 0.0, 0.0]) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn l2_norm_empty_is_zero() {
        assert_eq!(l2_norm(&[]), 0.0);
    }
}