use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorBase, TensorOps};
use super::layers::{MultiHeadAttention, FeedForward, RMSNorm, RoPE};
pub use super::loader::LlamaConfig;
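/// A single Llama decoder layer: pre-norm self-attention followed by a
/// pre-norm feed-forward block, each wrapped in a residual connection.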
#[derive(Debug, Clone)]
pub struct LlamaDecoderLayer {
pub self_attn: MultiHeadAttention,
pub mlp: FeedForward,
pub input_layernorm: RMSNorm,
pub post_attention_layernorm: RMSNorm,
}
impl LlamaDecoderLayer {
pub fn new(
self_attn: MultiHeadAttention,
mlp: FeedForward,
input_layernorm: RMSNorm,
post_attention_layernorm: RMSNorm,
) -> Self {
Self {
self_attn,
mlp,
input_layernorm,
post_attention_layernorm,
}
}
    pub fn forward(&self, x: &DenseTensor, mask: Option<&DenseTensor>) -> DenseTensor {
        // Pre-norm self-attention block with a residual connection.
        let normed = self.input_layernorm.forward(x);
        let attn_output = self.self_attn.forward_with_mask(&normed, mask);
        let hidden = x.add(&attn_output);
        // Pre-norm feed-forward block with a residual connection.
        let normed = self.post_attention_layernorm.forward(&hidden);
        let mlp_output = self.mlp.forward(&normed);
        hidden.add(&mlp_output)
    }
    pub fn forward_with_cache(
        &self,
        x: &DenseTensor,
        kv_cache: Option<(&DenseTensor, &DenseTensor)>,
        mask: Option<&DenseTensor>,
    ) -> (DenseTensor, Option<(DenseTensor, DenseTensor)>) {
        // Placeholder: the incoming cache is not consumed here; the layer runs a
        // full forward pass and echoes the cache back unchanged.
        let output = self.forward(x, mask);
        (output, kv_cache.map(|(k, v)| (k.clone(), v.clone())))
    }
pub fn num_parameters(&self) -> usize {
let mut total = 0;
total += self.self_attn.num_parameters();
total += self.mlp.num_parameters();
total += self.input_layernorm.weight.shape().iter().product::<usize>();
total += self.post_attention_layernorm.weight.shape().iter().product::<usize>();
total
}
}
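/// A decoder-only Llama-style model: token embeddings, a stack of
/// [`LlamaDecoderLayer`]s, a final [`RMSNorm`], and an optional output
/// projection. When `lm_head` is `None`, the token embedding matrix is
/// reused (tied) as the output projection.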
#[derive(Debug, Clone)]
pub struct LlamaModel {
pub config: LlamaConfig,
pub embed_tokens: DenseTensor,
pub layers: Vec<LlamaDecoderLayer>,
pub norm: RMSNorm,
pub lm_head: Option<DenseTensor>,
pub rope: RoPE,
}
impl LlamaModel {
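    /// Assembles the model from pre-built components and precomputes the RoPE
    /// table from `head_dim`, `max_position_embeddings`, and `rope_theta`.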
pub fn new(
config: LlamaConfig,
embed_tokens: DenseTensor,
layers: Vec<LlamaDecoderLayer>,
norm: RMSNorm,
lm_head: Option<DenseTensor>,
) -> Self {
let rope = RoPE::new(
config.head_dim(),
config.max_position_embeddings,
config.rope_theta,
);
Self {
config,
embed_tokens,
layers,
norm,
lm_head,
rope,
}
}
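    /// Runs the full forward pass for a batch of token id sequences: embedding
    /// lookup, every decoder layer, the final norm, and the vocabulary
    /// projection. Returns logits of shape `[batch_size, seq_len, vocab_size]`.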
pub fn forward(&self, input_ids: &[Vec<usize>], mask: Option<&DenseTensor>) -> DenseTensor {
let batch_size = input_ids.len();
let seq_len = input_ids[0].len();
let mut hidden = self.embed_tokens_batch(input_ids);
        // Position ids are derived here but not consumed by this forward path.
        let _positions: Vec<usize> = (0..seq_len).collect();
for layer in &self.layers {
hidden = layer.forward(&hidden, mask);
}
hidden = self.norm.forward(&hidden);
        // With no dedicated lm_head, tie the output projection to the token embedding matrix.
        let lm_head = self.lm_head.as_ref().unwrap_or(&self.embed_tokens);
        let lm_head_t = lm_head.transpose(None);
        // Flatten to [batch * seq, hidden] so the output projection is a single 2-D matmul.
        let hidden_data = hidden.data().to_vec();
        let hidden_dim = self.config.hidden_size;
        let flat_hidden = DenseTensor::new(hidden_data, vec![batch_size * seq_len, hidden_dim]);
        let logits_flat = flat_hidden.matmul(&lm_head_t);
        // Reshape the logits back to [batch_size, seq_len, vocab_size].
        let vocab_size = self.config.vocab_size;
        let logits_data = logits_flat.data().to_vec();
        DenseTensor::new(logits_data, vec![batch_size, seq_len, vocab_size])
}
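    /// Convenience wrapper around [`Self::forward`] for a single, unbatched sequence.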
pub fn forward_single(&self, input_ids: &[usize], mask: Option<&DenseTensor>) -> DenseTensor {
self.forward(&[input_ids.to_vec()], mask)
}
    fn embed_tokens_batch(&self, input_ids: &[Vec<usize>]) -> DenseTensor {
        let batch_size = input_ids.len();
        // All sequences in the batch are assumed to share the length of the first one.
        let seq_len = input_ids[0].len();
        let hidden_dim = self.config.hidden_size;
        let mut data = Vec::with_capacity(batch_size * seq_len * hidden_dim);
        for batch in input_ids {
            for &token_id in batch {
                // Copy the embedding row for this token id.
                let start = token_id * hidden_dim;
                let end = start + hidden_dim;
                data.extend_from_slice(&self.embed_tokens.data()[start..end]);
            }
        }
        DenseTensor::new(data, vec![batch_size, seq_len, hidden_dim])
    }
pub fn hidden_dim(&self) -> usize {
self.config.hidden_size
}
pub fn vocab_size(&self) -> usize {
self.config.vocab_size
}
pub fn num_layers(&self) -> usize {
self.layers.len()
}
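    /// Total number of scalar parameters across the embeddings, all decoder
    /// layers, the final norm, and the optional `lm_head`.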
pub fn num_parameters(&self) -> usize {
let mut total = 0;
total += self.embed_tokens.shape().iter().product::<usize>();
for layer in &self.layers {
total += layer.num_parameters();
}
total += self.norm.weight.shape().iter().product::<usize>();
if let Some(lm_head) = &self.lm_head {
total += lm_head.shape().iter().product::<usize>();
}
total
}
    pub fn size_mb(&self) -> f64 {
        // Assumes 8 bytes (f64) per parameter.
        (self.num_parameters() * 8) as f64 / (1024.0 * 1024.0)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::DenseTensor;
use crate::tensor::traits::TensorBase;
fn create_test_layer(config: &LlamaConfig) -> LlamaDecoderLayer {
let hidden_dim = config.hidden_size;
let num_heads = config.num_attention_heads;
let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);
let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);
let input_layernorm = RMSNorm::default(hidden_dim);
let post_attention_layernorm = RMSNorm::default(hidden_dim);
LlamaDecoderLayer::new(self_attn, mlp, input_layernorm, post_attention_layernorm)
}
#[test]
fn test_decoder_layer() {
let config = LlamaConfig::llama_7b();
let layer = create_test_layer(&config);
let batch_size = 2;
let seq_len = 4;
let x = DenseTensor::ones(vec![batch_size, seq_len, config.hidden_size]);
let output = layer.forward(&x, None);
assert_eq!(output.shape(), &[batch_size, seq_len, config.hidden_size]);
}
#[test]
fn test_llama_model_creation() {
let config = LlamaConfig::llama_7b();
let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
let layers = vec![create_test_layer(&config); config.num_hidden_layers];
let norm = RMSNorm::default(config.hidden_size);
let lm_head = None;
let model = LlamaModel::new(config, embed_tokens, layers, norm, lm_head);
assert_eq!(model.num_layers(), 32);
assert_eq!(model.vocab_size(), 32000);
assert_eq!(model.hidden_dim(), 4096);
}
}
use crate::transformer::graph_transformer::GraphTransformer;
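/// Builds a [`GraphTransformer`] view of a [`LlamaModel`] that mirrors its
/// layer count, attention head count, and hidden size, e.g. for inspection or
/// DOT export.
///
/// A minimal usage sketch, assuming a constructed `model: LlamaModel`:
///
/// ```ignore
/// let builder = LlamaModelGraphBuilder::new(&model);
/// let graph = builder.build_graph();
/// println!("{}", builder.export_to_dot(&graph));
/// ```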
pub struct LlamaModelGraphBuilder<'a> {
model: &'a LlamaModel,
}
impl<'a> LlamaModelGraphBuilder<'a> {
pub fn new(model: &'a LlamaModel) -> Self {
Self { model }
}
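    /// Builds the graph from a single placeholder token id; use
    /// [`Self::build_graph_for_input`] to build it for a specific prompt.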
pub fn build_graph(&self) -> GraphTransformer {
let mut transformer = GraphTransformer::new(
self.model.num_layers(),
self.model.config.num_attention_heads,
self.model.config.hidden_size,
);
        // Build the graph once from a minimal single-token placeholder input.
        let dummy_input = vec![0; 1];
        transformer.build_graph(&dummy_input);
transformer
}
pub fn build_graph_for_input(&self, input_ids: &[usize]) -> GraphTransformer {
let mut transformer = GraphTransformer::new(
self.model.num_layers(),
self.model.config.num_attention_heads,
self.model.config.hidden_size,
);
transformer.build_graph(input_ids);
transformer
}
pub fn export_to_dot(&self, transformer: &GraphTransformer) -> String {
transformer.to_dot()
}
}
#[cfg(test)]
mod graph_builder_tests {
use super::*;
use crate::transformer::layers::{MultiHeadAttention, FeedForward, RMSNorm};
fn create_test_layer(config: &LlamaConfig) -> LlamaDecoderLayer {
let hidden_dim = config.hidden_size;
let num_heads = config.num_attention_heads;
let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);
let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);
let input_layernorm = RMSNorm::default(hidden_dim);
let post_attention_layernorm = RMSNorm::default(hidden_dim);
LlamaDecoderLayer::new(self_attn, mlp, input_layernorm, post_attention_layernorm)
}
#[test]
fn test_llama_model_graph_builder() {
let config = LlamaConfig::llama_7b();
let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); 2];
        let norm = RMSNorm::default(config.hidden_size);
let lm_head = None;
let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);
let builder = LlamaModelGraphBuilder::new(&model);
let transformer = builder.build_graph();
assert!(transformer.num_nodes() > 0);
assert!(transformer.num_edges() > 0);
}
#[test]
fn test_llama_model_graph_builder_with_input() {
let config = LlamaConfig::llama_7b();
let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
let layers = vec![create_test_layer(&config); 1];
let norm = RMSNorm::default(config.hidden_size);
let lm_head = None;
let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);
let builder = LlamaModelGraphBuilder::new(&model);
let input_ids = vec![1, 2, 3, 4, 5];
let mut transformer = builder.build_graph_for_input(&input_ids);
assert!(transformer.num_nodes() > 0);
assert!(transformer.num_edges() > 0);
let output = transformer.forward(&input_ids);
assert!(!output.data().is_empty());
}
#[test]
fn test_graph_export_to_dot() {
let config = LlamaConfig::llama_7b();
let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
let layers = vec![create_test_layer(&config); 1];
let norm = RMSNorm::default(config.hidden_size);
let lm_head = None;
let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);
let builder = LlamaModelGraphBuilder::new(&model);
let transformer = builder.build_graph();
let dot = builder.export_to_dot(&transformer);
assert!(dot.contains("digraph Transformer"));
assert!(dot.contains("rankdir=TB"));
}
}