// gte/embed/output.rs

1use composable::Composable;
2use crate::commons::output::tensors::OutputTensors;
3
4
5/// Text embedding output
6pub struct TextEmbeddings {
7    pub embeddings: ndarray::Array2<f32>,
8}
9
10impl TextEmbeddings {
11    pub fn embeddings(&self, index: usize) -> ndarray::ArrayView1<f32> {
12        self.embeddings.slice(ndarray::s![index, ..])
13    }
14
15    pub fn len(&self) -> usize {
16        self.embeddings.dim().0
17    }
18
19    pub fn is_empty(&self) -> bool {
20        self.embeddings.is_empty()
21    }
22}
23
24
/// Name of a tensor in the model output, used to look that tensor up.
///
/// The default identifier is `"last_hidden_state"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OutputId(String);

impl Default for OutputId {
    fn default() -> Self {
        Self::from("last_hidden_state")
    }
}

impl From<String> for OutputId {
    fn from(name: String) -> Self {
        Self(name)
    }
}

impl From<&str> for OutputId {
    fn from(name: &str) -> Self {
        Self(String::from(name))
    }
}

impl AsRef<str> for OutputId {
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}

impl std::ops::Deref for OutputId {
    type Target = str;

    fn deref(&self) -> &str {
        self.0.as_str()
    }
}

impl std::fmt::Display for OutputId {
    /// Formats the bare identifier, honouring width/fill/alignment flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
34
35
/// Defines how sequence embeddings are extracted from the output tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ExtractorMode {
    /// The tensor is expected to directly provide one usable embedding per sequence
    /// (a 2-D tensor).
    Raw,
    /// The tensor is expected to provide one embedding per token; the vector of the
    /// token at the given position is used as the sequence embedding (usually the
    /// first token).
    Token(usize),
}

impl Default for ExtractorMode {
    /// Uses the embedding of the first token of each sequence.
    fn default() -> Self {
        ExtractorMode::Token(0)
    }
}
48
49
50/// Composable that extracts the embeddings from the output tensors, according to the specified output id and extraction mode
51#[derive(Default)]
52pub struct EmbeddingsExtractor {
53    output_id: OutputId,
54    mode: ExtractorMode,    
55}
56
57impl EmbeddingsExtractor {
58    pub fn new(output_id: &OutputId, mode: ExtractorMode) -> Self {
59        Self { 
60            output_id: output_id.clone(), 
61            mode,
62        }
63    }
64}
65
66impl Composable<OutputTensors<'_>, TextEmbeddings> for EmbeddingsExtractor {
67    fn apply(&self, output_tensors: OutputTensors) -> composable::Result<TextEmbeddings> {
68        // extract the tensor from the ORT output
69        let output_tensor = output_tensors.outputs.get(&self.output_id).ok_or_else(|| format!("tensor not found in model output: {}", self.output_id))?;
70        let output_tensor = output_tensor.try_extract_tensor::<f32>()?;
71        
72        // extract the actual embeddings depending on the desired mode
73        match self.mode {
74            ExtractorMode::Raw => {
75                // the raw output tensor is supposed to provide the actual embeddings by sequence
76                let embeddings = output_tensor.into_dimensionality::<ndarray::Ix2>()?;
77                Ok(TextEmbeddings { embeddings: embeddings.into_owned() })
78            },
79            ExtractorMode::Token(index) => {
80                // we select the selected token (by index) of each sequence         
81                let embeddings = output_tensor.slice(ndarray::s![.., index, ..]);
82                Ok(TextEmbeddings { embeddings: embeddings.into_owned() })
83            },
84        }
85    }
86}