#[cfg(test)]
use crate::error::{Error, RankMismatchPayload};
use crate::{array::Array, error::Result, lm::cache::KvCache};
pub trait Model {
fn forward(&self, tokens: &Array, cache: &mut [Box<dyn KvCache>]) -> Result<Array>;
fn forward_embeddings(
&self,
_embeddings: &Array,
_cache: &mut [Box<dyn KvCache>],
) -> Result<Array> {
Err(crate::error::Error::InvariantViolation(
crate::error::InvariantViolationPayload::new(
"Model::forward_embeddings",
"not implemented (VLM seam, M4; override this method in multimodal models)",
),
))
}
fn supports_input_embeddings(&self) -> bool {
false
}
}
/// Test-only stand-in for a real model.
///
/// Its `Model::forward` implementation emits the same fixed logits row at
/// every position and writes constant-filled keys/values into each cache layer.
#[cfg(test)]
pub(crate) struct MockModel {
    /// Logits row copied to every `(batch, position)` of the forward output;
    /// its length is the mock's vocabulary size.
    pub canned: Vec<f32>,
    /// Size of the KV-head axis of the dummy keys/values written to the cache.
    pub n_kv_heads: usize,
    /// Size of the last axis of the dummy keys/values written to the cache.
    pub head_dim: usize,
}
#[cfg(test)]
impl MockModel {
    /// Creates a mock whose canned logits are the ramp `0.0, 1.0, ..., vocab - 1`.
    ///
    /// The mock is configured with a single KV head and a head dimension of 2.
    pub(crate) fn new(vocab: usize) -> Self {
        // Logit `i` equals `i`, so the largest logit sits at the last index.
        let ramp = (0..vocab).map(|token_id| token_id as f32);
        MockModel {
            n_kv_heads: 1,
            head_dim: 2,
            canned: ramp.collect(),
        }
    }
}
#[cfg(test)]
impl Model for MockModel {
    /// Produces canned logits for `tokens` and appends dummy keys/values to
    /// every cache layer.
    ///
    /// `tokens` may be rank-1 `[S]` (treated as a batch of one) or rank-2
    /// `[B, S]`; only its shape is inspected, never its contents. Every entry
    /// of `cache` receives one `update` with keys filled with `1.0` and values
    /// filled with `2.0`, both shaped `[B, n_kv_heads, S, head_dim]`. The
    /// returned logits are shaped `[B, S, vocab]`, where `vocab` is
    /// `self.canned.len()` and each position holds a copy of `self.canned`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::RankMismatch`] when `tokens` is neither rank-1 nor
    /// rank-2, and propagates any error from `Array::from_slice` or from a
    /// cache layer's `update`.
    fn forward(&self, tokens: &Array, cache: &mut [Box<dyn KvCache>]) -> Result<Array> {
        let shape = tokens.shape();
        let (batch, seq) = match shape.as_slice() {
            [b, s] => (*b, *s),
            // A bare sequence is accepted as a batch of one.
            [s] => (1, *s),
            _ => {
                return Err(Error::RankMismatch(RankMismatchPayload::new(
                    "MockModel::forward: tokens must be rank-1 [S] or rank-2 [B, S]",
                    shape.len() as u32,
                    shape.to_vec(),
                )));
            }
        };

        // The K/V fill data is the same for every layer, so build it once
        // instead of reallocating both buffers per layer.
        let kv_shape = (batch, self.n_kv_heads, seq, self.head_dim);
        let elems = batch * self.n_kv_heads * seq * self.head_dim;
        let key_fill = vec![1.0_f32; elems];
        let value_fill = vec![2.0_f32; elems];
        for layer in cache.iter_mut() {
            let k = Array::from_slice::<f32>(&key_fill, &kv_shape)?;
            let v = Array::from_slice::<f32>(&value_fill, &kv_shape)?;
            layer.update(&k, &v)?;
        }

        // One copy of the canned row per (batch, position) pair.
        let vocab = self.canned.len();
        let data = self.canned.repeat(batch * seq);
        Array::from_slice::<f32>(&data, &(batch, seq, vocab))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lm::cache::{CacheConfig, KvCache, make_prompt_cache};

    /// Builds an `i32` token array of shape `[batch, seq]` from `ids`.
    fn tokens(ids: &[i32], batch: usize, seq: usize) -> Array {
        Array::from_slice::<i32>(ids, &(batch, seq)).unwrap()
    }

    /// A prefill call followed by a single-token call must advance every
    /// cache layer and return the canned row at every position.
    #[test]
    fn mock_model_forward_uses_cache() {
        const RAMP: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];

        let model = MockModel::new(5);
        let cfg = CacheConfig {
            num_hidden_layers: 2,
            sliding_window: None,
        };
        let mut cache = make_prompt_cache(&cfg);
        assert_eq!(cache.len(), 2);
        assert!(cache.iter().all(|c| c.is_empty()));

        // First call: three tokens at once.
        let mut logits = model
            .forward(&tokens(&[1, 2, 3], 1, 3), &mut cache)
            .unwrap();
        assert_eq!(logits.shape(), vec![1, 3, 5]);
        assert!(cache.iter().all(|c| c.offset() == 3));
        assert!(cache.iter().all(|c| !c.is_empty()));
        let v = logits.to_vec::<f32>().unwrap();
        // Check every position, not just the first row.
        assert_eq!(v.len(), 3 * RAMP.len());
        for row in v.chunks_exact(RAMP.len()) {
            assert_eq!(row, &RAMP[..]);
        }

        // Second call: one more token on top of the existing cache.
        let mut logits = model.forward(&tokens(&[4], 1, 1), &mut cache).unwrap();
        assert_eq!(logits.shape(), vec![1, 1, 5]);
        assert!(cache.iter().all(|c| c.offset() == 4));
        assert_eq!(logits.to_vec::<f32>().unwrap(), RAMP.to_vec());
    }

    /// The trait's default embeddings path must report "not implemented".
    #[test]
    fn forward_embeddings_default_is_unimplemented_seam() {
        let model = MockModel::new(3);
        let mut cache: Vec<Box<dyn KvCache>> = Vec::new();
        let emb = Array::from_slice::<f32>(&[0.0, 1.0], &(1usize, 1, 2)).unwrap();

        assert!(!model.supports_input_embeddings());
        let result = model.forward_embeddings(&emb, &mut cache);
        assert!(matches!(result, Err(Error::InvariantViolation(_))));
    }

    /// Rank-3 tokens must be rejected with a rank-mismatch error.
    #[test]
    fn forward_rejects_wrong_token_rank() {
        let model = MockModel::new(3);
        let mut cache: Vec<Box<dyn KvCache>> = Vec::new();
        // Same dtype as valid tokens so that rank is the only thing wrong.
        let bad = Array::from_slice::<i32>(&[1], &(1usize, 1, 1)).unwrap();

        let result = model.forward(&bad, &mut cache);
        assert!(matches!(result, Err(Error::RankMismatch(_))));
    }
}