inference/ner.rs

//! Named Entity Recognition (NER) engine — CE-4 GLiNER zero-shot NER.
//!
//! Two-layer extraction pipeline:
//! 1. **Rule-based pre-pass** — regex extraction of dates, URLs, UUIDs, emails, IPs.
//!    Always on, negligible latency, no model download required.
//! 2. **GLiNER ONNX engine** — zero-shot NER via GLiNER-medium ONNX INT8 (52 MB).
//!    Opt-in per namespace, lazy-loaded on first use.
//!
//! Extracted entities are stored as tags: `entity:person:Alice`, `entity:org:Anthropic`.
//!
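//! A minimal usage sketch (illustrative, not compiled here; assumes a Tokio
//! runtime and that `NerEngine` is reachable at this path):
//!
//! ```ignore
//! let ner = NerEngine::rule_based_only();
//! let entities = ner.extract("mail bob@example.com by 2024-03-15", &[]).await;
//! for e in &entities {
//!     println!("{}", e.to_tag()); // entity:email:bob@example.com, entity:date:2024-03-15
//! }
//! ```
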
use crate::error::{InferenceError, Result};
use ort::inputs;
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
use ort::value::Tensor;
use parking_lot::Mutex;
use regex::Regex;
use std::path::PathBuf;
use std::sync::Arc;
use tokenizers::Tokenizer;
use tracing::{debug, info, instrument, warn};

// ─────────────────────────────────────────────────────────────
// Public types
// ─────────────────────────────────────────────────────────────

/// A single extracted entity.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ExtractedEntity {
    /// Normalised entity type (e.g. "person", "org", "date", "url").
    pub entity_type: String,
    /// The entity surface form extracted from the text.
    pub value: String,
    /// Confidence score 0.0–1.0 (rule-based entities get 1.0).
    pub score: f32,
    /// Byte start offset in the original text.
    pub start: usize,
    /// Byte end offset in the original text.
    pub end: usize,
}

impl ExtractedEntity {
    /// Convert to the canonical tag format `entity:<type>:<value>`.
    pub fn to_tag(&self) -> String {
        let v = self.value.replace(':', "_");
        format!("entity:{}:{}", self.entity_type, v)
    }
}

// ─────────────────────────────────────────────────────────────
// Rule-based pre-pass
// ─────────────────────────────────────────────────────────────

struct RulePatterns {
    uuid: Regex,
    url: Regex,
    email: Regex,
    iso_date: Regex,
    natural_date: Regex,
    ip_v4: Regex,
}

impl RulePatterns {
    fn new() -> Self {
        Self {
            uuid: Regex::new(
                r"(?i)\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b",
            )
            .expect("uuid regex"),
            url: Regex::new(r#"https?://[^\s<>\[\]()"']+"#)
                .expect("url regex"),
            email: Regex::new(r"[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}")
                .expect("email regex"),
            iso_date: Regex::new(
                r"\b\d{4}-(?:0[1-9]|1[0-2])-(?:0[1-9]|[12]\d|3[01])\b",
            )
            .expect("iso_date regex"),
            natural_date: Regex::new(
                r"(?i)\b(?:Jan(?:uary)?|Feb(?:ruary)?|Mar(?:ch)?|Apr(?:il)?|May|Jun(?:e)?|Jul(?:y)?|Aug(?:ust)?|Sep(?:tember)?|Oct(?:ober)?|Nov(?:ember)?|Dec(?:ember)?)\s+\d{1,2}(?:,\s*\d{4})?\b",
            )
            .expect("natural_date regex"),
            ip_v4: Regex::new(
                r"\b(?:(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(?:25[0-5]|2[0-4]\d|[01]?\d\d?)\b",
            )
            .expect("ipv4 regex"),
        }
    }
}

lazy_static::lazy_static! {
    static ref RULE_PATTERNS: RulePatterns = RulePatterns::new();
}

/// Run the rule-based pre-pass — O(n) regex passes, zero model overhead.
///
/// Always extracts: uuid, url, email, date (ISO + natural), ipv4.
pub fn rule_based_extract(text: &str) -> Vec<ExtractedEntity> {
    let mut entities: Vec<ExtractedEntity> = Vec::new();

    let push = |entities: &mut Vec<ExtractedEntity>, entity_type: &str, m: regex::Match| {
        entities.push(ExtractedEntity {
            entity_type: entity_type.to_string(),
            value: m.as_str().to_string(),
            score: 1.0,
            start: m.start(),
            end: m.end(),
        });
    };

    // Order matters — emails run first, and later passes skip any match whose
    // start offset has already been claimed.
    for m in RULE_PATTERNS.email.find_iter(text) {
        push(&mut entities, "email", m);
    }
    for m in RULE_PATTERNS.url.find_iter(text) {
        // Skip if an entity already starts at this offset
        if !entities.iter().any(|e| e.start == m.start()) {
            push(&mut entities, "url", m);
        }
    }
    for m in RULE_PATTERNS.uuid.find_iter(text) {
        push(&mut entities, "uuid", m);
    }
    for m in RULE_PATTERNS.iso_date.find_iter(text) {
        push(&mut entities, "date", m);
    }
    for m in RULE_PATTERNS.natural_date.find_iter(text) {
        // Only if no ISO date already starts at this offset
        if !entities
            .iter()
            .any(|e| e.start == m.start() && e.entity_type == "date")
        {
            push(&mut entities, "date", m);
        }
    }
    for m in RULE_PATTERNS.ip_v4.find_iter(text) {
        push(&mut entities, "ip", m);
    }

    entities
}

// ─────────────────────────────────────────────────────────────
// GLiNER ONNX engine
// ─────────────────────────────────────────────────────────────

const GLINER_MODEL_REPO: &str = "onnx-community/gliner-medium-v2.1";
const GLINER_TOKENIZER_REPO: &str = "knowledgator/gliner-medium-v2.1";
const GLINER_ONNX_FILE: &str = "onnx/model_quantized.onnx";

/// Maximum span width in words (GLiNER default).
const MAX_SPAN_WIDTH: usize = 12;
/// Confidence threshold for accepting a span prediction.
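/// At 0.5 this is equivalent to keeping spans whose raw logit is >= 0.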
const SCORE_THRESHOLD: f32 = 0.5;

/// GLiNER zero-shot NER engine backed by ONNX Runtime.
///
/// Thread-safe — the session is mutex-guarded.
pub struct GlinerEngine {
    session: Arc<Mutex<Session>>,
    tokenizer: Arc<Tokenizer>,
}

impl GlinerEngine {
    /// Create a new GLiNER engine, downloading the model if not cached.
    #[instrument(skip_all)]
    pub async fn new(num_threads: Option<usize>) -> Result<Self> {
        let threads = num_threads.unwrap_or(1);
        info!("Initializing GLiNER NER engine (threads={})", threads);

        let (tokenizer_path, onnx_path) = Self::download_model_files().await?;

        let tokenizer = Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        let session = Session::builder()
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .with_intra_threads(threads)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?
            .commit_from_file(&onnx_path)
            .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

        info!("GLiNER engine ready");
        Ok(Self {
            session: Arc::new(Mutex::new(session)),
            tokenizer: Arc::new(tokenizer),
        })
    }

    /// Extract entities from text for the given entity types.
    ///
    /// Returns deduplicated, threshold-filtered entities sorted by start offset.
    pub async fn extract(&self, text: &str, entity_types: &[&str]) -> Result<Vec<ExtractedEntity>> {
        if entity_types.is_empty() || text.is_empty() {
            return Ok(Vec::new());
        }

        let text_owned = text.to_string();
        let entity_types_owned: Vec<String> = entity_types.iter().map(|s| s.to_string()).collect();
        let session = self.session.clone();
        let tokenizer = self.tokenizer.clone();

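        // Inference is CPU-bound; run it on the blocking pool so the async
        // executor is not stalled for the duration of the forward pass.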
        tokio::task::spawn_blocking(move || {
            Self::run_inference(
                &text_owned,
                &entity_types_owned
                    .iter()
                    .map(|s| s.as_str())
                    .collect::<Vec<_>>(),
                &session,
                &tokenizer,
            )
        })
        .await
        .map_err(|e| InferenceError::HubError(format!("GLiNER inference task panicked: {}", e)))?
    }

    fn run_inference(
        text: &str,
        entity_types: &[&str],
        session: &Arc<Mutex<Session>>,
        tokenizer: &Tokenizer,
    ) -> Result<Vec<ExtractedEntity>> {
        // ── Step 1: build input text ──────────────────────────────────────
        // Format: "type1 << >> type2 << >> text"
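        // e.g. entity_types = ["person", "org"], text = "Alice joined Anthropic"
        //   => "person << >> org << >> Alice joined Anthropic"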
        let mut full_text = entity_types.join(" << >> ");
        full_text.push_str(" << >> ");
        full_text.push_str(text);

        // ── Step 2: tokenize ──────────────────────────────────────────────
        let encoding = tokenizer
            .encode(full_text.as_str(), true)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;

        let token_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
        let attention_mask: Vec<i64> = encoding
            .get_attention_mask()
            .iter()
            .map(|&x| x as i64)
            .collect();
        let seq_len = token_ids.len();

        // ── Step 3: compute words_mask and text_length ────────────────────
        // words_mask[i] = 1 if token i is the start of a word, 0 otherwise.
        // We use word_ids from the fast tokenizer encoding.
        let word_ids = encoding.get_word_ids();

        let mut words_mask = vec![0i64; seq_len];
        let mut last_word_id: Option<u32> = None;
        let mut text_token_start = usize::MAX;

        // Find where the text portion begins: tokenize the entity-type prefix
        // ("type1 << >> type2 << >> ") on its own and count its distinct
        // word_ids. Any word whose global index is >= that count belongs to
        // the text rather than the prefix.
        let prefix = entity_types.join(" << >> ");
        let prefix_plus_sep = format!("{} << >> ", prefix);
        let prefix_encoding = tokenizer
            .encode(prefix_plus_sep.as_str(), false)
            .map_err(|e| InferenceError::TokenizationError(e.to_string()))?;
        let prefix_word_count = prefix_encoding
            .get_word_ids()
            .iter()
            .filter_map(|&w| w)
            .collect::<std::collections::HashSet<_>>()
            .len();
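        // e.g. ["person", "org"] gives the prefix "person << >> org << >> ",
        // which a whitespace-style pre-tokenizer splits into 6 words, so the
        // text's first word lands at global index 6 (illustrative; the exact
        // count depends on the tokenizer's pre-tokenization rules).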

        let mut text_word_count = 0usize;
        for (i, &wid_opt) in word_ids.iter().enumerate() {
            let wid = match wid_opt {
                Some(w) => w,
                None => {
                    last_word_id = None;
                    continue;
                }
            };
            let is_new_word = last_word_id.map(|lw| lw != wid).unwrap_or(true);
            if is_new_word {
                let global_word_idx = {
                    // count distinct words so far
                    word_ids[..i]
                        .iter()
                        .filter_map(|&w| w)
                        .collect::<std::collections::HashSet<_>>()
                        .len()
                };
                if global_word_idx >= prefix_word_count {
                    if text_token_start == usize::MAX {
                        text_token_start = i;
                    }
                    words_mask[i] = 1;
                    text_word_count += 1;
                }
            }
            last_word_id = Some(wid);
        }

        if text_word_count == 0 || text_token_start == usize::MAX {
            debug!("No text words found after entity type prefix, skipping inference");
            return Ok(Vec::new());
        }
        let text_lengths = vec![text_word_count as i64];

        // ── Step 4: generate candidate spans ─────────────────────────────
        // Enumerate all (start, end) word pairs within MAX_SPAN_WIDTH.
        let mut span_idx_flat: Vec<i64> = Vec::new();
        let mut span_mask: Vec<bool> = Vec::new();

        for start in 0..text_word_count {
            for end in start..text_word_count.min(start + MAX_SPAN_WIDTH) {
                span_idx_flat.push(start as i64);
                span_idx_flat.push(end as i64);
                span_mask.push(true);
            }
        }
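        // e.g. text_word_count = 3 yields (0,0) (0,1) (0,2) (1,1) (1,2) (2,2):
        // inclusive end indices, widths 1..=MAX_SPAN_WIDTH. The order here must
        // match `iter_spans` below, since logits are indexed by span position.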

        let num_spans = span_mask.len();
        if num_spans == 0 {
            return Ok(Vec::new());
        }

        // ── Step 5: run ORT session ───────────────────────────────────────
        let span_mask_values: Vec<i64> = span_mask
            .iter()
            .map(|&b| if b { 1i64 } else { 0 })
            .collect();

        let logits_raw: Vec<f32> = {
            let mut session_guard = session.lock();

            // ort rc.12: Tensor::from_array takes (shape_array, owned_Vec)
            let input_ids_t = Tensor::<i64>::from_array(([1usize, seq_len], token_ids))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let attn_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], attention_mask))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let words_mask_t = Tensor::<i64>::from_array(([1usize, seq_len], words_mask))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let text_lengths_t = Tensor::<i64>::from_array(([1usize], text_lengths))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let span_idx_t = Tensor::<i64>::from_array(([1usize, num_spans, 2], span_idx_flat))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            let span_mask_t = Tensor::<i64>::from_array(([1usize, num_spans], span_mask_values))
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

            let outputs = session_guard
                .run(inputs![
                    "input_ids" => input_ids_t,
                    "attention_mask" => attn_mask_t,
                    "words_mask" => words_mask_t,
                    "text_lengths" => text_lengths_t,
                    "span_idx" => span_idx_t,
                    "span_mask" => span_mask_t,
                ])
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;

            // outputs[0] = logits: shape [1, num_spans, num_entity_types]
            let (_shape, logits_slice) = outputs[0]
                .try_extract_tensor::<f32>()
                .map_err(|e| InferenceError::ModelLoadError(e.to_string()))?;
            logits_slice.to_vec()
        };

        // logits shape: [1, num_spans, num_entity_types]
        let num_entity_types = entity_types.len();
        let expected = num_spans * num_entity_types;
        if logits_raw.len() != expected {
            warn!(
                "GLiNER logits shape mismatch: got {}, expected {}",
                logits_raw.len(),
                expected
            );
            return Ok(Vec::new());
        }

        // ── Step 6: post-process — sigmoid, threshold, NMS ───────────────
        let mut raw_entities: Vec<(usize, usize, usize, f32)> = Vec::new(); // (type_idx, start, end, score)

        for (span_i, (start_w, end_w)) in iter_spans(text_word_count).enumerate() {
            for (type_i, _entity_type) in entity_types.iter().enumerate() {
                let logit = logits_raw[span_i * num_entity_types + type_i];
                let score = sigmoid(logit);
                if score >= SCORE_THRESHOLD {
                    raw_entities.push((type_i, start_w, end_w, score));
                }
            }
        }

        // NMS: keep highest-score non-overlapping spans per entity type.
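        // e.g. (person, 0..=2, 0.9) beats an overlapping (person, 1..=3, 0.7),
        // but (org, 1..=3, 0.7) survives, since overlap is only checked per type.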
        raw_entities.sort_by(|a, b| b.3.partial_cmp(&a.3).unwrap_or(std::cmp::Ordering::Equal));
        let mut kept: Vec<(usize, usize, usize, f32)> = Vec::new();
        'outer: for candidate in &raw_entities {
            for kept_span in &kept {
                // Same type and overlapping (inclusive end indices)
                if kept_span.0 == candidate.0
                    && kept_span.1 <= candidate.2
                    && candidate.1 <= kept_span.2
                {
                    continue 'outer;
                }
            }
            kept.push(*candidate);
        }

        // ── Step 7: map word offsets back to byte offsets ─────────────────
        let words: Vec<&str> = text.split_whitespace().collect();
        let mut word_byte_starts: Vec<usize> = Vec::with_capacity(words.len());
        let mut word_byte_ends: Vec<usize> = Vec::with_capacity(words.len());
        {
            let mut byte_pos = 0usize;
            for word in &words {
                // Find this word's start in the original text; the cursor only
                // moves forward, so repeated words resolve to the next occurrence.
                if let Some(rel) = text[byte_pos..].find(word) {
                    let start = byte_pos + rel;
                    let end = start + word.len();
                    word_byte_starts.push(start);
                    word_byte_ends.push(end);
                    byte_pos = end;
                } else {
                    word_byte_starts.push(byte_pos);
                    word_byte_ends.push(byte_pos);
                }
            }
        }

        let mut entities: Vec<ExtractedEntity> = kept
            .into_iter()
            .filter_map(|(type_i, start_w, end_w, score)| {
                let start_byte = *word_byte_starts.get(start_w)?;
                let end_byte = *word_byte_ends.get(end_w)?;
                let value = text[start_byte..end_byte].to_string();
                Some(ExtractedEntity {
                    entity_type: entity_types[type_i].to_lowercase().replace(' ', "_"),
                    value,
                    score,
                    start: start_byte,
                    end: end_byte,
                })
            })
            .collect();

        entities.sort_by_key(|e| e.start);
        debug!("GLiNER extracted {} entities", entities.len());
        Ok(entities)
    }

    // ── Model download helpers ────────────────────────────────────────────

    #[instrument(skip_all)]
    async fn download_model_files() -> Result<(PathBuf, PathBuf)> {
        info!(
            "Resolving GLiNER model files: tokenizer={}, onnx={}",
            GLINER_TOKENIZER_REPO, GLINER_MODEL_REPO
        );

        let tokenizer_cache = Self::model_cache_dir(GLINER_TOKENIZER_REPO)?;
        let onnx_cache = Self::model_cache_dir(GLINER_MODEL_REPO)?;
        let onnx_subdir = onnx_cache.join("onnx");
        std::fs::create_dir_all(&onnx_subdir)?;

        let local_tokenizer = tokenizer_cache.join("tokenizer.json");
        let local_onnx = onnx_subdir.join("model_quantized.onnx");

        if !local_tokenizer.exists() || !local_onnx.exists() {
            let tok_cache = tokenizer_cache.clone();
            let onnx_c = onnx_cache.clone();
            let tok_exists = local_tokenizer.exists();
            let onnx_exists = local_onnx.exists();

            tokio::task::spawn_blocking(move || {
                if !tok_exists {
                    crate::engine::EmbeddingEngine::download_hf_file_pub(
                        GLINER_TOKENIZER_REPO,
                        "tokenizer.json",
                        &tok_cache,
                    )
                    .map_err(|e| {
                        InferenceError::HubError(format!(
                            "Failed to download GLiNER tokenizer: {}",
                            e
                        ))
                    })?;
                }
                if !onnx_exists {
                    crate::engine::EmbeddingEngine::download_hf_file_pub(
                        GLINER_MODEL_REPO,
                        GLINER_ONNX_FILE,
                        &onnx_c,
                    )
                    .map_err(|e| {
                        InferenceError::HubError(format!(
                            "Failed to download GLiNER ONNX model: {}",
                            e
                        ))
                    })?;
                }
                Ok::<_, InferenceError>(())
            })
            .await
            .map_err(|e| InferenceError::HubError(format!("Download task panicked: {}", e)))??;
        } else {
            info!("GLiNER model files found in local cache");
        }

        // `local_onnx` is exactly `onnx_cache.join(GLINER_ONNX_FILE)`.
        Ok((local_tokenizer, local_onnx))
    }

    fn model_cache_dir(model_id: &str) -> Result<PathBuf> {
        let base = std::env::var("HF_HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|_| {
                let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
                PathBuf::from(home).join(".cache").join("huggingface")
            });
        let dir = base.join("dakera").join(model_id.replace('/', "--"));
        std::fs::create_dir_all(&dir)?;
        Ok(dir)
    }
}

// ─────────────────────────────────────────────────────────────
// NerEngine — unified interface (rule-based + GLiNER)
// ─────────────────────────────────────────────────────────────

/// Unified NER engine combining rule-based and GLiNER extraction.
pub struct NerEngine {
    gliner: Option<Arc<GlinerEngine>>,
}

impl NerEngine {
    /// Create a NerEngine with only the rule-based extractor (no model download).
    pub fn rule_based_only() -> Self {
        Self { gliner: None }
    }

    /// Create a NerEngine backed by GLiNER (downloads the model here if not cached).
    pub async fn with_gliner(num_threads: Option<usize>) -> Result<Self> {
        let gliner = GlinerEngine::new(num_threads).await?;
        Ok(Self {
            gliner: Some(Arc::new(gliner)),
        })
    }

    /// Extract entities from text.
    ///
    /// Always runs the rule-based pre-pass. If GLiNER is loaded and
    /// `gliner_types` is non-empty, also runs the neural extractor.
    /// Results are merged and deduplicated by offset.
    pub async fn extract(&self, text: &str, gliner_types: &[&str]) -> Vec<ExtractedEntity> {
        let mut entities = rule_based_extract(text);

        if let Some(ref gliner) = self.gliner {
            if !gliner_types.is_empty() {
                match gliner.extract(text, gliner_types).await {
                    Ok(neural) => {
                        for ne in neural {
                            // Skip if rule-based already captured the same span
                            if !entities
                                .iter()
                                .any(|e| e.start == ne.start && e.end == ne.end)
                            {
                                entities.push(ne);
                            }
                        }
                    }
                    Err(e) => {
                        warn!("GLiNER extraction failed, using rule-based only: {}", e);
                    }
                }
            }
        }

        entities.sort_by_key(|e| e.start);
        entities
    }
}
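
// Illustrative sketch (not compiled here) of enabling the neural layer on top
// of the always-on rule pass; label strings are free-form because GLiNER is
// zero-shot:
//
//     let ner = NerEngine::with_gliner(Some(4)).await?;
//     let ents = ner.extract("Alice joined Anthropic", &["person", "org"]).await;
//     // rule-based hits (dates, URLs, ...) are merged in automatically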

// ─────────────────────────────────────────────────────────────
// Helpers
// ─────────────────────────────────────────────────────────────

/// Iterate all valid (start, end) word index pairs up to MAX_SPAN_WIDTH.
fn iter_spans(num_words: usize) -> impl Iterator<Item = (usize, usize)> {
    (0..num_words).flat_map(move |start| {
        let max_end = num_words.min(start + MAX_SPAN_WIDTH);
        (start..max_end).map(move |end| (start, end))
    })
}

/// Numerically stable sigmoid.
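///
/// For negative inputs the `exp(x) / (1 + exp(x))` form is used so `exp` is
/// only ever called on a non-positive argument and cannot overflow.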
#[inline]
fn sigmoid(x: f32) -> f32 {
    if x >= 0.0 {
        1.0 / (1.0 + (-x).exp())
    } else {
        let ex = x.exp();
        ex / (1.0 + ex)
    }
}

// ─────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rule_based_uuid() {
        let text = "session id is 550e8400-e29b-41d4-a716-446655440000 here";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "uuid"));
    }

    #[test]
    fn test_rule_based_url() {
        let text = "check https://example.com/path?q=1 for details";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_email() {
        let text = "contact alice@example.com for support";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "email"));
        // Email should NOT also be captured as url
        assert!(!entities.iter().any(|e| e.entity_type == "url"));
    }

    #[test]
    fn test_rule_based_iso_date() {
        let text = "released on 2024-03-15 at noon";
        let entities = rule_based_extract(text);
        assert!(entities
            .iter()
            .any(|e| e.entity_type == "date" && e.value == "2024-03-15"));
    }

    #[test]
    fn test_rule_based_natural_date() {
        let text = "meeting on March 15, 2024 at noon";
        let entities = rule_based_extract(text);
        assert!(entities.iter().any(|e| e.entity_type == "date"));
    }

    #[test]
    fn test_entity_to_tag() {
        let e = ExtractedEntity {
            entity_type: "person".to_string(),
            value: "Alice Smith".to_string(),
            score: 0.9,
            start: 0,
            end: 11,
        };
        assert_eq!(e.to_tag(), "entity:person:Alice Smith");
    }

    #[test]
    fn test_entity_to_tag_colon_escaping() {
        let e = ExtractedEntity {
            entity_type: "url".to_string(),
            value: "http://example.com:8080/path".to_string(),
            score: 1.0,
            start: 0,
            end: 28,
        };
        let tag = e.to_tag();
        // Tag has format entity:<type>:<value> — exactly 2 colons before value
        // Colons within the value are replaced with underscores
        let parts: Vec<&str> = tag.splitn(3, ':').collect();
        assert_eq!(parts.len(), 3, "tag should have 3 parts: {}", tag);
        assert_eq!(parts[0], "entity");
        assert_eq!(parts[1], "url");
        assert!(
            !parts[2].contains(':'),
            "value should not contain colons: {}",
            parts[2]
        );
    }

    #[test]
    fn test_sigmoid() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-6);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-4);
        assert!((sigmoid(-100.0) - 0.0).abs() < 1e-4);
    }
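
    // Added sanity check: the eager span enumeration in `run_inference` step 4
    // and `iter_spans` must stay in lockstep, because logits are indexed by
    // span position.
    #[test]
    fn test_iter_spans_matches_eager_enumeration() {
        let n = 5usize;
        let mut eager = Vec::new();
        for start in 0..n {
            for end in start..n.min(start + MAX_SPAN_WIDTH) {
                eager.push((start, end));
            }
        }
        assert_eq!(eager, iter_spans(n).collect::<Vec<_>>());
    }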
}