// aurora_semantic/embeddings/pooling.rs
#![allow(dead_code)]

use serde::{Deserialize, Serialize};

/// How a sequence of per-token embedding vectors is collapsed into a single
/// fixed-size vector (see `pool_vectors`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PoolingStrategy {
    /// Use only the first ([CLS]) token's embedding.
    Cls,
    /// Element-wise arithmetic mean over all token embeddings.
    Mean,
    /// Element-wise maximum over all token embeddings.
    Max,
    /// Use only the last token's embedding (decoder-style models).
    LastToken,
}
22
23impl Default for PoolingStrategy {
24 fn default() -> Self {
25 Self::Mean
26 }
27}
28
29impl PoolingStrategy {
30 pub fn description(&self) -> &'static str {
32 match self {
33 Self::Cls => "Uses the [CLS] token embedding (first token)",
34 Self::Mean => "Averages all token embeddings, weighted by attention mask",
35 Self::Max => "Takes the maximum value across all tokens for each dimension",
36 Self::LastToken => "Uses the last valid token embedding (for decoder models like Jina Code 1.5B)",
37 }
38 }
39
40 pub fn recommended_for_similarity(&self) -> bool {
42 matches!(self, Self::Mean | Self::LastToken)
43 }
44}
45
46pub fn pool_vectors(embeddings: &[Vec<f32>], strategy: PoolingStrategy) -> Vec<f32> {
48 if embeddings.is_empty() {
49 return Vec::new();
50 }
51
52 let dim = embeddings[0].len();
53
54 match strategy {
55 PoolingStrategy::Cls => {
56 embeddings[0].clone()
58 }
59 PoolingStrategy::Mean => {
60 let mut result = vec![0.0f32; dim];
62 for emb in embeddings {
63 for (i, &v) in emb.iter().enumerate() {
64 result[i] += v;
65 }
66 }
67 let n = embeddings.len() as f32;
68 for v in &mut result {
69 *v /= n;
70 }
71 result
72 }
73 PoolingStrategy::Max => {
74 let mut result = vec![f32::NEG_INFINITY; dim];
76 for emb in embeddings {
77 for (i, &v) in emb.iter().enumerate() {
78 if v > result[i] {
79 result[i] = v;
80 }
81 }
82 }
83 result
84 }
85 PoolingStrategy::LastToken => {
86 embeddings.last().cloned().unwrap_or_default()
88 }
89 }
90}
91
/// Scale `v` in place to unit Euclidean length.
///
/// A zero-norm vector is left untouched (no division by zero).
pub fn normalize_vector(v: &mut [f32]) {
    let magnitude = v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();
    if magnitude > 0.0 {
        v.iter_mut().for_each(|x| *x /= magnitude);
    }
}
101
/// Cosine of the angle between `a` and `b`; returns 0.0 if either vector has
/// zero norm. Lengths must match (checked only in debug builds).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    // Single pass: dot product and both squared norms together.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}
116
/// Inner (dot) product of `a` and `b`. Lengths must match (debug-checked).
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y)
}
123
/// Euclidean (L2) distance between `a` and `b`. Lengths must match
/// (debug-checked).
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum();
    debug_assert_eq!(a.len(), b.len());
    squared.sqrt()
}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_pooling() {
        let embeddings = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];

        let result = pool_vectors(&embeddings, PoolingStrategy::Mean);
        assert_eq!(result, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_max_pooling() {
        let embeddings = vec![
            vec![1.0, 5.0, 3.0],
            vec![4.0, 2.0, 6.0],
            vec![7.0, 8.0, 1.0],
        ];

        let result = pool_vectors(&embeddings, PoolingStrategy::Max);
        assert_eq!(result, vec![7.0, 8.0, 6.0]);
    }

    // Previously untested: CLS pooling returns the first token's embedding.
    #[test]
    fn test_cls_pooling() {
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = pool_vectors(&embeddings, PoolingStrategy::Cls);
        assert_eq!(result, vec![1.0, 2.0]);
    }

    // Previously untested: last-token pooling returns the final embedding.
    #[test]
    fn test_last_token_pooling() {
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = pool_vectors(&embeddings, PoolingStrategy::LastToken);
        assert_eq!(result, vec![3.0, 4.0]);
    }

    // Previously untested: an empty input yields an empty vector for every strategy.
    #[test]
    fn test_empty_input() {
        let empty: Vec<Vec<f32>> = Vec::new();
        for strategy in [
            PoolingStrategy::Cls,
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
            PoolingStrategy::LastToken,
        ] {
            assert!(pool_vectors(&empty, strategy).is_empty());
        }
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);
    }

    #[test]
    fn test_normalize() {
        let mut v = vec![3.0, 4.0];
        normalize_vector(&mut v);

        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    // Previously untested helpers.
    #[test]
    fn test_dot_product() {
        assert_eq!(dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0);
    }

    #[test]
    fn test_euclidean_distance() {
        assert!((euclidean_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 0.001);
    }

    #[test]
    fn test_default_strategy() {
        assert_eq!(PoolingStrategy::default(), PoolingStrategy::Mean);
    }
}