1use crate::error::{InferenceError, Result};
14use ort::inputs;
15use ort::session::builder::GraphOptimizationLevel;
16use ort::session::Session;
17use ort::value::Tensor;
18use parking_lot::Mutex;
19use regex::Regex;
20use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::Arc;
23use tokenizers::Tokenizer;
24use tracing::{debug, info, instrument, warn};
25
26#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
32pub struct ExtractedEntity {
33 pub entity_type: String,
35 pub value: String,
37 pub score: f32,
39 pub start: usize,
41 pub end: usize,
43}
44
45impl ExtractedEntity {
46 pub fn to_tag(&self) -> String {
52 let normalized_value = normalize_tag_value(&self.value);
53 format!("entity:{}:{}", self.entity_type, normalized_value)
54 }
55
56 pub fn dedup_key(&self) -> (String, String) {
58 (self.entity_type.clone(), normalize_tag_value(&self.value))
59 }
60}
61
/// Normalizes a user-supplied entity label into a stable identifier.
///
/// - The label is lower-cased.
/// - Each run of whitespace (including leading/trailing) collapses to a single `_` between words.
/// - `:` is replaced by `_`, so the result can be embedded in an `entity:<type>:<value>`
///   tag (see `ExtractedEntity::to_tag`) without adding an extra separator.
///
/// Examples: `"Law Firm"` → `"law_firm"`, `" ORG "` → `"org"`.
pub fn normalize_label(label: &str) -> String {
    label
        .split_whitespace()
        .collect::<Vec<_>>()
        .join("_")
        .to_lowercase()
        .replace(':', "_")
}
72
/// Canonical form of an entity value for tags and de-duplication.
///
/// Whitespace runs collapse to a single space (leading/trailing whitespace is
/// dropped), the result is lower-cased, and `:` is replaced with `_` so the value
/// never introduces an extra tag separator.
fn normalize_tag_value(value: &str) -> String {
    let mut collapsed = String::with_capacity(value.len());
    for word in value.split_whitespace() {
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    collapsed.to_lowercase().replace(':', "_")
}
87
88pub fn deduplicate_entities(mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
95 entities.sort_by(|a, b| {
97 b.score
98 .partial_cmp(&a.score)
99 .unwrap_or(std::cmp::Ordering::Equal)
100 });
101
102 let mut seen: HashMap<(String, String), ()> = HashMap::new();
103 let mut out: Vec<ExtractedEntity> = Vec::with_capacity(entities.len());
104
105 for entity in entities {
106 let key = entity.dedup_key();
107 if seen.insert(key, ()).is_none() {
108 out.push(entity);
109 }
110 }
111
112 out.sort_by_key(|e| e.start);
114 out
115}
116
117struct RulePatterns {
122 uuid: Regex,
123 url: Regex,
124 email: Regex,
125 iso_date: Regex,
126 natural_date: Regex,
127 ip_v4: Regex,
128}
129
130impl RulePatterns {
131 fn new() -> Self {
132 Self {
133 uuid: Regex::new(
134 r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b",
135 )
136 .expect("uuid regex"),
137 url: Regex::new(r#"https?://[^\s<>\[\]()"']+"#).expect("url regex"),
138 email: Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
139 .expect("email regex"),
140 iso_date: Regex::new(
141 r"\b\d{4}-(?:0[1-9]|1[0-2])-(?:0[1-9]|[12]\d|3[01])\b",
142 )
143 .expect("iso_date regex"),
144 natural_date: Regex::new(
145 r"(?i)\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:,\s*\d{4})?\b",
146 )
147 .expect("natural_date regex"),
148 ip_v4: Regex::new(
149 r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b",
150 )
151 .expect("ipv4 regex"),
152 }
153 }
154}
155
156lazy_static::lazy_static! {
157 static ref RULE_PATTERNS: RulePatterns = RulePatterns::new();
158}
159
160pub fn rule_based_extract(text: &str) -> Vec<ExtractedEntity> {
165 let mut entities: Vec<ExtractedEntity> = Vec::new();
166
167 let push = |entities: &mut Vec<ExtractedEntity>, entity_type: &str, m: regex::Match| {
168 entities.push(ExtractedEntity {
169 entity_type: entity_type.to_string(),
170 value: m.as_str().trim().to_string(),
171 score: 1.0,
172 start: m.start(),
173 end: m.end(),
174 });
175 };
176
177 for m in RULE_PATTERNS.email.find_iter(text) {
179 push(&mut entities, "email", m);
180 }
181 for m in RULE_PATTERNS.url.find_iter(text) {
182 if !entities.iter().any(|e| e.start == m.start()) {
184 push(&mut entities, "url", m);
185 }
186 }
187 for m in RULE_PATTERNS.uuid.find_iter(text) {
188 push(&mut entities, "uuid", m);
189 }
190 for m in RULE_PATTERNS.iso_date.find_iter(text) {
191 push(&mut entities, "date", m);
192 }
193 for m in RULE_PATTERNS.natural_date.find_iter(text) {
194 if !entities
195 .iter()
196 .any(|e| e.start == m.start() && e.entity_type == "date")
197 {
198 push(&mut entities, "date", m);
199 }
200 }
201 for m in RULE_PATTERNS.ip_v4.find_iter(text) {
202 push(&mut entities, "ip", m);
203 }
204
205 entities
206}
207
/// Hugging Face repository the GLiNER ONNX model is downloaded from.
const GLINER_MODEL_REPO: &str = "onnx-community/gliner_medium-v2.1";
/// Hugging Face repository `tokenizer.json` is downloaded from (currently the same
/// repository as the model).
const GLINER_TOKENIZER_REPO: &str = "onnx-community/gliner_medium-v2.1";
/// Path of the ONNX model file inside [`GLINER_MODEL_REPO`]; also reused as the
/// path under the local model cache directory.
const GLINER_ONNX_FILE: &str = "onnx/model_quantized.onnx";

/// Maximum span length, in words, that GLiNER is asked to score.
///
/// NOTE(review): presumably this must equal the `max_width` the ONNX model was
/// exported with — confirm against the model config before changing it.
const MAX_SPAN_WIDTH: usize = 12;
/// Minimum sigmoid score a GLiNER span needs in order to be reported.
const DEFAULT_SCORE_THRESHOLD: f32 = 0.5;
/// Input text is truncated to this many whitespace-delimited words before inference.
const MAX_TEXT_WORDS: usize = 300;
223
/// GLiNER zero-shot NER engine: an ONNX Runtime session plus the matching tokenizer.
///
/// Both live behind `Arc`s so they can be cloned into `spawn_blocking` tasks. The
/// session is wrapped in a `parking_lot::Mutex`, and each inference call holds that
/// lock while the session runs.
pub struct GlinerEngine {
    /// Loaded GLiNER ONNX model.
    session: Arc<Mutex<Session>>,
    /// Tokenizer loaded from the model's `tokenizer.json`.
    tokenizer: Arc<Tokenizer>,
}
231
232impl GlinerEngine {
233 #[instrument(skip_all)]
235 pub async fn new(num_threads: Option<usize>) -> Result<Self> {
236 let threads = num_threads.unwrap_or(1);
237 info!("Initializing GLiNER NER engine (threads={})", threads);
238
239 let (tokenizer_path, onnx_path) = Self::download_model_files().await?;
240
241 let tokenizer = Tokenizer::from_file(&tokenizer_path)
242 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
243
244 let session = Session::builder()
245 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
246 .with_optimization_level(GraphOptimizationLevel::Level3)
247 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
248 .with_intra_threads(threads)
249 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
250 .commit_from_file(&onnx_path)
251 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
252
253 info!("GLiNER engine ready");
254 Ok(Self {
255 session: Arc::new(Mutex::new(session)),
256 tokenizer: Arc::new(tokenizer),
257 })
258 }
259
260 pub async fn extract(&self, text: &str, entity_types: &[&str]) -> Result<Vec<ExtractedEntity>> {
267 if entity_types.is_empty() || text.is_empty() {
268 return Ok(Vec::new());
269 }
270
271 let text_owned = text.to_string();
272 let entity_types_owned: Vec<String> = entity_types.iter().map(|s| s.to_string()).collect();
273 let session = self.session.clone();
274 let tokenizer = self.tokenizer.clone();
275
276 tokio::task::spawn_blocking(move || {
277 Self::run_inference(
278 &text_owned,
279 &entity_types_owned
280 .iter()
281 .map(|s| s.as_str())
282 .collect::<Vec<_>>(),
283 &session,
284 &tokenizer,
285 )
286 })
287 .await
288 .map_err(|e| InferenceError::HubError(format!("GLiNER inference task panicked: {}", e)))?
289 }
290
291 fn run_inference(
292 text: &str,
293 entity_types: &[&str],
294 session: &Arc<Mutex<Session>>,
295 tokenizer: &Tokenizer,
296 ) -> Result<Vec<ExtractedEntity>> {
297 let text = truncate_to_word_limit(text, MAX_TEXT_WORDS);
301
302 let prefix = entity_types.join(" << >> ");
305 let prefix_plus_sep = format!("{} << >> ", prefix);
306 let full_text = format!("{}{}", prefix_plus_sep, text);
307
308 let encoding = tokenizer
310 .encode(full_text.as_str(), true)
311 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
312
313 let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
314 let attention_mask: Vec<i64> = encoding
315 .get_attention_mask()
316 .iter()
317 .map(|&x| x as i64)
318 .collect();
319 let seq_len = token_ids.len();
320
321 let prefix_encoding = tokenizer
325 .encode(prefix_plus_sep.as_str(), false)
326 .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
327 let prefix_word_count = count_distinct_word_ids(prefix_encoding.get_word_ids());
328
329 let word_ids = encoding.get_word_ids();
341 let token_offsets = encoding.get_offsets(); let mut words_mask = vec![0i64; seq_len];
344 let mut last_word_id: Option<u32> = None;
345 let mut cumulative_word_count = 0usize; let mut text_word_count = 0usize;
347 let mut text_word_ids: Vec<u32> = Vec::new();
348 let mut word_byte_ranges: HashMap<u32, (usize, usize)> = HashMap::new();
350
351 for (i, &wid_opt) in word_ids.iter().enumerate() {
352 let wid = match wid_opt {
353 Some(w) => w,
354 None => {
355 last_word_id = None;
356 continue;
357 }
358 };
359
360 let (tok_start, tok_end) = token_offsets[i];
362 let entry = word_byte_ranges.entry(wid).or_insert((tok_start, tok_end));
363 if tok_start < entry.0 {
364 entry.0 = tok_start;
365 }
366 if tok_end > entry.1 {
367 entry.1 = tok_end;
368 }
369
370 let is_new_word = last_word_id.map(|lw| lw != wid).unwrap_or(true);
371 if is_new_word {
372 if cumulative_word_count >= prefix_word_count {
373 words_mask[i] = 1;
374 text_word_count += 1;
375 text_word_ids.push(wid);
376 }
377 cumulative_word_count += 1;
378 }
379 last_word_id = Some(wid);
380 }
381
382 if text_word_count == 0 {
383 debug!("No text words after entity type prefix — skipping inference");
384 return Ok(Vec::new());
385 }
386 let text_lengths = vec![text_word_count as i64];
387
388 let prefix_byte_offset = prefix_plus_sep.len();
391
392 let mut span_idx_flat: Vec<i64> = Vec::new();
394 let mut span_mask: Vec<bool> = Vec::new();
395
396 for start in 0..text_word_count {
397 for end in start..text_word_count.min(start + MAX_SPAN_WIDTH) {
398 span_idx_flat.push(start as i64);
399 span_idx_flat.push(end as i64);
400 span_mask.push(true);
401 }
402 }
403
404 let num_spans = span_mask.len();
405 if num_spans == 0 {
406 return Ok(Vec::new());
407 }
408
409 let logits_raw: Vec<f32> = {
413 let mut session_guard = session.lock();
414
415 let input_ids_t = Tensor::<i64>::from_array(([1usize, seq_len], token_ids))
416 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
417 let attn_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], attention_mask))
418 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
419 let words_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], words_mask))
420 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
421 let text_lengths_t = Tensor::<i64>::from_array(([1usize, 1usize], text_lengths))
423 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
424 let span_idx_t = Tensor::<i64>::from_array(([1usize, num_spans, 2], span_idx_flat))
425 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
426 let span_mask_t = Tensor::<bool>::from_array(([1usize, num_spans], span_mask))
427 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
428
429 let outputs = session_guard
430 .run(inputs![
431 "input_ids" => input_ids_t,
432 "attention_mask" => attn_mask_t,
433 "words_mask" => words_mask_t,
434 "text_lengths" => text_lengths_t,
435 "span_idx" => span_idx_t,
436 "span_mask" => span_mask_t,
437 ])
438 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
439
440 let (_shape, logits_slice) = outputs[0]
442 .try_extract_tensor::<f32>()
443 .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
444 logits_slice.to_vec()
445 };
446
447 let num_entity_types = entity_types.len();
449 if logits_raw.len() != num_spans * num_entity_types {
450 warn!(
451 "GLiNER logits shape mismatch: got {}, expected {}",
452 logits_raw.len(),
453 num_spans * num_entity_types
454 );
455 return Ok(Vec::new());
456 }
457
458 let mut raw_entities: Vec<(usize, usize, usize, f32)> = Vec::new(); for (span_i, (start_w, end_w)) in iter_spans(text_word_count).enumerate() {
462 for (type_i, _) in entity_types.iter().enumerate() {
463 let score = sigmoid(logits_raw[span_i * num_entity_types + type_i]);
464 if score >= DEFAULT_SCORE_THRESHOLD {
465 raw_entities.push((type_i, start_w, end_w, score));
466 }
467 }
468 }
469
470 raw_entities.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
473 let mut kept: Vec<(usize, usize, usize, f32)> = Vec::new();
474 'outer: for candidate in &raw_entities {
475 for kept_span in &kept {
476 if kept_span.0 == candidate.0
477 && kept_span.1 <= candidate.2
478 && candidate.1 <= kept_span.2
479 {
480 continue 'outer;
481 }
482 }
483 kept.push(*candidate);
484 }
485
486 let mut entities: Vec<ExtractedEntity> = kept
492 .into_iter()
493 .filter_map(|(type_i, start_w, end_w, score)| {
494 let start_wid = *text_word_ids.get(start_w)?;
495 let end_wid = *text_word_ids.get(end_w)?;
496 let &(start_byte_full, _) = word_byte_ranges.get(&start_wid)?;
497 let &(_, end_byte_full) = word_byte_ranges.get(&end_wid)?;
498
499 let start_byte = start_byte_full.saturating_sub(prefix_byte_offset);
501 let end_byte = end_byte_full.saturating_sub(prefix_byte_offset);
502
503 if start_byte >= end_byte || end_byte > text.len() {
504 return None;
505 }
506
507 let value = text[start_byte..end_byte].trim().to_string();
508 if value.is_empty() {
509 return None;
510 }
511
512 let entity_type = normalize_label(entity_types[type_i]);
515
516 Some(ExtractedEntity {
517 entity_type,
518 value,
519 score,
520 start: start_byte,
521 end: end_byte,
522 })
523 })
524 .collect();
525
526 entities.sort_by_key(|e| e.start);
527 debug!("GLiNER extracted {} entities", entities.len());
528 Ok(entities)
529 }
530
531 #[instrument(skip_all)]
534 async fn download_model_files() -> Result<(PathBuf, PathBuf)> {
535 info!(
536 "Resolving GLiNER model files: tokenizer={}, onnx={}",
537 GLINER_TOKENIZER_REPO, GLINER_MODEL_REPO
538 );
539
540 let tokenizer_cache = Self::model_cache_dir(GLINER_TOKENIZER_REPO)?;
541 let onnx_cache = Self::model_cache_dir(GLINER_MODEL_REPO)?;
542 let onnx_subdir = onnx_cache.join("onnx");
543 std::fs::create_dir_all(&onnx_subdir)?;
544
545 let local_tokenizer = tokenizer_cache.join("tokenizer.json");
546 let local_onnx = onnx_subdir.join("model_quantized.onnx");
547
548 if !local_tokenizer.exists() || !local_onnx.exists() {
549 let tok_cache = tokenizer_cache.clone();
550 let onnx_c = onnx_cache.clone();
551 let tok_exists = local_tokenizer.exists();
552 let onnx_exists = local_onnx.exists();
553
554 tokio::task::spawn_blocking(move || {
555 if !tok_exists {
556 crate::engine::EmbeddingEngine::download_hf_file_pub(
557 GLINER_TOKENIZER_REPO,
558 "tokenizer.json",
559 &tok_cache,
560 )
561 .map_err(|e| {
562 InferenceError::HubError(format!(
563 "Failed to download GLiNER tokenizer: {}",
564 e
565 ))
566 })?;
567 }
568 if !onnx_exists {
569 crate::engine::EmbeddingEngine::download_hf_file_pub(
570 GLINER_MODEL_REPO,
571 GLINER_ONNX_FILE,
572 &onnx_c,
573 )
574 .map_err(|e| {
575 InferenceError::HubError(format!(
576 "Failed to download GLiNER ONNX model: {}",
577 e
578 ))
579 })?;
580 }
581 Ok::<_, InferenceError>(())
582 })
583 .await
584 .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
585 } else {
586 info!("GLiNER model files found in local cache");
587 }
588
589 let final_onnx = onnx_cache.join(GLINER_ONNX_FILE);
590 Ok((local_tokenizer, final_onnx))
591 }
592
593 fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
594 let base = std::env::var("HF_HOME")
595 .map(PathBuf::from)
596 .unwrap_or_else(|_| {
597 let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
598 PathBuf::from(home).join(".cache").join("huggingface")
599 });
600 let dir = base.join("dakera").join(model_id.replace('/', "--"));
601 std::fs::create_dir_all(&dir)?;
602 Ok(dir)
603 }
604}
605
/// Entity extractor that combines the regex rules of `rule_based_extract` with an
/// optional GLiNER model.
pub struct NerEngine {
    /// GLiNER engine; `None` means only rule-based extraction is performed.
    gliner: Option<Arc<GlinerEngine>>,
}
614
615impl NerEngine {
616 pub fn rule_based_only() -> Self {
618 Self { gliner: None }
619 }
620
621 pub async fn with_gliner(num_threads: Option<usize>) -> Result<Self> {
623 let gliner = GlinerEngine::new(num_threads).await?;
624 Ok(Self {
625 gliner: Some(Arc::new(gliner)),
626 })
627 }
628
629 pub async fn extract(&self, text: &str, gliner_types: &[&str]) -> Vec<ExtractedEntity> {
638 let mut entities = rule_based_extract(text);
639
640 if let Some(ref gliner) = self.gliner {
641 if !gliner_types.is_empty() {
642 match gliner.extract(text, gliner_types).await {
643 Ok(neural) => {
644 for ne in neural {
645 if !entities
647 .iter()
648 .any(|e| e.start == ne.start && e.end == ne.end)
649 {
650 entities.push(ne);
651 }
652 }
653 }
654 Err(e) => {
655 warn!("GLiNER extraction failed, using rule-based only: {}", e);
656 }
657 }
658 }
659 }
660
661 entities.sort_by_key(|e| e.start);
662
663 deduplicate_entities(entities)
665 }
666}
667
/// Counts the distinct word ids in a tokenizer word-id sequence.
///
/// Tokens without a word id (`None`) are ignored.
fn count_distinct_word_ids(word_ids: &[Option<u32>]) -> usize {
    word_ids
        .iter()
        .flatten()
        .collect::<std::collections::HashSet<_>>()
        .len()
}
682
/// Returns the prefix of `text` containing at most `max_words` whitespace-delimited
/// words.
///
/// The cut is made at the whitespace character that ends word `max_words`, so any
/// whitespace after the last kept word is dropped. Text with `max_words` words or
/// fewer is returned unchanged (unless that word is followed by whitespace).
/// `max_words == 0` yields `""`.
fn truncate_to_word_limit(text: &str, max_words: usize) -> &str {
    if max_words == 0 {
        return "";
    }

    let mut words_done = 0usize;
    let mut in_word = false;

    for (i, ch) in text.char_indices() {
        if !ch.is_whitespace() {
            in_word = true;
            continue;
        }
        if in_word {
            words_done += 1;
            if words_done >= max_words {
                // `i` is a char boundary from `char_indices`, so slicing is safe.
                return &text[..i];
            }
        }
        in_word = false;
    }

    text
}
709
710fn iter_spans(num_words: usize) -> impl Iterator<Item = (usize, usize)> {
712 (0..num_words).flat_map(move |start| {
713 let max_end = num_words.min(start + MAX_SPAN_WIDTH);
714 (start..max_end).map(move |end| (start, end))
715 })
716}
717
/// Numerically stable logistic function `1 / (1 + e^-x)`.
///
/// Only a non-positive number is ever exponentiated, so `exp` cannot overflow for
/// large `|x|`.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let decay = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + decay)
    } else {
        decay / (1.0 + decay)
    }
}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ExtractedEntity` fixture.
    fn entity(entity_type: &str, value: &str, score: f32, start: usize, end: usize) -> ExtractedEntity {
        ExtractedEntity {
            entity_type: entity_type.to_string(),
            value: value.to_string(),
            score,
            start,
            end,
        }
    }

    /// True when any entity in `entities` has the given type.
    fn has_type(entities: &[ExtractedEntity], entity_type: &str) -> bool {
        entities.iter().any(|e| e.entity_type == entity_type)
    }

    // --- rule-based extraction ------------------------------------------------

    #[test]
    fn test_rule_based_uuid() {
        let found = rule_based_extract("session id is 550e8400-e29b-41d4-a716-446655440000 here");
        assert!(has_type(&found, "uuid"));
    }

    #[test]
    fn test_rule_based_url() {
        let found = rule_based_extract("check https://example.com/path?q=1 for details");
        assert!(has_type(&found, "url"));
    }

    #[test]
    fn test_rule_based_email() {
        let found = rule_based_extract("contact alice@example.com for support");
        assert!(has_type(&found, "email"));
        // An email address must not also be reported as a URL.
        assert!(!has_type(&found, "url"));
    }

    #[test]
    fn test_rule_based_iso_date() {
        let found = rule_based_extract("released on 2024-03-15 at noon");
        assert!(found
            .iter()
            .any(|e| e.entity_type == "date" && e.value == "2024-03-15"));
    }

    #[test]
    fn test_rule_based_natural_date() {
        let found = rule_based_extract("meeting on March 15, 2024 at noon");
        assert!(has_type(&found, "date"));
    }

    // --- tag formatting -------------------------------------------------------

    #[test]
    fn test_entity_to_tag_lowercase_value() {
        let e = entity("person", "Alice Smith", 0.9, 0, 11);
        assert_eq!(e.to_tag(), "entity:person:alice smith");
    }

    #[test]
    fn test_entity_to_tag_colon_escaping() {
        let tag = entity("url", "http://example.com:8080/path", 1.0, 0, 27).to_tag();
        let parts: Vec<&str> = tag.splitn(3, ':').collect();
        assert_eq!(parts.len(), 3, "tag should have 3 parts: {}", tag);
        assert_eq!(parts[0], "entity");
        assert_eq!(parts[1], "url");
        assert!(
            !parts[2].contains(':'),
            "value should not contain colons: {}",
            parts[2]
        );
    }

    #[test]
    fn test_entity_to_tag_normalizes_whitespace() {
        let e = entity("person", " John Doe ", 0.9, 0, 12);
        assert_eq!(e.to_tag(), "entity:person:john doe");
    }

    #[test]
    fn test_normalize_label() {
        assert_eq!(normalize_label("Person"), "person");
        assert_eq!(normalize_label("Law Firm"), "law_firm");
        assert_eq!(normalize_label(" ORG "), "org");
        assert_eq!(normalize_label("ORGANIZATION"), "organization");
        assert_eq!(normalize_label("location"), "location");
    }

    // --- de-duplication -------------------------------------------------------

    #[test]
    fn test_deduplicate_same_value_different_positions() {
        let deduped = deduplicate_entities(vec![
            entity("person", "Alice", 0.8, 0, 5),
            entity("person", "Alice", 0.9, 20, 25),
        ]);
        assert_eq!(
            deduped.len(),
            1,
            "same entity at different positions should be merged"
        );
        assert_eq!(deduped[0].score, 0.9, "should retain highest score");
    }

    #[test]
    fn test_deduplicate_case_insensitive() {
        let deduped = deduplicate_entities(vec![
            entity("person", "alice", 0.7, 10, 15),
            entity("person", "Alice", 0.95, 0, 5),
        ]);
        assert_eq!(
            deduped.len(),
            1,
            "case-insensitive dedup: 'Alice' == 'alice'"
        );
        assert_eq!(deduped[0].score, 0.95);
    }

    #[test]
    fn test_deduplicate_different_types_kept() {
        let deduped = deduplicate_entities(vec![
            entity("person", "Apple", 0.6, 0, 5),
            entity("organization", "Apple", 0.9, 0, 5),
        ]);
        assert_eq!(
            deduped.len(),
            2,
            "same value with different types must be kept separately"
        );
    }

    // --- helpers --------------------------------------------------------------

    #[test]
    fn test_truncate_to_word_limit_long() {
        let text = (0..500)
            .map(|i| format!("word{}", i))
            .collect::<Vec<String>>()
            .join(" ");
        let word_count = truncate_to_word_limit(&text, 300).split_whitespace().count();
        assert!(
            word_count <= 300,
            "truncated text must be ≤ 300 words, got {}",
            word_count
        );
    }

    #[test]
    fn test_truncate_to_word_limit_short_pass_through() {
        let text = "Hello world this is fine";
        assert_eq!(
            truncate_to_word_limit(text, 300),
            text,
            "short text must pass through unchanged"
        );
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-4);
        assert!((sigmoid(-100.0) - 0.0).abs() < 1e-4);
    }

    #[test]
    fn test_count_distinct_word_ids() {
        let wids: Vec<Option<u32>> =
            vec![Some(0), Some(0), Some(1), Some(1), Some(2), None, Some(3)];
        assert_eq!(count_distinct_word_ids(&wids), 4);
    }
}