Skip to main content

oxirs_embed/
embedding_aggregator.rs

1//! Embedding aggregation strategies for combining token-level embeddings.
2//!
3//! Provides multiple pooling strategies:
4//! - **Mean pooling**: average of all token embeddings
5//! - **Max pooling**: element-wise maximum across tokens
6//! - **CLS token extraction**: first token (index 0) embedding
7//! - **Attention-weighted aggregation**: weighted sum by attention scores
8//! - **Hierarchical aggregation**: sentence -> paragraph -> document
9//! - Configurable strategy selection, dimension-preserving output, batch support
10
11use std::collections::HashMap;
12
13// ---------------------------------------------------------------------------
14// Public types
15// ---------------------------------------------------------------------------
16
/// Strategy used to collapse a sequence of token embeddings into a single vector.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PoolingStrategy {
    /// Component-wise arithmetic mean over every token embedding.
    Mean,
    /// Component-wise maximum over every token embedding.
    Max,
    /// The embedding at index 0, i.e. the `[CLS]` position in BERT-style encoders.
    Cls,
    /// Weighted average driven by per-token attention scores.
    AttentionWeighted,
}
29
30/// Configuration for the embedding aggregator.
31#[derive(Debug, Clone)]
32pub struct AggregatorConfig {
33    /// The default pooling strategy when none is specified.
34    pub default_strategy: PoolingStrategy,
35    /// Whether to L2-normalise the result after aggregation.
36    pub normalize_output: bool,
37    /// Epsilon added to norms to prevent division by zero.
38    pub eps: f32,
39}
40
41impl Default for AggregatorConfig {
42    fn default() -> Self {
43        Self {
44            default_strategy: PoolingStrategy::Mean,
45            normalize_output: false,
46            eps: 1e-12,
47        }
48    }
49}
50
51/// Result of a single aggregation operation.
52#[derive(Debug, Clone)]
53pub struct AggregatedEmbedding {
54    /// The aggregated vector.
55    pub vector: Vec<f32>,
56    /// Which strategy was used.
57    pub strategy: PoolingStrategy,
58    /// How many token embeddings were consumed.
59    pub token_count: usize,
60}
61
/// Output of hierarchical aggregation: one set of vectors per level of the
/// sentence -> paragraph -> document hierarchy.
#[derive(Clone, Debug)]
pub struct HierarchicalResult {
    /// One pooled vector for each sentence that could be aggregated.
    pub sentence_embeddings: Vec<Vec<f32>>,
    /// One pooled vector per paragraph, where a paragraph is a contiguous group of sentences.
    pub paragraph_embeddings: Vec<Vec<f32>>,
    /// A single vector summarising the whole document.
    pub document_embedding: Vec<f32>,
}
72
73/// Batch aggregation result.
74#[derive(Debug, Clone)]
75pub struct BatchResult {
76    /// One aggregated embedding per input sequence.
77    pub embeddings: Vec<AggregatedEmbedding>,
78    /// Total number of sequences processed.
79    pub sequence_count: usize,
80}
81
82// ---------------------------------------------------------------------------
83// EmbeddingAggregator
84// ---------------------------------------------------------------------------
85
86/// Stateful embedding aggregator that tracks total aggregation operations.
87pub struct EmbeddingAggregator {
88    config: AggregatorConfig,
89    total_aggregations: u64,
90}
91
92impl EmbeddingAggregator {
93    /// Create a new aggregator with the given configuration.
94    pub fn new(config: AggregatorConfig) -> Self {
95        Self {
96            config,
97            total_aggregations: 0,
98        }
99    }
100
101    /// Aggregate a sequence of token embeddings using the default strategy.
102    ///
103    /// Each inner `Vec<f32>` is a single token's embedding.
104    /// All token embeddings must have the same dimensionality.
105    pub fn aggregate(&mut self, tokens: &[Vec<f32>]) -> Option<AggregatedEmbedding> {
106        self.aggregate_with(tokens, self.config.default_strategy, None)
107    }
108
109    /// Aggregate using a specific strategy.
110    ///
111    /// `attention_weights` is required when `strategy == AttentionWeighted` and
112    /// is ignored for other strategies.
113    pub fn aggregate_with(
114        &mut self,
115        tokens: &[Vec<f32>],
116        strategy: PoolingStrategy,
117        attention_weights: Option<&[f32]>,
118    ) -> Option<AggregatedEmbedding> {
119        if tokens.is_empty() {
120            return None;
121        }
122        let dim = tokens[0].len();
123        if dim == 0 {
124            return None;
125        }
126
127        let raw = match strategy {
128            PoolingStrategy::Mean => mean_pool(tokens, dim),
129            PoolingStrategy::Max => max_pool(tokens, dim),
130            PoolingStrategy::Cls => cls_pool(tokens),
131            PoolingStrategy::AttentionWeighted => {
132                attention_pool(tokens, attention_weights, dim, self.config.eps)
133            }
134        };
135
136        let vector = if self.config.normalize_output {
137            l2_normalize(&raw, self.config.eps)
138        } else {
139            raw
140        };
141
142        self.total_aggregations += 1;
143
144        Some(AggregatedEmbedding {
145            vector,
146            strategy,
147            token_count: tokens.len(),
148        })
149    }
150
151    /// Aggregate a batch of token sequences using the default strategy.
152    pub fn aggregate_batch(&mut self, batch: &[Vec<Vec<f32>>]) -> BatchResult {
153        self.aggregate_batch_with(batch, self.config.default_strategy)
154    }
155
156    /// Aggregate a batch with a specific strategy.
157    pub fn aggregate_batch_with(
158        &mut self,
159        batch: &[Vec<Vec<f32>>],
160        strategy: PoolingStrategy,
161    ) -> BatchResult {
162        let embeddings: Vec<AggregatedEmbedding> = batch
163            .iter()
164            .filter_map(|tokens| self.aggregate_with(tokens, strategy, None))
165            .collect();
166        let sequence_count = embeddings.len();
167        BatchResult {
168            embeddings,
169            sequence_count,
170        }
171    }
172
173    /// Perform hierarchical aggregation: sentence -> paragraph -> document.
174    ///
175    /// * `sentences` – each entry is a sequence of token embeddings forming one sentence.
176    /// * `paragraph_boundaries` – indices into `sentences` where a new paragraph starts
177    ///   (e.g. `[0, 3, 7]` means sentences 0..3 form paragraph 0, 3..7 form paragraph 1, etc.).
178    ///
179    /// Uses mean pooling at every level.
180    pub fn hierarchical_aggregate(
181        &mut self,
182        sentences: &[Vec<Vec<f32>>],
183        paragraph_boundaries: &[usize],
184    ) -> Option<HierarchicalResult> {
185        if sentences.is_empty() {
186            return None;
187        }
188
189        // 1. Sentence-level: aggregate each sentence's token embeddings.
190        let sentence_embeddings: Vec<Vec<f32>> = sentences
191            .iter()
192            .filter_map(|tokens| {
193                self.aggregate_with(tokens, PoolingStrategy::Mean, None)
194                    .map(|agg| agg.vector)
195            })
196            .collect();
197
198        if sentence_embeddings.is_empty() {
199            return None;
200        }
201
202        // 2. Paragraph-level: group sentence embeddings by boundaries.
203        let paragraph_embeddings =
204            aggregate_by_boundaries(&sentence_embeddings, paragraph_boundaries, self.config.eps);
205
206        // 3. Document-level: average of paragraph embeddings.
207        let dim = paragraph_embeddings.first().map(|v| v.len()).unwrap_or(0);
208        let document_embedding = if paragraph_embeddings.is_empty() || dim == 0 {
209            vec![]
210        } else {
211            mean_pool_refs(&paragraph_embeddings, dim)
212        };
213
214        Some(HierarchicalResult {
215            sentence_embeddings,
216            paragraph_embeddings,
217            document_embedding,
218        })
219    }
220
221    /// Compare two pooling strategies on the same tokens and return both results.
222    pub fn compare_strategies(
223        &mut self,
224        tokens: &[Vec<f32>],
225        strategy_a: PoolingStrategy,
226        strategy_b: PoolingStrategy,
227    ) -> (Option<AggregatedEmbedding>, Option<AggregatedEmbedding>) {
228        let a = self.aggregate_with(tokens, strategy_a, None);
229        let b = self.aggregate_with(tokens, strategy_b, None);
230        (a, b)
231    }
232
233    /// Return the total number of individual aggregation operations performed.
234    pub fn total_aggregations(&self) -> u64 {
235        self.total_aggregations
236    }
237
238    /// Return the current configuration.
239    pub fn config(&self) -> &AggregatorConfig {
240        &self.config
241    }
242
243    /// Build a summary of aggregation results per strategy from provided labels.
244    pub fn strategy_summary(results: &[AggregatedEmbedding]) -> HashMap<PoolingStrategy, usize> {
245        let mut counts: HashMap<PoolingStrategy, usize> = HashMap::new();
246        for r in results {
247            *counts.entry(r.strategy).or_insert(0) += 1;
248        }
249        counts
250    }
251}
252
253// ---------------------------------------------------------------------------
254// Free functions – pooling implementations
255// ---------------------------------------------------------------------------
256
/// Mean pooling: component-wise average of the token embeddings.
///
/// Produces a vector of length `dim`. A token only contributes to the first
/// `min(token.len(), dim)` components, but every component is divided by the
/// total token count.
fn mean_pool(tokens: &[Vec<f32>], dim: usize) -> Vec<f32> {
    let count = tokens.len() as f32;
    let mut sums = vec![0.0f32; dim];
    for token in tokens {
        sums.iter_mut()
            .zip(token.iter())
            .for_each(|(acc, &x)| *acc += x);
    }
    sums.into_iter().map(|s| s / count).collect()
}
271
/// Component-wise mean of a slice of vectors (used at the paragraph and document
/// levels of hierarchical aggregation).
///
/// Produces a vector of length `dim`. Each input contributes only to its first
/// `min(len, dim)` components, while every component is divided by `vectors.len()`.
fn mean_pool_refs(vectors: &[Vec<f32>], dim: usize) -> Vec<f32> {
    let denom = vectors.len() as f32;
    let totals = vectors.iter().fold(vec![0.0f32; dim], |mut totals, v| {
        for (slot, &component) in totals.iter_mut().zip(v) {
            *slot += component;
        }
        totals
    });
    totals.into_iter().map(|t| t / denom).collect()
}
286
/// Max pooling: component-wise maximum over the token embeddings.
///
/// Starts from negative infinity in each of the `dim` components, so a component
/// that no token reaches stays at `f32::NEG_INFINITY`. NaN values never win,
/// because the strict `>` comparison is false for NaN.
fn max_pool(tokens: &[Vec<f32>], dim: usize) -> Vec<f32> {
    tokens
        .iter()
        .fold(vec![f32::NEG_INFINITY; dim], |mut best, token| {
            for (slot, &x) in best.iter_mut().zip(token) {
                if x > *slot {
                    *slot = x;
                }
            }
            best
        })
}
299
/// CLS extraction: an owned copy of the first token's embedding, or an empty
/// vector when there are no tokens.
fn cls_pool(tokens: &[Vec<f32>]) -> Vec<f32> {
    match tokens.first() {
        Some(first) => first.clone(),
        None => Vec::new(),
    }
}
304
/// Attention-weighted pooling: weighted average of the token embeddings.
///
/// The supplied weights are rescaled linearly so they sum to 1. This is not a
/// softmax: raw weights are divided by their sum, so negative weights keep their sign.
///
/// Falls back to uniform weights (equivalent to mean pooling) when:
/// * `weights` is `None`,
/// * `weights.len()` differs from the number of tokens,
/// * the magnitude of the weight sum is below `eps`, or
/// * the weight sum is not finite (NaN, or ±∞ from an infinite weight or
///   overflow). Dividing by such a sum would otherwise turn the output into
///   NaNs or zeros.
///
/// Produces a vector of length `dim`; token components beyond `dim` are ignored.
fn attention_pool(tokens: &[Vec<f32>], weights: Option<&[f32]>, dim: usize, eps: f32) -> Vec<f32> {
    let n = tokens.len();
    let uniform = || vec![1.0 / n as f32; n];
    let effective_weights: Vec<f32> = match weights {
        Some(w) if w.len() == n => {
            // Linear normalisation to sum = 1.
            let sum: f32 = w.iter().sum();
            if !sum.is_finite() || sum.abs() < eps {
                uniform()
            } else {
                w.iter().map(|&v| v / sum).collect()
            }
        }
        _ => uniform(),
    };

    let mut result = vec![0.0f32; dim];
    for (tok, &weight) in tokens.iter().zip(&effective_weights) {
        for (slot, &v) in result.iter_mut().zip(tok) {
            *slot += v * weight;
        }
    }
    result
}
331
/// Return an L2-normalised copy of `vec`; the input slice is not modified.
///
/// If the Euclidean norm is below `eps` (e.g. the zero vector), an unchanged copy
/// is returned rather than dividing by a (near-)zero norm.
fn l2_normalize(vec: &[f32], eps: f32) -> Vec<f32> {
    let norm: f32 = vec.iter().map(|&v| v * v).sum::<f32>().sqrt();
    if norm < eps {
        return vec.to_vec();
    }
    vec.iter().map(|&v| v / norm).collect()
}
340
/// Cosine similarity between two `f32` slices.
///
/// Only the first `min(a.len(), b.len())` components are compared; any extra
/// components of the longer slice are ignored. Returns `0.0` when that common
/// prefix is empty or when either prefix has zero norm. Otherwise the result is
/// clamped to `[-1.0, 1.0]`.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let shared = a.len().min(b.len());
    if shared == 0 {
        return 0.0;
    }
    let (lhs, rhs) = (&a[..shared], &b[..shared]);

    let norm_lhs = lhs.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_rhs = rhs.iter().map(|x| x * x).sum::<f32>().sqrt();
    let dot = lhs.iter().zip(rhs).map(|(x, y)| x * y).sum::<f32>();

    match (norm_lhs == 0.0, norm_rhs == 0.0) {
        (false, false) => (dot / (norm_lhs * norm_rhs)).clamp(-1.0, 1.0),
        _ => 0.0,
    }
}
359
360/// Group vectors by paragraph boundaries and mean-pool each group.
361fn aggregate_by_boundaries(vectors: &[Vec<f32>], boundaries: &[usize], _eps: f32) -> Vec<Vec<f32>> {
362    if vectors.is_empty() {
363        return vec![];
364    }
365    let dim = vectors[0].len();
366
367    // Determine segment ranges.
368    let mut ranges: Vec<(usize, usize)> = Vec::new();
369    for (i, &start) in boundaries.iter().enumerate() {
370        let end = if i + 1 < boundaries.len() {
371            boundaries[i + 1]
372        } else {
373            vectors.len()
374        };
375        if start < end && start < vectors.len() {
376            ranges.push((start, end.min(vectors.len())));
377        }
378    }
379
380    // If no valid boundaries, treat entire input as one paragraph.
381    if ranges.is_empty() {
382        ranges.push((0, vectors.len()));
383    }
384
385    ranges
386        .iter()
387        .map(|&(start, end)| mean_pool_refs(&vectors[start..end], dim))
388        .collect()
389}
390
391// ---------------------------------------------------------------------------
392// Tests
393// ---------------------------------------------------------------------------
394
#[cfg(test)]
mod tests {
    // Unit tests for the pooling strategies, normalisation, batch and hierarchical
    // aggregation, counters, strategy summaries and cosine similarity.
    use super::*;

    /// Aggregator built from `AggregatorConfig::default()` (mean pooling, no normalisation).
    fn default_aggregator() -> EmbeddingAggregator {
        EmbeddingAggregator::new(AggregatorConfig::default())
    }

    /// Aggregator with default settings except that outputs are L2-normalised.
    fn normalizing_aggregator() -> EmbeddingAggregator {
        EmbeddingAggregator::new(AggregatorConfig {
            normalize_output: true,
            ..AggregatorConfig::default()
        })
    }

    /// Three token embeddings each of dimension 4.
    fn sample_tokens() -> Vec<Vec<f32>> {
        vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![5.0, 6.0, 7.0, 8.0],
            vec![9.0, 10.0, 11.0, 12.0],
        ]
    }

    // --- mean pooling ---

    #[test]
    fn test_mean_pool_correct_values() {
        let mut agg = default_aggregator();
        let result = agg
            .aggregate_with(&sample_tokens(), PoolingStrategy::Mean, None)
            .expect("should succeed");
        // mean = (1+5+9)/3, (2+6+10)/3, (3+7+11)/3, (4+8+12)/3 = 5, 6, 7, 8
        assert!((result.vector[0] - 5.0).abs() < 1e-5);
        assert!((result.vector[1] - 6.0).abs() < 1e-5);
        assert!((result.vector[2] - 7.0).abs() < 1e-5);
        assert!((result.vector[3] - 8.0).abs() < 1e-5);
    }

    #[test]
    fn test_mean_pool_dimension_preserved() {
        let mut agg = default_aggregator();
        let result = agg
            .aggregate_with(&sample_tokens(), PoolingStrategy::Mean, None)
            .expect("should succeed");
        assert_eq!(result.vector.len(), 4);
    }

    #[test]
    fn test_mean_pool_single_token() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![1.0, 2.0, 3.0]];
        let result = agg
            .aggregate_with(&tokens, PoolingStrategy::Mean, None)
            .expect("should succeed");
        // The mean of a single token is the token itself.
        assert!((result.vector[0] - 1.0).abs() < 1e-6);
        assert!((result.vector[1] - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_mean_pool_token_count() {
        let mut agg = default_aggregator();
        let result = agg
            .aggregate_with(&sample_tokens(), PoolingStrategy::Mean, None)
            .expect("should succeed");
        assert_eq!(result.token_count, 3);
    }

    // --- max pooling ---

    #[test]
    fn test_max_pool_correct_values() {
        let mut agg = default_aggregator();
        let result = agg
            .aggregate_with(&sample_tokens(), PoolingStrategy::Max, None)
            .expect("should succeed");
        // max = 9, 10, 11, 12
        assert!((result.vector[0] - 9.0).abs() < 1e-5);
        assert!((result.vector[1] - 10.0).abs() < 1e-5);
        assert!((result.vector[2] - 11.0).abs() < 1e-5);
        assert!((result.vector[3] - 12.0).abs() < 1e-5);
    }

    #[test]
    fn test_max_pool_with_negatives() {
        let mut agg = default_aggregator();
        // All-negative input guards against a zero-initialised accumulator.
        let tokens = vec![vec![-1.0, -5.0], vec![-3.0, -2.0]];
        let result = agg
            .aggregate_with(&tokens, PoolingStrategy::Max, None)
            .expect("should succeed");
        assert!((result.vector[0] - (-1.0)).abs() < 1e-6);
        assert!((result.vector[1] - (-2.0)).abs() < 1e-6);
    }

    #[test]
    fn test_max_pool_single_token() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![7.0, 8.0, 9.0]];
        let result = agg
            .aggregate_with(&tokens, PoolingStrategy::Max, None)
            .expect("should succeed");
        assert!((result.vector[0] - 7.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_pool_dimension_preserved() {
        let mut agg = default_aggregator();
        let result = agg
            .aggregate_with(&sample_tokens(), PoolingStrategy::Max, None)
            .expect("should succeed");
        assert_eq!(result.vector.len(), 4);
    }

    // --- CLS pooling ---

    #[test]
    fn test_cls_pool_returns_first_token() {
        let mut agg = default_aggregator();
        let result = agg
            .aggregate_with(&sample_tokens(), PoolingStrategy::Cls, None)
            .expect("should succeed");
        assert_eq!(result.vector, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_cls_pool_ignores_subsequent_tokens() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![100.0, 200.0], vec![999.0, 888.0]];
        let result = agg
            .aggregate_with(&tokens, PoolingStrategy::Cls, None)
            .expect("should succeed");
        assert!((result.vector[0] - 100.0).abs() < 1e-6);
    }

    #[test]
    fn test_cls_pool_token_count() {
        let mut agg = default_aggregator();
        let tokens = vec![vec![1.0], vec![2.0], vec![3.0]];
        let result = agg
            .aggregate_with(&tokens, PoolingStrategy::Cls, None)
            .expect("should succeed");
        // token_count reports the whole input, even though CLS only reads token 0.
        assert_eq!(result.token_count, 3);
    }

    // --- attention-weighted pooling ---

    #[test]
    fn test_attention_pool_uniform_weights_equals_mean() {
        let mut agg = default_aggregator();
        let tokens = sample_tokens();
        let weights = vec![1.0, 1.0, 1.0];
        let attn = agg
            .aggregate_with(&tokens, PoolingStrategy::AttentionWeighted, Some(&weights))
            .expect("should succeed");
        let mean = agg
            .aggregate_with(&tokens, PoolingStrategy::Mean, None)
            .expect("should succeed");
        for (a, m) in attn.vector.iter().zip(mean.vector.iter()) {
            assert!((a - m).abs() < 1e-5, "uniform attn should equal mean");
        }
    }

    #[test]
    fn test_attention_pool_single_nonzero_weight() {
        let mut agg = default_aggregator();
        let tokens = sample_tokens();
        // Only the last token gets weight
        let weights = vec![0.0, 0.0, 1.0];
        let result = agg
            .aggregate_with(&tokens, PoolingStrategy::AttentionWeighted, Some(&weights))
            .expect("should succeed");
        assert!((result.vector[0] - 9.0).abs() < 1e-5);
        assert!((result.vector[1] - 10.0).abs() < 1e-5);
    }

    #[test]
    fn test_attention_pool_mismatched_weights_falls_back_to_uniform() {
        let mut agg = default_aggregator();
        let tokens = sample_tokens();
        let weights = vec![1.0, 2.0]; // length 2 != 3 tokens
        let result = agg
            .aggregate_with(&tokens, PoolingStrategy::AttentionWeighted, Some(&weights))
            .expect("should succeed");
        // Falls back to uniform = mean pooling
        assert!((result.vector[0] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_attention_pool_no_weights_falls_back_to_uniform() {
        let mut agg = default_aggregator();
        let tokens = sample_tokens();
        let result = agg
            .aggregate_with(&tokens, PoolingStrategy::AttentionWeighted, None)
            .expect("should succeed");
        assert!((result.vector[0] - 5.0).abs() < 1e-5);
    }

    // --- normalization ---

    #[test]
    fn test_normalized_output_has_unit_norm() {
        let mut agg = normalizing_aggregator();
        let result = agg
            .aggregate_with(&sample_tokens(), PoolingStrategy::Mean, None)
            .expect("should succeed");
        let norm: f32 = result.vector.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "normalized output should have unit norm"
        );
    }

    #[test]
    fn test_non_normalized_output_not_unit_norm() {
        let mut agg = default_aggregator();
        let result = agg
            .aggregate_with(&sample_tokens(), PoolingStrategy::Mean, None)
            .expect("should succeed");
        let norm: f32 = result.vector.iter().map(|v| v * v).sum::<f32>().sqrt();
        // Mean of (5,6,7,8) has norm > 1
        assert!(norm > 1.0);
    }

    // --- empty / edge cases ---

    #[test]
    fn test_empty_tokens_returns_none() {
        let mut agg = default_aggregator();
        let result = agg.aggregate_with(&[], PoolingStrategy::Mean, None);
        assert!(result.is_none());
    }

    #[test]
    fn test_zero_dim_tokens_returns_none() {
        let mut agg = default_aggregator();
        let tokens: Vec<Vec<f32>> = vec![vec![], vec![]];
        let result = agg.aggregate_with(&tokens, PoolingStrategy::Mean, None);
        assert!(result.is_none());
    }

    // --- default aggregate ---

    #[test]
    fn test_aggregate_uses_default_strategy() {
        let mut agg = EmbeddingAggregator::new(AggregatorConfig {
            default_strategy: PoolingStrategy::Max,
            ..AggregatorConfig::default()
        });
        let result = agg.aggregate(&sample_tokens()).expect("should succeed");
        assert_eq!(result.strategy, PoolingStrategy::Max);
    }

    // --- batch aggregation ---

    #[test]
    fn test_batch_aggregate_count() {
        let mut agg = default_aggregator();
        let batch = vec![sample_tokens(), sample_tokens(), sample_tokens()];
        let result = agg.aggregate_batch(&batch);
        assert_eq!(result.sequence_count, 3);
        assert_eq!(result.embeddings.len(), 3);
    }

    #[test]
    fn test_batch_aggregate_with_empty_sequences() {
        let mut agg = default_aggregator();
        let batch: Vec<Vec<Vec<f32>>> = vec![sample_tokens(), vec![], sample_tokens()];
        let result = agg.aggregate_batch(&batch);
        // sequence_count counts produced embeddings, not input sequences.
        assert_eq!(
            result.sequence_count, 2,
            "empty sequence should be filtered out"
        );
    }

    #[test]
    fn test_batch_aggregate_strategy_propagates() {
        let mut agg = default_aggregator();
        let batch = vec![sample_tokens()];
        let result = agg.aggregate_batch_with(&batch, PoolingStrategy::Cls);
        assert_eq!(result.embeddings[0].strategy, PoolingStrategy::Cls);
    }

    // --- hierarchical aggregation ---

    #[test]
    fn test_hierarchical_single_sentence() {
        let mut agg = default_aggregator();
        let sentences = vec![sample_tokens()];
        let result = agg
            .hierarchical_aggregate(&sentences, &[0])
            .expect("should succeed");
        assert_eq!(result.sentence_embeddings.len(), 1);
        assert_eq!(result.paragraph_embeddings.len(), 1);
        assert_eq!(result.document_embedding.len(), 4);
    }

    #[test]
    fn test_hierarchical_two_paragraphs() {
        let mut agg = default_aggregator();
        let sentences = vec![
            vec![vec![1.0, 0.0], vec![3.0, 0.0]],  // sentence 0
            vec![vec![5.0, 0.0], vec![7.0, 0.0]],  // sentence 1
            vec![vec![9.0, 0.0], vec![11.0, 0.0]], // sentence 2
        ];
        let boundaries = vec![0, 2]; // paragraph 0 = sentences [0,1], paragraph 1 = sentence [2]
        let result = agg
            .hierarchical_aggregate(&sentences, &boundaries)
            .expect("should succeed");
        assert_eq!(result.paragraph_embeddings.len(), 2);
    }

    #[test]
    fn test_hierarchical_empty_returns_none() {
        let mut agg = default_aggregator();
        let result = agg.hierarchical_aggregate(&[], &[0]);
        assert!(result.is_none());
    }

    #[test]
    fn test_hierarchical_document_is_mean_of_paragraphs() {
        let mut agg = default_aggregator();
        // Two sentences, each with two tokens of dim 2
        let sentences = vec![
            vec![vec![2.0, 4.0], vec![4.0, 6.0]], // sentence 0 → mean = (3, 5)
            vec![vec![6.0, 8.0], vec![8.0, 10.0]], // sentence 1 → mean = (7, 9)
        ];
        // One paragraph encompassing both
        let result = agg
            .hierarchical_aggregate(&sentences, &[0])
            .expect("should succeed");
        // Paragraph mean = (3+7)/2, (5+9)/2 = (5, 7) = document
        assert!((result.document_embedding[0] - 5.0).abs() < 1e-5);
        assert!((result.document_embedding[1] - 7.0).abs() < 1e-5);
    }

    // --- compare strategies ---

    #[test]
    fn test_compare_strategies_returns_both() {
        let mut agg = default_aggregator();
        let (a, b) = agg.compare_strategies(
            &sample_tokens(),
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
        );
        assert!(a.is_some());
        assert!(b.is_some());
        assert_eq!(a.as_ref().map(|r| r.strategy), Some(PoolingStrategy::Mean));
        assert_eq!(b.as_ref().map(|r| r.strategy), Some(PoolingStrategy::Max));
    }

    #[test]
    fn test_compare_strategies_different_results() {
        let mut agg = default_aggregator();
        let (a, b) = agg.compare_strategies(
            &sample_tokens(),
            PoolingStrategy::Mean,
            PoolingStrategy::Max,
        );
        // Mean[0]=5, Max[0]=9
        assert!((a.as_ref().map(|r| r.vector[0]).unwrap_or(0.0) - 5.0).abs() < 1e-5);
        assert!((b.as_ref().map(|r| r.vector[0]).unwrap_or(0.0) - 9.0).abs() < 1e-5);
    }

    // --- total aggregations tracking ---

    #[test]
    fn test_total_aggregations_initially_zero() {
        let agg = default_aggregator();
        assert_eq!(agg.total_aggregations(), 0);
    }

    #[test]
    fn test_total_aggregations_increments() {
        let mut agg = default_aggregator();
        agg.aggregate(&sample_tokens());
        agg.aggregate(&sample_tokens());
        assert_eq!(agg.total_aggregations(), 2);
    }

    #[test]
    fn test_total_aggregations_batch_increments() {
        let mut agg = default_aggregator();
        let batch = vec![sample_tokens(), sample_tokens()];
        agg.aggregate_batch(&batch);
        // Each sequence in the batch counts as one aggregation.
        assert_eq!(agg.total_aggregations(), 2);
    }

    // --- strategy summary ---

    #[test]
    fn test_strategy_summary_counts() {
        let results = vec![
            AggregatedEmbedding {
                vector: vec![1.0],
                strategy: PoolingStrategy::Mean,
                token_count: 1,
            },
            AggregatedEmbedding {
                vector: vec![2.0],
                strategy: PoolingStrategy::Mean,
                token_count: 1,
            },
            AggregatedEmbedding {
                vector: vec![3.0],
                strategy: PoolingStrategy::Max,
                token_count: 1,
            },
        ];
        let summary = EmbeddingAggregator::strategy_summary(&results);
        assert_eq!(summary.get(&PoolingStrategy::Mean), Some(&2));
        assert_eq!(summary.get(&PoolingStrategy::Max), Some(&1));
        // Strategies with no results are absent rather than mapped to 0.
        assert_eq!(summary.get(&PoolingStrategy::Cls), None);
    }

    // --- cosine similarity ---

    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_empty() {
        let sim = cosine_similarity(&[], &[]);
        assert_eq!(sim, 0.0);
    }

    // --- config access ---

    #[test]
    fn test_config_accessor() {
        let agg = default_aggregator();
        assert_eq!(agg.config().default_strategy, PoolingStrategy::Mean);
        assert!(!agg.config().normalize_output);
    }

    #[test]
    fn test_aggregator_config_default() {
        let config = AggregatorConfig::default();
        assert_eq!(config.default_strategy, PoolingStrategy::Mean);
        assert!(!config.normalize_output);
        assert!(config.eps > 0.0);
    }
}