Skip to main content

anno/backends/
inference.rs

1//! Inference abstractions shared across `anno` backends.
2//!
3//! This module is mostly **plumbing**: common traits, data shapes, and small
4//! utilities used by multiple NER / IE backends (including “fixed-label” and
5//! “open/zero-shot” styles).
6//!
7//! Some of the terminology and design choices correspond to well-known
8//! architectures in the NER/IE literature, but the code here should be treated
9//! as an implementation substrate, not a verbatim reproduction of any single
10//! paper’s experiment section.
11//!
12//! ## Paper pointers (context only)
13//!
14//! - GLiNER: arXiv:2311.08526
15//! - UniversalNER: arXiv:2308.03279
16//! - W2NER: arXiv:2112.10070
17//! - ModernBERT: arXiv:2412.13663
18
19use std::borrow::Cow;
20use std::collections::HashMap;
21
22use crate::{Entity, EntityType};
23use anno_core::{RaggedBatch, Relation, SpanCandidate};
24
25// =============================================================================
26// Modality Types
27// =============================================================================
28
29/// Input modality for the encoder.
30///
31/// Supports text, images, and hybrid (OCR + visual) inputs.
32/// This enables ColPali-style visual document understanding.
33#[derive(Debug, Clone)]
34pub enum ModalityInput<'a> {
35    /// Plain text input
36    Text(Cow<'a, str>),
37    /// Image bytes (PNG/JPEG)
38    Image {
39        /// Raw image bytes
40        data: Cow<'a, [u8]>,
41        /// Image format hint
42        format: ImageFormat,
43    },
44    /// Hybrid: text with visual location (e.g., OCR result)
45    Hybrid {
46        /// Extracted text
47        text: Cow<'a, str>,
48        /// Visual bounding boxes for each token/word
49        visual_positions: Vec<VisualPosition>,
50    },
51}
52
/// Hint telling the decoder how a block of image bytes is encoded.
///
/// [`ImageFormat::Png`] is the [`Default`] value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ImageFormat {
    /// PNG-encoded image.
    #[default]
    Png,
    /// JPEG-encoded image.
    Jpeg,
    /// WebP-encoded image.
    Webp,
    /// Format is unknown; meant for auto-detection.
    Unknown,
}
66
/// Where a single text token sits inside an image.
///
/// The position is given as `x`/`y` plus `width`/`height`, each normalized to
/// the 0.0-1.0 range, together with the page it appears on.
#[derive(Debug, Clone, Copy)]
pub struct VisualPosition {
    /// Index of the token/word this position belongs to.
    pub token_idx: u32,
    /// Normalized x coordinate (0.0-1.0).
    pub x: f32,
    /// Normalized y coordinate (0.0-1.0).
    pub y: f32,
    /// Normalized width (0.0-1.0).
    pub width: f32,
    /// Normalized height (0.0-1.0).
    pub height: f32,
    /// Page number, for multi-page documents.
    pub page: u32,
}
83
84// =============================================================================
85// Semantic Registry (Pre-computed Label Embeddings)
86// =============================================================================
87
88/// A frozen, pre-computed registry of entity and relation types.
89///
90/// # Motivation
91///
92/// The `SemanticRegistry` is the "knowledge base" of a bi-encoder NER system.
93/// It stores pre-computed embeddings for all entity/relation types, enabling:
94///
95/// - **Zero-shot**: Add new types without retraining
96/// - **Speed**: Encode labels once, reuse forever
97/// - **Semantics**: Rich descriptions enable better matching
98///
99/// # Architecture
100///
101/// ```text
102/// ┌────────────────────────────────────────────────────────────────┐
103/// │                     SemanticRegistry                           │
104/// ├────────────────────────────────────────────────────────────────┤
105/// │  labels: [                                                     │
106/// │    { slug: "person", description: "named individual human" }   │
107/// │    { slug: "organization", description: "company or group" }   │
108/// │    { slug: "CEO_OF", description: "leads organization" }       │
109/// │  ]                                                             │
110/// │                                                                │
111/// │  embeddings: [768 floats] [768 floats] [768 floats]            │
112/// │              └────┬────┘  └────┬────┘  └────┬────┘             │
113/// │                   ▲            ▲            ▲                  │
114/// │              person        organization   CEO_OF               │
115/// │                                                                │
116/// │  label_index: { "person" → 0, "organization" → 1, ... }        │
117/// └────────────────────────────────────────────────────────────────┘
118/// ```
119///
120/// # Bi-Encoder Efficiency
121///
122/// The key insight from GLiNER is that label embeddings can be computed once
123/// and reused across all inference requests:
124///
125/// | Approach | Cost per query | Benefit |
126/// |----------|----------------|---------|
127/// | Cross-encoder | O(N × L) | Better accuracy |
128/// | Bi-encoder | O(N) + O(L) | Much faster, labels cached |
129///
130/// # Example
131///
132/// ```ignore
133/// use anno::SemanticRegistry;
134///
135/// // Build registry (expensive, do once at startup)
136/// let registry = SemanticRegistry::builder()
137///     .add_entity("person", "A named individual human being")
138///     .add_entity("organization", "A company, institution, or organized group")
139///     .add_relation("CEO_OF", "Chief executive officer of an organization")
140///     .build(&label_encoder)?;
141///
142/// // Use registry for all inference (cheap, cached embeddings)
143/// for document in documents {
144///     let entities = engine.extract(&document, &registry)?;
145/// }
146/// ```
147///
148/// # Adding Custom Types
149///
150/// ```ignore
151/// // Domain-specific medical entities
152/// let medical_registry = SemanticRegistry::builder()
153///     .add_entity("drug", "A pharmaceutical compound or medication")
154///     .add_entity("disease", "A medical condition or illness")
155///     .add_entity("gene", "A genetic sequence encoding a protein")
156///     .add_relation("TREATS", "Drug is used to treat disease")
157///     .add_relation("CAUSES", "Factor causes or leads to condition")
158///     .build(&label_encoder)?;
159/// ```
160#[derive(Debug, Clone)]
161pub struct SemanticRegistry {
162    /// Pre-computed embeddings for all labels.
163    /// Shape: [num_labels, hidden_dim]
164    /// Stored as flattened f32 for simplicity without tensor deps.
165    pub embeddings: Vec<f32>,
166    /// Hidden dimension of embeddings
167    pub hidden_dim: usize,
168    /// Metadata for each label (index corresponds to embedding row)
169    pub labels: Vec<LabelDefinition>,
170    /// Index mapping from label slug to embedding row
171    pub label_index: HashMap<String, usize>,
172}
173
174/// Definition of a semantic label (entity type or relation type).
175#[derive(Debug, Clone)]
176pub struct LabelDefinition {
177    /// Unique identifier (e.g., "person", "CEO_OF")
178    pub slug: String,
179    /// Human-readable description (used for encoding)
180    pub description: String,
181    /// Category: Entity or Relation
182    pub category: LabelCategory,
183    /// Expected source modality
184    pub modality: ModalityHint,
185    /// Minimum confidence threshold for this label
186    pub threshold: f32,
187}
188
/// The kind of thing a semantic label stands for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LabelCategory {
    /// A named entity (Person, Organization, Location, etc.).
    Entity,
    /// A relation between entities (CEO_OF, LOCATED_IN, etc.).
    Relation,
    /// An attribute of an entity (date of birth, revenue, etc.).
    Attribute,
}
199
/// Indicates which input modality a label applies to.
///
/// [`ModalityHint::TextOnly`] is the [`Default`] value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ModalityHint {
    /// Applies to text only (most entity types).
    #[default]
    TextOnly,
    /// Applies to visual input only (e.g., logos, signatures).
    VisualOnly,
    /// Applies to both text and visual input.
    Any,
}
211
212impl SemanticRegistry {
213    /// Create a builder for constructing a registry.
214    pub fn builder() -> SemanticRegistryBuilder {
215        SemanticRegistryBuilder::new()
216    }
217
218    /// Get number of labels in the registry.
219    pub fn len(&self) -> usize {
220        self.labels.len()
221    }
222
223    /// Check if registry is empty.
224    pub fn is_empty(&self) -> bool {
225        self.labels.is_empty()
226    }
227
228    /// Get embedding for a label by slug.
229    pub fn get_embedding(&self, slug: &str) -> Option<&[f32]> {
230        let idx = self.label_index.get(slug)?;
231        let start = idx * self.hidden_dim;
232        let end = start + self.hidden_dim;
233        if end <= self.embeddings.len() {
234            Some(&self.embeddings[start..end])
235        } else {
236            None
237        }
238    }
239
240    /// Get all entity labels (excluding relations).
241    pub fn entity_labels(&self) -> impl Iterator<Item = &LabelDefinition> {
242        self.labels
243            .iter()
244            .filter(|l| l.category == LabelCategory::Entity)
245    }
246
247    /// Get all relation labels.
248    pub fn relation_labels(&self) -> impl Iterator<Item = &LabelDefinition> {
249        self.labels
250            .iter()
251            .filter(|l| l.category == LabelCategory::Relation)
252    }
253
254    /// Create a standard NER registry with common entity types.
255    pub fn standard_ner(hidden_dim: usize) -> Self {
256        // Placeholder embeddings - in real use, these would be encoder outputs
257        let labels = vec![
258            LabelDefinition {
259                slug: "person".into(),
260                description: "A named individual human being".into(),
261                category: LabelCategory::Entity,
262                modality: ModalityHint::TextOnly,
263                threshold: 0.5,
264            },
265            LabelDefinition {
266                slug: "organization".into(),
267                description: "A company, institution, agency, or other group".into(),
268                category: LabelCategory::Entity,
269                modality: ModalityHint::TextOnly,
270                threshold: 0.5,
271            },
272            LabelDefinition {
273                slug: "location".into(),
274                description: "A geographical place, city, country, or region".into(),
275                category: LabelCategory::Entity,
276                modality: ModalityHint::TextOnly,
277                threshold: 0.5,
278            },
279            LabelDefinition {
280                slug: "date".into(),
281                description: "A calendar date or time expression".into(),
282                category: LabelCategory::Entity,
283                modality: ModalityHint::TextOnly,
284                threshold: 0.5,
285            },
286            LabelDefinition {
287                slug: "money".into(),
288                description: "A monetary amount with currency".into(),
289                category: LabelCategory::Entity,
290                modality: ModalityHint::TextOnly,
291                threshold: 0.5,
292            },
293        ];
294
295        let num_labels = labels.len();
296        let label_index: HashMap<String, usize> = labels
297            .iter()
298            .enumerate()
299            .map(|(i, l)| (l.slug.clone(), i))
300            .collect();
301
302        // Initialize with zeros (placeholder)
303        let embeddings = vec![0.0f32; num_labels * hidden_dim];
304
305        Self {
306            embeddings,
307            hidden_dim,
308            labels,
309            label_index,
310        }
311    }
312}
313
314/// Builder for SemanticRegistry.
315#[derive(Debug, Default)]
316pub struct SemanticRegistryBuilder {
317    labels: Vec<LabelDefinition>,
318}
319
320impl SemanticRegistryBuilder {
321    /// Create a new builder.
322    pub fn new() -> Self {
323        Self::default()
324    }
325
326    /// Add an entity type.
327    pub fn add_entity(mut self, slug: &str, description: &str) -> Self {
328        self.labels.push(LabelDefinition {
329            slug: slug.into(),
330            description: description.into(),
331            category: LabelCategory::Entity,
332            modality: ModalityHint::TextOnly,
333            threshold: 0.5,
334        });
335        self
336    }
337
338    /// Add a relation type.
339    pub fn add_relation(mut self, slug: &str, description: &str) -> Self {
340        self.labels.push(LabelDefinition {
341            slug: slug.into(),
342            description: description.into(),
343            category: LabelCategory::Relation,
344            modality: ModalityHint::TextOnly,
345            threshold: 0.5,
346        });
347        self
348    }
349
350    /// Add a label with full configuration.
351    pub fn add_label(mut self, label: LabelDefinition) -> Self {
352        self.labels.push(label);
353        self
354    }
355
356    /// Build the registry (placeholder - real impl needs encoder).
357    pub fn build_placeholder(self, hidden_dim: usize) -> SemanticRegistry {
358        let num_labels = self.labels.len();
359        let label_index: HashMap<String, usize> = self
360            .labels
361            .iter()
362            .enumerate()
363            .map(|(i, l)| (l.slug.clone(), i))
364            .collect();
365
366        SemanticRegistry {
367            embeddings: vec![0.0f32; num_labels * hidden_dim],
368            hidden_dim,
369            labels: self.labels,
370            label_index,
371        }
372    }
373}
374
375// =============================================================================
376// Core Encoder Traits (GLiNER/ModernBERT Alignment)
377// =============================================================================
378
379/// Text encoder trait for transformer-based encoders.
380///
381/// # Motivation
382///
383/// Modern NER systems require converting raw text into dense vector representations
384/// that capture semantic meaning. This trait abstracts the encoding step, allowing
385/// different transformer architectures to be used interchangeably.
386///
387/// # Supported Architectures
388///
389/// | Architecture | Context | Key Features | Speed |
390/// |--------------|---------|--------------|-------|
391/// | ModernBERT   | 8,192   | RoPE, GeGLU, unpadded inference | 3x faster |
392/// | DeBERTaV3    | 512     | Disentangled attention | Baseline |
393/// | BERT/RoBERTa | 512     | Classic, widely available | Baseline |
394///
395/// # Research Alignment (ModernBERT, Dec 2024)
396///
397/// From ModernBERT paper (arXiv:2412.13663):
398/// > "Pareto improvements to BERT... encoder-only models offer great
399/// > performance-size tradeoff for retrieval and classification."
400///
401/// Key innovations:
402/// - **Alternating Attention**: Global attention every 3 layers, local (128-token
403///   window) elsewhere. Reduces complexity for long sequences.
404/// - **Unpadding**: "ModernBERT unpads inputs *before* the token embedding layer
405///   and optionally repads model outputs leading to a 10-to-20 percent
406///   performance improvement over previous methods."
407/// - **RoPE**: Rotary positional embeddings enable extrapolation to longer sequences.
408/// - **GeGLU**: Gated activation function improves over GELU.
409///
410/// # Example
411///
412/// ```ignore
413/// use anno::TextEncoder;
414///
415/// fn process_document(encoder: &dyn TextEncoder, text: &str) {
416///     let output = encoder.encode(text).unwrap();
417///     println!("Encoded {} tokens into {} dimensions",
418///              output.num_tokens, output.hidden_dim);
419///
420///     // Token offsets map back to character positions
421///     for (i, (start, end)) in output.token_offsets.iter().enumerate() {
422///         println!("Token {}: chars {}..{}", i, start, end);
423///     }
424/// }
425/// ```
426pub trait TextEncoder: Send + Sync {
427    /// Encode text into token embeddings.
428    ///
429    /// # Arguments
430    /// * `text` - Input text to encode
431    ///
432    /// # Returns
433    /// * Token embeddings as flattened [num_tokens, hidden_dim]
434    /// * Attention mask indicating valid tokens
435    fn encode(&self, text: &str) -> crate::Result<EncoderOutput>;
436
437    /// Encode a batch of texts.
438    ///
439    /// # Arguments
440    /// * `texts` - Batch of input texts
441    ///
442    /// # Returns
443    /// * RaggedBatch containing all embeddings with document boundaries
444    fn encode_batch(&self, texts: &[&str]) -> crate::Result<(Vec<f32>, RaggedBatch)>;
445
446    /// Get the hidden dimension of the encoder.
447    fn hidden_dim(&self) -> usize;
448
449    /// Get the maximum sequence length.
450    fn max_length(&self) -> usize;
451
452    /// Get the encoder architecture name.
453    fn architecture(&self) -> &'static str;
454}
455
/// Result of encoding a text.
#[derive(Debug, Clone)]
pub struct EncoderOutput {
    /// Token embeddings, laid out as [num_tokens, hidden_dim].
    pub embeddings: Vec<f32>,
    /// How many tokens were produced.
    pub num_tokens: usize,
    /// Hidden dimension of each token embedding.
    pub hidden_dim: usize,
    /// Token-to-character mapping, used for span recovery.
    pub token_offsets: Vec<(usize, usize)>,
}
468
469/// Label encoder trait for encoding entity type descriptions.
470///
471/// # Motivation
472///
473/// Zero-shot NER works by encoding entity type *descriptions* into the same
474/// vector space as text spans. Instead of training separate classifiers for
475/// each entity type, we compute similarity between spans and label embeddings.
476///
477/// This enables:
478/// - **Unlimited entity types** at inference (no retraining needed)
479/// - **Faster inference** when labels are pre-computed
480/// - **Better generalization** to unseen entity types via semantic similarity
481///
482/// # Research Alignment
483///
484/// From GLiNER bi-encoder (knowledgator/modern-gliner-bi-base-v1.0):
485/// > "textual encoder is ModernBERT-base and entity label encoder is
486/// > sentence transformer - BGE-small-en."
487///
488/// # Example
489///
490/// ```ignore
491/// use anno::LabelEncoder;
492///
493/// fn setup_custom_types(encoder: &dyn LabelEncoder) {
494///     // Encode rich descriptions for better matching
495///     let labels = &[
496///         "a named individual human being",
497///         "a company, institution, or organized group",
498///         "a geographical location, city, country, or region",
499///     ];
500///
501///     let embeddings = encoder.encode_labels(labels).unwrap();
502///     // Store embeddings in SemanticRegistry for fast lookup
503/// }
504/// ```
505pub trait LabelEncoder: Send + Sync {
506    /// Encode a single label description.
507    ///
508    /// # Arguments
509    /// * `label` - Label description (e.g., "a named individual human being")
510    fn encode_label(&self, label: &str) -> crate::Result<Vec<f32>>;
511
512    /// Encode multiple labels.
513    ///
514    /// # Arguments
515    /// * `labels` - Label descriptions
516    ///
517    /// # Returns
518    /// Flattened embeddings: [num_labels, hidden_dim]
519    fn encode_labels(&self, labels: &[&str]) -> crate::Result<Vec<f32>>;
520
521    /// Get the hidden dimension.
522    fn hidden_dim(&self) -> usize;
523}
524
525/// Bi-encoder architecture combining text and label encoders.
526///
527/// # Motivation
528///
529/// The bi-encoder architecture treats NER as a **matching problem** rather than
530/// a classification problem. It encodes text spans and entity labels separately,
531/// then computes similarity scores to determine matches.
532///
533/// ```text
534/// ┌─────────────────┐         ┌─────────────────┐
535/// │   Text Input    │         │  Label Desc.    │
536/// │ "Steve Jobs"    │         │ "person name"   │
537/// └────────┬────────┘         └────────┬────────┘
538///          │                           │
539///          ▼                           ▼
540/// ┌─────────────────┐         ┌─────────────────┐
541/// │  TextEncoder    │         │  LabelEncoder   │
542/// │  (ModernBERT)   │         │  (BGE-small)    │
543/// └────────┬────────┘         └────────┬────────┘
544///          │                           │
545///          ▼                           ▼
546/// ┌─────────────────┐         ┌─────────────────┐
547/// │ Span Embedding  │◄───────►│ Label Embedding │
548/// │   [768]         │ cosine  │   [768]         │
549/// └─────────────────┘ sim     └─────────────────┘
550///                      │
551///                      ▼
552///               Score: 0.92
553/// ```
554///
555/// # Trade-offs
556///
557/// | Aspect | Bi-Encoder | Uni-Encoder |
558/// |--------|------------|-------------|
559/// | Entity types | Unlimited | Fixed at training |
560/// | Inference speed | Faster (pre-compute labels) | Slower |
561/// | Disambiguation | Harder (no label interaction) | Better |
562/// | Generalization | Better to new types | Limited |
563///
564/// # Research Alignment
565///
566/// From GLiNER: "GLiNER frames NER as a matching problem, comparing candidate
567/// spans with entity type embeddings."
568///
569/// From knowledgator: "Bi-encoder architecture brings several advantages...
570/// unlimited entities, faster inference, better generalization."
571///
572/// Drawback: "Lack of inter-label interactions that make it hard to
573/// disambiguate semantically similar but contextually different entities."
574///
575/// # Example
576///
577/// ```ignore
578/// use anno::BiEncoder;
579///
580/// fn extract_custom_entities(bi_enc: &dyn BiEncoder, text: &str) {
581///     let labels = &["software company", "hardware manufacturer", "person"];
582///     let scores = bi_enc.encode_and_match(text, labels, 8).unwrap();
583///
584///     for s in scores.iter().filter(|s| s.score > 0.5) {
585///         println!("Found '{}' as type {} (score: {:.2})",
586///                  &text[s.start..s.end], labels[s.label_idx], s.score);
587///     }
588/// }
589/// ```
590pub trait BiEncoder: Send + Sync {
591    /// Get the text encoder.
592    fn text_encoder(&self) -> &dyn TextEncoder;
593
594    /// Get the label encoder.
595    fn label_encoder(&self) -> &dyn LabelEncoder;
596
597    /// Encode text and labels, compute span-label similarities.
598    ///
599    /// # Arguments
600    /// * `text` - Input text
601    /// * `labels` - Entity type descriptions
602    /// * `max_span_width` - Maximum span width to consider
603    ///
604    /// # Returns
605    /// Similarity scores for each (span, label) pair
606    fn encode_and_match(
607        &self,
608        text: &str,
609        labels: &[&str],
610        max_span_width: usize,
611    ) -> crate::Result<Vec<SpanLabelScore>>;
612}
613
/// How well one span matches one label.
#[derive(Debug, Clone)]
pub struct SpanLabelScore {
    /// Start of the span (character offset).
    pub start: usize,
    /// End of the span (character offset, exclusive).
    pub end: usize,
    /// Index of the label that was matched.
    pub label_idx: usize,
    /// Similarity score (0.0 - 1.0).
    pub score: f32,
}
626
627// =============================================================================
628// Zero-Shot NER Trait
629// =============================================================================
630
631/// Zero-shot NER for open entity types.
632///
633/// # Motivation
634///
635/// Traditional NER models are trained on fixed taxonomies (PER, ORG, LOC, etc.)
636/// and cannot extract new entity types without retraining. Zero-shot NER solves
637/// this by allowing **arbitrary entity types at inference time**.
638///
639/// Instead of asking "is this a PERSON?", zero-shot NER asks "does this text
640/// span match the description 'a named individual human being'?"
641///
642/// # Use Cases
643///
644/// - **Domain adaptation**: Extract "gene names" or "legal citations" without
645///   training data
646/// - **Custom taxonomies**: Use your own entity hierarchy
647/// - **Rapid prototyping**: Test new entity types before investing in annotation
648///
649/// # Research Alignment
650///
651/// From GLiNER (arXiv:2311.08526):
652/// > "NER model capable of identifying any entity type using a bidirectional
653/// > transformer encoder... provides a practical alternative to traditional
654/// > NER models, which are limited to predefined entity types."
655///
656/// From UniversalNER (arXiv:2308.03279):
657/// > "Large language models demonstrate remarkable generalizability, such as
658/// > understanding arbitrary entities and relations."
659///
660/// # Example
661///
662/// ```ignore
663/// use anno::ZeroShotNER;
664///
665/// fn extract_medical_entities(ner: &dyn ZeroShotNER, clinical_note: &str) {
666///     // Define custom medical entity types at runtime
667///     let types = &["drug name", "disease", "symptom", "dosage"];
668///
669///     let entities = ner.extract_with_types(clinical_note, types, 0.5).unwrap();
670///     for e in entities {
671///         println!("{}: {} (conf: {:.2})", e.entity_type, e.text, e.confidence);
672///     }
673/// }
674///
675/// fn extract_with_descriptions(ner: &dyn ZeroShotNER, text: &str) {
676///     // Even richer: use natural language descriptions
677///     let descriptions = &[
678///         "a medication or pharmaceutical compound",
679///         "a medical condition or illness",
680///         "a physical sensation indicating illness",
681///     ];
682///
683///     let entities = ner.extract_with_descriptions(text, descriptions, 0.5).unwrap();
684/// }
685/// ```
686pub trait ZeroShotNER: Send + Sync {
687    /// Extract entities with custom types.
688    ///
689    /// # Arguments
690    /// * `text` - Input text
691    /// * `entity_types` - Entity type descriptions (arbitrary text, not fixed vocabulary)
692    ///   - Encoded as text embeddings via bi-encoder (semantic matching, not exact string match)
693    ///   - Any string works: `"disease"`, `"pharmaceutical compound"`, `"19th century French philosopher"`
694    ///   - **Replaces default types completely** - model only extracts the specified types
695    ///   - To include defaults, pass them explicitly: `&["person", "organization", "disease"]`
696    /// * `threshold` - Confidence threshold (0.0 - 1.0)
697    ///
698    /// # Returns
699    /// Entities with their matched types
700    ///
701    /// # Behavior
702    ///
703    /// - **Arbitrary text**: Type hints are not fixed vocabulary. They're encoded as embeddings,
704    ///   so semantic similarity determines matches (not exact string matching).
705    /// - **Replace, don't union**: This method completely replaces default entity types.
706    ///   The model only extracts the types you specify.
707    /// - **Semantic matching**: Uses cosine similarity between text span embeddings and label embeddings.
708    fn extract_with_types(
709        &self,
710        text: &str,
711        entity_types: &[&str],
712        threshold: f32,
713    ) -> crate::Result<Vec<Entity>>;
714
715    /// Extract entities with natural language descriptions.
716    ///
717    /// # Arguments
718    /// * `text` - Input text
719    /// * `descriptions` - Natural language descriptions of what to extract
720    ///   - Encoded as text embeddings (same as `extract_with_types`)
721    ///   - Examples: `"companies headquartered in Europe"`, `"diseases affecting the heart"`
722    ///   - **Replaces default types completely** - model only extracts the specified descriptions
723    /// * `threshold` - Confidence threshold
724    ///
725    /// # Behavior
726    ///
727    /// Same as `extract_with_types`, but accepts natural language descriptions instead of
728    /// short type labels. Both methods encode labels as embeddings and use semantic matching.
729    fn extract_with_descriptions(
730        &self,
731        text: &str,
732        descriptions: &[&str],
733        threshold: f32,
734    ) -> crate::Result<Vec<Entity>>;
735
736    /// Get default entity types for this model.
737    ///
738    /// Returns the entity types used by `extract_entities()` (via `Model` trait).
739    /// Useful for extending defaults: combine with custom types and pass to `extract_with_types()`.
740    ///
741    /// # Example: Extending defaults
742    ///
743    /// ```ignore
744    /// use anno::ZeroShotNER;
745    ///
746    /// let ner: &dyn ZeroShotNER = ...;
747    /// let defaults = ner.default_types();
748    ///
749    /// // Combine defaults with custom types
750    /// let mut types: Vec<&str> = defaults.to_vec();
751    /// types.extend(&["disease", "medication"]);
752    ///
753    /// let entities = ner.extract_with_types(text, &types, 0.5)?;
754    /// ```
755    fn default_types(&self) -> &[&'static str];
756}
757
758// =============================================================================
759// Relation Extractor Trait
760// =============================================================================
761
762/// Joint entity and relation extraction.
763///
764/// # Motivation
765///
766/// Real-world information extraction often requires both entities AND their
767/// relationships. For example, extracting "Steve Jobs" and "Apple" is useful,
768/// but knowing "Steve Jobs FOUNDED Apple" is far more valuable.
769///
770/// Joint extraction (vs pipeline) is preferred because:
771/// - **Error propagation**: Pipeline errors compound (bad entities → bad relations)
772/// - **Shared context**: Entities and relations inform each other
773/// - **Efficiency**: Single forward pass instead of two
774///
775/// # Architecture
776///
777/// ```text
778/// Input: "Steve Jobs founded Apple in 1976."
779///                │
780///                ▼
781/// ┌──────────────────────────────────┐
782/// │     Shared Encoder (BERT)        │
783/// └──────────────────────────────────┘
784///                │
785///         ┌──────┴──────┐
786///         ▼             ▼
787/// ┌───────────────┐  ┌───────────────┐
788/// │ Entity Head   │  │ Relation Head │
789/// │ (span class.) │  │ (pair class.) │
790/// └───────┬───────┘  └───────┬───────┘
791///         │                  │
792///         ▼                  ▼
793/// Entities:              Relations:
794/// - Steve Jobs [PER]     - (Steve Jobs, FOUNDED, Apple)
795/// - Apple [ORG]          - (Apple, FOUNDED_IN, 1976)
796/// - 1976 [DATE]
797/// ```
798///
799/// # Research Alignment
800///
801/// From GLiNER multi-task (arXiv:2406.12925):
802/// > "Generalist Lightweight Model for Various Information Extraction Tasks...
803/// > joint entity and relation extraction."
804///
805/// From W2NER (arXiv:2112.10070):
806/// > "Unified Named Entity Recognition as Word-Word Relation Classification...
807/// > handles flat, overlapped, and discontinuous NER."
808///
809/// # Example
810///
811/// ```ignore
812/// use anno::RelationExtractor;
813///
814/// fn build_knowledge_graph(extractor: &dyn RelationExtractor, text: &str) {
815///     let entity_types = &["person", "organization", "date"];
816///     let relation_types = &["founded", "works_for", "acquired"];
817///
818///     let result = extractor.extract_with_relations(
819///         text, entity_types, relation_types, 0.5
820///     ).unwrap();
821///
822///     // Build graph nodes from entities
823///     for e in &result.entities {
824///         println!("Node: {} ({})", e.text, e.entity_type);
825///     }
826///
827///     // Build graph edges from relations
828///     for r in &result.relations {
829///         let head = &result.entities[r.head_idx];
830///         let tail = &result.entities[r.tail_idx];
831///         println!("Edge: {} --[{}]--> {}", head.text, r.relation_type, tail.text);
832///     }
833/// }
834/// ```
pub trait RelationExtractor: Send + Sync {
    /// Extract entities and the relations between them in a single call.
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `entity_types` - Entity types to extract
    /// * `relation_types` - Relation types to extract
    /// * `threshold` - Confidence threshold; whether it gates entities,
    ///   relations, or both is left to the implementor
    ///
    /// # Returns
    /// An [`ExtractionWithRelations`]: the extracted entities plus relation
    /// triples whose `head_idx` / `tail_idx` are indices into that entity list.
    ///
    /// # Errors
    /// Failure conditions are implementation-defined and are reported through
    /// `crate::Result`.
    fn extract_with_relations(
        &self,
        text: &str,
        entity_types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> crate::Result<ExtractionWithRelations>;
}
854
/// Output from joint entity-relation extraction.
///
/// The derived `Default` is an empty result (no entities, no relations).
#[derive(Debug, Clone, Default)]
pub struct ExtractionWithRelations {
    /// Extracted entities; relation triples refer to these by position.
    pub entities: Vec<Entity>,
    /// Relations between entities (each triple's `head_idx` / `tail_idx` is an
    /// index into `entities`)
    pub relations: Vec<RelationTriple>,
}
863
/// A directed, typed relation linking two entities.
///
/// The two indices refer to positions in the entity list the triple was
/// produced with (see `ExtractionWithRelations::entities`); the triple does not
/// own or borrow the entities themselves.
///
/// Implements `PartialEq` so triples can be compared directly (e.g. in tests).
/// `Eq` is intentionally not derived because `confidence` is an `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct RelationTriple {
    /// Index of head entity in entities vec
    pub head_idx: usize,
    /// Index of tail entity in entities vec
    pub tail_idx: usize,
    /// Relation type label (e.g. "founded")
    pub relation_type: String,
    /// Confidence score assigned by the extractor
    pub confidence: f32,
}
876
877// =============================================================================
878// Discontinuous Entity Support (W2NER Research)
879// =============================================================================
880
881/// Support for discontinuous entity spans.
882///
883/// # Motivation
884///
885/// Not all entities are contiguous text spans. In coordination structures,
886/// entities can be **discontinuous** - scattered across non-adjacent positions.
887///
888/// # Examples of Discontinuous Entities
889///
890/// ```text
891/// "New York and Los Angeles airports"
892///  ^^^^^^^^     ^^^^^^^^^^^ ^^^^^^^^
893///  └──────────────────────────┘
894///     LOCATION: "New York airports" (discontinuous!)
895///                ^^^^^^^^^^^ ^^^^^^^^
896///                └───────────┘
897///                LOCATION: "Los Angeles airports" (contiguous)
898///
899/// "protein A and B complex"
900///  ^^^^^^^^^ ^^^ ^^^^^^^^^
901///  └────────────────────┘
902///     PROTEIN: "protein A ... complex" (discontinuous!)
903/// ```
904///
905/// # NER Complexity Hierarchy
906///
907/// | Type | Description | Example |
908/// |------|-------------|---------|
909/// | Flat | Non-overlapping spans | "John works at Google" |
910/// | Nested | Overlapping spans | "\[New \[York\] City\]" |
911/// | Discontinuous | Non-contiguous | "New York and LA \[airports\]" |
912///
913/// # Research Alignment
914///
915/// From W2NER (arXiv:2112.10070):
916/// > "Named entity recognition has been involved with three major types,
917/// > including flat, overlapped (aka. nested), and discontinuous NER...
918/// > we propose a novel architecture to model NER as word-word relation
919/// > classification."
920///
921/// W2NER achieves this by building a **handshaking matrix** where each cell
922/// (i, j) indicates whether tokens i and j are part of the same entity.
923///
924/// # Example
925///
926/// ```ignore
927/// use anno::DiscontinuousNER;
928///
929/// fn extract_complex_entities(ner: &dyn DiscontinuousNER, text: &str) {
930///     let types = &["location", "protein"];
931///     let entities = ner.extract_discontinuous(text, types, 0.5).unwrap();
932///
933///     for e in entities {
934///         if e.is_contiguous() {
935///             println!("Contiguous {}: '{}'", e.entity_type, e.text);
936///         } else {
937///             println!("Discontinuous {}: '{}' spans: {:?}",
938///                      e.entity_type, e.text, e.spans);
939///         }
940///     }
941/// }
942/// ```
pub trait DiscontinuousNER: Send + Sync {
    /// Extract entities, including ones made of several non-adjacent spans.
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `entity_types` - Entity types to extract
    /// * `threshold` - Confidence threshold; exactly how it is applied is left
    ///   to the implementor
    ///
    /// # Returns
    /// One [`DiscontinuousEntity`] per detected entity, each carrying its list
    /// of `(start, end)` spans (a single span for a contiguous entity).
    ///
    /// # Errors
    /// Failure conditions are implementation-defined and are reported through
    /// `crate::Result`.
    fn extract_discontinuous(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<DiscontinuousEntity>>;
}
960
961/// An entity that may span multiple non-contiguous regions.
962#[derive(Debug, Clone)]
963pub struct DiscontinuousEntity {
964    /// The spans that make up this entity (may be non-contiguous)
965    pub spans: Vec<(usize, usize)>,
966    /// Concatenated text from all spans
967    pub text: String,
968    /// Entity type
969    pub entity_type: String,
970    /// Confidence score
971    pub confidence: f32,
972}
973
974impl DiscontinuousEntity {
975    /// Check if this entity is contiguous (single span).
976    pub fn is_contiguous(&self) -> bool {
977        self.spans.len() == 1
978    }
979
980    /// Convert to a standard Entity if contiguous.
981    pub fn to_entity(&self) -> Option<Entity> {
982        if self.is_contiguous() {
983            let (start, end) = self.spans[0];
984            Some(Entity::new(
985                self.text.clone(),
986                EntityType::from_label(&self.entity_type),
987                start,
988                end,
989                self.confidence as f64,
990            ))
991        } else {
992            None
993        }
994    }
995}
996
997// =============================================================================
998// Late Interaction Trait
999// =============================================================================
1000
1001/// The core abstraction for bi-encoder NER scoring.
1002///
1003/// # Motivation
1004///
1005/// "Late interaction" refers to when the text and label representations
1006/// interact: at the very end of the pipeline, after both have been
1007/// independently encoded. This is in contrast to "early fusion" where
1008/// text and labels are concatenated before encoding.
1009///
1010/// ```text
1011///                      Early Fusion             Late Interaction
1012///                      ────────────             ────────────────
1013///
1014/// Encode:          [text + label]              text    label
1015///                        │                       │       │
1016///                        ▼                       ▼       ▼
1017///                    Encoder                  Enc_T   Enc_L
1018///                        │                       │       │
1019///                        ▼                       ▼       ▼
1020///                    Score                   emb_t   emb_l
1021///                                                │       │
1022///                                                └───┬───┘
1023///                                                    ▼
1024///                                              dot(emb_t, emb_l)
1025/// ```
1026///
1027/// Late interaction enables:
1028/// - Pre-computing label embeddings (major speedup)
1029/// - Adding new labels without re-encoding text
1030/// - Parallelizing text and label encoding
1031///
1032/// # The Math
1033///
1034/// ```text
/// Score(span, label) = σ(τ · (span_emb · label_emb))
///
/// where:
///   σ = sigmoid activation
///   · = dot product
///   τ = temperature, applied as a multiplicative scale (higher = sharper)
1041/// ```
1042///
1043/// # Implementations
1044///
1045/// | Interaction | Formula | Speed | Accuracy | Use Case |
1046/// |-------------|---------|-------|----------|----------|
1047/// | DotProduct  | s·l     | Fast  | Good     | General purpose |
1048/// | MaxSim      | max(s·l)| Medium| Better   | Multi-token labels |
1049/// | Bilinear    | s·W·l   | Slow  | Best     | When accuracy critical |
1050///
1051/// # Example
1052///
1053/// ```ignore
1054/// use anno::{LateInteraction, DotProductInteraction};
1055///
1056/// let interaction = DotProductInteraction::with_temperature(20.0);
1057///
1058/// // Span embeddings: 3 spans × 768 dim
1059/// let span_embs: Vec<f32> = get_span_embeddings(&tokens, &candidates);
1060///
1061/// // Label embeddings: 5 labels × 768 dim
1062/// let label_embs: Vec<f32> = registry.all_embeddings();
1063///
1064/// // Compute 3×5 = 15 similarity scores
1065/// let mut scores = interaction.compute_similarity(
1066///     &span_embs, 3, &label_embs, 5, 768
1067/// );
1068/// interaction.apply_sigmoid(&mut scores);
1069///
1070/// // scores[i*5 + j] = similarity between span i and label j
1071/// ```
pub trait LateInteraction: Send + Sync {
    /// Compute similarity scores between span and label embeddings.
    ///
    /// Both embedding arguments are flattened, row-major matrices.
    ///
    /// # Arguments
    /// * `span_embeddings` - Shape: \[num_spans, hidden_dim\]
    /// * `num_spans` - Number of span rows
    /// * `label_embeddings` - Shape: \[num_labels, hidden_dim\]
    /// * `num_labels` - Number of label rows
    /// * `hidden_dim` - Length of every embedding row
    ///
    /// # Returns
    /// Flattened similarity matrix of shape \[num_spans, num_labels\]; the score
    /// of span `s` against label `l` is stored at `s * num_labels + l`.
    fn compute_similarity(
        &self,
        span_embeddings: &[f32],
        num_spans: usize,
        label_embeddings: &[f32],
        num_labels: usize,
        hidden_dim: usize,
    ) -> Vec<f32>;

    /// Apply the logistic sigmoid `1 / (1 + e^(-x))` to every score in place.
    fn apply_sigmoid(&self, scores: &mut [f32]) {
        for score in scores.iter_mut() {
            *score = 1.0 / (1.0 + (-*score).exp());
        }
    }
}

/// Dot product interaction (default, fast).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DotProductInteraction {
    /// Temperature scaling: every raw dot product is multiplied by this value
    /// (higher = sharper distribution)
    pub temperature: f32,
}

impl Default for DotProductInteraction {
    /// Same as [`DotProductInteraction::new`] (temperature `1.0`).
    ///
    /// Implemented by hand because a derived `Default` would set the
    /// temperature to `0.0`, which multiplies every score down to zero.
    fn default() -> Self {
        Self::new()
    }
}

impl DotProductInteraction {
    /// Create with default temperature (1.0).
    #[must_use]
    pub fn new() -> Self {
        Self { temperature: 1.0 }
    }

    /// Create with custom temperature.
    #[must_use]
    pub fn with_temperature(temperature: f32) -> Self {
        Self { temperature }
    }
}

impl LateInteraction for DotProductInteraction {
    /// Scores every (span, label) pair as `temperature * dot(span, label)`.
    ///
    /// # Panics
    /// Panics if a row that has to be read lies outside `span_embeddings` or
    /// `label_embeddings`, i.e. a slice is shorter than `count * hidden_dim`.
    fn compute_similarity(
        &self,
        span_embeddings: &[f32],
        num_spans: usize,
        label_embeddings: &[f32],
        num_labels: usize,
        hidden_dim: usize,
    ) -> Vec<f32> {
        let mut scores = Vec::with_capacity(num_spans * num_labels);

        for s in 0..num_spans {
            let span_vec = &span_embeddings[s * hidden_dim..(s + 1) * hidden_dim];

            // Rows are appended in label order, so the flat index of each score
            // ends up at `s * num_labels + l`.
            scores.extend((0..num_labels).map(|l| {
                let label_vec = &label_embeddings[l * hidden_dim..(l + 1) * hidden_dim];
                let dot: f32 = span_vec.iter().zip(label_vec).map(|(a, b)| a * b).sum();
                dot * self.temperature
            }));
        }

        scores
    }
}

/// MaxSim interaction (ColBERT-style, better for phrases).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MaxSimInteraction {
    /// Temperature scaling: multiplicative factor applied to every score
    pub temperature: f32,
}

impl Default for MaxSimInteraction {
    /// Same as [`MaxSimInteraction::new`] (temperature `1.0`), rather than the
    /// all-zero value a derived `Default` would produce.
    fn default() -> Self {
        Self::new()
    }
}

impl MaxSimInteraction {
    /// Create with default settings (temperature 1.0).
    #[must_use]
    pub fn new() -> Self {
        Self { temperature: 1.0 }
    }

    /// Create with custom temperature.
    #[must_use]
    pub fn with_temperature(temperature: f32) -> Self {
        Self { temperature }
    }
}

impl LateInteraction for MaxSimInteraction {
    /// Scores every (span, label) pair.
    ///
    /// With one vector per span and per label there is nothing to take a
    /// maximum over, so this delegates to [`DotProductInteraction`], passing on
    /// this instance's own `temperature`.
    fn compute_similarity(
        &self,
        span_embeddings: &[f32],
        num_spans: usize,
        label_embeddings: &[f32],
        num_labels: usize,
        hidden_dim: usize,
    ) -> Vec<f32> {
        // For single-vector embeddings, MaxSim degrades to dot product.
        // True MaxSim requires multi-vector representations.
        DotProductInteraction::with_temperature(self.temperature).compute_similarity(
            span_embeddings,
            num_spans,
            label_embeddings,
            num_labels,
            hidden_dim,
        )
    }
}
1191
1192// =============================================================================
1193// Span Representation
1194// =============================================================================
1195
1196/// Configuration for span representation.
1197///
1198/// # Research Context (Deep Span Representations, arXiv:2210.04182)
1199///
1200/// From "Deep Span Representations for NER":
1201/// > "Existing span-based NER systems **shallowly aggregate** the token
1202/// > representations to span representations. However, this typically results
1203/// > in significant ineffectiveness for **long-span entities**."
1204///
1205/// Common span representation strategies:
1206///
1207/// | Method | Formula | Pros | Cons |
1208/// |--------|---------|------|------|
1209/// | Concat | [h_i; h_j] | Simple, fast | Ignores middle tokens |
1210/// | Pooling | mean(h_i:h_j) | Uses all tokens | Loses boundary info |
1211/// | Attention | attn(h_i:h_j) | Learnable | Expensive |
1212/// | GLiNER | FFN([h_i; h_j; w]) | Balanced | Requires width emb |
1213///
1214/// # Recommendation (GLiNER Default)
1215///
1216/// For most use cases, concatenating first + last token embeddings with
1217/// a width embedding provides the best tradeoff:
1218/// - O(N) complexity (vs O(N²) for all-pairs attention)
1219/// - Captures boundary positions (critical for NER)
1220/// - Width embedding disambiguates "I" vs "New York City"
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SpanRepConfig {
    /// Hidden dimension of the encoder
    pub hidden_dim: usize,
    /// Maximum span width (in tokens)
    ///
    /// GLiNER uses K=12: "to keep linear complexity without harming recall."
    /// Wider spans rarely contain coherent entities.
    pub max_width: usize,
    /// Whether to include width embeddings
    ///
    /// Critical for distinguishing spans of different lengths
    /// with similar boundary tokens.
    pub use_width_embeddings: bool,
    /// Width embedding dimension (typically hidden_dim / 4)
    pub width_emb_dim: usize,
}

impl Default for SpanRepConfig {
    /// Defaults: 768-dimensional encoder, spans of at most 12 tokens, width
    /// embeddings enabled with a quarter of the hidden dimension (192).
    fn default() -> Self {
        const HIDDEN_DIM: usize = 768;

        Self {
            hidden_dim: HIDDEN_DIM,
            max_width: 12,
            use_width_embeddings: true,
            // Derived from the hidden size so the two values cannot drift apart.
            width_emb_dim: HIDDEN_DIM / 4,
        }
    }
}
1249
1250/// Computes span representations from token embeddings.
1251///
1252/// # Research Alignment (GLiNER, NAACL 2024)
1253///
1254/// From the GLiNER paper (arXiv:2311.08526):
1255/// > "The representation of a span starting at position i and ending at
1256/// > position j in the input text, S_ij ∈ R^D, is computed as:
1257/// > **S_ij = FFN(h_i ⊗ h_j)**
1258/// > where FFN denotes a two-layer feedforward network, and ⊗ represents
1259/// > the concatenation operation."
1260///
1261/// The paper also notes:
1262/// > "We set an upper bound to the length (K=12) of the span in order to
1263/// > keep linear complexity in the size of the input text, without harming recall."
1264///
1265/// # Span Representation Formula
1266///
1267/// ```text
1268/// span_emb = FFN(Concat(token[i], token[j], width_emb[j-i]))
1269///          = W_2 · ReLU(W_1 · [h_i; h_j; w_{j-i}] + b_1) + b_2
1270/// ```
1271///
1272/// where:
1273/// - h_i = start token embedding
1274/// - h_j = end token embedding
1275/// - w_{j-i} = learned width embedding (captures span length)
1276///
1277/// This is the "gnarly bit" from GLiNER that enables zero-shot matching.
1278///
1279/// # Alternative: Global Pointer (arXiv:2208.03054)
1280///
1281/// Instead of enumerating spans, Global Pointer uses RoPE (rotary position
1282/// embeddings) to predict (start, end) pairs simultaneously:
1283///
1284/// ```text
1285/// score(i, j) = q_i^T * k_j    (where q, k have RoPE applied)
1286/// ```
1287///
1288/// Advantages:
1289/// - No explicit span enumeration needed
1290/// - Naturally handles nested entities
1291/// - More parameter-efficient
1292///
1293/// GLiNER-style enumeration is still preferred for zero-shot because
1294/// it allows pre-computing label embeddings.
1295pub struct SpanRepresentationLayer {
1296    /// Configuration
1297    pub config: SpanRepConfig,
1298    /// Projection weights: [input_dim, hidden_dim]
1299    pub projection_weights: Vec<f32>,
1300    /// Projection bias: \[hidden_dim\]
1301    pub projection_bias: Vec<f32>,
1302    /// Width embeddings: [max_width, width_emb_dim]
1303    pub width_embeddings: Vec<f32>,
1304}
1305
1306impl SpanRepresentationLayer {
1307    /// Create a new span representation layer with random initialization.
1308    pub fn new(config: SpanRepConfig) -> Self {
1309        let input_dim = config.hidden_dim * 2 + config.width_emb_dim;
1310
1311        Self {
1312            projection_weights: vec![0.0f32; input_dim * config.hidden_dim],
1313            projection_bias: vec![0.0f32; config.hidden_dim],
1314            width_embeddings: vec![0.0f32; config.max_width * config.width_emb_dim],
1315            config,
1316        }
1317    }
1318
1319    /// Compute span representations from token embeddings.
1320    ///
1321    /// # Arguments
1322    /// * `token_embeddings` - Flattened [num_tokens, hidden_dim]
1323    /// * `candidates` - Span candidates with start/end indices
1324    ///
1325    /// # Returns
1326    /// Span embeddings: [num_candidates, hidden_dim]
1327    pub fn forward(
1328        &self,
1329        token_embeddings: &[f32],
1330        candidates: &[SpanCandidate],
1331        batch: &RaggedBatch,
1332    ) -> Vec<f32> {
1333        let hidden_dim = self.config.hidden_dim;
1334        let width_emb_dim = self.config.width_emb_dim;
1335        let max_width = self.config.max_width;
1336
1337        // Check for overflow in allocation
1338        let total_elements = match candidates.len().checked_mul(hidden_dim) {
1339            Some(v) => v,
1340            None => {
1341                log::warn!(
1342                    "Span embedding allocation overflow: {} candidates * {} hidden_dim, returning empty",
1343                    candidates.len(), hidden_dim
1344                );
1345                return vec![];
1346            }
1347        };
1348        let mut span_embeddings = vec![0.0f32; total_elements];
1349
1350        for (span_idx, candidate) in candidates.iter().enumerate() {
1351            // Get document token range
1352            let doc_range = match batch.doc_range(candidate.doc_idx as usize) {
1353                Some(r) => r,
1354                None => continue,
1355            };
1356
1357            // Validate span before computing global indices
1358            if candidate.end <= candidate.start {
1359                log::warn!(
1360                    "Invalid span candidate: end ({}) <= start ({})",
1361                    candidate.end,
1362                    candidate.start
1363                );
1364                continue;
1365            }
1366
1367            // Global token indices
1368            let start_global = doc_range.start + candidate.start as usize;
1369            let end_global = doc_range.start + (candidate.end as usize) - 1; // Safe now that we validated
1370
1371            // Bounds check - must ensure both start and end slices fit
1372            // Use checked arithmetic to prevent overflow
1373            let start_byte = match start_global.checked_mul(hidden_dim) {
1374                Some(v) => v,
1375                None => {
1376                    log::warn!(
1377                        "Token index overflow: start_global={} * hidden_dim={}",
1378                        start_global,
1379                        hidden_dim
1380                    );
1381                    continue;
1382                }
1383            };
1384            let start_end_byte = match (start_global + 1).checked_mul(hidden_dim) {
1385                Some(v) => v,
1386                None => {
1387                    log::warn!(
1388                        "Token index overflow: (start_global+1)={} * hidden_dim={}",
1389                        start_global + 1,
1390                        hidden_dim
1391                    );
1392                    continue;
1393                }
1394            };
1395            let end_byte = match end_global.checked_mul(hidden_dim) {
1396                Some(v) => v,
1397                None => {
1398                    log::warn!(
1399                        "Token index overflow: end_global={} * hidden_dim={}",
1400                        end_global,
1401                        hidden_dim
1402                    );
1403                    continue;
1404                }
1405            };
1406            let end_end_byte = match (end_global + 1).checked_mul(hidden_dim) {
1407                Some(v) => v,
1408                None => {
1409                    log::warn!(
1410                        "Token index overflow: (end_global+1)={} * hidden_dim={}",
1411                        end_global + 1,
1412                        hidden_dim
1413                    );
1414                    continue;
1415                }
1416            };
1417
1418            if start_byte >= token_embeddings.len()
1419                || start_end_byte > token_embeddings.len()
1420                || end_byte >= token_embeddings.len()
1421                || end_end_byte > token_embeddings.len()
1422            {
1423                continue;
1424            }
1425
1426            // Get start and end token embeddings
1427            let start_emb = &token_embeddings[start_byte..start_end_byte];
1428            let end_emb = &token_embeddings[end_byte..end_end_byte];
1429
1430            // Optional width embedding (index = span_len - 1).
1431            let width_emb = if self.config.use_width_embeddings && width_emb_dim > 0 {
1432                let max_width_idx = max_width.saturating_sub(1);
1433                let span_len = candidate.width() as usize;
1434                let width_idx = span_len.saturating_sub(1).min(max_width_idx);
1435
1436                let width_start = width_idx.saturating_mul(width_emb_dim);
1437                let width_end = width_start.saturating_add(width_emb_dim);
1438                if width_end > self.width_embeddings.len() {
1439                    None
1440                } else {
1441                    Some(&self.width_embeddings[width_start..width_end])
1442                }
1443            } else {
1444                None
1445            };
1446
1447            // Baseline span representation: average of boundary embeddings (+ optional width signal).
1448            // This is deterministic and works without learned projection weights.
1449            let output_start = span_idx * hidden_dim;
1450            for h in 0..hidden_dim {
1451                span_embeddings[output_start + h] = (start_emb[h] + end_emb[h]) * 0.5;
1452                if let Some(width_emb) = width_emb {
1453                    if h < width_emb_dim {
1454                        span_embeddings[output_start + h] += width_emb[h] * 0.1;
1455                    }
1456                }
1457            }
1458        }
1459
1460        span_embeddings
1461    }
1462}
1463
1464// =============================================================================
1465// Handshaking Matrix (TPLinker-style Joint Extraction)
1466// =============================================================================
1467
/// Result cell in a handshaking matrix.
///
/// Implements `PartialEq` so cells can be compared directly (e.g. in tests);
/// `Eq` is not derived because `score` is an `f32`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct HandshakingCell {
    /// Row index (token i)
    pub i: u32,
    /// Column index (token j)
    pub j: u32,
    /// Predicted label index
    pub label_idx: u16,
    /// Confidence score
    pub score: f32,
}
1480
1481/// Handshaking matrix for joint entity-relation extraction.
1482///
1483/// # Research Alignment (W2NER, AAAI 2022)
1484///
1485/// From the W2NER paper (arXiv:2112.10070):
1486/// > "We present a novel alternative by modeling the unified NER as word-word
1487/// > relation classification, namely W2NER. The architecture resolves the kernel
1488/// > bottleneck of unified NER by effectively modeling the neighboring relations
1489/// > between entity words with **Next-Neighboring-Word (NNW)** and
1490/// > **Tail-Head-Word-* (THW-*)** relations."
1491///
1492/// In TPLinker/W2NER, we don't just tag tokens - we tag token PAIRS.
1493/// The matrix M\[i,j\] contains the label for the span (i, j).
1494///
1495/// # Key Relations
1496///
1497/// | Relation | Description | Purpose |
1498/// |----------|-------------|---------|
1499/// | NNW | Next-Neighboring-Word | Links adjacent tokens within entity |
1500/// | THW-* | Tail-Head-Word | Links end of one entity to start of next |
1501///
1502/// # Benefits
1503///
1504/// - Overlapping entities (same token in multiple spans)
1505/// - Joint entity-relation extraction in one pass
1506/// - Explicit boundary modeling
1507/// - Handles flat, nested, AND discontinuous NER in one model
1508pub struct HandshakingMatrix {
1509    /// Non-zero cells (sparse representation)
1510    pub cells: Vec<HandshakingCell>,
1511    /// Sequence length
1512    pub seq_len: usize,
1513    /// Number of labels
1514    pub num_labels: usize,
1515}
1516
1517impl HandshakingMatrix {
1518    /// Create from dense scores with thresholding.
1519    ///
1520    /// # Arguments
1521    /// * `scores` - Dense [seq_len, seq_len, num_labels] scores
1522    /// * `threshold` - Minimum score to keep
1523    pub fn from_dense(scores: &[f32], seq_len: usize, num_labels: usize, threshold: f32) -> Self {
1524        // Performance: Pre-allocate cells vec with estimated capacity
1525        // Most matrices have sparse cells (only high-scoring ones), so we estimate conservatively
1526        let estimated_capacity = (seq_len * seq_len / 10).min(1000); // ~10% of cells typically pass threshold
1527        let mut cells = Vec::with_capacity(estimated_capacity);
1528
1529        for i in 0..seq_len {
1530            for j in i..seq_len {
1531                // Upper triangular (i <= j)
1532                for l in 0..num_labels {
1533                    let idx = i * seq_len * num_labels + j * num_labels + l;
1534                    if idx < scores.len() {
1535                        let score = scores[idx];
1536                        if score >= threshold {
1537                            cells.push(HandshakingCell {
1538                                i: i as u32,
1539                                j: j as u32,
1540                                label_idx: l as u16,
1541                                score,
1542                            });
1543                        }
1544                    }
1545                }
1546            }
1547        }
1548
1549        Self {
1550            cells,
1551            seq_len,
1552            num_labels,
1553        }
1554    }
1555
1556    /// Decode entities from handshaking matrix.
1557    ///
1558    /// In W2NER convention, cell (i, j) represents a span where:
1559    /// - j is the start token index
1560    /// - i is the end token index (inclusive, so we add 1 for exclusive end)
1561    pub fn decode_entities<'a>(
1562        &self,
1563        registry: &'a SemanticRegistry,
1564    ) -> Vec<(SpanCandidate, &'a LabelDefinition, f32)> {
1565        let mut entities = Vec::new();
1566
1567        for cell in &self.cells {
1568            if let Some(label) = registry.labels.get(cell.label_idx as usize) {
1569                if label.category == LabelCategory::Entity {
1570                    // W2NER: j=start, i=end (inclusive), so span is [j, i+1)
1571                    entities.push((SpanCandidate::new(0, cell.j, cell.i + 1), label, cell.score));
1572                }
1573            }
1574        }
1575
1576        // Performance: Use unstable sort (we don't need stable sort here)
1577        // Sort by position, then by score (descending)
1578        entities.sort_unstable_by(|a, b| {
1579            a.0.start
1580                .cmp(&b.0.start)
1581                .then_with(|| a.0.end.cmp(&b.0.end))
1582                .then_with(|| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal))
1583        });
1584
1585        // Performance: Pre-allocate kept vec with estimated capacity
1586        // Non-maximum suppression
1587        let mut kept = Vec::with_capacity(entities.len().min(32));
1588        for (span, label, score) in entities {
1589            let overlaps = kept.iter().any(|(s, _, _): &(SpanCandidate, _, _)| {
1590                !(span.end <= s.start || s.end <= span.start)
1591            });
1592            if !overlaps {
1593                kept.push((span, label, score));
1594            }
1595        }
1596
1597        kept
1598    }
1599}
1600
1601// =============================================================================
1602// Coreference Resolution
1603// =============================================================================
1604
/// A coreference cluster (mentions referring to same entity).
///
/// Implements `PartialEq`/`Eq` so clusters can be compared directly
/// (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoreferenceCluster {
    /// Cluster ID
    pub id: u64,
    /// Member entities (indices into entity list)
    pub members: Vec<usize>,
    /// Representative entity index (most informative mention)
    pub representative: usize,
    /// Canonical name (from representative)
    pub canonical_name: String,
}
1617
/// Configuration for coreference resolution.
///
/// Implements `PartialEq` so configurations can be compared directly; `Eq` is
/// not derived because `similarity_threshold` is an `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct CoreferenceConfig {
    /// Minimum cosine similarity to link mentions (inclusive)
    pub similarity_threshold: f32,
    /// Maximum gap between two mentions for the embedding-similarity check,
    /// measured in the same units as `Entity::start` / `Entity::end`.
    /// `None` disables the limit. Pairs linked by string matching are not
    /// subject to it.
    pub max_distance: Option<usize>,
    /// Whether to use case-insensitive string equality/containment as a signal
    pub use_string_match: bool,
}

impl Default for CoreferenceConfig {
    /// Defaults: similarity threshold 0.85, maximum distance 500, string
    /// matching enabled.
    fn default() -> Self {
        Self {
            similarity_threshold: 0.85,
            max_distance: Some(500),
            use_string_match: true,
        }
    }
}
1638
1639/// Resolve coreferences between entities using embedding similarity.
1640///
1641/// # Algorithm
1642///
1643/// 1. Compute pairwise cosine similarity between entity embeddings
1644/// 2. Link entities above threshold (with optional distance constraint)
1645/// 3. Build clusters via transitive closure
1646/// 4. Select representative (longest/most informative mention)
1647///
1648/// # Example
1649///
1650/// Input entities: ["Lynn Conway", "She", "The engineer", "Conway"]
1651/// Output clusters: [{0, 1, 2, 3}] with canonical_name = "Lynn Conway"
1652pub fn resolve_coreferences(
1653    entities: &[Entity],
1654    embeddings: &[f32], // [num_entities, hidden_dim]
1655    hidden_dim: usize,
1656    config: &CoreferenceConfig,
1657) -> Vec<CoreferenceCluster> {
1658    let n = entities.len();
1659    if n == 0 {
1660        return vec![];
1661    }
1662
1663    // Union-find for clustering
1664    let mut parent: Vec<usize> = (0..n).collect();
1665
1666    fn find(parent: &mut [usize], i: usize) -> usize {
1667        if parent[i] != i {
1668            parent[i] = find(parent, parent[i]);
1669        }
1670        parent[i]
1671    }
1672
1673    fn union(parent: &mut [usize], i: usize, j: usize) {
1674        let pi = find(parent, i);
1675        let pj = find(parent, j);
1676        if pi != pj {
1677            parent[pi] = pj;
1678        }
1679    }
1680
1681    // Check all pairs
1682    for i in 0..n {
1683        for j in (i + 1)..n {
1684            // String match check (fast path)
1685            if config.use_string_match {
1686                let text_i = entities[i].text.to_lowercase();
1687                let text_j = entities[j].text.to_lowercase();
1688                if text_i == text_j || text_i.contains(&text_j) || text_j.contains(&text_i) {
1689                    // Same entity type required
1690                    if entities[i].entity_type == entities[j].entity_type {
1691                        union(&mut parent, i, j);
1692                        continue;
1693                    }
1694                }
1695            }
1696
1697            // Distance check
1698            if let Some(max_dist) = config.max_distance {
1699                let dist = if entities[i].end <= entities[j].start {
1700                    entities[j].start - entities[i].end
1701                } else {
1702                    entities[i].start.saturating_sub(entities[j].end)
1703                };
1704                if dist > max_dist {
1705                    continue;
1706                }
1707            }
1708
1709            // Embedding similarity
1710            if embeddings.len() >= (j + 1) * hidden_dim {
1711                let emb_i = &embeddings[i * hidden_dim..(i + 1) * hidden_dim];
1712                let emb_j = &embeddings[j * hidden_dim..(j + 1) * hidden_dim];
1713
1714                let similarity = cosine_similarity(emb_i, emb_j);
1715
1716                if similarity >= config.similarity_threshold {
1717                    // Same entity type required
1718                    if entities[i].entity_type == entities[j].entity_type {
1719                        union(&mut parent, i, j);
1720                    }
1721                }
1722            }
1723        }
1724    }
1725
1726    // Build clusters
1727    let mut cluster_members: HashMap<usize, Vec<usize>> = HashMap::new();
1728    for i in 0..n {
1729        let root = find(&mut parent, i);
1730        cluster_members.entry(root).or_default().push(i);
1731    }
1732
1733    // Convert to CoreferenceCluster
1734    let mut clusters = Vec::new();
1735    let mut cluster_id = 0u64;
1736
1737    for (_root, members) in cluster_members {
1738        if members.len() > 1 {
1739            // Find representative (longest mention)
1740            let representative = *members
1741                .iter()
1742                .max_by_key(|&&i| entities[i].text.len())
1743                .unwrap_or(&members[0]);
1744
1745            clusters.push(CoreferenceCluster {
1746                id: cluster_id,
1747                members,
1748                representative,
1749                canonical_name: entities[representative].text.clone(),
1750            });
1751            cluster_id += 1;
1752        }
1753    }
1754
1755    clusters
1756}
1757
/// Compute cosine similarity between two vectors.
///
/// Returns a value in [-1.0, 1.0] where:
/// - 1.0 = identical direction
/// - 0.0 = orthogonal
/// - -1.0 = opposite direction
///
/// The result is clamped to that range, because floating-point rounding can
/// otherwise push it marginally past ±1.0. `0.0` is returned when either vector
/// has zero (or NaN) norm, and when the computation itself degenerates to NaN
/// (e.g. the squared norms overflow `f32`).
///
/// If the slices differ in length, the dot product covers the common prefix while
/// each norm covers its full slice, which is equivalent to zero-padding the
/// shorter vector.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

    // `NaN > 0.0` is false, so NaN norms also take the 0.0 branch.
    if norm_a > 0.0 && norm_b > 0.0 {
        let similarity = dot / (norm_a * norm_b);
        if similarity.is_nan() {
            // inf / inf after overflow: no meaningful direction can be reported.
            0.0
        } else {
            similarity.clamp(-1.0, 1.0)
        }
    } else {
        0.0
    }
}
1775
1776// =============================================================================
1777// Relation Extraction
1778// =============================================================================
1779
/// Tunable settings for the heuristic relation extractor.
#[derive(Debug, Clone)]
pub struct RelationExtractionConfig {
    /// Largest allowed gap between the head and tail spans. The gap is measured in
    /// the same units as the entity `start`/`end` offsets (character offsets in
    /// `extract_relations` / `extract_relation_triples`); pairs further apart are skipped.
    pub max_span_distance: usize,
    /// Minimum confidence a relation must reach, after the distance penalty has
    /// been applied, to be kept.
    pub threshold: f32,
    /// Whether `extract_relations` should report trigger spans.
    pub extract_triggers: bool,
}

impl Default for RelationExtractionConfig {
    /// Defaults: a maximum span gap of 50, a confidence threshold of 0.5, and
    /// trigger extraction enabled.
    fn default() -> Self {
        RelationExtractionConfig {
            max_span_distance: 50,
            threshold: 0.5,
            extract_triggers: true,
        }
    }
}
1800
1801/// Extract relations between entities.
1802///
1803/// # Algorithm (Two-Pass)
1804///
1805/// 1. Run entity NER to find all entity mentions
1806/// 2. For each entity pair within distance threshold:
1807///    - Encode the span between them
1808///    - Match against relation type embeddings
1809///    - Optionally identify trigger span
1810///
1811/// # Returns
1812///
1813/// Relations with head/tail entities and optional trigger spans.
1814pub fn extract_relations(
1815    entities: &[Entity],
1816    text: &str,
1817    registry: &SemanticRegistry,
1818    config: &RelationExtractionConfig,
1819) -> Vec<Relation> {
1820    let mut relations = Vec::new();
1821    // `Entity` spans in anno are character offsets, but slicing a Rust `&str` requires byte
1822    // offsets. Build a converter once so we can safely slice and map trigger spans back.
1823    let span_converter = crate::offset::SpanConverter::new(text);
1824
1825    // Get relation labels
1826    let relation_labels: Vec<_> = registry.relation_labels().collect();
1827    if relation_labels.is_empty() {
1828        return relations;
1829    }
1830
1831    // Check all entity pairs
1832    for (i, head) in entities.iter().enumerate() {
1833        for (j, tail) in entities.iter().enumerate() {
1834            if i == j {
1835                continue;
1836            }
1837
1838            // Check distance
1839            let distance = if head.end <= tail.start {
1840                tail.start - head.end
1841            } else {
1842                head.start.saturating_sub(tail.end)
1843            };
1844
1845            if distance > config.max_span_distance {
1846                continue;
1847            }
1848
1849            // Look for relation triggers in the text between entities
1850            let (span_start, span_end) = if head.end <= tail.start {
1851                (head.end, tail.start)
1852            } else {
1853                (tail.end, head.start)
1854            };
1855
1856            let between_span = span_converter.from_chars(span_start, span_end);
1857            let between_text = text
1858                .get(between_span.byte_start..between_span.byte_end)
1859                .unwrap_or("");
1860
1861            // Simple heuristic: check for common relation indicators
1862            let relation_type = detect_relation_type(head, tail, between_text, &relation_labels);
1863
1864            if let Some((rel_type, mut confidence, trigger)) = relation_type {
1865                // Apply distance penalty: closer entities are more likely to be related
1866                // Confidence decays linearly from 1.0 at distance 0 to 0.5 at max_span_distance
1867                let distance_penalty = if distance < config.max_span_distance {
1868                    let penalty_factor =
1869                        1.0 - (distance as f64 / config.max_span_distance as f64) * 0.5;
1870                    penalty_factor.max(0.5) // Minimum 0.5 confidence even at max distance
1871                } else {
1872                    0.5 // At or beyond max distance, apply minimum confidence
1873                };
1874                confidence *= distance_penalty;
1875
1876                if confidence < config.threshold as f64 {
1877                    continue;
1878                }
1879
1880                // `detect_relation_type` returns byte offsets into `between_text`.
1881                let trigger_span = if config.extract_triggers {
1882                    trigger.map(|(s, e)| {
1883                        let trigger_start_byte = between_span.byte_start.saturating_add(s);
1884                        let trigger_end_byte = between_span.byte_start.saturating_add(e);
1885                        (
1886                            span_converter.byte_to_char(trigger_start_byte),
1887                            span_converter.byte_to_char(trigger_end_byte),
1888                        )
1889                    })
1890                } else {
1891                    None
1892                };
1893
1894                relations.push(Relation {
1895                    head: head.clone(),
1896                    tail: tail.clone(),
1897                    relation_type: rel_type.to_string(),
1898                    trigger_span,
1899                    confidence: confidence.clamp(0.0, 1.0), // Clamp to [0, 1]
1900                });
1901            }
1902        }
1903    }
1904
1905    relations
1906}
1907
1908/// Extract relations as index-based triples (for joint extraction backends).
1909///
1910/// This is the same heuristic logic as [`extract_relations`], but returns
1911/// [`RelationTriple`] with indices into the provided `entities` slice.
1912///
1913/// Notes:
1914/// - Entity spans are **character offsets**.
1915/// - Trigger spans are not currently exposed in `RelationTriple`.
1916#[must_use]
1917pub fn extract_relation_triples(
1918    entities: &[Entity],
1919    text: &str,
1920    registry: &SemanticRegistry,
1921    config: &RelationExtractionConfig,
1922) -> Vec<RelationTriple> {
1923    let mut triples = Vec::new();
1924    if entities.len() < 2 {
1925        return triples;
1926    }
1927
1928    // `Entity` spans are character offsets; slicing needs byte offsets.
1929    let span_converter = crate::offset::SpanConverter::new(text);
1930
1931    let relation_labels: Vec<_> = registry.relation_labels().collect();
1932    if relation_labels.is_empty() {
1933        return triples;
1934    }
1935
1936    for (i, head) in entities.iter().enumerate() {
1937        for (j, tail) in entities.iter().enumerate() {
1938            if i == j {
1939                continue;
1940            }
1941
1942            // Skip overlapping spans (avoids self-nesting artifacts like "New York" vs "York").
1943            if head.start < tail.end && tail.start < head.end {
1944                continue;
1945            }
1946
1947            // Check distance (character offsets)
1948            let distance = if head.end <= tail.start {
1949                tail.start - head.end
1950            } else {
1951                head.start.saturating_sub(tail.end)
1952            };
1953            if distance > config.max_span_distance {
1954                continue;
1955            }
1956
1957            let (span_start, span_end) = if head.end <= tail.start {
1958                (head.end, tail.start)
1959            } else {
1960                (tail.end, head.start)
1961            };
1962
1963            let between_span = span_converter.from_chars(span_start, span_end);
1964            let between_text = text
1965                .get(between_span.byte_start..between_span.byte_end)
1966                .unwrap_or("");
1967
1968            if let Some((rel_type, mut confidence, _trigger)) =
1969                detect_relation_type(head, tail, between_text, &relation_labels)
1970            {
1971                // Apply distance penalty (same logic as extract_relations)
1972                let distance_penalty = if distance < config.max_span_distance {
1973                    let penalty_factor =
1974                        1.0 - (distance as f64 / config.max_span_distance as f64) * 0.5;
1975                    penalty_factor.max(0.5)
1976                } else {
1977                    0.5
1978                };
1979                confidence *= distance_penalty;
1980
1981                if confidence < config.threshold as f64 {
1982                    continue;
1983                }
1984
1985                triples.push(RelationTriple {
1986                    head_idx: i,
1987                    tail_idx: j,
1988                    relation_type: rel_type.to_string(),
1989                    confidence: confidence as f32,
1990                });
1991            }
1992        }
1993    }
1994
1995    triples
1996}
1997
/// Result of relation detection: `(label, confidence, optional trigger span)`.
///
/// - The label is a string slice with lifetime `'a`, the same lifetime carried by the
///   label definitions passed to `detect_relation_type`.
/// - The confidence is the score before any distance penalty; callers scale it.
/// - The span, when present, is a `(start, end)` pair. `extract_relations` treats it
///   as byte offsets into the between-entities text it passed in.
type RelationMatch<'a> = (&'a str, f64, Option<(usize, usize)>);
2000
2001/// Detect relation type from context (heuristic fallback).
2002fn detect_relation_type<'a>(
2003    head: &Entity,
2004    tail: &Entity,
2005    between_text: &str,
2006    relation_labels: &[&'a LabelDefinition],
2007) -> Option<RelationMatch<'a>> {
2008    // Use Unicode-aware lowercasing for multilingual support
2009    // Note: For CJK languages, case doesn't apply, but this is safe
2010    let between_lower = between_text.to_lowercase();
2011
2012    // Normalize relation slugs so datasets that use kebab-case / colon-separated schemas
2013    // (e.g. DocRED: "part-of", "general-affiliation") can match our canonical patterns
2014    // (e.g. "PART_OF", "GENERAL_AFFILIATION").
2015    fn norm_rel_slug(s: &str) -> String {
2016        // Uppercase + map non-alphanumerics to '_' so we can compare across naming schemes.
2017        let mut out = String::with_capacity(s.len());
2018        let mut prev_underscore = false;
2019        for ch in s.chars() {
2020            if ch.is_alphanumeric() {
2021                // Keep Unicode letters/digits; uppercase ASCII for stable matching.
2022                if ch.is_ascii_alphabetic() {
2023                    out.push(ch.to_ascii_uppercase());
2024                } else {
2025                    out.push(ch);
2026                }
2027                prev_underscore = false;
2028            } else if !prev_underscore {
2029                out.push('_');
2030                prev_underscore = true;
2031            }
2032        }
2033        while out.starts_with('_') {
2034            out.remove(0);
2035        }
2036        while out.ends_with('_') {
2037            out.pop();
2038        }
2039        out
2040    }
2041
    // Common patterns: (relation_slug, triggers, confidence)
    // One entry of the built-in trigger table defined below. The same slug may appear
    // in several entries (e.g. one per language).
    struct RelPattern {
        /// Relation label this entry maps to, written in the table's canonical form
        /// (e.g. "PART_OF"; the CHisIEC entries use Chinese labels).
        slug: &'static str,
        /// Trigger phrases associated with the relation. Latin-script entries are
        /// written in lowercase. NOTE(review): presumably searched for in the
        /// lowercased between-entities text; the matching code is outside this view.
        triggers: &'static [&'static str],
        /// Base confidence attached to this entry.
        confidence: f64,
    }
2048
2049    let patterns: &[RelPattern] = &[
2050        // Employment relations
2051        RelPattern {
2052            slug: "CEO_OF",
2053            triggers: &[
2054                "ceo of",
2055                "chief executive",
2056                "chief executive officer",
2057                "leads",
2058                "founded",
2059                "founder of",
2060            ],
2061            confidence: 0.8,
2062        },
2063        RelPattern {
2064            slug: "WORKS_FOR",
2065            triggers: &[
2066                "works for",
2067                "works at",
2068                "employed by",
2069                "employee of",
2070                "works with",
2071                "staff at",
2072                "member of",
2073            ],
2074            confidence: 0.7,
2075        },
2076        RelPattern {
2077            slug: "FOUNDED",
2078            triggers: &[
2079                "founded",
2080                "co-founded",
2081                "cofounder",
2082                "started",
2083                "established",
2084                "created",
2085                "launched",
2086            ],
2087            confidence: 0.8,
2088        },
2089        RelPattern {
2090            slug: "MANAGES",
2091            triggers: &[
2092                "manages",
2093                "managing",
2094                "oversees",
2095                "directs",
2096                "supervises",
2097                "runs",
2098            ],
2099            confidence: 0.75,
2100        },
2101        RelPattern {
2102            slug: "REPORTS_TO",
2103            triggers: &["reports to", "reported to", "under", "reports directly to"],
2104            confidence: 0.7,
2105        },
2106        // Location relations
2107        RelPattern {
2108            slug: "LOCATED_IN",
2109            triggers: &[
2110                "in",
2111                "at",
2112                "based in",
2113                "located in",
2114                "headquartered in",
2115                "situated in",
2116                "found in",
2117            ],
2118            confidence: 0.6,
2119        },
2120        RelPattern {
2121            slug: "BORN_IN",
2122            triggers: &[
2123                "born in",
2124                "native of",
2125                "from",
2126                "hails from",
2127                "originated in",
2128            ],
2129            confidence: 0.7,
2130        },
2131        RelPattern {
2132            slug: "LIVES_IN",
2133            triggers: &["lives in", "resides in", "living in", "based in"],
2134            confidence: 0.65,
2135        },
2136        RelPattern {
2137            slug: "DIED_IN",
2138            triggers: &["died in", "passed away in", "deceased in"],
2139            confidence: 0.8,
2140        },
2141        // Temporal relations
2142        RelPattern {
2143            slug: "OCCURRED_ON",
2144            triggers: &["on", "occurred on", "happened on", "took place on", "dated"],
2145            confidence: 0.6,
2146        },
2147        RelPattern {
2148            slug: "STARTED_ON",
2149            triggers: &["started on", "began on", "commenced on", "initiated on"],
2150            confidence: 0.7,
2151        },
2152        RelPattern {
2153            slug: "ENDED_ON",
2154            triggers: &["ended on", "concluded on", "finished on", "completed on"],
2155            confidence: 0.7,
2156        },
2157        // Organizational relations
2158        RelPattern {
2159            slug: "PART_OF",
2160            triggers: &[
2161                "part of",
2162                "member of",
2163                "belongs to",
2164                "subsidiary of",
2165                "division of",
2166                "branch of",
2167            ],
2168            confidence: 0.7,
2169        },
2170        RelPattern {
2171            slug: "ACQUIRED",
2172            triggers: &[
2173                "acquired",
2174                "bought",
2175                "purchased",
2176                "took over",
2177                "merged with",
2178            ],
2179            confidence: 0.75,
2180        },
2181        RelPattern {
2182            slug: "MERGED_WITH",
2183            triggers: &["merged with", "merged into", "combined with", "joined with"],
2184            confidence: 0.8,
2185        },
2186        RelPattern {
2187            slug: "PARENT_OF",
2188            triggers: &["parent of", "parent company of", "owns", "owner of"],
2189            confidence: 0.75,
2190        },
2191        // Social relations
2192        RelPattern {
2193            slug: "MARRIED_TO",
2194            triggers: &["married to", "wed to", "spouse of", "husband of", "wife of"],
2195            confidence: 0.85,
2196        },
2197        RelPattern {
2198            slug: "CHILD_OF",
2199            triggers: &["son of", "daughter of", "child of", "offspring of"],
2200            confidence: 0.8,
2201        },
2202        RelPattern {
2203            slug: "SIBLING_OF",
2204            triggers: &["brother of", "sister of", "sibling of"],
2205            confidence: 0.8,
2206        },
2207        // Academic/Professional
2208        RelPattern {
2209            slug: "STUDIED_AT",
2210            triggers: &[
2211                "studied at",
2212                "attended",
2213                "graduated from",
2214                "alumni of",
2215                "educated at",
2216            ],
2217            confidence: 0.75,
2218        },
2219        RelPattern {
2220            slug: "TEACHES_AT",
2221            triggers: &["teaches at", "professor at", "instructor at", "faculty at"],
2222            confidence: 0.8,
2223        },
2224        // Product/Service relations
2225        RelPattern {
2226            slug: "DEVELOPS",
2227            triggers: &[
2228                "develops",
2229                "created",
2230                "built",
2231                "designed",
2232                "produces",
2233                "manufactures",
2234            ],
2235            confidence: 0.7,
2236        },
2237        RelPattern {
2238            slug: "USES",
2239            triggers: &["uses", "utilizes", "employs", "adopts", "implements"],
2240            confidence: 0.6,
2241        },
2242        // Dataset-style relation labels (DocRED/CHisIEC-like)
2243        //
2244        // These are the *coarse* label names we actually see in the CrossRE/DocRED-style
2245        // exports used by this repo (e.g. `docred_dev.json`), which differ from the
2246        // “canonical” IE labels above.
2247        RelPattern {
2248            slug: "NAMED",
2249            triggers: &[
2250                "called",
2251                "known as",
2252                "also known as",
2253                "named",
2254                "referred to as",
2255                "nickname",
2256            ],
2257            confidence: 0.6,
2258        },
2259        RelPattern {
2260            slug: "TYPE_OF",
2261            triggers: &[
2262                "type of",
2263                "kind of",
2264                "form of",
2265                "a type of",
2266                "is a",
2267                "are a",
2268            ],
2269            confidence: 0.6,
2270        },
2271        RelPattern {
2272            slug: "RELATED_TO",
2273            triggers: &["related to", "associated with", "connected to", "linked to"],
2274            confidence: 0.55,
2275        },
2276        RelPattern {
2277            slug: "ORIGIN",
2278            triggers: &[
2279                "from",
2280                "born",
2281                "originated",
2282                "created by",
2283                "invented by",
2284                "derived from",
2285                "spinoff",
2286                "spin-off",
2287            ],
2288            confidence: 0.55,
2289        },
2290        RelPattern {
2291            slug: "ROLE",
2292            triggers: &[
2293                "president",
2294                "ceo",
2295                "chair",
2296                "director",
2297                "editor",
2298                "producer",
2299                "actor",
2300                "professor",
2301                "fellow",
2302                "member",
2303            ],
2304            confidence: 0.55,
2305        },
2306        RelPattern {
2307            slug: "TEMPORAL",
2308            triggers: &[
2309                "in 19", "in 20", "during", "before", "after", "between", "until", "since",
2310            ],
2311            confidence: 0.5,
2312        },
2313        RelPattern {
2314            slug: "PHYSICAL",
2315            triggers: &["located in", "based in", "headquartered in", "at "],
2316            confidence: 0.55,
2317        },
2318        RelPattern {
2319            slug: "TOPIC",
2320            triggers: &["topic", "about", "on", "regarding", "focused on"],
2321            confidence: 0.5,
2322        },
2323        RelPattern {
2324            slug: "OPPOSITE",
2325            triggers: &["opposite", "contrasts with", "as opposed to"],
2326            confidence: 0.6,
2327        },
2328        RelPattern {
2329            slug: "WIN_DEFEAT",
2330            triggers: &["defeated", "beat", "won", "win", "lose", "lost to"],
2331            confidence: 0.6,
2332        },
2333        RelPattern {
2334            slug: "CAUSE_EFFECT",
2335            triggers: &["caused", "causes", "leads to", "results in", "because"],
2336            confidence: 0.55,
2337        },
2338        RelPattern {
2339            slug: "USAGE",
2340            triggers: &["use", "uses", "used", "using", "utilize", "employ", "adopt"],
2341            confidence: 0.55,
2342        },
2343        RelPattern {
2344            slug: "ARTIFACT",
2345            triggers: &[
2346                "tool",
2347                "library",
2348                "framework",
2349                "system",
2350                "artifact",
2351                "implementation",
2352            ],
2353            confidence: 0.55,
2354        },
2355        RelPattern {
2356            slug: "COMPARE",
2357            triggers: &[
2358                "compare",
2359                "compared to",
2360                "versus",
2361                "vs",
2362                "better than",
2363                "worse than",
2364            ],
2365            confidence: 0.55,
2366        },
2367        RelPattern {
2368            slug: "GENERAL_AFFILIATION",
2369            triggers: &[
2370                "affiliation",
2371                "affiliated with",
2372                "member of",
2373                "part of",
2374                "associated with",
2375            ],
2376            confidence: 0.55,
2377        },
2378        // CHisIEC (classical Chinese) relations (match either simplified or traditional labels)
2379        RelPattern {
2380            slug: "父母",
2381            triggers: &["父", "母", "父母"],
2382            confidence: 0.7,
2383        },
2384        RelPattern {
2385            slug: "兄弟",
2386            triggers: &["兄", "弟", "兄弟"],
2387            confidence: 0.7,
2388        },
2389        RelPattern {
2390            slug: "別名",
2391            triggers: &["別名", "别名"],
2392            confidence: 0.75,
2393        },
2394        RelPattern {
2395            slug: "到達",
2396            triggers: &["到", "至", "達", "到達", "到达"],
2397            confidence: 0.6,
2398        },
2399        RelPattern {
2400            slug: "出生於某地",
2401            triggers: &["生於", "生于", "出生於", "出生于"],
2402            confidence: 0.65,
2403        },
2404        RelPattern {
2405            slug: "任職",
2406            triggers: &["任", "拜", "任職", "任职"],
2407            confidence: 0.6,
2408        },
2409        RelPattern {
2410            slug: "管理",
2411            triggers: &["管", "治", "守", "管理"],
2412            confidence: 0.55,
2413        },
2414        RelPattern {
2415            slug: "駐守",
2416            triggers: &["駐", "驻", "守", "駐守", "驻守"],
2417            confidence: 0.55,
2418        },
2419        RelPattern {
2420            slug: "敵對攻伐",
2421            triggers: &["敵", "敌", "攻", "伐", "戰", "战"],
2422            confidence: 0.55,
2423        },
2424        RelPattern {
2425            slug: "同僚",
2426            triggers: &["同僚"],
2427            confidence: 0.55,
2428        },
2429        RelPattern {
2430            slug: "政治奧援",
2431            triggers: &["奧援", "奥援"],
2432            confidence: 0.55,
2433        },
2434        // Communication/Interaction
2435        RelPattern {
2436            slug: "MET_WITH",
2437            triggers: &["met with", "met", "met up with", "encountered", "saw"],
2438            confidence: 0.65,
2439        },
2440        RelPattern {
2441            slug: "SPOKE_WITH",
2442            triggers: &[
2443                "spoke with",
2444                "talked with",
2445                "discussed with",
2446                "conversed with",
2447            ],
2448            confidence: 0.7,
2449        },
2450        // Ownership
2451        RelPattern {
2452            slug: "OWNS",
2453            triggers: &["owns", "owner of", "possesses", "holds"],
2454            confidence: 0.75,
2455        },
2456        // =========================================================================
2457        // Multilingual relation triggers
2458        // =========================================================================
2459        // Spanish (es)
2460        RelPattern {
2461            slug: "WORKS_FOR",
2462            triggers: &["trabaja en", "trabaja para", "empleado de", "trabaja con"],
2463            confidence: 0.7,
2464        },
2465        RelPattern {
2466            slug: "FOUNDED",
2467            triggers: &["fundó", "fundada", "creó", "creada", "estableció", "inició"],
2468            confidence: 0.8,
2469        },
2470        RelPattern {
2471            slug: "LOCATED_IN",
2472            triggers: &[
2473                "en",
2474                "ubicado en",
2475                "situado en",
2476                "basado en",
2477                "localizado en",
2478            ],
2479            confidence: 0.6,
2480        },
2481        RelPattern {
2482            slug: "BORN_IN",
2483            triggers: &["nació en", "nacido en", "originario de", "de"],
2484            confidence: 0.7,
2485        },
2486        RelPattern {
2487            slug: "LIVES_IN",
2488            triggers: &["cerno en", "reside en", "viviendo en"],
2489            confidence: 0.65,
2490        },
2491        RelPattern {
2492            slug: "MARRIED_TO",
2493            triggers: &["casado con", "casada con", "esposo de", "esposa de"],
2494            confidence: 0.85,
2495        },
2496        // French (fr)
2497        RelPattern {
2498            slug: "WORKS_FOR",
2499            triggers: &[
2500                "travaille pour",
2501                "travaille à",
2502                "employé de",
2503                "travaille avec",
2504            ],
2505            confidence: 0.7,
2506        },
2507        RelPattern {
2508            slug: "FOUNDED",
2509            triggers: &["fondé", "fondée", "créé", "créée", "établi", "établie"],
2510            confidence: 0.8,
2511        },
2512        RelPattern {
2513            slug: "LOCATED_IN",
2514            triggers: &["dans", "à", "situé en", "basé en", "localisé en"],
2515            confidence: 0.6,
2516        },
2517        RelPattern {
2518            slug: "BORN_IN",
2519            triggers: &["né en", "née en", "originaire de", "de"],
2520            confidence: 0.7,
2521        },
2522        RelPattern {
2523            slug: "LIVES_IN",
2524            triggers: &["vit en", "réside en", "vivant en"],
2525            confidence: 0.65,
2526        },
2527        RelPattern {
2528            slug: "MARRIED_TO",
2529            triggers: &["marié avec", "mariée avec", "époux de", "épouse de"],
2530            confidence: 0.85,
2531        },
2532        // German (de)
2533        RelPattern {
2534            slug: "WORKS_FOR",
2535            triggers: &[
2536                "arbeitet für",
2537                "arbeitet bei",
2538                "angestellt bei",
2539                "arbeitet mit",
2540            ],
2541            confidence: 0.7,
2542        },
2543        RelPattern {
2544            slug: "FOUNDED",
2545            triggers: &[
2546                "gegründet",
2547                "gründete",
2548                "erstellt",
2549                "errichtet",
2550                "etabliert",
2551            ],
2552            confidence: 0.8,
2553        },
2554        RelPattern {
2555            slug: "LOCATED_IN",
2556            triggers: &["in", "bei", "situiert in", "basiert in", "befindet sich in"],
2557            confidence: 0.6,
2558        },
2559        RelPattern {
2560            slug: "BORN_IN",
2561            triggers: &["geboren in", "geboren am", "stammt aus", "aus"],
2562            confidence: 0.7,
2563        },
2564        RelPattern {
2565            slug: "LIVES_IN",
2566            triggers: &["lebt in", "wohnt in", "lebend in"],
2567            confidence: 0.65,
2568        },
2569        RelPattern {
2570            slug: "MARRIED_TO",
2571            triggers: &["verheiratet mit", "ehemann von", "ehefrau von"],
2572            confidence: 0.85,
2573        },
2574        // Chinese (zh) - Simplified
2575        RelPattern {
2576            slug: "WORKS_FOR",
2577            triggers: &["为", "在", "工作于", "就职于", "任职于"],
2578            confidence: 0.7,
2579        },
2580        RelPattern {
2581            slug: "FOUNDED",
2582            triggers: &["创立", "创建", "建立", "成立", "创办"],
2583            confidence: 0.8,
2584        },
2585        RelPattern {
2586            slug: "LOCATED_IN",
2587            triggers: &["在", "位于", "坐落于", "地处"],
2588            confidence: 0.6,
2589        },
2590        RelPattern {
2591            slug: "BORN_IN",
2592            triggers: &["出生于", "生于", "来自", "出生于"],
2593            confidence: 0.7,
2594        },
2595        RelPattern {
2596            slug: "LIVES_IN",
2597            triggers: &["居住于", "住在", "生活在"],
2598            confidence: 0.65,
2599        },
2600        RelPattern {
2601            slug: "MARRIED_TO",
2602            triggers: &["与...结婚", "嫁给", "娶了"],
2603            confidence: 0.85,
2604        },
2605        // Japanese (ja)
2606        RelPattern {
2607            slug: "WORKS_FOR",
2608            triggers: &["で働く", "に勤務", "に所属", "で就職"],
2609            confidence: 0.7,
2610        },
2611        RelPattern {
2612            slug: "FOUNDED",
2613            triggers: &["設立", "創立", "設立した", "創設"],
2614            confidence: 0.8,
2615        },
2616        RelPattern {
2617            slug: "LOCATED_IN",
2618            triggers: &["に", "で", "に位置", "に所在"],
2619            confidence: 0.6,
2620        },
2621        RelPattern {
2622            slug: "BORN_IN",
2623            triggers: &["に生まれた", "の出身", "で生まれた"],
2624            confidence: 0.7,
2625        },
2626        RelPattern {
2627            slug: "LIVES_IN",
2628            triggers: &["に住む", "に居住", "に在住"],
2629            confidence: 0.65,
2630        },
2631        RelPattern {
2632            slug: "MARRIED_TO",
2633            triggers: &["と結婚", "と結婚した", "の配偶者"],
2634            confidence: 0.85,
2635        },
2636        // Arabic (ar) - RTL
2637        RelPattern {
2638            slug: "WORKS_FOR",
2639            triggers: &["يعمل في", "يعمل لصالح", "موظف في", "يعمل مع"],
2640            confidence: 0.7,
2641        },
2642        RelPattern {
2643            slug: "FOUNDED",
2644            triggers: &["أسس", "أنشأ", "تأسست", "أنشأت"],
2645            confidence: 0.8,
2646        },
2647        RelPattern {
2648            slug: "LOCATED_IN",
2649            triggers: &["في", "ب", "يقع في", "موجود في"],
2650            confidence: 0.6,
2651        },
2652        RelPattern {
2653            slug: "BORN_IN",
2654            triggers: &["ولد في", "من مواليد", "من"],
2655            confidence: 0.7,
2656        },
2657        RelPattern {
2658            slug: "LIVES_IN",
2659            triggers: &["يعيش في", "يسكن في", "مقيم في"],
2660            confidence: 0.65,
2661        },
2662        RelPattern {
2663            slug: "MARRIED_TO",
2664            triggers: &["متزوج من", "زوج", "زوجة"],
2665            confidence: 0.85,
2666        },
2667        // Russian (ru)
2668        RelPattern {
2669            slug: "WORKS_FOR",
2670            triggers: &["работает в", "работает на", "работает для", "сотрудник"],
2671            confidence: 0.7,
2672        },
2673        RelPattern {
2674            slug: "FOUNDED",
2675            triggers: &["основал", "основала", "создал", "создала", "учредил"],
2676            confidence: 0.8,
2677        },
2678        RelPattern {
2679            slug: "LOCATED_IN",
2680            triggers: &["в", "на", "расположен в", "находится в"],
2681            confidence: 0.6,
2682        },
2683        RelPattern {
2684            slug: "BORN_IN",
2685            triggers: &["родился в", "родилась в", "родом из", "из"],
2686            confidence: 0.7,
2687        },
2688        RelPattern {
2689            slug: "LIVES_IN",
2690            triggers: &["живет в", "проживает в", "живущий в"],
2691            confidence: 0.65,
2692        },
2693        RelPattern {
2694            slug: "MARRIED_TO",
2695            triggers: &["женат на", "замужем за", "супруг", "супруга"],
2696            confidence: 0.85,
2697        },
2698    ];
2699
2700    for pattern in patterns {
2701        // Find the canonical label in the registry (case-insensitive).
2702        // We return the label's *original* slug so callers preserve user-provided casing.
2703        let label = match relation_labels.iter().find(|l| {
2704            // Match both:
2705            // - exact canonical names (e.g. "PART_OF")
2706            // - normalized dataset slugs (e.g. "part-of" -> "PART_OF")
2707            norm_rel_slug(&l.slug) == pattern.slug || l.slug.eq_ignore_ascii_case(pattern.slug)
2708        }) {
2709            Some(l) => *l,
2710            None => continue,
2711        };
2712
2713        for trigger in pattern.triggers {
2714            if let Some(pos) = between_lower.find(trigger) {
2715                // Validate entity types make sense for the relation
2716                let valid = match pattern.slug {
2717                    // Person-Organization relations
2718                    "CEO_OF" | "WORKS_FOR" | "FOUNDED" | "MANAGES" | "REPORTS_TO" => {
2719                        // If either side is unknown/misc, don't reject on type alone (relation datasets
2720                        // often use a richer schema than `EntityType`).
2721                        matches!(
2722                            head.entity_type,
2723                            EntityType::Other(_) | EntityType::Custom { .. }
2724                        ) || matches!(
2725                            tail.entity_type,
2726                            EntityType::Other(_) | EntityType::Custom { .. }
2727                        ) || (matches!(head.entity_type, EntityType::Person)
2728                            && matches!(tail.entity_type, EntityType::Organization))
2729                    }
2730                    // Location relations (any entity can be located in/born in a location)
2731                    "LOCATED_IN" | "BORN_IN" | "LIVES_IN" | "DIED_IN" => {
2732                        matches!(
2733                            tail.entity_type,
2734                            EntityType::Other(_) | EntityType::Custom { .. }
2735                        ) || matches!(tail.entity_type, EntityType::Location)
2736                    }
2737                    // Temporal relations (any entity can have temporal attributes)
2738                    "OCCURRED_ON" | "STARTED_ON" | "ENDED_ON" => {
2739                        matches!(
2740                            tail.entity_type,
2741                            EntityType::Other(_) | EntityType::Custom { .. }
2742                        ) || matches!(tail.entity_type, EntityType::Date | EntityType::Time)
2743                    }
2744                    // Organizational relations
2745                    "PART_OF" | "ACQUIRED" | "MERGED_WITH" | "PARENT_OF" => {
2746                        matches!(
2747                            head.entity_type,
2748                            EntityType::Other(_) | EntityType::Custom { .. }
2749                        ) || matches!(
2750                            tail.entity_type,
2751                            EntityType::Other(_) | EntityType::Custom { .. }
2752                        ) || (matches!(head.entity_type, EntityType::Organization)
2753                            && matches!(tail.entity_type, EntityType::Organization))
2754                    }
2755                    // Social relations
2756                    "MARRIED_TO" | "CHILD_OF" | "SIBLING_OF" => {
2757                        matches!(
2758                            head.entity_type,
2759                            EntityType::Other(_) | EntityType::Custom { .. }
2760                        ) || matches!(
2761                            tail.entity_type,
2762                            EntityType::Other(_) | EntityType::Custom { .. }
2763                        ) || (matches!(head.entity_type, EntityType::Person)
2764                            && matches!(tail.entity_type, EntityType::Person))
2765                    }
2766                    // Academic relations
2767                    "STUDIED_AT" | "TEACHES_AT" => {
2768                        matches!(
2769                            head.entity_type,
2770                            EntityType::Other(_) | EntityType::Custom { .. }
2771                        ) || matches!(
2772                            tail.entity_type,
2773                            EntityType::Other(_) | EntityType::Custom { .. }
2774                        ) || (matches!(head.entity_type, EntityType::Person)
2775                            && matches!(
2776                                tail.entity_type,
2777                                EntityType::Organization | EntityType::Location
2778                            ))
2779                    }
2780                    // Product relations
2781                    "DEVELOPS" | "USES" => {
2782                        matches!(
2783                            head.entity_type,
2784                            EntityType::Other(_) | EntityType::Custom { .. }
2785                        ) || matches!(
2786                            head.entity_type,
2787                            EntityType::Organization | EntityType::Person
2788                        )
2789                    }
2790                    // Interaction relations
2791                    "MET_WITH" | "SPOKE_WITH" => {
2792                        matches!(
2793                            head.entity_type,
2794                            EntityType::Other(_) | EntityType::Custom { .. }
2795                        ) || matches!(
2796                            tail.entity_type,
2797                            EntityType::Other(_) | EntityType::Custom { .. }
2798                        ) || (matches!(head.entity_type, EntityType::Person)
2799                            && matches!(
2800                                tail.entity_type,
2801                                EntityType::Person | EntityType::Organization
2802                            ))
2803                    }
2804                    // Ownership
2805                    "OWNS" => {
2806                        matches!(
2807                            head.entity_type,
2808                            EntityType::Other(_) | EntityType::Custom { .. }
2809                        ) || matches!(
2810                            head.entity_type,
2811                            EntityType::Person | EntityType::Organization
2812                        )
2813                    }
2814                    _ => true, // Default: allow any combination
2815                };
2816
2817                if valid {
2818                    return Some((
2819                        label.slug.as_str(),
2820                        pattern.confidence,
2821                        Some((pos, pos + trigger.len())),
2822                    ));
2823                }
2824            }
2825        }
2826    }
2827
2828    None
2829}
2830
2831// =============================================================================
2832// Binary Embeddings for Fast Blocking (Research: Hamming Distance)
2833// =============================================================================
2834
/// Binary hash for fast approximate nearest neighbor search.
///
/// # Research Background
///
/// Binary embeddings enable sub-linear search via Hamming distance. Key insight
/// from our research synthesis: **binary embeddings are for blocking, not primary
/// retrieval**. The sign-rank limitation means they cannot represent all similarity
/// relationships, but they excel at fast candidate filtering.
///
/// # Two-Stage Retrieval Pattern
///
/// ```text
/// Query → [Binary Hash] → Hamming Filter (fast) → Candidates
///                                                      ↓
///                                              [Dense Similarity]
///                                                      ↓
///                                               Final Results
/// ```
///
/// # Example
///
/// ```rust
/// use anno::backends::inference::BinaryHash;
///
/// // Create hashes from embeddings
/// let hash1 = BinaryHash::from_embedding(&[0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8]);
/// let hash2 = BinaryHash::from_embedding(&[0.15, -0.25, 0.35, -0.45, 0.55, -0.65, 0.75, -0.85]);
///
/// // Similar embeddings → low Hamming distance
/// assert!(hash1.hamming_distance(&hash2) < 2);
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BinaryHash {
    /// Packed bits (each u64 holds 64 bits).
    ///
    /// Embedding dimension `i` is stored at bit `i % 64` of word `i / 64`.
    pub bits: Vec<u64>,
    /// Original dimension (number of bits)
    pub dim: usize,
}

impl BinaryHash {
    /// Pack `dim` sign flags into 64-bit words.
    ///
    /// Shared by the `f32` and `f64` constructors so both use the exact same
    /// bit layout. `is_positive` is expected to yield at most `dim` items; both
    /// callers pass an iterator over a slice of length `dim`.
    fn pack_signs(dim: usize, is_positive: impl Iterator<Item = bool>) -> Self {
        let mut bits = vec![0u64; dim.div_ceil(64)];

        for (i, positive) in is_positive.enumerate() {
            if positive {
                bits[i / 64] |= 1u64 << (i % 64);
            }
        }

        Self { bits, dim }
    }

    /// Create from a dense embedding using sign function.
    ///
    /// Each positive value → 1, each negative/zero value → 0.
    /// `NaN` also maps to 0 because `NaN > 0.0` is false.
    #[must_use]
    pub fn from_embedding(embedding: &[f32]) -> Self {
        Self::pack_signs(embedding.len(), embedding.iter().map(|&val| val > 0.0))
    }

    /// Create from a dense f64 embedding.
    ///
    /// Uses the same rule and bit layout as [`BinaryHash::from_embedding`]:
    /// strictly positive values set the bit, everything else leaves it clear.
    #[must_use]
    pub fn from_embedding_f64(embedding: &[f64]) -> Self {
        Self::pack_signs(embedding.len(), embedding.iter().map(|&val| val > 0.0))
    }

    /// Compute Hamming distance (number of differing bits).
    ///
    /// Uses POPCNT instruction when available for hardware acceleration.
    ///
    /// Words are compared pairwise; if the two hashes hold a different number
    /// of words, only the overlapping prefix is compared and the remaining
    /// words of the longer hash are ignored.
    #[must_use]
    pub fn hamming_distance(&self, other: &Self) -> u32 {
        self.bits
            .iter()
            .zip(other.bits.iter())
            .map(|(a, b)| (a ^ b).count_ones())
            .sum()
    }

    /// Compute normalized Hamming distance (0.0 to 1.0).
    ///
    /// The raw distance is divided by `self.dim`. A zero-dimensional hash
    /// yields `0.0` rather than dividing by zero.
    #[must_use]
    pub fn hamming_distance_normalized(&self, other: &Self) -> f64 {
        if self.dim == 0 {
            return 0.0;
        }
        self.hamming_distance(other) as f64 / self.dim as f64
    }

    /// Convert Hamming distance to approximate cosine similarity.
    ///
    /// Based on the relationship: cos(θ) ≈ 1 - 2 * (hamming_distance / dim)
    /// This is an approximation valid for random hyperplane hashing.
    #[must_use]
    pub fn approximate_cosine(&self, other: &Self) -> f64 {
        1.0 - 2.0 * self.hamming_distance_normalized(other)
    }
}
2943
2944/// Blocker using binary embeddings for fast candidate filtering.
2945///
2946/// # Usage Pattern
2947///
2948/// 1. Pre-compute binary hashes for all entities in your KB
2949/// 2. At query time, hash the query embedding
2950/// 3. Find candidates within Hamming distance threshold
2951/// 4. Run dense similarity only on candidates
2952///
2953/// # Example
2954///
2955/// ```rust
2956/// use anno::backends::inference::{BinaryBlocker, BinaryHash};
2957///
2958/// let mut blocker = BinaryBlocker::new(8); // 8-bit Hamming threshold
2959///
2960/// // Add entities to the index
2961/// let hash1 = BinaryHash::from_embedding(&vec![0.1; 768]);
2962/// let hash2 = BinaryHash::from_embedding(&vec![-0.1; 768]);
2963/// blocker.add(0, hash1);
2964/// blocker.add(1, hash2);
2965///
2966/// // Query
2967/// let query = BinaryHash::from_embedding(&vec![0.1; 768]);
2968/// let candidates = blocker.query(&query);
2969/// assert!(candidates.contains(&0)); // Similar to hash1
2970/// ```
2971#[derive(Debug, Clone)]
2972pub struct BinaryBlocker {
2973    /// Hamming distance threshold for candidates
2974    pub threshold: u32,
2975    /// Index of hashes by ID
2976    index: Vec<(usize, BinaryHash)>,
2977}
2978
2979impl BinaryBlocker {
2980    /// Create a new blocker with the given threshold.
2981    #[must_use]
2982    pub fn new(threshold: u32) -> Self {
2983        Self {
2984            threshold,
2985            index: Vec::new(),
2986        }
2987    }
2988
2989    /// Add an entity to the index.
2990    pub fn add(&mut self, id: usize, hash: BinaryHash) {
2991        self.index.push((id, hash));
2992    }
2993
2994    /// Add multiple entities.
2995    pub fn add_batch(&mut self, entries: impl IntoIterator<Item = (usize, BinaryHash)>) {
2996        self.index.extend(entries);
2997    }
2998
2999    /// Find candidate IDs within Hamming distance threshold.
3000    #[must_use]
3001    pub fn query(&self, query: &BinaryHash) -> Vec<usize> {
3002        self.index
3003            .iter()
3004            .filter(|(_, hash)| hash.hamming_distance(query) <= self.threshold)
3005            .map(|(id, _)| *id)
3006            .collect()
3007    }
3008
3009    /// Find candidates with their distances.
3010    #[must_use]
3011    pub fn query_with_distance(&self, query: &BinaryHash) -> Vec<(usize, u32)> {
3012        self.index
3013            .iter()
3014            .map(|(id, hash)| (*id, hash.hamming_distance(query)))
3015            .filter(|(_, dist)| *dist <= self.threshold)
3016            .collect()
3017    }
3018
3019    /// Number of entries in the index.
3020    #[must_use]
3021    pub fn len(&self) -> usize {
3022        self.index.len()
3023    }
3024
3025    /// Check if index is empty.
3026    #[must_use]
3027    pub fn is_empty(&self) -> bool {
3028        self.index.is_empty()
3029    }
3030
3031    /// Clear the index.
3032    pub fn clear(&mut self) {
3033        self.index.clear();
3034    }
3035}
3036
3037/// Recommended two-stage retrieval using binary blocking + dense reranking.
3038///
3039/// # Research Context
3040///
3041/// This implements the pattern identified in our research synthesis:
3042/// - Stage 1: Binary blocking for O(n) candidate filtering
3043/// - Stage 2: Dense similarity for accurate ranking
3044///
3045/// The key insight is that binary embeddings have fundamental limitations
3046/// (sign-rank theorem) but excel at fast filtering.
3047///
3048/// # Arguments
3049///
3050/// * `query_embedding` - Dense query embedding
3051/// * `candidate_embeddings` - Dense embeddings of all candidates
3052/// * `binary_threshold` - Hamming distance threshold for blocking
3053/// * `top_k` - Number of final results to return
3054///
3055/// # Returns
3056///
3057/// Vector of (candidate_index, similarity_score) pairs, sorted by score descending.
3058#[must_use]
3059pub fn two_stage_retrieval(
3060    query_embedding: &[f32],
3061    candidate_embeddings: &[Vec<f32>],
3062    binary_threshold: u32,
3063    top_k: usize,
3064) -> Vec<(usize, f32)> {
3065    // Stage 1: Binary blocking
3066    let query_hash = BinaryHash::from_embedding(query_embedding);
3067
3068    let candidate_hashes: Vec<BinaryHash> = candidate_embeddings
3069        .iter()
3070        .map(|e| BinaryHash::from_embedding(e))
3071        .collect();
3072
3073    let mut blocker = BinaryBlocker::new(binary_threshold);
3074    for (i, hash) in candidate_hashes.into_iter().enumerate() {
3075        blocker.add(i, hash);
3076    }
3077
3078    let candidates = blocker.query(&query_hash);
3079
3080    // Stage 2: Dense similarity on candidates only
3081    // Performance: Pre-allocate scored vec with known size
3082    let mut scored: Vec<(usize, f32)> = Vec::with_capacity(candidates.len());
3083    scored.extend(candidates.into_iter().map(|idx| {
3084        let sim = cosine_similarity_f32(query_embedding, &candidate_embeddings[idx]);
3085        (idx, sim)
3086    }));
3087
3088    // Performance: Use unstable sort (we don't need stable sort here)
3089    // Sort by similarity descending
3090    scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
3091    scored.truncate(top_k);
3092    scored
3093}
3094
/// Cosine of the angle between two `f32` vectors.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when either vector has zero Euclidean norm; otherwise returns
/// `dot(a, b) / (‖a‖ · ‖b‖)`.
#[must_use]
pub fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();

    // Vectors of different (or zero) length have no meaningful angle.
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let (len_a, len_b) = (magnitude(a), magnitude(b));
    // Guard the division below against a zero denominator.
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }

    let inner: f32 = a.iter().zip(b).map(|(p, q)| p * q).sum();
    inner / (len_a * len_b)
}
3112
3113// =============================================================================
3114// Tests
3115// =============================================================================
3116
#[cfg(test)]
mod tests {
    use super::*;

    // The builder should count entity and relation labels separately while
    // `len()` reports the combined total.
    #[test]
    fn test_semantic_registry_builder() {
        let registry = SemanticRegistry::builder()
            .add_entity("person", "A human being")
            .add_entity("organization", "A company or group")
            .add_relation("WORKS_FOR", "Employment relationship")
            .build_placeholder(768);

        assert_eq!(registry.len(), 3);
        assert_eq!(registry.entity_labels().count(), 2);
        assert_eq!(registry.relation_labels().count(), 1);
    }

    // The stock NER registry must at least know "person" and "organization".
    #[test]
    fn test_standard_ner_registry() {
        let registry = SemanticRegistry::standard_ner(768);
        assert!(registry.len() >= 5);
        assert!(registry.label_index.contains_key("person"));
        assert!(registry.label_index.contains_key("organization"));
    }

    // One-hot span/label embeddings make the expected score matrix easy to read:
    // the output is indexed as `span * num_labels + label`.
    #[test]
    fn test_dot_product_interaction() {
        let interaction = DotProductInteraction::new();

        // 2 spans, 3 labels, hidden_dim=4
        let span_embs = vec![
            1.0, 0.0, 0.0, 0.0, // span 0
            0.0, 1.0, 0.0, 0.0, // span 1
        ];
        let label_embs = vec![
            1.0, 0.0, 0.0, 0.0, // label 0 (matches span 0)
            0.0, 1.0, 0.0, 0.0, // label 1 (matches span 1)
            0.5, 0.5, 0.0, 0.0, // label 2 (partial match both)
        ];

        let scores = interaction.compute_similarity(&span_embs, 2, &label_embs, 3, 4);

        assert_eq!(scores.len(), 6); // 2 * 3
        assert!((scores[0] - 1.0).abs() < 0.01); // span0 vs label0
        assert!((scores[4] - 1.0).abs() < 0.01); // span1 vs label1
    }

    // Identical, orthogonal and opposite vectors cover the +1 / 0 / -1 cases.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &d) - (-1.0)).abs() < 0.001);
    }

    // "Curie" should be clustered with "Marie Curie" even though the
    // embeddings are all-zero placeholders, and the longer mention should be
    // chosen as the canonical name.
    #[test]
    fn test_coreference_string_match() {
        let entities = vec![
            Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95),
            Entity::new("Curie", EntityType::Person, 50, 55, 0.90),
        ];

        let embeddings = vec![0.0f32; 2 * 768]; // Placeholder
        let clusters =
            resolve_coreferences(&entities, &embeddings, 768, &CoreferenceConfig::default());

        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].members.len(), 2);
        assert_eq!(clusters[0].canonical_name, "Marie Curie");
    }

    // Dense scores are laid out as [token_i][token_j][label]; four entries in
    // the upper triangle reach the 0.5 threshold.
    #[test]
    fn test_handshaking_matrix() {
        // 3 tokens, 2 labels, threshold 0.5
        let scores = vec![
            // token 0 with tokens 0,1,2 for labels 0,1
            0.9, 0.1, // (0,0)
            0.2, 0.8, // (0,1)
            0.1, 0.1, // (0,2)
            // token 1 with tokens 0,1,2
            0.0, 0.0, // (1,0) - skipped (lower triangle)
            0.7, 0.2, // (1,1)
            0.3, 0.6, // (1,2)
            // token 2
            0.0, 0.0, // (2,0)
            0.0, 0.0, // (2,1)
            0.1, 0.1, // (2,2)
        ];

        let matrix = HandshakingMatrix::from_dense(&scores, 3, 2, 0.5);

        // Should have cells for scores >= 0.5
        assert!(matrix.cells.len() >= 4);
    }

    // "founded" between a Person and an Organization should yield FOUNDED.
    #[test]
    fn test_relation_extraction() {
        let text = "Steve Jobs founded Apple Inc in 1976";

        // Spans must line up with `text`: "Steve Jobs" is 0..10 and "Apple" is
        // 19..24 (the text is pure ASCII, so byte and char offsets coincide).
        let entities = vec![
            Entity::new("Steve Jobs", EntityType::Person, 0, 10, 0.95),
            Entity::new("Apple", EntityType::Organization, 19, 24, 0.90),
        ];

        let registry = SemanticRegistry::builder()
            .add_relation("FOUNDED", "Founded an organization")
            .build_placeholder(768);

        let config = RelationExtractionConfig::default();
        let relations = extract_relations(&entities, text, &registry, &config);

        assert!(!relations.is_empty());
        assert_eq!(relations[0].relation_type, "FOUNDED");
    }

    // Same scenario as above, but with a multi-byte prefix so that a mix-up
    // between byte and character offsets would show up in the trigger span.
    #[test]
    fn test_relation_extraction_uses_character_offsets_with_unicode_prefix() {
        // Unicode prefix ensures byte offsets != character offsets.
        let text = "👋 Steve Jobs founded Apple Inc.";

        // Compute character offsets explicitly (Entity spans are char-based).
        let steve_start = text.find("Steve Jobs").expect("substring present");
        // `find` returns byte offset; convert to char offset.
        let conv = crate::offset::SpanConverter::new(text);
        let steve_start_char = conv.byte_to_char(steve_start);
        let steve_end_char = steve_start_char + "Steve Jobs".chars().count();

        let apple_start = text.find("Apple").expect("substring present");
        let apple_start_char = conv.byte_to_char(apple_start);
        let apple_end_char = apple_start_char + "Apple".chars().count();

        let entities = vec![
            Entity::new(
                "Steve Jobs",
                EntityType::Person,
                steve_start_char,
                steve_end_char,
                0.95,
            ),
            Entity::new(
                "Apple",
                EntityType::Organization,
                apple_start_char,
                apple_end_char,
                0.90,
            ),
        ];

        let registry = SemanticRegistry::builder()
            .add_relation("FOUNDED", "Founded an organization")
            .build_placeholder(768);

        let config = RelationExtractionConfig::default();
        let relations = extract_relations(&entities, text, &registry, &config);

        assert!(
            !relations.is_empty(),
            "Expected FOUNDED relation to be detected"
        );
        assert_eq!(relations[0].relation_type, "FOUNDED");

        // Trigger span should exist and cover "founded" in character offsets.
        let trigger = relations[0]
            .trigger_span
            .expect("expected trigger_span to be present");
        let trigger_text: String = text
            .chars()
            .skip(trigger.0)
            .take(trigger.1.saturating_sub(trigger.0))
            .collect();
        assert_eq!(trigger_text.to_ascii_lowercase(), "founded");
    }

    // =========================================================================
    // Binary Embedding Tests
    // =========================================================================

    // Sign bits are packed least-significant-bit first.
    #[test]
    fn test_binary_hash_creation() {
        let embedding = vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8];
        let hash = BinaryHash::from_embedding(&embedding);

        assert_eq!(hash.dim, 8);
        // Positive values at indices 0, 2, 4, 6 should be set
        // bits[0] should have bits 0, 2, 4, 6 set = 0b01010101 = 85
        assert_eq!(hash.bits[0], 85);
    }

    // Hashing the same embedding twice must give distance zero.
    #[test]
    fn test_hamming_distance_identical() {
        let embedding = vec![0.1; 64];
        let hash1 = BinaryHash::from_embedding(&embedding);
        let hash2 = BinaryHash::from_embedding(&embedding);

        assert_eq!(hash1.hamming_distance(&hash2), 0);
    }

    // Flipping every sign flips every bit.
    #[test]
    fn test_hamming_distance_opposite() {
        let embedding1 = vec![0.1; 64];
        let embedding2 = vec![-0.1; 64];
        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        assert_eq!(hash1.hamming_distance(&hash2), 64);
    }

    // Flipping half of the signs flips exactly half of the bits.
    #[test]
    fn test_hamming_distance_half() {
        let embedding1 = vec![0.1; 64];
        let mut embedding2 = vec![0.1; 64];
        // Flip second half
        embedding2[32..64].iter_mut().for_each(|x| *x = -0.1);

        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        assert_eq!(hash1.hamming_distance(&hash2), 32);
    }

    // With a threshold of 5 bits, a 2-bit neighbour is kept and the fully
    // inverted hash is rejected.
    #[test]
    fn test_binary_blocker() {
        let mut blocker = BinaryBlocker::new(5);

        // Add some hashes
        let base_embedding = vec![0.1; 64];
        let similar_embedding = {
            let mut e = vec![0.1; 64];
            e[0] = -0.1; // Flip 1 bit
            e[1] = -0.1; // Flip 2 bits
            e
        };
        let different_embedding = vec![-0.1; 64];

        blocker.add(0, BinaryHash::from_embedding(&base_embedding));
        blocker.add(1, BinaryHash::from_embedding(&similar_embedding));
        blocker.add(2, BinaryHash::from_embedding(&different_embedding));

        // Query with base
        let query = BinaryHash::from_embedding(&base_embedding);
        let candidates = blocker.query(&query);

        assert!(candidates.contains(&0), "Should find exact match");
        assert!(
            candidates.contains(&1),
            "Should find similar (2 bits different)"
        );
        assert!(
            !candidates.contains(&2),
            "Should NOT find opposite (64 bits different)"
        );
    }

    // The threshold (4) equals the embedding dimension, so blocking lets every
    // candidate through and the ranking is decided by dense similarity alone.
    #[test]
    fn test_two_stage_retrieval() {
        // Create embeddings
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0, 0.0],  // Identical
            vec![0.9, 0.1, 0.0, 0.0],  // Similar
            vec![-1.0, 0.0, 0.0, 0.0], // Opposite
            vec![0.0, 1.0, 0.0, 0.0],  // Orthogonal
        ];

        // Generous threshold to get candidates
        let results = two_stage_retrieval(&query, &candidates, 4, 2);

        assert!(!results.is_empty());
        // First result should be exact match
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 0.001);
    }

    // The Hamming-based cosine estimate should hit both extremes exactly.
    #[test]
    fn test_approximate_cosine() {
        let embedding1 = vec![0.1; 768];
        let embedding2 = vec![0.1; 768];
        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        // Identical → approximate cosine should be ~1.0
        let approx = hash1.approximate_cosine(&hash2);
        assert!((approx - 1.0).abs() < 0.001);

        // Opposite → approximate cosine should be ~-1.0
        let embedding3 = vec![-0.1; 768];
        let hash3 = BinaryHash::from_embedding(&embedding3);
        let approx_opp = hash1.approximate_cosine(&hash3);
        assert!((approx_opp - (-1.0)).abs() < 0.001);
    }
}