1use std::path::Path;
18
19use ndarray::{Array2, Array3, ArrayD};
20use ort::session::Session;
21use ort::value::Tensor;
22use tokenizers::Tokenizer;
23
24use crate::ner::ExtractedEntity;
25use crate::rel::ExtractedRelation;
26use crate::schema::ExtractionSchema;
27
/// Raw outputs of one ONNX forward pass, plus the word segmentation of the
/// input text needed to map model indices back onto character spans.
struct InferenceOutputs {
    // Span-classification scores, decoded by `decode_entities`.
    logits: ArrayD<f32>,
    // Head/tail entity indices per candidate pair; `None` when the model
    // exports no relation head.
    rel_idx: Option<ArrayD<i64>>,
    // Per-pair relation class scores.
    rel_logits: Option<ArrayD<f32>>,
    // Validity mask over candidate pairs, normalized to f32 (1.0 = valid).
    rel_mask: Option<ArrayD<f32>>,
    // (byte_start, byte_end, word_text) for each whitespace-separated word.
    word_spans: Vec<(usize, usize, String)>,
    // Number of words in the input text (== word_spans.len()).
    num_words: usize,
}
37
/// Maximum entity span width in words; candidate spans are enumerated up to
/// this width. NOTE(review): must match the width the model was exported
/// with — confirm against the ONNX graph.
const MAX_WIDTH: usize = 12;

// Special prompt tokens the model's tokenizer understands: entity-label
// marker, section separator, and relation-label marker.
const ENT_TOKEN: &str = "<<ENT>>";
const SEP_TOKEN: &str = "<<SEP>>";
const REL_TOKEN: &str = "<<REL>>";
45
/// Joint entity/relation extraction engine backed by an ONNX model and a
/// HuggingFace tokenizer, both loaded from disk in [`RelexEngine::new`].
pub struct RelexEngine {
    session: Session,
    tokenizer: Tokenizer,
}
51
/// Entities and relations extracted from a single input text.
pub struct RelexResult {
    pub entities: Vec<ExtractedEntity>,
    pub relations: Vec<ExtractedRelation>,
}
57
impl RelexEngine {
    /// Load the ONNX model and tokenizer from disk.
    ///
    /// All failures are reported as [`RelexError::ModelLoad`]; the model
    /// file's path is included in the message for the ONNX load step.
    pub fn new(model_path: &Path, tokenizer_path: &Path) -> Result<Self, RelexError> {
        let session = Session::builder()
            .map_err(|e| RelexError::ModelLoad(e.to_string()))?
            .with_intra_threads(4)
            .map_err(|e| RelexError::ModelLoad(e.to_string()))?
            .commit_from_file(model_path)
            .map_err(|e| RelexError::ModelLoad(format!("{}: {}", model_path.display(), e)))?;

        let tokenizer = Tokenizer::from_file(tokenizer_path)
            .map_err(|e| RelexError::ModelLoad(format!("tokenizer: {}", e)))?;

        Ok(Self { session, tokenizer })
    }

    /// Run the full pipeline: one forward pass, entity decoding, then — if
    /// the model exports relation outputs — relation decoding constrained by
    /// `schema`.
    ///
    /// `entity_threshold` / `relation_threshold` are sigmoid-probability
    /// cutoffs in `[0, 1]`.
    pub fn extract(
        &self,
        text: &str,
        entity_labels: &[&str],
        relation_labels: &[&str],
        entity_threshold: f32,
        relation_threshold: f32,
        schema: &ExtractionSchema,
    ) -> Result<RelexResult, RelexError> {
        let out = self.run_inference(text, entity_labels, relation_labels)?;

        let entities = decode_entities(
            &out.logits.view(),
            &out.word_spans,
            out.num_words,
            text,
            entity_labels,
            entity_threshold,
        );

        // Relations are decoded only when all three relation outputs are
        // present; an entity-only model yields an empty relation list.
        let relations = match (out.rel_idx, out.rel_logits, out.rel_mask) {
            (Some(ri), Some(rl), Some(rm)) => decode_relations(
                &ri.view(),
                &rl.view(),
                &rm.view(),
                &entities,
                relation_labels,
                relation_threshold,
                schema,
            ),
            _ => Vec::new(),
        };

        Ok(RelexResult {
            entities,
            relations,
        })
    }

    /// Build the label prompt, tokenize prompt + text, construct all six
    /// model inputs, and run one ONNX forward pass.
    fn run_inference(
        &self,
        text: &str,
        entity_labels: &[&str],
        relation_labels: &[&str],
    ) -> Result<InferenceOutputs, RelexError> {
        let words: Vec<(usize, usize, &str)> = split_words(text);
        let num_words = words.len();

        // Prompt layout: "<<ENT>> lbl ... <<SEP>> <<REL>> lbl ... <<SEP>> text".
        let mut prompt_parts: Vec<String> = Vec::new();
        for label in entity_labels {
            prompt_parts.push(format!("{} {}", ENT_TOKEN, label));
        }
        prompt_parts.push(SEP_TOKEN.to_string());
        for label in relation_labels {
            prompt_parts.push(format!("{} {}", REL_TOKEN, label));
        }
        prompt_parts.push(SEP_TOKEN.to_string());

        let prompt_prefix = prompt_parts.join(" ");
        let full_text = format!("{} {}", prompt_prefix, text);

        // `true`: add the tokenizer's special tokens (e.g. [CLS]/[SEP]).
        let encoding = self
            .tokenizer
            .encode(full_text.as_str(), true)
            .map_err(|e| RelexError::Inference(format!("tokenize: {}", e)))?;

        let ids = encoding.get_ids();
        let attention = encoding.get_attention_mask();
        let seq_len = ids.len();

        let input_ids: Vec<i64> = ids.iter().map(|&id| id as i64).collect();
        let attention_mask: Vec<i64> = attention.iter().map(|&a| a as i64).collect();

        // words_mask marks, per token, which text word (1-based) the token
        // STARTS; 0 for prompt tokens, special tokens, and continuations of
        // an already-marked word.
        let mut words_mask = vec![0i64; seq_len];

        let offsets = encoding.get_offsets();
        // +1 for the single space `format!("{} {}", ...)` inserted between
        // the prompt and the user text.
        // NOTE(review): assumes `get_offsets` returns byte offsets into
        // `full_text`, matching `prompt_prefix.len()` (bytes) — confirm for
        // the tokenizer/normalizer in use.
        let prompt_char_len = prompt_prefix.len() + 1;
        let mut prev_word_idx: Option<usize> = None;
        for (tok_idx, &(tok_start, tok_end)) in offsets.iter().enumerate() {
            // Skip the leading special token and any (0, 0) special offsets.
            if tok_idx == 0 || (tok_start == 0 && tok_end == 0) {
                continue;
            }
            // Skip tokens belonging to the prompt rather than the user text.
            if tok_start < prompt_char_len {
                continue;
            }

            // Re-base the token offsets onto the original `text`.
            let t_start = tok_start - prompt_char_len;
            let t_end = tok_end - prompt_char_len;

            for (word_idx, &(w_start, w_end, _)) in words.iter().enumerate() {
                // Half-open range overlap test between token and word.
                if t_start < w_end && t_end > w_start {
                    // Only the first token of each word gets a marker.
                    if prev_word_idx != Some(word_idx) {
                        words_mask[tok_idx] = (word_idx + 1) as i64;
                        prev_word_idx = Some(word_idx);
                    }
                    break;
                }
            }
        }

        // Enumerate every candidate span [start, start + width] up to
        // MAX_WIDTH words; spans running past the last word are zero-padded
        // and masked out so the tensor shape stays rectangular.
        let num_spans = num_words * MAX_WIDTH;
        let mut span_indices: Vec<i64> = Vec::with_capacity(num_spans * 2);
        let mut span_mask: Vec<bool> = Vec::with_capacity(num_spans);
        for start in 0..num_words {
            for width in 0..MAX_WIDTH {
                let end = start + width;
                if end < num_words {
                    span_indices.push(start as i64);
                    span_indices.push(end as i64);
                    span_mask.push(true);
                } else {
                    span_indices.push(0);
                    span_indices.push(0);
                    span_mask.push(false);
                }
            }
        }

        // Pack everything into batch-of-1 arrays with the shapes the ONNX
        // graph expects.
        let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let attention_mask_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let words_mask_arr = Array2::from_shape_vec((1, seq_len), words_mask)
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let text_lengths_arr = Array2::from_shape_vec((1, 1), vec![num_words as i64])
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let span_idx_arr = Array3::from_shape_vec((1, num_spans, 2), span_indices)
            .map_err(|e| RelexError::Inference(e.to_string()))?;
        let span_mask_arr = Array2::from_shape_vec((1, num_spans), span_mask)
            .map_err(|e| RelexError::Inference(e.to_string()))?;

        let v_ids = Tensor::from_array(input_ids_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_attn = Tensor::from_array(attention_mask_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_wmask = Tensor::from_array(words_mask_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_tlen = Tensor::from_array(text_lengths_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_sidx = Tensor::from_array(span_idx_arr).map_err(|e| RelexError::Inference(e.to_string()))?;
        let v_smask = Tensor::from_array(span_mask_arr).map_err(|e| RelexError::Inference(e.to_string()))?;

        let inputs = ort::inputs![
            "input_ids" => v_ids,
            "attention_mask" => v_attn,
            "words_mask" => v_wmask,
            "text_lengths" => v_tlen,
            "span_idx" => v_sidx,
            "span_mask" => v_smask,
        ]
        .map_err(|e| RelexError::Inference(e.to_string()))?;

        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| RelexError::Inference(e.to_string()))?;

        // "logits" is mandatory; the three relation outputs below are
        // optional and simply absent for entity-only model exports.
        let logits: ArrayD<f32> = outputs
            .get("logits")
            .ok_or_else(|| RelexError::Inference("missing 'logits' output".into()))?
            .try_extract_tensor::<f32>()
            .map_err(|e| RelexError::Inference(format!("logits tensor: {}", e)))?
            .into_owned();

        let rel_idx = outputs
            .get("rel_idx")
            .and_then(|v| v.try_extract_tensor::<i64>().ok())
            .map(|t| t.into_owned());
        let rel_logits = outputs
            .get("rel_logits")
            .and_then(|v| v.try_extract_tensor::<f32>().ok())
            .map(|t| t.into_owned());
        // Some exports emit rel_mask as bool, others as f32; normalize both
        // to an f32 array (1.0 = valid pair).
        let rel_mask = outputs
            .get("rel_mask")
            .and_then(|v| {
                if let Ok(t) = v.try_extract_tensor::<bool>() {
                    let converted: ArrayD<f32> = t.mapv(|b| if b { 1.0f32 } else { 0.0 });
                    Some(converted)
                } else {
                    v.try_extract_tensor::<f32>().ok().map(|t| t.into_owned())
                }
            });

        // Convert borrowed word slices to owned strings so the result does
        // not borrow from `text`.
        let owned_words: Vec<(usize, usize, String)> = words
            .iter()
            .map(|&(s, e, w)| (s, e, w.to_string()))
            .collect();

        Ok(InferenceOutputs {
            logits,
            rel_idx,
            rel_logits,
            rel_mask,
            word_spans: owned_words,
            num_words,
        })
    }
}
302
/// Split `text` into whitespace-separated words, returning each word's byte
/// range `(start, end)` together with the word slice itself.
fn split_words(text: &str) -> Vec<(usize, usize, &str)> {
    text.split_whitespace()
        .map(|word| {
            // `split_whitespace` yields subslices of `text`, so each word's
            // byte offset can be recovered from pointer arithmetic.
            let start = word.as_ptr() as usize - text.as_ptr() as usize;
            (start, start + word.len(), word)
        })
        .collect()
}
323
324fn decode_entities(
329 logits: &ndarray::ArrayViewD<f32>,
330 word_spans: &[(usize, usize, String)],
331 num_words: usize,
332 text: &str,
333 entity_labels: &[&str],
334 threshold: f32,
335) -> Vec<ExtractedEntity> {
336 let shape = logits.shape();
337 if shape.len() != 4 {
339 return Vec::new();
340 }
341
342 let _batch = shape[0];
343 let n_words = shape[1];
344 let max_w = shape[2];
345 let n_classes = shape[3];
346
347 let mut entities = Vec::new();
348
349 for word_start in 0..n_words.min(num_words) {
350 for width in 0..max_w.min(num_words - word_start) {
351 let word_end = word_start + width;
352
353 for class_idx in 0..n_classes.min(entity_labels.len()) {
354 let score = logits[[0, word_start, width, class_idx]];
355 let prob = sigmoid(score);
356
357 if prob >= threshold {
358 if word_start < word_spans.len() && word_end < word_spans.len() {
360 let char_start = word_spans[word_start].0;
361 let char_end = word_spans[word_end].1;
362
363 if char_end <= text.len() {
364 let span_text = text[char_start..char_end].to_string();
365 entities.push(ExtractedEntity {
366 text: span_text,
367 entity_type: entity_labels[class_idx].to_string(),
368 span_start: char_start,
369 span_end: char_end,
370 confidence: prob as f64,
371 });
372 }
373 }
374 }
375 }
376 }
377 }
378
379 entities.sort_by(|a, b| b.confidence.partial_cmp(&a.confidence).unwrap());
381 let mut used_ranges: Vec<(usize, usize)> = Vec::new();
382 entities.retain(|e| {
383 let overlaps = used_ranges
384 .iter()
385 .any(|&(s, end)| e.span_start < end && e.span_end > s);
386 if !overlaps {
387 used_ranges.push((e.span_start, e.span_end));
388 true
389 } else {
390 false
391 }
392 });
393
394 entities
395}
396
397fn decode_relations(
403 rel_idx: &ndarray::ArrayViewD<i64>,
404 rel_logits: &ndarray::ArrayViewD<f32>,
405 rel_mask: &ndarray::ArrayViewD<f32>,
406 entities: &[ExtractedEntity],
407 relation_labels: &[&str],
408 threshold: f32,
409 schema: &ExtractionSchema,
410) -> Vec<ExtractedRelation> {
411 let shape = rel_logits.shape();
412 if shape.len() != 3 {
413 return Vec::new();
414 }
415
416 let num_pairs = shape[1];
417 let num_rel_classes = shape[2];
418
419 let mut relations = Vec::new();
420 let mut seen = std::collections::HashSet::new();
421
422 for pair_idx in 0..num_pairs {
423 let mask_val = rel_mask[[0, pair_idx]];
425 if mask_val < 0.5 {
426 continue;
427 }
428
429 let head_idx = rel_idx[[0, pair_idx, 0]] as usize;
430 let tail_idx = rel_idx[[0, pair_idx, 1]] as usize;
431
432 if head_idx >= entities.len() || tail_idx >= entities.len() {
433 continue;
434 }
435
436 let head_entity = &entities[head_idx];
437 let tail_entity = &entities[tail_idx];
438
439 if head_entity.text == tail_entity.text {
440 continue;
441 }
442
443 for rel_idx_inner in 0..num_rel_classes.min(relation_labels.len()) {
445 let score = rel_logits[[0, pair_idx, rel_idx_inner]];
446 let prob = sigmoid(score);
447
448 if prob >= threshold {
449 let relation = relation_labels[rel_idx_inner];
450
451 if let Some(spec) = schema.relation_types.get(relation) {
453 let valid = spec.head.contains(&head_entity.entity_type)
454 && spec.tail.contains(&tail_entity.entity_type);
455 let valid_rev = spec.head.contains(&tail_entity.entity_type)
457 && spec.tail.contains(&head_entity.entity_type);
458
459 if !valid && !valid_rev {
460 continue;
461 }
462
463 let (h, t) = if valid {
464 (head_entity.text.clone(), tail_entity.text.clone())
465 } else {
466 (tail_entity.text.clone(), head_entity.text.clone())
467 };
468
469 let key = (h.clone(), relation.to_string(), t.clone());
470 if seen.insert(key) {
471 relations.push(ExtractedRelation {
472 head: h,
473 relation: relation.to_string(),
474 tail: t,
475 confidence: prob as f64,
476 });
477 }
478 }
479 }
480 }
481 }
482
483 relations
484}
485
/// Logistic function: maps any real score into the open interval (0, 1).
fn sigmoid(x: f32) -> f32 {
    let neg_exp = (-x).exp();
    1.0 / (1.0 + neg_exp)
}
489
/// Errors produced by [`RelexEngine`].
#[derive(Debug, thiserror::Error)]
pub enum RelexError {
    /// The ONNX model or the tokenizer could not be loaded from disk.
    #[error("failed to load relex model: {0}")]
    ModelLoad(String),

    /// Tokenization, input-tensor construction, or the ONNX forward pass failed.
    #[error("relex inference error: {0}")]
    Inference(String),
}