Skip to main content

anno/backends/
w2ner.rs

1//! W2NER - Unified NER via Word-Word Relation Classification.
2//!
3//! W2NER (Word-to-Word NER) models NER as classifying relations between
4//! every pair of words in a sentence. This elegantly handles:
5//!
6//! - **Nested entities**: "The \[University of \[California\]\]"
7//! - **Discontinuous entities**: "severe \[pain\] ... in \[abdomen\]" *(see limitation below)*
8//! - **Overlapping entities**: Same span, different types
9//!
10//! # Discontinuous Entities (Important Limitation)
11//!
12//! **True discontinuous entity decoding is not yet implemented.** The W2NER
13//! paper describes a grid-based algorithm for linking non-adjacent spans, but
14//! this implementation currently returns only contiguous spans.
15//!
16//! The [`DiscontinuousNER`] trait is implemented for API compatibility, but
17//! `extract_discontinuous()` wraps each contiguous entity into a single-segment
18//! result. The `W2NERConfig.allow_discontinuous` flag exists for forward-compatibility
19//! but does not change behavior today.
20//!
21//! # Language Support (Important Limitation)
22//!
23//! **This implementation uses whitespace tokenization** (`split_whitespace()`),
24//! which works correctly for:
25//!
26//! - **Latin-script languages**: English, German, French, Spanish, etc.
27//! - **Cyrillic**: Russian, Ukrainian, etc.
28//! - **Languages with explicit word boundaries**
29//!
30//! It does **NOT** work correctly for:
31//!
32//! - **CJK languages** (Chinese, Japanese, Korean): No whitespace between words
33//! - **Thai, Khmer, Lao**: Scriptio continua (no word boundaries)
34//! - **Languages requiring morphological analysis**
35//!
36//! If you need CJK/Thai support, consider:
37//! 1. Pre-tokenizing with a proper segmenter (e.g., jieba, mecab, pythainlp)
38//! 2. Using a different backend (e.g., GLiNER with subword tokenization)
39//!
//! The `language` parameter to [`Model::extract_entities`] does not change
//! tokenization; it is only used to log a warning when a non-whitespace
//! language hint is detected.
42//!
43//! # Architecture
44//!
45//! ```text
46//! Input: "New York City is great"
47//!
48//!        ┌─────────────────────────────┐
49//!        │      Encoder (BERT)          │
50//!        └─────────────────────────────┘
51//!                     │
52//!        ┌─────────────────────────────┐
53//!        │    Biaffine Attention        │
54//!        │    (word-word scoring)       │
55//!        └─────────────────────────────┘
56//!                     │
57//!        ┌───────────────────────────────┐
58//!        │     Word-Word Grid (N×N×L)    │
59//!        │  ┌───┬───┬───┬───┬───┐       │
60//!        │  │   │New│York│City│...│      │
61//!        │  ├───┼───┼───┼───┼───┤       │
62//!        │  │New│ B │NNW│THW│   │       │
63//!        │  ├───┼───┼───┼───┼───┤       │
64//!        │  │Yrk│   │ B │NNW│   │       │
65//!        │  ├───┼───┼───┼───┼───┤       │
66//!        │  │Cty│   │   │ B │   │       │
67//!        │  └───┴───┴───┴───┴───┘       │
68//!        └───────────────────────────────┘
69//!
70//! Legend:
71//!   B   = Begin entity
72//!   NNW = Next-Neighboring-Word (same entity)
73//!   THW = Tail-Head-Word (entity boundary)
74//! ```
75//!
76//! # Grid Labels
77//!
78//! W2NER uses three relation types for each entity label:
79//!
80//! - **NNW (Next-Neighboring-Word)**: Token i and j are adjacent in same entity
81//! - **THW (Tail-Head-Word)**: Token i is tail, token j is head of entity
82//! - **None**: No relation
83//!
84//! # Usage
85//!
86//! ```rust,ignore
87//! use anno::W2NER;
88//!
89//! // Load W2NER model (requires `onnx` feature)
90//! let w2ner = W2NER::from_pretrained("path/to/w2ner-model")?;
91//!
92//! let text = "The University of California Berkeley";
93//! let entities = w2ner.extract_entities(text, None)?;
94//! // Returns nested entities: ORG + nested LOC
95//! ```
96//!
97//! # References
98//!
99//! - [W2NER Paper](https://arxiv.org/abs/2112.10070) (AAAI 2022)
100//! - [TPLinker](https://aclanthology.org/2020.coling-main.138/) (related approach)
101
102use crate::backends::inference::{
103    DiscontinuousEntity, DiscontinuousNER, HandshakingCell, HandshakingMatrix,
104};
105use crate::{Entity, EntityType, Model, Result};
106
107#[cfg(feature = "onnx")]
108use crate::Error;
109
/// W2NER relation types for word-word classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum W2NERRelation {
    /// Next-Neighboring-Word: tokens are adjacent in same entity
    NNW,
    /// Tail-Head-Word: marks entity boundary (tail -> head)
    THW,
    /// No relation between tokens
    None,
}

impl W2NERRelation {
    /// Decode a relation from its grid label index.
    ///
    /// Index 0 and any out-of-range index both decode to [`Self::None`].
    #[must_use]
    pub fn from_index(idx: usize) -> Self {
        match idx {
            1 => Self::NNW,
            2 => Self::THW,
            // 0 is the explicit "no relation" channel; anything unknown is
            // treated the same way rather than panicking.
            _ => Self::None,
        }
    }

    /// Encode this relation as its grid label index (inverse of `from_index`).
    #[must_use]
    pub fn to_index(self) -> usize {
        match self {
            Self::None => 0,
            Self::NNW => 1,
            Self::THW => 2,
        }
    }
}
143
/// Configuration for W2NER decoding.
///
/// # Tokenization
///
/// W2NER uses **whitespace tokenization** (`split_whitespace()`), which works
/// for Latin-script languages but fails for CJK/Thai/Lao. See module-level
/// docs for details and workarounds.
#[derive(Debug, Clone)]
pub struct W2NERConfig {
    /// Confidence threshold for grid predictions
    pub threshold: f64,
    /// Entity type labels (maps grid channels to types)
    pub entity_labels: Vec<String>,
    /// Whether to extract nested entities
    pub allow_nested: bool,
    /// Whether to extract discontinuous entities.
    ///
    /// **Note**: Currently, discontinuous decoding is not fully implemented.
    /// This flag exists for forward-compatibility; setting it to `true` does
    /// not yet produce true discontinuous spans. See `backend-02` in docs.
    pub allow_discontinuous: bool,
    /// Model identifier for loading
    pub model_id: String,
}

impl Default for W2NERConfig {
    /// Sensible defaults: 0.5 threshold, CoNLL-style PER/ORG/LOC labels,
    /// nested and discontinuous extraction enabled, empty model id.
    fn default() -> Self {
        Self {
            threshold: 0.5,
            entity_labels: ["PER", "ORG", "LOC"].map(String::from).to_vec(),
            allow_nested: true,
            allow_discontinuous: true,
            model_id: String::new(),
        }
    }
}
180
/// W2NER model for unified named entity recognition.
///
/// Uses word-word relation classification to handle complex entity
/// structures (nested, overlapping, discontinuous).
///
/// # Feature Requirements
///
/// Requires the `onnx` feature for actual inference. Without it, only the
/// [`decode_from_matrix`](Self::decode_from_matrix) method works with
/// pre-computed grids.
///
/// # Example
///
/// ```rust,ignore
/// let w2ner = W2NER::from_pretrained("ljynlp/w2ner-bert-base")?;
///
/// // Handles nested entities naturally
/// let text = "The University of California Berkeley";
/// let entities = w2ner.extract_entities(text, None)?;
/// ```
pub struct W2NER {
    /// Decoding configuration (threshold, labels, nesting policy).
    config: W2NERConfig,
    /// ONNX runtime session; `None` until `from_pretrained` succeeds.
    /// Wrapped in a mutex because running inference mutates the session.
    #[cfg(feature = "onnx")]
    session: Option<crate::sync::Mutex<ort::session::Session>>,
    /// Tokenizer loaded alongside the model; `None` until `from_pretrained`.
    #[cfg(feature = "onnx")]
    tokenizer: Option<tokenizers::Tokenizer>,
}
208
209impl W2NER {
210    /// Create W2NER with default configuration.
211    #[must_use]
212    pub fn new() -> Self {
213        Self {
214            config: W2NERConfig::default(),
215            #[cfg(feature = "onnx")]
216            session: None,
217            #[cfg(feature = "onnx")]
218            tokenizer: None,
219        }
220    }
221
222    /// Create with custom configuration.
223    #[must_use]
224    pub fn with_config(config: W2NERConfig) -> Self {
225        Self {
226            config,
227            #[cfg(feature = "onnx")]
228            session: None,
229            #[cfg(feature = "onnx")]
230            tokenizer: None,
231        }
232    }
233
    /// Load W2NER model from path or HuggingFace.
    ///
    /// Automatically loads `.env` for HF_TOKEN if present.
    ///
    /// Resolution order:
    /// 1. If `model_path` exists on disk, load `model.onnx` and
    ///    `tokenizer.json` from that directory.
    /// 2. Otherwise treat `model_path` as a HuggingFace repo id and download
    ///    `model.onnx` (or `onnx/model.onnx`) plus `tokenizer.json`.
    /// 3. If the repo lacks ONNX files, optionally auto-export a local ONNX
    ///    model via `scripts/export_w2ner_to_onnx.py` (dev workflow).
    ///
    /// # Arguments
    /// * `model_path` - Local path or HuggingFace model ID
    ///
    /// # Environment
    /// * `HF_TOKEN` / `HF_API_TOKEN` - HuggingFace auth for gated repos
    /// * `ANNO_W2NER_AUTO_EXPORT` - force-enable/disable local auto-export
    /// * `W2NER_EXPORT_BERT_MODEL` - encoder id passed to the export script
    /// * `ANNO_CACHE_DIR` - overrides the cache root for exported models
    ///
    /// # Errors
    /// Returns [`Error::Retrieval`] on download/auth failures, failed
    /// auto-export, or ONNX session / tokenizer load failures.
    #[cfg(feature = "onnx")]
    pub fn from_pretrained(model_path: &str) -> Result<Self> {
        use hf_hub::api::sync::{Api, ApiBuilder};
        use ort::execution_providers::CPUExecutionProvider;
        use ort::session::Session;
        use std::path::Path;
        use std::process::Command;

        // Load .env if present (for HF_TOKEN)
        crate::env::load_dotenv();

        // Phase 1: resolve model + tokenizer files (local dir, HF download,
        // or local auto-export fallback).
        let (model_file, tokenizer_file) = if Path::new(model_path).exists() {
            // Local path
            let model_file = Path::new(model_path).join("model.onnx");
            let tokenizer_file = Path::new(model_path).join("tokenizer.json");
            (model_file, tokenizer_file)
        } else {
            // HuggingFace download - explicitly use token if available
            let api = if let Some(token) = crate::env::hf_token() {
                ApiBuilder::new()
                    .with_token(Some(token))
                    .build()
                    .map_err(|e| {
                        Error::Retrieval(format!(
                            "Failed to initialize HuggingFace API with token: {}",
                            e
                        ))
                    })?
            } else {
                Api::new().map_err(|e| {
                    Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
                })?
            };
            let repo = api.model(model_path.to_string());

            let (model_file, tokenizer_file) = match repo
                .get("model.onnx")
                .or_else(|_| repo.get("onnx/model.onnx"))
            {
                Ok(p) => {
                    let tok = repo.get("tokenizer.json").map_err(|e| {
                        Error::Retrieval(format!("Failed to download tokenizer: {}", e))
                    })?;
                    (p, tok)
                }
                Err(e) => {
                    let error_msg = format!("{e}");
                    // Check if it's an authentication error (401) or gated model
                    if error_msg.contains("401") || error_msg.contains("Unauthorized") {
                        return Err(Error::Retrieval(format!(
                            "W2NER model '{}' requires HuggingFace authentication.\n\
                             \n\
                             To fix this:\n\
                             1. Get a HuggingFace token from https://huggingface.co/settings/tokens\n\
                             2. Request access to the model on HuggingFace (if it's gated)\n\
                             3. Set the token: export HF_TOKEN=your_token_here (or HF_API_TOKEN)\n\
                             \n\
                             Alternative: set W2NER_MODEL_PATH to a local export (see scripts/export_w2ner_to_onnx.py).",
                            model_path
                        )));
                    }

                    // 404 / missing ONNX is common: HF repos typically don't ship `model.onnx`.
                    // We can auto-export a local ONNX model (bounded by env + CI) and proceed.
                    //
                    // IMPORTANT: many dev shells set `CI=1`, which should not disable auto-export
                    // when running locally. Only treat GitHub Actions as “CI” for this purpose.
                    // NOTE(review): `is_ok()` treats GITHUB_ACTIONS="false" as CI too —
                    // presumably acceptable, but confirm if that matters.
                    let in_github_actions = std::env::var("GITHUB_ACTIONS").is_ok();
                    let auto_export = match std::env::var("ANNO_W2NER_AUTO_EXPORT").ok() {
                        None => !in_github_actions,
                        Some(v) => {
                            // Any common truthy spelling enables auto-export.
                            let t = v.trim().to_lowercase();
                            t == "1" || t == "true" || t == "yes" || t == "y" || t == "on"
                        }
                    };

                    if auto_export {
                        let Some(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR").ok() else {
                            return Err(Error::Retrieval(format!(
                                "W2NER model '{}' is missing ONNX files, and auto-export is enabled, but CARGO_MANIFEST_DIR is not set.\n\
                                 \n\
                                 Fix:\n\
                                 - Run from the repo via cargo (so CARGO_MANIFEST_DIR is present), or\n\
                                 - Export manually and set W2NER_MODEL_PATH to the export directory.\n\
                                 \n\
                                 Original error: {e}",
                                model_path
                            )));
                        };

                        // Export location under the cache dir.
                        //
                        // IMPORTANT: `anno::eval` is feature-gated, so backends must not depend on
                        // it. Mirror the cache-root logic in a lightweight way here.
                        let cache_dir = std::env::var("ANNO_CACHE_DIR")
                            .ok()
                            .filter(|v| !v.trim().is_empty())
                            .map(std::path::PathBuf::from)
                            .unwrap_or_else(|| {
                                dirs::cache_dir()
                                    .unwrap_or_else(|| std::path::PathBuf::from("."))
                                    .join("anno")
                            });
                        // Export model choice: default to a public BERT id so auto-export works
                        // even when the configured W2NER HF repo is gated.
                        let export_bert_model = std::env::var("W2NER_EXPORT_BERT_MODEL")
                            .ok()
                            .filter(|v| !v.trim().is_empty())
                            .unwrap_or_else(|| "bert-base-cased".to_string());
                        // Sanitize the model id so it is safe as a directory name.
                        let safe_id = export_bert_model
                            .chars()
                            .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
                            .collect::<String>();
                        let out_dir = cache_dir.join("models").join("w2ner").join(safe_id);
                        std::fs::create_dir_all(&out_dir).map_err(|ioe| {
                            Error::Retrieval(format!(
                                "Failed to create W2NER export dir {:?}: {}",
                                out_dir, ioe
                            ))
                        })?;

                        // NOTE(review): assumes the crate lives two levels below the
                        // repo root (workspace layout) — confirm if crates move.
                        let script_path = std::path::PathBuf::from(manifest_dir)
                            .join("../../scripts/export_w2ner_to_onnx.py");
                        let out_onnx = out_dir.join("model.onnx");

                        // Run export via `uv`, which is expected in dev environments.
                        let mut cmd = Command::new("uv");
                        cmd.arg("run")
                            .arg(script_path)
                            .arg("--bert-model")
                            .arg(&export_bert_model)
                            .arg("--output")
                            .arg(&out_onnx);

                        let output = cmd.output().map_err(|ioe| {
                            Error::Retrieval(format!(
                                "Failed to spawn W2NER auto-export (uv): {}",
                                ioe
                            ))
                        })?;
                        if !output.status.success() {
                            // Surface both streams: the Python exporter logs to both.
                            let stderr = String::from_utf8_lossy(&output.stderr);
                            let stdout = String::from_utf8_lossy(&output.stdout);
                            return Err(Error::Retrieval(format!(
                                "W2NER auto-export failed (exit={}).\n\
                                 \n\
                                 stdout:\n{}\n\
                                 \n\
                                 stderr:\n{}\n\
                                 \n\
                                 Original HF error: {e}",
                                output.status.code().unwrap_or(-1),
                                stdout,
                                stderr
                            )));
                        }

                        // Tokenizer is saved alongside the ONNX by the export script.
                        let tok = out_dir.join("tokenizer.json");
                        if !out_onnx.exists() || !tok.exists() {
                            return Err(Error::Retrieval(format!(
                                "W2NER auto-export succeeded but expected files are missing.\n\
                                 expected: {:?} and {:?}",
                                out_onnx, tok
                            )));
                        }

                        (out_onnx, tok)
                    } else {
                        return Err(Error::Retrieval(format!(
                            "W2NER model '{}' not found or missing ONNX files.\n\
                             \n\
                             The model may be:\n\
                             - A gated model requiring access approval at https://huggingface.co/{}\n\
                             - Missing pre-exported ONNX files (model.onnx or onnx/model.onnx)\n\
                             - Removed or renamed on HuggingFace\n\
                             \n\
                             Fix options:\n\
                             - Set ANNO_W2NER_AUTO_EXPORT=1 (dev) to auto-export to ONNX\n\
                             - Or export manually and set W2NER_MODEL_PATH to the export directory\n\
                             \n\
                             If you have HF_TOKEN set, ensure you've requested and received access to this model.\n\
                             Alternative: Use nuner, gliner2, or other available NER backends.\n\
                             \n\
                             Original error: {e}",
                            model_path, model_path
                        )));
                    }
                }
            };

            (model_file, tokenizer_file)
        };

        // Phase 2: build the ONNX session (CPU provider) and load the tokenizer.
        let session = Session::builder()
            .map_err(|e| Error::Retrieval(format!("Failed to create session: {}", e)))?
            .with_execution_providers([CPUExecutionProvider::default().build()])
            .map_err(|e| Error::Retrieval(format!("Failed to set providers: {}", e)))?
            .commit_from_file(&model_file)
            .map_err(|e| Error::Retrieval(format!("Failed to load model: {}", e)))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_file)
            .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;

        log::debug!("[W2NER] Loaded model");

        Ok(Self {
            config: W2NERConfig {
                model_id: model_path.to_string(),
                ..Default::default()
            },
            session: Some(crate::sync::Mutex::new(session)),
            tokenizer: Some(tokenizer),
        })
    }
455
456    /// Set confidence threshold.
457    #[must_use]
458    pub fn with_threshold(mut self, threshold: f64) -> Self {
459        self.config.threshold = threshold.clamp(0.0, 1.0);
460        self
461    }
462
    /// Set entity type labels.
    ///
    /// Replaces the defaults (PER/ORG/LOC). Label order matters: each label's
    /// position selects its grid type channel during decoding.
    #[must_use]
    pub fn with_labels(mut self, labels: Vec<String>) -> Self {
        self.config.entity_labels = labels;
        self
    }
469
    /// Enable/disable nested entity extraction.
    ///
    /// When disabled, nested/overlapping spans are pruned during decoding,
    /// keeping only the outermost span at each position.
    #[must_use]
    pub fn with_nested(mut self, allow: bool) -> Self {
        self.config.allow_nested = allow;
        self
    }
476
477    /// Decode entities from a handshaking matrix.
478    ///
479    /// This is the core W2NER decoding algorithm that can be used with
480    /// pre-computed grid predictions (e.g., from external inference).
481    ///
482    /// # Algorithm
483    ///
484    /// 1. Find all THW cells (entity boundaries)
485    /// 2. For each THW(i,j), the entity spans from word j (head) to word i (tail)
486    /// 3. Handle nested/overlapping entities based on config
487    ///
488    /// # Arguments
489    ///
490    /// * `matrix` - The predicted word-word relation grid
491    /// * `tokens` - Original tokens for text reconstruction
492    /// * `entity_type_idx` - Which entity type channel this is
493    pub fn decode_from_matrix(
494        &self,
495        matrix: &HandshakingMatrix,
496        tokens: &[&str],
497        entity_type_idx: usize,
498    ) -> Vec<(usize, usize, f64)> {
499        // Performance: Pre-allocate entities vec with estimated capacity
500        let mut entities = Vec::with_capacity(16);
501
502        // Find all THW (Tail-Head-Word) markers
503        // THW at (i,j) means: token i is tail, token j is head
504        // Entity spans from j (head/start) to i (tail/end)
505        for cell in &matrix.cells {
506            let relation = W2NERRelation::from_index(cell.label_idx as usize);
507            if relation == W2NERRelation::THW && cell.score >= self.config.threshold as f32 {
508                let tail = cell.i as usize;
509                let head = cell.j as usize;
510
511                // Validate: head <= tail (head is start, tail is end)
512                if head <= tail && head < tokens.len() && tail < tokens.len() {
513                    entities.push((head, tail + 1, cell.score as f64));
514                }
515            }
516        }
517
518        // Performance: Use unstable sort (we don't need stable sort here)
519        // Sort by start position, then by length (longer first for nested)
520        entities.sort_unstable_by(|a, b| a.0.cmp(&b.0).then_with(|| (b.1 - b.0).cmp(&(a.1 - a.0))));
521
522        // Remove nested entities if not allowed
523        if !self.config.allow_nested {
524            entities = Self::remove_nested(&entities);
525        }
526
527        let _ = entity_type_idx; // May be used for multi-type grids
528        entities
529    }
530
531    /// Decode dense grid output to HandshakingMatrix.
532    ///
533    /// # Arguments
534    /// * `grid` - Dense grid of shape [seq_len, seq_len, num_relations]
535    /// * `seq_len` - Sequence length
536    /// * `threshold` - Score threshold for sparse representation
537    pub fn grid_to_matrix(
538        grid: &[f32],
539        seq_len: usize,
540        num_relations: usize,
541        threshold: f32,
542    ) -> HandshakingMatrix {
543        let mut cells = Vec::new();
544
545        for i in 0..seq_len {
546            for j in 0..seq_len {
547                for rel in 0..num_relations {
548                    let idx = i * seq_len * num_relations + j * num_relations + rel;
549                    if let Some(&score) = grid.get(idx) {
550                        if score >= threshold && rel > 0 {
551                            // rel > 0 excludes "None"
552                            cells.push(HandshakingCell {
553                                i: i as u32,
554                                j: j as u32,
555                                label_idx: rel as u16,
556                                score,
557                            });
558                        }
559                    }
560                }
561            }
562        }
563
564        HandshakingMatrix {
565            cells,
566            seq_len,
567            num_labels: num_relations,
568        }
569    }
570
571    /// Remove nested entities (keep outermost only).
572    fn remove_nested(entities: &[(usize, usize, f64)]) -> Vec<(usize, usize, f64)> {
573        let mut result = Vec::new();
574        let mut last_end = 0;
575
576        for &(start, end, score) in entities {
577            if start >= last_end {
578                result.push((start, end, score));
579                last_end = end;
580            }
581        }
582
583        result
584    }
585
    /// Map label string to EntityType.
    ///
    /// Matching is case-insensitive (Unicode uppercasing). Unrecognized
    /// labels are preserved verbatim (original casing) as `EntityType::Other`.
    fn map_label(label: &str) -> EntityType {
        match label.to_uppercase().as_str() {
            "PER" | "PERSON" => EntityType::Person,
            "ORG" | "ORGANIZATION" => EntityType::Organization,
            // GPE (geo-political entity, OntoNotes) is folded into Location.
            "LOC" | "LOCATION" | "GPE" => EntityType::Location,
            "DATE" => EntityType::Date,
            "TIME" => EntityType::Time,
            "MONEY" => EntityType::Money,
            "PERCENT" => EntityType::Percent,
            // Canonicalized to uppercase "MISC" even for lowercase input.
            "MISC" => EntityType::Other("MISC".to_string()),
            _ => EntityType::Other(label.to_string()),
        }
    }
600
601    /// Run inference with ONNX model.
602    #[cfg(feature = "onnx")]
603    pub fn extract_with_grid(&self, text: &str, threshold: f32) -> Result<Vec<Entity>> {
604        if text.is_empty() {
605            return Ok(vec![]);
606        }
607
608        let session = self.session.as_ref().ok_or_else(|| {
609            Error::Retrieval("Model not loaded. Call from_pretrained() first.".to_string())
610        })?;
611
612        let tokenizer = self
613            .tokenizer
614            .as_ref()
615            .ok_or_else(|| Error::Retrieval("Tokenizer not loaded.".to_string()))?;
616
617        // Tokenize via whitespace splitting.
618        //
619        // LIMITATION: This only works for languages with explicit word boundaries
620        // (Latin, Cyrillic, etc.). CJK/Thai/Khmer/Lao will produce single "words"
621        // for entire sentences, breaking entity extraction. See module docs.
622        let words: Vec<&str> = text.split_whitespace().collect();
623        if words.is_empty() {
624            return Ok(vec![]);
625        }
626
627        let encoding = tokenizer
628            .encode(text.to_string(), true)
629            .map_err(|e| Error::Parse(format!("Tokenization failed: {}", e)))?;
630
631        let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
632        let attention_mask: Vec<i64> = encoding
633            .get_attention_mask()
634            .iter()
635            .map(|&x| x as i64)
636            .collect();
637        let seq_len = input_ids.len();
638
639        // Build tensors
640        use ndarray::Array2;
641
642        let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
643            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
644        let attention_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
645            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
646
647        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
648            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
649        let attention_t = super::ort_compat::tensor_from_ndarray(attention_arr)
650            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
651
652        // Run inference with blocking lock for thread-safe parallel access
653        let mut session_guard = crate::sync::lock(session);
654
655        let outputs = session_guard
656            .run(ort::inputs![
657                "input_ids" => input_ids_t.into_dyn(),
658                "attention_mask" => attention_t.into_dyn(),
659            ])
660            .map_err(|e| Error::Parse(format!("Inference failed: {}", e)))?;
661
662        // Decode grid output
663        let output = outputs
664            .iter()
665            .next()
666            .map(|(_, v)| v)
667            .ok_or_else(|| Error::Parse("No output".to_string()))?;
668
669        let (_, data) = output
670            .try_extract_tensor::<f32>()
671            .map_err(|e| Error::Parse(format!("Extract failed: {}", e)))?;
672        let grid: Vec<f32> = data.to_vec();
673
674        // Convert grid to matrix and decode
675        let num_relations = 3; // None, NNW, THW
676        let matrix = Self::grid_to_matrix(&grid, seq_len, num_relations, threshold);
677
678        // Calculate word positions
679        // Note: This assumes words appear in order and don't overlap.
680        // If a word appears multiple times, this will find the first occurrence
681        // after the previous word. This is correct for tokenized input where
682        // words are in sequence, but may fail if words are out of order.
683        let word_positions: Vec<(usize, usize)> = {
684            // Performance: Pre-allocate positions vec with known size
685            let mut positions = Vec::with_capacity(words.len());
686            let mut pos = 0;
687            for (idx, word) in words.iter().enumerate() {
688                if let Some(start) = text[pos..].find(word) {
689                    let abs_start = pos + start;
690                    let abs_end = abs_start + word.len();
691                    // Validate position is after previous word (words should be in order)
692                    if !positions.is_empty() {
693                        let (_prev_start, prev_end) = positions[positions.len() - 1];
694                        if abs_start < prev_end {
695                            log::warn!(
696                                "Word '{}' (index {}) at position {} overlaps with previous word ending at {}",
697                                word,
698                                idx,
699                                abs_start,
700                                prev_end
701                            );
702                        }
703                    }
704                    positions.push((abs_start, abs_end));
705                    pos = abs_end;
706                } else {
707                    // Word not found - return error to prevent silent entity skipping
708                    return Err(Error::Parse(format!(
709                        "Word '{}' (index {}) not found in text starting at position {}",
710                        word, idx, pos
711                    )));
712                }
713            }
714            positions
715        };
716
717        // Validate that we found positions for all words
718        if word_positions.len() != words.len() {
719            return Err(Error::Parse(format!(
720                "Word position mismatch: found {} positions for {} words",
721                word_positions.len(),
722                words.len()
723            )));
724        }
725
726        // Word positions are byte offsets; `Entity` requires character offsets.
727        let span_converter = crate::offset::SpanConverter::new(text);
728
729        // Performance: Pre-allocate entities vec with estimated capacity
730        // Decode entities for each type
731        let mut entities = Vec::with_capacity(16);
732        for (type_idx, label) in self.config.entity_labels.iter().enumerate() {
733            let spans = self.decode_from_matrix(&matrix, &words.to_vec(), type_idx);
734
735            for (start_word, end_word, score) in spans {
736                if let (Some(&(start_pos, _)), Some(&(_, end_pos))) = (
737                    word_positions.get(start_word),
738                    word_positions.get(end_word.saturating_sub(1)),
739                ) {
740                    if let Some(entity_text) = text.get(start_pos..end_pos) {
741                        entities.push(Entity::new(
742                            entity_text,
743                            Self::map_label(label),
744                            span_converter.byte_to_char(start_pos),
745                            span_converter.byte_to_char(end_pos),
746                            score,
747                        ));
748                    }
749                }
750            }
751        }
752
753        Ok(entities)
754    }
755}
756
757impl Default for W2NER {
758    fn default() -> Self {
759        Self::new()
760    }
761}
762
763impl Model for W2NER {
764    fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
765        if text.trim().is_empty() {
766            return Ok(vec![]);
767        }
768
769        // Warn if the language hint suggests a non-whitespace-tokenized language.
770        // W2NER uses `split_whitespace()`, which doesn't work for CJK/Thai/etc.
771        if let Some(lang) = language {
772            let lang_lower = lang.to_lowercase();
773            let is_non_whitespace_lang = matches!(
774                lang_lower.as_str(),
775                "zh" | "zh-cn"
776                    | "zh-tw"
777                    | "chinese"
778                    | "mandarin"
779                    | "cantonese"
780                    | "ja"
781                    | "jp"
782                    | "japanese"
783                    | "ko"
784                    | "kr"
785                    | "korean"
786                    | "th"
787                    | "thai"
788                    | "km"
789                    | "khmer"
790                    | "lo"
791                    | "lao"
792                    | "my"
793                    | "burmese"
794                    | "myanmar"
795            );
796            if is_non_whitespace_lang {
797                log::warn!(
798                    "[W2NER] Language '{}' detected, but W2NER uses whitespace tokenization \
799                     which does not work correctly for CJK/Thai/Khmer/Lao. \
800                     Consider pre-tokenizing or using a different backend (e.g., GLiNER).",
801                    lang
802                );
803            }
804        }
805
806        #[cfg(feature = "onnx")]
807        {
808            if self.session.is_some() {
809                return self.extract_with_grid(text, self.config.threshold as f32);
810            }
811
812            Err(crate::Error::ModelInit(
813                "W2NER model not loaded. Call `W2NER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_entities`.".to_string(),
814            ))
815        }
816
817        #[cfg(not(feature = "onnx"))]
818        {
819            Err(crate::Error::FeatureNotAvailable(
820                "W2NER requires the 'onnx' feature. Build with: cargo build --features onnx"
821                    .to_string(),
822            ))
823        }
824    }
825
826    fn supported_types(&self) -> Vec<EntityType> {
827        self.config
828            .entity_labels
829            .iter()
830            .map(|l| Self::map_label(l))
831            .collect()
832    }
833
834    fn is_available(&self) -> bool {
835        #[cfg(feature = "onnx")]
836        {
837            self.session.is_some()
838        }
839        #[cfg(not(feature = "onnx"))]
840        {
841            false
842        }
843    }
844
845    fn name(&self) -> &'static str {
846        "w2ner"
847    }
848
849    fn description(&self) -> &'static str {
850        "W2NER: Unified NER via Word-Word Relation Classification (nested/discontinuous support)"
851    }
852
853    fn version(&self) -> String {
854        format!("w2ner-{}", self.config.model_id)
855    }
856}
857
858// =============================================================================
859// BatchCapable Trait Implementation
860// =============================================================================
861
862impl crate::BatchCapable for W2NER {
863    fn optimal_batch_size(&self) -> Option<usize> {
864        Some(4) // W2NER is more memory-intensive due to grid computation
865    }
866}
867
868// =============================================================================
869// StreamingCapable Trait Implementation
870// =============================================================================
871
872impl crate::StreamingCapable for W2NER {
873    fn recommended_chunk_size(&self) -> usize {
874        2048 // Smaller chunks due to grid memory requirements
875    }
876}
877
878// =============================================================================
879// DiscontinuousNER Trait Implementation
880// =============================================================================
881
882impl DiscontinuousNER for W2NER {
883    /// Extract entities with discontinuous span support.
884    ///
885    /// # Current Limitation
886    ///
887    /// **True discontinuous decoding is not yet implemented.** This method
888    /// currently wraps each contiguous entity into a single-segment
889    /// `DiscontinuousEntity`. The W2NER paper describes a grid-based decoding
890    /// algorithm for discontinuous entities, but this implementation does not
891    /// yet decode those relations.
892    ///
893    /// If you need true discontinuous entity support, consider:
894    /// 1. Post-processing with heuristics (e.g., linking "severe" to "pain")
895    /// 2. Using a specialized discontinuous NER model
896    ///
897    /// This trait implementation exists for API compatibility and will be
898    /// upgraded when true discontinuous decoding is implemented.
899    fn extract_discontinuous(
900        &self,
901        text: &str,
902        entity_types: &[&str],
903        threshold: f32,
904    ) -> Result<Vec<DiscontinuousEntity>> {
905        if text.trim().is_empty() {
906            return Ok(vec![]);
907        }
908
909        #[cfg(feature = "onnx")]
910        {
911            if self.session.is_some() {
912                // TODO(discontinuous): Implement true discontinuous decoding.
913                //
914                // The W2NER grid contains relation information that could be
915                // used to link non-adjacent spans into discontinuous entities.
916                // For now, we wrap each contiguous entity into a single-segment
917                // DiscontinuousEntity for API compatibility.
918                //
919                // See: https://arxiv.org/abs/2112.10070 (Section 3.3)
920                let entities = self.extract_with_grid(text, threshold)?;
921
922                return Ok(entities
923                    .into_iter()
924                    .map(|e| DiscontinuousEntity {
925                        spans: vec![(e.start, e.end)],
926                        text: e.text,
927                        entity_type: e.entity_type.as_label().to_string(),
928                        confidence: e.confidence as f32,
929                    })
930                    .collect());
931            }
932        }
933
934        let _ = (entity_types, threshold);
935
936        #[cfg(feature = "onnx")]
937        {
938            Err(crate::Error::ModelInit(
939                "W2NER model not loaded. Call `W2NER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_discontinuous`.".to_string(),
940            ))
941        }
942
943        #[cfg(not(feature = "onnx"))]
944        {
945            Err(crate::Error::FeatureNotAvailable(
946                "W2NER requires the 'onnx' feature. Build with: cargo build --features onnx"
947                    .to_string(),
948            ))
949        }
950    }
951}
952
#[cfg(test)]
mod tests {
    use super::*;

    // Relation enum round-trips through its wire indices:
    // 0 = None, 1 = NNW (next-neighboring-word), 2 = THW (tail-head-word).
    #[test]
    fn test_w2ner_relation_conversion() {
        assert_eq!(W2NERRelation::from_index(0), W2NERRelation::None);
        assert_eq!(W2NERRelation::from_index(1), W2NERRelation::NNW);
        assert_eq!(W2NERRelation::from_index(2), W2NERRelation::THW);

        assert_eq!(W2NERRelation::None.to_index(), 0);
        assert_eq!(W2NERRelation::NNW.to_index(), 1);
        assert_eq!(W2NERRelation::THW.to_index(), 2);
    }

    // Default config: 0.5 threshold, nested + discontinuous flags on,
    // three built-in entity labels.
    #[test]
    fn test_w2ner_config_defaults() {
        let config = W2NERConfig::default();
        assert!((config.threshold - 0.5).abs() < f64::EPSILON);
        assert!(config.allow_nested);
        assert!(config.allow_discontinuous);
        assert_eq!(config.entity_labels.len(), 3);
    }

    // A single THW cell (tail=2, head=0) decodes to one entity covering
    // word indices [0, 3) — the end index is exclusive.
    #[test]
    fn test_decode_simple_entity() {
        let w2ner = W2NER::new();
        let tokens = ["New", "York", "City"];

        // THW marker: tail=2, head=0 (entity spans all 3 tokens)
        let matrix = HandshakingMatrix {
            cells: vec![HandshakingCell {
                i: 2, // tail
                j: 0, // head
                label_idx: W2NERRelation::THW.to_index() as u16,
                score: 0.9,
            }],
            seq_len: 3,
            num_labels: 3,
        };

        let entities = w2ner.decode_from_matrix(&matrix, &tokens, 0);
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].0, 0); // start
        assert_eq!(entities[0].1, 3); // end
    }

    // With allow_nested on, an inner span fully contained in an outer span
    // is kept, so both THW cells decode to separate entities.
    #[test]
    fn test_decode_nested_entities() {
        let w2ner = W2NER::with_config(W2NERConfig {
            allow_nested: true,
            ..Default::default()
        });

        let tokens = ["University", "of", "California", "Berkeley"];

        let matrix = HandshakingMatrix {
            cells: vec![
                // Full entity: tail=3, head=0
                HandshakingCell {
                    i: 3,
                    j: 0,
                    label_idx: W2NERRelation::THW.to_index() as u16,
                    score: 0.95,
                },
                // Nested: tail=2, head=2 (just "California")
                HandshakingCell {
                    i: 2,
                    j: 2,
                    label_idx: W2NERRelation::THW.to_index() as u16,
                    score: 0.85,
                },
            ],
            seq_len: 4,
            num_labels: 3,
        };

        let entities = w2ner.decode_from_matrix(&matrix, &tokens, 0);
        assert_eq!(entities.len(), 2);
    }

    // remove_nested drops spans contained within another span, keeping
    // only the outermost (0, 4) entity.
    #[test]
    fn test_remove_nested() {
        let entities = vec![
            (0, 4, 0.9), // outer
            (2, 3, 0.8), // nested
        ];

        let filtered = W2NER::remove_nested(&entities);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0], (0, 4, 0.9));
    }

    // A dense (seq_len x seq_len x num_rels) score grid with one cell above
    // threshold converts to a sparse matrix with exactly that cell.
    #[test]
    fn test_grid_to_matrix() {
        // 3x3 grid with 3 relations (None, NNW, THW)
        let seq_len = 3;
        let num_rels = 3;
        let mut grid = vec![0.0f32; seq_len * seq_len * num_rels];

        // Set THW at (2, 0) with score 0.9
        // Index formula: i * seq_len * num_rels + j * num_rels + rel_idx
        let i = 2;
        let j = 0;
        let rel_thw = 2;
        let idx = i * seq_len * num_rels + j * num_rels + rel_thw;
        grid[idx] = 0.9;

        let matrix = W2NER::grid_to_matrix(&grid, seq_len, num_rels, 0.5);
        assert_eq!(matrix.cells.len(), 1);
        assert_eq!(matrix.cells[0].i, 2);
        assert_eq!(matrix.cells[0].j, 0);
    }

    // Label mapping is case-insensitive for known labels and falls back to
    // EntityType::Other for anything unrecognized.
    #[test]
    fn test_label_mapping() {
        assert_eq!(W2NER::map_label("PER"), EntityType::Person);
        assert_eq!(W2NER::map_label("org"), EntityType::Organization);
        assert_eq!(W2NER::map_label("GPE"), EntityType::Location);
        assert_eq!(
            W2NER::map_label("CUSTOM"),
            EntityType::Other("CUSTOM".to_string())
        );
    }

    // Empty input short-circuits to Ok(vec![]) before any model check, so
    // this succeeds even without a loaded model.
    #[test]
    fn test_empty_input() {
        let w2ner = W2NER::new();
        let entities = w2ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_not_available_without_model() {
        let w2ner = W2NER::new();
        // Without model loaded, should not be available
        assert!(!w2ner.is_available());
    }

    #[test]
    fn test_errors_without_model() {
        let w2ner = W2NER::new();
        // Without model, should return an explicit error (no silent empty fallback).
        let err = w2ner
            .extract_entities("Steve Jobs founded Apple", None)
            .unwrap_err();
        assert!(
            matches!(
                err,
                crate::Error::ModelInit(_) | crate::Error::FeatureNotAvailable(_)
            ),
            "unexpected error: {:?}",
            err
        );
    }
}