gliner/model/output/decoded/
token_flat.rs1use std::iter;
4use composable::Composable;
5use crate::util::result::Result;
6use crate::model::output::tensors::TensorOutput;
7use crate::{model::pipeline::context::EntityContext, text::span::Span};
8use crate::util::math::sigmoid;
9use super::SpanOutput;
10
11
/// Decoder for flat, token-level model output.
///
/// It carries the single setting the decoding pass needs: the cut-off that a
/// sigmoid-activated score has to reach before it is accepted.
pub struct FlatTokenDecoder {
    /// Lower bound compared against `sigmoid(logit)`; scores below it are rejected.
    threshold: f32,
}
20
21
22impl FlatTokenDecoder {
23
24 fn new(threshold: f32) -> Self {
25 Self {
26 threshold,
27 }
28 }
29
30 fn decode(&self, model_output: &[f32], input: &EntityContext) -> Result<Vec<Vec<Span>>> {
31 let tokens = &input.tokens;
32 let batch_size = tokens.len();
33 let num_entities = input.entities.len();
34
35 let sequence_padding = input.num_words * num_entities;
37 let position_padding = batch_size * sequence_padding;
38 let token_padding = num_entities;
39
40 let mut spans: Vec<Vec<Span>> = iter::repeat_with(Vec::new).take(batch_size).collect();
42
43 for start_idx in 0..position_padding {
45 if sigmoid(Self::get(model_output, start_idx)) < self.threshold {
47 continue
48 }
49
50 let sequence_id = (start_idx / sequence_padding) % batch_size;
52 let start_token = (start_idx / token_padding) % input.num_words;
53 let class = start_idx % num_entities;
54
55 let mut sum = 0f32;
57 let mut count = 0usize;
58
59 let mut end_token = start_token;
61 let mut end_idx = start_idx + position_padding;
62
63 while (((end_idx / sequence_padding) % batch_size) == sequence_id) && (end_idx < 2 * position_padding) {
64 if sigmoid(Self::get(model_output, end_idx)) >= self.threshold {
66 let score = sigmoid(Self::get(model_output, end_idx + position_padding));
68 if score < self.threshold {
69 break
70 }
71 else {
73 sum += score;
75 count += 1;
76 let probability = sum / (count as f32);
77
78 let span = input.create_span(sequence_id, start_token, end_token, class, probability)?;
80 spans.get_mut(sequence_id).unwrap().push(span);
81 }
82 }
83
84 end_token += 1;
86 end_idx += token_padding;
87 }
88 }
89
90 Ok(spans)
91 }
92
93 #[inline] fn get(model_output: &[f32], index: usize) -> f32 {
94 *model_output.get(index).unwrap()
95 }
96}
97
98
99pub struct TensorsToDecoded {
100 decoder: FlatTokenDecoder,
101}
102
103impl TensorsToDecoded {
104 pub fn new(threshold: f32) -> Self {
105 Self {
106 decoder: FlatTokenDecoder::new(threshold)
107 }
108 }
109}
110
111impl Composable<TensorOutput<'_>, SpanOutput> for TensorsToDecoded {
112 fn apply(&self, input: TensorOutput) -> Result<SpanOutput> {
113 let logits = input.tensors.get("logits").ok_or("logits not found in model output")?;
114 let (_shape, logits) = logits.try_extract_raw_tensor::<f32>()?;
115 let spans = self.decoder.decode(logits, &input.context)?;
116 Ok(SpanOutput::new(input.context.texts, input.context.entities, spans))
117 }
118}