1use crate::Entity;
37use std::collections::{HashMap, HashSet};
38
/// Configuration for [`EntityFeatureExtractor`].
///
/// All window sizes are measured in characters, matching the
/// character-based offsets used throughout this module.
#[derive(Debug, Clone)]
pub struct ExtractorConfig {
    /// Number of characters captured on each side of a mention.
    pub context_window: usize,
    /// Maximum character gap between two mentions for them to count
    /// as co-occurring.
    pub cooccurrence_window: usize,
    /// Lowercase entity text when building grouping keys.
    pub normalize_text: bool,
    /// Minimum number of co-occurrence events required to keep a
    /// partner entity (a floor of 1 keeps everything).
    pub min_cooccurrence_freq: usize,
}

impl Default for ExtractorConfig {
    fn default() -> Self {
        Self {
            context_window: 100,
            cooccurrence_window: 150,
            normalize_text: true,
            min_cooccurrence_freq: 1,
        }
    }
}

impl ExtractorConfig {
    /// Builder-style setter for [`Self::context_window`].
    ///
    /// `#[must_use]` because the method consumes `self` and returns the
    /// updated value; ignoring the result silently discards the change
    /// (consistent with the crate's other `#[must_use]` accessors).
    #[must_use]
    pub fn with_context_window(mut self, window: usize) -> Self {
        self.context_window = window;
        self
    }

    /// Builder-style setter for [`Self::cooccurrence_window`].
    #[must_use]
    pub fn with_cooccurrence_window(mut self, window: usize) -> Self {
        self.cooccurrence_window = window;
        self
    }
}
80
/// Local context features for a single entity mention.
#[derive(Debug, Clone)]
pub struct MentionContext {
    /// The mention these features were extracted for (cloned).
    pub entity: Entity,
    /// Characters immediately before the mention (up to `context_window`).
    pub left_context: String,
    /// Characters immediately after the mention (up to `context_window`).
    pub right_context: String,
    /// Mention start as a fraction of document length (0.0 = document start).
    pub relative_position: f64,
    /// Mention start offset in characters, as given by the entity.
    pub absolute_position: usize,
    /// Index of the containing sentence; `extract` currently always
    /// leaves this as `None`.
    pub sentence_index: Option<usize>,
    /// Heuristic: the mention appears to open a sentence or clause.
    pub likely_subject: bool,
    /// Heuristic: the mention sits on a short, capitalized heading line.
    pub likely_heading: bool,
    /// Whitespace-separated token count of the mention text.
    pub word_count: usize,
    /// Character count of the mention text.
    pub char_count: usize,
    /// First character of the mention text is uppercase.
    pub is_capitalized: bool,
    /// No alphabetic character in the mention text is lowercase.
    pub is_all_caps: bool,
    /// Mention text contains at least one ASCII digit.
    pub contains_digits: bool,
}
115
116impl MentionContext {
117 pub fn extract(text: &str, entity: &Entity, config: &ExtractorConfig) -> Self {
119 let text_chars: Vec<char> = text.chars().collect();
120 let text_len = text_chars.len();
121
122 let left_start = entity.start.saturating_sub(config.context_window);
124 let left_end = entity.start;
125 let right_start = entity.end.min(text_len);
126 let right_end = (entity.end + config.context_window).min(text_len);
127
128 let left_context: String = text_chars[left_start..left_end].iter().collect();
129 let right_context: String = text_chars[right_start..right_end].iter().collect();
130
131 let relative_position = if text_len > 0 {
132 entity.start as f64 / text_len as f64
133 } else {
134 0.0
135 };
136
137 let likely_subject = {
140 let trimmed = left_context.trim_end();
141 trimmed.is_empty()
142 || trimmed.ends_with('.')
143 || trimmed.ends_with('!')
144 || trimmed.ends_with('?')
145 || trimmed.ends_with('\n')
146 || trimmed.len() < 50
147 };
148
149 let likely_heading = {
151 let line_start = left_context.rfind('\n').map(|i| i + 1).unwrap_or(0);
152 let line_end = right_context.find('\n').unwrap_or(right_context.len());
153 let line_len = (left_context.len() - line_start) + entity.text.len() + line_end;
154 line_len < 100
155 && entity
156 .text
157 .chars()
158 .next()
159 .map(|c| c.is_uppercase())
160 .unwrap_or(false)
161 };
162
163 let first_char = entity.text.chars().next();
164 let is_capitalized = first_char.map(|c| c.is_uppercase()).unwrap_or(false);
165 let is_all_caps = entity
166 .text
167 .chars()
168 .all(|c| !c.is_alphabetic() || c.is_uppercase());
169 let contains_digits = entity.text.chars().any(|c| c.is_ascii_digit());
170
171 Self {
172 entity: entity.clone(),
173 left_context,
174 right_context,
175 relative_position,
176 absolute_position: entity.start,
177 sentence_index: None, likely_subject,
179 likely_heading,
180 word_count: entity.text.split_whitespace().count(),
181 char_count: entity.text.chars().count(),
182 is_capitalized,
183 is_all_caps,
184 contains_digits,
185 }
186 }
187
188 pub fn full_context(&self) -> String {
190 format!(
191 "{}[{}]{}",
192 self.left_context, self.entity.text, self.right_context
193 )
194 }
195}
196
197pub use anno_core::MentionType;
207
/// Features computed over a coreference chain (all mentions of one entity).
#[derive(Debug, Clone)]
pub struct ChainFeatures {
    /// Representative surface form (longest proper-name mention, else
    /// the first mention).
    pub canonical_form: String,
    /// Distinct surface forms observed in the chain.
    pub variations: Vec<String>,
    /// Number of mentions in the chain.
    pub chain_length: usize,
    /// Type label taken from the first proper-name mention (else the
    /// first mention), if any.
    pub entity_type: Option<String>,

    /// Smallest mention start offset (characters).
    pub first_mention_position: usize,
    /// Largest mention end offset (characters).
    pub last_mention_position: usize,
    /// Character distance between first and last mention.
    pub mention_spread: usize,
    /// `mention_spread` as a fraction of document length.
    pub relative_spread: f64,

    /// Mentions classified as proper names.
    pub named_count: usize,
    /// Mentions classified as nominal (unknown types are folded in here).
    pub nominal_count: usize,
    /// Mentions classified as pronominal (zero anaphora is folded in here).
    pub pronominal_count: usize,
    /// `pronominal_count / chain_length`.
    pub pronoun_ratio: f64,

    /// Mean of the mention start offsets.
    pub mean_position: f64,
    /// Shannon entropy of mention positions over 10 equal-width bins;
    /// 0.0 for singleton chains or empty documents.
    pub positional_entropy: f64,
    /// Mean mention confidence.
    pub mean_confidence: f64,
    /// Minimum mention confidence (0.0 for an empty chain).
    pub min_confidence: f64,
    /// Maximum mention confidence (0.0 for an empty chain).
    pub max_confidence: f64,

    /// Optional aggregated embedding, attached via `with_centroid`.
    pub centroid_embedding: Option<Vec<f32>>,
}
256
257impl ChainFeatures {
258 pub fn from_mentions(mentions: &[&Entity], text_len: usize) -> Self {
260 if mentions.is_empty() {
261 return Self::empty();
262 }
263
264 let mut variations_set: HashSet<String> = HashSet::new();
266 for m in mentions {
267 variations_set.insert(m.text.clone());
268 }
269 let variations: Vec<String> = variations_set.into_iter().collect();
270
271 let canonical_form = mentions
273 .iter()
274 .filter(|m| MentionType::classify(&m.text) == MentionType::Proper)
275 .max_by_key(|m| m.text.len())
276 .map(|m| m.text.clone())
277 .unwrap_or_else(|| mentions[0].text.clone());
278
279 let first_pos = mentions.iter().map(|m| m.start).min().unwrap_or(0);
281 let last_pos = mentions.iter().map(|m| m.end).max().unwrap_or(0);
282 let spread = last_pos.saturating_sub(first_pos);
283 let relative_spread = if text_len > 0 {
284 spread as f64 / text_len as f64
285 } else {
286 0.0
287 };
288
289 let mut named_count = 0;
291 let mut nominal_count = 0;
292 let mut pronominal_count = 0;
293 for m in mentions {
294 match MentionType::classify(&m.text) {
295 MentionType::Proper => named_count += 1,
296 MentionType::Nominal => nominal_count += 1,
297 MentionType::Pronominal => pronominal_count += 1,
298 MentionType::Zero => pronominal_count += 1, MentionType::Unknown => nominal_count += 1, }
301 }
302 let total = mentions.len();
303 let pronoun_ratio = pronominal_count as f64 / total as f64;
304
305 let positions: Vec<f64> = mentions.iter().map(|m| m.start as f64).collect();
307 let mean_position = positions.iter().sum::<f64>() / total as f64;
308
309 let positional_entropy = if text_len > 0 && total > 1 {
311 let n_bins = 10;
312 let bin_size = text_len / n_bins;
313 let mut bins = vec![0usize; n_bins];
314 for m in mentions {
315 let bin = (m.start / bin_size.max(1)).min(n_bins - 1);
316 bins[bin] += 1;
317 }
318 let total_f = total as f64;
319 bins.iter()
320 .filter(|&&c| c > 0)
321 .map(|&c| {
322 let p = c as f64 / total_f;
323 -p * p.ln()
324 })
325 .sum()
326 } else {
327 0.0
328 };
329
330 let confidences: Vec<f64> = mentions.iter().map(|m| m.confidence).collect();
332 let mean_confidence = confidences.iter().sum::<f64>() / total as f64;
333 let min_confidence = confidences.iter().cloned().fold(f64::INFINITY, f64::min);
334 let max_confidence = confidences
335 .iter()
336 .cloned()
337 .fold(f64::NEG_INFINITY, f64::max);
338
339 let entity_type = mentions
341 .iter()
342 .find(|m| MentionType::classify(&m.text) == MentionType::Proper)
343 .or_else(|| mentions.first())
344 .map(|m| m.entity_type.as_label().to_string());
345
346 Self {
347 canonical_form,
348 variations,
349 chain_length: total,
350 entity_type,
351 first_mention_position: first_pos,
352 last_mention_position: last_pos,
353 mention_spread: spread,
354 relative_spread,
355 named_count,
356 nominal_count,
357 pronominal_count,
358 pronoun_ratio,
359 mean_position,
360 positional_entropy,
361 mean_confidence,
362 min_confidence,
363 max_confidence,
364 centroid_embedding: None,
365 }
366 }
367
368 fn empty() -> Self {
370 Self {
371 canonical_form: String::new(),
372 variations: Vec::new(),
373 chain_length: 0,
374 entity_type: None,
375 first_mention_position: 0,
376 last_mention_position: 0,
377 mention_spread: 0,
378 relative_spread: 0.0,
379 named_count: 0,
380 nominal_count: 0,
381 pronominal_count: 0,
382 pronoun_ratio: 0.0,
383 mean_position: 0.0,
384 positional_entropy: 0.0,
385 mean_confidence: 0.0,
386 min_confidence: 0.0,
387 max_confidence: 0.0,
388 centroid_embedding: None,
389 }
390 }
391
392 pub fn with_centroid(mut self, embedding: Vec<f32>) -> Self {
394 self.centroid_embedding = Some(embedding);
395 self
396 }
397
398 pub fn is_singleton(&self) -> bool {
400 self.chain_length == 1
401 }
402
403 pub fn is_mostly_pronominal(&self) -> bool {
405 self.pronoun_ratio > 0.5
406 }
407
408 #[must_use]
410 pub fn variation_count(&self) -> usize {
411 self.variations.len()
412 }
413}
414
/// Co-occurrence features for one entity, keyed by normalized surface form.
#[derive(Debug, Clone)]
pub struct CooccurrenceFeatures {
    /// Normalized key of the entity these features belong to.
    pub entity_key: String,
    /// Partner keys ordered most-frequent first (populated by `finalize`).
    pub cooccurring_entities: Vec<String>,
    /// Co-occurrence event count per partner key.
    pub cooccurrence_counts: HashMap<String, usize>,
    /// Total number of recorded co-occurrence events.
    pub total_cooccurrences: usize,
    /// Number of distinct partner keys (populated by `finalize`).
    pub unique_cooccurrences: usize,
    /// Entity-type labels observed per partner key — one entry per event
    /// that carried a type.
    pub cooccurring_types: HashMap<String, Vec<String>>,
}

impl CooccurrenceFeatures {
    /// Empty feature set for `entity_key`.
    pub fn new(entity_key: String) -> Self {
        Self {
            entity_key,
            cooccurring_entities: Vec::new(),
            cooccurrence_counts: HashMap::new(),
            total_cooccurrences: 0,
            unique_cooccurrences: 0,
            cooccurring_types: HashMap::new(),
        }
    }

    /// Record one co-occurrence with `other_key`, optionally tagged with
    /// that entity's type label.
    pub fn add_cooccurrence(&mut self, other_key: &str, other_type: Option<&str>) {
        *self
            .cooccurrence_counts
            .entry(other_key.to_string())
            .or_insert(0) += 1;
        self.total_cooccurrences += 1;

        if let Some(t) = other_type {
            self.cooccurring_types
                .entry(other_key.to_string())
                .or_default()
                .push(t.to_string());
        }
    }

    /// Derive `cooccurring_entities` and `unique_cooccurrences` from the
    /// accumulated counts.
    ///
    /// Partners are ordered by descending count with ties broken
    /// lexicographically, so the result is deterministic. (The previous
    /// version left equal-count keys in HashMap iteration order, which
    /// varies between runs.)
    pub fn finalize(&mut self) {
        let mut keys: Vec<String> = self.cooccurrence_counts.keys().cloned().collect();
        keys.sort_by(|a, b| {
            self.cooccurrence_counts
                .get(b)
                .cmp(&self.cooccurrence_counts.get(a))
                .then_with(|| a.cmp(b))
        });
        self.unique_cooccurrences = keys.len();
        self.cooccurring_entities = keys;
    }

    /// Top `k` partners as `(key, count)`, most frequent first.
    /// Only meaningful after `finalize` has been called.
    pub fn top_k(&self, k: usize) -> Vec<(&str, usize)> {
        self.cooccurring_entities
            .iter()
            .take(k)
            .filter_map(|e| self.cooccurrence_counts.get(e).map(|&c| (e.as_str(), c)))
            .collect()
    }
}
485
/// Bundle of every feature group extracted from one document.
#[derive(Debug, Clone)]
pub struct DocumentFeatures {
    /// Per-mention local context features, in input order.
    pub mention_contexts: Vec<MentionContext>,
    /// Chain features keyed by normalized entity key.
    pub chain_features: HashMap<String, ChainFeatures>,
    /// Co-occurrence features keyed by normalized entity key.
    pub cooccurrence: HashMap<String, CooccurrenceFeatures>,
    /// Document-level statistics.
    pub document_stats: DocumentStats,
}
502
/// Aggregate statistics over one document.
#[derive(Debug, Clone)]
pub struct DocumentStats {
    /// Document length in characters.
    pub char_count: usize,
    /// Whitespace-separated token count.
    pub word_count: usize,
    /// Total number of entity mentions.
    pub mention_count: usize,
    /// Number of distinct normalized entity keys.
    pub unique_entity_count: usize,
    /// Mentions per 1000 characters.
    pub entity_density: f64,
    /// Mention count per entity-type label.
    pub type_distribution: HashMap<String, usize>,
}
519
/// Extracts mention, chain, co-occurrence, and document-level features
/// from a text plus its recognized entities.
#[derive(Debug, Clone)]
pub struct EntityFeatureExtractor {
    // Extraction parameters (context/co-occurrence windows, key
    // normalization, frequency threshold).
    config: ExtractorConfig,
}

impl Default for EntityFeatureExtractor {
    // Uses the default `ExtractorConfig` values.
    fn default() -> Self {
        Self::new(ExtractorConfig::default())
    }
}
540
541impl EntityFeatureExtractor {
542 pub fn new(config: ExtractorConfig) -> Self {
544 Self { config }
545 }
546
547 pub fn extract_all(&self, text: &str, entities: &[Entity]) -> DocumentFeatures {
549 let text_len = text.chars().count();
550
551 let mention_contexts: Vec<MentionContext> = entities
553 .iter()
554 .map(|e| MentionContext::extract(text, e, &self.config))
555 .collect();
556
557 let groups = self.group_entities(entities);
559
560 let chain_features: HashMap<String, ChainFeatures> = groups
562 .iter()
563 .map(|(key, mentions)| {
564 let refs: Vec<&Entity> = mentions.to_vec();
565 (key.clone(), ChainFeatures::from_mentions(&refs, text_len))
566 })
567 .collect();
568
569 let cooccurrence = self.extract_cooccurrence(entities);
571
572 let word_count = text.split_whitespace().count();
574 let unique_entity_count = groups.len();
575 let entity_density = if text_len > 0 {
576 (entities.len() as f64 / text_len as f64) * 1000.0
577 } else {
578 0.0
579 };
580
581 let mut type_distribution: HashMap<String, usize> = HashMap::new();
582 for e in entities {
583 *type_distribution
584 .entry(e.entity_type.as_label().to_string())
585 .or_insert(0) += 1;
586 }
587
588 let document_stats = DocumentStats {
589 char_count: text_len,
590 word_count,
591 mention_count: entities.len(),
592 unique_entity_count,
593 entity_density,
594 type_distribution,
595 };
596
597 DocumentFeatures {
598 mention_contexts,
599 chain_features,
600 cooccurrence,
601 document_stats,
602 }
603 }
604
605 pub fn extract_mentions(&self, text: &str, entities: &[Entity]) -> Vec<MentionContext> {
607 entities
608 .iter()
609 .map(|e| MentionContext::extract(text, e, &self.config))
610 .collect()
611 }
612
613 pub fn extract_chains(
615 &self,
616 text: &str,
617 entities: &[Entity],
618 ) -> HashMap<String, ChainFeatures> {
619 let text_len = text.chars().count();
620 let groups = self.group_entities(entities);
621
622 groups
623 .iter()
624 .map(|(key, mentions)| {
625 let refs: Vec<&Entity> = mentions.to_vec();
626 (key.clone(), ChainFeatures::from_mentions(&refs, text_len))
627 })
628 .collect()
629 }
630
631 pub fn extract_cooccurrence(
633 &self,
634 entities: &[Entity],
635 ) -> HashMap<String, CooccurrenceFeatures> {
636 let mut result: HashMap<String, CooccurrenceFeatures> = HashMap::new();
637
638 for e in entities {
640 let key = self.normalize_key(&e.text);
641 result
642 .entry(key.clone())
643 .or_insert_with(|| CooccurrenceFeatures::new(key));
644 }
645
646 for (i, e1) in entities.iter().enumerate() {
648 let key1 = self.normalize_key(&e1.text);
649
650 for e2 in entities.iter().skip(i + 1) {
651 let key2 = self.normalize_key(&e2.text);
652
653 if key1 == key2 {
655 continue;
656 }
657
658 let distance = if e1.end <= e2.start {
660 e2.start - e1.end
661 } else if e2.end <= e1.start {
662 e1.start.saturating_sub(e2.end)
663 } else {
664 0 };
666
667 if distance <= self.config.cooccurrence_window {
668 if let Some(f) = result.get_mut(&key1) {
669 f.add_cooccurrence(&key2, Some(e2.entity_type.as_label()));
670 }
671 if let Some(f) = result.get_mut(&key2) {
672 f.add_cooccurrence(&key1, Some(e1.entity_type.as_label()));
673 }
674 }
675 }
676 }
677
678 for f in result.values_mut() {
680 f.finalize();
681 }
682
683 result
684 }
685
686 fn group_entities<'a>(&self, entities: &'a [Entity]) -> HashMap<String, Vec<&'a Entity>> {
688 let mut groups: HashMap<String, Vec<&'a Entity>> = HashMap::new();
689 for e in entities {
690 let key = self.normalize_key(&e.text);
691 groups.entry(key).or_default().push(e);
692 }
693 groups
694 }
695
696 fn normalize_key(&self, text: &str) -> String {
698 if self.config.normalize_text {
699 text.to_lowercase().trim().to_string()
700 } else {
701 text.trim().to_string()
702 }
703 }
704}
705
/// Features computed over an ordered pair of mentions (used for
/// coreference-style pairwise scoring).
#[derive(Debug, Clone)]
pub struct PairwiseFeatures {
    /// Character gap between the two spans (0 when they overlap).
    pub char_distance: usize,
    /// Index gap between the two mentions, supplied by the caller.
    pub mention_distance: usize,
    /// Surface forms are byte-identical.
    pub exact_match: bool,
    /// Surface forms match after lowercasing.
    pub case_insensitive_match: bool,
    /// Token-level Jaccard similarity of the two surface forms.
    pub string_similarity: f64,
    /// The two mentions share the same entity type.
    pub type_match: bool,
    /// Mention-type classification of the first mention.
    pub mention_type_a: MentionType,
    /// Mention-type classification of the second mention.
    pub mention_type_b: MentionType,
    /// Second mention is a pronoun following a non-pronoun — a likely
    /// pronominal anaphora candidate.
    pub is_pronominal_anaphora: bool,
}
732
733impl PairwiseFeatures {
734 pub fn compute(a: &Entity, b: &Entity, mention_distance: usize) -> Self {
736 let char_distance = if a.end <= b.start {
737 b.start - a.end
738 } else if b.end <= a.start {
739 a.start.saturating_sub(b.end)
740 } else {
741 0
742 };
743
744 let exact_match = a.text == b.text;
745 let case_insensitive_match = a.text.to_lowercase() == b.text.to_lowercase();
746
747 let words_a: HashSet<&str> = a.text.split_whitespace().collect();
749 let words_b: HashSet<&str> = b.text.split_whitespace().collect();
750 let intersection = words_a.intersection(&words_b).count();
751 let union = words_a.union(&words_b).count();
752 let string_similarity = if union > 0 {
753 intersection as f64 / union as f64
754 } else {
755 0.0
756 };
757
758 let type_match = a.entity_type == b.entity_type;
759
760 let mention_type_a = MentionType::classify(&a.text);
761 let mention_type_b = MentionType::classify(&b.text);
762
763 let is_pronominal_anaphora = mention_type_b == MentionType::Pronominal
765 && mention_type_a != MentionType::Pronominal
766 && b.start > a.start;
767
768 Self {
769 char_distance,
770 mention_distance,
771 exact_match,
772 case_insensitive_match,
773 string_similarity,
774 type_match,
775 mention_type_a,
776 mention_type_b,
777 is_pronominal_anaphora,
778 }
779 }
780
781 pub fn compute_all_pairs(entities: &[Entity]) -> Vec<(usize, usize, PairwiseFeatures)> {
783 let mut pairs = Vec::new();
784 for (i, a) in entities.iter().enumerate() {
785 for (j, b) in entities.iter().enumerate().skip(i + 1) {
786 let mention_distance = j - i;
787 let features = Self::compute(a, b, mention_distance);
788 pairs.push((i, j, features));
789 }
790 }
791 pairs
792 }
793}
794
/// Combine a batch of embedding vectors into a single vector.
///
/// Returns `None` for an empty batch, zero-dimensional vectors,
/// inconsistent dimensions, or (for [`AggregationMethod::WeightedMean`])
/// a weight list that does not match the batch length or sums to zero.
pub fn aggregate_embeddings(
    embeddings: &[Vec<f32>],
    method: AggregationMethod,
) -> Option<Vec<f32>> {
    // Dimensionality is taken from the first vector; `?` rejects the
    // empty batch, and the scan rejects ragged or zero-width input.
    let dim = embeddings.first()?.len();
    if dim == 0 || embeddings.iter().any(|e| e.len() != dim) {
        return None;
    }

    match method {
        AggregationMethod::Mean => {
            // Element-wise sum, then divide by the batch size.
            let mut acc = vec![0.0f32; dim];
            for emb in embeddings {
                for (slot, &x) in acc.iter_mut().zip(emb) {
                    *slot += x;
                }
            }
            let count = embeddings.len() as f32;
            for slot in &mut acc {
                *slot /= count;
            }
            Some(acc)
        }
        AggregationMethod::Max => {
            // Element-wise maximum, seeded with -inf.
            let mut acc = vec![f32::NEG_INFINITY; dim];
            for emb in embeddings {
                for (slot, &x) in acc.iter_mut().zip(emb) {
                    *slot = slot.max(x);
                }
            }
            Some(acc)
        }
        AggregationMethod::First => embeddings.first().cloned(),
        AggregationMethod::WeightedMean { ref weights } => {
            if weights.len() != embeddings.len() {
                return None;
            }
            let total: f32 = weights.iter().sum();
            if total == 0.0 {
                return None;
            }
            // Weighted element-wise sum, normalized by the weight total.
            let mut acc = vec![0.0f32; dim];
            for (emb, &w) in embeddings.iter().zip(weights) {
                for (slot, &x) in acc.iter_mut().zip(emb) {
                    *slot += x * w;
                }
            }
            for slot in &mut acc {
                *slot /= total;
            }
            Some(acc)
        }
    }
}

/// How [`aggregate_embeddings`] combines a batch of vectors.
#[derive(Debug, Clone, Default)]
pub enum AggregationMethod {
    /// Element-wise arithmetic mean (the default).
    #[default]
    Mean,
    /// Element-wise maximum.
    Max,
    /// The first vector, unchanged.
    First,
    /// Mean weighted per vector; `weights` must match the batch length
    /// and must not sum to zero.
    WeightedMean {
        weights: Vec<f32>,
    },
}
880
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityType;

    // Shared fixture: five mentions with character offsets — two "Obama"
    // surface forms, the pronoun "He", plus "Angela Merkel" and "Berlin".
    fn sample_entities() -> Vec<Entity> {
        vec![
            Entity::new("Barack Obama", EntityType::Person, 0, 12, 0.95),
            Entity::new("Angela Merkel", EntityType::Person, 17, 30, 0.92),
            Entity::new("Berlin", EntityType::Location, 34, 40, 0.88),
            Entity::new("He", EntityType::Person, 42, 44, 0.85),
            Entity::new("Obama", EntityType::Person, 60, 65, 0.90),
        ]
    }

    // MentionType::classify comes from anno_core (re-exported above).
    #[test]
    fn test_mention_type_classification() {
        assert_eq!(MentionType::classify("he"), MentionType::Pronominal);
        assert_eq!(MentionType::classify("She"), MentionType::Pronominal);
        assert_eq!(MentionType::classify("Barack Obama"), MentionType::Proper);
        assert_eq!(MentionType::classify("the president"), MentionType::Nominal);
        assert_eq!(MentionType::classify("Apple Inc."), MentionType::Proper);
    }

    #[test]
    fn test_mention_context_extraction() {
        let text = "In Paris, Barack Obama met Angela Merkel. He discussed policy.";
        let entity = Entity::new("Barack Obama", EntityType::Person, 10, 22, 0.95);

        let ctx = MentionContext::extract(text, &entity, &ExtractorConfig::default());

        assert_eq!(ctx.entity.text, "Barack Obama");
        assert!(ctx.left_context.contains("Paris"));
        assert!(ctx.right_context.contains("met"));
        // Mention sits in the first half of the text.
        assert!(ctx.relative_position < 0.5);
        assert!(ctx.is_capitalized);
    }

    #[test]
    fn test_chain_features() {
        let entities = sample_entities();
        let text_len = 100;

        // Chain: "Barack Obama", "Obama", and the pronoun "He".
        let obama_mentions: Vec<&Entity> = entities
            .iter()
            .filter(|e| e.text.to_lowercase().contains("obama") || e.text.to_lowercase() == "he")
            .collect();

        let features = ChainFeatures::from_mentions(&obama_mentions, text_len);

        assert_eq!(features.chain_length, 3);
        assert!(features.variations.contains(&"Barack Obama".to_string()));
        assert!(features.pronominal_count >= 1);
        assert!(features.named_count >= 1);
        // Three distinct surface forms.
        assert_eq!(features.variation_count(), 3);
        assert!(!features.is_singleton());
    }

    #[test]
    fn test_cooccurrence_extraction() {
        let _text = "Barack Obama met Angela Merkel in Berlin. He discussed policy.";
        let entities = sample_entities();

        let extractor = EntityFeatureExtractor::default();
        let cooc = extractor.extract_cooccurrence(&entities);

        // Keys are normalized (lowercased) surface forms.
        let obama_cooc = cooc.get("barack obama").unwrap();
        assert!(obama_cooc
            .cooccurring_entities
            .contains(&"angela merkel".to_string()));
        assert!(obama_cooc
            .cooccurring_entities
            .contains(&"berlin".to_string()));
    }

    #[test]
    fn test_pairwise_features() {
        let a = Entity::new("Barack Obama", EntityType::Person, 0, 12, 0.95);
        let b = Entity::new("Obama", EntityType::Person, 50, 55, 0.90);
        let c = Entity::new("He", EntityType::Person, 60, 62, 0.85);

        // "Obama" shares a token with "Barack Obama".
        let ab = PairwiseFeatures::compute(&a, &b, 1);
        assert!(ab.case_insensitive_match || ab.string_similarity > 0.0);
        assert!(ab.type_match);

        // Pronoun after a proper name → anaphora candidate.
        let ac = PairwiseFeatures::compute(&a, &c, 2);
        assert!(ac.is_pronominal_anaphora);
    }

    #[test]
    fn test_full_extraction() {
        let text = "Barack Obama met Angela Merkel in Berlin. He discussed policy with her.";
        let entities = sample_entities();

        let extractor = EntityFeatureExtractor::default();
        let features = extractor.extract_all(text, &entities);

        assert_eq!(features.mention_contexts.len(), entities.len());
        assert!(!features.chain_features.is_empty());
        assert!(!features.cooccurrence.is_empty());
        assert!(features.document_stats.mention_count == entities.len());
    }

    #[test]
    fn test_aggregate_embeddings() {
        let emb1 = vec![1.0, 2.0, 3.0];
        let emb2 = vec![2.0, 4.0, 6.0];
        let embeddings = vec![emb1, emb2];

        let mean = aggregate_embeddings(&embeddings, AggregationMethod::Mean).unwrap();
        assert_eq!(mean, vec![1.5, 3.0, 4.5]);

        let max = aggregate_embeddings(&embeddings, AggregationMethod::Max).unwrap();
        assert_eq!(max, vec![2.0, 4.0, 6.0]);
    }
}