use ghostflow_core::Tensor;
use crate::linear::Linear;
use crate::Module;
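/// Hyperparameters for a LLaMA-style decoder-only transformer.
///
/// A `num_key_value_heads` value smaller than `num_attention_heads`
/// corresponds to grouped-query attention, as in the LLaMA 2 70B preset.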
#[derive(Debug, Clone)]
pub struct LLaMAConfig {
    pub vocab_size: usize,
    pub hidden_size: usize,
    pub intermediate_size: usize,
    pub num_layers: usize,
    pub num_attention_heads: usize,
    pub num_key_value_heads: usize,
    pub max_position_embeddings: usize,
    pub rms_norm_eps: f32,
    pub rope_theta: f32,
}

impl Default for LLaMAConfig {
    fn default() -> Self {
        LLaMAConfig {
            vocab_size: 32000,
            hidden_size: 4096,
            intermediate_size: 11008,
            num_layers: 32,
            num_attention_heads: 32,
            num_key_value_heads: 32,
            max_position_embeddings: 2048,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
        }
    }
}
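// Preset configurations matching the published LLaMA / LLaMA 2 model sizes,
// plus a tiny configuration for tests.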
impl LLaMAConfig {
    pub fn llama_7b() -> Self {
        Self::default()
    }

    pub fn llama_13b() -> Self {
        LLaMAConfig {
            hidden_size: 5120,
            intermediate_size: 13824,
            num_layers: 40,
            num_attention_heads: 40,
            num_key_value_heads: 40,
            ..Default::default()
        }
    }

    pub fn llama_30b() -> Self {
        LLaMAConfig {
            hidden_size: 6656,
            intermediate_size: 17920,
            num_layers: 60,
            num_attention_heads: 52,
            num_key_value_heads: 52,
            ..Default::default()
        }
    }

    pub fn llama_65b() -> Self {
        LLaMAConfig {
            hidden_size: 8192,
            intermediate_size: 22016,
            num_layers: 80,
            num_attention_heads: 64,
            num_key_value_heads: 64,
            ..Default::default()
        }
    }

    pub fn llama2_7b() -> Self {
        LLaMAConfig {
            max_position_embeddings: 4096,
            ..Self::llama_7b()
        }
    }

    pub fn llama2_13b() -> Self {
        LLaMAConfig {
            max_position_embeddings: 4096,
            ..Self::llama_13b()
        }
    }

    pub fn llama2_70b() -> Self {
        LLaMAConfig {
            hidden_size: 8192,
            intermediate_size: 28672,
            num_layers: 80,
            num_attention_heads: 64,
            num_key_value_heads: 8,
            max_position_embeddings: 4096,
            ..Default::default()
        }
    }

    pub fn llama_tiny() -> Self {
        LLaMAConfig {
            vocab_size: 1000,
            hidden_size: 256,
            intermediate_size: 688,
            num_layers: 4,
            num_attention_heads: 4,
            num_key_value_heads: 4,
            max_position_embeddings: 512,
            rms_norm_eps: 1e-6,
            rope_theta: 10000.0,
        }
    }
}
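/// Root-mean-square layer normalization as used in LLaMA.
///
/// Each vector along the last dimension is scaled by
/// `1 / sqrt(mean(x^2) + eps)` and then multiplied element-wise by a learned
/// per-feature `weight`; unlike LayerNorm, no mean is subtracted and no bias
/// is added.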
pub struct RMSNorm {
    weight: Tensor,
    eps: f32,
}

impl RMSNorm {
    pub fn new(hidden_size: usize, eps: f32) -> Self {
        let weight = Tensor::ones(&[hidden_size]);
        RMSNorm { weight, eps }
    }

    pub fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let x_data = x.data_f32();
        let dims = x.dims();

        if dims.len() < 2 {
            return Err(format!("Expected at least 2D input, got {}D", dims.len()));
        }

        let hidden_size = dims[dims.len() - 1];
        let batch_seq = x_data.len() / hidden_size;

        let weight_data = self.weight.data_f32();
        let mut result = Vec::with_capacity(x_data.len());

        for i in 0..batch_seq {
            let start = i * hidden_size;
            let end = start + hidden_size;
            let slice = &x_data[start..end];

            let mean_sq: f32 = slice.iter().map(|x| x * x).sum::<f32>() / hidden_size as f32;
            let rms = (mean_sq + self.eps).sqrt();

            for (j, &x) in slice.iter().enumerate() {
                result.push(x / rms * weight_data[j]);
            }
        }

        Tensor::from_slice(&result, dims)
            .map_err(|e| format!("Failed to create normalized tensor: {:?}", e))
    }
}
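/// Rotary position embeddings (RoPE) with precomputed tables.
///
/// `new` caches `cos` and `sin` of `pos / theta^(2i / dim)` for every
/// position up to `max_seq_len`; `forward` rotates each consecutive feature
/// pair `(x_{2i}, x_{2i+1})` by the angle cached for the given position.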
pub struct RotaryEmbedding {
    dim: usize,
    max_seq_len: usize,
    cos_cached: Vec<f32>,
    sin_cached: Vec<f32>,
}

impl RotaryEmbedding {
    pub fn new(dim: usize, max_seq_len: usize, theta: f32) -> Self {
        let mut cos_cached = Vec::with_capacity(max_seq_len * dim);
        let mut sin_cached = Vec::with_capacity(max_seq_len * dim);

        for pos in 0..max_seq_len {
            for i in 0..(dim / 2) {
                let freq = 1.0 / theta.powf(2.0 * i as f32 / dim as f32);
                let angle = pos as f32 * freq;
                cos_cached.push(angle.cos());
                sin_cached.push(angle.sin());
            }
        }

        RotaryEmbedding {
            dim,
            max_seq_len,
            cos_cached,
            sin_cached,
        }
    }

    pub fn forward(&self, x: &Tensor, position: usize) -> Result<Tensor, String> {
        if position >= self.max_seq_len {
            return Err(format!("Position {} exceeds max_seq_len {}", position, self.max_seq_len));
        }

        let x_data = x.data_f32();
        let dims = x.dims();
        let hidden_size = dims[dims.len() - 1];

        if hidden_size != self.dim {
            return Err(format!("Hidden size {} doesn't match RoPE dim {}", hidden_size, self.dim));
        }

        let mut result = Vec::with_capacity(x_data.len());
        let offset = position * (self.dim / 2);

        for chunk in x_data.chunks(self.dim) {
            for i in 0..(self.dim / 2) {
                let cos = self.cos_cached[offset + i];
                let sin = self.sin_cached[offset + i];

                let x1 = chunk[2 * i];
                let x2 = chunk[2 * i + 1];

                result.push(x1 * cos - x2 * sin);
                result.push(x1 * sin + x2 * cos);
            }
        }

        Tensor::from_slice(&result, dims)
            .map_err(|e| format!("Failed to apply RoPE: {:?}", e))
    }
}
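/// SwiGLU feed-forward block: `down_proj(silu(gate_proj(x)) * up_proj(x))`.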
pub struct SwiGLU {
    gate_proj: Linear,
    up_proj: Linear,
    down_proj: Linear,
}

impl SwiGLU {
    pub fn new(hidden_size: usize, intermediate_size: usize) -> Self {
        SwiGLU {
            gate_proj: Linear::new(hidden_size, intermediate_size),
            up_proj: Linear::new(hidden_size, intermediate_size),
            down_proj: Linear::new(intermediate_size, hidden_size),
        }
    }

    pub fn forward(&self, x: &Tensor) -> Tensor {
        let gate = self.gate_proj.forward(x);
        let up = self.up_proj.forward(x);

        let gate_silu = gate.silu();
        let intermediate = gate_silu.mul(&up).unwrap_or(gate_silu);

        self.down_proj.forward(&intermediate)
    }
}
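/// Self-attention projections with rotary position embeddings.
///
/// `num_kv_heads` smaller than `num_heads` gives the grouped-query layout:
/// the key/value projections then produce fewer heads than the query
/// projection.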
pub struct LLaMAAttention {
    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    o_proj: Linear,
    rope: RotaryEmbedding,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
}

impl LLaMAAttention {
    pub fn new(config: &LLaMAConfig) -> Self {
        let head_dim = config.hidden_size / config.num_attention_heads;

        LLaMAAttention {
            q_proj: Linear::new(config.hidden_size, config.num_attention_heads * head_dim),
            k_proj: Linear::new(config.hidden_size, config.num_key_value_heads * head_dim),
            v_proj: Linear::new(config.hidden_size, config.num_key_value_heads * head_dim),
            o_proj: Linear::new(config.num_attention_heads * head_dim, config.hidden_size),
            rope: RotaryEmbedding::new(head_dim, config.max_position_embeddings, config.rope_theta),
            num_heads: config.num_attention_heads,
            num_kv_heads: config.num_key_value_heads,
            head_dim,
        }
    }

    pub fn forward(&self, hidden_states: &Tensor, position: usize) -> Tensor {
        // Simplified attention pass: the key/value projections are computed
        // for shape parity but no attention scores are formed; the queries
        // are (optionally) rotated and fed straight to the output projection.
        let q = self.q_proj.forward(hidden_states);
        let _k = self.k_proj.forward(hidden_states);
        let _v = self.v_proj.forward(hidden_states);

        // `rope` was built with `head_dim`, so for multi-head configs this
        // call returns an Err and `unwrap_or` leaves the queries unrotated.
        let q_rope = self.rope.forward(&q, position).unwrap_or(q);

        self.o_proj.forward(&q_rope)
    }
}
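/// Pre-norm decoder block: RMSNorm -> self-attention -> residual add,
/// followed by RMSNorm -> SwiGLU MLP -> residual add.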
pub struct LLaMADecoderLayer {
    self_attn: LLaMAAttention,
    mlp: SwiGLU,
    input_layernorm: RMSNorm,
    post_attention_layernorm: RMSNorm,
}

impl LLaMADecoderLayer {
    pub fn new(config: &LLaMAConfig) -> Self {
        LLaMADecoderLayer {
            self_attn: LLaMAAttention::new(config),
            mlp: SwiGLU::new(config.hidden_size, config.intermediate_size),
            input_layernorm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
            post_attention_layernorm: RMSNorm::new(config.hidden_size, config.rms_norm_eps),
        }
    }

    pub fn forward(&self, hidden_states: &Tensor, position: usize) -> Result<Tensor, String> {
        let residual = hidden_states.clone();
        let hidden_states = self.input_layernorm.forward(hidden_states)?;
        let hidden_states = self.self_attn.forward(&hidden_states, position);
        let hidden_states = hidden_states.add(&residual).unwrap_or(hidden_states);

        let residual = hidden_states.clone();
        let hidden_states = self.post_attention_layernorm.forward(&hidden_states)?;
        let hidden_states = self.mlp.forward(&hidden_states);
        let hidden_states = hidden_states.add(&residual).unwrap_or(hidden_states);

        Ok(hidden_states)
    }
}
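/// The LLaMA backbone: token embeddings, a stack of decoder layers, and a
/// final RMSNorm over the hidden states.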
pub struct LLaMAModel {
    config: LLaMAConfig,
    embed_tokens: Tensor,
    layers: Vec<LLaMADecoderLayer>,
    norm: RMSNorm,
}

impl LLaMAModel {
    pub fn new(config: LLaMAConfig) -> Self {
        let embed_tokens = Tensor::randn(&[config.vocab_size, config.hidden_size]);

        let layers = (0..config.num_layers)
            .map(|_| LLaMADecoderLayer::new(&config))
            .collect();

        let norm = RMSNorm::new(config.hidden_size, config.rms_norm_eps);

        LLaMAModel {
            config,
            embed_tokens,
            layers,
            norm,
        }
    }
    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let mut hidden_states = self.get_embeddings(input_ids)?;

        // Run each decoder layer exactly once over the whole sequence. The
        // simplified attention applies a single RoPE position to every token,
        // so the index of the last token is passed down for all layers.
        let seq_len = input_ids.dims()[1];
        let position = seq_len.saturating_sub(1);
        for layer in &self.layers {
            hidden_states = layer.forward(&hidden_states, position)?;
        }

        self.norm.forward(&hidden_states)
    }
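    // Embedding lookup: copy the `hidden_size`-wide row of `embed_tokens`
    // for each token id in `input_ids`.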
    fn get_embeddings(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let ids_data = input_ids.data_f32();
        let embed_data = self.embed_tokens.data_f32();

        let dims = input_ids.dims();
        let batch_size = dims[0];
        let seq_length = dims[1];
        let hidden_size = self.config.hidden_size;

        let mut result = Vec::with_capacity(batch_size * seq_length * hidden_size);

        for &id in ids_data.iter() {
            let idx = id as usize;
            if idx >= self.config.vocab_size {
                return Err(format!("Token ID {} out of vocabulary", idx));
            }

            let start = idx * hidden_size;
            let end = start + hidden_size;
            result.extend_from_slice(&embed_data[start..end]);
        }

        Tensor::from_slice(&result, &[batch_size, seq_length, hidden_size])
            .map_err(|e| format!("Failed to create embeddings: {:?}", e))
    }
}
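/// LLaMA with a language-modeling head that projects the final hidden states
/// to vocabulary logits.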
pub struct LLaMAForCausalLM {
    model: LLaMAModel,
    lm_head: Linear,
}

impl LLaMAForCausalLM {
    pub fn new(config: LLaMAConfig) -> Self {
        let model = LLaMAModel::new(config.clone());
        let lm_head = Linear::new(config.hidden_size, config.vocab_size);

        LLaMAForCausalLM {
            model,
            lm_head,
        }
    }

    pub fn forward(&self, input_ids: &Tensor) -> Result<Tensor, String> {
        let hidden_states = self.model.forward(input_ids)?;
        let logits = self.lm_head.forward(&hidden_states);
        Ok(logits)
    }
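    // Greedy autoregressive generation: each step re-runs the full forward
    // pass over the growing sequence (no key/value cache) and appends the
    // argmax token.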
    pub fn generate(&self, input_ids: &Tensor, max_new_tokens: usize) -> Result<Vec<usize>, String> {
        let mut current_ids = input_ids.data_f32().iter().map(|&x| x as usize).collect::<Vec<_>>();

        for _ in 0..max_new_tokens {
            let input_tensor = Tensor::from_slice(
                &current_ids.iter().map(|&x| x as f32).collect::<Vec<_>>(),
                &[1, current_ids.len()]
            ).map_err(|e| format!("Failed to create input: {:?}", e))?;

            let logits = self.forward(&input_tensor)?;
            let next_token = self.sample_next_token(&logits)?;

            current_ids.push(next_token);
        }

        Ok(current_ids)
    }
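    // Greedy decoding: argmax over the vocabulary at the last sequence
    // position, assuming the batch-size-1 logits produced by `generate`.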
    fn sample_next_token(&self, logits: &Tensor) -> Result<usize, String> {
        let data = logits.data_f32();
        let dims = logits.dims();

        let seq_len = dims[1];
        let vocab_size = dims[2];

        let start = (seq_len - 1) * vocab_size;
        let end = start + vocab_size;
        let last_logits = &data[start..end];

        let next_token = last_logits.iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .ok_or_else(|| "Failed to sample token".to_string())?;

        Ok(next_token)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llama_config() {
        let config = LLaMAConfig::llama_7b();
        assert_eq!(config.hidden_size, 4096);
        assert_eq!(config.num_layers, 32);

        let config = LLaMAConfig::llama2_70b();
        assert_eq!(config.num_key_value_heads, 8);
        assert_eq!(config.max_position_embeddings, 4096);
    }

    #[test]
    fn test_rms_norm() {
        let norm = RMSNorm::new(128, 1e-6);
        let x = Tensor::randn(&[2, 4, 128]);
        let output = norm.forward(&x).unwrap();
        assert_eq!(output.dims(), &[2, 4, 128]);
    }

    #[test]
    fn test_rope() {
        let rope = RotaryEmbedding::new(64, 512, 10000.0);
        let x = Tensor::randn(&[2, 64]);
        let output = rope.forward(&x, 10).unwrap();
        assert_eq!(output.dims(), &[2, 64]);
    }

    #[test]
    fn test_llama_model() {
        let config = LLaMAConfig::llama_tiny();
        let model = LLaMAModel::new(config);

        let input_ids = Tensor::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let output = model.forward(&input_ids).unwrap();

        assert_eq!(output.dims(), &[2, 2, 256]);
    }
}