// cognee_embedding/utils.rs

/// Mean-pools per-token vectors into a single vector, skipping masked-out tokens.
///
/// `output_data` is read as a row-major matrix with `hidden_dim` values per token:
/// the value for token `s`, dimension `h` is `output_data[s * hidden_dim + h]`.
/// Only the first `seq_len` tokens are considered, and a token contributes only
/// when its `attention_mask` entry equals `1`.
///
/// # Arguments
///
/// * `output_data` - Flattened token vectors; positions that fall outside the
///   slice are silently skipped rather than causing a panic.
/// * `seq_len` - Number of token positions to visit.
/// * `hidden_dim` - Stride between consecutive tokens in `output_data`.
/// * `attention_mask` - One entry per token; positions past the end of the mask
///   are treated as masked out.
/// * `output_dim` - Length of the returned vector.
///
/// # Returns
///
/// A vector of length `output_dim`. The first `min(output_dim, hidden_dim)`
/// entries hold the mean over the contributing tokens; any remaining entries
/// are `0.0`. When no token contributes, every entry is `0.0`.
pub fn mean_pool(
    output_data: &[f32],
    seq_len: usize,
    hidden_dim: usize,
    attention_mask: &[i64],
    output_dim: usize,
) -> Vec<f32> {
    let mut pooled = vec![0.0f32; output_dim];
    let width = output_dim.min(hidden_dim);

    // Count exactly the tokens that are summed below, so that mask entries
    // beyond `seq_len` cannot inflate the divisor.
    let mut real_tokens = 0usize;

    let attended = attention_mask
        .iter()
        .take(seq_len)
        .enumerate()
        .filter(|&(_, &m)| m == 1);

    for (s, _) in attended {
        real_tokens += 1;

        // A row that starts past the end of the data yields an empty slice; a
        // partially available row contributes only the values that exist.
        let row = output_data
            .get(s.saturating_mul(hidden_dim)..)
            .unwrap_or_default();

        for (acc, &value) in pooled.iter_mut().take(width).zip(row) {
            *acc += value;
        }
    }

    // Clamp to 1 so a fully masked sequence yields zeros instead of NaN.
    let divisor = real_tokens.max(1) as f32;
    for val in &mut pooled {
        *val /= divisor;
    }

    pooled
}

/// Returns a copy of `vec` scaled to unit Euclidean (L2) length.
///
/// The norm is the square root of the sum of squared components. When that norm
/// is not strictly positive (an empty or all-zero input, or a norm that evaluates
/// to NaN), the input is returned unchanged, since dividing by it would not
/// produce a meaningful direction.
pub fn l2_normalize(vec: &[f32]) -> Vec<f32> {
    let norm = vec.iter().map(|x| x * x).sum::<f32>().sqrt();

    if norm > 0.0 {
        vec.iter().map(|x| x / norm).collect()
    } else {
        vec.to_vec()
    }
}

/// Computes the Euclidean (L2) norm of `vec`: the square root of the sum of its
/// squared components. An empty slice yields `0.0`.
pub fn compute_norm(vec: &[f32]) -> f32 {
    vec.iter().map(|x| x * x).sum::<f32>().sqrt()
}

use std::borrow::Cow;

/// Reports whether `s` contains anything other than whitespace.
///
/// Returns `false` for the empty string and for strings that become empty after
/// trimming leading and trailing whitespace; `true` otherwise.
pub fn is_embeddable(s: &str) -> bool {
    !s.trim().is_empty()
}

88pub fn sanitize_embedding_inputs<'a>(texts: &[&'a str]) -> Vec<Cow<'a, str>> {
98 texts
99 .iter()
100 .map(|&t| {
101 if is_embeddable(t) {
102 Cow::Borrowed(t)
103 } else {
104 Cow::Owned(".".to_string())
105 }
106 })
107 .collect()
108}
109
110pub fn handle_embedding_response(
120 original_texts: &[&str],
121 embeddings: Vec<Vec<f32>>,
122 dimensions: usize,
123) -> Vec<Vec<f32>> {
124 original_texts
125 .iter()
126 .zip(embeddings)
127 .map(|(t, v)| {
128 if is_embeddable(t) {
129 v
130 } else {
131 vec![0.0; dimensions]
132 }
133 })
134 .collect()
135}
136
/// Unit tests for the pooling, normalization, and input-sanitizing helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_normalization() {
        let vec = vec![3.0, 4.0];
        let normalized = l2_normalize(&vec);
        let norm = compute_norm(&normalized);

        assert!(
            (norm - 1.0).abs() < 0.001,
            "Expected norm ≈ 1.0, got {norm}"
        );
        // A 3-4-5 triangle normalizes to (0.6, 0.8).
        assert!((normalized[0] - 0.6).abs() < 0.001);
        assert!((normalized[1] - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_mean_pooling() {
        // Two tokens with three dimensions each: [1, 2, 3] and [4, 5, 6].
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let attention_mask = vec![1, 1];
        let pooled = mean_pool(&data, 2, 3, &attention_mask, 3);

        assert!((pooled[0] - 2.5).abs() < 0.001);
        assert!((pooled[1] - 3.5).abs() < 0.001);
        assert!((pooled[2] - 4.5).abs() < 0.001);
    }

    #[test]
    fn test_mean_pooling_with_padding() {
        // Three tokens with two dimensions each; the last one is masked out.
        let data = vec![1.0, 2.0, 3.0, 4.0, 0.0, 0.0];
        let attention_mask = vec![1, 1, 0];
        let pooled = mean_pool(&data, 3, 2, &attention_mask, 2);

        assert!((pooled[0] - 2.0).abs() < 0.001);
        assert!((pooled[1] - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_is_embeddable_non_empty() {
        assert!(is_embeddable("hello world"));
        assert!(is_embeddable("  some text  "));
        assert!(is_embeddable("."));
    }

    #[test]
    fn test_is_embeddable_empty_or_whitespace() {
        assert!(!is_embeddable(""));
        assert!(!is_embeddable("   "));
        assert!(!is_embeddable("\t\n"));
        assert!(!is_embeddable("\r\n"));
    }

    #[test]
    fn test_sanitize_embedding_inputs_preserves_valid() {
        let texts = ["hello", "world"];
        let sanitized = sanitize_embedding_inputs(&texts);
        assert_eq!(sanitized.len(), 2);
        assert_eq!(sanitized[0].as_ref(), "hello");
        assert_eq!(sanitized[1].as_ref(), "world");
        // Valid inputs must be passed through without allocating.
        assert!(matches!(sanitized[0], Cow::Borrowed(_)));
        assert!(matches!(sanitized[1], Cow::Borrowed(_)));
    }

    #[test]
    fn test_sanitize_embedding_inputs_replaces_empty() {
        let texts = ["", "   ", "valid", "\t"];
        let sanitized = sanitize_embedding_inputs(&texts);
        assert_eq!(sanitized.len(), 4);
        assert_eq!(sanitized[0].as_ref(), ".");
        assert_eq!(sanitized[1].as_ref(), ".");
        assert_eq!(sanitized[2].as_ref(), "valid");
        assert_eq!(sanitized[3].as_ref(), ".");
        assert!(matches!(sanitized[0], Cow::Owned(_)));
        assert!(matches!(sanitized[2], Cow::Borrowed(_)));
    }

    #[test]
    fn test_handle_embedding_response_zeros_invalid() {
        let original = ["valid", ""];
        let embeddings = vec![vec![1.0, 2.0, 3.0], vec![0.5, 0.5, 0.5]];
        let result = handle_embedding_response(&original, embeddings, 3);
        assert_eq!(result[0], vec![1.0, 2.0, 3.0]);
        assert_eq!(result[1], vec![0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_handle_embedding_response_all_valid() {
        let original = ["a", "b"];
        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let result = handle_embedding_response(&original, embeddings.clone(), 2);
        assert_eq!(result, embeddings);
    }

    #[test]
    fn test_handle_embedding_response_all_invalid() {
        let original = ["", "   "];
        let embeddings = vec![vec![9.9, 9.9], vec![8.8, 8.8]];
        let result = handle_embedding_response(&original, embeddings, 2);
        assert_eq!(result[0], vec![0.0, 0.0]);
        assert_eq!(result[1], vec![0.0, 0.0]);
    }
}