anno/backends/tplinker.rs
1//! TPLinker: Single-stage Joint Entity-Relation Extraction
2//!
3//! TPLinker uses a handshaking tagging scheme for joint entity-relation extraction.
4//! It models entity boundaries and relations simultaneously using a unified tagging matrix.
5//!
6//! # Implementation Status
7//!
8//! This module keeps the TPLinker **name + wiring** stable while the full neural handshaking
9//! model is still pending.
10//!
11//! Today, `TPLinker` is implemented as a **dependency-light heuristic baseline**:
12//! - entities are extracted using a zero-dependency NER baseline
13//! - relations are inferred using the shared heuristic matcher in `backends::inference`
14//!
15//! This makes relation extraction *function* end-to-end (for demos, DX, and eval harnesses)
16//! without pretending we have a trained TPLinker model.
17//!
18//! # Research
19//!
20//! - **Paper**: [TPLinker: Single-stage Joint Extraction](https://aclanthology.org/2020.coling-main.138/)
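//! - **Architecture**: Handshaking matrix where each cell (i,j) encodes token-pair links:
//!   - Entity spans: EH-to-ET links (entity head to entity tail)
//!   - Relations: per-relation SH-to-OH / ST-to-OT links (subject head/tail to object
//!     head/tail), with OH2SH / OT2ST marking the reversed direction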
24//!
25//! # Usage
26//!
27//! ```rust,ignore
28//! use anno::backends::tplinker::TPLinker;
29//!
30//! let extractor = TPLinker::new()?;
31//! let result = extractor.extract_with_relations(
32//! "Steve Jobs founded Apple in 1976.",
33//! &["person", "organization"],
34//! &["founded", "works_for"],
35//! 0.5
36//! )?;
37//!
38//! for entity in &result.entities {
39//! println!("Entity: {} ({})", entity.text, entity.entity_type);
40//! }
41//!
42//! for relation in &result.relations {
43//! let head = &result.entities[relation.head_idx];
44//! let tail = &result.entities[relation.tail_idx];
45//! println!("Relation: {} --[{}]--> {}", head.text, relation.relation_type, tail.text);
46//! }
47//! ```
48
49use crate::backends::inference::{
50 extract_relation_triples, ExtractionWithRelations, RelationExtractionConfig, RelationExtractor,
51 SemanticRegistry,
52};
53use crate::{Entity, EntityType, Model, Result};
54use std::borrow::Cow;
55use std::collections::HashSet;
56
/// TPLinker backend for joint entity-relation extraction.
///
/// The handshaking-matrix model that gives this type its name is not wired in yet:
/// the backend currently runs a heuristic NER pass plus trigger-based relation
/// matching, and this struct only stores the two confidence cut-offs applied to
/// that baseline's output.
#[derive(Debug)]
pub struct TPLinker {
    /// Minimum confidence an extracted entity must reach to be kept.
    entity_threshold: f32,
    /// Relation confidence cut-off used when a caller passes a non-positive
    /// threshold to relation extraction.
    relation_threshold: f32,
}
69
70impl TPLinker {
71 /// Create a new TPLinker instance.
72 pub fn new() -> Result<Self> {
73 Ok(Self::with_thresholds(0.15, 0.55))
74 }
75
76 /// Create with custom thresholds.
77 pub fn with_thresholds(entity_threshold: f32, relation_threshold: f32) -> Self {
78 Self {
79 entity_threshold,
80 relation_threshold,
81 }
82 }
83
84 /// Reserved decoder entrypoint (not implemented).
85 ///
86 /// A full TPLinker implementation would:
87 /// 1. Run ONNX model to get handshaking matrix predictions
88 /// 2. Decode entity boundaries from SH2OH/OH2SH tags
89 /// 3. Decode relations from handshaking between entity pairs
90 #[allow(dead_code)] // Placeholder helper; kept for future TPLinker ONNX decoding work.
91 fn extract_with_handshaking(
92 &self,
93 text: &str,
94 entity_types: &[&str],
95 relation_types: &[&str],
96 threshold: f32,
97 ) -> Result<ExtractionWithRelations> {
98 // Interpret the call-site `threshold` as the *relation* threshold.
99 // Entity extraction should remain governed by `self.entity_threshold`, otherwise
100 // relation-eval runs with `threshold=0.5` can accidentally wipe out almost all
101 // heuristic entities and produce zero relations.
102 let rel_threshold = if threshold > 0.0 {
103 threshold
104 } else {
105 self.relation_threshold
106 };
107 let ent_threshold = self.entity_threshold;
108
109 // Heuristic baseline: use the default stacked NER (pattern + heuristic).
110 // This keeps the RE baseline dependency-light while still extracting common structured
111 // entities (DATE/MONEY/EMAIL/...) that relations frequently attach to.
112 let ner = crate::StackedNER::default();
113 let mut entities = ner.extract_entities(text, None)?;
114
115 // Respect the requested entity schema when possible.
116 // Note: Some relation datasets provide rich, dataset-specific entity type labels
117 // (e.g. "programlang", "academicjournal"). Those are not representable in our
118 // `EntityType` enum, so filtering via `EntityType::from_label` would collapse them
119 // (typically to `Misc`) and accidentally drop all HeuristicNER entities.
120 //
121 // We only apply filtering when the requested schema looks like it targets the
122 // canonical types we can actually emit.
123 if !entity_types.is_empty() {
124 let requested: Vec<String> = entity_types.iter().map(|s| s.to_lowercase()).collect();
125 let looks_supported = requested.iter().all(|t| {
126 matches!(
127 t.as_str(),
128 "person"
129 | "per"
130 | "organization"
131 | "organisation"
132 | "org"
133 | "location"
134 | "loc"
135 | "date"
136 | "time"
137 | "money"
138 | "misc"
139 )
140 });
141 if looks_supported {
142 let allowed: HashSet<EntityType> = entity_types
143 .iter()
144 .map(|s| EntityType::from_label(s))
145 .collect();
146 entities.retain(|e| allowed.contains(&e.entity_type));
147 }
148 }
149
150 // Apply the *entity* threshold to entity confidences.
151 entities.retain(|e| e.confidence >= f64::from(ent_threshold));
152
153 // Add provenance to indicate heuristic baseline (not a neural TPLinker).
154 for entity in &mut entities {
155 entity.provenance = Some(crate::Provenance {
156 source: Cow::Borrowed("tplinker"),
157 method: crate::ExtractionMethod::Heuristic,
158 pattern: None,
159 raw_confidence: Some(entity.confidence),
160 model_version: Some(Cow::Borrowed("heuristic")),
161 timestamp: None,
162 });
163 }
164
165 // Extract relations: heuristic trigger-based extraction implemented in `inference.rs`.
166 //
167 // This is deliberately conservative: we only emit relations when we match a known trigger
168 // pattern *and* the relation type is present in `relation_types`. We do not "guess" a
169 // relation type just because two entities are nearby.
170 // If the caller doesn't provide an explicit relation schema, fall back to a conservative
171 // default set that matches the built-in heuristic trigger patterns.
172 //
173 // This keeps `TPLinker` usable from the CLI without requiring users to know label sets.
174 const DEFAULT_RELATIONS: &[&str] = &[
175 "CEO_OF",
176 "WORKS_FOR",
177 "FOUNDED",
178 "MANAGES",
179 "REPORTS_TO",
180 "LOCATED_IN",
181 "BORN_IN",
182 "LIVES_IN",
183 "DIED_IN",
184 "OCCURRED_ON",
185 "STARTED_ON",
186 "ENDED_ON",
187 "PART_OF",
188 "ACQUIRED",
189 "MERGED_WITH",
190 "PARENT_OF",
191 "MARRIED_TO",
192 "CHILD_OF",
193 "SIBLING_OF",
194 ];
195
196 let rels: Vec<&str> = if relation_types.is_empty() {
197 DEFAULT_RELATIONS.to_vec()
198 } else {
199 relation_types.to_vec()
200 };
201
202 let registry = {
203 let mut builder = SemanticRegistry::builder();
204 for rel in rels {
205 // Description is a best-effort placeholder; only the slug is used by the
206 // heuristic matcher today.
207 builder = builder.add_relation(rel, rel);
208 }
209 builder.build_placeholder(1)
210 };
211
212 let rel_config = RelationExtractionConfig {
213 threshold: rel_threshold,
214 max_span_distance: 120,
215 extract_triggers: false,
216 };
217
218 let relations = extract_relation_triples(&entities, text, ®istry, &rel_config);
219
220 Ok(ExtractionWithRelations {
221 entities,
222 relations,
223 })
224 }
225}
226
227impl Model for TPLinker {
228 fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
229 let heuristic = crate::StackedNER::default();
230 let mut entities = heuristic.extract_entities(text, None)?;
231 entities.retain(|e| e.confidence >= f64::from(self.entity_threshold));
232 Ok(entities)
233 }
234
235 fn supported_types(&self) -> Vec<EntityType> {
236 vec![
237 EntityType::Person,
238 EntityType::Organization,
239 EntityType::Location,
240 EntityType::Date,
241 EntityType::Time,
242 EntityType::Money,
243 ]
244 }
245
246 fn is_available(&self) -> bool {
247 true
248 }
249
250 fn name(&self) -> &'static str {
251 "tplinker"
252 }
253
254 fn description(&self) -> &'static str {
255 "TPLinker (heuristic baseline today; neural handshaking model TBD)"
256 }
257}
258
259impl RelationExtractor for TPLinker {
260 fn extract_with_relations(
261 &self,
262 text: &str,
263 entity_types: &[&str],
264 relation_types: &[&str],
265 threshold: f32,
266 ) -> Result<ExtractionWithRelations> {
267 self.extract_with_handshaking(text, entity_types, relation_types, threshold)
268 }
269}
270
271// Make TPLinker implement BatchCapable and StreamingCapable for consistency
272impl crate::BatchCapable for TPLinker {
273 fn extract_entities_batch(
274 &self,
275 texts: &[&str],
276 _language: Option<&str>,
277 ) -> Result<Vec<Vec<Entity>>> {
278 texts
279 .iter()
280 .map(|text| self.extract_entities(text, None))
281 .collect()
282 }
283}
284
285impl crate::StreamingCapable for TPLinker {
286 fn extract_entities_streaming(&self, chunk: &str, offset: usize) -> Result<Vec<Entity>> {
287 let mut entities = self.extract_entities(chunk, None)?;
288 for entity in &mut entities {
289 entity.start += offset;
290 entity.end += offset;
291 }
292 Ok(entities)
293 }
294}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Backend configured with the thresholds used throughout these tests.
    fn baseline() -> TPLinker {
        TPLinker::with_thresholds(0.15, 0.55)
    }

    /// The default constructor succeeds and reports the backend as available.
    #[test]
    fn test_tplinker_creation() {
        assert!(TPLinker::new().unwrap().is_available());
    }

    /// Plain entity extraction finds at least one entity in a simple sentence.
    #[test]
    fn test_tplinker_entity_extraction() {
        let found = baseline()
            .extract_entities("Steve Jobs founded Apple.", None)
            .unwrap();
        assert!(!found.is_empty());
    }

    /// A "founded" relation is produced for the canonical example sentence.
    #[test]
    fn test_tplinker_relation_extraction() {
        let sentence = "Steve Jobs founded Apple in 1976.";
        let entity_schema = ["person", "organization"];
        let relation_schema = ["founded"];

        let out = baseline()
            .extract_with_relations(sentence, &entity_schema, &relation_schema, 0.5)
            .unwrap();

        assert!(out.entities.len() >= 2);
        let has_founded = out
            .relations
            .iter()
            .any(|rel| rel.relation_type == "founded");
        assert!(
            has_founded,
            "expected a founded relation; got: {:?}",
            out.relations
        );
    }

    /// Spans stay valid on multi-byte input and relation indices stay in range.
    #[test]
    fn test_tplinker_unicode_offsets_invariants() {
        // Diverse scripts + emoji (multi-byte). Offsets must be character-based and valid.
        let text = "Dr. 田中 met François Müller in 東京. 🎉";
        let out = baseline()
            .extract_with_relations(
                text,
                &["person", "location", "organization"],
                &["works_for", "located_in", "founded"],
                0.0,
            )
            .unwrap();

        let char_count = text.chars().count();
        for ent in out.entities.iter() {
            let span = (ent.start, ent.end);
            assert!(ent.start < ent.end, "invalid span: {:?}", span);
            assert!(
                ent.end <= char_count,
                "span out of bounds: {:?} (len={})",
                span,
                char_count
            );
            let extracted =
                crate::offset::TextSpan::from_chars(text, ent.start, ent.end).extract(text);
            assert_eq!(extracted, ent.text);
        }

        let entity_count = out.entities.len();
        for rel in out.relations.iter() {
            assert!(rel.head_idx < entity_count);
            assert!(rel.tail_idx < entity_count);
        }
    }
}
365}