use crate::types::{EntityKind, FrameId};
use crate::{MemvidError, Result};
use std::path::{Path, PathBuf};

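/// Canonical name of the bundled NER model; also used as its directory name
/// under the models directory.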
pub const NER_MODEL_NAME: &str = "distilbert-ner";

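/// Download URL for the ONNX export of `dslim/distilbert-NER`.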
pub const NER_MODEL_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/onnx/model.onnx";

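/// Download URL for the matching `tokenizer.json`.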
pub const NER_TOKENIZER_URL: &str =
    "https://huggingface.co/dslim/distilbert-NER/resolve/main/tokenizer.json";

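/// Approximate download size of the model weights, in megabytes.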
pub const NER_MODEL_SIZE_MB: f32 = 261.0;

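/// Maximum input sequence length (in tokens) accepted by the model.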
pub const NER_MAX_SEQ_LEN: usize = 512;

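/// Default minimum softmax confidence required to keep a predicted entity.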
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_MIN_CONFIDENCE: f32 = 0.5;

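/// BIO label set emitted by the model, indexed by output position.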
#[cfg_attr(not(feature = "logic_mesh"), allow(dead_code))]
pub const NER_LABELS: &[&str] = &[
    "O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC",
];

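/// Metadata describing an available NER model.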
#[derive(Debug, Clone)]
pub struct NerModelInfo {
    /// Model identifier, also used for the on-disk directory name.
    pub name: &'static str,
    /// Download URL for the ONNX model weights.
    pub model_url: &'static str,
    /// Download URL for the matching `tokenizer.json`.
    pub tokenizer_url: &'static str,
    /// Approximate download size in megabytes.
    pub size_mb: f32,
    /// Maximum input sequence length in tokens.
    pub max_seq_len: usize,
    /// Whether this entry is the default model.
    pub is_default: bool,
}

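/// Registry of NER models known to this crate.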
pub static NER_MODELS: &[NerModelInfo] = &[NerModelInfo {
    name: NER_MODEL_NAME,
    model_url: NER_MODEL_URL,
    tokenizer_url: NER_TOKENIZER_URL,
    size_mb: NER_MODEL_SIZE_MB,
    max_seq_len: NER_MAX_SEQ_LEN,
    is_default: true,
}];

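/// Look up a registered NER model by name.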
pub fn get_ner_model_info(name: &str) -> Option<&'static NerModelInfo> {
    NER_MODELS.iter().find(|m| m.name == name)
}

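/// Return the default entry from `NER_MODELS`.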
pub fn default_ner_model_info() -> &'static NerModelInfo {
    NER_MODELS
        .iter()
        .find(|m| m.is_default)
        .expect("default NER model must exist")
}

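/// A single entity span extracted from a piece of text.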
#[derive(Debug, Clone)]
pub struct ExtractedEntity {
    /// The entity text as it appears in the source.
    pub text: String,
    /// Entity type label (e.g. `PER`, `ORG`, `LOC`, `MISC`).
    pub entity_type: String,
    /// Model confidence for the span, in `0.0..=1.0`.
    pub confidence: f32,
    /// Byte offset of the span start in the original text.
    pub byte_start: usize,
    /// Byte offset of the span end (exclusive) in the original text.
    pub byte_end: usize,
}

impl ExtractedEntity {
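    /// Map the raw NER label (with or without its `B-`/`I-` prefix) onto the
    /// crate's `EntityKind`.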
    pub fn to_entity_kind(&self) -> EntityKind {
        match self.entity_type.to_uppercase().as_str() {
            "PER" | "PERSON" | "B-PER" | "I-PER" => EntityKind::Person,
            "ORG" | "ORGANIZATION" | "B-ORG" | "I-ORG" => EntityKind::Organization,
            "LOC" | "LOCATION" | "B-LOC" | "I-LOC" => EntityKind::Location,
            "MISC" | "B-MISC" | "I-MISC" => EntityKind::Other,
            _ => EntityKind::Other,
        }
    }
}

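/// All entities extracted from a single frame.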
#[derive(Debug, Clone)]
pub struct FrameEntities {
    /// Frame the entities were extracted from.
    pub frame_id: FrameId,
    /// Entities found in that frame's content.
    pub entities: Vec<ExtractedEntity>,
}

#[cfg(feature = "logic_mesh")]
pub use model_impl::*;

#[cfg(feature = "logic_mesh")]
mod model_impl {
    use super::*;
    use ort::session::{Session, builder::GraphOptimizationLevel};
    use ort::value::Tensor;
    use std::sync::Mutex;
    use tokenizers::{
        PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer, TruncationDirection,
        TruncationParams, TruncationStrategy,
    };

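    /// DistilBERT-based NER model backed by ONNX Runtime.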
    pub struct NerModel {
        /// ONNX Runtime session holding the loaded model.
        session: Session,
        /// HuggingFace tokenizer, behind a mutex for safe shared access.
        tokenizer: Mutex<Tokenizer>,
        /// Path the model was loaded from (kept for diagnostics).
        model_path: PathBuf,
        /// Minimum softmax confidence required to keep a prediction.
        min_confidence: f32,
    }

    impl NerModel {
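        /// Load the DistilBERT-NER ONNX model and its tokenizer from disk,
        /// configuring padding, truncation and the ONNX Runtime session.
        ///
        /// A minimal usage sketch; the paths are illustrative and assume the
        /// model files were already downloaded (for example to the locations
        /// returned by `ner_model_path` / `ner_tokenizer_path`):
        ///
        /// ```ignore
        /// let mut model = NerModel::load(
        ///     "models/distilbert-ner/model.onnx",
        ///     "models/distilbert-ner/tokenizer.json",
        ///     None, // fall back to NER_MIN_CONFIDENCE
        /// )?;
        /// let entities = model.extract("Barack Obama visited Paris.")?;
        /// ```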
        pub fn load(
            model_path: impl AsRef<Path>,
            tokenizer_path: impl AsRef<Path>,
            min_confidence: Option<f32>,
        ) -> Result<Self> {
            let model_path = model_path.as_ref().to_path_buf();
            let tokenizer_path = tokenizer_path.as_ref();

            // Load the HuggingFace tokenizer definition.
            let mut tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to load tokenizer from {:?}: {}", tokenizer_path, e)
                        .into(),
                }
            })?;

            // Pad batches to the longest sequence using the [PAD] token.
            tokenizer.with_padding(Some(PaddingParams {
                strategy: PaddingStrategy::BatchLongest,
                direction: PaddingDirection::Right,
                pad_to_multiple_of: None,
                pad_id: 0,
                pad_type_id: 0,
                pad_token: "[PAD]".to_string(),
            }));

            // Truncate inputs that exceed the model's maximum sequence length.
            tokenizer
                .with_truncation(Some(TruncationParams {
                    max_length: NER_MAX_SEQ_LEN,
                    strategy: TruncationStrategy::LongestFirst,
                    stride: 0,
                    direction: TruncationDirection::Right,
                }))
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set truncation: {}", e).into(),
                })?;

            // Build the ONNX Runtime session.
            let session = Session::builder()
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create session builder: {}", e).into(),
                })?
                .with_optimization_level(GraphOptimizationLevel::Level3)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set optimization level: {}", e).into(),
                })?
                .with_intra_threads(4)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to set threads: {}", e).into(),
                })?
                .commit_from_file(&model_path)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to load model from {:?}: {}", model_path, e).into(),
                })?;

            tracing::info!(
                model = %model_path.display(),
                "DistilBERT-NER model loaded"
            );

            Ok(Self {
                session,
                tokenizer: Mutex::new(tokenizer),
                model_path,
                min_confidence: min_confidence.unwrap_or(NER_MIN_CONFIDENCE),
            })
        }

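        /// Run NER over `text` and return the extracted entity spans.
        ///
        /// Returns an empty list for blank input. Inputs longer than
        /// `NER_MAX_SEQ_LEN` tokens are truncated by the tokenizer.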
        pub fn extract(&mut self, text: &str) -> Result<Vec<ExtractedEntity>> {
            if text.trim().is_empty() {
                return Ok(Vec::new());
            }

            // Tokenize the input and pull out ids, attention mask, tokens and
            // byte offsets while holding the tokenizer lock.
            let tokenizer = self
                .tokenizer
                .lock()
                .map_err(|_| MemvidError::Lock("failed to lock tokenizer".into()))?;

            let encoding = tokenizer
                .encode(text, true)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("tokenization failed: {}", e).into(),
                })?;

            let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
            let attention_mask: Vec<i64> = encoding
                .get_attention_mask()
                .iter()
                .map(|&x| x as i64)
                .collect();
            let tokens = encoding.get_tokens().to_vec();
            let offsets = encoding.get_offsets().to_vec();

            // Release the tokenizer lock before running inference.
            drop(tokenizer);

            let seq_len = input_ids.len();

            // Shape the inputs as (batch = 1, seq_len) tensors.
            let input_ids_array = ndarray::Array2::from_shape_vec((1, seq_len), input_ids)
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create input_ids array: {}", e).into(),
                })?;

            let attention_mask_array =
                ndarray::Array2::from_shape_vec((1, seq_len), attention_mask).map_err(|e| {
                    MemvidError::NerModelNotAvailable {
                        reason: format!("failed to create attention_mask array: {}", e).into(),
                    }
                })?;

            let input_ids_tensor = Tensor::from_array(input_ids_array).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create input_ids tensor: {}", e).into(),
                }
            })?;

            let attention_mask_tensor = Tensor::from_array(attention_mask_array).map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to create attention_mask tensor: {}", e).into(),
                }
            })?;

            // The model exports a single output; if the session reports no
            // output names, fall back to "logits".
            let output_name = self
                .session
                .outputs
                .first()
                .map(|o| o.name.clone())
                .unwrap_or_else(|| "logits".into());

            // Run inference.
            let outputs = self
                .session
                .run(ort::inputs![
                    "input_ids" => input_ids_tensor,
                    "attention_mask" => attention_mask_tensor,
                ])
                .map_err(|e| MemvidError::NerModelNotAvailable {
                    reason: format!("inference failed: {}", e).into(),
                })?;

            let logits = outputs
                .get(&output_name)
                .ok_or_else(|| MemvidError::NerModelNotAvailable {
                    reason: format!("no output '{}' found", output_name).into(),
                })?;

            // Decode per-token logits into entity spans.
            let entities = Self::decode_predictions_static(
                text,
                &tokens,
                &offsets,
                logits,
                self.min_confidence,
            )?;

            Ok(entities)
        }

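        /// Greedily decode per-token logits into entity spans.
        ///
        /// For each token the arg-max label is taken and its softmax
        /// probability is used as the confidence; consecutive `B-X`/`I-X`
        /// tokens of the same type are merged into one span over the original
        /// byte offsets. Tokens labelled `O` or falling below `min_confidence`
        /// close the current entity.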
        fn decode_predictions_static(
            original_text: &str,
            tokens: &[String],
            offsets: &[(usize, usize)],
            logits: &ort::value::Value,
            min_confidence: f32,
        ) -> Result<Vec<ExtractedEntity>> {
            let (shape, data) = logits.try_extract_tensor::<f32>().map_err(|e| {
                MemvidError::NerModelNotAvailable {
                    reason: format!("failed to extract logits: {}", e).into(),
                }
            })?;

            // Expected logits shape: (batch, seq_len, num_labels).
            let shape_vec: Vec<i64> = shape.iter().copied().collect();

            if shape_vec.len() != 3 {
                return Err(MemvidError::NerModelNotAvailable {
                    reason: format!("unexpected logits shape: {:?}", shape_vec).into(),
                });
            }

            let seq_len = shape_vec[1] as usize;
            let num_labels = shape_vec[2] as usize;

            // Index into the flattened (batch = 1) logits buffer.
            let idx = |i: usize, j: usize| -> usize { i * num_labels + j };

            let mut entities = Vec::new();
            let mut current_entity: Option<(String, usize, usize, f32)> = None;

            for i in 0..seq_len {
                if i >= tokens.len() || i >= offsets.len() {
                    break;
                }

                // Special tokens flush any open entity and are otherwise skipped.
                let token = &tokens[i];
                if token == "[CLS]" || token == "[SEP]" || token == "[PAD]" {
                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
                        if end > start && end <= original_text.len() {
                            let text = original_text[start..end].trim().to_string();
                            if !text.is_empty() {
                                entities.push(ExtractedEntity {
                                    text,
                                    entity_type,
                                    confidence: conf,
                                    byte_start: start,
                                    byte_end: end,
                                });
                            }
                        }
                    }
                    continue;
                }

                // Arg-max over the label logits for this token.
                let mut max_score = f32::NEG_INFINITY;
                let mut max_label = 0usize;

                for j in 0..num_labels {
                    let score = data[idx(i, j)];
                    if score > max_score {
                        max_score = score;
                        max_label = j;
                    }
                }

                // Softmax probability of the arg-max label.
                let mut exp_sum = 0.0f32;
                for j in 0..num_labels {
                    exp_sum += (data[idx(i, j)] - max_score).exp();
                }
                let confidence = 1.0 / exp_sum;

                let label = NER_LABELS.get(max_label).unwrap_or(&"O");
                let (start_offset, end_offset) = offsets[i];

                if *label == "O" || confidence < min_confidence {
                    // An "O" or low-confidence token closes the current entity.
                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
                        if end > start && end <= original_text.len() {
                            let text = original_text[start..end].trim().to_string();
                            if !text.is_empty() {
                                entities.push(ExtractedEntity {
                                    text,
                                    entity_type,
                                    confidence: conf,
                                    byte_start: start,
                                    byte_end: end,
                                });
                            }
                        }
                    }
                } else if label.starts_with("B-") {
                    // "B-" starts a new entity; flush the previous one first.
                    if let Some((entity_type, start, end, conf)) = current_entity.take() {
                        if end > start && end <= original_text.len() {
                            let text = original_text[start..end].trim().to_string();
                            if !text.is_empty() {
                                entities.push(ExtractedEntity {
                                    text,
                                    entity_type,
                                    confidence: conf,
                                    byte_start: start,
                                    byte_end: end,
                                });
                            }
                        }
                    }
                    let entity_type = label[2..].to_string();
                    current_entity = Some((entity_type, start_offset, end_offset, confidence));
                } else if label.starts_with("I-") {
                    // "I-" extends the current entity when the types match.
                    if let Some((ref entity_type, start, _, ref mut conf)) = current_entity {
                        let expected_type = &label[2..];
                        if entity_type == expected_type {
                            current_entity = Some((
                                entity_type.clone(),
                                start,
                                end_offset,
                                (*conf + confidence) / 2.0,
                            ));
                        }
                    }
                }
            }

            // Flush the trailing entity, if any.
            if let Some((entity_type, start, end, conf)) = current_entity {
                if end > start && end <= original_text.len() {
                    let text = original_text[start..end].trim().to_string();
                    if !text.is_empty() {
                        entities.push(ExtractedEntity {
                            text,
                            entity_type,
                            confidence: conf,
                            byte_start: start,
                            byte_end: end,
                        });
                    }
                }
            }

            Ok(entities)
        }

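        /// Extract entities from a frame's content and bundle them with its id.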
        pub fn extract_from_frame(
            &mut self,
            frame_id: FrameId,
            content: &str,
        ) -> Result<FrameEntities> {
            let entities = self.extract(content)?;
            Ok(FrameEntities { frame_id, entities })
        }

        /// Minimum confidence threshold used when filtering predictions.
        pub fn min_confidence(&self) -> f32 {
            self.min_confidence
        }

        /// Path of the ONNX model this instance was loaded from.
        pub fn model_path(&self) -> &Path {
            &self.model_path
        }
    }

    impl std::fmt::Debug for NerModel {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.debug_struct("NerModel")
                .field("model_path", &self.model_path)
                .field("min_confidence", &self.min_confidence)
                .finish()
        }
    }
}

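/// Stub used when the `logic_mesh` feature is disabled; every method returns
/// `MemvidError::FeatureUnavailable`.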
#[cfg(not(feature = "logic_mesh"))]
#[allow(dead_code)]
pub struct NerModel {
    _private: (),
}

#[cfg(not(feature = "logic_mesh"))]
#[allow(dead_code)]
impl NerModel {
    pub fn load(
        _model_path: impl AsRef<Path>,
        _tokenizer_path: impl AsRef<Path>,
        _min_confidence: Option<f32>,
    ) -> Result<Self> {
        Err(MemvidError::FeatureUnavailable {
            feature: "logic_mesh",
        })
    }

    pub fn extract(&self, _text: &str) -> Result<Vec<ExtractedEntity>> {
        Err(MemvidError::FeatureUnavailable {
            feature: "logic_mesh",
        })
    }

    pub fn extract_from_frame(&self, _frame_id: FrameId, _content: &str) -> Result<FrameEntities> {
        Err(MemvidError::FeatureUnavailable {
            feature: "logic_mesh",
        })
    }
}

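/// Expected on-disk location of the NER model weights under `models_dir`.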
pub fn ner_model_path(models_dir: &Path) -> PathBuf {
    models_dir.join(NER_MODEL_NAME).join("model.onnx")
}

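/// Expected on-disk location of the NER tokenizer under `models_dir`.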
pub fn ner_tokenizer_path(models_dir: &Path) -> PathBuf {
    models_dir.join(NER_MODEL_NAME).join("tokenizer.json")
}

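/// Whether both the NER model and its tokenizer are present under `models_dir`.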
pub fn is_ner_model_installed(models_dir: &Path) -> bool {
    ner_model_path(models_dir).exists() && ner_tokenizer_path(models_dir).exists()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_entity_kind_mapping() {
        let cases = vec![
            ("PER", EntityKind::Person),
            ("B-PER", EntityKind::Person),
            ("I-PER", EntityKind::Person),
            ("ORG", EntityKind::Organization),
            ("B-ORG", EntityKind::Organization),
            ("LOC", EntityKind::Location),
            ("B-LOC", EntityKind::Location),
            ("MISC", EntityKind::Other),
            ("B-MISC", EntityKind::Other),
            ("unknown", EntityKind::Other),
        ];

        for (entity_type, expected_kind) in cases {
            let entity = ExtractedEntity {
                text: "test".to_string(),
                entity_type: entity_type.to_string(),
                confidence: 0.9,
                byte_start: 0,
                byte_end: 4,
            };
            assert_eq!(
                entity.to_entity_kind(),
                expected_kind,
                "Failed for entity_type: {}",
                entity_type
            );
        }
    }

    #[test]
    fn test_model_info() {
        let info = default_ner_model_info();
        assert_eq!(info.name, NER_MODEL_NAME);
        assert!(info.is_default);
        assert!(info.size_mb > 200.0);
    }

    #[test]
    fn test_model_paths() {
        let models_dir = PathBuf::from("/tmp/models");
        let model_path = ner_model_path(&models_dir);
        let tokenizer_path = ner_tokenizer_path(&models_dir);

        assert!(model_path.to_string_lossy().contains("model.onnx"));
        assert!(tokenizer_path.to_string_lossy().contains("tokenizer.json"));
    }

    #[test]
    fn test_ner_labels() {
        assert_eq!(NER_LABELS.len(), 9);
        assert_eq!(NER_LABELS[0], "O");
        assert_eq!(NER_LABELS[1], "B-PER");
        assert_eq!(NER_LABELS[3], "B-ORG");
        assert_eq!(NER_LABELS[5], "B-LOC");
        assert_eq!(NER_LABELS[7], "B-MISC");
    }
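
    // Sanity-check sketch for the registry lookup and the install check; the
    // nonexistent directory below is illustrative.
    #[test]
    fn test_model_lookup_and_install_check() {
        assert!(get_ner_model_info(NER_MODEL_NAME).is_some());
        assert!(get_ner_model_info("no-such-model").is_none());

        let missing_dir = PathBuf::from("/definitely/not/a/real/models/dir");
        assert!(!is_ner_model_installed(&missing_dir));
    }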
}