use crate::{
common::{normalize, Embedding},
output::{OutputKey, OutputPrecedence, SingleBatchOutput},
pooling::Pooling,
};
#[cfg(doc)]
use super::TextEmbedding;
/// Default list of output keys used to choose which model output holds the
/// embeddings: `OnlyOne`, then the named outputs `text_embeds`,
/// `last_hidden_state` and `sentence_embedding`, in that order.
///
/// NOTE(review): presumably earlier entries take priority (first match wins) and
/// `OnlyOne` matches when the model exposes exactly one output — confirm against
/// the `OutputPrecedence` implementation for `&[OutputKey]`.
pub const OUTPUT_TYPE_PRECEDENCE: &[OutputKey] = &[
OutputKey::OnlyOne,
OutputKey::ByName("text_embeds"),
OutputKey::ByName("last_hidden_state"),
OutputKey::ByName("sentence_embedding"),
];
#[allow(unused_variables)]
pub fn transformer_with_precedence(
output_precedence: impl OutputPrecedence,
pooling: Option<Pooling>,
) -> impl Fn(&[SingleBatchOutput]) -> anyhow::Result<Vec<Embedding>> {
move |batches| {
batches
.iter()
.map(|batch| {
batch
.select_and_pool_output(&output_precedence, pooling.clone())
.and_then(|array| {
array
.rows()
.into_iter()
.map(|row| {
row.as_slice()
.ok_or_else(|| {
anyhow::anyhow!("Failed to convert array row to slice")
})
.map(normalize)
})
.collect::<anyhow::Result<Vec<Embedding>>>()
})
})
.try_fold(Vec::new(), |mut acc, res| {
acc.extend(res?);
Ok(acc)
})
}
}