#![allow(clippy::expect_used)]
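//! Named-entity recognition (NER) support built on a DistilBERT ONNX model.
//!
//! This module defines the model metadata and download constants, the
//! [`ExtractedEntity`] and [`FrameEntities`] result types, path helpers for the
//! local model cache, and the [`NerModel`] wrapper. The ONNX-backed
//! implementation is only compiled with the `logic_mesh` feature; without it,
//! [`NerModel`] is a stub whose methods return `MemvidError::FeatureUnavailable`.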
use crate::types::{EntityKind, FrameId};
use crate::{MemvidError, Result};
use std::path::{Path, PathBuf};

/// Name of the default NER model; also used as its directory name under the models dir.
pub const NER_MODEL_NAME: &str = "distilbert-ner";

/// Download URL for the ONNX model weights.
pub const NER_MODEL_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/onnx/model.onnx";

/// Download URL for the matching tokenizer definition.
pub const NER_TOKENIZER_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/tokenizer.json";

/// Approximate on-disk size of the model, in megabytes.
pub const NER_MODEL_SIZE_MB: f32 = 261.0;

/// Maximum sequence length accepted by the model.
pub const NER_MAX_SEQ_LEN: usize = 512;

/// Minimum softmax confidence for a token prediction to be kept.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_MIN_CONFIDENCE: f32 = 0.5;

/// BIO label set emitted by the DistilBERT-NER model, in logit order.
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_LABELS: &[&str] = &[
    "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
];

/// Metadata describing an available NER model.
#[derive(Debug, Clone)]
pub struct NerModelInfo {
    /// Model name (also its cache directory name).
    pub name: &'static str,
    /// Download URL for the ONNX weights.
    pub model_url: &'static str,
    /// Download URL for the tokenizer definition.
    pub tokenizer_url: &'static str,
    /// Approximate download size in megabytes.
    pub size_mb: f32,
    /// Maximum input sequence length in tokens.
    pub max_seq_len: usize,
    /// Whether this model is the default choice.
    pub is_default: bool,
}

/// Registry of known NER models.
pub static NER_MODELS: &[NerModelInfo] = &[NerModelInfo {
    name: NER_MODEL_NAME,
    model_url: NER_MODEL_URL,
    tokenizer_url: NER_TOKENIZER_URL,
    size_mb: NER_MODEL_SIZE_MB,
    max_seq_len: NER_MAX_SEQ_LEN,
    is_default: true,
}];

/// Looks up a model by name in [`NER_MODELS`].
#[must_use]
pub fn get_ner_model_info(name: &str) -> Option<&'static NerModelInfo> {
    NER_MODELS.iter().find(|m| m.name == name)
}

/// Returns the default NER model entry.
///
/// # Panics
///
/// Panics if no entry in [`NER_MODELS`] is marked as default.
#[must_use]
pub fn default_ner_model_info() -> &'static NerModelInfo {
    NER_MODELS
        .iter()
        .find(|m| m.is_default)
        .expect("default NER model must exist")
}
100
101#[derive(Debug, Clone)]
107pub struct ExtractedEntity {
108 pub text: String,
110 pub entity_type: String,
112 pub confidence: f32,
114 pub byte_start: usize,
116 pub byte_end: usize,
118}
119
120impl ExtractedEntity {
121 #[must_use]
123 pub fn to_entity_kind(&self) -> EntityKind {
124 match self.entity_type.to_uppercase().as_str() {
125 "PER" | "PERSON" | "B-PER" | "I-PER" => EntityKind::Person,
126 "ORG" | "ORGANIZATION" | "B-ORG" | "I-ORG" => EntityKind::Organization,
127 "LOC" | "LOCATION" | "B-LOC" | "I-LOC" => EntityKind::Location,
128 "MISC" | "B-MISC" | "I-MISC" => EntityKind::Other,
129 _ => EntityKind::Other,
130 }
131 }
132}
133
134#[derive(Debug, Clone)]
136pub struct FrameEntities {
137 pub frame_id: FrameId,
139 pub entities: Vec<ExtractedEntity>,
141}
142
#[cfg(feature = "logic_mesh")]
pub use model_impl::*;

#[cfg(feature = "logic_mesh")]
mod model_impl {
    use super::*;
    use ort::session::{Session, builder::GraphOptimizationLevel};
    use ort::value::Tensor;
    use std::sync::Mutex;
    use tokenizers::{
        PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
        TruncationParams, TruncationStrategy,
    };

    /// ONNX-backed DistilBERT-NER model.
    pub struct NerModel {
        /// ONNX Runtime inference session.
        session: Session,
        /// Tokenizer, protected by a mutex.
        tokenizer: Mutex<Tokenizer>,
        /// Path the ONNX model was loaded from.
        model_path: PathBuf,
        /// Minimum confidence for keeping a predicted entity.
        min_confidence: f32,
    }

    impl NerModel {
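        /// Loads the DistilBERT-NER ONNX model and tokenizer from disk.
        ///
        /// When `min_confidence` is `None`, [`NER_MIN_CONFIDENCE`] is used.
        ///
        /// A minimal usage sketch, assuming the model files were already
        /// downloaded; the paths below are illustrative only:
        ///
        /// ```ignore
        /// let mut model = NerModel::load(
        ///     "/path/to/models/distilbert-ner/model.onnx",
        ///     "/path/to/models/distilbert-ner/tokenizer.json",
        ///     None,
        /// )?;
        /// let entities = model.extract("Ada Lovelace worked in London.")?;
        /// for e in &entities {
        ///     println!("{} ({}, {:.2})", e.text, e.entity_type, e.confidence);
        /// }
        /// ```
        ///
        /// # Errors
        ///
        /// Returns [`MemvidError::NerModelNotAvailable`] if the tokenizer or
        /// the ONNX session cannot be created.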
        pub fn load(
            model_path: impl AsRef<Path>,
            tokenizer_path: impl AsRef<Path>,
            min_confidence: Option<f32>,
        ) -> Result<Self> {
            let model_path = model_path.as_ref().to_path_buf();
            let tokenizer_path = tokenizer_path.as_ref();

            let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to load tokenizer from {:?}: {}", tokenizer_path, e)
                        .into(),
                }
            })?;

            tokenizer.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                direction: PaddingDirection::Right,
                pad_to_multiple_of: None,
                pad_id: 0,
                pad_type_id: 0,
                pad_token: "[PAD]".to_string(),
            }));

            tokenizer
                .with_truncation(Some(TruncationParams {
                    max_length: NER_MAX_SEQ_LEN,
                    strategy: TruncationStrategy::LongestFirst,
                    stride: 0,
                    direction: TruncationDirection::Right,
                }))
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set truncation: {}", e).into(),
                })?;

            let session = Session::builder()
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create session builder: {}", e).into(),
                })?
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set optimization level: {}", e).into(),
                })?
                .with_intra_threads(4)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set threads: {}", e).into(),
                })?
                .commit_from_file(&model_path)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to load model from {:?}: {}", model_path, e).into(),
                })?;

            tracing::info!(
                model = %model_path.display(),
                "DistilBERT-NER model loaded"
            );

            Ok(Self {
                session,
                tokenizer: Mutex::new(tokenizer),
                model_path,
                min_confidence: min_confidence.unwrap_or(NER_MIN_CONFIDENCE),
            })
        }

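        /// Runs NER inference over `text` and returns the extracted entities.
        ///
        /// Blank input yields an empty vector. Only predictions at or above
        /// the configured minimum confidence are merged into entities.
        ///
        /// # Errors
        ///
        /// Returns [`MemvidError::NerModelNotAvailable`] if tokenization or
        /// inference fails, and [`MemvidError::Lock`] if the tokenizer mutex
        /// is poisoned.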
        pub fn extract(&mut self, text: &str) -> Result<Vec<ExtractedEntity>> {
            if text.trim().is_empty() {
                return Ok(Vec::new());
            }

            let tokenizer = self
                .tokenizer
                .lock()
                .map_err(|_| MemvidError::Lock("failed to lock tokenizer".into()))?;

            let encoding =
                tokenizer
                    .encode(text, true)
                    .map_err(|e| MemvidError::NerModelNotAvailable {
                        reason: format!("tokenization failed: {}", e).into(),
                    })?;

            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|&x| x as i64)
                .collect();
            let tokens = encoding.get_tokens().to_vec();
            let offsets = encoding.get_offsets().to_vec();

            drop(tokenizer);

            let seq_len = input_ids.len();

            let input_ids_array = ndarray::Array2::from_shape_vec((1, seq_len), input_ids)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create input_ids array: {}", e).into(),
                })?;

            let attention_mask_array =
                ndarray::Array2::from_shape_vec((1, seq_len), attention_mask).map_err(|e| {
                    MemvidError::NerModelNotAvailable {
                        reason: format!("failed to create attention_mask array: {}", e).into(),
                    }
                })?;

            let input_ids_tensor = Tensor::from_array(input_ids_array).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create input_ids tensor: {}", e).into(),
                }
            })?;

            let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create attention_mask tensor: {}", e).into(),
                }
            })?;

            let output_name = self
                .session
                .outputs
                .first()
                .map(|o| o.name.clone())
                .unwrap_or_else(|| "logits".into());

            let outputs = self
                .session
                .run(ort::inputs![
                    "input_ids" => input_ids_tensor,
                    "attention_mask" => attention_mask_tensor,
                ])
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("inference failed: {}", e).into(),
                })?;

            let logits =
                outputs
                    .get(&output_name)
                    .ok_or_else(|| MemvidError::NerModelNotAvailable {
                        reason: format!("no output '{}' found", output_name).into(),
                    })?;

            let entities = Self::decode_predictions_static(
                text,
                &tokens,
                &offsets,
                logits,
                self.min_confidence,
            )?;

            Ok(entities)
        }

        /// Greedily decodes per-token BIO predictions into entity spans.
        fn decode_predictions_static(
            original_text: &str,
            tokens: &[String],
            offsets: &[(usize, usize)],
            logits: &ort::value::Value,
            min_confidence: f32,
        ) -> Result<Vec<ExtractedEntity>> {
            let (shape, data) = logits.try_extract_tensor::<f32>().map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to extract logits: {}", e).into(),
                }
            })?;

            let shape_vec: Vec<i64> = shape.iter().copied().collect();

            if shape_vec.len() != 3 {
                return Err(MemvidError::NerModelNotAvailable {
                    reason: format!("unexpected logits shape: {:?}", shape_vec).into(),
                });
            }

            let seq_len = shape_vec[1] as usize;
            let num_labels = shape_vec[2] as usize;

            // Logits are laid out as [batch = 1, seq_len, num_labels].
            let idx = |i: usize, j: usize| -> usize { i * num_labels + j };

            let mut entities = Vec::new();
            let mut current_entity: Option<(String, usize, usize, f32)> = None;

            for i in 0..seq_len {
                if i >= tokens.len() || i >= offsets.len() {
                    break;
                }

                let token = &tokens[i];
                if token == "[CLS]" || token == "[SEP]" || token == "[PAD]" {
                    // Special tokens end any entity currently being built.
                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
                        if end > start && end <= original_text.len() {
                            let text = original_text[start..end].trim().to_string();
                            if !text.is_empty() {
                                entities.push(ExtractedEntity {
                                    text,
                                    entity_type,
                                    confidence: conf,
                                    byte_start: start,
                                    byte_end: end,
                                });
                            }
                        }
                    }
                    continue;
                }

                let mut max_score = f32::NEG_INFINITY;
                let mut max_label = 0usize;

                for j in 0..num_labels {
                    let score = data[idx(i, j)];
                    if score > max_score {
                        max_score = score;
                        max_label = j;
                    }
                }

                // Softmax probability of the winning label.
                let mut exp_sum = 0.0f32;
                for j in 0..num_labels {
                    exp_sum += (data[idx(i, j)] - max_score).exp();
                }
                let confidence = 1.0 / exp_sum;

                let label = NER_LABELS.get(max_label).unwrap_or(&"O");
                let (start_offset, end_offset) = offsets[i];

                if *label == "O" || confidence < min_confidence {
                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
                        if end > start && end <= original_text.len() {
                            let text = original_text[start..end].trim().to_string();
                            if !text.is_empty() {
                                entities.push(ExtractedEntity {
                                    text,
                                    entity_type,
                                    confidence: conf,
                                    byte_start: start,
                                    byte_end: end,
                                });
                            }
                        }
                    }
                } else if label.starts_with("B-") {
                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
                        if end > start && end <= original_text.len() {
                            let text = original_text[start..end].trim().to_string();
                            if !text.is_empty() {
                                entities.push(ExtractedEntity {
                                    text,
                                    entity_type,
                                    confidence: conf,
                                    byte_start: start,
                                    byte_end: end,
                                });
                            }
                        }
                    }
                    let entity_type = label[2..].to_string();
                    current_entity = Some((entity_type, start_offset, end_offset, confidence));
                } else if label.starts_with("I-") {
                    if let Some((ref entity_type, start, _, ref mut conf)) = current_entity {
                        let expected_type = &label[2..];
                        if entity_type == expected_type {
                            current_entity = Some((
                                entity_type.clone(),
                                start,
                                end_offset,
                                (*conf + confidence) / 2.0,
                            ));
                        }
                    }
                }
            }

            if let Some((entity_type, start, end, conf)) = current_entity {
                if end > start && end <= original_text.len() {
                    let text = original_text[start..end].trim().to_string();
                    if !text.is_empty() {
                        entities.push(ExtractedEntity {
                            text,
                            entity_type,
                            confidence: conf,
                            byte_start: start,
                            byte_end: end,
                        });
                    }
                }
            }

            Ok(entities)
        }

        /// Extracts entities from frame content and tags them with the frame id.
        pub fn extract_from_frame(
            &mut self,
            frame_id: FrameId,
            content: &str,
        ) -> Result<FrameEntities> {
            let entities = self.extract(content)?;
            Ok(FrameEntities { frame_id, entities })
        }

        /// Minimum confidence threshold used when filtering predictions.
        pub fn min_confidence(&self) -> f32 {
            self.min_confidence
        }

        /// Path of the loaded ONNX model.
        pub fn model_path(&self) -> &Path {
            &self.model_path
        }
    }

    impl std::fmt::Debug for NerModel {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("NerModel")
                .field("model_path", &self.model_path)
                .field("min_confidence", &self.min_confidence)
                .finish()
        }
    }
}

/// Stub `NerModel` used when the `logic_mesh` feature is disabled.
///
/// Every method returns `MemvidError::FeatureUnavailable`.
#[cfg(not(feature = "logic_mesh"))]
#[allow(dead_code)]
pub struct NerModel {
    _private: (),
}

#[cfg(not(feature = "logic_mesh"))]
#[allow(dead_code)]
impl NerModel {
    pub fn load(
        _model_path: impl AsRef<Path>,
        _tokenizer_path: impl AsRef<Path>,
        _min_confidence: Option<f32>,
    ) -> Result<Self> {
        Err(MemvidError::FeatureUnavailable {
            feature: "logic_mesh",
        })
    }

    pub fn extract(&self, _text: &str) -> Result<Vec<ExtractedEntity>> {
        Err(MemvidError::FeatureUnavailable {
            feature: "logic_mesh",
        })
    }

    pub fn extract_from_frame(&self, _frame_id: FrameId, _content: &str) -> Result<FrameEntities> {
        Err(MemvidError::FeatureUnavailable {
            feature: "logic_mesh",
        })
    }
}

/// Expected path of the ONNX model file under `models_dir`.
#[must_use]
pub fn ner_model_path(models_dir: &Path) -> PathBuf {
    models_dir.join(NER_MODEL_NAME).join("model.onnx")
}

/// Expected path of the tokenizer file under `models_dir`.
#[must_use]
pub fn ner_tokenizer_path(models_dir: &Path) -> PathBuf {
    models_dir.join(NER_MODEL_NAME).join("tokenizer.json")
}

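/// Returns `true` when both the ONNX model and the tokenizer exist under
/// `models_dir`.
///
/// A minimal usage sketch; the models directory below is illustrative only:
///
/// ```ignore
/// let models_dir = std::path::Path::new("/path/to/models");
/// if !is_ner_model_installed(models_dir) {
///     // Fetch NER_MODEL_URL into ner_model_path(models_dir) and
///     // NER_TOKENIZER_URL into ner_tokenizer_path(models_dir) first.
/// }
/// ```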
#[must_use]
pub fn is_ner_model_installed(models_dir: &Path) -> bool {
    ner_model_path(models_dir).exists() && ner_tokenizer_path(models_dir).exists()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entity_kind_mapping() {
        let cases = vec![
            ("PER", EntityKind::Person),
            ("B-PER", EntityKind::Person),
            ("I-PER", EntityKind::Person),
            ("ORG", EntityKind::Organization),
            ("B-ORG", EntityKind::Organization),
            ("LOC", EntityKind::Location),
            ("B-LOC", EntityKind::Location),
            ("MISC", EntityKind::Other),
            ("B-MISC", EntityKind::Other),
            ("unknown", EntityKind::Other),
        ];

        for (entity_type, expected_kind) in cases {
            let entity = ExtractedEntity {
                text: "test".to_string(),
                entity_type: entity_type.to_string(),
                confidence: 0.9,
                byte_start: 0,
                byte_end: 4,
            };
            assert_eq!(
                entity.to_entity_kind(),
                expected_kind,
                "Failed for entity_type: {}",
                entity_type
            );
        }
    }

    #[test]
    fn test_model_info() {
        let info = default_ner_model_info();
        assert_eq!(info.name, NER_MODEL_NAME);
        assert!(info.is_default);
        assert!(info.size_mb > 200.0);
    }
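
    #[test]
    fn test_get_ner_model_info_lookup() {
        // The registered name resolves; an unknown name does not.
        assert!(get_ner_model_info(NER_MODEL_NAME).is_some());
        assert!(get_ner_model_info("no-such-model").is_none());
    }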

    #[test]
    fn test_model_paths() {
        let models_dir = PathBuf::from("/tmp/models");
        let model_path = ner_model_path(&models_dir);
        let tokenizer_path = ner_tokenizer_path(&models_dir);

        assert!(model_path.to_string_lossy().contains("model.onnx"));
        assert!(tokenizer_path.to_string_lossy().contains("tokenizer.json"));
    }

    #[test]
    fn test_ner_labels() {
        assert_eq!(NER_LABELS.len(), 9);
        assert_eq!(NER_LABELS[0], "O");
        assert_eq!(NER_LABELS[1], "B-PER");
        assert_eq!(NER_LABELS[3], "B-ORG");
        assert_eq!(NER_LABELS[5], "B-LOC");
        assert_eq!(NER_LABELS[7], "B-MISC");
    }
}