// anno/backends/gliner2.rs
//! GLiNER2: Multi-task Information Extraction.
//!
//! GLiNER2 extends GLiNER to support:
//! - Named Entity Recognition (with label descriptions)
//! - Text Classification (single/multi-label)
//! - Hierarchical Structure Extraction
//! - Task Composition (multiple tasks in one pass)
//!
//! This backend is based on the GLiNER2 paper (arXiv:2507.18546). The details of
//! prompt formatting and the full task schema are paper-defined; this module
//! focuses on the inference integration and trait wiring used by `anno`.
//!
//! # Trait Integration
//!
//! GLiNER2 implements the standard `anno` traits:
//! - `Model` - Core entity extraction interface
//! - `ZeroShotNER` - Open-domain entity types
//! - `RelationExtractor` - Joint entity-relation extraction (via GLiREL)
//! - `BatchCapable` - Efficient batch processing
//!
//! # Usage
//!
//! ```rust,ignore
//! use anno::{Model, ZeroShotNER, DEFAULT_GLINER2_MODEL};
//! use anno::backends::gliner2::{GLiNER2, TaskSchema};
//!
//! // Use the official Fastino Labs GLiNER2 model
//! let model = GLiNER2::from_pretrained(DEFAULT_GLINER2_MODEL)?;
//! // Or: GLiNER2::from_pretrained("fastino/gliner2-base-v1")?;
//!
//! // Standard Model trait
//! let entities = model.extract_entities("Apple announced iPhone 15", None)?;
//!
//! // Zero-shot with custom types
//! let types = &["company", "product", "event"];
//! let entities = model.extract_with_types(text, types, 0.5)?;
//!
//! // Multi-task extraction with schema (third arg: multi_label)
//! let schema = TaskSchema::new()
//!     .with_entities(&["person", "organization", "product"])
//!     .with_classification("sentiment", &["positive", "negative", "neutral"], false);
//!
//! let result = model.extract_with_schema("Apple announced iPhone 15", &schema)?;
//! ```
//!
//! # Backends
//!
//! - **ONNX** (recommended): `cargo build --features onnx`
//! - **Candle** (native): `cargo build --features candle`
51#[cfg(feature = "onnx")]
52use crate::sync::{lock, Mutex};
53use crate::{Entity, EntityType, Error, Result};
54use anno_core::EntityCategory;
55use serde::{Deserialize, Serialize};
56use std::collections::HashMap;
57#[cfg(feature = "candle")]
58use std::sync::RwLock;
59
60// Import trait definitions for implementations
61use crate::backends::inference::{ExtractionWithRelations, RelationExtractor, ZeroShotNER};
62
// =============================================================================
// Special Token IDs (gliner-multitask-large-v0.5 vocabulary)
// Valid tokens: [MASK]=128000, [FLERT]=128001, <<ENT>>=128002, <<SEP>>=128003
// Note: [P], [C], [L] markers don't exist in this model - DO NOT USE 128004+
// =============================================================================

/// <<ENT>> token - entity type marker (class_token_index in config)
#[cfg(feature = "onnx")]
const TOKEN_ENT: u32 = 128002;
/// <<SEP>> separator token between the type prompt and the text
#[cfg(feature = "onnx")]
const TOKEN_SEP: u32 = 128003;
/// Start token [CLS]
#[cfg(feature = "onnx")]
const TOKEN_START: u32 = 1;
/// End token [SEP]
#[cfg(feature = "onnx")]
const TOKEN_END: u32 = 2;

/// Default max span width (in words) when enumerating candidate spans
const MAX_SPAN_WIDTH: usize = 12;
/// Max count for structure instances (0-19)
#[cfg(feature = "candle")]
const MAX_COUNT: usize = 20;
87
88// =============================================================================
89// Label Embedding Cache
90// =============================================================================
91
/// Cache for label embeddings to avoid recomputation.
///
/// Only functional under the `candle` feature; otherwise it compiles to a
/// zero-sized placeholder so the type exists regardless of enabled backends.
#[derive(Debug, Default)]
pub struct LabelCache {
    /// label text -> embedding vector; `RwLock` allows concurrent readers.
    #[cfg(feature = "candle")]
    cache: RwLock<HashMap<String, Vec<f32>>>,
    /// Placeholder field for non-candle builds (struct must still be constructible).
    #[cfg(not(feature = "candle"))]
    _phantom: std::marker::PhantomData<()>,
}
100
#[cfg(feature = "candle")]
impl LabelCache {
    /// Construct an empty embedding cache.
    fn new() -> Self {
        let map = HashMap::new();
        Self {
            cache: RwLock::new(map),
        }
    }

    /// Look up a cached embedding. Returns `None` on a miss — or when the
    /// lock is poisoned, in which case the cache degrades to a no-op.
    fn get(&self, label: &str) -> Option<Vec<f32>> {
        let guard = self.cache.read().ok()?;
        guard.get(label).cloned()
    }

    /// Store an embedding. A poisoned lock silently skips the write; the
    /// cache is best-effort and recomputation is always possible.
    fn insert(&self, label: String, embedding: Vec<f32>) {
        if let Ok(mut map) = self.cache.write() {
            map.insert(label, embedding);
        }
    }
}
119
120#[cfg(not(feature = "candle"))]
121impl LabelCache {
122    #[allow(dead_code)]
123    fn new() -> Self {
124        Self {
125            _phantom: std::marker::PhantomData,
126        }
127    }
128}
129
130// =============================================================================
131// Task Schema
132// =============================================================================
133
/// Schema defining what to extract.
///
/// Use builder methods to construct complex schemas:
///
/// ```rust,ignore
/// let schema = TaskSchema::new()
///     .with_entities(&["person", "organization"])
///     .with_classification("sentiment", &["positive", "negative"], false)
///     .with_structure(
///         StructureTask::new("product")
///             .with_field("name", FieldType::String)
///             .with_field("price", FieldType::String)
///     );
/// ```
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskSchema {
    /// Entity types to extract (`None` = no NER task configured)
    pub entities: Option<EntityTask>,
    /// Classification tasks, each run independently
    pub classifications: Vec<ClassificationTask>,
    /// Structure extraction tasks
    pub structures: Vec<StructureTask>,
}
157
158impl TaskSchema {
159    /// Create empty schema.
160    pub fn new() -> Self {
161        Self::default()
162    }
163
164    /// Add entity types to extract.
165    pub fn with_entities(mut self, types: &[&str]) -> Self {
166        self.entities = Some(EntityTask {
167            types: types.iter().map(|s| s.to_string()).collect(),
168            descriptions: HashMap::new(),
169        });
170        self
171    }
172
173    /// Add entity types with descriptions for better zero-shot.
174    pub fn with_entities_described(mut self, types_with_desc: HashMap<String, String>) -> Self {
175        let types: Vec<String> = types_with_desc.keys().cloned().collect();
176        self.entities = Some(EntityTask {
177            types,
178            descriptions: types_with_desc,
179        });
180        self
181    }
182
183    /// Add a classification task.
184    pub fn with_classification(mut self, name: &str, labels: &[&str], multi_label: bool) -> Self {
185        self.classifications.push(ClassificationTask {
186            name: name.to_string(),
187            labels: labels.iter().map(|s| s.to_string()).collect(),
188            multi_label,
189            descriptions: HashMap::new(),
190        });
191        self
192    }
193
194    /// Add a structure extraction task.
195    pub fn with_structure(mut self, task: StructureTask) -> Self {
196        self.structures.push(task);
197        self
198    }
199}
200
/// Entity extraction task configuration.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EntityTask {
    /// Entity type labels
    pub types: Vec<String>,
    /// Optional descriptions, keyed by type label (may be empty)
    pub descriptions: HashMap<String, String>,
}
209
/// Classification task configuration.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClassificationTask {
    /// Task name (e.g., "sentiment"); used as the key in results
    pub name: String,
    /// Class labels
    pub labels: Vec<String>,
    /// Whether multiple labels can be selected (multi-label vs single-label)
    pub multi_label: bool,
    /// Optional descriptions, keyed by label (may be empty)
    pub descriptions: HashMap<String, String>,
}
222
/// Hierarchical structure extraction task.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StructureTask {
    /// Structure type name (parent entity)
    pub name: String,
    /// Internal alias for compatibility; mirrors `name` and is not serialized
    #[serde(skip)]
    pub structure_type: String,
    /// Child fields to extract
    pub fields: Vec<StructureField>,
}
234
235impl StructureTask {
236    /// Create new structure task.
237    pub fn new(name: &str) -> Self {
238        Self {
239            name: name.to_string(),
240            structure_type: name.to_string(),
241            fields: Vec::new(),
242        }
243    }
244
245    /// Add a field to extract.
246    pub fn with_field(mut self, name: &str, field_type: FieldType) -> Self {
247        self.fields.push(StructureField {
248            name: name.to_string(),
249            field_type,
250            description: None,
251            choices: None,
252        });
253        self
254    }
255
256    /// Add a field with description.
257    pub fn with_field_described(
258        mut self,
259        name: &str,
260        field_type: FieldType,
261        description: &str,
262    ) -> Self {
263        self.fields.push(StructureField {
264            name: name.to_string(),
265            field_type,
266            description: Some(description.to_string()),
267            choices: None,
268        });
269        self
270    }
271
272    /// Add a choice field with constrained options.
273    pub fn with_choice_field(mut self, name: &str, choices: &[&str]) -> Self {
274        self.fields.push(StructureField {
275            name: name.to_string(),
276            field_type: FieldType::Choice,
277            description: None,
278            choices: Some(choices.iter().map(|s| s.to_string()).collect()),
279        });
280        self
281    }
282}
283
/// Structure field configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructureField {
    /// Field name
    pub name: String,
    /// Field type (string, list, or constrained choice)
    pub field_type: FieldType,
    /// Optional description to guide extraction
    pub description: Option<String>,
    /// For `FieldType::Choice`: the allowed values
    pub choices: Option<Vec<String>>,
}
296
/// Field type for structure extraction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FieldType {
    /// Single string value
    String,
    /// List of values
    List,
    /// Choice from constrained options (see `StructureField::choices`)
    Choice,
}
307
308// =============================================================================
309// Extraction Results
310// =============================================================================
311
/// Combined extraction result for all tasks in a `TaskSchema`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtractionResult {
    /// Extracted entities (empty when no NER task was configured)
    pub entities: Vec<Entity>,
    /// Classification results keyed by task name
    pub classifications: HashMap<String, ClassificationResult>,
    /// Extracted structure instances from all structure tasks
    pub structures: Vec<ExtractedStructure>,
}
322
/// Classification result.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClassificationResult {
    /// Selected label(s); one entry for single-label, possibly several for multi-label
    pub labels: Vec<String>,
    /// Score for each label, keyed by label text
    pub scores: HashMap<String, f32>,
}
331
/// Extracted structure instance.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtractedStructure {
    /// Structure type (matches `StructureTask::name`)
    pub structure_type: String,
    /// Extracted field values, keyed by field name
    pub fields: HashMap<String, StructureValue>,
}
340
/// Value for a structure field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StructureValue {
    /// Single value (for `FieldType::String` / `FieldType::Choice`)
    Single(String),
    /// List of values (for `FieldType::List`)
    List(Vec<String>),
}
349
350// =============================================================================
351// ONNX Backend
352// =============================================================================
353
/// GLiNER2 ONNX implementation.
///
/// Wraps an ONNX Runtime session plus the matching HuggingFace tokenizer,
/// both loaded by [`GLiNER2Onnx::from_pretrained`].
#[cfg(feature = "onnx")]
#[derive(Debug)]
pub struct GLiNER2Onnx {
    /// ONNX Runtime session; locked per inference call since `run` takes `&mut`
    session: Mutex<ort::session::Session>,
    /// Tokenizer loaded from the repo's `tokenizer.json`
    tokenizer: tokenizers::Tokenizer,
    /// Model identifier the session was loaded from (diagnostics only)
    #[allow(dead_code)]
    model_name: String,
    /// Hidden size read from `config.json` (falls back to 768)
    #[allow(dead_code)]
    hidden_size: usize,
}
366
367#[cfg(feature = "onnx")]
368impl GLiNER2Onnx {
    /// Load model from HuggingFace Hub.
    ///
    /// Fetches (or reuses the local HF cache for) the ONNX graph
    /// (`onnx/model.onnx`, falling back to `model.onnx`), `tokenizer.json`,
    /// and `config.json`, then builds a CPU-only ONNX Runtime session.
    ///
    /// # Errors
    /// `Error::Retrieval` when any artifact cannot be fetched or the session
    /// cannot be created; `Error::Parse` when `config.json` is invalid JSON.
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        use hf_hub::api::sync::Api;
        use ort::execution_providers::CPUExecutionProvider;
        use ort::session::Session;

        let api = Api::new().map_err(|e| Error::Retrieval(format!("HF API: {}", e)))?;
        let repo = api.model(model_id.to_string());

        // Try different model file names (exports differ on where the graph lives)
        let model_path = repo
            .get("onnx/model.onnx")
            .or_else(|_| repo.get("model.onnx"))
            .map_err(|e| Error::Retrieval(format!("model.onnx: {}", e)))?;

        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| Error::Retrieval(format!("tokenizer.json: {}", e)))?;

        let config_path = repo
            .get("config.json")
            .map_err(|e| Error::Retrieval(format!("config.json: {}", e)))?;

        // Load tokenizer
        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| Error::Retrieval(format!("tokenizer: {}", e)))?;

        // Parse config; hidden_size falls back to 768 (base-sized encoder) if absent
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| Error::Retrieval(format!("config read: {}", e)))?;
        let config: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| Error::Parse(format!("config parse: {}", e)))?;
        let hidden_size = config["hidden_size"].as_u64().unwrap_or(768) as usize;

        // Create ONNX session with the CPU execution provider
        let session = Session::builder()
            .map_err(|e| Error::Retrieval(format!("ONNX builder: {}", e)))?
            .with_execution_providers([CPUExecutionProvider::default().build()])
            .map_err(|e| Error::Retrieval(format!("ONNX providers: {}", e)))?
            .commit_from_file(&model_path)
            .map_err(|e| Error::Retrieval(format!("ONNX load: {}", e)))?;

        log::info!(
            "[GLiNER2-ONNX] Loaded {} (hidden={})",
            model_id,
            hidden_size
        );
        log::debug!("[GLiNER2-ONNX] Model loaded");

        Ok(Self {
            session: Mutex::new(session),
            tokenizer,
            model_name: model_id.to_string(),
            hidden_size,
        })
    }
425
426    /// Extract entities, classifications, and structures according to schema.
427    pub fn extract(&self, text: &str, schema: &TaskSchema) -> Result<ExtractionResult> {
428        let mut result = ExtractionResult::default();
429
430        // NER extraction
431        if let Some(ref ent_task) = schema.entities {
432            let labels: Vec<&str> = ent_task.types.iter().map(|s| s.as_str()).collect();
433            let entities = self.extract_ner(text, &labels, 0.5)?;
434            result.entities = entities;
435        }
436
437        // Classification
438        for class_task in &schema.classifications {
439            let labels: Vec<&str> = class_task.labels.iter().map(|s| s.as_str()).collect();
440            let class_result = self.classify(text, &labels, class_task.multi_label)?;
441            result
442                .classifications
443                .insert(class_task.name.clone(), class_result);
444        }
445
446        // Structure extraction
447        for struct_task in &schema.structures {
448            let structures = self.extract_structure(text, struct_task)?;
449            result.structures.extend(structures);
450        }
451
452        Ok(result)
453    }
454
    /// Extract named entities using GLiNER2 NER format.
    ///
    /// Whitespace-tokenizes `text`, encodes the prompt via
    /// `encode_ner_prompt`, runs the ONNX session with its four inputs
    /// (input_ids, attention_mask, words_mask, text_lengths), and decodes
    /// entities scoring at or above `threshold`. Returns an empty vector for
    /// empty text or an empty type list.
    fn extract_ner(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        if text.is_empty() || entity_types.is_empty() {
            return Ok(Vec::new());
        }

        // Split into words (char offsets are reconstructed during decoding)
        let text_words: Vec<&str> = text.split_whitespace().collect();
        if text_words.is_empty() {
            return Ok(Vec::new());
        }

        // Encode following GLiNER2 format: [P] entities ([E]type1 [E]type2 ...) [SEP] text
        let (input_ids, attention_mask, words_mask) =
            self.encode_ner_prompt(&text_words, entity_types)?;

        // Build tensors - GLiNER2 ONNX model only needs 4 inputs:
        // input_ids, attention_mask, words_mask, text_lengths
        // (NOT span_idx/span_mask - those were for older model variants)
        use ndarray::Array2;

        let batch_size = 1;
        let seq_len = input_ids.len();

        let input_ids_arr = Array2::from_shape_vec((batch_size, seq_len), input_ids)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let attention_mask_arr = Array2::from_shape_vec((batch_size, seq_len), attention_mask)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let words_mask_arr = Array2::from_shape_vec((batch_size, seq_len), words_mask)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let text_lengths_arr =
            Array2::from_shape_vec((batch_size, 1), vec![text_words.len() as i64])
                .map_err(|e| Error::Parse(format!("Array: {}", e)))?;

        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;

        // Run inference with blocking lock for thread-safe parallel access
        let mut session = lock(&self.session);

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_t.into_dyn(),
                "attention_mask" => attention_mask_t.into_dyn(),
                "words_mask" => words_mask_t.into_dyn(),
                "text_lengths" => text_lengths_t.into_dyn(),
            ])
            .map_err(|e| Error::Inference(format!("ONNX run: {}", e)))?;

        // Decode raw logits into Entity spans
        self.decode_ner_output(&outputs, text, &text_words, entity_types, threshold)
    }
518
519    /// Encode NER prompt: [START] [P] entities ([E]type1 ...) [SEP] word1 word2 ... [END]
520    fn encode_ner_prompt(
521        &self,
522        text_words: &[&str],
523        entity_types: &[&str],
524    ) -> Result<(Vec<i64>, Vec<i64>, Vec<i64>)> {
525        let mut input_ids: Vec<i64> = Vec::new();
526        let mut word_mask: Vec<i64> = Vec::new();
527
528        // Start token [CLS]
529        input_ids.push(TOKEN_START as i64);
530        word_mask.push(0);
531
532        // Entity types: <<ENT>>type1 <<ENT>>type2 ...
533        // Format for token-level GLiNER: [CLS] <<ENT>>type1 <<ENT>>type2 ... <<SEP>> text [SEP]
534        for entity_type in entity_types {
535            input_ids.push(TOKEN_ENT as i64);
536            word_mask.push(0);
537
538            let type_enc = self
539                .tokenizer
540                .encode(*entity_type, false)
541                .map_err(|e| Error::Parse(format!("Tokenize: {}", e)))?;
542            for token_id in type_enc.get_ids() {
543                input_ids.push(*token_id as i64);
544                word_mask.push(0);
545            }
546        }
547
548        // [SEP] token
549        input_ids.push(TOKEN_SEP as i64);
550        word_mask.push(0);
551
552        // Text words with word_mask tracking
553        for (word_idx, word) in text_words.iter().enumerate() {
554            let word_enc = self
555                .tokenizer
556                .encode(*word, false)
557                .map_err(|e| Error::Parse(format!("Tokenize: {}", e)))?;
558
559            let word_id = (word_idx + 1) as i64; // 1-indexed
560            for (token_idx, token_id) in word_enc.get_ids().iter().enumerate() {
561                input_ids.push(*token_id as i64);
562                // First subword gets word ID, rest get 0
563                word_mask.push(if token_idx == 0 { word_id } else { 0 });
564            }
565        }
566
567        // End token
568        input_ids.push(TOKEN_END as i64);
569        word_mask.push(0);
570
571        let seq_len = input_ids.len();
572        let attention_mask: Vec<i64> = vec![1; seq_len];
573
574        Ok((input_ids, attention_mask, word_mask))
575    }
576
577    /// Generate span tensors.
578    /// Generate span tensors for span-level models (not needed for token-level ONNX models)
579    #[allow(dead_code)]
580    fn make_span_tensors(&self, num_words: usize) -> (Vec<i64>, Vec<bool>) {
581        // Use checked_mul to prevent overflow (same pattern as line 2388)
582        let num_spans = num_words.checked_mul(MAX_SPAN_WIDTH).unwrap_or_else(|| {
583            log::warn!(
584                "Span count overflow: {} words * {} MAX_SPAN_WIDTH, using max",
585                num_words,
586                MAX_SPAN_WIDTH
587            );
588            usize::MAX
589        });
590        // Check for overflow in num_spans * 2
591        let span_idx_len = num_spans.checked_mul(2).unwrap_or_else(|| {
592            log::warn!(
593                "Span idx length overflow: {} spans * 2, using max",
594                num_spans
595            );
596            usize::MAX
597        });
598        let mut span_idx: Vec<i64> = vec![0; span_idx_len];
599        let mut span_mask: Vec<bool> = vec![false; num_spans];
600
601        for start in 0..num_words {
602            let remaining = num_words - start;
603            let max_width = MAX_SPAN_WIDTH.min(remaining);
604
605            for width in 0..max_width {
606                // Check for overflow in dim calculation (same pattern as nuner.rs:399)
607                let dim = match start.checked_mul(MAX_SPAN_WIDTH) {
608                    Some(v) => match v.checked_add(width) {
609                        Some(d) => d,
610                        None => {
611                            log::warn!(
612                                "Dim calculation overflow: {} * {} + {}, skipping span",
613                                start,
614                                MAX_SPAN_WIDTH,
615                                width
616                            );
617                            continue;
618                        }
619                    },
620                    None => {
621                        log::warn!(
622                            "Dim calculation overflow: {} * {}, skipping span",
623                            start,
624                            MAX_SPAN_WIDTH
625                        );
626                        continue;
627                    }
628                };
629                // Check bounds before array access (dim * 2 could overflow or exceed span_idx_len)
630                if let Some(dim2) = dim.checked_mul(2) {
631                    if dim2 + 1 < span_idx_len && dim < num_spans {
632                        span_idx[dim2] = start as i64;
633                        span_idx[dim2 + 1] = (start + width) as i64;
634                        span_mask[dim] = true;
635                    } else {
636                        log::warn!(
637                            "Span idx access out of bounds: dim={}, dim*2={}, span_idx_len={}, num_spans={}, skipping",
638                            dim, dim2, span_idx_len, num_spans
639                        );
640                    }
641                } else {
642                    log::warn!("Dim * 2 overflow: dim={}, skipping span", dim);
643                }
644            }
645        }
646
647        (span_idx, span_mask)
648    }
649
    /// Decode NER output logits into `Entity` spans.
    ///
    /// Supports two ONNX output layouts, distinguished by tensor shape:
    /// - token-level BIO: `[3, 1, num_words, num_classes]` (dims: B/I/O,
    ///   batch, words, classes)
    /// - span-level: `[1, num_words, max_width, num_classes]`
    ///
    /// Logits are squashed with a sigmoid and compared against `threshold`.
    /// Any other shape silently yields no entities.
    ///
    /// NOTE(review): while extending an entity, the kept score is a pairwise
    /// running average `(score + i_score) / 2`, not a true mean over all
    /// member words — confirm against the reference implementation.
    fn decode_ner_output(
        &self,
        outputs: &ort::session::SessionOutputs,
        text: &str,
        text_words: &[&str],
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // The model has a single output tensor; take the first by name order.
        let output = outputs
            .iter()
            .next()
            .map(|(_, v)| v)
            .ok_or_else(|| Error::Parse("No output".into()))?;

        let (_, data_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Extract tensor: {}", e)))?;
        let output_data: Vec<f32> = data_slice.to_vec();

        let shape: Vec<i64> = match output.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
            _ => return Err(Error::Parse("Not a tensor".into())),
        };

        if output_data.is_empty() || shape.contains(&0) {
            return Err(Error::Inference(
                "GLiNER2 ONNX returned empty/degenerate output tensor. This usually indicates an incompatible ONNX export (shape mismatch or missing dynamic axes).".to_string(),
            ));
        }

        let mut entities = Vec::new();
        let num_words = text_words.len();

        // Token-level model: shape [position, batch, num_words, num_classes]
        // where position = 3 for BIO tagging (B=0, I=1, O=2)
        if shape.len() == 4 && shape[0] == 3 && shape[1] == 1 {
            let out_num_words = shape[2] as usize;
            let num_classes = shape[3] as usize;
            let word_class_size = out_num_words * num_classes;

            // BIO decoding: find B-type starts, extend with I-type
            // Output shape [BIO=3, batch=1, words, classes] flattened to [BIO * batch * words * classes]
            // BIO dimension: 0=Begin, 1=Inside, 2=Outside
            let b_offset = 0_usize; // Begin logits start at offset 0
            let i_offset = word_class_size; // Inside logits start after B (1 * word_class_size)

            #[allow(clippy::needless_range_loop)] // class_idx used for multiple array accesses
            for class_idx in 0..num_classes.min(entity_types.len()) {
                // Open entity for this class, if any: (start word index, running score)
                let mut current_start: Option<(usize, f32)> = None; // (word_idx, score)

                for word_idx in 0..out_num_words.min(num_words) {
                    // B logit at BIO dimension 0
                    let b_idx = b_offset + word_idx * num_classes + class_idx;
                    // I logit at BIO dimension 1
                    let i_idx = i_offset + word_idx * num_classes + class_idx;

                    // Out-of-range indices fall back to a large negative logit
                    // (sigmoid ~ 0), which can never cross the threshold.
                    let b_logit = if b_idx < output_data.len() {
                        output_data[b_idx]
                    } else {
                        -100.0
                    };
                    let i_logit = if i_idx < output_data.len() {
                        output_data[i_idx]
                    } else {
                        -100.0
                    };

                    // Sigmoid: logit -> probability.
                    let b_score = 1.0 / (1.0 + (-b_logit).exp());
                    let i_score = 1.0 / (1.0 + (-i_logit).exp());

                    if b_score >= threshold {
                        // End any existing entity
                        if let Some((start_word, avg_score)) = current_start.take() {
                            let end_word = word_idx - 1;
                            if start_word <= end_word && end_word < num_words {
                                let span_text = text_words[start_word..=end_word].join(" ");
                                let (start, end) = word_span_to_char_offsets(
                                    text, text_words, start_word, end_word,
                                );
                                let entity_type = map_entity_type(entity_types[class_idx]);
                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    start,
                                    end,
                                    avg_score as f64,
                                ));
                            }
                        }
                        // Start new entity
                        current_start = Some((word_idx, b_score));
                    } else if i_score >= threshold && current_start.is_some() {
                        // Continue entity - update score
                        if let Some((start_word, score)) = current_start {
                            current_start = Some((start_word, (score + i_score) / 2.0));
                        }
                    } else if current_start.is_some() {
                        // Neither B nor I fired: close the open entity here.
                        if let Some((start_word, avg_score)) = current_start.take() {
                            let end_word = word_idx - 1;
                            if start_word <= end_word && end_word < num_words {
                                let span_text = text_words[start_word..=end_word].join(" ");
                                let (start, end) = word_span_to_char_offsets(
                                    text, text_words, start_word, end_word,
                                );
                                let entity_type = map_entity_type(entity_types[class_idx]);
                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    start,
                                    end,
                                    avg_score as f64,
                                ));
                            }
                        }
                    }
                }

                // Handle entity still open at end of text
                if let Some((start_word, avg_score)) = current_start.take() {
                    let end_word = out_num_words.min(num_words) - 1;
                    if start_word <= end_word {
                        let span_text = text_words[start_word..=end_word].join(" ");
                        let (start, end) =
                            word_span_to_char_offsets(text, text_words, start_word, end_word);
                        let entity_type = map_entity_type(entity_types[class_idx]);
                        entities.push(Entity::new(
                            span_text,
                            entity_type,
                            start,
                            end,
                            avg_score as f64,
                        ));
                    }
                }
            }
        }
        // Span-level model: shape [batch, num_words, max_width, num_classes]
        else if shape.len() == 4 && shape[0] == 1 {
            let out_num_words = shape[1] as usize;
            let out_max_width = shape[2] as usize;
            let num_classes = shape[3] as usize;

            for word_idx in 0..out_num_words.min(num_words) {
                for width in 0..out_max_width.min(MAX_SPAN_WIDTH) {
                    let end_word = word_idx + width;
                    if end_word >= num_words {
                        continue;
                    }

                    #[allow(clippy::needless_range_loop)] // class_idx used for index math
                    for class_idx in 0..num_classes.min(entity_types.len()) {
                        // Row-major index into [words, width, classes].
                        let idx = (word_idx * out_max_width * num_classes)
                            + (width * num_classes)
                            + class_idx;

                        if idx < output_data.len() {
                            let logit = output_data[idx];
                            let score = 1.0 / (1.0 + (-logit).exp());

                            if score >= threshold {
                                let span_text = text_words[word_idx..=end_word].join(" ");
                                let (start, end) =
                                    word_span_to_char_offsets(text, text_words, word_idx, end_word);

                                let entity_type = map_entity_type(entity_types[class_idx]);

                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    start,
                                    end,
                                    score as f64,
                                ));
                            }
                        }
                    }
                }
            }
        }

        // Deduplicate: sort by start asc then end desc, drop exact (start, end) repeats
        entities.sort_by(|a, b| a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)));
        entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);

        Ok(entities)
    }
838
    /// Decode batched NER model output into one entity vector per input text.
    ///
    /// Supports two ONNX output layouts, distinguished by tensor shape:
    /// - Token-level BIO: `[3, batch, num_words, num_classes]` — dimension 0
    ///   holds the B/I/O logit planes; only B (index 0) and I (index 1) are
    ///   read here.
    /// - Span-level: `[batch, num_words, max_width, num_classes]`.
    ///
    /// # Arguments
    /// * `outputs` - ONNX session outputs; only the first output tensor is read.
    /// * `texts` - Original input texts (used for char-offset computation).
    /// * `text_words_batch` - Pre-split words per text, parallel to `texts`.
    /// * `entity_types` - Candidate labels, parallel to the class dimension.
    /// * `threshold` - Minimum sigmoid probability for a span to be emitted.
    ///
    /// # Errors
    /// `Error::Parse` when no tensor output is present or it is not f32;
    /// `Error::Inference` on an empty/degenerate tensor or unsupported shape.
    fn decode_ner_batch_output(
        &self,
        outputs: &ort::session::SessionOutputs,
        texts: &[&str],
        text_words_batch: &[Vec<&str>],
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Vec<Entity>>> {
        // Take the first (and only expected) output tensor.
        let output = outputs
            .iter()
            .next()
            .map(|(_, v)| v)
            .ok_or_else(|| Error::Parse("No output".into()))?;

        // Raw logits, flattened row-major.
        let (_, data_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Extract tensor: {}", e)))?;
        let output_data: Vec<f32> = data_slice.to_vec();

        // Recover the tensor shape from the value's type metadata.
        let shape: Vec<i64> = match output.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
            _ => return Err(Error::Parse("Not a tensor".into())),
        };

        if output_data.is_empty() || shape.contains(&0) {
            return Err(Error::Inference(
                "GLiNER2 ONNX returned empty/degenerate output tensor. This usually indicates an incompatible ONNX export (shape mismatch or missing dynamic axes).".to_string(),
            ));
        }

        let mut results = Vec::with_capacity(texts.len());

        // Token-level BIO output: [bio=3, batch, num_words, num_classes]
        if shape.len() == 4 && shape[0] == 3 {
            let batch_size = shape[1] as usize;
            let out_num_words = shape[2] as usize;
            let num_classes = shape[3] as usize;

            // Flat-index strides: one BIO plane spans `per_bio` floats, one
            // batch item within a plane spans `per_batch` floats.
            let per_bio = batch_size * out_num_words * num_classes;
            let per_batch = out_num_words * num_classes;

            for batch_idx in 0..batch_size.min(texts.len()) {
                let text = texts[batch_idx];
                let text_words = &text_words_batch[batch_idx];
                let num_words = text_words.len();
                let mut entities = Vec::new();

                // BIO decoding: find B-type starts, extend with I-type
                #[allow(clippy::needless_range_loop)] // class_idx used for index math
                for class_idx in 0..num_classes.min(entity_types.len()) {
                    // In-progress entity for this class: (start word, running score).
                    let mut current_start: Option<(usize, f32)> = None; // (word_idx, score)

                    for word_idx in 0..out_num_words.min(num_words) {
                        // B logit at BIO dimension 0
                        let b_idx = (batch_idx * per_batch) + (word_idx * num_classes) + class_idx;
                        // I logit at BIO dimension 1
                        let i_idx = per_bio
                            + (batch_idx * per_batch)
                            + (word_idx * num_classes)
                            + class_idx;

                        // Out-of-range indices fall back to -100.0, i.e. a
                        // sigmoid score of effectively zero.
                        let b_logit = output_data.get(b_idx).copied().unwrap_or(-100.0);
                        let i_logit = output_data.get(i_idx).copied().unwrap_or(-100.0);

                        // Sigmoid: logit -> probability.
                        let b_score = 1.0 / (1.0 + (-b_logit).exp());
                        let i_score = 1.0 / (1.0 + (-i_logit).exp());

                        if b_score >= threshold {
                            // End any existing entity
                            if let Some((start_word, avg_score)) = current_start.take() {
                                let end_word = word_idx.saturating_sub(1);
                                if start_word <= end_word && end_word < num_words {
                                    let span_text = text_words[start_word..=end_word].join(" ");
                                    let (start, end) = word_span_to_char_offsets(
                                        text, text_words, start_word, end_word,
                                    );
                                    let entity_type = map_entity_type(entity_types[class_idx]);
                                    entities.push(Entity::new(
                                        span_text,
                                        entity_type,
                                        start,
                                        end,
                                        avg_score as f64,
                                    ));
                                }
                            }
                            // Start new entity
                            current_start = Some((word_idx, b_score));
                        } else if i_score >= threshold && current_start.is_some() {
                            // Continue entity - update score. NOTE: this is a
                            // pairwise running average, which weights later
                            // tokens more heavily than a true mean.
                            if let Some((start_word, score)) = current_start {
                                current_start = Some((start_word, (score + i_score) / 2.0));
                            }
                        } else if current_start.is_some() {
                            // End entity (neither B nor I fired at this word).
                            if let Some((start_word, avg_score)) = current_start.take() {
                                let end_word = word_idx.saturating_sub(1);
                                if start_word <= end_word && end_word < num_words {
                                    let span_text = text_words[start_word..=end_word].join(" ");
                                    let (start, end) = word_span_to_char_offsets(
                                        text, text_words, start_word, end_word,
                                    );
                                    let entity_type = map_entity_type(entity_types[class_idx]);
                                    entities.push(Entity::new(
                                        span_text,
                                        entity_type,
                                        start,
                                        end,
                                        avg_score as f64,
                                    ));
                                }
                            }
                        }
                    }

                    // Handle entity still open at end of text
                    if let Some((start_word, avg_score)) = current_start.take() {
                        if !text_words.is_empty() {
                            let end_word = out_num_words.min(num_words).saturating_sub(1);
                            if start_word <= end_word && end_word < num_words {
                                let span_text = text_words[start_word..=end_word].join(" ");
                                let (start, end) = word_span_to_char_offsets(
                                    text, text_words, start_word, end_word,
                                );
                                let entity_type = map_entity_type(entity_types[class_idx]);
                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    start,
                                    end,
                                    avg_score as f64,
                                ));
                            }
                        }
                    }
                }

                // Deterministic order by (start asc, end desc), then drop
                // entities with identical (start, end); which duplicate
                // survives is unspecified under the unstable sort.
                entities
                    .sort_unstable_by(|a, b| a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)));
                entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);
                results.push(entities);
            }
        }
        // Span-level output: [batch, num_words, max_width, num_classes]
        else if shape.len() == 4 {
            let batch_size = shape[0] as usize;
            let out_num_words = shape[1] as usize;
            let out_max_width = shape[2] as usize;
            let num_classes = shape[3] as usize;
            // Floats per batch item in the flattened buffer.
            let stride_per_batch = out_num_words * out_max_width * num_classes;

            for batch_idx in 0..batch_size.min(texts.len()) {
                let text = texts[batch_idx];
                let text_words = &text_words_batch[batch_idx];
                let num_words = text_words.len();
                let batch_offset = batch_idx * stride_per_batch;
                let mut entities = Vec::new();

                // Enumerate every (start word, width) span up to MAX_SPAN_WIDTH.
                for word_idx in 0..out_num_words.min(num_words) {
                    for width in 0..out_max_width.min(MAX_SPAN_WIDTH) {
                        let end_word = word_idx + width;
                        if end_word >= num_words {
                            continue;
                        }

                        #[allow(clippy::needless_range_loop)] // class_idx used for index math
                        for class_idx in 0..num_classes.min(entity_types.len()) {
                            let idx = batch_offset
                                + (word_idx * out_max_width * num_classes)
                                + (width * num_classes)
                                + class_idx;

                            if idx < output_data.len() {
                                let logit = output_data[idx];
                                // Sigmoid: logit -> probability.
                                let score = 1.0 / (1.0 + (-logit).exp());

                                if score >= threshold {
                                    let span_text = text_words[word_idx..=end_word].join(" ");
                                    let (start, end) = word_span_to_char_offsets(
                                        text, text_words, word_idx, end_word,
                                    );

                                    let entity_type = map_entity_type(entity_types[class_idx]);

                                    entities.push(Entity::new(
                                        span_text,
                                        entity_type,
                                        start,
                                        end,
                                        score as f64,
                                    ));
                                }
                            }
                        }
                    }
                }

                // Same ordering/dedup policy as the BIO branch above.
                entities
                    .sort_unstable_by(|a, b| a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)));
                entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);
                results.push(entities);
            }
        } else {
            return Err(Error::Inference(format!(
                "Unsupported GLiNER2 batch output shape: {:?}. Expected [3,batch,words,classes] (BIO) or [batch,words,width,classes] (span-level).",
                shape
            )));
        }

        // Ensure output length matches input texts length (BatchCapable contract).
        while results.len() < texts.len() {
            results.push(Vec::new());
        }

        Ok(results)
    }
1056
    /// Classify `text` against `labels` (single- or multi-label).
    ///
    /// Builds an input sequence `START <<ENT>>label1 <<ENT>>label2 ... SEP
    /// text END` using the same marker tokens as the NER path, runs the ONNX
    /// session, then converts the first output tensor to probabilities:
    /// per-label sigmoid when `multi_label`, softmax otherwise.
    ///
    /// NOTE(review): the probabilities are computed over *all* values in the
    /// first output tensor; this assumes the model emits exactly one logit per
    /// label in `labels` order — confirm against the ONNX export.
    ///
    /// # Arguments
    /// * `text` - Input text; empty input yields a default (empty) result.
    /// * `labels` - Candidate labels; empty yields a default result.
    /// * `multi_label` - If true, every label with probability > 0.5 is
    ///   selected; otherwise only the argmax label is selected.
    ///
    /// # Errors
    /// `Error::Parse` on tokenization/tensor construction failures,
    /// `Error::Inference` when the ONNX session run fails.
    fn classify(
        &self,
        text: &str,
        labels: &[&str],
        multi_label: bool,
    ) -> Result<ClassificationResult> {
        if text.is_empty() || labels.is_empty() {
            return Ok(ClassificationResult::default());
        }

        // For classification, encode <<ENT>>label1 <<ENT>>label2 ... <<SEP>> text
        // Using same format as NER since this model uses shared token markers

        // Encode input
        let mut input_ids: Vec<i64> = Vec::new();

        input_ids.push(TOKEN_START as i64);

        // Labels: <<ENT>>label1 <<ENT>>label2 ...
        for label in labels {
            input_ids.push(TOKEN_ENT as i64);
            let label_enc = self
                .tokenizer
                .encode(*label, false)
                .map_err(|e| Error::Parse(format!("Tokenize: {}", e)))?;
            for id in label_enc.get_ids() {
                input_ids.push(*id as i64);
            }
        }

        input_ids.push(TOKEN_SEP as i64);

        // Text
        let text_enc = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Parse(format!("Tokenize: {}", e)))?;
        for id in text_enc.get_ids() {
            input_ids.push(*id as i64);
        }

        input_ids.push(TOKEN_END as i64);

        // Single-sequence batch: all positions attended.
        let seq_len = input_ids.len();
        let attention_mask: Vec<i64> = vec![1; seq_len];

        use ndarray::Array2;

        let input_arr = Array2::from_shape_vec((1, seq_len), input_ids)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let attn_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;

        let input_t = super::ort_compat::tensor_from_ndarray(input_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let attn_t = super::ort_compat::tensor_from_ndarray(attn_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;

        // For classification models, we typically need just input_ids and attention_mask
        // The model should output classification logits
        let mut session = lock(&self.session);

        // Try running with standard classification inputs
        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_t.into_dyn(),
                "attention_mask" => attn_t.into_dyn(),
            ])
            .map_err(|e| Error::Inference(format!("ONNX run: {}", e)))?;

        // Decode classification output (first output tensor only).
        let output = outputs
            .iter()
            .next()
            .map(|(_, v)| v)
            .ok_or_else(|| Error::Parse("No output".into()))?;

        let (_, data_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Extract: {}", e)))?;
        let logits: Vec<f32> = data_slice.to_vec();

        // Apply softmax or sigmoid
        let probs = if multi_label {
            // Independent per-label probabilities.
            logits.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect()
        } else {
            // Numerically stable softmax: shift by the max logit first.
            let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
            let sum: f32 = exp_logits.iter().sum();
            // Handle division by zero: if sum is 0 (all logits are -inf), return uniform distribution
            if sum > 0.0 {
                exp_logits.iter().map(|&x| x / sum).collect::<Vec<_>>()
            } else if logits.is_empty() {
                // Edge case: empty logits, return empty probabilities
                vec![]
            } else {
                // All logits are -inf, return uniform distribution
                let uniform = 1.0 / logits.len() as f32;
                vec![uniform; logits.len()]
            }
        };

        let mut scores = HashMap::new();
        let mut selected_labels: Vec<String> = Vec::new();

        // Score every label; missing positions default to 0.0.
        for (i, label) in labels.iter().enumerate() {
            let prob = probs.get(i).copied().unwrap_or(0.0);
            scores.insert((*label).to_string(), prob);

            if multi_label && prob > 0.5 {
                selected_labels.push((*label).to_string());
            }
        }

        if !multi_label {
            // Argmax over probs; `labels.get(idx)` guards the case where the
            // model emitted more logits than there are labels.
            if let Some((idx, _)) = probs
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            {
                if let Some(label) = labels.get(idx) {
                    selected_labels.push((*label).to_string());
                }
            }
        }

        Ok(ClassificationResult {
            labels: selected_labels,
            scores,
        })
    }
1189
1190    /// Extract hierarchical structures.
1191    fn extract_structure(
1192        &self,
1193        text: &str,
1194        task: &StructureTask,
1195    ) -> Result<Vec<ExtractedStructure>> {
1196        if text.is_empty() || task.fields.is_empty() {
1197            return Ok(Vec::new());
1198        }
1199
1200        // For structure extraction, first predict count of instances
1201        // Then extract fields for each instance
1202        // For simplicity, we'll use NER-style extraction for each field
1203
1204        let mut structures = Vec::new();
1205
1206        // Extract each field as a span
1207        let field_names: Vec<&str> = task.fields.iter().map(|f| f.name.as_str()).collect();
1208        let field_entities = self.extract_ner(text, &field_names, 0.3)?;
1209
1210        // Group by field type and build structure
1211        let mut structure = ExtractedStructure {
1212            structure_type: task.name.clone(),
1213            fields: HashMap::new(),
1214        };
1215
1216        for field in &task.fields {
1217            let matching: Vec<_> = field_entities
1218                .iter()
1219                .filter(|e| e.entity_type.as_label().eq_ignore_ascii_case(&field.name))
1220                .collect();
1221
1222            if !matching.is_empty() {
1223                let value = match field.field_type {
1224                    FieldType::List => {
1225                        let values: Vec<String> = matching.iter().map(|e| e.text.clone()).collect();
1226                        StructureValue::List(values)
1227                    }
1228                    FieldType::Choice => {
1229                        if let Some(ref choices) = field.choices {
1230                            let extracted = matching.first().map(|e| e.text.as_str()).unwrap_or("");
1231                            let best = choices
1232                                .iter()
1233                                .find(|c| extracted.to_lowercase().contains(&c.to_lowercase()))
1234                                .cloned()
1235                                .unwrap_or_else(|| extracted.to_string());
1236                            StructureValue::Single(best)
1237                        } else {
1238                            StructureValue::Single(
1239                                matching.first().map(|e| e.text.clone()).unwrap_or_default(),
1240                            )
1241                        }
1242                    }
1243                    FieldType::String => StructureValue::Single(
1244                        matching.first().map(|e| e.text.clone()).unwrap_or_default(),
1245                    ),
1246                };
1247                structure.fields.insert(field.name.clone(), value);
1248            }
1249        }
1250
1251        if !structure.fields.is_empty() {
1252            structures.push(structure);
1253        }
1254
1255        Ok(structures)
1256    }
1257
1258    /// Build prompt string for logging.
1259    #[allow(dead_code)]
1260    fn build_prompt(&self, schema: &TaskSchema) -> String {
1261        let mut parts = Vec::new();
1262
1263        if let Some(ref ent_task) = schema.entities {
1264            let types: Vec<String> = ent_task
1265                .types
1266                .iter()
1267                .map(|t| {
1268                    if let Some(desc) = ent_task.descriptions.get(t) {
1269                        format!("[E] {}: {}", t, desc)
1270                    } else {
1271                        format!("[E] {}", t)
1272                    }
1273                })
1274                .collect();
1275            parts.push(format!("[P] entities ({})", types.join(" ")));
1276        }
1277
1278        for class_task in &schema.classifications {
1279            let labels: Vec<String> = class_task
1280                .labels
1281                .iter()
1282                .map(|l| format!("[L] {}", l))
1283                .collect();
1284            parts.push(format!("[P] {} ({})", class_task.name, labels.join(" ")));
1285        }
1286
1287        for struct_task in &schema.structures {
1288            let fields: Vec<String> = struct_task
1289                .fields
1290                .iter()
1291                .map(|f| format!("[C] {}", f.name))
1292                .collect();
1293            parts.push(format!("[P] {} ({})", struct_task.name, fields.join(" ")));
1294        }
1295
1296        parts.join(" [SEP] ")
1297    }
1298}
1299
1300// =============================================================================
1301// Candle Backend
1302// =============================================================================
1303
1304#[cfg(feature = "candle")]
1305use crate::backends::encoder_candle::TextEncoder;
1306#[cfg(feature = "candle")]
1307use candle_core::{DType, Device, IndexOp, Module, Tensor, D};
1308#[cfg(feature = "candle")]
1309use candle_nn::{Linear, VarBuilder};
1310
/// GLiNER2 Candle implementation.
///
/// Native (non-ONNX) backend: a text encoder plus task-specific heads for
/// span-based NER, text classification, and structure-count prediction.
#[cfg(feature = "candle")]
#[derive(Debug)]
pub struct GLiNER2Candle {
    /// Text encoder producing token embeddings.
    encoder: crate::backends::encoder_candle::CandleEncoder,
    /// Span representation layer (start-token embedding + width embedding).
    span_rep: SpanRepLayer,
    /// Linear projection applied to label embeddings.
    label_proj: Linear,
    /// Classification head for [L] tokens.
    class_head: ClassificationHead,
    /// Structure count predictor for [P] tokens.
    count_predictor: CountPredictor,
    /// Compute device (CUDA when available, otherwise CPU).
    device: Device,
    /// Model identifier used at load time (retained for diagnostics).
    #[allow(dead_code)]
    model_name: String,
    /// Encoder hidden dimension, read from config.json (defaults to 768).
    hidden_size: usize,
    /// Label embedding cache
    label_cache: LabelCache,
}
1333
/// Span representation layer (from GLiNER).
///
/// Represents a candidate span as the sum of its start-token embedding and a
/// learned embedding of the span's width.
#[cfg(feature = "candle")]
pub struct SpanRepLayer {
    /// Width embeddings for spans of different sizes
    width_embeddings: candle_nn::Embedding,
    /// Max span width; widths are clamped to `max_width - 1` at lookup time.
    max_width: usize,
}
1342
#[cfg(feature = "candle")]
// Manual Debug: reports only `max_width`; the embedding weights are
// intentionally omitted from the output.
impl std::fmt::Debug for SpanRepLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SpanRepLayer")
            .field("max_width", &self.max_width)
            .finish()
    }
}
1351
/// Classification head for text classification tasks.
///
/// A single linear layer (`hidden_size` -> 1) applied per label embedding.
#[cfg(feature = "candle")]
pub struct ClassificationHead {
    /// MLP that projects [L] token embeddings to logits
    mlp: Linear,
}
1358
#[cfg(feature = "candle")]
// Manual Debug: the inner linear layer is omitted from the output.
impl std::fmt::Debug for ClassificationHead {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ClassificationHead").finish()
    }
}
1365
/// Count predictor for hierarchical structure extraction.
///
/// A single linear layer whose output width is the maximum predictable count;
/// the predicted count is the argmax over its logits.
#[cfg(feature = "candle")]
pub struct CountPredictor {
    /// MLP that predicts instance count (0-19)
    mlp: Linear,
}
1372
#[cfg(feature = "candle")]
// Manual Debug: the inner linear layer is omitted from the output.
impl std::fmt::Debug for CountPredictor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CountPredictor").finish()
    }
}
1379
#[cfg(feature = "candle")]
impl SpanRepLayer {
    /// Build the layer: a learned embedding table with one row per span width
    /// in `0..max_width`.
    fn new(hidden_size: usize, max_width: usize, vb: VarBuilder) -> Result<Self> {
        let width_embeddings =
            candle_nn::embedding(max_width, hidden_size, vb.pp("width_embeddings"))
                .map_err(|e| Error::Retrieval(format!("width_embeddings: {}", e)))?;
        Ok(Self {
            width_embeddings,
            max_width,
        })
    }

    /// Compute span representations as start-token embedding + width embedding.
    ///
    /// `token_embeddings` is `[batch, seq_len, hidden]`; `span_indices` is
    /// assumed to be `[batch, num_spans, 2]` holding `(start, end)` pairs
    /// (end exclusive, since width = end - start) — confirm against callers.
    /// Returns a `[batch, num_spans, hidden]` tensor.
    ///
    /// Invalid spans (`end <= start`) contribute a zero vector so the element
    /// count always matches the declared output shape.
    fn forward(&self, token_embeddings: &Tensor, span_indices: &Tensor) -> Result<Tensor> {
        let device = token_embeddings.device();
        let batch_size = token_embeddings.dims()[0];
        let _seq_len = token_embeddings.dims()[1];
        let hidden_size = token_embeddings.dims()[2];
        let num_spans = span_indices.dims()[1];

        let mut all_span_embs = Vec::new();

        for b in 0..batch_size {
            let batch_tokens = token_embeddings
                .i(b)
                .map_err(|e| Error::Inference(format!("batch index: {}", e)))?;
            let batch_spans = span_indices
                .i(b)
                .map_err(|e| Error::Inference(format!("span index: {}", e)))?;

            let spans_data = batch_spans
                .to_vec2::<i64>()
                .map_err(|e| Error::Inference(format!("spans to vec: {}", e)))?;

            let mut span_embs = Vec::new();

            for span in spans_data {
                let start = span[0] as usize;
                let end = span[1] as usize;
                // Validate span: end must be > start to prevent underflow.
                // BUG FIX: invalid spans were previously skipped entirely,
                // which desynchronized the flat element count from the
                // (batch, num_spans, hidden) shape passed to
                // `Tensor::from_vec` below and made the final reshape fail
                // for the whole batch. Emit a zero embedding instead so
                // every span slot stays filled.
                if end <= start {
                    log::warn!("Invalid span: end ({}) <= start ({})", end, start);
                    span_embs.extend(std::iter::repeat(0.0f32).take(hidden_size));
                    continue;
                }
                let width = end - start;

                // Get start token embedding, clamped to the last token.
                // `saturating_sub` guards against underflow on an empty
                // sequence (the index op then errors cleanly instead of
                // panicking on `0 - 1`).
                let start_emb = batch_tokens
                    .i(start.min(batch_tokens.dims()[0].saturating_sub(1)))
                    .map_err(|e| Error::Inference(format!("start emb: {}", e)))?;

                // Get width embedding, clamped to the largest learned width.
                let width_idx = width.min(self.max_width - 1);
                let width_emb = self
                    .width_embeddings
                    .forward(
                        &Tensor::new(&[width_idx as u32], device)
                            .map_err(|e| Error::Inference(format!("width idx: {}", e)))?,
                    )
                    .map_err(|e| Error::Inference(format!("width emb: {}", e)))?
                    .squeeze(0)
                    .map_err(|e| Error::Inference(format!("squeeze: {}", e)))?;

                // Combine: start + width (could also use end and pool)
                let combined = start_emb
                    .add(&width_emb)
                    .map_err(|e| Error::Inference(format!("add: {}", e)))?;

                let emb_vec = combined
                    .to_vec1::<f32>()
                    .map_err(|e| Error::Inference(format!("to vec: {}", e)))?;
                span_embs.extend(emb_vec);
            }

            all_span_embs.extend(span_embs);
        }

        Tensor::from_vec(all_span_embs, (batch_size, num_spans, hidden_size), device)
            .map_err(|e| Error::Inference(format!("span tensor: {}", e)))
    }
}
1460
#[cfg(feature = "candle")]
impl ClassificationHead {
    /// Build the head: one linear layer mapping `hidden_size` features to a
    /// single logit.
    fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
        candle_nn::linear(hidden_size, 1, vb.pp("mlp"))
            .map(|mlp| Self { mlp })
            .map_err(|e| Error::Retrieval(format!("classification mlp: {}", e)))
    }

    /// Project label-token embeddings to per-label logits.
    fn forward(&self, embeddings: &Tensor) -> Result<Tensor> {
        self.mlp
            .forward(embeddings)
            .map_err(|e| Error::Inference(format!("class head forward: {}", e)))
    }
}
1476
#[cfg(feature = "candle")]
impl CountPredictor {
    /// Build the predictor: one linear layer mapping `hidden_size` features to
    /// `max_count` count logits.
    fn new(hidden_size: usize, max_count: usize, vb: VarBuilder) -> Result<Self> {
        candle_nn::linear(hidden_size, max_count, vb.pp("mlp"))
            .map(|mlp| Self { mlp })
            .map_err(|e| Error::Retrieval(format!("count mlp: {}", e)))
    }

    /// Predict number of structure instances from a [P] token embedding.
    ///
    /// Takes the argmax over the count logits and clamps the result to a
    /// minimum of one instance.
    fn forward(&self, prompt_embedding: &Tensor) -> Result<usize> {
        let logits = self
            .mlp
            .forward(prompt_embedding)
            .map_err(|e| Error::Inference(format!("count forward: {}", e)))?;

        let flat = logits
            .flatten_all()
            .map_err(|e| Error::Inference(format!("flatten: {}", e)))?
            .to_vec1::<f32>()
            .map_err(|e| Error::Inference(format!("to vec: {}", e)))?;

        // Argmax over the logits; an empty vector defaults to index 1.
        let best = flat
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(1);

        // Always report at least one instance.
        Ok(best.max(1))
    }
}
1508
1509#[cfg(feature = "candle")]
1510impl GLiNER2Candle {
    /// Load model from HuggingFace Hub.
    ///
    /// Downloads `config.json` and the model weights (preferring safetensors,
    /// converting `pytorch_model.bin` as a fallback), then assembles the
    /// encoder, span-representation layer, label projection, classification
    /// head, and count predictor. Runs on CUDA device 0 when available,
    /// otherwise CPU.
    ///
    /// # Errors
    /// Returns `Error::Retrieval` when an artifact cannot be fetched/loaded
    /// and `Error::Parse` when `config.json` is malformed JSON.
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        use crate::backends::encoder_candle::CandleEncoder;
        use hf_hub::api::sync::Api;

        let api = Api::new().map_err(|e| Error::Retrieval(format!("HF API: {}", e)))?;
        let repo = api.model(model_id.to_string());

        // Load config
        let config_path = repo
            .get("config.json")
            .map_err(|e| Error::Retrieval(format!("config.json: {}", e)))?;
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| Error::Retrieval(format!("read config: {}", e)))?;
        let config: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| Error::Parse(format!("parse config: {}", e)))?;
        // Fall back to 768 when `hidden_size` is absent from the config.
        let hidden_size = config["hidden_size"].as_u64().unwrap_or(768) as usize;

        // Determine device
        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);

        // Load weights - try safetensors first, then convert pytorch if needed
        let weights_path = repo
            .get("model.safetensors")
            .or_else(|_| repo.get("gliner_model.safetensors"))
            .or_else(|_| {
                // Try to convert pytorch_model.bin to safetensors
                let pytorch_path = repo.get("pytorch_model.bin")?;
                crate::backends::gliner_candle::convert_pytorch_to_safetensors(&pytorch_path)
            })
            .map_err(|e| {
                Error::Retrieval(format!("weights not found and conversion failed: {}", e))
            })?;

        // SAFETY: VarBuilder::from_mmaped_safetensors uses unsafe internally for memory mapping.
        // The weights_path is validated to exist before this call, and the safetensors format
        // is validated by the library. This is a safe FFI boundary.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
                .map_err(|e| Error::Retrieval(format!("varbuilder: {}", e)))?
        };

        // Build components
        let encoder = CandleEncoder::from_pretrained(model_id)?;
        let span_rep = SpanRepLayer::new(hidden_size, MAX_SPAN_WIDTH, vb.pp("span_rep"))?;
        let label_proj = candle_nn::linear(hidden_size, hidden_size, vb.pp("label_projection"))
            .map_err(|e| Error::Retrieval(format!("label_projection: {}", e)))?;
        let class_head = ClassificationHead::new(hidden_size, vb.pp("classification"))?;
        let count_predictor =
            CountPredictor::new(hidden_size, MAX_COUNT, vb.pp("count_predictor"))?;

        log::info!(
            "[GLiNER2-Candle] Loaded {} (hidden={}) on {:?}",
            model_id,
            hidden_size,
            device
        );

        Ok(Self {
            encoder,
            span_rep,
            label_proj,
            class_head,
            count_predictor,
            device,
            model_name: model_id.to_string(),
            hidden_size,
            label_cache: LabelCache::new(),
        })
    }
1581
1582    /// Extract entities, classifications, and structures according to schema.
1583    pub fn extract(&self, text: &str, schema: &TaskSchema) -> Result<ExtractionResult> {
1584        let mut result = ExtractionResult::default();
1585
1586        // NER extraction
1587        if let Some(ref ent_task) = schema.entities {
1588            let entities = self.extract_entities(text, &ent_task.types, 0.5)?;
1589            result.entities = entities;
1590        }
1591
1592        // Classification
1593        for class_task in &schema.classifications {
1594            let class_result = self.classify(text, &class_task.labels, class_task.multi_label)?;
1595            result
1596                .classifications
1597                .insert(class_task.name.clone(), class_result);
1598        }
1599
1600        // Structure extraction with count prediction
1601        for struct_task in &schema.structures {
1602            let structures = self.extract_structure_with_count(text, struct_task)?;
1603            result.structures.extend(structures);
1604        }
1605
1606        Ok(result)
1607    }
1608
1609    /// Extract named entities with zero-shot labels.
1610    fn extract_entities(
1611        &self,
1612        text: &str,
1613        types: &[String],
1614        threshold: f32,
1615    ) -> Result<Vec<Entity>> {
1616        if text.is_empty() || types.is_empty() {
1617            return Ok(Vec::new());
1618        }
1619
1620        let labels: Vec<&str> = types.iter().map(|s| s.as_str()).collect();
1621
1622        // Tokenize and get words
1623        let words: Vec<&str> = text.split_whitespace().collect();
1624        if words.is_empty() {
1625            return Ok(Vec::new());
1626        }
1627
1628        // Encode text
1629        let (text_embeddings, word_positions) = self.encode_text(&words)?;
1630
1631        // Encode labels (with caching)
1632        let label_embeddings = self.encode_labels_cached(&labels)?;
1633
1634        // Generate span candidates
1635        let span_indices = self.generate_spans(words.len())?;
1636
1637        // Compute span embeddings
1638        let span_embs = self.span_rep.forward(&text_embeddings, &span_indices)?;
1639
1640        // Project labels
1641        let label_embs = self
1642            .label_proj
1643            .forward(&label_embeddings)
1644            .map_err(|e| Error::Inference(format!("label projection: {}", e)))?;
1645
1646        // Match spans to labels via cosine similarity
1647        let scores = self.match_spans_labels(&span_embs, &label_embs)?;
1648
1649        // Decode to entities
1650        self.decode_entities(text, &words, &word_positions, &scores, &labels, threshold)
1651    }
1652
    /// Classify `text` against `labels` using the classification head.
    ///
    /// The score per label blends cosine similarity between the first pooled
    /// text vector and the label embedding (weight 0.7) with the head's
    /// logits (weight 0.3). Multi-label mode applies a sigmoid and returns
    /// every label with probability > 0.5; single-label mode applies a
    /// softmax and returns the argmax label. `scores` always contains a
    /// probability for every input label.
    fn classify(
        &self,
        text: &str,
        labels: &[String],
        multi_label: bool,
    ) -> Result<ClassificationResult> {
        // Degenerate input: no text or no candidate labels.
        if text.is_empty() || labels.is_empty() {
            return Ok(ClassificationResult::default());
        }

        // Encode text and get [CLS] embedding
        // (the first `hidden_size` floats of the flattened encoder output).
        let (text_emb, _seq_len) = self.encoder.encode(text)?;
        let cls_emb = Tensor::from_vec(
            text_emb[..self.hidden_size].to_vec(),
            (1, self.hidden_size),
            &self.device,
        )
        .map_err(|e| Error::Inference(format!("cls tensor: {}", e)))?;

        // Encode labels (pooled embeddings, cached across calls).
        let labels_str: Vec<&str> = labels.iter().map(|s| s.as_str()).collect();
        let label_embs = self.encode_labels_cached(&labels_str)?;

        // Use classification head to get logits
        let label_logits = self.class_head.forward(&label_embs)?;
        let label_logits_vec = label_logits
            .flatten_all()
            .map_err(|e| Error::Inference(format!("flatten: {}", e)))?
            .to_vec1::<f32>()
            .map_err(|e| Error::Inference(format!("to vec: {}", e)))?;

        // Also compute similarity for ranking
        let cls_norm = l2_normalize(&cls_emb, D::Minus1)?;
        let label_norm = l2_normalize(&label_embs, D::Minus1)?;

        let sim_scores = cls_norm
            .matmul(
                &label_norm
                    .t()
                    .map_err(|e| Error::Inference(format!("transpose: {}", e)))?,
            )
            .map_err(|e| Error::Inference(format!("matmul: {}", e)))?;

        let sim_vec = sim_scores
            .flatten_all()
            .map_err(|e| Error::Inference(format!("flatten: {}", e)))?
            .to_vec1::<f32>()
            .map_err(|e| Error::Inference(format!("to vec: {}", e)))?;

        // Combine head logits with similarity (weighted).
        // NOTE(review): `.cycle()` repeats the head logits when they are
        // shorter than `sim_vec`; presumably this guards a head-output-shape
        // vs. label-count mismatch — confirm the head emits one logit per
        // label. With an empty logits vector `zip` yields nothing and
        // `combined` is empty.
        let combined: Vec<f32> = sim_vec
            .iter()
            .zip(label_logits_vec.iter().cycle())
            .map(|(s, l)| 0.7 * s + 0.3 * l)
            .collect();

        // Apply softmax (single-label) or sigmoid (multi-label)
        let probs = if multi_label {
            combined.iter().map(|&s| 1.0 / (1.0 + (-s).exp())).collect()
        } else {
            // Numerically stable softmax: subtract the max before exp.
            let max_score = combined.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_scores: Vec<f32> = combined.iter().map(|&s| (s - max_score).exp()).collect();
            let sum: f32 = exp_scores.iter().sum();
            // Handle division by zero: if sum is 0 (all logits are -inf), return uniform distribution
            if sum > 0.0 {
                exp_scores.iter().map(|&e| e / sum).collect::<Vec<_>>()
            } else if combined.is_empty() {
                // Edge case: empty scores, return empty probabilities
                vec![]
            } else {
                // All scores are -inf, return uniform distribution
                let uniform = 1.0 / combined.len() as f32;
                vec![uniform; combined.len()]
            }
        };

        // Build the per-label score map; labels without a computed probability
        // (possible when `combined` came up short) default to 0.0.
        let mut scores_map = HashMap::new();
        let mut result_labels = Vec::new();

        for (i, label) in labels.iter().enumerate() {
            let prob = probs.get(i).copied().unwrap_or(0.0);
            scores_map.insert(label.clone(), prob);

            // Multi-label: accept everything above the fixed 0.5 cutoff.
            if multi_label && prob > 0.5 {
                result_labels.push(label.clone());
            }
        }

        // Single-label: pick the argmax probability (NaNs compare as equal).
        if !multi_label {
            if let Some((idx, _)) = probs
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            {
                if let Some(label) = labels.get(idx) {
                    result_labels.push(label.clone());
                }
            }
        }

        Ok(ClassificationResult {
            labels: result_labels,
            scores: scores_map,
        })
    }
1759
1760    /// Extract hierarchical structures using count predictor.
1761    fn extract_structure_with_count(
1762        &self,
1763        text: &str,
1764        task: &StructureTask,
1765    ) -> Result<Vec<ExtractedStructure>> {
1766        if text.is_empty() || task.fields.is_empty() {
1767            return Ok(Vec::new());
1768        }
1769
1770        // Encode text to get [P] token embedding for count prediction
1771        let (text_emb, _) = self.encoder.encode(text)?;
1772        let prompt_emb = Tensor::from_vec(
1773            text_emb[..self.hidden_size].to_vec(),
1774            (self.hidden_size,),
1775            &self.device,
1776        )
1777        .map_err(|e| Error::Inference(format!("prompt tensor: {}", e)))?;
1778
1779        // Predict number of instances
1780        let num_instances = self.count_predictor.forward(&prompt_emb)?;
1781
1782        log::debug!(
1783            "[GLiNER2] Count predictor: {} instances for {}",
1784            num_instances,
1785            task.name
1786        );
1787
1788        let mut structures = Vec::new();
1789
1790        // Extract fields for each predicted instance
1791        for instance_idx in 0..num_instances {
1792            let mut structure = ExtractedStructure {
1793                structure_type: task.name.clone(),
1794                fields: HashMap::new(),
1795            };
1796
1797            for field in &task.fields {
1798                let field_label = field.description.as_ref().unwrap_or(&field.name);
1799
1800                // Extract values for this field
1801                let labels_vec: Vec<String> = vec![field_label.to_string()];
1802                let entities = self.extract_entities(text, &labels_vec, 0.3)?;
1803
1804                // For multi-instance, try to get the nth entity
1805                let entity_for_instance = entities.get(instance_idx);
1806
1807                if let Some(entity) = entity_for_instance {
1808                    let value = match field.field_type {
1809                        FieldType::List => {
1810                            // For list type, get all matching entities
1811                            let values: Vec<String> =
1812                                entities.iter().map(|e| e.text.clone()).collect();
1813                            StructureValue::List(values)
1814                        }
1815                        FieldType::Choice => {
1816                            if let Some(ref choices) = field.choices {
1817                                let extracted = &entity.text;
1818                                let best_choice = choices
1819                                    .iter()
1820                                    .find(|c| extracted.to_lowercase().contains(&c.to_lowercase()))
1821                                    .cloned()
1822                                    .unwrap_or_else(|| extracted.clone());
1823                                StructureValue::Single(best_choice)
1824                            } else {
1825                                StructureValue::Single(entity.text.clone())
1826                            }
1827                        }
1828                        FieldType::String => StructureValue::Single(entity.text.clone()),
1829                    };
1830
1831                    structure.fields.insert(field.name.clone(), value);
1832                }
1833            }
1834
1835            if !structure.fields.is_empty() {
1836                structures.push(structure);
1837            }
1838        }
1839
1840        Ok(structures)
1841    }
1842
1843    // =========================================================================
1844    // Helper methods
1845    // =========================================================================
1846
1847    fn encode_text(&self, words: &[&str]) -> Result<(Tensor, Vec<(usize, usize)>)> {
1848        let text = words.join(" ");
1849        let (embeddings, seq_len) = self.encoder.encode(&text)?;
1850
1851        // Reshape to [1, seq_len, hidden]
1852        let tensor = Tensor::from_vec(embeddings, (1, seq_len, self.hidden_size), &self.device)
1853            .map_err(|e| Error::Inference(format!("text tensor: {}", e)))?;
1854
1855        // Build word positions using character offsets
1856        let full_text = words.join(" ");
1857        let word_positions: Vec<(usize, usize)> = {
1858            let mut positions = Vec::new();
1859            let mut pos = 0;
1860            for (idx, word) in words.iter().enumerate() {
1861                if let Some(start) = full_text[pos..].find(word) {
1862                    let abs_start = pos + start;
1863                    let abs_end = abs_start + word.len();
1864                    // Validate position is after previous word (words should be in order)
1865                    if !positions.is_empty() {
1866                        let (_prev_start, prev_end) = positions[positions.len() - 1];
1867                        if abs_start < prev_end {
1868                            log::warn!(
1869                                "Word '{}' (index {}) at position {} overlaps with previous word ending at {}",
1870                                word,
1871                                idx,
1872                                abs_start,
1873                                prev_end
1874                            );
1875                        }
1876                    }
1877                    positions.push((abs_start, abs_end));
1878                    pos = abs_end;
1879                } else {
1880                    // Word not found - return error to prevent silent entity skipping
1881                    return Err(Error::Inference(format!(
1882                        "Word '{}' (index {}) not found in text starting at position {}",
1883                        word, idx, pos
1884                    )));
1885                }
1886            }
1887            positions
1888        };
1889
1890        // Validate that we found positions for all words
1891        if word_positions.len() != words.len() {
1892            return Err(Error::Inference(format!(
1893                "Word position mismatch: found {} positions for {} words",
1894                word_positions.len(),
1895                words.len()
1896            )));
1897        }
1898
1899        Ok((tensor, word_positions))
1900    }
1901
1902    fn encode_labels_cached(&self, labels: &[&str]) -> Result<Tensor> {
1903        let mut all_embeddings = Vec::new();
1904
1905        for label in labels {
1906            // Check cache first
1907            if let Some(cached) = self.label_cache.get(label) {
1908                all_embeddings.extend(cached);
1909            } else {
1910                let (embeddings, seq_len) = self.encoder.encode(label)?;
1911                // Average pool - handle empty sequences
1912                let avg: Vec<f32> = if seq_len == 0 {
1913                    // Return zero vector for empty sequences
1914                    vec![0.0f32; self.hidden_size]
1915                } else {
1916                    (0..self.hidden_size)
1917                        .map(|i| {
1918                            embeddings
1919                                .iter()
1920                                .skip(i)
1921                                .step_by(self.hidden_size)
1922                                .take(seq_len)
1923                                .sum::<f32>()
1924                                / seq_len as f32
1925                        })
1926                        .collect()
1927                };
1928
1929                // Cache it
1930                self.label_cache.insert(label.to_string(), avg.clone());
1931                all_embeddings.extend(avg);
1932            }
1933        }
1934
1935        Tensor::from_vec(
1936            all_embeddings,
1937            (labels.len(), self.hidden_size),
1938            &self.device,
1939        )
1940        .map_err(|e| Error::Inference(format!("label tensor: {}", e)))
1941    }
1942
1943    fn generate_spans(&self, num_words: usize) -> Result<Tensor> {
1944        // Performance: Pre-allocate spans vec with estimated capacity
1945        // num_words * MAX_SPAN_WIDTH * 2 (for start/end pairs)
1946        let estimated_capacity = num_words.saturating_mul(MAX_SPAN_WIDTH).saturating_mul(2);
1947        let mut spans = Vec::with_capacity(estimated_capacity.min(1000));
1948
1949        for start in 0..num_words {
1950            for width in 0..MAX_SPAN_WIDTH.min(num_words - start) {
1951                let end = start + width;
1952                spans.push(start as i64);
1953                spans.push(end as i64);
1954            }
1955        }
1956
1957        let num_spans = spans.len() / 2;
1958        Tensor::from_vec(spans, (1, num_spans, 2), &self.device)
1959            .map_err(|e| Error::Inference(format!("span tensor: {}", e)))
1960    }
1961
1962    fn match_spans_labels(&self, span_embs: &Tensor, label_embs: &Tensor) -> Result<Tensor> {
1963        let span_norm = l2_normalize(span_embs, D::Minus1)?;
1964        let label_norm = l2_normalize(label_embs, D::Minus1)?;
1965
1966        let batch_size = span_norm.dims()[0];
1967        let label_t = label_norm
1968            .t()
1969            .map_err(|e| Error::Inference(format!("transpose: {}", e)))?;
1970        let label_t = label_t
1971            .unsqueeze(0)
1972            .map_err(|e| Error::Inference(format!("unsqueeze: {}", e)))?
1973            .broadcast_as((batch_size, label_t.dims()[0], label_t.dims()[1]))
1974            .map_err(|e| Error::Inference(format!("broadcast: {}", e)))?;
1975
1976        let scores = span_norm
1977            .matmul(&label_t)
1978            .map_err(|e| Error::Inference(format!("matmul: {}", e)))?;
1979
1980        candle_nn::ops::sigmoid(&scores).map_err(|e| Error::Inference(format!("sigmoid: {}", e)))
1981    }
1982
1983    fn decode_entities(
1984        &self,
1985        text: &str,
1986        words: &[&str],
1987        _word_positions: &[(usize, usize)],
1988        scores: &Tensor,
1989        labels: &[&str],
1990        threshold: f32,
1991    ) -> Result<Vec<Entity>> {
1992        let scores_vec = scores
1993            .flatten_all()
1994            .map_err(|e| Error::Inference(format!("flatten scores: {}", e)))?
1995            .to_vec1::<f32>()
1996            .map_err(|e| Error::Inference(format!("scores to vec: {}", e)))?;
1997
1998        let num_labels = labels.len();
1999        let num_spans = scores_vec.len() / num_labels;
2000
2001        // Performance: Pre-allocate entities vec with estimated capacity
2002        let mut entities = Vec::with_capacity(num_spans.min(32));
2003        let mut span_idx = 0;
2004
2005        for start in 0..words.len() {
2006            for width in 0..MAX_SPAN_WIDTH.min(words.len() - start) {
2007                if span_idx >= num_spans {
2008                    break;
2009                }
2010
2011                let end = start + width;
2012
2013                for (label_idx, label) in labels.iter().enumerate() {
2014                    let score = scores_vec[span_idx * num_labels + label_idx];
2015
2016                    if score >= threshold {
2017                        let span_text = words[start..=end].join(" ");
2018                        let (char_start, char_end) =
2019                            word_span_to_char_offsets(text, words, start, end);
2020
2021                        let entity_type = map_entity_type(label);
2022
2023                        entities.push(Entity::new(
2024                            span_text,
2025                            entity_type,
2026                            char_start,
2027                            char_end,
2028                            score as f64,
2029                        ));
2030                    }
2031                }
2032
2033                span_idx += 1;
2034            }
2035        }
2036
2037        // Deduplicate
2038        entities.sort_by(|a, b| a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)));
2039        entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);
2040
2041        Ok(entities)
2042    }
2043}
2044
/// L2-normalize `tensor` along `dim`, safely handling zero vectors.
#[cfg(feature = "candle")]
fn l2_normalize(tensor: &Tensor, dim: D) -> Result<Tensor> {
    // ||x|| along `dim`, re-expanded to a trailing singleton dimension so it
    // broadcasts against the input.
    let magnitude = tensor
        .sqr()
        .map_err(|e| Error::Inference(format!("sqr: {}", e)))?
        .sum(dim)
        .map_err(|e| Error::Inference(format!("sum: {}", e)))?
        .sqrt()
        .map_err(|e| Error::Inference(format!("sqrt: {}", e)))?
        .unsqueeze(D::Minus1)
        .map_err(|e| Error::Inference(format!("unsqueeze: {}", e)))?;

    // Clamp away zeros so the division below can't produce NaN/inf.
    let safe_magnitude = magnitude
        .clamp(1e-12, f32::MAX)
        .map_err(|e| Error::Inference(format!("clamp: {}", e)))?;

    tensor
        .broadcast_div(&safe_magnitude)
        .map_err(|e| Error::Inference(format!("div: {}", e)))
}
2066
2067// =============================================================================
2068// Stub implementations (no feature)
2069// =============================================================================
2070
/// GLiNER2 stub (requires onnx or candle feature).
///
/// Exists so downstream code can still name the type when no inference
/// backend is compiled in; every method returns `Error::FeatureNotAvailable`.
#[cfg(not(any(feature = "onnx", feature = "candle")))]
#[derive(Debug)]
pub struct GLiNER2 {
    // Private unit field prevents construction outside this module.
    _private: (),
}
2077
#[cfg(not(any(feature = "onnx", feature = "candle")))]
impl GLiNER2 {
    /// Load model (requires feature).
    ///
    /// # Errors
    /// Always returns `Error::FeatureNotAvailable` in this stub build.
    pub fn from_pretrained(_model_id: &str) -> Result<Self> {
        Err(Error::FeatureNotAvailable(
            "GLiNER2 requires 'onnx' or 'candle' feature. \
             Build with: cargo build --features candle"
                .to_string(),
        ))
    }

    /// Extract (requires feature).
    ///
    /// # Errors
    /// Always returns `Error::FeatureNotAvailable` in this stub build.
    pub fn extract(&self, _text: &str, _schema: &TaskSchema) -> Result<ExtractionResult> {
        Err(Error::FeatureNotAvailable(
            "GLiNER2 requires features".to_string(),
        ))
    }
}
2096
2097// =============================================================================
2098// Unified GLiNER2 type
2099// =============================================================================
2100
/// GLiNER2 model - automatically selects best available backend.
///
/// Candle takes precedence: the ONNX alias below is only compiled when the
/// `candle` feature is disabled, so enabling both features yields the Candle
/// backend.
#[cfg(feature = "candle")]
pub type GLiNER2 = GLiNER2Candle;

/// GLiNER2 model - ONNX backend (when candle not enabled).
#[cfg(all(feature = "onnx", not(feature = "candle")))]
pub type GLiNER2 = GLiNER2Onnx;
2108
2109// =============================================================================
2110// Helper functions
2111// =============================================================================
2112
2113/// Convert word span indices to character offsets.
2114fn word_span_to_char_offsets(
2115    text: &str,
2116    words: &[&str],
2117    start_word: usize,
2118    end_word: usize,
2119) -> (usize, usize) {
2120    // Defensive: Validate bounds
2121    if words.is_empty()
2122        || start_word >= words.len()
2123        || end_word >= words.len()
2124        || start_word > end_word
2125    {
2126        // Return safe defaults: empty span at start of text
2127        return (0, 0);
2128    }
2129
2130    // Track our search position in **bytes**.
2131    let mut byte_pos = 0;
2132    let mut start_byte = 0;
2133    let mut end_byte = text.len();
2134    let mut found_start = false;
2135    let mut found_end = false;
2136
2137    for (i, word) in words.iter().enumerate() {
2138        if let Some(pos) = text.get(byte_pos..).and_then(|s| s.find(word)) {
2139            let abs_pos = byte_pos + pos;
2140
2141            if i == start_word {
2142                start_byte = abs_pos;
2143                found_start = true;
2144            }
2145            if i == end_word {
2146                end_byte = abs_pos + word.len();
2147                found_end = true;
2148                // Early exit: we found both start and end
2149                break;
2150            }
2151
2152            byte_pos = abs_pos + word.len();
2153        } else {
2154            // Word not found - this shouldn't happen in normal operation,
2155            // but if it does, we can't reliably compute offsets
2156            // Continue searching but mark that we may have incorrect results
2157        }
2158    }
2159
2160    // If we didn't find the words, return safe defaults
2161    if !found_start || !found_end {
2162        // Return empty span to avoid incorrect entity extraction
2163        (0, 0)
2164    } else {
2165        // Convert byte offsets to character offsets (anno spans are char-based).
2166        crate::offset::bytes_to_chars(text, start_byte, end_byte)
2167    }
2168}
2169
/// Map entity type string to EntityType.
///
/// Uses the canonical schema mapper for consistent semantics across all
/// backends; no context string is supplied (second argument is `None`).
fn map_entity_type(type_str: &str) -> EntityType {
    crate::schema::map_to_canonical(type_str, None)
}
2176
2177// =============================================================================
2178// Model Trait Implementation (ONNX)
2179// =============================================================================
2180
2181#[cfg(feature = "onnx")]
2182impl crate::Model for GLiNER2Onnx {
2183    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
2184        let schema = TaskSchema::new().with_entities(&[
2185            "person",
2186            "organization",
2187            "location",
2188            "date",
2189            "event",
2190        ]);
2191
2192        let result = self.extract(text, &schema)?;
2193        Ok(result.entities)
2194    }
2195
2196    fn supported_types(&self) -> Vec<EntityType> {
2197        vec![
2198            EntityType::Person,
2199            EntityType::Organization,
2200            EntityType::Location,
2201            EntityType::Date,
2202            EntityType::Custom {
2203                name: "event".to_string(),
2204                category: EntityCategory::Creative,
2205            },
2206            EntityType::Custom {
2207                name: "product".to_string(),
2208                category: EntityCategory::Creative,
2209            },
2210            EntityType::Other("misc".to_string()),
2211        ]
2212    }
2213
2214    fn is_available(&self) -> bool {
2215        true
2216    }
2217
2218    fn name(&self) -> &'static str {
2219        "GLiNER2-ONNX"
2220    }
2221
2222    fn description(&self) -> &'static str {
2223        "Multi-task information extraction via GLiNER2 (ONNX backend)"
2224    }
2225}
2226
2227// =============================================================================
2228// Model Trait Implementation (Candle)
2229// =============================================================================
2230
2231#[cfg(feature = "candle")]
2232impl crate::Model for GLiNER2Candle {
2233    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
2234        let schema = TaskSchema::new().with_entities(&[
2235            "person",
2236            "organization",
2237            "location",
2238            "date",
2239            "event",
2240        ]);
2241
2242        let result = self.extract(text, &schema)?;
2243        Ok(result.entities)
2244    }
2245
2246    fn supported_types(&self) -> Vec<EntityType> {
2247        vec![
2248            EntityType::Person,
2249            EntityType::Organization,
2250            EntityType::Location,
2251            EntityType::Date,
2252            EntityType::Custom {
2253                name: "event".to_string(),
2254                category: EntityCategory::Creative,
2255            },
2256            EntityType::Custom {
2257                name: "product".to_string(),
2258                category: EntityCategory::Creative,
2259            },
2260            EntityType::Other("misc".to_string()),
2261        ]
2262    }
2263
2264    fn is_available(&self) -> bool {
2265        true
2266    }
2267
2268    fn name(&self) -> &'static str {
2269        "GLiNER2-Candle"
2270    }
2271
2272    fn description(&self) -> &'static str {
2273        "Multi-task information extraction via GLiNER2 (native Rust/Candle)"
2274    }
2275}
2276
2277// =============================================================================
2278// ZeroShotNER Trait Implementation
2279// =============================================================================
2280
2281#[cfg(feature = "onnx")]
2282impl ZeroShotNER for GLiNER2Onnx {
2283    fn default_types(&self) -> &[&'static str] {
2284        &["person", "organization", "location", "date", "event"]
2285    }
2286
2287    fn extract_with_types(
2288        &self,
2289        text: &str,
2290        types: &[&str],
2291        threshold: f32,
2292    ) -> Result<Vec<Entity>> {
2293        self.extract_ner(text, types, threshold)
2294    }
2295
2296    fn extract_with_descriptions(
2297        &self,
2298        text: &str,
2299        descriptions: &[&str],
2300        threshold: f32,
2301    ) -> Result<Vec<Entity>> {
2302        // Use descriptions as entity types directly (GLiNER2 supports this)
2303        self.extract_ner(text, descriptions, threshold)
2304    }
2305}
2306
2307#[cfg(feature = "candle")]
2308impl ZeroShotNER for GLiNER2Candle {
2309    fn default_types(&self) -> &[&'static str] {
2310        &["person", "organization", "location", "date", "event"]
2311    }
2312
2313    fn extract_with_types(
2314        &self,
2315        text: &str,
2316        types: &[&str],
2317        threshold: f32,
2318    ) -> Result<Vec<Entity>> {
2319        let type_strings: Vec<String> = types.iter().map(|s| s.to_string()).collect();
2320        self.extract_entities(text, &type_strings, threshold)
2321    }
2322
2323    fn extract_with_descriptions(
2324        &self,
2325        text: &str,
2326        descriptions: &[&str],
2327        threshold: f32,
2328    ) -> Result<Vec<Entity>> {
2329        // Use descriptions as entity types directly (GLiNER2 supports this)
2330        let type_strings: Vec<String> = descriptions.iter().map(|s| s.to_string()).collect();
2331        self.extract_entities(text, &type_strings, threshold)
2332    }
2333}
2334
2335// =============================================================================
2336// RelationExtractor Trait Implementation
2337// =============================================================================
2338
/// Heuristic relation-type priors for a (head_type, tail_type) entity pair.
///
/// Returns `(relation_label, base_score)` candidates; an empty vec means the
/// pair carries no strong relation signal. Matching is case-insensitive
/// (both labels are uppercased before comparison). Arm order matters: the
/// wildcard `(_, "DATE") | (_, "TIME")` arm must stay after the specific pairs.
#[cfg(any(feature = "onnx", feature = "candle"))]
fn get_likely_relations(head_type: &str, tail_type: &str) -> Vec<(&'static str, f32)> {
    let h = head_type.to_uppercase();
    let t = tail_type.to_uppercase();

    match (h.as_str(), t.as_str()) {
        // CHisIEC-style entity type codes
        ("PER", "OFI") | ("PERSON", "OFI") => vec![("任职", 0.7), ("任職", 0.7)],
        ("OFI", "PER") => vec![("上下级", 0.6), ("上下級", 0.6)],
        ("PER", "LOC") => vec![
            ("到达", 0.55),
            ("到達", 0.55),
            ("出生于某地", 0.4),
            ("出生於某地", 0.4),
        ],
        ("LOC", "PER") => vec![("到达", 0.5), ("到達", 0.5)],
        ("PER", "PER") => vec![
            ("上下级", 0.45),
            ("上下級", 0.45),
            ("同僚", 0.4),
            ("父母", 0.3),
            ("兄弟", 0.3),
        ],
        ("OFI", "LOC") | ("LOC", "OFI") => vec![("管理", 0.5)],
        ("BOOK", "BOOK") | ("BOOK", "PER") | ("PER", "BOOK") => {
            vec![("别名", 0.35), ("別名", 0.35)]
        }
        // Person-Organization relations
        ("PERSON", "ORGANIZATION") | ("PER", "ORG") => vec![
            ("WORKS_FOR", 0.7),
            ("FOUNDED", 0.5),
            ("CEO_OF", 0.4),
            ("MEMBER_OF", 0.6),
        ],
        ("ORGANIZATION", "PERSON") | ("ORG", "PER") => {
            vec![("EMPLOYS", 0.7), ("FOUNDED_BY", 0.5), ("LED_BY", 0.4)]
        }
        // Person-Location relations
        ("PERSON", "LOCATION") | ("PERSON", "GPE") | ("PER", "GPE") => {
            vec![("LIVES_IN", 0.6), ("BORN_IN", 0.5), ("VISITED", 0.4)]
        }
        // Organization-Location relations
        ("ORGANIZATION", "LOCATION")
        | ("ORG", "LOC")
        | ("ORGANIZATION", "GPE")
        | ("ORG", "GPE") => vec![
            ("HEADQUARTERED_IN", 0.7),
            ("LOCATED_IN", 0.8),
            ("OPERATES_IN", 0.5),
        ],
        // Product-Organization relations
        ("PRODUCT", "ORGANIZATION") | ("PRODUCT", "ORG") => {
            vec![("MADE_BY", 0.8), ("PRODUCED_BY", 0.7)]
        }
        ("ORGANIZATION", "PRODUCT") | ("ORG", "PRODUCT") => {
            vec![("MAKES", 0.8), ("PRODUCES", 0.7), ("ANNOUNCED", 0.5)]
        }
        // Date relations (wildcard head — keep last among specific arms)
        (_, "DATE") | (_, "TIME") => vec![("OCCURRED_ON", 0.5), ("FOUNDED_ON", 0.4)],
        // Default: no strong relation signal
        _ => vec![],
    }
}
2404
/// Extract relations using proximity and type-based heuristics.
/// This is a lightweight approach that doesn't require a separate relation model.
///
/// `entities` carry **character** offsets into `text` (anno convention); all span
/// math below stays in char units. When `relation_types` is non-empty it acts as
/// an allow-list: only relations whose normalized slug matches one of its entries
/// are emitted, and the caller's exact spelling is used as the output label.
/// At most one triple — the highest-confidence one — survives per ordered
/// (head, tail) pair. Cost is O(n² · |trigger_patterns|) in entity count.
#[cfg(any(feature = "onnx", feature = "candle"))]
fn extract_relations_heuristic(
    entities: &[Entity],
    text: &str,
    relation_types: &[&str],
    threshold: f32,
) -> Vec<crate::backends::inference::RelationTriple> {
    use crate::backends::inference::RelationTriple;

    // Normalize relation slugs so dataset styles like "part-of" match canonical "PART_OF".
    // Runs of non-alphanumeric chars collapse to a single '_', ASCII letters are
    // uppercased (non-ASCII letters/digits pass through), and leading/trailing
    // underscores are trimmed.
    fn norm_rel_slug(s: &str) -> String {
        let mut out = String::with_capacity(s.len());
        let mut prev_underscore = false;
        for ch in s.chars() {
            if ch.is_alphanumeric() {
                // Keep Unicode letters/digits; uppercase ASCII for stable matching.
                if ch.is_ascii_alphabetic() {
                    out.push(ch.to_ascii_uppercase());
                } else {
                    out.push(ch);
                }
                prev_underscore = false;
            } else if !prev_underscore {
                out.push('_');
                prev_underscore = true;
            }
        }
        while out.starts_with('_') {
            out.remove(0);
        }
        while out.ends_with('_') {
            out.pop();
        }
        out
    }

    // Returns the caller's spelling for `canonical` when it appears (after
    // normalization) in the allow-list; None when the list is empty or no match.
    fn pick_relation_label(canonical: &str, relation_types: &[&str]) -> Option<String> {
        if relation_types.is_empty() {
            return None;
        }
        let want = norm_rel_slug(canonical);
        relation_types
            .iter()
            .find(|r| norm_rel_slug(r) == want)
            .map(|s| (*s).to_string())
    }

    let mut relations = Vec::new();
    // Entity offsets in anno are character offsets. Keep all length math consistent with chars.
    let text_char_count = text.chars().count();
    // .max(1) guards the proximity division against empty text.
    let text_char_len = text_char_count.max(1) as f32;

    // Relation trigger patterns (canonicalized).
    //
    // IMPORTANT: Many evaluation datasets use hyphenated / lowercase labels (e.g. "part-of",
    // "general-affiliation"). We emit the dataset’s exact label when it’s present in
    // `relation_types`; otherwise we fall back to the canonical name.
    let trigger_patterns: Vec<(&str, &str)> = vec![
        // CrossRE/DocRED-style coarse labels
        ("part of", "PART_OF"),
        ("subset of", "PART_OF"),
        ("member of", "PART_OF"),
        ("type of", "TYPE_OF"),
        ("kind of", "TYPE_OF"),
        ("is a", "TYPE_OF"),
        ("are a", "TYPE_OF"),
        ("related to", "RELATED_TO"),
        ("also known as", "NAMED"),
        ("known as", "NAMED"),
        ("called", "NAMED"),
        ("named", "NAMED"),
        ("born", "TEMPORAL"),
        ("in 19", "TEMPORAL"),
        ("in 20", "TEMPORAL"),
        ("during", "TEMPORAL"),
        ("from", "ORIGIN"),
        ("based in", "PHYSICAL"),
        ("located in", "PHYSICAL"),
        ("headquartered", "PHYSICAL"),
        // NOTE(review): plain substring match — "at " also fires inside words
        // like "that "; presumably acceptable noise for a fuzzy heuristic, but
        // worth confirming against eval precision.
        ("at ", "PHYSICAL"),
        ("vs", "COMPARE"),
        ("versus", "COMPARE"),
        ("compared", "COMPARE"),
        ("use", "USAGE"),
        ("used", "USAGE"),
        ("uses", "USAGE"),
        ("invented", "ARTIFACT"),
        ("created", "ARTIFACT"),
        ("built", "ARTIFACT"),
        ("developed", "ARTIFACT"),
        ("won", "WIN_DEFEAT"),
        ("defeated", "WIN_DEFEAT"),
        ("beat", "WIN_DEFEAT"),
        ("caused", "CAUSE_EFFECT"),
        ("causes", "CAUSE_EFFECT"),
        ("leads to", "CAUSE_EFFECT"),
        ("because", "CAUSE_EFFECT"),
        // CHisIEC (classical Chinese) relation labels + minimal triggers
        // Note: CHisIEC text is character-tokenized (no spaces), so we match short substrings.
        ("父", "父母"),
        ("母", "父母"),
        ("兄", "兄弟"),
        ("弟", "兄弟"),
        ("别名", "别名"),
        ("別名", "別名"),
        ("生于", "出生于某地"),
        ("生於", "出生於某地"),
        ("到", "到达"),
        ("到", "到達"),
        ("至", "到达"),
        ("至", "到達"),
        ("驻", "驻守"),
        ("駐", "駐守"),
        ("守", "驻守"),
        ("守", "駐守"),
        ("攻", "敌对攻伐"),
        ("伐", "敌对攻伐"),
        ("攻", "敵對攻伐"),
        ("伐", "敵對攻伐"),
        ("任", "任职"),
        ("任", "任職"),
        ("拜", "任职"),
        ("拜", "任職"),
        ("管", "管理"),
        ("治", "管理"),
        // Legacy canonical names (kept for non-CrossRE label sets)
        ("ceo", "CEO_OF"),
        ("founder", "FOUNDED"),
        ("founded", "FOUNDED"),
        ("works at", "WORKS_FOR"),
        ("works for", "WORKS_FOR"),
        ("employee", "WORKS_FOR"),
        ("born in", "BORN_IN"),
        ("lives in", "LIVES_IN"),
        ("announced", "ANNOUNCED"),
        ("released", "RELEASED"),
        ("acquired", "ACQUIRED"),
        ("bought", "ACQUIRED"),
        ("merged", "MERGED_WITH"),
    ];

    // O(n²) over ordered entity pairs; duplicates per pair are pruned at the end.
    for (i, head) in entities.iter().enumerate() {
        for (j, tail) in entities.iter().enumerate() {
            if i == j {
                continue;
            }

            // Distance-based scoring: closer entities are more likely related
            let head_center = (head.start + head.end) as f32 / 2.0;
            let tail_center = (tail.start + tail.end) as f32 / 2.0;
            let distance = (head_center - tail_center).abs() / text_char_len;
            let proximity_score = 1.0 - distance.min(1.0);

            // Type-based relation candidates
            let head_type = head.entity_type.as_label();
            let tail_type = tail.entity_type.as_label();
            let type_relations = get_likely_relations(head_type, tail_type);

            // Check for trigger patterns in text between entities
            let (span_start, span_end) = if head.end < tail.start {
                (head.end, tail.start)
            } else if tail.end < head.start {
                (tail.end, head.start)
            } else {
                // Overlapping entities - use surrounding context
                let min_start = head.start.min(tail.start);
                let max_end = head.end.max(tail.end);
                (
                    min_start.saturating_sub(20),
                    (max_end + 20).min(text_char_count),
                )
            };

            let between_text = if span_end > span_start && span_end <= text_char_count {
                crate::offset::TextSpan::from_chars(text, span_start, span_end).extract(text)
            } else {
                ""
            };
            // Lowercased copy for ASCII triggers; non-ASCII triggers (CHisIEC)
            // are matched against the original text below.
            let between_lower = between_text.to_ascii_lowercase();

            // Check trigger patterns
            for (trigger, rel_type) in &trigger_patterns {
                let hit = if trigger.is_ascii() {
                    between_lower.contains(trigger)
                } else {
                    between_text.contains(trigger)
                };
                if hit {
                    // Filter by requested relation types if specified, using normalization.
                    // This allows "part-of" to match canonical "PART_OF", etc.
                    if !relation_types.is_empty()
                        && pick_relation_label(rel_type, relation_types).is_none()
                    {
                        continue;
                    }

                    // pick_relation_label is evaluated a second time here
                    // (filter above + label below); cheap for small allow-lists.
                    let out_label = pick_relation_label(rel_type, relation_types)
                        .unwrap_or_else(|| rel_type.to_string());

                    // Blend proximity (floored at 0.4) with the mean entity confidence.
                    let confidence = (proximity_score * 0.6 + 0.4)
                        * (head.confidence + tail.confidence) as f32
                        / 2.0;
                    if confidence >= threshold {
                        relations.push(RelationTriple {
                            head_idx: i,
                            tail_idx: j,
                            relation_type: out_label,
                            confidence,
                        });
                    }
                }
            }

            // Type-based relations (if no explicit trigger found)
            let has_trigger_relation = relations.iter().any(|r| r.head_idx == i && r.tail_idx == j);
            if !has_trigger_relation && proximity_score > 0.3 {
                for (rel_type, base_score) in type_relations {
                    if !relation_types.is_empty()
                        && pick_relation_label(rel_type, relation_types).is_none()
                    {
                        continue;
                    }
                    let out_label = pick_relation_label(rel_type, relation_types)
                        .unwrap_or_else(|| rel_type.to_string());

                    let confidence =
                        proximity_score * base_score * (head.confidence + tail.confidence) as f32
                            / 2.0;
                    if confidence >= threshold {
                        relations.push(RelationTriple {
                            head_idx: i,
                            tail_idx: j,
                            relation_type: out_label,
                            confidence,
                        });
                        break; // Only add one type-based relation per pair
                    }
                }
            }
        }
    }

    // Sort by confidence and deduplicate
    relations.sort_by(|a, b| {
        b.confidence
            .partial_cmp(&a.confidence)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    // Keep only top relation per entity pair
    // (descending sort above guarantees the first — and thus retained — triple
    // for each ordered (head, tail) pair is the highest-confidence one).
    let mut seen_pairs = std::collections::HashSet::new();
    relations.retain(|r| seen_pairs.insert((r.head_idx, r.tail_idx)));

    relations
}
2662
#[cfg(feature = "onnx")]
impl RelationExtractor for GLiNER2Onnx {
    /// Joint entity + relation extraction: run NER first, then derive
    /// relations between the found entities via type/proximity heuristics.
    fn extract_with_relations(
        &self,
        text: &str,
        types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> Result<ExtractionWithRelations> {
        let entities = self.extract_ner(text, types, threshold)?;
        let relations = extract_relations_heuristic(&entities, text, relation_types, threshold);
        Ok(ExtractionWithRelations {
            entities,
            relations,
        })
    }
}
2684
#[cfg(feature = "candle")]
impl RelationExtractor for GLiNER2Candle {
    /// Joint entity + relation extraction: run NER first, then derive
    /// relations between the found entities via type/proximity heuristics.
    fn extract_with_relations(
        &self,
        text: &str,
        types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> Result<ExtractionWithRelations> {
        // The candle path takes owned labels; promote the borrowed slices first.
        let owned_types: Vec<String> = types.iter().map(|&t| t.to_owned()).collect();
        let entities = self.extract_entities(text, &owned_types, threshold)?;
        let relations = extract_relations_heuristic(&entities, text, relation_types, threshold);
        Ok(ExtractionWithRelations {
            entities,
            relations,
        })
    }
}
2706
2707// =============================================================================
2708// BatchCapable Trait Implementation
2709// =============================================================================
2710
#[cfg(feature = "onnx")]
impl crate::BatchCapable for GLiNER2Onnx {
    /// Extract entities from multiple texts in a single padded ONNX forward pass.
    ///
    /// Texts are whitespace-tokenized, encoded as NER prompts with the default
    /// entity types, padded to the longest sequence, and run as one batch.
    ///
    /// BUGFIX: the previous implementation skipped empty/whitespace-only texts
    /// when building batch rows but still passed the full `texts` slice to the
    /// decoder, misaligning model rows with their source texts whenever any
    /// input was empty. Empty texts now yield empty results and the decoded
    /// rows are scattered back to their original positions.
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let default_types = &["person", "organization", "location", "date", "event"];

        // Word-level tokenization for every input text.
        let text_words: Vec<Vec<&str>> = texts
            .iter()
            .map(|t| t.split_whitespace().collect())
            .collect();

        // Encode prompts for non-empty texts only, remembering each one's
        // original index so results can be scattered back in order.
        let mut nonempty_idx: Vec<usize> = Vec::new();
        let mut all_input_ids = Vec::new();
        let mut all_attention_masks = Vec::new();
        let mut all_words_masks = Vec::new();
        let mut all_text_lengths = Vec::new();
        let mut seq_lens = Vec::new();

        for (idx, words) in text_words.iter().enumerate() {
            if words.is_empty() {
                // Empty text: contributes an empty result, not a batch row.
                continue;
            }

            let (input_ids, attention_mask, words_mask) =
                self.encode_ner_prompt(words, default_types)?;
            seq_lens.push(input_ids.len());
            all_input_ids.push(input_ids);
            all_attention_masks.push(attention_mask);
            all_words_masks.push(words_mask);
            all_text_lengths.push(words.len() as i64);
            nonempty_idx.push(idx);
        }

        // All inputs empty: nothing to run.
        if nonempty_idx.is_empty() {
            return Ok(texts.iter().map(|_| Vec::new()).collect());
        }

        // Pad every sequence to the longest one in the batch.
        let max_seq_len = seq_lens.iter().copied().max().unwrap_or(0);
        for i in 0..all_input_ids.len() {
            let pad_len = max_seq_len - all_input_ids[i].len();
            all_input_ids[i].extend(std::iter::repeat_n(0i64, pad_len));
            all_attention_masks[i].extend(std::iter::repeat_n(0i64, pad_len));
            all_words_masks[i].extend(std::iter::repeat_n(0i64, pad_len));
        }

        // Build batched tensors - only 4 inputs (no span tensors).
        use ndarray::Array2;

        let batch_size = all_input_ids.len();

        let input_ids_flat: Vec<i64> = all_input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<i64> = all_attention_masks.into_iter().flatten().collect();
        let words_mask_flat: Vec<i64> = all_words_masks.into_iter().flatten().collect();

        // Validate lengths before array creation.
        let expected_input_len = batch_size * max_seq_len;
        if input_ids_flat.len() != expected_input_len {
            return Err(Error::Parse(format!(
                "Input IDs length mismatch: expected {}, got {}",
                expected_input_len,
                input_ids_flat.len()
            )));
        }

        let input_ids_arr = Array2::from_shape_vec((batch_size, max_seq_len), input_ids_flat)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let attention_mask_arr =
            Array2::from_shape_vec((batch_size, max_seq_len), attention_mask_flat)
                .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let words_mask_arr = Array2::from_shape_vec((batch_size, max_seq_len), words_mask_flat)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let text_lengths_arr = Array2::from_shape_vec((batch_size, 1), all_text_lengths)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;

        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;

        // Run batched inference with blocking lock for thread-safe parallel access
        let mut session = lock(&self.session);

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_t.into_dyn(),
                "attention_mask" => attention_mask_t.into_dyn(),
                "words_mask" => words_mask_t.into_dyn(),
                "text_lengths" => text_lengths_t.into_dyn(),
            ])
            .map_err(|e| Error::Inference(format!("ONNX batch run: {}", e)))?;

        // Decode only the rows that were actually batched, so decoder inputs
        // stay aligned with model outputs, then scatter back by source index.
        let sub_texts: Vec<&str> = nonempty_idx.iter().map(|&i| texts[i]).collect();
        let sub_words: Vec<Vec<&str>> =
            nonempty_idx.iter().map(|&i| text_words[i].clone()).collect();
        let decoded =
            self.decode_ner_batch_output(&outputs, &sub_texts, &sub_words, default_types, 0.5)?;

        let mut results: Vec<Vec<Entity>> = vec![Vec::new(); texts.len()];
        for (row, entities) in nonempty_idx.into_iter().zip(decoded) {
            results[row] = entities;
        }
        Ok(results)
    }

    /// Preferred batch size for this backend.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(16)
    }
}
2838
#[cfg(feature = "candle")]
impl crate::BatchCapable for GLiNER2Candle {
    /// Batch extraction for the candle backend.
    ///
    /// No tensor-level batching here: the label embeddings are computed once
    /// and cached, then each text runs sequentially against the warm cache.
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let default_types: Vec<String> = ["person", "organization", "location", "date", "event"]
            .iter()
            .map(|s| (*s).to_string())
            .collect();

        // Warm the label-embedding cache once so the per-text loop reuses it.
        let label_refs: Vec<&str> = default_types.iter().map(String::as_str).collect();
        let _ = self.encode_labels_cached(&label_refs)?;

        // Propagate the first extraction error; otherwise collect all results.
        texts
            .iter()
            .map(|text| self.extract_entities(text, &default_types, 0.5))
            .collect()
    }

    /// Preferred batch size for this backend.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(8)
    }
}
2877
2878// =============================================================================
2879// StreamingCapable Trait Implementation
2880// =============================================================================
2881
#[cfg(feature = "onnx")]
impl crate::StreamingCapable for GLiNER2Onnx {
    // Relies on the trait's default extract_entities_streaming, which
    // re-bases entity offsets per chunk.

    /// Preferred chunk length in characters (roughly a few hundred words).
    fn recommended_chunk_size(&self) -> usize {
        4096
    }
}
2890
#[cfg(feature = "candle")]
impl crate::StreamingCapable for GLiNER2Candle {
    // Relies on the trait's default extract_entities_streaming, which
    // re-bases entity offsets per chunk.

    /// Preferred chunk length in characters (matches the ONNX backend).
    fn recommended_chunk_size(&self) -> usize {
        4096
    }
}
2899
2900// =============================================================================
2901// GpuCapable Trait Implementation
2902// =============================================================================
2903
#[cfg(feature = "candle")]
impl crate::GpuCapable for GLiNER2Candle {
    /// True when inference runs on an accelerator (anything but CPU,
    /// i.e. Metal or CUDA).
    fn is_gpu_active(&self) -> bool {
        !matches!(&self.device, Device::Cpu)
    }

    /// Short device identifier for logging/diagnostics.
    fn device(&self) -> &str {
        match &self.device {
            Device::Cpu => "cpu",
            Device::Metal(_) => "metal",
            Device::Cuda(_) => "cuda",
        }
    }
}
2918
2919// =============================================================================
2920// Tests
2921// =============================================================================
2922
#[cfg(test)]
mod tests {
    use super::*;

    /// The relation heuristic must survive multi-byte text and match ASCII
    /// triggers case-insensitively.
    #[test]
    #[cfg(any(feature = "onnx", feature = "candle"))]
    fn test_relation_heuristic_unicode_safe_and_case_insensitive() {
        use crate::backends::inference::RelationTriple;
        use crate::offset::bytes_to_chars;

        let text = "Dr. 田中 is CEO of Apple Inc. in 東京. François works at OpenAI.";

        // Char-offset span of the first occurrence of `needle` within `text`.
        let char_span = |needle: &str| {
            let (byte_start, _) = text
                .match_indices(needle)
                .next()
                .expect("needle should exist in test text");
            bytes_to_chars(text, byte_start, byte_start + needle.len())
        };
        // Build an entity over the given surface form with fixed confidence.
        let entity = |surface: &str, kind: EntityType| {
            let (start, end) = char_span(surface);
            Entity::new(surface, kind, start, end, 0.9)
        };

        let entities = vec![
            entity("田中", EntityType::Person),
            entity("Apple Inc.", EntityType::Organization),
            entity("東京", EntityType::Location),
            entity("François", EntityType::Person),
            entity("OpenAI", EntityType::Organization),
        ];

        // Should not panic on Unicode text; should detect at least one trigger relation.
        let rels: Vec<RelationTriple> = extract_relations_heuristic(&entities, text, &[], 0.0);
        assert!(
            rels.iter()
                .any(|r| r.relation_type == "CEO_OF" || r.relation_type == "WORKS_FOR"),
            "expected at least one trigger-based relation, got {:?}",
            rels
        );
    }

    /// The schema builder should record both the entity task and the
    /// classification task.
    #[test]
    fn test_task_schema_builder() {
        let built = TaskSchema::new()
            .with_entities(&["person", "organization"])
            .with_classification("sentiment", &["positive", "negative"], false);

        assert!(built.entities.is_some());
        assert_eq!(built.entities.as_ref().unwrap().types.len(), 2);
        assert_eq!(built.classifications.len(), 1);
    }

    /// Structure tasks accumulate fields in builder order; choice fields keep
    /// their option list.
    #[test]
    fn test_structure_task_builder() {
        let built = StructureTask::new("product")
            .with_field("name", FieldType::String)
            .with_field_described("price", FieldType::String, "Product price in USD")
            .with_choice_field("category", &["electronics", "clothing"]);

        assert_eq!(built.fields.len(), 3);
        assert_eq!(built.fields[2].choices.as_ref().unwrap().len(), 2);
    }

    /// Word-index spans must round-trip to the exact character substring.
    #[test]
    fn test_word_span_to_char_offsets() {
        use crate::offset::TextSpan;

        let text = "John works at Apple";
        let words: Vec<&str> = text.split_whitespace().collect();

        // (first word index, last word index, expected substring)
        let cases = [
            (0usize, 0usize, "John"),
            (3, 3, "Apple"),
            (0, 2, "John works at"),
        ];
        for (first, last, expected) in cases {
            let (start, end) = word_span_to_char_offsets(text, &words, first, last);
            assert_eq!(TextSpan::from_chars(text, start, end).extract(text), expected);
        }
    }

    /// Label → EntityType mapping: known labels hit their variants, special
    /// labels become Custom, everything else becomes Other (uppercased).
    #[test]
    fn test_map_entity_type() {
        assert!(matches!(map_entity_type("person"), EntityType::Person));
        assert!(matches!(
            map_entity_type("ORGANIZATION"),
            EntityType::Organization
        ));
        assert!(matches!(map_entity_type("loc"), EntityType::Location));
        // Unknown types map to Other with the uppercase version (due to schema normalization)
        assert!(
            matches!(map_entity_type("custom_type"), EntityType::Other(ref s) if s == "CUSTOM_TYPE")
        );
        // Known special types map to Custom
        assert!(matches!(
            map_entity_type("product"),
            EntityType::Custom { .. }
        ));
        assert!(matches!(
            map_entity_type("event"),
            EntityType::Custom { .. }
        ));
    }
}