1use std::collections::HashMap;
12
/// Strategy used to collapse a sequence of token embeddings into one vector.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PoolingStrategy {
    /// Element-wise arithmetic mean over all tokens.
    Mean,
    /// Element-wise maximum over all tokens.
    Max,
    /// The first token's embedding, returned unchanged.
    Cls,
    /// Weighted sum of the tokens using caller-supplied attention weights.
    AttentionWeighted,
}
29
30#[derive(Debug, Clone)]
32pub struct AggregatorConfig {
33 pub default_strategy: PoolingStrategy,
35 pub normalize_output: bool,
37 pub eps: f32,
39}
40
41impl Default for AggregatorConfig {
42 fn default() -> Self {
43 Self {
44 default_strategy: PoolingStrategy::Mean,
45 normalize_output: false,
46 eps: 1e-12,
47 }
48 }
49}
50
51#[derive(Debug, Clone)]
53pub struct AggregatedEmbedding {
54 pub vector: Vec<f32>,
56 pub strategy: PoolingStrategy,
58 pub token_count: usize,
60}
61
/// Embeddings produced at each level of a sentence → paragraph → document pass.
#[derive(Clone, Debug)]
pub struct HierarchicalResult {
    /// One pooled vector per sentence that could be pooled.
    pub sentence_embeddings: Vec<Vec<f32>>,
    /// One vector per paragraph, each the mean of its sentence vectors.
    pub paragraph_embeddings: Vec<Vec<f32>>,
    /// Mean of the paragraph vectors; empty when there is nothing to average.
    pub document_embedding: Vec<f32>,
}
72
73#[derive(Debug, Clone)]
75pub struct BatchResult {
76 pub embeddings: Vec<AggregatedEmbedding>,
78 pub sequence_count: usize,
80}
81
82pub struct EmbeddingAggregator {
88 config: AggregatorConfig,
89 total_aggregations: u64,
90}
91
92impl EmbeddingAggregator {
93 pub fn new(config: AggregatorConfig) -> Self {
95 Self {
96 config,
97 total_aggregations: 0,
98 }
99 }
100
101 pub fn aggregate(&mut self, tokens: &[Vec<f32>]) -> Option<AggregatedEmbedding> {
106 self.aggregate_with(tokens, self.config.default_strategy, None)
107 }
108
109 pub fn aggregate_with(
114 &mut self,
115 tokens: &[Vec<f32>],
116 strategy: PoolingStrategy,
117 attention_weights: Option<&[f32]>,
118 ) -> Option<AggregatedEmbedding> {
119 if tokens.is_empty() {
120 return None;
121 }
122 let dim = tokens[0].len();
123 if dim == 0 {
124 return None;
125 }
126
127 let raw = match strategy {
128 PoolingStrategy::Mean => mean_pool(tokens, dim),
129 PoolingStrategy::Max => max_pool(tokens, dim),
130 PoolingStrategy::Cls => cls_pool(tokens),
131 PoolingStrategy::AttentionWeighted => {
132 attention_pool(tokens, attention_weights, dim, self.config.eps)
133 }
134 };
135
136 let vector = if self.config.normalize_output {
137 l2_normalize(&raw, self.config.eps)
138 } else {
139 raw
140 };
141
142 self.total_aggregations += 1;
143
144 Some(AggregatedEmbedding {
145 vector,
146 strategy,
147 token_count: tokens.len(),
148 })
149 }
150
151 pub fn aggregate_batch(&mut self, batch: &[Vec<Vec<f32>>]) -> BatchResult {
153 self.aggregate_batch_with(batch, self.config.default_strategy)
154 }
155
156 pub fn aggregate_batch_with(
158 &mut self,
159 batch: &[Vec<Vec<f32>>],
160 strategy: PoolingStrategy,
161 ) -> BatchResult {
162 let embeddings: Vec<AggregatedEmbedding> = batch
163 .iter()
164 .filter_map(|tokens| self.aggregate_with(tokens, strategy, None))
165 .collect();
166 let sequence_count = embeddings.len();
167 BatchResult {
168 embeddings,
169 sequence_count,
170 }
171 }
172
173 pub fn hierarchical_aggregate(
181 &mut self,
182 sentences: &[Vec<Vec<f32>>],
183 paragraph_boundaries: &[usize],
184 ) -> Option<HierarchicalResult> {
185 if sentences.is_empty() {
186 return None;
187 }
188
189 let sentence_embeddings: Vec<Vec<f32>> = sentences
191 .iter()
192 .filter_map(|tokens| {
193 self.aggregate_with(tokens, PoolingStrategy::Mean, None)
194 .map(|agg| agg.vector)
195 })
196 .collect();
197
198 if sentence_embeddings.is_empty() {
199 return None;
200 }
201
202 let paragraph_embeddings =
204 aggregate_by_boundaries(&sentence_embeddings, paragraph_boundaries, self.config.eps);
205
206 let dim = paragraph_embeddings.first().map(|v| v.len()).unwrap_or(0);
208 let document_embedding = if paragraph_embeddings.is_empty() || dim == 0 {
209 vec![]
210 } else {
211 mean_pool_refs(¶graph_embeddings, dim)
212 };
213
214 Some(HierarchicalResult {
215 sentence_embeddings,
216 paragraph_embeddings,
217 document_embedding,
218 })
219 }
220
221 pub fn compare_strategies(
223 &mut self,
224 tokens: &[Vec<f32>],
225 strategy_a: PoolingStrategy,
226 strategy_b: PoolingStrategy,
227 ) -> (Option<AggregatedEmbedding>, Option<AggregatedEmbedding>) {
228 let a = self.aggregate_with(tokens, strategy_a, None);
229 let b = self.aggregate_with(tokens, strategy_b, None);
230 (a, b)
231 }
232
233 pub fn total_aggregations(&self) -> u64 {
235 self.total_aggregations
236 }
237
238 pub fn config(&self) -> &AggregatorConfig {
240 &self.config
241 }
242
243 pub fn strategy_summary(results: &[AggregatedEmbedding]) -> HashMap<PoolingStrategy, usize> {
245 let mut counts: HashMap<PoolingStrategy, usize> = HashMap::new();
246 for r in results {
247 *counts.entry(r.strategy).or_insert(0) += 1;
248 }
249 counts
250 }
251}
252
/// Averages `tokens` element-wise over their first `dim` components.
///
/// A token shorter than `dim` contributes only the components it has, and
/// components past `dim` are ignored. The divisor is always `tokens.len()`,
/// so an empty `tokens` slice yields NaN entries (0/0).
fn mean_pool(tokens: &[Vec<f32>], dim: usize) -> Vec<f32> {
    let mut totals = vec![0.0f32; dim];
    for token in tokens {
        // `zip` stops at the shorter side, which caps the walk at `dim` components.
        for (slot, &value) in totals.iter_mut().zip(token.iter()) {
            *slot += value;
        }
    }
    let count = tokens.len() as f32;
    totals.into_iter().map(|total| total / count).collect()
}
271
/// Averages `vectors` element-wise over their first `dim` components.
///
/// A vector shorter than `dim` contributes only the components it has. The
/// divisor is always `vectors.len()`, so an empty slice yields NaN entries.
fn mean_pool_refs(vectors: &[Vec<f32>], dim: usize) -> Vec<f32> {
    let count = vectors.len() as f32;
    (0..dim)
        .map(|col| {
            // Sum the column in input order, skipping vectors too short to have it.
            let total = vectors
                .iter()
                .filter_map(|row| row.get(col))
                .fold(0.0f32, |sum, &x| sum + x);
            total / count
        })
        .collect()
}
286
/// Takes the element-wise maximum of `tokens` over their first `dim` components.
///
/// Every component starts at negative infinity and is replaced only by a
/// strictly greater value, so a component for which no token supplies a
/// greater value (including NaN-only columns) stays at negative infinity.
fn max_pool(tokens: &[Vec<f32>], dim: usize) -> Vec<f32> {
    (0..dim)
        .map(|col| {
            tokens
                .iter()
                .filter_map(|token| token.get(col))
                .fold(f32::NEG_INFINITY, |best, &candidate| {
                    // Explicit `>` (not `f32::max`) keeps NaN and signed-zero handling unchanged.
                    if candidate > best {
                        candidate
                    } else {
                        best
                    }
                })
        })
        .collect()
}
299
/// Returns a copy of the first token, or an empty vector when there are no tokens.
fn cls_pool(tokens: &[Vec<f32>]) -> Vec<f32> {
    match tokens {
        [head, ..] => head.clone(),
        [] => Vec::new(),
    }
}
304
/// Computes a weighted sum of `tokens` over their first `dim` components.
///
/// `weights` are divided by their sum before use, so they need not be
/// pre-normalized. Uniform weights (`1 / tokens.len()`) are used instead when
/// any of the following holds:
///
/// * `weights` is `None`;
/// * the number of weights differs from the number of tokens;
/// * the absolute value of the weight sum is below `eps`;
/// * the weight sum is NaN or infinite, which would otherwise turn every
///   output component into NaN.
///
/// A token shorter than `dim` contributes only the components it has.
fn attention_pool(tokens: &[Vec<f32>], weights: Option<&[f32]>, dim: usize, eps: f32) -> Vec<f32> {
    let n = tokens.len();
    let uniform = || vec![1.0 / n as f32; n];

    let effective_weights: Vec<f32> = match weights {
        Some(w) if w.len() == n => {
            let sum: f32 = w.iter().sum();
            // A non-finite sum cannot be normalized; treat it like a zero sum.
            if !sum.is_finite() || sum.abs() < eps {
                uniform()
            } else {
                w.iter().map(|&v| v / sum).collect()
            }
        }
        _ => uniform(),
    };

    let mut result = vec![0.0f32; dim];
    for (tok, &weight) in tokens.iter().zip(effective_weights.iter()) {
        // `zip` stops at the shorter side, capping the walk at `dim` components.
        for (slot, &v) in result.iter_mut().zip(tok.iter()) {
            *slot += v * weight;
        }
    }
    result
}
331
/// Scales `vec` to unit Euclidean length.
///
/// When the norm is below `eps` the input is returned unchanged (as a copy)
/// instead of being divided by a near-zero value.
fn l2_normalize(vec: &[f32], eps: f32) -> Vec<f32> {
    let squared_sum: f32 = vec.iter().map(|&v| v * v).sum();
    let norm = squared_sum.sqrt();
    // Kept as `norm < eps` so that a NaN norm still takes the division branch.
    if norm < eps {
        vec.to_vec()
    } else {
        let mut scaled = Vec::with_capacity(vec.len());
        for &component in vec {
            scaled.push(component / norm);
        }
        scaled
    }
}
340
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// When the slices differ in length only the common prefix is compared.
/// Returns `0.0` when that prefix is empty or when either prefix has a zero
/// norm.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len().min(b.len());
    if len == 0 {
        return 0.0;
    }
    let (lhs, rhs) = (&a[..len], &b[..len]);

    let norm_of = |xs: &[f32]| xs.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (na, nb) = (norm_of(lhs), norm_of(rhs));
    if na == 0.0 || nb == 0.0 {
        return 0.0;
    }

    let dot: f32 = lhs.iter().zip(rhs).map(|(x, y)| x * y).sum();
    (dot / (na * nb)).clamp(-1.0, 1.0)
}
359
360fn aggregate_by_boundaries(vectors: &[Vec<f32>], boundaries: &[usize], _eps: f32) -> Vec<Vec<f32>> {
362 if vectors.is_empty() {
363 return vec![];
364 }
365 let dim = vectors[0].len();
366
367 let mut ranges: Vec<(usize, usize)> = Vec::new();
369 for (i, &start) in boundaries.iter().enumerate() {
370 let end = if i + 1 < boundaries.len() {
371 boundaries[i + 1]
372 } else {
373 vectors.len()
374 };
375 if start < end && start < vectors.len() {
376 ranges.push((start, end.min(vectors.len())));
377 }
378 }
379
380 if ranges.is_empty() {
382 ranges.push((0, vectors.len()));
383 }
384
385 ranges
386 .iter()
387 .map(|&(start, end)| mean_pool_refs(&vectors[start..end], dim))
388 .collect()
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures and helpers -------------------------------------------

    /// Aggregator built from the default configuration.
    fn default_aggregator() -> EmbeddingAggregator {
        EmbeddingAggregator::new(AggregatorConfig::default())
    }

    /// Aggregator that L2-normalizes its output.
    fn normalizing_aggregator() -> EmbeddingAggregator {
        let config = AggregatorConfig {
            normalize_output: true,
            ..AggregatorConfig::default()
        };
        EmbeddingAggregator::new(config)
    }

    /// Three 4-dimensional tokens holding the values 1 through 12.
    fn sample_tokens() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![5.0, 6.0, 7.0, 8.0],
            vec![9.0, 10.0, 11.0, 12.0],
        ]
    }

    /// Pools `tokens` on `agg` and unwraps the result.
    fn pool_ok(
        agg: &mut EmbeddingAggregator,
        tokens: &[Vec<f32>],
        strategy: PoolingStrategy,
        weights: Option<&[f32]>,
    ) -> AggregatedEmbedding {
        agg.aggregate_with(tokens, strategy, weights)
            .expect("should succeed")
    }

    /// Pools the sample tokens on a fresh default aggregator.
    fn pool_sample(strategy: PoolingStrategy) -> AggregatedEmbedding {
        pool_ok(&mut default_aggregator(), &sample_tokens(), strategy, None)
    }

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    fn assert_near(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// Euclidean norm of `values`.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    // ---- mean pooling ---------------------------------------------------

    #[test]
    fn test_mean_pool_correct_values() {
        let pooled = pool_sample(PoolingStrategy::Mean);
        for (index, expected) in [5.0f32, 6.0, 7.0, 8.0].iter().enumerate() {
            assert_near(pooled.vector[index], *expected, 1e-5);
        }
    }

    #[test]
    fn test_mean_pool_dimension_preserved() {
        let pooled = pool_sample(PoolingStrategy::Mean);
        assert_eq!(pooled.vector.len(), 4);
    }

    #[test]
    fn test_mean_pool_single_token() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![1.0, 2.0, 3.0]];
        let pooled = pool_ok(&mut agg, &tokens, PoolingStrategy::Mean, None);
        assert_near(pooled.vector[0], 1.0, 1e-6);
        assert_near(pooled.vector[1], 2.0, 1e-6);
    }

    #[test]
    fn test_mean_pool_token_count() {
        let pooled = pool_sample(PoolingStrategy::Mean);
        assert_eq!(pooled.token_count, 3);
    }

    // ---- max pooling ----------------------------------------------------

    #[test]
    fn test_max_pool_correct_values() {
        let pooled = pool_sample(PoolingStrategy::Max);
        for (index, expected) in [9.0f32, 10.0, 11.0, 12.0].iter().enumerate() {
            assert_near(pooled.vector[index], *expected, 1e-5);
        }
    }

    #[test]
    fn test_max_pool_with_negatives() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![-1.0, -5.0], vec![-3.0, -2.0]];
        let pooled = pool_ok(&mut agg, &tokens, PoolingStrategy::Max, None);
        assert_near(pooled.vector[0], -1.0, 1e-6);
        assert_near(pooled.vector[1], -2.0, 1e-6);
    }

    #[test]
    fn test_max_pool_single_token() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![7.0, 8.0, 9.0]];
        let pooled = pool_ok(&mut agg, &tokens, PoolingStrategy::Max, None);
        assert_near(pooled.vector[0], 7.0, 1e-6);
    }

    #[test]
    fn test_max_pool_dimension_preserved() {
        let pooled = pool_sample(PoolingStrategy::Max);
        assert_eq!(pooled.vector.len(), 4);
    }

    // ---- CLS pooling ----------------------------------------------------

    #[test]
    fn test_cls_pool_returns_first_token() {
        let pooled = pool_sample(PoolingStrategy::Cls);
        assert_eq!(pooled.vector, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_cls_pool_ignores_subsequent_tokens() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![100.0, 200.0], vec![999.0, 888.0]];
        let pooled = pool_ok(&mut agg, &tokens, PoolingStrategy::Cls, None);
        assert_near(pooled.vector[0], 100.0, 1e-6);
    }

    #[test]
    fn test_cls_pool_token_count() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![1.0], vec![2.0], vec![3.0]];
        let pooled = pool_ok(&mut agg, &tokens, PoolingStrategy::Cls, None);
        assert_eq!(pooled.token_count, 3);
    }

    // ---- attention-weighted pooling -------------------------------------

    #[test]
    fn test_attention_pool_uniform_weights_equals_mean() {
        let mut agg = default_aggregator();
        let tokens = sample_tokens();
        let weights = vec![1.0, 1.0, 1.0];
        let attn = pool_ok(
            &mut agg,
            &tokens,
            PoolingStrategy::AttentionWeighted,
            Some(weights.as_slice()),
        );
        let mean = pool_ok(&mut agg, &tokens, PoolingStrategy::Mean, None);
        for (a, m) in attn.vector.iter().zip(mean.vector.iter()) {
            assert!((a - m).abs() < 1e-5, "uniform attn should equal mean");
        }
    }

    #[test]
    fn test_attention_pool_single_nonzero_weight() {
        let mut agg = default_aggregator();
        let tokens = sample_tokens();
        let weights = vec![0.0, 0.0, 1.0];
        let pooled = pool_ok(
            &mut agg,
            &tokens,
            PoolingStrategy::AttentionWeighted,
            Some(weights.as_slice()),
        );
        assert_near(pooled.vector[0], 9.0, 1e-5);
        assert_near(pooled.vector[1], 10.0, 1e-5);
    }

    #[test]
    fn test_attention_pool_mismatched_weights_falls_back_to_uniform() {
        let mut agg = default_aggregator();
        let tokens = sample_tokens();
        let weights = vec![1.0, 2.0];
        let pooled = pool_ok(
            &mut agg,
            &tokens,
            PoolingStrategy::AttentionWeighted,
            Some(weights.as_slice()),
        );
        assert_near(pooled.vector[0], 5.0, 1e-5);
    }

    #[test]
    fn test_attention_pool_no_weights_falls_back_to_uniform() {
        let pooled = pool_sample(PoolingStrategy::AttentionWeighted);
        assert_near(pooled.vector[0], 5.0, 1e-5);
    }

    // ---- output normalization -------------------------------------------

    #[test]
    fn test_normalized_output_has_unit_norm() {
        let mut agg = normalizing_aggregator();
        let pooled = pool_ok(&mut agg, &sample_tokens(), PoolingStrategy::Mean, None);
        let norm = l2_norm(&pooled.vector);
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "normalized output should have unit norm"
        );
    }

    #[test]
    fn test_non_normalized_output_not_unit_norm() {
        let pooled = pool_sample(PoolingStrategy::Mean);
        assert!(l2_norm(&pooled.vector) > 1.0);
    }

    // ---- degenerate input -----------------------------------------------

    #[test]
    fn test_empty_tokens_returns_none() {
        let mut agg = default_aggregator();
        let outcome = agg.aggregate_with(&[], PoolingStrategy::Mean, None);
        assert!(outcome.is_none());
    }

    #[test]
    fn test_zero_dim_tokens_returns_none() {
        let mut agg = default_aggregator();
        let tokens: Vec<Vec<f32>> = vec![vec![], vec![]];
        let outcome = agg.aggregate_with(&tokens, PoolingStrategy::Mean, None);
        assert!(outcome.is_none());
    }

    // ---- default strategy -----------------------------------------------

    #[test]
    fn test_aggregate_uses_default_strategy() {
        let config = AggregatorConfig {
            default_strategy: PoolingStrategy::Max,
            ..AggregatorConfig::default()
        };
        let mut agg = EmbeddingAggregator::new(config);
        let pooled = agg.aggregate(&sample_tokens()).expect("should succeed");
        assert_eq!(pooled.strategy, PoolingStrategy::Max);
    }

    // ---- batch aggregation ----------------------------------------------

    #[test]
    fn test_batch_aggregate_count() {
        let mut agg = default_aggregator();
        let batch = vec![sample_tokens(), sample_tokens(), sample_tokens()];
        let outcome = agg.aggregate_batch(&batch);
        assert_eq!(outcome.sequence_count, 3);
        assert_eq!(outcome.embeddings.len(), 3);
    }

    #[test]
    fn test_batch_aggregate_with_empty_sequences() {
        let mut agg = default_aggregator();
        let batch: Vec<Vec<Vec<f32>>> = vec![sample_tokens(), vec![], sample_tokens()];
        let outcome = agg.aggregate_batch(&batch);
        assert_eq!(
            outcome.sequence_count, 2,
            "empty sequence should be filtered out"
        );
    }

    #[test]
    fn test_batch_aggregate_strategy_propagates() {
        let mut agg = default_aggregator();
        let batch = vec![sample_tokens()];
        let outcome = agg.aggregate_batch_with(&batch, PoolingStrategy::Cls);
        assert_eq!(outcome.embeddings[0].strategy, PoolingStrategy::Cls);
    }

    // ---- hierarchical aggregation ---------------------------------------

    #[test]
    fn test_hierarchical_single_sentence() {
        let mut agg = default_aggregator();
        let sentences = vec![sample_tokens()];
        let outcome = agg
            .hierarchical_aggregate(&sentences, &[0])
            .expect("should succeed");
        assert_eq!(outcome.sentence_embeddings.len(), 1);
        assert_eq!(outcome.paragraph_embeddings.len(), 1);
        assert_eq!(outcome.document_embedding.len(), 4);
    }

    #[test]
    fn test_hierarchical_two_paragraphs() {
        let mut agg = default_aggregator();
        let sentences = vec![
            vec![vec![1.0, 0.0], vec![3.0, 0.0]],
            vec![vec![5.0, 0.0], vec![7.0, 0.0]],
            vec![vec![9.0, 0.0], vec![11.0, 0.0]],
        ];
        let boundaries = vec![0, 2];
        let outcome = agg
            .hierarchical_aggregate(&sentences, &boundaries)
            .expect("should succeed");
        assert_eq!(outcome.paragraph_embeddings.len(), 2);
    }

    #[test]
    fn test_hierarchical_empty_returns_none() {
        let mut agg = default_aggregator();
        let outcome = agg.hierarchical_aggregate(&[], &[0]);
        assert!(outcome.is_none());
    }

    #[test]
    fn test_hierarchical_document_is_mean_of_paragraphs() {
        let mut agg = default_aggregator();
        let sentences = vec![
            vec![vec![2.0, 4.0], vec![4.0, 6.0]],
            vec![vec![6.0, 8.0], vec![8.0, 10.0]],
        ];
        let outcome = agg
            .hierarchical_aggregate(&sentences, &[0])
            .expect("should succeed");
        assert_near(outcome.document_embedding[0], 5.0, 1e-5);
        assert_near(outcome.document_embedding[1], 7.0, 1e-5);
    }

    // ---- strategy comparison --------------------------------------------

    #[test]
    fn test_compare_strategies_returns_both() {
        let mut agg = default_aggregator();
        let (a, b) = agg.compare_strategies(
            &sample_tokens(),
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
        );
        assert!(a.is_some());
        assert!(b.is_some());
        assert_eq!(a.as_ref().map(|r| r.strategy), Some(PoolingStrategy::Mean));
        assert_eq!(b.as_ref().map(|r| r.strategy), Some(PoolingStrategy::Max));
    }

    #[test]
    fn test_compare_strategies_different_results() {
        let mut agg = default_aggregator();
        let (a, b) = agg.compare_strategies(
            &sample_tokens(),
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
        );
        // A missing result is read as 0.0, which then fails the closeness check.
        let first_component =
            |r: &Option<AggregatedEmbedding>| r.as_ref().map(|e| e.vector[0]).unwrap_or(0.0);
        assert_near(first_component(&a), 5.0, 1e-5);
        assert_near(first_component(&b), 9.0, 1e-5);
    }

    // ---- aggregation counter --------------------------------------------

    #[test]
    fn test_total_aggregations_initially_zero() {
        assert_eq!(default_aggregator().total_aggregations(), 0);
    }

    #[test]
    fn test_total_aggregations_increments() {
        let mut agg = default_aggregator();
        for _ in 0..2 {
            agg.aggregate(&sample_tokens());
        }
        assert_eq!(agg.total_aggregations(), 2);
    }

    #[test]
    fn test_total_aggregations_batch_increments() {
        let mut agg = default_aggregator();
        let batch = vec![sample_tokens(), sample_tokens()];
        agg.aggregate_batch(&batch);
        assert_eq!(agg.total_aggregations(), 2);
    }

    // ---- strategy summary -----------------------------------------------

    #[test]
    fn test_strategy_summary_counts() {
        let entry = |value: f32, strategy: PoolingStrategy| AggregatedEmbedding {
            vector: vec![value],
            strategy,
            token_count: 1,
        };
        let results = vec![
            entry(1.0, PoolingStrategy::Mean),
            entry(2.0, PoolingStrategy::Mean),
            entry(3.0, PoolingStrategy::Max),
        ];
        let summary = EmbeddingAggregator::strategy_summary(&results);
        assert_eq!(summary.get(&PoolingStrategy::Mean), Some(&2));
        assert_eq!(summary.get(&PoolingStrategy::Max), Some(&1));
        assert_eq!(summary.get(&PoolingStrategy::Cls), None);
    }

    // ---- cosine similarity ----------------------------------------------

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        assert_near(cosine_similarity(&a, &a), 1.0, 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        assert!((cosine_similarity(&a, &b) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    // ---- configuration --------------------------------------------------

    #[test]
    fn test_config_accessor() {
        let agg = default_aggregator();
        let config = agg.config();
        assert_eq!(config.default_strategy, PoolingStrategy::Mean);
        assert!(!config.normalize_output);
    }

    #[test]
    fn test_aggregator_config_default() {
        let config = AggregatorConfig::default();
        assert_eq!(config.default_strategy, PoolingStrategy::Mean);
        assert!(!config.normalize_output);
        assert!(config.eps > 0.0);
    }
}