anno/backends/ensemble/mod.rs
1//! Ensemble NER - Multi-backend extraction with unsupervised weighted voting.
2//!
3//! # Method
4//!
5//! This is an **unsupervised heuristic** approach (no training data required).
6//! Conflict resolution uses hand-tuned weights based on expected backend reliability.
7//! For supervised weight learning from labeled data, see `WeightLearner`.
8//!
9//! # The Core Idea
10//!
11//! Instead of simple priority-based stacking, `EnsembleNER`:
//! 1. Runs every configured backend in turn (opportunistically: a backend that returns an error is logged and skipped)
13//! 2. Collects candidate entities with provenance
14//! 3. Groups overlapping spans into conflict clusters
15//! 4. Resolves conflicts using weighted voting with agreement bonus
16//!
//! ```text
//! ┌───────────────────────────────────────────────────────────────────────────┐
//! │                         ENSEMBLE NER ARCHITECTURE                         │
//! ├───────────────────────────────────────────────────────────────────────────┤
//! │                                                                           │
//! │  Input: "Tim Cook, CEO of Apple, met with Sundar Pichai"                  │
//! │                                                                           │
//! │  ┌─────────────────────────────────────────────────────────────────────┐  │
//! │  │ PHASE 1: OPPORTUNISTIC EXTRACTION                                   │  │
//! │  │                                                                     │  │
//! │  │   Pattern   ───► [no entities]                                      │  │
//! │  │   Heuristic ───► Tim Cook (PER, 0.75), Apple (ORG, 0.80), ...       │  │
//! │  │   GLiNER    ───► Tim Cook (PER, 0.95), Apple (ORG, 0.87), ...       │  │
//! │  │   Candle    ───► [unavailable, skipped]                             │  │
//! │  └─────────────────────────────────────────────────────────────────────┘  │
//! │                                     │                                     │
//! │                                     ▼                                     │
//! │  ┌─────────────────────────────────────────────────────────────────────┐  │
//! │  │ PHASE 2: CANDIDATE AGGREGATION                                      │  │
//! │  │                                                                     │  │
//! │  │   Span [0:8] "Tim Cook":                                            │  │
//! │  │     • Heuristic: PER (0.75)                                         │  │
//! │  │     • GLiNER:    PER (0.95)                                         │  │
//! │  │     Agreement: 2/2 → HIGH confidence                                │  │
//! │  │                                                                     │  │
//! │  │   Span [17:22] "Apple":                                             │  │
//! │  │     • Heuristic: ORG (0.80)                                         │  │
//! │  │     • GLiNER:    ORG (0.87)                                         │  │
//! │  │     Agreement: 2/2 → HIGH confidence                                │  │
//! │  └─────────────────────────────────────────────────────────────────────┘  │
//! │                                     │                                     │
//! │                                     ▼                                     │
//! │  ┌─────────────────────────────────────────────────────────────────────┐  │
//! │  │ PHASE 3: CONFLICT RESOLUTION (weighted voting)                      │  │
//! │  │                                                                     │  │
//! │  │   Backend weights (illustrative; configured or learned):            │  │
//! │  │     Pattern:   0.99 (when it fires, almost always right)            │  │
//! │  │     GLiNER:    0.85 (ML-based, good accuracy)                       │  │
//! │  │     Heuristic: 0.65 (reasonable but noisy)                          │  │
//! │  │                                                                     │  │
//! │  │   For span [0:8]:                                                   │  │
//! │  │     Weighted vote = (0.65 * 0.75) + (0.85 * 0.95) = 1.295           │  │
//! │  │     Normalized = 1.295 / (0.65 + 0.85) ≈ 0.86                       │  │
//! │  │     Final = 0.86 + 0.10 agreement bonus ≈ 0.96                      │  │
//! │  └─────────────────────────────────────────────────────────────────────┘  │
//! │                                     │                                     │
//! │                                     ▼                                     │
//! │  ┌─────────────────────────────────────────────────────────────────────┐  │
//! │  │ OUTPUT                                                              │  │
//! │  │                                                                     │  │
//! │  │   Entity { text: "Tim Cook", type: PER, conf: 0.96,                 │  │
//! │  │            provenance: "ensemble(gliner+heuristic)" }               │  │
//! │  └─────────────────────────────────────────────────────────────────────┘  │
//! │                                                                           │
//! └───────────────────────────────────────────────────────────────────────────┘
//! ```
72//!
73//! # Conflict Resolution Strategies
74//!
75//! ## Weighted Voting (Unsupervised)
76//!
77//! Each backend has a weight based on its expected reliability:
78//! - Pattern backends: high weight (0.95+) when they fire
79//! - ML backends: medium-high weight (0.80-0.90)
80//! - Heuristic backends: lower weight (0.60-0.70)
81//!
82//! ## Type-Conditioned Voting
83//!
84//! Some backends are better at certain types:
85//! - Pattern: DATE, MONEY, EMAIL, URL (near-perfect)
86//! - GLiNER: PER, ORG (good), LOC (decent)
87//! - Heuristic: ORG (good with "Inc", "Corp"), PER (title+name patterns)
88//!
89//! ## Agreement Bonus
90//!
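//! When multiple backends assign the same type to overlapping spans, boost confidence
//! (exact boundary agreement is tracked separately in the entity's hierarchical confidence):
//! - 2 backends agree: +0.10 bonus (the default; configurable via `with_agreement_bonus`)
//! - 3+ backends agree: +0.15 bonus (1.5 × the configured bonus)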
94//!
95//! # Example
96//!
97//! ```rust
98//! use anno::{Model, EnsembleNER};
99//!
100//! let ner = EnsembleNER::new();
101//! let entities = ner.extract_entities("Tim Cook leads Apple Inc.", None).unwrap();
102//!
103//! // Each entity includes provenance and agreement info
104//! for e in &entities {
105//! println!("{}: {} (conf: {:.2}, sources: {:?})",
106//! e.entity_type.as_label(), e.text, e.confidence,
107//! e.provenance.as_ref().map(|p| &p.source));
108//! }
109//! ```
110
111use std::borrow::Cow;
112use std::collections::HashMap;
113use std::sync::Arc;
114
115use crate::{Entity, EntityType, Model, Result};
116
117fn method_for_backend_id(backend_id: &str) -> anno_core::ExtractionMethod {
118 match backend_id {
119 // Stable IDs used by `EnsembleNER::new()`.
120 "regex" => anno_core::ExtractionMethod::Pattern,
121 "heuristic" => anno_core::ExtractionMethod::Heuristic,
122 // Legacy backend id (deprecated, but still used in tests/compositions).
123 "rule" => anno_core::ExtractionMethod::Heuristic,
124 // Everything else: treat as neural by default.
125 _ => anno_core::ExtractionMethod::Neural,
126 }
127}
128
129pub mod weights;
130pub use weights::*;
131
132/// Weighted ensemble of NER backends.
133pub struct EnsembleNER {
134 backends: Vec<Arc<dyn Model + Send + Sync>>,
135 /// Stable backend IDs used for weighting and source tracking.
136 ///
137 /// This is intentionally decoupled from `backend.name()`, which is a
138 /// human-facing label and may vary across implementations (e.g. "GLiNER-ONNX").
139 backend_ids: Vec<String>,
140 weights: HashMap<String, BackendWeight>,
141 agreement_bonus: f64,
142 min_confidence: f64,
143 /// Transparent name showing constituent backends (e.g., "ensemble(regex|gliner|heuristic)")
144 name: String,
145 /// Cached static name (avoids Box::leak on every name() call)
146 name_static: std::sync::OnceLock<&'static str>,
147}
148
149impl Default for EnsembleNER {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
155impl EnsembleNER {
156 /// Create ensemble with all available backends.
157 #[must_use]
158 pub fn new() -> Self {
159 let mut backends: Vec<Arc<dyn Model + Send + Sync>> = Vec::new();
160 let mut backend_ids: Vec<&'static str> = Vec::new();
161
162 // Always add pattern (high precision for structured data)
163 backends.push(Arc::new(crate::RegexNER::new()));
164 backend_ids.push("regex");
165
166 // Add GLiNER if available
167 #[cfg(feature = "onnx")]
168 {
169 use super::GLiNEROnnx;
170 use crate::DEFAULT_GLINER_MODEL;
171 if let Ok(gliner) = GLiNEROnnx::new(DEFAULT_GLINER_MODEL) {
172 backends.push(Arc::new(gliner));
173 backend_ids.push("gliner");
174 }
175 }
176
177 // Add Candle GLiNER if available
178 #[cfg(feature = "candle")]
179 {
180 use super::GLiNERCandle;
181 use crate::DEFAULT_GLINER_MODEL;
182 if let Ok(candle) = GLiNERCandle::from_pretrained(DEFAULT_GLINER_MODEL) {
183 backends.push(Arc::new(candle));
184 backend_ids.push("gliner-candle");
185 }
186 }
187
188 // Always add heuristic as fallback
189 backends.push(Arc::new(crate::HeuristicNER::new()));
190 backend_ids.push("heuristic");
191
192 // Build transparent name showing constituents
193 // Use '|' for parallel weighted voting (no priority ordering)
194 let name = format!("ensemble({})", backend_ids.join("|"));
195
196 // Convert default weights to owned strings
197 let weights: HashMap<String, BackendWeight> = default_backend_weights()
198 .into_iter()
199 .map(|(k, v)| (k.to_string(), v))
200 .collect();
201
202 Self {
203 backends,
204 backend_ids: backend_ids.into_iter().map(str::to_string).collect(),
205 weights,
206 agreement_bonus: 0.10,
207 min_confidence: 0.30,
208 name,
209 name_static: std::sync::OnceLock::new(),
210 }
211 }
212
213 /// Create with custom backends.
214 #[must_use]
215 pub fn with_backends(backends: Vec<Box<dyn Model + Send + Sync>>) -> Self {
216 // For custom backends, use the backend's reported name as both ID and display string.
217 let backend_ids: Vec<String> = backends.iter().map(|b| b.name().to_string()).collect();
218 let name = format!("ensemble({})", backend_ids.join("|"));
219
220 let backends: Vec<Arc<dyn Model + Send + Sync>> =
221 backends.into_iter().map(Arc::from).collect();
222
223 let weights: HashMap<String, BackendWeight> = default_backend_weights()
224 .into_iter()
225 .map(|(k, v)| (k.to_string(), v))
226 .collect();
227
228 Self {
229 backends,
230 backend_ids,
231 weights,
232 agreement_bonus: 0.10,
233 min_confidence: 0.30,
234 name,
235 name_static: std::sync::OnceLock::new(),
236 }
237 }
238
239 /// Set custom backend weights.
240 #[must_use]
241 pub fn with_weights(mut self, weights: HashMap<String, BackendWeight>) -> Self {
242 self.weights = weights;
243 self
244 }
245
246 /// Set the agreement bonus (added when multiple backends agree).
247 #[must_use]
248 pub fn with_agreement_bonus(mut self, bonus: f64) -> Self {
249 self.agreement_bonus = bonus;
250 self
251 }
252
253 /// Set minimum confidence threshold.
254 #[must_use]
255 pub fn with_min_confidence(mut self, min: f64) -> Self {
256 self.min_confidence = min;
257 self
258 }
259
260 /// Get the weight for a backend and entity type.
261 fn get_weight(&self, backend_name: &str, entity_type: &EntityType) -> f64 {
262 if let Some(weight) = self.weights.get(backend_name) {
263 if let Some(ref type_weights) = weight.per_type {
264 type_weights.get(entity_type)
265 } else {
266 weight.overall
267 }
268 } else {
269 // Unknown backend - use conservative default
270 0.50
271 }
272 }
273
274 /// Resolve overlapping candidates using weighted voting.
275 fn resolve_candidates(&self, candidates: Vec<Candidate>) -> Option<Entity> {
276 if candidates.is_empty() {
277 return None;
278 }
279
280 if candidates.len() == 1 {
281 // Single candidate - use its confidence directly
282 let candidate = candidates
283 .into_iter()
284 .next()
285 .expect("candidates.len() == 1 guarantees next() is Some");
286 let mut entity = candidate.entity;
287 let original_prov = entity.provenance.clone();
288 let original_confidence = entity.confidence;
289 // Slight penalty for single-source
290 entity.confidence *= 0.95;
291 // Set provenance for single-source entities
292 entity.provenance = Some(anno_core::Provenance {
293 source: std::borrow::Cow::Owned(format!("ensemble({})", candidate.source)),
294 // Preserve underlying method/pattern when possible (important for nested ensembles).
295 method: original_prov
296 .as_ref()
297 .map(|p| p.method)
298 .unwrap_or_else(|| method_for_backend_id(&candidate.source)),
299 pattern: original_prov.as_ref().and_then(|p| p.pattern.clone()),
300 raw_confidence: original_prov
301 .as_ref()
302 .and_then(|p| p.raw_confidence)
303 .or(Some(original_confidence)),
304 model_version: None,
305 timestamp: None,
306 });
307 return Some(entity);
308 }
309
310 // Group by entity type
311 let mut type_votes: HashMap<String, Vec<&Candidate>> = HashMap::new();
312 for c in &candidates {
313 let type_key = c.entity.entity_type.as_label().to_string();
314 type_votes.entry(type_key).or_default().push(c);
315 }
316
317 // Find the type with highest weighted vote (deterministic tie-breaking).
318 //
319 // HashMap iteration order can vary across process runs. If two types tie on
320 // weighted_sum, we still need a stable selection.
321 //
322 // Ordering:
323 // 1) Higher weighted_sum wins
324 // 2) If tied, more candidates (more votes) wins
325 // 3) If tied, lexicographically smaller type key wins
326 let mut best_type: Option<(String, f64, usize, Vec<&Candidate>)> = None;
327 for (type_key, type_candidates) in &type_votes {
328 let weighted_sum: f64 = type_candidates
329 .iter()
330 .map(|c| c.backend_weight * c.entity.confidence)
331 .sum();
332 let count = type_candidates.len();
333
334 let should_replace = match &best_type {
335 None => true,
336 Some((best_key, best_sum, best_count, _)) => {
337 if weighted_sum > *best_sum {
338 true
339 } else if weighted_sum < *best_sum {
340 false
341 } else if count > *best_count {
342 true
343 } else if count < *best_count {
344 false
345 } else {
346 type_key < best_key
347 }
348 }
349 };
350
351 if should_replace {
352 best_type = Some((
353 type_key.clone(),
354 weighted_sum,
355 count,
356 type_candidates.clone(),
357 ));
358 }
359 }
360
361 let (_type_key, weighted_sum, _count, winning_candidates) = best_type?;
362
363 // Calculate ensemble confidence
364 let num_sources = winning_candidates.len();
365 let total_weight: f64 = winning_candidates.iter().map(|c| c.backend_weight).sum();
366
367 let base_confidence = if total_weight > 0.0 {
368 weighted_sum / total_weight
369 } else {
370 0.5
371 };
372
373 // Agreement bonus
374 let agreement_bonus = if num_sources >= 3 {
375 self.agreement_bonus * 1.5
376 } else if num_sources >= 2 {
377 self.agreement_bonus
378 } else {
379 0.0
380 };
381
382 let final_confidence = (base_confidence + agreement_bonus).min(1.0);
383
384 // Build merged entity
385 // Use the candidate with highest individual confidence as base
386 let best_candidate = winning_candidates.iter().max_by(|a, b| {
387 a.entity
388 .confidence
389 .partial_cmp(&b.entity.confidence)
390 .unwrap_or(std::cmp::Ordering::Equal)
391 })?;
392
393 let sources: Vec<String> = winning_candidates
394 .iter()
395 .map(|c| c.source.clone())
396 .collect();
397
398 // Calculate hierarchical confidence scores
399 // - linkage: How many backends detected an entity here (normalized)
400 // - type_score: Agreement on type classification
401 // - boundary: Agreement on exact span boundaries
402 let total_candidates = candidates.len() as f32;
403 let num_winners = winning_candidates.len() as f32;
404
405 // Linkage: ratio of candidates in winning type
406 let linkage = if total_candidates > 0.0 {
407 (num_winners / total_candidates).min(1.0)
408 } else {
409 0.5
410 };
411
412 // Type score: confidence in the winning type (weighted)
413 let type_score = final_confidence as f32;
414
415 // Boundary: agreement on span boundaries
416 // Check if all winning candidates have the same start/end
417 let reference_span = (best_candidate.entity.start, best_candidate.entity.end);
418 let span_agreement_count = winning_candidates
419 .iter()
420 .filter(|c| c.entity.start == reference_span.0 && c.entity.end == reference_span.1)
421 .count();
422 let boundary = if num_winners > 0.0 {
423 (span_agreement_count as f32 / num_winners).min(1.0)
424 } else {
425 1.0
426 };
427
428 let mut entity = best_candidate.entity.clone();
429 entity.confidence = final_confidence;
430 entity.hierarchical_confidence = Some(anno_core::HierarchicalConfidence::new(
431 linkage, type_score, boundary,
432 ));
433 entity.provenance = Some(anno_core::Provenance {
434 source: Cow::Owned(format!("ensemble({})", sources.join("+"))),
435 method: anno_core::ExtractionMethod::Consensus,
436 pattern: None,
437 raw_confidence: Some(base_confidence),
438 model_version: None,
439 timestamp: None,
440 });
441
442 Some(entity)
443 }
444}
445
446impl Model for EnsembleNER {
447 fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
448 if self.backends.is_empty() {
449 return Ok(Vec::new());
450 }
451
452 // Phase 1: Collect candidates from all backends
453 let mut all_candidates: Vec<Candidate> = Vec::new();
454
455 for (i, backend) in self.backends.iter().enumerate() {
456 let backend_id = self
457 .backend_ids
458 .get(i)
459 .cloned()
460 .unwrap_or_else(|| backend.name().to_string());
461
462 match backend.extract_entities(text, language) {
463 Ok(entities) => {
464 for entity in entities {
465 let weight = self.get_weight(&backend_id, &entity.entity_type);
466 all_candidates.push(Candidate {
467 entity,
468 source: backend_id.clone(),
469 backend_weight: weight,
470 });
471 }
472 }
473 Err(e) => {
474 // Log but continue (opportunistic)
475 log::debug!(
476 "EnsembleNER: Backend {} (id={}) failed: {}",
477 backend.name(),
478 backend_id,
479 e
480 );
481 }
482 }
483 }
484
485 if all_candidates.is_empty() {
486 return Ok(Vec::new());
487 }
488
489 // Phase 2: Group candidates by overlapping spans
490 let mut span_groups: Vec<Vec<Candidate>> = Vec::new();
491
492 for candidate in all_candidates {
493 let span = SpanKey::from_entity(&candidate.entity);
494
495 // Find existing group with overlapping span
496 let mut found_group = false;
497 for group in &mut span_groups {
498 if let Some(first) = group.first() {
499 let existing_span = SpanKey::from_entity(&first.entity);
500 if span.overlaps(&existing_span) {
501 group.push(candidate.clone());
502 found_group = true;
503 break;
504 }
505 }
506 }
507
508 if !found_group {
509 span_groups.push(vec![candidate]);
510 }
511 }
512
513 // Phase 3: Resolve each group
514 let mut results: Vec<Entity> = Vec::new();
515
516 for group in span_groups {
517 if let Some(entity) = self.resolve_candidates(group) {
518 if entity.confidence >= self.min_confidence {
519 results.push(entity);
520 }
521 }
522 }
523
524 // Sort by position
525 results.sort_by_key(|e| (e.start, e.end));
526
527 Ok(results)
528 }
529
530 fn supported_types(&self) -> Vec<EntityType> {
531 // Union of all backend types
532 let mut types: Vec<EntityType> = Vec::new();
533 for backend in &self.backends {
534 for t in backend.supported_types() {
535 if !types.contains(&t) {
536 types.push(t);
537 }
538 }
539 }
540 types
541 }
542
543 fn is_available(&self) -> bool {
544 // Available if at least one backend is available
545 self.backends.iter().any(|b| b.is_available())
546 }
547
548 fn name(&self) -> &'static str {
549 // Use OnceLock to cache the static string, avoiding repeated memory leaks
550 self.name_static
551 .get_or_init(|| Box::leak(self.name.clone().into_boxed_str()))
552 }
553
554 fn description(&self) -> &'static str {
555 "Ensemble NER: weighted voting across multiple backends"
556 }
557
558 fn capabilities(&self) -> crate::ModelCapabilities {
559 crate::ModelCapabilities {
560 batch_capable: true,
561 streaming_capable: true,
562 ..Default::default()
563 }
564 }
565}
566
567// Implement required traits
568impl crate::NamedEntityCapable for EnsembleNER {}
569
570impl crate::BatchCapable for EnsembleNER {
571 fn optimal_batch_size(&self) -> Option<usize> {
572 Some(8) // Reasonable default for ensemble
573 }
574}
575
576impl crate::StreamingCapable for EnsembleNER {
577 fn recommended_chunk_size(&self) -> usize {
578 8192
579 }
580}
581
582// =============================================================================
583// Weight Learning
584// =============================================================================
585
586/// Training example for weight learning.
587#[derive(Debug, Clone)]
588pub struct WeightTrainingExample {
589 /// Text of the entity
590 pub text: String,
591 /// True entity type (gold label)
592 pub gold_type: EntityType,
593 /// Span start
594 pub start: usize,
595 /// Span end
596 pub end: usize,
597 /// Predictions from each backend: (backend_name, predicted_type, confidence)
598 pub predictions: Vec<(String, EntityType, f64)>,
599}
600
/// Prediction counts for one backend, as accumulated by `WeightLearner`.
#[derive(Debug, Clone, Default)]
pub struct BackendStats {
    /// Number of counted predictions whose type matched the gold type.
    pub correct: usize,
    /// Number of predictions counted.
    pub total: usize,
    /// Per-type counts keyed by entity-type label. Each value is `(correct, total)`.
    pub per_type: HashMap<String, (usize, usize)>,
}

impl BackendStats {
    /// `correct / total`, or `0.0` when `total` is zero.
    fn ratio(correct: usize, total: usize) -> f64 {
        if total == 0 {
            0.0
        } else {
            correct as f64 / total as f64
        }
    }

    /// Overall (unsmoothed) precision; `0.0` when nothing has been counted.
    pub fn precision(&self) -> f64 {
        Self::ratio(self.correct, self.total)
    }

    /// Unsmoothed precision for one type label.
    ///
    /// Returns `0.0` when the label has no entry or its total is zero.
    pub fn type_precision(&self, entity_type: &str) -> f64 {
        self.per_type
            .get(entity_type)
            .map_or(0.0, |&(correct, total)| Self::ratio(correct, total))
    }
}
635
636/// Weight learner for EnsembleNER.
637///
638/// Learns optimal backend weights from evaluation data.
639///
640/// # Example
641///
642/// ```rust,ignore
643/// use anno::backends::ensemble::{EnsembleNER, WeightLearner};
644///
645/// let mut learner = WeightLearner::new();
646///
647/// // Add training examples from gold data
648/// for (text, gold_entities) in gold_data {
649/// learner.add_examples(&text, &gold_entities, &backends);
650/// }
651///
652/// // Learn weights
653/// let learned_weights = learner.learn_weights();
654///
655/// // Create ensemble with learned weights
656/// let ensemble = EnsembleNER::new().with_weights(learned_weights);
657/// ```
658pub struct WeightLearner {
659 /// Per-backend statistics
660 backend_stats: HashMap<String, BackendStats>,
661 /// Smoothing factor for precision (avoid division by zero / overfitting)
662 smoothing: f64,
663}
664
665impl Default for WeightLearner {
666 fn default() -> Self {
667 Self::new()
668 }
669}
670
671impl WeightLearner {
672 /// Create a new weight learner.
673 #[must_use]
674 pub fn new() -> Self {
675 Self {
676 backend_stats: HashMap::new(),
677 smoothing: 1.0, // Laplace smoothing
678 }
679 }
680
681 /// Set smoothing factor.
682 #[must_use]
683 pub fn with_smoothing(mut self, smoothing: f64) -> Self {
684 self.smoothing = smoothing;
685 self
686 }
687
688 /// Add a training example.
689 pub fn add_example(&mut self, example: &WeightTrainingExample) {
690 for (backend_name, predicted_type, _confidence) in &example.predictions {
691 let stats = self.backend_stats.entry(backend_name.clone()).or_default();
692
693 stats.total += 1;
694 let correct = *predicted_type == example.gold_type;
695 if correct {
696 stats.correct += 1;
697 }
698
699 // Per-type stats
700 let type_key = example.gold_type.as_label().to_string();
701 let type_stats = stats.per_type.entry(type_key).or_insert((0, 0));
702 type_stats.1 += 1;
703 if correct {
704 type_stats.0 += 1;
705 }
706 }
707 }
708
709 /// Add examples from gold entities and backend predictions.
710 ///
711 /// Runs each backend on the text and compares to gold entities.
712 pub fn add_from_backends(
713 &mut self,
714 text: &str,
715 gold_entities: &[Entity],
716 backends: &[(&str, &dyn Model)],
717 ) {
718 // Get predictions from each backend
719 let mut backend_preds: HashMap<String, Vec<Entity>> = HashMap::new();
720 for (name, backend) in backends {
721 if let Ok(entities) = backend.extract_entities(text, None) {
722 backend_preds.insert(name.to_string(), entities);
723 }
724 }
725
726 // Match predictions to gold entities
727 for gold in gold_entities {
728 let mut example = WeightTrainingExample {
729 text: gold.text.clone(),
730 gold_type: gold.entity_type.clone(),
731 start: gold.start,
732 end: gold.end,
733 predictions: Vec::new(),
734 };
735
736 for (backend_name, entities) in &backend_preds {
737 // Find matching prediction (same span)
738 for pred in entities {
739 if pred.start == gold.start && pred.end == gold.end {
740 example.predictions.push((
741 backend_name.clone(),
742 pred.entity_type.clone(),
743 pred.confidence,
744 ));
745 break;
746 }
747 }
748 }
749
750 if !example.predictions.is_empty() {
751 self.add_example(&example);
752 }
753 }
754 }
755
756 /// Learn optimal weights from accumulated statistics.
757 ///
758 /// Uses precision-based weighting with Laplace smoothing.
759 pub fn learn_weights(&self) -> HashMap<String, BackendWeight> {
760 let mut weights = HashMap::new();
761
762 for (backend_name, stats) in &self.backend_stats {
763 // Smoothed precision: (correct + smoothing) / (total + 2*smoothing)
764 let smoothed_precision = (stats.correct as f64 + self.smoothing)
765 / (stats.total as f64 + 2.0 * self.smoothing);
766
767 // Per-type weights
768 let mut type_weights = TypeWeights::default();
769 for (type_key, (correct, total)) in &stats.per_type {
770 let type_precision =
771 (*correct as f64 + self.smoothing) / (*total as f64 + 2.0 * self.smoothing);
772
773 match type_key.as_str() {
774 "PER" | "PERSON" => type_weights.person = type_precision,
775 "ORG" | "ORGANIZATION" => type_weights.organization = type_precision,
776 "LOC" | "LOCATION" | "GPE" => type_weights.location = type_precision,
777 "DATE" => type_weights.date = type_precision,
778 "MONEY" => type_weights.money = type_precision,
779 _ => type_weights.other = type_precision,
780 }
781 }
782
783 weights.insert(
784 backend_name.clone(),
785 BackendWeight {
786 overall: smoothed_precision,
787 per_type: Some(type_weights),
788 },
789 );
790 }
791
792 weights
793 }
794
795 /// Get statistics for a backend.
796 pub fn get_stats(&self, backend_name: &str) -> Option<&BackendStats> {
797 self.backend_stats.get(backend_name)
798 }
799
800 /// Get all backend names.
801 pub fn backend_names(&self) -> Vec<&String> {
802 self.backend_stats.keys().collect()
803 }
804}
805
806// =============================================================================
807// Tests
808// =============================================================================
809
810#[cfg(test)]
811mod tests;