// cognee_embedding/utils.rs

/// Mean pooling over the sequence dimension, respecting the attention mask.
///
/// Ported from examples/embeddings.rs create_embedding() function.
/// `output_data` is read as a row-major `seq_len x hidden_dim` matrix: element
/// `(s, h)` lives at index `s * hidden_dim + h`. Rows whose mask value is `1`
/// are summed, and the sum is divided by the number of such rows.
///
/// # Arguments
/// * `output_data` - Flattened ONNX output tensor (presumably a single batch
///   item — TODO confirm with the caller)
/// * `seq_len` - Sequence length; only the first `seq_len` mask entries are considered
/// * `hidden_dim` - Hidden dimension size (row stride of `output_data`)
/// * `attention_mask` - Mask indicating real vs padded tokens (1 = real, 0 = padding)
/// * `output_dim` - Length of the returned vector. Only the first
///   `min(output_dim, hidden_dim)` columns are pooled; any extra entries stay `0.0`.
///
/// # Returns
/// * Pooled embedding vector of length `output_dim`, averaged over the real
///   tokens that were actually summed. If there are no real tokens, the result
///   is all zeros.
///   Elements that would lie past the end of `output_data` are skipped rather
///   than causing a panic.
pub fn mean_pool(
    output_data: &[f32],
    seq_len: usize,
    hidden_dim: usize,
    attention_mask: &[i64],
    output_dim: usize,
) -> Vec<f32> {
    let mut pooled = vec![0.0f32; output_dim];
    let width = output_dim.min(hidden_dim);
    let mut real_tokens = 0usize;

    // Only the first `seq_len` mask entries correspond to rows of `output_data`.
    // Counting mask entries beyond that would inflate the denominator below.
    for (s, &mask) in attention_mask.iter().take(seq_len).enumerate() {
        if mask != 1 {
            continue;
        }
        real_tokens += 1;

        // Tolerate a truncated tensor: a row starting past the end is empty, and
        // `zip` stops at whichever of the row or `pooled[..width]` ends first.
        let row = output_data.get(s * hidden_dim..).unwrap_or(&[]);
        for (acc, &x) in pooled[..width].iter_mut().zip(row) {
            *acc += x;
        }
    }

    // `max(1)` avoids dividing by zero when every token is padding (the sum is all zeros).
    let divisor = real_tokens.max(1) as f32;
    for val in &mut pooled {
        *val /= divisor;
    }

    pooled
}

/// Scale `vec` so that its Euclidean (L2) length becomes 1.
///
/// Ported from examples/embeddings.rs l2_normalize() function.
///
/// # Arguments
/// * `vec` - Input vector (not modified)
///
/// # Returns
/// * A new vector with every component divided by the L2 norm of `vec`, so its
///   norm is ≈ 1.0. If that norm is not strictly positive (empty or all-zero
///   input, or a NaN norm), an unmodified copy of `vec` is returned instead.
pub fn l2_normalize(vec: &[f32]) -> Vec<f32> {
    let magnitude = vec.iter().map(|&x| x * x).sum::<f32>().sqrt();

    // Dividing by a zero magnitude would produce NaN/inf. `magnitude > 0.0` is
    // also false for NaN, so such inputs fall through to the plain copy.
    let can_scale = magnitude > 0.0;
    if !can_scale {
        return vec.to_vec();
    }

    vec.iter().map(|&x| x / magnitude).collect()
}

/// Euclidean (L2) length of `vec`: the square root of the sum of its squared components.
///
/// # Arguments
/// * `vec` - Input vector
///
/// # Returns
/// * L2 norm (magnitude) of the vector; zero for an empty or all-zero slice
pub fn compute_norm(vec: &[f32]) -> f32 {
    let sum_of_squares: f32 = vec.iter().map(|&c| c * c).sum();
    sum_of_squares.sqrt()
}

use std::borrow::Cow;

/// Returns `true` if `s` contains at least one non-whitespace character.
///
/// "Whitespace" is whatever [`str::trim`] strips, i.e. the Unicode `White_Space`
/// property. That includes ASCII spaces, tabs and newlines, but also characters
/// such as U+00A0 (no-break space) and U+3000 (ideographic space), so strings
/// made only of those are rejected too.
///
/// Used to detect inputs that would produce degenerate zero/NaN embeddings
/// when sent to an embedding API.
pub fn is_embeddable(s: &str) -> bool {
    !s.trim().is_empty()
}

88/// Replace empty/whitespace-only strings with `"."` to prevent API errors.
89///
90/// Returns a `Vec<Cow<str>>` of the same length as `texts`. Non-empty strings
91/// are returned as `Cow::Borrowed` (zero-copy); empty/whitespace-only strings
92/// are replaced with `Cow::Owned(".")`.
93///
94/// After receiving the API response, pair this with
95/// [`handle_embedding_response`] to zero out vectors for slots that were
96/// originally invalid.
97pub fn sanitize_embedding_inputs<'a>(texts: &[&'a str]) -> Vec<Cow<'a, str>> {
98    texts
99        .iter()
100        .map(|&t| {
101            if is_embeddable(t) {
102                Cow::Borrowed(t)
103            } else {
104                Cow::Owned(".".to_string())
105            }
106        })
107        .collect()
108}
109
110/// Replace embeddings for originally-invalid inputs with zero vectors.
111///
112/// Iterates `original_texts` in parallel with `embeddings`. For each slot
113/// where `original_texts[i]` is empty or whitespace-only (as determined by
114/// [`is_embeddable`]), the corresponding embedding is replaced with a zero
115/// vector of length `dimensions`.
116///
117/// This must be called with the *original* (unsanitized) texts, not the
118/// sanitized ones returned by [`sanitize_embedding_inputs`].
119pub fn handle_embedding_response(
120    original_texts: &[&str],
121    embeddings: Vec<Vec<f32>>,
122    dimensions: usize,
123) -> Vec<Vec<f32>> {
124    original_texts
125        .iter()
126        .zip(embeddings)
127        .map(|(t, v)| {
128            if is_embeddable(t) {
129                v
130            } else {
131                vec![0.0; dimensions]
132            }
133        })
134        .collect()
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_normalization() {
        let vec = vec![3.0, 4.0];
        let normalized = l2_normalize(&vec);
        let norm = compute_norm(&normalized);

        assert!(
            (norm - 1.0).abs() < 0.001,
            "Expected norm ≈ 1.0, got {norm}"
        );
        assert!((normalized[0] - 0.6).abs() < 0.001);
        assert!((normalized[1] - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_l2_normalize_zero_vector_is_unchanged() {
        // A zero vector has no direction; it must come back as-is, not as NaNs.
        assert_eq!(l2_normalize(&[0.0, 0.0, 0.0]), vec![0.0, 0.0, 0.0]);
        assert!(l2_normalize(&[]).is_empty());
    }

    #[test]
    fn test_compute_norm() {
        assert!((compute_norm(&[3.0, 4.0]) - 5.0).abs() < 0.001);
        assert_eq!(compute_norm(&[]), 0.0);
    }

    #[test]
    fn test_mean_pooling() {
        // Simple 2x3 tensor: [[1,2,3], [4,5,6]]
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let attention_mask = vec![1, 1]; // Both tokens are real

        let pooled = mean_pool(&data, 2, 3, &attention_mask, 3);

        // Mean of [[1,2,3], [4,5,6]] = [2.5, 3.5, 4.5]
        assert!((pooled[0] - 2.5).abs() < 0.001);
        assert!((pooled[1] - 3.5).abs() < 0.001);
        assert!((pooled[2] - 4.5).abs() < 0.001);
    }

    #[test]
    fn test_mean_pooling_with_padding() {
        // 3x2 tensor with one padded token
        let data = vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0];
        let attention_mask = vec![1, 1, 0]; // Third token is padding

        let pooled = mean_pool(&data, 3, 2, &attention_mask, 2);

        // Mean of only real tokens [[1,2], [3,4]] = [2.0, 3.0]
        assert!((pooled[0] - 2.0).abs() < 0.001);
        assert!((pooled[1] - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_mean_pooling_all_padding_returns_zeros() {
        // With no real tokens there is nothing to average; expect zeros, not NaN.
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let pooled = mean_pool(&data, 2, 2, &[0, 0], 2);
        assert_eq!(pooled, vec![0.0, 0.0]);
    }

    #[test]
    fn test_mean_pooling_output_dim_larger_than_hidden_dim() {
        // Columns beyond hidden_dim are left as zero padding.
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let pooled = mean_pool(&data, 2, 2, &[1, 1], 3);
        assert_eq!(pooled.len(), 3);
        assert!((pooled[0] - 2.0).abs() < 0.001);
        assert!((pooled[1] - 3.0).abs() < 0.001);
        assert_eq!(pooled[2], 0.0);
    }

    #[test]
    fn test_is_embeddable_non_empty() {
        assert!(is_embeddable("hello world"));
        assert!(is_embeddable("  some text  "));
        assert!(is_embeddable("."));
    }

    #[test]
    fn test_is_embeddable_empty_or_whitespace() {
        assert!(!is_embeddable(""));
        assert!(!is_embeddable("   "));
        assert!(!is_embeddable("\t\n"));
        assert!(!is_embeddable("\r\n"));
    }

    #[test]
    fn test_is_embeddable_unicode_whitespace() {
        // `str::trim` strips Unicode White_Space, not only ASCII.
        assert!(!is_embeddable("\u{3000}"));
        assert!(!is_embeddable("\u{a0}\u{2003}"));
    }

    #[test]
    fn test_sanitize_embedding_inputs_preserves_valid() {
        let texts = ["hello", "world"];
        let sanitized = sanitize_embedding_inputs(&texts);
        assert_eq!(sanitized.len(), 2);
        // Valid strings should be borrowed (no allocation)
        assert_eq!(sanitized[0].as_ref(), "hello");
        assert_eq!(sanitized[1].as_ref(), "world");
        assert!(matches!(sanitized[0], Cow::Borrowed(_)));
        assert!(matches!(sanitized[1], Cow::Borrowed(_)));
    }

    #[test]
    fn test_sanitize_embedding_inputs_replaces_empty() {
        let texts = ["", "   ", "valid", "\t"];
        let sanitized = sanitize_embedding_inputs(&texts);
        assert_eq!(sanitized.len(), 4);
        assert_eq!(sanitized[0].as_ref(), ".");
        assert_eq!(sanitized[1].as_ref(), ".");
        assert_eq!(sanitized[2].as_ref(), "valid");
        assert_eq!(sanitized[3].as_ref(), ".");
        assert!(matches!(sanitized[0], Cow::Owned(_)));
        assert!(matches!(sanitized[2], Cow::Borrowed(_)));
    }

    #[test]
    fn test_handle_embedding_response_zeros_invalid() {
        let original = ["valid", ""];
        let embeddings = vec![vec![1.0, 2.0, 3.0], vec![0.5, 0.5, 0.5]];
        let result = handle_embedding_response(&original, embeddings, 3);
        // Valid slot: unchanged
        assert_eq!(result[0], vec![1.0, 2.0, 3.0]);
        // Invalid slot: zeroed out
        assert_eq!(result[1], vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_handle_embedding_response_all_valid() {
        let original = ["a", "b"];
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = handle_embedding_response(&original, embeddings.clone(), 2);
        assert_eq!(result, embeddings);
    }

    #[test]
    fn test_handle_embedding_response_all_invalid() {
        let original = ["", "  "];
        let embeddings = vec![vec![9.9, 9.9], vec![8.8, 8.8]];
        let result = handle_embedding_response(&original, embeddings, 2);
        assert_eq!(result[0], vec![0.0, 0.0]);
        assert_eq!(result[1], vec![0.0, 0.0]);
    }

    #[test]
    fn test_sanitize_and_handle_round_trip() {
        // Sanitize, fake an API that embeds every (sanitized) slot, then zero the
        // slots that were originally empty.
        let original = ["hello", " "];
        let sanitized = sanitize_embedding_inputs(&original);
        assert_eq!(sanitized.len(), original.len());
        let fake_api: Vec<Vec<f32>> = sanitized.iter().map(|_| vec![1.0, 1.0]).collect();
        let result = handle_embedding_response(&original, fake_api, 2);
        assert_eq!(result, vec![vec![1.0, 1.0], vec![0.0, 0.0]]);
    }
}