Skip to main content

inference/
ner.rs

1//! Named Entity Recognition (NER) engine — CE-4 GLiNER zero-shot NER.
2//!
3//! Two-layer extraction pipeline:
4//! 1. **Rule-based pre-pass** — regex extraction of dates, URLs, UUIDs, emails, IPs.
5//!    Always on, zero latency, no model download required.
6//! 2. **GLiNER ONNX engine** — zero-shot NER via GLiNER-medium ONNX INT8 (52 MB).
7//!    Opt-in per namespace, lazy-loaded on first use.
8//!
9//! Extracted entities are stored as tags: `entity:person:alice`, `entity:org:anthropic`.
10//! Tag values are lowercased and whitespace-normalized for consistent deduplication.
11//!
12
13use crate::error::{InferenceError, Result};
14use ort::inputs;
15use ort::session::builder::GraphOptimizationLevel;
16use ort::session::Session;
17use ort::value::Tensor;
18use parking_lot::Mutex;
19use regex::Regex;
20use std::collections::HashMap;
21use std::path::PathBuf;
22use std::sync::Arc;
23use tokenizers::Tokenizer;
24use tracing::{debug, info, instrument, warn};
25
26// ─────────────────────────────────────────────────────────────
27// Public types
28// ─────────────────────────────────────────────────────────────
29
30/// A single extracted entity.
31#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
32pub struct ExtractedEntity {
33    /// Normalised entity type label, lowercase with underscores (e.g. "person", "law_firm").
34    pub entity_type: String,
35    /// The entity surface form as it appeared in the text (original casing, trimmed).
36    pub value: String,
37    /// Confidence score 0.0–1.0 (rule-based entities get 1.0).
38    pub score: f32,
39    /// Byte start offset in the original text.
40    pub start: usize,
41    /// Byte end offset in the original text.
42    pub end: usize,
43}
44
45impl ExtractedEntity {
46    /// Convert to the canonical tag format `entity:<type>:<normalized_value>`.
47    ///
48    /// Value is lowercased, whitespace-collapsed, and special characters sanitized
49    /// so that tags are consistent regardless of surface form casing.
50    /// "Alice Smith" and "alice smith" produce the same tag.
51    pub fn to_tag(&self) -> String {
52        let normalized_value = normalize_tag_value(&self.value);
53        format!("entity:{}:{}", self.entity_type, normalized_value)
54    }
55
56    /// Canonical dedup key: (entity_type, normalized_value).
57    pub fn dedup_key(&self) -> (String, String) {
58        (self.entity_type.clone(), normalize_tag_value(&self.value))
59    }
60}
61
/// Bring an entity type label into storage form: trimmed, lowercased, and with
/// every space replaced by an underscore.
///
/// Only used for the `entity_type` stored on extracted entities; the label text
/// placed into the model prompt is left exactly as the caller supplied it.
///
/// Examples: "Person" → "person", "Law Firm" → "law_firm", "  ORG  " → "org"
pub fn normalize_label(label: &str) -> String {
    let lowered = label.trim().to_lowercase();
    lowered
        .chars()
        .map(|c| if c == ' ' { '_' } else { c })
        .collect()
}
72
/// Normalize an entity value for use in a tag.
///
/// - Strips control characters that are not whitespace (e.g. NUL, BEL)
/// - Trims leading/trailing whitespace and collapses internal runs (including
///   tabs and newlines) to a single space
/// - Lowercases for case-insensitive deduplication
/// - Replaces `:` with `_` (tag structural separator)
fn normalize_tag_value(value: &str) -> String {
    // Remove non-whitespace control characters up front. Whitespace controls
    // (tab, newline, ...) are kept here so they still act as word separators
    // in the collapse step below.
    let cleaned: String = value
        .chars()
        .filter(|c| c.is_whitespace() || !c.is_control())
        .collect();

    cleaned
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .to_lowercase()
        .replace(':', "_")
}
87
88/// Deduplicate entities by (entity_type, normalized_value), keeping the highest score.
89///
90/// Repeated mentions of the same entity (e.g., "Alice" at positions 5 and 42) are
91/// folded into a single entity. The highest-confidence occurrence wins. Different
92/// entity types for the same value ("Apple" as person vs. organization) are kept
93/// as separate entries.
94pub fn deduplicate_entities(mut entities: Vec<ExtractedEntity>) -> Vec<ExtractedEntity> {
95    // Sort by score descending so the first insertion per key is the best score.
96    entities.sort_by(|a, b| {
97        b.score
98            .partial_cmp(&a.score)
99            .unwrap_or(std::cmp::Ordering::Equal)
100    });
101
102    let mut seen: HashMap<(String, String), ()> = HashMap::new();
103    let mut out: Vec<ExtractedEntity> = Vec::with_capacity(entities.len());
104
105    for entity in entities {
106        let key = entity.dedup_key();
107        if seen.insert(key, ()).is_none() {
108            out.push(entity);
109        }
110    }
111
112    // Re-sort by start offset for stable, position-ordered output.
113    out.sort_by_key(|e| e.start);
114    out
115}
116
117// ─────────────────────────────────────────────────────────────
118// Rule-based pre-pass
119// ─────────────────────────────────────────────────────────────
120
121struct RulePatterns {
122    uuid: Regex,
123    url: Regex,
124    email: Regex,
125    iso_date: Regex,
126    natural_date: Regex,
127    ip_v4: Regex,
128}
129
130impl RulePatterns {
131    fn new() -> Self {
132        Self {
133            uuid: Regex::new(
134                r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b",
135            )
136            .expect("uuid regex"),
137            url: Regex::new(r#"https?://[^\s<>\[\]()"']+"#).expect("url regex"),
138            email: Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
139                .expect("email regex"),
140            iso_date: Regex::new(
141                r"\b\d{4}-(?:0[1-9]|1[0-2])-(?:0[1-9]|[12]\d|3[01])\b",
142            )
143            .expect("iso_date regex"),
144            natural_date: Regex::new(
145                r"(?i)\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:,\s*\d{4})?\b",
146            )
147            .expect("natural_date regex"),
148            ip_v4: Regex::new(
149                r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b",
150            )
151            .expect("ipv4 regex"),
152        }
153    }
154}
155
156lazy_static::lazy_static! {
157    static ref RULE_PATTERNS: RulePatterns = RulePatterns::new();
158}
159
160/// Run the rule-based pre-pass — O(n) regex scan, zero model overhead.
161///
162/// Always extracts: uuid, url, email, date (ISO + natural), ipv4.
163/// Entity type labels are lowercase; values are trimmed.
164pub fn rule_based_extract(text: &str) -> Vec<ExtractedEntity> {
165    let mut entities: Vec<ExtractedEntity> = Vec::new();
166
167    let push = |entities: &mut Vec<ExtractedEntity>, entity_type: &str, m: regex::Match| {
168        entities.push(ExtractedEntity {
169            entity_type: entity_type.to_string(),
170            value: m.as_str().trim().to_string(),
171            score: 1.0,
172            start: m.start(),
173            end: m.end(),
174        });
175    };
176
177    // Order matters — email before URL (email contains @, URL starts with http)
178    for m in RULE_PATTERNS.email.find_iter(text) {
179        push(&mut entities, "email", m);
180    }
181    for m in RULE_PATTERNS.url.find_iter(text) {
182        // Skip if already captured as email
183        if !entities.iter().any(|e| e.start == m.start()) {
184            push(&mut entities, "url", m);
185        }
186    }
187    for m in RULE_PATTERNS.uuid.find_iter(text) {
188        push(&mut entities, "uuid", m);
189    }
190    for m in RULE_PATTERNS.iso_date.find_iter(text) {
191        push(&mut entities, "date", m);
192    }
193    for m in RULE_PATTERNS.natural_date.find_iter(text) {
194        if !entities
195            .iter()
196            .any(|e| e.start == m.start() && e.entity_type == "date")
197        {
198            push(&mut entities, "date", m);
199        }
200    }
201    for m in RULE_PATTERNS.ip_v4.find_iter(text) {
202        push(&mut entities, "ip", m);
203    }
204
205    entities
206}
207
208// ─────────────────────────────────────────────────────────────
209// GLiNER ONNX engine
210// ─────────────────────────────────────────────────────────────
211
/// Repository id the ONNX model file is downloaded from.
const GLINER_MODEL_REPO: &str = "onnx-community/gliner_medium-v2.1";
/// Repository id the tokenizer file is downloaded from (currently the same
/// repository as the model).
const GLINER_TOKENIZER_REPO: &str = "onnx-community/gliner_medium-v2.1";
/// Relative path of the ONNX model file, both inside the repository and
/// underneath the local model cache directory.
const GLINER_ONNX_FILE: &str = "onnx/model_quantized.onnx";

/// Maximum span width in words (GLiNER default).
const MAX_SPAN_WIDTH: usize = 12;
/// Default confidence threshold for accepting a span prediction.
const DEFAULT_SCORE_THRESHOLD: f32 = 0.5;
/// Maximum text words before truncation to stay within GLiNER's 512-token limit.
/// At ~1.5 tokens/word average: 300 words × 1.5 + ~15 prefix tokens ≈ 465 tokens — safe margin.
const MAX_TEXT_WORDS: usize = 300;
223
/// GLiNER zero-shot NER engine backed by ONNX Runtime.
///
/// Thread-safe — the session is mutex-guarded.
pub struct GlinerEngine {
    /// ONNX Runtime session; locked for the duration of each forward pass.
    session: Arc<Mutex<Session>>,
    /// Tokenizer loaded from the cached `tokenizer.json`; shared with the
    /// blocking inference task via `Arc`.
    tokenizer: Arc<Tokenizer>,
}
231
232impl GlinerEngine {
233    /// Create a new GLiNER engine, downloading the model if not cached.
234    #[instrument(skip_all)]
235    pub async fn new(num_threads: Option<usize>) -> Result<Self> {
236        let threads = num_threads.unwrap_or(1);
237        info!("Initializing GLiNER NER engine (threads={})", threads);
238
239        let (tokenizer_path, onnx_path) = Self::download_model_files().await?;
240
241        let tokenizer = Tokenizer::from_file(&tokenizer_path)
242            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
243
244        let session = Session::builder()
245            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
246            .with_optimization_level(GraphOptimizationLevel::Level3)
247            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
248            .with_intra_threads(threads)
249            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
250            .commit_from_file(&onnx_path)
251            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
252
253        info!("GLiNER engine ready");
254        Ok(Self {
255            session: Arc::new(Mutex::new(session)),
256            tokenizer: Arc::new(tokenizer),
257        })
258    }
259
260    /// Extract entities from text for the given entity types.
261    ///
262    /// Entity type labels are passed as user-provided strings; GLiNER is zero-shot
263    /// and uses the label text directly in its attention mechanism.
264    ///
265    /// Returns deduplicated, threshold-filtered entities sorted by start offset.
266    pub async fn extract(&self, text: &str, entity_types: &[&str]) -> Result<Vec<ExtractedEntity>> {
267        if entity_types.is_empty() || text.is_empty() {
268            return Ok(Vec::new());
269        }
270
271        let text_owned = text.to_string();
272        let entity_types_owned: Vec<String> = entity_types.iter().map(|s| s.to_string()).collect();
273        let session = self.session.clone();
274        let tokenizer = self.tokenizer.clone();
275
276        tokio::task::spawn_blocking(move || {
277            Self::run_inference(
278                &text_owned,
279                &entity_types_owned
280                    .iter()
281                    .map(|s| s.as_str())
282                    .collect::<Vec<_>>(),
283                &session,
284                &tokenizer,
285            )
286        })
287        .await
288        .map_err(|e| InferenceError::HubError(format!("GLiNER inference task panicked: {}", e)))?
289    }
290
291    fn run_inference(
292        text: &str,
293        entity_types: &[&str],
294        session: &Arc<Mutex<Session>>,
295        tokenizer: &Tokenizer,
296    ) -> Result<Vec<ExtractedEntity>> {
297        // ── Step 0: guard against texts that would overflow the model ─────
298        // Truncate at a word boundary before tokenising. GLiNER medium has a
299        // 512-token limit; silently exceeding it produces garbage logits.
300        let text = truncate_to_word_limit(text, MAX_TEXT_WORDS);
301
302        // ── Step 1: build full input text ────────────────────────────────
303        // GLiNER v2.1 prompt format: "type1 << >> type2 << >> <text>"
304        let prefix = entity_types.join(" << >> ");
305        let prefix_plus_sep = format!("{} << >> ", prefix);
306        let full_text = format!("{}{}", prefix_plus_sep, text);
307
308        // ── Step 2: tokenize full input ──────────────────────────────────
309        let encoding = tokenizer
310            .encode(full_text.as_str(), true)
311            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
312
313        let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
314        let attention_mask: Vec<i64> = encoding
315            .get_attention_mask()
316            .iter()
317            .map(|&x| x as i64)
318            .collect();
319        let seq_len = token_ids.len();
320
321        // ── Step 3: count prefix words (O(n)) ────────────────────────────
322        // Tokenize the prefix alone to determine how many words it contributes.
323        // This tells us where the actual text words begin in the full token sequence.
324        let prefix_encoding = tokenizer
325            .encode(prefix_plus_sep.as_str(), false)
326            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
327        let prefix_word_count = count_distinct_word_ids(prefix_encoding.get_word_ids());
328
329        // ── Step 4: single-pass words_mask + char offset map (O(n)) ──────
330        //
331        // Previous implementation recomputed a HashSet from scratch for each new word
332        // (O(n²) overall). This single pass is O(n):
333        //
334        //   a) words_mask[i] = 1 if token i is the first sub-token of a TEXT word.
335        //   b) text_word_ids — ordered list of global word-IDs for each text word,
336        //      used below to recover byte offsets via the tokenizer's offset table.
337        //   c) word_byte_ranges — global_word_id → (byte_start, byte_end) in full_text,
338        //      built by taking the union of all sub-token offsets for each word.
339        //
340        let word_ids = encoding.get_word_ids();
341        let token_offsets = encoding.get_offsets(); // byte offsets into full_text
342
343        let mut words_mask = vec![0i64; seq_len];
344        let mut last_word_id: Option<u32> = None;
345        let mut cumulative_word_count = 0usize; // prefix + text words seen
346        let mut text_word_count = 0usize;
347        let mut text_word_ids: Vec<u32> = Vec::new();
348        // Union of sub-token byte ranges per global word ID.
349        let mut word_byte_ranges: HashMap<u32, (usize, usize)> = HashMap::new();
350
351        for (i, &wid_opt) in word_ids.iter().enumerate() {
352            let wid = match wid_opt {
353                Some(w) => w,
354                None => {
355                    last_word_id = None;
356                    continue;
357                }
358            };
359
360            // Extend the byte range for this word to cover all its sub-tokens.
361            let (tok_start, tok_end) = token_offsets[i];
362            let entry = word_byte_ranges.entry(wid).or_insert((tok_start, tok_end));
363            if tok_start < entry.0 {
364                entry.0 = tok_start;
365            }
366            if tok_end > entry.1 {
367                entry.1 = tok_end;
368            }
369
370            let is_new_word = last_word_id.map(|lw| lw != wid).unwrap_or(true);
371            if is_new_word {
372                if cumulative_word_count >= prefix_word_count {
373                    words_mask[i] = 1;
374                    text_word_count += 1;
375                    text_word_ids.push(wid);
376                }
377                cumulative_word_count += 1;
378            }
379            last_word_id = Some(wid);
380        }
381
382        if text_word_count == 0 {
383            debug!("No text words after entity type prefix — skipping inference");
384            return Ok(Vec::new());
385        }
386        let text_lengths = vec![text_word_count as i64];
387
388        // Prefix byte length: tokenizer byte offsets for text words are relative to
389        // full_text; subtract this to obtain offsets relative to `text`.
390        let prefix_byte_offset = prefix_plus_sep.len();
391
392        // ── Step 5: enumerate candidate spans ────────────────────────────
393        let mut span_idx_flat: Vec<i64> = Vec::new();
394        let mut span_mask: Vec<bool> = Vec::new();
395
396        for start in 0..text_word_count {
397            for end in start..text_word_count.min(start + MAX_SPAN_WIDTH) {
398                span_idx_flat.push(start as i64);
399                span_idx_flat.push(end as i64);
400                span_mask.push(true);
401            }
402        }
403
404        let num_spans = span_mask.len();
405        if num_spans == 0 {
406            return Ok(Vec::new());
407        }
408
409        // ── Step 6: ORT session forward pass ─────────────────────────────
410        // onnx-community/gliner_medium-v2.1: span_mask → tensor(bool),
411        // all other inputs → tensor(int64).
412        let logits_raw: Vec<f32> = {
413            let mut session_guard = session.lock();
414
415            let input_ids_t = Tensor::<i64>::from_array(([1usize, seq_len], token_ids))
416                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
417            let attn_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], attention_mask))
418                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
419            let words_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], words_mask))
420                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
421            // Shape [1, 1]: GLiNER requires rank-2 tensor for text_lengths.
422            let text_lengths_t = Tensor::<i64>::from_array(([1usize, 1usize], text_lengths))
423                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
424            let span_idx_t = Tensor::<i64>::from_array(([1usize, num_spans, 2], span_idx_flat))
425                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
426            let span_mask_t = Tensor::<bool>::from_array(([1usize, num_spans], span_mask))
427                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
428
429            let outputs = session_guard
430                .run(inputs![
431                    "input_ids" => input_ids_t,
432                    "attention_mask" => attn_mask_t,
433                    "words_mask" => words_mask_t,
434                    "text_lengths" => text_lengths_t,
435                    "span_idx" => span_idx_t,
436                    "span_mask" => span_mask_t,
437                ])
438                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
439
440            // outputs[0] = logits: shape [1, num_spans, num_entity_types]
441            let (_shape, logits_slice) = outputs[0]
442                .try_extract_tensor::<f32>()
443                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
444            logits_slice.to_vec()
445        };
446
447        // logits shape: [1, num_spans, num_entity_types]
448        let num_entity_types = entity_types.len();
449        if logits_raw.len() != num_spans * num_entity_types {
450            warn!(
451                "GLiNER logits shape mismatch: got {}, expected {}",
452                logits_raw.len(),
453                num_spans * num_entity_types
454            );
455            return Ok(Vec::new());
456        }
457
458        // ── Step 7: sigmoid → threshold → per-type NMS ──────────────────
459        let mut raw_entities: Vec<(usize, usize, usize, f32)> = Vec::new(); // (type_idx, start_w, end_w, score)
460
461        for (span_i, (start_w, end_w)) in iter_spans(text_word_count).enumerate() {
462            for (type_i, _) in entity_types.iter().enumerate() {
463                let score = sigmoid(logits_raw[span_i * num_entity_types + type_i]);
464                if score >= DEFAULT_SCORE_THRESHOLD {
465                    raw_entities.push((type_i, start_w, end_w, score));
466                }
467            }
468        }
469
470        // NMS: sort by score, suppress same-type overlapping spans.
471        // Cross-type overlaps are preserved (e.g., "New York" as both location and org).
472        raw_entities.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
473        let mut kept: Vec<(usize, usize, usize, f32)> = Vec::new();
474        'outer: for candidate in &raw_entities {
475            for kept_span in &kept {
476                if kept_span.0 == candidate.0
477                    && kept_span.1 <= candidate.2
478                    && candidate.1 <= kept_span.2
479                {
480                    continue 'outer;
481                }
482            }
483            kept.push(*candidate);
484        }
485
486        // ── Step 8: map word-span indices → byte offsets in `text` ───────
487        // Uses the tokenizer's own byte-offset table (word_byte_ranges built in Step 4)
488        // rather than a whitespace-split approximation. This correctly handles
489        // sub-word-tokenised words: "John's" → tokens ["John", "'s"] share word_id=N,
490        // and their byte range covers the full "John's" span.
491        let mut entities: Vec<ExtractedEntity> = kept
492            .into_iter()
493            .filter_map(|(type_i, start_w, end_w, score)| {
494                let start_wid = *text_word_ids.get(start_w)?;
495                let end_wid = *text_word_ids.get(end_w)?;
496                let &(start_byte_full, _) = word_byte_ranges.get(&start_wid)?;
497                let &(_, end_byte_full) = word_byte_ranges.get(&end_wid)?;
498
499                // Convert from full_text-relative to text-relative byte offsets.
500                let start_byte = start_byte_full.saturating_sub(prefix_byte_offset);
501                let end_byte = end_byte_full.saturating_sub(prefix_byte_offset);
502
503                if start_byte >= end_byte || end_byte > text.len() {
504                    return None;
505                }
506
507                let value = text[start_byte..end_byte].trim().to_string();
508                if value.is_empty() {
509                    return None;
510                }
511
512                // Normalize entity type label for storage (lowercase + underscores).
513                // The original user-provided label was used in the model prompt above.
514                let entity_type = normalize_label(entity_types[type_i]);
515
516                Some(ExtractedEntity {
517                    entity_type,
518                    value,
519                    score,
520                    start: start_byte,
521                    end: end_byte,
522                })
523            })
524            .collect();
525
526        entities.sort_by_key(|e| e.start);
527        debug!("GLiNER extracted {} entities", entities.len());
528        Ok(entities)
529    }
530
531    // ── Model download helpers ────────────────────────────────────────────
532
533    #[instrument(skip_all)]
534    async fn download_model_files() -> Result<(PathBuf, PathBuf)> {
535        info!(
536            "Resolving GLiNER model files: tokenizer={}, onnx={}",
537            GLINER_TOKENIZER_REPO, GLINER_MODEL_REPO
538        );
539
540        let tokenizer_cache = Self::model_cache_dir(GLINER_TOKENIZER_REPO)?;
541        let onnx_cache = Self::model_cache_dir(GLINER_MODEL_REPO)?;
542        let onnx_subdir = onnx_cache.join("onnx");
543        std::fs::create_dir_all(&onnx_subdir)?;
544
545        let local_tokenizer = tokenizer_cache.join("tokenizer.json");
546        let local_onnx = onnx_subdir.join("model_quantized.onnx");
547
548        if !local_tokenizer.exists() || !local_onnx.exists() {
549            let tok_cache = tokenizer_cache.clone();
550            let onnx_c = onnx_cache.clone();
551            let tok_exists = local_tokenizer.exists();
552            let onnx_exists = local_onnx.exists();
553
554            tokio::task::spawn_blocking(move || {
555                if !tok_exists {
556                    crate::engine::EmbeddingEngine::download_hf_file_pub(
557                        GLINER_TOKENIZER_REPO,
558                        "tokenizer.json",
559                        &tok_cache,
560                    )
561                    .map_err(|e| {
562                        InferenceError::HubError(format!(
563                            "Failed to download GLiNER tokenizer: {}",
564                            e
565                        ))
566                    })?;
567                }
568                if !onnx_exists {
569                    crate::engine::EmbeddingEngine::download_hf_file_pub(
570                        GLINER_MODEL_REPO,
571                        GLINER_ONNX_FILE,
572                        &onnx_c,
573                    )
574                    .map_err(|e| {
575                        InferenceError::HubError(format!(
576                            "Failed to download GLiNER ONNX model: {}",
577                            e
578                        ))
579                    })?;
580                }
581                Ok::<_, InferenceError>(())
582            })
583            .await
584            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
585        } else {
586            info!("GLiNER model files found in local cache");
587        }
588
589        let final_onnx = onnx_cache.join(GLINER_ONNX_FILE);
590        Ok((local_tokenizer, final_onnx))
591    }
592
593    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
594        let base = std::env::var("HF_HOME")
595            .map(PathBuf::from)
596            .unwrap_or_else(|_| {
597                let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
598                PathBuf::from(home).join(".cache").join("huggingface")
599            });
600        let dir = base.join("dakera").join(model_id.replace('/', "--"));
601        std::fs::create_dir_all(&dir)?;
602        Ok(dir)
603    }
604}
605
606// ─────────────────────────────────────────────────────────────
607// NerEngine — unified interface (rule-based + GLiNER)
608// ─────────────────────────────────────────────────────────────
609
610/// Unified NER engine combining rule-based and GLiNER extraction.
611pub struct NerEngine {
612    gliner: Option<Arc<GlinerEngine>>,
613}
614
615impl NerEngine {
616    /// Create a NerEngine with only the rule-based extractor (no model download).
617    pub fn rule_based_only() -> Self {
618        Self { gliner: None }
619    }
620
621    /// Create a NerEngine backed by GLiNER (downloads model on first call).
622    pub async fn with_gliner(num_threads: Option<usize>) -> Result<Self> {
623        let gliner = GlinerEngine::new(num_threads).await?;
624        Ok(Self {
625            gliner: Some(Arc::new(gliner)),
626        })
627    }
628
629    /// Extract entities from text.
630    ///
631    /// Always runs the rule-based pre-pass. If GLiNER is loaded and
632    /// `gliner_types` is non-empty, also runs the neural extractor.
633    ///
634    /// Results are merged (rule-based entities are not duplicated by GLiNER),
635    /// then semantically deduplicated by (entity_type, normalized_value) so that
636    /// repeated mentions of the same entity collapse into one entry.
637    pub async fn extract(&self, text: &str, gliner_types: &[&str]) -> Vec<ExtractedEntity> {
638        let mut entities = rule_based_extract(text);
639
640        if let Some(ref gliner) = self.gliner {
641            if !gliner_types.is_empty() {
642                match gliner.extract(text, gliner_types).await {
643                    Ok(neural) => {
644                        for ne in neural {
645                            // Skip if a rule-based pass already covered the same byte span.
646                            if !entities
647                                .iter()
648                                .any(|e| e.start == ne.start && e.end == ne.end)
649                            {
650                                entities.push(ne);
651                            }
652                        }
653                    }
654                    Err(e) => {
655                        warn!("GLiNER extraction failed, using rule-based only: {}", e);
656                    }
657                }
658            }
659        }
660
661        entities.sort_by_key(|e| e.start);
662
663        // Deduplicate repeated mentions of the same (type, value) across the text.
664        deduplicate_entities(entities)
665    }
666}
667
668// ─────────────────────────────────────────────────────────────
669// Helpers
670// ─────────────────────────────────────────────────────────────
671
/// Number of different word IDs in `word_ids`; `None` entries are ignored.
fn count_distinct_word_ids(word_ids: &[Option<u32>]) -> usize {
    word_ids
        .iter()
        .flatten()
        .collect::<std::collections::HashSet<_>>()
        .len()
}
682
/// Truncate text to at most `max_words` whitespace-separated words.
///
/// Returns a prefix of `text` that ends right after the `max_words`-th word
/// (leading whitespace is kept, whitespace following the last kept word is
/// not). Texts with fewer than `max_words` words are returned unchanged, and
/// `max_words == 0` yields an empty slice.
fn truncate_to_word_limit(text: &str, max_words: usize) -> &str {
    // Zero words allowed: nothing of the text may remain.
    if max_words == 0 {
        return &text[..0];
    }

    let mut words_seen = 0usize;
    let mut inside_word = false;

    for (idx, ch) in text.char_indices() {
        if !ch.is_whitespace() {
            inside_word = true;
            continue;
        }
        // Whitespace directly after a word closes that word.
        if inside_word {
            words_seen += 1;
            if words_seen >= max_words {
                return &text[..idx];
            }
        }
        inside_word = false;
    }

    text
}
709
710/// Iterate all valid (start, end) word-index pairs within MAX_SPAN_WIDTH.
711fn iter_spans(num_words: usize) -> impl Iterator<Item = (usize, usize)> {
712    (0..num_words).flat_map(move |start| {
713        let max_end = num_words.min(start + MAX_SPAN_WIDTH);
714        (start..max_end).map(move |end| (start, end))
715    })
716}
717
/// Logistic function, evaluated so that `exp` is only ever called on a
/// non-positive argument and therefore cannot overflow.
#[inline]
fn sigmoid(x: f32) -> f32 {
    // exp(-|x|) lies in (0, 1] for every finite x.
    let decay = (-x.abs()).exp();
    if x >= 0.0 {
        1.0 / (1.0 + decay)
    } else {
        decay / (1.0 + decay)
    }
}
728
729// ─────────────────────────────────────────────────────────────
730// Tests
731// ─────────────────────────────────────────────────────────────
732
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an [`ExtractedEntity`] fixture.
    ///
    /// `end` is derived from `start` plus the byte length of `value`, so the
    /// span offsets can never drift out of sync with the surface form.
    fn entity(entity_type: &str, value: &str, score: f32, start: usize) -> ExtractedEntity {
        ExtractedEntity {
            entity_type: entity_type.to_string(),
            value: value.to_string(),
            score,
            start,
            end: start + value.len(),
        }
    }

    // ── Rule-based pre-pass ──────────────────────────────────

    #[test]
    fn test_rule_based_uuid() {
        let text = "session id is 550e8400-e29b-41d4-a716-446655440000 here";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "uuid"));
    }

    #[test]
    fn test_rule_based_url() {
        let text = "check https://example.com/path?q=1 for details";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_email() {
        let text = "contact alice@example.com for support";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "email"));
        // Email should NOT also be captured as url
        assert!(!entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_iso_date() {
        let text = "released on 2024-03-15 at noon";
        let entities = rule_based_extract(text);
        assert!(entities
            .iter()
            .any(|e| e.entity_type == "date" && e.value == "2024-03-15"));
    }

    #[test]
    fn test_rule_based_natural_date() {
        let text = "meeting on March 15, 2024 at noon";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "date"));
    }

    // ── Tag formatting ───────────────────────────────────────

    #[test]
    fn test_entity_to_tag_lowercase_value() {
        // Values are lowercased in tags for consistent deduplication.
        let e = entity("person", "Alice Smith", 0.9, 0);
        assert_eq!(e.to_tag(), "entity:person:alice smith");
    }

    #[test]
    fn test_entity_to_tag_colon_escaping() {
        // The value itself contains colons; they must not leak into the tag,
        // otherwise the `entity:<type>:<value>` structure becomes ambiguous.
        let e = entity("url", "http://example.com:8080/path", 1.0, 0);
        let tag = e.to_tag();
        let parts: Vec<&str> = tag.splitn(3, ':').collect();
        assert_eq!(parts.len(), 3, "tag should have 3 parts: {}", tag);
        assert_eq!(parts[0], "entity");
        assert_eq!(parts[1], "url");
        assert!(
            !parts[2].contains(':'),
            "value should not contain colons: {}",
            parts[2]
        );
    }

    #[test]
    fn test_entity_to_tag_normalizes_whitespace() {
        let e = entity("person", "  John   Doe  ", 0.9, 0);
        assert_eq!(e.to_tag(), "entity:person:john doe");
    }

    #[test]
    fn test_normalize_label() {
        assert_eq!(normalize_label("Person"), "person");
        assert_eq!(normalize_label("Law Firm"), "law_firm");
        assert_eq!(normalize_label("  ORG  "), "org");
        assert_eq!(normalize_label("ORGANIZATION"), "organization");
        assert_eq!(normalize_label("location"), "location");
    }

    // ── Deduplication ────────────────────────────────────────

    #[test]
    fn test_deduplicate_same_value_different_positions() {
        // "Alice" at position 0 and 20 — keep only one (highest score).
        let entities = vec![
            entity("person", "Alice", 0.8, 0),
            entity("person", "Alice", 0.9, 20),
        ];
        let deduped = deduplicate_entities(entities);
        assert_eq!(
            deduped.len(),
            1,
            "same entity at different positions should be merged"
        );
        assert_eq!(deduped[0].score, 0.9, "should retain highest score");
    }

    #[test]
    fn test_deduplicate_case_insensitive() {
        // "alice" and "Alice" are the same entity.
        let entities = vec![
            entity("person", "alice", 0.7, 10),
            entity("person", "Alice", 0.95, 0),
        ];
        let deduped = deduplicate_entities(entities);
        assert_eq!(
            deduped.len(),
            1,
            "case-insensitive dedup: 'Alice' == 'alice'"
        );
        assert_eq!(deduped[0].score, 0.95);
    }

    #[test]
    fn test_deduplicate_different_types_kept() {
        // "Apple" as person vs. organization — both must be kept.
        let entities = vec![
            entity("person", "Apple", 0.6, 0),
            entity("organization", "Apple", 0.9, 0),
        ];
        let deduped = deduplicate_entities(entities);
        assert_eq!(
            deduped.len(),
            2,
            "same value with different types must be kept separately"
        );
    }

    // ── Truncation ───────────────────────────────────────────

    #[test]
    fn test_truncate_to_word_limit_long() {
        let words: Vec<String> = (0..500).map(|i| format!("word{}", i)).collect();
        let text = words.join(" ");
        let truncated = truncate_to_word_limit(&text, 300);
        // A bare upper bound would also accept an empty result, so pin both the
        // prefix property and the exact number of words kept.
        assert!(
            text.starts_with(truncated),
            "truncation must return a prefix of the input"
        );
        let word_count = truncated.split_whitespace().count();
        assert_eq!(word_count, 300, "truncated text must keep exactly 300 words");
    }

    #[test]
    fn test_truncate_to_word_limit_short_pass_through() {
        let text = "Hello world this is fine";
        assert_eq!(
            truncate_to_word_limit(text, 300),
            text,
            "short text must pass through unchanged"
        );
    }

    // ── Numeric helpers ──────────────────────────────────────

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-4);
        assert!((sigmoid(-100.0) - 0.0).abs() < 1e-4);
        // Extreme logits must saturate instead of producing inf/NaN.
        assert!(sigmoid(1000.0).is_finite());
        assert!(sigmoid(-1000.0).is_finite());
    }

    #[test]
    fn test_count_distinct_word_ids() {
        let wids: Vec<Option<u32>> =
            vec![Some(0), Some(0), Some(1), Some(1), Some(2), None, Some(3)];
        assert_eq!(count_distinct_word_ids(&wids), 4);
    }
}