gliner/model/output/decoded/
token_flat.rs

1//! Experimental alternative for the first step of span decoding (in token mode)
2
3use std::iter;
4use composable::Composable;
5use crate::util::result::Result;
6use crate::model::output::tensors::TensorOutput;
7use crate::{model::pipeline::context::EntityContext, text::span::Span};
8use crate::util::math::sigmoid;
9use super::SpanOutput;
10
11
/// *Experimental* token decoder with a one-dimensional approach: it works directly on the flat
/// representation of the model output, using one padding (stride) per dimension to reach the
/// appropriate value.
///
/// Not very readable, but it might be interesting from a performance standpoint. It still has to
/// be benchmarked and checked for accuracy against the original implementation. In the meantime,
/// prefer `token.rs`, which performs the same operation in a much more readable way, based on the
/// four-dimensional output tensor.
#[derive(Debug, Clone)]
pub struct FlatTokenDecoder {
    /// Minimum score (compared after applying `sigmoid`) for a token score to be retained.
    threshold: f32,
}
20
21
22impl FlatTokenDecoder {
23
24    fn new(threshold: f32) -> Self {
25        Self {
26            threshold,
27        }
28    }
29
30    fn decode(&self, model_output: &[f32], input: &EntityContext) -> Result<Vec<Vec<Span>>> {        
31        let tokens = &input.tokens;
32        let batch_size = tokens.len();
33        let num_entities = input.entities.len();
34
35        // compute paddings to navigate the flattened tensor
36        let sequence_padding = input.num_words * num_entities;
37        let position_padding = batch_size * sequence_padding;
38        let token_padding = num_entities;
39
40        // prepare the set of spans
41        let mut spans: Vec<Vec<Span>> = iter::repeat_with(Vec::new).take(batch_size).collect();
42
43        // iterate over the whole vector
44        for start_idx in 0..position_padding {
45            // check the start token score is above threshold, otherwise continue
46            if sigmoid(Self::get(model_output, start_idx)) < self.threshold {
47                continue
48            }
49
50            // retrieve the appropriate indices
51            let sequence_id = (start_idx / sequence_padding) % batch_size;
52            let start_token = (start_idx / token_padding) % input.num_words;
53            let class = start_idx % num_entities;
54
55            // accumulators to compute the mean score of inside tokens
56            let mut sum = 0f32;
57            let mut count = 0usize;
58
59            // iterate over end tokens
60            let mut end_token = start_token;
61            let mut end_idx = start_idx + position_padding;
62
63            while (((end_idx / sequence_padding) % batch_size) == sequence_id) && (end_idx < 2 * position_padding) {
64                // check the end token score is above threshold, otherwise continue
65                if sigmoid(Self::get(model_output, end_idx)) >= self.threshold {
66                    // we won't consider a span at all if it contains a score below the threshold
67                    let score = sigmoid(Self::get(model_output, end_idx + position_padding));
68                    if score < self.threshold {
69                        break
70                    }
71                    // consume next inside token and update the results
72                    else {
73                        // compute the actual probability (score) for the current span
74                        sum += score;
75                        count += 1;
76                        let probability = sum / (count as f32);
77
78                        // actually create the span
79                        let span = input.create_span(sequence_id, start_token, end_token, class, probability)?;
80                        spans.get_mut(sequence_id).unwrap().push(span);
81                    }
82                }
83
84                // proceed
85                end_token += 1;
86                end_idx += token_padding;
87            }
88        }
89
90        Ok(spans)
91    }
92
93    #[inline] fn get(model_output: &[f32], index: usize) -> f32 {
94        *model_output.get(index).unwrap()
95    }
96}
97
98
99pub struct TensorsToDecoded {
100    decoder: FlatTokenDecoder,
101}
102
103impl TensorsToDecoded {
104    pub fn new(threshold: f32) -> Self {
105        Self { 
106            decoder: FlatTokenDecoder::new(threshold)
107        }
108    }
109}
110
111impl Composable<TensorOutput<'_>, SpanOutput> for TensorsToDecoded {
112    fn apply(&self, input: TensorOutput) -> Result<SpanOutput> {        
113        let logits = input.tensors.get("logits").ok_or("logits not found in model output")?;
114        let (_shape, logits) = logits.try_extract_raw_tensor::<f32>()?;
115        let spans = self.decoder.decode(logits, &input.context)?;        
116        Ok(SpanOutput::new(input.context.texts, input.context.entities, spans))      
117    }
118}