Skip to main content

inference/
ner.rs

1//! Named Entity Recognition (NER) engine — CE-4 GLiNER zero-shot NER.
2//!
3//! Two-layer extraction pipeline:
4//! 1. **Rule-based pre-pass** — regex extraction of dates, URLs, UUIDs, emails, IPs.
5//!    Always on, zero latency, no model download required.
6//! 2. **GLiNER ONNX engine** — zero-shot NER via GLiNER-medium ONNX INT8 (52 MB).
7//!    Opt-in per namespace, lazy-loaded on first use.
8//!
9//! Extracted entities are stored as tags: `entity:person:Alice`, `entity:org:Anthropic`.
10//!
11
12use crate::error::{InferenceError, Result};
13use ort::inputs;
14use ort::session::builder::GraphOptimizationLevel;
15use ort::session::Session;
16use ort::value::Tensor;
17use parking_lot::Mutex;
18use regex::Regex;
19use std::path::PathBuf;
20use std::sync::Arc;
21use tokenizers::Tokenizer;
22use tracing::{debug, info, instrument, warn};
23
24// ─────────────────────────────────────────────────────────────
25// Public types
26// ─────────────────────────────────────────────────────────────
27
28/// A single extracted entity.
29#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
30pub struct ExtractedEntity {
31    /// Normalised entity type (e.g. "person", "org", "date", "url").
32    pub entity_type: String,
33    /// The entity surface form extracted from the text.
34    pub value: String,
35    /// Confidence score 0.0–1.0 (rule-based entities get 1.0).
36    pub score: f32,
37    /// Byte start offset in the original text.
38    pub start: usize,
39    /// Byte end offset in the original text.
40    pub end: usize,
41}
42
43impl ExtractedEntity {
44    /// Convert to the canonical tag format `entity:<type>:<value>`.
45    pub fn to_tag(&self) -> String {
46        let v = self.value.replace(':', "_");
47        format!("entity:{}:{}", self.entity_type, v)
48    }
49}
50
51// ─────────────────────────────────────────────────────────────
52// Rule-based pre-pass
53// ─────────────────────────────────────────────────────────────
54
55struct RulePatterns {
56    uuid: Regex,
57    url: Regex,
58    email: Regex,
59    iso_date: Regex,
60    natural_date: Regex,
61    ip_v4: Regex,
62}
63
64impl RulePatterns {
65    fn new() -> Self {
66        Self {
67            uuid: Regex::new(
68                r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b",
69            )
70            .expect("uuid regex"),
71            url: Regex::new(r#"https?://[^\s<>\[\]()"']+"#)
72                .expect("url regex"),
73            email: Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
74                .expect("email regex"),
75            iso_date: Regex::new(
76                r"\b\d{4}-(?:0[1-9]|1[0-2])-(?:0[1-9]|[12]\d|3[01])\b",
77            )
78            .expect("iso_date regex"),
79            natural_date: Regex::new(
80                r"(?i)\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:,\s*\d{4})?\b",
81            )
82            .expect("natural_date regex"),
83            ip_v4: Regex::new(
84                r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b",
85            )
86            .expect("ipv4 regex"),
87        }
88    }
89}
90
91lazy_static::lazy_static! {
92    static ref RULE_PATTERNS: RulePatterns = RulePatterns::new();
93}
94
95/// Run the rule-based pre-pass — O(n) regex scan, zero model overhead.
96///
97/// Always extracts: uuid, url, email, date (ISO + natural), ipv4.
98pub fn rule_based_extract(text: &str) -> Vec<ExtractedEntity> {
99    let mut entities: Vec<ExtractedEntity> = Vec::new();
100
101    let push = |entities: &mut Vec<ExtractedEntity>, entity_type: &str, m: regex::Match| {
102        entities.push(ExtractedEntity {
103            entity_type: entity_type.to_string(),
104            value: m.as_str().to_string(),
105            score: 1.0,
106            start: m.start(),
107            end: m.end(),
108        });
109    };
110
111    // Order matters — email before URL (email contains @, URL starts with http)
112    for m in RULE_PATTERNS.email.find_iter(text) {
113        push(&mut entities, "email", m);
114    }
115    for m in RULE_PATTERNS.url.find_iter(text) {
116        // Skip if already captured as email
117        if !entities.iter().any(|e| e.start == m.start()) {
118            push(&mut entities, "url", m);
119        }
120    }
121    for m in RULE_PATTERNS.uuid.find_iter(text) {
122        push(&mut entities, "uuid", m);
123    }
124    for m in RULE_PATTERNS.iso_date.find_iter(text) {
125        push(&mut entities, "date", m);
126    }
127    for m in RULE_PATTERNS.natural_date.find_iter(text) {
128        // Only if no ISO date already at this offset
129        if !entities
130            .iter()
131            .any(|e| e.start == m.start() && e.entity_type == "date")
132        {
133            push(&mut entities, "date", m);
134        }
135    }
136    for m in RULE_PATTERNS.ip_v4.find_iter(text) {
137        push(&mut entities, "ip", m);
138    }
139
140    entities
141}
142
143// ─────────────────────────────────────────────────────────────
144// GLiNER ONNX engine
145// ─────────────────────────────────────────────────────────────
146
/// Hugging Face repository that hosts the GLiNER ONNX model.
const GLINER_MODEL_REPO: &str = "onnx-community/gliner_medium-v2.1";
/// Hugging Face repository that hosts the tokenizer (currently the model's repository).
const GLINER_TOKENIZER_REPO: &str = "onnx-community/gliner_medium-v2.1";
/// Repository-relative path of the quantized ONNX model file.
const GLINER_ONNX_FILE: &str = "onnx/model_quantized.onnx";

/// Widest candidate span, counted in words (GLiNER default).
const MAX_SPAN_WIDTH: usize = 12;
/// Minimum sigmoid score a span prediction needs to be accepted.
const SCORE_THRESHOLD: f32 = 0.5;
155
156/// GLiNER zero-shot NER engine backed by ONNX Runtime.
157///
158/// Thread-safe — the session is mutex-guarded.
159pub struct GlinerEngine {
160    session: Arc<Mutex<Session>>,
161    tokenizer: Arc<Tokenizer>,
162}
163
164impl GlinerEngine {
165    /// Create a new GLiNER engine, downloading the model if not cached.
166    #[instrument(skip_all)]
167    pub async fn new(num_threads: Option<usize>) -> Result<Self> {
168        let threads = num_threads.unwrap_or(1);
169        info!("Initializing GLiNER NER engine (threads={})", threads);
170
171        let (tokenizer_path, onnx_path) = Self::download_model_files().await?;
172
173        let tokenizer = Tokenizer::from_file(&tokenizer_path)
174            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
175
176        let session = Session::builder()
177            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
178            .with_optimization_level(GraphOptimizationLevel::Level3)
179            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
180            .with_intra_threads(threads)
181            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
182            .commit_from_file(&onnx_path)
183            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
184
185        info!("GLiNER engine ready");
186        Ok(Self {
187            session: Arc::new(Mutex::new(session)),
188            tokenizer: Arc::new(tokenizer),
189        })
190    }
191
192    /// Extract entities from text for the given entity types.
193    ///
194    /// Returns deduplicated, threshold-filtered entities sorted by start offset.
195    pub async fn extract(&self, text: &str, entity_types: &[&str]) -> Result<Vec<ExtractedEntity>> {
196        if entity_types.is_empty() || text.is_empty() {
197            return Ok(Vec::new());
198        }
199
200        let text_owned = text.to_string();
201        let entity_types_owned: Vec<String> = entity_types.iter().map(|s| s.to_string()).collect();
202        let session = self.session.clone();
203        let tokenizer = self.tokenizer.clone();
204
205        tokio::task::spawn_blocking(move || {
206            Self::run_inference(
207                &text_owned,
208                &entity_types_owned
209                    .iter()
210                    .map(|s| s.as_str())
211                    .collect::<Vec<_>>(),
212                &session,
213                &tokenizer,
214            )
215        })
216        .await
217        .map_err(|e| InferenceError::HubError(format!("GLiNER inference task panicked: {}", e)))?
218    }
219
220    fn run_inference(
221        text: &str,
222        entity_types: &[&str],
223        session: &Arc<Mutex<Session>>,
224        tokenizer: &Tokenizer,
225    ) -> Result<Vec<ExtractedEntity>> {
226        // ── Step 1: build input text ──────────────────────────────────────
227        // Format: "type1 << >> type2 << >> text"
228        let mut full_text = entity_types.join(" << >> ");
229        full_text.push_str(" << >> ");
230        full_text.push_str(text);
231
232        // ── Step 2: tokenize ──────────────────────────────────────────────
233        let encoding = tokenizer
234            .encode(full_text.as_str(), true)
235            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
236
237        let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
238        let attention_mask: Vec<i64> = encoding
239            .get_attention_mask()
240            .iter()
241            .map(|&x| x as i64)
242            .collect();
243        let seq_len = token_ids.len();
244
245        // ── Step 3: compute words_mask and text_length ────────────────────
246        // words_mask[i] = 1 if token i is the start of a word, 0 otherwise.
247        // We use word_ids from the fast tokenizer encoding.
248        let word_ids = encoding.get_word_ids();
249
250        let mut words_mask = vec![0i64; seq_len];
251        let mut last_word_id: Option<u32> = None;
252        let mut text_token_start = usize::MAX;
253
254        // Count words in the entity type prefix to find where text begins.
255        // Prefix format: "type1 << >> type2 << >> " — count distinct word_ids
256        // that appear before the text's first word.
257        // Simpler: find the token at which the text portion starts by
258        // counting words in the prefix.
259        let prefix = entity_types.join(" << >> ");
260        let prefix_plus_sep = format!("{} << >> ", prefix);
261        let prefix_encoding = tokenizer
262            .encode(prefix_plus_sep.as_str(), false)
263            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
264        let prefix_word_count = prefix_encoding
265            .get_word_ids()
266            .iter()
267            .filter_map(|&w| w)
268            .collect::<std::collections::HashSet<_>>()
269            .len();
270
271        let mut text_word_count = 0usize;
272        for (i, &wid_opt) in word_ids.iter().enumerate() {
273            let wid = match wid_opt {
274                Some(w) => w,
275                None => {
276                    last_word_id = None;
277                    continue;
278                }
279            };
280            let is_new_word = last_word_id.map(|lw| lw != wid).unwrap_or(true);
281            if is_new_word {
282                let global_word_idx = {
283                    // count distinct words so far
284                    word_ids[..i]
285                        .iter()
286                        .filter_map(|&w| w)
287                        .collect::<std::collections::HashSet<_>>()
288                        .len()
289                };
290                if global_word_idx >= prefix_word_count {
291                    if text_token_start == usize::MAX {
292                        text_token_start = i;
293                    }
294                    words_mask[i] = 1;
295                    text_word_count += 1;
296                }
297            }
298            last_word_id = Some(wid);
299        }
300
301        if text_word_count == 0 || text_token_start == usize::MAX {
302            debug!("No text words found after entity type prefix, skipping inference");
303            return Ok(Vec::new());
304        }
305        let text_lengths = vec![text_word_count as i64];
306
307        // ── Step 4: generate candidate spans ─────────────────────────────
308        // Enumerate all (start, end) word pairs within MAX_SPAN_WIDTH.
309        let mut span_idx_flat: Vec<i64> = Vec::new();
310        let mut span_mask: Vec<bool> = Vec::new();
311
312        for start in 0..text_word_count {
313            for end in start..text_word_count.min(start + MAX_SPAN_WIDTH) {
314                span_idx_flat.push(start as i64);
315                span_idx_flat.push(end as i64);
316                span_mask.push(true);
317            }
318        }
319
320        let num_spans = span_mask.len();
321        if num_spans == 0 {
322            return Ok(Vec::new());
323        }
324
325        // ── Step 5: run ORT session ───────────────────────────────────────
326        // onnx-community/gliner_medium-v2.1 defines span_mask as tensor(bool).
327        // All other inputs (input_ids, attention_mask, words_mask, text_lengths,
328        // span_idx) are tensor(int64) — verified against the ONNX proto elem_type.
329
330        let logits_raw: Vec<f32> = {
331            let mut session_guard = session.lock();
332
333            // ort rc.12: Tensor::from_array takes (shape_array, owned_Vec)
334            let input_ids_t = Tensor::<i64>::from_array(([1usize, seq_len], token_ids))
335                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
336            let attn_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], attention_mask))
337                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
338            let words_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], words_mask))
339                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
340            // Shape [batch=1, 1]: GLiNER model requires rank-2 tensor for text_lengths.
341            // All other inputs use [1, N] shape; text_lengths is [1, 1] for single text.
342            let text_lengths_t = Tensor::<i64>::from_array(([1usize, 1usize], text_lengths))
343                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
344            let span_idx_t = Tensor::<i64>::from_array(([1usize, num_spans, 2], span_idx_flat))
345                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
346            // span_mask is bool in the ONNX model graph — must use Tensor::<bool>.
347            let span_mask_t = Tensor::<bool>::from_array(([1usize, num_spans], span_mask))
348                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
349
350            let outputs = session_guard
351                .run(inputs![
352                    "input_ids" => input_ids_t,
353                    "attention_mask" => attn_mask_t,
354                    "words_mask" => words_mask_t,
355                    "text_lengths" => text_lengths_t,
356                    "span_idx" => span_idx_t,
357                    "span_mask" => span_mask_t,
358                ])
359                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
360
361            // outputs[0] = logits: shape [1, num_spans, num_entity_types]
362            let (_shape, logits_slice) = outputs[0]
363                .try_extract_tensor::<f32>()
364                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
365            logits_slice.to_vec()
366        };
367
368        // logits shape: [1, num_spans, num_entity_types]
369        let num_entity_types = entity_types.len();
370        let expected = num_spans * num_entity_types;
371        if logits_raw.len() != expected {
372            warn!(
373                "GLiNER logits shape mismatch: got {}, expected {}",
374                logits_raw.len(),
375                expected
376            );
377            return Ok(Vec::new());
378        }
379
380        // ── Step 6: post-process — sigmoid, threshold, NMS ───────────────
381        let mut raw_entities: Vec<(usize, usize, usize, f32)> = Vec::new(); // (type_idx, start, end, score)
382
383        for (span_i, (start_w, end_w)) in iter_spans(text_word_count).enumerate() {
384            for (type_i, _entity_type) in entity_types.iter().enumerate() {
385                let logit = logits_raw[span_i * num_entity_types + type_i];
386                let score = sigmoid(logit);
387                if score >= SCORE_THRESHOLD {
388                    raw_entities.push((type_i, start_w, end_w, score));
389                }
390            }
391        }
392
393        // NMS: keep highest-score non-overlapping spans per entity type.
394        raw_entities.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
395        let mut kept: Vec<(usize, usize, usize, f32)> = Vec::new();
396        'outer: for candidate in &raw_entities {
397            for kept_span in &kept {
398                // Same type and overlapping
399                if kept_span.0 == candidate.0
400                    && kept_span.1 <= candidate.2
401                    && candidate.1 <= kept_span.2
402                {
403                    continue 'outer;
404                }
405            }
406            kept.push(*candidate);
407        }
408
409        // ── Step 7: map word offsets back to char offsets ─────────────────
410        let words: Vec<&str> = text.split_whitespace().collect();
411        let mut word_char_starts: Vec<usize> = Vec::with_capacity(words.len());
412        let mut word_char_ends: Vec<usize> = Vec::with_capacity(words.len());
413        {
414            let mut char_pos = 0usize;
415            for word in &words {
416                // Find this word's start in the original text
417                if let Some(rel) = text[char_pos..].find(word) {
418                    let start = char_pos + rel;
419                    let end = start + word.len();
420                    word_char_starts.push(start);
421                    word_char_ends.push(end);
422                    char_pos = end;
423                } else {
424                    word_char_starts.push(char_pos);
425                    word_char_ends.push(char_pos);
426                }
427            }
428        }
429
430        let mut entities: Vec<ExtractedEntity> = kept
431            .into_iter()
432            .filter_map(|(type_i, start_w, end_w, score)| {
433                let start_char = *word_char_starts.get(start_w)?;
434                let end_char = *word_char_ends.get(end_w)?;
435                let value = text[start_char..end_char].to_string();
436                Some(ExtractedEntity {
437                    entity_type: entity_types[type_i].to_lowercase().replace(' ', "_"),
438                    value,
439                    score,
440                    start: start_char,
441                    end: end_char,
442                })
443            })
444            .collect();
445
446        entities.sort_by_key(|e| e.start);
447        debug!("GLiNER extracted {} entities", entities.len());
448        Ok(entities)
449    }
450
451    // ── Model download helpers ────────────────────────────────────────────
452
453    #[instrument(skip_all)]
454    async fn download_model_files() -> Result<(PathBuf, PathBuf)> {
455        info!(
456            "Resolving GLiNER model files: tokenizer={}, onnx={}",
457            GLINER_TOKENIZER_REPO, GLINER_MODEL_REPO
458        );
459
460        let tokenizer_cache = Self::model_cache_dir(GLINER_TOKENIZER_REPO)?;
461        let onnx_cache = Self::model_cache_dir(GLINER_MODEL_REPO)?;
462        let onnx_subdir = onnx_cache.join("onnx");
463        std::fs::create_dir_all(&onnx_subdir)?;
464
465        let local_tokenizer = tokenizer_cache.join("tokenizer.json");
466        let local_onnx = onnx_subdir.join("model_quantized.onnx");
467
468        if !local_tokenizer.exists() || !local_onnx.exists() {
469            let tok_cache = tokenizer_cache.clone();
470            let onnx_c = onnx_cache.clone();
471            let tok_exists = local_tokenizer.exists();
472            let onnx_exists = local_onnx.exists();
473
474            tokio::task::spawn_blocking(move || {
475                if !tok_exists {
476                    crate::engine::EmbeddingEngine::download_hf_file_pub(
477                        GLINER_TOKENIZER_REPO,
478                        "tokenizer.json",
479                        &tok_cache,
480                    )
481                    .map_err(|e| {
482                        InferenceError::HubError(format!(
483                            "Failed to download GLiNER tokenizer: {}",
484                            e
485                        ))
486                    })?;
487                }
488                if !onnx_exists {
489                    crate::engine::EmbeddingEngine::download_hf_file_pub(
490                        GLINER_MODEL_REPO,
491                        GLINER_ONNX_FILE,
492                        &onnx_c,
493                    )
494                    .map_err(|e| {
495                        InferenceError::HubError(format!(
496                            "Failed to download GLiNER ONNX model: {}",
497                            e
498                        ))
499                    })?;
500                }
501                Ok::<_, InferenceError>(())
502            })
503            .await
504            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
505        } else {
506            info!("GLiNER model files found in local cache");
507        }
508
509        let final_onnx = onnx_cache.join(GLINER_ONNX_FILE);
510        Ok((local_tokenizer, final_onnx))
511    }
512
513    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
514        let base = std::env::var("HF_HOME")
515            .map(PathBuf::from)
516            .unwrap_or_else(|_| {
517                let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
518                PathBuf::from(home).join(".cache").join("huggingface")
519            });
520        let dir = base.join("dakera").join(model_id.replace('/', "--"));
521        std::fs::create_dir_all(&dir)?;
522        Ok(dir)
523    }
524}
525
526// ─────────────────────────────────────────────────────────────
527// NerEngine — unified interface (rule-based + GLiNER)
528// ─────────────────────────────────────────────────────────────
529
530/// Unified NER engine combining rule-based and GLiNER extraction.
531pub struct NerEngine {
532    gliner: Option<Arc<GlinerEngine>>,
533}
534
535impl NerEngine {
536    /// Create a NerEngine with only the rule-based extractor (no model download).
537    pub fn rule_based_only() -> Self {
538        Self { gliner: None }
539    }
540
541    /// Create a NerEngine backed by GLiNER (downloads model on first call).
542    pub async fn with_gliner(num_threads: Option<usize>) -> Result<Self> {
543        let gliner = GlinerEngine::new(num_threads).await?;
544        Ok(Self {
545            gliner: Some(Arc::new(gliner)),
546        })
547    }
548
549    /// Extract entities from text.
550    ///
551    /// Always runs the rule-based pre-pass. If GLiNER is loaded and
552    /// `gliner_types` is non-empty, also runs the neural extractor.
553    /// Results are merged and deduplicated by offset.
554    pub async fn extract(&self, text: &str, gliner_types: &[&str]) -> Vec<ExtractedEntity> {
555        let mut entities = rule_based_extract(text);
556
557        if let Some(ref gliner) = self.gliner {
558            if !gliner_types.is_empty() {
559                match gliner.extract(text, gliner_types).await {
560                    Ok(neural) => {
561                        for ne in neural {
562                            // Skip if rule-based already captured the same span
563                            if !entities
564                                .iter()
565                                .any(|e| e.start == ne.start && e.end == ne.end)
566                            {
567                                entities.push(ne);
568                            }
569                        }
570                    }
571                    Err(e) => {
572                        warn!("GLiNER extraction failed, using rule-based only: {}", e);
573                    }
574                }
575            }
576        }
577
578        entities.sort_by_key(|e| e.start);
579        entities
580    }
581}
582
583// ─────────────────────────────────────────────────────────────
584// Helpers
585// ─────────────────────────────────────────────────────────────
586
587/// Iterate all valid (start, end) word index pairs up to MAX_SPAN_WIDTH.
588fn iter_spans(num_words: usize) -> impl Iterator<Item = (usize, usize)> {
589    (0..num_words).flat_map(move |start| {
590        let max_end = num_words.min(start + MAX_SPAN_WIDTH);
591        (start..max_end).map(move |end| (start, end))
592    })
593}
594
/// Logistic function, evaluated so that `exp` is only ever applied to a
/// non-positive argument and therefore cannot overflow.
#[inline]
fn sigmoid(x: f32) -> f32 {
    let decay = (-x.abs()).exp();
    let denom = 1.0 + decay;
    if x >= 0.0 {
        1.0 / denom
    } else {
        decay / denom
    }
}
605
606// ─────────────────────────────────────────────────────────────
607// Tests
608// ─────────────────────────────────────────────────────────────
609
#[cfg(test)]
mod tests {
    //! Unit tests for the rule-based pre-pass, tag formatting and the sigmoid helper.
    //! None of them needs the GLiNER model.

    use super::*;

    #[test]
    fn test_rule_based_uuid() {
        let text = "session id is 550e8400-e29b-41d4-a716-446655440000 here";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "uuid"));
    }

    #[test]
    fn test_rule_based_url() {
        let text = "check https://example.com/path?q=1 for details";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_email() {
        let text = "contact alice@example.com for support";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "email"));
        // Email should NOT also be captured as url
        assert!(!entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_iso_date() {
        let text = "released on 2024-03-15 at noon";
        let entities = rule_based_extract(text);
        assert!(entities
            .iter()
            .any(|e| e.entity_type == "date" && e.value == "2024-03-15"));
    }

    #[test]
    fn test_rule_based_natural_date() {
        let text = "meeting on March 15, 2024 at noon";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "date"));
    }

    #[test]
    fn test_rule_based_ipv4() {
        let text = "server at 192.168.1.10 is down";
        let entities = rule_based_extract(text);
        assert!(entities
            .iter()
            .any(|e| e.entity_type == "ip" && e.value == "192.168.1.10"));
    }

    #[test]
    fn test_entity_to_tag() {
        let e = ExtractedEntity {
            entity_type: "person".to_string(),
            value: "Alice Smith".to_string(),
            score: 0.9,
            start: 0,
            end: 11,
        };
        assert_eq!(e.to_tag(), "entity:person:Alice Smith");
    }

    #[test]
    fn test_entity_to_tag_colon_escaping() {
        let e = ExtractedEntity {
            entity_type: "url".to_string(),
            value: "http://example.com:8080/path".to_string(),
            score: 1.0,
            start: 0,
            // The value is 28 bytes long, so the exclusive end offset is 28.
            end: 28,
        };
        let tag = e.to_tag();
        // Tag has format entity:<type>:<value> — exactly 2 colons before value.
        // Colons within the value are replaced with underscores.
        let parts: Vec<&str> = tag.splitn(3, ':').collect();
        assert_eq!(parts.len(), 3, "tag should have 3 parts: {}", tag);
        assert_eq!(parts[0], "entity");
        assert_eq!(parts[1], "url");
        assert!(
            !parts[2].contains(':'),
            "value should not contain colons: {}",
            parts[2]
        );
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-4);
        assert!((sigmoid(-100.0) - 0.0).abs() < 1e-4);
    }
}