aurora_semantic/embeddings/pooling.rs

//! Pooling strategies for aggregating token embeddings.
//!
//! These utilities are available for custom embedder implementations.

#![allow(dead_code)]

use serde::{Deserialize, Serialize};

/// Strategy for pooling token embeddings into a single vector.
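///
/// A minimal usage sketch (the import path is assumed here, so the example
/// is marked `ignore`):
///
/// ```ignore
/// use aurora_semantic::embeddings::pooling::PoolingStrategy;
///
/// let strategy = PoolingStrategy::default();
/// assert_eq!(strategy, PoolingStrategy::Mean);
/// assert!(strategy.recommended_for_similarity());
/// ```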
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PoolingStrategy {
    /// Use the [CLS] token embedding.
    Cls,
    /// Mean of all token embeddings (weighted by attention mask).
    Mean,
    /// Max pooling across all tokens.
    Max,
    /// Use the last token embedding (required for Jina Code 1.5B).
    /// This is the correct pooling for decoder-based models like Qwen2.5-Coder.
    LastToken,
}

impl Default for PoolingStrategy {
    fn default() -> Self {
        Self::Mean
    }
}

impl PoolingStrategy {
    /// Get a description of the pooling strategy.
    pub fn description(&self) -> &'static str {
        match self {
            Self::Cls => "Uses the [CLS] token embedding (first token)",
            Self::Mean => "Averages all token embeddings, weighted by attention mask",
            Self::Max => "Takes the maximum value across all tokens for each dimension",
            Self::LastToken => {
                "Uses the last valid token embedding (for decoder models like Jina Code 1.5B)"
            }
        }
    }

    /// Check if this strategy is recommended for sentence similarity.
    pub fn recommended_for_similarity(&self) -> bool {
        matches!(self, Self::Mean | Self::LastToken)
    }
}

/// Pool a slice of token embeddings into a single vector on the CPU.
///
/// The `Mean` branch averages every vector it is given, so callers should
/// pass only valid (non-padding) token embeddings.
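///
/// A minimal sketch (the import path is assumed here, so the example is
/// marked `ignore`):
///
/// ```ignore
/// use aurora_semantic::embeddings::pooling::{pool_vectors, PoolingStrategy};
///
/// let tokens = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
/// assert_eq!(pool_vectors(&tokens, PoolingStrategy::Mean), vec![2.0, 3.0]);
/// ```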
pub fn pool_vectors(embeddings: &[Vec<f32>], strategy: PoolingStrategy) -> Vec<f32> {
    if embeddings.is_empty() {
        return Vec::new();
    }

    let dim = embeddings[0].len();
    debug_assert!(
        embeddings.iter().all(|e| e.len() == dim),
        "all token embeddings must have the same dimension"
    );

    match strategy {
        PoolingStrategy::Cls => {
            // First vector ([CLS] is the first token).
            embeddings[0].clone()
        }
        PoolingStrategy::Mean => {
            // Element-wise average of all vectors.
            let mut result = vec![0.0f32; dim];
            for emb in embeddings {
                for (i, &v) in emb.iter().enumerate() {
                    result[i] += v;
                }
            }
            let n = embeddings.len() as f32;
            for v in &mut result {
                *v /= n;
            }
            result
        }
        PoolingStrategy::Max => {
            // Element-wise maximum across all vectors.
            let mut result = vec![f32::NEG_INFINITY; dim];
            for emb in embeddings {
                for (i, &v) in emb.iter().enumerate() {
                    if v > result[i] {
                        result[i] = v;
                    }
                }
            }
            result
        }
        PoolingStrategy::LastToken => {
            // Last token embedding (for decoder-based models). The early
            // return above guarantees `last()` is `Some`.
            embeddings.last().cloned().unwrap_or_default()
        }
    }
}

/// Normalize a vector to unit length.
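///
/// A minimal sketch (the import path is assumed here, so the example is
/// marked `ignore`):
///
/// ```ignore
/// use aurora_semantic::embeddings::pooling::normalize_vector;
///
/// let mut v = vec![3.0, 4.0];
/// normalize_vector(&mut v);
/// assert_eq!(v, vec![0.6, 0.8]); // 3-4-5 triangle scaled to unit length
/// ```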
pub fn normalize_vector(v: &mut [f32]) {
    let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for x in v.iter_mut() {
            *x /= norm;
        }
    }
}

/// Compute cosine similarity between two vectors.
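///
/// A minimal sketch (the import path is assumed here, so the example is
/// marked `ignore`):
///
/// ```ignore
/// use aurora_semantic::embeddings::pooling::cosine_similarity;
///
/// let a = [1.0, 0.0];
/// let b = [0.0, 1.0];
/// assert_eq!(cosine_similarity(&a, &a), 1.0); // same direction
/// assert_eq!(cosine_similarity(&a, &b), 0.0); // orthogonal
/// ```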
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm_a > 0.0 && norm_b > 0.0 {
        dot / (norm_a * norm_b)
    } else {
        0.0
    }
}

/// Compute the dot product of two vectors.
/// Assumes the vectors are normalized.
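///
/// For unit-length vectors this equals [`cosine_similarity`] while skipping
/// the two norm computations. A minimal sketch (the import path is assumed
/// here, so the example is marked `ignore`):
///
/// ```ignore
/// use aurora_semantic::embeddings::pooling::{dot_product, normalize_vector};
///
/// let mut a = vec![3.0, 4.0];
/// let mut b = vec![4.0, 3.0];
/// normalize_vector(&mut a);
/// normalize_vector(&mut b);
/// assert!((dot_product(&a, &b) - 0.96).abs() < 1e-6); // 24 / 25
/// ```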
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Compute the Euclidean distance between two vectors.
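///
/// A minimal sketch (the import path is assumed here, so the example is
/// marked `ignore`):
///
/// ```ignore
/// use aurora_semantic::embeddings::pooling::euclidean_distance;
///
/// let a = [0.0, 0.0];
/// let b = [3.0, 4.0];
/// assert_eq!(euclidean_distance(&a, &b), 5.0); // 3-4-5 triangle
/// ```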
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean_pooling() {
        let embeddings = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];

        let result = pool_vectors(&embeddings, PoolingStrategy::Mean);
        assert_eq!(result, vec![4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_max_pooling() {
        let embeddings = vec![
            vec![1.0, 5.0, 3.0],
            vec![4.0, 2.0, 6.0],
            vec![7.0, 8.0, 1.0],
        ];

        let result = pool_vectors(&embeddings, PoolingStrategy::Max);
        assert_eq!(result, vec![7.0, 8.0, 6.0]);
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);
    }

    #[test]
    fn test_normalize() {
        let mut v = vec![3.0, 4.0];
        normalize_vector(&mut v);

        let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }
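
    // Additional, illustrative coverage for the remaining strategies and
    // metrics, following the behavior of `pool_vectors` above (Cls = first
    // vector, LastToken = last vector).
    #[test]
    fn test_cls_and_last_token_pooling() {
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        assert_eq!(
            pool_vectors(&embeddings, PoolingStrategy::Cls),
            vec![1.0, 2.0]
        );
        assert_eq!(
            pool_vectors(&embeddings, PoolingStrategy::LastToken),
            vec![5.0, 6.0]
        );
    }

    #[test]
    fn test_dot_product_matches_cosine_for_unit_vectors() {
        let mut a = vec![3.0, 4.0];
        let mut b = vec![4.0, 3.0];
        normalize_vector(&mut a);
        normalize_vector(&mut b);

        // For unit-length vectors the two metrics agree.
        assert!((dot_product(&a, &b) - cosine_similarity(&a, &b)).abs() < 0.001);
    }

    #[test]
    fn test_euclidean_distance() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        assert!((euclidean_distance(&a, &b) - 5.0).abs() < 0.001);
    }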
}