memvid_core/analysis/
ner.rs

1//! Named Entity Recognition (NER) module using DistilBERT-NER ONNX.
2//!
3//! This module provides entity extraction capabilities using DistilBERT-NER,
4//! a fast and accurate NER model fine-tuned on CoNLL-03.
5//!
6//! # Model
7//!
8//! Uses dslim/distilbert-NER ONNX (~261 MB) with 92% F1 score.
9//! Entities: Person (PER), Organization (ORG), Location (LOC), Miscellaneous (MISC)
10//!
11//! # Simple Interface
12//!
13//! Unlike GLiNER, DistilBERT-NER uses standard BERT tokenization:
14//! - Input: `input_ids`, `attention_mask`
15//! - Output: per-token logits for B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC, O
16
17use crate::types::{EntityKind, FrameId};
18use crate::{MemvidError, Result};
19use std::path::{Path, PathBuf};
20
// ============================================================================
// Configuration Constants
// ============================================================================

/// Identifier of the NER model, used for downloads and as the cache directory name.
pub const NER_MODEL_NAME: &str = "distilbert-ner";

/// HuggingFace download URL of the `dslim/distilbert-NER` ONNX export.
pub const NER_MODEL_URL: &str = concat!(
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/",
    "onnx/model.onnx"
);

/// HuggingFace download URL of the tokenizer definition matching the model.
pub const NER_TOKENIZER_URL: &str = concat!(
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/",
    "tokenizer.json"
);

/// Approximate size of the ONNX model download, in megabytes.
pub const NER_MODEL_SIZE_MB: f32 = 261.0;

/// Maximum sequence length, in tokens, accepted by the model.
pub const NER_MAX_SEQ_LEN: usize = 512;

/// Default minimum per-token confidence; weaker predictions are treated as "outside".
// Only read by the feature-gated model implementation, hence the conditional allow.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_MIN_CONFIDENCE: f32 = 0.5;

/// Model output class index -> CoNLL-03 tag (B-/I- prefixed, plus "O" for outside).
// Only read by the feature-gated model implementation, hence the conditional allow.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_LABELS: &[&str] = &[
    "O",      // 0
    "B-PER",  // 1
    "I-PER",  // 2
    "B-ORG",  // 3
    "I-ORG",  // 4
    "B-LOC",  // 5
    "I-LOC",  // 6
    "B-MISC", // 7
    "I-MISC", // 8
];
52
// ============================================================================
// Model Info
// ============================================================================

/// Static description of a downloadable NER model, as listed in [`NER_MODELS`].
#[derive(Clone, Debug)]
pub struct NerModelInfo {
    /// Registry key, matched exactly by [`get_ner_model_info`].
    pub name: &'static str,
    /// Where the ONNX model file is downloaded from.
    pub model_url: &'static str,
    /// Where the `tokenizer.json` file is downloaded from.
    pub tokenizer_url: &'static str,
    /// Approximate download size, in megabytes.
    pub size_mb: f32,
    /// Longest token sequence the model accepts.
    pub max_seq_len: usize,
    /// `true` for the entry returned by [`default_ner_model_info`].
    pub is_default: bool,
}
73
74/// Available NER models registry
75pub static NER_MODELS: &[NerModelInfo] = &[NerModelInfo {
76    name: NER_MODEL_NAME,
77    model_url: NER_MODEL_URL,
78    tokenizer_url: NER_TOKENIZER_URL,
79    size_mb: NER_MODEL_SIZE_MB,
80    max_seq_len: NER_MAX_SEQ_LEN,
81    is_default: true,
82}];
83
84/// Get NER model info by name
85pub fn get_ner_model_info(name: &str) -> Option<&'static NerModelInfo> {
86    NER_MODELS.iter().find(|m| m.name == name)
87}
88
89/// Get default NER model info
90pub fn default_ner_model_info() -> &'static NerModelInfo {
91    NER_MODELS
92        .iter()
93        .find(|m| m.is_default)
94        .expect("default NER model must exist")
95}
96
// ============================================================================
// Entity Extraction Types
// ============================================================================

/// A single entity mention found in a piece of text.
#[derive(Clone, Debug)]
pub struct ExtractedEntity {
    /// Text of the mention, as cut from the source string.
    pub text: String,
    /// Raw entity tag without IOB prefix, e.g. `PER`, `ORG`, `LOC`, `MISC`.
    pub entity_type: String,
    /// Model confidence in `0.0..=1.0`.
    pub confidence: f32,
    /// Byte index where the mention starts in the source text.
    pub byte_start: usize,
    /// Byte index one past the end of the mention in the source text.
    pub byte_end: usize,
}
115
116impl ExtractedEntity {
117    /// Convert the raw entity type to our EntityKind enum
118    pub fn to_entity_kind(&self) -> EntityKind {
119        match self.entity_type.to_uppercase().as_str() {
120            "PER" | "PERSON" | "B-PER" | "I-PER" => EntityKind::Person,
121            "ORG" | "ORGANIZATION" | "B-ORG" | "I-ORG" => EntityKind::Organization,
122            "LOC" | "LOCATION" | "B-LOC" | "I-LOC" => EntityKind::Location,
123            "MISC" | "B-MISC" | "I-MISC" => EntityKind::Other,
124            _ => EntityKind::Other,
125        }
126    }
127}
128
129/// Result of extracting entities from a frame
130#[derive(Debug, Clone)]
131pub struct FrameEntities {
132    /// Frame ID the entities were extracted from
133    pub frame_id: FrameId,
134    /// Extracted entities
135    pub entities: Vec<ExtractedEntity>,
136}
137
// ============================================================================
// NER Model (Feature-gated)
// ============================================================================

#[cfg(feature = "logic_mesh")]
pub use model_impl::*;

#[cfg(feature = "logic_mesh")]
mod model_impl {
    use super::*;
    use ort::session::{builder::GraphOptimizationLevel, Session};
    use ort::value::Tensor;
    use std::sync::Mutex;
    use tokenizers::{
        Encoding, PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer,
        TruncationDirection, TruncationParams, TruncationStrategy,
    };

    /// DistilBERT-NER model for entity extraction.
    ///
    /// Input longer than [`NER_MAX_SEQ_LEN`] tokens is split by the tokenizer into
    /// consecutive windows (its overflowing encodings). Each window is run through the
    /// model separately, so no part of the text is silently dropped. Entities never
    /// span two windows.
    pub struct NerModel {
        /// ONNX runtime session
        session: Session,
        /// Tokenizer for text preprocessing
        tokenizer: Mutex<Tokenizer>,
        /// Model path for reference
        model_path: PathBuf,
        /// Minimum confidence threshold
        min_confidence: f32,
    }

    impl NerModel {
        /// Load NER model from path.
        ///
        /// # Arguments
        /// * `model_path` - Path to the ONNX model file
        /// * `tokenizer_path` - Path to the tokenizer.json file
        /// * `min_confidence` - Minimum per-token confidence; defaults to
        ///   [`NER_MIN_CONFIDENCE`] when `None`
        ///
        /// # Errors
        /// Returns [`MemvidError::NerModelNotAvailable`] when the tokenizer cannot be
        /// loaded or configured, or the ONNX session cannot be built from `model_path`.
        pub fn load(
            model_path: impl AsRef<Path>,
            tokenizer_path: impl AsRef<Path>,
            min_confidence: Option<f32>,
        ) -> Result<Self> {
            let model_path = model_path.as_ref().to_path_buf();
            let tokenizer_path = tokenizer_path.as_ref();

            let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to load tokenizer from {:?}: {}", tokenizer_path, e)
                        .into(),
                }
            })?;

            tokenizer.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                direction: PaddingDirection::Right,
                pad_to_multiple_of: None,
                pad_id: 0,
                pad_type_id: 0,
                pad_token: "[PAD]".to_string(),
            }));

            // Text beyond `max_length` is not lost: it ends up in the encoding's
            // overflowing windows, which `extract` also processes.
            tokenizer
                .with_truncation(Some(TruncationParams {
                    max_length: NER_MAX_SEQ_LEN,
                    strategy: TruncationStrategy::LongestFirst,
                    stride: 0,
                    direction: TruncationDirection::Right,
                }))
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set truncation: {}", e).into(),
                })?;

            let session = Session::builder()
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create session builder: {}", e).into(),
                })?
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set optimization level: {}", e).into(),
                })?
                .with_intra_threads(4)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set threads: {}", e).into(),
                })?
                .commit_from_file(&model_path)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to load model from {:?}: {}", model_path, e).into(),
                })?;

            tracing::info!(
                model = %model_path.display(),
                "DistilBERT-NER model loaded"
            );

            Ok(Self {
                session,
                tokenizer: Mutex::new(tokenizer),
                model_path,
                min_confidence: min_confidence.unwrap_or(NER_MIN_CONFIDENCE),
            })
        }

        /// Extract entities from text.
        ///
        /// Returns an empty list for blank input. Entity byte offsets refer to `text`.
        ///
        /// # Errors
        /// [`MemvidError::Lock`] if the tokenizer mutex is poisoned, and
        /// [`MemvidError::NerModelNotAvailable`] if tokenization, tensor construction,
        /// inference or logits decoding fails.
        pub fn extract(&mut self, text: &str) -> Result<Vec<ExtractedEntity>> {
            if text.trim().is_empty() {
                return Ok(Vec::new());
            }

            let tokenizer = self
                .tokenizer
                .lock()
                .map_err(|_| MemvidError::Lock("failed to lock tokenizer".into()))?;
            let encoding = tokenizer.encode(text, true).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("tokenization failed: {}", e).into(),
                }
            })?;
            drop(tokenizer); // Release lock before inference

            // Resolve the output name before inference (avoids a borrow conflict with `run`).
            let output_name = self
                .session
                .outputs
                .first()
                .map(|o| o.name.clone())
                .unwrap_or_else(|| "logits".into());

            let overflow = encoding.get_overflowing();
            if !overflow.is_empty() {
                tracing::debug!(
                    windows = overflow.len() + 1,
                    "NER input exceeds {} tokens; running inference per window",
                    NER_MAX_SEQ_LEN
                );
            }

            let mut entities = Vec::new();
            for window in std::iter::once(&encoding).chain(overflow) {
                entities.extend(self.infer_window(text, window, &output_name)?);
            }
            Ok(entities)
        }

        /// Run the model on one tokenized window and decode its entities.
        fn infer_window(
            &mut self,
            text: &str,
            window: &Encoding,
            output_name: &str,
        ) -> Result<Vec<ExtractedEntity>> {
            let seq_len = window.get_ids().len();

            let input_ids_array =
                ndarray::Array2::from_shape_vec((1, seq_len), to_i64(window.get_ids())).map_err(
                    |e| MemvidError::NerModelNotAvailable {
                        reason: format!("failed to create input_ids array: {}", e).into(),
                    },
                )?;
            let attention_mask_array = ndarray::Array2::from_shape_vec(
                (1, seq_len),
                to_i64(window.get_attention_mask()),
            )
            .map_err(|e| MemvidError::NerModelNotAvailable {
                reason: format!("failed to create attention_mask array: {}", e).into(),
            })?;

            let input_ids_tensor = Tensor::from_array(input_ids_array).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create input_ids tensor: {}", e).into(),
                }
            })?;
            let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create attention_mask tensor: {}", e).into(),
                }
            })?;

            let min_confidence = self.min_confidence;
            let outputs = self
                .session
                .run(ort::inputs![
                    "input_ids" => input_ids_tensor,
                    "attention_mask" => attention_mask_tensor,
                ])
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("inference failed: {}", e).into(),
                })?;

            let logits = outputs
                .get(output_name)
                .ok_or_else(|| MemvidError::NerModelNotAvailable {
                    reason: format!("no output '{}' found", output_name).into(),
                })?;

            let entities = Self::decode_predictions(text, window, logits, min_confidence)?;
            Ok(entities)
        }

        /// Validate the logits tensor of one window and decode it into entities.
        ///
        /// Expects a rank-3 `[batch, seq_len, num_labels]` tensor; only the first batch
        /// row is read. A buffer shorter than `seq_len * num_labels` is reported as an
        /// error instead of panicking on an out-of-bounds index.
        fn decode_predictions(
            original_text: &str,
            window: &Encoding,
            logits: &ort::value::Value,
            min_confidence: f32,
        ) -> Result<Vec<ExtractedEntity>> {
            let (shape, data) = logits
                .try_extract_tensor::<f32>()
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to extract logits: {}", e).into(),
                })?;

            let dims: Vec<i64> = shape.iter().copied().collect();
            let layout = match *dims.as_slice() {
                [_, seq_len, num_labels] => usize::try_from(seq_len)
                    .ok()
                    .zip(usize::try_from(num_labels).ok())
                    .filter(|&(seq, labels)| {
                        seq.checked_mul(labels)
                            .map_or(false, |needed| needed <= data.len())
                    }),
                _ => None,
            };
            let (seq_len, num_labels) = layout.ok_or_else(|| MemvidError::NerModelNotAvailable {
                reason: format!("unexpected logits shape: {:?}", dims).into(),
            })?;

            Ok(decode_logits(
                original_text,
                window.get_special_tokens_mask(),
                window.get_word_ids(),
                window.get_offsets(),
                &data[..seq_len * num_labels],
                num_labels,
                min_confidence,
            ))
        }

        /// Extract entities from a frame's content
        pub fn extract_from_frame(
            &mut self,
            frame_id: FrameId,
            content: &str,
        ) -> Result<FrameEntities> {
            let entities = self.extract(content)?;
            Ok(FrameEntities { frame_id, entities })
        }

        /// Get minimum confidence threshold
        pub fn min_confidence(&self) -> f32 {
            self.min_confidence
        }

        /// Get model path
        pub fn model_path(&self) -> &Path {
            &self.model_path
        }
    }

    impl std::fmt::Debug for NerModel {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("NerModel")
                .field("model_path", &self.model_path)
                .field("min_confidence", &self.min_confidence)
                .finish()
        }
    }

    /// Widen tokenizer ids / masks (`u32`) to the `i64` values fed to the model.
    fn to_i64(values: &[u32]) -> Vec<i64> {
        values.iter().map(|&v| i64::from(v)).collect()
    }

    /// IOB tag of a predicted label, carrying the entity type without its prefix.
    #[derive(Clone, Copy)]
    enum IobTag<'a> {
        Begin(&'a str),
        Inside(&'a str),
    }

    /// Parse `B-X` / `I-X`; `O` and anything unrecognised yield `None` ("outside").
    fn parse_iob_tag(label: &str) -> Option<IobTag<'_>> {
        if let Some(kind) = label.strip_prefix("B-") {
            Some(IobTag::Begin(kind))
        } else {
            label.strip_prefix("I-").map(IobTag::Inside)
        }
    }

    /// Index of the highest logit and its softmax probability.
    ///
    /// Ties keep the first index.
    fn argmax_probability(row: &[f32]) -> (usize, f32) {
        let mut best = 0usize;
        let mut max = f32::NEG_INFINITY;
        for (j, &score) in row.iter().enumerate() {
            if score > max {
                max = score;
                best = j;
            }
        }
        // Max-shifted softmax: p(best) = 1 / sum_j exp(x_j - max)
        let denom: f32 = row.iter().map(|&score| (score - max).exp()).sum();
        (best, 1.0 / denom)
    }

    /// Entity being accumulated while walking a window's tokens.
    struct OpenSpan {
        entity_type: String,
        byte_start: usize,
        byte_end: usize,
        /// Sum of the probabilities of the word-initial tokens merged so far.
        score_sum: f32,
        scored_tokens: u32,
    }

    impl OpenSpan {
        fn new(entity_type: &str, byte_start: usize, byte_end: usize, score: f32) -> Self {
            Self {
                entity_type: entity_type.to_string(),
                byte_start,
                byte_end,
                score_sum: score,
                scored_tokens: 1,
            }
        }

        /// Grow the span to `byte_end`, folding `score` into the running mean.
        fn extend(&mut self, byte_end: usize, score: f32) {
            self.byte_end = byte_end;
            self.score_sum += score;
            self.scored_tokens += 1;
        }

        /// Convert into an entity.
        ///
        /// Offsets are trimmed together with the text, so they always agree with it.
        /// Returns `None` for empty, out-of-range or non-UTF-8-boundary spans.
        fn finish(self, text: &str) -> Option<ExtractedEntity> {
            let raw = text.get(self.byte_start..self.byte_end)?;
            let trimmed = raw.trim();
            if trimmed.is_empty() {
                return None;
            }
            let byte_start = self.byte_start + (raw.len() - raw.trim_start().len());
            Some(ExtractedEntity {
                text: trimmed.to_string(),
                entity_type: self.entity_type,
                confidence: self.score_sum / self.scored_tokens as f32,
                byte_start,
                byte_end: byte_start + trimmed.len(),
            })
        }
    }

    /// Close the open span (if any) and push it to `out` when it is non-empty.
    fn close_span(open: &mut Option<OpenSpan>, text: &str, out: &mut Vec<ExtractedEntity>) {
        if let Some(entity) = open.take().and_then(|span| span.finish(text)) {
            out.push(entity);
        }
    }

    /// Decode row-major `[seq_len, num_labels]` logits of one window into entities.
    ///
    /// Rules:
    /// * special tokens (per the tokenizer's special-token mask) close any open entity;
    /// * a continuation sub-word (same word id as the previous token) inherits the
    ///   decision of the word's first sub-word: it extends the entity that sub-word
    ///   joined, otherwise it is ignored;
    /// * `B-X` opens a new entity; `I-X` extends an open `X` entity, otherwise opens one;
    /// * `O`, or a prediction below `min_confidence` (NaN included), closes any open entity.
    ///
    /// An entity's confidence is the mean probability of its word-initial tokens.
    /// Decoding stops at the shortest of the logits rows and the token metadata.
    fn decode_logits(
        original_text: &str,
        special_tokens_mask: &[u32],
        word_ids: &[Option<u32>],
        offsets: &[(usize, usize)],
        logits: &[f32],
        num_labels: usize,
        min_confidence: f32,
    ) -> Vec<ExtractedEntity> {
        let mut entities = Vec::new();
        if num_labels == 0 {
            return entities;
        }

        let mut open: Option<OpenSpan> = None;
        let mut prev_word: Option<u32> = None;
        let mut prev_in_entity = false;

        let tokens = special_tokens_mask.iter().zip(word_ids).zip(offsets);
        for (row, ((&special, &word), &(tok_start, tok_end))) in
            logits.chunks_exact(num_labels).zip(tokens)
        {
            if special != 0 {
                close_span(&mut open, original_text, &mut entities);
                prev_word = None;
                prev_in_entity = false;
                continue;
            }

            if word.is_some() && word == prev_word {
                if prev_in_entity {
                    if let Some(span) = open.as_mut() {
                        span.byte_end = tok_end;
                    }
                }
                continue;
            }
            prev_word = word;

            let (label_idx, confidence) = argmax_probability(row);
            let label = NER_LABELS.get(label_idx).copied().unwrap_or("O");
            let tag = if confidence >= min_confidence {
                parse_iob_tag(label)
            } else {
                None
            };

            match tag {
                Some(IobTag::Inside(kind))
                    if matches!(&open, Some(span) if span.entity_type == kind) =>
                {
                    if let Some(span) = open.as_mut() {
                        span.extend(tok_end, confidence);
                    }
                    prev_in_entity = true;
                }
                Some(IobTag::Begin(kind) | IobTag::Inside(kind)) => {
                    close_span(&mut open, original_text, &mut entities);
                    open = Some(OpenSpan::new(kind, tok_start, tok_end, confidence));
                    prev_in_entity = true;
                }
                None => {
                    close_span(&mut open, original_text, &mut entities);
                    prev_in_entity = false;
                }
            }
        }

        close_span(&mut open, original_text, &mut entities);
        entities
    }

    #[cfg(test)]
    mod decode_tests {
        use super::*;

        /// Logit row that makes `label` the confident argmax.
        fn row(label: usize) -> Vec<f32> {
            let mut scores = vec![0.0; NER_LABELS.len()];
            scores[label] = 10.0;
            scores
        }

        #[test]
        fn subwords_and_inside_tokens_form_one_entity() {
            let text = "Ada Lovelace";
            // [CLS] Ada Love ##lace [SEP]; the continuation sub-word is predicted B-PER
            // and must not split the word into a separate entity.
            let mask: [u32; 5] = [1, 0, 0, 0, 1];
            let words: [Option<u32>; 5] = [None, Some(0), Some(1), Some(1), None];
            let offsets: [(usize, usize); 5] = [(0, 0), (0, 3), (4, 8), (8, 12), (0, 0)];
            let logits = [row(0), row(1), row(2), row(1), row(0)].concat();

            let entities =
                decode_logits(text, &mask, &words, &offsets, &logits, NER_LABELS.len(), 0.5);

            assert_eq!(entities.len(), 1);
            assert_eq!(entities[0].text, "Ada Lovelace");
            assert_eq!(entities[0].entity_type, "PER");
            assert_eq!((entities[0].byte_start, entities[0].byte_end), (0, 12));
        }

        #[test]
        fn inside_tag_after_outside_opens_entity() {
            let text = "in Paris";
            let mask: [u32; 4] = [1, 0, 0, 1];
            let words: [Option<u32>; 4] = [None, Some(0), Some(1), None];
            let offsets: [(usize, usize); 4] = [(0, 0), (0, 2), (3, 8), (0, 0)];
            // O, I-LOC
            let logits = [row(0), row(0), row(6), row(0)].concat();

            let entities =
                decode_logits(text, &mask, &words, &offsets, &logits, NER_LABELS.len(), 0.5);

            assert_eq!(entities.len(), 1);
            assert_eq!(entities[0].text, "Paris");
            assert_eq!(entities[0].entity_type, "LOC");
        }
    }
}
513
514// ============================================================================
515// Stub Implementation (when feature is disabled)
516// ============================================================================
517
518#[cfg(not(feature = "logic_mesh"))]
519#[allow(dead_code)]
520pub struct NerModel {
521    _private: (),
522}
523
524#[cfg(not(feature = "logic_mesh"))]
525#[allow(dead_code)]
526impl NerModel {
527    pub fn load(
528        _model_path: impl AsRef<Path>,
529        _tokenizer_path: impl AsRef<Path>,
530        _min_confidence: Option<f32>,
531    ) -> Result<Self> {
532        Err(MemvidError::FeatureUnavailable {
533            feature: "logic_mesh",
534        })
535    }
536
537    pub fn extract(&self, _text: &str) -> Result<Vec<ExtractedEntity>> {
538        Err(MemvidError::FeatureUnavailable {
539            feature: "logic_mesh",
540        })
541    }
542
543    pub fn extract_from_frame(&self, _frame_id: FrameId, _content: &str) -> Result<FrameEntities> {
544        Err(MemvidError::FeatureUnavailable {
545            feature: "logic_mesh",
546        })
547    }
548}
549
550// ============================================================================
551// Model Path Utilities
552// ============================================================================
553
554/// Get the expected path for the NER model in the models directory
555pub fn ner_model_path(models_dir: &Path) -> PathBuf {
556    models_dir.join(NER_MODEL_NAME).join("model.onnx")
557}
558
559/// Get the expected path for the NER tokenizer in the models directory
560pub fn ner_tokenizer_path(models_dir: &Path) -> PathBuf {
561    models_dir.join(NER_MODEL_NAME).join("tokenizer.json")
562}
563
564/// Check if NER model is installed
565pub fn is_ner_model_installed(models_dir: &Path) -> bool {
566    ner_model_path(models_dir).exists() && ner_tokenizer_path(models_dir).exists()
567}
568
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    fn entity_with_type(entity_type: &str) -> ExtractedEntity {
        ExtractedEntity {
            text: "test".to_string(),
            entity_type: entity_type.to_string(),
            confidence: 0.9,
            byte_start: 0,
            byte_end: 4,
        }
    }

    #[test]
    fn test_entity_kind_mapping() {
        let cases = vec![
            ("PER", EntityKind::Person),
            ("B-PER", EntityKind::Person),
            ("I-PER", EntityKind::Person),
            ("per", EntityKind::Person),
            ("ORG", EntityKind::Organization),
            ("B-ORG", EntityKind::Organization),
            ("I-ORG", EntityKind::Organization),
            ("LOC", EntityKind::Location),
            ("B-LOC", EntityKind::Location),
            ("Location", EntityKind::Location),
            ("MISC", EntityKind::Other),
            ("B-MISC", EntityKind::Other),
            ("I-MISC", EntityKind::Other),
            ("unknown", EntityKind::Other),
        ];

        for (entity_type, expected_kind) in cases {
            assert_eq!(
                entity_with_type(entity_type).to_entity_kind(),
                expected_kind,
                "Failed for entity_type: {}",
                entity_type
            );
        }
    }

    #[test]
    fn test_model_info() {
        let info = default_ner_model_info();
        assert_eq!(info.name, NER_MODEL_NAME);
        assert!(info.is_default);
        assert!(info.size_mb > 200.0);
        assert_eq!(info.max_seq_len, NER_MAX_SEQ_LEN);
    }

    #[test]
    fn test_model_info_lookup() {
        let info = get_ner_model_info(NER_MODEL_NAME).expect("default model is registered");
        assert_eq!(info.model_url, NER_MODEL_URL);
        assert_eq!(info.tokenizer_url, NER_TOKENIZER_URL);
        assert!(get_ner_model_info("no-such-model").is_none());
    }

    #[test]
    fn test_model_paths() {
        let models_dir = PathBuf::from("/tmp/models");

        assert_eq!(
            ner_model_path(&models_dir),
            models_dir.join(NER_MODEL_NAME).join("model.onnx")
        );
        assert_eq!(
            ner_tokenizer_path(&models_dir),
            models_dir.join(NER_MODEL_NAME).join("tokenizer.json")
        );
    }

    #[test]
    fn test_missing_models_dir_is_not_installed() {
        let models_dir = std::env::temp_dir().join("memvid-ner-test-no-such-models-dir");
        assert!(!is_ner_model_installed(&models_dir));
    }

    #[test]
    fn test_ner_labels() {
        assert_eq!(
            NER_LABELS,
            &["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
        );
    }
}