femto_gpt/gpt.rs

use crate::funcs::*;
use crate::graph::{Graph, GraphError, TensorId};
use crate::optimizer::{Optimizer, OptimizerState};
use crate::tensor::{GeneralTensor, Tensor, TensorError, TensorOps};
use rand::Rng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;

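/// Serializable training checkpoint: a snapshot of every trainable tensor (keyed by
/// parameter name) together with the optimizer's internal state.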
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState {
    pub tensors: HashMap<String, Tensor<f32>>,
    pub optimizer: OptimizerState,
}

pub struct GPT<G: Graph> {
    graph: G,
    num_tokens: usize,
    token_input: TensorId,
    pos_input: TensorId,
    output: TensorId,
    expected_output: TensorId,
    loss: TensorId,
    pos_input_fixed: Tensor<f32>,
}

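/// Draws `batch_size` random windows of `context_size + 1` consecutive tokens from
/// `dataset` (wrapping around its end) and splits each window into an input row and a
/// target row shifted by one position.
///
/// A minimal usage sketch (the token values are made up for illustration):
///
/// ```ignore
/// let dataset: Vec<usize> = vec![5, 1, 4, 2, 0, 3, 1, 2];
/// let mut rng = rand::thread_rng();
/// // Two rows of four tokens each; ys[i][j] is the token that follows xs[i][j].
/// let (xs, ys) = sample_dataset(&dataset, 2, 4, &mut rng);
/// ```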
fn sample_dataset<R: Rng>(
    dataset: &[usize],
    batch_size: usize,
    context_size: usize,
    rng: &mut R,
) -> (Tensor<usize>, Tensor<usize>) {
    let mut xs: Vec<usize> = Vec::with_capacity(batch_size * context_size);
    let mut ys: Vec<usize> = Vec::with_capacity(batch_size * context_size);
    for _i in 0..batch_size {
        let start: usize = rng.gen_range(0..dataset.len());
        let all = dataset
            .iter()
            .cycle()
            .skip(start)
            .take(context_size + 1)
            .cloned()
            .collect::<Vec<_>>();
        xs.extend(&all[0..context_size]);
        ys.extend(&all[1..context_size + 1]);
    }

    (
        Tensor::raw(&[batch_size, context_size], xs).unwrap(),
        Tensor::raw(&[batch_size, context_size], ys).unwrap(),
    )
}

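/// Samples a token index from a vector of logits: the logits are pushed through a
/// softmax, the probabilities are sorted, and a cumulative sum is walked from the most
/// probable token downwards until it exceeds a random threshold drawn from
/// `0.0..temperature`. Lower temperatures therefore concentrate sampling on the most
/// probable tokens; `temperature` is expected to lie in `(0.0, 1.0]`.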
fn select<R: Rng, T: TensorOps<f32>>(
    rng: &mut R,
    t: &T,
    temperature: f32,
) -> Result<usize, TensorError> {
    let t = Softmax::new().run(
        &[&GeneralTensor::Float(Tensor::<f32>::raw(
            t.shape(),
            t.blob().to_vec(),
        )?)],
        false,
    )?;
    let mut ts = t.blob().iter().cloned().enumerate().collect::<Vec<_>>();
    // Sort ascending by (quantized) probability so that `.rev()` visits the most
    // probable tokens first.
    ts.sort_by_key(|(_, b)| (b * 1000.) as usize);
    let dice = rng.gen_range(0.0..temperature);
    let mut accum = 0.;
    for (id, t) in ts.iter().rev() {
        accum += t;
        if dice < accum {
            return Ok(*id);
        }
    }
    // Fallback for floating-point rounding (or a temperature above 1.0): return the
    // least probable token, i.e. the last one visited above.
    Ok(ts.first().map(|(id, _)| *id).unwrap_or(0))
}

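/// Builds the fixed sinusoidal positional-encoding table from "Attention Is All You
/// Need": for position `k` and channel `col`,
///
/// ```text
/// PE(k, 2i)     = sin(k / 10000^(2i / embedding_size))
/// PE(k, 2i + 1) = cos(k / 10000^(2i / embedding_size))
/// ```
///
/// The resulting `[num_tokens, embedding_size]` tensor is loaded into the
/// non-trainable `pos_input` tensor before training and inference.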
fn pos_encode_inter(num_tokens: usize, embedding_size: usize) -> Tensor<f32> {
    let mut raw_new = Vec::new();
    let cols = embedding_size;
    let rows = num_tokens;
    for row in 0..rows {
        for col in 0..cols {
            let k = row as f32;
            let i = (col / 2) as f32;
            let factor = 10000f32.powf(2f32 * i / embedding_size as f32);

            let pos = if col % 2 == 0 {
                (k / factor).sin()
            } else {
                (k / factor).cos()
            };

            raw_new.push(pos);
        }
    }

    Tensor::raw(&[rows, cols], raw_new).unwrap()
}

impl<G: Graph> GPT<G> {
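    /// Builds the full GPT computation graph: token and positional embeddings, `num_layers`
    /// transformer blocks (layer norm, multi-head causal self-attention, and a feed-forward
    /// sub-layer with residual connections), a final layer norm, a linear projection to
    /// `vocab_size` logits, and a cross-entropy loss node.
    ///
    /// A hypothetical construction sketch; the concrete graph type (`CpuGraph`) and the
    /// hyper-parameter values are illustrative assumptions, not part of this file:
    ///
    /// ```ignore
    /// let mut rng = rand::thread_rng();
    /// let gpt = GPT::new(
    ///     &mut rng,
    ///     CpuGraph::new(),  // assumed CPU implementation of the `Graph` trait
    ///     None,             // no per-graph batch dimension for CPU training (see below)
    ///     255,              // vocab_size
    ///     64,               // embedding_degree
    ///     64,               // num_tokens (context length)
    ///     4,                // num_layers
    ///     4,                // num_heads
    ///     16,               // head_size
    ///     0.0,              // dropout
    /// )?;
    /// ```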
    pub fn new<R: Rng>(
        rng: &mut R,
        mut g: G,
        batch_size: Option<usize>,
        vocab_size: usize,
        embedding_degree: usize,
        num_tokens: usize,
        num_layers: usize,
        num_heads: usize,
        head_size: usize,
        dropout: f32,
    ) -> Result<Self, GraphError> {
        // Map each token to an `embedding_degree`-dimensional space through a lookup table.
        let token_embedding = g.alloc(
            Tensor::<f32>::rand(rng, &[vocab_size, embedding_degree]),
            true,
            "token_embedding".into(),
        )?;

        // Token inputs. We will get `num_tokens` tokens as inputs and will produce `num_tokens`
        // outputs.
        // In the case of CPU training, it's much more efficient to parallelize over the instances
        // of a single batch. (I.e. it's not very efficient to parallelize a single matrix
        // multiplication on a CPU. A better approach is to process a 32-instance batch on a
        // 32-core CPU, where each instance runs on its own core, without parallelizing the
        // individual operations.)
        // That's why we DO NOT specify a `batch_size` when training on a CPU.
        let token_input = g.alloc_usize(
            Tensor::<usize>::zeros(&if let Some(batch_size) = batch_size {
                vec![batch_size, num_tokens]
            } else {
                vec![num_tokens]
            }),
            "token_input".into(),
        )?;

        let expected_output = g.alloc_usize(
            Tensor::<usize>::zeros(&if let Some(batch_size) = batch_size {
                vec![batch_size, num_tokens]
            } else {
                vec![num_tokens]
            }),
            "expected_output".into(),
        )?;

        // Map each token index into an `embedding_degree`-dimensional vector through the
        // `token_embedding` lookup table.
        let embedded_token_input = g.call(Embedding::new(), &[token_input, token_embedding])?;

        // Map token positions into `embedding_degree`-dimensional vectors. This tensor is not
        // trainable; it is overwritten with the fixed sinusoidal encoding before use.
        let pos_input = g.alloc(
            Tensor::<f32>::rand(rng, &[num_tokens, embedding_degree]),
            false,
            "pos_input".into(),
        )?;

        // Positional and token information will both reside in a single
        // `embedding_degree`-dimensional vector.
        let inp = g.call(Add::new(), &[embedded_token_input, pos_input])?;

        let mut curr_inp = inp;
        for l in 0..num_layers {
            // Normalize input before applying multi-head attention
            let norm_coeff = g.alloc(
                Tensor::<f32>::rand(rng, &[embedding_degree]),
                true,
                format!("norm_{}_coeff", l),
            )?;
            let norm_bias = g.alloc(
                Tensor::<f32>::zeros(&[embedding_degree]),
                true,
                format!("norm_{}_bias", l),
            )?;
            let norm_inp = g.call(LayerNorm::new(), &[curr_inp, norm_coeff, norm_bias])?;

            let mut heads = Vec::new();

            // Multi-head Attention
            for h in 0..num_heads {
                // Key
                let k_params = g.alloc(
                    Tensor::<f32>::rand(rng, &[embedding_degree, head_size]),
                    true,
                    format!("head_{}_{}_k", l, h),
                )?;
                let k = g.call(MatMul::new(), &[norm_inp, k_params])?;

                // Query
                let q_params = g.alloc(
                    Tensor::<f32>::rand(rng, &[embedding_degree, head_size]),
                    true,
                    format!("head_{}_{}_q", l, h),
                )?;
                let q = g.call(MatMul::new(), &[norm_inp, q_params])?;

                // Value
                let v_params = g.alloc(
                    Tensor::<f32>::rand(rng, &[embedding_degree, head_size]),
                    true,
                    format!("head_{}_{}_v", l, h),
                )?;
                let v = g.call(MatMul::new(), &[norm_inp, v_params])?;

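                // Scaled dot-product attention with a causal mask, as computed below:
                //   scores  = (K * Q^T) / sqrt(head_size)
                //   scores  = tril_mask(scores)   // hide future positions
                //   weights = softmax(scores)     // one distribution per row
                //   out     = dropout(weights) * V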
                let q_t = g.call(Transpose::new(), &[q])?;
                let kq = g.call(MatMul::new(), &[k, q_t])?;

                let head_size_sqrt_inv = (head_size as f32).powf(-0.5);
                let kq_coeff = g.call(Coeff::new(head_size_sqrt_inv), &[kq])?;

                let masked_kq = g.call(TrilMask::new(num_tokens), &[kq_coeff])?;
                let soft_masked_kq = g.call(Softmax::new(), &[masked_kq])?;
                let dropped_soft_masked_kq = g.call(Dropout::new(dropout), &[soft_masked_kq])?;
                let atten = g.call(MatMul::new(), &[dropped_soft_masked_kq, v])?;
                heads.push(atten);
            }

            // Concat head results and project into embedding_degree
            let cat = g.call(Cat::new(), &heads)?;
            let proj_params = g.alloc(
                Tensor::<f32>::rand(rng, &[num_heads * head_size, embedding_degree]),
                true,
                format!("proj_{}_weights", l),
            )?;
            let proj_bias_params = g.alloc(
                Tensor::<f32>::zeros(&[embedding_degree]),
                true,
                format!("proj_{}_bias", l),
            )?;
            let proj_cat = g.call(MatMul::new(), &[cat, proj_params])?;
            let proj_cat_bias = g.call(Add::new(), &[proj_cat, proj_bias_params])?;
            let dropped_proj_cat_bias = g.call(Dropout::new(dropout), &[proj_cat_bias])?;

            // Add attention results to input and then normalize
            let add_atten = g.call(Add::new(), &[norm_inp, dropped_proj_cat_bias])?;
            let add_atten_norm_coeff = g.alloc(
                Tensor::<f32>::rand(rng, &[embedding_degree]),
                true,
                format!("atten_norm_{}_coeff", l),
            )?;
            let add_atten_norm_bias = g.alloc(
                Tensor::<f32>::zeros(&[embedding_degree]),
                true,
                format!("atten_norm_{}_bias", l),
            )?;
            let add_atten_norm = g.call(
                LayerNorm::new(),
                &[add_atten, add_atten_norm_coeff, add_atten_norm_bias],
            )?;

            // A feed-forward layer:
            // Linear embedding_degree -> 4*embedding_degree
            // Gelu
            // Linear 4*embedding_degree -> embedding_degree
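            // In equation form (W1/b1/W2/b2 are just shorthand for the tensors allocated below):
            //   ffn(x) = gelu(x * W1 + b1) * W2 + b2
            //   W1: [embedding_degree, 4*embedding_degree], W2: [4*embedding_degree, embedding_degree]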
            let lin1_params = g.alloc(
                Tensor::<f32>::rand(rng, &[embedding_degree, 4 * embedding_degree]),
                true,
                format!("feedforward1_{}_weights", l),
            )?;
            let bias1_params = g.alloc(
                Tensor::<f32>::zeros(&[4 * embedding_degree]),
                true,
                format!("feedforward1_{}_bias", l),
            )?;
            let lin1_result = g.call(MatMul::new(), &[add_atten_norm, lin1_params])?;
            let lin1_bias_result = g.call(Add::new(), &[lin1_result, bias1_params])?;
            let lin1_act = g.call(Gelu::new(), &[lin1_bias_result])?;
            let lin2_params = g.alloc(
                Tensor::<f32>::rand(rng, &[4 * embedding_degree, embedding_degree]),
                true,
                format!("feedforward2_{}_weights", l),
            )?;
            let bias2_params = g.alloc(
                Tensor::<f32>::zeros(&[embedding_degree]),
                true,
                format!("feedforward2_{}_bias", l),
            )?;
            let lin2_result = g.call(MatMul::new(), &[lin1_act, lin2_params])?;
            let lin2_bias_result = g.call(Add::new(), &[lin2_result, bias2_params])?;

            curr_inp = g.call(Add::new(), &[add_atten_norm, lin2_bias_result])?;
        }

        // Normalize the output after the last layer
        let norm_out_coeff = g.alloc(
            Tensor::<f32>::rand(rng, &[embedding_degree]),
            true,
            "head_norm_coeff".into(),
        )?;
        let norm_out_bias = g.alloc(
            Tensor::<f32>::zeros(&[embedding_degree]),
            true,
            "head_norm_bias".into(),
        )?;
        let norm_out = g.call(LayerNorm::new(), &[curr_inp, norm_out_coeff, norm_out_bias])?;

        // Map from embedding_degree to vocab_size through a linear layer
        let to_vocab = g.alloc(
            Tensor::<f32>::rand(rng, &[embedding_degree, vocab_size]),
            true,
            "head_map_weights".into(),
        )?;
        let to_vocab_bias = g.alloc(
            Tensor::<f32>::zeros(&[vocab_size]),
            true,
            "head_map_bias".into(),
        )?;
        let result_lin = g.call(MatMul::new(), &[norm_out, to_vocab])?;
        let output = g.call(Add::new(), &[result_lin, to_vocab_bias])?;

        let loss = g.call(CrossEntropy::new(), &[output, expected_output])?;

        Ok(Self {
            graph: g,
            num_tokens,
            token_input,
            pos_input,
            output,
            expected_output,
            loss,
            pos_input_fixed: pos_encode_inter(num_tokens, embedding_degree),
        })
    }

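    /// Fetches every parameter tensor from the graph backend into host memory, so that
    /// `get_training_state` (and any other host-side inspection) sees up-to-date values.
    /// Intended mainly for backends where parameters live outside host memory (e.g. a GPU graph).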
    pub fn sync(&mut self) -> Result<(), GraphError> {
        self.graph
            .params()
            .to_vec()
            .into_iter()
            .map(|p| self.graph.fetch(p, false))
            .collect::<Result<Vec<_>, GraphError>>()?;
        Ok(())
    }

    pub fn num_params(&self) -> usize {
        self.graph
            .params()
            .to_vec()
            .into_iter()
            .map(|p| self.graph.get(p).unwrap().as_float().unwrap().size())
            .sum::<usize>()
    }

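    /// Restores parameter tensors (matched by name) and, optionally, the optimizer state
    /// from a previously captured `TrainingState`.
    ///
    /// A hypothetical checkpoint-loading sketch; the file name and the use of `bincode`
    /// for serialization are assumptions, not something this module prescribes:
    ///
    /// ```ignore
    /// let bytes = std::fs::read("training_state.dat")?;
    /// let state: TrainingState = bincode::deserialize(&bytes)?;
    /// gpt.set_training_state(state, true)?; // `true` also restores the optimizer state
    /// ```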
    pub fn set_training_state(
        &mut self,
        training_state: TrainingState,
        load_optimizer: bool,
    ) -> Result<(), GraphError> {
        for p in self.graph.params().to_vec() {
            let name = self.graph.name_of(p)?;
            if let Some(t) = training_state.tensors.get(name) {
                self.graph.load(p, t)?;
            }
        }
        if load_optimizer {
            self.graph.set_optimizer_state(&training_state.optimizer)?;
        }
        Ok(())
    }

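    /// Captures the current parameter tensors and optimizer state as a `TrainingState`
    /// checkpoint. Call `sync()` first when using a device-backed graph so that the
    /// host-side copies are current.
    ///
    /// A hypothetical checkpoint-saving sketch (mirroring the loading example above):
    ///
    /// ```ignore
    /// gpt.sync()?;
    /// let bytes = bincode::serialize(&gpt.get_training_state()?)?;
    /// std::fs::write("training_state.dat", &bytes)?;
    /// ```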
    pub fn get_training_state(&self) -> Result<TrainingState, GraphError> {
        let mut state = TrainingState {
            tensors: Default::default(),
            optimizer: self.graph.get_optimizer_state()?,
        };
        for p in self.graph.params().iter() {
            let k = self.graph.name_of(*p)?.to_string();
            let v = self.graph.get(*p)?.as_float()?.clone();
            state.tensors.insert(k, v);
        }
        Ok(state)
    }

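    /// CPU training loop. Each step clones the graph once per batch instance, runs the
    /// forward and backward passes for the instances in parallel (via rayon), averages the
    /// per-instance gradients back into the main graph, and takes one optimizer step.
    /// Every 10 steps the parameters are synced and `callback` is invoked (e.g. to save a
    /// checkpoint or print a sample).
    ///
    /// A hypothetical invocation sketch; the `AdamW` optimizer name and the constant
    /// learning-rate closure are illustrative assumptions:
    ///
    /// ```ignore
    /// gpt.train_cpu(
    ///     &dataset,
    ///     100000,            // num_batches
    ///     32,                // batch_size (one instance per core works well)
    ///     None,              // limit, forwarded to backward_all
    ///     &AdamW::new(),     // assumed optimizer implementing `Optimizer`
    ///     |_step| 0.0001,    // learning-rate schedule
    ///     |_gpt| Ok(()),     // periodic callback
    /// )?;
    /// ```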
    pub fn train_cpu<
        O: Optimizer,
        F: Fn(usize) -> f32,
        C: Fn(&mut Self) -> Result<(), GraphError>,
    >(
        &mut self,
        dataset: &[usize],
        num_batches: usize,
        batch_size: usize,
        limit: Option<usize>,
        optimizer: &O,
        learning_rate: F,
        callback: C,
    ) -> Result<(), GraphError>
    where
        G: Clone + Send + Sync,
    {
        self.graph.load(self.pos_input, &self.pos_input_fixed)?;

        for i in 0..num_batches {
            let timer = Instant::now();
            // Run each batch instance on its own clone of the graph, one rayon task per instance.
            let (graphs, errs): (Vec<G>, Vec<f32>) = (0..batch_size)
                .into_par_iter()
                .map(|_| {
                    let mut rng = rand::thread_rng();
                    let mut graph = self.graph.clone();
                    let (xs, ys) = sample_dataset(dataset, 1, self.num_tokens, &mut rng);

                    graph.load_usize(self.token_input, &xs)?;
                    graph.load_usize(self.expected_output, &ys)?;
                    graph.forward(true)?;
                    graph.zero_grad()?;
                    let err = graph.backward_all(self.loss, limit)?;
                    Ok((graph, err))
                })
                .collect::<Result<Vec<(G, f32)>, GraphError>>()?
                .into_iter()
                .unzip();
            // Average each parameter's gradient across the instance graphs and load the
            // result back into the main graph before the optimizer step.
            for (id, avg) in self
                .graph
                .params()
                .to_vec()
                .into_par_iter()
                .map(|id| {
                    let mut avg = Tensor::<f32>::scalar(0.);
                    for g in graphs.iter() {
                        avg = (&avg + g.get_grad(id)?)?;
                    }
                    avg = avg.map_values(|f| f / graphs.len() as f32);
                    Ok((id, avg))
                })
                .collect::<Result<Vec<_>, GraphError>>()?
            {
                self.graph.load_grad(id, &avg)?;
            }
            let avg_loss = errs.iter().sum::<f32>() / errs.len() as f32;
            let lr = learning_rate(self.graph.optimizer_step());
            self.graph.optimize(optimizer, lr)?;
            if i % 10 == 0 {
                self.sync()?;
                callback(self)?;
            }
            println!(
                "Step: {} Loss: {} (Elapsed: {}ms)",
                self.graph.optimizer_step(),
                avg_loss,
                timer.elapsed().as_millis()
            );
        }
        Ok(())
    }

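    /// Single-graph training loop: every step samples one `[batch_size, num_tokens]` batch,
    /// runs forward/backward on the whole batch at once, and takes an optimizer step.
    /// This is the path meant for backends that batch efficiently (e.g. a GPU graph), in
    /// contrast to `train_cpu` above; `callback` is invoked every 50 steps.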
    pub fn train<O: Optimizer, F: Fn(usize) -> f32, C: Fn(&mut Self) -> Result<(), GraphError>>(
        &mut self,
        dataset: &[usize],
        num_batches: usize,
        batch_size: usize,
        limit: Option<usize>,
        optimizer: &O,
        learning_rate: F,
        callback: C,
    ) -> Result<(), GraphError> {
        self.graph.load(self.pos_input, &self.pos_input_fixed)?;

        for i in 0..num_batches {
            let timer = Instant::now();
            let mut rng = rand::thread_rng();
            let (xs, ys) = sample_dataset(dataset, batch_size, self.num_tokens, &mut rng);

            self.graph.load_usize(self.token_input, &xs)?;
            self.graph.load_usize(self.expected_output, &ys)?;

            self.graph.forward(true)?;
            self.graph.zero_grad()?;
            let err = self.graph.backward_all(self.loss, limit)?;
            let lr = learning_rate(self.graph.optimizer_step());
            self.graph.optimize(optimizer, lr)?;
            if i % 50 == 0 {
                callback(self)?;
            }
            println!(
                "Step: {} Loss: {} (Elapsed: {}ms)",
                self.graph.optimizer_step(),
                err,
                timer.elapsed().as_millis()
            );
        }
        Ok(())
    }

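    /// Autoregressive inference: starting from `prompt`, repeatedly runs the forward pass,
    /// samples the next token from the logits at the last filled position with the given
    /// `temperature` (see `select`), and slides the context window once it is full.
    /// `callback` is called with every emitted token, prompt tokens included.
    ///
    /// A hypothetical invocation sketch; the prompt tokens are made-up values:
    ///
    /// ```ignore
    /// let mut rng = rand::thread_rng();
    /// let tokens = gpt.infer(&mut rng, &[3, 1, 4], 100, 0.5, |_token| {})?;
    /// ```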
    pub fn infer<R: Rng, F: Fn(usize)>(
        &mut self,
        rng: &mut R,
        prompt: &[usize],
        count: usize,
        temperature: f32,
        callback: F,
    ) -> Result<Vec<usize>, GraphError> {
        let mut cnt = prompt.len();
        let mut context = vec![0; self.num_tokens];
        context[..prompt.len()].copy_from_slice(prompt);

        self.graph.load(self.pos_input, &self.pos_input_fixed)?;

        for ch in prompt {
            callback(*ch);
        }
        let mut chs = prompt.to_vec();
        for _ in 0..count {
            self.graph.load_usize(
                self.token_input,
                &Tensor::raw(&[1, self.num_tokens], context.clone())?,
            )?;

            self.graph.forward(false)?;
            self.graph.fetch(self.output, false)?;
            // Sample the next token from the logits at the last filled position of the context.
            let next_ch = select(
                rng,
                &self
                    .graph
                    .get(self.output)?
                    .as_float()?
                    .get(0)?
                    .get(cnt - 1)?,
                temperature,
            )?;

            chs.push(next_ch);
            callback(next_ch);
            if cnt == self.num_tokens {
                // The context window is full: slide it left by one token to make room
                // for the newly sampled one.
                context.remove(0);
                context.push(0);
                cnt -= 1;
            }
            context[cnt] = next_ch;
            cnt += 1;
        }
        Ok(chs)
    }
}