llms_from_scratch_rs/examples/
ch06.rs

1//! Examples from Chapter 6
2
3use crate::Example;
4use anyhow::{Context, Result};
5
6/// # Example usage of `download_and_unzip_spam_data`
7///
8/// #### Id
9/// 06.01
10///
11/// #### Page
12/// This example starts on page 173
13///
14/// #### CLI command
15/// ```sh
16/// # without cuda
17/// cargo run example 06.01
18///
19/// # with cuda
20/// cargo run --features cuda example 06.01
21/// ```
22pub struct EG01;
23
24impl Example for EG01 {
25    fn description(&self) -> String {
26        String::from("Sample usage of `download_and_unzip_spam_data`.")
27    }
28
29    fn page_source(&self) -> usize {
30        173_usize
31    }
32
33    fn main(&self) -> Result<()> {
34        use crate::listings::ch06::{download_and_unzip_spam_data, EXTRACTED_PATH, URL, ZIP_PATH};
35        use polars::prelude::*;
36        use std::sync::Arc;
37
38        // download sms spam .tsv file
39        download_and_unzip_spam_data(URL, ZIP_PATH, EXTRACTED_PATH)?;
40
41        // load in .tsv as a DataFrame
42        let f1 = Field::new("Label".into(), DataType::String);
43        let f2 = Field::new("Text".into(), DataType::String);
44        let sc = Arc::new(Schema::from_iter(vec![f1, f2]));
45        let parse_options = CsvParseOptions::default()
46            .with_separator(b'\t')
47            .with_quote_char(None);
48        let df = CsvReadOptions::default()
49            .with_parse_options(parse_options)
50            .with_schema(Some(sc))
51            .with_has_header(false)
52            .try_into_reader_with_file_path(Some("data/SMSSpamCollection.tsv".into()))
53            .unwrap()
54            .finish()?;
55        println!("{}", df);
56
57        // get value counts for label
58        let value_counts = addons::get_value_counts(&df, "Label")?;
59        println!("{}", value_counts);
60
61        Ok(())
62    }
63}
64
65/// # Example usage of `download_smsspam_parquet`
66///
67/// #### Id
68/// 06.02
69///
70/// #### Page
71/// This example starts on page 173
72///
73/// #### CLI command
74/// ```sh
75/// # without cuda
76/// cargo run example 06.02
77///
78/// # with cuda
79/// cargo run --features cuda example 06.02
80/// ```
81pub struct EG02;
82
83impl Example for EG02 {
84    fn description(&self) -> String {
85        String::from("Sample usage of `download_smsspam_parquet`.")
86    }
87
88    fn page_source(&self) -> usize {
89        173_usize
90    }
91
92    fn main(&self) -> Result<()> {
93        use crate::listings::ch06::{download_smsspam_parquet, PARQUET_FILENAME, PARQUET_URL};
94        use polars::prelude::*;
95        use std::path::PathBuf;
96
97        // download parquet file
98        download_smsspam_parquet(PARQUET_URL)?;
99
100        // load parquet
101        let mut file_path = PathBuf::from("data");
102        file_path.push(PARQUET_FILENAME);
103        let mut file = std::fs::File::open(file_path)?;
104        let df = ParquetReader::new(&mut file).finish()?;
105        let df = df
106            .clone()
107            .lazy()
108            .with_column(
109                when(col("label").eq(0))
110                    .then(lit("ham"))
111                    .otherwise(lit("spam"))
112                    .alias("label_text"),
113            )
114            .collect()?;
115        println!("{}", df);
116
117        // get value counts for label
118        let value_counts = addons::get_value_counts(&df, "label_text")?;
119        println!("{}", value_counts);
120
121        Ok(())
122    }
123}
124
125/// # Example usage of `create_balanced_dataset`
126///
127/// #### Id
128/// 06.03
129///
130/// #### Page
131/// This example starts on page 174
132///
133/// #### CLI command
134/// ```sh
135/// # without cuda
136/// cargo run example 06.03
137///
138/// # with cuda
139/// cargo run --features cuda example 06.03
140/// ```
141pub struct EG03;
142
143impl Example for EG03 {
144    fn description(&self) -> String {
145        String::from("Example usage of `create_balanced_dataset`.")
146    }
147
148    fn page_source(&self) -> usize {
149        174_usize
150    }
151
152    fn main(&self) -> Result<()> {
153        use crate::listings::ch06::{
154            create_balanced_dataset, download_smsspam_parquet, PARQUET_FILENAME, PARQUET_URL,
155        };
156        use polars::prelude::*;
157        use std::path::PathBuf;
158
159        // download parquet
160        download_smsspam_parquet(PARQUET_URL)?;
161
162        // load parquet
163        let mut file_path = PathBuf::from("data");
164        file_path.push(PARQUET_FILENAME);
165        let mut file = std::fs::File::open(file_path).unwrap();
166        let df = ParquetReader::new(&mut file).finish().unwrap();
167
168        // balance dataset
169        let balanced_df = create_balanced_dataset(df)?;
170        println!("{}", balanced_df);
171
172        // get value counts for label
173        let value_counts = addons::get_value_counts(&balanced_df, "label")?;
174        println!("{}", value_counts);
175
176        Ok(())
177    }
178}
179
180/// # Example usage of `random_split`
181///
182/// #### Id
183/// 06.04
184///
185/// #### Page
186/// This example starts on page 175
187///
188/// #### CLI command
189/// ```sh
190/// # without cuda
191/// cargo run example 06.04
192///
193/// # with cuda
194/// cargo run --features cuda example 06.04
195/// ```
196pub struct EG04;
197
198impl Example for EG04 {
199    fn description(&self) -> String {
200        String::from("Example usage of `random_split` to create our train, test, val splits.")
201    }
202
203    fn page_source(&self) -> usize {
204        174_usize
205    }
206
207    fn main(&self) -> Result<()> {
208        use crate::listings::ch06::{
209            create_balanced_dataset, download_smsspam_parquet, random_split, PARQUET_FILENAME,
210            PARQUET_URL,
211        };
212        use polars::prelude::*;
213        use std::{path::PathBuf, str::FromStr};
214
215        // download parquet
216        download_smsspam_parquet(PARQUET_URL)?;
217
218        // load parquet
219        let mut file_path = PathBuf::from("data");
220        file_path.push(PARQUET_FILENAME);
221        let mut file = std::fs::File::open(file_path).unwrap();
222        let df = ParquetReader::new(&mut file).finish().unwrap();
223
224        // balance dataset
225        let balanced_df = create_balanced_dataset(df)?;
226
227        // create train, test, val splits
228        let (mut train_df, mut validation_df, mut test_df) =
229            random_split(&balanced_df, 0.7_f32, 0.1_f32)?;
230        println!("{}", train_df);
231        println!("{}", validation_df);
232        println!("{}", test_df);
233
234        // save dfs to csv
235        let train_path = PathBuf::from_str("data/train.parquet")?;
236        let validation_path = PathBuf::from_str("data/validation.parquet")?;
237        let test_path = PathBuf::from_str("data/test.parquet")?;
238
239        addons::write_parquet(&mut train_df, train_path)?;
240        addons::write_parquet(&mut validation_df, validation_path)?;
241        addons::write_parquet(&mut test_df, test_path)?;
242
243        Ok(())
244    }
245}
246
247/// # Creating `SpamDataset` for train, test, and validation via `SpamDatasetBuilder`
248///
249/// #### Id
250/// 06.05
251///
252/// #### Page
253/// This example starts on page 178
254///
255/// #### CLI command
256/// ```sh
257/// # without cuda
258/// cargo run example 06.05
259///
260/// # with cuda
261/// cargo run --features cuda example 06.05
262/// ```
263pub struct EG05;
264
265impl Example for EG05 {
266    fn description(&self) -> String {
267        String::from("Creating `SpamDataset` for train, test, and validation.")
268    }
269
270    fn page_source(&self) -> usize {
271        178_usize
272    }
273
274    fn main(&self) -> Result<()> {
275        use crate::listings::ch06::SpamDatasetBuilder;
276        use anyhow::anyhow;
277        use std::ops::Not;
278        use std::path::Path;
279        use tiktoken_rs::get_bpe_from_model;
280
281        let tokenizer = get_bpe_from_model("gpt2")?;
282
283        let train_path = Path::new("data").join("train.parquet");
284        if train_path.exists().not() {
285            return Err(anyhow!(
286                "Missing 'data/train.parquet' file. Please run EG 06.04."
287            ));
288        }
289        let train_dataset = SpamDatasetBuilder::new(&tokenizer)
290            .load_data_from_parquet(train_path)
291            .build();
292        println!("train dataset max length: {}", train_dataset.max_length());
293
294        let val_path = Path::new("data").join("validation.parquet");
295        if val_path.exists().not() {
296            return Err(anyhow!(
297                "Missing 'data/validation.parquet' file. Please run EG 06.04."
298            ));
299        }
300        let val_dataset = SpamDatasetBuilder::new(&tokenizer)
301            .load_data_from_parquet(val_path)
302            .max_length(Some(train_dataset.max_length()))
303            .build();
304        println!("val dataset max length: {}", val_dataset.max_length());
305
306        let test_path = Path::new("data").join("test.parquet");
307        if test_path.exists().not() {
308            return Err(anyhow!(
309                "Missing 'data/test.parquet' file. Please run EG 06.04."
310            ));
311        }
312        let test_dataset = SpamDatasetBuilder::new(&tokenizer)
313            .load_data_from_parquet(test_path)
314            .max_length(Some(train_dataset.max_length()))
315            .build();
316        println!("test dataset max length: {}", test_dataset.max_length());
317        Ok(())
318    }
319}
320
321/// # Creating a `SpamDataLoader` for each of the train, val and test datasets.
322///
323/// #### Id
324/// 06.06
325///
326/// #### Page
327/// This example starts on page 180
328///
329/// #### CLI command
330/// ```sh
331/// # without cuda
332/// cargo run example 06.06
333///
334/// # with cuda
335/// cargo run --features cuda example 06.06
336/// ```
337pub struct EG06;
338
339impl EG06 {
340    pub fn main_with_return(
341        &self,
342        verbose: bool,
343    ) -> Result<(
344        crate::listings::ch06::SpamDataLoader,
345        crate::listings::ch06::SpamDataLoader,
346        crate::listings::ch06::SpamDataLoader,
347    )> {
348        use crate::listings::ch06::{SpamDataLoader, SpamDatasetBuilder};
349        use anyhow::anyhow;
350        use std::ops::Not;
351        use std::path::Path;
352        use tiktoken_rs::get_bpe_from_model;
353
354        // create datasets
355        let tokenizer = get_bpe_from_model("gpt2")?;
356
357        let train_path = Path::new("data").join("train.parquet");
358        if train_path.exists().not() {
359            return Err(anyhow!(
360                "Missing 'data/train.parquet' file. Please run EG 06.04."
361            ));
362        }
363        let train_dataset = SpamDatasetBuilder::new(&tokenizer)
364            .load_data_from_parquet(train_path)
365            .build();
366
367        let val_path = Path::new("data").join("validation.parquet");
368        if val_path.exists().not() {
369            return Err(anyhow!(
370                "Missing 'data/validation.parquet' file. Please run EG 06.04."
371            ));
372        }
373        let val_dataset = SpamDatasetBuilder::new(&tokenizer)
374            .load_data_from_parquet(val_path)
375            .build();
376
377        let test_path = Path::new("data").join("test.parquet");
378        if test_path.exists().not() {
379            return Err(anyhow!(
380                "Missing 'data/test.parquet' file. Please run EG 06.04."
381            ));
382        }
383        let test_dataset = SpamDatasetBuilder::new(&tokenizer)
384            .load_data_from_parquet(test_path)
385            .build();
386
387        // create loaders
388        let batch_size = 8_usize;
389        let train_loader = SpamDataLoader::new(train_dataset, batch_size, true, true);
390        let val_loader = SpamDataLoader::new(val_dataset, batch_size, false, false);
391        let test_loader = SpamDataLoader::new(test_dataset, batch_size, false, false);
392
393        if verbose {
394            // see last batch of train loader
395            let (input_batch, target_batch) = train_loader.batcher().last().unwrap()?;
396            println!("Input batch dimensions: {:?}", input_batch.shape());
397            println!("Label batch dimensions: {:?}", target_batch.shape());
398
399            // print total number of batches in each data loader
400            println!("{:?} training batches", train_loader.len());
401            println!("{:?} validation batches", val_loader.len());
402            println!("{:?} test batches", test_loader.len());
403        }
404
405        Ok((train_loader, val_loader, test_loader))
406    }
407}
408
409impl Example for EG06 {
410    fn description(&self) -> String {
411        "Creating a `SpamDataLoader` for each of the train, val and test datasets.".to_string()
412    }
413
414    fn page_source(&self) -> usize {
415        180_usize
416    }
417
418    fn main(&self) -> Result<()> {
419        let _ = self.main_with_return(true)?;
420        Ok(())
421    }
422}
423
424/// # Example usage of `download_and_load_gpt2`.
425///
426/// #### Id
427/// 06.07
428///
429/// #### Page
430/// This example starts on page 182
431///
432/// #### CLI command
433/// ```sh
434/// # without cuda
435/// cargo run example 06.07
436///
437/// # with cuda
438/// cargo run --features cuda example 06.07
439/// ```
440pub struct EG07;
441
442impl Example for EG07 {
443    fn description(&self) -> String {
444        String::from("Example usage of `download_and_load_gpt2`.")
445    }
446
447    fn page_source(&self) -> usize {
448        182_usize
449    }
450
451    fn main(&self) -> Result<()> {
452        use crate::listings::{
453            ch04::Config,
454            ch05::{generate, text_to_token_ids, token_ids_to_text},
455            ch06::{download_and_load_gpt2, HF_GPT2_MODEL_ID},
456        };
457        use candle_core::{DType, Device};
458        use candle_nn::{VarBuilder, VarMap};
459        use rand::{rngs::StdRng, SeedableRng};
460        use tiktoken_rs::get_bpe_from_model;
461
462        // use `download_and_load_gpt2`
463        let mut cfg = Config::gpt2_124m();
464        cfg.qkv_bias = true;
465        let varmap = VarMap::new();
466        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
467        let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
468
469        // sample setup and load tokenizer
470        let tokenizer = get_bpe_from_model("gpt2")?;
471        let mut rng = StdRng::seed_from_u64(42_u64);
472
473        // generate next tokens with model
474        let text_1 = "Every effort moves you";
475        let token_ids = generate(
476            &model,
477            text_to_token_ids(text_1, &tokenizer, vb.device())?,
478            15_usize,
479            cfg.context_length,
480            None,
481            None,
482            None,
483            &mut rng,
484        )?;
485
486        // decode the token ids to print the output text
487        println!(
488            "Output text:\n{:?}",
489            token_ids_to_text(token_ids, &tokenizer)
490        );
491
492        // test inherent classification abilities
493        let text_2 = "Is the following text 'spam'? Answer with 'yes' or \
494        'no': 'You are a winner you have been specially selected to receive $1000 \
495        cash or a $2000 award.'";
496        let token_ids = generate(
497            &model,
498            text_to_token_ids(text_2, &tokenizer, vb.device())?,
499            23_usize,
500            cfg.context_length,
501            None,
502            None,
503            None,
504            &mut rng,
505        )?;
506
507        // decode the token ids to print the classification
508        println!(
509            "Output text:\n{:?}",
510            token_ids_to_text(token_ids, &tokenizer)
511        );
512
513        Ok(())
514    }
515}
516
517/// # Printing the model variables via `varmap.data()`
518///
519/// #### Id
520/// 06.08
521///
522/// #### Page
523/// This example starts on page 185
524///
525/// #### CLI command
526/// ```sh
527/// # without cuda
528/// cargo run example 06.08
529///
530/// # with cuda
531/// cargo run --features cuda example 06.08
532/// ```
533pub struct EG08;
534
535impl Example for EG08 {
536    fn description(&self) -> String {
537        String::from("Printing the model architecture via `varmap.data()`.")
538    }
539
540    fn page_source(&self) -> usize {
541        185_usize
542    }
543
544    fn main(&self) -> Result<()> {
545        use crate::listings::{
546            ch04::Config,
547            ch06::{download_and_load_gpt2, HF_GPT2_MODEL_ID},
548        };
549        use candle_core::{DType, Device};
550        use candle_nn::{VarBuilder, VarMap};
551
552        // use `download_and_load_gpt2`
553        let mut cfg = Config::gpt2_124m();
554        cfg.qkv_bias = true;
555        let varmap = VarMap::new();
556        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
557        let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
558
559        // print model architecture
560        println!("{:#?}", model);
561
562        Ok(())
563    }
564}
565
566/// # Modifying the `out_head` of a GPT2Model and running inference
567///
568/// #### Id
569/// 06.09
570///
571/// #### Page
572/// This example starts on page 186
573///
574/// #### CLI command
575/// ```sh
576/// # without cuda
577/// cargo run example 06.09
578///
579/// # with cuda
580/// cargo run --features cuda example 06.09
581/// ```
582pub struct EG09;
583
584impl Example for EG09 {
585    fn description(&self) -> String {
586        String::from("Modifying the `out_head` of a GPT2Model and running inference.")
587    }
588
589    fn page_source(&self) -> usize {
590        186_usize
591    }
592
593    fn main(&self) -> Result<()> {
594        use crate::listings::{
595            ch04::Config,
596            ch06::{download_and_load_gpt2, modify_out_head_for_classification, HF_GPT2_MODEL_ID},
597        };
598        use candle_core::{DType, Device, IndexOp, ModuleT, Tensor};
599        use candle_nn::{VarBuilder, VarMap};
600        use tiktoken_rs::get_bpe_from_model;
601
602        // use `download_and_load_gpt2`
603        let mut cfg = Config::gpt2_124m();
604        cfg.qkv_bias = true;
605        let varmap = VarMap::new();
606        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
607        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
608
609        // print old head
610        let tensor_data = varmap.data().lock().unwrap();
611        let out_head = tensor_data.get("model.out_head.weight");
612        println!("old classification head: {:?}", out_head);
613        drop(tensor_data);
614
615        // modify classification head
616        let num_classes = 2_usize;
617        modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?;
618
619        // get out head
620        let tensor_data = varmap.data().lock().unwrap();
621        let out_head = tensor_data.get("model.out_head.weight");
622        println!("new classification head: {:?}", out_head);
623
624        // run sample inference
625        let tokenizer = get_bpe_from_model("gpt2")?;
626        let inputs = tokenizer.encode_with_special_tokens("Do you have time");
627        let num_tokens = inputs.len();
628        let inputs = Tensor::from_vec(inputs, num_tokens, vb.device())?.unsqueeze(0)?;
629        println!("Inputs: {:?}", inputs.to_vec2::<u32>());
630        println!("Inputs dimensions: {:?}", inputs);
631
632        let outputs = model.forward_t(&inputs, false)?;
633        println!("Outputs: {:?}", outputs.to_vec3::<f32>());
634        println!("Outputs dimensions: {:?}", outputs);
635
636        // get last output token to use for making predictions of spam/ham
637        let (_b, c, _vocab_size) = outputs.dims3()?;
638        println!(
639            "Last output token: {:?}",
640            outputs.i((.., c - 1, ..))?.to_vec2::<f32>()
641        );
642
643        Ok(())
644    }
645}
646
647/// # Toy example of using `candle_nn::softmax` on output values to classify spam/ham
648///
649/// #### Id
650/// 06.10
651///
652/// #### Page
653/// This example starts on page 192
654///
655/// #### CLI command
656/// ```sh
657/// # without cuda
658/// cargo run example 06.10
659///
660/// # with cuda
661/// cargo run --features cuda example 06.10
662/// ```
663pub struct EG10;
664
665impl Example for EG10 {
666    fn description(&self) -> String {
667        "Toy example of predicting spam/ham from logits.".to_string()
668    }
669
670    fn page_source(&self) -> usize {
671        192_usize
672    }
673
674    fn main(&self) -> Result<()> {
675        use candle_core::{Device, Tensor, D};
676
677        let dev = Device::cuda_if_available(0)?;
678        let logits = Tensor::new(&[[-3.5983_f32, 3.9902]], &dev)?;
679        println!(
680            "Last output token (i.e. logits): {:?}",
681            logits.to_vec2::<f32>()?
682        );
683
684        let label = logits.argmax(D::Minus1)?;
685        println!("Class label: {:?}", label.squeeze(0)?.to_scalar::<u32>()?);
686
687        Ok(())
688    }
689}
690
691/// # Example usage of `calc_accuracy_loader` to compute accuracy on test, train, val sets
692///
693/// #### Id
694/// 06.11
695///
696/// #### Page
697/// This example starts on page 193
698///
699/// #### CLI command
700/// ```sh
701/// # without cuda
702/// cargo run example 06.11
703///
704/// # with cuda
705/// cargo run --features cuda example 06.11
706/// ```
707pub struct EG11;
708
709impl Example for EG11 {
710    fn description(&self) -> String {
711        let desc = "Example usage of `calc_accuracy_loader` to compute accuracy on \
712        test, train, val sets.";
713        desc.to_string()
714    }
715
716    fn page_source(&self) -> usize {
717        192_usize
718    }
719
720    fn main(&self) -> Result<()> {
721        use crate::listings::{
722            ch04::Config,
723            ch06::{
724                calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
725                HF_GPT2_MODEL_ID,
726            },
727        };
728        use candle_core::{DType, Device};
729        use candle_nn::{VarBuilder, VarMap};
730
731        // get gpt model with classification head
732        let mut cfg = Config::gpt2_124m();
733        cfg.qkv_bias = true;
734        let varmap = VarMap::new();
735        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
736        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
737        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
738
739        // get data loaders
740        let eg06 = EG06; // re-use
741        let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;
742
743        // compute accuracies
744        let num_batches = Some(10_usize);
745        let train_accuracy =
746            calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
747        let val_accuracy =
748            calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
749        let test_accuracy =
750            calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;
751
752        println!("Training accuracy: {}", train_accuracy);
753        println!("Validation accuracy: {}", val_accuracy);
754        println!("Test accuracy: {}", test_accuracy);
755
756        Ok(())
757    }
758}
759
760/// # Example usage of `calc_loss_loader` to compute cross-entropy loss on train, val, test sets
761///
762/// #### Id
763/// 06.12
764///
765/// #### Page
766/// This example starts on page 194
767///
768/// #### CLI command
769/// ```sh
770/// # without cuda
771/// cargo run example 06.12
772///
773/// # with cuda
774/// cargo run --features cuda example 06.12
775/// ```
776pub struct EG12;
777
778impl Example for EG12 {
779    fn description(&self) -> String {
780        let desc = "Example usage of `calc_loss_loader` to compute accuracy on \
781        test, train, val sets.";
782        desc.to_string()
783    }
784
785    fn page_source(&self) -> usize {
786        194_usize
787    }
788
789    fn main(&self) -> Result<()> {
790        use crate::listings::{
791            ch04::Config,
792            ch06::{
793                calc_loss_loader, download_and_load_gpt2, modify_out_head_for_classification,
794                HF_GPT2_MODEL_ID,
795            },
796        };
797        use candle_core::{DType, Device};
798        use candle_nn::{VarBuilder, VarMap};
799
800        // get gpt model with classification head
801        let mut cfg = Config::gpt2_124m();
802        cfg.qkv_bias = true;
803        let varmap = VarMap::new();
804        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
805        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
806        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
807
808        // get data loaders
809        let eg06 = EG06; // re-use
810        let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;
811
812        // compute accuracies
813        let num_batches = Some(5_usize);
814        let train_loss = calc_loss_loader(&train_loader, &model, vb.device(), num_batches, None)?;
815        let val_loss = calc_loss_loader(&val_loader, &model, vb.device(), num_batches, None)?;
816        let test_loss = calc_loss_loader(&test_loader, &model, vb.device(), num_batches, None)?;
817
818        println!("Training loss: {}", train_loss);
819        println!("Validation loss: {}", val_loss);
820        println!("Test loss: {}", test_loss);
821
822        Ok(())
823    }
824}
825
826/// # Example usage of `train_classifier_simple` and `plot_values` functions
827///
828/// #### Id
829/// 06.13
830///
831/// #### Page
832/// This example starts on page 149
833///
834/// #### CLI command
835/// ```sh
836/// # without cuda
837/// cargo run example 06.13
838///
839/// # with cuda
840/// cargo run --features cuda example 06.13
841/// ```
842pub struct EG13;
843
844impl Example for EG13 {
845    fn description(&self) -> String {
846        String::from("Example usage of `train_classifier_simple` and `plot_values` function.")
847    }
848
849    fn page_source(&self) -> usize {
850        197_usize
851    }
852
853    fn main(&self) -> Result<()> {
854        use crate::listings::{
855            ch04::Config,
856            ch06::{
857                download_and_load_gpt2, modify_out_head_for_classification, plot_values,
858                train_classifier_simple, HF_GPT2_MODEL_ID,
859            },
860        };
861        use candle_core::{DType, Device, Var};
862        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
863        use ndarray::linspace;
864        use std::path::Path;
865
866        // get gpt model with classification head
867        let mut cfg = Config::gpt2_124m();
868        cfg.qkv_bias = true;
869        let varmap = VarMap::new();
870        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
871        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
872        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
873
874        // get data loaders
875        let eg06 = EG06; // re-use
876        let (train_loader, val_loader, _test_loader) = eg06.main_with_return(false)?;
877
878        // trainable params and optimizer
879        // trainable: last trf block, final layer norm, classification head
880        let mut training_vars: Vec<Var> = vec![];
881        let tensor_data = varmap.data().lock().unwrap();
882        let var_names: Vec<&String> = tensor_data
883            .keys()
884            .filter(|k| k.contains("final_norm") || k.contains("out_head") || k.contains("trf.11"))
885            .collect();
886
887        println!("Training variables: {:?}\n", var_names);
888
889        for var_name in var_names.into_iter() {
890            let var = tensor_data.get(var_name).unwrap();
891            training_vars.push(var.clone());
892        }
893        drop(tensor_data);
894
895        let optimizer = AdamW::new(
896            training_vars,
897            ParamsAdamW {
898                lr: 5e-5,
899                weight_decay: 0.1,
900                ..Default::default()
901            },
902        )?;
903
904        let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize);
905        let (train_loss, val_loss, train_accs, val_accs, num_examples) = train_classifier_simple(
906            &model,
907            &train_loader,
908            &val_loader,
909            optimizer,
910            vb.device(),
911            num_epochs,
912            eval_freq,
913            eval_iter,
914            None,
915        )?;
916
917        // save model
918        println!("Saving weights to `./clf.checkpoint.safetensors`");
919        varmap.save("clf.checkpoint.safetensors")?;
920
921        // prepare and save plots
922        let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_loss.len()));
923        let examples_seen = Vec::from_iter(linspace(0_f32, num_examples as f32, train_loss.len()));
924        let label = "loss";
925        let save_path =
926            Path::new(format!("plot_classification_{label}.html").as_str()).to_path_buf();
927        plot_values(
928            epochs_seen,
929            examples_seen,
930            train_loss,
931            val_loss,
932            label,
933            save_path,
934        )?;
935
936        let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_accs.len()));
937        let examples_seen = Vec::from_iter(linspace(0_f32, num_examples as f32, train_accs.len()));
938        let label = "accuracy";
939        let save_path =
940            Path::new(format!("plot_classification_{label}.html").as_str()).to_path_buf();
941        plot_values(
942            epochs_seen,
943            examples_seen,
944            train_accs,
945            val_accs,
946            label,
947            save_path,
948        )?;
949
950        Ok(())
951    }
952}
953
954/// # Loading fine-tuned model and calculate performance on whole train, val and test sets.
955///
956/// #### Id
957/// 06.14
958///
959/// #### Page
960/// This example starts on page 200
961///
962/// #### CLI command
963/// ```sh
964/// # without cuda
965/// cargo run example 06.14
966///
967/// # with cuda
968/// cargo run --features cuda example 06.14
969/// ```
970pub struct EG14;
971
972impl Example for EG14 {
973    fn description(&self) -> String {
974        String::from(
975            "Loading fine-tuned model and calculate performance on whole train, val and test sets.",
976        )
977    }
978
979    fn page_source(&self) -> usize {
980        200_usize
981    }
982
983    fn main(&self) -> Result<()> {
984        use crate::listings::{
985            ch04::{Config, GPTModel},
986            ch06::{calc_accuracy_loader, modify_out_head_for_classification},
987        };
988        use candle_core::{DType, Device};
989        use candle_nn::{VarBuilder, VarMap};
990
991        // get gpt model with classification head
992        let mut cfg = Config::gpt2_124m();
993        cfg.qkv_bias = true;
994        let mut varmap = VarMap::new();
995        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
996        let mut model = GPTModel::new(cfg, vb.pp("model"))?;
997        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
998
999        // load safetensors
1000        varmap
1001            .load("clf.checkpoint.safetensors")
1002            .with_context(|| "Missing 'clf.checkpoint.safetensors' file. Please run EG 06.13.")?;
1003
1004        // get data loaders
1005        let eg06 = EG06; // re-use
1006        let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;
1007
1008        // compute accuracies
1009        let num_batches = None;
1010        let train_accuracy =
1011            calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
1012        let val_accuracy =
1013            calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
1014        let test_accuracy =
1015            calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;
1016
1017        println!("Training accuracy: {}", train_accuracy);
1018        println!("Validation accuracy: {}", val_accuracy);
1019        println!("Test accuracy: {}", test_accuracy);
1020
1021        Ok(())
1022    }
1023}
1024
/// # Example usage of `classify_review`
///
/// Classifies two sample texts with the fine-tuned model. Requires
/// `data/train.parquet` and `clf.checkpoint.safetensors`; the example's error
/// messages point to EG 06.04 and EG 06.13 respectively as the steps that produce them.
///
/// #### Id
/// 06.15
///
/// #### Page
/// This example starts on page 202
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 06.15
///
/// # with cuda
/// cargo run --features cuda example 06.15
/// ```
pub struct EG15;
1042
1043impl Example for EG15 {
1044    fn description(&self) -> String {
1045        String::from("Example usage of `classify_review`.")
1046    }
1047
1048    fn page_source(&self) -> usize {
1049        202_usize
1050    }
1051
1052    fn main(&self) -> Result<()> {
1053        use crate::listings::{
1054            ch04::{Config, GPTModel},
1055            ch06::{
1056                classify_review, modify_out_head_for_classification, SpamDatasetBuilder,
1057                PAD_TOKEN_ID,
1058            },
1059        };
1060        use anyhow::anyhow;
1061        use candle_core::{DType, Device};
1062        use candle_nn::{VarBuilder, VarMap};
1063        use std::ops::Not;
1064        use std::path::Path;
1065        use tiktoken_rs::get_bpe_from_model;
1066
1067        // tokenizer and train_dataset
1068        let tokenizer = get_bpe_from_model("gpt2")?;
1069        let train_path = Path::new("data").join("train.parquet");
1070        if train_path.exists().not() {
1071            return Err(anyhow!(
1072                "Missing 'data/train.parquet' file. Please run EG 06.04."
1073            ));
1074        }
1075        let train_dataset = SpamDatasetBuilder::new(&tokenizer)
1076            .load_data_from_parquet(train_path)
1077            .build();
1078
1079        // get gpt model with classification head
1080        let mut cfg = Config::gpt2_124m();
1081        cfg.qkv_bias = true;
1082        let mut varmap = VarMap::new();
1083        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
1084        let mut model = GPTModel::new(cfg, vb.pp("model"))?;
1085        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
1086
1087        // load safetensors from finetuning
1088        varmap
1089            .load("clf.checkpoint.safetensors")
1090            .with_context(|| "Missing 'clf.checkpoint.safetensors' file. Please run EG 06.13.")?;
1091
1092        // classify texts
1093        let text_1 = "You are a winner you have been specially selected to receive \
1094        $1000 cash or a $2000 award.";
1095        println!(
1096            "{}",
1097            classify_review(
1098                text_1,
1099                &model,
1100                &tokenizer,
1101                vb.device(),
1102                Some(train_dataset.max_length()),
1103                PAD_TOKEN_ID,
1104            )
1105            .with_context(|| "Failed to classify text_1.")?,
1106        );
1107
1108        let text_2 = "Hey, just wanted to check if we're still on for \"
1109        dinner tonight? Let me know!";
1110        println!(
1111            "{}",
1112            classify_review(
1113                text_2,
1114                &model,
1115                &tokenizer,
1116                vb.device(),
1117                Some(train_dataset.max_length()),
1118                PAD_TOKEN_ID,
1119            )
1120            .with_context(|| "Failed to classify text_2.")?,
1121        );
1122
1123        Ok(())
1124    }
1125}
1126
1127pub mod addons {
1128    //! Auxiliary module for examples::ch06
1129    use polars::prelude::*;
1130    use std::path::Path;
1131
1132    /// Helper function to get value counts for a polars::DataFrame for a specified column
1133    pub fn get_value_counts(df: &DataFrame, cname: &str) -> anyhow::Result<DataFrame> {
1134        let result = df
1135            .clone()
1136            .lazy()
1137            .select([col(cname)
1138                .value_counts(false, false, "count", false)
1139                .alias("value_counts")])
1140            .collect()?;
1141        Ok(result)
1142    }
1143
1144    pub fn write_csv<P: AsRef<Path>>(df: &mut DataFrame, fname: P) -> anyhow::Result<()> {
1145        let mut file = std::fs::File::create(fname)?;
1146        CsvWriter::new(&mut file).finish(df)?;
1147        Ok(())
1148    }
1149
1150    pub fn write_parquet<P: AsRef<Path>>(df: &mut DataFrame, fname: P) -> anyhow::Result<()> {
1151        let mut file = std::fs::File::create(fname)?;
1152        ParquetWriter::new(&mut file).finish(df)?;
1153        Ok(())
1154    }
1155}