1use crate::util::result::Result;
2use crate::commons::output::tensors::OutputTensors;
3
/// Name of the model output tensor that `TextSimilarities::try_from` reads scores from.
const TENSOR_LOGITS: &str = "logits";
5
6pub struct TextSimilarities {
8 pub scores: ndarray::Array1<f32>,
9}
10
11
12impl TextSimilarities {
13 pub fn try_from(tensors: OutputTensors, sigmoid: bool) -> Result<Self> {
14 let last_hidden_state = tensors.outputs.get(TENSOR_LOGITS).ok_or_else(|| format!("expected tensor not found in model output: {TENSOR_LOGITS}"))?;
16 let scores = last_hidden_state.try_extract_tensor::<f32>()?;
17 let scores = scores.slice(ndarray::s!(.., 0));
19 let scores = if sigmoid { crate::util::math::sigmoid_a(&scores) } else { scores.into_owned() };
21 Ok(Self { scores })
23 }
24}