// ctxgraph_extract/relex.rs
//! Custom ONNX pipeline for `gliner-relex-large-v0.5`.
//!
//! This model performs joint NER + relation extraction in a single forward pass
//! using a DeBERTa-v3-large backbone with a GCN-based relation layer.
//!
//! Architecture: UniEncoderSpanRelex (span_mode=markerV0)
//!
//! ONNX inputs (6 tensors):
//!   input_ids, attention_mask, words_mask, text_lengths, span_idx, span_mask
//!
//! ONNX outputs (4 tensors):
//!   logits         — entity span scores [batch, num_words, max_width, num_classes]
//!   rel_idx        — entity pair indices [batch, num_pairs, 2]
//!   rel_logits     — relation type scores [batch, num_pairs, num_rel_classes]
//!   rel_mask       — valid pair mask [batch, num_pairs]

17use std::path::Path;
18
19use ndarray::{Array2, Array3, ArrayD};
20use ort::session::Session;
21use ort::value::Tensor;
22use tokenizers::Tokenizer;
23
24use crate::ner::ExtractedEntity;
25use crate::rel::ExtractedRelation;
26use crate::schema::ExtractionSchema;
27
/// Raw tensor outputs from ONNX inference, fully owned.
///
/// Owned copies are taken so the data outlives the ort `SessionOutputs`
/// borrow; the relation tensors are optional because decoding falls back
/// to entities-only when any of them is missing.
struct InferenceOutputs {
    /// Entity span scores: [batch, num_words, max_width, num_classes].
    logits: ArrayD<f32>,
    /// Entity pair indices: [batch, num_pairs, 2]; `None` if absent.
    rel_idx: Option<ArrayD<i64>>,
    /// Relation type scores: [batch, num_pairs, num_rel_classes]; `None` if absent.
    rel_logits: Option<ArrayD<f32>>,
    /// Valid-pair mask: [batch, num_pairs], 1.0 = valid (converted from bool).
    rel_mask: Option<ArrayD<f32>>,
    /// Word segmentation of the input text: (byte_start, byte_end, word).
    word_spans: Vec<(usize, usize, String)>,
    /// Number of words in the text portion (length of `word_spans`).
    num_words: usize,
}
37
/// Maximum span width (number of words). Matches model config `max_width: 12`.
const MAX_WIDTH: usize = 12;

/// Special marker tokens (string forms, not token IDs) used to build the
/// relex prompt prefix; the tokenizer maps them to their reserved IDs.
const ENT_TOKEN: &str = "<<ENT>>";
const SEP_TOKEN: &str = "<<SEP>>";
const REL_TOKEN: &str = "<<REL>>";
45
/// The relex ONNX inference engine.
///
/// Owns the ONNX runtime session and the matching tokenizer; construct
/// with [`RelexEngine::new`], run with [`RelexEngine::extract`].
pub struct RelexEngine {
    /// ONNX runtime session for the relex model.
    session: Session,
    /// Tokenizer paired with the model checkpoint (sentencepiece-based).
    tokenizer: Tokenizer,
}
51
/// Result from a single relex inference pass.
pub struct RelexResult {
    /// Decoded entities (greedy non-overlapping, highest confidence kept).
    pub entities: Vec<ExtractedEntity>,
    /// Decoded relations, schema-validated and deduplicated.
    pub relations: Vec<ExtractedRelation>,
}
57
impl RelexEngine {
    /// Load the relex ONNX model and tokenizer.
    ///
    /// # Errors
    ///
    /// Returns [`RelexError::ModelLoad`] if the ONNX session cannot be
    /// built from `model_path` or the tokenizer file fails to parse.
    pub fn new(model_path: &Path, tokenizer_path: &Path) -> Result<Self, RelexError> {
        let session = Session::builder()
            .map_err(|e| RelexError::ModelLoad(e.to_string()))?
            .with_intra_threads(4)
            .map_err(|e| RelexError::ModelLoad(e.to_string()))?
            .commit_from_file(model_path)
            .map_err(|e| RelexError::ModelLoad(format!("{}: {}", model_path.display(), e)))?;

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| RelexError::ModelLoad(format!("tokenizer: {}", e)))?;

        Ok(Self { session, tokenizer })
    }

    /// Run joint NER + relation extraction on text.
    ///
    /// - `entity_labels`: entity type names (e.g., ["Person", "Database", "Service"])
    /// - `relation_labels`: relation type names (e.g., ["chose", "replaced", "depends_on"])
    /// - `entity_threshold`: minimum sigmoid score for entity spans (0.0-1.0)
    /// - `relation_threshold`: minimum sigmoid score for relations (0.0-1.0)
    ///
    /// Relations are decoded only when all three relation outputs
    /// (`rel_idx`, `rel_logits`, `rel_mask`) are present; otherwise the
    /// result carries an empty `relations` vector.
    ///
    /// # Errors
    ///
    /// Returns [`RelexError::Inference`] if tokenization, tensor
    /// construction, or the ONNX forward pass fails.
    pub fn extract(
        &self,
        text: &str,
        entity_labels: &[&str],
        relation_labels: &[&str],
        entity_threshold: f32,
        relation_threshold: f32,
        schema: &ExtractionSchema,
    ) -> Result<RelexResult, RelexError> {
        let out = self.run_inference(text, entity_labels, relation_labels)?;

        let entities = decode_entities(
            &out.logits.view(),
            &out.word_spans,
            out.num_words,
            text,
            entity_labels,
            entity_threshold,
        );

        // All three relation tensors are required; a partial set means the
        // model export has no (usable) relation head for this input.
        let relations = match (out.rel_idx, out.rel_logits, out.rel_mask) {
            (Some(ri), Some(rl), Some(rm)) => decode_relations(
                &ri.view(),
                &rl.view(),
                &rm.view(),
                &entities,
                relation_labels,
                relation_threshold,
                schema,
            ),
            _ => Vec::new(),
        };

        Ok(RelexResult {
            entities,
            relations,
        })
    }

    /// Build the 6 ONNX input tensors, run the forward pass, and return
    /// fully-owned output tensors together with the word segmentation
    /// (`word_spans`, `num_words`) needed to decode them.
    fn run_inference(
        &self,
        text: &str,
        entity_labels: &[&str],
        relation_labels: &[&str],
    ) -> Result<InferenceOutputs, RelexError> {
        // Split text into words (simple whitespace split with char offsets)
        let words: Vec<(usize, usize, &str)> = split_words(text);
        let num_words = words.len();

        // Build the prompt string:
        // <<ENT>> Person <<ENT>> Database ... <<SEP>> <<REL>> chose <<REL>> replaced ... <<SEP>> text
        let mut prompt_parts: Vec<String> = Vec::new();
        for label in entity_labels {
            prompt_parts.push(format!("{} {}", ENT_TOKEN, label));
        }
        prompt_parts.push(SEP_TOKEN.to_string());
        for label in relation_labels {
            prompt_parts.push(format!("{} {}", REL_TOKEN, label));
        }
        prompt_parts.push(SEP_TOKEN.to_string());

        let prompt_prefix = prompt_parts.join(" ");
        let full_text = format!("{} {}", prompt_prefix, text);

        // Tokenize the full prompt+text (with special tokens, e.g. [CLS]/[SEP]).
        let encoding = self
            .tokenizer
            .encode(full_text.as_str(), true)
            .map_err(|e| RelexError::Inference(format!("tokenize: {}", e)))?;

        let ids = encoding.get_ids();
        let attention = encoding.get_attention_mask();
        let seq_len = ids.len();

        // Build input_ids and attention_mask (widened to i64 for ONNX).
        let input_ids: Vec<i64> = ids.iter().map(|&id| id as i64).collect();
        let attention_mask: Vec<i64> = attention.iter().map(|&a| a as i64).collect();

        // Build words_mask: maps each token to its word index in the text portion.
        //
        // The sentencepiece tokenizer produces offsets that include a leading `▁`
        // (space) character as part of the token, so token offsets don't align
        // exactly with word start positions. We use overlap-based matching: a
        // token maps to a word if the token's character range overlaps the word's
        // range. Only the first sub-token of each word receives the 1-based word
        // index; continuation sub-tokens remain 0 (matching GLiNER's
        // `prepare_word_mask` which uses `word_ids()` from HuggingFace tokenizers).
        let mut words_mask = vec![0i64; seq_len];

        let offsets = encoding.get_offsets();
        // +1 for the space between prompt and text. NOTE(review): `len()` is a
        // byte count, matching `split_words`' byte offsets, but this assumes the
        // tokenizer's offsets are byte-based too — confirm for non-ASCII labels.
        let prompt_char_len = prompt_prefix.len() + 1;

        let mut prev_word_idx: Option<usize> = None;
        for (tok_idx, &(tok_start, tok_end)) in offsets.iter().enumerate() {
            if tok_idx == 0 || (tok_start == 0 && tok_end == 0) {
                continue; // skip [CLS], [SEP], and padding
            }
            if tok_start < prompt_char_len {
                continue; // skip prompt tokens
            }

            // Convert token char range to text-relative offsets
            let t_start = tok_start - prompt_char_len;
            let t_end = tok_end - prompt_char_len;

            // Find word whose range overlaps this token (first sub-token only).
            // `prev_word_idx` ensures continuation sub-tokens stay 0.
            for (word_idx, &(w_start, w_end, _)) in words.iter().enumerate() {
                if t_start < w_end && t_end > w_start {
                    if prev_word_idx != Some(word_idx) {
                        words_mask[tok_idx] = (word_idx + 1) as i64; // 1-based
                        prev_word_idx = Some(word_idx);
                    }
                    break;
                }
            }
        }

        // Build span_idx and span_mask.
        // Model expects exactly num_words * MAX_WIDTH spans (reshape requirement),
        // so out-of-range spans are emitted as masked (0, 0) padding entries.
        let num_spans = num_words * MAX_WIDTH;
        let mut span_indices: Vec<i64> = Vec::with_capacity(num_spans * 2);
        let mut span_mask: Vec<bool> = Vec::with_capacity(num_spans);
        for start in 0..num_words {
            for width in 0..MAX_WIDTH {
                let end = start + width;
                if end < num_words {
                    span_indices.push(start as i64);
                    span_indices.push(end as i64);
                    span_mask.push(true);
                } else {
                    // Padding span (invalid)
                    span_indices.push(0);
                    span_indices.push(0);
                    span_mask.push(false);
                }
            }
        }

        // Build ndarray tensors (batch size fixed at 1).
        let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let attention_mask_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let words_mask_arr = Array2::from_shape_vec((1, seq_len), words_mask)
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let text_lengths_arr = Array2::from_shape_vec((1, 1), vec![num_words as i64])
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let span_idx_arr = Array3::from_shape_vec((1, num_spans, 2), span_indices)
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let span_mask_arr = Array2::from_shape_vec((1, num_spans), span_mask)
            .map_err(|e| RelexError::Inference(e.to_string()))?;

        // Convert to ort Values
        let v_ids = Tensor::from_array(input_ids_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_attn = Tensor::from_array(attention_mask_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_wmask = Tensor::from_array(words_mask_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_tlen = Tensor::from_array(text_lengths_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_sidx = Tensor::from_array(span_idx_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_smask = Tensor::from_array(span_mask_arr).map_err(|e| RelexError::Inference(e.to_string()))?;

        let inputs = ort::inputs![
            "input_ids" => v_ids,
            "attention_mask" => v_attn,
            "words_mask" => v_wmask,
            "text_lengths" => v_tlen,
            "span_idx" => v_sidx,
            "span_mask" => v_smask,
        ]
        .map_err(|e| RelexError::Inference(e.to_string()))?;

        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| RelexError::Inference(e.to_string()))?;

        // Extract tensors into owned arrays before SessionOutputs is dropped.
        // `logits` is mandatory; the relation outputs are best-effort.
        let logits: ArrayD<f32> = outputs
            .get("logits")
            .ok_or_else(|| RelexError::Inference("missing 'logits' output".into()))?
            .try_extract_tensor::<f32>()
            .map_err(|e| RelexError::Inference(format!("logits tensor: {}", e)))?
            .into_owned();

        let rel_idx = outputs
            .get("rel_idx")
            .and_then(|v| v.try_extract_tensor::<i64>().ok())
            .map(|t| t.into_owned());
        let rel_logits = outputs
            .get("rel_logits")
            .and_then(|v| v.try_extract_tensor::<f32>().ok())
            .map(|t| t.into_owned());
        // rel_mask is output as bool by the ONNX model; convert to f32 for decode.
        let rel_mask = outputs
            .get("rel_mask")
            .and_then(|v| {
                // Try bool first (matches ONNX export), fall back to f32
                if let Ok(t) = v.try_extract_tensor::<bool>() {
                    let converted: ArrayD<f32> = t.mapv(|b| if b { 1.0f32 } else { 0.0 });
                    Some(converted)
                } else {
                    v.try_extract_tensor::<f32>().ok().map(|t| t.into_owned())
                }
            });

        // Take ownership of the word segmentation (it borrowed from `text`).
        let owned_words: Vec<(usize, usize, String)> = words
            .iter()
            .map(|&(s, e, w)| (s, e, w.to_string()))
            .collect();

        Ok(InferenceOutputs {
            logits,
            rel_idx,
            rel_logits,
            rel_mask,
            word_spans: owned_words,
            num_words,
        })
    }
}
302
/// Split text into whitespace-delimited words with byte offsets:
/// (start, end, word_str).
///
/// Unicode whitespace is honored via `char::is_whitespace`; offsets are
/// byte positions into `text`, suitable for slicing.
fn split_words(text: &str) -> Vec<(usize, usize, &str)> {
    let mut out = Vec::new();
    let mut chars = text.char_indices().peekable();

    while let Some(&(pos, c)) = chars.peek() {
        if c.is_whitespace() {
            // Skip runs of whitespace between words.
            chars.next();
            continue;
        }
        // Consume one word: advance until whitespace or end of input.
        let start = pos;
        let mut end = pos + c.len_utf8();
        chars.next();
        while let Some(&(i, next)) = chars.peek() {
            if next.is_whitespace() {
                break;
            }
            end = i + next.len_utf8();
            chars.next();
        }
        out.push((start, end, &text[start..end]));
    }

    out
}
323
324/// Decode entity spans from the logits tensor.
325///
326/// logits shape: [batch=1, num_words, max_width, num_entity_classes]
327/// Each value is the score for span (word_i, word_i+width) being entity class c.
328fn decode_entities(
329    logits: &ndarray::ArrayViewD<f32>,
330    word_spans: &[(usize, usize, String)],
331    num_words: usize,
332    text: &str,
333    entity_labels: &[&str],
334    threshold: f32,
335) -> Vec<ExtractedEntity> {
336    let shape = logits.shape();
337    // shape: [1, num_words, max_width, num_classes]
338    if shape.len() != 4 {
339        return Vec::new();
340    }
341
342    let _batch = shape[0];
343    let n_words = shape[1];
344    let max_w = shape[2];
345    let n_classes = shape[3];
346
347    let mut entities = Vec::new();
348
349    for word_start in 0..n_words.min(num_words) {
350        for width in 0..max_w.min(num_words - word_start) {
351            let word_end = word_start + width;
352
353            for class_idx in 0..n_classes.min(entity_labels.len()) {
354                let score = logits[[0, word_start, width, class_idx]];
355                let prob = sigmoid(score);
356
357                if prob >= threshold {
358                    // Convert word indices to character offsets
359                    if word_start < word_spans.len() && word_end < word_spans.len() {
360                        let char_start = word_spans[word_start].0;
361                        let char_end = word_spans[word_end].1;
362
363                        if char_end <= text.len() {
364                            let span_text = text[char_start..char_end].to_string();
365                            entities.push(ExtractedEntity {
366                                text: span_text,
367                                entity_type: entity_labels[class_idx].to_string(),
368                                span_start: char_start,
369                                span_end: char_end,
370                                confidence: prob as f64,
371                            });
372                        }
373                    }
374                }
375            }
376        }
377    }
378
379    // Greedy dedup: for overlapping spans, keep highest confidence
380    entities.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
381    let mut used_ranges: Vec<(usize, usize)> = Vec::new();
382    entities.retain(|e| {
383        let overlaps = used_ranges
384            .iter()
385            .any(|&(s, end)| e.span_start < end && e.span_end > s);
386        if !overlaps {
387            used_ranges.push((e.span_start, e.span_end));
388            true
389        } else {
390            false
391        }
392    });
393
394    entities
395}
396
397/// Decode relations from model outputs.
398///
399/// rel_idx shape:    [batch=1, num_pairs, 2]     — indices into entity list
400/// rel_logits shape: [batch=1, num_pairs, num_rel_classes] — scores per relation type
401/// rel_mask shape:   [batch=1, num_pairs]         — valid pair indicator
402fn decode_relations(
403    rel_idx: &ndarray::ArrayViewD<i64>,
404    rel_logits: &ndarray::ArrayViewD<f32>,
405    rel_mask: &ndarray::ArrayViewD<f32>,
406    entities: &[ExtractedEntity],
407    relation_labels: &[&str],
408    threshold: f32,
409    schema: &ExtractionSchema,
410) -> Vec<ExtractedRelation> {
411    let shape = rel_logits.shape();
412    if shape.len() != 3 {
413        return Vec::new();
414    }
415
416    let num_pairs = shape[1];
417    let num_rel_classes = shape[2];
418
419    let mut relations = Vec::new();
420    let mut seen = std::collections::HashSet::new();
421
422    for pair_idx in 0..num_pairs {
423        // Check mask
424        let mask_val = rel_mask[[0, pair_idx]];
425        if mask_val < 0.5 {
426            continue;
427        }
428
429        let head_idx = rel_idx[[0, pair_idx, 0]] as usize;
430        let tail_idx = rel_idx[[0, pair_idx, 1]] as usize;
431
432        if head_idx >= entities.len() || tail_idx >= entities.len() {
433            continue;
434        }
435
436        let head_entity = &entities[head_idx];
437        let tail_entity = &entities[tail_idx];
438
439        if head_entity.text == tail_entity.text {
440            continue;
441        }
442
443        // Find best relation type for this pair
444        for rel_idx_inner in 0..num_rel_classes.min(relation_labels.len()) {
445            let score = rel_logits[[0, pair_idx, rel_idx_inner]];
446            let prob = sigmoid(score);
447
448            if prob >= threshold {
449                let relation = relation_labels[rel_idx_inner];
450
451                // Validate against schema
452                if let Some(spec) = schema.relation_types.get(relation) {
453                    let valid = spec.head.contains(&head_entity.entity_type)
454                        && spec.tail.contains(&tail_entity.entity_type);
455                    // Also check reverse direction
456                    let valid_rev = spec.head.contains(&tail_entity.entity_type)
457                        && spec.tail.contains(&head_entity.entity_type);
458
459                    if !valid && !valid_rev {
460                        continue;
461                    }
462
463                    let (h, t) = if valid {
464                        (head_entity.text.clone(), tail_entity.text.clone())
465                    } else {
466                        (tail_entity.text.clone(), head_entity.text.clone())
467                    };
468
469                    let key = (h.clone(), relation.to_string(), t.clone());
470                    if seen.insert(key) {
471                        relations.push(ExtractedRelation {
472                            head: h,
473                            relation: relation.to_string(),
474                            tail: t,
475                            confidence: prob as f64,
476                        });
477                    }
478                }
479            }
480        }
481    }
482
483    relations
484}
485
/// Logistic function: maps any real score into the open interval (0, 1).
fn sigmoid(x: f32) -> f32 {
    let neg_exp = (-x).exp();
    (1.0 + neg_exp).recip()
}
489
/// Errors surfaced by [`RelexEngine`].
#[derive(Debug, thiserror::Error)]
pub enum RelexError {
    /// ONNX session construction or tokenizer loading failed.
    #[error("failed to load relex model: {0}")]
    ModelLoad(String),

    /// Tokenization, tensor construction, or the forward pass failed.
    #[error("relex inference error: {0}")]
    Inference(String),
}