1#[cfg(feature = "onnx")]
52use crate::sync::lock;
53use crate::{Entity, EntityType, Error, Result};
54use anno_core::EntityCategory;
55#[cfg(feature = "candle")]
56use candle_core::Device;
57
58pub(crate) mod relations;
59
60use crate::backends::inference::{ExtractionWithRelations, RelationExtractor, ZeroShotNER};
61
62#[cfg(feature = "candle")]
63pub mod candle;
64#[cfg(feature = "onnx")]
65pub mod onnx;
66pub mod schema;
67#[cfg(feature = "candle")]
68pub use candle::GLiNER2Candle;
69#[cfg(feature = "onnx")]
70pub use onnx::GLiNER2Onnx;
71pub use schema::{
72 ClassificationResult, ClassificationTask, EntityTask, ExtractedStructure, ExtractionResult,
73 FieldType, LabelCache, StructureTask, StructureValue, TaskSchema,
74};
75
/// Stub `GLiNER2` exposed when neither inference backend feature
/// (`onnx` / `candle`) is compiled in; its methods always return
/// `Error::FeatureNotAvailable`.
#[cfg(not(any(feature = "onnx", feature = "candle")))]
#[derive(Debug)]
pub struct GLiNER2 {
    // Zero-sized private field: prevents construction outside this module.
    _private: (),
}
85
86#[cfg(not(any(feature = "onnx", feature = "candle")))]
87impl GLiNER2 {
88 pub fn from_pretrained(_model_id: &str) -> Result<Self> {
90 Err(Error::FeatureNotAvailable(
91 "GLiNER2 requires 'onnx' or 'candle' feature. \
92 Build with: cargo build --features candle"
93 .to_string(),
94 ))
95 }
96
97 pub fn extract(&self, _text: &str, _schema: &TaskSchema) -> Result<ExtractionResult> {
99 Err(Error::FeatureNotAvailable(
100 "GLiNER2 requires features".to_string(),
101 ))
102 }
103}
104
/// With the `candle` feature enabled, `GLiNER2` resolves to the native Candle
/// backend; it takes precedence over ONNX when both features are on (the ONNX
/// alias is guarded by `not(feature = "candle")`).
#[cfg(feature = "candle")]
pub type GLiNER2 = GLiNER2Candle;
112
/// With `onnx` enabled but not `candle`, `GLiNER2` resolves to the ONNX backend.
#[cfg(all(feature = "onnx", not(feature = "candle")))]
pub type GLiNER2 = GLiNER2Onnx;
116
117pub(super) fn word_span_to_char_offsets(
123 text: &str,
124 words: &[&str],
125 start_word: usize,
126 end_word: usize,
127) -> (usize, usize) {
128 if words.is_empty()
130 || start_word >= words.len()
131 || end_word >= words.len()
132 || start_word > end_word
133 {
134 return (0, 0);
136 }
137
138 let mut byte_pos = 0;
140 let mut start_byte = 0;
141 let mut end_byte = text.len();
142 let mut found_start = false;
143 let mut found_end = false;
144
145 for (i, word) in words.iter().enumerate() {
146 if let Some(pos) = text.get(byte_pos..).and_then(|s| s.find(word)) {
147 let abs_pos = byte_pos + pos;
148
149 if i == start_word {
150 start_byte = abs_pos;
151 found_start = true;
152 }
153 if i == end_word {
154 end_byte = abs_pos + word.len();
155 found_end = true;
156 break;
158 }
159
160 byte_pos = abs_pos + word.len();
161 } else {
162 }
166 }
167
168 if !found_start || !found_end {
170 (0, 0)
172 } else {
173 crate::offset::bytes_to_chars(text, start_byte, end_byte)
175 }
176}
177
/// Maps a raw label string to a canonical `EntityType` by delegating to the
/// crate-wide schema mapping (second argument `None` — presumably the default
/// mapping context; see `crate::schema::map_to_canonical`).
pub(super) fn map_entity_type(type_str: &str) -> EntityType {
    crate::schema::map_to_canonical(type_str, None)
}
184
/// `Model` integration for the ONNX backend: plain NER over a fixed default
/// label set, routed through the multi-task `extract` entry point.
#[cfg(feature = "onnx")]
impl crate::Model for GLiNER2Onnx {
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
        // Default zero-shot label set used when the caller supplies none.
        let labels = ["person", "organization", "location", "date", "event"];
        let schema = TaskSchema::new().with_entities(&labels);
        self.extract(text, &schema).map(|result| result.entities)
    }

    fn supported_types(&self) -> Vec<EntityType> {
        let mut types = vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Date,
        ];
        // "event" and "product" surface as creative custom types.
        for name in ["event", "product"] {
            types.push(EntityType::Custom {
                name: name.to_string(),
                category: EntityCategory::Creative,
            });
        }
        types.push(EntityType::Other("misc".to_string()));
        types
    }

    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "GLiNER2-ONNX"
    }

    fn description(&self) -> &'static str {
        "Multi-task information extraction via GLiNER2 (ONNX backend)"
    }

    fn capabilities(&self) -> crate::ModelCapabilities {
        // Same capability set as the struct-literal form: batch, streaming,
        // relations, and dynamic labels; everything else stays at default.
        let mut caps = crate::ModelCapabilities::default();
        caps.batch_capable = true;
        caps.streaming_capable = true;
        caps.relation_capable = true;
        caps.dynamic_labels = true;
        caps
    }
}
244
// Marker impl: advertises named-entity extraction support for the ONNX backend.
#[cfg(feature = "onnx")]
impl crate::NamedEntityCapable for GLiNER2Onnx {}
247
/// Label-driven extraction for the ONNX backend: caller-supplied labels are
/// forwarded to the zero-shot NER path with a fixed default threshold.
#[cfg(feature = "onnx")]
impl crate::DynamicLabels for GLiNER2Onnx {
    fn extract_with_labels(
        &self,
        text: &str,
        labels: &[&str],
        _language: Option<&str>,
    ) -> crate::Result<Vec<crate::Entity>> {
        // Default confidence cutoff for zero-shot label matching.
        const THRESHOLD: f32 = 0.3;
        ZeroShotNER::extract_with_types(self, text, labels, THRESHOLD)
    }
}
259
/// `Model` integration for the Candle backend: plain NER over a fixed default
/// label set, routed through the multi-task `extract` entry point.
#[cfg(feature = "candle")]
impl crate::Model for GLiNER2Candle {
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
        // Default zero-shot label set used when the caller supplies none.
        let labels = ["person", "organization", "location", "date", "event"];
        let schema = TaskSchema::new().with_entities(&labels);
        self.extract(text, &schema).map(|result| result.entities)
    }

    fn supported_types(&self) -> Vec<EntityType> {
        let mut types = vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Date,
        ];
        // "event" and "product" surface as creative custom types.
        for name in ["event", "product"] {
            types.push(EntityType::Custom {
                name: name.to_string(),
                category: EntityCategory::Creative,
            });
        }
        types.push(EntityType::Other("misc".to_string()));
        types
    }

    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "GLiNER2-Candle"
    }

    fn description(&self) -> &'static str {
        "Multi-task information extraction via GLiNER2 (native Rust/Candle)"
    }

    fn capabilities(&self) -> crate::ModelCapabilities {
        // Same capability set as the struct-literal form; unlike the ONNX
        // backend this one also reports GPU capability.
        let mut caps = crate::ModelCapabilities::default();
        caps.batch_capable = true;
        caps.streaming_capable = true;
        caps.gpu_capable = true;
        caps.relation_capable = true;
        caps.dynamic_labels = true;
        caps
    }
}
320
// Marker impl: advertises named-entity extraction support for the Candle backend.
#[cfg(feature = "candle")]
impl crate::NamedEntityCapable for GLiNER2Candle {}
323
/// Label-driven extraction for the Candle backend: caller-supplied labels are
/// forwarded to the zero-shot NER path with a fixed default threshold.
#[cfg(feature = "candle")]
impl crate::DynamicLabels for GLiNER2Candle {
    fn extract_with_labels(
        &self,
        text: &str,
        labels: &[&str],
        _language: Option<&str>,
    ) -> crate::Result<Vec<crate::Entity>> {
        // Default confidence cutoff for zero-shot label matching.
        const THRESHOLD: f32 = 0.3;
        ZeroShotNER::extract_with_types(self, text, labels, THRESHOLD)
    }
}
335
/// Zero-shot NER for the ONNX backend; type-label and description prompts are
/// handled by the same underlying NER pass.
#[cfg(feature = "onnx")]
impl ZeroShotNER for GLiNER2Onnx {
    fn default_types(&self) -> &[&'static str] {
        &["person", "organization", "location", "date", "event"]
    }

    fn extract_with_types(
        &self,
        text: &str,
        types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        self.extract_ner(text, types, threshold)
    }

    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // Descriptions are treated exactly like type labels by this backend.
        ZeroShotNER::extract_with_types(self, text, descriptions, threshold)
    }
}
365
/// Zero-shot NER for the Candle backend. The native API takes owned label
/// strings, so the borrowed slices are converted before dispatch.
#[cfg(feature = "candle")]
impl ZeroShotNER for GLiNER2Candle {
    fn default_types(&self) -> &[&'static str] {
        &["person", "organization", "location", "date", "event"]
    }

    fn extract_with_types(
        &self,
        text: &str,
        types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        let owned: Vec<String> = types.iter().map(|&t| t.to_owned()).collect();
        self.extract_entities(text, &owned, threshold)
    }

    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // Descriptions are treated exactly like type labels by this backend.
        ZeroShotNER::extract_with_types(self, text, descriptions, threshold)
    }
}
393
/// Entity + relation extraction for the ONNX backend: entities come from the
/// NER pass, relations from the shared heuristic over extracted entities.
#[cfg(feature = "onnx")]
impl RelationExtractor for GLiNER2Onnx {
    fn extract_with_relations(
        &self,
        text: &str,
        types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> Result<ExtractionWithRelations> {
        self.extract_ner(text, types, threshold).map(|entities| {
            let relations = relations::extract_relations_heuristic(
                &entities,
                text,
                relation_types,
                threshold,
            );
            ExtractionWithRelations {
                entities,
                relations,
            }
        })
    }
}
416
/// Entity + relation extraction for the Candle backend: entities come from
/// the native NER pass (which wants owned labels), relations from the shared
/// heuristic over extracted entities.
#[cfg(feature = "candle")]
impl RelationExtractor for GLiNER2Candle {
    fn extract_with_relations(
        &self,
        text: &str,
        types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> Result<ExtractionWithRelations> {
        let owned: Vec<String> = types.iter().map(|&t| t.to_owned()).collect();
        self.extract_entities(text, &owned, threshold).map(|entities| {
            let relations = relations::extract_relations_heuristic(
                &entities,
                text,
                relation_types,
                threshold,
            );
            ExtractionWithRelations {
                entities,
                relations,
            }
        })
    }
}
439
/// `RelationCapable` adapter for the ONNX backend: runs the relation
/// extractor with the crate-wide default entity/relation label sets and a
/// fixed threshold, then converts to the anno relation representation.
#[cfg(feature = "onnx")]
impl crate::RelationCapable for GLiNER2Onnx {
    fn extract_with_relations(
        &self,
        text: &str,
        _language: Option<&str>,
    ) -> Result<(Vec<Entity>, Vec<crate::Relation>)> {
        use crate::backends::inference::{DEFAULT_ENTITY_TYPES, DEFAULT_RELATION_TYPES};
        // Default confidence cutoff shared with the Candle adapter.
        const THRESHOLD: f32 = 0.3;
        RelationExtractor::extract_with_relations(
            self,
            text,
            DEFAULT_ENTITY_TYPES,
            DEFAULT_RELATION_TYPES,
            THRESHOLD,
        )
        .map(|result| result.into_anno_relations())
    }
}
462
/// `RelationCapable` adapter for the Candle backend: runs the relation
/// extractor with the crate-wide default entity/relation label sets and a
/// fixed threshold, then converts to the anno relation representation.
#[cfg(feature = "candle")]
impl crate::RelationCapable for GLiNER2Candle {
    fn extract_with_relations(
        &self,
        text: &str,
        _language: Option<&str>,
    ) -> Result<(Vec<Entity>, Vec<crate::Relation>)> {
        use crate::backends::inference::{DEFAULT_ENTITY_TYPES, DEFAULT_RELATION_TYPES};
        // Default confidence cutoff shared with the ONNX adapter.
        const THRESHOLD: f32 = 0.3;
        RelationExtractor::extract_with_relations(
            self,
            text,
            DEFAULT_ENTITY_TYPES,
            DEFAULT_RELATION_TYPES,
            THRESHOLD,
        )
        .map(|result| result.into_anno_relations())
    }
}
481
/// True batched inference for the ONNX backend: every text is encoded as an
/// NER prompt, the encodings are padded to a common sequence length, and the
/// whole batch is run through the ONNX session in a single call.
#[cfg(feature = "onnx")]
impl crate::BatchCapable for GLiNER2Onnx {
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Fixed label set used for the whole batch.
        let default_types = &["person", "organization", "location", "date", "event"];

        // Whitespace tokenization per text; word boundaries drive the
        // words_mask / text_lengths inputs below.
        let text_words: Vec<Vec<&str>> = texts
            .iter()
            .map(|t| t.split_whitespace().collect())
            .collect();

        let max_words = text_words.iter().map(|w| w.len()).max().unwrap_or(0);
        if max_words == 0 {
            // All texts are empty/whitespace-only: nothing to run.
            return Ok(texts.iter().map(|_| Vec::new()).collect());
        }

        // Per-text model inputs, collected before padding.
        let mut all_input_ids = Vec::new();
        let mut all_attention_masks = Vec::new();
        let mut all_words_masks = Vec::new();
        let mut all_text_lengths = Vec::new();
        let mut seq_lens = Vec::new();

        for words in &text_words {
            if words.is_empty() {
                // NOTE(review): empty texts push a 0 into `seq_lens` but add
                // no row to the input/mask vectors, so the model batch has
                // fewer rows than `texts`. `decode_ner_batch_output` receives
                // the full `texts`/`text_words` — confirm it realigns output
                // rows with the skipped entries.
                seq_lens.push(0);
                continue;
            }

            let (input_ids, attention_mask, words_mask) =
                self.encode_ner_prompt(words, default_types)?;
            seq_lens.push(input_ids.len());
            all_input_ids.push(input_ids);
            all_attention_masks.push(attention_mask);
            all_words_masks.push(words_mask);
            all_text_lengths.push(words.len() as i64);
        }

        if seq_lens.iter().all(|&l| l == 0) {
            // Every text was skipped above; avoid running an empty batch.
            return Ok(texts.iter().map(|_| Vec::new()).collect());
        }

        let max_seq_len = seq_lens.iter().copied().max().unwrap_or(0);

        // Right-pad every row with zeros to the batch's max sequence length.
        for i in 0..all_input_ids.len() {
            let pad_len = max_seq_len - all_input_ids[i].len();
            all_input_ids[i].extend(std::iter::repeat_n(0i64, pad_len));
            all_attention_masks[i].extend(std::iter::repeat_n(0i64, pad_len));
            all_words_masks[i].extend(std::iter::repeat_n(0i64, pad_len));
        }

        use ndarray::Array2;

        let batch_size = all_input_ids.len();

        // Flatten row-major for (batch_size, max_seq_len) arrays.
        let input_ids_flat: Vec<i64> = all_input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<i64> = all_attention_masks.into_iter().flatten().collect();
        let words_mask_flat: Vec<i64> = all_words_masks.into_iter().flatten().collect();

        // Sanity check: padding must have produced a full rectangular batch.
        let expected_input_len = batch_size * max_seq_len;
        if input_ids_flat.len() != expected_input_len {
            return Err(Error::Parse(format!(
                "Input IDs length mismatch: expected {}, got {}",
                expected_input_len,
                input_ids_flat.len()
            )));
        }

        let input_ids_arr = Array2::from_shape_vec((batch_size, max_seq_len), input_ids_flat)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let attention_mask_arr =
            Array2::from_shape_vec((batch_size, max_seq_len), attention_mask_flat)
                .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let words_mask_arr = Array2::from_shape_vec((batch_size, max_seq_len), words_mask_flat)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let text_lengths_arr = Array2::from_shape_vec((batch_size, 1), all_text_lengths)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;

        // Convert ndarray buffers into ort tensors via the compat shim.
        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;

        // Session access is serialized behind the crate's lock helper.
        let mut session = lock(&self.session);

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_t.into_dyn(),
                "attention_mask" => attention_mask_t.into_dyn(),
                "words_mask" => words_mask_t.into_dyn(),
                "text_lengths" => text_lengths_t.into_dyn(),
            ])
            .map_err(|e| Error::Inference(format!("ONNX batch run: {}", e)))?;

        // Decode with a fixed 0.5 confidence threshold for batch mode.
        self.decode_ner_batch_output(&outputs, texts, &text_words, default_types, 0.5)
    }

    fn optimal_batch_size(&self) -> Option<usize> {
        Some(16)
    }
}
613
/// Batch extraction for the Candle backend. Texts are processed one at a
/// time; the shared label encodings are computed once up front so each
/// per-text call can reuse them.
#[cfg(feature = "candle")]
impl crate::BatchCapable for GLiNER2Candle {
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Fixed label set used for the whole batch.
        let default_types: Vec<String> = ["person", "organization", "location", "date", "event"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        // Pre-encode the labels once; the result is intentionally discarded
        // (presumably this populates the label cache consumed by the
        // per-text calls below — verify against encode_labels_cached).
        let label_refs: Vec<&str> = default_types.iter().map(String::as_str).collect();
        let _ = self.encode_labels_cached(&label_refs)?;

        // Fallible collect short-circuits on the first extraction error,
        // matching the behavior of a `?` inside a loop.
        texts
            .iter()
            .map(|text| self.extract_entities(text, &default_types, 0.5))
            .collect()
    }

    fn optimal_batch_size(&self) -> Option<usize> {
        Some(8)
    }
}
652
/// Streaming support for the ONNX backend.
#[cfg(feature = "onnx")]
impl crate::StreamingCapable for GLiNER2Onnx {
    /// Chunk size hint consumed by streaming drivers.
    fn recommended_chunk_size(&self) -> usize {
        4096
    }
}
665
/// Streaming support for the Candle backend.
#[cfg(feature = "candle")]
impl crate::StreamingCapable for GLiNER2Candle {
    // Chunk size hint consumed by streaming drivers; matches the ONNX backend.
    fn recommended_chunk_size(&self) -> usize {
        4096
    }
}
674
/// GPU reporting for the Candle backend, derived from the model's
/// `candle_core::Device`.
#[cfg(feature = "candle")]
impl crate::GpuCapable for GLiNER2Candle {
    fn is_gpu_active(&self) -> bool {
        // Metal and CUDA both count as GPU execution; CPU does not.
        matches!(&self.device, Device::Metal(_) | Device::Cuda(_))
    }

    fn device(&self) -> &str {
        // Stable lowercase names for the three candle device kinds.
        match &self.device {
            Device::Cpu => "cpu",
            Device::Metal(_) => "metal",
            Device::Cuda(_) => "cuda",
        }
    }
}
693
#[cfg(test)]
mod tests {
    use super::*;

    // The relation heuristic must be safe on multi-byte UTF-8 text and match
    // trigger words regardless of case.
    #[test]
    #[cfg(any(feature = "onnx", feature = "candle"))]
    fn test_relation_heuristic_unicode_safe_and_case_insensitive() {
        use crate::backends::inference::RelationTriple;
        use crate::offset::bytes_to_chars;

        let text = "Dr. 田中 is CEO of Apple Inc. in 東京. François works at OpenAI.";
        // Helper: character span of the first occurrence of `needle` in `text`.
        let span = |needle: &str| {
            let (b_start, _) = text
                .match_indices(needle)
                .next()
                .expect("needle should exist in test text");
            let b_end = b_start + needle.len();
            bytes_to_chars(text, b_start, b_end)
        };

        let (s, e) = span("田中");
        let e_tanaka = Entity::new("田中", EntityType::Person, s, e, 0.9);
        let (s, e) = span("Apple Inc.");
        let e_apple = Entity::new("Apple Inc.", EntityType::Organization, s, e, 0.9);
        let (s, e) = span("東京");
        let e_tokyo = Entity::new("東京", EntityType::Location, s, e, 0.9);
        let (s, e) = span("François");
        let e_francois = Entity::new("François", EntityType::Person, s, e, 0.9);
        let (s, e) = span("OpenAI");
        let e_openai = Entity::new("OpenAI", EntityType::Organization, s, e, 0.9);

        let entities = vec![e_tanaka, e_apple, e_tokyo, e_francois, e_openai];

        // Empty relation-type list and zero threshold: rely purely on the
        // heuristic's built-in triggers ("CEO of", "works at").
        let rels: Vec<RelationTriple> =
            relations::extract_relations_heuristic(&entities, text, &[], 0.0);
        assert!(
            rels.iter()
                .any(|r| r.relation_type == "CEO_OF" || r.relation_type == "WORKS_FOR"),
            "expected at least one trigger-based relation, got {:?}",
            rels
        );
    }

    // Builder API: entity + classification tasks accumulate on the schema.
    #[test]
    fn test_task_schema_builder() {
        let schema = TaskSchema::new()
            .with_entities(&["person", "organization"])
            .with_classification("sentiment", &["positive", "negative"], false);

        assert!(schema.entities.is_some());
        assert_eq!(schema.entities.as_ref().unwrap().types.len(), 2);
        assert_eq!(schema.classifications.len(), 1);
    }

    // Builder API: plain, described, and choice fields all land in `fields`.
    #[test]
    fn test_structure_task_builder() {
        let task = StructureTask::new("product")
            .with_field("name", FieldType::String)
            .with_field_described("price", FieldType::String, "Product price in USD")
            .with_choice_field("category", &["electronics", "clothing"]);

        assert_eq!(task.fields.len(), 3);
        assert_eq!(task.fields[2].choices.as_ref().unwrap().len(), 2);
    }

    // Round-trip: word-index spans map back to the exact source substrings.
    #[test]
    fn test_word_span_to_char_offsets() {
        use crate::offset::TextSpan;

        let text = "John works at Apple";
        let words: Vec<&str> = text.split_whitespace().collect();

        // Single-word span at the start.
        let (start, end) = word_span_to_char_offsets(text, &words, 0, 0);
        assert_eq!(TextSpan::from_chars(text, start, end).extract(text), "John");

        // Single-word span at the end.
        let (start, end) = word_span_to_char_offsets(text, &words, 3, 3);
        assert_eq!(
            TextSpan::from_chars(text, start, end).extract(text),
            "Apple"
        );

        // Multi-word span.
        let (start, end) = word_span_to_char_offsets(text, &words, 0, 2);
        assert_eq!(
            TextSpan::from_chars(text, start, end).extract(text),
            "John works at"
        );
    }

    // Canonical mapping: known labels (any case) map to builtin variants,
    // "product"/"event" to Custom, and unknown labels to uppercased Other.
    #[test]
    fn test_map_entity_type() {
        assert!(matches!(map_entity_type("person"), EntityType::Person));
        assert!(matches!(
            map_entity_type("ORGANIZATION"),
            EntityType::Organization
        ));
        assert!(matches!(map_entity_type("loc"), EntityType::Location));
        assert!(
            matches!(map_entity_type("custom_type"), EntityType::Other(ref s) if s == "CUSTOM_TYPE")
        );
        assert!(matches!(
            map_entity_type("product"),
            EntityType::Custom { .. }
        ));
        assert!(matches!(
            map_entity_type("event"),
            EntityType::Custom { .. }
        ));
    }
}
809}