use crate::jamba2::{
config::{Jamba2Config, Jamba2ConfigError},
model::{Jamba2Error, Jamba2Model},
};
/// Errors surfaced by the causal-LM task head.
///
/// Wraps the lower-level config and model errors via `#[from]`, so `?`
/// converts them automatically.
#[derive(Debug, thiserror::Error)]
pub enum Jamba2TaskError {
    /// Configuration rejected by `Jamba2Config::validate`.
    #[error("Configuration error: {0}")]
    Config(#[from] Jamba2ConfigError),
    /// Failure propagated from the underlying `Jamba2Model`.
    #[error("Model error: {0}")]
    Model(#[from] Jamba2Error),
    /// Logits tensor had an unexpected shape.
    ///
    /// NOTE(review): this variant is never constructed in this file —
    /// presumably reserved for a shape check elsewhere; confirm it is used.
    #[error("Logits shape mismatch: expected [seq_len={expected_seq}, vocab={expected_vocab}], got [{got}]")]
    LogitsMismatch {
        expected_seq: usize,
        expected_vocab: usize,
        got: usize,
    },
}
/// Output of a causal-LM forward pass.
pub struct CausalLmOutput {
    // One row per input position; each row has one logit per vocabulary entry.
    pub logits: Vec<Vec<f64>>,
    // Backbone hidden states, one row per input position.
    pub hidden_states: Vec<Vec<f64>>,
}
/// Jamba2 backbone with a language-modeling projection head.
pub struct Jamba2ForCausalLM {
    // Underlying Jamba2 backbone producing hidden states.
    model: Jamba2Model,
    // LM projection: `vocab_size` rows of `hidden_size` weights each.
    lm_head: Vec<Vec<f64>>,
}
impl Jamba2ForCausalLM {
    /// Builds a causal-LM head on top of a freshly constructed `Jamba2Model`.
    ///
    /// The LM head is a deterministic placeholder projection: row `i` of the
    /// `vocab_size x hidden_size` matrix carries a single 0.01 weight at
    /// column `i % hidden_size`.
    ///
    /// # Errors
    /// Returns `Jamba2TaskError::Config` when `config.validate()` rejects the
    /// configuration.
    pub fn new(config: Jamba2Config) -> Result<Self, Jamba2TaskError> {
        config.validate()?;
        let vocab_size = config.vocab_size;
        let hidden_size = config.hidden_size;
        // NOTE(review): assumes validate() rejects hidden_size == 0, so the
        // modulo below cannot divide by zero — confirm in config.rs.
        let lm_head: Vec<Vec<f64>> = (0..vocab_size)
            .map(|i| {
                let mut row = vec![0.0f64; hidden_size];
                row[i % hidden_size] = 0.01;
                row
            })
            .collect();
        let model = Jamba2Model::new(config);
        Ok(Self { model, lm_head })
    }

    /// Runs the backbone and projects each hidden state to vocabulary logits.
    ///
    /// `logits[t][v]` is the dot product of LM-head row `v` with the hidden
    /// state at position `t`.
    ///
    /// # Errors
    /// Propagates any `Jamba2Error` raised by the underlying model.
    pub fn forward(&self, input_ids: &[u32]) -> Result<CausalLmOutput, Jamba2TaskError> {
        let hidden_states = self.model.forward(input_ids)?;
        let logits: Vec<Vec<f64>> = hidden_states
            .iter()
            .map(|h| {
                self.lm_head
                    .iter()
                    .map(|row| row.iter().zip(h.iter()).map(|(w, v)| w * v).sum())
                    .collect()
            })
            .collect();
        Ok(CausalLmOutput {
            logits,
            hidden_states,
        })
    }

    /// Greedy decoding: appends up to `max_new_tokens` argmax tokens and
    /// returns only the newly generated ones.
    ///
    /// # Errors
    /// Returns `Jamba2Error::EmptyInput` for an empty prompt, and propagates
    /// any error raised during the per-step forward passes.
    pub fn generate(
        &self,
        input_ids: &[u32],
        max_new_tokens: usize,
    ) -> Result<Vec<u32>, Jamba2TaskError> {
        if input_ids.is_empty() {
            return Err(Jamba2TaskError::Model(Jamba2Error::EmptyInput));
        }
        let vocab_size = self.model.config().vocab_size as u32;
        // Final length is known up front; reserve once instead of regrowing.
        let mut tokens: Vec<u32> = Vec::with_capacity(input_ids.len() + max_new_tokens);
        tokens.extend_from_slice(input_ids);
        for _ in 0..max_new_tokens {
            // The full sequence is re-run every step (no state cache), so
            // decoding costs O(steps) forward passes over a growing prompt.
            let output = self.forward(&tokens)?;
            // `.last()` instead of indexing: avoids the panic path on an
            // (impossible-by-construction) empty logits matrix and falls back
            // to token 0 like the NaN-only case already did.
            let next_token = output
                .logits
                .last()
                .and_then(|row| {
                    row.iter().enumerate().max_by(|(_, a), (_, b)| {
                        a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
                    })
                })
                .map(|(i, _)| i as u32)
                .unwrap_or(0);
            // Defensive clamp into the vocabulary range.
            tokens.push(next_token % vocab_size);
        }
        Ok(tokens[input_ids.len()..].to_vec())
    }

    /// Read-only access to the underlying backbone model.
    pub fn model(&self) -> &Jamba2Model {
        &self.model
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jamba2::config::Jamba2Config;

    /// Minimal configuration that keeps every test fast.
    fn small_config() -> Jamba2Config {
        Jamba2Config {
            vocab_size: 64,
            hidden_size: 32,
            intermediate_size: 64,
            num_hidden_layers: 4,
            num_attention_heads: 4,
            num_key_value_heads: 2,
            head_dim: 8,
            mamba_d_state: 4,
            mamba_d_conv: 2,
            mamba_expand: 2,
            mamba_dt_rank: 8,
            attn_layer_offset: 1,
            attn_layer_period: 2,
            expert_layer_offset: 1,
            expert_layer_period: 2,
            num_experts: 4,
            num_experts_per_tok: 2,
            max_position_embeddings: 64,
            rms_norm_eps: 1e-5,
            rope_theta: 10000.0,
            hidden_act: "silu".to_string(),
            attention_dropout: 0.0,
            tie_word_embeddings: false,
        }
    }

    #[test]
    fn test_construction_succeeds() {
        let result = Jamba2ForCausalLM::new(small_config());
        assert!(
            result.is_ok(),
            "Jamba2ForCausalLM must construct successfully"
        );
    }

    #[test]
    fn test_forward_produces_output() {
        // `.expect` (error type derives Debug) reports the actual failure,
        // unlike the previous `unwrap_or_else(|_| panic!(...))` which
        // discarded it.
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        let result = model.forward(&[1u32, 2, 3]);
        assert!(result.is_ok(), "forward must succeed");
    }

    #[test]
    fn test_forward_logits_row_count() {
        let cfg = small_config();
        let model = Jamba2ForCausalLM::new(cfg.clone()).expect("init failed");
        let out = model.forward(&[1u32, 2, 3]).expect("forward failed");
        assert_eq!(out.logits.len(), 3, "logits must have seq_len rows");
    }

    #[test]
    fn test_forward_logits_vocab_dim() {
        let cfg = small_config();
        let vocab = cfg.vocab_size;
        let model = Jamba2ForCausalLM::new(cfg).expect("init failed");
        let out = model.forward(&[0u32]).expect("forward failed");
        assert_eq!(
            out.logits[0].len(),
            vocab,
            "logit row must have vocab_size columns"
        );
    }

    #[test]
    fn test_forward_hidden_states_length() {
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        let out = model.forward(&[1u32, 2]).expect("forward failed");
        assert_eq!(
            out.hidden_states.len(),
            2,
            "hidden_states must have seq_len rows"
        );
    }

    #[test]
    fn test_generate_token_count() {
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        let result = model.generate(&[1u32, 2], 3);
        assert!(result.is_ok(), "generate must succeed");
        let tokens = result.expect("generate failed");
        assert_eq!(tokens.len(), 3, "must return exactly 3 new tokens");
    }

    #[test]
    fn test_generate_empty_input_error() {
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        let err = model.generate(&[], 1);
        assert!(err.is_err(), "empty input must return error");
    }

    #[test]
    fn test_generate_zero_tokens() {
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        // Previously `unwrap_or_default()` here masked a failing `generate`
        // as an empty vec, making the assertion pass vacuously.
        let tokens = model.generate(&[1u32], 0).expect("generate failed");
        assert!(tokens.is_empty(), "zero new tokens must return empty vec");
    }

    #[test]
    fn test_generated_tokens_within_vocab() {
        let cfg = small_config();
        let vocab = cfg.vocab_size as u32;
        let model = Jamba2ForCausalLM::new(cfg).expect("init failed");
        // Fail loudly instead of skipping the assertions via `if let Ok`.
        let tokens = model.generate(&[1u32, 2], 3).expect("generate failed");
        for &t in &tokens {
            assert!(t < vocab, "token {t} must be < vocab {vocab}");
        }
    }

    #[test]
    fn test_model_accessor_works() {
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        let _m = model.model();
    }

    #[test]
    fn test_task_error_config_display() {
        use crate::jamba2::config::Jamba2ConfigError;
        let err = Jamba2TaskError::Config(Jamba2ConfigError::InvalidField("test".to_string()));
        let msg = format!("{err}");
        assert!(!msg.is_empty(), "error display must be non-empty");
    }

    #[test]
    fn test_causal_lm_output_accessible() {
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        let out = model.forward(&[0u32]).expect("forward failed");
        let _ = &out.logits;
        let _ = &out.hidden_states;
    }

    #[test]
    fn test_forward_logits_finite() {
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        let out = model.forward(&[1u32]).expect("forward failed");
        for row in &out.logits {
            for &v in row {
                assert!(v.is_finite(), "logit {v} must be finite");
            }
        }
    }

    #[test]
    fn test_forward_deterministic() {
        let model = Jamba2ForCausalLM::new(small_config()).expect("init failed");
        let ids = &[1u32, 2, 3];
        let a = model.forward(ids).expect("first forward failed");
        let b = model.forward(ids).expect("second forward failed");
        assert_eq!(a.logits, b.logits, "forward must be deterministic");
    }

    #[test]
    fn test_validate_rejects_zero_vocab() {
        let mut cfg = small_config();
        cfg.vocab_size = 0;
        let result = Jamba2ForCausalLM::new(cfg);
        assert!(result.is_err(), "zero vocab_size must be rejected");
    }

    #[test]
    fn test_hidden_states_dim_matches_config() {
        let cfg = small_config();
        let hidden_size = cfg.hidden_size;
        let model = Jamba2ForCausalLM::new(cfg).expect("init failed");
        let out = model.forward(&[0u32]).expect("forward failed");
        for row in &out.hidden_states {
            assert_eq!(row.len(), hidden_size, "hidden dim must match config");
        }
    }
}