use crate::errors::NoosResult;
/// Interface for a locally hosted model that maps token ids to a vector of `f32` scores.
///
/// Implementors must be [`Send`]. `forward` and `reset_cache` take `&mut self`, so an
/// implementation is free to keep mutable state between calls.
pub trait LocalModel: Send {
    /// Runs the model on `tokens` at the given `position`.
    ///
    /// Returns a `Vec<f32>` on success. NOTE(review): presumably one logit per vocabulary
    /// entry (the in-file mock returns `vocab_size()` values) — confirm against the real
    /// implementations. The exact meaning of `position` is not visible from this trait.
    ///
    /// # Errors
    ///
    /// Propagates whatever failure the implementation reports through `NoosResult`.
    fn forward(&mut self, tokens: &[u32], position: usize) -> NoosResult<Vec<f32>>;

    /// Size of the model's vocabulary.
    fn vocab_size(&self) -> usize;

    /// Resets state the implementation carries across `forward` calls.
    ///
    /// NOTE(review): what exactly is cached is implementation-defined; the in-file mock
    /// only resets its call counter.
    fn reset_cache(&mut self);
}
/// Test support for [`LocalModel`]: a deterministic mock implementation and its unit tests.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;

    /// Value the mock writes into the single "peak" logit; every other logit is `0.0`.
    const PEAK_LOGIT: f32 = 5.0;

    /// Deterministic [`LocalModel`] stand-in for tests.
    ///
    /// Each `forward` call ignores its inputs and returns `vocab_size` logits that are all
    /// zero except for one peak whose index advances by one per call (wrapping around).
    pub(crate) struct MockModel {
        /// Length of every logits vector produced by `forward`.
        vocab_size: usize,
        /// Number of `forward` calls since construction or the last `reset_cache`.
        call_count: usize,
    }

    impl MockModel {
        /// Creates a mock with the given vocabulary size and a zeroed call counter.
        pub fn new(vocab_size: usize) -> Self {
            Self {
                vocab_size,
                call_count: 0,
            }
        }
    }

    impl LocalModel for MockModel {
        /// Returns `vocab_size` logits with a single peak at `call_count % vocab_size`.
        ///
        /// The counter is incremented before the peak index is computed, so the first
        /// call after construction or `reset_cache` peaks at index 1 (for a vocabulary
        /// larger than one). An empty vocabulary yields an empty vector instead of
        /// panicking. This implementation never returns an error.
        fn forward(&mut self, _tokens: &[u32], _position: usize) -> NoosResult<Vec<f32>> {
            self.call_count += 1;
            let mut logits = vec![0.0f32; self.vocab_size];
            // `checked_rem` is `None` when `vocab_size == 0`; a plain `%` would panic
            // with a remainder-by-zero in that case.
            if let Some(peak) = self.call_count.checked_rem(self.vocab_size) {
                logits[peak] = PEAK_LOGIT;
            }
            Ok(logits)
        }

        /// Returns the vocabulary size passed to [`MockModel::new`].
        fn vocab_size(&self) -> usize {
            self.vocab_size
        }

        /// Rewinds the call counter so the peak sequence starts over.
        fn reset_cache(&mut self) {
            self.call_count = 0;
        }
    }

    #[test]
    fn mock_model_returns_logits() {
        let mut model = MockModel::new(10);
        let logits = model.forward(&[1, 2, 3], 0).unwrap();
        assert_eq!(logits.len(), 10);
    }

    #[test]
    fn mock_model_cycles_peak() {
        let mut model = MockModel::new(3);
        let logits1 = model.forward(&[1], 0).unwrap();
        assert_eq!(logits1[1], PEAK_LOGIT);
        let logits2 = model.forward(&[1], 1).unwrap();
        assert_eq!(logits2[2], PEAK_LOGIT);
        // The third call wraps around to the start of the vocabulary.
        let logits3 = model.forward(&[1], 2).unwrap();
        assert_eq!(logits3[0], PEAK_LOGIT);
    }

    #[test]
    fn reset_cache_resets_state() {
        let mut model = MockModel::new(3);
        model.forward(&[1], 0).unwrap();
        model.forward(&[1], 1).unwrap();
        model.reset_cache();
        let logits = model.forward(&[1], 0).unwrap();
        assert_eq!(logits[1], PEAK_LOGIT);
    }

    #[test]
    fn empty_vocab_yields_empty_logits() {
        let mut model = MockModel::new(0);
        assert_eq!(model.vocab_size(), 0);
        // Must not panic on the remainder-by-zero path.
        let logits = model.forward(&[1], 0).unwrap();
        assert!(logits.is_empty());
    }
}