// anno/backends/gliner2/mod.rs
//! GLiNER2: Multi-task Information Extraction.
//!
//! GLiNER2 extends GLiNER to support:
//! - Named Entity Recognition (with label descriptions)
//! - Text Classification (single/multi-label)
//! - Hierarchical Structure Extraction
//! - Task Composition (multiple tasks in one pass)
//!
//! This backend is based on the GLiNER2 paper (arXiv:2507.18546). The details of
//! prompt formatting and the full task schema are paper-defined; this module
//! focuses on the inference integration and trait wiring used by `anno`.
//!
//! # Trait Integration
//!
//! GLiNER2 implements the standard `anno` traits:
//! - `Model` - Core entity extraction interface
//! - `ZeroShotNER` - Open-domain entity types
//! - `RelationExtractor` - Joint entity-relation extraction (via GLiREL)
//! - `BatchCapable` - Efficient batch processing
//!
//! # Usage
//!
//! ```rust,ignore
//! use anno::{Model, ZeroShotNER, DEFAULT_GLINER2_MODEL};
//! use anno::backends::gliner2::{GLiNER2, TaskSchema};
//!
//! // Use the official Fastino Labs GLiNER2 model
//! let model = GLiNER2::from_pretrained(DEFAULT_GLINER2_MODEL)?;
//! // Or: GLiNER2::from_pretrained("fastino/gliner2-base-v1")?;
//!
//! // Standard Model trait
//! let entities = model.extract_entities("Apple announced iPhone 15", None)?;
//!
//! // Zero-shot with custom types
//! let types = &["company", "product", "event"];
//! let entities = model.extract_with_types(text, types, 0.5)?;
//!
//! // Multi-task extraction with schema (third argument: multi-label flag)
//! let schema = TaskSchema::new()
//!     .with_entities(&["person", "organization", "product"])
//!     .with_classification("sentiment", &["positive", "negative", "neutral"], false);
//!
//! let result = model.extract_with_schema("Apple announced iPhone 15", &schema)?;
//! ```
//!
//! # Backends
//!
//! - **ONNX** (recommended): `cargo build --features onnx`
//! - **Candle** (native): `cargo build --features candle`

#[cfg(feature = "onnx")]
use crate::sync::lock;
use crate::{Entity, EntityType, Error, Result};
use anno_core::EntityCategory;
#[cfg(feature = "candle")]
use candle_core::Device;

// Heuristic relation extraction shared by both backends (see the
// `RelationExtractor` impls below).
pub(crate) mod relations;

use crate::backends::inference::{ExtractionWithRelations, RelationExtractor, ZeroShotNER};

#[cfg(feature = "candle")]
pub mod candle;
#[cfg(feature = "onnx")]
pub mod onnx;
pub mod schema;
#[cfg(feature = "candle")]
pub use candle::GLiNER2Candle;
#[cfg(feature = "onnx")]
pub use onnx::GLiNER2Onnx;
// Re-export the schema vocabulary so callers can build `TaskSchema`s without
// importing the submodule directly.
pub use schema::{
    ClassificationResult, ClassificationTask, EntityTask, ExtractedStructure, ExtractionResult,
    FieldType, LabelCache, StructureTask, StructureValue, TaskSchema,
};

// Stub implementations (no feature)
// =============================================================================

/// GLiNER2 stub (requires onnx or candle feature).
///
/// Compiled only when neither inference backend feature is enabled; the impl
/// below returns `Error::FeatureNotAvailable` from every method.
#[cfg(not(any(feature = "onnx", feature = "candle")))]
#[derive(Debug)]
pub struct GLiNER2 {
    // Prevents construction outside this module.
    _private: (),
}

86#[cfg(not(any(feature = "onnx", feature = "candle")))]
87impl GLiNER2 {
88    /// Load model (requires feature).
89    pub fn from_pretrained(_model_id: &str) -> Result<Self> {
90        Err(Error::FeatureNotAvailable(
91            "GLiNER2 requires 'onnx' or 'candle' feature. \
92             Build with: cargo build --features candle"
93                .to_string(),
94        ))
95    }
96
97    /// Extract (requires feature).
98    pub fn extract(&self, _text: &str, _schema: &TaskSchema) -> Result<ExtractionResult> {
99        Err(Error::FeatureNotAvailable(
100            "GLiNER2 requires features".to_string(),
101        ))
102    }
103}
104
// =============================================================================
// Unified GLiNER2 type
// =============================================================================

/// GLiNER2 model - automatically selects best available backend.
///
/// Candle takes precedence when both features are enabled (this alias shadows
/// the ONNX one below, which is gated on `not(candle)`).
#[cfg(feature = "candle")]
pub type GLiNER2 = GLiNER2Candle;

/// GLiNER2 model - ONNX backend (when candle not enabled).
#[cfg(all(feature = "onnx", not(feature = "candle")))]
pub type GLiNER2 = GLiNER2Onnx;

// =============================================================================
// Helper functions
// =============================================================================

121/// Convert word span indices to character offsets.
122pub(super) fn word_span_to_char_offsets(
123    text: &str,
124    words: &[&str],
125    start_word: usize,
126    end_word: usize,
127) -> (usize, usize) {
128    // Defensive: Validate bounds
129    if words.is_empty()
130        || start_word >= words.len()
131        || end_word >= words.len()
132        || start_word > end_word
133    {
134        // Return safe defaults: empty span at start of text
135        return (0, 0);
136    }
137
138    // Track our search position in **bytes**.
139    let mut byte_pos = 0;
140    let mut start_byte = 0;
141    let mut end_byte = text.len();
142    let mut found_start = false;
143    let mut found_end = false;
144
145    for (i, word) in words.iter().enumerate() {
146        if let Some(pos) = text.get(byte_pos..).and_then(|s| s.find(word)) {
147            let abs_pos = byte_pos + pos;
148
149            if i == start_word {
150                start_byte = abs_pos;
151                found_start = true;
152            }
153            if i == end_word {
154                end_byte = abs_pos + word.len();
155                found_end = true;
156                // Early exit: we found both start and end
157                break;
158            }
159
160            byte_pos = abs_pos + word.len();
161        } else {
162            // Word not found - this shouldn't happen in normal operation,
163            // but if it does, we can't reliably compute offsets
164            // Continue searching but mark that we may have incorrect results
165        }
166    }
167
168    // If we didn't find the words, return safe defaults
169    if !found_start || !found_end {
170        // Return empty span to avoid incorrect entity extraction
171        (0, 0)
172    } else {
173        // Convert byte offsets to character offsets (anno spans are char-based).
174        crate::offset::bytes_to_chars(text, start_byte, end_byte)
175    }
176}
177
/// Map entity type string to EntityType.
///
/// Uses the canonical schema mapper for consistent semantics across all backends.
/// Matching is case-insensitive for known types; unknown strings fall through to
/// `EntityType::Other` (see `test_map_entity_type` in this file).
pub(super) fn map_entity_type(type_str: &str) -> EntityType {
    crate::schema::map_to_canonical(type_str, None)
}

// =============================================================================
// Model Trait Implementation (ONNX)
// =============================================================================

#[cfg(feature = "onnx")]
impl crate::Model for GLiNER2Onnx {
    /// Run NER with the default label set via a single-task schema.
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
        let labels = ["person", "organization", "location", "date", "event"];
        let schema = TaskSchema::new().with_entities(&labels);
        self.extract(text, &schema).map(|result| result.entities)
    }

    fn supported_types(&self) -> Vec<EntityType> {
        let event = EntityType::Custom {
            name: "event".to_string(),
            category: EntityCategory::Creative,
        };
        let product = EntityType::Custom {
            name: "product".to_string(),
            category: EntityCategory::Creative,
        };
        vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Date,
            event,
            product,
            EntityType::Other("misc".to_string()),
        ]
    }

    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "GLiNER2-ONNX"
    }

    fn description(&self) -> &'static str {
        "Multi-task information extraction via GLiNER2 (ONNX backend)"
    }

    fn capabilities(&self) -> crate::ModelCapabilities {
        crate::ModelCapabilities {
            batch_capable: true,
            streaming_capable: true,
            relation_capable: true,
            dynamic_labels: true,
            ..Default::default()
        }
    }
}

// Empty marker impl: no overrides needed for this backend.
#[cfg(feature = "onnx")]
impl crate::NamedEntityCapable for GLiNER2Onnx {}

#[cfg(feature = "onnx")]
impl crate::DynamicLabels for GLiNER2Onnx {
    /// Extract entities for caller-supplied label strings.
    ///
    /// Delegates to `ZeroShotNER::extract_with_types` with a fixed 0.3
    /// confidence threshold; `_language` is ignored by this backend.
    fn extract_with_labels(
        &self,
        text: &str,
        labels: &[&str],
        _language: Option<&str>,
    ) -> crate::Result<Vec<crate::Entity>> {
        <Self as ZeroShotNER>::extract_with_types(self, text, labels, 0.3)
    }
}

// =============================================================================
// Model Trait Implementation (Candle)
// =============================================================================

#[cfg(feature = "candle")]
impl crate::Model for GLiNER2Candle {
    /// Run NER with the default label set via a single-task schema.
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
        let labels = ["person", "organization", "location", "date", "event"];
        let schema = TaskSchema::new().with_entities(&labels);
        self.extract(text, &schema).map(|result| result.entities)
    }

    fn supported_types(&self) -> Vec<EntityType> {
        // Both "event" and "product" share the Creative category.
        let custom = |name: &str| EntityType::Custom {
            name: name.to_string(),
            category: EntityCategory::Creative,
        };
        vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Date,
            custom("event"),
            custom("product"),
            EntityType::Other("misc".to_string()),
        ]
    }

    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "GLiNER2-Candle"
    }

    fn description(&self) -> &'static str {
        "Multi-task information extraction via GLiNER2 (native Rust/Candle)"
    }

    fn capabilities(&self) -> crate::ModelCapabilities {
        crate::ModelCapabilities {
            batch_capable: true,
            streaming_capable: true,
            gpu_capable: true,
            relation_capable: true,
            dynamic_labels: true,
            ..Default::default()
        }
    }
}

// Empty marker impl: no overrides needed for this backend.
#[cfg(feature = "candle")]
impl crate::NamedEntityCapable for GLiNER2Candle {}

#[cfg(feature = "candle")]
impl crate::DynamicLabels for GLiNER2Candle {
    /// Extract entities for caller-supplied label strings.
    ///
    /// Delegates to `ZeroShotNER::extract_with_types` with a fixed 0.3
    /// confidence threshold; `_language` is ignored by this backend.
    fn extract_with_labels(
        &self,
        text: &str,
        labels: &[&str],
        _language: Option<&str>,
    ) -> crate::Result<Vec<crate::Entity>> {
        <Self as ZeroShotNER>::extract_with_types(self, text, labels, 0.3)
    }
}

// =============================================================================
// ZeroShotNER Trait Implementation
// =============================================================================

#[cfg(feature = "onnx")]
impl ZeroShotNER for GLiNER2Onnx {
    /// Default label set used when callers supply no types.
    fn default_types(&self) -> &[&'static str] {
        const DEFAULTS: &[&'static str] =
            &["person", "organization", "location", "date", "event"];
        DEFAULTS
    }

    /// Open-vocabulary NER over caller-chosen type labels.
    fn extract_with_types(
        &self,
        text: &str,
        types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        self.extract_ner(text, types, threshold)
    }

    /// Descriptions are forwarded verbatim as labels; GLiNER2 accepts arbitrary
    /// label text, so descriptions work the same way short labels do.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        self.extract_ner(text, descriptions, threshold)
    }
}

#[cfg(feature = "candle")]
impl ZeroShotNER for GLiNER2Candle {
    /// Default label set used when callers supply no types.
    fn default_types(&self) -> &[&'static str] {
        &["person", "organization", "location", "date", "event"]
    }

    /// Open-vocabulary NER over caller-chosen type labels.
    fn extract_with_types(
        &self,
        text: &str,
        types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // The candle backend's inherent API takes owned label strings.
        let owned: Vec<String> = types.iter().map(|t| (*t).to_owned()).collect();
        self.extract_entities(text, &owned, threshold)
    }

    /// Descriptions are forwarded verbatim as labels; GLiNER2 accepts arbitrary
    /// label text, so descriptions work the same way short labels do.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        let owned: Vec<String> = descriptions.iter().map(|d| (*d).to_owned()).collect();
        self.extract_entities(text, &owned, threshold)
    }
}

#[cfg(feature = "onnx")]
impl RelationExtractor for GLiNER2Onnx {
    /// Joint entity + relation extraction: entities come from the model,
    /// relations are derived afterwards by the shared heuristic in `relations`.
    fn extract_with_relations(
        &self,
        text: &str,
        types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> Result<ExtractionWithRelations> {
        let entities = self.extract_ner(text, types, threshold)?;
        let relations =
            relations::extract_relations_heuristic(&entities, text, relation_types, threshold);
        Ok(ExtractionWithRelations { entities, relations })
    }
}

#[cfg(feature = "candle")]
impl RelationExtractor for GLiNER2Candle {
    /// Joint entity + relation extraction: entities come from the model,
    /// relations are derived afterwards by the shared heuristic in `relations`.
    fn extract_with_relations(
        &self,
        text: &str,
        types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> Result<ExtractionWithRelations> {
        // The candle backend's inherent API takes owned label strings.
        let owned_types: Vec<String> = types.iter().map(|t| (*t).to_owned()).collect();
        let entities = self.extract_entities(text, &owned_types, threshold)?;
        let relations =
            relations::extract_relations_heuristic(&entities, text, relation_types, threshold);
        Ok(ExtractionWithRelations { entities, relations })
    }
}

// =============================================================================
// RelationCapable Trait Implementation (high-level public interface)
// =============================================================================

#[cfg(feature = "onnx")]
impl crate::RelationCapable for GLiNER2Onnx {
    /// High-level entity + relation extraction using the library-default
    /// entity/relation type lists and a fixed 0.3 threshold. `_language` is
    /// ignored by this backend.
    fn extract_with_relations(
        &self,
        text: &str,
        _language: Option<&str>,
    ) -> Result<(Vec<Entity>, Vec<crate::Relation>)> {
        use crate::backends::inference::{DEFAULT_ENTITY_TYPES, DEFAULT_RELATION_TYPES};
        <Self as RelationExtractor>::extract_with_relations(
            self,
            text,
            DEFAULT_ENTITY_TYPES,
            DEFAULT_RELATION_TYPES,
            0.3,
        )
        .map(|extraction| extraction.into_anno_relations())
    }
}

#[cfg(feature = "candle")]
impl crate::RelationCapable for GLiNER2Candle {
    /// High-level entity + relation extraction using the library-default
    /// entity/relation type lists and a fixed 0.3 threshold. `_language` is
    /// ignored by this backend.
    fn extract_with_relations(
        &self,
        text: &str,
        _language: Option<&str>,
    ) -> Result<(Vec<Entity>, Vec<crate::Relation>)> {
        use crate::backends::inference::{DEFAULT_ENTITY_TYPES, DEFAULT_RELATION_TYPES};
        <Self as RelationExtractor>::extract_with_relations(
            self,
            text,
            DEFAULT_ENTITY_TYPES,
            DEFAULT_RELATION_TYPES,
            0.3,
        )
        .map(|extraction| extraction.into_anno_relations())
    }
}

// =============================================================================
// BatchCapable Trait Implementation
// =============================================================================

#[cfg(feature = "onnx")]
impl crate::BatchCapable for GLiNER2Onnx {
    /// Batched NER over `texts` using the default label set.
    ///
    /// Strategy: encode every prompt, pad all sequences to the batch maximum,
    /// run one ONNX forward pass, then decode per-text results. Decoding uses
    /// a fixed 0.5 threshold (NOTE(review): the single-text paths in this
    /// module use 0.3 — confirm the asymmetry is intentional).
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let default_types = &["person", "organization", "location", "date", "event"];

        // For true batching, we need to:
        // 1. Tokenize all texts
        // 2. Pad to max length
        // 3. Run as single batch
        // 4. Split results back

        // Collect word-level tokenizations
        let text_words: Vec<Vec<&str>> = texts
            .iter()
            .map(|t| t.split_whitespace().collect())
            .collect();

        // Find max word count
        let max_words = text_words.iter().map(|w| w.len()).max().unwrap_or(0);
        if max_words == 0 {
            return Ok(texts.iter().map(|_| Vec::new()).collect());
        }

        // Encode all prompts (no span tensors needed for current model)
        let mut all_input_ids = Vec::new();
        let mut all_attention_masks = Vec::new();
        let mut all_words_masks = Vec::new();
        let mut all_text_lengths = Vec::new();
        let mut seq_lens = Vec::new();

        for words in &text_words {
            if words.is_empty() {
                // Handle empty text
                // NOTE(review): empty texts get no tensor row, so the batch has
                // fewer rows than `texts`; verify decode_ner_batch_output
                // re-aligns outputs when empty and non-empty texts are mixed.
                seq_lens.push(0);
                continue;
            }

            let (input_ids, attention_mask, words_mask) =
                self.encode_ner_prompt(words, default_types)?;
            seq_lens.push(input_ids.len());
            all_input_ids.push(input_ids);
            all_attention_masks.push(attention_mask);
            all_words_masks.push(words_mask);
            all_text_lengths.push(words.len() as i64);
        }

        // If all texts were empty, return empty results
        if seq_lens.iter().all(|&l| l == 0) {
            return Ok(texts.iter().map(|_| Vec::new()).collect());
        }

        // Pad sequences to max length
        let max_seq_len = seq_lens.iter().copied().max().unwrap_or(0);

        for i in 0..all_input_ids.len() {
            let pad_len = max_seq_len - all_input_ids[i].len();
            all_input_ids[i].extend(std::iter::repeat_n(0i64, pad_len));
            all_attention_masks[i].extend(std::iter::repeat_n(0i64, pad_len));
            all_words_masks[i].extend(std::iter::repeat_n(0i64, pad_len));
        }

        // Build batched tensors - only 4 inputs (no span tensors)
        use ndarray::Array2;

        let batch_size = all_input_ids.len();

        let input_ids_flat: Vec<i64> = all_input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<i64> = all_attention_masks.into_iter().flatten().collect();
        let words_mask_flat: Vec<i64> = all_words_masks.into_iter().flatten().collect();

        // Validate lengths before array creation
        let expected_input_len = batch_size * max_seq_len;
        if input_ids_flat.len() != expected_input_len {
            return Err(Error::Parse(format!(
                "Input IDs length mismatch: expected {}, got {}",
                expected_input_len,
                input_ids_flat.len()
            )));
        }

        let input_ids_arr = Array2::from_shape_vec((batch_size, max_seq_len), input_ids_flat)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let attention_mask_arr =
            Array2::from_shape_vec((batch_size, max_seq_len), attention_mask_flat)
                .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let words_mask_arr = Array2::from_shape_vec((batch_size, max_seq_len), words_mask_flat)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let text_lengths_arr = Array2::from_shape_vec((batch_size, 1), all_text_lengths)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;

        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;

        // Run batched inference with blocking lock for thread-safe parallel access
        let mut session = lock(&self.session);

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_t.into_dyn(),
                "attention_mask" => attention_mask_t.into_dyn(),
                "words_mask" => words_mask_t.into_dyn(),
                "text_lengths" => text_lengths_t.into_dyn(),
            ])
            .map_err(|e| Error::Inference(format!("ONNX batch run: {}", e)))?;

        // Decode batch results
        self.decode_ner_batch_output(&outputs, texts, &text_words, default_types, 0.5)
    }

    /// Preferred batch size for this backend.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(16)
    }
}

#[cfg(feature = "candle")]
impl crate::BatchCapable for GLiNER2Candle {
    /// Sequential batch extraction with a pre-warmed label-embedding cache.
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let default_types: Vec<String> = ["person", "organization", "location", "date", "event"]
            .iter()
            .map(|s| (*s).to_string())
            .collect();

        // Warm the label-embedding cache once so every per-text call reuses it.
        let label_refs: Vec<&str> = default_types.iter().map(String::as_str).collect();
        let _ = self.encode_labels_cached(&label_refs)?;

        // Labels are cached now; short-circuits on the first failing text.
        texts
            .iter()
            .map(|text| self.extract_entities(text, &default_types, 0.5))
            .collect()
    }

    /// Preferred batch size for this backend.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(8)
    }
}

// =============================================================================
// StreamingCapable Trait Implementation
// =============================================================================

#[cfg(feature = "onnx")]
impl crate::StreamingCapable for GLiNER2Onnx {
    // Uses default extract_entities_streaming implementation which adjusts offsets

    /// Chunk size in characters for the default streaming implementation.
    fn recommended_chunk_size(&self) -> usize {
        4096 // Characters - translates to roughly a few hundred words
    }
}

#[cfg(feature = "candle")]
impl crate::StreamingCapable for GLiNER2Candle {
    // Uses default extract_entities_streaming implementation which adjusts offsets

    /// Chunk size in characters, matching the ONNX backend's choice.
    fn recommended_chunk_size(&self) -> usize {
        4096
    }
}

// =============================================================================
// GpuCapable Trait Implementation
// =============================================================================

#[cfg(feature = "candle")]
impl crate::GpuCapable for GLiNER2Candle {
    /// True when the model runs on an accelerator (Metal or CUDA).
    fn is_gpu_active(&self) -> bool {
        match &self.device {
            Device::Metal(_) | Device::Cuda(_) => true,
            Device::Cpu => false,
        }
    }

    /// Short device tag for diagnostics: "cpu", "metal", or "cuda".
    fn device(&self) -> &str {
        if matches!(&self.device, Device::Cpu) {
            "cpu"
        } else if matches!(&self.device, Device::Metal(_)) {
            "metal"
        } else {
            "cuda"
        }
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // Regression test: the relation heuristic must handle multi-byte UTF-8
    // (CJK, accented Latin) without slicing mid-character.
    #[test]
    #[cfg(any(feature = "onnx", feature = "candle"))]
    fn test_relation_heuristic_unicode_safe_and_case_insensitive() {
        use crate::backends::inference::RelationTriple;
        use crate::offset::bytes_to_chars;

        let text = "Dr. 田中 is CEO of Apple Inc. in 東京. François works at OpenAI.";
        // Helper: locate `needle` in `text` and return its char-based span.
        let span = |needle: &str| {
            let (b_start, _) = text
                .match_indices(needle)
                .next()
                .expect("needle should exist in test text");
            let b_end = b_start + needle.len();
            bytes_to_chars(text, b_start, b_end)
        };

        let (s, e) = span("田中");
        let e_tanaka = Entity::new("田中", EntityType::Person, s, e, 0.9);
        let (s, e) = span("Apple Inc.");
        let e_apple = Entity::new("Apple Inc.", EntityType::Organization, s, e, 0.9);
        let (s, e) = span("東京");
        let e_tokyo = Entity::new("東京", EntityType::Location, s, e, 0.9);
        let (s, e) = span("François");
        let e_francois = Entity::new("François", EntityType::Person, s, e, 0.9);
        let (s, e) = span("OpenAI");
        let e_openai = Entity::new("OpenAI", EntityType::Organization, s, e, 0.9);

        let entities = vec![e_tanaka, e_apple, e_tokyo, e_francois, e_openai];

        // Should not panic on Unicode text; should detect at least one trigger relation.
        let rels: Vec<RelationTriple> =
            relations::extract_relations_heuristic(&entities, text, &[], 0.0);
        assert!(
            rels.iter()
                .any(|r| r.relation_type == "CEO_OF" || r.relation_type == "WORKS_FOR"),
            "expected at least one trigger-based relation, got {:?}",
            rels
        );
    }

    #[test]
    fn test_task_schema_builder() {
        // Third with_classification argument is the multi-label flag.
        let schema = TaskSchema::new()
            .with_entities(&["person", "organization"])
            .with_classification("sentiment", &["positive", "negative"], false);

        assert!(schema.entities.is_some());
        assert_eq!(schema.entities.as_ref().unwrap().types.len(), 2);
        assert_eq!(schema.classifications.len(), 1);
    }

    #[test]
    fn test_structure_task_builder() {
        let task = StructureTask::new("product")
            .with_field("name", FieldType::String)
            .with_field_described("price", FieldType::String, "Product price in USD")
            .with_choice_field("category", &["electronics", "clothing"]);

        assert_eq!(task.fields.len(), 3);
        assert_eq!(task.fields[2].choices.as_ref().unwrap().len(), 2);
    }

    #[test]
    fn test_word_span_to_char_offsets() {
        use crate::offset::TextSpan;

        let text = "John works at Apple";
        let words: Vec<&str> = text.split_whitespace().collect();

        // Single word at the start of the text.
        let (start, end) = word_span_to_char_offsets(text, &words, 0, 0);
        assert_eq!(TextSpan::from_chars(text, start, end).extract(text), "John");

        // Single word at the end of the text.
        let (start, end) = word_span_to_char_offsets(text, &words, 3, 3);
        assert_eq!(
            TextSpan::from_chars(text, start, end).extract(text),
            "Apple"
        );

        // Multi-word span, inclusive of both endpoints.
        let (start, end) = word_span_to_char_offsets(text, &words, 0, 2);
        assert_eq!(
            TextSpan::from_chars(text, start, end).extract(text),
            "John works at"
        );
    }

    #[test]
    fn test_map_entity_type() {
        assert!(matches!(map_entity_type("person"), EntityType::Person));
        assert!(matches!(
            map_entity_type("ORGANIZATION"),
            EntityType::Organization
        ));
        assert!(matches!(map_entity_type("loc"), EntityType::Location));
        // Unknown types map to Other with the uppercase version (due to schema normalization)
        assert!(
            matches!(map_entity_type("custom_type"), EntityType::Other(ref s) if s == "CUSTOM_TYPE")
        );
        // Known special types map to Custom
        assert!(matches!(
            map_entity_type("product"),
            EntityType::Custom { .. }
        ));
        assert!(matches!(
            map_entity_type("event"),
            EntityType::Custom { .. }
        ));
    }
}