// anno/backends/extractor.rs
use crate::{Entity, EntityType, Model, RegexNER, Result};
use std::sync::Arc;
23
/// Default model identifiers passed to the backend constructors in this file.
///
/// NOTE(review): the `<owner>/<model>` shape suggests Hugging Face Hub
/// repository ids — confirm against the backend loaders.
pub mod defaults {
    // --- BERT-style models -------------------------------------------------

    /// Model id used for the BERT ONNX backend.
    pub const BERT_ONNX: &str = "protectai/bert-base-NER-onnx";

    /// Model id used for the Candle BERT backend.
    pub const CANDLE_BERT: &str = "dslim/bert-base-NER";

    // --- GLiNER models, smallest to largest --------------------------------

    /// Small GLiNER model id.
    pub const GLINER_SMALL: &str = "onnx-community/gliner_small-v2.1";

    /// Medium GLiNER model id.
    pub const GLINER_MEDIUM: &str = "onnx-community/gliner_medium-v2.1";

    /// Large GLiNER model id.
    pub const GLINER_LARGE: &str = "onnx-community/gliner_large-v2.1";

    /// Multitask GLiNER model id.
    pub const GLINER_MULTITASK: &str = "onnx-community/gliner-multitask-large-v0.5";
}
44
/// Lower-case labels for the standard entity categories, in a fixed order.
///
/// NOTE(review): presumably the label set handed to zero-shot backends —
/// confirm against the callers, which are not visible in this file.
#[rustfmt::skip]
pub const STANDARD_ENTITY_TYPES: &[&str] = &[
    "person", "organization", "location",
    "date", "money", "percent",
];
54
/// Identifies which extraction backend an extractor was built around.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum BackendType {
    /// GLiNER model; the only backend reported as zero-shot capable.
    GLiNER,
    /// BERT model executed through ONNX.
    BertOnnx,
    /// BERT model executed through Candle.
    Candle,
    /// Regex patterns only, no ML model.
    Pattern,
}

impl BackendType {
    /// Returns the short, lower-case identifier of this backend.
    #[must_use]
    pub fn name(&self) -> &'static str {
        match self {
            Self::Pattern => "pattern",
            Self::Candle => "candle",
            Self::BertOnnx => "bert-onnx",
            Self::GLiNER => "gliner",
        }
    }

    /// Returns `true` for every backend except [`BackendType::Pattern`].
    ///
    /// NOTE(review): presumably because the ML backends fetch model files —
    /// confirm against the backend loaders.
    #[must_use]
    pub fn requires_network(&self) -> bool {
        match self {
            Self::Pattern => false,
            Self::GLiNER | Self::BertOnnx | Self::Candle => true,
        }
    }

    /// Returns `true` only for [`BackendType::GLiNER`].
    #[must_use]
    pub fn supports_zero_shot(&self) -> bool {
        *self == Self::GLiNER
    }
}
93
94pub struct NERExtractor {
116 primary: Option<Arc<dyn Model>>,
118 fallback: Arc<RegexNER>,
120 backend_type: BackendType,
122}
123
124impl NERExtractor {
125 pub fn new(primary: Option<Arc<dyn Model>>, backend_type: BackendType) -> Self {
127 Self {
128 primary,
129 fallback: Arc::new(RegexNER::new()),
130 backend_type,
131 }
132 }
133
134 #[must_use]
139 pub fn pattern_only() -> Self {
140 Self {
141 primary: None,
142 fallback: Arc::new(RegexNER::new()),
143 backend_type: BackendType::Pattern,
144 }
145 }
146
147 #[must_use]
155 pub fn best_available() -> Self {
156 #[cfg(feature = "onnx")]
158 {
159 if let Ok(extractor) = Self::with_gliner(defaults::GLINER_SMALL) {
160 log::info!("[NER] Using GLiNER Small (zero-shot)");
161 return extractor;
162 }
163 log::warn!("[NER] GLiNER init failed, trying BERT ONNX");
164
165 if let Ok(extractor) = Self::with_bert_onnx(defaults::BERT_ONNX) {
167 log::info!("[NER] Using BERT ONNX");
168 return extractor;
169 }
170 log::warn!("[NER] BERT ONNX init failed, trying Candle");
171 }
172
173 #[cfg(feature = "candle")]
175 {
176 if let Ok(extractor) = Self::with_candle(defaults::CANDLE_BERT) {
177 log::info!("[NER] Using Candle");
178 return extractor;
179 }
180 log::warn!("[NER] Candle init failed, falling back to patterns");
181 }
182
183 log::info!("[NER] Using RegexNER (structured entities only)");
185 Self::pattern_only()
186 }
187
188 #[must_use]
194 pub fn fast() -> Self {
195 #[cfg(feature = "onnx")]
196 {
197 if let Ok(extractor) = Self::with_gliner(defaults::GLINER_SMALL) {
198 log::info!("[NER] Using GLiNER Small (fast mode)");
199 return extractor;
200 }
201 }
202 log::info!("[NER] Using RegexNER (structured entities only)");
203 Self::pattern_only()
204 }
205
206 #[must_use]
214 pub fn best_quality() -> Self {
215 #[cfg(feature = "onnx")]
216 {
217 if let Ok(extractor) = Self::with_gliner(defaults::GLINER_LARGE) {
218 log::info!("[NER] Using GLiNER Large (best quality)");
219 return extractor;
220 }
221 if let Ok(extractor) = Self::with_gliner(defaults::GLINER_MEDIUM) {
222 log::info!("[NER] Using GLiNER Medium");
223 return extractor;
224 }
225 if let Ok(extractor) = Self::with_bert_onnx(defaults::BERT_ONNX) {
226 log::info!("[NER] Using BERT ONNX");
227 return extractor;
228 }
229 }
230 log::info!("[NER] Using RegexNER (structured entities only)");
231 Self::pattern_only()
232 }
233
234 #[cfg(feature = "onnx")]
242 pub fn with_bert_onnx(model_name: &str) -> Result<Self> {
243 let bert = crate::backends::BertNEROnnx::new(model_name)?;
244 Ok(Self {
245 primary: Some(Arc::new(bert)),
246 fallback: Arc::new(RegexNER::new()),
247 backend_type: BackendType::BertOnnx,
248 })
249 }
250
251 #[cfg(not(feature = "onnx"))]
253 pub fn with_bert_onnx(_model_name: &str) -> Result<Self> {
254 Ok(Self::pattern_only())
255 }
256
257 #[cfg(feature = "onnx")]
265 pub fn with_gliner(model_name: &str) -> Result<Self> {
266 let gliner = crate::backends::GLiNEROnnx::new(model_name)?;
267 Ok(Self {
268 primary: Some(Arc::new(gliner)),
269 fallback: Arc::new(RegexNER::new()),
270 backend_type: BackendType::GLiNER,
271 })
272 }
273
274 #[cfg(not(feature = "onnx"))]
276 pub fn with_gliner(_model_name: &str) -> Result<Self> {
277 Ok(Self::pattern_only())
278 }
279
280 #[cfg(feature = "candle")]
287 pub fn with_candle(model_name: &str) -> Result<Self> {
288 let candle = crate::backends::CandleNER::new(model_name)?;
289 Ok(Self {
290 primary: Some(Arc::new(candle)),
291 fallback: Arc::new(RegexNER::new()),
292 backend_type: BackendType::Candle,
293 })
294 }
295
296 #[cfg(not(feature = "candle"))]
298 pub fn with_candle(_model_name: &str) -> Result<Self> {
299 Ok(Self::pattern_only())
300 }
301
302 pub fn extract(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
306 if let Some(ref primary) = self.primary {
308 if primary.is_available() {
309 match primary.extract_entities(text, language) {
310 Ok(entities) if !entities.is_empty() => return Ok(entities),
311 Ok(_) => {
312 log::debug!("[NER] Primary returned empty, using fallback");
313 }
314 Err(e) => {
315 log::debug!("[NER] Primary failed ({}), using fallback", e);
316 }
317 }
318 }
319 }
320
321 self.fallback.extract_entities(text, language)
323 }
324
325 pub fn extract_hybrid(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
335 let mut entities = Vec::with_capacity(16);
337
338 if let Some(ref primary) = self.primary {
340 if primary.is_available() {
341 if let Ok(ml_entities) = primary.extract_entities(text, language) {
342 entities.extend(
344 ml_entities
345 .into_iter()
346 .filter(|e| e.entity_type.requires_ml()),
347 );
348 }
349 }
350 }
351
352 if let Ok(pattern_entities) = self.fallback.extract_entities(text, language) {
354 for pe in pattern_entities {
356 let overlaps = entities.iter().any(|e| {
357 !(e.end <= pe.start || pe.end <= e.start)
359 });
360 if !overlaps {
361 entities.push(pe);
362 }
363 }
364 }
365
366 entities.sort_unstable_by_key(|e| e.start);
369
370 Ok(entities)
371 }
372
373 #[must_use]
375 pub fn backend_type(&self) -> BackendType {
376 self.backend_type
377 }
378
379 #[must_use]
381 pub fn active_backend_name(&self) -> &'static str {
382 if let Some(ref primary) = self.primary {
383 if primary.is_available() {
384 return primary.name();
385 }
386 }
387 self.fallback.name()
388 }
389
390 #[must_use]
392 pub fn has_ml_backend(&self) -> bool {
393 self.primary.as_ref().is_some_and(|p| p.is_available())
394 }
395
396 #[must_use]
398 pub fn supports_zero_shot(&self) -> bool {
399 self.backend_type.supports_zero_shot()
400 }
401}
402
403impl Model for NERExtractor {
405 fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
406 self.extract(text, language)
407 }
408
409 fn supported_types(&self) -> Vec<EntityType> {
410 if let Some(ref primary) = self.primary {
411 if primary.is_available() {
412 return primary.supported_types();
413 }
414 }
415 self.fallback.supported_types()
416 }
417
418 fn is_available(&self) -> bool {
419 true }
421
422 fn name(&self) -> &'static str {
423 self.active_backend_name()
424 }
425
426 fn description(&self) -> &'static str {
427 match self.backend_type {
428 BackendType::GLiNER => "GLiNER zero-shot NER (ONNX/Candle backends)",
429 BackendType::BertOnnx => "BERT NER via ONNX Runtime",
430 BackendType::Candle => "BERT NER via Candle (Rust-native)",
431 BackendType::Pattern => "Regex-based NER (structured entities only)",
432 }
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;

    // A pattern-only extractor reports the Pattern backend and no ML features.
    #[test]
    fn test_pattern_only() {
        let ner = NERExtractor::pattern_only();

        assert_eq!(ner.backend_type(), BackendType::Pattern);
        assert!(!ner.has_ml_backend());
        assert!(!ner.supports_zero_shot());
    }

    // Whatever backend gets selected, structured entities must still be found.
    #[test]
    fn test_best_available_always_works() {
        let ner = NERExtractor::best_available();
        assert!(ner.is_available());

        let found = ner
            .extract("Meeting on 2024-01-15 cost $100.", None)
            .unwrap();
        let has_date_or_money = found
            .iter()
            .any(|e| matches!(e.entity_type, EntityType::Date | EntityType::Money));
        assert!(has_date_or_money, "Should find pattern entities");
    }

    // Table of (backend, requires_network, supports_zero_shot).
    #[test]
    fn test_backend_type_properties() {
        let expectations = [
            (BackendType::GLiNER, true, true),
            (BackendType::BertOnnx, true, false),
            (BackendType::Candle, true, false),
            (BackendType::Pattern, false, false),
        ];

        for &(backend, needs_network, _) in &expectations {
            assert_eq!(
                backend.requires_network(),
                needs_network,
                "requires_network mismatch for {:?}",
                backend
            );
        }
        for &(backend, _, zero_shot) in &expectations {
            assert_eq!(
                backend.supports_zero_shot(),
                zero_shot,
                "supports_zero_shot mismatch for {:?}",
                backend
            );
        }
    }

    // Hybrid extraction with no ML model still yields the regex entities.
    #[test]
    fn test_extract_hybrid() {
        let ner = NERExtractor::pattern_only();
        let found = ner
            .extract_hybrid("Meeting at 3:30 PM cost $50.", None)
            .unwrap();

        assert!(!found.is_empty());
    }
}