gliclass/output/
classes.rs

1use std::collections::{BTreeMap, HashMap};
2use ordered_float::OrderedFloat;
3
4use crate::{input::text::Labels, util::result::Result};
5use super::tensors::OutputTensors;
6
/// Name of the model output tensor from which the classification logits are read.
const TENSOR_LOGITS: &str = "logits";

9/// End-result of the classification inference pipeline.
10pub struct Classes {
11    pub scores: ndarray::Array2<f32>,
12    pub labels: Labels,
13}
14
15
16impl Classes {
17    pub fn try_from(tensors: OutputTensors) -> Result<Self> {        
18        let logits = tensors.outputs.get(TENSOR_LOGITS).ok_or_else(|| format!("expected tensor not found in model output: {TENSOR_LOGITS}"))?;
19        let scores = logits.try_extract_tensor::<f32>()?;
20        let scores = scores.into_dimensionality::<ndarray::Ix2>()?;
21        let scores = crate::util::math::sigmoid_a(&scores);
22        Ok(Self { scores, labels: tensors.labels })
23    }
24
25    /// Returns the scores for each label, for the given text index, in their original order
26    pub fn scores(&self, index: usize) -> Option<Vec<f32>> {
27        if index < self.scores.nrows() { Some(self.scores.row(index).to_vec()) } else { None }
28    }
29
30    /// Returns the scores by label, for the given text index (with an optional threshold)
31    pub fn labeled_scores(&self, index: usize, threshold: Option<f32>) -> Option<HashMap<&String, f32>> {
32        Some(self.labels.get(index)?
33            .into_iter()
34            .zip(self.scores(index)?)
35            .filter(|(_, score)| *score >= threshold.unwrap_or(0.0))            
36            .collect()
37        )
38    }
39
40    /// Returns the ordered scores by label, for the given text index (with an optional threshold)
41    pub fn ordered_scores(&self, index: usize, threshold: Option<f32>) -> Option<BTreeMap<OrderedFloat<f32>, &String>> {
42        Some(self.scores(index)?
43            .into_iter()
44            .zip(self.labels.get(index)?)
45            .filter(|(score, _)| *score >= threshold.unwrap_or(0.0))
46            .map(|(score, label)| (OrderedFloat::from(score), label))
47            .collect()
48        )
49    }
50
51    /// Returns the best label for the given text index
52    pub fn best_label(&self, text_index: usize, threshold: Option<f32>) -> Option<&str> {
53        let label_index = self
54            .scores(text_index)?
55            .into_iter().enumerate()
56            .filter(|(_, score)| *score >= threshold.unwrap_or(0.0))
57            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
58            .map(|(i, _)| i)?;
59        self.labels
60            .get(text_index)?
61            .get(label_index)
62            .map(String::as_str)
63    }
64
65    pub fn len(&self) -> usize {
66        self.scores.nrows()
67    }
68
69}