// anno/backends/ensemble/mod.rs
1//! Ensemble NER - Multi-backend extraction with unsupervised weighted voting.
2//!
3//! # Method
4//!
5//! This is an **unsupervised heuristic** approach (no training data required).
6//! Conflict resolution uses hand-tuned weights based on expected backend reliability.
7//! For supervised weight learning from labeled data, see `WeightLearner`.
8//!
9//! # The Core Idea
10//!
11//! Instead of simple priority-based stacking, `EnsembleNER`:
12//! 1. Runs ALL available backends opportunistically (in parallel conceptually)
13//! 2. Collects candidate entities with provenance
14//! 3. Groups overlapping spans into conflict clusters
15//! 4. Resolves conflicts using weighted voting with agreement bonus
16//!
17//! ```text
18//! ┌─────────────────────────────────────────────────────────────────────────┐
19//! │                    ENSEMBLE NER ARCHITECTURE                            │
20//! ├─────────────────────────────────────────────────────────────────────────┤
21//! │                                                                         │
22//! │  Input: "Tim Cook, CEO of Apple, met with Sundar Pichai"                │
23//! │                                                                         │
24//! │  ┌──────────────────────────────────────────────────────────────────┐   │
25//! │  │ PHASE 1: OPPORTUNISTIC EXTRACTION (parallel)                     │   │
26//! │  │                                                                  │   │
27//! │  │  Pattern ──────► [no entities]                                   │   │
28//! │  │  Heuristic ────► Tim Cook (PER, 0.75), Apple (ORG, 0.80), ...    │   │
29//! │  │  GLiNER ────────► Tim Cook (PER, 0.95), Apple (ORG, 0.87), ...   │   │
30//! │  │  Candle ────────► [unavailable, skip]                            │   │
31//! │  └──────────────────────────────────────────────────────────────────┘   │
32//! │                            │                                            │
33//! │                            ▼                                            │
34//! │  ┌──────────────────────────────────────────────────────────────────┐   │
35//! │  │ PHASE 2: CANDIDATE AGGREGATION                                   │   │
36//! │  │                                                                  │   │
37//! │  │  Span [0:8] "Tim Cook":                                          │   │
38//! │  │    • Heuristic: PER (0.75)                                       │   │
39//! │  │    • GLiNER: PER (0.95)                                          │   │
40//! │  │    Agreement: 2/2 → HIGH confidence                              │   │
41//! │  │                                                                  │   │
42//! │  │  Span [17:22] "Apple":                                           │   │
43//! │  │    • Heuristic: ORG (0.80)                                       │   │
44//! │  │    • GLiNER: ORG (0.87)                                          │   │
45//! │  │    Agreement: 2/2 → HIGH confidence                              │   │
46//! │  └──────────────────────────────────────────────────────────────────┘   │
47//! │                            │                                            │
48//! │                            ▼                                            │
49//! │  ┌──────────────────────────────────────────────────────────────────┐   │
50//! │  │ PHASE 3: CONFLICT RESOLUTION (weighted voting)                   │   │
51//! │  │                                                                  │   │
52//! │  │  Backend weights (learned or configured):                        │   │
53//! │  │    Pattern: 0.99 (when fires, almost always right)               │   │
54//! │  │    GLiNER:  0.85 (ML-based, good accuracy)                       │   │
55//! │  │    Heuristic: 0.65 (reasonable but noisy)                        │   │
56//! │  │                                                                  │   │
57//! │  │  For span [0:8]:                                                 │   │
58//! │  │    Weighted vote = (0.65 * 0.75) + (0.85 * 0.95) = 1.29          │   │
59//! │  │    Normalized confidence = 0.91                                  │   │
60//! │  └──────────────────────────────────────────────────────────────────┘   │
61//! │                            │                                            │
62//! │                            ▼                                            │
63//! │  ┌──────────────────────────────────────────────────────────────────┐   │
64//! │  │ OUTPUT                                                           │   │
65//! │  │                                                                  │   │
66//! │  │  Entity { text: "Tim Cook", type: PER, conf: 0.91,               │   │
67//! │  │           sources: ["heuristic", "gliner"], agreement: 1.0 }     │   │
68//! │  └──────────────────────────────────────────────────────────────────┘   │
69//! │                                                                         │
70//! └─────────────────────────────────────────────────────────────────────────┘
71//! ```
72//!
73//! # Conflict Resolution Strategies
74//!
75//! ## Weighted Voting (Unsupervised)
76//!
77//! Each backend has a weight based on its expected reliability:
78//! - Pattern backends: high weight (0.95+) when they fire
79//! - ML backends: medium-high weight (0.80-0.90)
80//! - Heuristic backends: lower weight (0.60-0.70)
81//!
82//! ## Type-Conditioned Voting
83//!
84//! Some backends are better at certain types:
85//! - Pattern: DATE, MONEY, EMAIL, URL (near-perfect)
86//! - GLiNER: PER, ORG (good), LOC (decent)
87//! - Heuristic: ORG (good with "Inc", "Corp"), PER (title+name patterns)
88//!
89//! ## Agreement Bonus
90//!
91//! When multiple backends agree on type AND span, boost confidence:
92//! - 2 backends agree: +0.10 bonus
93//! - 3+ backends agree: +0.15 bonus
94//!
95//! # Example
96//!
97//! ```rust
98//! use anno::{Model, EnsembleNER};
99//!
100//! let ner = EnsembleNER::new();
101//! let entities = ner.extract_entities("Tim Cook leads Apple Inc.", None).unwrap();
102//!
103//! // Each entity includes provenance and agreement info
104//! for e in &entities {
105//!     println!("{}: {} (conf: {:.2}, sources: {:?})",
106//!              e.entity_type.as_label(), e.text, e.confidence,
107//!              e.provenance.as_ref().map(|p| &p.source));
108//! }
109//! ```
110
111use std::borrow::Cow;
112use std::collections::HashMap;
113use std::sync::Arc;
114
115use crate::{Entity, EntityType, Model, Result};
116
117fn method_for_backend_id(backend_id: &str) -> anno_core::ExtractionMethod {
118    match backend_id {
119        // Stable IDs used by `EnsembleNER::new()`.
120        "regex" => anno_core::ExtractionMethod::Pattern,
121        "heuristic" => anno_core::ExtractionMethod::Heuristic,
122        // Legacy backend id (deprecated, but still used in tests/compositions).
123        "rule" => anno_core::ExtractionMethod::Heuristic,
124        // Everything else: treat as neural by default.
125        _ => anno_core::ExtractionMethod::Neural,
126    }
127}
128
129pub mod weights;
130pub use weights::*;
131
/// Weighted ensemble of NER backends.
///
/// Runs every configured backend, groups overlapping candidate spans, and
/// resolves each group by weighted voting (see module docs for the method).
pub struct EnsembleNER {
    /// Backends queried opportunistically; individual failures are logged
    /// at debug level and skipped.
    backends: Vec<Arc<dyn Model + Send + Sync>>,
    /// Stable backend IDs used for weighting and source tracking.
    ///
    /// This is intentionally decoupled from `backend.name()`, which is a
    /// human-facing label and may vary across implementations (e.g. "GLiNER-ONNX").
    backend_ids: Vec<String>,
    /// Per-backend voting weights, keyed by stable backend ID. Backends with
    /// no entry fall back to a conservative default weight (0.50).
    weights: HashMap<String, BackendWeight>,
    /// Confidence bonus added when >= 2 backends agree (1.5x when >= 3 agree).
    agreement_bonus: f64,
    /// Resolved entities below this confidence are dropped from the output.
    min_confidence: f64,
    /// Transparent name showing constituent backends (e.g., "ensemble(regex|gliner|heuristic)")
    name: String,
    /// Cached static name (avoids Box::leak on every name() call)
    name_static: std::sync::OnceLock<&'static str>,
}
148
149impl Default for EnsembleNER {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155impl EnsembleNER {
156    /// Create ensemble with all available backends.
157    #[must_use]
158    pub fn new() -> Self {
159        let mut backends: Vec<Arc<dyn Model + Send + Sync>> = Vec::new();
160        let mut backend_ids: Vec<&'static str> = Vec::new();
161
162        // Always add pattern (high precision for structured data)
163        backends.push(Arc::new(crate::RegexNER::new()));
164        backend_ids.push("regex");
165
166        // Add GLiNER if available
167        #[cfg(feature = "onnx")]
168        {
169            use super::GLiNEROnnx;
170            use crate::DEFAULT_GLINER_MODEL;
171            if let Ok(gliner) = GLiNEROnnx::new(DEFAULT_GLINER_MODEL) {
172                backends.push(Arc::new(gliner));
173                backend_ids.push("gliner");
174            }
175        }
176
177        // Add Candle GLiNER if available
178        #[cfg(feature = "candle")]
179        {
180            use super::GLiNERCandle;
181            use crate::DEFAULT_GLINER_MODEL;
182            if let Ok(candle) = GLiNERCandle::from_pretrained(DEFAULT_GLINER_MODEL) {
183                backends.push(Arc::new(candle));
184                backend_ids.push("gliner-candle");
185            }
186        }
187
188        // Always add heuristic as fallback
189        backends.push(Arc::new(crate::HeuristicNER::new()));
190        backend_ids.push("heuristic");
191
192        // Build transparent name showing constituents
193        // Use '|' for parallel weighted voting (no priority ordering)
194        let name = format!("ensemble({})", backend_ids.join("|"));
195
196        // Convert default weights to owned strings
197        let weights: HashMap<String, BackendWeight> = default_backend_weights()
198            .into_iter()
199            .map(|(k, v)| (k.to_string(), v))
200            .collect();
201
202        Self {
203            backends,
204            backend_ids: backend_ids.into_iter().map(str::to_string).collect(),
205            weights,
206            agreement_bonus: 0.10,
207            min_confidence: 0.30,
208            name,
209            name_static: std::sync::OnceLock::new(),
210        }
211    }
212
213    /// Create with custom backends.
214    #[must_use]
215    pub fn with_backends(backends: Vec<Box<dyn Model + Send + Sync>>) -> Self {
216        // For custom backends, use the backend's reported name as both ID and display string.
217        let backend_ids: Vec<String> = backends.iter().map(|b| b.name().to_string()).collect();
218        let name = format!("ensemble({})", backend_ids.join("|"));
219
220        let backends: Vec<Arc<dyn Model + Send + Sync>> =
221            backends.into_iter().map(Arc::from).collect();
222
223        let weights: HashMap<String, BackendWeight> = default_backend_weights()
224            .into_iter()
225            .map(|(k, v)| (k.to_string(), v))
226            .collect();
227
228        Self {
229            backends,
230            backend_ids,
231            weights,
232            agreement_bonus: 0.10,
233            min_confidence: 0.30,
234            name,
235            name_static: std::sync::OnceLock::new(),
236        }
237    }
238
239    /// Set custom backend weights.
240    #[must_use]
241    pub fn with_weights(mut self, weights: HashMap<String, BackendWeight>) -> Self {
242        self.weights = weights;
243        self
244    }
245
246    /// Set the agreement bonus (added when multiple backends agree).
247    #[must_use]
248    pub fn with_agreement_bonus(mut self, bonus: f64) -> Self {
249        self.agreement_bonus = bonus;
250        self
251    }
252
253    /// Set minimum confidence threshold.
254    #[must_use]
255    pub fn with_min_confidence(mut self, min: f64) -> Self {
256        self.min_confidence = min;
257        self
258    }
259
260    /// Get the weight for a backend and entity type.
261    fn get_weight(&self, backend_name: &str, entity_type: &EntityType) -> f64 {
262        if let Some(weight) = self.weights.get(backend_name) {
263            if let Some(ref type_weights) = weight.per_type {
264                type_weights.get(entity_type)
265            } else {
266                weight.overall
267            }
268        } else {
269            // Unknown backend - use conservative default
270            0.50
271        }
272    }
273
    /// Resolve overlapping candidates using weighted voting.
    ///
    /// Returns `None` when `candidates` is empty. A single candidate is passed
    /// through with a small confidence penalty (0.95x) and ensemble-wrapped
    /// provenance; multiple candidates vote by type, and the merged entity is
    /// built from the highest-confidence candidate of the winning type.
    fn resolve_candidates(&self, candidates: Vec<Candidate>) -> Option<Entity> {
        if candidates.is_empty() {
            return None;
        }

        if candidates.len() == 1 {
            // Single candidate - use its confidence directly
            let candidate = candidates
                .into_iter()
                .next()
                .expect("candidates.len() == 1 guarantees next() is Some");
            let mut entity = candidate.entity;
            // Keep the original provenance/confidence around so we can carry
            // them through into the new ensemble provenance record.
            let original_prov = entity.provenance.clone();
            let original_confidence = entity.confidence;
            // Slight penalty for single-source
            entity.confidence *= 0.95;
            // Set provenance for single-source entities
            entity.provenance = Some(anno_core::Provenance {
                source: std::borrow::Cow::Owned(format!("ensemble({})", candidate.source)),
                // Preserve underlying method/pattern when possible (important for nested ensembles).
                method: original_prov
                    .as_ref()
                    .map(|p| p.method)
                    .unwrap_or_else(|| method_for_backend_id(&candidate.source)),
                pattern: original_prov.as_ref().and_then(|p| p.pattern.clone()),
                // raw_confidence records the pre-penalty value when the backend
                // did not already provide one.
                raw_confidence: original_prov
                    .as_ref()
                    .and_then(|p| p.raw_confidence)
                    .or(Some(original_confidence)),
                model_version: None,
                timestamp: None,
            });
            return Some(entity);
        }

        // Group by entity type
        let mut type_votes: HashMap<String, Vec<&Candidate>> = HashMap::new();
        for c in &candidates {
            let type_key = c.entity.entity_type.as_label().to_string();
            type_votes.entry(type_key).or_default().push(c);
        }

        // Find the type with highest weighted vote (deterministic tie-breaking).
        //
        // HashMap iteration order can vary across process runs. If two types tie on
        // weighted_sum, we still need a stable selection.
        //
        // Ordering:
        // 1) Higher weighted_sum wins
        // 2) If tied, more candidates (more votes) wins
        // 3) If tied, lexicographically smaller type key wins
        let mut best_type: Option<(String, f64, usize, Vec<&Candidate>)> = None;
        for (type_key, type_candidates) in &type_votes {
            // Each candidate's vote = backend weight x its own confidence.
            let weighted_sum: f64 = type_candidates
                .iter()
                .map(|c| c.backend_weight * c.entity.confidence)
                .sum();
            let count = type_candidates.len();

            let should_replace = match &best_type {
                None => true,
                Some((best_key, best_sum, best_count, _)) => {
                    if weighted_sum > *best_sum {
                        true
                    } else if weighted_sum < *best_sum {
                        false
                    } else if count > *best_count {
                        true
                    } else if count < *best_count {
                        false
                    } else {
                        type_key < best_key
                    }
                }
            };

            if should_replace {
                best_type = Some((
                    type_key.clone(),
                    weighted_sum,
                    count,
                    type_candidates.clone(),
                ));
            }
        }

        // `best_type` is Some here because candidates is non-empty, but use
        // `?` rather than unwrap to stay panic-free.
        let (_type_key, weighted_sum, _count, winning_candidates) = best_type?;

        // Calculate ensemble confidence
        let num_sources = winning_candidates.len();
        let total_weight: f64 = winning_candidates.iter().map(|c| c.backend_weight).sum();

        // Weighted mean of the winners' confidences; 0.5 is a defensive
        // fallback if all weights were zero.
        let base_confidence = if total_weight > 0.0 {
            weighted_sum / total_weight
        } else {
            0.5
        };

        // Agreement bonus (1.5x the configured bonus when 3+ sources agree).
        let agreement_bonus = if num_sources >= 3 {
            self.agreement_bonus * 1.5
        } else if num_sources >= 2 {
            self.agreement_bonus
        } else {
            0.0
        };

        let final_confidence = (base_confidence + agreement_bonus).min(1.0);

        // Build merged entity
        // Use the candidate with highest individual confidence as base
        let best_candidate = winning_candidates.iter().max_by(|a, b| {
            a.entity
                .confidence
                .partial_cmp(&b.entity.confidence)
                .unwrap_or(std::cmp::Ordering::Equal)
        })?;

        let sources: Vec<String> = winning_candidates
            .iter()
            .map(|c| c.source.clone())
            .collect();

        // Calculate hierarchical confidence scores
        // - linkage: How many backends detected an entity here (normalized)
        // - type_score: Agreement on type classification
        // - boundary: Agreement on exact span boundaries
        let total_candidates = candidates.len() as f32;
        let num_winners = winning_candidates.len() as f32;

        // Linkage: ratio of candidates in winning type
        let linkage = if total_candidates > 0.0 {
            (num_winners / total_candidates).min(1.0)
        } else {
            0.5
        };

        // Type score: confidence in the winning type (weighted)
        let type_score = final_confidence as f32;

        // Boundary: agreement on span boundaries
        // Check if all winning candidates have the same start/end
        let reference_span = (best_candidate.entity.start, best_candidate.entity.end);
        let span_agreement_count = winning_candidates
            .iter()
            .filter(|c| c.entity.start == reference_span.0 && c.entity.end == reference_span.1)
            .count();
        let boundary = if num_winners > 0.0 {
            (span_agreement_count as f32 / num_winners).min(1.0)
        } else {
            1.0
        };

        let mut entity = best_candidate.entity.clone();
        entity.confidence = final_confidence;
        entity.hierarchical_confidence = Some(anno_core::HierarchicalConfidence::new(
            linkage, type_score, boundary,
        ));
        // Merged provenance: raw_confidence keeps the pre-bonus weighted mean.
        entity.provenance = Some(anno_core::Provenance {
            source: Cow::Owned(format!("ensemble({})", sources.join("+"))),
            method: anno_core::ExtractionMethod::Consensus,
            pattern: None,
            raw_confidence: Some(base_confidence),
            model_version: None,
            timestamp: None,
        });

        Some(entity)
    }
444}
445
446impl Model for EnsembleNER {
447    fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
448        if self.backends.is_empty() {
449            return Ok(Vec::new());
450        }
451
452        // Phase 1: Collect candidates from all backends
453        let mut all_candidates: Vec<Candidate> = Vec::new();
454
455        for (i, backend) in self.backends.iter().enumerate() {
456            let backend_id = self
457                .backend_ids
458                .get(i)
459                .cloned()
460                .unwrap_or_else(|| backend.name().to_string());
461
462            match backend.extract_entities(text, language) {
463                Ok(entities) => {
464                    for entity in entities {
465                        let weight = self.get_weight(&backend_id, &entity.entity_type);
466                        all_candidates.push(Candidate {
467                            entity,
468                            source: backend_id.clone(),
469                            backend_weight: weight,
470                        });
471                    }
472                }
473                Err(e) => {
474                    // Log but continue (opportunistic)
475                    log::debug!(
476                        "EnsembleNER: Backend {} (id={}) failed: {}",
477                        backend.name(),
478                        backend_id,
479                        e
480                    );
481                }
482            }
483        }
484
485        if all_candidates.is_empty() {
486            return Ok(Vec::new());
487        }
488
489        // Phase 2: Group candidates by overlapping spans
490        let mut span_groups: Vec<Vec<Candidate>> = Vec::new();
491
492        for candidate in all_candidates {
493            let span = SpanKey::from_entity(&candidate.entity);
494
495            // Find existing group with overlapping span
496            let mut found_group = false;
497            for group in &mut span_groups {
498                if let Some(first) = group.first() {
499                    let existing_span = SpanKey::from_entity(&first.entity);
500                    if span.overlaps(&existing_span) {
501                        group.push(candidate.clone());
502                        found_group = true;
503                        break;
504                    }
505                }
506            }
507
508            if !found_group {
509                span_groups.push(vec![candidate]);
510            }
511        }
512
513        // Phase 3: Resolve each group
514        let mut results: Vec<Entity> = Vec::new();
515
516        for group in span_groups {
517            if let Some(entity) = self.resolve_candidates(group) {
518                if entity.confidence >= self.min_confidence {
519                    results.push(entity);
520                }
521            }
522        }
523
524        // Sort by position
525        results.sort_by_key(|e| (e.start, e.end));
526
527        Ok(results)
528    }
529
530    fn supported_types(&self) -> Vec<EntityType> {
531        // Union of all backend types
532        let mut types: Vec<EntityType> = Vec::new();
533        for backend in &self.backends {
534            for t in backend.supported_types() {
535                if !types.contains(&t) {
536                    types.push(t);
537                }
538            }
539        }
540        types
541    }
542
543    fn is_available(&self) -> bool {
544        // Available if at least one backend is available
545        self.backends.iter().any(|b| b.is_available())
546    }
547
    fn name(&self) -> &'static str {
        // The trait demands `&'static str`, but the ensemble name is built at
        // runtime. Leak it exactly once via OnceLock and return the cached
        // leaked slice on later calls — a bounded leak (one string per
        // instance), unlike leaking on every invocation.
        self.name_static
            .get_or_init(|| Box::leak(self.name.clone().into_boxed_str()))
    }
553
    fn description(&self) -> &'static str {
        // Static summary; the dynamic constituent list lives in `name()`.
        "Ensemble NER: weighted voting across multiple backends"
    }
557
558    fn capabilities(&self) -> crate::ModelCapabilities {
559        crate::ModelCapabilities {
560            batch_capable: true,
561            streaming_capable: true,
562            ..Default::default()
563        }
564    }
565}
566
// Implement required traits
// Marker impl: the ensemble performs named-entity extraction; no extra
// methods are required beyond the `Model` implementation above.
impl crate::NamedEntityCapable for EnsembleNER {}
569
570impl crate::BatchCapable for EnsembleNER {
571    fn optimal_batch_size(&self) -> Option<usize> {
572        Some(8) // Reasonable default for ensemble
573    }
574}
575
576impl crate::StreamingCapable for EnsembleNER {
577    fn recommended_chunk_size(&self) -> usize {
578        8192
579    }
580}
581
582// =============================================================================
583// Weight Learning
584// =============================================================================
585
/// Training example for weight learning.
///
/// One example corresponds to a single gold entity plus whatever each
/// backend predicted for that exact span.
#[derive(Debug, Clone)]
pub struct WeightTrainingExample {
    /// Text of the entity
    pub text: String,
    /// True entity type (gold label)
    pub gold_type: EntityType,
    /// Span start (same offset convention as `Entity::start` — assumed; confirm byte vs char)
    pub start: usize,
    /// Span end (exclusive, matching `Entity::end` — assumed; confirm)
    pub end: usize,
    /// Predictions from each backend: (backend_name, predicted_type, confidence)
    pub predictions: Vec<(String, EntityType, f64)>,
}
600
/// Statistics for weight learning.
///
/// Accumulated per backend: `correct`/`total` count span-matched predictions
/// whose type agreed/disagreed with the gold label.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    /// Total correct predictions
    pub correct: usize,
    /// Total predictions made
    pub total: usize,
    /// Per-type statistics keyed by gold type label; value is (correct, total)
    pub per_type: HashMap<String, (usize, usize)>,
}
611
612impl BackendStats {
613    /// Calculate overall precision.
614    pub fn precision(&self) -> f64 {
615        if self.total == 0 {
616            0.0
617        } else {
618            self.correct as f64 / self.total as f64
619        }
620    }
621
622    /// Calculate per-type precision.
623    pub fn type_precision(&self, entity_type: &str) -> f64 {
624        if let Some((correct, total)) = self.per_type.get(entity_type) {
625            if *total == 0 {
626                0.0
627            } else {
628                *correct as f64 / *total as f64
629            }
630        } else {
631            0.0
632        }
633    }
634}
635
/// Weight learner for EnsembleNER.
///
/// Learns optimal backend weights from evaluation data. Weights are smoothed
/// precisions, so they lie strictly between 0 and 1.
///
/// # Example
///
/// ```rust,ignore
/// use anno::backends::ensemble::{EnsembleNER, WeightLearner};
///
/// let mut learner = WeightLearner::new();
///
/// // Add training examples from gold data
/// for (text, gold_entities) in gold_data {
///     learner.add_examples(&text, &gold_entities, &backends);
/// }
///
/// // Learn weights
/// let learned_weights = learner.learn_weights();
///
/// // Create ensemble with learned weights
/// let ensemble = EnsembleNER::new().with_weights(learned_weights);
/// ```
pub struct WeightLearner {
    /// Per-backend statistics, keyed by backend name
    backend_stats: HashMap<String, BackendStats>,
    /// Smoothing factor for precision (avoid division by zero / overfitting)
    smoothing: f64,
}
664
665impl Default for WeightLearner {
666    fn default() -> Self {
667        Self::new()
668    }
669}
670
671impl WeightLearner {
672    /// Create a new weight learner.
673    #[must_use]
674    pub fn new() -> Self {
675        Self {
676            backend_stats: HashMap::new(),
677            smoothing: 1.0, // Laplace smoothing
678        }
679    }
680
681    /// Set smoothing factor.
682    #[must_use]
683    pub fn with_smoothing(mut self, smoothing: f64) -> Self {
684        self.smoothing = smoothing;
685        self
686    }
687
688    /// Add a training example.
689    pub fn add_example(&mut self, example: &WeightTrainingExample) {
690        for (backend_name, predicted_type, _confidence) in &example.predictions {
691            let stats = self.backend_stats.entry(backend_name.clone()).or_default();
692
693            stats.total += 1;
694            let correct = *predicted_type == example.gold_type;
695            if correct {
696                stats.correct += 1;
697            }
698
699            // Per-type stats
700            let type_key = example.gold_type.as_label().to_string();
701            let type_stats = stats.per_type.entry(type_key).or_insert((0, 0));
702            type_stats.1 += 1;
703            if correct {
704                type_stats.0 += 1;
705            }
706        }
707    }
708
709    /// Add examples from gold entities and backend predictions.
710    ///
711    /// Runs each backend on the text and compares to gold entities.
712    pub fn add_from_backends(
713        &mut self,
714        text: &str,
715        gold_entities: &[Entity],
716        backends: &[(&str, &dyn Model)],
717    ) {
718        // Get predictions from each backend
719        let mut backend_preds: HashMap<String, Vec<Entity>> = HashMap::new();
720        for (name, backend) in backends {
721            if let Ok(entities) = backend.extract_entities(text, None) {
722                backend_preds.insert(name.to_string(), entities);
723            }
724        }
725
726        // Match predictions to gold entities
727        for gold in gold_entities {
728            let mut example = WeightTrainingExample {
729                text: gold.text.clone(),
730                gold_type: gold.entity_type.clone(),
731                start: gold.start,
732                end: gold.end,
733                predictions: Vec::new(),
734            };
735
736            for (backend_name, entities) in &backend_preds {
737                // Find matching prediction (same span)
738                for pred in entities {
739                    if pred.start == gold.start && pred.end == gold.end {
740                        example.predictions.push((
741                            backend_name.clone(),
742                            pred.entity_type.clone(),
743                            pred.confidence,
744                        ));
745                        break;
746                    }
747                }
748            }
749
750            if !example.predictions.is_empty() {
751                self.add_example(&example);
752            }
753        }
754    }
755
    /// Learn optimal weights from accumulated statistics.
    ///
    /// Uses precision-based weighting with Laplace smoothing. Returns one
    /// `BackendWeight` per backend seen via `add_example`/`add_from_backends`,
    /// carrying both an overall weight and per-type weights.
    pub fn learn_weights(&self) -> HashMap<String, BackendWeight> {
        let mut weights = HashMap::new();

        for (backend_name, stats) in &self.backend_stats {
            // Smoothed precision: (correct + smoothing) / (total + 2*smoothing)
            // — add-k smoothing of a binary outcome, so the result is always
            // strictly between 0 and 1 even with zero observations.
            let smoothed_precision = (stats.correct as f64 + self.smoothing)
                / (stats.total as f64 + 2.0 * self.smoothing);

            // Per-type weights
            let mut type_weights = TypeWeights::default();
            for (type_key, (correct, total)) in &stats.per_type {
                let type_precision =
                    (*correct as f64 + self.smoothing) / (*total as f64 + 2.0 * self.smoothing);

                // NOTE(review): if two distinct keys map to the same field
                // (e.g. both "PER" and "PERSON" appear in per_type), the key
                // visited last overwrites the earlier one, and HashMap
                // iteration order is unspecified — confirm that gold labels
                // are normalized upstream.
                match type_key.as_str() {
                    "PER" | "PERSON" => type_weights.person = type_precision,
                    "ORG" | "ORGANIZATION" => type_weights.organization = type_precision,
                    "LOC" | "LOCATION" | "GPE" => type_weights.location = type_precision,
                    "DATE" => type_weights.date = type_precision,
                    "MONEY" => type_weights.money = type_precision,
                    _ => type_weights.other = type_precision,
                }
            }

            weights.insert(
                backend_name.clone(),
                BackendWeight {
                    overall: smoothed_precision,
                    per_type: Some(type_weights),
                },
            );
        }

        weights
    }
794
    /// Get statistics for a backend.
    ///
    /// Returns `None` when no example mentioning `backend_name` was added.
    pub fn get_stats(&self, backend_name: &str) -> Option<&BackendStats> {
        self.backend_stats.get(backend_name)
    }

    /// Get all backend names.
    ///
    /// NOTE: order is unspecified (HashMap iteration order).
    pub fn backend_names(&self) -> Vec<&String> {
        self.backend_stats.keys().collect()
    }
804}
805
806// =============================================================================
807// Tests
808// =============================================================================
809
810#[cfg(test)]
811mod tests;