// ctxgraph_extract/nli.rs
//! Zero-shot relation classification via NLI cross-encoder.
//!
//! Uses a DeBERTa-v3-xsmall NLI model to score (premise, hypothesis) pairs.
//! For relation extraction, the premise is the source text and the hypothesis
//! is a natural language statement like "X depends on Y".
//!
//! Output labels: index 0 = contradiction, index 1 = entailment, index 2 = neutral.

use std::path::Path;

use ort::session::Session;
use ort::value::Tensor;
use tokenizers::Tokenizer;

use crate::ner::ExtractedEntity;
use crate::rel::ExtractedRelation;
use crate::schema::ExtractionSchema;

19/// Hypothesis templates for each relation type.
20/// Multiple templates per relation improve recall.
21const HYPOTHESIS_TEMPLATES: &[(&str, &[&str])] = &[
22    ("chose", &[
23        "{head} chose {tail}",
24        "{head} selected {tail}",
25    ]),
26    ("rejected", &[
27        "{head} rejected {tail}",
28        "{head} decided against {tail}",
29    ]),
30    ("replaced", &[
31        "{head} replaced {tail}",
32        "{tail} was replaced by {head}",
33    ]),
34    ("depends_on", &[
35        "{head} depends on {tail}",
36        "{head} uses {tail}",
37    ]),
38    ("fixed", &[
39        "{head} fixed {tail}",
40        "{head} resolved {tail}",
41    ]),
42    ("introduced", &[
43        "{head} introduced {tail}",
44        "{head} added {tail}",
45    ]),
46    ("deprecated", &[
47        "{head} deprecated {tail}",
48        "{head} removed {tail}",
49    ]),
50    ("caused", &[
51        "{head} caused {tail}",
52        "{head} led to {tail}",
53    ]),
54    ("constrained_by", &[
55        "{head} is constrained by {tail}",
56        "{head} must comply with {tail}",
57    ]),
58];
60/// NLI-based relation extraction engine.
61pub struct NliEngine {
62    session: Session,
63    tokenizer: Tokenizer,
64}
66/// Index of the "entailment" label in model output.
67const ENTAILMENT_IDX: usize = 1;
69impl NliEngine {
70    /// Load the NLI ONNX model and tokenizer.
71    pub fn new(model_path: &Path, tokenizer_path: &Path) -> Result<Self, NliError> {
72        let session = Session::builder()
73            .and_then(|b| b.with_intra_threads(1))
74            .and_then(|b| b.commit_from_file(model_path))
75            .map_err(|e| NliError::ModelLoad(e.to_string()))?;
76
77        let tokenizer = Tokenizer::from_file(tokenizer_path)
78            .map_err(|e| NliError::ModelLoad(e.to_string()))?;
79
80        Ok(Self { session, tokenizer })
81    }
82
83    /// Score a single (premise, hypothesis) pair.
84    /// Returns softmax probabilities [contradiction, entailment, neutral].
85    fn score(&self, premise: &str, hypothesis: &str) -> Result<[f32; 3], NliError> {
86        let encoding = self.tokenizer
87            .encode((premise, hypothesis), true)
88            .map_err(|e| NliError::Inference(e.to_string()))?;
89
90        let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
91        let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&m| m as i64).collect();
92
93        let seq_len = input_ids.len();
94
95        let ids_tensor = Tensor::from_array(([1, seq_len], input_ids))
96            .map_err(|e| NliError::Inference(e.to_string()))?;
97        let mask_tensor = Tensor::from_array(([1, seq_len], attention_mask))
98            .map_err(|e| NliError::Inference(e.to_string()))?;
99
100        let inputs = ort::inputs![ids_tensor, mask_tensor]
101            .map_err(|e| NliError::Inference(e.to_string()))?;
102
103        let outputs = self.session.run(inputs)
104            .map_err(|e| NliError::Inference(e.to_string()))?;
105
106        // Output shape: [1, 3] — logits for [contradiction, entailment, neutral]
107        let logits_view = outputs[0]
108            .try_extract_tensor::<f32>()
109            .map_err(|e| NliError::Inference(e.to_string()))?;
110
111        let logits = logits_view.as_slice()
112            .ok_or_else(|| NliError::Inference("non-contiguous logits".into()))?;
113
114        if logits.len() < 3 {
115            return Err(NliError::Inference(format!("expected 3 logits, got {}", logits.len())));
116        }
117
118        // Softmax
119        let max_logit = logits[0].max(logits[1]).max(logits[2]);
120        let exp: Vec<f32> = logits[..3].iter().map(|&l| (l - max_logit).exp()).collect();
121        let sum: f32 = exp.iter().sum();
122        Ok([exp[0] / sum, exp[1] / sum, exp[2] / sum])
123    }
124
125    /// Extract relations using NLI entailment scoring.
126    ///
127    /// For each entity pair in the text, tests hypothesis templates for each
128    /// relation type. Returns relations where entailment score exceeds the threshold.
129    pub fn extract(
130        &self,
131        text: &str,
132        entities: &[ExtractedEntity],
133        schema: &ExtractionSchema,
134        threshold: f32,
135    ) -> Result<Vec<ExtractedRelation>, NliError> {
136        let mut relations = Vec::new();
137        let mut seen = std::collections::HashSet::<(String, String, String)>::new();
138
139        // Split text into sentences for focused premises
140        let sentences = split_into_sentences(text);
141
142        for (sent_start, sent_end) in &sentences {
143            let premise = &text[*sent_start..*sent_end];
144
145            // Find entities in this sentence (+ adjacent sentence window)
146            let sent_entities: Vec<&ExtractedEntity> = entities
147                .iter()
148                .filter(|e| e.span_start >= *sent_start && e.span_start < *sent_end)
149                .collect();
150
151            if sent_entities.len() < 2 {
152                continue;
153            }
154
155            // Test all entity pairs
156            for (i, head) in sent_entities.iter().enumerate() {
157                for tail in sent_entities.iter().skip(i + 1) {
158                    if head.text == tail.text {
159                        continue;
160                    }
161
162                    // Test both directions for each relation
163                    for &(rel_name, templates) in HYPOTHESIS_TEMPLATES {
164                        // Check schema validity
165                        let schema_valid = schema.relation_types.get(rel_name)
166                            .map(|spec| {
167                                (spec.head.contains(&head.entity_type) && spec.tail.contains(&tail.entity_type))
168                                || (spec.head.contains(&tail.entity_type) && spec.tail.contains(&head.entity_type))
169                            })
170                            .unwrap_or(false);
171
172                        if !schema_valid {
173                            continue;
174                        }
175
176                        // Try head→tail direction
177                        let mut best_score_fwd: f32 = 0.0;
178                        for template in templates {
179                            let hypothesis = template
180                                .replace("{head}", &head.text)
181                                .replace("{tail}", &tail.text);
182                            if let Ok(probs) = self.score(premise, &hypothesis) {
183                                best_score_fwd = best_score_fwd.max(probs[ENTAILMENT_IDX]);
184                            }
185                        }
186
187                        // Try tail→head direction
188                        let mut best_score_rev: f32 = 0.0;
189                        for template in templates {
190                            let hypothesis = template
191                                .replace("{head}", &tail.text)
192                                .replace("{tail}", &head.text);
193                            if let Ok(probs) = self.score(premise, &hypothesis) {
194                                best_score_rev = best_score_rev.max(probs[ENTAILMENT_IDX]);
195                            }
196                        }
197
198                        // Pick the best direction
199                        let (actual_head, actual_tail, score) = if best_score_fwd >= best_score_rev {
200                            (&head.text, &tail.text, best_score_fwd)
201                        } else {
202                            (&tail.text, &head.text, best_score_rev)
203                        };
204
205                        if score >= threshold {
206                            let key = (actual_head.clone(), rel_name.to_string(), actual_tail.clone());
207                            if seen.insert(key) {
208                                relations.push(ExtractedRelation {
209                                    head: actual_head.clone(),
210                                    relation: rel_name.to_string(),
211                                    tail: actual_tail.clone(),
212                                    confidence: score as f64,
213                                });
214                            }
215                        }
216                    }
217                }
218            }
219        }
220
221        // Keep only top relation per entity pair (highest confidence)
222        deduplicate_by_pair(&mut relations);
223
224        Ok(relations)
225    }
226}
228/// Keep only the highest-confidence relation per (head, tail) pair.
229fn deduplicate_by_pair(relations: &mut Vec<ExtractedRelation>) {
230    let mut best: std::collections::HashMap<(String, String), usize> = std::collections::HashMap::new();
231
232    for (i, rel) in relations.iter().enumerate() {
233        let key = (rel.head.clone(), rel.tail.clone());
234        let rev_key = (rel.tail.clone(), rel.head.clone());
235        let existing_key = if best.contains_key(&key) { Some(key.clone()) }
236            else if best.contains_key(&rev_key) { Some(rev_key) }
237            else { None };
238
239        if let Some(k) = existing_key {
240            let prev_idx = best[&k];
241            if rel.confidence > relations[prev_idx].confidence {
242                best.insert(k, i);
243            }
244        } else {
245            best.insert(key, i);
246        }
247    }
248
249    let keep: std::collections::HashSet<usize> = best.values().copied().collect();
250    let mut idx = 0;
251    relations.retain(|_| {
252        let k = keep.contains(&idx);
253        idx += 1;
254        k
255    });
256}
258/// Simple sentence splitting (byte-level).
259fn split_into_sentences(text: &str) -> Vec<(usize, usize)> {
260    let mut ranges = Vec::new();
261    let bytes = text.as_bytes();
262    let len = text.len();
263    let mut seg_start = 0usize;
264    let mut i = 0usize;
265
266    while i < len {
267        let boundary = if i + 1 < len
268            && (bytes[i] == b'.' || bytes[i] == b'!' || bytes[i] == b'?')
269            && bytes[i + 1] == b' '
270        {
271            Some(i + 1)
272        } else if i + 1 < len && bytes[i] == b'\n' && bytes[i + 1] == b'\n' {
273            Some(i)
274        } else {
275            None
276        };
277
278        if let Some(end) = boundary {
279            ranges.push((seg_start, end));
280            seg_start = end + 1;
281            i = seg_start;
282            continue;
283        }
284        i += 1;
285    }
286    if seg_start < len {
287        ranges.push((seg_start, len));
288    }
289    if ranges.is_empty() {
290        ranges.push((0, len));
291    }
292    ranges
293}
295#[derive(Debug, thiserror::Error)]
296pub enum NliError {
297    #[error("failed to load NLI model: {0}")]
298    ModelLoad(String),
299
300    #[error("NLI inference error: {0}")]
301    Inference(String),
302}