#[cfg(test)]
use crate::error::{EmptyInputPayload, LengthMismatchPayload, RankMismatchPayload};
use crate::{array::Array, error::Result};
/// Result of an [`EmbeddingModel::forward`] call.
///
/// Holds the model's last hidden state together with an optional pooled output.
/// Both fields are private: read them through the accessor methods, or take
/// ownership of both with `into_parts`.
#[derive(Debug)]
pub struct EmbeddingModelOutput {
/// Last hidden state returned by the model.
// NOTE(review): the test mock in this file emits a (batch, seq_len, hidden) array;
// confirm that real implementors use the same layout.
last_hidden_state: Array,
/// Pooled output, or `None` when the model did not provide one.
pooled_output: Option<Array>,
}
impl EmbeddingModelOutput {
    /// Builds an output from a last hidden state and an optional pooled output.
    pub fn new(last_hidden_state: Array, pooled_output: Option<Array>) -> Self {
        Self {
            last_hidden_state,
            pooled_output,
        }
    }

    /// Builds an output that carries only a last hidden state (no pooled output).
    pub fn from_hidden_state(last_hidden_state: Array) -> Self {
        Self {
            last_hidden_state,
            pooled_output: None,
        }
    }

    /// Borrows the last hidden state.
    #[inline(always)]
    pub fn last_hidden_state(&self) -> &Array {
        let Self {
            last_hidden_state: hidden,
            ..
        } = self;
        hidden
    }

    /// Borrows the pooled output, if one is present.
    #[inline(always)]
    pub fn pooled_output(&self) -> Option<&Array> {
        let Self {
            pooled_output: pooled,
            ..
        } = self;
        pooled.as_ref()
    }

    /// Consumes the output and returns `(last_hidden_state, pooled_output)`.
    #[inline(always)]
    pub fn into_parts(self) -> (Array, Option<Array>) {
        let Self {
            last_hidden_state,
            pooled_output,
        } = self;
        (last_hidden_state, pooled_output)
    }
}
/// A model that turns input ids plus an attention mask into an [`EmbeddingModelOutput`].
pub trait EmbeddingModel {
/// Runs a forward pass over `input_ids` and `attention_mask`.
///
/// NOTE(review): the test mock in this file requires rank-2 `(batch, seq_len)` ids;
/// presumably other implementors expect the same layout — confirm.
///
/// # Errors
///
/// Returns an error when the implementation cannot produce an output for the inputs.
fn forward(&self, input_ids: &Array, attention_mask: &Array) -> Result<EmbeddingModelOutput>;
}
#[cfg(test)]
/// Test-only [`EmbeddingModel`] that replays pre-canned rows instead of running a model.
pub(crate) struct MockEmbeddingModel {
/// Hidden-state rows, one per sequence position; `forward` emits the first
/// `seq_len` of them for every batch entry.
pub canned: Vec<Vec<f32>>,
/// Optional pooled rows; `forward` gives batch entry `b` the row at `b % len`.
pub pooled: Option<Vec<Vec<f32>>>,
}
#[cfg(test)]
impl MockEmbeddingModel {
    /// Right-pads every row with `0.0` so all rows share the length of the longest one.
    ///
    /// An empty input is returned unchanged. Rows are resized in place, so no
    /// second outer vector is allocated.
    fn pad_rows(mut rows: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
        let width = rows.iter().map(Vec::len).max().unwrap_or(0);
        for row in &mut rows {
            row.resize(width, 0.0);
        }
        rows
    }

    /// Creates a mock whose hidden-state rows are `canned`, zero-padded to a
    /// common width, and with no pooled output configured.
    pub(crate) fn new(canned: Vec<Vec<f32>>) -> Self {
        Self {
            canned: Self::pad_rows(canned),
            pooled: None,
        }
    }

    /// Configures the pooled rows (zero-padded to a common width) and returns the mock.
    ///
    /// Any previously configured pooled rows are replaced.
    pub(crate) fn with_pooled(mut self, pooled: Vec<Vec<f32>>) -> Self {
        self.pooled = Some(Self::pad_rows(pooled));
        self
    }
}
#[cfg(test)]
impl EmbeddingModel for MockEmbeddingModel {
    /// Produces deterministic output from the canned rows; the attention mask is ignored.
    ///
    /// Only the shape of `input_ids` is read, and it must be rank-2 `(batch, seq_len)`.
    /// The first `seq_len` canned rows are repeated once per batch entry to form a
    /// `(batch, seq_len, hidden)` last hidden state. When pooled rows are configured,
    /// batch entry `b` receives pooled row `b % rows`, giving a `(batch, pooled_hidden)`
    /// pooled output.
    ///
    /// # Errors
    ///
    /// - `Error::RankMismatch` if `input_ids` is not rank-2.
    /// - `Error::LengthMismatch` if `seq_len` exceeds the number of canned rows.
    /// - `Error::EmptyInput` if pooled rows were configured but the list is empty.
    /// - Whatever `Array::from_slice` reports when it cannot build an array.
    fn forward(&self, input_ids: &Array, _attention_mask: &Array) -> Result<EmbeddingModelOutput> {
        use crate::error::Error;

        let dims = input_ids.shape();
        let (batch, seq) = if let [b, s] = dims.as_slice() {
            (*b, *s)
        } else {
            return Err(Error::RankMismatch(RankMismatchPayload::new(
                "MockEmbeddingModel::forward expects rank-2 (batch, seq_len) ids",
                dims.len() as u32,
                dims.clone(),
            )));
        };

        let available = self.canned.len();
        if seq > available {
            return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                "MockEmbeddingModel: seq_len must be <= canned positions",
                available,
                seq,
            )));
        }

        // Width is taken from the first row; the constructors pad all rows to one width.
        let hidden = self.canned.first().map_or(0, Vec::len);

        // Flatten the first `seq` rows once, then tile that slab across the batch.
        let one_entry: Vec<f32> = self.canned.iter().take(seq).flatten().copied().collect();
        let tiled = one_entry.repeat(batch);
        let last_hidden_state = Array::from_slice::<f32>(&tiled, &(batch, seq, hidden))?;

        let pooled_output = match self.pooled.as_deref() {
            None => None,
            Some([]) => {
                return Err(Error::EmptyInput(EmptyInputPayload::new(
                    "MockEmbeddingModel: pooled_output rows",
                )));
            }
            Some(rows) => {
                let pooled_hidden = rows[0].len();
                // Cycling through the rows hands batch entry `b` the row at `b % rows.len()`.
                let flat: Vec<f32> = rows
                    .iter()
                    .cycle()
                    .take(batch)
                    .flatten()
                    .copied()
                    .collect();
                Some(Array::from_slice::<f32>(&flat, &(batch, pooled_hidden))?)
            }
        };

        Ok(EmbeddingModelOutput::new(last_hidden_state, pooled_output))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;

    /// The first `seq_len` canned rows are emitted once per batch entry.
    #[test]
    fn mock_forward_tiles_canned_rows_across_batch() {
        let model = MockEmbeddingModel::new(vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]);
        let ids = Array::from_slice::<i32>(&[7, 8, 9, 10], &(2, 2)).unwrap();
        let mask = Array::from_slice::<f32>(&[1.0, 1.0, 1.0, 1.0], &(2, 2)).unwrap();
        let out = model.forward(&ids, &mask).unwrap();
        assert_eq!(out.last_hidden_state().shape(), vec![2, 2, 2]);
        assert!(out.pooled_output().is_none());
        let (mut lhs, _) = out.into_parts();
        assert_eq!(
            lhs.to_vec::<f32>().unwrap(),
            vec![1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]
        );
    }

    /// Pooled rows are zero-padded to a common width and cycled over the batch.
    #[test]
    fn mock_forward_cycles_pooled_rows_across_batch() {
        let model = MockEmbeddingModel::new(vec![vec![1.0, 2.0]])
            .with_pooled(vec![vec![9.0, 8.0], vec![7.0]]);
        let ids = Array::from_slice::<i32>(&[1, 2, 3], &(3, 1)).unwrap();
        let mask = Array::from_slice::<f32>(&[1.0, 1.0, 1.0], &(3, 1)).unwrap();
        let out = model.forward(&ids, &mask).unwrap();
        assert_eq!(out.pooled_output().unwrap().shape(), vec![3, 2]);
        let (_, pooled) = out.into_parts();
        let mut pooled = pooled.unwrap();
        assert_eq!(
            pooled.to_vec::<f32>().unwrap(),
            vec![9.0, 8.0, 7.0, 0.0, 9.0, 8.0]
        );
    }

    /// Rank-1 ids must fail with the rank error specifically, not just any error.
    #[test]
    fn mock_forward_rejects_wrong_rank() {
        let model = MockEmbeddingModel::new(vec![vec![1.0, 2.0]]);
        let bad = Array::from_slice::<i32>(&[1, 2, 3], &(3,)).unwrap();
        let mask = Array::from_slice::<f32>(&[1.0, 1.0, 1.0], &(3,)).unwrap();
        assert!(matches!(
            model.forward(&bad, &mask),
            Err(Error::RankMismatch(_))
        ));
    }

    /// Asking for more positions than there are canned rows is a length error.
    #[test]
    fn mock_forward_rejects_seq_longer_than_canned() {
        let model = MockEmbeddingModel::new(vec![vec![1.0, 2.0]]);
        let ids = Array::from_slice::<i32>(&[1, 2], &(1, 2)).unwrap();
        let mask = Array::from_slice::<f32>(&[1.0, 1.0], &(1, 2)).unwrap();
        assert!(matches!(
            model.forward(&ids, &mask),
            Err(Error::LengthMismatch(_))
        ));
    }

    /// Configuring pooled output with zero rows is reported as empty input.
    #[test]
    fn mock_forward_rejects_empty_pooled_rows() {
        let model = MockEmbeddingModel::new(vec![vec![1.0, 2.0]]).with_pooled(Vec::new());
        let ids = Array::from_slice::<i32>(&[1], &(1, 1)).unwrap();
        let mask = Array::from_slice::<f32>(&[1.0], &(1, 1)).unwrap();
        assert!(matches!(
            model.forward(&ids, &mask),
            Err(Error::EmptyInput(_))
        ));
    }
}