// memvid_core/analysis/ner.rs
// Safe expect: Static NER model lookup with guaranteed default.
#![allow(clippy::expect_used)]
//! Named Entity Recognition (NER) module using DistilBERT-NER ONNX.
//!
//! This module provides entity extraction capabilities using DistilBERT-NER,
//! a fast and accurate NER model fine-tuned on CoNLL-03.
//!
//! # Model
//!
//! Uses dslim/distilbert-NER ONNX (~261 MB) with 92% F1 score.
//! Entities: Person (PER), Organization (ORG), Location (LOC), Miscellaneous (MISC)
//!
//! # Simple Interface
//!
//! Unlike `GLiNER`, DistilBERT-NER uses standard BERT tokenization:
//! - Input: `input_ids`, `attention_mask`
//! - Output: per-token logits for B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC, O

use crate::types::{EntityKind, FrameId};
use crate::{MemvidError, Result};
use std::path::{Path, PathBuf};

// ============================================================================
// Configuration Constants
// ============================================================================

/// Model name for downloads and caching; also the on-disk subdirectory name
/// under the models directory (see `ner_model_path` / `ner_tokenizer_path`).
pub const NER_MODEL_NAME: &str = "distilbert-ner";

/// Model download URL (`HuggingFace`)
pub const NER_MODEL_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/onnx/model.onnx";

/// Tokenizer URL
pub const NER_TOKENIZER_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/tokenizer.json";

/// Approximate model size in MB (informational)
pub const NER_MODEL_SIZE_MB: f32 = 261.0;

/// Maximum sequence length for the model; longer inputs are truncated by the
/// tokenizer configuration in `NerModel::load`.
pub const NER_MAX_SEQ_LEN: usize = 512;

/// Minimum confidence threshold for entity extraction; token predictions with
/// a softmax probability below this are treated as "O" (outside) by the
/// decoder.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_MIN_CONFIDENCE: f32 = 0.5;

/// NER label mapping (CoNLL-03 format)
/// O=0, B-PER=1, I-PER=2, B-ORG=3, I-ORG=4, B-LOC=5, I-LOC=6, B-MISC=7, I-MISC=8
///
/// NOTE(review): this index order must match the model's `id2label` mapping in
/// its config.json — verify against dslim/distilbert-NER, whose published
/// mapping may order MISC before PER; a mismatch silently mislabels entities.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_LABELS: &[&str] = &[
    "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
];

// ============================================================================
// Model Info
// ============================================================================

/// NER model info for the models registry
///
/// Describes one downloadable NER model: where to fetch its ONNX graph and
/// tokenizer, its approximate size, and its sequence-length limit.
#[derive(Debug, Clone)]
pub struct NerModelInfo {
    /// Model identifier (matches `NER_MODEL_NAME` for the bundled entry)
    pub name: &'static str,
    /// URL for ONNX model
    pub model_url: &'static str,
    /// URL for tokenizer JSON
    pub tokenizer_url: &'static str,
    /// Model size in MB
    pub size_mb: f32,
    /// Maximum sequence length
    pub max_seq_len: usize,
    /// Whether this is the default model
    pub is_default: bool,
}

76/// Available NER models registry
77pub static NER_MODELS: &[NerModelInfo] = &[NerModelInfo {
78    name: NER_MODEL_NAME,
79    model_url: NER_MODEL_URL,
80    tokenizer_url: NER_TOKENIZER_URL,
81    size_mb: NER_MODEL_SIZE_MB,
82    max_seq_len: NER_MAX_SEQ_LEN,
83    is_default: true,
84}];
85
86/// Get NER model info by name
87#[must_use]
88pub fn get_ner_model_info(name: &str) -> Option<&'static NerModelInfo> {
89    NER_MODELS.iter().find(|m| m.name == name)
90}
91
92/// Get default NER model info
93#[must_use]
94pub fn default_ner_model_info() -> &'static NerModelInfo {
95    NER_MODELS
96        .iter()
97        .find(|m| m.is_default)
98        .expect("default NER model must exist")
99}
100
// ============================================================================
// Entity Extraction Types
// ============================================================================

/// Raw entity mention extracted from text
#[derive(Debug, Clone)]
pub struct ExtractedEntity {
    /// The extracted text span (trimmed slice of the source text)
    pub text: String,
    /// Entity type (PER, ORG, LOC, MISC)
    pub entity_type: String,
    /// Confidence score (0.0-1.0); softmax probability of the predicted label
    pub confidence: f32,
    /// Byte offset start in original text
    pub byte_start: usize,
    /// Byte offset end in original text (exclusive)
    pub byte_end: usize,
}

120impl ExtractedEntity {
121    /// Convert the raw entity type to our `EntityKind` enum
122    #[must_use]
123    pub fn to_entity_kind(&self) -> EntityKind {
124        match self.entity_type.to_uppercase().as_str() {
125            "PER" | "PERSON" | "B-PER" | "I-PER" => EntityKind::Person,
126            "ORG" | "ORGANIZATION" | "B-ORG" | "I-ORG" => EntityKind::Organization,
127            "LOC" | "LOCATION" | "B-LOC" | "I-LOC" => EntityKind::Location,
128            "MISC" | "B-MISC" | "I-MISC" => EntityKind::Other,
129            _ => EntityKind::Other,
130        }
131    }
132}
133
/// Result of extracting entities from a frame
#[derive(Debug, Clone)]
pub struct FrameEntities {
    /// Frame ID the entities were extracted from
    pub frame_id: FrameId,
    /// Extracted entities (may be empty when nothing was detected)
    pub entities: Vec<ExtractedEntity>,
}

// ============================================================================
// NER Model (Feature-gated)
// ============================================================================

#[cfg(feature = "logic_mesh")]
pub use model_impl::*;

#[cfg(feature = "logic_mesh")]
mod model_impl {
    use super::*;
    use ort::session::{Session, builder::GraphOptimizationLevel};
    use ort::value::Tensor;
    use std::sync::Mutex;
    use tokenizers::{
        PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
        TruncationParams, TruncationStrategy,
    };

    /// DistilBERT-NER model for entity extraction
    ///
    /// Owns the ONNX Runtime session and the HuggingFace tokenizer.
    /// NOTE(review): all extraction entry points take `&mut self`, so the
    /// `Mutex` around the tokenizer may be vestigial — confirm no shared-`&self`
    /// caller exists before removing it.
    pub struct NerModel {
        /// ONNX runtime session
        session: Session,
        /// Tokenizer for text preprocessing
        tokenizer: Mutex<Tokenizer>,
        /// Model path for reference
        model_path: PathBuf,
        /// Minimum confidence threshold
        min_confidence: f32,
    }

    impl NerModel {
        /// Load NER model from path
        ///
        /// Configures the tokenizer with right-side `[PAD]` padding and
        /// truncation at `NER_MAX_SEQ_LEN`, then builds an ONNX Runtime
        /// session with level-3 graph optimization and 4 intra-op threads.
        ///
        /// # Arguments
        /// * `model_path` - Path to the ONNX model file
        /// * `tokenizer_path` - Path to the tokenizer.json file
        /// * `min_confidence` - Minimum confidence threshold (default: 0.5)
        ///
        /// # Errors
        /// Returns `MemvidError::NerModelNotAvailable` if either file fails to
        /// load or the session cannot be configured.
        pub fn load(
            model_path: impl AsRef<Path>,
            tokenizer_path: impl AsRef<Path>,
            min_confidence: Option<f32>,
        ) -> Result<Self> {
            let model_path = model_path.as_ref().to_path_buf();
            let tokenizer_path = tokenizer_path.as_ref();

            // Load tokenizer
            let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to load tokenizer from {:?}: {}", tokenizer_path, e)
                        .into(),
                }
            })?;

            // Configure padding and truncation.
            // Pad on the right with `[PAD]` (id 0) up to the longest sequence
            // in a batch.
            tokenizer.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                direction: PaddingDirection::Right,
                pad_to_multiple_of: None,
                pad_id: 0,
                pad_type_id: 0,
                pad_token: "[PAD]".to_string(),
            }));

            // Drop anything beyond the model's sequence-length window.
            tokenizer
                .with_truncation(Some(TruncationParams {
                    max_length: NER_MAX_SEQ_LEN,
                    strategy: TruncationStrategy::LongestFirst,
                    stride: 0,
                    direction: TruncationDirection::Right,
                }))
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set truncation: {}", e).into(),
                })?;

            // Initialize ONNX Runtime
            let session = Session::builder()
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create session builder: {}", e).into(),
                })?
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set optimization level: {}", e).into(),
                })?
                .with_intra_threads(4)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set threads: {}", e).into(),
                })?
                .commit_from_file(&model_path)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to load model from {:?}: {}", model_path, e).into(),
                })?;

            tracing::info!(
                model = %model_path.display(),
                "DistilBERT-NER model loaded"
            );

            Ok(Self {
                session,
                tokenizer: Mutex::new(tokenizer),
                model_path,
                min_confidence: min_confidence.unwrap_or(NER_MIN_CONFIDENCE),
            })
        }

        /// Extract entities from text
        ///
        /// Tokenizes `text` (adding special tokens), runs the ONNX session on
        /// `input_ids`/`attention_mask`, and decodes the per-token logits into
        /// entity spans. Whitespace-only input returns an empty list without
        /// touching the model.
        ///
        /// # Errors
        /// `MemvidError::Lock` if the tokenizer mutex is poisoned;
        /// `MemvidError::NerModelNotAvailable` for tokenization, tensor, or
        /// inference failures.
        pub fn extract(&mut self, text: &str) -> Result<Vec<ExtractedEntity>> {
            if text.trim().is_empty() {
                return Ok(Vec::new());
            }

            // Tokenize
            let tokenizer = self
                .tokenizer
                .lock()
                .map_err(|_| MemvidError::Lock("failed to lock tokenizer".into()))?;

            // `true` = add special tokens ([CLS]/[SEP]); the decoder skips them.
            let encoding =
                tokenizer
                    .encode(text, true)
                    .map_err(|e| MemvidError::NerModelNotAvailable {
                        reason: format!("tokenization failed: {}", e).into(),
                    })?;

            // Copy out everything needed after the lock is released.
            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|&x| x as i64)
                .collect();
            let tokens = encoding.get_tokens().to_vec();
            let offsets = encoding.get_offsets().to_vec();

            drop(tokenizer); // Release lock before inference

            let seq_len = input_ids.len();

            // Create input tensors using Tensor::from_array (batch size 1)
            let input_ids_array = ndarray::Array2::from_shape_vec((1, seq_len), input_ids)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create input_ids array: {}", e).into(),
                })?;

            let attention_mask_array =
                ndarray::Array2::from_shape_vec((1, seq_len), attention_mask).map_err(|e| {
                    MemvidError::NerModelNotAvailable {
                        reason: format!("failed to create attention_mask array: {}", e).into(),
                    }
                })?;

            let input_ids_tensor = Tensor::from_array(input_ids_array).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create input_ids tensor: {}", e).into(),
                }
            })?;

            let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create attention_mask tensor: {}", e).into(),
                }
            })?;

            // Get output name before inference (avoid borrow conflict);
            // fall back to the conventional "logits" name.
            let output_name = self
                .session
                .outputs
                .first()
                .map(|o| o.name.clone())
                .unwrap_or_else(|| "logits".into());

            // Run inference
            let outputs = self
                .session
                .run(ort::inputs![
                    "input_ids" => input_ids_tensor,
                    "attention_mask" => attention_mask_tensor,
                ])
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("inference failed: {}", e).into(),
                })?;

            let logits =
                outputs
                    .get(&output_name)
                    .ok_or_else(|| MemvidError::NerModelNotAvailable {
                        reason: format!("no output '{}' found", output_name).into(),
                    })?;

            // Parse logits to get predictions
            let entities = Self::decode_predictions_static(
                text,
                &tokens,
                &offsets,
                logits,
                self.min_confidence,
            )?;

            Ok(entities)
        }

343        /// Decode model predictions into entities (static to avoid borrow issues)
344        fn decode_predictions_static(
345            original_text: &str,
346            tokens: &[String],
347            offsets: &[(usize, usize)],
348            logits: &ort::value::Value,
349            min_confidence: f32,
350        ) -> Result<Vec<ExtractedEntity>> {
351            // Extract the logits tensor - shape: [1, seq_len, num_labels]
352            let (shape, data) = logits.try_extract_tensor::<f32>().map_err(|e| {
353                MemvidError::NerModelNotAvailable {
354                    reason: format!("failed to extract logits: {}", e).into(),
355                }
356            })?;
357
358            // Shape is iterable, convert to Vec
359            let shape_vec: Vec<i64> = shape.iter().copied().collect();
360
361            if shape_vec.len() != 3 {
362                return Err(MemvidError::NerModelNotAvailable {
363                    reason: format!("unexpected logits shape: {:?}", shape_vec).into(),
364                });
365            }
366
367            let seq_len = shape_vec[1] as usize;
368            let num_labels = shape_vec[2] as usize;
369
370            // Helper to index into flat data: [batch, seq, labels] -> flat index
371            let idx = |i: usize, j: usize| -> usize { i * num_labels + j };
372
373            let mut entities = Vec::new();
374            let mut current_entity: Option<(String, usize, usize, f32)> = None;
375
376            for i in 0..seq_len {
377                if i >= tokens.len() || i >= offsets.len() {
378                    break;
379                }
380
381                // Skip special tokens
382                let token = &tokens[i];
383                if token == "[CLS]" || token == "[SEP]" || token == "[PAD]" {
384                    // Finalize any current entity
385                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
386                        if end > start && end <= original_text.len() {
387                            let text = original_text[start..end].trim().to_string();
388                            if !text.is_empty() {
389                                entities.push(ExtractedEntity {
390                                    text,
391                                    entity_type,
392                                    confidence: conf,
393                                    byte_start: start,
394                                    byte_end: end,
395                                });
396                            }
397                        }
398                    }
399                    continue;
400                }
401
402                // Get prediction for this token
403                let mut max_score = f32::NEG_INFINITY;
404                let mut max_label = 0usize;
405
406                for j in 0..num_labels {
407                    let score = data[idx(i, j)];
408                    if score > max_score {
409                        max_score = score;
410                        max_label = j;
411                    }
412                }
413
414                // Apply softmax to get confidence
415                let mut exp_sum = 0.0f32;
416                for j in 0..num_labels {
417                    exp_sum += (data[idx(i, j)] - max_score).exp();
418                }
419                let confidence = 1.0 / exp_sum;
420
421                let label = NER_LABELS.get(max_label).unwrap_or(&"O");
422                let (start_offset, end_offset) = offsets[i];
423
424                if *label == "O" || confidence < min_confidence {
425                    // End any current entity
426                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
427                        if end > start && end <= original_text.len() {
428                            let text = original_text[start..end].trim().to_string();
429                            if !text.is_empty() {
430                                entities.push(ExtractedEntity {
431                                    text,
432                                    entity_type,
433                                    confidence: conf,
434                                    byte_start: start,
435                                    byte_end: end,
436                                });
437                            }
438                        }
439                    }
440                } else if label.starts_with("B-") {
441                    // Start new entity (end previous if any)
442                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
443                        if end > start && end <= original_text.len() {
444                            let text = original_text[start..end].trim().to_string();
445                            if !text.is_empty() {
446                                entities.push(ExtractedEntity {
447                                    text,
448                                    entity_type,
449                                    confidence: conf,
450                                    byte_start: start,
451                                    byte_end: end,
452                                });
453                            }
454                        }
455                    }
456                    let entity_type = label[2..].to_string(); // Remove "B-" prefix
457                    current_entity = Some((entity_type, start_offset, end_offset, confidence));
458                } else if label.starts_with("I-") {
459                    // Continue entity
460                    if let Some((ref entity_type, start, _, ref mut conf)) = current_entity {
461                        let expected_type = &label[2..];
462                        if entity_type == expected_type {
463                            current_entity = Some((
464                                entity_type.clone(),
465                                start,
466                                end_offset,
467                                (*conf + confidence) / 2.0,
468                            ));
469                        }
470                    }
471                }
472            }
473
474            // Finalize last entity
475            if let Some((entity_type, start, end, conf)) = current_entity {
476                if end > start && end <= original_text.len() {
477                    let text = original_text[start..end].trim().to_string();
478                    if !text.is_empty() {
479                        entities.push(ExtractedEntity {
480                            text,
481                            entity_type,
482                            confidence: conf,
483                            byte_start: start,
484                            byte_end: end,
485                        });
486                    }
487                }
488            }
489
490            Ok(entities)
491        }
492
493        /// Extract entities from a frame's content
494        pub fn extract_from_frame(
495            &mut self,
496            frame_id: FrameId,
497            content: &str,
498        ) -> Result<FrameEntities> {
499            let entities = self.extract(content)?;
500            Ok(FrameEntities { frame_id, entities })
501        }
502
        /// Get minimum confidence threshold
        ///
        /// Softmax probability below which token predictions are discarded
        /// during decoding.
        pub fn min_confidence(&self) -> f32 {
            self.min_confidence
        }

        /// Get model path
        ///
        /// Filesystem location of the ONNX file this session was loaded from.
        pub fn model_path(&self) -> &Path {
            &self.model_path
        }
    }

    // Manual `Debug` reporting only the path and threshold — presumably
    // because `Session`/`Tokenizer` lack `Debug`; confirm against the
    // ort/tokenizers crates before switching to a derive.
    impl std::fmt::Debug for NerModel {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("NerModel")
                .field("model_path", &self.model_path)
                .field("min_confidence", &self.min_confidence)
                .finish()
        }
    }
}

// ============================================================================
// Stub Implementation (when feature is disabled)
// ============================================================================

/// Stub `NerModel` compiled when the `logic_mesh` feature is off; every
/// method fails with `MemvidError::FeatureUnavailable`.
#[cfg(not(feature = "logic_mesh"))]
#[allow(dead_code)]
pub struct NerModel {
    // Zero-sized private field prevents construction outside this module.
    _private: (),
}

534#[cfg(not(feature = "logic_mesh"))]
535#[allow(dead_code)]
536impl NerModel {
537    pub fn load(
538        _model_path: impl AsRef<Path>,
539        _tokenizer_path: impl AsRef<Path>,
540        _min_confidence: Option<f32>,
541    ) -> Result<Self> {
542        Err(MemvidError::FeatureUnavailable {
543            feature: "logic_mesh",
544        })
545    }
546
547    pub fn extract(&self, _text: &str) -> Result<Vec<ExtractedEntity>> {
548        Err(MemvidError::FeatureUnavailable {
549            feature: "logic_mesh",
550        })
551    }
552
553    pub fn extract_from_frame(&self, _frame_id: FrameId, _content: &str) -> Result<FrameEntities> {
554        Err(MemvidError::FeatureUnavailable {
555            feature: "logic_mesh",
556        })
557    }
558}
559
// ============================================================================
// Model Path Utilities
// ============================================================================

564/// Get the expected path for the NER model in the models directory
565#[must_use]
566pub fn ner_model_path(models_dir: &Path) -> PathBuf {
567    models_dir.join(NER_MODEL_NAME).join("model.onnx")
568}
569
570/// Get the expected path for the NER tokenizer in the models directory
571#[must_use]
572pub fn ner_tokenizer_path(models_dir: &Path) -> PathBuf {
573    models_dir.join(NER_MODEL_NAME).join("tokenizer.json")
574}
575
576/// Check if NER model is installed
577#[must_use]
578pub fn is_ner_model_installed(models_dir: &Path) -> bool {
579    ner_model_path(models_dir).exists() && ner_tokenizer_path(models_dir).exists()
580}
581
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

590    #[test]
591    fn test_entity_kind_mapping() {
592        let cases = vec![
593            ("PER", EntityKind::Person),
594            ("B-PER", EntityKind::Person),
595            ("I-PER", EntityKind::Person),
596            ("ORG", EntityKind::Organization),
597            ("B-ORG", EntityKind::Organization),
598            ("LOC", EntityKind::Location),
599            ("B-LOC", EntityKind::Location),
600            ("MISC", EntityKind::Other),
601            ("B-MISC", EntityKind::Other),
602            ("unknown", EntityKind::Other),
603        ];
604
605        for (entity_type, expected_kind) in cases {
606            let entity = ExtractedEntity {
607                text: "test".to_string(),
608                entity_type: entity_type.to_string(),
609                confidence: 0.9,
610                byte_start: 0,
611                byte_end: 4,
612            };
613            assert_eq!(
614                entity.to_entity_kind(),
615                expected_kind,
616                "Failed for entity_type: {}",
617                entity_type
618            );
619        }
620    }
621
622    #[test]
623    fn test_model_info() {
624        let info = default_ner_model_info();
625        assert_eq!(info.name, NER_MODEL_NAME);
626        assert!(info.is_default);
627        assert!(info.size_mb > 200.0);
628    }
629
630    #[test]
631    fn test_model_paths() {
632        let models_dir = PathBuf::from("/tmp/models");
633        let model_path = ner_model_path(&models_dir);
634        let tokenizer_path = ner_tokenizer_path(&models_dir);
635
636        assert!(model_path.to_string_lossy().contains("model.onnx"));
637        assert!(tokenizer_path.to_string_lossy().contains("tokenizer.json"));
638    }
639
640    #[test]
641    fn test_ner_labels() {
642        assert_eq!(NER_LABELS.len(), 9);
643        assert_eq!(NER_LABELS[0], "O");
644        assert_eq!(NER_LABELS[1], "B-PER");
645        assert_eq!(NER_LABELS[3], "B-ORG");
646        assert_eq!(NER_LABELS[5], "B-LOC");
647        assert_eq!(NER_LABELS[7], "B-MISC");
648    }
649}