use ndarray::{Array2, ArrayView, Dim, IxDynImpl};
use ort::session::SessionOutputs;
use crate::pooling;
use super::{OutputKey, OutputPrecedence};
/// Output of one batch: the raw `ort` session outputs together with the
/// attention mask that accompanies them.
///
/// The two lifetime parameters are forwarded unchanged to
/// [`SessionOutputs`].
pub struct SingleBatchOutput<'r, 's> {
/// Outputs produced by the `ort` session, looked up by key when an output
/// tensor is selected.
pub session_outputs: SessionOutputs<'r, 's>,
/// Two-dimensional `i64` attention mask; it is handed to mean pooling.
// NOTE(review): presumably shaped (batch, sequence) to match the selected
// tensor -- not verifiable from this file, confirm against the producer.
pub attention_mask_array: Array2<i64>,
}
impl SingleBatchOutput<'_, '_> {
    /// Picks one session output according to `precedence` and returns it as a
    /// dynamic-dimensional `f32` tensor view.
    ///
    /// The keys yielded by `precedence.key_precedence()` are tried in order;
    /// the first key that resolves to an existing output wins:
    ///
    /// * [`OutputKey::OnlyOne`] resolves to the first output of the session.
    /// * [`OutputKey::ByOrder`] resolves to the output at the given position
    ///   in the session's key order.
    /// * [`OutputKey::ByName`] resolves to the output with the given name.
    ///
    /// # Errors
    ///
    /// Returns an error that lists the available output names when none of
    /// the keys resolves, and forwards the extraction error when the chosen
    /// output cannot be extracted as an `f32` tensor.
    pub fn select_output(
        &self,
        precedence: &impl OutputPrecedence,
    ) -> anyhow::Result<ArrayView<f32, Dim<IxDynImpl>>> {
        let outputs = &self.session_outputs;
        let ort_output: &ort::value::Value = precedence
            .key_precedence()
            .find_map(|key| match key {
                // NOTE(review): this takes the first output without checking
                // that it is the only one; confirm that is the intended
                // meaning of `OnlyOne`.
                OutputKey::OnlyOne => outputs.get(outputs.keys().next()?),
                // An out-of-range index yields `None`, so the next key in the
                // precedence list is tried instead of failing immediately.
                OutputKey::ByOrder(idx) => outputs.get(outputs.keys().nth(*idx)?),
                OutputKey::ByName(name) => outputs.get(name),
            })
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "No suitable output found in the session outputs. Available outputs: {:?}",
                    outputs.keys().collect::<Vec<_>>()
                )
            })?;

        ort_output
            .try_extract_tensor::<f32>()
            .map_err(anyhow::Error::new)
    }

    /// Selects an output tensor via [`Self::select_output`] and reduces it
    /// with the requested pooling strategy.
    ///
    /// When `pooling_opt` is `None`, the `Default` value of
    /// [`pooling::Pooling`] is used.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::select_output`] or from the pooling
    /// function that is invoked.
    pub fn select_and_pool_output(
        &self,
        precedence: &impl OutputPrecedence,
        pooling_opt: Option<pooling::Pooling>,
    ) -> anyhow::Result<Array2<f32>> {
        let tensor = self.select_output(precedence)?;
        match pooling_opt.unwrap_or_default() {
            pooling::Pooling::Cls => pooling::cls(&tensor),
            // The mask is cloned because `self` is only borrowed here while
            // `pooling::mean` is handed an owned array.
            pooling::Pooling::Mean => pooling::mean(&tensor, self.attention_mask_array.clone()),
        }
    }
}
/// Owns a list of [`SingleBatchOutput`] values.
///
/// The batches are stored in the order in which they were supplied to the
/// constructor.
pub struct EmbeddingOutput<'r, 's> {
/// Stored per-batch outputs; private to this type.
batches: Vec<SingleBatchOutput<'r, 's>>,
}
impl<'r, 's> EmbeddingOutput<'r, 's> {
    /// Builds an `EmbeddingOutput` by collecting every item of `batches`,
    /// keeping the iteration order.
    pub fn new(batches: impl IntoIterator<Item = SingleBatchOutput<'r, 's>>) -> Self {
        let batches: Vec<_> = batches.into_iter().collect();
        Self { batches }
    }

    /// Consumes the wrapper and returns the stored batches untouched.
    pub fn into_raw(self) -> Vec<SingleBatchOutput<'r, 's>> {
        let Self { batches } = self;
        batches
    }

    /// Passes all stored batches to `transformer` as a slice and returns
    /// whatever it produces, success or error, without modification.
    pub fn export_with_transformer<R>(
        &self,
        transformer: impl Fn(&[SingleBatchOutput]) -> anyhow::Result<R>,
    ) -> anyhow::Result<R> {
        transformer(self.batches.as_slice())
    }
}