use crate::funcs::*;
use crate::graph::{Graph, GraphError, TensorId};
use crate::optimizer::{Optimizer, OptimizerState};
use crate::tensor::{GeneralTensor, Tensor, TensorError, TensorOps};
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;

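/// Snapshot of everything needed to pause and resume training: every
/// parameter tensor, keyed by its graph name, plus the optimizer's state.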
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState {
    pub tensors: HashMap<String, Tensor<f32>>,
    pub optimizer: OptimizerState,
}

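/// A decoder-only GPT built on top of a computation-graph backend `G`.
/// Besides the graph itself, it keeps the `TensorId`s of the nodes that get
/// fed (token/positional inputs, expected output) or read back (logits,
/// loss), along with the fixed sinusoidal positional encoding.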
pub struct GPT<G: Graph> {
    graph: G,
    num_tokens: usize,
    token_input: TensorId,
    pos_input: TensorId,
    output: TensorId,
    expected_output: TensorId,
    loss: TensorId,
    pos_input_fixed: Tensor<f32>,
}

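/// Samples `batch_size` random `(input, target)` windows of `context_size`
/// tokens from a flat token stream; each target is its input shifted left by
/// one position, and `cycle()` wraps windows that run past the end.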
fn sample_dataset<R: Rng>(
    dataset: &[usize],
    batch_size: usize,
    context_size: usize,
    rng: &mut R,
) -> (Tensor<usize>, Tensor<usize>) {
    let mut xs: Vec<usize> = Vec::with_capacity(batch_size * context_size);
    let mut ys: Vec<usize> = Vec::with_capacity(batch_size * context_size);
    for _ in 0..batch_size {
        let start: usize = rng.gen_range(0..dataset.len());
        // Take `context_size + 1` tokens so the target window can be the
        // input window shifted by one.
        let all = dataset
            .iter()
            .cycle()
            .skip(start)
            .take(context_size + 1)
            .cloned()
            .collect::<Vec<_>>();
        xs.extend(&all[0..context_size]);
        ys.extend(&all[1..context_size + 1]);
    }

    (
        Tensor::raw(&[batch_size, context_size], xs).unwrap(),
        Tensor::raw(&[batch_size, context_size], ys).unwrap(),
    )
}

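/// Samples a token index from a row of logits. The logits are softmaxed and
/// walked from the most probable token downwards, accumulating probability
/// mass until it crosses a dice roll drawn from `[0, temperature)`; lower
/// temperatures therefore make the pick greedier.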
fn select<R: Rng, T: TensorOps<f32>>(
    rng: &mut R,
    t: &T,
    temperature: f32,
) -> Result<usize, TensorError> {
    let t = Softmax::new().run(
        &[&GeneralTensor::Float(Tensor::<f32>::raw(
            t.shape(),
            t.blob().to_vec(),
        )?)],
        false,
    )?;
    let mut ts = t.blob().iter().cloned().enumerate().collect::<Vec<_>>();
    // Sort ascending by (quantized) probability so `rev()` iterates from the
    // most probable token downwards.
    ts.sort_by_key(|(_, b)| (b * 1000.) as usize);
    let dice = rng.gen_range(0.0..temperature);
    let mut accum = 0.;
    for (id, t) in ts.iter().rev() {
        accum += t;
        if dice < accum {
            return Ok(*id);
        }
    }
    // With `temperature > 1` the dice can exceed the total probability mass
    // (~1.0); fall back to the least probable token instead of panicking.
    Ok(ts.first().map(|(id, _)| *id).unwrap_or(0))
}

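/// Fixed sinusoidal positional encoding (Vaswani et al., 2017):
/// `PE(k, 2i) = sin(k / 10000^(2i/d))` and `PE(k, 2i+1) = cos(k / 10000^(2i/d))`,
/// where `k` is the position and `d` the embedding size.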
fn pos_encode_inter(num_tokens: usize, embedding_size: usize) -> Tensor<f32> {
    let mut raw_new = Vec::new();
    let cols = embedding_size;
    let rows = num_tokens;
    for row in 0..rows {
        for col in 0..cols {
            let k = row as f32;
            let i = (col / 2) as f32;
            let factor = 10000f32.powf(2f32 * i / embedding_size as f32);

            let pos = if col % 2 == 0 {
                (k / factor).sin()
            } else {
                (k / factor).cos()
            };

            raw_new.push(pos);
        }
    }

    Tensor::raw(&[rows, cols], raw_new).unwrap()
}

impl<G: Graph> GPT<G> {
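    /// Builds the full forward graph: token plus positional embeddings, a
    /// stack of `num_layers` pre-norm transformer blocks (multi-head causal
    /// self-attention followed by a 4x GELU feed-forward), a final LayerNorm,
    /// a linear projection to `vocab_size` logits, and a cross-entropy loss
    /// against `expected_output`.
    ///
    /// A construction sketch; the `CpuGraph` backend name and the
    /// hyperparameter values are assumptions for illustration:
    ///
    /// ```ignore
    /// let mut rng = rand::thread_rng();
    /// let gpt = GPT::new(
    ///     &mut rng,
    ///     CpuGraph::new(), // any `Graph` implementation
    ///     Some(32),        // batch size (None for unbatched graphs)
    ///     vocab_size,
    ///     64,              // embedding_degree
    ///     64,              // num_tokens (context size)
    ///     4,               // num_layers
    ///     4,               // num_heads
    ///     16,              // head_size
    ///     0.0,             // dropout
    /// )?;
    /// ```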
    pub fn new<R: Rng>(
        rng: &mut R,
        mut g: G,
        batch_size: Option<usize>,
        vocab_size: usize,
        embedding_degree: usize,
        num_tokens: usize,
        num_layers: usize,
        num_heads: usize,
        head_size: usize,
        dropout: f32,
    ) -> Result<Self, GraphError> {
        // Learnable token-embedding table: one row per vocabulary entry.
        let token_embedding = g.alloc(
            Tensor::<f32>::rand(rng, &[vocab_size, embedding_degree]),
            true,
            "token_embedding".into(),
        )?;

        // Token indices fed at train/inference time; shaped
        // [batch_size, num_tokens] when batched, [num_tokens] otherwise.
        let token_input = g.alloc_usize(
            Tensor::<usize>::zeros(&if let Some(batch_size) = batch_size {
                vec![batch_size, num_tokens]
            } else {
                vec![num_tokens]
            }),
            "token_input".into(),
        )?;

        // Ground-truth next tokens, aligned one step ahead of `token_input`.
        let expected_output = g.alloc_usize(
            Tensor::<usize>::zeros(&if let Some(batch_size) = batch_size {
                vec![batch_size, num_tokens]
            } else {
                vec![num_tokens]
            }),
            "expected_output".into(),
        )?;

        // Look the token indices up in the embedding table.
        let embedded_token_input = g.call(Embedding::new(), &[token_input, token_embedding])?;

        // Positional input; not trainable, since the fixed sinusoidal
        // encoding is loaded into it before training and inference.
        let pos_input = g.alloc(
            Tensor::<f32>::rand(rng, &[num_tokens, embedding_degree]),
            false,
            "pos_input".into(),
        )?;

        let inp = g.call(Add::new(), &[embedded_token_input, pos_input])?;

        let mut curr_inp = inp;
        for l in 0..num_layers {
            // Pre-norm: normalize the block input before attention.
            let norm_coeff = g.alloc(
                Tensor::<f32>::rand(rng, &[embedding_degree]),
                true,
                format!("norm_{}_coeff", l),
            )?;
            let norm_bias = g.alloc(
                Tensor::<f32>::zeros(&[embedding_degree]),
                true,
                format!("norm_{}_bias", l),
            )?;
            let norm_inp = g.call(LayerNorm::new(), &[curr_inp, norm_coeff, norm_bias])?;

            let mut heads = Vec::new();

            // Multi-head causal self-attention.
            for h in 0..num_heads {
                let k_params = g.alloc(
                    Tensor::<f32>::rand(rng, &[embedding_degree, head_size]),
                    true,
                    format!("head_{}_{}_k", l, h),
                )?;
                let k = g.call(MatMul::new(), &[norm_inp, k_params])?;

                let q_params = g.alloc(
                    Tensor::<f32>::rand(rng, &[embedding_degree, head_size]),
                    true,
                    format!("head_{}_{}_q", l, h),
                )?;
                let q = g.call(MatMul::new(), &[norm_inp, q_params])?;

                let v_params = g.alloc(
                    Tensor::<f32>::rand(rng, &[embedding_degree, head_size]),
                    true,
                    format!("head_{}_{}_v", l, h),
                )?;
                let v = g.call(MatMul::new(), &[norm_inp, v_params])?;

                // Scaled dot-product attention scores: K·Qᵀ / sqrt(head_size).
                let q_t = g.call(Transpose::new(), &[q])?;
                let kq = g.call(MatMul::new(), &[k, q_t])?;

                let head_size_sqrt_inv = (head_size as f32).powf(-0.5);
                let kq_coeff = g.call(Coeff::new(head_size_sqrt_inv), &[kq])?;

                // Causal mask: each position may only attend to itself and
                // earlier positions.
                let masked_kq = g.call(TrilMask::new(num_tokens), &[kq_coeff])?;
                let soft_masked_kq = g.call(Softmax::new(), &[masked_kq])?;
                let dropped_soft_masked_kq = g.call(Dropout::new(dropout), &[soft_masked_kq])?;
                let atten = g.call(MatMul::new(), &[dropped_soft_masked_kq, v])?;
                heads.push(atten);
            }

            // Concatenate the heads and project back to `embedding_degree`.
            let cat = g.call(Cat::new(), &heads)?;
            let proj_params = g.alloc(
                Tensor::<f32>::rand(rng, &[num_heads * head_size, embedding_degree]),
                true,
                format!("proj_{}_weights", l),
            )?;
            let proj_bias_params = g.alloc(
                Tensor::<f32>::zeros(&[embedding_degree]),
                true,
                format!("proj_{}_bias", l),
            )?;
            let proj_cat = g.call(MatMul::new(), &[cat, proj_params])?;
            let proj_cat_bias = g.call(Add::new(), &[proj_cat, proj_bias_params])?;
            let dropped_proj_cat_bias = g.call(Dropout::new(dropout), &[proj_cat_bias])?;

            // Residual connection (from the normalized input) around the
            // attention block, then a second LayerNorm before the MLP.
            let add_atten = g.call(Add::new(), &[norm_inp, dropped_proj_cat_bias])?;
            let add_atten_norm_coeff = g.alloc(
                Tensor::<f32>::rand(rng, &[embedding_degree]),
                true,
                format!("atten_norm_{}_coeff", l),
            )?;
            let add_atten_norm_bias = g.alloc(
                Tensor::<f32>::zeros(&[embedding_degree]),
                true,
                format!("atten_norm_{}_bias", l),
            )?;
            let add_atten_norm = g.call(
                LayerNorm::new(),
                &[add_atten, add_atten_norm_coeff, add_atten_norm_bias],
            )?;

            // Position-wise feed-forward: expand 4x, GELU, project back.
            let lin1_params = g.alloc(
                Tensor::<f32>::rand(rng, &[embedding_degree, 4 * embedding_degree]),
                true,
                format!("feedforward1_{}_weights", l),
            )?;
            let bias1_params = g.alloc(
                Tensor::<f32>::zeros(&[4 * embedding_degree]),
                true,
                format!("feedforward1_{}_bias", l),
            )?;
            let lin1_result = g.call(MatMul::new(), &[add_atten_norm, lin1_params])?;
            let lin1_bias_result = g.call(Add::new(), &[lin1_result, bias1_params])?;
            let lin1_act = g.call(Gelu::new(), &[lin1_bias_result])?;
            let lin2_params = g.alloc(
                Tensor::<f32>::rand(rng, &[4 * embedding_degree, embedding_degree]),
                true,
                format!("feedforward2_{}_weights", l),
            )?;
            let bias2_params = g.alloc(
                Tensor::<f32>::zeros(&[embedding_degree]),
                true,
                format!("feedforward2_{}_bias", l),
            )?;
            let lin2_result = g.call(MatMul::new(), &[lin1_act, lin2_params])?;
            let lin2_bias_result = g.call(Add::new(), &[lin2_result, bias2_params])?;

            // Residual connection around the feed-forward block.
            curr_inp = g.call(Add::new(), &[add_atten_norm, lin2_bias_result])?;
        }

        // Final LayerNorm followed by a linear head mapping hidden states
        // back to vocabulary logits.
        let norm_out_coeff = g.alloc(
            Tensor::<f32>::rand(rng, &[embedding_degree]),
            true,
            "head_norm_coeff".into(),
        )?;
        let norm_out_bias = g.alloc(
            Tensor::<f32>::zeros(&[embedding_degree]),
            true,
            "head_norm_bias".into(),
        )?;
        let norm_out = g.call(LayerNorm::new(), &[curr_inp, norm_out_coeff, norm_out_bias])?;

        let to_vocab = g.alloc(
            Tensor::<f32>::rand(rng, &[embedding_degree, vocab_size]),
            true,
            "head_map_weights".into(),
        )?;
        let to_vocab_bias = g.alloc(
            Tensor::<f32>::zeros(&[vocab_size]),
            true,
            "head_map_bias".into(),
        )?;
        let result_lin = g.call(MatMul::new(), &[norm_out, to_vocab])?;
        let output = g.call(Add::new(), &[result_lin, to_vocab_bias])?;

        let loss = g.call(CrossEntropy::new(), &[output, expected_output])?;

        Ok(Self {
            graph: g,
            num_tokens,
            token_input,
            pos_input,
            output,
            expected_output,
            loss,
            pos_input_fixed: pos_encode_inter(num_tokens, embedding_degree),
        })
    }

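    /// Fetches all parameter tensors from the backing graph (e.g. GPU
    /// memory) into host memory so they can be read or serialized.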
    pub fn sync(&mut self) -> Result<(), GraphError> {
        self.graph
            .params()
            .to_vec()
            .into_iter()
            .map(|p| self.graph.fetch(p, false))
            .collect::<Result<Vec<_>, GraphError>>()?;
        Ok(())
    }

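    /// Total number of scalar values across all trainable parameter tensors.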
    pub fn num_params(&self) -> usize {
        self.graph
            .params()
            .to_vec()
            .into_iter()
            .map(|p| self.graph.get(p).unwrap().as_float().unwrap().size())
            .sum::<usize>()
    }

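    /// Restores parameters (and optionally the optimizer state) from a saved
    /// `TrainingState`; parameters missing from the snapshot keep their
    /// current values.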
    pub fn set_training_state(
        &mut self,
        training_state: TrainingState,
        load_optimizer: bool,
    ) -> Result<(), GraphError> {
        for p in self.graph.params().to_vec() {
            let name = self.graph.name_of(p)?;
            if let Some(t) = training_state.tensors.get(name) {
                self.graph.load(p, t)?;
            }
        }
        if load_optimizer {
            self.graph.set_optimizer_state(&training_state.optimizer)?;
        }
        Ok(())
    }

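    /// Captures the current parameters and optimizer state as a serializable
    /// `TrainingState` snapshot (call `sync` first on device-backed graphs).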
    pub fn get_training_state(&self) -> Result<TrainingState, GraphError> {
        let mut state = TrainingState {
            tensors: Default::default(),
            optimizer: self.graph.get_optimizer_state()?,
        };
        for p in self.graph.params().iter() {
            let k = self.graph.name_of(*p)?.to_string();
            let v = self.graph.get(*p)?.as_float()?.clone();
            state.tensors.insert(k, v);
        }
        Ok(state)
    }

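    /// Data-parallel training on the CPU: each step clones the graph once
    /// per batch element (via rayon), runs forward/backward on independently
    /// sampled windows, averages the per-clone gradients, and applies a
    /// single optimizer update. `limit` is passed straight to
    /// `backward_all`; `callback` runs every 10 steps after a `sync`, e.g.
    /// for checkpointing.
    ///
    /// A usage sketch; the `AdamW` optimizer and the `save_checkpoint`
    /// helper are assumptions for illustration:
    ///
    /// ```ignore
    /// gpt.train_cpu(
    ///     &dataset,
    ///     100_000,       // num_batches
    ///     32,            // batch_size
    ///     None,          // limit: backpropagate through the whole graph
    ///     &AdamW::new(),
    ///     |_step| 1e-4,  // constant learning-rate schedule
    ///     |gpt| save_checkpoint(&gpt.get_training_state()?), // hypothetical
    /// )?;
    /// ```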
    pub fn train_cpu<
        O: Optimizer,
        F: Fn(usize) -> f32,
        C: Fn(&mut Self) -> Result<(), GraphError>,
    >(
        &mut self,
        dataset: &[usize],
        num_batches: usize,
        batch_size: usize,
        limit: Option<usize>,
        optimizer: &O,
        learning_rate: F,
        callback: C,
    ) -> Result<(), GraphError>
    where
        G: Clone + Send + Sync,
    {
        self.graph.load(self.pos_input, &self.pos_input_fixed)?;

        for i in 0..num_batches {
            let timer = Instant::now();
            // Run forward/backward on `batch_size` independent clones of the
            // graph in parallel, one sampled window each.
            let (graphs, errs): (Vec<G>, Vec<f32>) = (0..batch_size)
                .into_par_iter()
                .map(|_| {
                    let mut rng = rand::thread_rng();
                    let mut graph = self.graph.clone();
                    let (xs, ys) = sample_dataset(dataset, 1, self.num_tokens, &mut rng);

                    graph.load_usize(self.token_input, &xs)?;
                    graph.load_usize(self.expected_output, &ys)?;
                    graph.forward(true)?;
                    graph.zero_grad()?;
                    let err = graph.backward_all(self.loss, limit)?;
                    Ok((graph, err))
                })
                .collect::<Result<Vec<(G, f32)>, GraphError>>()?
                .into_iter()
                .unzip();
            // Average each parameter's gradient over the clones and load the
            // result back into the main graph.
            for (id, avg) in self
                .graph
                .params()
                .to_vec()
                .into_par_iter()
                .map(|id| {
                    let mut avg = Tensor::<f32>::scalar(0.);
                    for g in graphs.iter() {
                        avg = (&avg + g.get_grad(id)?)?;
                    }
                    avg = avg.map_values(|f| f / graphs.len() as f32);
                    Ok((id, avg))
                })
                .collect::<Result<Vec<_>, GraphError>>()?
            {
                self.graph.load_grad(id, &avg)?;
            }
            let avg_loss = errs.iter().sum::<f32>() / errs.len() as f32;
            let lr = learning_rate(self.graph.optimizer_step());
            self.graph.optimize(optimizer, lr)?;
            if i % 10 == 0 {
                self.sync()?;
                callback(self)?;
            }
            println!(
                "Step: {} Loss: {} (Elapsed: {}ms)",
                self.graph.optimizer_step(),
                avg_loss,
                timer.elapsed().as_millis()
            );
        }
        Ok(())
    }

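    /// Single-graph training loop: the whole batch is packed into one
    /// forward/backward pass instead of cloning the graph per sample, which
    /// suits backends that batch natively (e.g. a GPU graph). `callback`
    /// runs every 50 steps.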
    pub fn train<O: Optimizer, F: Fn(usize) -> f32, C: Fn(&mut Self) -> Result<(), GraphError>>(
        &mut self,
        dataset: &[usize],
        num_batches: usize,
        batch_size: usize,
        limit: Option<usize>,
        optimizer: &O,
        learning_rate: F,
        callback: C,
    ) -> Result<(), GraphError> {
        self.graph.load(self.pos_input, &self.pos_input_fixed)?;

        for i in 0..num_batches {
            let timer = Instant::now();
            let mut rng = rand::thread_rng();
            let (xs, ys) = sample_dataset(dataset, batch_size, self.num_tokens, &mut rng);

            self.graph.load_usize(self.token_input, &xs)?;
            self.graph.load_usize(self.expected_output, &ys)?;

            self.graph.forward(true)?;
            self.graph.zero_grad()?;
            let err = self.graph.backward_all(self.loss, limit)?;
            let lr = learning_rate(self.graph.optimizer_step());
            self.graph.optimize(optimizer, lr)?;
            if i % 50 == 0 {
                callback(self)?;
            }
            println!(
                "Step: {} Loss: {} (Elapsed: {}ms)",
                self.graph.optimizer_step(),
                err,
                timer.elapsed().as_millis()
            );
        }
        Ok(())
    }

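    /// Autoregressive sampling: seeds a fixed-size context with `prompt`,
    /// then repeatedly predicts the next token from the logits at the last
    /// filled position, sliding the window once it fills up. `callback`
    /// receives every emitted token, prompt included.
    ///
    /// A usage sketch; the `tokenizer` and its methods are assumptions for
    /// illustration, not part of this module:
    ///
    /// ```ignore
    /// let mut rng = rand::thread_rng();
    /// let completion = gpt.infer(
    ///     &mut rng,
    ///     &tokenizer.tokenize("Once upon a time"),
    ///     100, // number of tokens to generate
    ///     0.5, // temperature
    ///     |t| print!("{}", tokenizer.untokenize(&[t])),
    /// )?;
    /// ```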
    pub fn infer<R: Rng, F: Fn(usize)>(
        &mut self,
        rng: &mut R,
        prompt: &[usize],
        count: usize,
        temperature: f32,
        callback: F,
    ) -> Result<Vec<usize>, GraphError> {
        let mut cnt = prompt.len();
        let mut context = vec![0; self.num_tokens];
        context[..prompt.len()].copy_from_slice(prompt);

        self.graph.load(self.pos_input, &self.pos_input_fixed)?;

        for ch in prompt {
            callback(*ch);
        }
        let mut chs = prompt.to_vec();
        for _ in 0..count {
            self.graph.load_usize(
                self.token_input,
                &Tensor::raw(&[1, self.num_tokens], context.clone())?,
            )?;

            // Sample the next token from the logits at the last filled
            // position of the context.
            self.graph.forward(false)?;
            self.graph.fetch(self.output, false)?;
            let next_ch = select(
                rng,
                &self
                    .graph
                    .get(self.output)?
                    .as_float()?
                    .get(0)?
                    .get(cnt - 1)?,
                temperature,
            )?;

            chs.push(next_ch);
            callback(next_ch);
            // Once the context window is full, slide it left by one token to
            // make room for the next prediction.
            if cnt == self.num_tokens {
                context.remove(0);
                context.push(0);
                cnt -= 1;
            }
            context[cnt] = next_ch;
            cnt += 1;
        }
        Ok(chs)
    }
}