1use composable::Composable;
2use crate::commons::output::tensors::OutputTensors;
3
4
5pub struct TextEmbeddings {
7 pub embeddings: ndarray::Array2<f32>,
8}
9
10impl TextEmbeddings {
11 pub fn embeddings(&self, index: usize) -> ndarray::ArrayView1<f32> {
12 self.embeddings.slice(ndarray::s![index, ..])
13 }
14
15 pub fn len(&self) -> usize {
16 self.embeddings.dim().0
17 }
18
19 pub fn is_empty(&self) -> bool {
20 self.embeddings.is_empty()
21 }
22}
23
24
/// Name of a model output tensor.
///
/// A thin wrapper around `String` that dereferences to `str`. The default
/// value is `"last_hidden_state"`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OutputId(String);

impl Default for OutputId {
    /// Returns the id `"last_hidden_state"`.
    fn default() -> Self {
        OutputId("last_hidden_state".into())
    }
}

impl From<String> for OutputId {
    /// Wraps an owned string without copying it.
    fn from(s: String) -> Self {
        OutputId(s)
    }
}

impl From<&str> for OutputId {
    /// Copies the string slice into a new id.
    fn from(s: &str) -> Self {
        OutputId(s.to_string())
    }
}

impl AsRef<str> for OutputId {
    fn as_ref(&self) -> &str {
        &self.0
    }
}

impl std::ops::Deref for OutputId {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl std::fmt::Display for OutputId {
    /// Writes the wrapped name, honouring formatter flags such as width and padding.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified on purpose: an unqualified `self.0.fmt(f)` only resolves
        // through whichever `fmt` traits happen to be imported in this module, and
        // becomes ambiguous when more than one of them is in scope.
        std::fmt::Display::fmt(&self.0, f)
    }
}
34
35
/// Selects how embeddings are taken from a model output tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExtractorMode {
    /// Use the output tensor as-is; it has to be 2-D.
    Raw,
    /// Take the entries at the given index along the second axis of a 3-D tensor.
    Token(usize),
}
44
45impl Default for ExtractorMode {
46 fn default() -> Self { ExtractorMode::Token(0) }
47}
48
49
50#[derive(Default)]
52pub struct EmbeddingsExtractor {
53 output_id: OutputId,
54 mode: ExtractorMode,
55}
56
57impl EmbeddingsExtractor {
58 pub fn new(output_id: &OutputId, mode: ExtractorMode) -> Self {
59 Self {
60 output_id: output_id.clone(),
61 mode,
62 }
63 }
64}
65
66impl Composable<OutputTensors<'_>, TextEmbeddings> for EmbeddingsExtractor {
67 fn apply(&self, output_tensors: OutputTensors) -> composable::Result<TextEmbeddings> {
68 let output_tensor = output_tensors.outputs.get(&self.output_id).ok_or_else(|| format!("tensor not found in model output: {}", self.output_id))?;
70 let output_tensor = output_tensor.try_extract_tensor::<f32>()?;
71
72 match self.mode {
74 ExtractorMode::Raw => {
75 let embeddings = output_tensor.into_dimensionality::<ndarray::Ix2>()?;
77 Ok(TextEmbeddings { embeddings: embeddings.into_owned() })
78 },
79 ExtractorMode::Token(index) => {
80 let embeddings = output_tensor.slice(ndarray::s![.., index, ..]);
82 Ok(TextEmbeddings { embeddings: embeddings.into_owned() })
83 },
84 }
85 }
86}