// gliner/model/input/tensors/span.rs

1use ort::session::SessionInputs;
2use composable::Composable;
3use crate::util::result::Result;
4use super::super::encoded::EncodedInput;
5use super::super::super::pipeline::context::EntityContext;
6
7
// Names of the ONNX session input tensors expected by the span-mode model.
// These must match the input names declared in the exported model graph.
const TENSOR_INPUT_IDS: &str = "input_ids";           // subword token ids
const TENSOR_ATTENTION_MASK: &str = "attention_mask"; // padding mask for token ids
const TENSOR_WORD_MASK: &str = "words_mask";          // word-level mask
const TENSOR_TEXT_LENGTHS: &str = "text_lengths";     // word count per sequence
const TENSOR_SPAN_IDX: &str = "span_idx";             // (start, end) word offsets per candidate span
const TENSOR_SPAN_MASK: &str = "span_mask";           // validity mask for candidate spans
14
15
/// Ready-for-inference tensors (span mode)
pub struct SpanTensors<'a> {
    /// Named session inputs (input_ids, attention_mask, words_mask,
    /// text_lengths, span_idx, span_mask) ready to feed to the ONNX session.
    pub tensors: SessionInputs<'a, 'a>,
    /// Texts, tokens, entity labels and word counts carried over from the
    /// encoded input, needed later to decode the model output.
    pub context: EntityContext,
}
21
22impl SpanTensors<'_> {
23
24    pub fn from(encoded: EncodedInput, max_width: usize) -> Result<Self> {
25        let (span_idx, span_mask) = Self::make_spans_tensors(&encoded, max_width);
26        let inputs = ort::inputs!{
27            TENSOR_INPUT_IDS => encoded.input_ids,
28            TENSOR_ATTENTION_MASK => encoded.attention_masks,
29            TENSOR_WORD_MASK => encoded.word_masks,
30            TENSOR_TEXT_LENGTHS => encoded.text_lengths,
31            TENSOR_SPAN_IDX => span_idx,
32            TENSOR_SPAN_MASK => span_mask,
33        }?;
34        Ok(Self {
35            tensors: inputs.into(),
36            context: EntityContext { 
37                texts: encoded.texts, 
38                tokens: encoded.tokens, 
39                entities: encoded.entities, 
40                num_words: encoded.num_words 
41            },            
42        })
43    }
44
45    pub fn inputs() -> [&'static str; 6] {
46        [TENSOR_INPUT_IDS, TENSOR_ATTENTION_MASK, TENSOR_WORD_MASK, TENSOR_TEXT_LENGTHS, TENSOR_SPAN_IDX, TENSOR_SPAN_MASK]
47    }
48
49    /// Expected tensor for num_words=4 and max_width=12:
50    /// ```text
51    /// start, end, mask
52    /// 0, 0, true
53    /// 0, 1, true
54    /// 0, 2, true
55    /// 0, 3, true
56    /// 0, 0, false
57    /// [...until we have all the 12 spans for this token]
58    /// 1, 1, true
59    /// 1, 2, true
60    /// 1, 3, true
61    /// 0, 0, false
62    /// [...]
63    /// 2, 2, true
64    /// 2, 3, true
65    /// 0, 0, false
66    /// [...]
67    /// 3, 3, true
68    /// 0, 0, false
69    /// [...]
70    /// ```    
71    fn make_spans_tensors(encoded: &EncodedInput, max_width: usize) -> (ndarray::Array3<i64>, ndarray::Array2<bool>) {
72        // total number of spans for each sequence: at most num_words * max_width
73        let num_spans = encoded.num_words * max_width;
74        
75        // prepare output tensors (zero-filled, values will be set in place)
76        let mut span_idx = ndarray::Array::zeros((encoded.texts.len(), num_spans, 2));
77        let mut span_mask = ndarray::Array::from_elem((encoded.texts.len(), num_spans), false);
78
79        // iterate over segments
80        for s in 0..encoded.texts.len() {
81            // get the actual width of the current segment
82            let text_width = *encoded.text_lengths.get((s, 0)).unwrap() as usize;
83
84            // repeat for each start offset in [0;text_width]
85            for start in 0..text_width {          
86                // remaining width from start offset
87                let remaining_width = text_width - start;
88                // the maximum span width is no more than remaining width in the sequence, or maximum span width
89                let actual_max_width = std::cmp::min(max_width, remaining_width);
90                // repeat for each possible width in between
91                for width in 0..actual_max_width {
92                    // retrieve the appropriate dimension on the second axis
93                    let dim = start * max_width + width;
94                    // fill the tensors in place
95                    span_idx[[s, dim, 0]] = start as i64; // start offset
96                    span_idx[[s, dim, 1]] = (start + width) as i64; // end offset
97                    span_mask[[s, dim]] = true; // mask
98                }
99            }
100        }
101
102        // return both tensors
103        (span_idx, span_mask)
104    }
105
106}
107
108
/// Composable: Encoded => SpanTensors
pub struct EncodedToTensors {
    // Maximum span width (in words) passed through to tensor construction.
    max_width: usize,
}

impl EncodedToTensors {
    /// Creates the conversion step with the given maximum span width.
    pub fn new(max_width: usize) -> Self {
        EncodedToTensors { max_width }
    }
}
119
120impl<'a> Composable<EncodedInput, SpanTensors<'a>> for EncodedToTensors {
121    fn apply(&self, input: EncodedInput) -> Result<SpanTensors<'a>> {
122        SpanTensors::from(input, self.max_width)
123    }
124}
125
126
/// Composable: SpanTensors => (SessionInput, EntityContext)
///
/// Stateless step; declared as a unit struct (idiomatic for field-less types).
/// Both `TensorsToSessionInput {}` construction and `Default` keep working
/// for existing callers.
#[derive(Default)]
pub struct TensorsToSessionInput;
131
132
133impl<'a> Composable<SpanTensors<'a>, (SessionInputs<'a, 'a>, EntityContext)> for TensorsToSessionInput {
134    fn apply(&self, input: SpanTensors<'a>) -> Result<(SessionInputs<'a, 'a>, EntityContext)> {
135        Ok((input.tensors, input.context))
136    }
137}
138
139
/// Unit tests
#[cfg(test)]
mod tests {
    use ort::session::SessionInputValue;
    use super::*;

    #[test]
    fn test() -> Result<()> {
        // Silence some clippy warnings for unit tests
        #![allow(clippy::get_first)]
        #![allow(clippy::unwrap_used)]
        // Run the full preparation pipeline: text -> tokens -> prompt -> encoding -> tensors
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file(std::path::Path::new("models/gliner_small-v2.1/tokenizer.json"))?;
        let batch = [ "My name is James Bond", "I like to drive my Aston Martin"];
        let entities = [ "movie character", "vehicle" ];
        let input = super::super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = super::super::super::prompt::PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;
        let spans = SpanTensors::from(encoded, 12)?;
        // Extract the span tensors back out of the session inputs
        let span_idx = get_tensor("span_idx", &spans.tensors)?;
        let span_idx = span_idx.try_extract_tensor::<i64>()?;
        let span_masks = get_tensor("span_mask", &spans.tensors)?;
        let span_masks = span_masks.try_extract_tensor::<bool>()?;
        // Assertions (TODO: add more)
        // 7 words max in the batch * max_width 12 = 84 span slots per sequence
        assert_eq!(span_idx.shape(), vec![2, 84, 2]);
        assert_eq!(span_masks.shape(), vec![2, 84]);
        // Everything rules
        Ok(())
    }

    /// Looks up a session input value by tensor name; errors if the inputs
    /// are not a value map or the key is absent.
    fn get_tensor<'a>(key: &str, si: &'a SessionInputs<'a, 'a>) -> Result<&'a SessionInputValue<'a>> {
        if let SessionInputs::ValueMap(map) = si {
            for (k, v) in map {
                if k.eq(key) {
                    return Ok(v);
                }
            }
        }
        Err("cannot extract expected tensor".into())
    }

}