Skip to main content

ctxgraph_extract/
ner.rs

1use std::path::Path;
2
3use gliner::model::input::text::TextInput;
4use gliner::model::params::Parameters;
5use gliner::model::pipeline::span::SpanMode;
6use gliner::model::GLiNER;
7use orp::params::RuntimeParameters;
8
/// An entity extracted from text by the NER model.
///
/// Derives `PartialEq` so extraction results can be compared directly
/// (e.g. in tests or when de-duplicating); `Eq` is not derivable because
/// `confidence` is a float.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    /// Surface text of the matched span, as reported by the model.
    pub text: String,
    /// Entity type: the raw GLiNER label, or its canonical key when a
    /// label-to-type mapping is supplied to `NerEngine::extract`.
    pub entity_type: String,
    /// Start offset of the span within the input text.
    /// NOTE(review): confirm whether gline-rs reports byte or char offsets
    /// before using this to slice the original `&str`.
    pub span_start: usize,
    /// End offset of the span within the input text (presumably exclusive —
    /// TODO confirm against gline-rs `Span::offsets`).
    pub span_end: usize,
    /// Span probability reported by the model (0.0–1.0), widened from `f32`.
    pub confidence: f64,
}
18
19/// NER engine wrapping gline-rs GLiNER in span mode.
20///
21/// Uses `onnx-community/gliner_large-v2.1` (or any span-based GLiNER ONNX model).
22pub struct NerEngine {
23    model: GLiNER<SpanMode>,
24}
25
26impl NerEngine {
27    /// Create a new NER engine from model and tokenizer paths.
28    ///
29    /// - `model_path`: path to `model.onnx` (or `model_int8.onnx`)
30    /// - `tokenizer_path`: path to `tokenizer.json`
31    /// - `threshold`: minimum GLiNER span probability (0.0–1.0). `Parameters::default`
32    ///   uses 0.5 which is too aggressive for domain-specific labels; use a lower
33    ///   value like 0.1–0.3 and let the pipeline apply its own threshold on top.
34    pub fn new(model_path: &Path, tokenizer_path: &Path, threshold: f32) -> Result<Self, NerError> {
35        let params = Parameters::default().with_threshold(threshold);
36        let runtime_params = RuntimeParameters::default();
37
38        let model = GLiNER::<SpanMode>::new(
39            params,
40            runtime_params,
41            tokenizer_path.to_str().ok_or(NerError::InvalidPath(
42                tokenizer_path.display().to_string(),
43            ))?,
44            model_path
45                .to_str()
46                .ok_or(NerError::InvalidPath(model_path.display().to_string()))?,
47        )
48        .map_err(|e| NerError::ModelLoad(e.to_string()))?;
49
50        Ok(Self { model })
51    }
52
53    /// Extract entities from text using the given labels.
54    ///
55    /// `label_to_type` is an optional mapping from GLiNER label string → canonical
56    /// entity type key. Pass `None` to use the label string as-is for `entity_type`.
57    /// Pass `Some(pairs)` when using natural-language descriptions as labels so the
58    /// returned `entity_type` is the short canonical key (e.g. "Database").
59    pub fn extract(
60        &self,
61        text: &str,
62        labels: &[&str],
63        label_to_type: Option<&std::collections::HashMap<&str, &str>>,
64    ) -> Result<Vec<ExtractedEntity>, NerError> {
65        let input = TextInput::from_str(&[text], labels)
66            .map_err(|e| NerError::Inference(e.to_string()))?;
67
68        let output = self
69            .model
70            .inference(input)
71            .map_err(|e| NerError::Inference(e.to_string()))?;
72
73        let mut entities = Vec::new();
74
75        // output.spans is Vec<Vec<Span>> — outer vec is per-sequence
76        for sequence_spans in &output.spans {
77            for span in sequence_spans {
78                // Use character byte offsets from the span directly — avoids the
79                // `text.find()` pitfall that always returns the first occurrence.
80                let (start, end) = span.offsets();
81                let span_text = span.text().to_string();
82                let raw_class = span.class();
83                let entity_type = match label_to_type {
84                    Some(map) => map.get(raw_class).copied().unwrap_or(raw_class),
85                    None => raw_class,
86                }
87                .to_string();
88
89                entities.push(ExtractedEntity {
90                    text: span_text,
91                    entity_type,
92                    span_start: start,
93                    span_end: end,
94                    confidence: span.probability() as f64,
95                });
96            }
97        }
98
99        Ok(entities)
100    }
101}
102
103#[derive(Debug, thiserror::Error)]
104pub enum NerError {
105    #[error("invalid path: {0}")]
106    InvalidPath(String),
107
108    #[error("failed to load model: {0}")]
109    ModelLoad(String),
110
111    #[error("inference error: {0}")]
112    Inference(String),
113}