Skip to main content

anno/backends/w2ner/
mod.rs

1//! W2NER - Unified NER via Word-Word Relation Classification.
2//!
3//! W2NER (Word-to-Word NER) models NER as classifying relations between
4//! every pair of words in a sentence. This elegantly handles:
5//!
6//! - **Nested entities**: "The \[University of \[California\]\]"
//! - **Discontinuous entities**: "severe \[pain\] ... in \[abdomen\]" *(see "Discontinuous Entities" below)*
8//! - **Overlapping entities**: Same span, different types
9//!
10//! # Discontinuous Entities
11//!
12//! True discontinuous entity decoding is implemented via [`W2NER::decode_discontinuous_from_matrix`].
13//! The algorithm follows arXiv:2112.10070 §3.3: THW cells identify entity boundaries and NNW cells
14//! identify adjacent-word connections within the same entity; gaps in the NNW chain produce disjoint
15//! sub-spans.
16//!
17//! Use `extract_discontinuous()` (the [`DiscontinuousNER`] trait) when you need
18//! non-contiguous span support. The standard `extract_entities()` returns only
19//! contiguous spans for backwards-compatibility with the flat `Entity` type.
20//!
21//! # Language Support (Important Limitation)
22//!
23//! **This implementation uses whitespace tokenization** (`split_whitespace()`),
24//! which works correctly for:
25//!
26//! - **Latin-script languages**: English, German, French, Spanish, etc.
27//! - **Cyrillic**: Russian, Ukrainian, etc.
28//! - **Languages with explicit word boundaries**
29//!
30//! It does **NOT** work correctly for:
31//!
32//! - **CJK languages** (Chinese, Japanese, Korean): No whitespace between words
33//! - **Thai, Khmer, Lao**: Scriptio continua (no word boundaries)
34//! - **Languages requiring morphological analysis**
35//!
36//! If you need CJK/Thai support, consider:
37//! 1. Pre-tokenizing with a proper segmenter (e.g., jieba, mecab, pythainlp)
38//! 2. Using a different backend (e.g., GLiNER with subword tokenization)
39//!
40//! The `language` parameter to [`Model::extract_entities`] is currently ignored,
41//! but a warning is logged if a non-whitespace language is detected.
42//!
43//! # Architecture
44//!
45//! ```text
46//! Input: "New York City is great"
47//!
48//!        ┌─────────────────────────────┐
49//!        │      Encoder (BERT)          │
50//!        └─────────────────────────────┘
51//!                     │
52//!        ┌─────────────────────────────┐
53//!        │    Biaffine Attention        │
54//!        │    (word-word scoring)       │
55//!        └─────────────────────────────┘
56//!                     │
57//!        ┌───────────────────────────────┐
58//!        │     Word-Word Grid (N×N×L)    │
59//!        │  ┌───┬───┬───┬───┬───┐       │
60//!        │  │   │New│York│City│...│      │
61//!        │  ├───┼───┼───┼───┼───┤       │
62//!        │  │New│ B │NNW│THW│   │       │
63//!        │  ├───┼───┼───┼───┼───┤       │
64//!        │  │Yrk│   │ B │NNW│   │       │
65//!        │  ├───┼───┼───┼───┼───┤       │
66//!        │  │Cty│   │   │ B │   │       │
67//!        │  └───┴───┴───┴───┴───┘       │
68//!        └───────────────────────────────┘
69//!
70//! Legend:
71//!   B   = Begin entity
72//!   NNW = Next-Neighboring-Word (same entity)
73//!   THW = Tail-Head-Word (entity boundary)
74//! ```
75//!
76//! # Grid Labels
77//!
//! W2NER classifies every ordered word pair into one of three relations (this
//! implementation decodes a single three-channel grid: None, NNW, THW):
79//!
80//! - **NNW (Next-Neighboring-Word)**: Token i and j are adjacent in same entity
81//! - **THW (Tail-Head-Word)**: Token i is tail, token j is head of entity
82//! - **None**: No relation
83//!
84//! # Usage
85//!
86//! ```rust,ignore
87//! use anno::W2NER;
88//!
89//! // Load W2NER model (requires `onnx` feature)
90//! let w2ner = W2NER::from_pretrained("path/to/w2ner-model")?;
91//!
92//! let text = "The University of California Berkeley";
93//! let entities = w2ner.extract_entities(text, None)?;
94//! // Returns nested entities: ORG + nested LOC
95//! ```
96//!
97//! # References
98//!
99//! - [W2NER Paper](https://arxiv.org/abs/2112.10070) (AAAI 2022)
100//! - [TPLinker](https://aclanthology.org/2020.coling-main.138/) (related approach)
101
102pub mod decode;
103pub use decode::{map_label_to_entity_type, DiscontinuousDecodeRow, W2NERRelation};
104
105use crate::backends::inference::{DiscontinuousEntity, DiscontinuousNER, HandshakingMatrix};
106use crate::{Entity, EntityType, Model, Result};
107
108#[cfg(feature = "onnx")]
109use crate::Error;
110
/// Configuration for W2NER decoding.
///
/// # Tokenization
///
/// W2NER uses **whitespace tokenization** (`split_whitespace()`), which works
/// for Latin-script languages but fails for CJK/Thai/Lao. See module-level
/// docs for details and workarounds.
///
/// # Defaults
///
/// `threshold = 0.5`, labels `PER`/`ORG`/`LOC`, nested and discontinuous
/// extraction enabled, and an empty `model_id`.
#[derive(Debug, Clone)]
pub struct W2NERConfig {
    /// Confidence threshold for grid predictions, expected in `[0.0, 1.0]`.
    ///
    /// `W2NER::with_threshold` clamps into that range; assigning the field
    /// directly does not. The decoder receives the value narrowed to `f32`.
    pub threshold: f64,
    /// Entity type labels. The label at position `i` names entity type index
    /// `i` during decoding, so order matters.
    pub entity_labels: Vec<String>,
    /// Whether to extract nested entities (forwarded to the contiguous decoder).
    pub allow_nested: bool,
    /// Whether discontinuous entities are wanted.
    ///
    /// Non-contiguous spans are produced only through the
    /// `DiscontinuousNER::extract_discontinuous` path, which uses NNW/THW grid
    /// decoding (`W2NER::decode_discontinuous_from_matrix`). The flat
    /// `extract_entities` path always returns contiguous spans.
    ///
    /// NOTE(review): the grid-decoding methods on `W2NER` do not read this flag
    /// themselves; confirm whether `extract_discontinuous` honours it.
    pub allow_discontinuous: bool,
    /// Model identifier (local path or HuggingFace id).
    ///
    /// `W2NER::from_pretrained` records the path it was given here.
    pub model_id: String,
}

impl Default for W2NERConfig {
    /// Threshold `0.5`, labels `PER`/`ORG`/`LOC`, nested and discontinuous
    /// extraction on, no model id.
    fn default() -> Self {
        Self {
            threshold: 0.5,
            entity_labels: vec!["PER".to_string(), "ORG".to_string(), "LOC".to_string()],
            allow_nested: true,
            allow_discontinuous: true,
            model_id: String::new(),
        }
    }
}
147
/// W2NER model for unified named entity recognition.
///
/// Uses word-word relation classification to handle complex entity
/// structures (nested, overlapping, discontinuous).
///
/// # Feature Requirements
///
/// Requires the `onnx` feature for actual inference (`from_pretrained`,
/// `extract_with_grid`). Without it the model cannot run, but the decoding
/// helpers that operate on pre-computed grids still work:
/// [`decode_from_matrix`](Self::decode_from_matrix),
/// [`decode_discontinuous_from_matrix`](Self::decode_discontinuous_from_matrix)
/// and [`grid_to_matrix`](Self::grid_to_matrix).
///
/// # Example
///
/// ```rust,ignore
/// let w2ner = W2NER::from_pretrained("ljynlp/w2ner-bert-base")?;
///
/// // Handles nested entities naturally
/// let text = "The University of California Berkeley";
/// let entities = w2ner.extract_entities(text, None)?;
/// ```
pub struct W2NER {
    /// Decoding configuration (threshold, labels, nesting flags, model id).
    config: W2NERConfig,
    /// ONNX Runtime session; `None` until `from_pretrained` loads a model.
    /// Inference locks this mutex around each `run` call.
    #[cfg(feature = "onnx")]
    session: Option<crate::sync::Mutex<ort::session::Session>>,
    /// Tokenizer paired with the ONNX model; `None` until `from_pretrained`
    /// loads it.
    #[cfg(feature = "onnx")]
    tokenizer: Option<tokenizers::Tokenizer>,
}
175
176impl W2NER {
177    /// Create W2NER with default configuration.
178    #[must_use]
179    pub fn new() -> Self {
180        Self {
181            config: W2NERConfig::default(),
182            #[cfg(feature = "onnx")]
183            session: None,
184            #[cfg(feature = "onnx")]
185            tokenizer: None,
186        }
187    }
188
189    /// Create with custom configuration.
190    #[must_use]
191    pub fn with_config(config: W2NERConfig) -> Self {
192        Self {
193            config,
194            #[cfg(feature = "onnx")]
195            session: None,
196            #[cfg(feature = "onnx")]
197            tokenizer: None,
198        }
199    }
200
201    /// Load W2NER model from path or HuggingFace.
202    ///
203    /// Automatically loads `.env` for HF_TOKEN if present.
204    ///
205    /// # Arguments
206    /// * `model_path` - Local path or HuggingFace model ID
207    #[cfg(feature = "onnx")]
208    pub fn from_pretrained(model_path: &str) -> Result<Self> {
209        use hf_hub::api::sync::{Api, ApiBuilder};
210        use ort::execution_providers::CPUExecutionProvider;
211        use ort::session::Session;
212        use std::path::Path;
213        use std::process::Command;
214
215        // Load .env if present (for HF_TOKEN)
216        crate::env::load_dotenv();
217
218        let (model_file, tokenizer_file) = if Path::new(model_path).exists() {
219            // Local path
220            let model_file = Path::new(model_path).join("model.onnx");
221            let tokenizer_file = Path::new(model_path).join("tokenizer.json");
222            (model_file, tokenizer_file)
223        } else {
224            // HuggingFace download - explicitly use token if available
225            let api = if let Some(token) = crate::env::hf_token() {
226                ApiBuilder::new()
227                    .with_token(Some(token))
228                    .build()
229                    .map_err(|e| {
230                        Error::Retrieval(format!(
231                            "Failed to initialize HuggingFace API with token: {}",
232                            e
233                        ))
234                    })?
235            } else {
236                Api::new().map_err(|e| {
237                    Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
238                })?
239            };
240            let repo = api.model(model_path.to_string());
241
242            let (model_file, tokenizer_file) = match repo
243                .get("model.onnx")
244                .or_else(|_| repo.get("onnx/model.onnx"))
245            {
246                Ok(p) => {
247                    let tok = repo.get("tokenizer.json").map_err(|e| {
248                        Error::Retrieval(format!("Failed to download tokenizer: {}", e))
249                    })?;
250                    (p, tok)
251                }
252                Err(e) => {
253                    let error_msg = format!("{e}");
254                    // Check if it's an authentication error (401) or gated model
255                    if error_msg.contains("401") || error_msg.contains("Unauthorized") {
256                        return Err(Error::Retrieval(format!(
257                            "W2NER model '{}' requires HuggingFace authentication.\n\
258                             \n\
259                             To fix this:\n\
260                             1. Get a HuggingFace token from https://huggingface.co/settings/tokens\n\
261                             2. Request access to the model on HuggingFace (if it's gated)\n\
262                             3. Set the token: export HF_TOKEN=your_token_here (or HF_API_TOKEN)\n\
263                             \n\
264                             Alternative: set W2NER_MODEL_PATH to a local export (see scripts/export_w2ner_to_onnx.py).",
265                            model_path
266                        )));
267                    }
268
269                    // 404 / missing ONNX is common: HF repos typically don't ship `model.onnx`.
270                    // We can auto-export a local ONNX model (bounded by env + CI) and proceed.
271                    //
272                    // IMPORTANT: many dev shells set `CI=1`, which should not disable auto-export
273                    // when running locally. Only treat GitHub Actions as “CI” for this purpose.
274                    let in_github_actions = std::env::var("GITHUB_ACTIONS").is_ok();
275                    let auto_export = match std::env::var("ANNO_W2NER_AUTO_EXPORT").ok() {
276                        None => !in_github_actions,
277                        Some(v) => {
278                            let t = v.trim().to_lowercase();
279                            t == "1" || t == "true" || t == "yes" || t == "y" || t == "on"
280                        }
281                    };
282
283                    if auto_export {
284                        let Some(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR").ok() else {
285                            return Err(Error::Retrieval(format!(
286                                "W2NER model '{}' is missing ONNX files, and auto-export is enabled, but CARGO_MANIFEST_DIR is not set.\n\
287                                 \n\
288                                 Fix:\n\
289                                 - Run from the repo via cargo (so CARGO_MANIFEST_DIR is present), or\n\
290                                 - Export manually and set W2NER_MODEL_PATH to the export directory.\n\
291                                 \n\
292                                 Original error: {e}",
293                                model_path
294                            )));
295                        };
296
297                        // Export location under the cache dir.
298                        //
299                        // IMPORTANT: `anno::eval` is feature-gated, so backends must not depend on
300                        // it. Mirror the cache-root logic in a lightweight way here.
301                        let cache_dir = std::env::var("ANNO_CACHE_DIR")
302                            .ok()
303                            .filter(|v| !v.trim().is_empty())
304                            .map(std::path::PathBuf::from)
305                            .unwrap_or_else(|| {
306                                dirs::cache_dir()
307                                    .unwrap_or_else(|| std::path::PathBuf::from("."))
308                                    .join("anno")
309                            });
310                        // Export model choice: default to a public BERT id so auto-export works
311                        // even when the configured W2NER HF repo is gated.
312                        let export_bert_model = std::env::var("W2NER_EXPORT_BERT_MODEL")
313                            .ok()
314                            .filter(|v| !v.trim().is_empty())
315                            .unwrap_or_else(|| "bert-base-cased".to_string());
316                        let safe_id = export_bert_model
317                            .chars()
318                            .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
319                            .collect::<String>();
320                        let out_dir = cache_dir.join("models").join("w2ner").join(safe_id);
321                        std::fs::create_dir_all(&out_dir).map_err(|ioe| {
322                            Error::Retrieval(format!(
323                                "Failed to create W2NER export dir {:?}: {}",
324                                out_dir, ioe
325                            ))
326                        })?;
327
328                        let script_path = std::path::PathBuf::from(manifest_dir)
329                            .join("../../scripts/export_w2ner_to_onnx.py");
330                        let out_onnx = out_dir.join("model.onnx");
331
332                        // Run export via `uv`, which is expected in dev environments.
333                        let mut cmd = Command::new("uv");
334                        cmd.arg("run")
335                            .arg(script_path)
336                            .arg("--bert-model")
337                            .arg(&export_bert_model)
338                            .arg("--output")
339                            .arg(&out_onnx);
340
341                        let output = cmd.output().map_err(|ioe| {
342                            Error::Retrieval(format!(
343                                "Failed to spawn W2NER auto-export (uv): {}",
344                                ioe
345                            ))
346                        })?;
347                        if !output.status.success() {
348                            let stderr = String::from_utf8_lossy(&output.stderr);
349                            let stdout = String::from_utf8_lossy(&output.stdout);
350                            return Err(Error::Retrieval(format!(
351                                "W2NER auto-export failed (exit={}).\n\
352                                 \n\
353                                 stdout:\n{}\n\
354                                 \n\
355                                 stderr:\n{}\n\
356                                 \n\
357                                 Original HF error: {e}",
358                                output.status.code().unwrap_or(-1),
359                                stdout,
360                                stderr
361                            )));
362                        }
363
364                        // Tokenizer is saved alongside the ONNX by the export script.
365                        let tok = out_dir.join("tokenizer.json");
366                        if !out_onnx.exists() || !tok.exists() {
367                            return Err(Error::Retrieval(format!(
368                                "W2NER auto-export succeeded but expected files are missing.\n\
369                                 expected: {:?} and {:?}",
370                                out_onnx, tok
371                            )));
372                        }
373
374                        (out_onnx, tok)
375                    } else {
376                        return Err(Error::Retrieval(format!(
377                            "W2NER model '{}' not found or missing ONNX files.\n\
378                             \n\
379                             The model may be:\n\
380                             - A gated model requiring access approval at https://huggingface.co/{}\n\
381                             - Missing pre-exported ONNX files (model.onnx or onnx/model.onnx)\n\
382                             - Removed or renamed on HuggingFace\n\
383                             \n\
384                             Fix options:\n\
385                             - Set ANNO_W2NER_AUTO_EXPORT=1 (dev) to auto-export to ONNX\n\
386                             - Or export manually and set W2NER_MODEL_PATH to the export directory\n\
387                             \n\
388                             If you have HF_TOKEN set, ensure you've requested and received access to this model.\n\
389                             Alternative: Use nuner, gliner2, or other available NER backends.\n\
390                             \n\
391                             Original error: {e}",
392                            model_path, model_path
393                        )));
394                    }
395                }
396            };
397
398            (model_file, tokenizer_file)
399        };
400
401        let session = Session::builder()
402            .map_err(|e| Error::Retrieval(format!("Failed to create session: {}", e)))?
403            .with_execution_providers([CPUExecutionProvider::default().build()])
404            .map_err(|e| Error::Retrieval(format!("Failed to set providers: {}", e)))?
405            .commit_from_file(&model_file)
406            .map_err(|e| Error::Retrieval(format!("Failed to load model: {}", e)))?;
407
408        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_file)
409            .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;
410
411        log::debug!("[W2NER] Loaded model");
412
413        Ok(Self {
414            config: W2NERConfig {
415                model_id: model_path.to_string(),
416                ..Default::default()
417            },
418            session: Some(crate::sync::Mutex::new(session)),
419            tokenizer: Some(tokenizer),
420        })
421    }
422
423    /// Set confidence threshold.
424    #[must_use]
425    pub fn with_threshold(mut self, threshold: f64) -> Self {
426        self.config.threshold = threshold.clamp(0.0, 1.0);
427        self
428    }
429
430    /// Set entity type labels.
431    #[must_use]
432    pub fn with_labels(mut self, labels: Vec<String>) -> Self {
433        self.config.entity_labels = labels;
434        self
435    }
436
437    /// Enable/disable nested entity extraction.
438    #[must_use]
439    pub fn with_nested(mut self, allow: bool) -> Self {
440        self.config.allow_nested = allow;
441        self
442    }
443
444    /// Decode entities from a handshaking matrix.
445    ///
446    /// Delegates to [`decode::decode_from_matrix`] with this model's threshold
447    /// and `allow_nested` settings. See that function for the full algorithm.
448    pub fn decode_from_matrix(
449        &self,
450        matrix: &HandshakingMatrix,
451        tokens: &[&str],
452        entity_type_idx: usize,
453    ) -> Vec<(usize, usize, f64)> {
454        decode::decode_from_matrix(
455            matrix,
456            tokens,
457            entity_type_idx,
458            self.config.threshold as f32,
459            self.config.allow_nested,
460        )
461    }
462
463    /// Decode discontinuous entities from a handshaking matrix.
464    ///
465    /// Delegates to [`decode::decode_discontinuous_from_matrix`].
466    pub fn decode_discontinuous_from_matrix(
467        &self,
468        matrix: &HandshakingMatrix,
469        tokens: &[&str],
470        threshold: f32,
471    ) -> Vec<DiscontinuousDecodeRow> {
472        let first_label = self
473            .config
474            .entity_labels
475            .first()
476            .map(|s| s.as_str())
477            .unwrap_or("");
478        decode::decode_discontinuous_from_matrix(matrix, tokens, threshold, first_label)
479    }
480
481    /// Decode dense grid output to HandshakingMatrix.
482    ///
483    /// Delegates to [`decode::grid_to_matrix`].
484    pub fn grid_to_matrix(
485        grid: &[f32],
486        seq_len: usize,
487        num_relations: usize,
488        threshold: f32,
489    ) -> HandshakingMatrix {
490        decode::grid_to_matrix(grid, seq_len, num_relations, threshold)
491    }
492
493    /// Run inference with ONNX model.
494    #[cfg(feature = "onnx")]
495    pub fn extract_with_grid(&self, text: &str, threshold: f32) -> Result<Vec<Entity>> {
496        if text.is_empty() {
497            return Ok(vec![]);
498        }
499
500        let session = self.session.as_ref().ok_or_else(|| {
501            Error::Retrieval("Model not loaded. Call from_pretrained() first.".to_string())
502        })?;
503
504        let tokenizer = self
505            .tokenizer
506            .as_ref()
507            .ok_or_else(|| Error::Retrieval("Tokenizer not loaded.".to_string()))?;
508
509        // Tokenize via whitespace splitting.
510        //
511        // LIMITATION: This only works for languages with explicit word boundaries
512        // (Latin, Cyrillic, etc.). CJK/Thai/Khmer/Lao will produce single "words"
513        // for entire sentences, breaking entity extraction. See module docs.
514        let words: Vec<&str> = text.split_whitespace().collect();
515        if words.is_empty() {
516            return Ok(vec![]);
517        }
518
519        let encoding = tokenizer
520            .encode(text.to_string(), true)
521            .map_err(|e| Error::Parse(format!("Tokenization failed: {}", e)))?;
522
523        let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
524        let attention_mask: Vec<i64> = encoding
525            .get_attention_mask()
526            .iter()
527            .map(|&x| x as i64)
528            .collect();
529        let seq_len = input_ids.len();
530
531        // Build tensors
532        use ndarray::Array2;
533
534        let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
535            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
536        let attention_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
537            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
538
539        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
540            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
541        let attention_t = super::ort_compat::tensor_from_ndarray(attention_arr)
542            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
543
544        // Run inference with blocking lock for thread-safe parallel access
545        let mut session_guard = crate::sync::lock(session);
546
547        let outputs = session_guard
548            .run(ort::inputs![
549                "input_ids" => input_ids_t.into_dyn(),
550                "attention_mask" => attention_t.into_dyn(),
551            ])
552            .map_err(|e| Error::Parse(format!("Inference failed: {}", e)))?;
553
554        // Decode grid output
555        let output = outputs
556            .iter()
557            .next()
558            .map(|(_, v)| v)
559            .ok_or_else(|| Error::Parse("No output".to_string()))?;
560
561        let (_, data) = output
562            .try_extract_tensor::<f32>()
563            .map_err(|e| Error::Parse(format!("Extract failed: {}", e)))?;
564        let grid: Vec<f32> = data.to_vec();
565
566        // Convert grid to matrix and decode
567        let num_relations = 3; // None, NNW, THW
568        let matrix = Self::grid_to_matrix(&grid, seq_len, num_relations, threshold);
569
570        // Calculate word positions
571        // Note: This assumes words appear in order and don't overlap.
572        // If a word appears multiple times, this will find the first occurrence
573        // after the previous word. This is correct for tokenized input where
574        // words are in sequence, but may fail if words are out of order.
575        let word_positions: Vec<(usize, usize)> = {
576            // Performance: Pre-allocate positions vec with known size
577            let mut positions = Vec::with_capacity(words.len());
578            let mut pos = 0;
579            for (idx, word) in words.iter().enumerate() {
580                if let Some(start) = text[pos..].find(word) {
581                    let abs_start = pos + start;
582                    let abs_end = abs_start + word.len();
583                    // Validate position is after previous word (words should be in order)
584                    if !positions.is_empty() {
585                        let (_prev_start, prev_end) = positions[positions.len() - 1];
586                        if abs_start < prev_end {
587                            log::warn!(
588                                "Word '{}' (index {}) at position {} overlaps with previous word ending at {}",
589                                word,
590                                idx,
591                                abs_start,
592                                prev_end
593                            );
594                        }
595                    }
596                    positions.push((abs_start, abs_end));
597                    pos = abs_end;
598                } else {
599                    // Word not found - return error to prevent silent entity skipping
600                    return Err(Error::Parse(format!(
601                        "Word '{}' (index {}) not found in text starting at position {}",
602                        word, idx, pos
603                    )));
604                }
605            }
606            positions
607        };
608
609        // Validate that we found positions for all words
610        if word_positions.len() != words.len() {
611            return Err(Error::Parse(format!(
612                "Word position mismatch: found {} positions for {} words",
613                word_positions.len(),
614                words.len()
615            )));
616        }
617
618        // Word positions are byte offsets; `Entity` requires character offsets.
619        let span_converter = crate::offset::SpanConverter::new(text);
620
621        // Performance: Pre-allocate entities vec with estimated capacity
622        // Decode entities for each type
623        let mut entities = Vec::with_capacity(16);
624        for (type_idx, label) in self.config.entity_labels.iter().enumerate() {
625            let spans = self.decode_from_matrix(&matrix, &words.to_vec(), type_idx);
626
627            for (start_word, end_word, score) in spans {
628                if let (Some(&(start_pos, _)), Some(&(_, end_pos))) = (
629                    word_positions.get(start_word),
630                    word_positions.get(end_word.saturating_sub(1)),
631                ) {
632                    if let Some(entity_text) = text.get(start_pos..end_pos) {
633                        entities.push(Entity::new(
634                            entity_text,
635                            decode::map_label_to_entity_type(label),
636                            span_converter.byte_to_char(start_pos),
637                            span_converter.byte_to_char(end_pos),
638                            score,
639                        ));
640                    }
641                }
642            }
643        }
644
645        Ok(entities)
646    }
647
648    /// Full discontinuous-NER extraction using the NNW+THW grid decoding algorithm.
649    ///
650    /// Called by `extract_discontinuous` when an ONNX session is loaded.
651    #[cfg(feature = "onnx")]
652    fn extract_discontinuous_with_nnw(
653        &self,
654        text: &str,
655        threshold: f32,
656    ) -> Result<Vec<DiscontinuousEntity>> {
657        use ndarray::Array2;
658
659        let session = self
660            .session
661            .as_ref()
662            .ok_or_else(|| Error::Retrieval("Model not loaded.".to_string()))?;
663        let tokenizer = self
664            .tokenizer
665            .as_ref()
666            .ok_or_else(|| Error::Retrieval("Tokenizer not loaded.".to_string()))?;
667
668        let words: Vec<&str> = text.split_whitespace().collect();
669        if words.is_empty() {
670            return Ok(vec![]);
671        }
672
673        let encoding = tokenizer
674            .encode(text.to_string(), true)
675            .map_err(|e| Error::Parse(format!("Tokenization failed: {}", e)))?;
676
677        let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
678        let attention_mask: Vec<i64> = encoding
679            .get_attention_mask()
680            .iter()
681            .map(|&x| x as i64)
682            .collect();
683        let seq_len = input_ids.len();
684
685        let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
686            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
687        let attention_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
688            .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
689        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
690            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
691        let attention_t = super::ort_compat::tensor_from_ndarray(attention_arr)
692            .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
693
694        let grid: Vec<f32> = {
695            let mut session_guard = crate::sync::lock(session);
696            let outputs = session_guard
697                .run(ort::inputs![
698                    "input_ids" => input_ids_t.into_dyn(),
699                    "attention_mask" => attention_t.into_dyn(),
700                ])
701                .map_err(|e| Error::Parse(format!("Inference failed: {}", e)))?;
702            let output = outputs
703                .iter()
704                .next()
705                .map(|(_, v)| v)
706                .ok_or_else(|| Error::Parse("No output".to_string()))?;
707            let (_, data) = output
708                .try_extract_tensor::<f32>()
709                .map_err(|e| Error::Parse(format!("Extract failed: {}", e)))?;
710            data.to_vec()
711        }; // session_guard + outputs dropped here
712
713        let num_relations = 3; // None, NNW, THW
714        let matrix = Self::grid_to_matrix(&grid, seq_len, num_relations, threshold);
715
716        // Compute word → byte position map
717        let word_positions: Vec<(usize, usize)> = {
718            let mut positions = Vec::with_capacity(words.len());
719            let mut pos = 0;
720            for word in &words {
721                if let Some(start) = text[pos..].find(word) {
722                    let abs_start = pos + start;
723                    let abs_end = abs_start + word.len();
724                    positions.push((abs_start, abs_end));
725                    pos = abs_end;
726                } else {
727                    return Err(Error::Parse(format!("Word '{}' not found", word)));
728                }
729            }
730            positions
731        };
732
733        let span_converter = crate::offset::SpanConverter::new(text);
734
735        // Decode with NNW-aware discontinuous algorithm
736        let decoded = self.decode_discontinuous_from_matrix(&matrix, &words, threshold);
737        let mut entities = Vec::new();
738        for (type_label, word_spans, score) in decoded {
739            // Convert word-index spans to character-offset spans
740            let mut char_spans: Vec<(usize, usize)> = Vec::new();
741            let mut valid = true;
742            for (ws, we) in &word_spans {
743                let word_start = *ws;
744                let word_end = we.saturating_sub(1);
745                if let (Some(&(byte_start, _)), Some(&(_, byte_end))) =
746                    (word_positions.get(word_start), word_positions.get(word_end))
747                {
748                    char_spans.push((
749                        span_converter.byte_to_char(byte_start),
750                        span_converter.byte_to_char(byte_end),
751                    ));
752                } else {
753                    valid = false;
754                    break;
755                }
756            }
757            if !valid || char_spans.is_empty() {
758                continue;
759            }
760
761            // Reconstruct entity text from all spans
762            let entity_text: String = word_spans
763                .iter()
764                .filter_map(|(ws, we)| {
765                    let last = we.saturating_sub(1);
766                    let byte_start = word_positions.get(*ws)?.0;
767                    let byte_end = word_positions.get(last)?.1;
768                    text.get(byte_start..byte_end)
769                })
770                .collect::<Vec<_>>()
771                .join(" ");
772
773            entities.push(DiscontinuousEntity {
774                spans: char_spans,
775                text: entity_text,
776                entity_type: type_label,
777                confidence: score as f32,
778            });
779        }
780
781        Ok(entities)
782    }
783}
784
785impl Default for W2NER {
786    fn default() -> Self {
787        Self::new()
788    }
789}
790
791impl Model for W2NER {
792    fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
793        if text.trim().is_empty() {
794            return Ok(vec![]);
795        }
796
797        // Warn if the language hint suggests a non-whitespace-tokenized language.
798        // W2NER uses `split_whitespace()`, which doesn't work for CJK/Thai/etc.
799        if let Some(lang) = language {
800            let lang_lower = lang.to_lowercase();
801            let is_non_whitespace_lang = matches!(
802                lang_lower.as_str(),
803                "zh" | "zh-cn"
804                    | "zh-tw"
805                    | "chinese"
806                    | "mandarin"
807                    | "cantonese"
808                    | "ja"
809                    | "jp"
810                    | "japanese"
811                    | "ko"
812                    | "kr"
813                    | "korean"
814                    | "th"
815                    | "thai"
816                    | "km"
817                    | "khmer"
818                    | "lo"
819                    | "lao"
820                    | "my"
821                    | "burmese"
822                    | "myanmar"
823            );
824            if is_non_whitespace_lang {
825                log::warn!(
826                    "[W2NER] Language '{}' detected, but W2NER uses whitespace tokenization \
827                     which does not work correctly for CJK/Thai/Khmer/Lao. \
828                     Consider pre-tokenizing or using a different backend (e.g., GLiNER).",
829                    lang
830                );
831            }
832        }
833
834        #[cfg(feature = "onnx")]
835        {
836            if self.session.is_some() {
837                return self.extract_with_grid(text, self.config.threshold as f32);
838            }
839
840            Err(crate::Error::ModelInit(
841                "W2NER model not loaded. Call `W2NER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_entities`.".to_string(),
842            ))
843        }
844
845        #[cfg(not(feature = "onnx"))]
846        {
847            Err(crate::Error::FeatureNotAvailable(
848                "W2NER requires the 'onnx' feature. Build with: cargo build --features onnx"
849                    .to_string(),
850            ))
851        }
852    }
853
854    fn supported_types(&self) -> Vec<EntityType> {
855        self.config
856            .entity_labels
857            .iter()
858            .map(|l| decode::map_label_to_entity_type(l))
859            .collect()
860    }
861
862    fn is_available(&self) -> bool {
863        #[cfg(feature = "onnx")]
864        {
865            self.session.is_some()
866        }
867        #[cfg(not(feature = "onnx"))]
868        {
869            false
870        }
871    }
872
873    fn name(&self) -> &'static str {
874        "w2ner"
875    }
876
877    fn description(&self) -> &'static str {
878        "W2NER: Unified NER via Word-Word Relation Classification (nested/discontinuous support)"
879    }
880
881    fn version(&self) -> String {
882        format!("w2ner-{}", self.config.model_id)
883    }
884
885    fn capabilities(&self) -> crate::ModelCapabilities {
886        crate::ModelCapabilities {
887            batch_capable: true,
888            optimal_batch_size: Some(4),
889            streaming_capable: true,
890            discontinuous_capable: true,
891            ..Default::default()
892        }
893    }
894}
895
// Opts W2NER into `NamedEntityCapable`; the empty body relies entirely on the trait's
// provided items (if any).
impl crate::NamedEntityCapable for W2NER {}
897
898// =============================================================================
899// BatchCapable Trait Implementation
900// =============================================================================
901
902impl crate::BatchCapable for W2NER {
903    fn optimal_batch_size(&self) -> Option<usize> {
904        Some(4) // W2NER is more memory-intensive due to grid computation
905    }
906}
907
908// =============================================================================
909// StreamingCapable Trait Implementation
910// =============================================================================
911
912impl crate::StreamingCapable for W2NER {
913    fn recommended_chunk_size(&self) -> usize {
914        2048 // Smaller chunks due to grid memory requirements
915    }
916}
917
918// =============================================================================
919// DiscontinuousNER Trait Implementation
920// =============================================================================
921
922impl DiscontinuousNER for W2NER {
923    /// Extract entities with discontinuous span support via the full NNW+THW decoding algorithm.
924    ///
925    /// Uses `decode_discontinuous_from_matrix` (arXiv:2112.10070 §3.3):
926    /// THW cells identify entity boundaries; NNW cells identify adjacent word pairs within
927    /// the same entity.  Gaps in the NNW chain produce disjoint sub-spans, yielding true
928    /// discontinuous entities (e.g. "severe … pain" → two spans).
929    fn extract_discontinuous(
930        &self,
931        text: &str,
932        entity_types: &[&str],
933        threshold: f32,
934    ) -> Result<Vec<DiscontinuousEntity>> {
935        if text.trim().is_empty() {
936            return Ok(vec![]);
937        }
938
939        #[cfg(feature = "onnx")]
940        {
941            if self.session.is_some() {
942                return self.extract_discontinuous_with_nnw(text, threshold);
943            }
944        }
945
946        let _ = (entity_types, threshold);
947
948        #[cfg(feature = "onnx")]
949        {
950            Err(crate::Error::ModelInit(
951                "W2NER model not loaded. Call `W2NER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_discontinuous`.".to_string(),
952            ))
953        }
954
955        #[cfg(not(feature = "onnx"))]
956        {
957            Err(crate::Error::FeatureNotAvailable(
958                "W2NER requires the 'onnx' feature. Build with: cargo build --features onnx"
959                    .to_string(),
960            ))
961        }
962    }
963}
964
965#[cfg(test)]
966mod tests;