use crate::{
common::{normalize, Embedding},
output::{OutputKey, OutputPrecedence, SingleBatchOutput},
pooling::Pooling,
};
#[cfg(doc)]
use super::TextEmbedding;
/// Output keys used to choose which model output is turned into embeddings,
/// listed in precedence order.
///
/// NOTE(review): presumably an `OutputPrecedence` implementation tries these keys
/// in order and uses the first output that is present — confirm against that trait.
///
/// The identifier is misspelled ("PRECENDENCE"). It is public, so it is kept as-is
/// for compatibility with existing callers.
pub const OUTPUT_TYPE_PRECENDENCE: &[OutputKey] = &[
// Assumed to match a model that exposes exactly one output — confirm `OnlyOne` semantics.
OutputKey::OnlyOne,
OutputKey::ByName("last_hidden_state"),
OutputKey::ByName("sentence_embedding"),
];
/// Builds a post-processing closure that turns raw model outputs into normalized embeddings.
///
/// For every [`SingleBatchOutput`], the returned closure:
///
/// 1. selects an output with `output_precedence`;
/// 2. applies the optional `pooling` via `select_and_pool_output`;
/// 3. normalizes each row of the resulting array with [`normalize`].
///
/// The embeddings of all batches are concatenated in batch order.
///
/// # Errors
///
/// The closure returns the first error produced by `select_and_pool_output`. No partial
/// result is returned in that case.
pub fn transformer_with_precedence(
    output_precedence: impl OutputPrecedence,
    pooling: Option<Pooling>,
) -> impl Fn(&[SingleBatchOutput]) -> anyhow::Result<Vec<Embedding>> {
    move |batches| {
        let mut embeddings = Vec::new();
        for batch in batches {
            let pooled = batch.select_and_pool_output(&output_precedence, pooling.clone())?;
            embeddings.extend(pooled.rows().into_iter().map(|row| match row.as_slice() {
                Some(values) => normalize(values),
                // `as_slice` is `None` when the row is not contiguous in standard layout
                // (e.g. a strided or transposed view). Copy it rather than panicking.
                None => normalize(row.to_vec().as_slice()),
            }));
        }
        Ok(embeddings)
    }
}