llms_from_scratch_rs/examples/
ch05.rs

1//! Examples from Chapter 5
2
3use crate::Example;
4use anyhow::Result;
5
6/// # Example usage of `text_to_token_ids` and `token_ids_to_text`
7///
8/// #### Id
9/// 05.01
10///
11/// #### Page
12/// This example starts on page 132
13///
14/// #### CLI command
15/// ```sh
16/// # without cuda
17/// cargo run example 05.01
18///
19/// # with cuda
20/// cargo run --features cuda example 05.01
21/// ```
22pub struct EG01;
23
24impl Example for EG01 {
25    fn description(&self) -> String {
26        String::from("Example usage of `text_to_token_ids` and `token_ids_to_text`.")
27    }
28
29    fn page_source(&self) -> usize {
30        132_usize
31    }
32
33    fn main(&self) -> Result<()> {
34        use crate::listings::{
35            ch04::{generate_text_simple, Config, GPTModel},
36            ch05::{text_to_token_ids, token_ids_to_text},
37        };
38        use candle_core::{DType, Device};
39        use candle_nn::{VarBuilder, VarMap};
40        use tiktoken_rs::get_bpe_from_model;
41
42        // construct model
43        let varmap = VarMap::new();
44        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
45        let cfg = Config::gpt2_124m();
46        let model = GPTModel::new(Config::gpt2_124m(), vb.pp("model"))?;
47
48        // sample setup and load tokenizer
49        let start_context = "Every effort moves you";
50        let tokenizer = get_bpe_from_model("gpt2")?;
51
52        // generate next tokens with model
53        let max_new_tokens = 10_usize;
54        let token_ids = generate_text_simple(
55            &model,
56            text_to_token_ids(start_context, &tokenizer, vb.device())?,
57            max_new_tokens,
58            cfg.context_length,
59        )?;
60
61        // decode the token ids to print the output text
62        println!(
63            "Output text:\n{:?}",
64            token_ids_to_text(token_ids, &tokenizer)
65        );
66        Ok(())
67    }
68}
69
70/// # Example computation of cross-entropy and perplexity
71///
72/// #### Id
73/// 05.02
74///
75/// #### Page
76/// This example starts on page 133
77///
78/// #### CLI command
79/// ```sh
80/// # without cuda
81/// cargo run example 05.02
82///
83/// # with cuda
84/// cargo run --features cuda example 05.02
85/// ```
86pub struct EG02;
87
88impl Example for EG02 {
89    fn description(&self) -> String {
90        let desc = "Example computation of cross-entropy and perplexity.";
91        String::from(desc)
92    }
93
94    fn page_source(&self) -> usize {
95        133_usize
96    }
97
98    fn main(&self) -> Result<()> {
99        use crate::listings::{
100            ch04::{Config, GPTModel},
101            ch05::token_ids_to_text,
102        };
103        use candle_core::{DType, Device, IndexOp, ModuleT, Tensor, D};
104        use candle_nn::{loss::cross_entropy, ops::softmax, VarBuilder, VarMap};
105        use tiktoken_rs::get_bpe_from_model;
106
107        // construct model
108        let varmap = VarMap::new();
109        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
110        let cfg = Config::gpt2_124m();
111        let model = GPTModel::new(cfg, vb.pp("model"))?;
112
113        // inputs and target tensors
114        let inputs = Tensor::new(&[[16833_u32, 3626, 6100], [40, 1107, 588]], vb.device())?;
115        let targets = Tensor::new(&[[3626_u32, 6100, 345], [1107, 588, 11311]], vb.device())?;
116
117        // logits and probas
118        let logits = model.forward_t(&inputs, false)?;
119        let probas = softmax(&logits, D::Minus1)?;
120        println!("{:?}", probas);
121
122        // get next token id from probas
123        let token_ids = probas.argmax_keepdim(D::Minus1)?;
124        println!("Token IDs:\n{:?}", token_ids.to_vec3::<u32>());
125
126        // compare predictions to targets
127        let tokenizer = get_bpe_from_model("gpt2")?;
128        println!(
129            "Targets batch 1: {:?}",
130            token_ids_to_text(targets.i(0)?, &tokenizer)
131        );
132        println!(
133            "Outputs batch 1: {:?}",
134            token_ids_to_text(token_ids.i(0)?.flatten_all()?, &tokenizer)
135        );
136
137        // let's see the predicted probas for the target tokens
138        let text_idx = 0_usize;
139        let target_probas_1 =
140            addons::get_target_token_probas_helper(text_idx, &targets, &probas, vb.device())?;
141
142        println!("Text 1: {:?}", target_probas_1);
143
144        let text_idx = 1_usize;
145        let target_probas_2 =
146            addons::get_target_token_probas_helper(text_idx, &targets, &probas, vb.device())?;
147
148        println!("Text 2: {:?}", target_probas_2);
149
150        // compute log probas
151        let log_probas = Tensor::cat(&[&target_probas_1, &target_probas_2], 0)?.log()?;
152        println!("Log probas: {:?}", log_probas);
153
154        // compute average
155        let avg_log_probas = log_probas.mean(0)?;
156        println!("Avg log probbas: {:?}", avg_log_probas);
157
158        // compute negative average log probas or cross-entropy
159        let neg_avg_log_probas = (log_probas.mean(0)? * -1_f64)?;
160        println!("Neg avg log probbas: {:?}", neg_avg_log_probas);
161
162        // compute cross entropy with candle_nn::ops::loss::cross_entropy
163        println!("Logits shape: {:?}", logits);
164        println!("Targets shape: {:?}", targets);
165
166        let logits_flat = logits.flatten(0, 1)?;
167        let targets_flat = targets.flatten_all()?;
168        println!("Flattened logits: {:?}", logits_flat.shape());
169        println!("Flattened targets: {:?}", targets_flat.shape());
170
171        let loss = cross_entropy(&logits_flat, &targets_flat)?;
172        println!("loss: {:?}", loss);
173
174        // perplexity
175        let perplexity = loss.exp()?;
176        println!("perplexity: {:?}", perplexity);
177        Ok(())
178    }
179}
180
181/// # Split text into train and validation datasets and loaders
182///
183/// #### Id
184/// 05.03
185///
186/// #### Page
187/// This example starts on page 141
188///
189/// #### CLI command
190/// ```sh
191/// # without cuda
192/// cargo run example 05.03
193///
194/// # with cuda
195/// cargo run --features cuda example 05.03
196/// ```
197pub struct EG03;
198
199impl Example for EG03 {
200    fn description(&self) -> String {
201        String::from("Split text into train and validation datasets and loaders.")
202    }
203
204    fn page_source(&self) -> usize {
205        141_usize
206    }
207
208    fn main(&self) -> Result<()> {
209        use crate::listings::ch02::DataLoader;
210
211        let (train_loader, val_loader) = addons::get_train_val_data_loaders(true)?;
212
213        let mut train_batcher = train_loader.batcher();
214        let mut val_batcher = val_loader.batcher();
215
216        println!("Train loader:");
217        while let Some(Ok((x, y))) = train_batcher.next() {
218            println!("{:?}, {:?}", x.shape(), y.shape())
219        }
220
221        println!("Valdiation loader:");
222        while let Some(Ok((x, y))) = val_batcher.next() {
223            println!("{:?}, {:?}", x.shape(), y.shape())
224        }
225        Ok(())
226    }
227}
228
/// # Example usage of `calc_loss_loader`
///
/// #### Id
/// 05.04
///
/// #### Page
/// This example starts on page 145
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 05.04
///
/// # with cuda
/// cargo run --features cuda example 05.04
/// ```
pub struct EG04;

impl Example for EG04 {
    fn description(&self) -> String {
        String::from("Example usage of `calc_loss_loader`.")
    }

    fn page_source(&self) -> usize {
        145_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch05::calc_loss_loader,
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};

        // construct model (untrained, randomly initialized weights)
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let cfg = Config::gpt2_124m();
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        // build train and val loaders with utility function from addons module
        let (train_loader, val_loader) = addons::get_train_val_data_loaders(false)?;

        // compute train and val loss
        // NOTE(review): the two trailing `None`s presumably mean "no batch
        // limit / default options" — confirm against `calc_loss_loader`
        let train_loss = calc_loss_loader(&train_loader, &model, vb.device(), None, None)?;
        let val_loss = calc_loss_loader(&val_loader, &model, vb.device(), None, None)?;

        println!("Training loss {:?}", train_loss);
        println!("Validation loss {:?}", val_loss);
        Ok(())
    }
}
282
283/// # Example usage of `train_model_simple` function and plotting loss curves
284///
285/// #### Id
286/// 05.05
287///
288/// #### Page
289/// This example starts on page 149
290///
291/// #### CLI command
292/// ```sh
293/// # without cuda
294/// cargo run example 05.05
295///
296/// # with cuda
297/// cargo run --features cuda example 05.05
298/// ```
299pub struct EG05;
300
301impl Example for EG05 {
302    fn description(&self) -> String {
303        String::from("Example usage of `train_model_simple` function and plotting loss curves.")
304    }
305
306    fn page_source(&self) -> usize {
307        149_usize
308    }
309
310    fn main(&self) -> Result<()> {
311        use crate::listings::{
312            ch04::{generate_text_simple, Config, GPTModel},
313            ch05::{plot_losses, text_to_token_ids, token_ids_to_text, train_model_simple},
314        };
315        use candle_core::{DType, Device};
316        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
317        use ndarray::linspace;
318        use std::path::Path;
319        use tiktoken_rs::get_bpe_from_model;
320
321        let varmap = VarMap::new();
322        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
323        let cfg = Config::gpt2_124m();
324        let model = GPTModel::new(Config::gpt2_124m(), vb.pp("model"))?;
325        let optimizer = AdamW::new(
326            varmap.all_vars(),
327            ParamsAdamW {
328                lr: 0.0004,
329                weight_decay: 0.1,
330                ..Default::default()
331            },
332        )?;
333        let tokenizer = get_bpe_from_model("gpt2")?;
334        let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 10_usize);
335        let (train_loader, val_loader) = addons::get_train_val_data_loaders(false)?;
336        let start_context = "Every effort moves you";
337        let (train_losses, val_losses, tokens_seen) = train_model_simple(
338            &model,
339            &train_loader,
340            &val_loader,
341            optimizer,
342            vb.device(),
343            num_epochs,
344            eval_freq,
345            eval_iter,
346            start_context,
347            &tokenizer,
348            None,
349        )?;
350
351        // run inference with trained model using deterministic decoding
352        let token_ids = generate_text_simple(
353            &model,
354            text_to_token_ids(start_context, &tokenizer, vb.device())?,
355            25,
356            cfg.context_length,
357        )?;
358
359        // should be the same as the last output generation during training
360        println!(
361            "Output text:\n{:?}",
362            token_ids_to_text(token_ids, &tokenizer)
363        );
364
365        // plot loss curves
366        println!("Saving weights to `./plot_pretraining_loss.html`");
367        let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_losses.len()));
368        let tokens_seen = tokens_seen
369            .into_iter()
370            .map(|el| el as f32)
371            .collect::<Vec<_>>();
372        let save_path = Path::new("plot_retraining_loss.html").to_path_buf();
373        plot_losses(
374            epochs_seen,
375            tokens_seen,
376            train_losses,
377            val_losses,
378            save_path,
379        )?;
380
381        Ok(())
382    }
383}
384
385/// # Manual multinomial with/without temperature scaling
386///
387/// #### Id
388/// 05.06
389///
390/// #### Page
391/// This example starts on page 152
392///
393/// #### CLI command
394/// ```sh
395/// # without cuda
396/// cargo run example 05.06
397///
398/// # with cuda
399/// cargo run --features cuda example 05.06
400/// ```
401pub struct EG06;
402
403impl Example for EG06 {
404    fn description(&self) -> String {
405        String::from("Manual multinomial with/without temperature scaling decoding example.")
406    }
407
408    fn page_source(&self) -> usize {
409        152_usize
410    }
411
412    #[allow(unused_variables)]
413    fn main(&self) -> Result<()> {
414        use crate::listings::ch05::{print_sampled_tokens, sample_multinomial};
415        use candle_core::D;
416        use candle_nn::ops::softmax;
417        use rand::{rngs::StdRng, SeedableRng};
418
419        let (vocab, inverse_vocab) = addons::get_vocab_and_inversed_vocab();
420        let next_token_logits = addons::get_next_token_logits()?;
421
422        let probas = softmax(&next_token_logits, D::Minus1)?;
423
424        // greedy sampling
425        let next_token_id = probas.argmax(D::Minus1)?;
426        println!(
427            "Greedy sampling next token: {:?}",
428            inverse_vocab.get(&next_token_id.to_scalar::<u32>()?)
429        );
430
431        // multinomial sampling
432        let mut rng = StdRng::seed_from_u64(123_u64);
433        let next_token_id = sample_multinomial(&mut rng, &probas.to_vec1::<f32>()?)?;
434        println!(
435            "Multinomial samping next token: {:?}",
436            inverse_vocab.get(&next_token_id)
437        );
438
439        // temperature scaling
440        let temp = 0.1;
441        let scaled_logits = (next_token_logits / temp)?;
442        let scaled_probas = softmax(&scaled_logits, D::Minus1)?;
443        let next_token_id = sample_multinomial(&mut rng, &scaled_probas.to_vec1::<f32>()?)?;
444        println!(
445            "Temp (temp=0.1) scaled multinomial sampling next token: {:?}",
446            inverse_vocab.get(&next_token_id)
447        );
448
449        // generate multinomial random sample
450        println!("Temp (temp=1.0) scaling sampling conducted 1000 times:");
451        let with_expected_vals = false;
452        print_sampled_tokens(
453            &probas.to_vec1::<f32>()?,
454            &inverse_vocab,
455            with_expected_vals, // this is set in Exercise 5.1
456        )?;
457        Ok(())
458    }
459}
460
/// # Example of extracting topk probas
///
/// #### Id
/// 05.07
///
/// #### Page
/// This example starts on page 156
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 05.07
///
/// # with cuda
/// cargo run --features cuda example 05.07
/// ```
pub struct EG07;

impl Example for EG07 {
    fn description(&self) -> String {
        String::from("Example of extracting topk probas.")
    }

    fn page_source(&self) -> usize {
        156_usize
    }

    #[allow(dead_code, unused_variables)]
    fn main(&self) -> Result<()> {
        use crate::candle_addons::TopK;
        use candle_core::{Tensor, D};
        use candle_nn::ops::softmax;

        // toy 9-word vocab and the book's example next-token logits
        let (vocab, inverse_vocab) = addons::get_vocab_and_inversed_vocab();
        let next_token_logits = addons::get_next_token_logits()?;

        // top-k logits: keep the k largest logits and their positions
        let top_k = 3_usize;
        let (top_logits, top_pos) = next_token_logits.topk_last_dim0(top_k)?;
        println!("Top logits: {:?}", top_logits.to_vec1::<f32>());
        println!("Top pos: {:?}", top_pos.to_vec1::<u32>());

        // masking to get new logits: mask is true where a logit is smaller
        // than the smallest of the top-k logits; those entries are replaced
        // with -inf so softmax assigns them zero probability
        let mask = next_token_logits.broadcast_lt(&top_logits.min(D::Minus1)?)?;
        let on_true = next_token_logits
            .ones_like()?
            .broadcast_mul(&Tensor::new(f32::NEG_INFINITY, next_token_logits.device())?)?;
        let new_logits = mask.where_cond(&on_true, &next_token_logits)?;
        println!("mask: {:?}", mask);
        println!("new_logits: {:?}", new_logits);

        // get top-k probas: softmax over the masked logits
        let topk_probas = softmax(&new_logits, D::Minus1)?;
        println!("probas: {:?}", topk_probas);
        Ok(())
    }
}
518
519/// # Example usage of `generate`
520///
521/// #### Id
522/// 05.08
523///
524/// #### Page
525/// This example starts on page 158
526///
527/// #### CLI command
528/// ```sh
529/// # without cuda
530/// cargo run example 05.08
531///
532/// # with cuda
533/// cargo run --features cuda example 05.08
534/// ```
535pub struct EG08;
536
537impl Example for EG08 {
538    fn description(&self) -> String {
539        String::from("Example usage of `generate`.")
540    }
541
542    fn page_source(&self) -> usize {
543        158_usize
544    }
545
546    fn main(&self) -> Result<()> {
547        use crate::listings::{
548            ch04::{Config, GPTModel},
549            ch05::{generate, text_to_token_ids, token_ids_to_text},
550        };
551        use candle_core::{DType, Device};
552        use candle_nn::{VarBuilder, VarMap};
553        use rand::{rngs::StdRng, SeedableRng};
554        use tiktoken_rs::get_bpe_from_model;
555
556        // construct model
557        let varmap = VarMap::new();
558        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
559        let cfg = Config::gpt2_124m();
560        let model = GPTModel::new(Config::gpt2_124m(), vb.pp("model"))?;
561
562        // sample setup and load tokenizer
563        let start_context = "Every effort moves you";
564        let tokenizer = get_bpe_from_model("gpt2")?;
565
566        // generate next tokens with model
567        let mut rng = StdRng::seed_from_u64(42_u64);
568        let token_ids = generate(
569            &model,
570            text_to_token_ids(start_context, &tokenizer, vb.device())?,
571            15_usize,
572            cfg.context_length,
573            Some(1.4_f64),
574            Some(25_usize),
575            None,
576            &mut rng,
577        )?;
578
579        // decode the token ids to print the output text
580        println!(
581            "Output text:\n{:?}",
582            token_ids_to_text(token_ids, &tokenizer)
583        );
584        Ok(())
585    }
586}
587
588/// # Saving and loading a candle model
589///
590/// #### Id
591/// 05.09
592///
593/// #### Page
594/// This example starts on page 159
595///
596/// #### CLI command
597/// ```sh
598/// # without cuda
599/// cargo run example 05.09
600///
601/// # with cuda
602/// cargo run --features cuda example 05.09
603/// ```
604pub struct EG09;
605
606impl Example for EG09 {
607    fn description(&self) -> String {
608        String::from("Saving and loading a candle model.")
609    }
610
611    fn page_source(&self) -> usize {
612        159_usize
613    }
614
615    fn main(&self) -> Result<()> {
616        use crate::listings::{
617            ch04::{Config, GPTModel},
618            ch05::train_model_simple,
619        };
620        use candle_core::{DType, Device, Error, IndexOp};
621        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
622        use tiktoken_rs::get_bpe_from_model;
623
624        // construt model
625        let varmap = VarMap::new();
626        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
627        let cfg = Config::gpt2_124m();
628        let model = GPTModel::new(cfg, vb.pp("model"))?;
629        let optimizer = AdamW::new(
630            varmap.all_vars(),
631            ParamsAdamW {
632                lr: 0.0004,
633                weight_decay: 0.1,
634                ..Default::default()
635            },
636        )?;
637
638        // train model for an epoch
639        let tokenizer = get_bpe_from_model("gpt2")?;
640        let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 1_usize);
641        let (train_loader, val_loader) = addons::get_train_val_data_loaders(false)?;
642        let start_context = "Every effort moves you";
643        let _ = train_model_simple(
644            &model,
645            &train_loader,
646            &val_loader,
647            optimizer,
648            vb.device(),
649            num_epochs,
650            eval_freq,
651            eval_iter,
652            start_context,
653            &tokenizer,
654            None,
655        );
656
657        // save weights
658        println!(
659            "model.out_head.weight first 10 weights BEFORE save: {:?}",
660            varmap
661                .data()
662                .lock()
663                .unwrap()
664                .get("model.out_head.weight")
665                .ok_or_else(|| {
666                    Error::CannotFindTensor {
667                        path: "model.out_head.weight".to_string(),
668                    }
669                    .bt()
670                })?
671                .i((1, ..10))?
672                .to_vec1::<f32>()
673        );
674
675        println!("Saving weights to `./checkpoint.safetensors`");
676        varmap.save("checkpoint.safetensors")?;
677
678        // construct a new copy of the model and its varmap
679        let mut varmap = VarMap::new();
680        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
681        let _model = GPTModel::new(cfg, vb.pp("model"))?;
682        println!(
683            "model.out_head.weight first 10 weights BEFORE load: {:?}",
684            varmap
685                .data()
686                .lock()
687                .unwrap()
688                .get("model.out_head.weight")
689                .ok_or_else(|| {
690                    Error::CannotFindTensor {
691                        path: "model.out_head.weight".to_string(),
692                    }
693                    .bt()
694                })?
695                .i((1, ..10))?
696                .to_vec1::<f32>()
697        );
698
699        // load the saved weights into the new model copy
700        println!("Loading weights from `./checkpoint.safetensors`");
701        varmap.load("checkpoint.safetensors")?;
702        println!(
703            "model.out_head.weight first 10 weights AFTER load: {:?}",
704            varmap
705                .data()
706                .lock()
707                .unwrap()
708                .get("model.out_head.weight")
709                .ok_or_else(|| {
710                    Error::CannotFindTensor {
711                        path: "model.out_head.weight".to_string(),
712                    }
713                    .bt()
714                })?
715                .i((1, ..10))?
716                .to_vec1::<f32>()
717        );
718        Ok(())
719    }
720}
721
722/// # Example for downloading safetensors from HuggingFace Hub
723///
724/// #### Id
725/// 05.10
726///
727/// #### Page
728/// This example starts on page 161
729///
730/// #### CLI command
731/// ```sh
732/// # without cuda
733/// cargo run example 05.10
734///
735/// # with cuda
736/// cargo run --features cuda example 05.10
737/// ```
738pub struct EG10;
739
740impl Example for EG10 {
741    fn description(&self) -> String {
742        String::from("Example for downloading safetensors from HuggingFace Hub.")
743    }
744
745    fn page_source(&self) -> usize {
746        161_usize
747    }
748
749    fn main(&self) -> Result<()> {
750        use crate::listings::ch04::Config;
751        use candle_core::Device;
752        use hf_hub::api::sync::Api;
753
754        let api = Api::new()?;
755        let repo = api.model("openai-community/gpt2".to_string());
756        let weights = repo.get("model.safetensors")?;
757        let weights = candle_core::safetensors::load(weights, &Device::Cpu)?;
758
759        // update config
760        let mut cfg = Config::gpt2_124m();
761        cfg.qkv_bias = true;
762
763        println!("{:?}", cfg);
764
765        println!("{:?}", weights);
766        Ok(())
767    }
768}
769
770/// # Example usage of `load_weights_into_gpt`
771///
772/// #### Id
773/// 05.11
774///
775/// #### Page
776/// This example starts on page 167
777///
778/// #### CLI command
779/// ```sh
780/// # without cuda
781/// cargo run example 05.11
782///
783/// # with cuda
784/// cargo run --features cuda example 05.11
785/// ```
786pub struct EG11;
787
788impl Example for EG11 {
789    fn description(&self) -> String {
790        String::from("Example usage of `load_weights_into_gpt`.")
791    }
792
793    fn page_source(&self) -> usize {
794        167_usize
795    }
796
797    fn main(&self) -> Result<()> {
798        use crate::listings::{
799            ch04::{Config, GPTModel},
800            ch05::{generate, load_weights_into_gpt, text_to_token_ids, token_ids_to_text},
801        };
802        use candle_core::{DType, Device};
803        use candle_nn::{VarBuilder, VarMap};
804        use hf_hub::api::sync::Api;
805        use rand::{rngs::StdRng, SeedableRng};
806        use tiktoken_rs::get_bpe_from_model;
807
808        let dev = Device::cuda_if_available(0)?;
809        let varmap = VarMap::new();
810        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
811        let mut cfg = Config::gpt2_124m();
812        cfg.qkv_bias = true;
813        let model = GPTModel::new(cfg, vb.pp("model"))?;
814
815        // get weights from HF Hub
816        let api = Api::new()?;
817        let repo = api.model("openai-community/gpt2".to_string());
818        let weights = repo.get("model.safetensors")?;
819        let weights = candle_core::safetensors::load(weights, &dev)?;
820
821        // load weights
822        load_weights_into_gpt(&varmap, weights, Some("model"), cfg.n_layers)?;
823
824        // sample setup and load tokenizer
825        let start_context = "Every effort moves you";
826        let tokenizer = get_bpe_from_model("gpt2")?;
827
828        let mut rng = StdRng::seed_from_u64(42_u64);
829        let token_ids = generate(
830            &model,
831            text_to_token_ids(start_context, &tokenizer, vb.device())?,
832            25_usize,
833            cfg.context_length,
834            Some(0.1_f64),
835            Some(50_usize),
836            None,
837            &mut rng,
838        )?;
839
840        // decode the token ids to print the output text
841        println!(
842            "Output text:\n{:?}",
843            token_ids_to_text(token_ids, &tokenizer)?
844        );
845        Ok(())
846    }
847}
848
pub mod addons {
    //! Auxiliary module for examples::ch05
    use crate::listings::ch02::GPTDataLoader;
    use candle_core::{Device, IndexOp, Result, Tensor};
    use std::collections::HashMap;

    /// Helper function to get the model's predicted probabilities for the
    /// target tokens of the _i-th_ input sequence.
    ///
    /// - `text_idx`: index of the sequence within the batch
    /// - `targets`: target token ids, indexed `(batch, position)`
    /// - `probas`: probabilities indexed `(batch, position, vocab)` —
    ///   inferred from the 3-tuple indexing below; confirm at call sites
    /// - `dev`: device on which the returned 1-D tensor is created
    pub fn get_target_token_probas_helper(
        text_idx: usize,
        targets: &Tensor,
        probas: &Tensor,
        dev: &Device,
    ) -> Result<Tensor> {
        let target_tokens_1 = targets.i(text_idx)?.to_vec1::<u32>()?;
        let mut target_probas_1: Vec<f32> = vec![];
        // collect probas[text_idx, i, target_token_i] for every position i
        for (i, target_token) in target_tokens_1.iter().enumerate() {
            let target_proba = probas
                .i((text_idx, i, *target_token as usize))?
                .to_scalar::<f32>()?;
            target_probas_1.push(target_proba);
        }
        Tensor::from_vec(target_probas_1, target_tokens_1.len(), dev)
    }

    /// Helper function for producing `GPTDataLoader` for train and val splits
    ///
    /// Reads `data/the-verdict.txt`, splits it 90/10 into train/validation
    /// text, and builds a loader for each split with context length 256 and
    /// batch size 2. When `verbose` is set, prints basic corpus stats.
    pub fn get_train_val_data_loaders(
        verbose: bool,
    ) -> anyhow::Result<(GPTDataLoader, GPTDataLoader)> {
        use crate::listings::{ch02::create_dataloader_v1, ch04::Config};
        use std::fs;
        use tiktoken_rs::get_bpe_from_model;

        // load the verdict short story and compute stats
        let text_data =
            fs::read_to_string("data/the-verdict.txt").expect("Unable to read the file");
        // NOTE(review): `len()` counts bytes, not characters; identical for
        // ASCII text, but the label printed below says "Characters"
        let total_characters = text_data.len();
        let tokenizer = get_bpe_from_model("gpt2")?;
        let total_tokens = tokenizer
            .encode_with_special_tokens(text_data.as_str())
            .len();
        if verbose {
            println!("Characters: {:?}", total_characters);
            println!("Tokens: {:?}", total_tokens);
        }

        // establish train and val data
        // NOTE(review): `split_idx` is a byte offset; the slice would panic if
        // it fell inside a multi-byte UTF-8 character — fine for this ASCII
        // short story, but worth confirming before reusing on other corpora
        let train_ratio = 0.90_f32;
        let split_idx = (train_ratio * text_data.len() as f32) as usize;
        let train_data = &text_data[..split_idx];
        let val_data = &text_data[split_idx..];

        // build train and val GPTDatasetV1 and batchers
        let mut cfg = Config::gpt2_124m();
        cfg.context_length = 256_usize;

        let batch_size = 2_usize;
        let max_length = cfg.context_length;
        let stride = cfg.context_length;

        // the two boolean flags are (true, true) for train and (false, false)
        // for val — presumably shuffle/drop-last; see `create_dataloader_v1`
        let train_loader =
            create_dataloader_v1(train_data, batch_size, max_length, stride, true, true);
        let val_loader =
            create_dataloader_v1(val_data, batch_size, max_length, stride, false, false);

        Ok((train_loader, val_loader))
    }

    /// Helper function to get vocab and inversed vocab `HashMap`'s
    ///
    /// Returns the book's toy 9-word vocabulary as `(word -> id, id -> word)`
    /// maps; the inverse map is derived from the forward one.
    pub fn get_vocab_and_inversed_vocab() -> (HashMap<&'static str, u32>, HashMap<u32, &'static str>)
    {
        let vocab = HashMap::from([
            ("closer", 0_u32),
            ("every", 1),
            ("effort", 2),
            ("forward", 3),
            ("inches", 4),
            ("moves", 5),
            ("pizza", 6),
            ("toward", 7),
            ("you", 8),
        ]);
        // invert (word, id) pairs into an (id -> word) lookup
        let inverse_vocab = vocab
            .iter()
            .map(|(k, v)| (*v, *k))
            .collect::<HashMap<u32, &str>>();
        (vocab, inverse_vocab)
    }

    /// Helper function to get the example next token logits used in the book
    ///
    /// ```rust
    /// use candle_core::{Device, Tensor};
    ///
    /// let dev = Device::cuda_if_available(0).unwrap();
    /// let next_token_logits = Tensor::new(
    ///     &[4.51_f32, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79],
    ///     &dev,
    /// );
    /// ```
    pub fn get_next_token_logits() -> Result<Tensor> {
        #![allow(clippy::approx_constant)]
        let dev = Device::cuda_if_available(0)?;
        // one logit per word of the toy vocab above
        Tensor::new(
            &[4.51_f32, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79],
            &dev,
        )
    }
}
956}