use std::sync::Arc;
use crate::backend::Backend;
use crate::tensor::{DType, Tensor};
use super::config::ModelConfig;
use super::deltanet::RecurrentConfig;
use super::error::{ModelError, ModelResult};
use super::layers::{Linear, NormLayer, TransformerLayer};
use super::deltanet::DeltaNetConfig;
use super::{Architecture, InferenceContext, Model};
/// A decoder-only transformer (LLaMA-style) model: token embedding table,
/// a stack of transformer layers, a final norm, and an output projection.
/// Also covers hybrid architectures where some layers are recurrent
/// (Mamba / DeltaNet) instead of attention.
pub struct LlamaModel {
    /// Model hyperparameters (layer count, hidden size, vocab size, ...).
    config: ModelConfig,
    /// Embedding matrix; may be quantized (dequantized on demand, see
    /// `dequantize_embeddings`).
    token_embedding: Tensor,
    /// One entry per layer; length is validated against `config.num_layers`
    /// in `new`.
    layers: Vec<TransformerLayer>,
    /// Final normalization applied before the output projection.
    norm: NormLayer,
    /// Projection from the last hidden state to vocabulary logits.
    output: Linear,
    /// Which concrete architecture these weights came from (drives e.g.
    /// Gemma embedding scaling and Mamba-vs-DeltaNet config selection).
    architecture: Architecture,
    /// `recurrent_mask[i]` is true when layer `i` is recurrent (SSM-style).
    recurrent_mask: Vec<bool>,
    /// Present only when at least one layer is recurrent AND the config
    /// carries SSM parameters; used to size per-layer recurrent state.
    recurrent_config: Option<RecurrentConfig>,
}
impl LlamaModel {
    /// Builds a model from pre-loaded parts, validating the layer count and
    /// precomputing the recurrent-layer metadata.
    ///
    /// # Errors
    /// Returns `ModelError::ConfigError` when `layers.len()` disagrees with
    /// `config.num_layers`.
    pub fn new(
        config: ModelConfig,
        token_embedding: Tensor,
        layers: Vec<TransformerLayer>,
        norm: NormLayer,
        output: Linear,
        architecture: Architecture,
    ) -> ModelResult<Self> {
        if layers.len() != config.num_layers {
            return Err(ModelError::ConfigError(format!(
                "Expected {} layers, got {}",
                config.num_layers,
                layers.len()
            )));
        }
        // Which layers use recurrent (SSM) state instead of KV-cache attention.
        let recurrent_mask: Vec<bool> = layers.iter().map(|l| l.is_recurrent()).collect();
        let has_recurrent = recurrent_mask.iter().any(|&r| r);
        // Only build a RecurrentConfig when both the layers and the config
        // agree that recurrent/SSM parameters are present.
        let recurrent_config = if has_recurrent && config.has_ssm() {
            let is_mamba =
                matches!(architecture, Architecture::Mamba | Architecture::Mamba2);
            Some(if is_mamba {
                RecurrentConfig::Mamba(super::mamba::MambaConfig {
                    d_inner: config.ssm_d_inner,
                    d_state: config.ssm_d_state,
                    dt_rank: config.ssm_dt_rank,
                    // Guard against a zero-size conv kernel in malformed metadata.
                    conv_kernel: config.ssm_conv_kernel.max(1),
                })
            } else {
                // DeltaNet: derive per-head dimensions from the generic SSM
                // config fields. NOTE(review): `ssm_dt_rank` is reused here as
                // the number of value heads and `ssm_n_group` as the number of
                // key heads — assumes the loader populated them that way.
                let d_inner = config.ssm_d_inner;
                let d_state = config.ssm_d_state;
                let num_v_heads = config.ssm_dt_rank;
                let num_k_heads = config.ssm_n_group.max(1);
                // `.max(1)` avoids a divide-by-zero on malformed metadata.
                let head_v_dim = d_inner / num_v_heads.max(1);
                let head_k_dim = d_state;
                // NOTE(review): unlike the Mamba branch above, the conv kernel
                // is NOT clamped to >= 1 here — confirm whether 0 is meaningful
                // (e.g. "no convolution") for DeltaNet.
                let conv_kernel = config.ssm_conv_kernel;
                let q_dim = num_k_heads * head_k_dim;
                let k_dim = num_k_heads * head_k_dim;
                // Fused projection width: query + key + value.
                let qkv_dim = q_dim + k_dim + d_inner;
                RecurrentConfig::DeltaNet(DeltaNetConfig {
                    d_inner,
                    d_state,
                    num_v_heads,
                    num_k_heads,
                    head_v_dim,
                    head_k_dim,
                    conv_kernel,
                    qkv_dim,
                })
            })
        } else {
            None
        };
        Ok(Self {
            config,
            token_embedding,
            layers,
            norm,
            output,
            architecture,
            recurrent_mask,
            recurrent_config,
        })
    }

    /// Creates a fresh inference context bound to `backend`, including
    /// per-layer recurrent state when the model has recurrent layers.
    pub fn create_context(&self, backend: Arc<dyn Backend>) -> InferenceContext {
        if let Some(ref rc) = self.recurrent_config {
            InferenceContext::new_with_recurrent(
                &self.config,
                backend,
                &self.recurrent_mask,
                rc,
            )
        } else {
            InferenceContext::new(&self.config, backend)
        }
    }

    /// The model's hyperparameters.
    pub fn config(&self) -> &ModelConfig {
        &self.config
    }

    /// The layer stack, in forward order.
    pub fn layers(&self) -> &[TransformerLayer] {
        &self.layers
    }

    /// Decomposes the model into its owned parts, in field-declaration order
    /// (e.g. for weight conversion or re-assembly via `new`).
    #[allow(clippy::type_complexity)]
    pub fn into_parts(
        self,
    ) -> (
        ModelConfig,
        Tensor,
        Vec<TransformerLayer>,
        NormLayer,
        Linear,
        Architecture,
        Vec<bool>,
        Option<RecurrentConfig>,
    ) {
        (
            self.config,
            self.token_embedding,
            self.layers,
            self.norm,
            self.output,
            self.architecture,
            self.recurrent_mask,
            self.recurrent_config,
        )
    }

    /// The final normalization layer.
    pub fn norm(&self) -> &NormLayer {
        &self.norm
    }

    /// The output (logit) projection.
    pub fn output(&self) -> &Linear {
        &self.output
    }

    /// The raw (possibly quantized) token embedding tensor.
    pub fn token_embedding(&self) -> &Tensor {
        &self.token_embedding
    }

    /// Returns the embedding table as f32: borrowed directly when already
    /// F32, otherwise dequantized through the backend into an owned buffer.
    ///
    /// NOTE(review): the non-F32 path allocates and dequantizes the entire
    /// table on every call (both `forward` and `embed_tokens` hit this once
    /// per invocation) — consider caching if this shows up in profiles.
    fn dequantize_embeddings<'a>(
        &'a self,
        backend: &dyn Backend,
    ) -> ModelResult<std::borrow::Cow<'a, [f32]>> {
        if self.token_embedding.dtype() == DType::F32 {
            Ok(std::borrow::Cow::Borrowed(self.token_embedding.as_f32()?))
        } else {
            let numel = self.token_embedding.numel();
            let mut dequant = Tensor::zeros(vec![numel], DType::F32);
            backend.dequantize(&self.token_embedding, &mut dequant)?;
            Ok(std::borrow::Cow::Owned(dequant.as_f32()?.to_vec()))
        }
    }

    /// Gathers embedding rows for `tokens` into a new tensor.
    ///
    /// Shape: `[hidden_size]` for a single token, `[seq_len, hidden_size]`
    /// otherwise.
    ///
    /// # Errors
    /// `InvalidMetadata` when a token id is outside the vocabulary or when
    /// the embedding table is too small for the requested row.
    pub fn embed_tokens(&self, tokens: &[u32], backend: &dyn Backend) -> ModelResult<Tensor> {
        let hidden_size = self.config.hidden_size;
        let vocab_size = self.config.vocab_size;
        let seq_len = tokens.len();
        let embedding_data = self.dequantize_embeddings(backend)?;
        let mut output = vec![0.0f32; seq_len * hidden_size];
        for (i, &token) in tokens.iter().enumerate() {
            let token_idx = token as usize;
            if token_idx >= vocab_size {
                return Err(ModelError::InvalidMetadata {
                    key: "token".into(),
                    message: format!("Token ID {} exceeds vocab size {}", token, vocab_size),
                });
            }
            let src_start = token_idx * hidden_size;
            let src_end = src_start + hidden_size;
            // The table may be smaller than vocab_size * hidden_size if the
            // file is truncated or misdeclared; surface that as an error.
            if src_end > embedding_data.len() {
                return Err(ModelError::InvalidMetadata {
                    key: "embedding".into(),
                    message: format!(
                        "Embedding index out of bounds: token_idx={}, src_end={}, embedding_len={}",
                        token_idx,
                        src_end,
                        embedding_data.len()
                    ),
                });
            }
            let dst_start = i * hidden_size;
            output[dst_start..dst_start + hidden_size]
                .copy_from_slice(&embedding_data[src_start..src_end]);
        }
        // Single-token batches are returned 1-D; presumably callers rely on
        // this shape distinction — TODO confirm before changing.
        if seq_len == 1 {
            Tensor::from_f32(&output, vec![hidden_size])
        } else {
            Tensor::from_f32(&output, vec![seq_len, hidden_size])
        }
        .map_err(|e| e.into())
    }

    /// Applies the final norm and output projection to `hidden`, then the
    /// optional logit soft-cap `cap * tanh(v / cap)` when
    /// `config.final_logit_softcap > 0`.
    fn compute_logits(&self, hidden: &Tensor, backend: &dyn Backend) -> ModelResult<Tensor> {
        let mut normed = Tensor::zeros(hidden.shape().to_vec(), DType::F32);
        self.norm.forward(hidden, &mut normed, backend)?;
        let mut logits = Tensor::zeros(vec![self.config.vocab_size], DType::F32);
        self.output.forward(&normed, &mut logits, backend)?;
        if self.config.final_logit_softcap > 0.0 {
            let cap = self.config.final_logit_softcap;
            let data = logits.as_f32_mut()?;
            for v in data.iter_mut() {
                *v = cap * (*v / cap).tanh();
            }
        }
        Ok(logits)
    }
}
impl Model for LlamaModel {
    fn create_context(&self, backend: Arc<dyn Backend>) -> InferenceContext {
        // Delegate to the inherent method, which picks the recurrent-aware
        // context when the model has SSM layers.
        self.create_context(backend)
    }

    /// Runs a forward pass over `tokens`, advancing `ctx.position` by
    /// `tokens.len()`, and returns the logits for the final token.
    ///
    /// # Errors
    /// - `ContextLengthExceeded` when the new position would pass
    ///   `config.max_seq_len`.
    /// - `InvalidMetadata` for an empty token batch, an out-of-vocab token
    ///   id, or an embedding table too small for the requested row.
    fn forward(&self, tokens: &[u32], ctx: &mut InferenceContext) -> ModelResult<Tensor> {
        let backend = ctx.backend.as_ref();
        let num_tokens = tokens.len();
        // An empty batch previously panicked later on `hiddens.last().unwrap()`;
        // fail fast with a proper error instead.
        if num_tokens == 0 {
            return Err(ModelError::InvalidMetadata {
                key: "tokens".into(),
                message: "forward() called with an empty token batch".into(),
            });
        }
        let new_pos = ctx.position + num_tokens;
        if new_pos > self.config.max_seq_len {
            return Err(ModelError::ContextLengthExceeded {
                current: new_pos,
                max: self.config.max_seq_len,
            });
        }
        let embedding_data = self.dequantize_embeddings(backend)?;
        let hidden_size = self.config.hidden_size;
        let vocab_size = self.config.vocab_size;
        // Gather one embedding row per token.
        let mut hiddens: Vec<Tensor> = Vec::with_capacity(num_tokens);
        for &token in tokens {
            let token_idx = token as usize;
            if token_idx >= vocab_size {
                return Err(ModelError::InvalidMetadata {
                    key: "token".into(),
                    message: format!("Token ID {} exceeds vocab size {}", token, vocab_size),
                });
            }
            let src = token_idx * hidden_size;
            // Checked slice: a truncated embedding table must surface as an
            // error (matching `embed_tokens`), not an index panic.
            let row = embedding_data.get(src..src + hidden_size).ok_or_else(|| {
                ModelError::InvalidMetadata {
                    key: "embedding".into(),
                    message: format!(
                        "Embedding index out of bounds: token_idx={}, src_end={}, embedding_len={}",
                        token_idx,
                        src + hidden_size,
                        embedding_data.len()
                    ),
                }
            })?;
            hiddens.push(Tensor::from_f32(row, vec![hidden_size])?);
        }
        // Read the debug flag once instead of re-querying the environment in
        // every layer iteration.
        let debug = std::env::var("LLAMA_DEBUG").is_ok() && ctx.position == 0;
        if debug {
            let h = hiddens.last().unwrap().as_f32().unwrap();
            let n = h.len().min(8);
            eprintln!("[DBG] tokens: {:?}", tokens);
            eprintln!("[DBG] embed[0..{}]: {:?}", n, &h[..n]);
        }
        if self.architecture.is_gemma() {
            // Gemma-family models scale embeddings by sqrt(hidden_size)
            // before the first layer.
            let scale = (hidden_size as f32).sqrt();
            for hidden in &mut hiddens {
                let data = hidden.as_f32_mut()?;
                for v in data.iter_mut() {
                    *v *= scale;
                }
            }
        }
        // `num_tokens >= 1` was checked above, so `layers` being non-empty
        // is the only remaining requirement (validated in `new`).
        let last_layer = self.layers.len().saturating_sub(1);
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            for (token_offset, hidden) in hiddens.iter_mut().enumerate() {
                let current_pos = ctx.position + token_offset;
                let recurrent_state = ctx
                    .recurrent_state
                    .as_mut()
                    .and_then(|rs| rs.states[layer_idx].as_mut());
                *hidden = layer.forward(
                    hidden,
                    &mut ctx.kv_cache.k_cache[layer_idx],
                    &mut ctx.kv_cache.v_cache[layer_idx],
                    current_pos,
                    self.config.rope_config.freq_base,
                    self.config.rope_config.freq_scale,
                    backend,
                    recurrent_state,
                )?;
            }
            // Trace the first few layers and the last one. (Was hard-coded to
            // `layer_idx == 31`, which only matched 32-layer models.)
            if debug && (layer_idx < 4 || layer_idx == last_layer) {
                let h = hiddens.last().unwrap().as_f32().unwrap();
                let n = h.len().min(8);
                let is_rec = layer.is_recurrent();
                let rms: f32 = h.iter().map(|x| x * x).sum::<f32>() / h.len() as f32;
                eprintln!(
                    "[DBG] layer {} ({}): rms={:.6} first8={:?}",
                    layer_idx,
                    if is_rec { "deltanet" } else { "attn" },
                    rms.sqrt(),
                    &h[..n]
                );
            }
        }
        ctx.position = new_pos;
        ctx.kv_cache.seq_len = new_pos;
        self.compute_logits(hiddens.last().unwrap(), backend)
    }

    fn config(&self) -> &ModelConfig {
        &self.config
    }

    fn architecture(&self) -> Architecture {
        self.architecture
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check the canonical LLaMA-7B hyperparameters.
    #[test]
    fn test_llama_config() {
        let cfg = ModelConfig::llama_7b();
        let actual = (cfg.vocab_size, cfg.hidden_size, cfg.num_layers, cfg.num_heads);
        assert_eq!(actual, (32000, 4096, 32, 32));
    }
}