// anno/backends/inference.rs
1//! Inference abstractions shared across `anno` backends.
2//!
3//! This module is mostly **plumbing**: common traits, data shapes, and small
4//! utilities used by multiple NER / IE backends (including “fixed-label” and
5//! “open/zero-shot” styles).
6//!
7//! Some of the terminology and design choices correspond to well-known
8//! architectures in the NER/IE literature, but the code here should be treated
9//! as an implementation substrate, not a verbatim reproduction of any single
10//! paper’s experiment section.
11//!
12//! ## Paper pointers (context only)
13//!
14//! - GLiNER: arXiv:2311.08526
15//! - UniversalNER: arXiv:2308.03279
16//! - W2NER: arXiv:2112.10070
17//! - ModernBERT: arXiv:2412.13663
18
19use std::borrow::Cow;
20use std::collections::HashMap;
21
22use crate::{Entity, EntityType};
23use anno_core::{RaggedBatch, Relation, SpanCandidate};
24
25// =============================================================================
26// Modality Types
27// =============================================================================
28
/// Input modality for the encoder.
///
/// Supports text, images, and hybrid (OCR + visual) inputs.
/// This enables ColPali-style visual document understanding.
///
/// Borrowed data is held via [`Cow`], so callers may pass either borrowed
/// slices or owned buffers without forcing an extra copy.
#[derive(Debug, Clone)]
pub enum ModalityInput<'a> {
    /// Plain text input
    Text(Cow<'a, str>),
    /// Image bytes (PNG/JPEG)
    Image {
        /// Raw, still-encoded image bytes (not decoded pixels)
        data: Cow<'a, [u8]>,
        /// Image format hint used when decoding `data`
        format: ImageFormat,
    },
    /// Hybrid: text with visual location (e.g., OCR result)
    Hybrid {
        /// Extracted text
        text: Cow<'a, str>,
        /// Visual bounding boxes for each token/word of `text`
        visual_positions: Vec<VisualPosition>,
    },
}
52
/// Image format hint for decoding.
///
/// This is only a *hint*: `Unknown` signals the decoder should auto-detect
/// from the byte content. Defaults to [`ImageFormat::Png`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ImageFormat {
    /// PNG format (the default)
    #[default]
    Png,
    /// JPEG format
    Jpeg,
    /// WebP format
    Webp,
    /// Unknown/auto-detect from the byte stream
    Unknown,
}
66
/// Visual position of a text token in an image.
///
/// Coordinates and sizes are normalized to `0.0..=1.0` relative to the page
/// dimensions, so they are resolution-independent.
/// NOTE(review): origin is presumably the top-left corner with `y` growing
/// downward (the common image convention) — confirm against the OCR producer.
#[derive(Debug, Clone, Copy)]
pub struct VisualPosition {
    /// Token/word index (position within the extracted token sequence)
    pub token_idx: u32,
    /// Normalized x coordinate (0.0-1.0)
    pub x: f32,
    /// Normalized y coordinate (0.0-1.0)
    pub y: f32,
    /// Normalized width (0.0-1.0)
    pub width: f32,
    /// Normalized height (0.0-1.0)
    pub height: f32,
    /// Page number (for multi-page documents)
    pub page: u32,
}
83
84// =============================================================================
85// Semantic Registry (Pre-computed Label Embeddings)
86// =============================================================================
87
88/// A frozen, pre-computed registry of entity and relation types.
89///
90/// # Motivation
91///
92/// The `SemanticRegistry` is the "knowledge base" of a bi-encoder NER system.
93/// It stores pre-computed embeddings for all entity/relation types, enabling:
94///
95/// - **Zero-shot**: Add new types without retraining
96/// - **Speed**: Encode labels once, reuse forever
97/// - **Semantics**: Rich descriptions enable better matching
98///
99/// # Architecture
100///
101/// ```text
102/// ┌────────────────────────────────────────────────────────────────┐
103/// │ SemanticRegistry │
104/// ├────────────────────────────────────────────────────────────────┤
105/// │ labels: [ │
106/// │ { slug: "person", description: "named individual human" } │
107/// │ { slug: "organization", description: "company or group" } │
108/// │ { slug: "CEO_OF", description: "leads organization" } │
109/// │ ] │
110/// │ │
111/// │ embeddings: [768 floats] [768 floats] [768 floats] │
112/// │ └────┬────┘ └────┬────┘ └────┬────┘ │
113/// │ ▲ ▲ ▲ │
114/// │ person organization CEO_OF │
115/// │ │
116/// │ label_index: { "person" → 0, "organization" → 1, ... } │
117/// └────────────────────────────────────────────────────────────────┘
118/// ```
119///
120/// # Bi-Encoder Efficiency
121///
122/// The key insight from GLiNER is that label embeddings can be computed once
123/// and reused across all inference requests:
124///
125/// | Approach | Cost per query | Benefit |
126/// |----------|----------------|---------|
127/// | Cross-encoder | O(N × L) | Better accuracy |
128/// | Bi-encoder | O(N) + O(L) | Much faster, labels cached |
129///
130/// # Example
131///
132/// ```ignore
133/// use anno::SemanticRegistry;
134///
135/// // Build registry (expensive, do once at startup)
136/// let registry = SemanticRegistry::builder()
137/// .add_entity("person", "A named individual human being")
138/// .add_entity("organization", "A company, institution, or organized group")
139/// .add_relation("CEO_OF", "Chief executive officer of an organization")
140/// .build(&label_encoder)?;
141///
142/// // Use registry for all inference (cheap, cached embeddings)
143/// for document in documents {
144/// let entities = engine.extract(&document, ®istry)?;
145/// }
146/// ```
147///
148/// # Adding Custom Types
149///
150/// ```ignore
151/// // Domain-specific medical entities
152/// let medical_registry = SemanticRegistry::builder()
153/// .add_entity("drug", "A pharmaceutical compound or medication")
154/// .add_entity("disease", "A medical condition or illness")
155/// .add_entity("gene", "A genetic sequence encoding a protein")
156/// .add_relation("TREATS", "Drug is used to treat disease")
157/// .add_relation("CAUSES", "Factor causes or leads to condition")
158/// .build(&label_encoder)?;
159/// ```
#[derive(Debug, Clone)]
pub struct SemanticRegistry {
    /// Pre-computed embeddings for all labels.
    /// Shape: [num_labels, hidden_dim], flattened row-major so row `i`
    /// occupies `embeddings[i * hidden_dim .. (i + 1) * hidden_dim]`.
    /// Stored as flattened f32 for simplicity without tensor deps.
    pub embeddings: Vec<f32>,
    /// Hidden dimension of embeddings (row width of `embeddings`)
    pub hidden_dim: usize,
    /// Metadata for each label (index corresponds to embedding row)
    pub labels: Vec<LabelDefinition>,
    /// Index mapping from label slug to embedding row.
    /// Invariant: values are valid indices into `labels` and `embeddings` rows;
    /// constructors in this module uphold `embeddings.len() == labels.len() * hidden_dim`.
    pub label_index: HashMap<String, usize>,
}
173
/// Definition of a semantic label (entity type or relation type).
///
/// The `description` — not the `slug` — is what gets encoded into an
/// embedding, so richer descriptions generally yield better matching.
#[derive(Debug, Clone)]
pub struct LabelDefinition {
    /// Unique identifier (e.g., "person", "CEO_OF")
    pub slug: String,
    /// Human-readable description (used for encoding)
    pub description: String,
    /// Category: Entity, Relation, or Attribute
    pub category: LabelCategory,
    /// Expected source modality
    pub modality: ModalityHint,
    /// Minimum confidence threshold for this label (expected range 0.0-1.0)
    pub threshold: f32,
}
188
/// Category of semantic label.
///
/// Used to partition a [`SemanticRegistry`]'s labels (see
/// `entity_labels()` / `relation_labels()` accessors).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LabelCategory {
    /// Named entity (Person, Organization, Location, etc.)
    Entity,
    /// Relation between entities (CEO_OF, LOCATED_IN, etc.)
    Relation,
    /// Attribute of an entity (date of birth, revenue, etc.)
    Attribute,
}
199
/// Hint for which modality this label applies to.
///
/// Defaults to [`ModalityHint::TextOnly`], which fits most entity types.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ModalityHint {
    /// Text-only (most entity types; the default)
    #[default]
    TextOnly,
    /// Visual-only (e.g., logos, signatures)
    VisualOnly,
    /// Works with both text and visual
    Any,
}
211
212impl SemanticRegistry {
213 /// Create a builder for constructing a registry.
214 pub fn builder() -> SemanticRegistryBuilder {
215 SemanticRegistryBuilder::new()
216 }
217
218 /// Get number of labels in the registry.
219 pub fn len(&self) -> usize {
220 self.labels.len()
221 }
222
223 /// Check if registry is empty.
224 pub fn is_empty(&self) -> bool {
225 self.labels.is_empty()
226 }
227
228 /// Get embedding for a label by slug.
229 pub fn get_embedding(&self, slug: &str) -> Option<&[f32]> {
230 let idx = self.label_index.get(slug)?;
231 let start = idx * self.hidden_dim;
232 let end = start + self.hidden_dim;
233 if end <= self.embeddings.len() {
234 Some(&self.embeddings[start..end])
235 } else {
236 None
237 }
238 }
239
240 /// Get all entity labels (excluding relations).
241 pub fn entity_labels(&self) -> impl Iterator<Item = &LabelDefinition> {
242 self.labels
243 .iter()
244 .filter(|l| l.category == LabelCategory::Entity)
245 }
246
247 /// Get all relation labels.
248 pub fn relation_labels(&self) -> impl Iterator<Item = &LabelDefinition> {
249 self.labels
250 .iter()
251 .filter(|l| l.category == LabelCategory::Relation)
252 }
253
254 /// Create a standard NER registry with common entity types.
255 pub fn standard_ner(hidden_dim: usize) -> Self {
256 // Placeholder embeddings - in real use, these would be encoder outputs
257 let labels = vec![
258 LabelDefinition {
259 slug: "person".into(),
260 description: "A named individual human being".into(),
261 category: LabelCategory::Entity,
262 modality: ModalityHint::TextOnly,
263 threshold: 0.5,
264 },
265 LabelDefinition {
266 slug: "organization".into(),
267 description: "A company, institution, agency, or other group".into(),
268 category: LabelCategory::Entity,
269 modality: ModalityHint::TextOnly,
270 threshold: 0.5,
271 },
272 LabelDefinition {
273 slug: "location".into(),
274 description: "A geographical place, city, country, or region".into(),
275 category: LabelCategory::Entity,
276 modality: ModalityHint::TextOnly,
277 threshold: 0.5,
278 },
279 LabelDefinition {
280 slug: "date".into(),
281 description: "A calendar date or time expression".into(),
282 category: LabelCategory::Entity,
283 modality: ModalityHint::TextOnly,
284 threshold: 0.5,
285 },
286 LabelDefinition {
287 slug: "money".into(),
288 description: "A monetary amount with currency".into(),
289 category: LabelCategory::Entity,
290 modality: ModalityHint::TextOnly,
291 threshold: 0.5,
292 },
293 ];
294
295 let num_labels = labels.len();
296 let label_index: HashMap<String, usize> = labels
297 .iter()
298 .enumerate()
299 .map(|(i, l)| (l.slug.clone(), i))
300 .collect();
301
302 // Initialize with zeros (placeholder)
303 let embeddings = vec![0.0f32; num_labels * hidden_dim];
304
305 Self {
306 embeddings,
307 hidden_dim,
308 labels,
309 label_index,
310 }
311 }
312}
313
/// Builder for [`SemanticRegistry`].
///
/// Collects label definitions; embeddings are produced at build time
/// (placeholder zeros via `build_placeholder`, or a real encoder pass).
#[derive(Debug, Default)]
pub struct SemanticRegistryBuilder {
    // Labels accumulated in insertion order; insertion order determines
    // the embedding row index in the built registry.
    labels: Vec<LabelDefinition>,
}
319
320impl SemanticRegistryBuilder {
321 /// Create a new builder.
322 pub fn new() -> Self {
323 Self::default()
324 }
325
326 /// Add an entity type.
327 pub fn add_entity(mut self, slug: &str, description: &str) -> Self {
328 self.labels.push(LabelDefinition {
329 slug: slug.into(),
330 description: description.into(),
331 category: LabelCategory::Entity,
332 modality: ModalityHint::TextOnly,
333 threshold: 0.5,
334 });
335 self
336 }
337
338 /// Add a relation type.
339 pub fn add_relation(mut self, slug: &str, description: &str) -> Self {
340 self.labels.push(LabelDefinition {
341 slug: slug.into(),
342 description: description.into(),
343 category: LabelCategory::Relation,
344 modality: ModalityHint::TextOnly,
345 threshold: 0.5,
346 });
347 self
348 }
349
350 /// Add a label with full configuration.
351 pub fn add_label(mut self, label: LabelDefinition) -> Self {
352 self.labels.push(label);
353 self
354 }
355
356 /// Build the registry (placeholder - real impl needs encoder).
357 pub fn build_placeholder(self, hidden_dim: usize) -> SemanticRegistry {
358 let num_labels = self.labels.len();
359 let label_index: HashMap<String, usize> = self
360 .labels
361 .iter()
362 .enumerate()
363 .map(|(i, l)| (l.slug.clone(), i))
364 .collect();
365
366 SemanticRegistry {
367 embeddings: vec![0.0f32; num_labels * hidden_dim],
368 hidden_dim,
369 labels: self.labels,
370 label_index,
371 }
372 }
373}
374
375// =============================================================================
376// Core Encoder Traits (GLiNER/ModernBERT Alignment)
377// =============================================================================
378
379/// Text encoder trait for transformer-based encoders.
380///
381/// # Motivation
382///
383/// Modern NER systems require converting raw text into dense vector representations
384/// that capture semantic meaning. This trait abstracts the encoding step, allowing
385/// different transformer architectures to be used interchangeably.
386///
387/// # Supported Architectures
388///
389/// | Architecture | Context | Key Features | Speed |
390/// |--------------|---------|--------------|-------|
391/// | ModernBERT | 8,192 | RoPE, GeGLU, unpadded inference | 3x faster |
392/// | DeBERTaV3 | 512 | Disentangled attention | Baseline |
393/// | BERT/RoBERTa | 512 | Classic, widely available | Baseline |
394///
395/// # Research Alignment (ModernBERT, Dec 2024)
396///
397/// From ModernBERT paper (arXiv:2412.13663):
398/// > "Pareto improvements to BERT... encoder-only models offer great
399/// > performance-size tradeoff for retrieval and classification."
400///
401/// Key innovations:
402/// - **Alternating Attention**: Global attention every 3 layers, local (128-token
403/// window) elsewhere. Reduces complexity for long sequences.
404/// - **Unpadding**: "ModernBERT unpads inputs *before* the token embedding layer
405/// and optionally repads model outputs leading to a 10-to-20 percent
406/// performance improvement over previous methods."
407/// - **RoPE**: Rotary positional embeddings enable extrapolation to longer sequences.
408/// - **GeGLU**: Gated activation function improves over GELU.
409///
410/// # Example
411///
412/// ```ignore
413/// use anno::TextEncoder;
414///
415/// fn process_document(encoder: &dyn TextEncoder, text: &str) {
416/// let output = encoder.encode(text).unwrap();
417/// println!("Encoded {} tokens into {} dimensions",
418/// output.num_tokens, output.hidden_dim);
419///
420/// // Token offsets map back to character positions
421/// for (i, (start, end)) in output.token_offsets.iter().enumerate() {
422/// println!("Token {}: chars {}..{}", i, start, end);
423/// }
424/// }
425/// ```
pub trait TextEncoder: Send + Sync {
    /// Encode text into token embeddings.
    ///
    /// # Arguments
    /// * `text` - Input text to encode
    ///
    /// # Returns
    /// An [`EncoderOutput`] whose embeddings are flattened row-major as
    /// `[num_tokens, hidden_dim]`, together with token-to-character offsets
    /// for span recovery. (Note: `EncoderOutput` carries no attention mask;
    /// only valid tokens are expected in the output.)
    fn encode(&self, text: &str) -> crate::Result<EncoderOutput>;

    /// Encode a batch of texts.
    ///
    /// # Arguments
    /// * `texts` - Batch of input texts
    ///
    /// # Returns
    /// A tuple of (flattened embeddings for all documents, `RaggedBatch`
    /// describing per-document boundaries within that flat buffer).
    fn encode_batch(&self, texts: &[&str]) -> crate::Result<(Vec<f32>, RaggedBatch)>;

    /// Get the hidden dimension of the encoder (row width of embeddings).
    fn hidden_dim(&self) -> usize;

    /// Get the maximum sequence length, in tokens.
    fn max_length(&self) -> usize;

    /// Get the encoder architecture name (e.g., "modernbert").
    fn architecture(&self) -> &'static str;
}
455
/// Output from text encoding.
///
/// Invariant expected by consumers: `embeddings.len() == num_tokens * hidden_dim`
/// and `token_offsets.len() == num_tokens`.
#[derive(Debug, Clone)]
pub struct EncoderOutput {
    /// Token embeddings: [num_tokens, hidden_dim], flattened row-major
    pub embeddings: Vec<f32>,
    /// Number of tokens
    pub num_tokens: usize,
    /// Hidden dimension
    pub hidden_dim: usize,
    /// Token-to-character mapping (for span recovery).
    /// Each entry is `(start, end)` into the original text; whether these are
    /// byte or char indices depends on the tokenizer — confirm per backend.
    pub token_offsets: Vec<(usize, usize)>,
}
468
469/// Label encoder trait for encoding entity type descriptions.
470///
471/// # Motivation
472///
473/// Zero-shot NER works by encoding entity type *descriptions* into the same
474/// vector space as text spans. Instead of training separate classifiers for
475/// each entity type, we compute similarity between spans and label embeddings.
476///
477/// This enables:
478/// - **Unlimited entity types** at inference (no retraining needed)
479/// - **Faster inference** when labels are pre-computed
480/// - **Better generalization** to unseen entity types via semantic similarity
481///
482/// # Research Alignment
483///
484/// From GLiNER bi-encoder (knowledgator/modern-gliner-bi-base-v1.0):
485/// > "textual encoder is ModernBERT-base and entity label encoder is
486/// > sentence transformer - BGE-small-en."
487///
488/// # Example
489///
490/// ```ignore
491/// use anno::LabelEncoder;
492///
493/// fn setup_custom_types(encoder: &dyn LabelEncoder) {
494/// // Encode rich descriptions for better matching
495/// let labels = &[
496/// "a named individual human being",
497/// "a company, institution, or organized group",
498/// "a geographical location, city, country, or region",
499/// ];
500///
501/// let embeddings = encoder.encode_labels(labels).unwrap();
502/// // Store embeddings in SemanticRegistry for fast lookup
503/// }
504/// ```
pub trait LabelEncoder: Send + Sync {
    /// Encode a single label description.
    ///
    /// # Arguments
    /// * `label` - Label description (e.g., "a named individual human being")
    ///
    /// # Returns
    /// A single embedding vector of length [`Self::hidden_dim`].
    fn encode_label(&self, label: &str) -> crate::Result<Vec<f32>>;

    /// Encode multiple labels.
    ///
    /// # Arguments
    /// * `labels` - Label descriptions
    ///
    /// # Returns
    /// Flattened embeddings: [num_labels, hidden_dim], row-major
    /// (row `i` corresponds to `labels[i]`)
    fn encode_labels(&self, labels: &[&str]) -> crate::Result<Vec<f32>>;

    /// Get the hidden dimension of the label embeddings.
    fn hidden_dim(&self) -> usize;
}
524
525/// Bi-encoder architecture combining text and label encoders.
526///
527/// # Motivation
528///
529/// The bi-encoder architecture treats NER as a **matching problem** rather than
530/// a classification problem. It encodes text spans and entity labels separately,
531/// then computes similarity scores to determine matches.
532///
533/// ```text
534/// ┌─────────────────┐ ┌─────────────────┐
535/// │ Text Input │ │ Label Desc. │
536/// │ "Steve Jobs" │ │ "person name" │
537/// └────────┬────────┘ └────────┬────────┘
538/// │ │
539/// ▼ ▼
540/// ┌─────────────────┐ ┌─────────────────┐
541/// │ TextEncoder │ │ LabelEncoder │
542/// │ (ModernBERT) │ │ (BGE-small) │
543/// └────────┬────────┘ └────────┬────────┘
544/// │ │
545/// ▼ ▼
546/// ┌─────────────────┐ ┌─────────────────┐
547/// │ Span Embedding │◄───────►│ Label Embedding │
548/// │ [768] │ cosine │ [768] │
549/// └─────────────────┘ sim └─────────────────┘
550/// │
551/// ▼
552/// Score: 0.92
553/// ```
554///
555/// # Trade-offs
556///
557/// | Aspect | Bi-Encoder | Uni-Encoder |
558/// |--------|------------|-------------|
559/// | Entity types | Unlimited | Fixed at training |
560/// | Inference speed | Faster (pre-compute labels) | Slower |
561/// | Disambiguation | Harder (no label interaction) | Better |
562/// | Generalization | Better to new types | Limited |
563///
564/// # Research Alignment
565///
566/// From GLiNER: "GLiNER frames NER as a matching problem, comparing candidate
567/// spans with entity type embeddings."
568///
569/// From knowledgator: "Bi-encoder architecture brings several advantages...
570/// unlimited entities, faster inference, better generalization."
571///
572/// Drawback: "Lack of inter-label interactions that make it hard to
573/// disambiguate semantically similar but contextually different entities."
574///
575/// # Example
576///
577/// ```ignore
578/// use anno::BiEncoder;
579///
580/// fn extract_custom_entities(bi_enc: &dyn BiEncoder, text: &str) {
581/// let labels = &["software company", "hardware manufacturer", "person"];
582/// let scores = bi_enc.encode_and_match(text, labels, 8).unwrap();
583///
584/// for s in scores.iter().filter(|s| s.score > 0.5) {
585/// println!("Found '{}' as type {} (score: {:.2})",
586/// &text[s.start..s.end], labels[s.label_idx], s.score);
587/// }
588/// }
589/// ```
pub trait BiEncoder: Send + Sync {
    /// Get the text encoder half of the bi-encoder.
    fn text_encoder(&self) -> &dyn TextEncoder;

    /// Get the label encoder half of the bi-encoder.
    ///
    /// NOTE(review): text and label encoders may have different hidden
    /// dimensions (e.g., ModernBERT + BGE-small); implementations are
    /// expected to project into a shared space before scoring — confirm.
    fn label_encoder(&self) -> &dyn LabelEncoder;

    /// Encode text and labels, compute span-label similarities.
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `labels` - Entity type descriptions
    /// * `max_span_width` - Maximum span width to consider (bounds the
    ///   candidate-span enumeration)
    ///
    /// # Returns
    /// Similarity scores for each (span, label) pair; span positions are
    /// character offsets per [`SpanLabelScore`]
    fn encode_and_match(
        &self,
        text: &str,
        labels: &[&str],
        max_span_width: usize,
    ) -> crate::Result<Vec<SpanLabelScore>>;
}
613
/// Score for a (span, label) match, as produced by [`BiEncoder::encode_and_match`].
#[derive(Debug, Clone)]
pub struct SpanLabelScore {
    /// Span start (character offset, inclusive)
    pub start: usize,
    /// Span end (character offset, exclusive)
    pub end: usize,
    /// Label index into the `labels` slice passed to the scoring call
    pub label_idx: usize,
    /// Similarity score (0.0 - 1.0, higher is a better match)
    pub score: f32,
}
626
627// =============================================================================
628// Zero-Shot NER Trait
629// =============================================================================
630
631/// Zero-shot NER for open entity types.
632///
633/// # Motivation
634///
635/// Traditional NER models are trained on fixed taxonomies (PER, ORG, LOC, etc.)
636/// and cannot extract new entity types without retraining. Zero-shot NER solves
637/// this by allowing **arbitrary entity types at inference time**.
638///
639/// Instead of asking "is this a PERSON?", zero-shot NER asks "does this text
640/// span match the description 'a named individual human being'?"
641///
642/// # Use Cases
643///
644/// - **Domain adaptation**: Extract "gene names" or "legal citations" without
645/// training data
646/// - **Custom taxonomies**: Use your own entity hierarchy
647/// - **Rapid prototyping**: Test new entity types before investing in annotation
648///
649/// # Research Alignment
650///
651/// From GLiNER (arXiv:2311.08526):
652/// > "NER model capable of identifying any entity type using a bidirectional
653/// > transformer encoder... provides a practical alternative to traditional
654/// > NER models, which are limited to predefined entity types."
655///
656/// From UniversalNER (arXiv:2308.03279):
657/// > "Large language models demonstrate remarkable generalizability, such as
658/// > understanding arbitrary entities and relations."
659///
660/// # Example
661///
662/// ```ignore
663/// use anno::ZeroShotNER;
664///
665/// fn extract_medical_entities(ner: &dyn ZeroShotNER, clinical_note: &str) {
666/// // Define custom medical entity types at runtime
667/// let types = &["drug name", "disease", "symptom", "dosage"];
668///
669/// let entities = ner.extract_with_types(clinical_note, types, 0.5).unwrap();
670/// for e in entities {
671/// println!("{}: {} (conf: {:.2})", e.entity_type, e.text, e.confidence);
672/// }
673/// }
674///
675/// fn extract_with_descriptions(ner: &dyn ZeroShotNER, text: &str) {
676/// // Even richer: use natural language descriptions
677/// let descriptions = &[
678/// "a medication or pharmaceutical compound",
679/// "a medical condition or illness",
680/// "a physical sensation indicating illness",
681/// ];
682///
683/// let entities = ner.extract_with_descriptions(text, descriptions, 0.5).unwrap();
684/// }
685/// ```
pub trait ZeroShotNER: Send + Sync {
    /// Extract entities with custom types.
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `entity_types` - Entity type descriptions (arbitrary text, not fixed vocabulary)
    ///   - Encoded as text embeddings via bi-encoder (semantic matching, not exact string match)
    ///   - Any string works: `"disease"`, `"pharmaceutical compound"`, `"19th century French philosopher"`
    ///   - **Replaces default types completely** - model only extracts the specified types
    ///   - To include defaults, pass them explicitly: `&["person", "organization", "disease"]`
    /// * `threshold` - Confidence threshold (0.0 - 1.0); matches scoring below it are dropped
    ///
    /// # Returns
    /// Entities with their matched types
    ///
    /// # Behavior
    ///
    /// - **Arbitrary text**: Type hints are not fixed vocabulary. They're encoded as embeddings,
    ///   so semantic similarity determines matches (not exact string matching).
    /// - **Replace, don't union**: This method completely replaces default entity types.
    ///   The model only extracts the types you specify.
    /// - **Semantic matching**: Uses cosine similarity between text span embeddings and label embeddings.
    fn extract_with_types(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>>;

    /// Extract entities with natural language descriptions.
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `descriptions` - Natural language descriptions of what to extract
    ///   - Encoded as text embeddings (same as `extract_with_types`)
    ///   - Examples: `"companies headquartered in Europe"`, `"diseases affecting the heart"`
    ///   - **Replaces default types completely** - model only extracts the specified descriptions
    /// * `threshold` - Confidence threshold (0.0 - 1.0)
    ///
    /// # Behavior
    ///
    /// Same as `extract_with_types`, but accepts natural language descriptions instead of
    /// short type labels. Both methods encode labels as embeddings and use semantic matching;
    /// the split exists purely for API clarity at call sites.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<Entity>>;

    /// Get default entity types for this model.
    ///
    /// Returns the entity types used by `extract_entities()` (via `Model` trait).
    /// Useful for extending defaults: combine with custom types and pass to `extract_with_types()`.
    ///
    /// # Example: Extending defaults
    ///
    /// ```ignore
    /// use anno::ZeroShotNER;
    ///
    /// let ner: &dyn ZeroShotNER = ...;
    /// let defaults = ner.default_types();
    ///
    /// // Combine defaults with custom types
    /// let mut types: Vec<&str> = defaults.to_vec();
    /// types.extend(&["disease", "medication"]);
    ///
    /// let entities = ner.extract_with_types(text, &types, 0.5)?;
    /// ```
    fn default_types(&self) -> &[&'static str];
}
757
758// =============================================================================
759// Relation Extractor Trait
760// =============================================================================
761
762/// Joint entity and relation extraction.
763///
764/// # Motivation
765///
766/// Real-world information extraction often requires both entities AND their
767/// relationships. For example, extracting "Steve Jobs" and "Apple" is useful,
768/// but knowing "Steve Jobs FOUNDED Apple" is far more valuable.
769///
770/// Joint extraction (vs pipeline) is preferred because:
771/// - **Error propagation**: Pipeline errors compound (bad entities → bad relations)
772/// - **Shared context**: Entities and relations inform each other
773/// - **Efficiency**: Single forward pass instead of two
774///
775/// # Architecture
776///
777/// ```text
778/// Input: "Steve Jobs founded Apple in 1976."
779/// │
780/// ▼
781/// ┌──────────────────────────────────┐
782/// │ Shared Encoder (BERT) │
783/// └──────────────────────────────────┘
784/// │
785/// ┌──────┴──────┐
786/// ▼ ▼
787/// ┌───────────────┐ ┌───────────────┐
788/// │ Entity Head │ │ Relation Head │
789/// │ (span class.) │ │ (pair class.) │
790/// └───────┬───────┘ └───────┬───────┘
791/// │ │
792/// ▼ ▼
793/// Entities: Relations:
794/// - Steve Jobs [PER] - (Steve Jobs, FOUNDED, Apple)
795/// - Apple [ORG] - (Apple, FOUNDED_IN, 1976)
796/// - 1976 [DATE]
797/// ```
798///
799/// # Research Alignment
800///
801/// From GLiNER multi-task (arXiv:2406.12925):
802/// > "Generalist Lightweight Model for Various Information Extraction Tasks...
803/// > joint entity and relation extraction."
804///
805/// From W2NER (arXiv:2112.10070):
806/// > "Unified Named Entity Recognition as Word-Word Relation Classification...
807/// > handles flat, overlapped, and discontinuous NER."
808///
809/// # Example
810///
811/// ```ignore
812/// use anno::RelationExtractor;
813///
814/// fn build_knowledge_graph(extractor: &dyn RelationExtractor, text: &str) {
815/// let entity_types = &["person", "organization", "date"];
816/// let relation_types = &["founded", "works_for", "acquired"];
817///
818/// let result = extractor.extract_with_relations(
819/// text, entity_types, relation_types, 0.5
820/// ).unwrap();
821///
822/// // Build graph nodes from entities
823/// for e in &result.entities {
824/// println!("Node: {} ({})", e.text, e.entity_type);
825/// }
826///
827/// // Build graph edges from relations
828/// for r in &result.relations {
829/// let head = &result.entities[r.head_idx];
830/// let tail = &result.entities[r.tail_idx];
831/// println!("Edge: {} --[{}]--> {}", head.text, r.relation_type, tail.text);
832/// }
833/// }
834/// ```
pub trait RelationExtractor: Send + Sync {
    /// Extract entities and relations jointly (single pass, not a pipeline).
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `entity_types` - Entity types to extract (arbitrary descriptions,
    ///   matched semantically as in [`ZeroShotNER`])
    /// * `relation_types` - Relation types to extract
    /// * `threshold` - Confidence threshold (0.0 - 1.0) applied to both
    ///   entities and relations
    ///
    /// # Returns
    /// Entities plus relations; each relation references entities by index
    /// into the returned `entities` vec (see [`RelationTriple`])
    fn extract_with_relations(
        &self,
        text: &str,
        entity_types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> crate::Result<ExtractionWithRelations>;
}
854
/// Output from joint entity-relation extraction.
///
/// Invariant expected by consumers: every `head_idx`/`tail_idx` in
/// `relations` is a valid index into `entities`.
#[derive(Debug, Clone, Default)]
pub struct ExtractionWithRelations {
    /// Extracted entities
    pub entities: Vec<Entity>,
    /// Relations between entities (indices into entities vec)
    pub relations: Vec<RelationTriple>,
}
863
/// A relation triple linking two entities (directed: head → relation → tail).
#[derive(Debug, Clone)]
pub struct RelationTriple {
    /// Index of head (subject) entity in the accompanying entities vec
    pub head_idx: usize,
    /// Index of tail (object) entity in the accompanying entities vec
    pub tail_idx: usize,
    /// Relation type (e.g., "FOUNDED", "CEO_OF")
    pub relation_type: String,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f32,
}
876
877// =============================================================================
878// Discontinuous Entity Support (W2NER Research)
879// =============================================================================
880
881/// Support for discontinuous entity spans.
882///
883/// # Motivation
884///
885/// Not all entities are contiguous text spans. In coordination structures,
886/// entities can be **discontinuous** - scattered across non-adjacent positions.
887///
888/// # Examples of Discontinuous Entities
889///
890/// ```text
891/// "New York and Los Angeles airports"
892/// ^^^^^^^^ ^^^^^^^^^^^ ^^^^^^^^
893/// └──────────────────────────┘
894/// LOCATION: "New York airports" (discontinuous!)
895/// ^^^^^^^^^^^ ^^^^^^^^
896/// └───────────┘
897/// LOCATION: "Los Angeles airports" (contiguous)
898///
899/// "protein A and B complex"
900/// ^^^^^^^^^ ^^^ ^^^^^^^^^
901/// └────────────────────┘
902/// PROTEIN: "protein A ... complex" (discontinuous!)
903/// ```
904///
905/// # NER Complexity Hierarchy
906///
907/// | Type | Description | Example |
908/// |------|-------------|---------|
909/// | Flat | Non-overlapping spans | "John works at Google" |
910/// | Nested | Overlapping spans | "\[New \[York\] City\]" |
911/// | Discontinuous | Non-contiguous | "New York and LA \[airports\]" |
912///
913/// # Research Alignment
914///
915/// From W2NER (arXiv:2112.10070):
916/// > "Named entity recognition has been involved with three major types,
917/// > including flat, overlapped (aka. nested), and discontinuous NER...
918/// > we propose a novel architecture to model NER as word-word relation
919/// > classification."
920///
921/// W2NER achieves this by building a **handshaking matrix** where each cell
922/// (i, j) indicates whether tokens i and j are part of the same entity.
923///
924/// # Example
925///
926/// ```ignore
927/// use anno::DiscontinuousNER;
928///
929/// fn extract_complex_entities(ner: &dyn DiscontinuousNER, text: &str) {
930/// let types = &["location", "protein"];
931/// let entities = ner.extract_discontinuous(text, types, 0.5).unwrap();
932///
933/// for e in entities {
934/// if e.is_contiguous() {
935/// println!("Contiguous {}: '{}'", e.entity_type, e.text);
936/// } else {
937/// println!("Discontinuous {}: '{}' spans: {:?}",
938/// e.entity_type, e.text, e.spans);
939/// }
940/// }
941/// }
942/// ```
pub trait DiscontinuousNER: Send + Sync {
    /// Extract entities including discontinuous spans.
    ///
    /// # Arguments
    /// * `text` - Input text
    /// * `entity_types` - Entity types to extract
    /// * `threshold` - Confidence threshold; candidates scoring below it are
    ///   dropped (presumably in `[0.0, 1.0]` — confirm with implementors)
    ///
    /// # Returns
    /// Entities, potentially with multiple non-contiguous spans
    ///
    /// # Errors
    /// Returns an implementation-specific error (e.g. backend inference
    /// failure) via [`crate::Result`].
    fn extract_discontinuous(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> crate::Result<Vec<DiscontinuousEntity>>;
}
960
/// An entity that may span multiple non-contiguous regions.
///
/// Produced by [`DiscontinuousNER::extract_discontinuous`]; a contiguous
/// entity has exactly one span (see [`DiscontinuousEntity::is_contiguous`]).
#[derive(Debug, Clone)]
pub struct DiscontinuousEntity {
    /// The spans that make up this entity (may be non-contiguous)
    ///
    /// Each pair is `(start, end)`. NOTE(review): the offset unit (byte vs
    /// character) is not established here — confirm against the producing
    /// backend before slicing text with these values.
    pub spans: Vec<(usize, usize)>,
    /// Concatenated text from all spans
    pub text: String,
    /// Entity type
    pub entity_type: String,
    /// Confidence score
    pub confidence: f32,
}
973
974impl DiscontinuousEntity {
975 /// Check if this entity is contiguous (single span).
976 pub fn is_contiguous(&self) -> bool {
977 self.spans.len() == 1
978 }
979
980 /// Convert to a standard Entity if contiguous.
981 pub fn to_entity(&self) -> Option<Entity> {
982 if self.is_contiguous() {
983 let (start, end) = self.spans[0];
984 Some(Entity::new(
985 self.text.clone(),
986 EntityType::from_label(&self.entity_type),
987 start,
988 end,
989 self.confidence as f64,
990 ))
991 } else {
992 None
993 }
994 }
995}
996
997// =============================================================================
998// Late Interaction Trait
999// =============================================================================
1000
1001/// The core abstraction for bi-encoder NER scoring.
1002///
1003/// # Motivation
1004///
1005/// "Late interaction" refers to when the text and label representations
1006/// interact: at the very end of the pipeline, after both have been
1007/// independently encoded. This is in contrast to "early fusion" where
1008/// text and labels are concatenated before encoding.
1009///
1010/// ```text
1011/// Early Fusion Late Interaction
1012/// ──────────── ────────────────
1013///
1014/// Encode: [text + label] text label
1015/// │ │ │
1016/// ▼ ▼ ▼
1017/// Encoder Enc_T Enc_L
1018/// │ │ │
1019/// ▼ ▼ ▼
1020/// Score emb_t emb_l
1021/// │ │
1022/// └───┬───┘
1023/// ▼
1024/// dot(emb_t, emb_l)
1025/// ```
1026///
1027/// Late interaction enables:
1028/// - Pre-computing label embeddings (major speedup)
1029/// - Adding new labels without re-encoding text
1030/// - Parallelizing text and label encoding
1031///
1032/// # The Math
1033///
1034/// ```text
1035/// Score(span, label) = σ(span_emb · label_emb / τ)
1036///
1037/// where:
1038/// σ = sigmoid activation
1039/// · = dot product
1040/// τ = temperature (sharpness parameter)
1041/// ```
1042///
1043/// # Implementations
1044///
1045/// | Interaction | Formula | Speed | Accuracy | Use Case |
1046/// |-------------|---------|-------|----------|----------|
1047/// | DotProduct | s·l | Fast | Good | General purpose |
1048/// | MaxSim | max(s·l)| Medium| Better | Multi-token labels |
1049/// | Bilinear | s·W·l | Slow | Best | When accuracy critical |
1050///
1051/// # Example
1052///
1053/// ```ignore
1054/// use anno::{LateInteraction, DotProductInteraction};
1055///
1056/// let interaction = DotProductInteraction::with_temperature(20.0);
1057///
1058/// // Span embeddings: 3 spans × 768 dim
1059/// let span_embs: Vec<f32> = get_span_embeddings(&tokens, &candidates);
1060///
1061/// // Label embeddings: 5 labels × 768 dim
1062/// let label_embs: Vec<f32> = registry.all_embeddings();
1063///
1064/// // Compute 3×5 = 15 similarity scores
1065/// let mut scores = interaction.compute_similarity(
1066/// &span_embs, 3, &label_embs, 5, 768
1067/// );
1068/// interaction.apply_sigmoid(&mut scores);
1069///
1070/// // scores[i*5 + j] = similarity between span i and label j
1071/// ```
pub trait LateInteraction: Send + Sync {
    /// Compute similarity scores between span and label embeddings.
    ///
    /// # Arguments
    /// * `span_embeddings` - Shape: [num_spans, hidden_dim]
    /// * `label_embeddings` - Shape: [num_labels, hidden_dim]
    ///
    /// # Returns
    /// Similarity matrix of shape: [num_spans, num_labels], row-major:
    /// the score for span `s` against label `l` is at `s * num_labels + l`.
    fn compute_similarity(
        &self,
        span_embeddings: &[f32],
        num_spans: usize,
        label_embeddings: &[f32],
        num_labels: usize,
        hidden_dim: usize,
    ) -> Vec<f32>;

    /// Apply sigmoid activation to scores.
    ///
    /// Maps raw logits into `(0, 1)` in place via `1 / (1 + e^(-x))`.
    fn apply_sigmoid(&self, scores: &mut [f32]) {
        for s in scores.iter_mut() {
            *s = 1.0 / (1.0 + (-*s).exp());
        }
    }
}
1097
/// Dot product interaction (default, fast).
#[derive(Debug, Clone, Copy)]
pub struct DotProductInteraction {
    /// Temperature scaling (higher = sharper distribution)
    pub temperature: f32,
}

impl Default for DotProductInteraction {
    /// Default to a neutral temperature of `1.0`.
    ///
    /// The previous `#[derive(Default)]` produced `temperature: 0.0`, which
    /// multiplies every similarity score to zero and disagrees with `new()`.
    fn default() -> Self {
        Self { temperature: 1.0 }
    }
}
1104
1105impl DotProductInteraction {
1106 /// Create with default temperature (1.0).
1107 pub fn new() -> Self {
1108 Self { temperature: 1.0 }
1109 }
1110
1111 /// Create with custom temperature.
1112 #[must_use]
1113 pub fn with_temperature(temperature: f32) -> Self {
1114 Self { temperature }
1115 }
1116}
1117
1118impl LateInteraction for DotProductInteraction {
1119 fn compute_similarity(
1120 &self,
1121 span_embeddings: &[f32],
1122 num_spans: usize,
1123 label_embeddings: &[f32],
1124 num_labels: usize,
1125 hidden_dim: usize,
1126 ) -> Vec<f32> {
1127 let mut scores = vec![0.0f32; num_spans * num_labels];
1128
1129 for s in 0..num_spans {
1130 let span_start = s * hidden_dim;
1131 let span_end = span_start + hidden_dim;
1132 let span_vec = &span_embeddings[span_start..span_end];
1133
1134 for l in 0..num_labels {
1135 let label_start = l * hidden_dim;
1136 let label_end = label_start + hidden_dim;
1137 let label_vec = &label_embeddings[label_start..label_end];
1138
1139 // Dot product
1140 let mut dot: f32 = span_vec
1141 .iter()
1142 .zip(label_vec.iter())
1143 .map(|(a, b)| a * b)
1144 .sum();
1145
1146 // Temperature scaling
1147 dot *= self.temperature;
1148
1149 scores[s * num_labels + l] = dot;
1150 }
1151 }
1152
1153 scores
1154 }
1155}
1156
/// MaxSim interaction (ColBERT-style, better for phrases).
#[derive(Debug, Clone, Copy)]
pub struct MaxSimInteraction {
    /// Temperature scaling
    pub temperature: f32,
}

impl Default for MaxSimInteraction {
    /// Default to a neutral temperature of `1.0`.
    ///
    /// The previous `#[derive(Default)]` produced `temperature: 0.0`, which
    /// zeroes every score and disagrees with `new()`.
    fn default() -> Self {
        Self { temperature: 1.0 }
    }
}
1163
1164impl MaxSimInteraction {
1165 /// Create with default settings.
1166 pub fn new() -> Self {
1167 Self { temperature: 1.0 }
1168 }
1169}
1170
1171impl LateInteraction for MaxSimInteraction {
1172 fn compute_similarity(
1173 &self,
1174 span_embeddings: &[f32],
1175 num_spans: usize,
1176 label_embeddings: &[f32],
1177 num_labels: usize,
1178 hidden_dim: usize,
1179 ) -> Vec<f32> {
1180 // For single-vector embeddings, MaxSim degrades to dot product
1181 // True MaxSim requires multi-vector representations
1182 DotProductInteraction::new().compute_similarity(
1183 span_embeddings,
1184 num_spans,
1185 label_embeddings,
1186 num_labels,
1187 hidden_dim,
1188 )
1189 }
1190}
1191
1192// =============================================================================
1193// Span Representation
1194// =============================================================================
1195
1196/// Configuration for span representation.
1197///
1198/// # Research Context (Deep Span Representations, arXiv:2210.04182)
1199///
1200/// From "Deep Span Representations for NER":
1201/// > "Existing span-based NER systems **shallowly aggregate** the token
1202/// > representations to span representations. However, this typically results
1203/// > in significant ineffectiveness for **long-span entities**."
1204///
1205/// Common span representation strategies:
1206///
1207/// | Method | Formula | Pros | Cons |
1208/// |--------|---------|------|------|
1209/// | Concat | [h_i; h_j] | Simple, fast | Ignores middle tokens |
1210/// | Pooling | mean(h_i:h_j) | Uses all tokens | Loses boundary info |
1211/// | Attention | attn(h_i:h_j) | Learnable | Expensive |
1212/// | GLiNER | FFN([h_i; h_j; w]) | Balanced | Requires width emb |
1213///
1214/// # Recommendation (GLiNER Default)
1215///
1216/// For most use cases, concatenating first + last token embeddings with
1217/// a width embedding provides the best tradeoff:
1218/// - O(N) complexity (vs O(N²) for all-pairs attention)
1219/// - Captures boundary positions (critical for NER)
1220/// - Width embedding disambiguates "I" vs "New York City"
#[derive(Debug, Clone)]
pub struct SpanRepConfig {
    /// Hidden dimension of the encoder
    pub hidden_dim: usize,
    /// Maximum span width (in tokens)
    ///
    /// GLiNER uses K=12: "to keep linear complexity without harming recall."
    /// Wider spans rarely contain coherent entities.
    pub max_width: usize,
    /// Whether to include width embeddings
    ///
    /// Critical for distinguishing spans of different lengths
    /// with similar boundary tokens.
    pub use_width_embeddings: bool,
    /// Width embedding dimension (typically hidden_dim / 4)
    ///
    /// In the baseline `SpanRepresentationLayer::forward`, only the first
    /// `min(width_emb_dim, hidden_dim)` components actually influence the
    /// output row.
    pub width_emb_dim: usize,
}
1238
1239impl Default for SpanRepConfig {
1240 fn default() -> Self {
1241 Self {
1242 hidden_dim: 768,
1243 max_width: 12,
1244 use_width_embeddings: true,
1245 width_emb_dim: 192, // 768 / 4
1246 }
1247 }
1248}
1249
1250/// Computes span representations from token embeddings.
1251///
1252/// # Research Alignment (GLiNER, NAACL 2024)
1253///
1254/// From the GLiNER paper (arXiv:2311.08526):
1255/// > "The representation of a span starting at position i and ending at
1256/// > position j in the input text, S_ij ∈ R^D, is computed as:
1257/// > **S_ij = FFN(h_i ⊗ h_j)**
1258/// > where FFN denotes a two-layer feedforward network, and ⊗ represents
1259/// > the concatenation operation."
1260///
1261/// The paper also notes:
1262/// > "We set an upper bound to the length (K=12) of the span in order to
1263/// > keep linear complexity in the size of the input text, without harming recall."
1264///
1265/// # Span Representation Formula
1266///
1267/// ```text
1268/// span_emb = FFN(Concat(token[i], token[j], width_emb[j-i]))
1269/// = W_2 · ReLU(W_1 · [h_i; h_j; w_{j-i}] + b_1) + b_2
1270/// ```
1271///
1272/// where:
1273/// - h_i = start token embedding
1274/// - h_j = end token embedding
1275/// - w_{j-i} = learned width embedding (captures span length)
1276///
1277/// This is the "gnarly bit" from GLiNER that enables zero-shot matching.
1278///
1279/// # Alternative: Global Pointer (arXiv:2208.03054)
1280///
1281/// Instead of enumerating spans, Global Pointer uses RoPE (rotary position
1282/// embeddings) to predict (start, end) pairs simultaneously:
1283///
1284/// ```text
1285/// score(i, j) = q_i^T * k_j (where q, k have RoPE applied)
1286/// ```
1287///
1288/// Advantages:
1289/// - No explicit span enumeration needed
1290/// - Naturally handles nested entities
1291/// - More parameter-efficient
1292///
1293/// GLiNER-style enumeration is still preferred for zero-shot because
1294/// it allows pre-computing label embeddings.
pub struct SpanRepresentationLayer {
    /// Configuration
    pub config: SpanRepConfig,
    /// Projection weights: [input_dim, hidden_dim]
    ///
    /// NOTE(review): zero-initialized by `new` and not read by the current
    /// baseline `forward` implementation.
    pub projection_weights: Vec<f32>,
    /// Projection bias: \[hidden_dim\]
    ///
    /// NOTE(review): zero-initialized by `new` and not read by `forward`.
    pub projection_bias: Vec<f32>,
    /// Width embeddings: [max_width, width_emb_dim]
    pub width_embeddings: Vec<f32>,
}
1305
impl SpanRepresentationLayer {
    /// Create a new span representation layer with zero-initialized
    /// parameters.
    ///
    /// Note: all buffers are filled with `0.0` — no random initialization or
    /// weight loading happens here, and `forward` currently does not use the
    /// projection parameters at all.
    pub fn new(config: SpanRepConfig) -> Self {
        // Input to the (eventual) projection: [h_i; h_j; width_emb].
        let input_dim = config.hidden_dim * 2 + config.width_emb_dim;

        Self {
            projection_weights: vec![0.0f32; input_dim * config.hidden_dim],
            projection_bias: vec![0.0f32; config.hidden_dim],
            width_embeddings: vec![0.0f32; config.max_width * config.width_emb_dim],
            config,
        }
    }

    /// Compute span representations from token embeddings.
    ///
    /// # Arguments
    /// * `token_embeddings` - Flattened [num_tokens, hidden_dim]
    /// * `candidates` - Span candidates with start/end indices
    /// * `batch` - Ragged batch; maps each candidate's per-document indices
    ///   to global token positions via `doc_range`
    ///
    /// # Returns
    /// Span embeddings: [num_candidates, hidden_dim]. Rows for invalid or
    /// out-of-bounds candidates are left as zeros (the candidate is skipped,
    /// never dropped, so row `i` always corresponds to `candidates[i]`).
    ///
    /// NOTE(review): this is a deterministic baseline — it averages the two
    /// boundary token embeddings (plus a small width signal) and does NOT
    /// apply `projection_weights`/`projection_bias`.
    pub fn forward(
        &self,
        token_embeddings: &[f32],
        candidates: &[SpanCandidate],
        batch: &RaggedBatch,
    ) -> Vec<f32> {
        let hidden_dim = self.config.hidden_dim;
        let width_emb_dim = self.config.width_emb_dim;
        let max_width = self.config.max_width;

        // Check for overflow in allocation
        let total_elements = match candidates.len().checked_mul(hidden_dim) {
            Some(v) => v,
            None => {
                log::warn!(
                    "Span embedding allocation overflow: {} candidates * {} hidden_dim, returning empty",
                    candidates.len(), hidden_dim
                );
                return vec![];
            }
        };
        let mut span_embeddings = vec![0.0f32; total_elements];

        for (span_idx, candidate) in candidates.iter().enumerate() {
            // Get document token range
            let doc_range = match batch.doc_range(candidate.doc_idx as usize) {
                Some(r) => r,
                None => continue,
            };

            // Validate span before computing global indices
            if candidate.end <= candidate.start {
                log::warn!(
                    "Invalid span candidate: end ({}) <= start ({})",
                    candidate.end,
                    candidate.start
                );
                continue;
            }

            // Global token indices
            // `end_global` is the last token INSIDE the span (candidate.end
            // is exclusive).
            let start_global = doc_range.start + candidate.start as usize;
            let end_global = doc_range.start + (candidate.end as usize) - 1; // Safe now that we validated

            // Bounds check - must ensure both start and end slices fit
            // Use checked arithmetic to prevent overflow
            // Note: despite the `*_byte` names, these are f32 element offsets
            // into `token_embeddings`, not byte offsets.
            let start_byte = match start_global.checked_mul(hidden_dim) {
                Some(v) => v,
                None => {
                    log::warn!(
                        "Token index overflow: start_global={} * hidden_dim={}",
                        start_global,
                        hidden_dim
                    );
                    continue;
                }
            };
            let start_end_byte = match (start_global + 1).checked_mul(hidden_dim) {
                Some(v) => v,
                None => {
                    log::warn!(
                        "Token index overflow: (start_global+1)={} * hidden_dim={}",
                        start_global + 1,
                        hidden_dim
                    );
                    continue;
                }
            };
            let end_byte = match end_global.checked_mul(hidden_dim) {
                Some(v) => v,
                None => {
                    log::warn!(
                        "Token index overflow: end_global={} * hidden_dim={}",
                        end_global,
                        hidden_dim
                    );
                    continue;
                }
            };
            let end_end_byte = match (end_global + 1).checked_mul(hidden_dim) {
                Some(v) => v,
                None => {
                    log::warn!(
                        "Token index overflow: (end_global+1)={} * hidden_dim={}",
                        end_global + 1,
                        hidden_dim
                    );
                    continue;
                }
            };

            // Skip candidates whose boundary slices don't fit; their output
            // rows stay zeroed.
            if start_byte >= token_embeddings.len()
                || start_end_byte > token_embeddings.len()
                || end_byte >= token_embeddings.len()
                || end_end_byte > token_embeddings.len()
            {
                continue;
            }

            // Get start and end token embeddings
            let start_emb = &token_embeddings[start_byte..start_end_byte];
            let end_emb = &token_embeddings[end_byte..end_end_byte];

            // Optional width embedding (index = span_len - 1).
            // Widths beyond `max_width` are clamped to the last row.
            let width_emb = if self.config.use_width_embeddings && width_emb_dim > 0 {
                let max_width_idx = max_width.saturating_sub(1);
                let span_len = candidate.width() as usize;
                let width_idx = span_len.saturating_sub(1).min(max_width_idx);

                let width_start = width_idx.saturating_mul(width_emb_dim);
                let width_end = width_start.saturating_add(width_emb_dim);
                if width_end > self.width_embeddings.len() {
                    None
                } else {
                    Some(&self.width_embeddings[width_start..width_end])
                }
            } else {
                None
            };

            // Baseline span representation: average of boundary embeddings (+ optional width signal).
            // This is deterministic and works without learned projection weights.
            // The width signal is added (scaled by 0.1) only to the first
            // `width_emb_dim` components of the output row.
            let output_start = span_idx * hidden_dim;
            for h in 0..hidden_dim {
                span_embeddings[output_start + h] = (start_emb[h] + end_emb[h]) * 0.5;
                if let Some(width_emb) = width_emb {
                    if h < width_emb_dim {
                        span_embeddings[output_start + h] += width_emb[h] * 0.1;
                    }
                }
            }
        }

        span_embeddings
    }
}
1463
1464// =============================================================================
1465// Handshaking Matrix (TPLinker-style Joint Extraction)
1466// =============================================================================
1467
/// Result cell in a handshaking matrix.
#[derive(Debug, Clone, Copy)]
pub struct HandshakingCell {
    /// Row index (token i)
    pub i: u32,
    /// Column index (token j)
    pub j: u32,
    /// Predicted label index
    ///
    /// Index into the label set used to produce the scores (resolved against
    /// a `SemanticRegistry` in `HandshakingMatrix::decode_entities`).
    pub label_idx: u16,
    /// Confidence score
    pub score: f32,
}
1480
1481/// Handshaking matrix for joint entity-relation extraction.
1482///
1483/// # Research Alignment (W2NER, AAAI 2022)
1484///
1485/// From the W2NER paper (arXiv:2112.10070):
1486/// > "We present a novel alternative by modeling the unified NER as word-word
1487/// > relation classification, namely W2NER. The architecture resolves the kernel
1488/// > bottleneck of unified NER by effectively modeling the neighboring relations
1489/// > between entity words with **Next-Neighboring-Word (NNW)** and
1490/// > **Tail-Head-Word-* (THW-*)** relations."
1491///
1492/// In TPLinker/W2NER, we don't just tag tokens - we tag token PAIRS.
1493/// The matrix M\[i,j\] contains the label for the span (i, j).
1494///
1495/// # Key Relations
1496///
1497/// | Relation | Description | Purpose |
1498/// |----------|-------------|---------|
1499/// | NNW | Next-Neighboring-Word | Links adjacent tokens within entity |
1500/// | THW-* | Tail-Head-Word | Links end of one entity to start of next |
1501///
1502/// # Benefits
1503///
1504/// - Overlapping entities (same token in multiple spans)
1505/// - Joint entity-relation extraction in one pass
1506/// - Explicit boundary modeling
1507/// - Handles flat, nested, AND discontinuous NER in one model
pub struct HandshakingMatrix {
    /// Non-zero cells (sparse representation)
    ///
    /// Only cells whose score passed the construction threshold are stored
    /// (see `from_dense`).
    pub cells: Vec<HandshakingCell>,
    /// Sequence length
    pub seq_len: usize,
    /// Number of labels
    pub num_labels: usize,
}
1516
impl HandshakingMatrix {
    /// Create from dense scores with thresholding.
    ///
    /// # Arguments
    /// * `scores` - Dense [seq_len, seq_len, num_labels] scores
    /// * `threshold` - Minimum score to keep
    ///
    /// Only the upper triangle (i <= j) is scanned; indices past the end of
    /// `scores` (a short buffer) are silently skipped.
    pub fn from_dense(scores: &[f32], seq_len: usize, num_labels: usize, threshold: f32) -> Self {
        // Performance: Pre-allocate cells vec with estimated capacity
        // Most matrices have sparse cells (only high-scoring ones), so we estimate conservatively
        let estimated_capacity = (seq_len * seq_len / 10).min(1000); // ~10% of cells typically pass threshold
        let mut cells = Vec::with_capacity(estimated_capacity);

        for i in 0..seq_len {
            for j in i..seq_len {
                // Upper triangular (i <= j)
                for l in 0..num_labels {
                    let idx = i * seq_len * num_labels + j * num_labels + l;
                    if idx < scores.len() {
                        let score = scores[idx];
                        if score >= threshold {
                            cells.push(HandshakingCell {
                                i: i as u32,
                                j: j as u32,
                                label_idx: l as u16,
                                score,
                            });
                        }
                    }
                }
            }
        }

        Self {
            cells,
            seq_len,
            num_labels,
        }
    }

    /// Decode entities from handshaking matrix.
    ///
    /// In W2NER convention, cell (i, j) represents a span where:
    /// - j is the start token index
    /// - i is the end token index (inclusive, so we add 1 for exclusive end)
    ///
    /// NOTE(review): `from_dense` only emits upper-triangular cells (i <= j),
    /// while this decoder reads j as start and i as inclusive end. Under that
    /// combination only i == j cells yield non-empty spans — confirm whether
    /// producers are expected to fill the lower triangle instead.
    pub fn decode_entities<'a>(
        &self,
        registry: &'a SemanticRegistry,
    ) -> Vec<(SpanCandidate, &'a LabelDefinition, f32)> {
        let mut entities = Vec::new();

        // Keep only entity-category labels; unknown label indices are skipped.
        for cell in &self.cells {
            if let Some(label) = registry.labels.get(cell.label_idx as usize) {
                if label.category == LabelCategory::Entity {
                    // W2NER: j=start, i=end (inclusive), so span is [j, i+1)
                    entities.push((SpanCandidate::new(0, cell.j, cell.i + 1), label, cell.score));
                }
            }
        }

        // Performance: Use unstable sort (we don't need stable sort here)
        // Sort by position, then by score (descending)
        entities.sort_unstable_by(|a, b| {
            a.0.start
                .cmp(&b.0.start)
                .then_with(|| a.0.end.cmp(&b.0.end))
                .then_with(|| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal))
        });

        // Performance: Pre-allocate kept vec with estimated capacity
        // Non-maximum suppression
        // Greedy pass over the sorted list: any candidate overlapping an
        // already-kept span is dropped.
        let mut kept = Vec::with_capacity(entities.len().min(32));
        for (span, label, score) in entities {
            let overlaps = kept.iter().any(|(s, _, _): &(SpanCandidate, _, _)| {
                !(span.end <= s.start || s.end <= span.start)
            });
            if !overlaps {
                kept.push((span, label, score));
            }
        }

        kept
    }
}
1600
1601// =============================================================================
1602// Coreference Resolution
1603// =============================================================================
1604
/// A coreference cluster (mentions referring to same entity).
#[derive(Debug, Clone)]
pub struct CoreferenceCluster {
    /// Cluster ID
    ///
    /// Sequential; assigned by `resolve_coreferences`.
    pub id: u64,
    /// Member entities (indices into entity list)
    pub members: Vec<usize>,
    /// Representative entity index (most informative mention)
    ///
    /// `resolve_coreferences` picks the member with the longest mention text.
    pub representative: usize,
    /// Canonical name (from representative)
    pub canonical_name: String,
}
1617
/// Configuration for coreference resolution.
#[derive(Debug, Clone)]
pub struct CoreferenceConfig {
    /// Minimum cosine similarity to link mentions
    pub similarity_threshold: f32,
    /// Maximum token distance between coreferent mentions
    ///
    /// NOTE(review): `resolve_coreferences` compares this against gaps in
    /// `Entity.start`/`Entity.end`, which are character offsets elsewhere in
    /// this module — confirm whether "token" here should read "character".
    /// `None` disables the distance constraint.
    pub max_distance: Option<usize>,
    /// Whether to use exact string matching as a signal
    pub use_string_match: bool,
}
1628
impl Default for CoreferenceConfig {
    fn default() -> Self {
        Self {
            // Strict: only near-duplicate embeddings are linked.
            similarity_threshold: 0.85,
            // Mentions more than 500 units apart are never linked.
            max_distance: Some(500),
            // Exact/substring match (same entity type) links without embeddings.
            use_string_match: true,
        }
    }
}
1638
1639/// Resolve coreferences between entities using embedding similarity.
1640///
1641/// # Algorithm
1642///
1643/// 1. Compute pairwise cosine similarity between entity embeddings
1644/// 2. Link entities above threshold (with optional distance constraint)
1645/// 3. Build clusters via transitive closure
1646/// 4. Select representative (longest/most informative mention)
1647///
1648/// # Example
1649///
1650/// Input entities: ["Lynn Conway", "She", "The engineer", "Conway"]
1651/// Output clusters: [{0, 1, 2, 3}] with canonical_name = "Lynn Conway"
1652pub fn resolve_coreferences(
1653 entities: &[Entity],
1654 embeddings: &[f32], // [num_entities, hidden_dim]
1655 hidden_dim: usize,
1656 config: &CoreferenceConfig,
1657) -> Vec<CoreferenceCluster> {
1658 let n = entities.len();
1659 if n == 0 {
1660 return vec![];
1661 }
1662
1663 // Union-find for clustering
1664 let mut parent: Vec<usize> = (0..n).collect();
1665
1666 fn find(parent: &mut [usize], i: usize) -> usize {
1667 if parent[i] != i {
1668 parent[i] = find(parent, parent[i]);
1669 }
1670 parent[i]
1671 }
1672
1673 fn union(parent: &mut [usize], i: usize, j: usize) {
1674 let pi = find(parent, i);
1675 let pj = find(parent, j);
1676 if pi != pj {
1677 parent[pi] = pj;
1678 }
1679 }
1680
1681 // Check all pairs
1682 for i in 0..n {
1683 for j in (i + 1)..n {
1684 // String match check (fast path)
1685 if config.use_string_match {
1686 let text_i = entities[i].text.to_lowercase();
1687 let text_j = entities[j].text.to_lowercase();
1688 if text_i == text_j || text_i.contains(&text_j) || text_j.contains(&text_i) {
1689 // Same entity type required
1690 if entities[i].entity_type == entities[j].entity_type {
1691 union(&mut parent, i, j);
1692 continue;
1693 }
1694 }
1695 }
1696
1697 // Distance check
1698 if let Some(max_dist) = config.max_distance {
1699 let dist = if entities[i].end <= entities[j].start {
1700 entities[j].start - entities[i].end
1701 } else {
1702 entities[i].start.saturating_sub(entities[j].end)
1703 };
1704 if dist > max_dist {
1705 continue;
1706 }
1707 }
1708
1709 // Embedding similarity
1710 if embeddings.len() >= (j + 1) * hidden_dim {
1711 let emb_i = &embeddings[i * hidden_dim..(i + 1) * hidden_dim];
1712 let emb_j = &embeddings[j * hidden_dim..(j + 1) * hidden_dim];
1713
1714 let similarity = cosine_similarity(emb_i, emb_j);
1715
1716 if similarity >= config.similarity_threshold {
1717 // Same entity type required
1718 if entities[i].entity_type == entities[j].entity_type {
1719 union(&mut parent, i, j);
1720 }
1721 }
1722 }
1723 }
1724 }
1725
1726 // Build clusters
1727 let mut cluster_members: HashMap<usize, Vec<usize>> = HashMap::new();
1728 for i in 0..n {
1729 let root = find(&mut parent, i);
1730 cluster_members.entry(root).or_default().push(i);
1731 }
1732
1733 // Convert to CoreferenceCluster
1734 let mut clusters = Vec::new();
1735 let mut cluster_id = 0u64;
1736
1737 for (_root, members) in cluster_members {
1738 if members.len() > 1 {
1739 // Find representative (longest mention)
1740 let representative = *members
1741 .iter()
1742 .max_by_key(|&&i| entities[i].text.len())
1743 .unwrap_or(&members[0]);
1744
1745 clusters.push(CoreferenceCluster {
1746 id: cluster_id,
1747 members,
1748 representative,
1749 canonical_name: entities[representative].text.clone(),
1750 });
1751 cluster_id += 1;
1752 }
1753 }
1754
1755 clusters
1756}
1757
1758/// Compute cosine similarity between two vectors.
1759///
1760/// Returns a value in [-1.0, 1.0] where:
1761/// - 1.0 = identical direction
1762/// - 0.0 = orthogonal
1763/// - -1.0 = opposite direction
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Each norm is taken over the FULL vector, independent of the other's
    // length (the dot product only covers the overlapping prefix).
    let norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    // Degenerate (zero or NaN) norms yield 0.0 instead of dividing.
    if !(norm_a > 0.0 && norm_b > 0.0) {
        return 0.0;
    }

    let dot = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y);
    dot / (norm_a * norm_b)
}
1775
1776// =============================================================================
1777// Relation Extraction
1778// =============================================================================
1779
/// Configuration for relation extraction.
#[derive(Debug, Clone)]
pub struct RelationExtractionConfig {
    /// Maximum token distance between head and tail
    ///
    /// NOTE(review): `extract_relations` / `extract_relation_triples` compare
    /// this against character-offset gaps — confirm the intended unit.
    pub max_span_distance: usize,
    /// Minimum confidence for relation
    ///
    /// Applied after the distance penalty scales the raw match confidence.
    pub threshold: f32,
    /// Whether to extract relation triggers
    pub extract_triggers: bool,
}
1790
impl Default for RelationExtractionConfig {
    fn default() -> Self {
        Self {
            // Entities more than 50 units apart are never paired.
            max_span_distance: 50,
            // Keeps only matches that survive the distance penalty reasonably.
            threshold: 0.5,
            extract_triggers: true,
        }
    }
}
1800
1801/// Extract relations between entities.
1802///
1803/// # Algorithm (Two-Pass)
1804///
1805/// 1. Run entity NER to find all entity mentions
1806/// 2. For each entity pair within distance threshold:
1807/// - Encode the span between them
1808/// - Match against relation type embeddings
1809/// - Optionally identify trigger span
1810///
1811/// # Returns
1812///
1813/// Relations with head/tail entities and optional trigger spans.
pub fn extract_relations(
    entities: &[Entity],
    text: &str,
    registry: &SemanticRegistry,
    config: &RelationExtractionConfig,
) -> Vec<Relation> {
    let mut relations = Vec::new();
    // `Entity` spans in anno are character offsets, but slicing a Rust `&str` requires byte
    // offsets. Build a converter once so we can safely slice and map trigger spans back.
    let span_converter = crate::offset::SpanConverter::new(text);

    // Get relation labels
    let relation_labels: Vec<_> = registry.relation_labels().collect();
    if relation_labels.is_empty() {
        return relations;
    }

    // Check all entity pairs (ordered: both (head, tail) and (tail, head)).
    // NOTE(review): unlike `extract_relation_triples`, overlapping entity
    // pairs are NOT skipped here — confirm whether both should filter them.
    for (i, head) in entities.iter().enumerate() {
        for (j, tail) in entities.iter().enumerate() {
            if i == j {
                continue;
            }

            // Check distance
            // Character-offset gap between the nearer edges of the two spans
            // (saturates to 0 when they overlap).
            let distance = if head.end <= tail.start {
                tail.start - head.end
            } else {
                head.start.saturating_sub(tail.end)
            };

            if distance > config.max_span_distance {
                continue;
            }

            // Look for relation triggers in the text between entities
            let (span_start, span_end) = if head.end <= tail.start {
                (head.end, tail.start)
            } else {
                (tail.end, head.start)
            };

            let between_span = span_converter.from_chars(span_start, span_end);
            let between_text = text
                .get(between_span.byte_start..between_span.byte_end)
                .unwrap_or("");

            // Simple heuristic: check for common relation indicators
            let relation_type = detect_relation_type(head, tail, between_text, &relation_labels);

            if let Some((rel_type, mut confidence, trigger)) = relation_type {
                // Apply distance penalty: closer entities are more likely to be related
                // Confidence decays linearly from 1.0 at distance 0 to 0.5 at max_span_distance
                let distance_penalty = if distance < config.max_span_distance {
                    let penalty_factor =
                        1.0 - (distance as f64 / config.max_span_distance as f64) * 0.5;
                    penalty_factor.max(0.5) // Minimum 0.5 confidence even at max distance
                } else {
                    0.5 // At or beyond max distance, apply minimum confidence
                };
                confidence *= distance_penalty;

                if confidence < config.threshold as f64 {
                    continue;
                }

                // `detect_relation_type` returns byte offsets into `between_text`.
                // Map them back to absolute character offsets in `text`.
                let trigger_span = if config.extract_triggers {
                    trigger.map(|(s, e)| {
                        let trigger_start_byte = between_span.byte_start.saturating_add(s);
                        let trigger_end_byte = between_span.byte_start.saturating_add(e);
                        (
                            span_converter.byte_to_char(trigger_start_byte),
                            span_converter.byte_to_char(trigger_end_byte),
                        )
                    })
                } else {
                    None
                };

                relations.push(Relation {
                    head: head.clone(),
                    tail: tail.clone(),
                    relation_type: rel_type.to_string(),
                    trigger_span,
                    confidence: confidence.clamp(0.0, 1.0), // Clamp to [0, 1]
                });
            }
        }
    }

    relations
}
1907
1908/// Extract relations as index-based triples (for joint extraction backends).
1909///
1910/// This is the same heuristic logic as [`extract_relations`], but returns
1911/// [`RelationTriple`] with indices into the provided `entities` slice.
1912///
1913/// Notes:
1914/// - Entity spans are **character offsets**.
1915/// - Trigger spans are not currently exposed in `RelationTriple`.
#[must_use]
pub fn extract_relation_triples(
    entities: &[Entity],
    text: &str,
    registry: &SemanticRegistry,
    config: &RelationExtractionConfig,
) -> Vec<RelationTriple> {
    let mut triples = Vec::new();
    // A relation needs two distinct entities.
    if entities.len() < 2 {
        return triples;
    }

    // `Entity` spans are character offsets; slicing needs byte offsets.
    let span_converter = crate::offset::SpanConverter::new(text);

    let relation_labels: Vec<_> = registry.relation_labels().collect();
    if relation_labels.is_empty() {
        return triples;
    }

    // Ordered pairs: both (head, tail) and (tail, head) are considered.
    for (i, head) in entities.iter().enumerate() {
        for (j, tail) in entities.iter().enumerate() {
            if i == j {
                continue;
            }

            // Skip overlapping spans (avoids self-nesting artifacts like "New York" vs "York").
            if head.start < tail.end && tail.start < head.end {
                continue;
            }

            // Check distance (character offsets)
            let distance = if head.end <= tail.start {
                tail.start - head.end
            } else {
                head.start.saturating_sub(tail.end)
            };
            if distance > config.max_span_distance {
                continue;
            }

            // Window of text strictly between the two entities; the heuristic
            // matcher looks for trigger words in this window.
            let (span_start, span_end) = if head.end <= tail.start {
                (head.end, tail.start)
            } else {
                (tail.end, head.start)
            };

            let between_span = span_converter.from_chars(span_start, span_end);
            let between_text = text
                .get(between_span.byte_start..between_span.byte_end)
                .unwrap_or("");

            if let Some((rel_type, mut confidence, _trigger)) =
                detect_relation_type(head, tail, between_text, &relation_labels)
            {
                // Apply distance penalty (same logic as extract_relations)
                let distance_penalty = if distance < config.max_span_distance {
                    let penalty_factor =
                        1.0 - (distance as f64 / config.max_span_distance as f64) * 0.5;
                    penalty_factor.max(0.5)
                } else {
                    0.5
                };
                confidence *= distance_penalty;

                if confidence < config.threshold as f64 {
                    continue;
                }

                // Trigger spans are intentionally dropped: `RelationTriple`
                // has no field for them (see function docs).
                triples.push(RelationTriple {
                    head_idx: i,
                    tail_idx: j,
                    relation_type: rel_type.to_string(),
                    confidence: confidence as f32,
                });
            }
        }
    }

    triples
}
1997
1998/// Result of relation detection: (label, confidence, optional span).
1999type RelationMatch<'a> = (&'a str, f64, Option<(usize, usize)>);
2000
2001/// Detect relation type from context (heuristic fallback).
2002fn detect_relation_type<'a>(
2003 head: &Entity,
2004 tail: &Entity,
2005 between_text: &str,
2006 relation_labels: &[&'a LabelDefinition],
2007) -> Option<RelationMatch<'a>> {
2008 // Use Unicode-aware lowercasing for multilingual support
2009 // Note: For CJK languages, case doesn't apply, but this is safe
2010 let between_lower = between_text.to_lowercase();
2011
2012 // Normalize relation slugs so datasets that use kebab-case / colon-separated schemas
2013 // (e.g. DocRED: "part-of", "general-affiliation") can match our canonical patterns
2014 // (e.g. "PART_OF", "GENERAL_AFFILIATION").
2015 fn norm_rel_slug(s: &str) -> String {
2016 // Uppercase + map non-alphanumerics to '_' so we can compare across naming schemes.
2017 let mut out = String::with_capacity(s.len());
2018 let mut prev_underscore = false;
2019 for ch in s.chars() {
2020 if ch.is_alphanumeric() {
2021 // Keep Unicode letters/digits; uppercase ASCII for stable matching.
2022 if ch.is_ascii_alphabetic() {
2023 out.push(ch.to_ascii_uppercase());
2024 } else {
2025 out.push(ch);
2026 }
2027 prev_underscore = false;
2028 } else if !prev_underscore {
2029 out.push('_');
2030 prev_underscore = true;
2031 }
2032 }
2033 while out.starts_with('_') {
2034 out.remove(0);
2035 }
2036 while out.ends_with('_') {
2037 out.pop();
2038 }
2039 out
2040 }
2041
2042 // Common patterns: (relation_slug, triggers, confidence)
2043 struct RelPattern {
2044 slug: &'static str,
2045 triggers: &'static [&'static str],
2046 confidence: f64,
2047 }
2048
2049 let patterns: &[RelPattern] = &[
2050 // Employment relations
2051 RelPattern {
2052 slug: "CEO_OF",
2053 triggers: &[
2054 "ceo of",
2055 "chief executive",
2056 "chief executive officer",
2057 "leads",
2058 "founded",
2059 "founder of",
2060 ],
2061 confidence: 0.8,
2062 },
2063 RelPattern {
2064 slug: "WORKS_FOR",
2065 triggers: &[
2066 "works for",
2067 "works at",
2068 "employed by",
2069 "employee of",
2070 "works with",
2071 "staff at",
2072 "member of",
2073 ],
2074 confidence: 0.7,
2075 },
2076 RelPattern {
2077 slug: "FOUNDED",
2078 triggers: &[
2079 "founded",
2080 "co-founded",
2081 "cofounder",
2082 "started",
2083 "established",
2084 "created",
2085 "launched",
2086 ],
2087 confidence: 0.8,
2088 },
2089 RelPattern {
2090 slug: "MANAGES",
2091 triggers: &[
2092 "manages",
2093 "managing",
2094 "oversees",
2095 "directs",
2096 "supervises",
2097 "runs",
2098 ],
2099 confidence: 0.75,
2100 },
2101 RelPattern {
2102 slug: "REPORTS_TO",
2103 triggers: &["reports to", "reported to", "under", "reports directly to"],
2104 confidence: 0.7,
2105 },
2106 // Location relations
2107 RelPattern {
2108 slug: "LOCATED_IN",
2109 triggers: &[
2110 "in",
2111 "at",
2112 "based in",
2113 "located in",
2114 "headquartered in",
2115 "situated in",
2116 "found in",
2117 ],
2118 confidence: 0.6,
2119 },
2120 RelPattern {
2121 slug: "BORN_IN",
2122 triggers: &[
2123 "born in",
2124 "native of",
2125 "from",
2126 "hails from",
2127 "originated in",
2128 ],
2129 confidence: 0.7,
2130 },
2131 RelPattern {
2132 slug: "LIVES_IN",
2133 triggers: &["lives in", "resides in", "living in", "based in"],
2134 confidence: 0.65,
2135 },
2136 RelPattern {
2137 slug: "DIED_IN",
2138 triggers: &["died in", "passed away in", "deceased in"],
2139 confidence: 0.8,
2140 },
2141 // Temporal relations
2142 RelPattern {
2143 slug: "OCCURRED_ON",
2144 triggers: &["on", "occurred on", "happened on", "took place on", "dated"],
2145 confidence: 0.6,
2146 },
2147 RelPattern {
2148 slug: "STARTED_ON",
2149 triggers: &["started on", "began on", "commenced on", "initiated on"],
2150 confidence: 0.7,
2151 },
2152 RelPattern {
2153 slug: "ENDED_ON",
2154 triggers: &["ended on", "concluded on", "finished on", "completed on"],
2155 confidence: 0.7,
2156 },
2157 // Organizational relations
2158 RelPattern {
2159 slug: "PART_OF",
2160 triggers: &[
2161 "part of",
2162 "member of",
2163 "belongs to",
2164 "subsidiary of",
2165 "division of",
2166 "branch of",
2167 ],
2168 confidence: 0.7,
2169 },
2170 RelPattern {
2171 slug: "ACQUIRED",
2172 triggers: &[
2173 "acquired",
2174 "bought",
2175 "purchased",
2176 "took over",
2177 "merged with",
2178 ],
2179 confidence: 0.75,
2180 },
2181 RelPattern {
2182 slug: "MERGED_WITH",
2183 triggers: &["merged with", "merged into", "combined with", "joined with"],
2184 confidence: 0.8,
2185 },
2186 RelPattern {
2187 slug: "PARENT_OF",
2188 triggers: &["parent of", "parent company of", "owns", "owner of"],
2189 confidence: 0.75,
2190 },
2191 // Social relations
2192 RelPattern {
2193 slug: "MARRIED_TO",
2194 triggers: &["married to", "wed to", "spouse of", "husband of", "wife of"],
2195 confidence: 0.85,
2196 },
2197 RelPattern {
2198 slug: "CHILD_OF",
2199 triggers: &["son of", "daughter of", "child of", "offspring of"],
2200 confidence: 0.8,
2201 },
2202 RelPattern {
2203 slug: "SIBLING_OF",
2204 triggers: &["brother of", "sister of", "sibling of"],
2205 confidence: 0.8,
2206 },
2207 // Academic/Professional
2208 RelPattern {
2209 slug: "STUDIED_AT",
2210 triggers: &[
2211 "studied at",
2212 "attended",
2213 "graduated from",
2214 "alumni of",
2215 "educated at",
2216 ],
2217 confidence: 0.75,
2218 },
2219 RelPattern {
2220 slug: "TEACHES_AT",
2221 triggers: &["teaches at", "professor at", "instructor at", "faculty at"],
2222 confidence: 0.8,
2223 },
2224 // Product/Service relations
2225 RelPattern {
2226 slug: "DEVELOPS",
2227 triggers: &[
2228 "develops",
2229 "created",
2230 "built",
2231 "designed",
2232 "produces",
2233 "manufactures",
2234 ],
2235 confidence: 0.7,
2236 },
2237 RelPattern {
2238 slug: "USES",
2239 triggers: &["uses", "utilizes", "employs", "adopts", "implements"],
2240 confidence: 0.6,
2241 },
2242 // Dataset-style relation labels (DocRED/CHisIEC-like)
2243 //
2244 // These are the *coarse* label names we actually see in the CrossRE/DocRED-style
2245 // exports used by this repo (e.g. `docred_dev.json`), which differ from the
2246 // “canonical” IE labels above.
2247 RelPattern {
2248 slug: "NAMED",
2249 triggers: &[
2250 "called",
2251 "known as",
2252 "also known as",
2253 "named",
2254 "referred to as",
2255 "nickname",
2256 ],
2257 confidence: 0.6,
2258 },
2259 RelPattern {
2260 slug: "TYPE_OF",
2261 triggers: &[
2262 "type of",
2263 "kind of",
2264 "form of",
2265 "a type of",
2266 "is a",
2267 "are a",
2268 ],
2269 confidence: 0.6,
2270 },
2271 RelPattern {
2272 slug: "RELATED_TO",
2273 triggers: &["related to", "associated with", "connected to", "linked to"],
2274 confidence: 0.55,
2275 },
2276 RelPattern {
2277 slug: "ORIGIN",
2278 triggers: &[
2279 "from",
2280 "born",
2281 "originated",
2282 "created by",
2283 "invented by",
2284 "derived from",
2285 "spinoff",
2286 "spin-off",
2287 ],
2288 confidence: 0.55,
2289 },
2290 RelPattern {
2291 slug: "ROLE",
2292 triggers: &[
2293 "president",
2294 "ceo",
2295 "chair",
2296 "director",
2297 "editor",
2298 "producer",
2299 "actor",
2300 "professor",
2301 "fellow",
2302 "member",
2303 ],
2304 confidence: 0.55,
2305 },
2306 RelPattern {
2307 slug: "TEMPORAL",
2308 triggers: &[
2309 "in 19", "in 20", "during", "before", "after", "between", "until", "since",
2310 ],
2311 confidence: 0.5,
2312 },
2313 RelPattern {
2314 slug: "PHYSICAL",
2315 triggers: &["located in", "based in", "headquartered in", "at "],
2316 confidence: 0.55,
2317 },
2318 RelPattern {
2319 slug: "TOPIC",
2320 triggers: &["topic", "about", "on", "regarding", "focused on"],
2321 confidence: 0.5,
2322 },
2323 RelPattern {
2324 slug: "OPPOSITE",
2325 triggers: &["opposite", "contrasts with", "as opposed to"],
2326 confidence: 0.6,
2327 },
2328 RelPattern {
2329 slug: "WIN_DEFEAT",
2330 triggers: &["defeated", "beat", "won", "win", "lose", "lost to"],
2331 confidence: 0.6,
2332 },
2333 RelPattern {
2334 slug: "CAUSE_EFFECT",
2335 triggers: &["caused", "causes", "leads to", "results in", "because"],
2336 confidence: 0.55,
2337 },
2338 RelPattern {
2339 slug: "USAGE",
2340 triggers: &["use", "uses", "used", "using", "utilize", "employ", "adopt"],
2341 confidence: 0.55,
2342 },
2343 RelPattern {
2344 slug: "ARTIFACT",
2345 triggers: &[
2346 "tool",
2347 "library",
2348 "framework",
2349 "system",
2350 "artifact",
2351 "implementation",
2352 ],
2353 confidence: 0.55,
2354 },
2355 RelPattern {
2356 slug: "COMPARE",
2357 triggers: &[
2358 "compare",
2359 "compared to",
2360 "versus",
2361 "vs",
2362 "better than",
2363 "worse than",
2364 ],
2365 confidence: 0.55,
2366 },
2367 RelPattern {
2368 slug: "GENERAL_AFFILIATION",
2369 triggers: &[
2370 "affiliation",
2371 "affiliated with",
2372 "member of",
2373 "part of",
2374 "associated with",
2375 ],
2376 confidence: 0.55,
2377 },
2378 // CHisIEC (classical Chinese) relations (match either simplified or traditional labels)
2379 RelPattern {
2380 slug: "父母",
2381 triggers: &["父", "母", "父母"],
2382 confidence: 0.7,
2383 },
2384 RelPattern {
2385 slug: "兄弟",
2386 triggers: &["兄", "弟", "兄弟"],
2387 confidence: 0.7,
2388 },
2389 RelPattern {
2390 slug: "別名",
2391 triggers: &["別名", "别名"],
2392 confidence: 0.75,
2393 },
2394 RelPattern {
2395 slug: "到達",
2396 triggers: &["到", "至", "達", "到達", "到达"],
2397 confidence: 0.6,
2398 },
2399 RelPattern {
2400 slug: "出生於某地",
2401 triggers: &["生於", "生于", "出生於", "出生于"],
2402 confidence: 0.65,
2403 },
2404 RelPattern {
2405 slug: "任職",
2406 triggers: &["任", "拜", "任職", "任职"],
2407 confidence: 0.6,
2408 },
2409 RelPattern {
2410 slug: "管理",
2411 triggers: &["管", "治", "守", "管理"],
2412 confidence: 0.55,
2413 },
2414 RelPattern {
2415 slug: "駐守",
2416 triggers: &["駐", "驻", "守", "駐守", "驻守"],
2417 confidence: 0.55,
2418 },
2419 RelPattern {
2420 slug: "敵對攻伐",
2421 triggers: &["敵", "敌", "攻", "伐", "戰", "战"],
2422 confidence: 0.55,
2423 },
2424 RelPattern {
2425 slug: "同僚",
2426 triggers: &["同僚"],
2427 confidence: 0.55,
2428 },
2429 RelPattern {
2430 slug: "政治奧援",
2431 triggers: &["奧援", "奥援"],
2432 confidence: 0.55,
2433 },
2434 // Communication/Interaction
2435 RelPattern {
2436 slug: "MET_WITH",
2437 triggers: &["met with", "met", "met up with", "encountered", "saw"],
2438 confidence: 0.65,
2439 },
2440 RelPattern {
2441 slug: "SPOKE_WITH",
2442 triggers: &[
2443 "spoke with",
2444 "talked with",
2445 "discussed with",
2446 "conversed with",
2447 ],
2448 confidence: 0.7,
2449 },
2450 // Ownership
2451 RelPattern {
2452 slug: "OWNS",
2453 triggers: &["owns", "owner of", "possesses", "holds"],
2454 confidence: 0.75,
2455 },
2456 // =========================================================================
2457 // Multilingual relation triggers
2458 // =========================================================================
2459 // Spanish (es)
2460 RelPattern {
2461 slug: "WORKS_FOR",
2462 triggers: &["trabaja en", "trabaja para", "empleado de", "trabaja con"],
2463 confidence: 0.7,
2464 },
2465 RelPattern {
2466 slug: "FOUNDED",
2467 triggers: &["fundó", "fundada", "creó", "creada", "estableció", "inició"],
2468 confidence: 0.8,
2469 },
2470 RelPattern {
2471 slug: "LOCATED_IN",
2472 triggers: &[
2473 "en",
2474 "ubicado en",
2475 "situado en",
2476 "basado en",
2477 "localizado en",
2478 ],
2479 confidence: 0.6,
2480 },
2481 RelPattern {
2482 slug: "BORN_IN",
2483 triggers: &["nació en", "nacido en", "originario de", "de"],
2484 confidence: 0.7,
2485 },
2486 RelPattern {
2487 slug: "LIVES_IN",
2488 triggers: &["cerno en", "reside en", "viviendo en"],
2489 confidence: 0.65,
2490 },
2491 RelPattern {
2492 slug: "MARRIED_TO",
2493 triggers: &["casado con", "casada con", "esposo de", "esposa de"],
2494 confidence: 0.85,
2495 },
2496 // French (fr)
2497 RelPattern {
2498 slug: "WORKS_FOR",
2499 triggers: &[
2500 "travaille pour",
2501 "travaille à",
2502 "employé de",
2503 "travaille avec",
2504 ],
2505 confidence: 0.7,
2506 },
2507 RelPattern {
2508 slug: "FOUNDED",
2509 triggers: &["fondé", "fondée", "créé", "créée", "établi", "établie"],
2510 confidence: 0.8,
2511 },
2512 RelPattern {
2513 slug: "LOCATED_IN",
2514 triggers: &["dans", "à", "situé en", "basé en", "localisé en"],
2515 confidence: 0.6,
2516 },
2517 RelPattern {
2518 slug: "BORN_IN",
2519 triggers: &["né en", "née en", "originaire de", "de"],
2520 confidence: 0.7,
2521 },
2522 RelPattern {
2523 slug: "LIVES_IN",
2524 triggers: &["vit en", "réside en", "vivant en"],
2525 confidence: 0.65,
2526 },
2527 RelPattern {
2528 slug: "MARRIED_TO",
2529 triggers: &["marié avec", "mariée avec", "époux de", "épouse de"],
2530 confidence: 0.85,
2531 },
2532 // German (de)
2533 RelPattern {
2534 slug: "WORKS_FOR",
2535 triggers: &[
2536 "arbeitet für",
2537 "arbeitet bei",
2538 "angestellt bei",
2539 "arbeitet mit",
2540 ],
2541 confidence: 0.7,
2542 },
2543 RelPattern {
2544 slug: "FOUNDED",
2545 triggers: &[
2546 "gegründet",
2547 "gründete",
2548 "erstellt",
2549 "errichtet",
2550 "etabliert",
2551 ],
2552 confidence: 0.8,
2553 },
2554 RelPattern {
2555 slug: "LOCATED_IN",
2556 triggers: &["in", "bei", "situiert in", "basiert in", "befindet sich in"],
2557 confidence: 0.6,
2558 },
2559 RelPattern {
2560 slug: "BORN_IN",
2561 triggers: &["geboren in", "geboren am", "stammt aus", "aus"],
2562 confidence: 0.7,
2563 },
2564 RelPattern {
2565 slug: "LIVES_IN",
2566 triggers: &["lebt in", "wohnt in", "lebend in"],
2567 confidence: 0.65,
2568 },
2569 RelPattern {
2570 slug: "MARRIED_TO",
2571 triggers: &["verheiratet mit", "ehemann von", "ehefrau von"],
2572 confidence: 0.85,
2573 },
2574 // Chinese (zh) - Simplified
2575 RelPattern {
2576 slug: "WORKS_FOR",
2577 triggers: &["为", "在", "工作于", "就职于", "任职于"],
2578 confidence: 0.7,
2579 },
2580 RelPattern {
2581 slug: "FOUNDED",
2582 triggers: &["创立", "创建", "建立", "成立", "创办"],
2583 confidence: 0.8,
2584 },
2585 RelPattern {
2586 slug: "LOCATED_IN",
2587 triggers: &["在", "位于", "坐落于", "地处"],
2588 confidence: 0.6,
2589 },
2590 RelPattern {
2591 slug: "BORN_IN",
2592 triggers: &["出生于", "生于", "来自", "出生于"],
2593 confidence: 0.7,
2594 },
2595 RelPattern {
2596 slug: "LIVES_IN",
2597 triggers: &["居住于", "住在", "生活在"],
2598 confidence: 0.65,
2599 },
2600 RelPattern {
2601 slug: "MARRIED_TO",
2602 triggers: &["与...结婚", "嫁给", "娶了"],
2603 confidence: 0.85,
2604 },
2605 // Japanese (ja)
2606 RelPattern {
2607 slug: "WORKS_FOR",
2608 triggers: &["で働く", "に勤務", "に所属", "で就職"],
2609 confidence: 0.7,
2610 },
2611 RelPattern {
2612 slug: "FOUNDED",
2613 triggers: &["設立", "創立", "設立した", "創設"],
2614 confidence: 0.8,
2615 },
2616 RelPattern {
2617 slug: "LOCATED_IN",
2618 triggers: &["に", "で", "に位置", "に所在"],
2619 confidence: 0.6,
2620 },
2621 RelPattern {
2622 slug: "BORN_IN",
2623 triggers: &["に生まれた", "の出身", "で生まれた"],
2624 confidence: 0.7,
2625 },
2626 RelPattern {
2627 slug: "LIVES_IN",
2628 triggers: &["に住む", "に居住", "に在住"],
2629 confidence: 0.65,
2630 },
2631 RelPattern {
2632 slug: "MARRIED_TO",
2633 triggers: &["と結婚", "と結婚した", "の配偶者"],
2634 confidence: 0.85,
2635 },
2636 // Arabic (ar) - RTL
2637 RelPattern {
2638 slug: "WORKS_FOR",
2639 triggers: &["يعمل في", "يعمل لصالح", "موظف في", "يعمل مع"],
2640 confidence: 0.7,
2641 },
2642 RelPattern {
2643 slug: "FOUNDED",
2644 triggers: &["أسس", "أنشأ", "تأسست", "أنشأت"],
2645 confidence: 0.8,
2646 },
2647 RelPattern {
2648 slug: "LOCATED_IN",
2649 triggers: &["في", "ب", "يقع في", "موجود في"],
2650 confidence: 0.6,
2651 },
2652 RelPattern {
2653 slug: "BORN_IN",
2654 triggers: &["ولد في", "من مواليد", "من"],
2655 confidence: 0.7,
2656 },
2657 RelPattern {
2658 slug: "LIVES_IN",
2659 triggers: &["يعيش في", "يسكن في", "مقيم في"],
2660 confidence: 0.65,
2661 },
2662 RelPattern {
2663 slug: "MARRIED_TO",
2664 triggers: &["متزوج من", "زوج", "زوجة"],
2665 confidence: 0.85,
2666 },
2667 // Russian (ru)
2668 RelPattern {
2669 slug: "WORKS_FOR",
2670 triggers: &["работает в", "работает на", "работает для", "сотрудник"],
2671 confidence: 0.7,
2672 },
2673 RelPattern {
2674 slug: "FOUNDED",
2675 triggers: &["основал", "основала", "создал", "создала", "учредил"],
2676 confidence: 0.8,
2677 },
2678 RelPattern {
2679 slug: "LOCATED_IN",
2680 triggers: &["в", "на", "расположен в", "находится в"],
2681 confidence: 0.6,
2682 },
2683 RelPattern {
2684 slug: "BORN_IN",
2685 triggers: &["родился в", "родилась в", "родом из", "из"],
2686 confidence: 0.7,
2687 },
2688 RelPattern {
2689 slug: "LIVES_IN",
2690 triggers: &["живет в", "проживает в", "живущий в"],
2691 confidence: 0.65,
2692 },
2693 RelPattern {
2694 slug: "MARRIED_TO",
2695 triggers: &["женат на", "замужем за", "супруг", "супруга"],
2696 confidence: 0.85,
2697 },
2698 ];
2699
2700 for pattern in patterns {
2701 // Find the canonical label in the registry (case-insensitive).
2702 // We return the label's *original* slug so callers preserve user-provided casing.
2703 let label = match relation_labels.iter().find(|l| {
2704 // Match both:
2705 // - exact canonical names (e.g. "PART_OF")
2706 // - normalized dataset slugs (e.g. "part-of" -> "PART_OF")
2707 norm_rel_slug(&l.slug) == pattern.slug || l.slug.eq_ignore_ascii_case(pattern.slug)
2708 }) {
2709 Some(l) => *l,
2710 None => continue,
2711 };
2712
2713 for trigger in pattern.triggers {
2714 if let Some(pos) = between_lower.find(trigger) {
2715 // Validate entity types make sense for the relation
2716 let valid = match pattern.slug {
2717 // Person-Organization relations
2718 "CEO_OF" | "WORKS_FOR" | "FOUNDED" | "MANAGES" | "REPORTS_TO" => {
2719 // If either side is unknown/misc, don't reject on type alone (relation datasets
2720 // often use a richer schema than `EntityType`).
2721 matches!(
2722 head.entity_type,
2723 EntityType::Other(_) | EntityType::Custom { .. }
2724 ) || matches!(
2725 tail.entity_type,
2726 EntityType::Other(_) | EntityType::Custom { .. }
2727 ) || (matches!(head.entity_type, EntityType::Person)
2728 && matches!(tail.entity_type, EntityType::Organization))
2729 }
2730 // Location relations (any entity can be located in/born in a location)
2731 "LOCATED_IN" | "BORN_IN" | "LIVES_IN" | "DIED_IN" => {
2732 matches!(
2733 tail.entity_type,
2734 EntityType::Other(_) | EntityType::Custom { .. }
2735 ) || matches!(tail.entity_type, EntityType::Location)
2736 }
2737 // Temporal relations (any entity can have temporal attributes)
2738 "OCCURRED_ON" | "STARTED_ON" | "ENDED_ON" => {
2739 matches!(
2740 tail.entity_type,
2741 EntityType::Other(_) | EntityType::Custom { .. }
2742 ) || matches!(tail.entity_type, EntityType::Date | EntityType::Time)
2743 }
2744 // Organizational relations
2745 "PART_OF" | "ACQUIRED" | "MERGED_WITH" | "PARENT_OF" => {
2746 matches!(
2747 head.entity_type,
2748 EntityType::Other(_) | EntityType::Custom { .. }
2749 ) || matches!(
2750 tail.entity_type,
2751 EntityType::Other(_) | EntityType::Custom { .. }
2752 ) || (matches!(head.entity_type, EntityType::Organization)
2753 && matches!(tail.entity_type, EntityType::Organization))
2754 }
2755 // Social relations
2756 "MARRIED_TO" | "CHILD_OF" | "SIBLING_OF" => {
2757 matches!(
2758 head.entity_type,
2759 EntityType::Other(_) | EntityType::Custom { .. }
2760 ) || matches!(
2761 tail.entity_type,
2762 EntityType::Other(_) | EntityType::Custom { .. }
2763 ) || (matches!(head.entity_type, EntityType::Person)
2764 && matches!(tail.entity_type, EntityType::Person))
2765 }
2766 // Academic relations
2767 "STUDIED_AT" | "TEACHES_AT" => {
2768 matches!(
2769 head.entity_type,
2770 EntityType::Other(_) | EntityType::Custom { .. }
2771 ) || matches!(
2772 tail.entity_type,
2773 EntityType::Other(_) | EntityType::Custom { .. }
2774 ) || (matches!(head.entity_type, EntityType::Person)
2775 && matches!(
2776 tail.entity_type,
2777 EntityType::Organization | EntityType::Location
2778 ))
2779 }
2780 // Product relations
2781 "DEVELOPS" | "USES" => {
2782 matches!(
2783 head.entity_type,
2784 EntityType::Other(_) | EntityType::Custom { .. }
2785 ) || matches!(
2786 head.entity_type,
2787 EntityType::Organization | EntityType::Person
2788 )
2789 }
2790 // Interaction relations
2791 "MET_WITH" | "SPOKE_WITH" => {
2792 matches!(
2793 head.entity_type,
2794 EntityType::Other(_) | EntityType::Custom { .. }
2795 ) || matches!(
2796 tail.entity_type,
2797 EntityType::Other(_) | EntityType::Custom { .. }
2798 ) || (matches!(head.entity_type, EntityType::Person)
2799 && matches!(
2800 tail.entity_type,
2801 EntityType::Person | EntityType::Organization
2802 ))
2803 }
2804 // Ownership
2805 "OWNS" => {
2806 matches!(
2807 head.entity_type,
2808 EntityType::Other(_) | EntityType::Custom { .. }
2809 ) || matches!(
2810 head.entity_type,
2811 EntityType::Person | EntityType::Organization
2812 )
2813 }
2814 _ => true, // Default: allow any combination
2815 };
2816
2817 if valid {
2818 return Some((
2819 label.slug.as_str(),
2820 pattern.confidence,
2821 Some((pos, pos + trigger.len())),
2822 ));
2823 }
2824 }
2825 }
2826 }
2827
2828 None
2829}
2830
2831// =============================================================================
2832// Binary Embeddings for Fast Blocking (Research: Hamming Distance)
2833// =============================================================================
2834
/// Binary hash for fast approximate nearest neighbor search.
///
/// # Research Background
///
/// Binary embeddings enable sub-linear search via Hamming distance. Key insight
/// from our research synthesis: **binary embeddings are for blocking, not primary
/// retrieval**. The sign-rank limitation means they cannot represent all similarity
/// relationships, but they excel at fast candidate filtering.
///
/// # Two-Stage Retrieval Pattern
///
/// ```text
/// Query → [Binary Hash] → Hamming Filter (fast) → Candidates
///                                                      ↓
///                                              [Dense Similarity]
///                                                      ↓
///                                               Final Results
/// ```
///
/// # Example
///
/// ```rust
/// use anno::backends::inference::BinaryHash;
///
/// // Create hashes from embeddings
/// let hash1 = BinaryHash::from_embedding(&[0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8]);
/// let hash2 = BinaryHash::from_embedding(&[0.15, -0.25, 0.35, -0.45, 0.55, -0.65, 0.75, -0.85]);
///
/// // Similar embeddings → low Hamming distance
/// assert!(hash1.hamming_distance(&hash2) < 2);
/// ```
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BinaryHash {
    /// Packed bits (each u64 holds 64 bits)
    pub bits: Vec<u64>,
    /// Original dimension (number of bits)
    pub dim: usize,
}

impl BinaryHash {
    /// Pack a stream of sign bits (`true` = positive component) into u64 words.
    ///
    /// Shared implementation for [`Self::from_embedding`] and
    /// [`Self::from_embedding_f64`], which previously duplicated this loop.
    fn pack_signs(signs: impl Iterator<Item = bool>, dim: usize) -> Self {
        let mut bits = vec![0u64; dim.div_ceil(64)];
        for (i, positive) in signs.enumerate() {
            if positive {
                bits[i / 64] |= 1u64 << (i % 64);
            }
        }
        Self { bits, dim }
    }

    /// Create from a dense embedding using sign function.
    ///
    /// Each positive value → 1, each negative/zero value → 0.
    #[must_use]
    pub fn from_embedding(embedding: &[f32]) -> Self {
        Self::pack_signs(embedding.iter().map(|&v| v > 0.0), embedding.len())
    }

    /// Create from a dense f64 embedding.
    ///
    /// Same sign rule as [`Self::from_embedding`]: positive → 1, otherwise 0.
    #[must_use]
    pub fn from_embedding_f64(embedding: &[f64]) -> Self {
        Self::pack_signs(embedding.iter().map(|&v| v > 0.0), embedding.len())
    }

    /// Compute Hamming distance (number of differing bits).
    ///
    /// Uses POPCNT instruction when available for hardware acceleration.
    ///
    /// NOTE: if the two hashes have different dimensions, the comparison
    /// silently covers only the shorter word vector (zip truncates).
    #[must_use]
    pub fn hamming_distance(&self, other: &Self) -> u32 {
        self.bits
            .iter()
            .zip(other.bits.iter())
            .map(|(a, b)| (a ^ b).count_ones())
            .sum()
    }

    /// Compute normalized Hamming distance (0.0 to 1.0).
    ///
    /// Returns 0.0 for zero-dimensional hashes (avoids division by zero).
    #[must_use]
    pub fn hamming_distance_normalized(&self, other: &Self) -> f64 {
        if self.dim == 0 {
            return 0.0;
        }
        self.hamming_distance(other) as f64 / self.dim as f64
    }

    /// Convert Hamming distance to approximate cosine similarity.
    ///
    /// Based on the relationship: cos(θ) ≈ 1 - 2 * (hamming_distance / dim)
    /// This is an approximation valid for random hyperplane hashing.
    #[must_use]
    pub fn approximate_cosine(&self, other: &Self) -> f64 {
        1.0 - 2.0 * self.hamming_distance_normalized(other)
    }
}
2943
2944/// Blocker using binary embeddings for fast candidate filtering.
2945///
2946/// # Usage Pattern
2947///
2948/// 1. Pre-compute binary hashes for all entities in your KB
2949/// 2. At query time, hash the query embedding
2950/// 3. Find candidates within Hamming distance threshold
2951/// 4. Run dense similarity only on candidates
2952///
2953/// # Example
2954///
2955/// ```rust
2956/// use anno::backends::inference::{BinaryBlocker, BinaryHash};
2957///
2958/// let mut blocker = BinaryBlocker::new(8); // 8-bit Hamming threshold
2959///
2960/// // Add entities to the index
2961/// let hash1 = BinaryHash::from_embedding(&vec![0.1; 768]);
2962/// let hash2 = BinaryHash::from_embedding(&vec![-0.1; 768]);
2963/// blocker.add(0, hash1);
2964/// blocker.add(1, hash2);
2965///
2966/// // Query
2967/// let query = BinaryHash::from_embedding(&vec![0.1; 768]);
2968/// let candidates = blocker.query(&query);
2969/// assert!(candidates.contains(&0)); // Similar to hash1
2970/// ```
2971#[derive(Debug, Clone)]
2972pub struct BinaryBlocker {
2973 /// Hamming distance threshold for candidates
2974 pub threshold: u32,
2975 /// Index of hashes by ID
2976 index: Vec<(usize, BinaryHash)>,
2977}
2978
2979impl BinaryBlocker {
2980 /// Create a new blocker with the given threshold.
2981 #[must_use]
2982 pub fn new(threshold: u32) -> Self {
2983 Self {
2984 threshold,
2985 index: Vec::new(),
2986 }
2987 }
2988
2989 /// Add an entity to the index.
2990 pub fn add(&mut self, id: usize, hash: BinaryHash) {
2991 self.index.push((id, hash));
2992 }
2993
2994 /// Add multiple entities.
2995 pub fn add_batch(&mut self, entries: impl IntoIterator<Item = (usize, BinaryHash)>) {
2996 self.index.extend(entries);
2997 }
2998
2999 /// Find candidate IDs within Hamming distance threshold.
3000 #[must_use]
3001 pub fn query(&self, query: &BinaryHash) -> Vec<usize> {
3002 self.index
3003 .iter()
3004 .filter(|(_, hash)| hash.hamming_distance(query) <= self.threshold)
3005 .map(|(id, _)| *id)
3006 .collect()
3007 }
3008
3009 /// Find candidates with their distances.
3010 #[must_use]
3011 pub fn query_with_distance(&self, query: &BinaryHash) -> Vec<(usize, u32)> {
3012 self.index
3013 .iter()
3014 .map(|(id, hash)| (*id, hash.hamming_distance(query)))
3015 .filter(|(_, dist)| *dist <= self.threshold)
3016 .collect()
3017 }
3018
3019 /// Number of entries in the index.
3020 #[must_use]
3021 pub fn len(&self) -> usize {
3022 self.index.len()
3023 }
3024
3025 /// Check if index is empty.
3026 #[must_use]
3027 pub fn is_empty(&self) -> bool {
3028 self.index.is_empty()
3029 }
3030
3031 /// Clear the index.
3032 pub fn clear(&mut self) {
3033 self.index.clear();
3034 }
3035}
3036
3037/// Recommended two-stage retrieval using binary blocking + dense reranking.
3038///
3039/// # Research Context
3040///
3041/// This implements the pattern identified in our research synthesis:
3042/// - Stage 1: Binary blocking for O(n) candidate filtering
3043/// - Stage 2: Dense similarity for accurate ranking
3044///
3045/// The key insight is that binary embeddings have fundamental limitations
3046/// (sign-rank theorem) but excel at fast filtering.
3047///
3048/// # Arguments
3049///
3050/// * `query_embedding` - Dense query embedding
3051/// * `candidate_embeddings` - Dense embeddings of all candidates
3052/// * `binary_threshold` - Hamming distance threshold for blocking
3053/// * `top_k` - Number of final results to return
3054///
3055/// # Returns
3056///
3057/// Vector of (candidate_index, similarity_score) pairs, sorted by score descending.
3058#[must_use]
3059pub fn two_stage_retrieval(
3060 query_embedding: &[f32],
3061 candidate_embeddings: &[Vec<f32>],
3062 binary_threshold: u32,
3063 top_k: usize,
3064) -> Vec<(usize, f32)> {
3065 // Stage 1: Binary blocking
3066 let query_hash = BinaryHash::from_embedding(query_embedding);
3067
3068 let candidate_hashes: Vec<BinaryHash> = candidate_embeddings
3069 .iter()
3070 .map(|e| BinaryHash::from_embedding(e))
3071 .collect();
3072
3073 let mut blocker = BinaryBlocker::new(binary_threshold);
3074 for (i, hash) in candidate_hashes.into_iter().enumerate() {
3075 blocker.add(i, hash);
3076 }
3077
3078 let candidates = blocker.query(&query_hash);
3079
3080 // Stage 2: Dense similarity on candidates only
3081 // Performance: Pre-allocate scored vec with known size
3082 let mut scored: Vec<(usize, f32)> = Vec::with_capacity(candidates.len());
3083 scored.extend(candidates.into_iter().map(|idx| {
3084 let sim = cosine_similarity_f32(query_embedding, &candidate_embeddings[idx]);
3085 (idx, sim)
3086 }));
3087
3088 // Performance: Use unstable sort (we don't need stable sort here)
3089 // Sort by similarity descending
3090 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
3091 scored.truncate(top_k);
3092 scored
3093}
3094
/// Compute cosine similarity between two f32 vectors.
///
/// Returns 0.0 for mismatched lengths, empty inputs, or zero-norm vectors.
#[must_use]
pub fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate dot product and squared norms in a single pass.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    // Keep the zero checks on the individual norms (not the product) so
    // near-underflow inputs behave the same as before.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot / (norm_a * norm_b)
}
3112
3113// =============================================================================
3114// Tests
3115// =============================================================================
3116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_semantic_registry_builder() {
        // Two entity labels + one relation label → total len of 3.
        let registry = SemanticRegistry::builder()
            .add_entity("person", "A human being")
            .add_entity("organization", "A company or group")
            .add_relation("WORKS_FOR", "Employment relationship")
            .build_placeholder(768);

        assert_eq!(registry.len(), 3);
        assert_eq!(registry.entity_labels().count(), 2);
        assert_eq!(registry.relation_labels().count(), 1);
    }

    #[test]
    fn test_standard_ner_registry() {
        // The built-in NER registry must contain at least the core labels.
        let registry = SemanticRegistry::standard_ner(768);
        assert!(registry.len() >= 5);
        assert!(registry.label_index.contains_key("person"));
        assert!(registry.label_index.contains_key("organization"));
    }

    #[test]
    fn test_dot_product_interaction() {
        let interaction = DotProductInteraction::new();

        // 2 spans, 3 labels, hidden_dim=4
        let span_embs = vec![
            1.0, 0.0, 0.0, 0.0, // span 0
            0.0, 1.0, 0.0, 0.0, // span 1
        ];
        let label_embs = vec![
            1.0, 0.0, 0.0, 0.0, // label 0 (matches span 0)
            0.0, 1.0, 0.0, 0.0, // label 1 (matches span 1)
            0.5, 0.5, 0.0, 0.0, // label 2 (partial match both)
        ];

        let scores = interaction.compute_similarity(&span_embs, 2, &label_embs, 3, 4);

        // Scores are laid out row-major: span-major, label-minor.
        assert_eq!(scores.len(), 6); // 2 * 3
        assert!((scores[0] - 1.0).abs() < 0.01); // span0 vs label0
        assert!((scores[4] - 1.0).abs() < 0.01); // span1 vs label1
    }

    #[test]
    fn test_cosine_similarity() {
        // Identical vectors → similarity 1.0.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        // Orthogonal vectors → similarity 0.0.
        let c = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&a, &c).abs() < 0.001);

        // Opposite vectors → similarity -1.0.
        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &d) - (-1.0)).abs() < 0.001);
    }

    #[test]
    fn test_coreference_string_match() {
        // "Curie" is a suffix of "Marie Curie" → should cluster together
        // even with placeholder (all-zero) embeddings.
        let entities = vec![
            Entity::new("Marie Curie", EntityType::Person, 0, 11, 0.95),
            Entity::new("Curie", EntityType::Person, 50, 55, 0.90),
        ];

        let embeddings = vec![0.0f32; 2 * 768]; // Placeholder
        let clusters =
            resolve_coreferences(&entities, &embeddings, 768, &CoreferenceConfig::default());

        assert_eq!(clusters.len(), 1);
        assert_eq!(clusters[0].members.len(), 2);
        // Canonical name is the longer mention.
        assert_eq!(clusters[0].canonical_name, "Marie Curie");
    }

    #[test]
    fn test_handshaking_matrix() {
        // 3 tokens, 2 labels, threshold 0.5
        let scores = vec![
            // token 0 with tokens 0,1,2 for labels 0,1
            0.9, 0.1, // (0,0)
            0.2, 0.8, // (0,1)
            0.1, 0.1, // (0,2)
            // token 1 with tokens 0,1,2
            0.0, 0.0, // (1,0) - skipped (lower triangle)
            0.7, 0.2, // (1,1)
            0.3, 0.6, // (1,2)
            // token 2
            0.0, 0.0, // (2,0)
            0.0, 0.0, // (2,1)
            0.1, 0.1, // (2,2)
        ];

        let matrix = HandshakingMatrix::from_dense(&scores, 3, 2, 0.5);

        // Should have cells for scores >= 0.5 (0.9, 0.8, 0.7, 0.6 above).
        assert!(matrix.cells.len() >= 4);
    }

    #[test]
    fn test_relation_extraction() {
        // Person + Organization with a "founded" trigger between them.
        let entities = vec![
            Entity::new("Steve Jobs", EntityType::Person, 0, 10, 0.95),
            Entity::new("Apple", EntityType::Organization, 20, 25, 0.90),
        ];

        let text = "Steve Jobs founded Apple Inc in 1976";

        let registry = SemanticRegistry::builder()
            .add_relation("FOUNDED", "Founded an organization")
            .build_placeholder(768);

        let config = RelationExtractionConfig::default();
        let relations = extract_relations(&entities, text, &registry, &config);

        assert!(!relations.is_empty());
        assert_eq!(relations[0].relation_type, "FOUNDED");
    }

    #[test]
    fn test_relation_extraction_uses_character_offsets_with_unicode_prefix() {
        // Unicode prefix ensures byte offsets != character offsets.
        let text = "👋 Steve Jobs founded Apple Inc.";

        // Compute character offsets explicitly (Entity spans are char-based).
        let steve_start = text.find("Steve Jobs").expect("substring present");
        // `find` returns byte offset; convert to char offset.
        let conv = crate::offset::SpanConverter::new(text);
        let steve_start_char = conv.byte_to_char(steve_start);
        let steve_end_char = steve_start_char + "Steve Jobs".chars().count();

        let apple_start = text.find("Apple").expect("substring present");
        let apple_start_char = conv.byte_to_char(apple_start);
        let apple_end_char = apple_start_char + "Apple".chars().count();

        let entities = vec![
            Entity::new(
                "Steve Jobs",
                EntityType::Person,
                steve_start_char,
                steve_end_char,
                0.95,
            ),
            Entity::new(
                "Apple",
                EntityType::Organization,
                apple_start_char,
                apple_end_char,
                0.90,
            ),
        ];

        let registry = SemanticRegistry::builder()
            .add_relation("FOUNDED", "Founded an organization")
            .build_placeholder(768);

        let config = RelationExtractionConfig::default();
        let relations = extract_relations(&entities, text, &registry, &config);

        assert!(
            !relations.is_empty(),
            "Expected FOUNDED relation to be detected"
        );
        assert_eq!(relations[0].relation_type, "FOUNDED");

        // Trigger span should exist and cover "founded" in character offsets.
        let trigger = relations[0]
            .trigger_span
            .expect("expected trigger_span to be present");
        let trigger_text: String = text
            .chars()
            .skip(trigger.0)
            .take(trigger.1.saturating_sub(trigger.0))
            .collect();
        assert_eq!(trigger_text.to_ascii_lowercase(), "founded");
    }

    // =========================================================================
    // Binary Embedding Tests
    // =========================================================================

    #[test]
    fn test_binary_hash_creation() {
        // Sign-based binarization: positive components set their bit.
        let embedding = vec![0.1, -0.2, 0.3, -0.4, 0.5, -0.6, 0.7, -0.8];
        let hash = BinaryHash::from_embedding(&embedding);

        assert_eq!(hash.dim, 8);
        // Positive values at indices 0, 2, 4, 6 should be set
        // bits[0] should have bits 0, 2, 4, 6 set = 0b01010101 = 85
        assert_eq!(hash.bits[0], 85);
    }

    #[test]
    fn test_hamming_distance_identical() {
        // Same embedding → same hash → distance 0.
        let embedding = vec![0.1; 64];
        let hash1 = BinaryHash::from_embedding(&embedding);
        let hash2 = BinaryHash::from_embedding(&embedding);

        assert_eq!(hash1.hamming_distance(&hash2), 0);
    }

    #[test]
    fn test_hamming_distance_opposite() {
        // Sign-flipped embedding flips every bit → distance == dim.
        let embedding1 = vec![0.1; 64];
        let embedding2 = vec![-0.1; 64];
        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        assert_eq!(hash1.hamming_distance(&hash2), 64);
    }

    #[test]
    fn test_hamming_distance_half() {
        let embedding1 = vec![0.1; 64];
        let mut embedding2 = vec![0.1; 64];
        // Flip second half
        embedding2[32..64].iter_mut().for_each(|x| *x = -0.1);

        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        assert_eq!(hash1.hamming_distance(&hash2), 32);
    }

    #[test]
    fn test_binary_blocker() {
        // Max Hamming distance of 5 for candidate retrieval.
        let mut blocker = BinaryBlocker::new(5);

        // Add some hashes
        let base_embedding = vec![0.1; 64];
        let similar_embedding = {
            let mut e = vec![0.1; 64];
            e[0] = -0.1; // Flip 1 bit
            e[1] = -0.1; // Flip 2 bits
            e
        };
        let different_embedding = vec![-0.1; 64];

        blocker.add(0, BinaryHash::from_embedding(&base_embedding));
        blocker.add(1, BinaryHash::from_embedding(&similar_embedding));
        blocker.add(2, BinaryHash::from_embedding(&different_embedding));

        // Query with base
        let query = BinaryHash::from_embedding(&base_embedding);
        let candidates = blocker.query(&query);

        assert!(candidates.contains(&0), "Should find exact match");
        assert!(
            candidates.contains(&1),
            "Should find similar (2 bits different)"
        );
        assert!(
            !candidates.contains(&2),
            "Should NOT find opposite (64 bits different)"
        );
    }

    #[test]
    fn test_two_stage_retrieval() {
        // Create embeddings
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let candidates = vec![
            vec![1.0, 0.0, 0.0, 0.0],  // Identical
            vec![0.9, 0.1, 0.0, 0.0],  // Similar
            vec![-1.0, 0.0, 0.0, 0.0], // Opposite
            vec![0.0, 1.0, 0.0, 0.0],  // Orthogonal
        ];

        // Generous threshold to get candidates
        let results = two_stage_retrieval(&query, &candidates, 4, 2);

        assert!(!results.is_empty());
        // First result should be exact match
        assert_eq!(results[0].0, 0);
        assert!((results[0].1 - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_approximate_cosine() {
        let embedding1 = vec![0.1; 768];
        let embedding2 = vec![0.1; 768];
        let hash1 = BinaryHash::from_embedding(&embedding1);
        let hash2 = BinaryHash::from_embedding(&embedding2);

        // Identical → approximate cosine should be ~1.0
        let approx = hash1.approximate_cosine(&hash2);
        assert!((approx - 1.0).abs() < 0.001);

        // Opposite → approximate cosine should be ~-1.0
        let embedding3 = vec![-0.1; 768];
        let hash3 = BinaryHash::from_embedding(&embedding3);
        let approx_opp = hash1.approximate_cosine(&hash3);
        assert!((approx_opp - (-1.0)).abs() < 0.001);
    }
}