// anno/backends/w2ner.rs
1//! W2NER - Unified NER via Word-Word Relation Classification.
2//!
3//! W2NER (Word-to-Word NER) models NER as classifying relations between
4//! every pair of words in a sentence. This elegantly handles:
5//!
6//! - **Nested entities**: "The \[University of \[California\]\]"
7//! - **Discontinuous entities**: "severe \[pain\] ... in \[abdomen\]" *(see limitation below)*
8//! - **Overlapping entities**: Same span, different types
9//!
10//! # Discontinuous Entities (Important Limitation)
11//!
12//! **True discontinuous entity decoding is not yet implemented.** The W2NER
13//! paper describes a grid-based algorithm for linking non-adjacent spans, but
14//! this implementation currently returns only contiguous spans.
15//!
16//! The [`DiscontinuousNER`] trait is implemented for API compatibility, but
17//! `extract_discontinuous()` wraps each contiguous entity into a single-segment
18//! result. The `W2NERConfig.allow_discontinuous` flag exists for forward-compatibility
19//! but does not change behavior today.
20//!
21//! # Language Support (Important Limitation)
22//!
23//! **This implementation uses whitespace tokenization** (`split_whitespace()`),
24//! which works correctly for:
25//!
26//! - **Latin-script languages**: English, German, French, Spanish, etc.
27//! - **Cyrillic**: Russian, Ukrainian, etc.
28//! - **Languages with explicit word boundaries**
29//!
30//! It does **NOT** work correctly for:
31//!
32//! - **CJK languages** (Chinese, Japanese, Korean): No whitespace between words
33//! - **Thai, Khmer, Lao**: Scriptio continua (no word boundaries)
34//! - **Languages requiring morphological analysis**
35//!
36//! If you need CJK/Thai support, consider:
37//! 1. Pre-tokenizing with a proper segmenter (e.g., jieba, mecab, pythainlp)
38//! 2. Using a different backend (e.g., GLiNER with subword tokenization)
39//!
40//! The `language` parameter to [`Model::extract_entities`] is currently ignored,
41//! but a warning is logged if a non-whitespace language is detected.
42//!
43//! # Architecture
44//!
45//! ```text
46//! Input: "New York City is great"
47//!
48//! ┌─────────────────────────────┐
49//! │ Encoder (BERT) │
50//! └─────────────────────────────┘
51//! │
52//! ┌─────────────────────────────┐
53//! │ Biaffine Attention │
54//! │ (word-word scoring) │
55//! └─────────────────────────────┘
56//! │
57//! ┌───────────────────────────────┐
58//! │ Word-Word Grid (N×N×L) │
59//! │ ┌───┬───┬───┬───┬───┐ │
60//! │ │ │New│York│City│...│ │
61//! │ ├───┼───┼───┼───┼───┤ │
62//! │ │New│ B │NNW│THW│ │ │
63//! │ ├───┼───┼───┼───┼───┤ │
64//! │ │Yrk│ │ B │NNW│ │ │
65//! │ ├───┼───┼───┼───┼───┤ │
66//! │ │Cty│ │ │ B │ │ │
67//! │ └───┴───┴───┴───┴───┘ │
68//! └───────────────────────────────┘
69//!
70//! Legend:
71//! B = Begin entity
72//! NNW = Next-Neighboring-Word (same entity)
73//! THW = Tail-Head-Word (entity boundary)
74//! ```
75//!
76//! # Grid Labels
77//!
78//! W2NER uses three relation types for each entity label:
79//!
80//! - **NNW (Next-Neighboring-Word)**: Token i and j are adjacent in same entity
81//! - **THW (Tail-Head-Word)**: Token i is tail, token j is head of entity
82//! - **None**: No relation
83//!
84//! # Usage
85//!
86//! ```rust,ignore
87//! use anno::W2NER;
88//!
89//! // Load W2NER model (requires `onnx` feature)
90//! let w2ner = W2NER::from_pretrained("path/to/w2ner-model")?;
91//!
92//! let text = "The University of California Berkeley";
93//! let entities = w2ner.extract_entities(text, None)?;
94//! // Returns nested entities: ORG + nested LOC
95//! ```
96//!
97//! # References
98//!
99//! - [W2NER Paper](https://arxiv.org/abs/2112.10070) (AAAI 2022)
100//! - [TPLinker](https://aclanthology.org/2020.coling-main.138/) (related approach)
101
102use crate::backends::inference::{
103 DiscontinuousEntity, DiscontinuousNER, HandshakingCell, HandshakingMatrix,
104};
105use crate::{Entity, EntityType, Model, Result};
106
107#[cfg(feature = "onnx")]
108use crate::Error;
109
/// W2NER relation types for word-word classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum W2NERRelation {
    /// Next-Neighboring-Word: tokens are adjacent in same entity
    NNW,
    /// Tail-Head-Word: marks entity boundary (tail -> head)
    THW,
    /// No relation between tokens
    None,
}

impl W2NERRelation {
    /// Decode a relation from its label index.
    ///
    /// Index 1 is `NNW`, 2 is `THW`; index 0 and anything out of range map
    /// to `None`.
    #[must_use]
    pub fn from_index(idx: usize) -> Self {
        if idx == 1 {
            Self::NNW
        } else if idx == 2 {
            Self::THW
        } else {
            Self::None
        }
    }

    /// Encode this relation as its label index (inverse of [`Self::from_index`]).
    #[must_use]
    pub fn to_index(self) -> usize {
        match self {
            Self::NNW => 1,
            Self::THW => 2,
            Self::None => 0,
        }
    }
}
143
/// Configuration for W2NER decoding.
///
/// # Tokenization
///
/// W2NER uses **whitespace tokenization** (`split_whitespace()`), which works
/// for Latin-script languages but fails for CJK/Thai/Lao. See module-level
/// docs for details and workarounds.
#[derive(Debug, Clone)]
pub struct W2NERConfig {
    /// Confidence threshold for grid predictions
    pub threshold: f64,
    /// Entity type labels (maps grid channels to types)
    pub entity_labels: Vec<String>,
    /// Whether to extract nested entities
    pub allow_nested: bool,
    /// Whether to extract discontinuous entities.
    ///
    /// **Note**: Currently, discontinuous decoding is not fully implemented.
    /// This flag exists for forward-compatibility; setting it to `true` does
    /// not yet produce true discontinuous spans. See `backend-02` in docs.
    pub allow_discontinuous: bool,
    /// Model identifier for loading
    pub model_id: String,
}

impl Default for W2NERConfig {
    /// Defaults: threshold 0.5, CoNLL-style labels (PER/ORG/LOC), nesting and
    /// the (forward-compat) discontinuous flag enabled, empty model id.
    fn default() -> Self {
        let entity_labels = ["PER", "ORG", "LOC"]
            .iter()
            .map(|label| (*label).to_string())
            .collect();
        Self {
            threshold: 0.5,
            entity_labels,
            allow_nested: true,
            allow_discontinuous: true,
            model_id: String::new(),
        }
    }
}
180
/// W2NER model for unified named entity recognition.
///
/// Uses word-word relation classification to handle complex entity
/// structures (nested, overlapping, discontinuous).
///
/// # Feature Requirements
///
/// Requires the `onnx` feature for actual inference. Without it, only the
/// [`decode_from_matrix`](Self::decode_from_matrix) method works with
/// pre-computed grids.
///
/// # Example
///
/// ```rust,ignore
/// let w2ner = W2NER::from_pretrained("ljynlp/w2ner-bert-base")?;
///
/// // Handles nested entities naturally
/// let text = "The University of California Berkeley";
/// let entities = w2ner.extract_entities(text, None)?;
/// ```
pub struct W2NER {
    /// Decoding configuration (threshold, labels, nesting behavior, model id).
    config: W2NERConfig,
    /// ONNX session; `None` until `from_pretrained` succeeds. Guarded by a
    /// mutex because inference takes the session mutably (see `extract_with_grid`).
    #[cfg(feature = "onnx")]
    session: Option<crate::sync::Mutex<ort::session::Session>>,
    /// HuggingFace tokenizer paired with the model; `None` until loaded.
    #[cfg(feature = "onnx")]
    tokenizer: Option<tokenizers::Tokenizer>,
}
208
209impl W2NER {
210 /// Create W2NER with default configuration.
211 #[must_use]
212 pub fn new() -> Self {
213 Self {
214 config: W2NERConfig::default(),
215 #[cfg(feature = "onnx")]
216 session: None,
217 #[cfg(feature = "onnx")]
218 tokenizer: None,
219 }
220 }
221
    /// Create with custom configuration.
    ///
    /// No model or tokenizer is loaded; only grid-decoding helpers
    /// (e.g. `decode_from_matrix`) are usable until `from_pretrained` is called.
    #[must_use]
    pub fn with_config(config: W2NERConfig) -> Self {
        Self {
            config,
            #[cfg(feature = "onnx")]
            session: None,
            #[cfg(feature = "onnx")]
            tokenizer: None,
        }
    }
233
    /// Load W2NER model from path or HuggingFace.
    ///
    /// Automatically loads `.env` for HF_TOKEN if present.
    ///
    /// Resolution order: a local directory containing `model.onnx` +
    /// `tokenizer.json`, then a HuggingFace repo (`model.onnx` or
    /// `onnx/model.onnx`), then — if the repo lacks ONNX files — an optional
    /// local auto-export via `uv run scripts/export_w2ner_to_onnx.py`.
    ///
    /// # Arguments
    /// * `model_path` - Local path or HuggingFace model ID
    ///
    /// # Errors
    ///
    /// Returns `Error::Retrieval` when the model/tokenizer cannot be
    /// downloaded, auto-exported, or loaded from disk.
    #[cfg(feature = "onnx")]
    pub fn from_pretrained(model_path: &str) -> Result<Self> {
        use hf_hub::api::sync::{Api, ApiBuilder};
        use ort::execution_providers::CPUExecutionProvider;
        use ort::session::Session;
        use std::path::Path;
        use std::process::Command;

        // Load .env if present (for HF_TOKEN)
        crate::env::load_dotenv();

        let (model_file, tokenizer_file) = if Path::new(model_path).exists() {
            // Local path: expect exported model + tokenizer side by side.
            let model_file = Path::new(model_path).join("model.onnx");
            let tokenizer_file = Path::new(model_path).join("tokenizer.json");
            (model_file, tokenizer_file)
        } else {
            // HuggingFace download - explicitly use token if available
            let api = if let Some(token) = crate::env::hf_token() {
                ApiBuilder::new()
                    .with_token(Some(token))
                    .build()
                    .map_err(|e| {
                        Error::Retrieval(format!(
                            "Failed to initialize HuggingFace API with token: {}",
                            e
                        ))
                    })?
            } else {
                Api::new().map_err(|e| {
                    Error::Retrieval(format!("Failed to initialize HuggingFace API: {}", e))
                })?
            };
            let repo = api.model(model_path.to_string());

            // Try the conventional ONNX layouts in the repo root, then `onnx/`.
            let (model_file, tokenizer_file) = match repo
                .get("model.onnx")
                .or_else(|_| repo.get("onnx/model.onnx"))
            {
                Ok(p) => {
                    let tok = repo.get("tokenizer.json").map_err(|e| {
                        Error::Retrieval(format!("Failed to download tokenizer: {}", e))
                    })?;
                    (p, tok)
                }
                Err(e) => {
                    let error_msg = format!("{e}");
                    // Check if it's an authentication error (401) or gated model
                    if error_msg.contains("401") || error_msg.contains("Unauthorized") {
                        return Err(Error::Retrieval(format!(
                            "W2NER model '{}' requires HuggingFace authentication.\n\
                            \n\
                            To fix this:\n\
                            1. Get a HuggingFace token from https://huggingface.co/settings/tokens\n\
                            2. Request access to the model on HuggingFace (if it's gated)\n\
                            3. Set the token: export HF_TOKEN=your_token_here (or HF_API_TOKEN)\n\
                            \n\
                            Alternative: set W2NER_MODEL_PATH to a local export (see scripts/export_w2ner_to_onnx.py).",
                            model_path
                        )));
                    }

                    // 404 / missing ONNX is common: HF repos typically don't ship `model.onnx`.
                    // We can auto-export a local ONNX model (bounded by env + CI) and proceed.
                    //
                    // IMPORTANT: many dev shells set `CI=1`, which should not disable auto-export
                    // when running locally. Only treat GitHub Actions as “CI” for this purpose.
                    let in_github_actions = std::env::var("GITHUB_ACTIONS").is_ok();
                    // ANNO_W2NER_AUTO_EXPORT overrides the default (on locally,
                    // off under GitHub Actions); truthy values: 1/true/yes/y/on.
                    let auto_export = match std::env::var("ANNO_W2NER_AUTO_EXPORT").ok() {
                        None => !in_github_actions,
                        Some(v) => {
                            let t = v.trim().to_lowercase();
                            t == "1" || t == "true" || t == "yes" || t == "y" || t == "on"
                        }
                    };

                    if auto_export {
                        let Some(manifest_dir) = std::env::var("CARGO_MANIFEST_DIR").ok() else {
                            return Err(Error::Retrieval(format!(
                                "W2NER model '{}' is missing ONNX files, and auto-export is enabled, but CARGO_MANIFEST_DIR is not set.\n\
                                \n\
                                Fix:\n\
                                - Run from the repo via cargo (so CARGO_MANIFEST_DIR is present), or\n\
                                - Export manually and set W2NER_MODEL_PATH to the export directory.\n\
                                \n\
                                Original error: {e}",
                                model_path
                            )));
                        };

                        // Export location under the cache dir.
                        //
                        // IMPORTANT: `anno::eval` is feature-gated, so backends must not depend on
                        // it. Mirror the cache-root logic in a lightweight way here.
                        let cache_dir = std::env::var("ANNO_CACHE_DIR")
                            .ok()
                            .filter(|v| !v.trim().is_empty())
                            .map(std::path::PathBuf::from)
                            .unwrap_or_else(|| {
                                dirs::cache_dir()
                                    .unwrap_or_else(|| std::path::PathBuf::from("."))
                                    .join("anno")
                            });
                        // Export model choice: default to a public BERT id so auto-export works
                        // even when the configured W2NER HF repo is gated.
                        let export_bert_model = std::env::var("W2NER_EXPORT_BERT_MODEL")
                            .ok()
                            .filter(|v| !v.trim().is_empty())
                            .unwrap_or_else(|| "bert-base-cased".to_string());
                        // Sanitize the model id into a filesystem-safe directory name.
                        let safe_id = export_bert_model
                            .chars()
                            .map(|c| if c.is_ascii_alphanumeric() { c } else { '_' })
                            .collect::<String>();
                        let out_dir = cache_dir.join("models").join("w2ner").join(safe_id);
                        std::fs::create_dir_all(&out_dir).map_err(|ioe| {
                            Error::Retrieval(format!(
                                "Failed to create W2NER export dir {:?}: {}",
                                out_dir, ioe
                            ))
                        })?;

                        // NOTE(review): assumes the crate lives two levels below the
                        // repo root so `../../scripts` resolves — confirm layout.
                        let script_path = std::path::PathBuf::from(manifest_dir)
                            .join("../../scripts/export_w2ner_to_onnx.py");
                        let out_onnx = out_dir.join("model.onnx");

                        // Run export via `uv`, which is expected in dev environments.
                        let mut cmd = Command::new("uv");
                        cmd.arg("run")
                            .arg(script_path)
                            .arg("--bert-model")
                            .arg(&export_bert_model)
                            .arg("--output")
                            .arg(&out_onnx);

                        let output = cmd.output().map_err(|ioe| {
                            Error::Retrieval(format!(
                                "Failed to spawn W2NER auto-export (uv): {}",
                                ioe
                            ))
                        })?;
                        if !output.status.success() {
                            // Surface both streams: the Python exporter logs to both.
                            let stderr = String::from_utf8_lossy(&output.stderr);
                            let stdout = String::from_utf8_lossy(&output.stdout);
                            return Err(Error::Retrieval(format!(
                                "W2NER auto-export failed (exit={}).\n\
                                \n\
                                stdout:\n{}\n\
                                \n\
                                stderr:\n{}\n\
                                \n\
                                Original HF error: {e}",
                                output.status.code().unwrap_or(-1),
                                stdout,
                                stderr
                            )));
                        }

                        // Tokenizer is saved alongside the ONNX by the export script.
                        let tok = out_dir.join("tokenizer.json");
                        if !out_onnx.exists() || !tok.exists() {
                            return Err(Error::Retrieval(format!(
                                "W2NER auto-export succeeded but expected files are missing.\n\
                                expected: {:?} and {:?}",
                                out_onnx, tok
                            )));
                        }

                        (out_onnx, tok)
                    } else {
                        return Err(Error::Retrieval(format!(
                            "W2NER model '{}' not found or missing ONNX files.\n\
                            \n\
                            The model may be:\n\
                            - A gated model requiring access approval at https://huggingface.co/{}\n\
                            - Missing pre-exported ONNX files (model.onnx or onnx/model.onnx)\n\
                            - Removed or renamed on HuggingFace\n\
                            \n\
                            Fix options:\n\
                            - Set ANNO_W2NER_AUTO_EXPORT=1 (dev) to auto-export to ONNX\n\
                            - Or export manually and set W2NER_MODEL_PATH to the export directory\n\
                            \n\
                            If you have HF_TOKEN set, ensure you've requested and received access to this model.\n\
                            Alternative: Use nuner, gliner2, or other available NER backends.\n\
                            \n\
                            Original error: {e}",
                            model_path, model_path
                        )));
                    }
                }
            };

            (model_file, tokenizer_file)
        };

        let session = Session::builder()
            .map_err(|e| Error::Retrieval(format!("Failed to create session: {}", e)))?
            .with_execution_providers([CPUExecutionProvider::default().build()])
            .map_err(|e| Error::Retrieval(format!("Failed to set providers: {}", e)))?
            .commit_from_file(&model_file)
            .map_err(|e| Error::Retrieval(format!("Failed to load model: {}", e)))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_file)
            .map_err(|e| Error::Retrieval(format!("Failed to load tokenizer: {}", e)))?;

        log::debug!("[W2NER] Loaded model");

        Ok(Self {
            config: W2NERConfig {
                model_id: model_path.to_string(),
                ..Default::default()
            },
            session: Some(crate::sync::Mutex::new(session)),
            tokenizer: Some(tokenizer),
        })
    }
455
    /// Set confidence threshold.
    ///
    /// Values are clamped to `[0.0, 1.0]`. Applied both when sparsifying the
    /// grid and when accepting THW spans during decoding.
    #[must_use]
    pub fn with_threshold(mut self, threshold: f64) -> Self {
        self.config.threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Set entity type labels.
    ///
    /// Labels map grid channels to entity types and are converted via
    /// `map_label` when building `Entity` values.
    #[must_use]
    pub fn with_labels(mut self, labels: Vec<String>) -> Self {
        self.config.entity_labels = labels;
        self
    }

    /// Enable/disable nested entity extraction.
    ///
    /// When disabled, `decode_from_matrix` drops spans nested inside an
    /// already-kept span (see `remove_nested`).
    #[must_use]
    pub fn with_nested(mut self, allow: bool) -> Self {
        self.config.allow_nested = allow;
        self
    }
476
477 /// Decode entities from a handshaking matrix.
478 ///
479 /// This is the core W2NER decoding algorithm that can be used with
480 /// pre-computed grid predictions (e.g., from external inference).
481 ///
482 /// # Algorithm
483 ///
484 /// 1. Find all THW cells (entity boundaries)
485 /// 2. For each THW(i,j), the entity spans from word j (head) to word i (tail)
486 /// 3. Handle nested/overlapping entities based on config
487 ///
488 /// # Arguments
489 ///
490 /// * `matrix` - The predicted word-word relation grid
491 /// * `tokens` - Original tokens for text reconstruction
492 /// * `entity_type_idx` - Which entity type channel this is
493 pub fn decode_from_matrix(
494 &self,
495 matrix: &HandshakingMatrix,
496 tokens: &[&str],
497 entity_type_idx: usize,
498 ) -> Vec<(usize, usize, f64)> {
499 // Performance: Pre-allocate entities vec with estimated capacity
500 let mut entities = Vec::with_capacity(16);
501
502 // Find all THW (Tail-Head-Word) markers
503 // THW at (i,j) means: token i is tail, token j is head
504 // Entity spans from j (head/start) to i (tail/end)
505 for cell in &matrix.cells {
506 let relation = W2NERRelation::from_index(cell.label_idx as usize);
507 if relation == W2NERRelation::THW && cell.score >= self.config.threshold as f32 {
508 let tail = cell.i as usize;
509 let head = cell.j as usize;
510
511 // Validate: head <= tail (head is start, tail is end)
512 if head <= tail && head < tokens.len() && tail < tokens.len() {
513 entities.push((head, tail + 1, cell.score as f64));
514 }
515 }
516 }
517
518 // Performance: Use unstable sort (we don't need stable sort here)
519 // Sort by start position, then by length (longer first for nested)
520 entities.sort_unstable_by(|a, b| a.0.cmp(&b.0).then_with(|| (b.1 - b.0).cmp(&(a.1 - a.0))));
521
522 // Remove nested entities if not allowed
523 if !self.config.allow_nested {
524 entities = Self::remove_nested(&entities);
525 }
526
527 let _ = entity_type_idx; // May be used for multi-type grids
528 entities
529 }
530
531 /// Decode dense grid output to HandshakingMatrix.
532 ///
533 /// # Arguments
534 /// * `grid` - Dense grid of shape [seq_len, seq_len, num_relations]
535 /// * `seq_len` - Sequence length
536 /// * `threshold` - Score threshold for sparse representation
537 pub fn grid_to_matrix(
538 grid: &[f32],
539 seq_len: usize,
540 num_relations: usize,
541 threshold: f32,
542 ) -> HandshakingMatrix {
543 let mut cells = Vec::new();
544
545 for i in 0..seq_len {
546 for j in 0..seq_len {
547 for rel in 0..num_relations {
548 let idx = i * seq_len * num_relations + j * num_relations + rel;
549 if let Some(&score) = grid.get(idx) {
550 if score >= threshold && rel > 0 {
551 // rel > 0 excludes "None"
552 cells.push(HandshakingCell {
553 i: i as u32,
554 j: j as u32,
555 label_idx: rel as u16,
556 score,
557 });
558 }
559 }
560 }
561 }
562 }
563
564 HandshakingMatrix {
565 cells,
566 seq_len,
567 num_labels: num_relations,
568 }
569 }
570
571 /// Remove nested entities (keep outermost only).
572 fn remove_nested(entities: &[(usize, usize, f64)]) -> Vec<(usize, usize, f64)> {
573 let mut result = Vec::new();
574 let mut last_end = 0;
575
576 for &(start, end, score) in entities {
577 if start >= last_end {
578 result.push((start, end, score));
579 last_end = end;
580 }
581 }
582
583 result
584 }
585
586 /// Map label string to EntityType.
587 fn map_label(label: &str) -> EntityType {
588 match label.to_uppercase().as_str() {
589 "PER" | "PERSON" => EntityType::Person,
590 "ORG" | "ORGANIZATION" => EntityType::Organization,
591 "LOC" | "LOCATION" | "GPE" => EntityType::Location,
592 "DATE" => EntityType::Date,
593 "TIME" => EntityType::Time,
594 "MONEY" => EntityType::Money,
595 "PERCENT" => EntityType::Percent,
596 "MISC" => EntityType::Other("MISC".to_string()),
597 _ => EntityType::Other(label.to_string()),
598 }
599 }
600
601 /// Run inference with ONNX model.
602 #[cfg(feature = "onnx")]
603 pub fn extract_with_grid(&self, text: &str, threshold: f32) -> Result<Vec<Entity>> {
604 if text.is_empty() {
605 return Ok(vec![]);
606 }
607
608 let session = self.session.as_ref().ok_or_else(|| {
609 Error::Retrieval("Model not loaded. Call from_pretrained() first.".to_string())
610 })?;
611
612 let tokenizer = self
613 .tokenizer
614 .as_ref()
615 .ok_or_else(|| Error::Retrieval("Tokenizer not loaded.".to_string()))?;
616
617 // Tokenize via whitespace splitting.
618 //
619 // LIMITATION: This only works for languages with explicit word boundaries
620 // (Latin, Cyrillic, etc.). CJK/Thai/Khmer/Lao will produce single "words"
621 // for entire sentences, breaking entity extraction. See module docs.
622 let words: Vec<&str> = text.split_whitespace().collect();
623 if words.is_empty() {
624 return Ok(vec![]);
625 }
626
627 let encoding = tokenizer
628 .encode(text.to_string(), true)
629 .map_err(|e| Error::Parse(format!("Tokenization failed: {}", e)))?;
630
631 let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&x| x as i64).collect();
632 let attention_mask: Vec<i64> = encoding
633 .get_attention_mask()
634 .iter()
635 .map(|&x| x as i64)
636 .collect();
637 let seq_len = input_ids.len();
638
639 // Build tensors
640 use ndarray::Array2;
641
642 let input_ids_arr = Array2::from_shape_vec((1, seq_len), input_ids)
643 .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
644 let attention_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
645 .map_err(|e| Error::Parse(format!("Array error: {}", e)))?;
646
647 let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
648 .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
649 let attention_t = super::ort_compat::tensor_from_ndarray(attention_arr)
650 .map_err(|e| Error::Parse(format!("Tensor error: {}", e)))?;
651
652 // Run inference with blocking lock for thread-safe parallel access
653 let mut session_guard = crate::sync::lock(session);
654
655 let outputs = session_guard
656 .run(ort::inputs![
657 "input_ids" => input_ids_t.into_dyn(),
658 "attention_mask" => attention_t.into_dyn(),
659 ])
660 .map_err(|e| Error::Parse(format!("Inference failed: {}", e)))?;
661
662 // Decode grid output
663 let output = outputs
664 .iter()
665 .next()
666 .map(|(_, v)| v)
667 .ok_or_else(|| Error::Parse("No output".to_string()))?;
668
669 let (_, data) = output
670 .try_extract_tensor::<f32>()
671 .map_err(|e| Error::Parse(format!("Extract failed: {}", e)))?;
672 let grid: Vec<f32> = data.to_vec();
673
674 // Convert grid to matrix and decode
675 let num_relations = 3; // None, NNW, THW
676 let matrix = Self::grid_to_matrix(&grid, seq_len, num_relations, threshold);
677
678 // Calculate word positions
679 // Note: This assumes words appear in order and don't overlap.
680 // If a word appears multiple times, this will find the first occurrence
681 // after the previous word. This is correct for tokenized input where
682 // words are in sequence, but may fail if words are out of order.
683 let word_positions: Vec<(usize, usize)> = {
684 // Performance: Pre-allocate positions vec with known size
685 let mut positions = Vec::with_capacity(words.len());
686 let mut pos = 0;
687 for (idx, word) in words.iter().enumerate() {
688 if let Some(start) = text[pos..].find(word) {
689 let abs_start = pos + start;
690 let abs_end = abs_start + word.len();
691 // Validate position is after previous word (words should be in order)
692 if !positions.is_empty() {
693 let (_prev_start, prev_end) = positions[positions.len() - 1];
694 if abs_start < prev_end {
695 log::warn!(
696 "Word '{}' (index {}) at position {} overlaps with previous word ending at {}",
697 word,
698 idx,
699 abs_start,
700 prev_end
701 );
702 }
703 }
704 positions.push((abs_start, abs_end));
705 pos = abs_end;
706 } else {
707 // Word not found - return error to prevent silent entity skipping
708 return Err(Error::Parse(format!(
709 "Word '{}' (index {}) not found in text starting at position {}",
710 word, idx, pos
711 )));
712 }
713 }
714 positions
715 };
716
717 // Validate that we found positions for all words
718 if word_positions.len() != words.len() {
719 return Err(Error::Parse(format!(
720 "Word position mismatch: found {} positions for {} words",
721 word_positions.len(),
722 words.len()
723 )));
724 }
725
726 // Word positions are byte offsets; `Entity` requires character offsets.
727 let span_converter = crate::offset::SpanConverter::new(text);
728
729 // Performance: Pre-allocate entities vec with estimated capacity
730 // Decode entities for each type
731 let mut entities = Vec::with_capacity(16);
732 for (type_idx, label) in self.config.entity_labels.iter().enumerate() {
733 let spans = self.decode_from_matrix(&matrix, &words.to_vec(), type_idx);
734
735 for (start_word, end_word, score) in spans {
736 if let (Some(&(start_pos, _)), Some(&(_, end_pos))) = (
737 word_positions.get(start_word),
738 word_positions.get(end_word.saturating_sub(1)),
739 ) {
740 if let Some(entity_text) = text.get(start_pos..end_pos) {
741 entities.push(Entity::new(
742 entity_text,
743 Self::map_label(label),
744 span_converter.byte_to_char(start_pos),
745 span_converter.byte_to_char(end_pos),
746 score,
747 ));
748 }
749 }
750 }
751 }
752
753 Ok(entities)
754 }
755}
756
757impl Default for W2NER {
758 fn default() -> Self {
759 Self::new()
760 }
761}
762
impl Model for W2NER {
    /// Extract entities from `text`.
    ///
    /// `language` is only used to log a warning for scripts that whitespace
    /// tokenization cannot handle; it does not change extraction behavior.
    ///
    /// # Errors
    ///
    /// `ModelInit` when no model is loaded (with `onnx`), or
    /// `FeatureNotAvailable` when built without the `onnx` feature.
    fn extract_entities(&self, text: &str, language: Option<&str>) -> Result<Vec<Entity>> {
        // Fast path: blank input yields no entities.
        if text.trim().is_empty() {
            return Ok(vec![]);
        }

        // Warn if the language hint suggests a non-whitespace-tokenized language.
        // W2NER uses `split_whitespace()`, which doesn't work for CJK/Thai/etc.
        if let Some(lang) = language {
            let lang_lower = lang.to_lowercase();
            let is_non_whitespace_lang = matches!(
                lang_lower.as_str(),
                "zh" | "zh-cn"
                    | "zh-tw"
                    | "chinese"
                    | "mandarin"
                    | "cantonese"
                    | "ja"
                    | "jp"
                    | "japanese"
                    | "ko"
                    | "kr"
                    | "korean"
                    | "th"
                    | "thai"
                    | "km"
                    | "khmer"
                    | "lo"
                    | "lao"
                    | "my"
                    | "burmese"
                    | "myanmar"
            );
            if is_non_whitespace_lang {
                log::warn!(
                    "[W2NER] Language '{}' detected, but W2NER uses whitespace tokenization \
                    which does not work correctly for CJK/Thai/Khmer/Lao. \
                    Consider pre-tokenizing or using a different backend (e.g., GLiNER).",
                    lang
                );
            }
        }

        #[cfg(feature = "onnx")]
        {
            if self.session.is_some() {
                return self.extract_with_grid(text, self.config.threshold as f32);
            }

            Err(crate::Error::ModelInit(
                "W2NER model not loaded. Call `W2NER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_entities`.".to_string(),
            ))
        }

        #[cfg(not(feature = "onnx"))]
        {
            Err(crate::Error::FeatureNotAvailable(
                "W2NER requires the 'onnx' feature. Build with: cargo build --features onnx"
                    .to_string(),
            ))
        }
    }

    /// Entity types this backend can emit, derived from the configured labels.
    fn supported_types(&self) -> Vec<EntityType> {
        self.config
            .entity_labels
            .iter()
            .map(|l| Self::map_label(l))
            .collect()
    }

    /// True only when an ONNX session is loaded (always false without `onnx`).
    fn is_available(&self) -> bool {
        #[cfg(feature = "onnx")]
        {
            self.session.is_some()
        }
        #[cfg(not(feature = "onnx"))]
        {
            false
        }
    }

    /// Stable backend identifier.
    fn name(&self) -> &'static str {
        "w2ner"
    }

    /// Human-readable backend summary.
    fn description(&self) -> &'static str {
        "W2NER: Unified NER via Word-Word Relation Classification (nested/discontinuous support)"
    }

    /// Version string built from the configured model id.
    fn version(&self) -> String {
        format!("w2ner-{}", self.config.model_id)
    }
}
857
858// =============================================================================
859// BatchCapable Trait Implementation
860// =============================================================================
861
impl crate::BatchCapable for W2NER {
    /// Suggested batch size for batched extraction.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(4) // W2NER is more memory-intensive due to grid computation
    }
}
867
868// =============================================================================
869// StreamingCapable Trait Implementation
870// =============================================================================
871
impl crate::StreamingCapable for W2NER {
    /// Suggested chunk size (in characters) for streaming extraction.
    fn recommended_chunk_size(&self) -> usize {
        2048 // Smaller chunks due to grid memory requirements
    }
}
877
878// =============================================================================
879// DiscontinuousNER Trait Implementation
880// =============================================================================
881
impl DiscontinuousNER for W2NER {
    /// Extract entities with discontinuous span support.
    ///
    /// # Current Limitation
    ///
    /// **True discontinuous decoding is not yet implemented.** This method
    /// currently wraps each contiguous entity into a single-segment
    /// `DiscontinuousEntity`. The W2NER paper describes a grid-based decoding
    /// algorithm for discontinuous entities, but this implementation does not
    /// yet decode those relations.
    ///
    /// If you need true discontinuous entity support, consider:
    /// 1. Post-processing with heuristics (e.g., linking "severe" to "pain")
    /// 2. Using a specialized discontinuous NER model
    ///
    /// This trait implementation exists for API compatibility and will be
    /// upgraded when true discontinuous decoding is implemented.
    ///
    /// # Errors
    ///
    /// `ModelInit` when no model is loaded (with `onnx`), or
    /// `FeatureNotAvailable` when built without the `onnx` feature.
    fn extract_discontinuous(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<DiscontinuousEntity>> {
        if text.trim().is_empty() {
            return Ok(vec![]);
        }

        #[cfg(feature = "onnx")]
        {
            if self.session.is_some() {
                // TODO(discontinuous): Implement true discontinuous decoding.
                //
                // The W2NER grid contains relation information that could be
                // used to link non-adjacent spans into discontinuous entities.
                // For now, we wrap each contiguous entity into a single-segment
                // DiscontinuousEntity for API compatibility.
                //
                // See: https://arxiv.org/abs/2112.10070 (Section 3.3)
                let entities = self.extract_with_grid(text, threshold)?;

                return Ok(entities
                    .into_iter()
                    .map(|e| DiscontinuousEntity {
                        spans: vec![(e.start, e.end)],
                        text: e.text,
                        entity_type: e.entity_type.as_label().to_string(),
                        confidence: e.confidence as f32,
                    })
                    .collect());
            }
        }

        // `entity_types` is accepted for API compatibility but not used yet;
        // this binding silences unused-variable warnings in configurations
        // where the branch above is compiled out.
        let _ = (entity_types, threshold);

        #[cfg(feature = "onnx")]
        {
            Err(crate::Error::ModelInit(
                "W2NER model not loaded. Call `W2NER::from_pretrained(...)` (requires `onnx` feature) before calling `extract_discontinuous`.".to_string(),
            ))
        }

        #[cfg(not(feature = "onnx"))]
        {
            Err(crate::Error::FeatureNotAvailable(
                "W2NER requires the 'onnx' feature. Build with: cargo build --features onnx"
                    .to_string(),
            ))
        }
    }
}
952
#[cfg(test)]
mod tests {
    use super::*;

    // Index <-> variant round trip for the three relation labels.
    #[test]
    fn test_w2ner_relation_conversion() {
        assert_eq!(W2NERRelation::from_index(0), W2NERRelation::None);
        assert_eq!(W2NERRelation::from_index(1), W2NERRelation::NNW);
        assert_eq!(W2NERRelation::from_index(2), W2NERRelation::THW);

        assert_eq!(W2NERRelation::None.to_index(), 0);
        assert_eq!(W2NERRelation::NNW.to_index(), 1);
        assert_eq!(W2NERRelation::THW.to_index(), 2);
    }

    #[test]
    fn test_w2ner_config_defaults() {
        let config = W2NERConfig::default();
        assert!((config.threshold - 0.5).abs() < f64::EPSILON);
        assert!(config.allow_nested);
        assert!(config.allow_discontinuous);
        assert_eq!(config.entity_labels.len(), 3);
    }

    // One THW(tail=2, head=0) cell decodes to a single span over tokens [0, 3).
    #[test]
    fn test_decode_simple_entity() {
        let w2ner = W2NER::new();
        let tokens = ["New", "York", "City"];

        // THW marker: tail=2, head=0 (entity spans all 3 tokens)
        let matrix = HandshakingMatrix {
            cells: vec![HandshakingCell {
                i: 2, // tail
                j: 0, // head
                label_idx: W2NERRelation::THW.to_index() as u16,
                score: 0.9,
            }],
            seq_len: 3,
            num_labels: 3,
        };

        let entities = w2ner.decode_from_matrix(&matrix, &tokens, 0);
        assert_eq!(entities.len(), 1);
        assert_eq!(entities[0].0, 0); // start
        assert_eq!(entities[0].1, 3); // end
    }

    // With allow_nested=true, both the outer and the nested span survive.
    #[test]
    fn test_decode_nested_entities() {
        let w2ner = W2NER::with_config(W2NERConfig {
            allow_nested: true,
            ..Default::default()
        });

        let tokens = ["University", "of", "California", "Berkeley"];

        let matrix = HandshakingMatrix {
            cells: vec![
                // Full entity: tail=3, head=0
                HandshakingCell {
                    i: 3,
                    j: 0,
                    label_idx: W2NERRelation::THW.to_index() as u16,
                    score: 0.95,
                },
                // Nested: tail=2, head=2 (just "California")
                HandshakingCell {
                    i: 2,
                    j: 2,
                    label_idx: W2NERRelation::THW.to_index() as u16,
                    score: 0.85,
                },
            ],
            seq_len: 4,
            num_labels: 3,
        };

        let entities = w2ner.decode_from_matrix(&matrix, &tokens, 0);
        assert_eq!(entities.len(), 2);
    }

    // remove_nested keeps only the outermost of two overlapping spans.
    #[test]
    fn test_remove_nested() {
        let entities = vec![
            (0, 4, 0.9), // outer
            (2, 3, 0.8), // nested
        ];

        let filtered = W2NER::remove_nested(&entities);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0], (0, 4, 0.9));
    }

    // Dense-to-sparse conversion keeps only above-threshold, non-"None" cells.
    #[test]
    fn test_grid_to_matrix() {
        // 3x3 grid with 3 relations (None, NNW, THW)
        let seq_len = 3;
        let num_rels = 3;
        let mut grid = vec![0.0f32; seq_len * seq_len * num_rels];

        // Set THW at (2, 0) with score 0.9
        // Index formula: i * seq_len * num_rels + j * num_rels + rel_idx
        let i = 2;
        let j = 0;
        let rel_thw = 2;
        let idx = i * seq_len * num_rels + j * num_rels + rel_thw;
        grid[idx] = 0.9;

        let matrix = W2NER::grid_to_matrix(&grid, seq_len, num_rels, 0.5);
        assert_eq!(matrix.cells.len(), 1);
        assert_eq!(matrix.cells[0].i, 2);
        assert_eq!(matrix.cells[0].j, 0);
    }

    // Label mapping is case-insensitive and preserves unknown labels.
    #[test]
    fn test_label_mapping() {
        assert_eq!(W2NER::map_label("PER"), EntityType::Person);
        assert_eq!(W2NER::map_label("org"), EntityType::Organization);
        assert_eq!(W2NER::map_label("GPE"), EntityType::Location);
        assert_eq!(
            W2NER::map_label("CUSTOM"),
            EntityType::Other("CUSTOM".to_string())
        );
    }

    // Blank input short-circuits before the model-availability check.
    #[test]
    fn test_empty_input() {
        let w2ner = W2NER::new();
        let entities = w2ner.extract_entities("", None).unwrap();
        assert!(entities.is_empty());
    }

    #[test]
    fn test_not_available_without_model() {
        let w2ner = W2NER::new();
        // Without model loaded, should not be available
        assert!(!w2ner.is_available());
    }

    #[test]
    fn test_errors_without_model() {
        let w2ner = W2NER::new();
        // Without model, should return an explicit error (no silent empty fallback).
        let err = w2ner
            .extract_entities("Steve Jobs founded Apple", None)
            .unwrap_err();
        assert!(
            matches!(
                err,
                crate::Error::ModelInit(_) | crate::Error::FeatureNotAvailable(_)
            ),
            "unexpected error: {:?}",
            err
        );
    }
}