// gliner/model/input/tensors/span.rs
use ort::session::SessionInputs;
2use composable::Composable;
3use crate::util::result::Result;
4use super::super::encoded::EncodedInput;
5use super::super::super::pipeline::context::EntityContext;
6
7
8const TENSOR_INPUT_IDS: &str = "input_ids";
9const TENSOR_ATTENTION_MASK: &str = "attention_mask";
10const TENSOR_WORD_MASK: &str = "words_mask";
11const TENSOR_TEXT_LENGTHS: &str = "text_lengths";
12const TENSOR_SPAN_IDX: &str = "span_idx";
13const TENSOR_SPAN_MASK: &str = "span_mask";
14
15
/// Inputs for the span-mode model: the named ONNX session tensors together
/// with the [`EntityContext`] that downstream stages need to map model
/// output back onto the original texts, tokens and entity labels.
pub struct SpanTensors<'a> {
    // Named input tensors (input_ids, attention_mask, words_mask,
    // text_lengths, span_idx, span_mask) ready to feed to the session.
    pub tensors: SessionInputs<'a, 'a>,
    // Carried alongside the tensors so predictions can be decoded later.
    pub context: EntityContext,
}
21
22impl SpanTensors<'_> {
23
24 pub fn from(encoded: EncodedInput, max_width: usize) -> Result<Self> {
25 let (span_idx, span_mask) = Self::make_spans_tensors(&encoded, max_width);
26 let inputs = ort::inputs!{
27 TENSOR_INPUT_IDS => encoded.input_ids,
28 TENSOR_ATTENTION_MASK => encoded.attention_masks,
29 TENSOR_WORD_MASK => encoded.word_masks,
30 TENSOR_TEXT_LENGTHS => encoded.text_lengths,
31 TENSOR_SPAN_IDX => span_idx,
32 TENSOR_SPAN_MASK => span_mask,
33 }?;
34 Ok(Self {
35 tensors: inputs.into(),
36 context: EntityContext {
37 texts: encoded.texts,
38 tokens: encoded.tokens,
39 entities: encoded.entities,
40 num_words: encoded.num_words
41 },
42 })
43 }
44
45 pub fn inputs() -> [&'static str; 6] {
46 [TENSOR_INPUT_IDS, TENSOR_ATTENTION_MASK, TENSOR_WORD_MASK, TENSOR_TEXT_LENGTHS, TENSOR_SPAN_IDX, TENSOR_SPAN_MASK]
47 }
48
49 fn make_spans_tensors(encoded: &EncodedInput, max_width: usize) -> (ndarray::Array3<i64>, ndarray::Array2<bool>) {
72 let num_spans = encoded.num_words * max_width;
74
75 let mut span_idx = ndarray::Array::zeros((encoded.texts.len(), num_spans, 2));
77 let mut span_mask = ndarray::Array::from_elem((encoded.texts.len(), num_spans), false);
78
79 for s in 0..encoded.texts.len() {
81 let text_width = *encoded.text_lengths.get((s, 0)).unwrap() as usize;
83
84 for start in 0..text_width {
86 let remaining_width = text_width - start;
88 let actual_max_width = std::cmp::min(max_width, remaining_width);
90 for width in 0..actual_max_width {
92 let dim = start * max_width + width;
94 span_idx[[s, dim, 0]] = start as i64; span_idx[[s, dim, 1]] = (start + width) as i64; span_mask[[s, dim]] = true; }
99 }
100 }
101
102 (span_idx, span_mask)
104 }
105
106}
107
108
/// Composable pipeline step that converts an `EncodedInput` into
/// [`SpanTensors`] (see the `Composable` impl below).
pub struct EncodedToTensors {
    // Maximum span width, in words, used when enumerating candidate spans.
    max_width: usize,
}
113
114impl EncodedToTensors {
115 pub fn new(max_width: usize) -> Self {
116 Self { max_width }
117 }
118}
119
// Pipeline stage: encoded batch -> span tensors, delegating to
// `SpanTensors::from` with the configured maximum span width.
impl<'a> Composable<EncodedInput, SpanTensors<'a>> for EncodedToTensors {
    fn apply(&self, input: EncodedInput) -> Result<SpanTensors<'a>> {
        SpanTensors::from(input, self.max_width)
    }
}
125
126
/// Composable pipeline step that splits [`SpanTensors`] into its session
/// inputs and entity context, so inference and decoding can consume them
/// separately (see the `Composable` impl below).
#[derive(Default)]
pub struct TensorsToSessionInput {
}
131
132
133impl<'a> Composable<SpanTensors<'a>, (SessionInputs<'a, 'a>, EntityContext)> for TensorsToSessionInput {
134 fn apply(&self, input: SpanTensors<'a>) -> Result<(SessionInputs<'a, 'a>, EntityContext)> {
135 Ok((input.tensors, input.context))
136 }
137}
138
139
#[cfg(test)]
mod tests {
    use ort::session::SessionInputValue;
    use super::*;

    /// End-to-end: raw text batch -> tokenized -> prompted -> encoded ->
    /// span tensors, then checks the derived `span_idx` / `span_mask`
    /// tensors have the expected shapes (batch = 2, and with max_width = 12
    /// the longest prompt yields 7 * 12 = 84 candidate span slots).
    #[test]
    fn test() -> Result<()> {
        #![allow(clippy::get_first)]
        #![allow(clippy::unwrap_used)]
        let splitter = crate::text::splitter::RegexSplitter::default();
        let tokenizer = crate::text::tokenizer::HFTokenizer::from_file(std::path::Path::new("models/gliner_small-v2.1/tokenizer.json"))?;
        let batch = [ "My name is James Bond", "I like to drive my Aston Martin"];
        let entities = [ "movie character", "vehicle" ];
        let input = super::super::super::text::TextInput::from_str(&batch, &entities)?;
        let tokenized = super::super::super::tokenized::TokenizedInput::from(input, &splitter, None)?;
        let prepared = super::super::super::prompt::PromptInput::from(tokenized);
        let encoded = EncodedInput::from(prepared, &tokenizer)?;
        let spans = SpanTensors::from(encoded, 12)?;
        let span_idx = get_tensor("span_idx", &spans.tensors)?;
        let span_idx = span_idx.try_extract_tensor::<i64>()?;
        let span_masks = get_tensor("span_mask", &spans.tensors)?;
        let span_masks = span_masks.try_extract_tensor::<bool>()?;
        // Compare against plain arrays rather than `vec![...]` — no
        // allocation needed (clippy::useless_vec). The dead
        // `if false { println!(...) }` debug block was removed.
        assert_eq!(span_idx.shape(), [2, 84, 2]);
        assert_eq!(span_masks.shape(), [2, 84]);
        Ok(())
    }

    /// Looks up a named tensor among the session inputs; errors when the
    /// inputs are not a value map or the key is absent.
    fn get_tensor<'a>(key: &str, si: &'a SessionInputs<'a, 'a>) -> Result<&'a SessionInputValue<'a>> {
        if let SessionInputs::ValueMap(map) = si {
            for (k, v) in map {
                if k.eq(key) {
                    return Ok(v);
                }
            }
        }
        Err("cannot extract expected tensor".into())
    }

}