Skip to main content

anno/backends/
tplinker.rs

1//! TPLinker: Single-stage Joint Entity-Relation Extraction
2//!
3//! TPLinker uses a handshaking tagging scheme for joint entity-relation extraction.
4//! It models entity boundaries and relations simultaneously using a unified tagging matrix.
5//!
6//! # Implementation Status
7//!
8//! This module keeps the TPLinker **name + wiring** stable while the full neural handshaking
9//! model is still pending.
10//!
11//! Today, `TPLinker` is implemented as a **dependency-light heuristic baseline**:
12//! - entities are extracted using a zero-dependency NER baseline
13//! - relations are inferred using the shared heuristic matcher in `backends::inference`
14//!
15//! This makes relation extraction *function* end-to-end (for demos, DX, and eval harnesses)
16//! without pretending we have a trained TPLinker model.
17//!
18//! # Research
19//!
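//! - **Paper**: [TPLinker: Single-stage Joint Extraction of Entities and Relations Through Token Pair Linking](https://aclanthology.org/2020.coling-main.138/)
//! - **Architecture**: Handshaking matrix over token pairs (i, j) with three link types:
//!   - Entity spans: EH-to-ET (entity head to entity tail)
//!   - Relations, one pair of matrices per relation type: SH-to-OH (subject head to
//!     object head) and ST-to-OT (subject tail to object tail)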
24//!
25//! # Usage
26//!
27//! ```rust,ignore
//! use anno::backends::inference::RelationExtractor; // trait providing `extract_with_relations`
//! use anno::backends::tplinker::TPLinker;
29//!
30//! let extractor = TPLinker::new()?;
31//! let result = extractor.extract_with_relations(
32//!     "Steve Jobs founded Apple in 1976.",
33//!     &["person", "organization"],
34//!     &["founded", "works_for"],
35//!     0.5
36//! )?;
37//!
38//! for entity in &result.entities {
39//!     println!("Entity: {} ({})", entity.text, entity.entity_type);
40//! }
41//!
42//! for relation in &result.relations {
43//!     let head = &result.entities[relation.head_idx];
44//!     let tail = &result.entities[relation.tail_idx];
45//!     println!("Relation: {} --[{}]--> {}", head.text, relation.relation_type, tail.text);
46//! }
47//! ```
48
49use crate::backends::inference::{
50    extract_relation_triples, ExtractionWithRelations, RelationExtractionConfig, RelationExtractor,
51    SemanticRegistry,
52};
53use crate::{Entity, EntityType, Model, Result};
54use std::borrow::Cow;
55use std::collections::HashSet;
56
/// TPLinker backend for joint entity-relation extraction.
///
/// The public API mirrors a TPLinker (handshaking-matrix) extractor, but the current
/// implementation is a heuristic baseline: entities come from the default `StackedNER`
/// and relations from the trigger-based matcher in `backends::inference`. No neural
/// handshaking model is loaded.
#[derive(Debug, Clone)]
pub struct TPLinker {
    /// Minimum confidence an entity must have to be kept.
    entity_threshold: f32,
    /// Fallback minimum confidence for relations, used when the call-site threshold
    /// passed to `extract_with_relations` is not positive.
    relation_threshold: f32,
}
69
70impl TPLinker {
71    /// Create a new TPLinker instance.
72    pub fn new() -> Result<Self> {
73        Ok(Self::with_thresholds(0.15, 0.55))
74    }
75
76    /// Create with custom thresholds.
77    pub fn with_thresholds(entity_threshold: f32, relation_threshold: f32) -> Self {
78        Self {
79            entity_threshold,
80            relation_threshold,
81        }
82    }
83
84    /// Reserved decoder entrypoint (not implemented).
85    ///
86    /// A full TPLinker implementation would:
87    /// 1. Run ONNX model to get handshaking matrix predictions
88    /// 2. Decode entity boundaries from SH2OH/OH2SH tags
89    /// 3. Decode relations from handshaking between entity pairs
90    #[allow(dead_code)] // Placeholder helper; kept for future TPLinker ONNX decoding work.
91    fn extract_with_handshaking(
92        &self,
93        text: &str,
94        entity_types: &[&str],
95        relation_types: &[&str],
96        threshold: f32,
97    ) -> Result<ExtractionWithRelations> {
98        // Interpret the call-site `threshold` as the *relation* threshold.
99        // Entity extraction should remain governed by `self.entity_threshold`, otherwise
100        // relation-eval runs with `threshold=0.5` can accidentally wipe out almost all
101        // heuristic entities and produce zero relations.
102        let rel_threshold = if threshold > 0.0 {
103            threshold
104        } else {
105            self.relation_threshold
106        };
107        let ent_threshold = self.entity_threshold;
108
109        // Heuristic baseline: use the default stacked NER (pattern + heuristic).
110        // This keeps the RE baseline dependency-light while still extracting common structured
111        // entities (DATE/MONEY/EMAIL/...) that relations frequently attach to.
112        let ner = crate::StackedNER::default();
113        let mut entities = ner.extract_entities(text, None)?;
114
115        // Respect the requested entity schema when possible.
116        // Note: Some relation datasets provide rich, dataset-specific entity type labels
117        // (e.g. "programlang", "academicjournal"). Those are not representable in our
118        // `EntityType` enum, so filtering via `EntityType::from_label` would collapse them
119        // (typically to `Misc`) and accidentally drop all HeuristicNER entities.
120        //
121        // We only apply filtering when the requested schema looks like it targets the
122        // canonical types we can actually emit.
123        if !entity_types.is_empty() {
124            let requested: Vec<String> = entity_types.iter().map(|s| s.to_lowercase()).collect();
125            let looks_supported = requested.iter().all(|t| {
126                matches!(
127                    t.as_str(),
128                    "person"
129                        | "per"
130                        | "organization"
131                        | "organisation"
132                        | "org"
133                        | "location"
134                        | "loc"
135                        | "date"
136                        | "time"
137                        | "money"
138                        | "misc"
139                )
140            });
141            if looks_supported {
142                let allowed: HashSet<EntityType> = entity_types
143                    .iter()
144                    .map(|s| EntityType::from_label(s))
145                    .collect();
146                entities.retain(|e| allowed.contains(&e.entity_type));
147            }
148        }
149
150        // Apply the *entity* threshold to entity confidences.
151        entities.retain(|e| e.confidence >= f64::from(ent_threshold));
152
153        // Add provenance to indicate heuristic baseline (not a neural TPLinker).
154        for entity in &mut entities {
155            entity.provenance = Some(crate::Provenance {
156                source: Cow::Borrowed("tplinker"),
157                method: crate::ExtractionMethod::Heuristic,
158                pattern: None,
159                raw_confidence: Some(entity.confidence),
160                model_version: Some(Cow::Borrowed("heuristic")),
161                timestamp: None,
162            });
163        }
164
165        // Extract relations: heuristic trigger-based extraction implemented in `inference.rs`.
166        //
167        // This is deliberately conservative: we only emit relations when we match a known trigger
168        // pattern *and* the relation type is present in `relation_types`. We do not "guess" a
169        // relation type just because two entities are nearby.
170        // If the caller doesn't provide an explicit relation schema, fall back to a conservative
171        // default set that matches the built-in heuristic trigger patterns.
172        //
173        // This keeps `TPLinker` usable from the CLI without requiring users to know label sets.
174        const DEFAULT_RELATIONS: &[&str] = &[
175            "CEO_OF",
176            "WORKS_FOR",
177            "FOUNDED",
178            "MANAGES",
179            "REPORTS_TO",
180            "LOCATED_IN",
181            "BORN_IN",
182            "LIVES_IN",
183            "DIED_IN",
184            "OCCURRED_ON",
185            "STARTED_ON",
186            "ENDED_ON",
187            "PART_OF",
188            "ACQUIRED",
189            "MERGED_WITH",
190            "PARENT_OF",
191            "MARRIED_TO",
192            "CHILD_OF",
193            "SIBLING_OF",
194        ];
195
196        let rels: Vec<&str> = if relation_types.is_empty() {
197            DEFAULT_RELATIONS.to_vec()
198        } else {
199            relation_types.to_vec()
200        };
201
202        let registry = {
203            let mut builder = SemanticRegistry::builder();
204            for rel in rels {
205                // Description is a best-effort placeholder; only the slug is used by the
206                // heuristic matcher today.
207                builder = builder.add_relation(rel, rel);
208            }
209            builder.build_placeholder(1)
210        };
211
212        let rel_config = RelationExtractionConfig {
213            threshold: rel_threshold,
214            max_span_distance: 120,
215            extract_triggers: false,
216        };
217
218        let relations = extract_relation_triples(&entities, text, &registry, &rel_config);
219
220        Ok(ExtractionWithRelations {
221            entities,
222            relations,
223        })
224    }
225}
226
227impl Model for TPLinker {
228    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
229        let heuristic = crate::StackedNER::default();
230        let mut entities = heuristic.extract_entities(text, None)?;
231        entities.retain(|e| e.confidence >= f64::from(self.entity_threshold));
232        Ok(entities)
233    }
234
235    fn supported_types(&self) -> Vec<EntityType> {
236        vec![
237            EntityType::Person,
238            EntityType::Organization,
239            EntityType::Location,
240            EntityType::Date,
241            EntityType::Time,
242            EntityType::Money,
243        ]
244    }
245
246    fn is_available(&self) -> bool {
247        true
248    }
249
250    fn name(&self) -> &'static str {
251        "tplinker"
252    }
253
254    fn description(&self) -> &'static str {
255        "TPLinker (heuristic baseline today; neural handshaking model TBD)"
256    }
257
258    fn capabilities(&self) -> crate::ModelCapabilities {
259        crate::ModelCapabilities {
260            batch_capable: true,
261            streaming_capable: true,
262            recommended_chunk_size: Some(10_000),
263            relation_capable: true,
264            ..Default::default()
265        }
266    }
267}
268
269impl crate::NamedEntityCapable for TPLinker {}
270
271impl RelationExtractor for TPLinker {
272    fn extract_with_relations(
273        &self,
274        text: &str,
275        entity_types: &[&str],
276        relation_types: &[&str],
277        threshold: f32,
278    ) -> Result<ExtractionWithRelations> {
279        self.extract_with_handshaking(text, entity_types, relation_types, threshold)
280    }
281}
282
283impl crate::RelationCapable for TPLinker {
284    fn extract_with_relations(
285        &self,
286        text: &str,
287        _language: Option<&str>,
288    ) -> Result<(Vec<Entity>, Vec<crate::Relation>)> {
289        use crate::backends::inference::{DEFAULT_ENTITY_TYPES, DEFAULT_RELATION_TYPES};
290        let result = <Self as RelationExtractor>::extract_with_relations(
291            self,
292            text,
293            DEFAULT_ENTITY_TYPES,
294            DEFAULT_RELATION_TYPES,
295            0.5,
296        )?;
297        Ok(result.into_anno_relations())
298    }
299}
300
301// Make TPLinker implement BatchCapable and StreamingCapable for consistency
302impl crate::BatchCapable for TPLinker {
303    fn extract_entities_batch(
304        &self,
305        texts: &[&str],
306        _language: Option<&str>,
307    ) -> Result<Vec<Vec<Entity>>> {
308        texts
309            .iter()
310            .map(|text| self.extract_entities(text, None))
311            .collect()
312    }
313}
314
315impl crate::StreamingCapable for TPLinker {
316    fn extract_entities_streaming(&self, chunk: &str, offset: usize) -> Result<Vec<Entity>> {
317        let mut entities = self.extract_entities(chunk, None)?;
318        for entity in &mut entities {
319            entity.start += offset;
320            entity.end += offset;
321        }
322        Ok(entities)
323    }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tplinker_creation() {
        let tplinker = TPLinker::new().unwrap();
        assert!(tplinker.is_available());
    }

    #[test]
    fn test_tplinker_entity_extraction() {
        let tplinker = TPLinker::with_thresholds(0.15, 0.55);
        let entities = tplinker
            .extract_entities("Steve Jobs founded Apple.", None)
            .unwrap();
        assert!(!entities.is_empty());
    }

    #[test]
    fn test_tplinker_relation_extraction() {
        let tplinker = TPLinker::with_thresholds(0.15, 0.55);
        let out = tplinker
            .extract_with_relations(
                "Steve Jobs founded Apple in 1976.",
                &["person", "organization"],
                &["founded"],
                0.5,
            )
            .unwrap();
        assert!(out.entities.len() >= 2);
        assert!(
            out.relations.iter().any(|r| r.relation_type == "founded"),
            "expected a founded relation; got: {:?}",
            out.relations
        );
    }

    #[test]
    fn test_tplinker_relation_path_tags_heuristic_provenance() {
        // Entities returned by the relation path must be marked as coming from the
        // heuristic TPLinker baseline.
        let tplinker = TPLinker::with_thresholds(0.15, 0.55);
        let out = tplinker
            .extract_with_relations(
                "Steve Jobs founded Apple in 1976.",
                &["person", "organization"],
                &["founded"],
                0.5,
            )
            .unwrap();
        for e in &out.entities {
            let provenance = e
                .provenance
                .as_ref()
                .expect("relation path should tag every entity with provenance");
            assert_eq!(provenance.source, "tplinker");
        }
    }

    #[test]
    fn test_tplinker_batch_matches_single_text_extraction() {
        use crate::BatchCapable;

        let tplinker = TPLinker::new().unwrap();
        let texts = ["Steve Jobs founded Apple.", "Marie Curie was born in Warsaw."];
        let batch = tplinker.extract_entities_batch(&texts, None).unwrap();
        assert_eq!(batch.len(), texts.len());

        for (text, batched) in texts.iter().zip(&batch) {
            let single = tplinker.extract_entities(text, None).unwrap();
            let mut expected: Vec<(usize, usize)> =
                single.iter().map(|e| (e.start, e.end)).collect();
            let mut actual: Vec<(usize, usize)> =
                batched.iter().map(|e| (e.start, e.end)).collect();
            expected.sort_unstable();
            actual.sort_unstable();
            assert_eq!(expected, actual, "batch/single mismatch for {text:?}");
        }
    }

    #[test]
    fn test_tplinker_streaming_shifts_offsets() {
        use crate::StreamingCapable;

        let tplinker = TPLinker::new().unwrap();
        let chunk = "Steve Jobs founded Apple.";
        let offset = 100;

        let base = tplinker.extract_entities(chunk, None).unwrap();
        let shifted = tplinker.extract_entities_streaming(chunk, offset).unwrap();
        assert_eq!(base.len(), shifted.len());

        let mut expected: Vec<(usize, usize)> = base
            .iter()
            .map(|e| (e.start + offset, e.end + offset))
            .collect();
        let mut actual: Vec<(usize, usize)> = shifted.iter().map(|e| (e.start, e.end)).collect();
        expected.sort_unstable();
        actual.sort_unstable();
        assert_eq!(expected, actual);
    }

    #[test]
    fn test_tplinker_unicode_offsets_invariants() {
        // Diverse scripts + emoji (multi-byte). Offsets must be character-based and valid.
        let tplinker = TPLinker::with_thresholds(0.15, 0.55);
        let text = "Dr. 田中 met François Müller in 東京. 🎉";
        let out = tplinker
            .extract_with_relations(
                text,
                &["person", "location", "organization"],
                &["works_for", "located_in", "founded"],
                0.0,
            )
            .unwrap();

        let text_len = text.chars().count();
        for e in &out.entities {
            assert!(e.start < e.end, "invalid span: {:?}", (e.start, e.end));
            assert!(
                e.end <= text_len,
                "span out of bounds: {:?} (len={})",
                (e.start, e.end),
                text_len
            );
            let extracted = crate::offset::TextSpan::from_chars(text, e.start, e.end).extract(text);
            assert_eq!(extracted, e.text);
        }
        for r in &out.relations {
            assert!(r.head_idx < out.entities.len());
            assert!(r.tail_idx < out.entities.len());
        }
    }
}