// lean_ctx/core/embeddings/pooling.rs
/// Averages the hidden-state rows of every attended token into one vector.
///
/// `hidden_states` is read as a flat buffer in which position `pos` occupies
/// the `dim` values starting at index `pos * dim`. Only the first `seq_len`
/// positions are examined, and a position contributes when its
/// `attention_mask` entry is greater than zero. Positions without a mask
/// entry are treated as masked out.
///
/// Values that would lie past the end of `hidden_states` are skipped, but the
/// position still counts toward the divisor.
///
/// Returns a vector of length `dim`. If no position is attended, every
/// component is zero.
pub fn mean_pool(
    hidden_states: &[f32],
    attention_mask: &[i32],
    seq_len: usize,
    dim: usize,
) -> Vec<f32> {
    let mut pooled = vec![0.0f32; dim];
    let mut attended = 0.0f32;

    // Indices of the positions (within the first `seq_len`) whose mask is set.
    let active_positions = attention_mask
        .iter()
        .take(seq_len)
        .enumerate()
        .filter(|&(_, &flag)| flag > 0)
        .map(|(pos, _)| pos);

    for pos in active_positions {
        // Whatever part of this position's row is actually present in the buffer;
        // `zip` below caps it at `dim` values.
        let row = hidden_states.get(pos * dim..).unwrap_or_default();
        for (acc, &value) in pooled.iter_mut().zip(row) {
            *acc += value;
        }
        attended += 1.0;
    }

    // Leave the zero vector untouched when nothing was attended.
    if attended > 0.0 {
        pooled.iter_mut().for_each(|acc| *acc /= attended);
    }

    pooled
}
39
/// Rescales `vec` in place so that its Euclidean (L2) norm becomes 1.
///
/// The slice is left unchanged when its norm is not greater than
/// `f32::EPSILON`, which avoids dividing by a zero or near-zero length.
pub fn normalize_l2(vec: &mut [f32]) {
    let squared_sum: f32 = vec.iter().map(|&component| component * component).sum();
    let magnitude = squared_sum.sqrt();

    if magnitude > f32::EPSILON {
        vec.iter_mut()
            .for_each(|component| *component /= magnitude);
    }
}
49
/// Returns the Euclidean (L2) norm of `vec`: the square root of the sum of its
/// squared components.
pub fn l2_norm(vec: &[f32]) -> f32 {
    let squared_sum: f32 = vec.iter().map(|&component| component * component).sum();
    squared_sum.sqrt()
}
54
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing individual components.
    const TOLERANCE: f32 = 1e-6;

    /// Checks that the leading components of `actual` match `expected` within
    /// `TOLERANCE`. Components of `actual` beyond `expected.len()` are not
    /// inspected; a shorter `actual` panics on the out-of-range index.
    fn assert_prefix_close(actual: &[f32], expected: &[f32]) {
        for (index, &want) in expected.iter().enumerate() {
            let got = actual[index];
            assert!(
                (got - want).abs() < TOLERANCE,
                "component {}: got {}, expected {}",
                index,
                got,
                want
            );
        }
    }

    #[test]
    fn mean_pool_basic() {
        // Two attended tokens with three dimensions each.
        let hidden = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mask = [1, 1];
        let pooled = mean_pool(&hidden, &mask, 2, 3);
        assert_eq!(pooled.len(), 3);
        assert_prefix_close(&pooled, &[2.5, 3.5, 4.5]);
    }

    #[test]
    fn mean_pool_with_padding() {
        // Three tokens of two dimensions; the last one is masked out.
        let hidden = [1.0, 2.0, 3.0, 4.0, 0.0, 0.0];
        let mask = [1, 1, 0];
        let pooled = mean_pool(&hidden, &mask, 3, 2);
        assert_prefix_close(&pooled, &[2.0, 3.0]);
    }

    #[test]
    fn mean_pool_single_token() {
        let hidden = [5.0, 10.0];
        let mask = [1];
        let pooled = mean_pool(&hidden, &mask, 1, 2);
        assert_prefix_close(&pooled, &[5.0, 10.0]);
    }

    #[test]
    fn mean_pool_all_masked() {
        let hidden = [1.0, 2.0, 3.0, 4.0];
        let mask = [0, 0];
        let pooled = mean_pool(&hidden, &mask, 2, 2);
        for &component in &pooled {
            assert!(component == 0.0);
        }
    }

    #[test]
    fn normalize_l2_basic() {
        let mut values = [3.0, 4.0];
        normalize_l2(&mut values);
        assert_prefix_close(&values, &[0.6, 0.8]);

        // The normalized vector must have (approximately) unit length.
        let squared_sum: f32 = values.iter().map(|x| x * x).sum();
        let norm = squared_sum.sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn normalize_l2_already_normalized() {
        let mut values = [1.0, 0.0, 0.0];
        normalize_l2(&mut values);
        assert_prefix_close(&values, &[1.0]);
    }

    #[test]
    fn normalize_l2_zero_vector() {
        let mut values = [0.0, 0.0, 0.0];
        normalize_l2(&mut values);
        for &component in &values {
            assert!(component == 0.0);
        }
    }

    #[test]
    fn l2_norm_basic() {
        let norm = l2_norm(&[3.0, 4.0]);
        assert!((norm - 5.0).abs() < TOLERANCE);
    }

    #[test]
    fn l2_norm_unit() {
        let norm = l2_norm(&[1.0, 0.0, 0.0]);
        assert!((norm - 1.0).abs() < TOLERANCE);
    }
}