use crate::tensor::DenseTensor;
use crate::tensor::traits::{TensorBase, TensorOps};
use super::layers::{MultiHeadAttention, FeedForward, RMSNorm, RoPE};
pub use super::loader::LlamaConfig;

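/// One pre-norm Llama decoder block: RMSNorm -> self-attention -> residual add,
/// then RMSNorm -> feed-forward -> residual add.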
#[derive(Debug, Clone)]
pub struct LlamaDecoderLayer {
    pub self_attn: MultiHeadAttention,
    pub mlp: FeedForward,
    pub input_layernorm: RMSNorm,
    pub post_attention_layernorm: RMSNorm,
}

impl LlamaDecoderLayer {
    pub fn new(
        self_attn: MultiHeadAttention,
        mlp: FeedForward,
        input_layernorm: RMSNorm,
        post_attention_layernorm: RMSNorm,
    ) -> Self {
        Self {
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
        }
    }

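    /// Runs the block on `x` of shape `[batch, seq, hidden]` and returns a
    /// tensor of the same shape.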
    pub fn forward(&self, x: &DenseTensor, mask: Option<&DenseTensor>) -> DenseTensor {
        // Attention sub-block: pre-norm, attend, then add the residual.
        let normed = self.input_layernorm.forward(x);
        let attn_output = self.self_attn.forward_with_mask(&normed, mask);
        let hidden = x.add(&attn_output);

        // Feed-forward sub-block: pre-norm, project, then add the residual.
        let normed = self.post_attention_layernorm.forward(&hidden);
        let mlp_output = self.mlp.forward(&normed);

        hidden.add(&mlp_output)
    }

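    /// Cache-aware variant of `forward`. Currently a pass-through: it
    /// recomputes the full forward pass and returns the incoming cache
    /// unchanged, so incremental KV-cached decoding is not yet implemented.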
    pub fn forward_with_cache(
        &self,
        x: &DenseTensor,
        kv_cache: Option<(&DenseTensor, &DenseTensor)>,
        mask: Option<&DenseTensor>,
    ) -> (DenseTensor, Option<(DenseTensor, DenseTensor)>) {
        let output = self.forward(x, mask);
        (output, kv_cache.map(|(k, v)| (k.clone(), v.clone())))
    }

    pub fn num_parameters(&self) -> usize {
        let mut total = 0;

        total += self.self_attn.num_parameters();
        total += self.mlp.num_parameters();
        total += self.input_layernorm.weight.shape().iter().product::<usize>();
        total += self.post_attention_layernorm.weight.shape().iter().product::<usize>();

        total
    }
}

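/// A Llama-style decoder-only model: token embeddings, a stack of decoder
/// layers, a final RMSNorm, and an optional LM head. When `lm_head` is
/// `None`, the embedding matrix is reused as the output projection
/// (weight tying).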
#[derive(Debug, Clone)]
pub struct LlamaModel {
    pub config: LlamaConfig,
    pub embed_tokens: DenseTensor,
    pub layers: Vec<LlamaDecoderLayer>,
    pub norm: RMSNorm,
    pub lm_head: Option<DenseTensor>,
    pub rope: RoPE,
}

impl LlamaModel {
    pub fn new(
        config: LlamaConfig,
        embed_tokens: DenseTensor,
        layers: Vec<LlamaDecoderLayer>,
        norm: RMSNorm,
        lm_head: Option<DenseTensor>,
    ) -> Self {
        let rope = RoPE::new(
            config.head_dim(),
            config.max_position_embeddings,
            config.rope_theta,
        );

        Self {
            config,
            embed_tokens,
            layers,
            norm,
            lm_head,
            rope,
        }
    }

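    /// Full forward pass: embed `input_ids`, run every decoder layer, apply
    /// the final norm, and project to logits of shape `[batch, seq, vocab]`.
    /// All sequences in the batch are assumed to share the same length.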
    pub fn forward(&self, input_ids: &[Vec<usize>], mask: Option<&DenseTensor>) -> DenseTensor {
        let batch_size = input_ids.len();
        let seq_len = input_ids[0].len();

        let mut hidden = self.embed_tokens_batch(input_ids);

        // Position indices are unused here; rotary position embeddings are
        // expected to be applied within the attention layers.
        let _positions: Vec<usize> = (0..seq_len).collect();

        for layer in &self.layers {
            hidden = layer.forward(&hidden, mask);
        }

        hidden = self.norm.forward(&hidden);

        // Fall back to the tied embedding matrix when no separate LM head exists.
        let lm_head = self.lm_head.as_ref().unwrap_or(&self.embed_tokens);
        let lm_head_t = lm_head.transpose(None);

        // Flatten to [batch * seq, hidden] for the projection, then reshape
        // the logits back to [batch, seq, vocab].
        let hidden_data = hidden.data().to_vec();
        let hidden_dim = self.config.hidden_size;
        let flat_hidden = DenseTensor::new(hidden_data, vec![batch_size * seq_len, hidden_dim]);

        let logits_flat = flat_hidden.matmul(&lm_head_t);

        let vocab_size = self.config.vocab_size;
        let logits_data = logits_flat.data().to_vec();

        DenseTensor::new(logits_data, vec![batch_size, seq_len, vocab_size])
    }

    pub fn forward_single(&self, input_ids: &[usize], mask: Option<&DenseTensor>) -> DenseTensor {
        self.forward(&[input_ids.to_vec()], mask)
    }

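    /// Gathers the embedding row for each token id into a
    /// `[batch, seq, hidden]` tensor. Token ids must be below `vocab_size`,
    /// or the row slice will be out of bounds.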
    fn embed_tokens_batch(&self, input_ids: &[Vec<usize>]) -> DenseTensor {
        let batch_size = input_ids.len();
        let seq_len = input_ids[0].len();
        let hidden_dim = self.config.hidden_size;

        let mut data = Vec::with_capacity(batch_size * seq_len * hidden_dim);

        for batch in input_ids {
            for &token_id in batch {
                // Each embedding row is a contiguous slice of the table.
                let start = token_id * hidden_dim;
                let end = start + hidden_dim;
                data.extend_from_slice(&self.embed_tokens.data()[start..end]);
            }
        }

        DenseTensor::new(data, vec![batch_size, seq_len, hidden_dim])
    }

    pub fn hidden_dim(&self) -> usize {
        self.config.hidden_size
    }

    pub fn vocab_size(&self) -> usize {
        self.config.vocab_size
    }

    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    pub fn num_parameters(&self) -> usize {
        let mut total = 0;

        total += self.embed_tokens.shape().iter().product::<usize>();

        for layer in &self.layers {
            total += layer.num_parameters();
        }

        total += self.norm.weight.shape().iter().product::<usize>();

        if let Some(lm_head) = &self.lm_head {
            total += lm_head.shape().iter().product::<usize>();
        }

        total
    }

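    /// Approximate in-memory size in MiB, assuming 8 bytes (f64) per parameter.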
    pub fn size_mb(&self) -> f64 {
        (self.num_parameters() * 8) as f64 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::DenseTensor;
    use crate::tensor::traits::TensorBase;

    fn create_test_layer(config: &LlamaConfig) -> LlamaDecoderLayer {
        let hidden_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;

        let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);

        let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
        let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

        let input_layernorm = RMSNorm::default(hidden_dim);
        let post_attention_layernorm = RMSNorm::default(hidden_dim);

        LlamaDecoderLayer::new(self_attn, mlp, input_layernorm, post_attention_layernorm)
    }

    #[test]
    fn test_decoder_layer() {
        let config = LlamaConfig::llama_7b();
        let layer = create_test_layer(&config);

        let batch_size = 2;
        let seq_len = 4;
        let x = DenseTensor::ones(vec![batch_size, seq_len, config.hidden_size]);

        let output = layer.forward(&x, None);

        assert_eq!(output.shape(), &[batch_size, seq_len, config.hidden_size]);
    }

    #[test]
    fn test_llama_model_creation() {
        let config = LlamaConfig::llama_7b();

        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); config.num_hidden_layers];
        let norm = RMSNorm::default(config.hidden_size);
        let lm_head = None;

        let model = LlamaModel::new(config, embed_tokens, layers, norm, lm_head);

        assert_eq!(model.num_layers(), 32);
        assert_eq!(model.vocab_size(), 32000);
        assert_eq!(model.hidden_dim(), 4096);
    }
}

use crate::transformer::graph_transformer::GraphTransformer;

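/// Builds a `GraphTransformer` view of a `LlamaModel`, so the model's
/// computation graph can be inspected or exported as Graphviz DOT.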
pub struct LlamaModelGraphBuilder<'a> {
    model: &'a LlamaModel,
}

impl<'a> LlamaModelGraphBuilder<'a> {
    pub fn new(model: &'a LlamaModel) -> Self {
        Self { model }
    }

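    /// Builds the graph from the model's dimensions alone, tracing the
    /// topology with a single placeholder token.
    ///
    /// Illustrative usage, assuming a constructed `model` is already in scope:
    ///
    /// ```ignore
    /// let builder = LlamaModelGraphBuilder::new(&model);
    /// let graph = builder.build_graph();
    /// println!("{}", builder.export_to_dot(&graph));
    /// ```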
    pub fn build_graph(&self) -> GraphTransformer {
        let mut transformer = GraphTransformer::new(
            self.model.num_layers(),
            self.model.config.num_attention_heads,
            self.model.config.hidden_size,
        );

        // A single placeholder token is enough to materialize the topology.
        let dummy_input = vec![0; 1];
        transformer.build_graph(&dummy_input);

        transformer
    }

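    /// Like `build_graph`, but traces the graph for a concrete input sequence.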
    pub fn build_graph_for_input(&self, input_ids: &[usize]) -> GraphTransformer {
        let mut transformer = GraphTransformer::new(
            self.model.num_layers(),
            self.model.config.num_attention_heads,
            self.model.config.hidden_size,
        );

        transformer.build_graph(input_ids);
        transformer
    }

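    /// Renders a previously built graph in Graphviz DOT format.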
    pub fn export_to_dot(&self, transformer: &GraphTransformer) -> String {
        transformer.to_dot()
    }
}

#[cfg(test)]
mod graph_builder_tests {
    use super::*;
    use crate::transformer::layers::{MultiHeadAttention, FeedForward, RMSNorm};

    fn create_test_layer(config: &LlamaConfig) -> LlamaDecoderLayer {
        let hidden_dim = config.hidden_size;
        let num_heads = config.num_attention_heads;

        let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
        let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);

        let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
        let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
        let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

        let input_layernorm = RMSNorm::default(hidden_dim);
        let post_attention_layernorm = RMSNorm::default(hidden_dim);

        LlamaDecoderLayer::new(self_attn, mlp, input_layernorm, post_attention_layernorm)
    }

    #[test]
    fn test_llama_model_graph_builder() {
        let config = LlamaConfig::llama_7b();
        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); 2];
        let norm = RMSNorm::default(config.hidden_size);
        let lm_head = None;

        let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);

        let builder = LlamaModelGraphBuilder::new(&model);
        let transformer = builder.build_graph();

        assert!(transformer.num_nodes() > 0);
        assert!(transformer.num_edges() > 0);
    }

    #[test]
    fn test_llama_model_graph_builder_with_input() {
        let config = LlamaConfig::llama_7b();
        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); 1];
        let norm = RMSNorm::default(config.hidden_size);
        let lm_head = None;

        let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);

        let builder = LlamaModelGraphBuilder::new(&model);
        let input_ids = vec![1, 2, 3, 4, 5];
        let mut transformer = builder.build_graph_for_input(&input_ids);

        assert!(transformer.num_nodes() > 0);
        assert!(transformer.num_edges() > 0);

        let output = transformer.forward(&input_ids);
        assert!(!output.data().is_empty());
    }

    #[test]
    fn test_graph_export_to_dot() {
        let config = LlamaConfig::llama_7b();
        let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);
        let layers = vec![create_test_layer(&config); 1];
        let norm = RMSNorm::default(config.hidden_size);
        let lm_head = None;

        let model = LlamaModel::new(config.clone(), embed_tokens, layers, norm, lm_head);

        let builder = LlamaModelGraphBuilder::new(&model);
        let transformer = builder.build_graph();
        let dot = builder.export_to_dot(&transformer);

        assert!(dot.contains("digraph Transformer"));
        assert!(dot.contains("rankdir=TB"));
    }
}
479}