gliclass/output/
classes.rs1use std::collections::{BTreeMap, HashMap};
2use ordered_float::OrderedFloat;
3
4use crate::{input::text::Labels, util::result::Result};
5use super::tensors::OutputTensors;
6
/// Name of the model output tensor that is looked up to obtain the raw logits.
const TENSOR_LOGITS: &str = "logits";
8
9pub struct Classes {
11 pub scores: ndarray::Array2<f32>,
12 pub labels: Labels,
13}
14
15
16impl Classes {
17 pub fn try_from(tensors: OutputTensors) -> Result<Self> {
18 let logits = tensors.outputs.get(TENSOR_LOGITS).ok_or_else(|| format!("expected tensor not found in model output: {TENSOR_LOGITS}"))?;
19 let scores = logits.try_extract_tensor::<f32>()?;
20 let scores = scores.into_dimensionality::<ndarray::Ix2>()?;
21 let scores = crate::util::math::sigmoid_a(&scores);
22 Ok(Self { scores, labels: tensors.labels })
23 }
24
25 pub fn scores(&self, index: usize) -> Option<Vec<f32>> {
27 if index < self.scores.nrows() { Some(self.scores.row(index).to_vec()) } else { None }
28 }
29
30 pub fn labeled_scores(&self, index: usize, threshold: Option<f32>) -> Option<HashMap<&String, f32>> {
32 Some(self.labels.get(index)?
33 .into_iter()
34 .zip(self.scores(index)?)
35 .filter(|(_, score)| *score >= threshold.unwrap_or(0.0))
36 .collect()
37 )
38 }
39
40 pub fn ordered_scores(&self, index: usize, threshold: Option<f32>) -> Option<BTreeMap<OrderedFloat<f32>, &String>> {
42 Some(self.scores(index)?
43 .into_iter()
44 .zip(self.labels.get(index)?)
45 .filter(|(score, _)| *score >= threshold.unwrap_or(0.0))
46 .map(|(score, label)| (OrderedFloat::from(score), label))
47 .collect()
48 )
49 }
50
51 pub fn best_label(&self, text_index: usize, threshold: Option<f32>) -> Option<&str> {
53 let label_index = self
54 .scores(text_index)?
55 .into_iter().enumerate()
56 .filter(|(_, score)| *score >= threshold.unwrap_or(0.0))
57 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
58 .map(|(i, _)| i)?;
59 self.labels
60 .get(text_index)?
61 .get(label_index)
62 .map(String::as_str)
63 }
64
65 pub fn len(&self) -> usize {
66 self.scores.nrows()
67 }
68
69}