1use crate::Example;
4use anyhow::{Context, Result};
5
6pub struct EG01;
23
24impl Example for EG01 {
25 fn description(&self) -> String {
26 String::from("Sample usage of `download_and_unzip_spam_data`.")
27 }
28
29 fn page_source(&self) -> usize {
30 173_usize
31 }
32
33 fn main(&self) -> Result<()> {
34 use crate::listings::ch06::{download_and_unzip_spam_data, EXTRACTED_PATH, URL, ZIP_PATH};
35 use polars::prelude::*;
36 use std::sync::Arc;
37
38 download_and_unzip_spam_data(URL, ZIP_PATH, EXTRACTED_PATH)?;
40
41 let f1 = Field::new("Label".into(), DataType::String);
43 let f2 = Field::new("Text".into(), DataType::String);
44 let sc = Arc::new(Schema::from_iter(vec![f1, f2]));
45 let parse_options = CsvParseOptions::default()
46 .with_separator(b'\t')
47 .with_quote_char(None);
48 let df = CsvReadOptions::default()
49 .with_parse_options(parse_options)
50 .with_schema(Some(sc))
51 .with_has_header(false)
52 .try_into_reader_with_file_path(Some("data/SMSSpamCollection.tsv".into()))
53 .unwrap()
54 .finish()?;
55 println!("{}", df);
56
57 let value_counts = addons::get_value_counts(&df, "Label")?;
59 println!("{}", value_counts);
60
61 Ok(())
62 }
63}
64
65pub struct EG02;
82
83impl Example for EG02 {
84 fn description(&self) -> String {
85 String::from("Sample usage of `download_smsspam_parquet`.")
86 }
87
88 fn page_source(&self) -> usize {
89 173_usize
90 }
91
92 fn main(&self) -> Result<()> {
93 use crate::listings::ch06::{download_smsspam_parquet, PARQUET_FILENAME, PARQUET_URL};
94 use polars::prelude::*;
95 use std::path::PathBuf;
96
97 download_smsspam_parquet(PARQUET_URL)?;
99
100 let mut file_path = PathBuf::from("data");
102 file_path.push(PARQUET_FILENAME);
103 let mut file = std::fs::File::open(file_path)?;
104 let df = ParquetReader::new(&mut file).finish()?;
105 let df = df
106 .clone()
107 .lazy()
108 .with_column(
109 when(col("label").eq(0))
110 .then(lit("ham"))
111 .otherwise(lit("spam"))
112 .alias("label_text"),
113 )
114 .collect()?;
115 println!("{}", df);
116
117 let value_counts = addons::get_value_counts(&df, "label_text")?;
119 println!("{}", value_counts);
120
121 Ok(())
122 }
123}
124
125pub struct EG03;
142
143impl Example for EG03 {
144 fn description(&self) -> String {
145 String::from("Example usage of `create_balanced_dataset`.")
146 }
147
148 fn page_source(&self) -> usize {
149 174_usize
150 }
151
152 fn main(&self) -> Result<()> {
153 use crate::listings::ch06::{
154 create_balanced_dataset, download_smsspam_parquet, PARQUET_FILENAME, PARQUET_URL,
155 };
156 use polars::prelude::*;
157 use std::path::PathBuf;
158
159 download_smsspam_parquet(PARQUET_URL)?;
161
162 let mut file_path = PathBuf::from("data");
164 file_path.push(PARQUET_FILENAME);
165 let mut file = std::fs::File::open(file_path).unwrap();
166 let df = ParquetReader::new(&mut file).finish().unwrap();
167
168 let balanced_df = create_balanced_dataset(df)?;
170 println!("{}", balanced_df);
171
172 let value_counts = addons::get_value_counts(&balanced_df, "label")?;
174 println!("{}", value_counts);
175
176 Ok(())
177 }
178}
179
180pub struct EG04;
197
198impl Example for EG04 {
199 fn description(&self) -> String {
200 String::from("Example usage of `random_split` to create our train, test, val splits.")
201 }
202
203 fn page_source(&self) -> usize {
204 174_usize
205 }
206
207 fn main(&self) -> Result<()> {
208 use crate::listings::ch06::{
209 create_balanced_dataset, download_smsspam_parquet, random_split, PARQUET_FILENAME,
210 PARQUET_URL,
211 };
212 use polars::prelude::*;
213 use std::{path::PathBuf, str::FromStr};
214
215 download_smsspam_parquet(PARQUET_URL)?;
217
218 let mut file_path = PathBuf::from("data");
220 file_path.push(PARQUET_FILENAME);
221 let mut file = std::fs::File::open(file_path).unwrap();
222 let df = ParquetReader::new(&mut file).finish().unwrap();
223
224 let balanced_df = create_balanced_dataset(df)?;
226
227 let (mut train_df, mut validation_df, mut test_df) =
229 random_split(&balanced_df, 0.7_f32, 0.1_f32)?;
230 println!("{}", train_df);
231 println!("{}", validation_df);
232 println!("{}", test_df);
233
234 let train_path = PathBuf::from_str("data/train.parquet")?;
236 let validation_path = PathBuf::from_str("data/validation.parquet")?;
237 let test_path = PathBuf::from_str("data/test.parquet")?;
238
239 addons::write_parquet(&mut train_df, train_path)?;
240 addons::write_parquet(&mut validation_df, validation_path)?;
241 addons::write_parquet(&mut test_df, test_path)?;
242
243 Ok(())
244 }
245}
246
247pub struct EG05;
264
265impl Example for EG05 {
266 fn description(&self) -> String {
267 String::from("Creating `SpamDataset` for train, test, and validation.")
268 }
269
270 fn page_source(&self) -> usize {
271 178_usize
272 }
273
274 fn main(&self) -> Result<()> {
275 use crate::listings::ch06::SpamDatasetBuilder;
276 use anyhow::anyhow;
277 use std::ops::Not;
278 use std::path::Path;
279 use tiktoken_rs::get_bpe_from_model;
280
281 let tokenizer = get_bpe_from_model("gpt2")?;
282
283 let train_path = Path::new("data").join("train.parquet");
284 if train_path.exists().not() {
285 return Err(anyhow!(
286 "Missing 'data/train.parquet' file. Please run EG 06.04."
287 ));
288 }
289 let train_dataset = SpamDatasetBuilder::new(&tokenizer)
290 .load_data_from_parquet(train_path)
291 .build();
292 println!("train dataset max length: {}", train_dataset.max_length());
293
294 let val_path = Path::new("data").join("validation.parquet");
295 if val_path.exists().not() {
296 return Err(anyhow!(
297 "Missing 'data/validation.parquet' file. Please run EG 06.04."
298 ));
299 }
300 let val_dataset = SpamDatasetBuilder::new(&tokenizer)
301 .load_data_from_parquet(val_path)
302 .max_length(Some(train_dataset.max_length()))
303 .build();
304 println!("val dataset max length: {}", val_dataset.max_length());
305
306 let test_path = Path::new("data").join("test.parquet");
307 if test_path.exists().not() {
308 return Err(anyhow!(
309 "Missing 'data/test.parquet' file. Please run EG 06.04."
310 ));
311 }
312 let test_dataset = SpamDatasetBuilder::new(&tokenizer)
313 .load_data_from_parquet(test_path)
314 .max_length(Some(train_dataset.max_length()))
315 .build();
316 println!("test dataset max length: {}", test_dataset.max_length());
317 Ok(())
318 }
319}
320
321pub struct EG06;
338
339impl EG06 {
340 pub fn main_with_return(
341 &self,
342 verbose: bool,
343 ) -> Result<(
344 crate::listings::ch06::SpamDataLoader,
345 crate::listings::ch06::SpamDataLoader,
346 crate::listings::ch06::SpamDataLoader,
347 )> {
348 use crate::listings::ch06::{SpamDataLoader, SpamDatasetBuilder};
349 use anyhow::anyhow;
350 use std::ops::Not;
351 use std::path::Path;
352 use tiktoken_rs::get_bpe_from_model;
353
354 let tokenizer = get_bpe_from_model("gpt2")?;
356
357 let train_path = Path::new("data").join("train.parquet");
358 if train_path.exists().not() {
359 return Err(anyhow!(
360 "Missing 'data/train.parquet' file. Please run EG 06.04."
361 ));
362 }
363 let train_dataset = SpamDatasetBuilder::new(&tokenizer)
364 .load_data_from_parquet(train_path)
365 .build();
366
367 let val_path = Path::new("data").join("validation.parquet");
368 if val_path.exists().not() {
369 return Err(anyhow!(
370 "Missing 'data/validation.parquet' file. Please run EG 06.04."
371 ));
372 }
373 let val_dataset = SpamDatasetBuilder::new(&tokenizer)
374 .load_data_from_parquet(val_path)
375 .build();
376
377 let test_path = Path::new("data").join("test.parquet");
378 if test_path.exists().not() {
379 return Err(anyhow!(
380 "Missing 'data/test.parquet' file. Please run EG 06.04."
381 ));
382 }
383 let test_dataset = SpamDatasetBuilder::new(&tokenizer)
384 .load_data_from_parquet(test_path)
385 .build();
386
387 let batch_size = 8_usize;
389 let train_loader = SpamDataLoader::new(train_dataset, batch_size, true, true);
390 let val_loader = SpamDataLoader::new(val_dataset, batch_size, false, false);
391 let test_loader = SpamDataLoader::new(test_dataset, batch_size, false, false);
392
393 if verbose {
394 let (input_batch, target_batch) = train_loader.batcher().last().unwrap()?;
396 println!("Input batch dimensions: {:?}", input_batch.shape());
397 println!("Label batch dimensions: {:?}", target_batch.shape());
398
399 println!("{:?} training batches", train_loader.len());
401 println!("{:?} validation batches", val_loader.len());
402 println!("{:?} test batches", test_loader.len());
403 }
404
405 Ok((train_loader, val_loader, test_loader))
406 }
407}
408
409impl Example for EG06 {
410 fn description(&self) -> String {
411 "Creating a `SpamDataLoader` for each of the train, val and test datasets.".to_string()
412 }
413
414 fn page_source(&self) -> usize {
415 180_usize
416 }
417
418 fn main(&self) -> Result<()> {
419 let _ = self.main_with_return(true)?;
420 Ok(())
421 }
422}
423
424pub struct EG07;
441
442impl Example for EG07 {
443 fn description(&self) -> String {
444 String::from("Example usage of `download_and_load_gpt2`.")
445 }
446
447 fn page_source(&self) -> usize {
448 182_usize
449 }
450
451 fn main(&self) -> Result<()> {
452 use crate::listings::{
453 ch04::Config,
454 ch05::{generate, text_to_token_ids, token_ids_to_text},
455 ch06::{download_and_load_gpt2, HF_GPT2_MODEL_ID},
456 };
457 use candle_core::{DType, Device};
458 use candle_nn::{VarBuilder, VarMap};
459 use rand::{rngs::StdRng, SeedableRng};
460 use tiktoken_rs::get_bpe_from_model;
461
462 let mut cfg = Config::gpt2_124m();
464 cfg.qkv_bias = true;
465 let varmap = VarMap::new();
466 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
467 let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
468
469 let tokenizer = get_bpe_from_model("gpt2")?;
471 let mut rng = StdRng::seed_from_u64(42_u64);
472
473 let text_1 = "Every effort moves you";
475 let token_ids = generate(
476 &model,
477 text_to_token_ids(text_1, &tokenizer, vb.device())?,
478 15_usize,
479 cfg.context_length,
480 None,
481 None,
482 None,
483 &mut rng,
484 )?;
485
486 println!(
488 "Output text:\n{:?}",
489 token_ids_to_text(token_ids, &tokenizer)
490 );
491
492 let text_2 = "Is the following text 'spam'? Answer with 'yes' or \
494 'no': 'You are a winner you have been specially selected to receive $1000 \
495 cash or a $2000 award.'";
496 let token_ids = generate(
497 &model,
498 text_to_token_ids(text_2, &tokenizer, vb.device())?,
499 23_usize,
500 cfg.context_length,
501 None,
502 None,
503 None,
504 &mut rng,
505 )?;
506
507 println!(
509 "Output text:\n{:?}",
510 token_ids_to_text(token_ids, &tokenizer)
511 );
512
513 Ok(())
514 }
515}
516
517pub struct EG08;
534
535impl Example for EG08 {
536 fn description(&self) -> String {
537 String::from("Printing the model architecture via `varmap.data()`.")
538 }
539
540 fn page_source(&self) -> usize {
541 185_usize
542 }
543
544 fn main(&self) -> Result<()> {
545 use crate::listings::{
546 ch04::Config,
547 ch06::{download_and_load_gpt2, HF_GPT2_MODEL_ID},
548 };
549 use candle_core::{DType, Device};
550 use candle_nn::{VarBuilder, VarMap};
551
552 let mut cfg = Config::gpt2_124m();
554 cfg.qkv_bias = true;
555 let varmap = VarMap::new();
556 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
557 let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
558
559 println!("{:#?}", model);
561
562 Ok(())
563 }
564}
565
566pub struct EG09;
583
584impl Example for EG09 {
585 fn description(&self) -> String {
586 String::from("Modifying the `out_head` of a GPT2Model and running inference.")
587 }
588
589 fn page_source(&self) -> usize {
590 186_usize
591 }
592
593 fn main(&self) -> Result<()> {
594 use crate::listings::{
595 ch04::Config,
596 ch06::{download_and_load_gpt2, modify_out_head_for_classification, HF_GPT2_MODEL_ID},
597 };
598 use candle_core::{DType, Device, IndexOp, ModuleT, Tensor};
599 use candle_nn::{VarBuilder, VarMap};
600 use tiktoken_rs::get_bpe_from_model;
601
602 let mut cfg = Config::gpt2_124m();
604 cfg.qkv_bias = true;
605 let varmap = VarMap::new();
606 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
607 let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
608
609 let tensor_data = varmap.data().lock().unwrap();
611 let out_head = tensor_data.get("model.out_head.weight");
612 println!("old classification head: {:?}", out_head);
613 drop(tensor_data);
614
615 let num_classes = 2_usize;
617 modify_out_head_for_classification(&mut model, cfg, num_classes, &varmap, vb.pp("model"))?;
618
619 let tensor_data = varmap.data().lock().unwrap();
621 let out_head = tensor_data.get("model.out_head.weight");
622 println!("new classification head: {:?}", out_head);
623
624 let tokenizer = get_bpe_from_model("gpt2")?;
626 let inputs = tokenizer.encode_with_special_tokens("Do you have time");
627 let num_tokens = inputs.len();
628 let inputs = Tensor::from_vec(inputs, num_tokens, vb.device())?.unsqueeze(0)?;
629 println!("Inputs: {:?}", inputs.to_vec2::<u32>());
630 println!("Inputs dimensions: {:?}", inputs);
631
632 let outputs = model.forward_t(&inputs, false)?;
633 println!("Outputs: {:?}", outputs.to_vec3::<f32>());
634 println!("Outputs dimensions: {:?}", outputs);
635
636 let (_b, c, _vocab_size) = outputs.dims3()?;
638 println!(
639 "Last output token: {:?}",
640 outputs.i((.., c - 1, ..))?.to_vec2::<f32>()
641 );
642
643 Ok(())
644 }
645}
646
647pub struct EG10;
664
665impl Example for EG10 {
666 fn description(&self) -> String {
667 "Toy example of predicting spam/ham from logits.".to_string()
668 }
669
670 fn page_source(&self) -> usize {
671 192_usize
672 }
673
674 fn main(&self) -> Result<()> {
675 use candle_core::{Device, Tensor, D};
676
677 let dev = Device::cuda_if_available(0)?;
678 let logits = Tensor::new(&[[-3.5983_f32, 3.9902]], &dev)?;
679 println!(
680 "Last output token (i.e. logits): {:?}",
681 logits.to_vec2::<f32>()?
682 );
683
684 let label = logits.argmax(D::Minus1)?;
685 println!("Class label: {:?}", label.squeeze(0)?.to_scalar::<u32>()?);
686
687 Ok(())
688 }
689}
690
691pub struct EG11;
708
709impl Example for EG11 {
710 fn description(&self) -> String {
711 let desc = "Example usage of `calc_accuracy_loader` to compute accuracy on \
712 test, train, val sets.";
713 desc.to_string()
714 }
715
716 fn page_source(&self) -> usize {
717 192_usize
718 }
719
720 fn main(&self) -> Result<()> {
721 use crate::listings::{
722 ch04::Config,
723 ch06::{
724 calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
725 HF_GPT2_MODEL_ID,
726 },
727 };
728 use candle_core::{DType, Device};
729 use candle_nn::{VarBuilder, VarMap};
730
731 let mut cfg = Config::gpt2_124m();
733 cfg.qkv_bias = true;
734 let varmap = VarMap::new();
735 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
736 let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
737 modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
738
739 let eg06 = EG06; let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;
742
743 let num_batches = Some(10_usize);
745 let train_accuracy =
746 calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
747 let val_accuracy =
748 calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
749 let test_accuracy =
750 calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;
751
752 println!("Training accuracy: {}", train_accuracy);
753 println!("Validation accuracy: {}", val_accuracy);
754 println!("Test accuracy: {}", test_accuracy);
755
756 Ok(())
757 }
758}
759
760pub struct EG12;
777
778impl Example for EG12 {
779 fn description(&self) -> String {
780 let desc = "Example usage of `calc_loss_loader` to compute accuracy on \
781 test, train, val sets.";
782 desc.to_string()
783 }
784
785 fn page_source(&self) -> usize {
786 194_usize
787 }
788
789 fn main(&self) -> Result<()> {
790 use crate::listings::{
791 ch04::Config,
792 ch06::{
793 calc_loss_loader, download_and_load_gpt2, modify_out_head_for_classification,
794 HF_GPT2_MODEL_ID,
795 },
796 };
797 use candle_core::{DType, Device};
798 use candle_nn::{VarBuilder, VarMap};
799
800 let mut cfg = Config::gpt2_124m();
802 cfg.qkv_bias = true;
803 let varmap = VarMap::new();
804 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
805 let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
806 modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
807
808 let eg06 = EG06; let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;
811
812 let num_batches = Some(5_usize);
814 let train_loss = calc_loss_loader(&train_loader, &model, vb.device(), num_batches, None)?;
815 let val_loss = calc_loss_loader(&val_loader, &model, vb.device(), num_batches, None)?;
816 let test_loss = calc_loss_loader(&test_loader, &model, vb.device(), num_batches, None)?;
817
818 println!("Training loss: {}", train_loss);
819 println!("Validation loss: {}", val_loss);
820 println!("Test loss: {}", test_loss);
821
822 Ok(())
823 }
824}
825
826pub struct EG13;
843
844impl Example for EG13 {
845 fn description(&self) -> String {
846 String::from("Example usage of `train_classifier_simple` and `plot_values` function.")
847 }
848
849 fn page_source(&self) -> usize {
850 197_usize
851 }
852
853 fn main(&self) -> Result<()> {
854 use crate::listings::{
855 ch04::Config,
856 ch06::{
857 download_and_load_gpt2, modify_out_head_for_classification, plot_values,
858 train_classifier_simple, HF_GPT2_MODEL_ID,
859 },
860 };
861 use candle_core::{DType, Device, Var};
862 use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
863 use ndarray::linspace;
864 use std::path::Path;
865
866 let mut cfg = Config::gpt2_124m();
868 cfg.qkv_bias = true;
869 let varmap = VarMap::new();
870 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
871 let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
872 modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
873
874 let eg06 = EG06; let (train_loader, val_loader, _test_loader) = eg06.main_with_return(false)?;
877
878 let mut training_vars: Vec<Var> = vec![];
881 let tensor_data = varmap.data().lock().unwrap();
882 let var_names: Vec<&String> = tensor_data
883 .keys()
884 .filter(|k| k.contains("final_norm") || k.contains("out_head") || k.contains("trf.11"))
885 .collect();
886
887 println!("Training variables: {:?}\n", var_names);
888
889 for var_name in var_names.into_iter() {
890 let var = tensor_data.get(var_name).unwrap();
891 training_vars.push(var.clone());
892 }
893 drop(tensor_data);
894
895 let optimizer = AdamW::new(
896 training_vars,
897 ParamsAdamW {
898 lr: 5e-5,
899 weight_decay: 0.1,
900 ..Default::default()
901 },
902 )?;
903
904 let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize);
905 let (train_loss, val_loss, train_accs, val_accs, num_examples) = train_classifier_simple(
906 &model,
907 &train_loader,
908 &val_loader,
909 optimizer,
910 vb.device(),
911 num_epochs,
912 eval_freq,
913 eval_iter,
914 None,
915 )?;
916
917 println!("Saving weights to `./clf.checkpoint.safetensors`");
919 varmap.save("clf.checkpoint.safetensors")?;
920
921 let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_loss.len()));
923 let examples_seen = Vec::from_iter(linspace(0_f32, num_examples as f32, train_loss.len()));
924 let label = "loss";
925 let save_path =
926 Path::new(format!("plot_classification_{label}.html").as_str()).to_path_buf();
927 plot_values(
928 epochs_seen,
929 examples_seen,
930 train_loss,
931 val_loss,
932 label,
933 save_path,
934 )?;
935
936 let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_accs.len()));
937 let examples_seen = Vec::from_iter(linspace(0_f32, num_examples as f32, train_accs.len()));
938 let label = "accuracy";
939 let save_path =
940 Path::new(format!("plot_classification_{label}.html").as_str()).to_path_buf();
941 plot_values(
942 epochs_seen,
943 examples_seen,
944 train_accs,
945 val_accs,
946 label,
947 save_path,
948 )?;
949
950 Ok(())
951 }
952}
953
954pub struct EG14;
971
972impl Example for EG14 {
973 fn description(&self) -> String {
974 String::from(
975 "Loading fine-tuned model and calculate performance on whole train, val and test sets.",
976 )
977 }
978
979 fn page_source(&self) -> usize {
980 200_usize
981 }
982
983 fn main(&self) -> Result<()> {
984 use crate::listings::{
985 ch04::{Config, GPTModel},
986 ch06::{calc_accuracy_loader, modify_out_head_for_classification},
987 };
988 use candle_core::{DType, Device};
989 use candle_nn::{VarBuilder, VarMap};
990
991 let mut cfg = Config::gpt2_124m();
993 cfg.qkv_bias = true;
994 let mut varmap = VarMap::new();
995 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
996 let mut model = GPTModel::new(cfg, vb.pp("model"))?;
997 modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
998
999 varmap
1001 .load("clf.checkpoint.safetensors")
1002 .with_context(|| "Missing 'clf.checkpoint.safetensors' file. Please run EG 06.13.")?;
1003
1004 let eg06 = EG06; let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;
1007
1008 let num_batches = None;
1010 let train_accuracy =
1011 calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
1012 let val_accuracy =
1013 calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
1014 let test_accuracy =
1015 calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;
1016
1017 println!("Training accuracy: {}", train_accuracy);
1018 println!("Validation accuracy: {}", val_accuracy);
1019 println!("Test accuracy: {}", test_accuracy);
1020
1021 Ok(())
1022 }
1023}
1024
1025pub struct EG15;
1042
1043impl Example for EG15 {
1044 fn description(&self) -> String {
1045 String::from("Example usage of `classify_review`.")
1046 }
1047
1048 fn page_source(&self) -> usize {
1049 202_usize
1050 }
1051
1052 fn main(&self) -> Result<()> {
1053 use crate::listings::{
1054 ch04::{Config, GPTModel},
1055 ch06::{
1056 classify_review, modify_out_head_for_classification, SpamDatasetBuilder,
1057 PAD_TOKEN_ID,
1058 },
1059 };
1060 use anyhow::anyhow;
1061 use candle_core::{DType, Device};
1062 use candle_nn::{VarBuilder, VarMap};
1063 use std::ops::Not;
1064 use std::path::Path;
1065 use tiktoken_rs::get_bpe_from_model;
1066
1067 let tokenizer = get_bpe_from_model("gpt2")?;
1069 let train_path = Path::new("data").join("train.parquet");
1070 if train_path.exists().not() {
1071 return Err(anyhow!(
1072 "Missing 'data/train.parquet' file. Please run EG 06.04."
1073 ));
1074 }
1075 let train_dataset = SpamDatasetBuilder::new(&tokenizer)
1076 .load_data_from_parquet(train_path)
1077 .build();
1078
1079 let mut cfg = Config::gpt2_124m();
1081 cfg.qkv_bias = true;
1082 let mut varmap = VarMap::new();
1083 let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
1084 let mut model = GPTModel::new(cfg, vb.pp("model"))?;
1085 modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;
1086
1087 varmap
1089 .load("clf.checkpoint.safetensors")
1090 .with_context(|| "Missing 'clf.checkpoint.safetensors' file. Please run EG 06.13.")?;
1091
1092 let text_1 = "You are a winner you have been specially selected to receive \
1094 $1000 cash or a $2000 award.";
1095 println!(
1096 "{}",
1097 classify_review(
1098 text_1,
1099 &model,
1100 &tokenizer,
1101 vb.device(),
1102 Some(train_dataset.max_length()),
1103 PAD_TOKEN_ID,
1104 )
1105 .with_context(|| "Failed to classify text_1.")?,
1106 );
1107
1108 let text_2 = "Hey, just wanted to check if we're still on for \"
1109 dinner tonight? Let me know!";
1110 println!(
1111 "{}",
1112 classify_review(
1113 text_2,
1114 &model,
1115 &tokenizer,
1116 vb.device(),
1117 Some(train_dataset.max_length()),
1118 PAD_TOKEN_ID,
1119 )
1120 .with_context(|| "Failed to classify text_2.")?,
1121 );
1122
1123 Ok(())
1124 }
1125}
1126
1127pub mod addons {
1128 use polars::prelude::*;
1130 use std::path::Path;
1131
1132 pub fn get_value_counts(df: &DataFrame, cname: &str) -> anyhow::Result<DataFrame> {
1134 let result = df
1135 .clone()
1136 .lazy()
1137 .select([col(cname)
1138 .value_counts(false, false, "count", false)
1139 .alias("value_counts")])
1140 .collect()?;
1141 Ok(result)
1142 }
1143
1144 pub fn write_csv<P: AsRef<Path>>(df: &mut DataFrame, fname: P) -> anyhow::Result<()> {
1145 let mut file = std::fs::File::create(fname)?;
1146 CsvWriter::new(&mut file).finish(df)?;
1147 Ok(())
1148 }
1149
1150 pub fn write_parquet<P: AsRef<Path>>(df: &mut DataFrame, fname: P) -> anyhow::Result<()> {
1151 let mut file = std::fs::File::create(fname)?;
1152 ParquetWriter::new(&mut file).finish(df)?;
1153 Ok(())
1154 }
1155}