// gte/rerank/output.rs

1use crate::util::result::Result;
2use crate::commons::output::tensors::OutputTensors;
3
/// Key under which the model output stores the re-ranking logits tensor.
const TENSOR_LOGITS: &str = "logits";
5
6/// Re-ranking output
7pub struct TextSimilarities {
8    pub scores: ndarray::Array1<f32>,
9}
10
11
12impl TextSimilarities {
13    pub fn try_from(tensors: OutputTensors, sigmoid: bool) -> Result<Self> {
14        // extract last hidden state from `ort` output
15        let last_hidden_state = tensors.outputs.get(TENSOR_LOGITS).ok_or_else(|| format!("expected tensor not found in model output: {TENSOR_LOGITS}"))?;
16        let scores = last_hidden_state.try_extract_tensor::<f32>()?;
17        // reduce superfluous dimensionality 
18        let scores = scores.slice(ndarray::s!(.., 0));
19        // apply sigmoid        
20        let scores = if sigmoid { crate::util::math::sigmoid_a(&scores) } else { scores.into_owned() };
21        // job's done
22        Ok(Self {  scores })
23    }
24}