use ndarray::{Array2, ArrayView, Dim, IxDynImpl};
use crate::pooling;
use super::{OutputKey, OutputPrecedence};
pub struct SingleBatchOutput {
pub outputs: Vec<(String, ort::value::Value)>,
pub attention_mask_array: Array2<i64>,
}
impl SingleBatchOutput {
pub fn select_output(
&self,
precedence: &impl OutputPrecedence,
) -> anyhow::Result<ArrayView<'_, f32, Dim<IxDynImpl>>> {
let ort_output: &ort::value::Value = precedence
.key_precedence()
.find_map(|key| match key {
OutputKey::OnlyOne => {
if self.outputs.len() == 1 {
self.outputs.first().map(|(_, v)| v)
} else {
None
}
}
OutputKey::ByOrder(idx) => self.outputs.get(*idx).map(|(_, v)| v),
OutputKey::ByName(name) => {
self.outputs.iter().find(|(n, _)| n == name).map(|(_, v)| v)
}
})
.ok_or_else(|| {
anyhow::Error::msg(format!(
"No suitable output found in the outputs. Available outputs: {:?}",
self.outputs.iter().map(|(k, _)| k).collect::<Vec<_>>()
))
})?;
ort_output.try_extract_array().map_err(anyhow::Error::new)
}
pub fn select_and_pool_output(
&self,
precedence: &impl OutputPrecedence,
pooling_opt: Option<pooling::Pooling>,
) -> anyhow::Result<Array2<f32>> {
let tensor = self.select_output(precedence)?;
match pooling_opt.unwrap_or_default() {
pooling::Pooling::Cls => pooling::cls(&tensor),
pooling::Pooling::Mean => pooling::mean(&tensor, self.attention_mask_array.clone()),
}
}
}
/// Collection of per-batch outputs from an embedding run.
pub struct EmbeddingOutput {
    /// Batches in the order they were supplied to [`EmbeddingOutput::new`].
    batches: Vec<SingleBatchOutput>,
}

impl EmbeddingOutput {
    /// Builds an `EmbeddingOutput` by collecting `batches` into a vector, preserving
    /// iteration order.
    pub fn new(batches: impl IntoIterator<Item = SingleBatchOutput>) -> Self {
        let collected = batches.into_iter().collect::<Vec<_>>();
        Self { batches: collected }
    }

    /// Consumes `self` and hands back the underlying batches.
    pub fn into_raw(self) -> Vec<SingleBatchOutput> {
        let Self { batches } = self;
        batches
    }

    /// Runs `transformer` over a borrowed slice of all batches and returns its result.
    ///
    /// # Errors
    ///
    /// Returns whatever error `transformer` returns; this method adds none of its own.
    pub fn export_with_transformer<R>(
        &self,
        transformer: impl Fn(&[SingleBatchOutput]) -> anyhow::Result<R>,
    ) -> anyhow::Result<R> {
        let all_batches: &[SingleBatchOutput] = self.batches.as_slice();
        transformer(all_batches)
    }
}