Skip to main content

anno/backends/inference/
traits.rs

1//! Core extraction traits: ZeroShotNER, RelationExtractor, RelationCapable defaults,
2//! and DiscontinuousNER.
3
4#[allow(unused_imports)]
5use crate::{Entity, EntityType, Relation};
6
// =============================================================================
// Zero-Shot NER Trait
// =============================================================================
9
10/// Zero-shot NER for open entity types.
11///
12/// # Motivation
13///
14/// Traditional NER models are trained on fixed taxonomies (PER, ORG, LOC, etc.)
15/// and cannot extract new entity types without retraining. Zero-shot NER solves
16/// this by allowing **arbitrary entity types at inference time**.
17///
18/// Instead of asking "is this a PERSON?", zero-shot NER asks "does this text
19/// span match the description 'a named individual human being'?"
20///
21/// # Use Cases
22///
23/// - **Domain adaptation**: Extract "gene names" or "legal citations" without
24///   training data
25/// - **Custom taxonomies**: Use your own entity hierarchy
26/// - **Rapid prototyping**: Test new entity types before investing in annotation
27///
28/// # Research Alignment
29///
30/// From GLiNER (arXiv:2311.08526):
31/// > "NER model capable of identifying any entity type using a bidirectional
32/// > transformer encoder... provides a practical alternative to traditional
33/// > NER models, which are limited to predefined entity types."
34///
35/// From UniversalNER (arXiv:2308.03279):
36/// > "Large language models demonstrate remarkable generalizability, such as
37/// > understanding arbitrary entities and relations."
38///
39/// # Example
40///
41/// ```ignore
42/// use anno::ZeroShotNER;
43///
44/// fn extract_medical_entities(ner: &dyn ZeroShotNER, clinical_note: &str) {
45///     // Define custom medical entity types at runtime
46///     let types = &["drug name", "disease", "symptom", "dosage"];
47///
48///     let entities = ner.extract_with_types(clinical_note, types, 0.5).unwrap();
49///     for e in entities {
50///         println!("{}: {} (conf: {:.2})", e.entity_type, e.text, e.confidence);
51///     }
52/// }
53///
54/// fn extract_with_descriptions(ner: &dyn ZeroShotNER, text: &str) {
55///     // Even richer: use natural language descriptions
56///     let descriptions = &[
57///         "a medication or pharmaceutical compound",
58///         "a medical condition or illness",
59///         "a physical sensation indicating illness",
60///     ];
61///
62///     let entities = ner.extract_with_descriptions(text, descriptions, 0.5).unwrap();
63/// }
64/// ```
65pub trait ZeroShotNER: Send + Sync {
66    /// Extract entities with custom types.
67    ///
68    /// # Arguments
69    /// * `text` - Input text
70    /// * `entity_types` - Entity type descriptions (arbitrary text, not fixed vocabulary)
71    ///   - Encoded as text embeddings via bi-encoder (semantic matching, not exact string match)
72    ///   - Any string works: `"disease"`, `"pharmaceutical compound"`, `"19th century French philosopher"`
73    ///   - **Replaces default types completely** - model only extracts the specified types
74    ///   - To include defaults, pass them explicitly: `&["person", "organization", "disease"]`
75    /// * `threshold` - Confidence threshold (0.0 - 1.0)
76    ///
77    /// # Returns
78    /// Entities with their matched types
79    ///
80    /// # Behavior
81    ///
82    /// - **Arbitrary text**: Type hints are not fixed vocabulary. They're encoded as embeddings,
83    ///   so semantic similarity determines matches (not exact string matching).
84    /// - **Replace, don't union**: This method completely replaces default entity types.
85    ///   The model only extracts the types you specify.
86    /// - **Semantic matching**: Uses cosine similarity between text span embeddings and label embeddings.
87    fn extract_with_types(
88        &self,
89        text: &str,
90        entity_types: &[&str],
91        threshold: f32,
92    ) -> crate::Result<Vec<Entity>>;
93
94    /// Extract entities with natural language descriptions.
95    ///
96    /// # Arguments
97    /// * `text` - Input text
98    /// * `descriptions` - Natural language descriptions of what to extract
99    ///   - Encoded as text embeddings (same as `extract_with_types`)
100    ///   - Examples: `"companies headquartered in Europe"`, `"diseases affecting the heart"`
101    ///   - **Replaces default types completely** - model only extracts the specified descriptions
102    /// * `threshold` - Confidence threshold
103    ///
104    /// # Behavior
105    ///
106    /// Same as `extract_with_types`, but accepts natural language descriptions instead of
107    /// short type labels. Both methods encode labels as embeddings and use semantic matching.
108    fn extract_with_descriptions(
109        &self,
110        text: &str,
111        descriptions: &[&str],
112        threshold: f32,
113    ) -> crate::Result<Vec<Entity>>;
114
115    /// Get default entity types for this model.
116    ///
117    /// Returns the entity types used by `extract_entities()` (via `Model` trait).
118    /// Useful for extending defaults: combine with custom types and pass to `extract_with_types()`.
119    ///
120    /// # Example: Extending defaults
121    ///
122    /// ```ignore
123    /// use anno::ZeroShotNER;
124    ///
125    /// let ner: &dyn ZeroShotNER = ...;
126    /// let defaults = ner.default_types();
127    ///
128    /// // Combine defaults with custom types
129    /// let mut types: Vec<&str> = defaults.to_vec();
130    /// types.extend(&["disease", "medication"]);
131    ///
132    /// let entities = ner.extract_with_types(text, &types, 0.5)?;
133    /// ```
134    fn default_types(&self) -> &[&'static str];
135}
136
137// =============================================================================
138// Relation Extractor Trait
139// =============================================================================
140
141/// Joint entity and relation extraction.
142///
143/// # Motivation
144///
145/// Real-world information extraction often requires both entities AND their
146/// relationships. For example, extracting "Steve Jobs" and "Apple" is useful,
147/// but knowing "Steve Jobs FOUNDED Apple" is far more valuable.
148///
149/// Joint extraction (vs pipeline) is preferred because:
150/// - **Error propagation**: Pipeline errors compound (bad entities → bad relations)
151/// - **Shared context**: Entities and relations inform each other
152/// - **Efficiency**: Single forward pass instead of two
153///
154/// # Architecture
155///
156/// ```text
157/// Input: "Steve Jobs founded Apple in 1976."
158///                │
159///                ▼
160/// ┌──────────────────────────────────┐
161/// │     Shared Encoder (BERT)        │
162/// └──────────────────────────────────┘
163///                │
164///         ┌──────┴──────┐
165///         ▼             ▼
166/// ┌───────────────┐  ┌───────────────┐
167/// │ Entity Head   │  │ Relation Head │
168/// │ (span class.) │  │ (pair class.) │
169/// └───────┬───────┘  └───────┬───────┘
170///         │                  │
171///         ▼                  ▼
172/// Entities:              Relations:
173/// - Steve Jobs [PER]     - (Steve Jobs, FOUNDED, Apple)
174/// - Apple [ORG]          - (Apple, FOUNDED_IN, 1976)
175/// - 1976 [DATE]
176/// ```
177///
178/// # Research Alignment
179///
180/// From GLiNER multi-task (arXiv:2406.12925):
181/// > "Generalist Lightweight Model for Various Information Extraction Tasks...
182/// > joint entity and relation extraction."
183///
184/// From W2NER (arXiv:2112.10070):
185/// > "Unified Named Entity Recognition as Word-Word Relation Classification...
186/// > handles flat, overlapped, and discontinuous NER."
187///
188/// # Example
189///
190/// ```ignore
191/// use anno::RelationExtractor;
192///
193/// fn build_knowledge_graph(extractor: &dyn RelationExtractor, text: &str) {
194///     let entity_types = &["person", "organization", "date"];
195///     let relation_types = &["founded", "works_for", "acquired"];
196///
197///     let result = extractor.extract_with_relations(
198///         text, entity_types, relation_types, 0.5
199///     ).unwrap();
200///
201///     // Build graph nodes from entities
202///     for e in &result.entities {
203///         println!("Node: {} ({})", e.text, e.entity_type);
204///     }
205///
206///     // Build graph edges from relations
207///     for r in &result.relations {
208///         let head = &result.entities[r.head_idx];
209///         let tail = &result.entities[r.tail_idx];
210///         println!("Edge: {} --[{}]--> {}", head.text, r.relation_type, tail.text);
211///     }
212/// }
213/// ```
214pub trait RelationExtractor: Send + Sync {
215    /// Extract entities and relations jointly.
216    ///
217    /// # Arguments
218    /// * `text` - Input text
219    /// * `entity_types` - Entity types to extract
220    /// * `relation_types` - Relation types to extract
221    /// * `threshold` - Confidence threshold
222    ///
223    /// # Returns
224    /// Entities and relations between them
225    fn extract_with_relations(
226        &self,
227        text: &str,
228        entity_types: &[&str],
229        relation_types: &[&str],
230        threshold: f32,
231    ) -> crate::Result<ExtractionWithRelations>;
232}
233
234/// Output from joint entity-relation extraction.
235#[derive(Debug, Clone, Default)]
236pub struct ExtractionWithRelations {
237    /// Extracted entities
238    pub entities: Vec<Entity>,
239    /// Relations between entities (indices into entities vec)
240    pub relations: Vec<RelationTriple>,
241}
242
/// A directed relation between two entities of an [`ExtractionWithRelations`].
///
/// Endpoints are referenced by position rather than by value: `head_idx` and `tail_idx`
/// index into the `entities` vector of the extraction this triple belongs to.
/// Consumers should not assume the indices are in bounds.
/// `ExtractionWithRelations::into_anno_relations` drops triples whose indices are out of
/// range.
#[derive(Debug, Clone, PartialEq)]
pub struct RelationTriple {
    /// Index of head entity in entities vec
    pub head_idx: usize,
    /// Index of tail entity in entities vec
    pub tail_idx: usize,
    /// Relation type label
    pub relation_type: String,
    /// Confidence score
    pub confidence: f32,
}
255
256// =============================================================================
257// Shared defaults for RelationCapable::extract_with_relations
258// =============================================================================
259
/// Fallback entity labels for the argument-free `RelationCapable` convenience interface.
///
/// The set is deliberately broad (people, organizations, places, dates, products, and
/// events) so it overlaps the common CoNLL / OntoNotes / ACE taxonomies. Callers that
/// need precise control should pass their own labels to
/// `RelationExtractor::extract_with_relations`.
pub(crate) const DEFAULT_ENTITY_TYPES: &[&str] = &[
    // Actors.
    "person",
    "organization",
    // Where and when.
    "location",
    "date",
    // Things and happenings.
    "product",
    "event",
];
272
/// Fallback relation labels for the argument-free `RelationCapable` convenience interface.
///
/// These pair with `DEFAULT_ENTITY_TYPES`. Callers that need precise control should pass
/// their own labels to `RelationExtractor::extract_with_relations`.
pub(crate) const DEFAULT_RELATION_TYPES: &[&str] = &[
    "founded",
    "works_for",
    "located_in",
    "acquired",
    "born_in",
    "member_of",
    "ceo_of",
    "part_of",
    "subsidiary_of",
];
285
286impl ExtractionWithRelations {
287    /// Convert index-based `RelationTriple`s into owned `anno_core::Relation` values.
288    ///
289    /// Indices that are out-of-bounds (should not happen but can in malformed output) are
290    /// silently dropped.
291    #[must_use]
292    pub fn into_anno_relations(self) -> (Vec<Entity>, Vec<crate::Relation>) {
293        let relations = self
294            .relations
295            .iter()
296            .filter_map(|t| {
297                let head = self.entities.get(t.head_idx)?.clone();
298                let tail = self.entities.get(t.tail_idx)?.clone();
299                Some(crate::Relation::new(
300                    head,
301                    tail,
302                    t.relation_type.clone(),
303                    t.confidence as f64,
304                ))
305            })
306            .collect();
307        (self.entities, relations)
308    }
309}
310
311// =============================================================================
312// Discontinuous Entity Support (W2NER Research)
313// =============================================================================
314
315/// Support for discontinuous entity spans.
316///
317/// # Motivation
318///
319/// Not all entities are contiguous text spans. In coordination structures,
320/// entities can be **discontinuous** - scattered across non-adjacent positions.
321///
322/// # Examples of Discontinuous Entities
323///
324/// ```text
325/// "New York and Los Angeles airports"
326///  ^^^^^^^^     ^^^^^^^^^^^ ^^^^^^^^
327///  └──────────────────────────┘
328///     LOCATION: "New York airports" (discontinuous!)
329///                ^^^^^^^^^^^ ^^^^^^^^
330///                └───────────┘
331///                LOCATION: "Los Angeles airports" (contiguous)
332///
333/// "protein A and B complex"
334///  ^^^^^^^^^ ^^^ ^^^^^^^^^
335///  └────────────────────┘
336///     PROTEIN: "protein A ... complex" (discontinuous!)
337/// ```
338///
339/// # NER Complexity Hierarchy
340///
341/// | Type | Description | Example |
342/// |------|-------------|---------|
343/// | Flat | Non-overlapping spans | "John works at Google" |
344/// | Nested | Overlapping spans | "\[New \[York\] City\]" |
345/// | Discontinuous | Non-contiguous | "New York and LA \[airports\]" |
346///
347/// # Research Alignment
348///
349/// From W2NER (arXiv:2112.10070):
350/// > "Named entity recognition has been involved with three major types,
351/// > including flat, overlapped (aka. nested), and discontinuous NER...
352/// > we propose a novel architecture to model NER as word-word relation
353/// > classification."
354///
355/// W2NER achieves this by building a **handshaking matrix** where each cell
356/// (i, j) indicates whether tokens i and j are part of the same entity.
357///
358/// # Example
359///
360/// ```ignore
361/// use anno::DiscontinuousNER;
362///
363/// fn extract_complex_entities(ner: &dyn DiscontinuousNER, text: &str) {
364///     let types = &["location", "protein"];
365///     let entities = ner.extract_discontinuous(text, types, 0.5).unwrap();
366///
367///     for e in entities {
368///         if e.is_contiguous() {
369///             println!("Contiguous {}: '{}'", e.entity_type, e.text);
370///         } else {
371///             println!("Discontinuous {}: '{}' spans: {:?}",
372///                      e.entity_type, e.text, e.spans);
373///         }
374///     }
375/// }
376/// ```
377pub trait DiscontinuousNER: Send + Sync {
378    /// Extract entities including discontinuous spans.
379    ///
380    /// # Arguments
381    /// * `text` - Input text
382    /// * `entity_types` - Entity types to extract
383    /// * `threshold` - Confidence threshold
384    ///
385    /// # Returns
386    /// Entities, potentially with multiple non-contiguous spans
387    fn extract_discontinuous(
388        &self,
389        text: &str,
390        entity_types: &[&str],
391        threshold: f32,
392    ) -> crate::Result<Vec<DiscontinuousEntity>>;
393}
394
/// An entity whose surface form may be split across several non-adjacent regions.
///
/// A single-element `spans` vector describes an ordinary contiguous mention. More than
/// one element describes a discontinuous mention such as "protein A ... complex".
#[derive(Debug, Clone, PartialEq)]
pub struct DiscontinuousEntity {
    /// `(start, end)` offset pairs, one per region of the entity (may be non-contiguous).
    ///
    /// NOTE(review): whether these are byte or character offsets, and whether `end` is
    /// exclusive, is defined by the producing backend. Confirm before slicing text with them.
    pub spans: Vec<(usize, usize)>,
    /// Concatenated text from all spans (the joining convention is backend-defined)
    pub text: String,
    /// Entity type label
    pub entity_type: String,
    /// Confidence score
    pub confidence: f32,
}
407
408impl DiscontinuousEntity {
409    /// Check if this entity is contiguous (single span).
410    pub fn is_contiguous(&self) -> bool {
411        self.spans.len() == 1
412    }
413
414    /// Convert to a standard Entity if contiguous.
415    pub fn to_entity(&self) -> Option<Entity> {
416        if self.is_contiguous() {
417            let (start, end) = self.spans[0];
418            Some(Entity::new(
419                self.text.clone(),
420                EntityType::from_label(&self.entity_type),
421                start,
422                end,
423                self.confidence as f64,
424            ))
425        } else {
426            None
427        }
428    }
429}
430
431// =============================================================================