1#[cfg(feature = "onnx")]
52use crate::sync::{lock, Mutex};
53use crate::{Entity, EntityType, Error, Result};
54use anno_core::EntityCategory;
55use serde::{Deserialize, Serialize};
56use std::collections::HashMap;
57#[cfg(feature = "candle")]
58use std::sync::RwLock;
59
60use crate::backends::inference::{ExtractionWithRelations, RelationExtractor, ZeroShotNER};
62
// Special tokenizer ids used when assembling GLiNER2 prompts.
// NOTE(review): values presumably correspond to the model's added special
// tokens (entity-marker / separator) — confirm against the tokenizer config.
#[cfg(feature = "onnx")]
const TOKEN_ENT: u32 = 128002;
#[cfg(feature = "onnx")]
const TOKEN_SEP: u32 = 128003;
// Sequence start/end marker ids (BOS/EOS-style).
#[cfg(feature = "onnx")]
const TOKEN_START: u32 = 1;
#[cfg(feature = "onnx")]
const TOKEN_END: u32 = 2;

// Maximum entity-span width (in words) considered by span-level decoding.
const MAX_SPAN_WIDTH: usize = 12;
// Upper bound used by the candle count-predictor head.
// NOTE(review): its use is outside this chunk — verify against CountPredictor.
#[cfg(feature = "candle")]
const MAX_COUNT: usize = 20;
87
/// Cache mapping label names to embedding vectors so repeated labels are not
/// re-encoded on every call.
///
/// Only functional with the `candle` feature; otherwise it is a zero-sized
/// placeholder so the field can exist unconditionally.
#[derive(Debug, Default)]
pub struct LabelCache {
    // Interior mutability behind an RwLock: reads dominate writes.
    #[cfg(feature = "candle")]
    cache: RwLock<HashMap<String, Vec<f32>>>,
    #[cfg(not(feature = "candle"))]
    _phantom: std::marker::PhantomData<()>,
}
100
#[cfg(feature = "candle")]
impl LabelCache {
    /// Creates an empty embedding cache.
    fn new() -> Self {
        Self {
            cache: RwLock::new(HashMap::new()),
        }
    }

    /// Returns a clone of the embedding cached under `label`, if any.
    /// A poisoned lock is treated as a cache miss.
    fn get(&self, label: &str) -> Option<Vec<f32>> {
        let guard = self.cache.read().ok()?;
        guard.get(label).map(|emb| emb.to_vec())
    }

    /// Stores `embedding` under `label`. A poisoned lock makes this a silent
    /// no-op — the cache is purely an optimization.
    fn insert(&self, label: String, embedding: Vec<f32>) {
        if let Ok(mut guard) = self.cache.write() {
            guard.insert(label, embedding);
        }
    }
}
119
#[cfg(not(feature = "candle"))]
impl LabelCache {
    /// Creates the no-op placeholder cache used when `candle` is disabled.
    #[allow(dead_code)]
    fn new() -> Self {
        Self {
            _phantom: std::marker::PhantomData,
        }
    }
}
129
/// Declarative description of what to extract from a text in one pass:
/// optional entity types, zero or more classification tasks, and zero or
/// more structured-record tasks.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TaskSchema {
    /// Entity-extraction task, if requested.
    pub entities: Option<EntityTask>,
    /// Text-classification tasks to run.
    pub classifications: Vec<ClassificationTask>,
    /// Structured-record extraction tasks to run.
    pub structures: Vec<StructureTask>,
}
157
158impl TaskSchema {
159 pub fn new() -> Self {
161 Self::default()
162 }
163
164 pub fn with_entities(mut self, types: &[&str]) -> Self {
166 self.entities = Some(EntityTask {
167 types: types.iter().map(|s| s.to_string()).collect(),
168 descriptions: HashMap::new(),
169 });
170 self
171 }
172
173 pub fn with_entities_described(mut self, types_with_desc: HashMap<String, String>) -> Self {
175 let types: Vec<String> = types_with_desc.keys().cloned().collect();
176 self.entities = Some(EntityTask {
177 types,
178 descriptions: types_with_desc,
179 });
180 self
181 }
182
183 pub fn with_classification(mut self, name: &str, labels: &[&str], multi_label: bool) -> Self {
185 self.classifications.push(ClassificationTask {
186 name: name.to_string(),
187 labels: labels.iter().map(|s| s.to_string()).collect(),
188 multi_label,
189 descriptions: HashMap::new(),
190 });
191 self
192 }
193
194 pub fn with_structure(mut self, task: StructureTask) -> Self {
196 self.structures.push(task);
197 self
198 }
199}
200
/// Entity-extraction task: the entity type names to look for, plus optional
/// per-type descriptions.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EntityTask {
    /// Entity type names (e.g. "person", "location").
    pub types: Vec<String>,
    /// Optional natural-language description keyed by type name.
    pub descriptions: HashMap<String, String>,
}
209
/// Text-classification task over a fixed label set.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClassificationTask {
    /// Task name; used as the key in `ExtractionResult::classifications`.
    pub name: String,
    /// Candidate labels.
    pub labels: Vec<String>,
    /// When true, every label scoring above threshold is selected;
    /// when false, only the single best label is selected.
    pub multi_label: bool,
    /// Optional natural-language description keyed by label.
    pub descriptions: HashMap<String, String>,
}
222
/// Structured-record extraction task: a named record with typed fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct StructureTask {
    /// Task name; also used as the emitted structure's type.
    pub name: String,
    /// Not serialized; mirrors `name` when built via `StructureTask::new`.
    #[serde(skip)]
    pub structure_type: String,
    /// Fields to fill in each extracted record.
    pub fields: Vec<StructureField>,
}
234
235impl StructureTask {
236 pub fn new(name: &str) -> Self {
238 Self {
239 name: name.to_string(),
240 structure_type: name.to_string(),
241 fields: Vec::new(),
242 }
243 }
244
245 pub fn with_field(mut self, name: &str, field_type: FieldType) -> Self {
247 self.fields.push(StructureField {
248 name: name.to_string(),
249 field_type,
250 description: None,
251 choices: None,
252 });
253 self
254 }
255
256 pub fn with_field_described(
258 mut self,
259 name: &str,
260 field_type: FieldType,
261 description: &str,
262 ) -> Self {
263 self.fields.push(StructureField {
264 name: name.to_string(),
265 field_type,
266 description: Some(description.to_string()),
267 choices: None,
268 });
269 self
270 }
271
272 pub fn with_choice_field(mut self, name: &str, choices: &[&str]) -> Self {
274 self.fields.push(StructureField {
275 name: name.to_string(),
276 field_type: FieldType::Choice,
277 description: None,
278 choices: Some(choices.iter().map(|s| s.to_string()).collect()),
279 });
280 self
281 }
282}
283
/// One field of a `StructureTask`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StructureField {
    /// Field name (matched case-insensitively against predicted labels).
    pub name: String,
    /// Shape of the extracted value.
    pub field_type: FieldType,
    /// Optional natural-language description of the field.
    pub description: Option<String>,
    /// Allowed values when `field_type` is `Choice`.
    pub choices: Option<Vec<String>>,
}
296
/// Shape of a structure field's value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FieldType {
    /// Single free-text value.
    String,
    /// Multiple values collected into a list.
    List,
    /// One value from a fixed set of choices.
    Choice,
}
307
/// Combined output of one `extract` call over a `TaskSchema`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtractionResult {
    /// Entities found by the entity task (empty if none requested).
    pub entities: Vec<Entity>,
    /// Classification results keyed by task name.
    pub classifications: HashMap<String, ClassificationResult>,
    /// Extracted structured records.
    pub structures: Vec<ExtractedStructure>,
}
322
/// Result of a single classification task.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ClassificationResult {
    /// Selected labels (one for single-label tasks, any number above the
    /// selection threshold for multi-label tasks).
    pub labels: Vec<String>,
    /// Probability score per candidate label.
    pub scores: HashMap<String, f32>,
}
331
/// One extracted structured record.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ExtractedStructure {
    /// Type name of the record (the task's name).
    pub structure_type: String,
    /// Extracted values keyed by field name; fields with no match are absent.
    pub fields: HashMap<String, StructureValue>,
}
340
/// Value of one structure field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StructureValue {
    /// Single value (`FieldType::String` / `FieldType::Choice`).
    Single(String),
    /// Multiple values (`FieldType::List`).
    List(Vec<String>),
}
349
/// GLiNER2 backend running an ONNX export through `ort`.
#[cfg(feature = "onnx")]
#[derive(Debug)]
pub struct GLiNER2Onnx {
    // Mutex because running the session requires exclusive (mutable) access.
    session: Mutex<ort::session::Session>,
    tokenizer: tokenizers::Tokenizer,
    #[allow(dead_code)]
    model_name: String,
    #[allow(dead_code)]
    hidden_size: usize,
}
366
367#[cfg(feature = "onnx")]
368impl GLiNER2Onnx {
    /// Downloads the model, tokenizer, and config from the Hugging Face Hub
    /// and builds a CPU-only ONNX session.
    ///
    /// `model_id` is a HF repo id (e.g. "org/model"). Returns
    /// `Error::Retrieval` on download/session failures and `Error::Parse`
    /// when config.json is malformed.
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        use hf_hub::api::sync::Api;
        use ort::execution_providers::CPUExecutionProvider;
        use ort::session::Session;

        let api = Api::new().map_err(|e| Error::Retrieval(format!("HF API: {}", e)))?;
        let repo = api.model(model_id.to_string());

        // Prefer the conventional `onnx/` subfolder; fall back to repo root.
        let model_path = repo
            .get("onnx/model.onnx")
            .or_else(|_| repo.get("model.onnx"))
            .map_err(|e| Error::Retrieval(format!("model.onnx: {}", e)))?;

        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| Error::Retrieval(format!("tokenizer.json: {}", e)))?;

        let config_path = repo
            .get("config.json")
            .map_err(|e| Error::Retrieval(format!("config.json: {}", e)))?;

        let tokenizer = tokenizers::Tokenizer::from_file(&tokenizer_path)
            .map_err(|e| Error::Retrieval(format!("tokenizer: {}", e)))?;

        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| Error::Retrieval(format!("config read: {}", e)))?;
        let config: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| Error::Parse(format!("config parse: {}", e)))?;
        // Default to 768 when the config omits `hidden_size`.
        let hidden_size = config["hidden_size"].as_u64().unwrap_or(768) as usize;

        let session = Session::builder()
            .map_err(|e| Error::Retrieval(format!("ONNX builder: {}", e)))?
            .with_execution_providers([CPUExecutionProvider::default().build()])
            .map_err(|e| Error::Retrieval(format!("ONNX providers: {}", e)))?
            .commit_from_file(&model_path)
            .map_err(|e| Error::Retrieval(format!("ONNX load: {}", e)))?;

        log::info!(
            "[GLiNER2-ONNX] Loaded {} (hidden={})",
            model_id,
            hidden_size
        );
        log::debug!("[GLiNER2-ONNX] Model loaded");

        Ok(Self {
            session: Mutex::new(session),
            tokenizer,
            model_name: model_id.to_string(),
            hidden_size,
        })
    }
425
426 pub fn extract(&self, text: &str, schema: &TaskSchema) -> Result<ExtractionResult> {
428 let mut result = ExtractionResult::default();
429
430 if let Some(ref ent_task) = schema.entities {
432 let labels: Vec<&str> = ent_task.types.iter().map(|s| s.as_str()).collect();
433 let entities = self.extract_ner(text, &labels, 0.5)?;
434 result.entities = entities;
435 }
436
437 for class_task in &schema.classifications {
439 let labels: Vec<&str> = class_task.labels.iter().map(|s| s.as_str()).collect();
440 let class_result = self.classify(text, &labels, class_task.multi_label)?;
441 result
442 .classifications
443 .insert(class_task.name.clone(), class_result);
444 }
445
446 for struct_task in &schema.structures {
448 let structures = self.extract_structure(text, struct_task)?;
449 result.structures.extend(structures);
450 }
451
452 Ok(result)
453 }
454
    /// Runs zero-shot NER over `text` for the candidate `entity_types`,
    /// keeping predictions whose score reaches `threshold`.
    ///
    /// Returns an empty vec for empty text or an empty type list.
    fn extract_ner(
        &self,
        text: &str,
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        if text.is_empty() || entity_types.is_empty() {
            return Ok(Vec::new());
        }

        // Whitespace word split: the model addresses words via `words_mask`,
        // and char offsets are reconstructed from these words when decoding.
        let text_words: Vec<&str> = text.split_whitespace().collect();
        if text_words.is_empty() {
            return Ok(Vec::new());
        }

        let (input_ids, attention_mask, words_mask) =
            self.encode_ner_prompt(&text_words, entity_types)?;

        use ndarray::Array2;

        let batch_size = 1;
        let seq_len = input_ids.len();

        // Build the four model inputs as [1, seq_len] / [1, 1] tensors.
        let input_ids_arr = Array2::from_shape_vec((batch_size, seq_len), input_ids)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let attention_mask_arr = Array2::from_shape_vec((batch_size, seq_len), attention_mask)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let words_mask_arr = Array2::from_shape_vec((batch_size, seq_len), words_mask)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let text_lengths_arr =
            Array2::from_shape_vec((batch_size, 1), vec![text_words.len() as i64])
                .map_err(|e| Error::Parse(format!("Array: {}", e)))?;

        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;

        // Exclusive access required to run the session.
        let mut session = lock(&self.session);

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_t.into_dyn(),
                "attention_mask" => attention_mask_t.into_dyn(),
                "words_mask" => words_mask_t.into_dyn(),
                "text_lengths" => text_lengths_t.into_dyn(),
            ])
            .map_err(|e| Error::Inference(format!("ONNX run: {}", e)))?;

        self.decode_ner_output(&outputs, text, &text_words, entity_types, threshold)
    }
518
519 fn encode_ner_prompt(
521 &self,
522 text_words: &[&str],
523 entity_types: &[&str],
524 ) -> Result<(Vec<i64>, Vec<i64>, Vec<i64>)> {
525 let mut input_ids: Vec<i64> = Vec::new();
526 let mut word_mask: Vec<i64> = Vec::new();
527
528 input_ids.push(TOKEN_START as i64);
530 word_mask.push(0);
531
532 for entity_type in entity_types {
535 input_ids.push(TOKEN_ENT as i64);
536 word_mask.push(0);
537
538 let type_enc = self
539 .tokenizer
540 .encode(*entity_type, false)
541 .map_err(|e| Error::Parse(format!("Tokenize: {}", e)))?;
542 for token_id in type_enc.get_ids() {
543 input_ids.push(*token_id as i64);
544 word_mask.push(0);
545 }
546 }
547
548 input_ids.push(TOKEN_SEP as i64);
550 word_mask.push(0);
551
552 for (word_idx, word) in text_words.iter().enumerate() {
554 let word_enc = self
555 .tokenizer
556 .encode(*word, false)
557 .map_err(|e| Error::Parse(format!("Tokenize: {}", e)))?;
558
559 let word_id = (word_idx + 1) as i64; for (token_idx, token_id) in word_enc.get_ids().iter().enumerate() {
561 input_ids.push(*token_id as i64);
562 word_mask.push(if token_idx == 0 { word_id } else { 0 });
564 }
565 }
566
567 input_ids.push(TOKEN_END as i64);
569 word_mask.push(0);
570
571 let seq_len = input_ids.len();
572 let attention_mask: Vec<i64> = vec![1; seq_len];
573
574 Ok((input_ids, attention_mask, word_mask))
575 }
576
577 #[allow(dead_code)]
580 fn make_span_tensors(&self, num_words: usize) -> (Vec<i64>, Vec<bool>) {
581 let num_spans = num_words.checked_mul(MAX_SPAN_WIDTH).unwrap_or_else(|| {
583 log::warn!(
584 "Span count overflow: {} words * {} MAX_SPAN_WIDTH, using max",
585 num_words,
586 MAX_SPAN_WIDTH
587 );
588 usize::MAX
589 });
590 let span_idx_len = num_spans.checked_mul(2).unwrap_or_else(|| {
592 log::warn!(
593 "Span idx length overflow: {} spans * 2, using max",
594 num_spans
595 );
596 usize::MAX
597 });
598 let mut span_idx: Vec<i64> = vec![0; span_idx_len];
599 let mut span_mask: Vec<bool> = vec![false; num_spans];
600
601 for start in 0..num_words {
602 let remaining = num_words - start;
603 let max_width = MAX_SPAN_WIDTH.min(remaining);
604
605 for width in 0..max_width {
606 let dim = match start.checked_mul(MAX_SPAN_WIDTH) {
608 Some(v) => match v.checked_add(width) {
609 Some(d) => d,
610 None => {
611 log::warn!(
612 "Dim calculation overflow: {} * {} + {}, skipping span",
613 start,
614 MAX_SPAN_WIDTH,
615 width
616 );
617 continue;
618 }
619 },
620 None => {
621 log::warn!(
622 "Dim calculation overflow: {} * {}, skipping span",
623 start,
624 MAX_SPAN_WIDTH
625 );
626 continue;
627 }
628 };
629 if let Some(dim2) = dim.checked_mul(2) {
631 if dim2 + 1 < span_idx_len && dim < num_spans {
632 span_idx[dim2] = start as i64;
633 span_idx[dim2 + 1] = (start + width) as i64;
634 span_mask[dim] = true;
635 } else {
636 log::warn!(
637 "Span idx access out of bounds: dim={}, dim*2={}, span_idx_len={}, num_spans={}, skipping",
638 dim, dim2, span_idx_len, num_spans
639 );
640 }
641 } else {
642 log::warn!("Dim * 2 overflow: dim={}, skipping span", dim);
643 }
644 }
645 }
646
647 (span_idx, span_mask)
648 }
649
    /// Decodes the first output tensor of an NER run into `Entity` spans.
    ///
    /// Two export layouts are supported:
    /// - `[3, 1, words, classes]`: word-level BIO logits (planes B, I, O);
    /// - `[1, words, width, classes]`: span-level logits per (start, width).
    ///
    /// Scores are `sigmoid(logit)`; predictions below `threshold` are
    /// dropped. Duplicate (start, end) spans are removed after sorting.
    fn decode_ner_output(
        &self,
        outputs: &ort::session::SessionOutputs,
        text: &str,
        text_words: &[&str],
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        // Take the first output by iteration order.
        let output = outputs
            .iter()
            .next()
            .map(|(_, v)| v)
            .ok_or_else(|| Error::Parse("No output".into()))?;

        let (_, data_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Extract tensor: {}", e)))?;
        let output_data: Vec<f32> = data_slice.to_vec();

        let shape: Vec<i64> = match output.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
            _ => return Err(Error::Parse("Not a tensor".into())),
        };

        if output_data.is_empty() || shape.contains(&0) {
            return Err(Error::Inference(
                "GLiNER2 ONNX returned empty/degenerate output tensor. This usually indicates an incompatible ONNX export (shape mismatch or missing dynamic axes).".to_string(),
            ));
        }

        let mut entities = Vec::new();
        let num_words = text_words.len();

        // Layout 1: [3, 1, words, classes] — word-level BIO decoding.
        if shape.len() == 4 && shape[0] == 3 && shape[1] == 1 {
            let out_num_words = shape[2] as usize;
            let num_classes = shape[3] as usize;
            let word_class_size = out_num_words * num_classes;

            // Flat-buffer offsets of the B and I planes (plane order is
            // B, I, O; the O plane is never read).
            let b_offset = 0_usize;
            let i_offset = word_class_size;
            #[allow(clippy::needless_range_loop)]
            for class_idx in 0..num_classes.min(entity_types.len()) {
                // Open span: (start word index, running average score).
                let mut current_start: Option<(usize, f32)> = None;
                for word_idx in 0..out_num_words.min(num_words) {
                    let b_idx = b_offset + word_idx * num_classes + class_idx;
                    let i_idx = i_offset + word_idx * num_classes + class_idx;

                    // Out-of-range reads degrade to a large negative logit
                    // (sigmoid ~ 0) instead of panicking.
                    let b_logit = if b_idx < output_data.len() {
                        output_data[b_idx]
                    } else {
                        -100.0
                    };
                    let i_logit = if i_idx < output_data.len() {
                        output_data[i_idx]
                    } else {
                        -100.0
                    };

                    let b_score = 1.0 / (1.0 + (-b_logit).exp());
                    let i_score = 1.0 / (1.0 + (-i_logit).exp());

                    if b_score >= threshold {
                        // A new span begins here: flush the open one first.
                        if let Some((start_word, avg_score)) = current_start.take() {
                            let end_word = word_idx - 1;
                            if start_word <= end_word && end_word < num_words {
                                let span_text = text_words[start_word..=end_word].join(" ");
                                let (start, end) = word_span_to_char_offsets(
                                    text, text_words, start_word, end_word,
                                );
                                let entity_type = map_entity_type(entity_types[class_idx]);
                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    start,
                                    end,
                                    avg_score as f64,
                                ));
                            }
                        }
                        current_start = Some((word_idx, b_score));
                    } else if i_score >= threshold && current_start.is_some() {
                        // Inside continuation: fold the I score into the
                        // running average.
                        if let Some((start_word, score)) = current_start {
                            current_start = Some((start_word, (score + i_score) / 2.0));
                        }
                    } else if current_start.is_some() {
                        // Neither B nor I fired: the open span ends before
                        // this word.
                        if let Some((start_word, avg_score)) = current_start.take() {
                            let end_word = word_idx - 1;
                            if start_word <= end_word && end_word < num_words {
                                let span_text = text_words[start_word..=end_word].join(" ");
                                let (start, end) = word_span_to_char_offsets(
                                    text, text_words, start_word, end_word,
                                );
                                let entity_type = map_entity_type(entity_types[class_idx]);
                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    start,
                                    end,
                                    avg_score as f64,
                                ));
                            }
                        }
                    }
                }

                // Flush a span that is still open at the end of the text.
                if let Some((start_word, avg_score)) = current_start.take() {
                    let end_word = out_num_words.min(num_words) - 1;
                    if start_word <= end_word {
                        let span_text = text_words[start_word..=end_word].join(" ");
                        let (start, end) =
                            word_span_to_char_offsets(text, text_words, start_word, end_word);
                        let entity_type = map_entity_type(entity_types[class_idx]);
                        entities.push(Entity::new(
                            span_text,
                            entity_type,
                            start,
                            end,
                            avg_score as f64,
                        ));
                    }
                }
            }
        }
        // Layout 2: [1, words, width, classes] — span-level decoding.
        else if shape.len() == 4 && shape[0] == 1 {
            let out_num_words = shape[1] as usize;
            let out_max_width = shape[2] as usize;
            let num_classes = shape[3] as usize;

            for word_idx in 0..out_num_words.min(num_words) {
                for width in 0..out_max_width.min(MAX_SPAN_WIDTH) {
                    let end_word = word_idx + width;
                    if end_word >= num_words {
                        continue;
                    }

                    #[allow(clippy::needless_range_loop)]
                    for class_idx in 0..num_classes.min(entity_types.len()) {
                        let idx = (word_idx * out_max_width * num_classes)
                            + (width * num_classes)
                            + class_idx;

                        if idx < output_data.len() {
                            let logit = output_data[idx];
                            let score = 1.0 / (1.0 + (-logit).exp());

                            if score >= threshold {
                                let span_text = text_words[word_idx..=end_word].join(" ");
                                let (start, end) =
                                    word_span_to_char_offsets(text, text_words, word_idx, end_word);

                                let entity_type = map_entity_type(entity_types[class_idx]);

                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    start,
                                    end,
                                    score as f64,
                                ));
                            }
                        }
                    }
                }
            }
        }

        // Sort by start ascending (longer span first on ties), then drop
        // exact duplicate (start, end) spans, keeping the first occurrence.
        entities.sort_by(|a, b| a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)));
        entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);

        Ok(entities)
    }
838
    /// Batch variant of `decode_ner_output`: decodes entities for each text
    /// in the batch.
    ///
    /// Supported layouts:
    /// - `[3, batch, words, classes]`: word-level BIO logits;
    /// - `[batch, words, width, classes]`: span-level logits.
    /// Any other layout is rejected. If the model emitted fewer rows than
    /// texts, the missing rows yield empty entity lists.
    fn decode_ner_batch_output(
        &self,
        outputs: &ort::session::SessionOutputs,
        texts: &[&str],
        text_words_batch: &[Vec<&str>],
        entity_types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Vec<Entity>>> {
        let output = outputs
            .iter()
            .next()
            .map(|(_, v)| v)
            .ok_or_else(|| Error::Parse("No output".into()))?;

        let (_, data_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Extract tensor: {}", e)))?;
        let output_data: Vec<f32> = data_slice.to_vec();

        let shape: Vec<i64> = match output.dtype() {
            ort::value::ValueType::Tensor { shape, .. } => shape.iter().copied().collect(),
            _ => return Err(Error::Parse("Not a tensor".into())),
        };

        if output_data.is_empty() || shape.contains(&0) {
            return Err(Error::Inference(
                "GLiNER2 ONNX returned empty/degenerate output tensor. This usually indicates an incompatible ONNX export (shape mismatch or missing dynamic axes).".to_string(),
            ));
        }

        let mut results = Vec::with_capacity(texts.len());

        // Layout 1: [3, batch, words, classes] — BIO decoding per text.
        if shape.len() == 4 && shape[0] == 3 {
            let batch_size = shape[1] as usize;
            let out_num_words = shape[2] as usize;
            let num_classes = shape[3] as usize;

            // per_bio: stride between the B and I planes; per_batch: stride
            // of one text's slice within a plane.
            let per_bio = batch_size * out_num_words * num_classes;
            let per_batch = out_num_words * num_classes;

            for batch_idx in 0..batch_size.min(texts.len()) {
                let text = texts[batch_idx];
                let text_words = &text_words_batch[batch_idx];
                let num_words = text_words.len();
                let mut entities = Vec::new();

                #[allow(clippy::needless_range_loop)]
                for class_idx in 0..num_classes.min(entity_types.len()) {
                    // Open span: (start word index, running average score).
                    let mut current_start: Option<(usize, f32)> = None;
                    for word_idx in 0..out_num_words.min(num_words) {
                        let b_idx = (batch_idx * per_batch) + (word_idx * num_classes) + class_idx;
                        let i_idx = per_bio
                            + (batch_idx * per_batch)
                            + (word_idx * num_classes)
                            + class_idx;

                        // Out-of-range reads degrade to a large negative
                        // logit (sigmoid ~ 0) instead of panicking.
                        let b_logit = output_data.get(b_idx).copied().unwrap_or(-100.0);
                        let i_logit = output_data.get(i_idx).copied().unwrap_or(-100.0);

                        let b_score = 1.0 / (1.0 + (-b_logit).exp());
                        let i_score = 1.0 / (1.0 + (-i_logit).exp());

                        if b_score >= threshold {
                            // A new span begins: flush any open span first.
                            if let Some((start_word, avg_score)) = current_start.take() {
                                let end_word = word_idx.saturating_sub(1);
                                if start_word <= end_word && end_word < num_words {
                                    let span_text = text_words[start_word..=end_word].join(" ");
                                    let (start, end) = word_span_to_char_offsets(
                                        text, text_words, start_word, end_word,
                                    );
                                    let entity_type = map_entity_type(entity_types[class_idx]);
                                    entities.push(Entity::new(
                                        span_text,
                                        entity_type,
                                        start,
                                        end,
                                        avg_score as f64,
                                    ));
                                }
                            }
                            current_start = Some((word_idx, b_score));
                        } else if i_score >= threshold && current_start.is_some() {
                            // Continuation: fold the I score into the
                            // running average.
                            if let Some((start_word, score)) = current_start {
                                current_start = Some((start_word, (score + i_score) / 2.0));
                            }
                        } else if current_start.is_some() {
                            // Neither B nor I fired: close the open span.
                            if let Some((start_word, avg_score)) = current_start.take() {
                                let end_word = word_idx.saturating_sub(1);
                                if start_word <= end_word && end_word < num_words {
                                    let span_text = text_words[start_word..=end_word].join(" ");
                                    let (start, end) = word_span_to_char_offsets(
                                        text, text_words, start_word, end_word,
                                    );
                                    let entity_type = map_entity_type(entity_types[class_idx]);
                                    entities.push(Entity::new(
                                        span_text,
                                        entity_type,
                                        start,
                                        end,
                                        avg_score as f64,
                                    ));
                                }
                            }
                        }
                    }

                    // Flush a span still open at the end of the text.
                    if let Some((start_word, avg_score)) = current_start.take() {
                        if !text_words.is_empty() {
                            let end_word = out_num_words.min(num_words).saturating_sub(1);
                            if start_word <= end_word && end_word < num_words {
                                let span_text = text_words[start_word..=end_word].join(" ");
                                let (start, end) = word_span_to_char_offsets(
                                    text, text_words, start_word, end_word,
                                );
                                let entity_type = map_entity_type(entity_types[class_idx]);
                                entities.push(Entity::new(
                                    span_text,
                                    entity_type,
                                    start,
                                    end,
                                    avg_score as f64,
                                ));
                            }
                        }
                    }
                }

                // Sort by start ascending (longer span first on ties) and
                // drop exact duplicate spans.
                entities
                    .sort_unstable_by(|a, b| a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)));
                entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);
                results.push(entities);
            }
        }
        // Layout 2: [batch, words, width, classes] — span-level decoding.
        else if shape.len() == 4 {
            let batch_size = shape[0] as usize;
            let out_num_words = shape[1] as usize;
            let out_max_width = shape[2] as usize;
            let num_classes = shape[3] as usize;
            let stride_per_batch = out_num_words * out_max_width * num_classes;

            for batch_idx in 0..batch_size.min(texts.len()) {
                let text = texts[batch_idx];
                let text_words = &text_words_batch[batch_idx];
                let num_words = text_words.len();
                let batch_offset = batch_idx * stride_per_batch;
                let mut entities = Vec::new();

                for word_idx in 0..out_num_words.min(num_words) {
                    for width in 0..out_max_width.min(MAX_SPAN_WIDTH) {
                        let end_word = word_idx + width;
                        if end_word >= num_words {
                            continue;
                        }

                        #[allow(clippy::needless_range_loop)]
                        for class_idx in 0..num_classes.min(entity_types.len()) {
                            let idx = batch_offset
                                + (word_idx * out_max_width * num_classes)
                                + (width * num_classes)
                                + class_idx;

                            if idx < output_data.len() {
                                let logit = output_data[idx];
                                let score = 1.0 / (1.0 + (-logit).exp());

                                if score >= threshold {
                                    let span_text = text_words[word_idx..=end_word].join(" ");
                                    let (start, end) = word_span_to_char_offsets(
                                        text, text_words, word_idx, end_word,
                                    );

                                    let entity_type = map_entity_type(entity_types[class_idx]);

                                    entities.push(Entity::new(
                                        span_text,
                                        entity_type,
                                        start,
                                        end,
                                        score as f64,
                                    ));
                                }
                            }
                        }
                    }
                }

                entities
                    .sort_unstable_by(|a, b| a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)));
                entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);
                results.push(entities);
            }
        } else {
            return Err(Error::Inference(format!(
                "Unsupported GLiNER2 batch output shape: {:?}. Expected [3,batch,words,classes] (BIO) or [batch,words,width,classes] (span-level).",
                shape
            )));
        }

        // Pad so callers can zip results with their input texts.
        while results.len() < texts.len() {
            results.push(Vec::new());
        }

        Ok(results)
    }
1056
    /// Zero-shot text classification over `labels`.
    ///
    /// Builds a `[START] (ENT label-tokens)* [SEP] text-tokens [END]`
    /// prompt, runs the model, and reads the first output tensor as one
    /// logit per label: sigmoid scores with a 0.5 selection cutoff when
    /// `multi_label`, otherwise softmax plus argmax.
    fn classify(
        &self,
        text: &str,
        labels: &[&str],
        multi_label: bool,
    ) -> Result<ClassificationResult> {
        if text.is_empty() || labels.is_empty() {
            return Ok(ClassificationResult::default());
        }

        let mut input_ids: Vec<i64> = Vec::new();

        input_ids.push(TOKEN_START as i64);

        // Each candidate label is introduced by the ENT marker token.
        for label in labels {
            input_ids.push(TOKEN_ENT as i64);
            let label_enc = self
                .tokenizer
                .encode(*label, false)
                .map_err(|e| Error::Parse(format!("Tokenize: {}", e)))?;
            for id in label_enc.get_ids() {
                input_ids.push(*id as i64);
            }
        }

        input_ids.push(TOKEN_SEP as i64);

        let text_enc = self
            .tokenizer
            .encode(text, false)
            .map_err(|e| Error::Parse(format!("Tokenize: {}", e)))?;
        for id in text_enc.get_ids() {
            input_ids.push(*id as i64);
        }

        input_ids.push(TOKEN_END as i64);

        let seq_len = input_ids.len();
        let attention_mask: Vec<i64> = vec![1; seq_len];

        use ndarray::Array2;

        let input_arr = Array2::from_shape_vec((1, seq_len), input_ids)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let attn_arr = Array2::from_shape_vec((1, seq_len), attention_mask)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;

        let input_t = super::ort_compat::tensor_from_ndarray(input_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let attn_t = super::ort_compat::tensor_from_ndarray(attn_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;

        // Exclusive access required to run the session.
        let mut session = lock(&self.session);

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_t.into_dyn(),
                "attention_mask" => attn_t.into_dyn(),
            ])
            .map_err(|e| Error::Inference(format!("ONNX run: {}", e)))?;

        let output = outputs
            .iter()
            .next()
            .map(|(_, v)| v)
            .ok_or_else(|| Error::Parse("No output".into()))?;

        let (_, data_slice) = output
            .try_extract_tensor::<f32>()
            .map_err(|e| Error::Parse(format!("Extract: {}", e)))?;
        let logits: Vec<f32> = data_slice.to_vec();

        let probs = if multi_label {
            // Independent per-label probabilities.
            logits.iter().map(|&x| 1.0 / (1.0 + (-x).exp())).collect()
        } else {
            // Numerically stable softmax (max subtraction).
            let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_logits: Vec<f32> = logits.iter().map(|&x| (x - max_logit).exp()).collect();
            let sum: f32 = exp_logits.iter().sum();
            if sum > 0.0 {
                exp_logits.iter().map(|&x| x / sum).collect::<Vec<_>>()
            } else if logits.is_empty() {
                vec![]
            } else {
                // Degenerate logits: fall back to a uniform distribution.
                let uniform = 1.0 / logits.len() as f32;
                vec![uniform; logits.len()]
            }
        };

        let mut scores = HashMap::new();
        let mut selected_labels: Vec<String> = Vec::new();

        // If the model emitted fewer logits than labels, missing ones
        // score 0.
        for (i, label) in labels.iter().enumerate() {
            let prob = probs.get(i).copied().unwrap_or(0.0);
            scores.insert((*label).to_string(), prob);

            if multi_label && prob > 0.5 {
                selected_labels.push((*label).to_string());
            }
        }

        // Single-label mode: select the argmax label.
        if !multi_label {
            if let Some((idx, _)) = probs
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            {
                if let Some(label) = labels.get(idx) {
                    selected_labels.push((*label).to_string());
                }
            }
        }

        Ok(ClassificationResult {
            labels: selected_labels,
            scores,
        })
    }
1189
1190 fn extract_structure(
1192 &self,
1193 text: &str,
1194 task: &StructureTask,
1195 ) -> Result<Vec<ExtractedStructure>> {
1196 if text.is_empty() || task.fields.is_empty() {
1197 return Ok(Vec::new());
1198 }
1199
1200 let mut structures = Vec::new();
1205
1206 let field_names: Vec<&str> = task.fields.iter().map(|f| f.name.as_str()).collect();
1208 let field_entities = self.extract_ner(text, &field_names, 0.3)?;
1209
1210 let mut structure = ExtractedStructure {
1212 structure_type: task.name.clone(),
1213 fields: HashMap::new(),
1214 };
1215
1216 for field in &task.fields {
1217 let matching: Vec<_> = field_entities
1218 .iter()
1219 .filter(|e| e.entity_type.as_label().eq_ignore_ascii_case(&field.name))
1220 .collect();
1221
1222 if !matching.is_empty() {
1223 let value = match field.field_type {
1224 FieldType::List => {
1225 let values: Vec<String> = matching.iter().map(|e| e.text.clone()).collect();
1226 StructureValue::List(values)
1227 }
1228 FieldType::Choice => {
1229 if let Some(ref choices) = field.choices {
1230 let extracted = matching.first().map(|e| e.text.as_str()).unwrap_or("");
1231 let best = choices
1232 .iter()
1233 .find(|c| extracted.to_lowercase().contains(&c.to_lowercase()))
1234 .cloned()
1235 .unwrap_or_else(|| extracted.to_string());
1236 StructureValue::Single(best)
1237 } else {
1238 StructureValue::Single(
1239 matching.first().map(|e| e.text.clone()).unwrap_or_default(),
1240 )
1241 }
1242 }
1243 FieldType::String => StructureValue::Single(
1244 matching.first().map(|e| e.text.clone()).unwrap_or_default(),
1245 ),
1246 };
1247 structure.fields.insert(field.name.clone(), value);
1248 }
1249 }
1250
1251 if !structure.fields.is_empty() {
1252 structures.push(structure);
1253 }
1254
1255 Ok(structures)
1256 }
1257
1258 #[allow(dead_code)]
1260 fn build_prompt(&self, schema: &TaskSchema) -> String {
1261 let mut parts = Vec::new();
1262
1263 if let Some(ref ent_task) = schema.entities {
1264 let types: Vec<String> = ent_task
1265 .types
1266 .iter()
1267 .map(|t| {
1268 if let Some(desc) = ent_task.descriptions.get(t) {
1269 format!("[E] {}: {}", t, desc)
1270 } else {
1271 format!("[E] {}", t)
1272 }
1273 })
1274 .collect();
1275 parts.push(format!("[P] entities ({})", types.join(" ")));
1276 }
1277
1278 for class_task in &schema.classifications {
1279 let labels: Vec<String> = class_task
1280 .labels
1281 .iter()
1282 .map(|l| format!("[L] {}", l))
1283 .collect();
1284 parts.push(format!("[P] {} ({})", class_task.name, labels.join(" ")));
1285 }
1286
1287 for struct_task in &schema.structures {
1288 let fields: Vec<String> = struct_task
1289 .fields
1290 .iter()
1291 .map(|f| format!("[C] {}", f.name))
1292 .collect();
1293 parts.push(format!("[P] {} ({})", struct_task.name, fields.join(" ")));
1294 }
1295
1296 parts.join(" [SEP] ")
1297 }
1298}
1299
1300#[cfg(feature = "candle")]
1305use crate::backends::encoder_candle::TextEncoder;
1306#[cfg(feature = "candle")]
1307use candle_core::{DType, Device, IndexOp, Module, Tensor, D};
1308#[cfg(feature = "candle")]
1309use candle_nn::{Linear, VarBuilder};
1310
/// GLiNER2 backend implemented on top of candle (pure-Rust inference).
#[cfg(feature = "candle")]
#[derive(Debug)]
pub struct GLiNER2Candle {
    /// Text encoder producing token embeddings.
    encoder: crate::backends::encoder_candle::CandleEncoder,
    /// Builds span representations from token embeddings.
    span_rep: SpanRepLayer,
    // NOTE(review): presumably projects label embeddings into the span
    // matching space — confirm in the forward pass (not in this view).
    label_proj: Linear,
    /// Classification head.
    class_head: ClassificationHead,
    /// Count prediction head.
    count_predictor: CountPredictor,
    /// Device tensors are allocated on.
    device: Device,
    #[allow(dead_code)]
    model_name: String,
    /// Encoder hidden dimension.
    hidden_size: usize,
    /// Caches label embeddings across calls to avoid re-encoding.
    label_cache: LabelCache,
}
1333
#[cfg(feature = "candle")]
/// Span-representation layer: combines a span's start-token embedding with a
/// learned embedding indexed by span width.
pub struct SpanRepLayer {
    /// One learned row per span width (0..max_width).
    width_embeddings: candle_nn::Embedding,
    /// Number of width rows; widths are clamped to `max_width - 1`.
    max_width: usize,
}
1342
#[cfg(feature = "candle")]
impl std::fmt::Debug for SpanRepLayer {
    /// Manual `Debug`: `candle_nn::Embedding` cannot be derived, so only the
    /// plain `max_width` field is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SpanRepLayer");
        builder.field("max_width", &self.max_width);
        builder.finish()
    }
}
1351
#[cfg(feature = "candle")]
/// Classification head: a single linear unit mapping a label embedding to one
/// logit (see `ClassificationHead::forward`).
pub struct ClassificationHead {
    mlp: Linear,
}
1358
#[cfg(feature = "candle")]
impl std::fmt::Debug for ClassificationHead {
    /// Manual `Debug`: the inner `Linear` is not `Debug`, so just the name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let builder = f.debug_struct("ClassificationHead");
        builder.finish()
    }
}
1365
#[cfg(feature = "candle")]
/// Count head: a linear layer producing `MAX_COUNT` logits, one per possible
/// structure-instance count (see `CountPredictor::forward`).
pub struct CountPredictor {
    mlp: Linear,
}
1372
#[cfg(feature = "candle")]
impl std::fmt::Debug for CountPredictor {
    /// Manual `Debug`: the inner `Linear` is not `Debug`, so just the name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let builder = f.debug_struct("CountPredictor");
        builder.finish()
    }
}
1379
#[cfg(feature = "candle")]
impl SpanRepLayer {
    /// Load the span layer: `max_width` learned rows of size `hidden_size`
    /// under the `width_embeddings` prefix of `vb`.
    fn new(hidden_size: usize, max_width: usize, vb: VarBuilder) -> Result<Self> {
        let width_embeddings =
            candle_nn::embedding(max_width, hidden_size, vb.pp("width_embeddings"))
                .map_err(|e| Error::Retrieval(format!("width_embeddings: {}", e)))?;
        Ok(Self {
            width_embeddings,
            max_width,
        })
    }

    /// Compute one embedding per candidate span.
    ///
    /// `token_embeddings` is `(batch, seq, hidden)`; `span_indices` is
    /// `(batch, num_spans, 2)` holding `[start, end]` word indices where `end`
    /// is INCLUSIVE (see `generate_spans`: width 0 means a single word). Each
    /// span embedding is `token_emb[start] + width_embedding[end - start]`.
    ///
    /// Returns a `(batch, num_spans, hidden)` tensor.
    fn forward(&self, token_embeddings: &Tensor, span_indices: &Tensor) -> Result<Tensor> {
        let device = token_embeddings.device();
        let batch_size = token_embeddings.dims()[0];
        let hidden_size = token_embeddings.dims()[2];
        let num_spans = span_indices.dims()[1];

        let mut all_span_embs = Vec::with_capacity(batch_size * num_spans * hidden_size);

        for b in 0..batch_size {
            let batch_tokens = token_embeddings
                .i(b)
                .map_err(|e| Error::Inference(format!("batch index: {}", e)))?;
            let batch_spans = span_indices
                .i(b)
                .map_err(|e| Error::Inference(format!("span index: {}", e)))?;

            let spans_data = batch_spans
                .to_vec2::<i64>()
                .map_err(|e| Error::Inference(format!("spans to vec: {}", e)))?;

            for span in spans_data {
                let start = span[0] as usize;
                let end = span[1] as usize;

                // BUG FIX (two related defects in the old code):
                // 1. `end` is inclusive, so `end == start` is a valid
                //    single-word span; the previous `end <= start` check
                //    discarded every width-0 span emitted by generate_spans.
                // 2. Skipping a span with `continue` shrank this flat buffer,
                //    so the final Tensor::from_vec with the fixed
                //    (batch, num_spans, hidden) shape could never succeed.
                //    Genuinely invalid spans now emit a zero vector so the
                //    layout stays intact.
                if end < start {
                    log::warn!("Invalid span: end ({}) < start ({})", end, start);
                    all_span_embs.extend(std::iter::repeat(0.0f32).take(hidden_size));
                    continue;
                }
                // Width index: 0 for a single-word span, clamped to the table.
                let width_idx = (end - start).min(self.max_width - 1);

                // Clamp start into the sequence to avoid out-of-range indexing.
                let start_emb = batch_tokens
                    .i(start.min(batch_tokens.dims()[0] - 1))
                    .map_err(|e| Error::Inference(format!("start emb: {}", e)))?;

                let width_emb = self
                    .width_embeddings
                    .forward(
                        &Tensor::new(&[width_idx as u32], device)
                            .map_err(|e| Error::Inference(format!("width idx: {}", e)))?,
                    )
                    .map_err(|e| Error::Inference(format!("width emb: {}", e)))?
                    .squeeze(0)
                    .map_err(|e| Error::Inference(format!("squeeze: {}", e)))?;

                let combined = start_emb
                    .add(&width_emb)
                    .map_err(|e| Error::Inference(format!("add: {}", e)))?;

                let emb_vec = combined
                    .to_vec1::<f32>()
                    .map_err(|e| Error::Inference(format!("to vec: {}", e)))?;
                all_span_embs.extend(emb_vec);
            }
        }

        Tensor::from_vec(all_span_embs, (batch_size, num_spans, hidden_size), device)
            .map_err(|e| Error::Inference(format!("span tensor: {}", e)))
    }
}
1460
#[cfg(feature = "candle")]
impl ClassificationHead {
    /// Build the head: one linear unit mapping a `hidden_size` label
    /// embedding to a single logit, loaded from `vb`'s "mlp" prefix.
    fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
        let mlp = candle_nn::linear(hidden_size, 1, vb.pp("mlp"))
            .map_err(|e| Error::Retrieval(format!("classification mlp: {}", e)))?;
        Ok(Self { mlp })
    }

    /// Score label embeddings: returns one logit per label row.
    fn forward(&self, label_embeddings: &Tensor) -> Result<Tensor> {
        self.mlp
            .forward(label_embeddings)
            .map_err(|e| Error::Inference(format!("class head forward: {}", e)))
    }
}
1476
#[cfg(feature = "candle")]
impl CountPredictor {
    /// Build the count head: a linear layer mapping `hidden_size` features to
    /// `max_count` logits (one per possible instance count).
    fn new(hidden_size: usize, max_count: usize, vb: VarBuilder) -> Result<Self> {
        let mlp = candle_nn::linear(hidden_size, max_count, vb.pp("mlp"))
            .map_err(|e| Error::Retrieval(format!("count mlp: {}", e)))?;
        Ok(Self { mlp })
    }

    /// Predict how many structure instances the text contains: argmax over the
    /// count logits, floored at 1 so at least one extraction pass always runs.
    fn forward(&self, prompt_embedding: &Tensor) -> Result<usize> {
        let logits = self
            .mlp
            .forward(prompt_embedding)
            .map_err(|e| Error::Inference(format!("count forward: {}", e)))?;

        let scores = logits
            .flatten_all()
            .map_err(|e| Error::Inference(format!("flatten: {}", e)))?
            .to_vec1::<f32>()
            .map_err(|e| Error::Inference(format!("to vec: {}", e)))?;

        // Argmax via Iterator::max_by (ties resolve to the later index, as
        // before); an empty logit vector falls back to a count of 1.
        let best = scores
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx)
            .unwrap_or(1);

        Ok(best.max(1))
    }
}
1508
1509#[cfg(feature = "candle")]
1510impl GLiNER2Candle {
    /// Download (or reuse cached) model files from the Hugging Face Hub and
    /// assemble the full GLiNER2 pipeline.
    ///
    /// Reads `hidden_size` from config.json (defaulting to 768 when absent),
    /// then loads weights from `model.safetensors`, falling back to
    /// `gliner_model.safetensors` and finally to converting
    /// `pytorch_model.bin`.
    ///
    /// # Errors
    /// Returns `Error::Retrieval` for download/load failures and
    /// `Error::Parse` for a malformed config.json.
    pub fn from_pretrained(model_id: &str) -> Result<Self> {
        use crate::backends::encoder_candle::CandleEncoder;
        use hf_hub::api::sync::Api;

        let api = Api::new().map_err(|e| Error::Retrieval(format!("HF API: {}", e)))?;
        let repo = api.model(model_id.to_string());

        let config_path = repo
            .get("config.json")
            .map_err(|e| Error::Retrieval(format!("config.json: {}", e)))?;
        let config_str = std::fs::read_to_string(&config_path)
            .map_err(|e| Error::Retrieval(format!("read config: {}", e)))?;
        let config: serde_json::Value = serde_json::from_str(&config_str)
            .map_err(|e| Error::Parse(format!("parse config: {}", e)))?;
        // Missing/invalid hidden_size silently defaults to 768 (BERT-base).
        let hidden_size = config["hidden_size"].as_u64().unwrap_or(768) as usize;

        // Prefer GPU 0 when the cuda feature stack is available.
        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);

        // Weight resolution order: standard safetensors name, GLiNER-specific
        // name, then on-the-fly conversion of a PyTorch checkpoint.
        let weights_path = repo
            .get("model.safetensors")
            .or_else(|_| repo.get("gliner_model.safetensors"))
            .or_else(|_| {
                let pytorch_path = repo.get("pytorch_model.bin")?;
                crate::backends::gliner_candle::convert_pytorch_to_safetensors(&pytorch_path)
            })
            .map_err(|e| {
                Error::Retrieval(format!("weights not found and conversion failed: {}", e))
            })?;

        // SAFETY: mmap of the downloaded safetensors file; the file is assumed
        // not to be mutated for the lifetime of the VarBuilder.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], DType::F32, &device)
                .map_err(|e| Error::Retrieval(format!("varbuilder: {}", e)))?
        };

        let encoder = CandleEncoder::from_pretrained(model_id)?;
        let span_rep = SpanRepLayer::new(hidden_size, MAX_SPAN_WIDTH, vb.pp("span_rep"))?;
        let label_proj = candle_nn::linear(hidden_size, hidden_size, vb.pp("label_projection"))
            .map_err(|e| Error::Retrieval(format!("label_projection: {}", e)))?;
        let class_head = ClassificationHead::new(hidden_size, vb.pp("classification"))?;
        let count_predictor =
            CountPredictor::new(hidden_size, MAX_COUNT, vb.pp("count_predictor"))?;

        log::info!(
            "[GLiNER2-Candle] Loaded {} (hidden={}) on {:?}",
            model_id,
            hidden_size,
            device
        );

        Ok(Self {
            encoder,
            span_rep,
            label_proj,
            class_head,
            count_predictor,
            device,
            model_name: model_id.to_string(),
            hidden_size,
            label_cache: LabelCache::new(),
        })
    }
1581
1582 pub fn extract(&self, text: &str, schema: &TaskSchema) -> Result<ExtractionResult> {
1584 let mut result = ExtractionResult::default();
1585
1586 if let Some(ref ent_task) = schema.entities {
1588 let entities = self.extract_entities(text, &ent_task.types, 0.5)?;
1589 result.entities = entities;
1590 }
1591
1592 for class_task in &schema.classifications {
1594 let class_result = self.classify(text, &class_task.labels, class_task.multi_label)?;
1595 result
1596 .classifications
1597 .insert(class_task.name.clone(), class_result);
1598 }
1599
1600 for struct_task in &schema.structures {
1602 let structures = self.extract_structure_with_count(text, struct_task)?;
1603 result.structures.extend(structures);
1604 }
1605
1606 Ok(result)
1607 }
1608
    /// Zero-shot NER: score every candidate span against every type label and
    /// return entities whose sigmoid score reaches `threshold`.
    ///
    /// Pipeline: whitespace word split -> encoder -> span enumeration ->
    /// span embeddings -> projected label embeddings -> span/label matching ->
    /// decoding. Empty text or an empty type list short-circuits to `[]`.
    fn extract_entities(
        &self,
        text: &str,
        types: &[String],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        if text.is_empty() || types.is_empty() {
            return Ok(Vec::new());
        }

        let labels: Vec<&str> = types.iter().map(|s| s.as_str()).collect();

        // Word-level tokenization: spans are defined over whitespace words.
        let words: Vec<&str> = text.split_whitespace().collect();
        if words.is_empty() {
            return Ok(Vec::new());
        }

        let (text_embeddings, word_positions) = self.encode_text(&words)?;

        // Label embeddings come from the per-label cache where possible.
        let label_embeddings = self.encode_labels_cached(&labels)?;

        let span_indices = self.generate_spans(words.len())?;

        let span_embs = self.span_rep.forward(&text_embeddings, &span_indices)?;

        // Project labels into the span space before similarity matching.
        let label_embs = self
            .label_proj
            .forward(&label_embeddings)
            .map_err(|e| Error::Inference(format!("label projection: {}", e)))?;

        let scores = self.match_spans_labels(&span_embs, &label_embs)?;

        self.decode_entities(text, &words, &word_positions, &scores, &labels, threshold)
    }
1652
    /// Classify `text` against `labels`.
    ///
    /// Scores are a 0.7/0.3 blend of (a) cosine similarity between the pooled
    /// text embedding and each label embedding and (b) the classification
    /// head's per-label logit. Multi-label mode applies a sigmoid and returns
    /// every label with probability > 0.5; single-label mode applies softmax
    /// and returns the argmax label. All per-label probabilities are reported
    /// in `scores`.
    fn classify(
        &self,
        text: &str,
        labels: &[String],
        multi_label: bool,
    ) -> Result<ClassificationResult> {
        if text.is_empty() || labels.is_empty() {
            return Ok(ClassificationResult::default());
        }

        // Pooled text representation: first hidden_size values of the encoder
        // output (the leading token position).
        let (text_emb, _seq_len) = self.encoder.encode(text)?;
        let cls_emb = Tensor::from_vec(
            text_emb[..self.hidden_size].to_vec(),
            (1, self.hidden_size),
            &self.device,
        )
        .map_err(|e| Error::Inference(format!("cls tensor: {}", e)))?;

        let labels_str: Vec<&str> = labels.iter().map(|s| s.as_str()).collect();
        let label_embs = self.encode_labels_cached(&labels_str)?;

        // Per-label logits from the learned head.
        let label_logits = self.class_head.forward(&label_embs)?;
        let label_logits_vec = label_logits
            .flatten_all()
            .map_err(|e| Error::Inference(format!("flatten: {}", e)))?
            .to_vec1::<f32>()
            .map_err(|e| Error::Inference(format!("to vec: {}", e)))?;

        // Cosine similarity text x labels via L2-normalized matmul.
        let cls_norm = l2_normalize(&cls_emb, D::Minus1)?;
        let label_norm = l2_normalize(&label_embs, D::Minus1)?;

        let sim_scores = cls_norm
            .matmul(
                &label_norm
                    .t()
                    .map_err(|e| Error::Inference(format!("transpose: {}", e)))?,
            )
            .map_err(|e| Error::Inference(format!("matmul: {}", e)))?;

        let sim_vec = sim_scores
            .flatten_all()
            .map_err(|e| Error::Inference(format!("flatten: {}", e)))?
            .to_vec1::<f32>()
            .map_err(|e| Error::Inference(format!("to vec: {}", e)))?;

        // Blend similarity and head logit; `cycle()` guards against a logit
        // vector shorter than sim_vec (zip stops at sim_vec's length).
        let combined: Vec<f32> = sim_vec
            .iter()
            .zip(label_logits_vec.iter().cycle())
            .map(|(s, l)| 0.7 * s + 0.3 * l)
            .collect();

        let probs = if multi_label {
            // Independent sigmoid per label.
            combined.iter().map(|&s| 1.0 / (1.0 + (-s).exp())).collect()
        } else {
            // Max-shifted softmax; degenerate sums fall back to uniform.
            let max_score = combined.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_scores: Vec<f32> = combined.iter().map(|&s| (s - max_score).exp()).collect();
            let sum: f32 = exp_scores.iter().sum();
            if sum > 0.0 {
                exp_scores.iter().map(|&e| e / sum).collect::<Vec<_>>()
            } else if combined.is_empty() {
                vec![]
            } else {
                let uniform = 1.0 / combined.len() as f32;
                vec![uniform; combined.len()]
            }
        };

        let mut scores_map = HashMap::new();
        let mut result_labels = Vec::new();

        for (i, label) in labels.iter().enumerate() {
            let prob = probs.get(i).copied().unwrap_or(0.0);
            scores_map.insert(label.clone(), prob);

            if multi_label && prob > 0.5 {
                result_labels.push(label.clone());
            }
        }

        // Single-label: pick the argmax probability.
        if !multi_label {
            if let Some((idx, _)) = probs
                .iter()
                .enumerate()
                .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            {
                if let Some(label) = labels.get(idx) {
                    result_labels.push(label.clone());
                }
            }
        }

        Ok(ClassificationResult {
            labels: result_labels,
            scores: scores_map,
        })
    }
1759
1760 fn extract_structure_with_count(
1762 &self,
1763 text: &str,
1764 task: &StructureTask,
1765 ) -> Result<Vec<ExtractedStructure>> {
1766 if text.is_empty() || task.fields.is_empty() {
1767 return Ok(Vec::new());
1768 }
1769
1770 let (text_emb, _) = self.encoder.encode(text)?;
1772 let prompt_emb = Tensor::from_vec(
1773 text_emb[..self.hidden_size].to_vec(),
1774 (self.hidden_size,),
1775 &self.device,
1776 )
1777 .map_err(|e| Error::Inference(format!("prompt tensor: {}", e)))?;
1778
1779 let num_instances = self.count_predictor.forward(&prompt_emb)?;
1781
1782 log::debug!(
1783 "[GLiNER2] Count predictor: {} instances for {}",
1784 num_instances,
1785 task.name
1786 );
1787
1788 let mut structures = Vec::new();
1789
1790 for instance_idx in 0..num_instances {
1792 let mut structure = ExtractedStructure {
1793 structure_type: task.name.clone(),
1794 fields: HashMap::new(),
1795 };
1796
1797 for field in &task.fields {
1798 let field_label = field.description.as_ref().unwrap_or(&field.name);
1799
1800 let labels_vec: Vec<String> = vec![field_label.to_string()];
1802 let entities = self.extract_entities(text, &labels_vec, 0.3)?;
1803
1804 let entity_for_instance = entities.get(instance_idx);
1806
1807 if let Some(entity) = entity_for_instance {
1808 let value = match field.field_type {
1809 FieldType::List => {
1810 let values: Vec<String> =
1812 entities.iter().map(|e| e.text.clone()).collect();
1813 StructureValue::List(values)
1814 }
1815 FieldType::Choice => {
1816 if let Some(ref choices) = field.choices {
1817 let extracted = &entity.text;
1818 let best_choice = choices
1819 .iter()
1820 .find(|c| extracted.to_lowercase().contains(&c.to_lowercase()))
1821 .cloned()
1822 .unwrap_or_else(|| extracted.clone());
1823 StructureValue::Single(best_choice)
1824 } else {
1825 StructureValue::Single(entity.text.clone())
1826 }
1827 }
1828 FieldType::String => StructureValue::Single(entity.text.clone()),
1829 };
1830
1831 structure.fields.insert(field.name.clone(), value);
1832 }
1833 }
1834
1835 if !structure.fields.is_empty() {
1836 structures.push(structure);
1837 }
1838 }
1839
1840 Ok(structures)
1841 }
1842
1843 fn encode_text(&self, words: &[&str]) -> Result<(Tensor, Vec<(usize, usize)>)> {
1848 let text = words.join(" ");
1849 let (embeddings, seq_len) = self.encoder.encode(&text)?;
1850
1851 let tensor = Tensor::from_vec(embeddings, (1, seq_len, self.hidden_size), &self.device)
1853 .map_err(|e| Error::Inference(format!("text tensor: {}", e)))?;
1854
1855 let full_text = words.join(" ");
1857 let word_positions: Vec<(usize, usize)> = {
1858 let mut positions = Vec::new();
1859 let mut pos = 0;
1860 for (idx, word) in words.iter().enumerate() {
1861 if let Some(start) = full_text[pos..].find(word) {
1862 let abs_start = pos + start;
1863 let abs_end = abs_start + word.len();
1864 if !positions.is_empty() {
1866 let (_prev_start, prev_end) = positions[positions.len() - 1];
1867 if abs_start < prev_end {
1868 log::warn!(
1869 "Word '{}' (index {}) at position {} overlaps with previous word ending at {}",
1870 word,
1871 idx,
1872 abs_start,
1873 prev_end
1874 );
1875 }
1876 }
1877 positions.push((abs_start, abs_end));
1878 pos = abs_end;
1879 } else {
1880 return Err(Error::Inference(format!(
1882 "Word '{}' (index {}) not found in text starting at position {}",
1883 word, idx, pos
1884 )));
1885 }
1886 }
1887 positions
1888 };
1889
1890 if word_positions.len() != words.len() {
1892 return Err(Error::Inference(format!(
1893 "Word position mismatch: found {} positions for {} words",
1894 word_positions.len(),
1895 words.len()
1896 )));
1897 }
1898
1899 Ok((tensor, word_positions))
1900 }
1901
    /// Encode each label into a mean-pooled embedding, using the per-label
    /// cache to avoid re-running the encoder for repeated labels.
    ///
    /// Returns a `(num_labels, hidden_size)` tensor in `labels` order.
    fn encode_labels_cached(&self, labels: &[&str]) -> Result<Tensor> {
        let mut all_embeddings = Vec::new();

        for label in labels {
            if let Some(cached) = self.label_cache.get(label) {
                all_embeddings.extend(cached);
            } else {
                let (embeddings, seq_len) = self.encoder.encode(label)?;
                // Mean-pool across tokens. The strided iteration assumes the
                // flat encoder output is token-major ([token][dim]), matching
                // the (1, seq_len, hidden) reshape in encode_text.
                let avg: Vec<f32> = if seq_len == 0 {
                    // Degenerate encoding: fall back to a zero vector.
                    vec![0.0f32; self.hidden_size]
                } else {
                    (0..self.hidden_size)
                        .map(|i| {
                            embeddings
                                .iter()
                                .skip(i)
                                .step_by(self.hidden_size)
                                .take(seq_len)
                                .sum::<f32>()
                                / seq_len as f32
                        })
                        .collect()
                };

                self.label_cache.insert(label.to_string(), avg.clone());
                all_embeddings.extend(avg);
            }
        }

        Tensor::from_vec(
            all_embeddings,
            (labels.len(), self.hidden_size),
            &self.device,
        )
        .map_err(|e| Error::Inference(format!("label tensor: {}", e)))
    }
1942
1943 fn generate_spans(&self, num_words: usize) -> Result<Tensor> {
1944 let estimated_capacity = num_words.saturating_mul(MAX_SPAN_WIDTH).saturating_mul(2);
1947 let mut spans = Vec::with_capacity(estimated_capacity.min(1000));
1948
1949 for start in 0..num_words {
1950 for width in 0..MAX_SPAN_WIDTH.min(num_words - start) {
1951 let end = start + width;
1952 spans.push(start as i64);
1953 spans.push(end as i64);
1954 }
1955 }
1956
1957 let num_spans = spans.len() / 2;
1958 Tensor::from_vec(spans, (1, num_spans, 2), &self.device)
1959 .map_err(|e| Error::Inference(format!("span tensor: {}", e)))
1960 }
1961
    /// Score every span against every label via cosine similarity, squashed
    /// through a sigmoid.
    ///
    /// `span_embs` is `(batch, num_spans, hidden)` and `label_embs` is
    /// `(num_labels, hidden)`; the result is `(batch, num_spans, num_labels)`
    /// with values in (0, 1).
    fn match_spans_labels(&self, span_embs: &Tensor, label_embs: &Tensor) -> Result<Tensor> {
        let span_norm = l2_normalize(span_embs, D::Minus1)?;
        let label_norm = l2_normalize(label_embs, D::Minus1)?;

        // Broadcast the transposed label matrix across the batch dimension so
        // a single batched matmul computes all similarities.
        let batch_size = span_norm.dims()[0];
        let label_t = label_norm
            .t()
            .map_err(|e| Error::Inference(format!("transpose: {}", e)))?;
        let label_t = label_t
            .unsqueeze(0)
            .map_err(|e| Error::Inference(format!("unsqueeze: {}", e)))?
            .broadcast_as((batch_size, label_t.dims()[0], label_t.dims()[1]))
            .map_err(|e| Error::Inference(format!("broadcast: {}", e)))?;

        let scores = span_norm
            .matmul(&label_t)
            .map_err(|e| Error::Inference(format!("matmul: {}", e)))?;

        candle_nn::ops::sigmoid(&scores).map_err(|e| Error::Inference(format!("sigmoid: {}", e)))
    }
1982
1983 fn decode_entities(
1984 &self,
1985 text: &str,
1986 words: &[&str],
1987 _word_positions: &[(usize, usize)],
1988 scores: &Tensor,
1989 labels: &[&str],
1990 threshold: f32,
1991 ) -> Result<Vec<Entity>> {
1992 let scores_vec = scores
1993 .flatten_all()
1994 .map_err(|e| Error::Inference(format!("flatten scores: {}", e)))?
1995 .to_vec1::<f32>()
1996 .map_err(|e| Error::Inference(format!("scores to vec: {}", e)))?;
1997
1998 let num_labels = labels.len();
1999 let num_spans = scores_vec.len() / num_labels;
2000
2001 let mut entities = Vec::with_capacity(num_spans.min(32));
2003 let mut span_idx = 0;
2004
2005 for start in 0..words.len() {
2006 for width in 0..MAX_SPAN_WIDTH.min(words.len() - start) {
2007 if span_idx >= num_spans {
2008 break;
2009 }
2010
2011 let end = start + width;
2012
2013 for (label_idx, label) in labels.iter().enumerate() {
2014 let score = scores_vec[span_idx * num_labels + label_idx];
2015
2016 if score >= threshold {
2017 let span_text = words[start..=end].join(" ");
2018 let (char_start, char_end) =
2019 word_span_to_char_offsets(text, words, start, end);
2020
2021 let entity_type = map_entity_type(label);
2022
2023 entities.push(Entity::new(
2024 span_text,
2025 entity_type,
2026 char_start,
2027 char_end,
2028 score as f64,
2029 ));
2030 }
2031 }
2032
2033 span_idx += 1;
2034 }
2035 }
2036
2037 entities.sort_by(|a, b| a.start.cmp(&b.start).then_with(|| b.end.cmp(&a.end)));
2039 entities.dedup_by(|a, b| a.start == b.start && a.end == b.end);
2040
2041 Ok(entities)
2042 }
2043}
2044
#[cfg(feature = "candle")]
/// L2-normalize `tensor` along `dim` so rows have unit Euclidean norm.
///
/// Norms are clamped to a 1e-12 floor to avoid division by zero for
/// all-zero rows; the normalized result is produced via broadcast division.
fn l2_normalize(tensor: &Tensor, dim: D) -> Result<Tensor> {
    // sqrt(sum(x^2)) along `dim`, keeping a trailing unit axis for broadcast.
    let norm = tensor
        .sqr()
        .map_err(|e| Error::Inference(format!("sqr: {}", e)))?
        .sum(dim)
        .map_err(|e| Error::Inference(format!("sum: {}", e)))?
        .sqrt()
        .map_err(|e| Error::Inference(format!("sqrt: {}", e)))?
        .unsqueeze(D::Minus1)
        .map_err(|e| Error::Inference(format!("unsqueeze: {}", e)))?;

    let norm_clamped = norm
        .clamp(1e-12, f32::MAX)
        .map_err(|e| Error::Inference(format!("clamp: {}", e)))?;

    tensor
        .broadcast_div(&norm_clamped)
        .map_err(|e| Error::Inference(format!("div: {}", e)))
}
2066
#[cfg(not(any(feature = "onnx", feature = "candle")))]
/// Stub used when neither inference backend is compiled in; every method
/// returns `Error::FeatureNotAvailable`.
#[derive(Debug)]
pub struct GLiNER2 {
    _private: (),
}
2077
#[cfg(not(any(feature = "onnx", feature = "candle")))]
impl GLiNER2 {
    /// Always fails: no backend feature is enabled in this build.
    pub fn from_pretrained(_model_id: &str) -> Result<Self> {
        Err(Error::FeatureNotAvailable(
            "GLiNER2 requires 'onnx' or 'candle' feature. \
             Build with: cargo build --features candle"
                .to_string(),
        ))
    }

    /// Always fails: no backend feature is enabled in this build.
    pub fn extract(&self, _text: &str, _schema: &TaskSchema) -> Result<ExtractionResult> {
        Err(Error::FeatureNotAvailable(
            "GLiNER2 requires features".to_string(),
        ))
    }
}
2096
/// Default `GLiNER2` backend: Candle takes precedence when both features are enabled.
#[cfg(feature = "candle")]
pub type GLiNER2 = GLiNER2Candle;

/// ONNX backend is used only when the `candle` feature is not enabled.
#[cfg(all(feature = "onnx", not(feature = "candle")))]
pub type GLiNER2 = GLiNER2Onnx;
2108
2109fn word_span_to_char_offsets(
2115 text: &str,
2116 words: &[&str],
2117 start_word: usize,
2118 end_word: usize,
2119) -> (usize, usize) {
2120 if words.is_empty()
2122 || start_word >= words.len()
2123 || end_word >= words.len()
2124 || start_word > end_word
2125 {
2126 return (0, 0);
2128 }
2129
2130 let mut byte_pos = 0;
2132 let mut start_byte = 0;
2133 let mut end_byte = text.len();
2134 let mut found_start = false;
2135 let mut found_end = false;
2136
2137 for (i, word) in words.iter().enumerate() {
2138 if let Some(pos) = text.get(byte_pos..).and_then(|s| s.find(word)) {
2139 let abs_pos = byte_pos + pos;
2140
2141 if i == start_word {
2142 start_byte = abs_pos;
2143 found_start = true;
2144 }
2145 if i == end_word {
2146 end_byte = abs_pos + word.len();
2147 found_end = true;
2148 break;
2150 }
2151
2152 byte_pos = abs_pos + word.len();
2153 } else {
2154 }
2158 }
2159
2160 if !found_start || !found_end {
2162 (0, 0)
2164 } else {
2165 crate::offset::bytes_to_chars(text, start_byte, end_byte)
2167 }
2168}
2169
/// Map a free-form label string onto the crate's canonical `EntityType`
/// (thin wrapper over the schema mapper, with no context hint).
fn map_entity_type(type_str: &str) -> EntityType {
    crate::schema::map_to_canonical(type_str, None)
}
2176
#[cfg(feature = "onnx")]
/// `crate::Model` adapter: generic NER over a fixed default type set.
impl crate::Model for GLiNER2Onnx {
    /// Extract entities using the default schema (person, organization,
    /// location, date, event); the language hint is ignored.
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
        let schema = TaskSchema::new().with_entities(&[
            "person",
            "organization",
            "location",
            "date",
            "event",
        ]);

        let result = self.extract(text, &schema)?;
        Ok(result.entities)
    }

    /// Advertised type coverage. NOTE(review): lists "product"/"misc" even
    /// though the default schema above does not request them — confirm intent.
    fn supported_types(&self) -> Vec<EntityType> {
        vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Date,
            EntityType::Custom {
                name: "event".to_string(),
                category: EntityCategory::Creative,
            },
            EntityType::Custom {
                name: "product".to_string(),
                category: EntityCategory::Creative,
            },
            EntityType::Other("misc".to_string()),
        ]
    }

    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "GLiNER2-ONNX"
    }

    fn description(&self) -> &'static str {
        "Multi-task information extraction via GLiNER2 (ONNX backend)"
    }
}
2226
#[cfg(feature = "candle")]
/// `crate::Model` adapter: generic NER over a fixed default type set.
impl crate::Model for GLiNER2Candle {
    /// Extract entities using the default schema (person, organization,
    /// location, date, event); the language hint is ignored.
    fn extract_entities(&self, text: &str, _language: Option<&str>) -> Result<Vec<Entity>> {
        let schema = TaskSchema::new().with_entities(&[
            "person",
            "organization",
            "location",
            "date",
            "event",
        ]);

        let result = self.extract(text, &schema)?;
        Ok(result.entities)
    }

    /// Advertised type coverage. NOTE(review): lists "product"/"misc" even
    /// though the default schema above does not request them — confirm intent.
    fn supported_types(&self) -> Vec<EntityType> {
        vec![
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
            EntityType::Date,
            EntityType::Custom {
                name: "event".to_string(),
                category: EntityCategory::Creative,
            },
            EntityType::Custom {
                name: "product".to_string(),
                category: EntityCategory::Creative,
            },
            EntityType::Other("misc".to_string()),
        ]
    }

    fn is_available(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "GLiNER2-Candle"
    }

    fn description(&self) -> &'static str {
        "Multi-task information extraction via GLiNER2 (native Rust/Candle)"
    }
}
2276
#[cfg(feature = "onnx")]
/// Zero-shot NER adapter for the ONNX backend.
impl ZeroShotNER for GLiNER2Onnx {
    fn default_types(&self) -> &[&'static str] {
        &["person", "organization", "location", "date", "event"]
    }

    /// Extract entities for caller-supplied type names.
    fn extract_with_types(
        &self,
        text: &str,
        types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        self.extract_ner(text, types, threshold)
    }

    /// Extract entities for natural-language type descriptions; GLiNER2
    /// treats descriptions the same way as plain type labels.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        self.extract_ner(text, descriptions, threshold)
    }
}
2306
#[cfg(feature = "candle")]
/// Zero-shot NER adapter for the Candle backend.
impl ZeroShotNER for GLiNER2Candle {
    fn default_types(&self) -> &[&'static str] {
        &["person", "organization", "location", "date", "event"]
    }

    /// Extract entities for caller-supplied type names.
    fn extract_with_types(
        &self,
        text: &str,
        types: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        let type_strings: Vec<String> = types.iter().map(|s| s.to_string()).collect();
        self.extract_entities(text, &type_strings, threshold)
    }

    /// Extract entities for natural-language type descriptions; GLiNER2
    /// treats descriptions the same way as plain type labels.
    fn extract_with_descriptions(
        &self,
        text: &str,
        descriptions: &[&str],
        threshold: f32,
    ) -> Result<Vec<Entity>> {
        let type_strings: Vec<String> = descriptions.iter().map(|s| s.to_string()).collect();
        self.extract_entities(text, &type_strings, threshold)
    }
}
2334
2335#[cfg(any(feature = "onnx", feature = "candle"))]
/// Heuristic prior over relation labels for a (head, tail) entity-type pair.
///
/// Returns `(relation_label, base_confidence)` candidates. Types are compared
/// case-insensitively; both Chinese (simplified + traditional variants) and
/// English relation vocabularies are covered. Unknown pairs yield an empty
/// list.
fn get_likely_relations(head_type: &str, tail_type: &str) -> Vec<(&'static str, f32)> {
    let h = head_type.to_uppercase();
    let t = tail_type.to_uppercase();

    match (h.as_str(), t.as_str()) {
        // --- Chinese relation vocabulary ---
        ("PER", "OFI") | ("PERSON", "OFI") => vec![("任职", 0.7), ("任職", 0.7)],
        ("OFI", "PER") => vec![("上下级", 0.6), ("上下級", 0.6)],
        ("PER", "LOC") => vec![
            ("到达", 0.55),
            ("到達", 0.55),
            ("出生于某地", 0.4),
            ("出生於某地", 0.4),
        ],
        ("LOC", "PER") => vec![("到达", 0.5), ("到達", 0.5)],
        ("PER", "PER") => vec![
            ("上下级", 0.45),
            ("上下級", 0.45),
            ("同僚", 0.4),
            ("父母", 0.3),
            ("兄弟", 0.3),
        ],
        ("OFI", "LOC") | ("LOC", "OFI") => vec![("管理", 0.5)],
        ("BOOK", "BOOK") | ("BOOK", "PER") | ("PER", "BOOK") => {
            vec![("别名", 0.35), ("別名", 0.35)]
        }
        // --- English relation vocabulary ---
        ("PERSON", "ORGANIZATION") | ("PER", "ORG") => vec![
            ("WORKS_FOR", 0.7),
            ("FOUNDED", 0.5),
            ("CEO_OF", 0.4),
            ("MEMBER_OF", 0.6),
        ],
        ("ORGANIZATION", "PERSON") | ("ORG", "PER") => {
            vec![("EMPLOYS", 0.7), ("FOUNDED_BY", 0.5), ("LED_BY", 0.4)]
        }
        ("PERSON", "LOCATION") | ("PERSON", "GPE") | ("PER", "GPE") => {
            vec![("LIVES_IN", 0.6), ("BORN_IN", 0.5), ("VISITED", 0.4)]
        }
        ("ORGANIZATION", "LOCATION")
        | ("ORG", "LOC")
        | ("ORGANIZATION", "GPE")
        | ("ORG", "GPE") => vec![
            ("HEADQUARTERED_IN", 0.7),
            ("LOCATED_IN", 0.8),
            ("OPERATES_IN", 0.5),
        ],
        ("PRODUCT", "ORGANIZATION") | ("PRODUCT", "ORG") => {
            vec![("MADE_BY", 0.8), ("PRODUCED_BY", 0.7)]
        }
        ("ORGANIZATION", "PRODUCT") | ("ORG", "PRODUCT") => {
            vec![("MAKES", 0.8), ("PRODUCES", 0.7), ("ANNOUNCED", 0.5)]
        }
        // Temporal tails apply regardless of head type.
        (_, "DATE") | (_, "TIME") => vec![("OCCURRED_ON", 0.5), ("FOUNDED_ON", 0.4)],
        // No prior knowledge for this pair.
        _ => vec![],
    }
}
2404
2405#[cfg(any(feature = "onnx", feature = "candle"))]
2408fn extract_relations_heuristic(
2409 entities: &[Entity],
2410 text: &str,
2411 relation_types: &[&str],
2412 threshold: f32,
2413) -> Vec<crate::backends::inference::RelationTriple> {
2414 use crate::backends::inference::RelationTriple;
2415
2416 fn norm_rel_slug(s: &str) -> String {
2418 let mut out = String::with_capacity(s.len());
2419 let mut prev_underscore = false;
2420 for ch in s.chars() {
2421 if ch.is_alphanumeric() {
2422 if ch.is_ascii_alphabetic() {
2424 out.push(ch.to_ascii_uppercase());
2425 } else {
2426 out.push(ch);
2427 }
2428 prev_underscore = false;
2429 } else if !prev_underscore {
2430 out.push('_');
2431 prev_underscore = true;
2432 }
2433 }
2434 while out.starts_with('_') {
2435 out.remove(0);
2436 }
2437 while out.ends_with('_') {
2438 out.pop();
2439 }
2440 out
2441 }
2442
2443 fn pick_relation_label(canonical: &str, relation_types: &[&str]) -> Option<String> {
2444 if relation_types.is_empty() {
2445 return None;
2446 }
2447 let want = norm_rel_slug(canonical);
2448 relation_types
2449 .iter()
2450 .find(|r| norm_rel_slug(r) == want)
2451 .map(|s| (*s).to_string())
2452 }
2453
2454 let mut relations = Vec::new();
2455 let text_char_count = text.chars().count();
2457 let text_char_len = text_char_count.max(1) as f32;
2458
2459 let trigger_patterns: Vec<(&str, &str)> = vec![
2465 ("part of", "PART_OF"),
2467 ("subset of", "PART_OF"),
2468 ("member of", "PART_OF"),
2469 ("type of", "TYPE_OF"),
2470 ("kind of", "TYPE_OF"),
2471 ("is a", "TYPE_OF"),
2472 ("are a", "TYPE_OF"),
2473 ("related to", "RELATED_TO"),
2474 ("also known as", "NAMED"),
2475 ("known as", "NAMED"),
2476 ("called", "NAMED"),
2477 ("named", "NAMED"),
2478 ("born", "TEMPORAL"),
2479 ("in 19", "TEMPORAL"),
2480 ("in 20", "TEMPORAL"),
2481 ("during", "TEMPORAL"),
2482 ("from", "ORIGIN"),
2483 ("based in", "PHYSICAL"),
2484 ("located in", "PHYSICAL"),
2485 ("headquartered", "PHYSICAL"),
2486 ("at ", "PHYSICAL"),
2487 ("vs", "COMPARE"),
2488 ("versus", "COMPARE"),
2489 ("compared", "COMPARE"),
2490 ("use", "USAGE"),
2491 ("used", "USAGE"),
2492 ("uses", "USAGE"),
2493 ("invented", "ARTIFACT"),
2494 ("created", "ARTIFACT"),
2495 ("built", "ARTIFACT"),
2496 ("developed", "ARTIFACT"),
2497 ("won", "WIN_DEFEAT"),
2498 ("defeated", "WIN_DEFEAT"),
2499 ("beat", "WIN_DEFEAT"),
2500 ("caused", "CAUSE_EFFECT"),
2501 ("causes", "CAUSE_EFFECT"),
2502 ("leads to", "CAUSE_EFFECT"),
2503 ("because", "CAUSE_EFFECT"),
2504 ("父", "父母"),
2507 ("母", "父母"),
2508 ("兄", "兄弟"),
2509 ("弟", "兄弟"),
2510 ("别名", "别名"),
2511 ("別名", "別名"),
2512 ("生于", "出生于某地"),
2513 ("生於", "出生於某地"),
2514 ("到", "到达"),
2515 ("到", "到達"),
2516 ("至", "到达"),
2517 ("至", "到達"),
2518 ("驻", "驻守"),
2519 ("駐", "駐守"),
2520 ("守", "驻守"),
2521 ("守", "駐守"),
2522 ("攻", "敌对攻伐"),
2523 ("伐", "敌对攻伐"),
2524 ("攻", "敵對攻伐"),
2525 ("伐", "敵對攻伐"),
2526 ("任", "任职"),
2527 ("任", "任職"),
2528 ("拜", "任职"),
2529 ("拜", "任職"),
2530 ("管", "管理"),
2531 ("治", "管理"),
2532 ("ceo", "CEO_OF"),
2534 ("founder", "FOUNDED"),
2535 ("founded", "FOUNDED"),
2536 ("works at", "WORKS_FOR"),
2537 ("works for", "WORKS_FOR"),
2538 ("employee", "WORKS_FOR"),
2539 ("born in", "BORN_IN"),
2540 ("lives in", "LIVES_IN"),
2541 ("announced", "ANNOUNCED"),
2542 ("released", "RELEASED"),
2543 ("acquired", "ACQUIRED"),
2544 ("bought", "ACQUIRED"),
2545 ("merged", "MERGED_WITH"),
2546 ];
2547
2548 for (i, head) in entities.iter().enumerate() {
2549 for (j, tail) in entities.iter().enumerate() {
2550 if i == j {
2551 continue;
2552 }
2553
2554 let head_center = (head.start + head.end) as f32 / 2.0;
2556 let tail_center = (tail.start + tail.end) as f32 / 2.0;
2557 let distance = (head_center - tail_center).abs() / text_char_len;
2558 let proximity_score = 1.0 - distance.min(1.0);
2559
2560 let head_type = head.entity_type.as_label();
2562 let tail_type = tail.entity_type.as_label();
2563 let type_relations = get_likely_relations(head_type, tail_type);
2564
2565 let (span_start, span_end) = if head.end < tail.start {
2567 (head.end, tail.start)
2568 } else if tail.end < head.start {
2569 (tail.end, head.start)
2570 } else {
2571 let min_start = head.start.min(tail.start);
2573 let max_end = head.end.max(tail.end);
2574 (
2575 min_start.saturating_sub(20),
2576 (max_end + 20).min(text_char_count),
2577 )
2578 };
2579
2580 let between_text = if span_end > span_start && span_end <= text_char_count {
2581 crate::offset::TextSpan::from_chars(text, span_start, span_end).extract(text)
2582 } else {
2583 ""
2584 };
2585 let between_lower = between_text.to_ascii_lowercase();
2586
2587 for (trigger, rel_type) in &trigger_patterns {
2589 let hit = if trigger.is_ascii() {
2590 between_lower.contains(trigger)
2591 } else {
2592 between_text.contains(trigger)
2593 };
2594 if hit {
2595 if !relation_types.is_empty()
2598 && pick_relation_label(rel_type, relation_types).is_none()
2599 {
2600 continue;
2601 }
2602
2603 let out_label = pick_relation_label(rel_type, relation_types)
2604 .unwrap_or_else(|| rel_type.to_string());
2605
2606 let confidence = (proximity_score * 0.6 + 0.4)
2607 * (head.confidence + tail.confidence) as f32
2608 / 2.0;
2609 if confidence >= threshold {
2610 relations.push(RelationTriple {
2611 head_idx: i,
2612 tail_idx: j,
2613 relation_type: out_label,
2614 confidence,
2615 });
2616 }
2617 }
2618 }
2619
2620 let has_trigger_relation = relations.iter().any(|r| r.head_idx == i && r.tail_idx == j);
2622 if !has_trigger_relation && proximity_score > 0.3 {
2623 for (rel_type, base_score) in type_relations {
2624 if !relation_types.is_empty()
2625 && pick_relation_label(rel_type, relation_types).is_none()
2626 {
2627 continue;
2628 }
2629 let out_label = pick_relation_label(rel_type, relation_types)
2630 .unwrap_or_else(|| rel_type.to_string());
2631
2632 let confidence =
2633 proximity_score * base_score * (head.confidence + tail.confidence) as f32
2634 / 2.0;
2635 if confidence >= threshold {
2636 relations.push(RelationTriple {
2637 head_idx: i,
2638 tail_idx: j,
2639 relation_type: out_label,
2640 confidence,
2641 });
2642 break; }
2644 }
2645 }
2646 }
2647 }
2648
2649 relations.sort_by(|a, b| {
2651 b.confidence
2652 .partial_cmp(&a.confidence)
2653 .unwrap_or(std::cmp::Ordering::Equal)
2654 });
2655
2656 let mut seen_pairs = std::collections::HashSet::new();
2658 relations.retain(|r| seen_pairs.insert((r.head_idx, r.tail_idx)));
2659
2660 relations
2661}
2662
#[cfg(feature = "onnx")]
impl RelationExtractor for GLiNER2Onnx {
    /// Runs NER via the ONNX path, then derives relations between the
    /// extracted entities with the shared proximity/trigger heuristic.
    fn extract_with_relations(
        &self,
        text: &str,
        types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> Result<ExtractionWithRelations> {
        // Entities first; relation inference operates purely on the spans.
        let entities = self.extract_ner(text, types, threshold)?;
        Ok(ExtractionWithRelations {
            relations: extract_relations_heuristic(&entities, text, relation_types, threshold),
            entities,
        })
    }
}
2684
#[cfg(feature = "candle")]
impl RelationExtractor for GLiNER2Candle {
    /// Runs NER via the Candle path, then derives relations between the
    /// extracted entities with the shared proximity/trigger heuristic.
    fn extract_with_relations(
        &self,
        text: &str,
        types: &[&str],
        relation_types: &[&str],
        threshold: f32,
    ) -> Result<ExtractionWithRelations> {
        // The Candle extractor takes owned label strings.
        let owned_types: Vec<String> = types.iter().map(ToString::to_string).collect();
        let entities = self.extract_entities(text, &owned_types, threshold)?;
        let relations = extract_relations_heuristic(&entities, text, relation_types, threshold);
        Ok(ExtractionWithRelations {
            entities,
            relations,
        })
    }
}
2706
#[cfg(feature = "onnx")]
impl crate::BatchCapable for GLiNER2Onnx {
    /// Runs NER over a batch of texts in a single ONNX session call.
    ///
    /// Each text is whitespace-tokenized, encoded against a fixed default
    /// label set, right-padded to the longest sequence in the batch, and
    /// decoded by `decode_ner_batch_output`. The `_language` hint is
    /// currently ignored.
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        // Label set used when the batch API has no caller-supplied types.
        let default_types = &["person", "organization", "location", "date", "event"];

        // Whitespace tokenization per text; word-level offsets are recovered
        // later during decoding.
        let text_words: Vec<Vec<&str>> = texts
            .iter()
            .map(|t| t.split_whitespace().collect())
            .collect();

        // Every text empty or whitespace-only: nothing to run.
        let max_words = text_words.iter().map(|w| w.len()).max().unwrap_or(0);
        if max_words == 0 {
            return Ok(texts.iter().map(|_| Vec::new()).collect());
        }

        let mut all_input_ids = Vec::new();
        let mut all_attention_masks = Vec::new();
        let mut all_words_masks = Vec::new();
        let mut all_text_lengths = Vec::new();
        let mut seq_lens = Vec::new();

        for words in &text_words {
            if words.is_empty() {
                // NOTE(review): a whitespace-only text records seq_len 0 but
                // contributes no tensor row, so the batch ends up with fewer
                // rows than `texts` has entries. Verify that
                // decode_ner_batch_output re-aligns rows with `texts` in this
                // mixed case; otherwise results could shift by one index.
                seq_lens.push(0);
                continue;
            }

            let (input_ids, attention_mask, words_mask) =
                self.encode_ner_prompt(words, default_types)?;
            seq_lens.push(input_ids.len());
            all_input_ids.push(input_ids);
            all_attention_masks.push(attention_mask);
            all_words_masks.push(words_mask);
            all_text_lengths.push(words.len() as i64);
        }

        if seq_lens.iter().all(|&l| l == 0) {
            return Ok(texts.iter().map(|_| Vec::new()).collect());
        }

        let max_seq_len = seq_lens.iter().copied().max().unwrap_or(0);

        // Right-pad ids and masks with zeros to the batch maximum so the
        // flattened buffers form a rectangular (batch, seq) tensor.
        for i in 0..all_input_ids.len() {
            let pad_len = max_seq_len - all_input_ids[i].len();
            all_input_ids[i].extend(std::iter::repeat_n(0i64, pad_len));
            all_attention_masks[i].extend(std::iter::repeat_n(0i64, pad_len));
            all_words_masks[i].extend(std::iter::repeat_n(0i64, pad_len));
        }

        use ndarray::Array2;

        let batch_size = all_input_ids.len();

        // Flatten to row-major buffers for 2-D tensor construction.
        let input_ids_flat: Vec<i64> = all_input_ids.into_iter().flatten().collect();
        let attention_mask_flat: Vec<i64> = all_attention_masks.into_iter().flatten().collect();
        let words_mask_flat: Vec<i64> = all_words_masks.into_iter().flatten().collect();

        // Sanity check that padding really produced a full rectangle.
        let expected_input_len = batch_size * max_seq_len;
        if input_ids_flat.len() != expected_input_len {
            return Err(Error::Parse(format!(
                "Input IDs length mismatch: expected {}, got {}",
                expected_input_len,
                input_ids_flat.len()
            )));
        }

        let input_ids_arr = Array2::from_shape_vec((batch_size, max_seq_len), input_ids_flat)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let attention_mask_arr =
            Array2::from_shape_vec((batch_size, max_seq_len), attention_mask_flat)
                .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        let words_mask_arr = Array2::from_shape_vec((batch_size, max_seq_len), words_mask_flat)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;
        // text_lengths is shaped (batch, 1): one word count per row.
        let text_lengths_arr = Array2::from_shape_vec((batch_size, 1), all_text_lengths)
            .map_err(|e| Error::Parse(format!("Array: {}", e)))?;

        let input_ids_t = super::ort_compat::tensor_from_ndarray(input_ids_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let attention_mask_t = super::ort_compat::tensor_from_ndarray(attention_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let words_mask_t = super::ort_compat::tensor_from_ndarray(words_mask_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;
        let text_lengths_t = super::ort_compat::tensor_from_ndarray(text_lengths_arr)
            .map_err(|e| Error::Parse(format!("Tensor: {}", e)))?;

        // Session access is serialized behind a mutex; one batched run.
        let mut session = lock(&self.session);

        let outputs = session
            .run(ort::inputs![
                "input_ids" => input_ids_t.into_dyn(),
                "attention_mask" => attention_mask_t.into_dyn(),
                "words_mask" => words_mask_t.into_dyn(),
                "text_lengths" => text_lengths_t.into_dyn(),
            ])
            .map_err(|e| Error::Inference(format!("ONNX batch run: {}", e)))?;

        // NOTE(review): confidence threshold is hard-coded to 0.5 on the
        // batch path, while single-text APIs accept a caller threshold —
        // confirm this asymmetry is intentional.
        self.decode_ner_batch_output(&outputs, texts, &text_words, default_types, 0.5)
    }

    /// Preferred number of texts per batch for this backend.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(16)
    }
}
2838
#[cfg(feature = "candle")]
impl crate::BatchCapable for GLiNER2Candle {
    /// Batch extraction for the Candle backend.
    ///
    /// Warms the label-embedding cache once for the default label set, then
    /// processes each text sequentially (no tensor-level batching here).
    fn extract_entities_batch(
        &self,
        texts: &[&str],
        _language: Option<&str>,
    ) -> Result<Vec<Vec<Entity>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let default_types: Vec<String> = ["person", "organization", "location", "date", "event"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        // Encode the labels up front so every per-text call hits the cache.
        let label_refs: Vec<&str> = default_types.iter().map(String::as_str).collect();
        let _ = self.encode_labels_cached(&label_refs)?;

        // Fail fast on the first extraction error, like a `?` inside a loop.
        texts
            .iter()
            .map(|text| self.extract_entities(text, &default_types, 0.5))
            .collect()
    }

    /// Preferred number of texts per batch for this backend.
    fn optimal_batch_size(&self) -> Option<usize> {
        Some(8)
    }
}
2877
#[cfg(feature = "onnx")]
impl crate::StreamingCapable for GLiNER2Onnx {
    /// Suggested chunk size when streaming long inputs through this backend.
    fn recommended_chunk_size(&self) -> usize {
        const DEFAULT_CHUNK: usize = 4096;
        DEFAULT_CHUNK
    }
}
2890
#[cfg(feature = "candle")]
impl crate::StreamingCapable for GLiNER2Candle {
    /// Suggested chunk size when streaming long inputs through this backend.
    fn recommended_chunk_size(&self) -> usize {
        // Same fixed size as the ONNX backend.
        const DEFAULT_CHUNK: usize = 4096;
        DEFAULT_CHUNK
    }
}
2899
#[cfg(feature = "candle")]
impl crate::GpuCapable for GLiNER2Candle {
    /// True whenever the model is placed on anything other than the CPU
    /// (i.e. a Metal or CUDA device).
    fn is_gpu_active(&self) -> bool {
        !matches!(&self.device, Device::Cpu)
    }

    /// Short lowercase name of the device the model currently runs on.
    fn device(&self) -> &str {
        match &self.device {
            Device::Cuda(_) => "cuda",
            Device::Metal(_) => "metal",
            Device::Cpu => "cpu",
        }
    }
}
2918
#[cfg(test)]
mod tests {
    use super::*;

    // The relation heuristic must be safe on multibyte text (CJK, accented
    // Latin) — i.e. operate on char offsets, not byte offsets — and match
    // ASCII trigger words case-insensitively ("CEO" vs the "ceo" trigger).
    #[test]
    #[cfg(any(feature = "onnx", feature = "candle"))]
    fn test_relation_heuristic_unicode_safe_and_case_insensitive() {
        use crate::backends::inference::RelationTriple;
        use crate::offset::bytes_to_chars;

        let text = "Dr. 田中 is CEO of Apple Inc. in 東京. François works at OpenAI.";
        // Locate `needle` in `text` and convert its byte range to char
        // offsets, matching the char-based spans that Entity uses.
        let span = |needle: &str| {
            let (b_start, _) = text
                .match_indices(needle)
                .next()
                .expect("needle should exist in test text");
            let b_end = b_start + needle.len();
            bytes_to_chars(text, b_start, b_end)
        };

        let (s, e) = span("田中");
        let e_tanaka = Entity::new("田中", EntityType::Person, s, e, 0.9);
        let (s, e) = span("Apple Inc.");
        let e_apple = Entity::new("Apple Inc.", EntityType::Organization, s, e, 0.9);
        let (s, e) = span("東京");
        let e_tokyo = Entity::new("東京", EntityType::Location, s, e, 0.9);
        let (s, e) = span("François");
        let e_francois = Entity::new("François", EntityType::Person, s, e, 0.9);
        let (s, e) = span("OpenAI");
        let e_openai = Entity::new("OpenAI", EntityType::Organization, s, e, 0.9);

        let entities = vec![e_tanaka, e_apple, e_tokyo, e_francois, e_openai];

        // Empty relation-type filter and zero threshold: accept everything
        // the heuristic produces, then check a trigger-based relation exists.
        let rels: Vec<RelationTriple> = extract_relations_heuristic(&entities, text, &[], 0.0);
        assert!(
            rels.iter()
                .any(|r| r.relation_type == "CEO_OF" || r.relation_type == "WORKS_FOR"),
            "expected at least one trigger-based relation, got {:?}",
            rels
        );
    }

    // Builder methods accumulate tasks without clobbering each other.
    #[test]
    fn test_task_schema_builder() {
        let schema = TaskSchema::new()
            .with_entities(&["person", "organization"])
            .with_classification("sentiment", &["positive", "negative"], false);

        assert!(schema.entities.is_some());
        assert_eq!(schema.entities.as_ref().unwrap().types.len(), 2);
        assert_eq!(schema.classifications.len(), 1);
    }

    // StructureTask collects fields in insertion order; a choice field
    // records its allowed values.
    #[test]
    fn test_structure_task_builder() {
        let task = StructureTask::new("product")
            .with_field("name", FieldType::String)
            .with_field_described("price", FieldType::String, "Product price in USD")
            .with_choice_field("category", &["electronics", "clothing"]);

        assert_eq!(task.fields.len(), 3);
        assert_eq!(task.fields[2].choices.as_ref().unwrap().len(), 2);
    }

    // word_span_to_char_offsets maps inclusive word indices back to char
    // offsets that slice the original whitespace-separated text.
    #[test]
    fn test_word_span_to_char_offsets() {
        use crate::offset::TextSpan;

        let text = "John works at Apple";
        let words: Vec<&str> = text.split_whitespace().collect();

        // Single word at the start.
        let (start, end) = word_span_to_char_offsets(text, &words, 0, 0);
        assert_eq!(TextSpan::from_chars(text, start, end).extract(text), "John");

        // Single word at the end.
        let (start, end) = word_span_to_char_offsets(text, &words, 3, 3);
        assert_eq!(
            TextSpan::from_chars(text, start, end).extract(text),
            "Apple"
        );

        // Multi-word span keeps interior whitespace.
        let (start, end) = word_span_to_char_offsets(text, &words, 0, 2);
        assert_eq!(
            TextSpan::from_chars(text, start, end).extract(text),
            "John works at"
        );
    }

    // map_entity_type: case-insensitive mapping of known labels to built-in
    // variants; unknown labels become Other (upper-cased) or Custom.
    #[test]
    fn test_map_entity_type() {
        assert!(matches!(map_entity_type("person"), EntityType::Person));
        assert!(matches!(
            map_entity_type("ORGANIZATION"),
            EntityType::Organization
        ));
        assert!(matches!(map_entity_type("loc"), EntityType::Location));
        assert!(
            matches!(map_entity_type("custom_type"), EntityType::Other(ref s) if s == "CUSTOM_TYPE")
        );
        assert!(matches!(
            map_entity_type("product"),
            EntityType::Custom { .. }
        ));
        assert!(matches!(
            map_entity_type("event"),
            EntityType::Custom { .. }
        ));
    }
}