1use crate::Example;
4use anyhow::Result;
5
6pub struct EG01;
23
24impl Example for EG01 {
25 fn description(&self) -> String {
26 String::from("Example usage of `text_to_token_ids` and `token_ids_to_text`.")
27 }
28
29 fn page_source(&self) -> usize {
30 132_usize
31 }
32
33 fn main(&self) -> Result<()> {
34 use crate::listings::{
35 ch04::{generate_text_simple, Config, GPTModel},
36 ch05::{text_to_token_ids, token_ids_to_text},
37 };
38 use candle_core::{DType, Device};
39 use candle_nn::{VarBuilder, VarMap};
40 use tiktoken_rs::get_bpe_from_model;
41
42 let varmap = VarMap::new();
44 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
45 let cfg = Config::gpt2_124m();
46 let model = GPTModel::new(Config::gpt2_124m(), vb.pp("model"))?;
47
48 let start_context = "Every effort moves you";
50 let tokenizer = get_bpe_from_model("gpt2")?;
51
52 let max_new_tokens = 10_usize;
54 let token_ids = generate_text_simple(
55 &model,
56 text_to_token_ids(start_context, &tokenizer, vb.device())?,
57 max_new_tokens,
58 cfg.context_length,
59 )?;
60
61 println!(
63 "Output text:\n{:?}",
64 token_ids_to_text(token_ids, &tokenizer)
65 );
66 Ok(())
67 }
68}
69
70pub struct EG02;
87
88impl Example for EG02 {
89 fn description(&self) -> String {
90 let desc = "Example computation of cross-entropy and perplexity.";
91 String::from(desc)
92 }
93
94 fn page_source(&self) -> usize {
95 133_usize
96 }
97
98 fn main(&self) -> Result<()> {
99 use crate::listings::{
100 ch04::{Config, GPTModel},
101 ch05::token_ids_to_text,
102 };
103 use candle_core::{DType, Device, IndexOp, ModuleT, Tensor, D};
104 use candle_nn::{loss::cross_entropy, ops::softmax, VarBuilder, VarMap};
105 use tiktoken_rs::get_bpe_from_model;
106
107 let varmap = VarMap::new();
109 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
110 let cfg = Config::gpt2_124m();
111 let model = GPTModel::new(cfg, vb.pp("model"))?;
112
113 let inputs = Tensor::new(&[[16833_u32, 3626, 6100], [40, 1107, 588]], vb.device())?;
115 let targets = Tensor::new(&[[3626_u32, 6100, 345], [1107, 588, 11311]], vb.device())?;
116
117 let logits = model.forward_t(&inputs, false)?;
119 let probas = softmax(&logits, D::Minus1)?;
120 println!("{:?}", probas);
121
122 let token_ids = probas.argmax_keepdim(D::Minus1)?;
124 println!("Token IDs:\n{:?}", token_ids.to_vec3::<u32>());
125
126 let tokenizer = get_bpe_from_model("gpt2")?;
128 println!(
129 "Targets batch 1: {:?}",
130 token_ids_to_text(targets.i(0)?, &tokenizer)
131 );
132 println!(
133 "Outputs batch 1: {:?}",
134 token_ids_to_text(token_ids.i(0)?.flatten_all()?, &tokenizer)
135 );
136
137 let text_idx = 0_usize;
139 let target_probas_1 =
140 addons::get_target_token_probas_helper(text_idx, &targets, &probas, vb.device())?;
141
142 println!("Text 1: {:?}", target_probas_1);
143
144 let text_idx = 1_usize;
145 let target_probas_2 =
146 addons::get_target_token_probas_helper(text_idx, &targets, &probas, vb.device())?;
147
148 println!("Text 2: {:?}", target_probas_2);
149
150 let log_probas = Tensor::cat(&[&target_probas_1, &target_probas_2], 0)?.log()?;
152 println!("Log probas: {:?}", log_probas);
153
154 let avg_log_probas = log_probas.mean(0)?;
156 println!("Avg log probbas: {:?}", avg_log_probas);
157
158 let neg_avg_log_probas = (log_probas.mean(0)? * -1_f64)?;
160 println!("Neg avg log probbas: {:?}", neg_avg_log_probas);
161
162 println!("Logits shape: {:?}", logits);
164 println!("Targets shape: {:?}", targets);
165
166 let logits_flat = logits.flatten(0, 1)?;
167 let targets_flat = targets.flatten_all()?;
168 println!("Flattened logits: {:?}", logits_flat.shape());
169 println!("Flattened targets: {:?}", targets_flat.shape());
170
171 let loss = cross_entropy(&logits_flat, &targets_flat)?;
172 println!("loss: {:?}", loss);
173
174 let perplexity = loss.exp()?;
176 println!("perplexity: {:?}", perplexity);
177 Ok(())
178 }
179}
180
181pub struct EG03;
198
199impl Example for EG03 {
200 fn description(&self) -> String {
201 String::from("Split text into train and validation datasets and loaders.")
202 }
203
204 fn page_source(&self) -> usize {
205 141_usize
206 }
207
208 fn main(&self) -> Result<()> {
209 use crate::listings::ch02::DataLoader;
210
211 let (train_loader, val_loader) = addons::get_train_val_data_loaders(true)?;
212
213 let mut train_batcher = train_loader.batcher();
214 let mut val_batcher = val_loader.batcher();
215
216 println!("Train loader:");
217 while let Some(Ok((x, y))) = train_batcher.next() {
218 println!("{:?}, {:?}", x.shape(), y.shape())
219 }
220
221 println!("Valdiation loader:");
222 while let Some(Ok((x, y))) = val_batcher.next() {
223 println!("{:?}, {:?}", x.shape(), y.shape())
224 }
225 Ok(())
226 }
227}
228
229pub struct EG04;
246
247impl Example for EG04 {
248 fn description(&self) -> String {
249 String::from("Example usage of `calc_loss_loader`.")
250 }
251
252 fn page_source(&self) -> usize {
253 145_usize
254 }
255
256 fn main(&self) -> Result<()> {
257 use crate::listings::{
258 ch04::{Config, GPTModel},
259 ch05::calc_loss_loader,
260 };
261 use candle_core::{DType, Device};
262 use candle_nn::{VarBuilder, VarMap};
263
264 let varmap = VarMap::new();
266 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
267 let cfg = Config::gpt2_124m();
268 let model = GPTModel::new(cfg, vb.pp("model"))?;
269
270 let (train_loader, val_loader) = addons::get_train_val_data_loaders(false)?;
272
273 let train_loss = calc_loss_loader(&train_loader, &model, vb.device(), None, None)?;
275 let val_loss = calc_loss_loader(&val_loader, &model, vb.device(), None, None)?;
276
277 println!("Training loss {:?}", train_loss);
278 println!("Validation loss {:?}", val_loss);
279 Ok(())
280 }
281}
282
283pub struct EG05;
300
301impl Example for EG05 {
302 fn description(&self) -> String {
303 String::from("Example usage of `train_model_simple` function and plotting loss curves.")
304 }
305
306 fn page_source(&self) -> usize {
307 149_usize
308 }
309
310 fn main(&self) -> Result<()> {
311 use crate::listings::{
312 ch04::{generate_text_simple, Config, GPTModel},
313 ch05::{plot_losses, text_to_token_ids, token_ids_to_text, train_model_simple},
314 };
315 use candle_core::{DType, Device};
316 use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
317 use ndarray::linspace;
318 use std::path::Path;
319 use tiktoken_rs::get_bpe_from_model;
320
321 let varmap = VarMap::new();
322 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
323 let cfg = Config::gpt2_124m();
324 let model = GPTModel::new(Config::gpt2_124m(), vb.pp("model"))?;
325 let optimizer = AdamW::new(
326 varmap.all_vars(),
327 ParamsAdamW {
328 lr: 0.0004,
329 weight_decay: 0.1,
330 ..Default::default()
331 },
332 )?;
333 let tokenizer = get_bpe_from_model("gpt2")?;
334 let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 10_usize);
335 let (train_loader, val_loader) = addons::get_train_val_data_loaders(false)?;
336 let start_context = "Every effort moves you";
337 let (train_losses, val_losses, tokens_seen) = train_model_simple(
338 &model,
339 &train_loader,
340 &val_loader,
341 optimizer,
342 vb.device(),
343 num_epochs,
344 eval_freq,
345 eval_iter,
346 start_context,
347 &tokenizer,
348 None,
349 )?;
350
351 let token_ids = generate_text_simple(
353 &model,
354 text_to_token_ids(start_context, &tokenizer, vb.device())?,
355 25,
356 cfg.context_length,
357 )?;
358
359 println!(
361 "Output text:\n{:?}",
362 token_ids_to_text(token_ids, &tokenizer)
363 );
364
365 println!("Saving weights to `./plot_pretraining_loss.html`");
367 let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_losses.len()));
368 let tokens_seen = tokens_seen
369 .into_iter()
370 .map(|el| el as f32)
371 .collect::<Vec<_>>();
372 let save_path = Path::new("plot_retraining_loss.html").to_path_buf();
373 plot_losses(
374 epochs_seen,
375 tokens_seen,
376 train_losses,
377 val_losses,
378 save_path,
379 )?;
380
381 Ok(())
382 }
383}
384
385pub struct EG06;
402
403impl Example for EG06 {
404 fn description(&self) -> String {
405 String::from("Manual multinomial with/without temperature scaling decoding example.")
406 }
407
408 fn page_source(&self) -> usize {
409 152_usize
410 }
411
412 #[allow(unused_variables)]
413 fn main(&self) -> Result<()> {
414 use crate::listings::ch05::{print_sampled_tokens, sample_multinomial};
415 use candle_core::D;
416 use candle_nn::ops::softmax;
417 use rand::{rngs::StdRng, SeedableRng};
418
419 let (vocab, inverse_vocab) = addons::get_vocab_and_inversed_vocab();
420 let next_token_logits = addons::get_next_token_logits()?;
421
422 let probas = softmax(&next_token_logits, D::Minus1)?;
423
424 let next_token_id = probas.argmax(D::Minus1)?;
426 println!(
427 "Greedy sampling next token: {:?}",
428 inverse_vocab.get(&next_token_id.to_scalar::<u32>()?)
429 );
430
431 let mut rng = StdRng::seed_from_u64(123_u64);
433 let next_token_id = sample_multinomial(&mut rng, &probas.to_vec1::<f32>()?)?;
434 println!(
435 "Multinomial samping next token: {:?}",
436 inverse_vocab.get(&next_token_id)
437 );
438
439 let temp = 0.1;
441 let scaled_logits = (next_token_logits / temp)?;
442 let scaled_probas = softmax(&scaled_logits, D::Minus1)?;
443 let next_token_id = sample_multinomial(&mut rng, &scaled_probas.to_vec1::<f32>()?)?;
444 println!(
445 "Temp (temp=0.1) scaled multinomial sampling next token: {:?}",
446 inverse_vocab.get(&next_token_id)
447 );
448
449 println!("Temp (temp=1.0) scaling sampling conducted 1000 times:");
451 let with_expected_vals = false;
452 print_sampled_tokens(
453 &probas.to_vec1::<f32>()?,
454 &inverse_vocab,
455 with_expected_vals, )?;
457 Ok(())
458 }
459}
460
461pub struct EG07;
478
479impl Example for EG07 {
480 fn description(&self) -> String {
481 String::from("Example of extracting topk probas.")
482 }
483
484 fn page_source(&self) -> usize {
485 156_usize
486 }
487
488 #[allow(dead_code, unused_variables)]
489 fn main(&self) -> Result<()> {
490 use crate::candle_addons::TopK;
491 use candle_core::{Tensor, D};
492 use candle_nn::ops::softmax;
493
494 let (vocab, inverse_vocab) = addons::get_vocab_and_inversed_vocab();
495 let next_token_logits = addons::get_next_token_logits()?;
496
497 let top_k = 3_usize;
499 let (top_logits, top_pos) = next_token_logits.topk_last_dim0(top_k)?;
500 println!("Top logits: {:?}", top_logits.to_vec1::<f32>());
501 println!("Top pos: {:?}", top_pos.to_vec1::<u32>());
502
503 let mask = next_token_logits.broadcast_lt(&top_logits.min(D::Minus1)?)?;
505 let on_true = next_token_logits
506 .ones_like()?
507 .broadcast_mul(&Tensor::new(f32::NEG_INFINITY, next_token_logits.device())?)?;
508 let new_logits = mask.where_cond(&on_true, &next_token_logits)?;
509 println!("mask: {:?}", mask);
510 println!("new_logits: {:?}", new_logits);
511
512 let topk_probas = softmax(&new_logits, D::Minus1)?;
514 println!("probas: {:?}", topk_probas);
515 Ok(())
516 }
517}
518
519pub struct EG08;
536
537impl Example for EG08 {
538 fn description(&self) -> String {
539 String::from("Example usage of `generate`.")
540 }
541
542 fn page_source(&self) -> usize {
543 158_usize
544 }
545
546 fn main(&self) -> Result<()> {
547 use crate::listings::{
548 ch04::{Config, GPTModel},
549 ch05::{generate, text_to_token_ids, token_ids_to_text},
550 };
551 use candle_core::{DType, Device};
552 use candle_nn::{VarBuilder, VarMap};
553 use rand::{rngs::StdRng, SeedableRng};
554 use tiktoken_rs::get_bpe_from_model;
555
556 let varmap = VarMap::new();
558 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
559 let cfg = Config::gpt2_124m();
560 let model = GPTModel::new(Config::gpt2_124m(), vb.pp("model"))?;
561
562 let start_context = "Every effort moves you";
564 let tokenizer = get_bpe_from_model("gpt2")?;
565
566 let mut rng = StdRng::seed_from_u64(42_u64);
568 let token_ids = generate(
569 &model,
570 text_to_token_ids(start_context, &tokenizer, vb.device())?,
571 15_usize,
572 cfg.context_length,
573 Some(1.4_f64),
574 Some(25_usize),
575 None,
576 &mut rng,
577 )?;
578
579 println!(
581 "Output text:\n{:?}",
582 token_ids_to_text(token_ids, &tokenizer)
583 );
584 Ok(())
585 }
586}
587
588pub struct EG09;
605
606impl Example for EG09 {
607 fn description(&self) -> String {
608 String::from("Saving and loading a candle model.")
609 }
610
611 fn page_source(&self) -> usize {
612 159_usize
613 }
614
615 fn main(&self) -> Result<()> {
616 use crate::listings::{
617 ch04::{Config, GPTModel},
618 ch05::train_model_simple,
619 };
620 use candle_core::{DType, Device, Error, IndexOp};
621 use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
622 use tiktoken_rs::get_bpe_from_model;
623
624 let varmap = VarMap::new();
626 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
627 let cfg = Config::gpt2_124m();
628 let model = GPTModel::new(cfg, vb.pp("model"))?;
629 let optimizer = AdamW::new(
630 varmap.all_vars(),
631 ParamsAdamW {
632 lr: 0.0004,
633 weight_decay: 0.1,
634 ..Default::default()
635 },
636 )?;
637
638 let tokenizer = get_bpe_from_model("gpt2")?;
640 let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 1_usize);
641 let (train_loader, val_loader) = addons::get_train_val_data_loaders(false)?;
642 let start_context = "Every effort moves you";
643 let _ = train_model_simple(
644 &model,
645 &train_loader,
646 &val_loader,
647 optimizer,
648 vb.device(),
649 num_epochs,
650 eval_freq,
651 eval_iter,
652 start_context,
653 &tokenizer,
654 None,
655 );
656
657 println!(
659 "model.out_head.weight first 10 weights BEFORE save: {:?}",
660 varmap
661 .data()
662 .lock()
663 .unwrap()
664 .get("model.out_head.weight")
665 .ok_or_else(|| {
666 Error::CannotFindTensor {
667 path: "model.out_head.weight".to_string(),
668 }
669 .bt()
670 })?
671 .i((1, ..10))?
672 .to_vec1::<f32>()
673 );
674
675 println!("Saving weights to `./checkpoint.safetensors`");
676 varmap.save("checkpoint.safetensors")?;
677
678 let mut varmap = VarMap::new();
680 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
681 let _model = GPTModel::new(cfg, vb.pp("model"))?;
682 println!(
683 "model.out_head.weight first 10 weights BEFORE load: {:?}",
684 varmap
685 .data()
686 .lock()
687 .unwrap()
688 .get("model.out_head.weight")
689 .ok_or_else(|| {
690 Error::CannotFindTensor {
691 path: "model.out_head.weight".to_string(),
692 }
693 .bt()
694 })?
695 .i((1, ..10))?
696 .to_vec1::<f32>()
697 );
698
699 println!("Loading weights from `./checkpoint.safetensors`");
701 varmap.load("checkpoint.safetensors")?;
702 println!(
703 "model.out_head.weight first 10 weights AFTER load: {:?}",
704 varmap
705 .data()
706 .lock()
707 .unwrap()
708 .get("model.out_head.weight")
709 .ok_or_else(|| {
710 Error::CannotFindTensor {
711 path: "model.out_head.weight".to_string(),
712 }
713 .bt()
714 })?
715 .i((1, ..10))?
716 .to_vec1::<f32>()
717 );
718 Ok(())
719 }
720}
721
722pub struct EG10;
739
740impl Example for EG10 {
741 fn description(&self) -> String {
742 String::from("Example for downloading safetensors from HuggingFace Hub.")
743 }
744
745 fn page_source(&self) -> usize {
746 161_usize
747 }
748
749 fn main(&self) -> Result<()> {
750 use crate::listings::ch04::Config;
751 use candle_core::Device;
752 use hf_hub::api::sync::Api;
753
754 let api = Api::new()?;
755 let repo = api.model("openai-community/gpt2".to_string());
756 let weights = repo.get("model.safetensors")?;
757 let weights = candle_core::safetensors::load(weights, &Device::Cpu)?;
758
759 let mut cfg = Config::gpt2_124m();
761 cfg.qkv_bias = true;
762
763 println!("{:?}", cfg);
764
765 println!("{:?}", weights);
766 Ok(())
767 }
768}
769
770pub struct EG11;
787
788impl Example for EG11 {
789 fn description(&self) -> String {
790 String::from("Example usage of `load_weights_into_gpt`.")
791 }
792
793 fn page_source(&self) -> usize {
794 167_usize
795 }
796
797 fn main(&self) -> Result<()> {
798 use crate::listings::{
799 ch04::{Config, GPTModel},
800 ch05::{generate, load_weights_into_gpt, text_to_token_ids, token_ids_to_text},
801 };
802 use candle_core::{DType, Device};
803 use candle_nn::{VarBuilder, VarMap};
804 use hf_hub::api::sync::Api;
805 use rand::{rngs::StdRng, SeedableRng};
806 use tiktoken_rs::get_bpe_from_model;
807
808 let dev = Device::cuda_if_available(0)?;
809 let varmap = VarMap::new();
810 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
811 let mut cfg = Config::gpt2_124m();
812 cfg.qkv_bias = true;
813 let model = GPTModel::new(cfg, vb.pp("model"))?;
814
815 let api = Api::new()?;
817 let repo = api.model("openai-community/gpt2".to_string());
818 let weights = repo.get("model.safetensors")?;
819 let weights = candle_core::safetensors::load(weights, &dev)?;
820
821 load_weights_into_gpt(&varmap, weights, Some("model"), cfg.n_layers)?;
823
824 let start_context = "Every effort moves you";
826 let tokenizer = get_bpe_from_model("gpt2")?;
827
828 let mut rng = StdRng::seed_from_u64(42_u64);
829 let token_ids = generate(
830 &model,
831 text_to_token_ids(start_context, &tokenizer, vb.device())?,
832 25_usize,
833 cfg.context_length,
834 Some(0.1_f64),
835 Some(50_usize),
836 None,
837 &mut rng,
838 )?;
839
840 println!(
842 "Output text:\n{:?}",
843 token_ids_to_text(token_ids, &tokenizer)?
844 );
845 Ok(())
846 }
847}
848
849pub mod addons {
850 use crate::listings::ch02::GPTDataLoader;
852 use candle_core::{Device, IndexOp, Result, Tensor};
853 use std::collections::HashMap;
854
855 pub fn get_target_token_probas_helper(
857 text_idx: usize,
858 targets: &Tensor,
859 probas: &Tensor,
860 dev: &Device,
861 ) -> Result<Tensor> {
862 let target_tokens_1 = targets.i(text_idx)?.to_vec1::<u32>()?;
863 let mut target_probas_1: Vec<f32> = vec![];
864 for (i, target_token) in target_tokens_1.iter().enumerate() {
865 let target_proba = probas
866 .i((text_idx, i, *target_token as usize))?
867 .to_scalar::<f32>()?;
868 target_probas_1.push(target_proba);
869 }
870 Tensor::from_vec(target_probas_1, target_tokens_1.len(), dev)
871 }
872
873 pub fn get_train_val_data_loaders(
875 verbose: bool,
876 ) -> anyhow::Result<(GPTDataLoader, GPTDataLoader)> {
877 use crate::listings::{ch02::create_dataloader_v1, ch04::Config};
878 use std::fs;
879 use tiktoken_rs::get_bpe_from_model;
880
881 let text_data =
883 fs::read_to_string("data/the-verdict.txt").expect("Unable to read the file");
884 let total_characters = text_data.len();
885 let tokenizer = get_bpe_from_model("gpt2")?;
886 let total_tokens = tokenizer
887 .encode_with_special_tokens(text_data.as_str())
888 .len();
889 if verbose {
890 println!("Characters: {:?}", total_characters);
891 println!("Tokens: {:?}", total_tokens);
892 }
893
894 let train_ratio = 0.90_f32;
896 let split_idx = (train_ratio * text_data.len() as f32) as usize;
897 let train_data = &text_data[..split_idx];
898 let val_data = &text_data[split_idx..];
899
900 let mut cfg = Config::gpt2_124m();
902 cfg.context_length = 256_usize;
903
904 let batch_size = 2_usize;
905 let max_length = cfg.context_length;
906 let stride = cfg.context_length;
907
908 let train_loader =
909 create_dataloader_v1(train_data, batch_size, max_length, stride, true, true);
910 let val_loader =
911 create_dataloader_v1(val_data, batch_size, max_length, stride, false, false);
912
913 Ok((train_loader, val_loader))
914 }
915
916 pub fn get_vocab_and_inversed_vocab() -> (HashMap<&'static str, u32>, HashMap<u32, &'static str>)
918 {
919 let vocab = HashMap::from([
920 ("closer", 0_u32),
921 ("every", 1),
922 ("effort", 2),
923 ("forward", 3),
924 ("inches", 4),
925 ("moves", 5),
926 ("pizza", 6),
927 ("toward", 7),
928 ("you", 8),
929 ]);
930 let inverse_vocab = vocab
931 .iter()
932 .map(|(k, v)| (*v, *k))
933 .collect::<HashMap<u32, &str>>();
934 (vocab, inverse_vocab)
935 }
936
937 pub fn get_next_token_logits() -> Result<Tensor> {
949 #![allow(clippy::approx_constant)]
950 let dev = Device::cuda_if_available(0)?;
951 Tensor::new(
952 &[4.51_f32, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79],
953 &dev,
954 )
955 }
956}