1use std::collections::HashMap;
11use std::path::Path;
12
13use ort::session::Session;
14use ort::value::Tensor;
15
16use crate::ner::ExtractedEntity;
17use crate::rel::{ExtractedRelation, RelError};
18
/// Minimum softmax probability a predicted relation must reach to be
/// reported; overridable per-instance via `RelationClassifier::with_threshold`.
const DEFAULT_THRESHOLD: f32 = 0.5;

/// Number of relation classes the bundled model predicts.
/// Invariant: must equal `LABEL_MAP.len()`.
const NUM_CLASSES: usize = 10;

/// Compiled-in fallback class labels, indexed by model output position.
/// Used when no `label_map.json` sidecar file is found next to the model.
const LABEL_MAP: &[&str] = &[
    "chose", "rejected", "replaced", "depends_on", "fixed", "introduced", "deprecated", "caused", "constrained_by", "none",
];

/// Index of the "none" (no relation) class within `LABEL_MAP`.
/// Invariant: `LABEL_MAP[NONE_CLASS_IDX] == "none"`.
const NONE_CLASS_IDX: usize = 9;
41
/// ONNX-backed classifier that scores candidate relations between entity pairs.
pub struct RelationClassifier {
    // Loaded ONNX Runtime inference session for the classifier head.
    session: Session,
    // Class labels indexed by model output position; loaded from a
    // `label_map.json` sidecar when present, otherwise the LABEL_MAP defaults.
    label_map: Vec<String>,
    // Minimum softmax probability below which predictions are dropped.
    threshold: f32,
}
48
49impl RelationClassifier {
50 pub fn new(model_path: &Path) -> Result<Self, RelError> {
55 let session = Session::builder()
56 .and_then(|b| b.with_intra_threads(1))
57 .and_then(|b| b.commit_from_file(model_path))
58 .map_err(|e| RelError::ModelLoad(e.to_string()))?;
59
60 let label_map = if let Some(parent) = model_path.parent() {
62 let label_map_path = parent.join("label_map.json");
63 load_label_map(&label_map_path).unwrap_or_else(default_label_map)
64 } else {
65 default_label_map()
66 };
67
68 Ok(Self {
69 session,
70 label_map,
71 threshold: DEFAULT_THRESHOLD,
72 })
73 }
74
75 pub fn with_threshold(mut self, threshold: f32) -> Self {
77 self.threshold = threshold;
78 self
79 }
80
81 pub fn classify_embedding(
86 &self,
87 embedding: &[f32],
88 ) -> Result<Option<(String, f32)>, RelError> {
89 let dim = embedding.len();
90 let tensor = Tensor::from_array(([1, dim], embedding.to_vec()))
91 .map_err(|e| RelError::Inference(e.to_string()))?;
92
93 let inputs = ort::inputs![
94 "embedding" => tensor,
95 ]
96 .map_err(|e| RelError::Inference(e.to_string()))?;
97
98 let outputs = self
99 .session
100 .run(inputs)
101 .map_err(|e| RelError::Inference(e.to_string()))?;
102
103 let logits_view = outputs[0]
104 .try_extract_tensor::<f32>()
105 .map_err(|e| RelError::Inference(e.to_string()))?;
106
107 let logits = logits_view
108 .as_slice()
109 .ok_or_else(|| RelError::Inference("non-contiguous logits".into()))?;
110
111 let num_classes = logits.len().min(self.label_map.len());
112 if num_classes == 0 {
113 return Err(RelError::Inference("empty logits output".into()));
114 }
115
116 let probs = softmax(&logits[..num_classes]);
118
119 let (best_idx, best_prob) = probs
121 .iter()
122 .enumerate()
123 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
124 .unwrap(); if best_idx == NONE_CLASS_IDX || best_idx >= self.label_map.len() {
128 return Ok(None);
129 }
130 if *best_prob < self.threshold {
131 return Ok(None);
132 }
133
134 Ok(Some((self.label_map[best_idx].clone(), *best_prob)))
135 }
136
137 pub fn classify_batch<F>(
143 &self,
144 text: &str,
145 entities: &[ExtractedEntity],
146 embed_fn: &F,
147 ) -> Result<Vec<ExtractedRelation>, RelError>
148 where
149 F: Fn(&str) -> Result<Vec<f32>, RelError>,
150 {
151 let mut relations = Vec::new();
152 let mut seen = std::collections::HashSet::<(String, String)>::new();
153
154 for (i, head_ent) in entities.iter().enumerate() {
155 for tail_ent in entities.iter().skip(i + 1) {
156 if head_ent.text == tail_ent.text {
157 continue;
158 }
159
160 let pair_key = if head_ent.text < tail_ent.text {
161 (head_ent.text.clone(), tail_ent.text.clone())
162 } else {
163 (tail_ent.text.clone(), head_ent.text.clone())
164 };
165
166 if !seen.insert(pair_key) {
167 continue;
168 }
169
170 if let Some((relation, confidence)) =
172 self.classify_pair(text, &head_ent.text, &tail_ent.text, embed_fn)?
173 {
174 relations.push(ExtractedRelation {
175 head: head_ent.text.clone(),
176 relation,
177 tail: tail_ent.text.clone(),
178 confidence: confidence as f64,
179 });
180 continue;
181 }
182
183 if let Some((relation, confidence)) =
185 self.classify_pair(text, &tail_ent.text, &head_ent.text, embed_fn)?
186 {
187 relations.push(ExtractedRelation {
188 head: tail_ent.text.clone(),
189 relation,
190 tail: head_ent.text.clone(),
191 confidence: confidence as f64,
192 });
193 }
194 }
195 }
196
197 Ok(relations)
198 }
199
200 fn classify_pair<F>(
205 &self,
206 text: &str,
207 head: &str,
208 tail: &str,
209 embed_fn: &F,
210 ) -> Result<Option<(String, f32)>, RelError>
211 where
212 F: Fn(&str) -> Result<Vec<f32>, RelError>,
213 {
214 let head_ctx_text = insert_single_marker(text, head, "[E1]", "[/E1]");
216 let head_ctx = embed_fn(&head_ctx_text)?;
217
218 let tail_ctx_text = insert_single_marker(text, tail, "[E2]", "[/E2]");
220 let tail_ctx = embed_fn(&tail_ctx_text)?;
221
222 let pair_text = format!("{} {}", head, tail);
224 let pair_ctx = embed_fn(&pair_text)?;
225
226 let mut combined = Vec::with_capacity(head_ctx.len() + tail_ctx.len() + pair_ctx.len());
228 combined.extend_from_slice(&head_ctx);
229 combined.extend_from_slice(&tail_ctx);
230 combined.extend_from_slice(&pair_ctx);
231
232 self.classify_embedding(&combined)
233 }
234}
235
/// Wraps the head entity in `[E1]…[/E1]` and the tail entity in `[E2]…[/E2]`
/// inside `text`. Entities that cannot be located safely are prepended with
/// their markers instead, so the output always contains both markers.
///
/// FIX: the previous version applied byte offsets found in `to_lowercase()`
/// copies directly to the original string. Unicode case folds can change
/// byte lengths (e.g. 'İ' → "i\u{307}"), which garbled marker placement and
/// could panic on non-char-boundary slices. Offsets are now taken from an
/// exact match first, with a guarded case-insensitive fallback.
fn insert_entity_markers(text: &str, head: &str, tail: &str) -> String {
    // Case-insensitive find that only returns offsets safe to slice `haystack`.
    fn find_ci(haystack: &str, needle: &str) -> Option<(usize, usize)> {
        // Exact match: offsets are valid for `haystack` directly.
        if let Some(pos) = haystack.find(needle) {
            return Some((pos, pos + needle.len()));
        }
        let hay_lower = haystack.to_lowercase();
        let needle_lower = needle.to_lowercase();
        // Offsets computed on the lowercased copies are only transferable
        // when lowercasing preserved byte lengths.
        if hay_lower.len() != haystack.len() || needle_lower.len() != needle.len() {
            return None;
        }
        let pos = hay_lower.find(&needle_lower)?;
        let end = pos + needle.len();
        (haystack.is_char_boundary(pos) && haystack.is_char_boundary(end)).then_some((pos, end))
    }

    // Wraps the span `start..end` of `text` in the given marker pair.
    fn mark(text: &str, (start, end): (usize, usize), open: &str, close: &str) -> String {
        format!("{}{}{}{}{}", &text[..start], open, &text[start..end], close, &text[end..])
    }

    let head_span = find_ci(text, head);
    let tail_span = find_ci(text, tail);

    match (head_span, tail_span) {
        (Some((hp, he)), Some((tp, _))) if hp <= tp => {
            // Re-find the tail strictly after the head to avoid overlapping spans.
            match find_ci(&text[he..], tail) {
                Some((rp, re)) => {
                    let (abs_tp, abs_te) = (he + rp, he + re);
                    format!(
                        "{}[E1]{}[/E1]{}[E2]{}[/E2]{}",
                        &text[..hp],
                        &text[hp..he],
                        &text[he..abs_tp],
                        &text[abs_tp..abs_te],
                        &text[abs_te..]
                    )
                }
                // Tail only occurs inside/before the head span: prepend it.
                None => format!("[E2]{}[/E2] {}", tail, mark(text, (hp, he), "[E1]", "[/E1]")),
            }
        }
        (Some((hp, he)), Some((tp, te))) => {
            if te <= hp {
                // Tail strictly precedes head: wrap both in place.
                format!(
                    "{}[E2]{}[/E2]{}[E1]{}[/E1]{}",
                    &text[..tp],
                    &text[tp..te],
                    &text[te..hp],
                    &text[hp..he],
                    &text[he..]
                )
            } else {
                // Occurrences overlap: prepend the head, wrap only the tail.
                format!("[E1]{}[/E1] {}", head, mark(text, (tp, te), "[E2]", "[/E2]"))
            }
        }
        (Some(span), None) => format!("[E2]{}[/E2] {}", tail, mark(text, span, "[E1]", "[/E1]")),
        (None, Some(span)) => format!("[E1]{}[/E1] {}", head, mark(text, span, "[E2]", "[/E2]")),
        (None, None) => {
            format!("[E1]{}[/E1] [E2]{}[/E2] {}", head, tail, text)
        }
    }
}
322
/// Wraps the first occurrence of `entity` in `text` with `open`/`close`
/// markers; when the entity cannot be located safely, the marked entity is
/// prepended to the unchanged text instead.
///
/// FIX: the previous version applied byte offsets found in a `to_lowercase()`
/// copy directly to the original string. Unicode case folds can change byte
/// lengths (e.g. 'İ'), which garbled marker placement and could panic on
/// non-char-boundary slices. Exact match is tried first; the
/// case-insensitive fallback is guarded by length and boundary checks.
fn insert_single_marker(text: &str, entity: &str, open: &str, close: &str) -> String {
    // Exact match: byte offsets are valid for `text` directly.
    if let Some(pos) = text.find(entity) {
        let end = pos + entity.len();
        return format!("{}{}{}{}{}", &text[..pos], open, &text[pos..end], close, &text[end..]);
    }
    // Case-insensitive fallback; offsets from the lowercased copies are only
    // transferable when lowercasing preserved byte lengths.
    let text_lower = text.to_lowercase();
    let entity_lower = entity.to_lowercase();
    if text_lower.len() == text.len() && entity_lower.len() == entity.len() {
        if let Some(pos) = text_lower.find(&entity_lower) {
            let end = pos + entity.len();
            if text.is_char_boundary(pos) && text.is_char_boundary(end) {
                return format!(
                    "{}{}{}{}{}",
                    &text[..pos],
                    open,
                    &text[pos..end],
                    close,
                    &text[end..]
                );
            }
        }
    }
    // Entity not present: prepend it so the marker is always in the output.
    format!("{}{}{} {}", open, entity, close, text)
}
335
/// Converts raw logits into a probability distribution.
///
/// The running maximum is subtracted before exponentiating so large logits
/// cannot overflow `f32`.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut out: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
    let total: f32 = out.iter().sum();
    for p in &mut out {
        *p /= total;
    }
    out
}
343
344fn load_label_map(path: &Path) -> Option<Vec<String>> {
346 let content = std::fs::read_to_string(path).ok()?;
347 let map: HashMap<String, String> = serde_json::from_str(&content).ok()?;
348
349 let max_idx = map.keys().filter_map(|k| k.parse::<usize>().ok()).max()?;
350 let mut labels = vec!["none".to_string(); max_idx + 1];
351 for (k, v) in &map {
352 if let Ok(idx) = k.parse::<usize>() {
353 labels[idx] = v.clone();
354 }
355 }
356 Some(labels)
357}
358
359fn default_label_map() -> Vec<String> {
361 LABEL_MAP.iter().map(|&s| s.to_string()).collect()
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_markers_both_found() {
        // Both entities appear in order; both are wrapped in place.
        let marked = insert_entity_markers(
            "The team chose PostgreSQL over MySQL for the project.",
            "PostgreSQL",
            "MySQL",
        );
        assert!(marked.contains("[E1]PostgreSQL[/E1]"));
        assert!(marked.contains("[E2]MySQL[/E2]"));
    }

    #[test]
    fn test_insert_markers_tail_before_head() {
        // The tail precedes the head; markers must still land on the right entities.
        let marked =
            insert_entity_markers("MySQL was replaced by PostgreSQL.", "PostgreSQL", "MySQL");
        assert!(marked.contains("[E1]PostgreSQL[/E1]"));
        assert!(marked.contains("[E2]MySQL[/E2]"));
    }

    #[test]
    fn test_insert_markers_neither_found() {
        // Missing entities are prepended while the original text is preserved.
        let marked = insert_entity_markers("Some unrelated text about databases.", "Redis", "Kafka");
        assert!(marked.contains("[E1]Redis[/E1]"));
        assert!(marked.contains("[E2]Kafka[/E2]"));
        assert!(marked.contains("databases"));
    }

    #[test]
    fn test_softmax() {
        // Probabilities sum to 1 and preserve the ordering of the logits.
        let probs = softmax(&[1.0, 2.0, 3.0]);
        assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-5);
        assert!(probs[2] > probs[1]);
        assert!(probs[1] > probs[0]);
    }

    #[test]
    fn test_default_label_map() {
        // The fallback map mirrors LABEL_MAP exactly.
        let labels = default_label_map();
        assert_eq!(labels.len(), NUM_CLASSES);
        assert_eq!(labels[0], "chose");
        assert_eq!(labels[9], "none");
    }
}