Skip to main content

anno/backends/
ensemble.rs

1//! Ensemble NER - Multi-backend extraction with unsupervised weighted voting.
2//!
3//! # Method
4//!
5//! This is an **unsupervised heuristic** approach (no training data required).
6//! Conflict resolution uses hand-tuned weights based on expected backend reliability.
7//! For supervised weight learning from labeled data, see `WeightLearner`.
8//!
9//! # The Core Idea
10//!
11//! Instead of simple priority-based stacking, `EnsembleNER`:
12//! 1. Runs ALL available backends opportunistically (in parallel conceptually)
13//! 2. Collects candidate entities with provenance
14//! 3. Groups overlapping spans into conflict clusters
15//! 4. Resolves conflicts using weighted voting with agreement bonus
16//!
17//! ```text
18//! ┌─────────────────────────────────────────────────────────────────────────┐
19//! │                    ENSEMBLE NER ARCHITECTURE                            │
20//! ├─────────────────────────────────────────────────────────────────────────┤
21//! │                                                                         │
22//! │  Input: "Tim Cook, CEO of Apple, met with Sundar Pichai"                │
23//! │                                                                         │
24//! │  ┌──────────────────────────────────────────────────────────────────┐   │
25//! │  │ PHASE 1: OPPORTUNISTIC EXTRACTION (parallel)                     │   │
26//! │  │                                                                  │   │
27//! │  │  Pattern ──────► [no entities]                                   │   │
28//! │  │  Heuristic ────► Tim Cook (PER, 0.75), Apple (ORG, 0.80), ...    │   │
29//! │  │  GLiNER ────────► Tim Cook (PER, 0.95), Apple (ORG, 0.87), ...   │   │
30//! │  │  Candle ────────► [unavailable, skip]                            │   │
31//! │  └──────────────────────────────────────────────────────────────────┘   │
32//! │                            │                                            │
33//! │                            ▼                                            │
34//! │  ┌──────────────────────────────────────────────────────────────────┐   │
35//! │  │ PHASE 2: CANDIDATE AGGREGATION                                   │   │
36//! │  │                                                                  │   │
37//! │  │  Span [0:8] "Tim Cook":                                          │   │
38//! │  │    • Heuristic: PER (0.75)                                       │   │
39//! │  │    • GLiNER: PER (0.95)                                          │   │
40//! │  │    Agreement: 2/2 → HIGH confidence                              │   │
41//! │  │                                                                  │   │
42//! │  │  Span [17:22] "Apple":                                           │   │
43//! │  │    • Heuristic: ORG (0.80)                                       │   │
44//! │  │    • GLiNER: ORG (0.87)                                          │   │
45//! │  │    Agreement: 2/2 → HIGH confidence                              │   │
46//! │  └──────────────────────────────────────────────────────────────────┘   │
47//! │                            │                                            │
48//! │                            ▼                                            │
49//! │  ┌──────────────────────────────────────────────────────────────────┐   │
50//! │  │ PHASE 3: CONFLICT RESOLUTION (weighted voting)                   │   │
51//! │  │                                                                  │   │
52//! │  │  Backend weights (learned or configured):                        │   │
53//! │  │    Pattern: 0.99 (when fires, almost always right)               │   │
54//! │  │    GLiNER:  0.85 (ML-based, good accuracy)                       │   │
55//! │  │    Heuristic: 0.65 (reasonable but noisy)                        │   │
56//! │  │                                                                  │   │
57//! │  │  For span [0:8]:                                                 │   │
//! │  │    Weighted vote = (0.65 * 0.75) + (0.85 * 0.95) = 1.295         │   │
//! │  │    Normalized = 1.295 / 1.50 = 0.86, +0.10 bonus → 0.96          │   │
//! │  └──────────────────────────────────────────────────────────────────┘   │
//! │                            │                                            │
//! │                            ▼                                            │
//! │  ┌──────────────────────────────────────────────────────────────────┐   │
//! │  │ OUTPUT                                                           │   │
//! │  │                                                                  │   │
//! │  │  Entity { text: "Tim Cook", type: PER, conf: 0.96,               │   │
67//! │  │           sources: ["heuristic", "gliner"], agreement: 1.0 }     │   │
68//! │  └──────────────────────────────────────────────────────────────────┘   │
69//! │                                                                         │
70//! └─────────────────────────────────────────────────────────────────────────┘
71//! ```
72//!
73//! # Conflict Resolution Strategies
74//!
75//! ## Weighted Voting (Unsupervised)
76//!
77//! Each backend has a weight based on its expected reliability:
78//! - Pattern backends: high weight (0.95+) when they fire
79//! - ML backends: medium-high weight (0.80-0.90)
80//! - Heuristic backends: lower weight (0.60-0.70)
81//!
82//! ## Type-Conditioned Voting
83//!
84//! Some backends are better at certain types:
85//! - Pattern: DATE, MONEY, EMAIL, URL (near-perfect)
86//! - GLiNER: PER, ORG (good), LOC (decent)
87//! - Heuristic: ORG (good with "Inc", "Corp"), PER (title+name patterns)
88//!
89//! ## Agreement Bonus
90//!
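//! When multiple backends agree on the type of overlapping spans (more than 50%
//! of the shorter span), confidence is boosted by the agreement bonus
//! (configurable via `EnsembleNER::with_agreement_bonus`):
//! - 2 backends agree: +0.10 bonus (default)
//! - 3+ backends agree: +0.15 bonus (1.5× the configured bonus)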
94//!
95//! # Example
96//!
97//! ```rust
98//! use anno::{Model, EnsembleNER};
99//!
100//! let ner = EnsembleNER::new();
101//! let entities = ner.extract_entities("Tim Cook leads Apple Inc.", None).unwrap();
102//!
103//! // Each entity includes provenance and agreement info
104//! for e in &entities {
105//!     println!("{}: {} (conf: {:.2}, sources: {:?})",
106//!              e.entity_type.as_label(), e.text, e.confidence,
107//!              e.provenance.as_ref().map(|p| &p.source));
108//! }
109//! ```
110
111use std::borrow::Cow;
112use std::collections::HashMap;
113use std::sync::Arc;
114
115use crate::{Entity, EntityType, Model, Result};
116
117fn method_for_backend_id(backend_id: &str) -> anno_core::ExtractionMethod {
118    match backend_id {
119        // Stable IDs used by `EnsembleNER::new()`.
120        "regex" => anno_core::ExtractionMethod::Pattern,
121        "heuristic" => anno_core::ExtractionMethod::Heuristic,
122        // Legacy backend id (deprecated, but still used in tests/compositions).
123        "rule" => anno_core::ExtractionMethod::Heuristic,
124        // Everything else: treat as neural by default.
125        _ => anno_core::ExtractionMethod::Neural,
126    }
127}
128
129// =============================================================================
130// Backend Weights
131// =============================================================================
132
133/// Reliability weight for a backend (0.0 to 1.0).
134///
135/// Higher weight = more trusted when resolving conflicts.
136#[derive(Debug, Clone, Copy)]
137pub struct BackendWeight {
138    /// Overall reliability of this backend
139    pub overall: f64,
140    /// Type-specific weights (optional overrides)
141    pub per_type: Option<TypeWeights>,
142}
143
144impl Default for BackendWeight {
145    fn default() -> Self {
146        Self {
147            overall: 0.5,
148            per_type: None,
149        }
150    }
151}
152
/// Per-entity-type reliability weights for a single backend.
///
/// Different backends have different accuracy profiles for different entity
/// types. When a `BackendWeight` carries `TypeWeights`, the value for the
/// entity's type is used *as* the backend's weight for that entity: it
/// replaces `BackendWeight::overall` rather than scaling it.
///
/// Note: the derived `Default` sets every field to `0.0`, i.e. a backend with
/// default `TypeWeights` contributes nothing to the weighted vote. Set the
/// fields explicitly.
#[derive(Debug, Clone, Copy, Default)]
pub struct TypeWeights {
    /// Weight used for Person entities
    pub person: f64,
    /// Weight used for Organization entities
    pub organization: f64,
    /// Weight used for Location entities
    pub location: f64,
    /// Weight used for Date entities
    pub date: f64,
    /// Weight used for Money entities
    pub money: f64,
    /// Weight used for every other entity type (e.g. URLs, emails)
    pub other: f64,
}
172
173impl TypeWeights {
174    fn get(&self, entity_type: &EntityType) -> f64 {
175        match entity_type {
176            EntityType::Person => self.person,
177            EntityType::Organization => self.organization,
178            EntityType::Location => self.location,
179            EntityType::Date => self.date,
180            EntityType::Money => self.money,
181            _ => self.other,
182        }
183    }
184}
185
186/// Default weights based on empirical observations.
187fn default_backend_weights() -> HashMap<&'static str, BackendWeight> {
188    let mut weights = HashMap::new();
189
190    // Pattern backends: very high precision when they fire
191    weights.insert(
192        "regex",
193        BackendWeight {
194            overall: 0.98,
195            per_type: Some(TypeWeights {
196                date: 0.99,
197                money: 0.99,
198                person: 0.50, // Pattern doesn't do NER
199                organization: 0.50,
200                location: 0.50,
201                other: 0.95, // URLs, emails, etc.
202            }),
203        },
204    );
205
206    // GLiNER: good ML-based NER
207    weights.insert(
208        "gliner",
209        BackendWeight {
210            overall: 0.85,
211            per_type: Some(TypeWeights {
212                person: 0.90,
213                organization: 0.85,
214                location: 0.80,
215                date: 0.75,
216                money: 0.70,
217                other: 0.75,
218            }),
219        },
220    );
221    weights.insert(
222        "GLiNER-ONNX",
223        BackendWeight {
224            overall: 0.85,
225            per_type: Some(TypeWeights {
226                person: 0.90,
227                organization: 0.85,
228                location: 0.80,
229                date: 0.75,
230                money: 0.70,
231                other: 0.75,
232            }),
233        },
234    );
235
236    // GLiNER Candle
237    weights.insert(
238        "gliner-candle",
239        BackendWeight {
240            overall: 0.85,
241            per_type: None,
242        },
243    );
244
245    // BERT NER
246    weights.insert(
247        "bert-ner-onnx",
248        BackendWeight {
249            overall: 0.80,
250            per_type: None,
251        },
252    );
253
254    // Heuristic: reasonable but noisy
255    weights.insert(
256        "heuristic",
257        BackendWeight {
258            overall: 0.60,
259            per_type: Some(TypeWeights {
260                person: 0.65,       // Title + Name pattern works well
261                organization: 0.70, // "Inc", "Corp" patterns
262                location: 0.55,     // Context-dependent
263                date: 0.40,         // Better to use pattern
264                money: 0.40,
265                other: 0.50,
266            }),
267        },
268    );
269
270    weights
271}
272
273// =============================================================================
274// Candidate Entity (with source tracking)
275// =============================================================================
276
277/// An entity candidate from a specific backend.
278#[derive(Debug, Clone)]
279struct Candidate {
280    entity: Entity,
281    source: String,
282    backend_weight: f64,
283}
284
285// =============================================================================
286// Span Key (for grouping overlapping entities)
287// =============================================================================
288
/// Span boundaries of a candidate, used to decide which candidates compete.
///
/// Two candidates are treated as the "same span" when their spans overlap
/// significantly (more than half of the shorter span).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct SpanKey {
    /// Offset where the span begins.
    start: usize,
    /// Offset where the span ends (`end - start` is the span length).
    end: usize,
}
297
298impl SpanKey {
299    fn from_entity(e: &Entity) -> Self {
300        Self {
301            start: e.start,
302            end: e.end,
303        }
304    }
305
306    /// Check if two spans overlap significantly (>50% of smaller span).
307    fn overlaps(&self, other: &SpanKey) -> bool {
308        let overlap_start = self.start.max(other.start);
309        let overlap_end = self.end.min(other.end);
310
311        if overlap_start >= overlap_end {
312            return false;
313        }
314
315        let overlap = overlap_end - overlap_start;
316        let smaller_span = (self.end - self.start).min(other.end - other.start);
317
318        // Overlap if >50% of smaller span is covered
319        (overlap as f64 / smaller_span as f64) > 0.5
320    }
321}
322
323// =============================================================================
324// EnsembleNER
325// =============================================================================
326
327/// Ensemble NER that runs ALL backends and resolves conflicts via weighted voting.
328///
329/// Unlike [`StackedNER`] (priority-based cascade), `EnsembleNER`:
330/// 1. Runs ALL backends in parallel (conceptually)
331/// 2. Groups overlapping spans into conflict clusters
332/// 3. Resolves via weighted voting with type-conditioned weights
333/// 4. Applies agreement bonus when multiple backends agree
334///
335/// # When to Use
336///
337/// - **EnsembleNER**: Maximum accuracy, latency not critical
338/// - **StackedNER**: Production, predictable latency, early exit
339///
340/// # Example
341///
342/// ```rust
343/// use anno::{EnsembleNER, Model, RegexNER, HeuristicNER};
344///
345/// // Default: uses all available backends
346/// let ensemble = EnsembleNER::new();
347///
348/// // Custom: specific backends
349/// let custom = EnsembleNER::with_backends(vec![
350///     Box::new(RegexNER::new()),
351///     Box::new(HeuristicNER::new()),
352/// ]);
353///
354/// let entities = custom.extract_entities("Contact us at test@example.com", None)?;
355/// # Ok::<(), anno::Error>(())
356/// ```
357///
358/// # Algorithm
359///
360/// 1. **Collect candidates**: Run each backend, tag results with provenance
361/// 2. **Cluster overlaps**: Group candidates with >50% span overlap
362/// 3. **Weighted vote**: Score each candidate by `backend_weight * confidence`
363/// 4. **Agreement bonus**: Add +0.10 when 2+ backends agree on type
364/// 5. **Select winner**: Highest weighted score wins the cluster
365///
366/// [`StackedNER`]: super::stacked::StackedNER
367pub struct EnsembleNER {
368    backends: Vec<Arc<dyn Model + Send + Sync>>,
369    /// Stable backend IDs used for weighting and source tracking.
370    ///
371    /// This is intentionally decoupled from `backend.name()`, which is a
372    /// human-facing label and may vary across implementations (e.g. "GLiNER-ONNX").
373    backend_ids: Vec<String>,
374    weights: HashMap<String, BackendWeight>,
375    agreement_bonus: f64,
376    min_confidence: f64,
377    /// Transparent name showing constituent backends (e.g., "ensemble(regex|gliner|heuristic)")
378    name: String,
379    /// Cached static name (avoids Box::leak on every name() call)
380    name_static: std::sync::OnceLock<&'static str>,
381}
382
383impl Default for EnsembleNER {
384    fn default() -> Self {
385        Self::new()
386    }
387}
388
389impl EnsembleNER {
390    /// Create ensemble with all available backends.
391    #[must_use]
392    pub fn new() -> Self {
393        let mut backends: Vec<Arc<dyn Model + Send + Sync>> = Vec::new();
394        let mut backend_ids: Vec<&'static str> = Vec::new();
395
396        // Always add pattern (high precision for structured data)
397        backends.push(Arc::new(crate::RegexNER::new()));
398        backend_ids.push("regex");
399
400        // Add GLiNER if available
401        #[cfg(feature = "onnx")]
402        {
403            use super::GLiNEROnnx;
404            use crate::DEFAULT_GLINER_MODEL;
405            if let Ok(gliner) = GLiNEROnnx::new(DEFAULT_GLINER_MODEL) {
406                backends.push(Arc::new(gliner));
407                backend_ids.push("gliner");
408            }
409        }
410
411        // Add Candle GLiNER if available
412        #[cfg(feature = "candle")]
413        {
414            use super::GLiNERCandle;
415            use crate::DEFAULT_GLINER_MODEL;
416            if let Ok(candle) = GLiNERCandle::from_pretrained(DEFAULT_GLINER_MODEL) {
417                backends.push(Arc::new(candle));
418                backend_ids.push("gliner-candle");
419            }
420        }
421
422        // Always add heuristic as fallback
423        backends.push(Arc::new(crate::HeuristicNER::new()));
424        backend_ids.push("heuristic");
425
426        // Build transparent name showing constituents
427        // Use '|' for parallel weighted voting (no priority ordering)
428        let name = format!("ensemble({})", backend_ids.join("|"));
429
430        // Convert default weights to owned strings
431        let weights: HashMap<String, BackendWeight> = default_backend_weights()
432            .into_iter()
433            .map(|(k, v)| (k.to_string(), v))
434            .collect();
435
436        Self {
437            backends,
438            backend_ids: backend_ids.into_iter().map(str::to_string).collect(),
439            weights,
440            agreement_bonus: 0.10,
441            min_confidence: 0.30,
442            name,
443            name_static: std::sync::OnceLock::new(),
444        }
445    }
446
447    /// Create with custom backends.
448    #[must_use]
449    pub fn with_backends(backends: Vec<Box<dyn Model + Send + Sync>>) -> Self {
450        // For custom backends, use the backend's reported name as both ID and display string.
451        let backend_ids: Vec<String> = backends.iter().map(|b| b.name().to_string()).collect();
452        let name = format!("ensemble({})", backend_ids.join("|"));
453
454        let backends: Vec<Arc<dyn Model + Send + Sync>> =
455            backends.into_iter().map(Arc::from).collect();
456
457        let weights: HashMap<String, BackendWeight> = default_backend_weights()
458            .into_iter()
459            .map(|(k, v)| (k.to_string(), v))
460            .collect();
461
462        Self {
463            backends,
464            backend_ids,
465            weights,
466            agreement_bonus: 0.10,
467            min_confidence: 0.30,
468            name,
469            name_static: std::sync::OnceLock::new(),
470        }
471    }
472
473    /// Set custom backend weights.
474    #[must_use]
475    pub fn with_weights(mut self, weights: HashMap<String, BackendWeight>) -> Self {
476        self.weights = weights;
477        self
478    }
479
480    /// Set the agreement bonus (added when multiple backends agree).
481    #[must_use]
482    pub fn with_agreement_bonus(mut self, bonus: f64) -> Self {
483        self.agreement_bonus = bonus;
484        self
485    }
486
487    /// Set minimum confidence threshold.
488    #[must_use]
489    pub fn with_min_confidence(mut self, min: f64) -> Self {
490        self.min_confidence = min;
491        self
492    }
493
494    /// Get the weight for a backend and entity type.
495    fn get_weight(&self, backend_name: &str, entity_type: &EntityType) -> f64 {
496        if let Some(weight) = self.weights.get(backend_name) {
497            if let Some(ref type_weights) = weight.per_type {
498                type_weights.get(entity_type)
499            } else {
500                weight.overall
501            }
502        } else {
503            // Unknown backend - use conservative default
504            0.50
505        }
506    }
507
508    /// Resolve overlapping candidates using weighted voting.
509    fn resolve_candidates(&self, candidates: Vec<Candidate>) -> Option<Entity> {
510        if candidates.is_empty() {
511            return None;
512        }
513
514        if candidates.len() == 1 {
515            // Single candidate - use its confidence directly
516            let candidate = candidates
517                .into_iter()
518                .next()
519                .expect("candidates.len() == 1 guarantees next() is Some");
520            let mut entity = candidate.entity;
521            let original_prov = entity.provenance.clone();
522            let original_confidence = entity.confidence;
523            // Slight penalty for single-source
524            entity.confidence *= 0.95;
525            // Set provenance for single-source entities
526            entity.provenance = Some(anno_core::Provenance {
527                source: std::borrow::Cow::Owned(format!("ensemble({})", candidate.source)),
528                // Preserve underlying method/pattern when possible (important for nested ensembles).
529                method: original_prov
530                    .as_ref()
531                    .map(|p| p.method)
532                    .unwrap_or_else(|| method_for_backend_id(&candidate.source)),
533                pattern: original_prov.as_ref().and_then(|p| p.pattern.clone()),
534                raw_confidence: original_prov
535                    .as_ref()
536                    .and_then(|p| p.raw_confidence)
537                    .or(Some(original_confidence)),
538                model_version: None,
539                timestamp: None,
540            });
541            return Some(entity);
542        }
543
544        // Group by entity type
545        let mut type_votes: HashMap<String, Vec<&Candidate>> = HashMap::new();
546        for c in &candidates {
547            let type_key = c.entity.entity_type.as_label().to_string();
548            type_votes.entry(type_key).or_default().push(c);
549        }
550
551        // Find the type with highest weighted vote (deterministic tie-breaking).
552        //
553        // HashMap iteration order can vary across process runs. If two types tie on
554        // weighted_sum, we still need a stable selection.
555        //
556        // Ordering:
557        // 1) Higher weighted_sum wins
558        // 2) If tied, more candidates (more votes) wins
559        // 3) If tied, lexicographically smaller type key wins
560        let mut best_type: Option<(String, f64, usize, Vec<&Candidate>)> = None;
561        for (type_key, type_candidates) in &type_votes {
562            let weighted_sum: f64 = type_candidates
563                .iter()
564                .map(|c| c.backend_weight * c.entity.confidence)
565                .sum();
566            let count = type_candidates.len();
567
568            let should_replace = match &best_type {
569                None => true,
570                Some((best_key, best_sum, best_count, _)) => {
571                    if weighted_sum > *best_sum {
572                        true
573                    } else if weighted_sum < *best_sum {
574                        false
575                    } else if count > *best_count {
576                        true
577                    } else if count < *best_count {
578                        false
579                    } else {
580                        type_key < best_key
581                    }
582                }
583            };
584
585            if should_replace {
586                best_type = Some((
587                    type_key.clone(),
588                    weighted_sum,
589                    count,
590                    type_candidates.clone(),
591                ));
592            }
593        }
594
595        let (_type_key, weighted_sum, _count, winning_candidates) = best_type?;
596
597        // Calculate ensemble confidence
598        let num_sources = winning_candidates.len();
599        let total_weight: f64 = winning_candidates.iter().map(|c| c.backend_weight).sum();
600
601        let base_confidence = if total_weight > 0.0 {
602            weighted_sum / total_weight
603        } else {
604            0.5
605        };
606
607        // Agreement bonus
608        let agreement_bonus = if num_sources >= 3 {
609            self.agreement_bonus * 1.5
610        } else if num_sources >= 2 {
611            self.agreement_bonus
612        } else {
613            0.0
614        };
615
616        let final_confidence = (base_confidence + agreement_bonus).min(1.0);
617
618        // Build merged entity
619        // Use the candidate with highest individual confidence as base
620        let best_candidate = winning_candidates.iter().max_by(|a, b| {
621            a.entity
622                .confidence
623                .partial_cmp(&b.entity.confidence)
624                .unwrap_or(std::cmp::Ordering::Equal)
625        })?;
626
627        let sources: Vec<String> = winning_candidates
628            .iter()
629            .map(|c| c.source.clone())
630            .collect();
631
632        // Calculate hierarchical confidence scores
633        // - linkage: How many backends detected an entity here (normalized)
634        // - type_score: Agreement on type classification
635        // - boundary: Agreement on exact span boundaries
636        let total_candidates = candidates.len() as f32;
637        let num_winners = winning_candidates.len() as f32;
638
639        // Linkage: ratio of candidates in winning type
640        let linkage = if total_candidates > 0.0 {
641            (num_winners / total_candidates).min(1.0)
642        } else {
643            0.5
644        };
645
646        // Type score: confidence in the winning type (weighted)
647        let type_score = final_confidence as f32;
648
649        // Boundary: agreement on span boundaries
650        // Check if all winning candidates have the same start/end
651        let reference_span = (best_candidate.entity.start, best_candidate.entity.end);
652        let span_agreement_count = winning_candidates
653            .iter()
654            .filter(|c| c.entity.start == reference_span.0 && c.entity.end == reference_span.1)
655            .count();
656        let boundary = if num_winners > 0.0 {
657            (span_agreement_count as f32 / num_winners).min(1.0)
658        } else {
659            1.0
660        };
661
662        let mut entity = best_candidate.entity.clone();
663        entity.confidence = final_confidence;
664        entity.hierarchical_confidence = Some(anno_core::HierarchicalConfidence::new(
665            linkage, type_score, boundary,
666        ));
667        entity.provenance = Some(anno_core::Provenance {
668            source: Cow::Owned(format!("ensemble({})", sources.join("+"))),
669            method: anno_core::ExtractionMethod::Consensus,
670            pattern: None,
671            raw_confidence: Some(base_confidence),
672            model_version: None,
673            timestamp: None,
674        });
675
676        Some(entity)
677    }
678}
679
680impl Model for EnsembleNER {
681    fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
682        if self.backends.is_empty() {
683            return Ok(Vec::new());
684        }
685
686        // Phase 1: Collect candidates from all backends
687        let mut all_candidates: Vec<Candidate> = Vec::new();
688
689        for (i, backend) in self.backends.iter().enumerate() {
690            let backend_id = self
691                .backend_ids
692                .get(i)
693                .cloned()
694                .unwrap_or_else(|| backend.name().to_string());
695
696            match backend.extract_entities(text, language) {
697                Ok(entities) => {
698                    for entity in entities {
699                        let weight = self.get_weight(&backend_id, &entity.entity_type);
700                        all_candidates.push(Candidate {
701                            entity,
702                            source: backend_id.clone(),
703                            backend_weight: weight,
704                        });
705                    }
706                }
707                Err(e) => {
708                    // Log but continue (opportunistic)
709                    log::debug!(
710                        "EnsembleNER: Backend {} (id={}) failed: {}",
711                        backend.name(),
712                        backend_id,
713                        e
714                    );
715                }
716            }
717        }
718
719        if all_candidates.is_empty() {
720            return Ok(Vec::new());
721        }
722
723        // Phase 2: Group candidates by overlapping spans
724        let mut span_groups: Vec<Vec<Candidate>> = Vec::new();
725
726        for candidate in all_candidates {
727            let span = SpanKey::from_entity(&candidate.entity);
728
729            // Find existing group with overlapping span
730            let mut found_group = false;
731            for group in &mut span_groups {
732                if let Some(first) = group.first() {
733                    let existing_span = SpanKey::from_entity(&first.entity);
734                    if span.overlaps(&existing_span) {
735                        group.push(candidate.clone());
736                        found_group = true;
737                        break;
738                    }
739                }
740            }
741
742            if !found_group {
743                span_groups.push(vec![candidate]);
744            }
745        }
746
747        // Phase 3: Resolve each group
748        let mut results: Vec<Entity> = Vec::new();
749
750        for group in span_groups {
751            if let Some(entity) = self.resolve_candidates(group) {
752                if entity.confidence >= self.min_confidence {
753                    results.push(entity);
754                }
755            }
756        }
757
758        // Sort by position
759        results.sort_by_key(|e| (e.start, e.end));
760
761        Ok(results)
762    }
763
764    fn supported_types(&self) -> Vec<EntityType> {
765        // Union of all backend types
766        let mut types: Vec<EntityType> = Vec::new();
767        for backend in &self.backends {
768            for t in backend.supported_types() {
769                if !types.contains(&t) {
770                    types.push(t);
771                }
772            }
773        }
774        types
775    }
776
777    fn is_available(&self) -> bool {
778        // Available if at least one backend is available
779        self.backends.iter().any(|b| b.is_available())
780    }
781
782    fn name(&self) -> &'static str {
783        // Use OnceLock to cache the static string, avoiding repeated memory leaks
784        self.name_static
785            .get_or_init(|| Box::leak(self.name.clone().into_boxed_str()))
786    }
787
788    fn description(&self) -> &'static str {
789        "Ensemble NER: weighted voting across multiple backends"
790    }
791}
792
793// Implement required traits
794impl crate::NamedEntityCapable for EnsembleNER {}
795
796impl crate::BatchCapable for EnsembleNER {
797    fn optimal_batch_size(&self) -> Option<usize> {
798        Some(8) // Reasonable default for ensemble
799    }
800}
801
802impl crate::StreamingCapable for EnsembleNER {
803    fn recommended_chunk_size(&self) -> usize {
804        8192
805    }
806}
807
808// =============================================================================
809// Weight Learning
810// =============================================================================
811
/// Training example for weight learning: one gold-labelled span plus what each
/// backend predicted for that span.
///
/// [`WeightLearner::add_example`] reads only `gold_type` and `predictions`.
/// `text`, `start`, `end`, and the per-prediction confidence are carried along for
/// inspection but do not influence the learned statistics.
#[derive(Debug, Clone)]
pub struct WeightTrainingExample {
    /// Surface text of the gold entity
    pub text: String,
    /// True entity type (gold label)
    pub gold_type: EntityType,
    /// Span start (copied from the gold `Entity::start` by `add_from_backends`)
    pub start: usize,
    /// Span end (copied from the gold `Entity::end` by `add_from_backends`)
    pub end: usize,
    /// Predictions from each backend: (backend_name, predicted_type, confidence)
    pub predictions: Vec<(String, EntityType, f64)>,
}
826
/// Running tallies of how often one backend's predicted type matched the gold type.
///
/// Filled in by `WeightLearner::add_example`, which keys `per_type` by the label
/// of the *gold* entity type.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    /// Recorded predictions whose type equalled the gold type
    pub correct: usize,
    /// Recorded predictions in total
    pub total: usize,
    /// Per-gold-type tallies: label -> (correct, total)
    pub per_type: HashMap<String, (usize, usize)>,
}

impl BackendStats {
    /// Fraction of recorded predictions that were correct.
    ///
    /// Returns `0.0` when nothing has been recorded yet.
    pub fn precision(&self) -> f64 {
        match self.total {
            0 => 0.0,
            total => self.correct as f64 / total as f64,
        }
    }

    /// Fraction correct for a single type label.
    ///
    /// Returns `0.0` when the label was never seen or has a zero total.
    pub fn type_precision(&self, entity_type: &str) -> f64 {
        self.per_type
            .get(entity_type)
            .filter(|(_, total)| *total > 0)
            .map_or(0.0, |&(correct, total)| correct as f64 / total as f64)
    }
}
861
862/// Weight learner for EnsembleNER.
863///
864/// Learns optimal backend weights from evaluation data.
865///
866/// # Example
867///
868/// ```rust,ignore
869/// use anno::backends::ensemble::{EnsembleNER, WeightLearner};
870///
871/// let mut learner = WeightLearner::new();
872///
873/// // Add training examples from gold data
874/// for (text, gold_entities) in gold_data {
875///     learner.add_examples(&text, &gold_entities, &backends);
876/// }
877///
878/// // Learn weights
879/// let learned_weights = learner.learn_weights();
880///
881/// // Create ensemble with learned weights
882/// let ensemble = EnsembleNER::new().with_weights(learned_weights);
883/// ```
884pub struct WeightLearner {
885    /// Per-backend statistics
886    backend_stats: HashMap<String, BackendStats>,
887    /// Smoothing factor for precision (avoid division by zero / overfitting)
888    smoothing: f64,
889}
890
891impl Default for WeightLearner {
892    fn default() -> Self {
893        Self::new()
894    }
895}
896
897impl WeightLearner {
898    /// Create a new weight learner.
899    #[must_use]
900    pub fn new() -> Self {
901        Self {
902            backend_stats: HashMap::new(),
903            smoothing: 1.0, // Laplace smoothing
904        }
905    }
906
907    /// Set smoothing factor.
908    #[must_use]
909    pub fn with_smoothing(mut self, smoothing: f64) -> Self {
910        self.smoothing = smoothing;
911        self
912    }
913
914    /// Add a training example.
915    pub fn add_example(&mut self, example: &WeightTrainingExample) {
916        for (backend_name, predicted_type, _confidence) in &example.predictions {
917            let stats = self.backend_stats.entry(backend_name.clone()).or_default();
918
919            stats.total += 1;
920            let correct = *predicted_type == example.gold_type;
921            if correct {
922                stats.correct += 1;
923            }
924
925            // Per-type stats
926            let type_key = example.gold_type.as_label().to_string();
927            let type_stats = stats.per_type.entry(type_key).or_insert((0, 0));
928            type_stats.1 += 1;
929            if correct {
930                type_stats.0 += 1;
931            }
932        }
933    }
934
935    /// Add examples from gold entities and backend predictions.
936    ///
937    /// Runs each backend on the text and compares to gold entities.
938    pub fn add_from_backends(
939        &mut self,
940        text: &str,
941        gold_entities: &[Entity],
942        backends: &[(&str, &dyn Model)],
943    ) {
944        // Get predictions from each backend
945        let mut backend_preds: HashMap<String, Vec<Entity>> = HashMap::new();
946        for (name, backend) in backends {
947            if let Ok(entities) = backend.extract_entities(text, None) {
948                backend_preds.insert(name.to_string(), entities);
949            }
950        }
951
952        // Match predictions to gold entities
953        for gold in gold_entities {
954            let mut example = WeightTrainingExample {
955                text: gold.text.clone(),
956                gold_type: gold.entity_type.clone(),
957                start: gold.start,
958                end: gold.end,
959                predictions: Vec::new(),
960            };
961
962            for (backend_name, entities) in &backend_preds {
963                // Find matching prediction (same span)
964                for pred in entities {
965                    if pred.start == gold.start && pred.end == gold.end {
966                        example.predictions.push((
967                            backend_name.clone(),
968                            pred.entity_type.clone(),
969                            pred.confidence,
970                        ));
971                        break;
972                    }
973                }
974            }
975
976            if !example.predictions.is_empty() {
977                self.add_example(&example);
978            }
979        }
980    }
981
982    /// Learn optimal weights from accumulated statistics.
983    ///
984    /// Uses precision-based weighting with Laplace smoothing.
985    pub fn learn_weights(&self) -> HashMap<String, BackendWeight> {
986        let mut weights = HashMap::new();
987
988        for (backend_name, stats) in &self.backend_stats {
989            // Smoothed precision: (correct + smoothing) / (total + 2*smoothing)
990            let smoothed_precision = (stats.correct as f64 + self.smoothing)
991                / (stats.total as f64 + 2.0 * self.smoothing);
992
993            // Per-type weights
994            let mut type_weights = TypeWeights::default();
995            for (type_key, (correct, total)) in &stats.per_type {
996                let type_precision =
997                    (*correct as f64 + self.smoothing) / (*total as f64 + 2.0 * self.smoothing);
998
999                match type_key.as_str() {
1000                    "PER" | "PERSON" => type_weights.person = type_precision,
1001                    "ORG" | "ORGANIZATION" => type_weights.organization = type_precision,
1002                    "LOC" | "LOCATION" | "GPE" => type_weights.location = type_precision,
1003                    "DATE" => type_weights.date = type_precision,
1004                    "MONEY" => type_weights.money = type_precision,
1005                    _ => type_weights.other = type_precision,
1006                }
1007            }
1008
1009            weights.insert(
1010                backend_name.clone(),
1011                BackendWeight {
1012                    overall: smoothed_precision,
1013                    per_type: Some(type_weights),
1014                },
1015            );
1016        }
1017
1018        weights
1019    }
1020
1021    /// Get statistics for a backend.
1022    pub fn get_stats(&self, backend_name: &str) -> Option<&BackendStats> {
1023        self.backend_stats.get(backend_name)
1024    }
1025
1026    /// Get all backend names.
1027    pub fn backend_names(&self) -> Vec<&String> {
1028        self.backend_stats.keys().collect()
1029    }
1030}
1031
1032// =============================================================================
1033// Tests
1034// =============================================================================
1035
#[cfg(test)]
mod tests {
    use super::*;
    use anno_core::ExtractionMethod;

    fn fast_ensemble() -> EnsembleNER {
        // Keep unit tests deterministic and fast: do not initialize model-loading backends here.
        EnsembleNER::with_backends(vec![
            Box::new(crate::RegexNER::new()),
            Box::new(crate::HeuristicNER::new()),
        ])
    }

    #[test]
    fn test_new_backend_ids_have_weights() {
        let ner = EnsembleNER::new();

        // For the built-in constructor, we require stable IDs so weights apply as intended.
        assert!(
            !ner.backend_ids.is_empty(),
            "EnsembleNER::new() should have at least one backend"
        );

        for id in &ner.backend_ids {
            assert!(
                ner.weights.contains_key(id),
                "EnsembleNER::new(): missing weight for backend id={:?}. This usually means the ensemble's advertised IDs drifted from default_backend_weights keys.",
                id
            );
        }
    }

    #[test]
    fn test_ensemble_basic() {
        let ner = fast_ensemble();
        let entities = ner
            .extract_entities("Tim Cook is the CEO of Apple Inc.", None)
            .unwrap();

        // Should find at least some entities
        assert!(!entities.is_empty(), "Should extract entities");

        // Check that provenance exists (may or may not say "ensemble" for single-source entities)
        for e in &entities {
            assert!(
                e.provenance.is_some(),
                "All entities should have provenance"
            );
        }
    }

    #[test]
    fn test_span_overlap() {
        // Overlap is judged relative to the smaller span and must exceed 0.5, so a
        // pair like [0,10) vs [5,15) (5/10 = 0.5) would sit exactly on the boundary.
        // Use a clearly-overlapping pair instead: [0,10) vs [3,15) shares [3,10),
        // i.e. 7 chars, and 7/10 = 0.7 > 0.5.
        let span1 = SpanKey { start: 0, end: 10 };
        let span2 = SpanKey { start: 3, end: 15 };
        let span3 = SpanKey { start: 20, end: 30 };

        assert!(span1.overlaps(&span2), "Overlapping spans should match");
        assert!(
            !span1.overlaps(&span3),
            "Non-overlapping spans should not match"
        );
    }

    #[test]
    fn test_backend_weights() {
        let weights = default_backend_weights();

        // Regex (pattern) backend should have high weight
        assert!(weights["regex"].overall > 0.9);

        // GLiNER should have good weight
        assert!(weights["gliner"].overall > 0.8);

        // Heuristic should have lower weight
        assert!(weights["heuristic"].overall < 0.7);
    }

    #[test]
    fn test_type_specific_weights() {
        let weights = default_backend_weights();

        // Pattern should be best for dates
        let pattern_date = weights["regex"].per_type.as_ref().unwrap().date;
        let heuristic_date = weights["heuristic"].per_type.as_ref().unwrap().date;
        assert!(pattern_date > heuristic_date);

        // Heuristic should be decent for orgs
        let heuristic_org = weights["heuristic"].per_type.as_ref().unwrap().organization;
        assert!(heuristic_org > 0.6);
    }

    #[test]
    fn test_agreement_bonus() {
        let ner = fast_ensemble().with_agreement_bonus(0.15);
        assert!((ner.agreement_bonus - 0.15).abs() < 0.001);
    }

    #[test]
    fn test_weight_learner_basic() {
        let mut learner = WeightLearner::new();

        // Add some training examples
        learner.add_example(&WeightTrainingExample {
            text: "Apple".to_string(),
            gold_type: EntityType::Organization,
            start: 0,
            end: 5,
            predictions: vec![
                ("heuristic".to_string(), EntityType::Organization, 0.8),
                ("gliner".to_string(), EntityType::Organization, 0.9),
            ],
        });

        learner.add_example(&WeightTrainingExample {
            text: "Paris".to_string(),
            gold_type: EntityType::Location,
            start: 0,
            end: 5,
            predictions: vec![
                ("heuristic".to_string(), EntityType::Person, 0.6), // Wrong!
                ("gliner".to_string(), EntityType::Location, 0.85),
            ],
        });

        // Learn weights
        let weights = learner.learn_weights();

        // GLiNER should have higher weight (2/2 correct vs 1/2)
        let gliner_weight = weights.get("gliner").map(|w| w.overall).unwrap_or(0.0);
        let heuristic_weight = weights.get("heuristic").map(|w| w.overall).unwrap_or(0.0);

        assert!(
            gliner_weight > heuristic_weight,
            "GLiNER should have higher weight (was {} vs {})",
            gliner_weight,
            heuristic_weight
        );
    }

    #[test]
    fn test_weight_learner_zero_smoothing_gives_raw_ratio() {
        let mut learner = WeightLearner::new().with_smoothing(0.0);

        learner.add_example(&WeightTrainingExample {
            text: "Apple".to_string(),
            gold_type: EntityType::Organization,
            start: 0,
            end: 5,
            predictions: vec![
                ("heuristic".to_string(), EntityType::Organization, 0.8),
                ("gliner".to_string(), EntityType::Organization, 0.9),
            ],
        });
        learner.add_example(&WeightTrainingExample {
            text: "Paris".to_string(),
            gold_type: EntityType::Location,
            start: 0,
            end: 5,
            predictions: vec![
                ("heuristic".to_string(), EntityType::Person, 0.6),
                ("gliner".to_string(), EntityType::Location, 0.85),
            ],
        });

        // With no smoothing, overall weight is exactly correct / total.
        let weights = learner.learn_weights();
        assert!((weights["gliner"].overall - 1.0).abs() < 1e-12);
        assert!((weights["heuristic"].overall - 0.5).abs() < 1e-12);

        let stats = learner.get_stats("heuristic").expect("heuristic stats");
        assert_eq!((stats.correct, stats.total), (1, 2));
    }

    #[test]
    fn test_backend_stats() {
        let mut stats = BackendStats {
            correct: 8,
            total: 10,
            ..Default::default()
        };
        stats.per_type.insert("PER".to_string(), (5, 6));

        assert!((stats.precision() - 0.8).abs() < 0.01);
        assert!((stats.type_precision("PER") - 0.833).abs() < 0.01);
        assert!((stats.type_precision("ORG") - 0.0).abs() < 0.01); // Unknown type
    }

    // =========================================================================
    // Additional Edge Case Tests
    // =========================================================================

    #[test]
    fn test_empty_text() {
        let ner = fast_ensemble();
        let entities = ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_whitespace_only_text() {
        let ner = fast_ensemble();
        let entities = ner.extract_entities("   \t\n   ", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_resolve_candidates_tie_break_is_order_independent() {
        let ner = fast_ensemble();
        let span_text = "Apple";
        let span = (0, 5);

        let e_person = Entity::new(span_text, EntityType::Person, span.0, span.1, 0.5);
        let e_org = Entity::new(span_text, EntityType::Organization, span.0, span.1, 0.5);

        let c1 = Candidate {
            entity: e_person,
            source: "heuristic".to_string(),
            backend_weight: 1.0,
        };
        let c2 = Candidate {
            entity: e_org,
            source: "heuristic".to_string(),
            backend_weight: 1.0,
        };

        let out_a = ner
            .resolve_candidates(vec![c1.clone(), c2.clone()])
            .expect("should resolve");
        let out_b = ner
            .resolve_candidates(vec![c2, c1])
            .expect("should resolve");

        assert_eq!(
            out_a.entity_type, out_b.entity_type,
            "tie resolution should not depend on candidate order"
        );

        let key_a = out_a.entity_type.as_label().to_string();
        let person_key = EntityType::Person.as_label().to_string();
        let org_key = EntityType::Organization.as_label().to_string();
        let expected = std::cmp::min(person_key, org_key);
        assert_eq!(
            key_a, expected,
            "tie-break should choose lexicographically smallest type label"
        );
    }

    #[test]
    fn test_single_source_preserves_underlying_method_and_pattern() {
        // With a single backend, ensemble should preserve the backend's extraction method/pattern
        // (important for explainability and nested composition).
        let ner = EnsembleNER::with_backends(vec![Box::new(crate::RegexNER::new())]);
        let text = "Contact test@email.com on 2024-01-15";
        let entities = ner.extract_entities(text, None).expect("extract");
        assert!(!entities.is_empty());

        let email = entities
            .iter()
            .find(|e| e.text == "test@email.com")
            .expect("email entity should exist");
        let prov = email.provenance.as_ref().expect("provenance");

        assert_eq!(prov.method, ExtractionMethod::Pattern);
        assert!(
            prov.pattern.is_some(),
            "expected to preserve regex pattern name"
        );
    }

    #[test]
    fn test_nested_single_source_preserves_inner_method() {
        // Inner ensemble produces provenance.method = Heuristic; outer should not overwrite it
        // to Neural just because the backend id is "ensemble(...)".
        let inner = EnsembleNER::with_backends(vec![Box::new(crate::HeuristicNER::new())]);
        let outer = EnsembleNER::with_backends(vec![Box::new(inner)]);

        let text = "John Smith visited Paris.";
        let entities = outer.extract_entities(text, None).expect("extract");
        assert!(!entities.is_empty());

        for e in &entities {
            let prov = e.provenance.as_ref().expect("provenance");
            assert_eq!(
                prov.method,
                ExtractionMethod::Heuristic,
                "expected outer to preserve inner method"
            );
        }
    }

    #[test]
    fn test_span_key_self_overlap() {
        let span = SpanKey { start: 0, end: 10 };
        assert!(span.overlaps(&span), "Span should overlap with itself");
    }

    #[test]
    fn test_span_key_adjacent_no_overlap() {
        let span1 = SpanKey { start: 0, end: 10 };
        let span2 = SpanKey { start: 10, end: 20 };
        assert!(!span1.overlaps(&span2), "Adjacent spans should not overlap");
    }

    #[test]
    fn test_span_key_contained() {
        let outer = SpanKey { start: 0, end: 20 };
        let inner = SpanKey { start: 5, end: 15 };
        assert!(outer.overlaps(&inner), "Contained spans should overlap");
        assert!(inner.overlaps(&outer), "Overlap should be symmetric");
    }

    #[test]
    fn test_backend_stats_empty() {
        let stats = BackendStats::default();
        assert!((stats.precision() - 0.0).abs() < 0.001);
        assert!((stats.type_precision("ANY") - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_weight_learner_empty() {
        // With no examples there is nothing to learn from: the learner must return
        // an empty map so callers fall back to their default weights.
        let learner = WeightLearner::new();
        assert!(learner.learn_weights().is_empty());
        assert!(learner.backend_names().is_empty());
    }

    #[test]
    fn test_ensemble_with_language() {
        let ner = fast_ensemble();

        // Try with English language hint
        let entities = ner
            .extract_entities("Tim Cook is the CEO of Apple.", Some("en"))
            .unwrap();

        // Should find entities (language hint shouldn't break anything)
        assert!(
            !entities.is_empty(),
            "Should find entities with language hint"
        );
    }

    #[test]
    fn test_type_weights_structure() {
        let weights = TypeWeights {
            person: 0.9,
            location: 0.85,
            organization: 0.88,
            date: 0.95,
            money: 0.8,
            other: 0.7,
        };

        assert!(weights.person > 0.0);
        assert!(weights.date > weights.other);
    }

    #[test]
    fn test_backend_weight_structure() {
        let weight = BackendWeight {
            overall: 0.85,
            per_type: Some(TypeWeights {
                person: 0.9,
                location: 0.88,
                organization: 0.87,
                date: 0.92,
                money: 0.85,
                other: 0.75,
            }),
        };

        assert!(weight.overall > 0.0);
        assert!(weight.per_type.is_some());
    }

    #[test]
    fn test_unicode_extraction() {
        let ner = EnsembleNER::new();
        let entities = ner
            .extract_entities("東京で会議がありました。", None)
            .unwrap();

        // Should not crash on Unicode
        for e in &entities {
            assert!(e.confidence >= 0.0 && e.confidence <= 1.0);
        }
    }

    #[test]
    fn test_ensemble_provenance_tracking() {
        let ner = EnsembleNER::new();
        let entities = ner
            .extract_entities("Barack Obama visited Paris yesterday.", None)
            .unwrap();

        for e in &entities {
            // All entities should have provenance
            assert!(
                e.provenance.is_some(),
                "Entity '{}' ({:?}) at {}..{} has no provenance",
                e.text,
                e.entity_type,
                e.start,
                e.end
            );
            let prov = e.provenance.as_ref().unwrap();
            // Provenance source should not be empty
            assert!(!prov.source.is_empty());
        }
    }
}