Skip to main content

anno/
features.rs

1//! Entity feature extraction for downstream ML and analysis.
2//!
3//! This module provides comprehensive feature extraction for entities at multiple
4//! levels of granularity:
5//!
6//! - **Mention-level**: Context windows, position, syntactic role
7//! - **Chain/Track-level**: Aggregate statistics across coreference chains
8//! - **Co-occurrence**: Which entities appear together
9//! - **Document-level**: Cross-entity patterns
10//!
11//! # Use Cases
12//!
13//! - Training coreference models (pairwise features)
14//! - Entity classification/linking (mention context)
15//! - Knowledge graph construction (co-occurrence patterns)
16//! - Entity salience prediction (aggregate features)
17//!
18//! # Example
19//!
20//! ```rust,ignore
21//! use anno::features::{EntityFeatureExtractor, ExtractorConfig};
22//! use anno::{Model, StackedNER};
23//!
24//! let text = "Barack Obama met Angela Merkel in Berlin. He discussed policy with her.";
25//! let ner = StackedNER::default();
26//! let entities = ner.extract_entities(text, None)?;
27//!
28//! let extractor = EntityFeatureExtractor::new(ExtractorConfig::default());
29//! let features = extractor.extract_all(text, &entities);
30//!
31//! // Get co-occurring entities for "Obama"
32//! let obama_cooc = features.cooccurrence.get("barack obama").unwrap();
33//! assert!(obama_cooc.cooccurring_entities.contains(&"angela merkel".to_string()));
34//! ```
35
36use crate::Entity;
37use std::collections::{HashMap, HashSet};
38
39// =============================================================================
40// Configuration
41// =============================================================================
42
/// Tunable parameters controlling how entity features are extracted.
#[derive(Debug, Clone)]
pub struct ExtractorConfig {
    /// Number of characters captured on each side of a mention as context.
    pub context_window: usize,
    /// Window size (in characters) used when deciding whether two mentions co-occur.
    pub cooccurrence_window: usize,
    /// When `true`, entity text is lowercased before being used as a grouping key.
    pub normalize_text: bool,
    /// Minimum frequency for an entity to be included in co-occurrence.
    pub min_cooccurrence_freq: usize,
}

impl ExtractorConfig {
    /// Context window (characters per side) used by [`Default`].
    const DEFAULT_CONTEXT_WINDOW: usize = 100;
    /// Co-occurrence window (characters) used by [`Default`].
    const DEFAULT_COOCCURRENCE_WINDOW: usize = 150;

    /// Return this configuration with `context_window` replaced by `window`.
    pub fn with_context_window(self, window: usize) -> Self {
        Self {
            context_window: window,
            ..self
        }
    }

    /// Return this configuration with `cooccurrence_window` replaced by `window`.
    pub fn with_cooccurrence_window(self, window: usize) -> Self {
        Self {
            cooccurrence_window: window,
            ..self
        }
    }
}

impl Default for ExtractorConfig {
    /// 100-character context, 150-character co-occurrence window, lowercase
    /// normalization enabled, and a minimum co-occurrence frequency of 1.
    fn default() -> Self {
        Self {
            context_window: Self::DEFAULT_CONTEXT_WINDOW,
            cooccurrence_window: Self::DEFAULT_COOCCURRENCE_WINDOW,
            normalize_text: true,
            min_cooccurrence_freq: 1,
        }
    }
}
80
81// =============================================================================
82// Mention-Level Features
83// =============================================================================
84
/// Context and surface features for a single entity mention.
///
/// Instances are produced by `MentionContext::extract`; all positions are
/// expressed in characters, not bytes.
#[derive(Debug, Clone)]
pub struct MentionContext {
    /// The entity mention itself (a clone of the input entity).
    pub entity: Entity,
    /// Text immediately before the mention (at most `context_window` characters).
    pub left_context: String,
    /// Text immediately after the mention (at most `context_window` characters).
    pub right_context: String,
    /// Start offset as a fraction of the document's character length
    /// (0.0 = start, 1.0 = end); 0.0 for an empty document.
    pub relative_position: f64,
    /// Character offset of the mention start (`entity.start`).
    pub absolute_position: usize,
    /// Sentence index, if sentence boundaries were detected.
    /// `extract` performs no sentence segmentation and always leaves this `None`.
    pub sentence_index: Option<usize>,
    /// Is this likely in subject position? (heuristic: near sentence start).
    pub likely_subject: bool,
    /// Is this in a heading or title? (heuristic: the surrounding line is short
    /// and the mention starts with an uppercase letter).
    pub likely_heading: bool,
    /// Number of whitespace-separated words in the mention text.
    pub word_count: usize,
    /// Number of characters in the mention text.
    pub char_count: usize,
    /// Does the mention start with an uppercase letter?
    pub is_capitalized: bool,
    /// Is every alphabetic character of the mention uppercase?
    /// (Vacuously `true` when the mention has no alphabetic characters.)
    pub is_all_caps: bool,
    /// Does the mention contain at least one ASCII digit?
    pub contains_digits: bool,
}
115
116impl MentionContext {
117    /// Extract mention context from text and entity.
118    pub fn extract(text: &str, entity: &Entity, config: &ExtractorConfig) -> Self {
119        let text_chars: Vec<char> = text.chars().collect();
120        let text_len = text_chars.len();
121
122        // Safe bounds for context extraction
123        let left_start = entity.start.saturating_sub(config.context_window);
124        let left_end = entity.start;
125        let right_start = entity.end.min(text_len);
126        let right_end = (entity.end + config.context_window).min(text_len);
127
128        let left_context: String = text_chars[left_start..left_end].iter().collect();
129        let right_context: String = text_chars[right_start..right_end].iter().collect();
130
131        let relative_position = if text_len > 0 {
132            entity.start as f64 / text_len as f64
133        } else {
134            0.0
135        };
136
137        // Heuristic: likely subject if within first 50 chars of a sentence
138        // (simplified: check if near a period/newline in left context)
139        let likely_subject = {
140            let trimmed = left_context.trim_end();
141            trimmed.is_empty()
142                || trimmed.ends_with('.')
143                || trimmed.ends_with('!')
144                || trimmed.ends_with('?')
145                || trimmed.ends_with('\n')
146                || trimmed.len() < 50
147        };
148
149        // Heuristic: likely heading if the line is short and mostly capitalized
150        let likely_heading = {
151            let line_start = left_context.rfind('\n').map(|i| i + 1).unwrap_or(0);
152            let line_end = right_context.find('\n').unwrap_or(right_context.len());
153            let line_len = (left_context.len() - line_start) + entity.text.len() + line_end;
154            line_len < 100
155                && entity
156                    .text
157                    .chars()
158                    .next()
159                    .map(|c| c.is_uppercase())
160                    .unwrap_or(false)
161        };
162
163        let first_char = entity.text.chars().next();
164        let is_capitalized = first_char.map(|c| c.is_uppercase()).unwrap_or(false);
165        let is_all_caps = entity
166            .text
167            .chars()
168            .all(|c| !c.is_alphabetic() || c.is_uppercase());
169        let contains_digits = entity.text.chars().any(|c| c.is_ascii_digit());
170
171        Self {
172            entity: entity.clone(),
173            left_context,
174            right_context,
175            relative_position,
176            absolute_position: entity.start,
177            sentence_index: None, // Would need sentence segmentation
178            likely_subject,
179            likely_heading,
180            word_count: entity.text.split_whitespace().count(),
181            char_count: entity.text.chars().count(),
182            is_capitalized,
183            is_all_caps,
184            contains_digits,
185        }
186    }
187
188    /// Get the full context string (left + entity + right).
189    pub fn full_context(&self) -> String {
190        format!(
191            "{}[{}]{}",
192            self.left_context, self.entity.text, self.right_context
193        )
194    }
195}
196
197// =============================================================================
198// Chain/Track-Level Features
199// =============================================================================
200
201// Re-export the canonical MentionType from anno_core.
202// This unifies the type system across the anno ecosystem.
203//
204// Note: The canonical type uses `Proper` instead of `Named`. For compatibility,
205// use `MentionType::NAMED` constant or `MentionType::is_named()` method.
206pub use anno_core::MentionType;
207
/// Aggregate features for a coreference chain (group of mentions referring to same entity).
///
/// Built by `ChainFeatures::from_mentions`; an empty mention list yields a value
/// whose counts and statistics are all zero.
#[derive(Debug, Clone)]
pub struct ChainFeatures {
    /// Canonical/representative surface form: the longest mention classified as
    /// `MentionType::Proper`, or the first mention when none is proper.
    pub canonical_form: String,
    /// All distinct surface forms observed in the chain.
    pub variations: Vec<String>,
    /// Number of mentions in the chain.
    pub chain_length: usize,
    /// Entity type label, taken from the first proper mention (or the first
    /// mention when none is proper). `None` for an empty chain.
    pub entity_type: Option<String>,

    // Positional features
    /// Smallest start offset among the chain's mentions (character offset).
    pub first_mention_position: usize,
    /// Largest end offset among the chain's mentions (character offset).
    pub last_mention_position: usize,
    /// Spread: `last_mention_position - first_mention_position`.
    pub mention_spread: usize,
    /// Spread as fraction of document length (0.0 when the document length is 0).
    pub relative_spread: f64,

    // Type distribution
    /// Count of named (`Proper`) mentions.
    pub named_count: usize,
    /// Count of nominal mentions (mentions classified `Unknown` are counted here too).
    pub nominal_count: usize,
    /// Count of pronominal mentions (`Zero` mentions are counted here too).
    pub pronominal_count: usize,
    /// Fraction of mentions that are pronominal (`pronominal_count / chain_length`).
    pub pronoun_ratio: f64,

    // Statistical features
    /// Mean start offset of the mentions.
    pub mean_position: f64,
    /// Positional entropy (how spread out are mentions?): entropy, using the
    /// natural logarithm, of the mention start offsets distributed over 10
    /// position bins. 0.0 for fewer than two mentions or an empty document.
    pub positional_entropy: f64,
    /// Mean confidence across mentions.
    pub mean_confidence: f64,
    /// Min confidence.
    pub min_confidence: f64,
    /// Max confidence.
    pub max_confidence: f64,

    // Aggregate embedding (if available)
    /// Centroid embedding (mean of mention embeddings, if available).
    /// `from_mentions` leaves this `None`; it is set through `with_centroid`.
    pub centroid_embedding: Option<Vec<f32>>,
}
256
257impl ChainFeatures {
258    /// Compute chain features from a group of related entity mentions.
259    pub fn from_mentions(mentions: &[&Entity], text_len: usize) -> Self {
260        if mentions.is_empty() {
261            return Self::empty();
262        }
263
264        // Collect variations
265        let mut variations_set: HashSet<String> = HashSet::new();
266        for m in mentions {
267            variations_set.insert(m.text.clone());
268        }
269        let variations: Vec<String> = variations_set.into_iter().collect();
270
271        // Find canonical form (longest named mention, or first)
272        let canonical_form = mentions
273            .iter()
274            .filter(|m| MentionType::classify(&m.text) == MentionType::Proper)
275            .max_by_key(|m| m.text.len())
276            .map(|m| m.text.clone())
277            .unwrap_or_else(|| mentions[0].text.clone());
278
279        // Positions
280        let first_pos = mentions.iter().map(|m| m.start).min().unwrap_or(0);
281        let last_pos = mentions.iter().map(|m| m.end).max().unwrap_or(0);
282        let spread = last_pos.saturating_sub(first_pos);
283        let relative_spread = if text_len > 0 {
284            spread as f64 / text_len as f64
285        } else {
286            0.0
287        };
288
289        // Type distribution
290        let mut named_count = 0;
291        let mut nominal_count = 0;
292        let mut pronominal_count = 0;
293        for m in mentions {
294            match MentionType::classify(&m.text) {
295                MentionType::Proper => named_count += 1,
296                MentionType::Nominal => nominal_count += 1,
297                MentionType::Pronominal => pronominal_count += 1,
298                MentionType::Zero => pronominal_count += 1, // Treat zeros like pronouns
299                MentionType::Unknown => nominal_count += 1, // Conservative default
300            }
301        }
302        let total = mentions.len();
303        let pronoun_ratio = pronominal_count as f64 / total as f64;
304
305        // Statistical features
306        let positions: Vec<f64> = mentions.iter().map(|m| m.start as f64).collect();
307        let mean_position = positions.iter().sum::<f64>() / total as f64;
308
309        // Positional entropy: how spread out?
310        let positional_entropy = if text_len > 0 && total > 1 {
311            let n_bins = 10;
312            let bin_size = text_len / n_bins;
313            let mut bins = vec![0usize; n_bins];
314            for m in mentions {
315                let bin = (m.start / bin_size.max(1)).min(n_bins - 1);
316                bins[bin] += 1;
317            }
318            let total_f = total as f64;
319            bins.iter()
320                .filter(|&&c| c > 0)
321                .map(|&c| {
322                    let p = c as f64 / total_f;
323                    -p * p.ln()
324                })
325                .sum()
326        } else {
327            0.0
328        };
329
330        // Confidence stats
331        let confidences: Vec<f64> = mentions.iter().map(|m| m.confidence).collect();
332        let mean_confidence = confidences.iter().sum::<f64>() / total as f64;
333        let min_confidence = confidences.iter().cloned().fold(f64::INFINITY, f64::min);
334        let max_confidence = confidences
335            .iter()
336            .cloned()
337            .fold(f64::NEG_INFINITY, f64::max);
338
339        // Entity type from first named mention or first mention
340        let entity_type = mentions
341            .iter()
342            .find(|m| MentionType::classify(&m.text) == MentionType::Proper)
343            .or_else(|| mentions.first())
344            .map(|m| m.entity_type.as_label().to_string());
345
346        Self {
347            canonical_form,
348            variations,
349            chain_length: total,
350            entity_type,
351            first_mention_position: first_pos,
352            last_mention_position: last_pos,
353            mention_spread: spread,
354            relative_spread,
355            named_count,
356            nominal_count,
357            pronominal_count,
358            pronoun_ratio,
359            mean_position,
360            positional_entropy,
361            mean_confidence,
362            min_confidence,
363            max_confidence,
364            centroid_embedding: None,
365        }
366    }
367
368    /// Create empty chain features.
369    fn empty() -> Self {
370        Self {
371            canonical_form: String::new(),
372            variations: Vec::new(),
373            chain_length: 0,
374            entity_type: None,
375            first_mention_position: 0,
376            last_mention_position: 0,
377            mention_spread: 0,
378            relative_spread: 0.0,
379            named_count: 0,
380            nominal_count: 0,
381            pronominal_count: 0,
382            pronoun_ratio: 0.0,
383            mean_position: 0.0,
384            positional_entropy: 0.0,
385            mean_confidence: 0.0,
386            min_confidence: 0.0,
387            max_confidence: 0.0,
388            centroid_embedding: None,
389        }
390    }
391
392    /// Set the centroid embedding.
393    pub fn with_centroid(mut self, embedding: Vec<f32>) -> Self {
394        self.centroid_embedding = Some(embedding);
395        self
396    }
397
398    /// Is this a singleton chain (single mention)?
399    pub fn is_singleton(&self) -> bool {
400        self.chain_length == 1
401    }
402
403    /// Is this chain mostly pronominal?
404    pub fn is_mostly_pronominal(&self) -> bool {
405        self.pronoun_ratio > 0.5
406    }
407
408    /// Number of unique surface form variations.
409    #[must_use]
410    pub fn variation_count(&self) -> usize {
411        self.variations.len()
412    }
413}
414
415// =============================================================================
416// Co-occurrence Features
417// =============================================================================
418
/// Co-occurrence features for an entity.
///
/// Populate with [`CooccurrenceFeatures::add_cooccurrence`], then call
/// [`CooccurrenceFeatures::finalize`] to derive the ranked entity list and the
/// unique count.
#[derive(Debug, Clone)]
pub struct CooccurrenceFeatures {
    /// The entity (normalized key).
    pub entity_key: String,
    /// Entities that co-occur within the window, ordered by descending count
    /// and then by key. Only valid after `finalize`.
    pub cooccurring_entities: Vec<String>,
    /// Co-occurrence counts per entity.
    pub cooccurrence_counts: HashMap<String, usize>,
    /// Total co-occurrence count (sum of all).
    pub total_cooccurrences: usize,
    /// Unique co-occurring entity count. Only valid after `finalize`.
    pub unique_cooccurrences: usize,
    /// Entity types of co-occurring entities. One label is appended per recorded
    /// co-occurrence that supplied a type, so labels may repeat.
    pub cooccurring_types: HashMap<String, Vec<String>>,
}

impl CooccurrenceFeatures {
    /// Create new, empty co-occurrence features for the entity `entity_key`.
    pub fn new(entity_key: String) -> Self {
        Self {
            entity_key,
            cooccurring_entities: Vec::new(),
            cooccurrence_counts: HashMap::new(),
            total_cooccurrences: 0,
            unique_cooccurrences: 0,
            cooccurring_types: HashMap::new(),
        }
    }

    /// Record one co-occurrence with `other_key`, optionally noting its type.
    ///
    /// Updates the per-entity and total counts; the derived fields are only
    /// refreshed by [`CooccurrenceFeatures::finalize`].
    pub fn add_cooccurrence(&mut self, other_key: &str, other_type: Option<&str>) {
        *self
            .cooccurrence_counts
            .entry(other_key.to_string())
            .or_insert(0) += 1;
        self.total_cooccurrences += 1;

        if let Some(t) = other_type {
            self.cooccurring_types
                .entry(other_key.to_string())
                .or_default()
                .push(t.to_string());
        }
    }

    /// Finalize the features: rebuild the ranked list of co-occurring entities
    /// and the unique count from the accumulated counts.
    ///
    /// Entities are ordered by descending count; ties are broken by key so the
    /// result (and therefore [`CooccurrenceFeatures::top_k`]) does not depend on
    /// hash-map iteration order.
    pub fn finalize(&mut self) {
        let mut ranked: Vec<(&String, usize)> = self
            .cooccurrence_counts
            .iter()
            .map(|(key, &count)| (key, count))
            .collect();
        // Keys are unique, so the comparator is a total order and an unstable
        // sort is deterministic.
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

        self.cooccurring_entities = ranked.into_iter().map(|(key, _)| key.clone()).collect();
        self.unique_cooccurrences = self.cooccurring_entities.len();
    }

    /// Get top-k co-occurring entities by count, as `(key, count)` pairs.
    ///
    /// Reads the ranked list built by [`CooccurrenceFeatures::finalize`], so it
    /// returns nothing until `finalize` has been called.
    pub fn top_k(&self, k: usize) -> Vec<(&str, usize)> {
        self.cooccurring_entities
            .iter()
            .take(k)
            .filter_map(|e| self.cooccurrence_counts.get(e).map(|&c| (e.as_str(), c)))
            .collect()
    }
}
485
486// =============================================================================
487// Document-Level Feature Collection
488// =============================================================================
489
/// Complete feature extraction results for a document.
///
/// Returned by `EntityFeatureExtractor::extract_all`. The two maps are keyed by
/// the normalized entity text (trimmed, and lowercased when the extractor's
/// `normalize_text` option is on).
#[derive(Debug, Clone)]
pub struct DocumentFeatures {
    /// Mention-level features for each entity occurrence, in input order.
    pub mention_contexts: Vec<MentionContext>,
    /// Chain features grouped by normalized entity key.
    pub chain_features: HashMap<String, ChainFeatures>,
    /// Co-occurrence features per entity, keyed by normalized entity key.
    pub cooccurrence: HashMap<String, CooccurrenceFeatures>,
    /// Document-level statistics.
    pub document_stats: DocumentStats,
}
502
/// Document-level statistics computed by `EntityFeatureExtractor::extract_all`.
#[derive(Debug, Clone)]
pub struct DocumentStats {
    /// Document length in characters (Unicode scalar values, not bytes).
    pub char_count: usize,
    /// Document length in whitespace-separated words.
    pub word_count: usize,
    /// Total entity mention count.
    pub mention_count: usize,
    /// Unique entity count (by normalized text).
    pub unique_entity_count: usize,
    /// Entity density (mentions per 1000 chars); 0.0 for an empty document.
    pub entity_density: f64,
    /// Entity type distribution: number of mentions per entity type label.
    pub type_distribution: HashMap<String, usize>,
}
519
520// =============================================================================
521// Main Extractor
522// =============================================================================
523
524/// Entity feature extractor.
525///
526/// Extracts comprehensive features from entities at multiple levels:
527/// - Mention context (surrounding text, position)
528/// - Chain aggregates (for coreference chains)
529/// - Co-occurrence patterns (which entities appear together)
530#[derive(Debug, Clone)]
531pub struct EntityFeatureExtractor {
532    config: ExtractorConfig,
533}
534
535impl Default for EntityFeatureExtractor {
536    fn default() -> Self {
537        Self::new(ExtractorConfig::default())
538    }
539}
540
541impl EntityFeatureExtractor {
542    /// Create a new extractor with the given configuration.
543    pub fn new(config: ExtractorConfig) -> Self {
544        Self { config }
545    }
546
547    /// Extract all features from text and entities.
548    pub fn extract_all(&self, text: &str, entities: &[Entity]) -> DocumentFeatures {
549        let text_len = text.chars().count();
550
551        // 1. Extract mention contexts
552        let mention_contexts: Vec<MentionContext> = entities
553            .iter()
554            .map(|e| MentionContext::extract(text, e, &self.config))
555            .collect();
556
557        // 2. Group entities by normalized key
558        let groups = self.group_entities(entities);
559
560        // 3. Compute chain features
561        let chain_features: HashMap<String, ChainFeatures> = groups
562            .iter()
563            .map(|(key, mentions)| {
564                let refs: Vec<&Entity> = mentions.to_vec();
565                (key.clone(), ChainFeatures::from_mentions(&refs, text_len))
566            })
567            .collect();
568
569        // 4. Compute co-occurrence features
570        let cooccurrence = self.extract_cooccurrence(entities);
571
572        // 5. Document stats
573        let word_count = text.split_whitespace().count();
574        let unique_entity_count = groups.len();
575        let entity_density = if text_len > 0 {
576            (entities.len() as f64 / text_len as f64) * 1000.0
577        } else {
578            0.0
579        };
580
581        let mut type_distribution: HashMap<String, usize> = HashMap::new();
582        for e in entities {
583            *type_distribution
584                .entry(e.entity_type.as_label().to_string())
585                .or_insert(0) += 1;
586        }
587
588        let document_stats = DocumentStats {
589            char_count: text_len,
590            word_count,
591            mention_count: entities.len(),
592            unique_entity_count,
593            entity_density,
594            type_distribution,
595        };
596
597        DocumentFeatures {
598            mention_contexts,
599            chain_features,
600            cooccurrence,
601            document_stats,
602        }
603    }
604
605    /// Extract only mention contexts (lightweight).
606    pub fn extract_mentions(&self, text: &str, entities: &[Entity]) -> Vec<MentionContext> {
607        entities
608            .iter()
609            .map(|e| MentionContext::extract(text, e, &self.config))
610            .collect()
611    }
612
613    /// Extract only chain features (requires grouping).
614    pub fn extract_chains(
615        &self,
616        text: &str,
617        entities: &[Entity],
618    ) -> HashMap<String, ChainFeatures> {
619        let text_len = text.chars().count();
620        let groups = self.group_entities(entities);
621
622        groups
623            .iter()
624            .map(|(key, mentions)| {
625                let refs: Vec<&Entity> = mentions.to_vec();
626                (key.clone(), ChainFeatures::from_mentions(&refs, text_len))
627            })
628            .collect()
629    }
630
631    /// Extract only co-occurrence features.
632    pub fn extract_cooccurrence(
633        &self,
634        entities: &[Entity],
635    ) -> HashMap<String, CooccurrenceFeatures> {
636        let mut result: HashMap<String, CooccurrenceFeatures> = HashMap::new();
637
638        // Initialize features for each unique entity
639        for e in entities {
640            let key = self.normalize_key(&e.text);
641            result
642                .entry(key.clone())
643                .or_insert_with(|| CooccurrenceFeatures::new(key));
644        }
645
646        // Find co-occurrences
647        for (i, e1) in entities.iter().enumerate() {
648            let key1 = self.normalize_key(&e1.text);
649
650            for e2 in entities.iter().skip(i + 1) {
651                let key2 = self.normalize_key(&e2.text);
652
653                // Skip self-cooccurrence
654                if key1 == key2 {
655                    continue;
656                }
657
658                // Check if within cooccurrence window
659                let distance = if e1.end <= e2.start {
660                    e2.start - e1.end
661                } else if e2.end <= e1.start {
662                    e1.start.saturating_sub(e2.end)
663                } else {
664                    0 // overlapping
665                };
666
667                if distance <= self.config.cooccurrence_window {
668                    if let Some(f) = result.get_mut(&key1) {
669                        f.add_cooccurrence(&key2, Some(e2.entity_type.as_label()));
670                    }
671                    if let Some(f) = result.get_mut(&key2) {
672                        f.add_cooccurrence(&key1, Some(e1.entity_type.as_label()));
673                    }
674                }
675            }
676        }
677
678        // Finalize all
679        for f in result.values_mut() {
680            f.finalize();
681        }
682
683        result
684    }
685
686    /// Group entities by normalized key.
687    fn group_entities<'a>(&self, entities: &'a [Entity]) -> HashMap<String, Vec<&'a Entity>> {
688        let mut groups: HashMap<String, Vec<&'a Entity>> = HashMap::new();
689        for e in entities {
690            let key = self.normalize_key(&e.text);
691            groups.entry(key).or_default().push(e);
692        }
693        groups
694    }
695
696    /// Normalize entity text to a key for grouping.
697    fn normalize_key(&self, text: &str) -> String {
698        if self.config.normalize_text {
699            text.to_lowercase().trim().to_string()
700        } else {
701            text.trim().to_string()
702        }
703    }
704}
705
706// =============================================================================
707// Pairwise Features (for coreference training)
708// =============================================================================
709
/// Pairwise features between two mentions `a` and `b` (for coreference scoring).
#[derive(Debug, Clone)]
pub struct PairwiseFeatures {
    /// Gap in characters between the two mention spans (0 when they overlap or touch).
    pub char_distance: usize,
    /// Distance in mentions, as supplied by the caller. `compute_all_pairs`
    /// passes the index difference `j - i`, so adjacent mentions have distance 1.
    pub mention_distance: usize,
    /// Do the surface forms match exactly?
    pub exact_match: bool,
    /// Do the surface forms match after lowercasing?
    pub case_insensitive_match: bool,
    /// String similarity: Jaccard index over the sets of whitespace-separated
    /// words (case-sensitive); 0.0 when both mentions have no words.
    pub string_similarity: f64,
    /// Do the entity types match?
    pub type_match: bool,
    /// Mention type of first mention.
    pub mention_type_a: MentionType,
    /// Mention type of second mention.
    pub mention_type_b: MentionType,
    /// Is the second mention a pronoun referring back? True when `b` is
    /// pronominal, `a` is not, and `b` starts after `a`.
    pub is_pronominal_anaphora: bool,
}
732
733impl PairwiseFeatures {
734    /// Compute pairwise features between two mentions.
735    pub fn compute(a: &Entity, b: &Entity, mention_distance: usize) -> Self {
736        let char_distance = if a.end <= b.start {
737            b.start - a.end
738        } else if b.end <= a.start {
739            a.start.saturating_sub(b.end)
740        } else {
741            0
742        };
743
744        let exact_match = a.text == b.text;
745        let case_insensitive_match = a.text.to_lowercase() == b.text.to_lowercase();
746
747        // Jaccard similarity on words
748        let words_a: HashSet<&str> = a.text.split_whitespace().collect();
749        let words_b: HashSet<&str> = b.text.split_whitespace().collect();
750        let intersection = words_a.intersection(&words_b).count();
751        let union = words_a.union(&words_b).count();
752        let string_similarity = if union > 0 {
753            intersection as f64 / union as f64
754        } else {
755            0.0
756        };
757
758        let type_match = a.entity_type == b.entity_type;
759
760        let mention_type_a = MentionType::classify(&a.text);
761        let mention_type_b = MentionType::classify(&b.text);
762
763        // Pronominal anaphora: second mention is pronoun, first is not
764        let is_pronominal_anaphora = mention_type_b == MentionType::Pronominal
765            && mention_type_a != MentionType::Pronominal
766            && b.start > a.start;
767
768        Self {
769            char_distance,
770            mention_distance,
771            exact_match,
772            case_insensitive_match,
773            string_similarity,
774            type_match,
775            mention_type_a,
776            mention_type_b,
777            is_pronominal_anaphora,
778        }
779    }
780
781    /// Compute features for all mention pairs in a document.
782    pub fn compute_all_pairs(entities: &[Entity]) -> Vec<(usize, usize, PairwiseFeatures)> {
783        let mut pairs = Vec::new();
784        for (i, a) in entities.iter().enumerate() {
785            for (j, b) in entities.iter().enumerate().skip(i + 1) {
786                let mention_distance = j - i;
787                let features = Self::compute(a, b, mention_distance);
788                pairs.push((i, j, features));
789            }
790        }
791        pairs
792    }
793}
794
795// =============================================================================
796// Embedding Aggregation Utilities
797// =============================================================================
798
/// Aggregate embeddings from multiple mentions into a single vector.
///
/// Returns `None` when no aggregate can be formed:
/// - `embeddings` is empty,
/// - the embeddings are zero-dimensional or differ in dimension,
/// - for [`AggregationMethod::WeightedMean`], the number of weights differs
///   from the number of embeddings, or the weights sum to zero or to a
///   non-finite value (which would otherwise produce a NaN-filled vector).
pub fn aggregate_embeddings(
    embeddings: &[Vec<f32>],
    method: AggregationMethod,
) -> Option<Vec<f32>> {
    let first = embeddings.first()?;
    let dim = first.len();

    // Every embedding must share the same non-zero dimension.
    if dim == 0 || embeddings.iter().any(|e| e.len() != dim) {
        return None;
    }

    match method {
        AggregationMethod::Mean => {
            let mut result = vec![0.0f32; dim];
            for emb in embeddings {
                for (slot, &v) in result.iter_mut().zip(emb) {
                    *slot += v;
                }
            }
            let n = embeddings.len() as f32;
            for v in &mut result {
                *v /= n;
            }
            Some(result)
        }
        AggregationMethod::Max => {
            let mut result = vec![f32::NEG_INFINITY; dim];
            for emb in embeddings {
                for (slot, &v) in result.iter_mut().zip(emb) {
                    *slot = slot.max(v);
                }
            }
            Some(result)
        }
        AggregationMethod::First => Some(first.clone()),
        AggregationMethod::WeightedMean { ref weights } => {
            // One weight per embedding.
            if weights.len() != embeddings.len() {
                return None;
            }
            let total_weight: f32 = weights.iter().sum();
            // A zero, NaN or infinite normalizer cannot yield a meaningful mean.
            if total_weight == 0.0 || !total_weight.is_finite() {
                return None;
            }
            let mut result = vec![0.0f32; dim];
            for (emb, &w) in embeddings.iter().zip(weights.iter()) {
                for (slot, &v) in result.iter_mut().zip(emb) {
                    *slot += v * w;
                }
            }
            for v in &mut result {
                *v /= total_weight;
            }
            Some(result)
        }
    }
}

/// Method for aggregating multiple embeddings into one.
#[derive(Debug, Clone, Default)]
pub enum AggregationMethod {
    /// Mean of all embeddings.
    #[default]
    Mean,
    /// Element-wise max.
    Max,
    /// Just use the first embedding.
    First,
    /// Weighted mean with one weight per embedding (not per dimension).
    WeightedMean {
        /// Weight of each embedding, in the same order as the embeddings;
        /// the length must equal the number of embeddings.
        weights: Vec<f32>,
    },
}
880
881// =============================================================================
882// Tests
883// =============================================================================
884
#[cfg(test)]
mod tests {
    use super::*;
    use crate::EntityType;

    /// Shared fixture: four person mentions ("Barack Obama", "Angela Merkel", "He",
    /// "Obama") and one location ("Berlin"), with fixed offsets and confidences.
    fn sample_entities() -> Vec<Entity> {
        vec![
            Entity::new("Barack Obama", EntityType::Person, 0, 12, 0.95),
            Entity::new("Angela Merkel", EntityType::Person, 17, 30, 0.92),
            Entity::new("Berlin", EntityType::Location, 34, 40, 0.88),
            Entity::new("He", EntityType::Person, 42, 44, 0.85),
            Entity::new("Obama", EntityType::Person, 60, 65, 0.90),
        ]
    }

    #[test]
    fn test_mention_type_classification() {
        assert_eq!(MentionType::classify("he"), MentionType::Pronominal);
        assert_eq!(MentionType::classify("She"), MentionType::Pronominal);
        assert_eq!(MentionType::classify("Barack Obama"), MentionType::Proper);
        assert_eq!(MentionType::classify("the president"), MentionType::Nominal);
        assert_eq!(MentionType::classify("Apple Inc."), MentionType::Proper);
    }

    #[test]
    fn test_mention_context_extraction() {
        let text = "In Paris, Barack Obama met Angela Merkel. He discussed policy.";
        let entity = Entity::new("Barack Obama", EntityType::Person, 10, 22, 0.95);

        let ctx = MentionContext::extract(text, &entity, &ExtractorConfig::default());

        assert_eq!(ctx.entity.text, "Barack Obama");
        assert!(ctx.left_context.contains("Paris"));
        assert!(ctx.right_context.contains("met"));
        assert!(ctx.relative_position < 0.5); // Early in document
        assert!(ctx.is_capitalized);
    }

    #[test]
    fn test_chain_features() {
        let entities = sample_entities();
        let text_len = 100;

        // Group Obama mentions: any surface form containing "obama", plus the pronoun "he".
        let obama_mentions: Vec<&Entity> = entities
            .iter()
            .filter(|e| {
                let lowered = e.text.to_lowercase();
                lowered.contains("obama") || lowered == "he"
            })
            .collect();

        let features = ChainFeatures::from_mentions(&obama_mentions, text_len);

        assert_eq!(features.chain_length, 3); // Barack Obama, He, Obama
        assert!(features.variations.contains(&"Barack Obama".to_string()));
        assert!(features.pronominal_count >= 1); // "He"
        assert!(features.named_count >= 1); // "Barack Obama"

        // Test variation_count()
        // Should have: "Barack Obama", "He", "Obama"
        assert_eq!(features.variation_count(), 3);
        assert!(!features.is_singleton()); // Multiple mentions
    }

    #[test]
    fn test_cooccurrence_extraction() {
        // Reference sentence for the fixture (not passed in: `extract_cooccurrence`
        // only takes the entities):
        // "Barack Obama met Angela Merkel in Berlin. He discussed policy."
        let entities = sample_entities();

        let extractor = EntityFeatureExtractor::default();
        let cooc = extractor.extract_cooccurrence(&entities);

        let obama_cooc = cooc.get("barack obama").unwrap();
        assert!(obama_cooc
            .cooccurring_entities
            .contains(&"angela merkel".to_string()));
        assert!(obama_cooc
            .cooccurring_entities
            .contains(&"berlin".to_string()));
    }

    #[test]
    fn test_pairwise_features() {
        let a = Entity::new("Barack Obama", EntityType::Person, 0, 12, 0.95);
        let b = Entity::new("Obama", EntityType::Person, 50, 55, 0.90);
        let c = Entity::new("He", EntityType::Person, 60, 62, 0.85);

        let ab = PairwiseFeatures::compute(&a, &b, 1);
        assert!(ab.case_insensitive_match || ab.string_similarity > 0.0);
        assert!(ab.type_match);

        let ac = PairwiseFeatures::compute(&a, &c, 2);
        assert!(ac.is_pronominal_anaphora);
    }

    #[test]
    fn test_full_extraction() {
        let text = "Barack Obama met Angela Merkel in Berlin. He discussed policy with her.";
        let entities = sample_entities();

        let extractor = EntityFeatureExtractor::default();
        let features = extractor.extract_all(text, &entities);

        assert_eq!(features.mention_contexts.len(), entities.len());
        assert!(!features.chain_features.is_empty());
        assert!(!features.cooccurrence.is_empty());
        // `assert_eq!` so a failure reports both counts.
        assert_eq!(features.document_stats.mention_count, entities.len());
    }

    #[test]
    fn test_aggregate_embeddings() {
        let emb1 = vec![1.0, 2.0, 3.0];
        let emb2 = vec![2.0, 4.0, 6.0];
        let embeddings = vec![emb1, emb2];

        let mean = aggregate_embeddings(&embeddings, AggregationMethod::Mean).unwrap();
        assert_eq!(mean, vec![1.5, 3.0, 4.5]);

        let max = aggregate_embeddings(&embeddings, AggregationMethod::Max).unwrap();
        assert_eq!(max, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_aggregate_embeddings_first_and_weighted() {
        let embeddings = vec![vec![1.0, 2.0, 3.0], vec![2.0, 4.0, 6.0]];

        let first = aggregate_embeddings(&embeddings, AggregationMethod::First).unwrap();
        assert_eq!(first, vec![1.0, 2.0, 3.0]);

        // One weight per embedding: (1*e1 + 3*e2) / 4.
        let weighted = aggregate_embeddings(
            &embeddings,
            AggregationMethod::WeightedMean {
                weights: vec![1.0, 3.0],
            },
        )
        .unwrap();
        assert_eq!(weighted, vec![1.75, 3.5, 5.25]);
    }

    #[test]
    fn test_aggregate_embeddings_invalid_input() {
        // No embeddings at all.
        assert!(aggregate_embeddings(&[], AggregationMethod::Mean).is_none());

        // Zero-dimensional embeddings.
        assert!(aggregate_embeddings(&[Vec::new()], AggregationMethod::First).is_none());

        // Mismatched dimensions.
        let ragged = vec![vec![1.0, 2.0], vec![3.0]];
        assert!(aggregate_embeddings(&ragged, AggregationMethod::Max).is_none());

        let embeddings = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        // Number of weights must match number of embeddings.
        assert!(aggregate_embeddings(
            &embeddings,
            AggregationMethod::WeightedMean { weights: vec![1.0] },
        )
        .is_none());

        // Weights summing to zero cannot be normalised.
        assert!(aggregate_embeddings(
            &embeddings,
            AggregationMethod::WeightedMean {
                weights: vec![0.0, 0.0],
            },
        )
        .is_none());
    }
}