memvid_core/analysis/
ner.rs

1//! Named Entity Recognition (NER) module using DistilBERT-NER ONNX.
2//!
3//! This module provides entity extraction capabilities using DistilBERT-NER,
4//! a fast and accurate NER model fine-tuned on CoNLL-03.
5//!
6//! # Model
7//!
8//! Uses dslim/distilbert-NER ONNX (~261 MB) with 92% F1 score.
9//! Entities: Person (PER), Organization (ORG), Location (LOC), Miscellaneous (MISC)
10//!
11//! # Simple Interface
12//!
13//! Unlike GLiNER, DistilBERT-NER uses standard BERT tokenization:
14//! - Input: `input_ids`, `attention_mask`
15//! - Output: per-token logits for B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC, O
16
17use crate::types::{EntityKind, FrameId};
18use crate::{MemvidError, Result};
19use std::path::{Path, PathBuf};
20
21// ============================================================================
22// Configuration Constants
23// ============================================================================
24
/// Model name for downloads and caching.
///
/// Also used as the sub-directory name under the models directory by
/// `ner_model_path` and `ner_tokenizer_path`.
pub const NER_MODEL_NAME: &str = "distilbert-ner";

/// Model download URL (HuggingFace) for the ONNX weights.
pub const NER_MODEL_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/onnx/model.onnx";

/// Download URL (HuggingFace) for the matching `tokenizer.json`.
pub const NER_TOKENIZER_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/tokenizer.json";

/// Approximate model size in MB.
pub const NER_MODEL_SIZE_MB: f32 = 261.0;

/// Maximum sequence length for the model, in tokens.
///
/// The feature-gated loader passes this value as the tokenizer's truncation
/// `max_length`.
pub const NER_MAX_SEQ_LEN: usize = 512;

/// Minimum confidence threshold for entity extraction.
///
/// Used as the default when `NerModel::load` is called with
/// `min_confidence: None`.
// Only the `logic_mesh` implementation reads this constant, so silence the
// dead-code lint when that feature is disabled.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_MIN_CONFIDENCE: f32 = 0.5;

/// NER label mapping (CoNLL-03 format), indexed by the model's class index.
///
/// O=0, B-PER=1, I-PER=2, B-ORG=3, I-ORG=4, B-LOC=5, I-LOC=6, B-MISC=7, I-MISC=8
// Outside of tests, only the `logic_mesh` implementation reads this table.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_LABELS: &[&str] = &[
    "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
];
52
53// ============================================================================
54// Model Info
55// ============================================================================
56
/// NER model info for the models registry.
///
/// Every field is either `'static` or a plain value, so entries can live in
/// the `NER_MODELS` static table.
#[derive(Debug, Clone)]
pub struct NerModelInfo {
    /// Model identifier; the key matched by `get_ner_model_info`.
    pub name: &'static str,
    /// URL for the ONNX model file.
    pub model_url: &'static str,
    /// URL for the tokenizer JSON file.
    pub tokenizer_url: &'static str,
    /// Model size in MB.
    pub size_mb: f32,
    /// Maximum sequence length.
    pub max_seq_len: usize,
    /// Whether this is the default model (the entry returned by
    /// `default_ner_model_info`).
    pub is_default: bool,
}
73
/// Available NER models registry.
///
/// Currently holds a single entry built from the `NER_*` constants above; it is
/// flagged as the default, which `default_ner_model_info` relies on.
pub static NER_MODELS: &[NerModelInfo] = &[NerModelInfo {
    name: NER_MODEL_NAME,
    model_url: NER_MODEL_URL,
    tokenizer_url: NER_TOKENIZER_URL,
    size_mb: NER_MODEL_SIZE_MB,
    max_seq_len: NER_MAX_SEQ_LEN,
    is_default: true,
}];
83
84/// Get NER model info by name
85pub fn get_ner_model_info(name: &str) -> Option<&'static NerModelInfo> {
86    NER_MODELS.iter().find(|m| m.name == name)
87}
88
89/// Get default NER model info
90pub fn default_ner_model_info() -> &'static NerModelInfo {
91    NER_MODELS
92        .iter()
93        .find(|m| m.is_default)
94        .expect("default NER model must exist")
95}
96
97// ============================================================================
98// Entity Extraction Types
99// ============================================================================
100
101/// Raw entity mention extracted from text
102#[derive(Debug, Clone)]
103pub struct ExtractedEntity {
104    /// The extracted text span
105    pub text: String,
106    /// Entity type (PER, ORG, LOC, MISC)
107    pub entity_type: String,
108    /// Confidence score (0.0-1.0)
109    pub confidence: f32,
110    /// Byte offset start in original text
111    pub byte_start: usize,
112    /// Byte offset end in original text
113    pub byte_end: usize,
114}
115
116impl ExtractedEntity {
117    /// Convert the raw entity type to our EntityKind enum
118    pub fn to_entity_kind(&self) -> EntityKind {
119        match self.entity_type.to_uppercase().as_str() {
120            "PER" | "PERSON" | "B-PER" | "I-PER" => EntityKind::Person,
121            "ORG" | "ORGANIZATION" | "B-ORG" | "I-ORG" => EntityKind::Organization,
122            "LOC" | "LOCATION" | "B-LOC" | "I-LOC" => EntityKind::Location,
123            "MISC" | "B-MISC" | "I-MISC" => EntityKind::Other,
124            _ => EntityKind::Other,
125        }
126    }
127}
128
/// Result of extracting entities from a frame.
///
/// Pairs the entities with the id of the frame whose content was analysed.
#[derive(Debug, Clone)]
pub struct FrameEntities {
    /// Frame ID the entities were extracted from
    pub frame_id: FrameId,
    /// Extracted entities
    pub entities: Vec<ExtractedEntity>,
}
137
138// ============================================================================
139// NER Model (Feature-gated)
140// ============================================================================
141
142#[cfg(feature = "logic_mesh")]
143pub use model_impl::*;
144
145#[cfg(feature = "logic_mesh")]
146mod model_impl {
147    use super::*;
148    use ort::session::{Session, builder::GraphOptimizationLevel};
149    use ort::value::Tensor;
150    use std::sync::Mutex;
151    use tokenizers::{
152        PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
153        TruncationParams, TruncationStrategy,
154    };
155
    /// DistilBERT-NER model for entity extraction.
    ///
    /// Constructed with `NerModel::load`; all fields are private.
    pub struct NerModel {
        /// ONNX runtime session used to run inference
        session: Session,
        /// Tokenizer for text preprocessing, locked for the duration of each
        /// `encode` call
        tokenizer: Mutex<Tokenizer>,
        /// Path the ONNX model was loaded from, kept for reference
        model_path: PathBuf,
        /// Minimum confidence a token prediction needs to count as an entity
        min_confidence: f32,
    }
167
168    impl NerModel {
169        /// Load NER model from path
170        ///
171        /// # Arguments
172        /// * `model_path` - Path to the ONNX model file
173        /// * `tokenizer_path` - Path to the tokenizer.json file
174        /// * `min_confidence` - Minimum confidence threshold (default: 0.5)
175        pub fn load(
176            model_path: impl AsRef<Path>,
177            tokenizer_path: impl AsRef<Path>,
178            min_confidence: Option<f32>,
179        ) -> Result<Self> {
180            let model_path = model_path.as_ref().to_path_buf();
181            let tokenizer_path = tokenizer_path.as_ref();
182
183            // Load tokenizer
184            let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
185                MemvidError::NerModelNotAvailable {
186                    reason: format!("failed to load tokenizer from {:?}: {}", tokenizer_path, e)
187                        .into(),
188                }
189            })?;
190
191            // Configure padding and truncation
192            tokenizer.with_padding(Some(PaddingParams {
193                strategy: PaddingStrategy::BatchLongest,
194                direction: PaddingDirection::Right,
195                pad_to_multiple_of: None,
196                pad_id: 0,
197                pad_type_id: 0,
198                pad_token: "[PAD]".to_string(),
199            }));
200
201            tokenizer
202                .with_truncation(Some(TruncationParams {
203                    max_length: NER_MAX_SEQ_LEN,
204                    strategy: TruncationStrategy::LongestFirst,
205                    stride: 0,
206                    direction: TruncationDirection::Right,
207                }))
208                .map_err(|e| MemvidError::NerModelNotAvailable {
209                    reason: format!("failed to set truncation: {}", e).into(),
210                })?;
211
212            // Initialize ONNX Runtime
213            let session = Session::builder()
214                .map_err(|e| MemvidError::NerModelNotAvailable {
215                    reason: format!("failed to create session builder: {}", e).into(),
216                })?
217                .with_optimization_level(GraphOptimizationLevel::Level3)
218                .map_err(|e| MemvidError::NerModelNotAvailable {
219                    reason: format!("failed to set optimization level: {}", e).into(),
220                })?
221                .with_intra_threads(4)
222                .map_err(|e| MemvidError::NerModelNotAvailable {
223                    reason: format!("failed to set threads: {}", e).into(),
224                })?
225                .commit_from_file(&model_path)
226                .map_err(|e| MemvidError::NerModelNotAvailable {
227                    reason: format!("failed to load model from {:?}: {}", model_path, e).into(),
228                })?;
229
230            tracing::info!(
231                model = %model_path.display(),
232                "DistilBERT-NER model loaded"
233            );
234
235            Ok(Self {
236                session,
237                tokenizer: Mutex::new(tokenizer),
238                model_path,
239                min_confidence: min_confidence.unwrap_or(NER_MIN_CONFIDENCE),
240            })
241        }
242
243        /// Extract entities from text
244        pub fn extract(&mut self, text: &str) -> Result<Vec<ExtractedEntity>> {
245            if text.trim().is_empty() {
246                return Ok(Vec::new());
247            }
248
249            // Tokenize
250            let tokenizer = self
251                .tokenizer
252                .lock()
253                .map_err(|_| MemvidError::Lock("failed to lock tokenizer".into()))?;
254
255            let encoding =
256                tokenizer
257                    .encode(text, true)
258                    .map_err(|e| MemvidError::NerModelNotAvailable {
259                        reason: format!("tokenization failed: {}", e).into(),
260                    })?;
261
262            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
263            let attention_mask: Vec<i64> = encoding
264                .get_attention_mask()
265                .iter()
266                .map(|&x| x as i64)
267                .collect();
268            let tokens = encoding.get_tokens().to_vec();
269            let offsets = encoding.get_offsets().to_vec();
270
271            drop(tokenizer); // Release lock before inference
272
273            let seq_len = input_ids.len();
274
275            // Create input tensors using Tensor::from_array
276            let input_ids_array = ndarray::Array2::from_shape_vec((1, seq_len), input_ids)
277                .map_err(|e| MemvidError::NerModelNotAvailable {
278                    reason: format!("failed to create input_ids array: {}", e).into(),
279                })?;
280
281            let attention_mask_array =
282                ndarray::Array2::from_shape_vec((1, seq_len), attention_mask).map_err(|e| {
283                    MemvidError::NerModelNotAvailable {
284                        reason: format!("failed to create attention_mask array: {}", e).into(),
285                    }
286                })?;
287
288            let input_ids_tensor = Tensor::from_array(input_ids_array).map_err(|e| {
289                MemvidError::NerModelNotAvailable {
290                    reason: format!("failed to create input_ids tensor: {}", e).into(),
291                }
292            })?;
293
294            let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
295                MemvidError::NerModelNotAvailable {
296                    reason: format!("failed to create attention_mask tensor: {}", e).into(),
297                }
298            })?;
299
300            // Get output name before inference (avoid borrow conflict)
301            let output_name = self
302                .session
303                .outputs
304                .first()
305                .map(|o| o.name.clone())
306                .unwrap_or_else(|| "logits".into());
307
308            // Run inference
309            let outputs = self
310                .session
311                .run(ort::inputs![
312                    "input_ids" => input_ids_tensor,
313                    "attention_mask" => attention_mask_tensor,
314                ])
315                .map_err(|e| MemvidError::NerModelNotAvailable {
316                    reason: format!("inference failed: {}", e).into(),
317                })?;
318
319            let logits =
320                outputs
321                    .get(&output_name)
322                    .ok_or_else(|| MemvidError::NerModelNotAvailable {
323                        reason: format!("no output '{}' found", output_name).into(),
324                    })?;
325
326            // Parse logits to get predictions
327            let entities = Self::decode_predictions_static(
328                text,
329                &tokens,
330                &offsets,
331                logits,
332                self.min_confidence,
333            )?;
334
335            Ok(entities)
336        }
337
338        /// Decode model predictions into entities (static to avoid borrow issues)
339        fn decode_predictions_static(
340            original_text: &str,
341            tokens: &[String],
342            offsets: &[(usize, usize)],
343            logits: &ort::value::Value,
344            min_confidence: f32,
345        ) -> Result<Vec<ExtractedEntity>> {
346            // Extract the logits tensor - shape: [1, seq_len, num_labels]
347            let (shape, data) = logits.try_extract_tensor::<f32>().map_err(|e| {
348                MemvidError::NerModelNotAvailable {
349                    reason: format!("failed to extract logits: {}", e).into(),
350                }
351            })?;
352
353            // Shape is iterable, convert to Vec
354            let shape_vec: Vec<i64> = shape.iter().copied().collect();
355
356            if shape_vec.len() != 3 {
357                return Err(MemvidError::NerModelNotAvailable {
358                    reason: format!("unexpected logits shape: {:?}", shape_vec).into(),
359                });
360            }
361
362            let seq_len = shape_vec[1] as usize;
363            let num_labels = shape_vec[2] as usize;
364
365            // Helper to index into flat data: [batch, seq, labels] -> flat index
366            let idx = |i: usize, j: usize| -> usize { i * num_labels + j };
367
368            let mut entities = Vec::new();
369            let mut current_entity: Option<(String, usize, usize, f32)> = None;
370
371            for i in 0..seq_len {
372                if i >= tokens.len() || i >= offsets.len() {
373                    break;
374                }
375
376                // Skip special tokens
377                let token = &tokens[i];
378                if token == "[CLS]" || token == "[SEP]" || token == "[PAD]" {
379                    // Finalize any current entity
380                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
381                        if end > start && end <= original_text.len() {
382                            let text = original_text[start..end].trim().to_string();
383                            if !text.is_empty() {
384                                entities.push(ExtractedEntity {
385                                    text,
386                                    entity_type,
387                                    confidence: conf,
388                                    byte_start: start,
389                                    byte_end: end,
390                                });
391                            }
392                        }
393                    }
394                    continue;
395                }
396
397                // Get prediction for this token
398                let mut max_score = f32::NEG_INFINITY;
399                let mut max_label = 0usize;
400
401                for j in 0..num_labels {
402                    let score = data[idx(i, j)];
403                    if score > max_score {
404                        max_score = score;
405                        max_label = j;
406                    }
407                }
408
409                // Apply softmax to get confidence
410                let mut exp_sum = 0.0f32;
411                for j in 0..num_labels {
412                    exp_sum += (data[idx(i, j)] - max_score).exp();
413                }
414                let confidence = 1.0 / exp_sum;
415
416                let label = NER_LABELS.get(max_label).unwrap_or(&"O");
417                let (start_offset, end_offset) = offsets[i];
418
419                if *label == "O" || confidence < min_confidence {
420                    // End any current entity
421                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
422                        if end > start && end <= original_text.len() {
423                            let text = original_text[start..end].trim().to_string();
424                            if !text.is_empty() {
425                                entities.push(ExtractedEntity {
426                                    text,
427                                    entity_type,
428                                    confidence: conf,
429                                    byte_start: start,
430                                    byte_end: end,
431                                });
432                            }
433                        }
434                    }
435                } else if label.starts_with("B-") {
436                    // Start new entity (end previous if any)
437                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
438                        if end > start && end <= original_text.len() {
439                            let text = original_text[start..end].trim().to_string();
440                            if !text.is_empty() {
441                                entities.push(ExtractedEntity {
442                                    text,
443                                    entity_type,
444                                    confidence: conf,
445                                    byte_start: start,
446                                    byte_end: end,
447                                });
448                            }
449                        }
450                    }
451                    let entity_type = label[2..].to_string(); // Remove "B-" prefix
452                    current_entity = Some((entity_type, start_offset, end_offset, confidence));
453                } else if label.starts_with("I-") {
454                    // Continue entity
455                    if let Some((ref entity_type, start, _, ref mut conf)) = current_entity {
456                        let expected_type = &label[2..];
457                        if entity_type == expected_type {
458                            current_entity = Some((
459                                entity_type.clone(),
460                                start,
461                                end_offset,
462                                (*conf + confidence) / 2.0,
463                            ));
464                        }
465                    }
466                }
467            }
468
469            // Finalize last entity
470            if let Some((entity_type, start, end, conf)) = current_entity {
471                if end > start && end <= original_text.len() {
472                    let text = original_text[start..end].trim().to_string();
473                    if !text.is_empty() {
474                        entities.push(ExtractedEntity {
475                            text,
476                            entity_type,
477                            confidence: conf,
478                            byte_start: start,
479                            byte_end: end,
480                        });
481                    }
482                }
483            }
484
485            Ok(entities)
486        }
487
488        /// Extract entities from a frame's content
489        pub fn extract_from_frame(
490            &mut self,
491            frame_id: FrameId,
492            content: &str,
493        ) -> Result<FrameEntities> {
494            let entities = self.extract(content)?;
495            Ok(FrameEntities { frame_id, entities })
496        }
497
        /// Get the minimum confidence threshold this model was loaded with.
        pub fn min_confidence(&self) -> f32 {
            self.min_confidence
        }
502
503        /// Get model path
504        pub fn model_path(&self) -> &Path {
505            &self.model_path
506        }
507    }
508
509    impl std::fmt::Debug for NerModel {
510        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
511            f.debug_struct("NerModel")
512                .field("model_path", &self.model_path)
513                .field("min_confidence", &self.min_confidence)
514                .finish()
515        }
516    }
517}
518
519// ============================================================================
520// Stub Implementation (when feature is disabled)
521// ============================================================================
522
523#[cfg(not(feature = "logic_mesh"))]
524#[allow(dead_code)]
525pub struct NerModel {
526    _private: (),
527}
528
529#[cfg(not(feature = "logic_mesh"))]
530#[allow(dead_code)]
531impl NerModel {
532    pub fn load(
533        _model_path: impl AsRef<Path>,
534        _tokenizer_path: impl AsRef<Path>,
535        _min_confidence: Option<f32>,
536    ) -> Result<Self> {
537        Err(MemvidError::FeatureUnavailable {
538            feature: "logic_mesh",
539        })
540    }
541
542    pub fn extract(&self, _text: &str) -> Result<Vec<ExtractedEntity>> {
543        Err(MemvidError::FeatureUnavailable {
544            feature: "logic_mesh",
545        })
546    }
547
548    pub fn extract_from_frame(&self, _frame_id: FrameId, _content: &str) -> Result<FrameEntities> {
549        Err(MemvidError::FeatureUnavailable {
550            feature: "logic_mesh",
551        })
552    }
553}
554
555// ============================================================================
556// Model Path Utilities
557// ============================================================================
558
559/// Get the expected path for the NER model in the models directory
560pub fn ner_model_path(models_dir: &Path) -> PathBuf {
561    models_dir.join(NER_MODEL_NAME).join("model.onnx")
562}
563
564/// Get the expected path for the NER tokenizer in the models directory
565pub fn ner_tokenizer_path(models_dir: &Path) -> PathBuf {
566    models_dir.join(NER_MODEL_NAME).join("tokenizer.json")
567}
568
569/// Check if NER model is installed
570pub fn is_ner_model_installed(models_dir: &Path) -> bool {
571    ner_model_path(models_dir).exists() && ner_tokenizer_path(models_dir).exists()
572}
573
574// ============================================================================
575// Tests
576// ============================================================================
577
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a throwaway entity whose only meaningful field is `entity_type`.
    fn entity_with_type(entity_type: &str) -> ExtractedEntity {
        ExtractedEntity {
            text: "test".to_string(),
            entity_type: entity_type.to_string(),
            confidence: 0.9,
            byte_start: 0,
            byte_end: 4,
        }
    }

    #[test]
    fn test_entity_kind_mapping() {
        let cases = vec![
            // Bare tags, long names and both BIO prefixes for every kind.
            ("PER", EntityKind::Person),
            ("PERSON", EntityKind::Person),
            ("B-PER", EntityKind::Person),
            ("I-PER", EntityKind::Person),
            ("ORG", EntityKind::Organization),
            ("ORGANIZATION", EntityKind::Organization),
            ("B-ORG", EntityKind::Organization),
            ("I-ORG", EntityKind::Organization),
            ("LOC", EntityKind::Location),
            ("LOCATION", EntityKind::Location),
            ("B-LOC", EntityKind::Location),
            ("I-LOC", EntityKind::Location),
            ("MISC", EntityKind::Other),
            ("B-MISC", EntityKind::Other),
            ("I-MISC", EntityKind::Other),
            // Matching is case-insensitive.
            ("per", EntityKind::Person),
            ("Organization", EntityKind::Organization),
            ("b-loc", EntityKind::Location),
            // Anything unrecognised falls back to Other.
            ("unknown", EntityKind::Other),
            ("", EntityKind::Other),
        ];

        for (entity_type, expected_kind) in cases {
            let entity = entity_with_type(entity_type);
            assert_eq!(
                entity.to_entity_kind(),
                expected_kind,
                "Failed for entity_type: {}",
                entity_type
            );
        }
    }

    #[test]
    fn test_model_info() {
        let info = default_ner_model_info();
        assert_eq!(info.name, NER_MODEL_NAME);
        assert!(info.is_default);
        assert!(info.size_mb > 200.0);
        assert_eq!(info.max_seq_len, NER_MAX_SEQ_LEN);

        // The registry must expose exactly one default entry.
        assert_eq!(NER_MODELS.iter().filter(|m| m.is_default).count(), 1);

        // Lookup by name finds the registered model and nothing else.
        let by_name = get_ner_model_info(NER_MODEL_NAME).expect("registered model must be found");
        assert_eq!(by_name.model_url, NER_MODEL_URL);
        assert_eq!(by_name.tokenizer_url, NER_TOKENIZER_URL);
        assert!(get_ner_model_info("no-such-model").is_none());
    }

    #[test]
    fn test_model_paths() {
        let models_dir = PathBuf::from("/tmp/models");
        let model_dir = models_dir.join(NER_MODEL_NAME);

        // Compare whole paths rather than substrings so the directory layout
        // is pinned as well as the file names.
        assert_eq!(ner_model_path(&models_dir), model_dir.join("model.onnx"));
        assert_eq!(
            ner_tokenizer_path(&models_dir),
            model_dir.join("tokenizer.json")
        );

        // A directory that does not exist cannot contain an installed model.
        assert!(!is_ner_model_installed(Path::new(
            "/nonexistent/memvid-ner-test-models-dir"
        )));
    }

    #[test]
    fn test_ner_labels() {
        let expected: &[&str] = &[
            "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
        ];
        assert_eq!(NER_LABELS, expected);

        // After "O", labels come in B-/I- pairs that share an entity type.
        for pair in NER_LABELS[1..].chunks(2) {
            assert_eq!(pair.len(), 2);
            let begin = pair[0].strip_prefix("B-").expect("expected a B- label");
            let inside = pair[1].strip_prefix("I-").expect("expected an I- label");
            assert_eq!(begin, inside);
        }
    }
}
641}