1use crate::types::{EntityKind, FrameId};
18use crate::{MemvidError, Result};
19use std::path::{Path, PathBuf};
20
/// Name of the bundled NER model; also the sub-directory it is stored under
/// inside the models directory.
pub const NER_MODEL_NAME: &str = "distilbert-ner";

/// URL of the ONNX model file.
pub const NER_MODEL_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/onnx/model.onnx";

/// URL of the tokenizer definition that belongs to [`NER_MODEL_URL`].
pub const NER_TOKENIZER_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/tokenizer.json";

/// Model size in megabytes.
// NOTE(review): presumably the approximate download size — confirm.
pub const NER_MODEL_SIZE_MB: f32 = 261.0;

/// Maximum number of tokens fed to the model; the tokenizer truncates longer inputs.
pub const NER_MAX_SEQ_LEN: usize = 512;

/// Confidence threshold applied when the caller does not provide one.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_MIN_CONFIDENCE: f32 = 0.5;

/// Label names indexed by predicted class id (BIO tagging scheme).
// NOTE(review): the order must match the model's `id2label` mapping — confirm
// against the published model configuration.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_LABELS: &[&str] = &[
    "O",
    "B-PER",
    "I-PER",
    "B-ORG",
    "I-ORG",
    "B-LOC",
    "I-LOC",
    "B-MISC",
    "I-MISC",
];
52
/// Static description of one NER model known to the registry ([`NER_MODELS`]).
#[derive(Clone, Debug)]
pub struct NerModelInfo {
    /// Registry name used for lookups.
    pub name: &'static str,
    /// Where the ONNX model file can be fetched from.
    pub model_url: &'static str,
    /// Where the tokenizer definition can be fetched from.
    pub tokenizer_url: &'static str,
    /// Model size in megabytes.
    pub size_mb: f32,
    /// Maximum input length in tokens.
    pub max_seq_len: usize,
    /// Whether this entry is the one returned as the default model.
    pub is_default: bool,
}
73
74pub static NER_MODELS: &[NerModelInfo] = &[NerModelInfo {
76 name: NER_MODEL_NAME,
77 model_url: NER_MODEL_URL,
78 tokenizer_url: NER_TOKENIZER_URL,
79 size_mb: NER_MODEL_SIZE_MB,
80 max_seq_len: NER_MAX_SEQ_LEN,
81 is_default: true,
82}];
83
84pub fn get_ner_model_info(name: &str) -> Option<&'static NerModelInfo> {
86 NER_MODELS.iter().find(|m| m.name == name)
87}
88
89pub fn default_ner_model_info() -> &'static NerModelInfo {
91 NER_MODELS
92 .iter()
93 .find(|m| m.is_default)
94 .expect("default NER model must exist")
95}
96
/// A single entity mention found in a piece of text.
#[derive(Clone, Debug)]
pub struct ExtractedEntity {
    /// The mention as it appears in the source text.
    pub text: String,
    /// Label of the mention (for example `PER`, `ORG`, `LOC` or `MISC`).
    pub entity_type: String,
    /// Confidence attached to the mention.
    pub confidence: f32,
    /// Byte offset in the source text where the mention starts.
    pub byte_start: usize,
    /// Byte offset in the source text where the mention ends.
    pub byte_end: usize,
}
115
116impl ExtractedEntity {
117 pub fn to_entity_kind(&self) -> EntityKind {
119 match self.entity_type.to_uppercase().as_str() {
120 "PER" | "PERSON" | "B-PER" | "I-PER" => EntityKind::Person,
121 "ORG" | "ORGANIZATION" | "B-ORG" | "I-ORG" => EntityKind::Organization,
122 "LOC" | "LOCATION" | "B-LOC" | "I-LOC" => EntityKind::Location,
123 "MISC" | "B-MISC" | "I-MISC" => EntityKind::Other,
124 _ => EntityKind::Other,
125 }
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct FrameEntities {
132 pub frame_id: FrameId,
134 pub entities: Vec<ExtractedEntity>,
136}
137
138#[cfg(feature = "logic_mesh")]
143pub use model_impl::*;
144
145#[cfg(feature = "logic_mesh")]
146mod model_impl {
147 use super::*;
148 use ort::session::{builder::GraphOptimizationLevel, Session};
149 use ort::value::Tensor;
150 use std::sync::Mutex;
151 use tokenizers::{
152 PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
153 TruncationParams, TruncationStrategy,
154 };
155
156 pub struct NerModel {
158 session: Session,
160 tokenizer: Mutex<Tokenizer>,
162 model_path: PathBuf,
164 min_confidence: f32,
166 }
167
168 impl NerModel {
169 pub fn load(
176 model_path: impl AsRef<Path>,
177 tokenizer_path: impl AsRef<Path>,
178 min_confidence: Option<f32>,
179 ) -> Result<Self> {
180 let model_path = model_path.as_ref().to_path_buf();
181 let tokenizer_path = tokenizer_path.as_ref();
182
183 let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
185 MemvidError::NerModelNotAvailable {
186 reason: format!("failed to load tokenizer from {:?}: {}", tokenizer_path, e)
187 .into(),
188 }
189 })?;
190
191 tokenizer.with_padding(Some(PaddingParams {
193 strategy: PaddingStrategy::BatchLongest,
194 direction: PaddingDirection::Right,
195 pad_to_multiple_of: None,
196 pad_id: 0,
197 pad_type_id: 0,
198 pad_token: "[PAD]".to_string(),
199 }));
200
201 tokenizer
202 .with_truncation(Some(TruncationParams {
203 max_length: NER_MAX_SEQ_LEN,
204 strategy: TruncationStrategy::LongestFirst,
205 stride: 0,
206 direction: TruncationDirection::Right,
207 }))
208 .map_err(|e| MemvidError::NerModelNotAvailable {
209 reason: format!("failed to set truncation: {}", e).into(),
210 })?;
211
212 let session = Session::builder()
214 .map_err(|e| MemvidError::NerModelNotAvailable {
215 reason: format!("failed to create session builder: {}", e).into(),
216 })?
217 .with_optimization_level(GraphOptimizationLevel::Level3)
218 .map_err(|e| MemvidError::NerModelNotAvailable {
219 reason: format!("failed to set optimization level: {}", e).into(),
220 })?
221 .with_intra_threads(4)
222 .map_err(|e| MemvidError::NerModelNotAvailable {
223 reason: format!("failed to set threads: {}", e).into(),
224 })?
225 .commit_from_file(&model_path)
226 .map_err(|e| MemvidError::NerModelNotAvailable {
227 reason: format!("failed to load model from {:?}: {}", model_path, e).into(),
228 })?;
229
230 tracing::info!(
231 model = %model_path.display(),
232 "DistilBERT-NER model loaded"
233 );
234
235 Ok(Self {
236 session,
237 tokenizer: Mutex::new(tokenizer),
238 model_path,
239 min_confidence: min_confidence.unwrap_or(NER_MIN_CONFIDENCE),
240 })
241 }
242
243 pub fn extract(&mut self, text: &str) -> Result<Vec<ExtractedEntity>> {
245 if text.trim().is_empty() {
246 return Ok(Vec::new());
247 }
248
249 let tokenizer = self.tokenizer.lock().map_err(|_| MemvidError::Lock(
251 "failed to lock tokenizer".into(),
252 ))?;
253
254 let encoding = tokenizer.encode(text, true).map_err(|e| {
255 MemvidError::NerModelNotAvailable {
256 reason: format!("tokenization failed: {}", e).into(),
257 }
258 })?;
259
260 let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
261 let attention_mask: Vec<i64> =
262 encoding.get_attention_mask().iter().map(|&x| x as i64).collect();
263 let tokens = encoding.get_tokens().to_vec();
264 let offsets = encoding.get_offsets().to_vec();
265
266 drop(tokenizer); let seq_len = input_ids.len();
269
270 let input_ids_array =
272 ndarray::Array2::from_shape_vec((1, seq_len), input_ids).map_err(|e| {
273 MemvidError::NerModelNotAvailable {
274 reason: format!("failed to create input_ids array: {}", e).into(),
275 }
276 })?;
277
278 let attention_mask_array =
279 ndarray::Array2::from_shape_vec((1, seq_len), attention_mask).map_err(|e| {
280 MemvidError::NerModelNotAvailable {
281 reason: format!("failed to create attention_mask array: {}", e).into(),
282 }
283 })?;
284
285 let input_ids_tensor =
286 Tensor::from_array(input_ids_array).map_err(|e| MemvidError::NerModelNotAvailable {
287 reason: format!("failed to create input_ids tensor: {}", e).into(),
288 })?;
289
290 let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
291 MemvidError::NerModelNotAvailable {
292 reason: format!("failed to create attention_mask tensor: {}", e).into(),
293 }
294 })?;
295
296 let output_name = self
298 .session
299 .outputs
300 .first()
301 .map(|o| o.name.clone())
302 .unwrap_or_else(|| "logits".into());
303
304 let outputs = self
306 .session
307 .run(ort::inputs![
308 "input_ids" => input_ids_tensor,
309 "attention_mask" => attention_mask_tensor,
310 ])
311 .map_err(|e| MemvidError::NerModelNotAvailable {
312 reason: format!("inference failed: {}", e).into(),
313 })?;
314
315 let logits = outputs
316 .get(&output_name)
317 .ok_or_else(|| MemvidError::NerModelNotAvailable {
318 reason: format!("no output '{}' found", output_name).into(),
319 })?;
320
321 let entities = Self::decode_predictions_static(
323 text,
324 &tokens,
325 &offsets,
326 logits,
327 self.min_confidence,
328 )?;
329
330 Ok(entities)
331 }
332
333 fn decode_predictions_static(
335 original_text: &str,
336 tokens: &[String],
337 offsets: &[(usize, usize)],
338 logits: &ort::value::Value,
339 min_confidence: f32,
340 ) -> Result<Vec<ExtractedEntity>> {
341 let (shape, data) = logits
343 .try_extract_tensor::<f32>()
344 .map_err(|e| MemvidError::NerModelNotAvailable {
345 reason: format!("failed to extract logits: {}", e).into(),
346 })?;
347
348 let shape_vec: Vec<i64> = shape.iter().copied().collect();
350
351 if shape_vec.len() != 3 {
352 return Err(MemvidError::NerModelNotAvailable {
353 reason: format!("unexpected logits shape: {:?}", shape_vec).into(),
354 });
355 }
356
357 let seq_len = shape_vec[1] as usize;
358 let num_labels = shape_vec[2] as usize;
359
360 let idx = |i: usize, j: usize| -> usize { i * num_labels + j };
362
363 let mut entities = Vec::new();
364 let mut current_entity: Option<(String, usize, usize, f32)> = None;
365
366 for i in 0..seq_len {
367 if i >= tokens.len() || i >= offsets.len() {
368 break;
369 }
370
371 let token = &tokens[i];
373 if token == "[CLS]" || token == "[SEP]" || token == "[PAD]" {
374 if let Some((entity_type, start, end, conf)) = current_entity.take() {
376 if end > start && end <= original_text.len() {
377 let text = original_text[start..end].trim().to_string();
378 if !text.is_empty() {
379 entities.push(ExtractedEntity {
380 text,
381 entity_type,
382 confidence: conf,
383 byte_start: start,
384 byte_end: end,
385 });
386 }
387 }
388 }
389 continue;
390 }
391
392 let mut max_score = f32::NEG_INFINITY;
394 let mut max_label = 0usize;
395
396 for j in 0..num_labels {
397 let score = data[idx(i, j)];
398 if score > max_score {
399 max_score = score;
400 max_label = j;
401 }
402 }
403
404 let mut exp_sum = 0.0f32;
406 for j in 0..num_labels {
407 exp_sum += (data[idx(i, j)] - max_score).exp();
408 }
409 let confidence = 1.0 / exp_sum;
410
411 let label = NER_LABELS.get(max_label).unwrap_or(&"O");
412 let (start_offset, end_offset) = offsets[i];
413
414 if *label == "O" || confidence < min_confidence {
415 if let Some((entity_type, start, end, conf)) = current_entity.take() {
417 if end > start && end <= original_text.len() {
418 let text = original_text[start..end].trim().to_string();
419 if !text.is_empty() {
420 entities.push(ExtractedEntity {
421 text,
422 entity_type,
423 confidence: conf,
424 byte_start: start,
425 byte_end: end,
426 });
427 }
428 }
429 }
430 } else if label.starts_with("B-") {
431 if let Some((entity_type, start, end, conf)) = current_entity.take() {
433 if end > start && end <= original_text.len() {
434 let text = original_text[start..end].trim().to_string();
435 if !text.is_empty() {
436 entities.push(ExtractedEntity {
437 text,
438 entity_type,
439 confidence: conf,
440 byte_start: start,
441 byte_end: end,
442 });
443 }
444 }
445 }
446 let entity_type = label[2..].to_string(); current_entity = Some((entity_type, start_offset, end_offset, confidence));
448 } else if label.starts_with("I-") {
449 if let Some((ref entity_type, start, _, ref mut conf)) = current_entity {
451 let expected_type = &label[2..];
452 if entity_type == expected_type {
453 current_entity = Some((
454 entity_type.clone(),
455 start,
456 end_offset,
457 (*conf + confidence) / 2.0,
458 ));
459 }
460 }
461 }
462 }
463
464 if let Some((entity_type, start, end, conf)) = current_entity {
466 if end > start && end <= original_text.len() {
467 let text = original_text[start..end].trim().to_string();
468 if !text.is_empty() {
469 entities.push(ExtractedEntity {
470 text,
471 entity_type,
472 confidence: conf,
473 byte_start: start,
474 byte_end: end,
475 });
476 }
477 }
478 }
479
480 Ok(entities)
481 }
482
483 pub fn extract_from_frame(
485 &mut self,
486 frame_id: FrameId,
487 content: &str,
488 ) -> Result<FrameEntities> {
489 let entities = self.extract(content)?;
490 Ok(FrameEntities { frame_id, entities })
491 }
492
493 pub fn min_confidence(&self) -> f32 {
495 self.min_confidence
496 }
497
498 pub fn model_path(&self) -> &Path {
500 &self.model_path
501 }
502 }
503
504 impl std::fmt::Debug for NerModel {
505 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
506 f.debug_struct("NerModel")
507 .field("model_path", &self.model_path)
508 .field("min_confidence", &self.min_confidence)
509 .finish()
510 }
511 }
512}
513
/// Placeholder for the NER model when the `logic_mesh` feature is disabled.
///
/// The private unit field prevents construction outside this module; every
/// operation on the stub reports that the feature is unavailable.
#[cfg(not(feature = "logic_mesh"))]
#[allow(dead_code)]
pub struct NerModel {
    _private: (),
}
523
524#[cfg(not(feature = "logic_mesh"))]
525#[allow(dead_code)]
526impl NerModel {
527 pub fn load(
528 _model_path: impl AsRef<Path>,
529 _tokenizer_path: impl AsRef<Path>,
530 _min_confidence: Option<f32>,
531 ) -> Result<Self> {
532 Err(MemvidError::FeatureUnavailable {
533 feature: "logic_mesh",
534 })
535 }
536
537 pub fn extract(&self, _text: &str) -> Result<Vec<ExtractedEntity>> {
538 Err(MemvidError::FeatureUnavailable {
539 feature: "logic_mesh",
540 })
541 }
542
543 pub fn extract_from_frame(&self, _frame_id: FrameId, _content: &str) -> Result<FrameEntities> {
544 Err(MemvidError::FeatureUnavailable {
545 feature: "logic_mesh",
546 })
547 }
548}
549
550pub fn ner_model_path(models_dir: &Path) -> PathBuf {
556 models_dir.join(NER_MODEL_NAME).join("model.onnx")
557}
558
559pub fn ner_tokenizer_path(models_dir: &Path) -> PathBuf {
561 models_dir.join(NER_MODEL_NAME).join("tokenizer.json")
562}
563
564pub fn is_ner_model_installed(models_dir: &Path) -> bool {
566 ner_model_path(models_dir).exists() && ner_tokenizer_path(models_dir).exists()
567}
568
/// Unit tests for the parts of this module that do not need the ONNX model.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a throw-away entity that carries the given label.
    fn entity_with_label(label: &str) -> ExtractedEntity {
        ExtractedEntity {
            text: "test".to_string(),
            entity_type: label.to_string(),
            confidence: 0.9,
            byte_start: 0,
            byte_end: 4,
        }
    }

    #[test]
    fn test_entity_kind_mapping() {
        let cases = vec![
            ("PER", EntityKind::Person),
            ("PERSON", EntityKind::Person),
            ("B-PER", EntityKind::Person),
            ("I-PER", EntityKind::Person),
            ("ORG", EntityKind::Organization),
            ("ORGANIZATION", EntityKind::Organization),
            ("B-ORG", EntityKind::Organization),
            ("I-ORG", EntityKind::Organization),
            ("LOC", EntityKind::Location),
            ("LOCATION", EntityKind::Location),
            ("B-LOC", EntityKind::Location),
            ("I-LOC", EntityKind::Location),
            ("MISC", EntityKind::Other),
            ("B-MISC", EntityKind::Other),
            ("I-MISC", EntityKind::Other),
            ("unknown", EntityKind::Other),
            ("", EntityKind::Other),
        ];

        for (entity_type, expected_kind) in cases {
            assert_eq!(
                entity_with_label(entity_type).to_entity_kind(),
                expected_kind,
                "Failed for entity_type: {}",
                entity_type
            );
        }
    }

    #[test]
    fn test_entity_kind_mapping_ignores_case() {
        assert_eq!(entity_with_label("per").to_entity_kind(), EntityKind::Person);
        assert_eq!(
            entity_with_label("b-org").to_entity_kind(),
            EntityKind::Organization
        );
        assert_eq!(
            entity_with_label("Location").to_entity_kind(),
            EntityKind::Location
        );
    }

    #[test]
    fn test_model_info() {
        let info = default_ner_model_info();
        assert_eq!(info.name, NER_MODEL_NAME);
        assert!(info.is_default);
        assert!(info.size_mb > 200.0);
        assert_eq!(info.max_seq_len, NER_MAX_SEQ_LEN);
    }

    #[test]
    fn test_model_lookup_by_name() {
        let info = get_ner_model_info(NER_MODEL_NAME).expect("bundled model must be registered");
        assert_eq!(info.model_url, NER_MODEL_URL);
        assert_eq!(info.tokenizer_url, NER_TOKENIZER_URL);
        assert!(get_ner_model_info("no-such-model").is_none());
    }

    #[test]
    fn test_model_paths() {
        let models_dir = PathBuf::from("/tmp/models");
        let model_dir = models_dir.join(NER_MODEL_NAME);

        // Compare whole paths rather than substrings so a wrong directory
        // layout cannot slip through.
        assert_eq!(ner_model_path(&models_dir), model_dir.join("model.onnx"));
        assert_eq!(
            ner_tokenizer_path(&models_dir),
            model_dir.join("tokenizer.json")
        );
    }

    #[test]
    fn test_ner_labels() {
        assert_eq!(NER_LABELS.len(), 9);
        assert_eq!(NER_LABELS[0], "O");
        assert_eq!(NER_LABELS[1], "B-PER");
        assert_eq!(NER_LABELS[3], "B-ORG");
        assert_eq!(NER_LABELS[5], "B-LOC");
        assert_eq!(NER_LABELS[7], "B-MISC");

        // After "O", labels come in (B-X, I-X) pairs of the same entity type.
        for pair in NER_LABELS[1..].chunks(2) {
            assert_eq!(pair.len(), 2);
            assert!(pair[0].starts_with("B-"), "unexpected label {}", pair[0]);
            assert_eq!(pair[1], format!("I-{}", &pair[0][2..]));
        }
    }
}