use crate::causal_attention::CausalAttention;
use crate::decoder_layer::DecoderLayer;
use crate::model_config::ModelConfig;
use crate::rms_norm::RmsNorm;
use crate::types::{Error, Result};
use burn::nn::{Embedding, EmbeddingConfig};
use burn::tensor::backend::Backend;
use burn::tensor::{Int, Tensor, TensorData};
use std::path::Path;
/// Decoder-only transformer: token embeddings, a stack of decoder layers,
/// and a final RMS norm. Weights start randomly initialized; see
/// `load_safetensors` for loading pretrained values.
#[derive(Clone)]
pub struct DecoderModel<B: Backend> {
// Token-id -> vector lookup table.
pub(crate) embeddings: Embedding<B>,
// One entry per hidden layer, applied in order in `forward`.
pub(crate) layers: Vec<DecoderLayer<B>>,
// RMS norm applied after the last decoder layer.
pub(crate) final_norm: RmsNorm<B>,
// Token id treated as padding when pooling (defaults to 0 in `new`).
pub(crate) pad_token_id: i64,
// Maximum sequence length accepted by `forward`; 0 disables the check.
pub(crate) max_position_embeddings: usize,
// Width of the hidden states produced by the model.
pub(crate) hidden_size: usize,
// Device that all tensors created by this model live on.
pub(crate) device: B::Device,
}
impl<B: Backend> DecoderModel<B> {
/// Builds a randomly initialized decoder model from `config` on `device`.
///
/// Call [`Self::load_safetensors`] afterwards to install pretrained weights.
///
/// # Errors
/// Returns [`Error::InvalidConfig`] when `num_hidden_layers`, `vocab_size`,
/// or `num_attention_heads` is zero, and propagates any error from building
/// an individual decoder layer.
pub fn new(device: &B::Device, config: ModelConfig) -> Result<Self> {
    if config.num_hidden_layers == 0 {
        return Err(Error::InvalidConfig(
            "num_hidden_layers must be greater than 0 for decoder model".into(),
        ));
    }
    if config.vocab_size == 0 {
        return Err(Error::InvalidConfig(
            "vocab_size must be greater than 0 for decoder model".into(),
        ));
    }
    // Bug fix: the head_dim fallback below divides by num_attention_heads,
    // which previously panicked with a divide-by-zero instead of surfacing
    // a config error when the head count was 0.
    if config.num_attention_heads == 0 {
        return Err(Error::InvalidConfig(
            "num_attention_heads must be greater than 0 for decoder model".into(),
        ));
    }
    let embeddings = EmbeddingConfig::new(config.vocab_size, config.hidden_size).init(device);
    let head_dim = config
        .head_dim
        .unwrap_or_else(|| config.hidden_size / config.num_attention_heads);
    // One rotary-embedding table, shared (cloned) by every layer.
    let rope = CausalAttention::build_rope(device, &config, head_dim);
    let mut layers = Vec::with_capacity(config.num_hidden_layers);
    for _ in 0..config.num_hidden_layers {
        layers.push(DecoderLayer::new(device, &config, rope.clone())?);
    }
    let final_norm = RmsNorm::new(device, &config);
    // When the config does not name a pad token, fall back to id 0.
    let pad_token_id = config.pad_token_id.unwrap_or(0);
    Ok(Self {
        embeddings,
        layers,
        final_norm,
        pad_token_id,
        max_position_embeddings: config.max_position_embeddings,
        hidden_size: config.hidden_size,
        device: device.clone(),
    })
}
/// Runs the decoder stack over `input_ids`, returning hidden states of
/// shape `[batch, seq_len, hidden_size]` after the final RMS norm.
///
/// # Errors
/// Returns [`Error::InvalidConfig`] for an empty sequence, or for one
/// longer than `max_position_embeddings` (when that limit is non-zero).
pub fn forward(&self, input_ids: Tensor<B, 2, Int>) -> Result<Tensor<B, 3>> {
    let [_batch, seq_len] = input_ids.dims();
    if seq_len == 0 {
        return Err(Error::InvalidConfig(
            "input sequence length must be greater than 0".into(),
        ));
    }
    let limit = self.max_position_embeddings;
    if limit > 0 && seq_len > limit {
        return Err(Error::InvalidConfig(format!(
            "Sequence length {} exceeds configured maximum {}",
            seq_len, limit
        )));
    }
    // Embed tokens, then thread the activations through each layer in order.
    // The second argument is 0 — presumably a starting position offset;
    // confirm against DecoderLayer::forward.
    let embedded = self.embeddings.forward(input_ids);
    let hidden = self
        .layers
        .iter()
        .fold(embedded, |states, layer| layer.forward(states, 0));
    Ok(self.final_norm.forward(hidden))
}
/// Pools each sequence down to the hidden state of its last non-pad token,
/// yielding a `[batch, hidden_size]` tensor.
///
/// # Errors
/// Propagates any failure from reading `input_ids` back to host memory.
pub fn pool_hidden_states(
    &self,
    hidden_states: Tensor<B, 3>,
    input_ids: Tensor<B, 2, Int>,
) -> Result<Tensor<B, 2>> {
    let last_indices = self.last_token_indices(&input_ids)?;
    let [batch_size, _seq_len, hidden_size] = hidden_states.dims();
    // gather along dim 1 wants an index per (batch, feature) slot, so each
    // row's index is repeated hidden_size times.
    let repeated: Vec<i64> = last_indices
        .into_iter()
        .flat_map(|idx| std::iter::repeat(idx).take(hidden_size))
        .collect();
    let index_tensor = Tensor::<B, 3, Int>::from_data(
        TensorData::new(repeated, [batch_size, 1, hidden_size]),
        &self.device,
    );
    let pooled = hidden_states.gather(1, index_tensor);
    Ok(pooled.reshape([batch_size, hidden_size]))
}
/// Width of the hidden states this model produces.
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
pub fn load_safetensors(&mut self, safetensors_path: &Path) -> Result<()> {
use crate::weight_loader::{load_linear, load_embedding, WeightLoader};
use burn::module::Param;
let bytes = std::fs::read(safetensors_path).map_err(|err| {
Error::LoadError(format!(
"Failed to read SafeTensors file {}: {err}",
safetensors_path.display()
))
})?;
let loader = WeightLoader::from_bytes(&bytes)?;
let embed_names = [
"model.embed_tokens.weight",
"transformer.wte.weight",
"embeddings.word_embeddings.weight",
];
for name in embed_names {
if loader.has_tensor(name) {
self.embeddings = load_embedding(&loader, name, &self.device)?;
break;
}
}
for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
let prefix = format!("model.layers.{}", layer_idx);
if loader.has_tensor(&format!("{}.self_attn.q_proj.weight", prefix)) {
layer.attention.q_proj = load_linear(
&loader,
&format!("{}.self_attn.q_proj.weight", prefix),
Some(&format!("{}.self_attn.q_proj.bias", prefix)),
&self.device,
)?;
layer.attention.k_proj = load_linear(
&loader,
&format!("{}.self_attn.k_proj.weight", prefix),
Some(&format!("{}.self_attn.k_proj.bias", prefix)),
&self.device,
)?;
layer.attention.v_proj = load_linear(
&loader,
&format!("{}.self_attn.v_proj.weight", prefix),
Some(&format!("{}.self_attn.v_proj.bias", prefix)),
&self.device,
)?;
layer.attention.o_proj = load_linear(
&loader,
&format!("{}.self_attn.o_proj.weight", prefix),
Some(&format!("{}.self_attn.o_proj.bias", prefix)),
&self.device,
)?;
}
if loader.has_tensor(&format!("{}.mlp.gate_proj.weight", prefix)) {
layer.gate_proj = load_linear(
&loader,
&format!("{}.mlp.gate_proj.weight", prefix),
None,
&self.device,
)?;
layer.up_proj = load_linear(
&loader,
&format!("{}.mlp.up_proj.weight", prefix),
None,
&self.device,
)?;
layer.down_proj = load_linear(
&loader,
&format!("{}.mlp.down_proj.weight", prefix),
None,
&self.device,
)?;
}
if loader.has_tensor(&format!("{}.input_layernorm.weight", prefix)) {
let norm_tensor = loader.load_tensor(&format!("{}.input_layernorm.weight", prefix))?;
let norm_weight = norm_tensor.to_tensor::<B, 1>(&self.device, [norm_tensor.shape[0]])?;
layer.attention_norm.inner.gamma = Param::from_tensor(norm_weight);
}
if loader.has_tensor(&format!("{}.post_attention_layernorm.weight", prefix)) {
let norm_tensor = loader.load_tensor(&format!("{}.post_attention_layernorm.weight", prefix))?;
let norm_weight = norm_tensor.to_tensor::<B, 1>(&self.device, [norm_tensor.shape[0]])?;
layer.ffn_norm.inner.gamma = Param::from_tensor(norm_weight);
}
}
let final_norm_names = ["model.norm.weight", "transformer.ln_f.weight"];
for name in final_norm_names {
if loader.has_tensor(name) {
let norm_tensor = loader.load_tensor(name)?;
let norm_weight = norm_tensor.to_tensor::<B, 1>(&self.device, [norm_tensor.shape[0]])?;
self.final_norm.inner.gamma = Param::from_tensor(norm_weight);
break;
}
}
log::info!("Successfully loaded decoder weights from {}", safetensors_path.display());
Ok(())
}
/// For each batch row, returns the index of the last token that is not the
/// pad token. A row consisting entirely of padding maps to index 0, matching
/// a backward scan that stops at the first position.
///
/// # Errors
/// Propagates any failure from converting `input_ids` into host memory.
fn last_token_indices(&self, input_ids: &Tensor<B, 2, Int>) -> Result<Vec<i64>> {
    let [_batch_size, seq_len] = input_ids.dims();
    let flat = input_ids
        .clone()
        .into_data()
        .into_vec::<i64>()
        .map_err(|err| Error::InferenceError(err.to_string()))?;
    // The flattened data is row-major, so each seq_len chunk is one row.
    let indices = flat
        .chunks_exact(seq_len)
        .map(|row| {
            row.iter()
                .rposition(|&token| token != self.pad_token_id)
                .unwrap_or(0) as i64
        })
        .collect();
    Ok(indices)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::ndarray::NdArray;

    /// Pooling must select the hidden state of the last non-pad token
    /// in every batch row, not simply the final position.
    #[test]
    fn pool_last_token_respects_padding() {
        let device = <NdArray<f32> as Backend>::Device::default();

        // Tiny one-layer model with pad token id 0.
        let mut cfg = ModelConfig::default();
        cfg.vocab_size = 16;
        cfg.hidden_size = 4;
        cfg.intermediate_size = Some(8);
        cfg.num_hidden_layers = 1;
        cfg.num_attention_heads = 1;
        cfg.num_key_value_heads = Some(1);
        cfg.max_position_embeddings = 8;
        cfg.position_embedding_type = Some("rope".to_string());
        cfg.rms_norm_eps = Some(1e-6);
        cfg.pad_token_id = Some(0);
        let model = DecoderModel::<NdArray<f32>>::new(&device, cfg).expect("model");

        // Row 0 pads from position 2 onward; row 1 pads from position 3.
        let token_ids = Tensor::<NdArray<f32>, 2, Int>::from_data(
            [[5i64, 6, 0, 0], [7, 8, 9, 0]],
            &device,
        );
        // Distinct values per position make the selected rows identifiable.
        let states = Tensor::<NdArray<f32>, 3>::from_data(
            [
                [
                    [1.0, 1.0, 1.0, 1.0],
                    [2.0, 2.0, 2.0, 2.0],
                    [3.0, 3.0, 3.0, 3.0],
                    [4.0, 4.0, 4.0, 4.0],
                ],
                [
                    [10.0, 10.0, 10.0, 10.0],
                    [20.0, 20.0, 20.0, 20.0],
                    [30.0, 30.0, 30.0, 30.0],
                    [40.0, 40.0, 40.0, 40.0],
                ],
            ],
            &device,
        );

        let pooled = model
            .pool_hidden_states(states, token_ids)
            .expect("pool")
            .into_data()
            .into_vec::<f32>()
            .expect("pooled data");

        // Last non-pad tokens are at positions 1 and 2 respectively.
        assert_eq!(pooled, vec![2.0, 2.0, 2.0, 2.0, 30.0, 30.0, 30.0, 30.0]);
    }
}