// ctxgraph_extract/relclf.rs

//! Relation classifier using sentence embeddings + logistic regression.
//!
//! Takes a 384-dim embedding (from all-MiniLM-L6-v2) as input and outputs
//! one of 10 relation classes. The ONNX model is a tiny linear classifier
//! (~15 KB) that runs a single matmul + bias.
//!
//! Entity markers `[E1]`/`[/E1]` and `[E2]`/`[/E2]` are inserted around
//! entities in the text before embedding.

use std::collections::HashMap;
use std::path::Path;

use ort::session::Session;
use ort::value::Tensor;

use crate::ner::ExtractedEntity;
use crate::rel::{ExtractedRelation, RelError};
18
/// Default confidence threshold for accepting a classification.
/// Predictions whose top softmax probability falls below this are dropped.
const DEFAULT_THRESHOLD: f32 = 0.5;

/// Number of output classes (including "none").
const NUM_CLASSES: usize = 10;

/// The label map: class index → relation name.
/// The order must match the class indices the ONNX model was trained with.
const LABEL_MAP: &[&str] = &[
    "chose",          // 0
    "rejected",       // 1
    "replaced",       // 2
    "depends_on",     // 3
    "fixed",          // 4
    "introduced",     // 5
    "deprecated",     // 6
    "caused",         // 7
    "constrained_by", // 8
    "none",           // 9
];

/// Index of the "none" class in the label map.
/// Must stay in sync with the position of "none" in `LABEL_MAP` above.
const NONE_CLASS_IDX: usize = 9;
41
/// Embedding-based relation classifier (logistic regression on MiniLM embeddings).
pub struct RelationClassifier {
    /// ONNX Runtime session holding the loaded linear-classifier model.
    session: Session,
    /// Class index → relation name; loaded from `label_map.json` next to the
    /// model if present, otherwise the built-in `LABEL_MAP`.
    label_map: Vec<String>,
    /// Minimum softmax probability required to accept a non-"none" prediction.
    threshold: f32,
}
48
impl RelationClassifier {
    /// Load the ONNX model from disk.
    ///
    /// The model must expose a single input named "embedding" and an output
    /// named "logits" (one score per class).
    ///
    /// NOTE(review): the original doc claimed an input shape of [1, 384],
    /// but `classify_pair` feeds a 3 × 384 = 1152-dim concatenation into
    /// `classify_embedding` — confirm the deployed model's actual input width.
    ///
    /// A `label_map.json` located next to the model file overrides the
    /// built-in `LABEL_MAP`; any load/parse failure silently falls back to
    /// the built-in map.
    pub fn new(model_path: &Path) -> Result<Self, RelError> {
        let session = Session::builder()
            .and_then(|b| b.with_intra_threads(1))
            .and_then(|b| b.commit_from_file(model_path))
            .map_err(|e| RelError::ModelLoad(e.to_string()))?;

        // Try loading label_map.json from the model's directory; fall back to
        // the compiled-in labels if the file is absent or malformed.
        let label_map = if let Some(parent) = model_path.parent() {
            let label_map_path = parent.join("label_map.json");
            load_label_map(&label_map_path).unwrap_or_else(default_label_map)
        } else {
            default_label_map()
        };

        Ok(Self {
            session,
            label_map,
            threshold: DEFAULT_THRESHOLD,
        })
    }

    /// Set the confidence threshold for accepting a classification
    /// (builder-style; consumes and returns `self`).
    pub fn with_threshold(mut self, threshold: f32) -> Self {
        self.threshold = threshold;
        self
    }

    /// Classify a relation from a pre-computed embedding.
    ///
    /// Returns `Some((relation_type, confidence))` if a relation is detected,
    /// or `None` if the "none" class wins or confidence is below threshold.
    ///
    /// NOTE(review): `NONE_CLASS_IDX` (9) is hard-coded; if a loaded
    /// `label_map.json` places "none" at a different index, this guard is
    /// wrong — confirm external label maps always keep "none" at index 9.
    pub fn classify_embedding(
        &self,
        embedding: &[f32],
    ) -> Result<Option<(String, f32)>, RelError> {
        // Shape is [1, len]; the runtime will reject a width mismatch against
        // the model's declared input.
        let dim = embedding.len();
        let tensor = Tensor::from_array(([1, dim], embedding.to_vec()))
            .map_err(|e| RelError::Inference(e.to_string()))?;

        let inputs = ort::inputs![
            "embedding" => tensor,
        ]
        .map_err(|e| RelError::Inference(e.to_string()))?;

        let outputs = self
            .session
            .run(inputs)
            .map_err(|e| RelError::Inference(e.to_string()))?;

        // The first output is assumed to be the logits tensor.
        let logits_view = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| RelError::Inference(e.to_string()))?;

        let logits = logits_view
            .as_slice()
            .ok_or_else(|| RelError::Inference("non-contiguous logits".into()))?;

        // Score only the classes we have labels for, in case the model emits
        // more logits than the label map covers (or vice versa).
        let num_classes = logits.len().min(self.label_map.len());
        if num_classes == 0 {
            return Err(RelError::Inference("empty logits output".into()));
        }

        // Softmax
        let probs = softmax(&logits[..num_classes]);

        // Find the top class
        let (best_idx, best_prob) = probs
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap(); // safe: num_classes > 0

        // Return None if "none" class wins or confidence is below threshold.
        // (The `best_idx >= self.label_map.len()` arm is unreachable here —
        // best_idx < num_classes <= label_map.len() — kept as a defensive guard.)
        if best_idx == NONE_CLASS_IDX || best_idx >= self.label_map.len() {
            return Ok(None);
        }
        if *best_prob < self.threshold {
            return Ok(None);
        }

        Ok(Some((self.label_map[best_idx].clone(), *best_prob)))
    }

    /// Classify all unique entity pairs using entity-pair embeddings.
    ///
    /// For each unordered pair, tries head→tail first and only falls back to
    /// tail→head when the first direction yields nothing. `embed_fn` takes
    /// text and returns a 384-dim vector; it is invoked three times per
    /// direction attempted (see `classify_pair`).
    pub fn classify_batch<F>(
        &self,
        text: &str,
        entities: &[ExtractedEntity],
        embed_fn: &F,
    ) -> Result<Vec<ExtractedRelation>, RelError>
    where
        F: Fn(&str) -> Result<Vec<f32>, RelError>,
    {
        let mut relations = Vec::new();
        // Deduplicate by unordered surface-text pair so (A, B) and (B, A)
        // are processed at most once across the whole entity list.
        let mut seen = std::collections::HashSet::<(String, String)>::new();

        for (i, head_ent) in entities.iter().enumerate() {
            for tail_ent in entities.iter().skip(i + 1) {
                // Skip self-pairs (identical surface text).
                if head_ent.text == tail_ent.text {
                    continue;
                }

                // Canonical (lexicographically ordered) key for the pair.
                let pair_key = if head_ent.text < tail_ent.text {
                    (head_ent.text.clone(), tail_ent.text.clone())
                } else {
                    (tail_ent.text.clone(), head_ent.text.clone())
                };

                if !seen.insert(pair_key) {
                    continue;
                }

                // Try head→tail
                if let Some((relation, confidence)) =
                    self.classify_pair(text, &head_ent.text, &tail_ent.text, embed_fn)?
                {
                    relations.push(ExtractedRelation {
                        head: head_ent.text.clone(),
                        relation,
                        tail: tail_ent.text.clone(),
                        confidence: confidence as f64,
                    });
                    // A hit in this direction suppresses the reverse check.
                    continue;
                }

                // Try tail→head
                if let Some((relation, confidence)) =
                    self.classify_pair(text, &tail_ent.text, &head_ent.text, embed_fn)?
                {
                    relations.push(ExtractedRelation {
                        head: tail_ent.text.clone(),
                        relation,
                        tail: head_ent.text.clone(),
                        confidence: confidence as f64,
                    });
                }
            }
        }

        Ok(relations)
    }

    /// Classify a single entity pair using 3-component entity-pair embeddings.
    ///
    /// Generates: head_ctx (text with [E1] only), tail_ctx (text with [E2] only),
    /// pair_ctx (just entity names). Concatenates the three embeddings
    /// (3 × 384 = 1152 values, assuming 384-dim vectors from `embed_fn`) as
    /// the input to `classify_embedding`.
    fn classify_pair<F>(
        &self,
        text: &str,
        head: &str,
        tail: &str,
        embed_fn: &F,
    ) -> Result<Option<(String, f32)>, RelError>
    where
        F: Fn(&str) -> Result<Vec<f32>, RelError>,
    {
        // head_ctx: text with only head marked
        let head_ctx_text = insert_single_marker(text, head, "[E1]", "[/E1]");
        let head_ctx = embed_fn(&head_ctx_text)?;

        // tail_ctx: text with only tail marked
        let tail_ctx_text = insert_single_marker(text, tail, "[E2]", "[/E2]");
        let tail_ctx = embed_fn(&tail_ctx_text)?;

        // pair_ctx: just the entity names
        let pair_text = format!("{} {}", head, tail);
        let pair_ctx = embed_fn(&pair_text)?;

        // Concatenate: [head_ctx || tail_ctx || pair_ctx]
        let mut combined = Vec::with_capacity(head_ctx.len() + tail_ctx.len() + pair_ctx.len());
        combined.extend_from_slice(&head_ctx);
        combined.extend_from_slice(&tail_ctx);
        combined.extend_from_slice(&pair_ctx);

        self.classify_embedding(&combined)
    }
}
235
/// Insert `[E1]`/`[/E1]` around the head entity and `[E2]`/`[/E2]` around
/// the tail entity in the text.
///
/// Matching is ASCII case-insensitive. An entity that cannot be located in
/// the text (or whose occurrence overlaps the other entity's span) has its
/// marked form prepended instead, so the output always contains both marker
/// pairs exactly once.
fn insert_entity_markers(text: &str, head: &str, tail: &str) -> String {
    // Case-insensitive find that returns a byte offset valid in `text`.
    //
    // Fix: the previous implementation searched in `text.to_lowercase()` and
    // used the resulting byte offsets to slice the ORIGINAL string.
    // `to_lowercase` can change byte lengths (e.g. 'İ' lowercases to a
    // 3-byte sequence), so those offsets could drift, silently mis-slice, or
    // panic on a non-char-boundary. Scanning `text`'s own char boundaries and
    // comparing fixed-width windows with `eq_ignore_ascii_case` keeps every
    // offset valid; `get` returns None for out-of-range or mid-char slices.
    fn find_ci(haystack: &str, needle: &str) -> Option<usize> {
        if needle.is_empty() {
            return Some(0);
        }
        haystack.char_indices().map(|(i, _)| i).find(|&i| {
            haystack
                .get(i..i + needle.len())
                .map_or(false, |window| window.eq_ignore_ascii_case(needle))
        })
    }

    let head_pos = find_ci(text, head);
    let tail_pos = find_ci(text, tail);

    match (head_pos, tail_pos) {
        (Some(hp), Some(tp)) if hp <= tp => {
            let head_end = hp + head.len();
            // Find tail strictly after the head span to avoid overlap.
            if let Some(rel_tp) = find_ci(&text[head_end..], tail) {
                let abs_tp = head_end + rel_tp;
                let tail_end = abs_tp + tail.len();
                format!(
                    "{}[E1]{}[/E1]{}[E2]{}[/E2]{}",
                    &text[..hp],
                    &text[hp..head_end],
                    &text[head_end..abs_tp],
                    &text[abs_tp..tail_end],
                    &text[tail_end..]
                )
            } else {
                // Tail overlaps with head or not found after it — prepend the
                // tail marker instead.
                format!(
                    "[E2]{}[/E2] {}[E1]{}[/E1]{}",
                    tail,
                    &text[..hp],
                    &text[hp..head_end],
                    &text[head_end..]
                )
            }
        }
        (Some(hp), Some(tp)) => {
            // Tail occurs before head.
            let tail_end = tp + tail.len();
            let head_end = hp + head.len();
            if tail_end <= hp {
                format!(
                    "{}[E2]{}[/E2]{}[E1]{}[/E1]{}",
                    &text[..tp],
                    &text[tp..tail_end],
                    &text[tail_end..hp],
                    &text[hp..head_end],
                    &text[head_end..]
                )
            } else {
                // Overlapping spans — prepend the head marker.
                format!(
                    "[E1]{}[/E1] {}[E2]{}[/E2]{}",
                    head,
                    &text[..tp],
                    &text[tp..tail_end],
                    &text[tail_end..]
                )
            }
        }
        (Some(hp), None) => {
            // Only the head is present; prepend the tail marker.
            let head_end = hp + head.len();
            format!(
                "[E2]{}[/E2] {}[E1]{}[/E1]{}",
                tail,
                &text[..hp],
                &text[hp..head_end],
                &text[head_end..]
            )
        }
        (None, Some(tp)) => {
            // Only the tail is present; prepend the head marker.
            let tail_end = tp + tail.len();
            format!(
                "[E1]{}[/E1] {}[E2]{}[/E2]{}",
                head,
                &text[..tp],
                &text[tp..tail_end],
                &text[tail_end..]
            )
        }
        (None, None) => {
            // Neither entity found — prepend both marked forms.
            format!("[E1]{}[/E1] [E2]{}[/E2] {}", head, tail, text)
        }
    }
}
322
/// Insert a single entity marker around the first occurrence of `entity` in
/// `text` (ASCII case-insensitive); if `entity` is absent, prepend its marked
/// form to the text instead.
fn insert_single_marker(text: &str, entity: &str, open: &str, close: &str) -> String {
    // Case-insensitive find that returns a byte offset valid in `text`.
    //
    // Fix: byte offsets obtained from `text.to_lowercase()` are not generally
    // valid in `text` itself (lowercasing can change byte lengths, e.g. 'İ'),
    // which could mis-slice or panic below. Scan `text`'s own char boundaries
    // instead; `get` safely rejects out-of-range or mid-char windows.
    fn find_ci(haystack: &str, needle: &str) -> Option<usize> {
        if needle.is_empty() {
            return Some(0);
        }
        haystack.char_indices().map(|(i, _)| i).find(|&i| {
            haystack
                .get(i..i + needle.len())
                .map_or(false, |window| window.eq_ignore_ascii_case(needle))
        })
    }

    if let Some(pos) = find_ci(text, entity) {
        let end = pos + entity.len();
        format!("{}{}{}{}{}", &text[..pos], open, &text[pos..end], close, &text[end..])
    } else {
        format!("{}{}{} {}", open, entity, close, text)
    }
}
335
/// Numerically stable softmax over a slice of f32 logits.
///
/// Subtracts the maximum logit before exponentiating so large inputs do not
/// overflow. An empty input yields an empty output.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let mut peak = f32::NEG_INFINITY;
    for &value in logits {
        if value > peak {
            peak = value;
        }
    }
    let mut scores: Vec<f32> = logits.iter().map(|&value| (value - peak).exp()).collect();
    let total: f32 = scores.iter().sum();
    for score in &mut scores {
        *score /= total;
    }
    scores
}
343
344/// Load label_map.json from disk: `{"0": "chose", "1": "rejected", ...}`.
345fn load_label_map(path: &Path) -> Option<Vec<String>> {
346    let content = std::fs::read_to_string(path).ok()?;
347    let map: HashMap<String, String> = serde_json::from_str(&content).ok()?;
348
349    let max_idx = map.keys().filter_map(|k| k.parse::<usize>().ok()).max()?;
350    let mut labels = vec!["none".to_string(); max_idx + 1];
351    for (k, v) in &map {
352        if let Ok(idx) = k.parse::<usize>() {
353            labels[idx] = v.clone();
354        }
355    }
356    Some(labels)
357}
358
359/// Return the built-in label map as a Vec<String>.
360fn default_label_map() -> Vec<String> {
361    LABEL_MAP.iter().map(|&s| s.to_string()).collect()
362}
363
364#[cfg(test)]
365mod tests {
366    use super::*;
367
368    #[test]
369    fn test_insert_markers_both_found() {
370        let text = "The team chose PostgreSQL over MySQL for the project.";
371        let result = insert_entity_markers(text, "PostgreSQL", "MySQL");
372        assert!(result.contains("[E1]PostgreSQL[/E1]"));
373        assert!(result.contains("[E2]MySQL[/E2]"));
374    }
375
376    #[test]
377    fn test_insert_markers_tail_before_head() {
378        let text = "MySQL was replaced by PostgreSQL.";
379        let result = insert_entity_markers(text, "PostgreSQL", "MySQL");
380        assert!(result.contains("[E1]PostgreSQL[/E1]"));
381        assert!(result.contains("[E2]MySQL[/E2]"));
382    }
383
384    #[test]
385    fn test_insert_markers_neither_found() {
386        let text = "Some unrelated text about databases.";
387        let result = insert_entity_markers(text, "Redis", "Kafka");
388        assert!(result.contains("[E1]Redis[/E1]"));
389        assert!(result.contains("[E2]Kafka[/E2]"));
390        assert!(result.contains("databases"));
391    }
392
393    #[test]
394    fn test_softmax() {
395        let probs = softmax(&[1.0, 2.0, 3.0]);
396        let sum: f32 = probs.iter().sum();
397        assert!((sum - 1.0).abs() < 1e-5);
398        assert!(probs[2] > probs[1]);
399        assert!(probs[1] > probs[0]);
400    }
401
402    #[test]
403    fn test_default_label_map() {
404        let map = default_label_map();
405        assert_eq!(map.len(), NUM_CLASSES);
406        assert_eq!(map[0], "chose");
407        assert_eq!(map[9], "none");
408    }
409}