1use std::path::Path;
2
3use gliner::model::input::text::TextInput;
4use gliner::model::params::Parameters;
5use gliner::model::pipeline::span::SpanMode;
6use gliner::model::GLiNER;
7use orp::params::RuntimeParameters;
8
/// A single entity span recognised by [`NerEngine::extract`].
///
/// Derives `PartialEq` so results can be compared and asserted on directly.
/// `Eq` is not derived because `confidence` is a floating-point value.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtractedEntity {
    /// Surface text of the span, as returned by the model.
    pub text: String,
    /// Entity type: the model's predicted label, optionally remapped through the
    /// caller-supplied label map passed to [`NerEngine::extract`].
    pub entity_type: String,
    /// Start offset of the span in the input text, as reported by GLiNER's
    /// `Span::offsets`. NOTE(review): presumably a byte offset — confirm before
    /// using it to slice the input string.
    pub span_start: usize,
    /// End offset of the span, as reported by GLiNER's `Span::offsets`.
    /// NOTE(review): presumably exclusive — confirm.
    pub span_end: usize,
    /// Model probability for this span, converted to `f64`.
    pub confidence: f64,
}
18
/// Named-entity recognition engine backed by a GLiNER model running in span mode.
///
/// Construct with [`NerEngine::new`] and query with [`NerEngine::extract`].
pub struct NerEngine {
    /// The loaded GLiNER span-mode model used for every inference call.
    model: GLiNER<SpanMode>,
}
25
26impl NerEngine {
27 pub fn new(model_path: &Path, tokenizer_path: &Path, threshold: f32) -> Result<Self, NerError> {
35 let params = Parameters::default().with_threshold(threshold);
36 let runtime_params = RuntimeParameters::default();
37
38 let model = GLiNER::<SpanMode>::new(
39 params,
40 runtime_params,
41 tokenizer_path.to_str().ok_or(NerError::InvalidPath(
42 tokenizer_path.display().to_string(),
43 ))?,
44 model_path
45 .to_str()
46 .ok_or(NerError::InvalidPath(model_path.display().to_string()))?,
47 )
48 .map_err(|e| NerError::ModelLoad(e.to_string()))?;
49
50 Ok(Self { model })
51 }
52
53 pub fn extract(
60 &self,
61 text: &str,
62 labels: &[&str],
63 label_to_type: Option<&std::collections::HashMap<&str, &str>>,
64 ) -> Result<Vec<ExtractedEntity>, NerError> {
65 let input = TextInput::from_str(&[text], labels)
66 .map_err(|e| NerError::Inference(e.to_string()))?;
67
68 let output = self
69 .model
70 .inference(input)
71 .map_err(|e| NerError::Inference(e.to_string()))?;
72
73 let mut entities = Vec::new();
74
75 for sequence_spans in &output.spans {
77 for span in sequence_spans {
78 let (start, end) = span.offsets();
81 let span_text = span.text().to_string();
82 let raw_class = span.class();
83 let entity_type = match label_to_type {
84 Some(map) => map.get(raw_class).copied().unwrap_or(raw_class),
85 None => raw_class,
86 }
87 .to_string();
88
89 entities.push(ExtractedEntity {
90 text: span_text,
91 entity_type,
92 span_start: start,
93 span_end: end,
94 confidence: span.probability() as f64,
95 });
96 }
97 }
98
99 Ok(entities)
100 }
101}
102
103#[derive(Debug, thiserror::Error)]
104pub enum NerError {
105 #[error("invalid path: {0}")]
106 InvalidPath(String),
107
108 #[error("failed to load model: {0}")]
109 ModelLoad(String),
110
111 #[error("inference error: {0}")]
112 Inference(String),
113}