1use std::path::Path;
10
11use ort::session::Session;
12use ort::value::Tensor;
13use tokenizers::Tokenizer;
14
15use crate::ner::ExtractedEntity;
16use crate::rel::ExtractedRelation;
17use crate::schema::ExtractionSchema;
18
/// Relation label → NLI hypothesis paraphrases.
///
/// Each template's `{head}`/`{tail}` placeholder is replaced with an entity's
/// surface text, and the resulting sentence is scored for entailment against
/// the sentence containing both entities. Multiple paraphrases per relation
/// improve recall: the best entailment probability across a relation's
/// templates is taken as that relation's score.
const HYPOTHESIS_TEMPLATES: &[(&str, &[&str])] = &[
    ("chose", &[
        "{head} chose {tail}",
        "{head} selected {tail}",
    ]),
    ("rejected", &[
        "{head} rejected {tail}",
        "{head} decided against {tail}",
    ]),
    ("replaced", &[
        "{head} replaced {tail}",
        // Passive-voice variant: note {tail} comes first in the sentence.
        "{tail} was replaced by {head}",
    ]),
    ("depends_on", &[
        "{head} depends on {tail}",
        "{head} uses {tail}",
    ]),
    ("fixed", &[
        "{head} fixed {tail}",
        "{head} resolved {tail}",
    ]),
    ("introduced", &[
        "{head} introduced {tail}",
        "{head} added {tail}",
    ]),
    ("deprecated", &[
        "{head} deprecated {tail}",
        "{head} removed {tail}",
    ]),
    ("caused", &[
        "{head} caused {tail}",
        "{head} led to {tail}",
    ]),
    ("constrained_by", &[
        "{head} is constrained by {tail}",
        "{head} must comply with {tail}",
    ]),
];
59
/// Zero-shot relation extractor: scores natural-language hypotheses about
/// entity pairs against their source sentence using an NLI cross-encoder
/// exported to ONNX.
pub struct NliEngine {
    session: Session,     // ONNX Runtime session holding the NLI model
    tokenizer: Tokenizer, // tokenizer paired with the exported model
}
65
// Index of the "entailment" class in the model's 3-way logit output.
// NOTE(review): label order is model-dependent — MNLI-style exports variously
// use [contradiction, neutral, entailment] or the reverse. Index 1 is unusual
// (it is "neutral" in both common orderings); confirm against the exported
// model's label config.
const ENTAILMENT_IDX: usize = 1;
68
69impl NliEngine {
70 pub fn new(model_path: &Path, tokenizer_path: &Path) -> Result<Self, NliError> {
72 let session = Session::builder()
73 .and_then(|b| b.with_intra_threads(1))
74 .and_then(|b| b.commit_from_file(model_path))
75 .map_err(|e| NliError::ModelLoad(e.to_string()))?;
76
77 let tokenizer = Tokenizer::from_file(tokenizer_path)
78 .map_err(|e| NliError::ModelLoad(e.to_string()))?;
79
80 Ok(Self { session, tokenizer })
81 }
82
83 fn score(&self, premise: &str, hypothesis: &str) -> Result<[f32; 3], NliError> {
86 let encoding = self.tokenizer
87 .encode((premise, hypothesis), true)
88 .map_err(|e| NliError::Inference(e.to_string()))?;
89
90 let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
91 let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&m| m as i64).collect();
92
93 let seq_len = input_ids.len();
94
95 let ids_tensor = Tensor::from_array(([1, seq_len], input_ids))
96 .map_err(|e| NliError::Inference(e.to_string()))?;
97 let mask_tensor = Tensor::from_array(([1, seq_len], attention_mask))
98 .map_err(|e| NliError::Inference(e.to_string()))?;
99
100 let inputs = ort::inputs![ids_tensor, mask_tensor]
101 .map_err(|e| NliError::Inference(e.to_string()))?;
102
103 let outputs = self.session.run(inputs)
104 .map_err(|e| NliError::Inference(e.to_string()))?;
105
106 let logits_view = outputs[0]
108 .try_extract_tensor::<f32>()
109 .map_err(|e| NliError::Inference(e.to_string()))?;
110
111 let logits = logits_view.as_slice()
112 .ok_or_else(|| NliError::Inference("non-contiguous logits".into()))?;
113
114 if logits.len() < 3 {
115 return Err(NliError::Inference(format!("expected 3 logits, got {}", logits.len())));
116 }
117
118 let max_logit = logits[0].max(logits[1]).max(logits[2]);
120 let exp: Vec<f32> = logits[..3].iter().map(|&l| (l - max_logit).exp()).collect();
121 let sum: f32 = exp.iter().sum();
122 Ok([exp[0] / sum, exp[1] / sum, exp[2] / sum])
123 }
124
125 pub fn extract(
130 &self,
131 text: &str,
132 entities: &[ExtractedEntity],
133 schema: &ExtractionSchema,
134 threshold: f32,
135 ) -> Result<Vec<ExtractedRelation>, NliError> {
136 let mut relations = Vec::new();
137 let mut seen = std::collections::HashSet::<(String, String, String)>::new();
138
139 let sentences = split_into_sentences(text);
141
142 for (sent_start, sent_end) in &sentences {
143 let premise = &text[*sent_start..*sent_end];
144
145 let sent_entities: Vec<&ExtractedEntity> = entities
147 .iter()
148 .filter(|e| e.span_start >= *sent_start && e.span_start < *sent_end)
149 .collect();
150
151 if sent_entities.len() < 2 {
152 continue;
153 }
154
155 for (i, head) in sent_entities.iter().enumerate() {
157 for tail in sent_entities.iter().skip(i + 1) {
158 if head.text == tail.text {
159 continue;
160 }
161
162 for &(rel_name, templates) in HYPOTHESIS_TEMPLATES {
164 let schema_valid = schema.relation_types.get(rel_name)
166 .map(|spec| {
167 (spec.head.contains(&head.entity_type) && spec.tail.contains(&tail.entity_type))
168 || (spec.head.contains(&tail.entity_type) && spec.tail.contains(&head.entity_type))
169 })
170 .unwrap_or(false);
171
172 if !schema_valid {
173 continue;
174 }
175
176 let mut best_score_fwd: f32 = 0.0;
178 for template in templates {
179 let hypothesis = template
180 .replace("{head}", &head.text)
181 .replace("{tail}", &tail.text);
182 if let Ok(probs) = self.score(premise, &hypothesis) {
183 best_score_fwd = best_score_fwd.max(probs[ENTAILMENT_IDX]);
184 }
185 }
186
187 let mut best_score_rev: f32 = 0.0;
189 for template in templates {
190 let hypothesis = template
191 .replace("{head}", &tail.text)
192 .replace("{tail}", &head.text);
193 if let Ok(probs) = self.score(premise, &hypothesis) {
194 best_score_rev = best_score_rev.max(probs[ENTAILMENT_IDX]);
195 }
196 }
197
198 let (actual_head, actual_tail, score) = if best_score_fwd >= best_score_rev {
200 (&head.text, &tail.text, best_score_fwd)
201 } else {
202 (&tail.text, &head.text, best_score_rev)
203 };
204
205 if score >= threshold {
206 let key = (actual_head.clone(), rel_name.to_string(), actual_tail.clone());
207 if seen.insert(key) {
208 relations.push(ExtractedRelation {
209 head: actual_head.clone(),
210 relation: rel_name.to_string(),
211 tail: actual_tail.clone(),
212 confidence: score as f64,
213 });
214 }
215 }
216 }
217 }
218 }
219 }
220
221 deduplicate_by_pair(&mut relations);
223
224 Ok(relations)
225 }
226}
227
228fn deduplicate_by_pair(relations: &mut Vec<ExtractedRelation>) {
230 let mut best: std::collections::HashMap<(String, String), usize> = std::collections::HashMap::new();
231
232 for (i, rel) in relations.iter().enumerate() {
233 let key = (rel.head.clone(), rel.tail.clone());
234 let rev_key = (rel.tail.clone(), rel.head.clone());
235 let existing_key = if best.contains_key(&key) { Some(key.clone()) }
236 else if best.contains_key(&rev_key) { Some(rev_key) }
237 else { None };
238
239 if let Some(k) = existing_key {
240 let prev_idx = best[&k];
241 if rel.confidence > relations[prev_idx].confidence {
242 best.insert(k, i);
243 }
244 } else {
245 best.insert(key, i);
246 }
247 }
248
249 let keep: std::collections::HashSet<usize> = best.values().copied().collect();
250 let mut idx = 0;
251 relations.retain(|_| {
252 let k = keep.contains(&idx);
253 idx += 1;
254 k
255 });
256}
257
/// Splits `text` into byte ranges approximating sentences.
///
/// A boundary is `.`, `!` or `?` followed by a space, or a blank line
/// (`\n\n`). Terminal punctuation stays in its sentence; all whitespace
/// after a boundary is skipped, so segments never start with stray spaces
/// or newlines and runs of blank lines no longer produce empty ranges
/// (both were defects in the previous version, which advanced only one
/// byte past a paragraph break). All boundary bytes are ASCII, so every
/// returned offset is a valid UTF-8 char boundary. When no boundary exists
/// (including empty input), the whole text is returned as a single range.
fn split_into_sentences(text: &str) -> Vec<(usize, usize)> {
    let bytes = text.as_bytes();
    let len = text.len();
    let mut ranges = Vec::new();
    let mut seg_start = 0usize;
    let mut i = 0usize;

    while i < len {
        // Exclusive end of the current segment, if `i` sits on a boundary.
        let seg_end = if i + 1 < len
            && matches!(bytes[i], b'.' | b'!' | b'?')
            && bytes[i + 1] == b' '
        {
            Some(i + 1) // keep the punctuation, drop the space
        } else if i + 1 < len && bytes[i] == b'\n' && bytes[i + 1] == b'\n' {
            Some(i) // drop both newlines
        } else {
            None
        };

        if let Some(end) = seg_end {
            // Never emit empty segments (e.g. consecutive blank lines).
            if end > seg_start {
                ranges.push((seg_start, end));
            }
            // Resume at the first non-whitespace byte after the boundary.
            let mut next = end + 1;
            while next < len && bytes[next].is_ascii_whitespace() {
                next += 1;
            }
            seg_start = next;
            i = next;
            continue;
        }
        i += 1;
    }

    if seg_start < len {
        ranges.push((seg_start, len));
    }
    if ranges.is_empty() {
        ranges.push((0, len));
    }
    ranges
}
294
/// Errors produced by [`NliEngine`].
#[derive(Debug, thiserror::Error)]
pub enum NliError {
    /// The ONNX model or tokenizer file could not be loaded.
    #[error("failed to load NLI model: {0}")]
    ModelLoad(String),

    /// Tokenization or ONNX inference failed at runtime.
    #[error("NLI inference error: {0}")]
    Inference(String),
}