anno/backends/inference/traits.rs
1//! Core extraction traits: ZeroShotNER, RelationExtractor, RelationCapable defaults,
2//! and DiscontinuousNER.
3
4#[allow(unused_imports)]
5use crate::{Entity, EntityType, Relation};
6
7// Zero-Shot NER Trait
8// =============================================================================
9
10/// Zero-shot NER for open entity types.
11///
12/// # Motivation
13///
14/// Traditional NER models are trained on fixed taxonomies (PER, ORG, LOC, etc.)
15/// and cannot extract new entity types without retraining. Zero-shot NER solves
16/// this by allowing **arbitrary entity types at inference time**.
17///
18/// Instead of asking "is this a PERSON?", zero-shot NER asks "does this text
19/// span match the description 'a named individual human being'?"
20///
21/// # Use Cases
22///
23/// - **Domain adaptation**: Extract "gene names" or "legal citations" without
24/// training data
25/// - **Custom taxonomies**: Use your own entity hierarchy
26/// - **Rapid prototyping**: Test new entity types before investing in annotation
27///
28/// # Research Alignment
29///
30/// From GLiNER (arXiv:2311.08526):
31/// > "NER model capable of identifying any entity type using a bidirectional
32/// > transformer encoder... provides a practical alternative to traditional
33/// > NER models, which are limited to predefined entity types."
34///
35/// From UniversalNER (arXiv:2308.03279):
36/// > "Large language models demonstrate remarkable generalizability, such as
37/// > understanding arbitrary entities and relations."
38///
39/// # Example
40///
41/// ```ignore
42/// use anno::ZeroShotNER;
43///
44/// fn extract_medical_entities(ner: &dyn ZeroShotNER, clinical_note: &str) {
45/// // Define custom medical entity types at runtime
46/// let types = &["drug name", "disease", "symptom", "dosage"];
47///
48/// let entities = ner.extract_with_types(clinical_note, types, 0.5).unwrap();
49/// for e in entities {
50/// println!("{}: {} (conf: {:.2})", e.entity_type, e.text, e.confidence);
51/// }
52/// }
53///
54/// fn extract_with_descriptions(ner: &dyn ZeroShotNER, text: &str) {
55/// // Even richer: use natural language descriptions
56/// let descriptions = &[
57/// "a medication or pharmaceutical compound",
58/// "a medical condition or illness",
59/// "a physical sensation indicating illness",
60/// ];
61///
62/// let entities = ner.extract_with_descriptions(text, descriptions, 0.5).unwrap();
63/// }
64/// ```
65pub trait ZeroShotNER: Send + Sync {
66 /// Extract entities with custom types.
67 ///
68 /// # Arguments
69 /// * `text` - Input text
70 /// * `entity_types` - Entity type descriptions (arbitrary text, not fixed vocabulary)
71 /// - Encoded as text embeddings via bi-encoder (semantic matching, not exact string match)
72 /// - Any string works: `"disease"`, `"pharmaceutical compound"`, `"19th century French philosopher"`
73 /// - **Replaces default types completely** - model only extracts the specified types
74 /// - To include defaults, pass them explicitly: `&["person", "organization", "disease"]`
75 /// * `threshold` - Confidence threshold (0.0 - 1.0)
76 ///
77 /// # Returns
78 /// Entities with their matched types
79 ///
80 /// # Behavior
81 ///
82 /// - **Arbitrary text**: Type hints are not fixed vocabulary. They're encoded as embeddings,
83 /// so semantic similarity determines matches (not exact string matching).
84 /// - **Replace, don't union**: This method completely replaces default entity types.
85 /// The model only extracts the types you specify.
86 /// - **Semantic matching**: Uses cosine similarity between text span embeddings and label embeddings.
87 fn extract_with_types(
88 &self,
89 text: &str,
90 entity_types: &[&str],
91 threshold: f32,
92 ) -> crate::Result<Vec<Entity>>;
93
94 /// Extract entities with natural language descriptions.
95 ///
96 /// # Arguments
97 /// * `text` - Input text
98 /// * `descriptions` - Natural language descriptions of what to extract
99 /// - Encoded as text embeddings (same as `extract_with_types`)
100 /// - Examples: `"companies headquartered in Europe"`, `"diseases affecting the heart"`
101 /// - **Replaces default types completely** - model only extracts the specified descriptions
102 /// * `threshold` - Confidence threshold
103 ///
104 /// # Behavior
105 ///
106 /// Same as `extract_with_types`, but accepts natural language descriptions instead of
107 /// short type labels. Both methods encode labels as embeddings and use semantic matching.
108 fn extract_with_descriptions(
109 &self,
110 text: &str,
111 descriptions: &[&str],
112 threshold: f32,
113 ) -> crate::Result<Vec<Entity>>;
114
115 /// Get default entity types for this model.
116 ///
117 /// Returns the entity types used by `extract_entities()` (via `Model` trait).
118 /// Useful for extending defaults: combine with custom types and pass to `extract_with_types()`.
119 ///
120 /// # Example: Extending defaults
121 ///
122 /// ```ignore
123 /// use anno::ZeroShotNER;
124 ///
125 /// let ner: &dyn ZeroShotNER = ...;
126 /// let defaults = ner.default_types();
127 ///
128 /// // Combine defaults with custom types
129 /// let mut types: Vec<&str> = defaults.to_vec();
130 /// types.extend(&["disease", "medication"]);
131 ///
132 /// let entities = ner.extract_with_types(text, &types, 0.5)?;
133 /// ```
134 fn default_types(&self) -> &[&'static str];
135}
136
137// =============================================================================
138// Relation Extractor Trait
139// =============================================================================
140
141/// Joint entity and relation extraction.
142///
143/// # Motivation
144///
145/// Real-world information extraction often requires both entities AND their
146/// relationships. For example, extracting "Steve Jobs" and "Apple" is useful,
147/// but knowing "Steve Jobs FOUNDED Apple" is far more valuable.
148///
149/// Joint extraction (vs pipeline) is preferred because:
150/// - **Error propagation**: Pipeline errors compound (bad entities → bad relations)
151/// - **Shared context**: Entities and relations inform each other
152/// - **Efficiency**: Single forward pass instead of two
153///
154/// # Architecture
155///
156/// ```text
157/// Input: "Steve Jobs founded Apple in 1976."
158/// │
159/// ▼
160/// ┌──────────────────────────────────┐
161/// │ Shared Encoder (BERT) │
162/// └──────────────────────────────────┘
163/// │
164/// ┌──────┴──────┐
165/// ▼ ▼
166/// ┌───────────────┐ ┌───────────────┐
167/// │ Entity Head │ │ Relation Head │
168/// │ (span class.) │ │ (pair class.) │
169/// └───────┬───────┘ └───────┬───────┘
170/// │ │
171/// ▼ ▼
172/// Entities: Relations:
173/// - Steve Jobs [PER] - (Steve Jobs, FOUNDED, Apple)
174/// - Apple [ORG] - (Apple, FOUNDED_IN, 1976)
175/// - 1976 [DATE]
176/// ```
177///
178/// # Research Alignment
179///
180/// From GLiNER multi-task (arXiv:2406.12925):
181/// > "Generalist Lightweight Model for Various Information Extraction Tasks...
182/// > joint entity and relation extraction."
183///
184/// From W2NER (arXiv:2112.10070):
185/// > "Unified Named Entity Recognition as Word-Word Relation Classification...
186/// > handles flat, overlapped, and discontinuous NER."
187///
188/// # Example
189///
190/// ```ignore
191/// use anno::RelationExtractor;
192///
193/// fn build_knowledge_graph(extractor: &dyn RelationExtractor, text: &str) {
194/// let entity_types = &["person", "organization", "date"];
195/// let relation_types = &["founded", "works_for", "acquired"];
196///
197/// let result = extractor.extract_with_relations(
198/// text, entity_types, relation_types, 0.5
199/// ).unwrap();
200///
201/// // Build graph nodes from entities
202/// for e in &result.entities {
203/// println!("Node: {} ({})", e.text, e.entity_type);
204/// }
205///
206/// // Build graph edges from relations
207/// for r in &result.relations {
208/// let head = &result.entities[r.head_idx];
209/// let tail = &result.entities[r.tail_idx];
210/// println!("Edge: {} --[{}]--> {}", head.text, r.relation_type, tail.text);
211/// }
212/// }
213/// ```
214pub trait RelationExtractor: Send + Sync {
215 /// Extract entities and relations jointly.
216 ///
217 /// # Arguments
218 /// * `text` - Input text
219 /// * `entity_types` - Entity types to extract
220 /// * `relation_types` - Relation types to extract
221 /// * `threshold` - Confidence threshold
222 ///
223 /// # Returns
224 /// Entities and relations between them
225 fn extract_with_relations(
226 &self,
227 text: &str,
228 entity_types: &[&str],
229 relation_types: &[&str],
230 threshold: f32,
231 ) -> crate::Result<ExtractionWithRelations>;
232}
233
234/// Output from joint entity-relation extraction.
235#[derive(Debug, Clone, Default)]
236pub struct ExtractionWithRelations {
237 /// Extracted entities
238 pub entities: Vec<Entity>,
239 /// Relations between entities (indices into entities vec)
240 pub relations: Vec<RelationTriple>,
241}
242
/// A single typed relation between two entities, addressed by index.
#[derive(Clone, Debug)]
pub struct RelationTriple {
    /// Position of the head entity in the accompanying entity vector.
    pub head_idx: usize,
    /// Position of the tail entity in the accompanying entity vector.
    pub tail_idx: usize,
    /// Label of the relation (for example `"founded"`).
    pub relation_type: String,
    /// Confidence score attached to this relation.
    pub confidence: f32,
}
255
256// =============================================================================
257// Shared defaults for RelationCapable::extract_with_relations
258// =============================================================================
259
/// Entity types used by the argument-free `RelationCapable` convenience interface.
///
/// The list is deliberately broad and is meant to span the usual NER taxonomies
/// (CoNLL, OntoNotes, ACE). Code that needs an exact type set should call
/// `RelationExtractor::extract_with_relations` itself and pass its own list.
pub(crate) const DEFAULT_ENTITY_TYPES: &[&str] = &[
    "person",
    "organization",
    "location",
    "date",
    "product",
    "event",
];
272
/// Relation types used by the argument-free `RelationCapable` convenience interface.
///
/// A deliberately broad list; pass an explicit list to
/// `RelationExtractor::extract_with_relations` when precise control is needed.
pub(crate) const DEFAULT_RELATION_TYPES: &[&str] = &[
    "founded",
    "works_for",
    "located_in",
    "acquired",
    "born_in",
    "member_of",
    "ceo_of",
    "part_of",
    "subsidiary_of",
];
285
286impl ExtractionWithRelations {
287 /// Convert index-based `RelationTriple`s into owned `anno_core::Relation` values.
288 ///
289 /// Indices that are out-of-bounds (should not happen but can in malformed output) are
290 /// silently dropped.
291 #[must_use]
292 pub fn into_anno_relations(self) -> (Vec<Entity>, Vec<crate::Relation>) {
293 let relations = self
294 .relations
295 .iter()
296 .filter_map(|t| {
297 let head = self.entities.get(t.head_idx)?.clone();
298 let tail = self.entities.get(t.tail_idx)?.clone();
299 Some(crate::Relation::new(
300 head,
301 tail,
302 t.relation_type.clone(),
303 t.confidence as f64,
304 ))
305 })
306 .collect();
307 (self.entities, relations)
308 }
309}
310
311// =============================================================================
312// Discontinuous Entity Support (W2NER Research)
313// =============================================================================
314
315/// Support for discontinuous entity spans.
316///
317/// # Motivation
318///
319/// Not all entities are contiguous text spans. In coordination structures,
320/// entities can be **discontinuous** - scattered across non-adjacent positions.
321///
322/// # Examples of Discontinuous Entities
323///
324/// ```text
325/// "New York and Los Angeles airports"
326/// ^^^^^^^^ ^^^^^^^^^^^ ^^^^^^^^
327/// └──────────────────────────┘
328/// LOCATION: "New York airports" (discontinuous!)
329/// ^^^^^^^^^^^ ^^^^^^^^
330/// └───────────┘
331/// LOCATION: "Los Angeles airports" (contiguous)
332///
333/// "protein A and B complex"
334/// ^^^^^^^^^ ^^^ ^^^^^^^^^
335/// └────────────────────┘
336/// PROTEIN: "protein A ... complex" (discontinuous!)
337/// ```
338///
339/// # NER Complexity Hierarchy
340///
341/// | Type | Description | Example |
342/// |------|-------------|---------|
343/// | Flat | Non-overlapping spans | "John works at Google" |
344/// | Nested | Overlapping spans | "\[New \[York\] City\]" |
345/// | Discontinuous | Non-contiguous | "New York and LA \[airports\]" |
346///
347/// # Research Alignment
348///
349/// From W2NER (arXiv:2112.10070):
350/// > "Named entity recognition has been involved with three major types,
351/// > including flat, overlapped (aka. nested), and discontinuous NER...
352/// > we propose a novel architecture to model NER as word-word relation
353/// > classification."
354///
355/// W2NER achieves this by building a **handshaking matrix** where each cell
356/// (i, j) indicates whether tokens i and j are part of the same entity.
357///
358/// # Example
359///
360/// ```ignore
361/// use anno::DiscontinuousNER;
362///
363/// fn extract_complex_entities(ner: &dyn DiscontinuousNER, text: &str) {
364/// let types = &["location", "protein"];
365/// let entities = ner.extract_discontinuous(text, types, 0.5).unwrap();
366///
367/// for e in entities {
368/// if e.is_contiguous() {
369/// println!("Contiguous {}: '{}'", e.entity_type, e.text);
370/// } else {
371/// println!("Discontinuous {}: '{}' spans: {:?}",
372/// e.entity_type, e.text, e.spans);
373/// }
374/// }
375/// }
376/// ```
377pub trait DiscontinuousNER: Send + Sync {
378 /// Extract entities including discontinuous spans.
379 ///
380 /// # Arguments
381 /// * `text` - Input text
382 /// * `entity_types` - Entity types to extract
383 /// * `threshold` - Confidence threshold
384 ///
385 /// # Returns
386 /// Entities, potentially with multiple non-contiguous spans
387 fn extract_discontinuous(
388 &self,
389 text: &str,
390 entity_types: &[&str],
391 threshold: f32,
392 ) -> crate::Result<Vec<DiscontinuousEntity>>;
393}
394
/// An entity made up of one or more text regions that need not be adjacent.
#[derive(Clone, Debug)]
pub struct DiscontinuousEntity {
    /// `(start, end)` pairs, one per region of the entity (may be non-contiguous).
    // NOTE(review): the offset unit (bytes vs. chars) is not stated here — confirm
    // against the producers of this type.
    pub spans: Vec<(usize, usize)>,
    /// Text of all spans, concatenated.
    pub text: String,
    /// Label of the entity type.
    pub entity_type: String,
    /// Confidence score attached to this entity.
    pub confidence: f32,
}
407
408impl DiscontinuousEntity {
409 /// Check if this entity is contiguous (single span).
410 pub fn is_contiguous(&self) -> bool {
411 self.spans.len() == 1
412 }
413
414 /// Convert to a standard Entity if contiguous.
415 pub fn to_entity(&self) -> Option<Entity> {
416 if self.is_contiguous() {
417 let (start, end) = self.spans[0];
418 Some(Entity::new(
419 self.text.clone(),
420 EntityType::from_label(&self.entity_type),
421 start,
422 end,
423 self.confidence as f64,
424 ))
425 } else {
426 None
427 }
428 }
429}
430
431// =============================================================================