use crate::Example;
use anyhow::{anyhow, Context, Result};

pub struct EG01;

impl Example for EG01 {
    fn description(&self) -> String {
        "Example usage of `download_and_load_file`.".to_string()
    }

    fn page_source(&self) -> usize {
        207_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            download_and_load_file, DATA_DIR, INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
        };
        use std::path::Path;

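        // Download the instruction-data JSON on first use, then load it from disk.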
        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;
        println!("Number of entries: {}", data.len());

        println!("Example entry:\n{}\n", data[50]);

        println!("Another example entry:\n{}", data[999]);

        Ok(())
    }
}

pub struct EG02;

impl Example for EG02 {
    fn description(&self) -> String {
        "Example usage of `format_input`.".to_string()
    }

    fn page_source(&self) -> usize {
        209_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            download_and_load_file, AlpacaPromptFormatter, InstructionExample, PromptFormatter,
            DATA_DIR, INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
        };
        use std::path::Path;

        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;
        let prompt_formatter = AlpacaPromptFormatter;

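        // The Alpaca format places the instruction (and optional input) in the
        // prompt; the expected completion is appended under a `### Response:` header.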
        let model_input = prompt_formatter.format_input(&data[50]);
        let detailed_response = format!("\n\n### Response:\n{}", data[50].output());
        println!("{}", model_input + &detailed_response);

        println!("\n---\n");

        let model_input = prompt_formatter.format_input(&data[999]);
        let detailed_response = format!("\n\n### Response:\n{}", data[999].output());
        println!("{}", model_input + &detailed_response);

        Ok(())
    }
}

pub struct EG03;

impl Example for EG03 {
    fn description(&self) -> String {
        String::from("Example usage of `partition_data`.")
    }

    fn page_source(&self) -> usize {
        210_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            download_and_load_file, partition_data, DATA_DIR, INSTRUCTION_DATA_FILENAME,
            INSTRUCTION_DATA_URL,
        };
        use std::path::Path;

        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

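        // 85% of the entries go to training and 5% to validation; the remaining
        // 10% form the test split.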
        let (train_data, val_data, test_data) = partition_data(data, 0.85_f32, 0.05_f32)?;

        println!("Training set length: {}", train_data.len());
        println!("Validation set length: {}", val_data.len());
        println!("Test set length: {}", test_data.len());

        Ok(())
    }
}

pub struct EG04;

impl Example for EG04 {
    fn description(&self) -> String {
        "Example usage of the `<|endoftext|>` special token with tiktoken.".to_string()
    }

    fn page_source(&self) -> usize {
        214_usize
    }

    fn main(&self) -> Result<()> {
        use std::collections::HashSet;
        use tiktoken_rs::get_bpe_from_model;

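        // `<|endoftext|>` is a special token, so it must be explicitly allowed at
        // encode time; for GPT-2 it maps to the single token id 50256.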
        let allowed_special = HashSet::from(["<|endoftext|>"]);
        let tokenizer = get_bpe_from_model("gpt2")?;
        println!("{:?}", tokenizer.encode("<|endoftext|>", allowed_special));

        Ok(())
    }
}

pub struct EG05;

impl Example for EG05 {
    fn description(&self) -> String {
        String::from("Example usage of `InstructionDataCollator::custom_collate_fn`.")
    }

    fn page_source(&self) -> usize {
        220_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::InstructionDataCollator;
        use candle_core::{Device, Tensor};

        let device = Device::cuda_if_available(0)?;
        let inputs_1 = Tensor::new(&[0_u32, 1, 2, 3, 4], &device)?;
        let inputs_2 = Tensor::new(&[5_u32, 6], &device)?;
        let inputs_3 = Tensor::new(&[7_u32, 8, 9], &device)?;
        let batch = vec![inputs_1, inputs_2, inputs_3];

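        // The collator pads every sequence to the longest one in the batch and
        // derives next-token targets (the inputs shifted left by one position).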
        let collator = InstructionDataCollator::new();
        let (inputs, targets) = collator.custom_collate_fn(batch)?;

        println!("inputs:\n{:?}", inputs.to_vec2::<u32>()?);
        println!("targets:\n{:?}", targets.to_vec2::<i64>()?);

        Ok(())
    }
}

pub struct EG06;

impl Example for EG06 {
    fn description(&self) -> String {
        "An adapted example demonstrating the effect of `ignore_index` in `calc_loss_batch`."
            .to_string()
    }

    fn page_source(&self) -> usize {
        221_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch05::{calc_loss_batch, DEFAULT_IGNORE_INDEX},
        };
        use candle_core::{DType, Device, Tensor};
        use candle_nn::{VarBuilder, VarMap};

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let cfg = Config::gpt_sm_test();
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        let inputs = Tensor::new(&[[100_u32, 20, 300]], vb.device())?;
        let targets = Tensor::new(&[[1_u32, 2, 3]], vb.device())?;
        let loss = calc_loss_batch(&inputs, &targets, &model, vb.device(), false, None)?;

        println!("Inputs: {:?}", inputs.to_vec2::<u32>()?);
        println!("Targets: {:?}", targets.to_vec2::<u32>()?);
        println!("Loss: {:?}", loss);

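        // The second sequence's targets are all DEFAULT_IGNORE_INDEX, so it is
        // masked out of the cross-entropy entirely: `loss_2` should match the
        // single-sequence loss above.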
        let inputs_2 = Tensor::new(&[[100_u32, 20, 300], [400, 7, 88]], vb.device())?;
        let targets_2 = Tensor::new(
            &[
                [1_i64, 2, 3],
                [
                    DEFAULT_IGNORE_INDEX,
                    DEFAULT_IGNORE_INDEX,
                    DEFAULT_IGNORE_INDEX,
                ],
            ],
            vb.device(),
        )?;
        let loss_2 = calc_loss_batch(
            &inputs_2,
            &targets_2,
            &model,
            vb.device(),
            false,
            Some(DEFAULT_IGNORE_INDEX),
        )?;

        println!(
            "---\nSimilar inputs but now a second sequence whose target has the ignore index:\n"
        );

        println!("Inputs: {:?}", inputs_2.to_vec2::<u32>()?);
        println!("Targets: {:?}", targets_2.to_vec2::<i64>()?);
        println!("Loss: {:?}", loss_2);

        Ok(())
    }
}

pub struct EG07;

impl EG07 {
    pub fn main_with_return(
        &self,
        batch_size: usize,
        verbose: bool,
    ) -> Result<(
        crate::listings::ch07::InstructionDataLoader<
            crate::listings::ch07::InstructionDataCollator,
        >,
        crate::listings::ch07::InstructionDataLoader<
            crate::listings::ch07::InstructionDataCollator,
        >,
        crate::listings::ch07::InstructionDataLoader<
            crate::listings::ch07::InstructionDataCollator,
        >,
    )> {
        use crate::listings::ch07::{
            download_and_load_file, partition_data, AlpacaPromptFormatter, DataLoader,
            InstructionDataCollator, InstructionDataLoader, InstructionDataset, DATA_DIR,
            INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
        };
        use candle_core::Device;
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;

        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

        let (train_data, val_data, test_data) = partition_data(data, 0.85_f32, 0.05_f32)?;
        let prompt_formatter = AlpacaPromptFormatter;
        let train_dataset = InstructionDataset::new(train_data, &tokenizer, &prompt_formatter);
        let val_dataset = InstructionDataset::new(val_data, &tokenizer, &prompt_formatter);
        let test_dataset = InstructionDataset::new(test_data, &tokenizer, &prompt_formatter);

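        // Collate on the GPU when available and truncate sequences to GPT-2's
        // 1,024-token context window; the boolean pairs mirror the (shuffle,
        // drop_last) flags of the earlier chapters' loaders, enabled only for training.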
        let collator = InstructionDataCollator::new()
            .device(Device::cuda_if_available(0)?)
            .allowed_max_length(Some(1024_usize));
        let train_loader =
            InstructionDataLoader::new(train_dataset, batch_size, true, true, collator.clone());
        let val_loader =
            InstructionDataLoader::new(val_dataset, batch_size, false, false, collator.clone());
        let test_loader =
            InstructionDataLoader::new(test_dataset, batch_size, false, false, collator);

        if verbose {
            println!("Train loader:");
            let mut batcher = train_loader.batcher();
            while let Some(Ok((inputs, targets))) = batcher.next() {
                println!("inputs: {:?} targets: {:?}", inputs, targets);
            }
        }

        Ok((train_loader, val_loader, test_loader))
    }
}

impl Example for EG07 {
    fn description(&self) -> String {
        let desc = "Creating an `InstructionDataLoader` for each of the train, \
            val and test data partitions.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        225_usize
    }

    fn main(&self) -> Result<()> {
        self.main_with_return(8_usize, true)?;
        Ok(())
    }
}

pub struct EG08;

impl Example for EG08 {
    fn description(&self) -> String {
        "Example usage of `download_and_load_gpt2` and sample instruction inference.".to_string()
    }

    fn page_source(&self) -> usize {
        227_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch05::{generate, text_to_token_ids, token_ids_to_text},
            ch07::{
                download_and_load_file, download_and_load_gpt2, partition_data,
                AlpacaPromptFormatter, PromptFormatter, DATA_DIR, INSTRUCTION_DATA_FILENAME,
                INSTRUCTION_DATA_URL,
            },
        };
        use candle_core::{DType, Device, Tensor};
        use candle_nn::{VarBuilder, VarMap};
        use rand::{rngs::StdRng, SeedableRng};
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;
        let (_train_data, val_data, _test_data) = partition_data(data, 0.85_f32, 0.05_f32)?;

        let mut cfg = Config::gpt2_medium();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model_id = "openai-community/gpt2-medium";
        let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, model_id)?;

        let prompt_formatter = AlpacaPromptFormatter;
        let input_text = prompt_formatter.format_input(&val_data[0]);
        println!("{}", input_text);

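        // Generate 35 new tokens with sampling disabled (the two `None`s are the
        // temperature and top-k options of the ch05 `generate` listing), stopping
        // early on the end-of-text token (id 50256).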
        let tokenizer = get_bpe_from_model("gpt2")?;
        let mut rng = StdRng::seed_from_u64(42_u64);
        let token_ids = generate(
            &model,
            text_to_token_ids(input_text.as_str(), &tokenizer, vb.device())?,
            35_usize,
            cfg.context_length,
            None,
            None,
            Some(Tensor::new(&[50_256_u32], vb.device())?),
            &mut rng,
        )?;
        let generated_text = token_ids_to_text(token_ids, &tokenizer)?;
        let response_text = generated_text[input_text.len()..].trim();

        println!("---generated-text-below---\n{}", response_text);

        Ok(())
    }
}

pub struct EG09;

impl Example for EG09 {
    fn description(&self) -> String {
        let desc = "Example usage of `calc_loss_loader` to compute the loss on the \
            train and validation instruction datasets.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        230_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch05::DEFAULT_IGNORE_INDEX,
            ch07::{calc_loss_loader, download_and_load_gpt2},
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};

        let mut cfg = Config::gpt2_medium();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model_id = "openai-community/gpt2-medium";
        let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, model_id)?;

        let eg07 = EG07;
        let (train_loader, val_loader, _test_loader) = eg07.main_with_return(8_usize, false)?;

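        // Average the cross-entropy over only the first five batches for a quick
        // estimate of the loss before fine-tuning.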
        let num_batches = Some(5_usize);
        let train_loss = calc_loss_loader(
            &train_loader,
            &model,
            vb.device(),
            num_batches,
            Some(DEFAULT_IGNORE_INDEX),
        )?;
        let val_loss = calc_loss_loader(
            &val_loader,
            &model,
            vb.device(),
            num_batches,
            Some(DEFAULT_IGNORE_INDEX),
        )?;

        println!("Training loss: {}", train_loss);
        println!("Validation loss: {}", val_loss);

        Ok(())
    }
}

pub struct EG10;

impl Example for EG10 {
    fn description(&self) -> String {
        "Example usage of the `train_model_simple` and `plot_losses` functions.".to_string()
    }

    fn page_source(&self) -> usize {
        231_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch05::plot_losses,
            ch07::{
                download_and_load_gpt2, train_model_simple, AlpacaPromptFormatter, PromptFormatter,
                DEFAULT_IGNORE_INDEX,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
        use ndarray::linspace;
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let model_id = "openai-community/gpt2";
        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, model_id)?;

        let eg07 = EG07;
        let (train_loader, val_loader, _test_loader) = eg07.main_with_return(8_usize, false)?;

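        // Fine-tune for two epochs with AdamW (lr 5e-5, weight decay 0.1),
        // evaluating on `eval_iter` batches every `eval_freq` training steps.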
        let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 2_usize);
        let optimizer = AdamW::new(
            varmap.all_vars(),
            ParamsAdamW {
                lr: 0.00005,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;
        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;
        let start_context = prompt_formatter.format_input(&val_loader.dataset().data()[0]);
        let (train_losses, val_losses, tokens_seen) = train_model_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            start_context.as_str(),
            &tokenizer,
            Some(DEFAULT_IGNORE_INDEX),
        )?;

        println!("Saving weights to `./ift.checkpoint.safetensors`");
        varmap.save("ift.checkpoint.safetensors")?;

        println!("Saving plot to `./plot_ift_loss.html`");
        let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_losses.len()));
        let tokens_seen = tokens_seen
            .into_iter()
            .map(|el| el as f32)
            .collect::<Vec<_>>();
        let save_path = Path::new("plot_ift_loss.html").to_path_buf();
        plot_losses(
            epochs_seen,
            tokens_seen,
            train_losses,
            val_losses,
            save_path,
        )?;

        Ok(())
    }
}

pub struct EG11;

impl Example for EG11 {
    fn description(&self) -> String {
        let desc = "Example of extracting model-generated responses and \
            comparing them to the correct ones.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        234_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch05::{generate, text_to_token_ids, token_ids_to_text},
            ch07::{AlpacaPromptFormatter, InstructionExample, PromptFormatter},
        };
        use candle_core::{DType, Device, Tensor};
        use candle_nn::{VarBuilder, VarMap};
        use rand::{rngs::StdRng, SeedableRng};
        use tiktoken_rs::get_bpe_from_model;

        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        varmap
            .load("ift.checkpoint.safetensors")
            .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;

        let eg07 = EG07;
        let (_train_loader, _val_loader, test_loader) = eg07.main_with_return(8_usize, false)?;
        let tokenizer = get_bpe_from_model("gpt2")?;
        let mut rng = StdRng::seed_from_u64(42_u64);
        let prompt_formatter = AlpacaPromptFormatter;

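        // For the first three test entries, generate up to 256 new tokens and strip
        // the prompt prefix so only the model's response is shown.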
        for entry in &test_loader.dataset().data()[..3] {
            let input_text = prompt_formatter.format_input(entry);
            let token_ids = generate(
                &model,
                text_to_token_ids(&input_text[..], &tokenizer, vb.device())?,
                256_usize,
                cfg.context_length,
                None,
                None,
                Some(Tensor::new(&[50_256_u32], vb.device())?),
                &mut rng,
            )?;
            let generated_text = token_ids_to_text(token_ids, &tokenizer)?;
            let response_text = generated_text[input_text.len()..].trim();

            println!("{}", input_text);
            println!("\nCorrect response:\n>>{}", entry.output());
            println!("\nModel response:\n>>{}", response_text);
            println!("-----------------------------------------");
        }

        Ok(())
    }
}

pub struct EG12;

impl Example for EG12 {
    fn description(&self) -> String {
        "Example usage of `generate_test_set_responses`.".to_string()
    }

    fn page_source(&self) -> usize {
        237_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch07::{generate_test_set_responses, AlpacaPromptFormatter, DATA_DIR},
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};
        use std::path::Path;

        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        varmap
            .load("ift.checkpoint.safetensors")
            .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;

        let eg07 = EG07;
        let (_train_loader, _val_loader, test_loader) = eg07.main_with_return(8_usize, false)?;

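        // Attach a generated response to every test entry and persist the result
        // as JSON; EG 07.15 and EG 07.16 later score these responses.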
        let save_path = Path::new(DATA_DIR).join("instruction_data_with_response.json");
        let prompt_formatter = AlpacaPromptFormatter;
        let mut test_data = test_loader.dataset().data().clone();
        generate_test_set_responses(
            &mut test_data,
            &model,
            cfg.context_length,
            vb.device(),
            save_path,
            &prompt_formatter,
        )?;

        println!("{}", test_data[0]);

        Ok(())
    }
}

pub struct EG13;

impl Example for EG13 {
    fn description(&self) -> String {
        "An example to check if the `ollama` process is running.".to_string()
    }

    fn page_source(&self) -> usize {
        241_usize
    }

    fn main(&self) -> Result<()> {
        use anyhow::anyhow;
        use sysinfo::System;

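        // Scan the process table for an executable named exactly "ollama"; the
        // following examples need a local Ollama server for LLM-based judging.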
        let sys = System::new_all();
        let mut ollama_processes = sys.processes_by_exact_name("ollama".as_ref());
        let _ = ollama_processes.next().ok_or(anyhow!(
            "Ollama not running. Launch ollama before proceeding."
        ))?;

        println!("Ollama running");

        Ok(())
    }
}

pub struct EG14;

impl Example for EG14 {
    fn description(&self) -> String {
        "Example usage of `query_model`.".to_string()
    }

    fn page_source(&self) -> usize {
        243_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{query_model, DEFAULT_OLLAMA_API_URL};

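        // Send a single prompt to the local Ollama server; this assumes the
        // `llama3` model has already been pulled.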
        let model = "llama3";
        let result = query_model("What do Llamas eat?", model, DEFAULT_OLLAMA_API_URL)?;

        println!("{}", result);
        Ok(())
    }
}

pub struct EG15;

impl Example for EG15 {
    fn description(&self) -> String {
        let desc = "Using Llama3 (via Ollama) as the LLM judge to evaluate model responses.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        244_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            load_instruction_data_from_json, query_model, AlpacaPromptFormatter,
            InstructionExample, InstructionResponseExample, PromptFormatter, DATA_DIR,
            DEFAULT_OLLAMA_API_URL,
        };
        use std::path::Path;

        let file_path = Path::new(DATA_DIR).join("instruction_data_with_response.json");
        let test_data: Vec<InstructionResponseExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_response.json' file. Please run EG 07.12."
            })?;

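        // Ask the judge model to grade each of the first three responses on a
        // 0-100 scale, given the formatted prompt and the reference answer.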
        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        for (ix, entry) in test_data.iter().enumerate().take(3_usize) {
            let model_response = entry
                .model_response()
                .as_ref()
                .ok_or_else(|| anyhow!("Entry {ix} is missing a model response."))?;
            let prompt = format!(
                "Given the input `{}` and the correct output `{}`, score the \
                model response `{}` on a scale from 0 to 100, where 100 is the \
                best score.",
                prompt_formatter.format_input(entry),
                entry.output(),
                model_response
            );

            println!("\nDataset response:");
            println!("\n>>{}", entry.output());
            println!("\nModel response:");
            println!("\n>>{}", model_response);
            println!("\nScore:");
            println!(
                "\n>>{}",
                query_model(prompt.as_str(), model, DEFAULT_OLLAMA_API_URL)?
            );
        }

        Ok(())
    }
}

pub struct EG16;

impl Example for EG16 {
    fn description(&self) -> String {
        "Example usage of `generate_model_scores`.".to_string()
    }

    fn page_source(&self) -> usize {
        246_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            generate_model_scores, load_instruction_data_from_json, AlpacaPromptFormatter,
            DATA_DIR, DEFAULT_OLLAMA_API_URL,
        };
        use std::path::Path;

        let file_path = Path::new(DATA_DIR).join("instruction_data_with_response.json");
        let test_data = load_instruction_data_from_json(file_path).with_context(|| {
            "Missing 'instruction_data_with_response.json' file. Please run EG 07.12."
        })?;

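        // Score every response with the judge model; the printed "x of y" reflects
        // that some judge replies may not parse into a numeric score.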
        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        let scores =
            generate_model_scores(&test_data, DEFAULT_OLLAMA_API_URL, model, &prompt_formatter)?;

        println!("Number of scores: {} of {}", scores.len(), test_data.len());
        let average_score = scores.iter().sum::<f32>() / scores.len() as f32;
        println!("Average score: {}", average_score);

        Ok(())
    }
}

pub struct EG17;

impl Example for EG17 {
    fn description(&self) -> String {
        let desc = "[Bonus from DPO notebook] Usage of \
            `generate_chosen_and_rejected_response` to create a preference example.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            bonus::generate_chosen_and_rejected_response, download_and_load_file,
            AlpacaPromptFormatter, DATA_DIR, DEFAULT_OLLAMA_API_URL, INSTRUCTION_DATA_FILENAME,
            INSTRUCTION_DATA_URL,
        };
        use rand::{rngs::StdRng, SeedableRng};
        use std::path::Path;

        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

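        // Build a (chosen, rejected) response pair for a single entry by querying
        // the Ollama model; the seeded RNG makes the helper's random choice
        // reproducible across runs.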
        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        let mut rng = StdRng::seed_from_u64(42_u64);
        let preference_example = generate_chosen_and_rejected_response(
            &data[42],
            DEFAULT_OLLAMA_API_URL,
            model,
            &prompt_formatter,
            &mut rng,
        )?;

        println!("{:#?}", preference_example);

        Ok(())
    }
}

pub struct EG18;

impl Example for EG18 {
    fn description(&self) -> String {
        "[Bonus from DPO notebook] Example usage of `generate_preference_dataset`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch07::bonus::generate_preference_dataset,
            ch07::{
                download_and_load_file, AlpacaPromptFormatter, DATA_DIR, DEFAULT_OLLAMA_API_URL,
                INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
            },
        };
        use std::path::Path;

        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

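        // Build a chosen/rejected pair for every entry and write the preference
        // dataset to JSON; each entry costs an Ollama round trip, so expect this
        // to take a while.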
        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        let save_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        generate_preference_dataset(
            &data,
            DEFAULT_OLLAMA_API_URL,
            model,
            &prompt_formatter,
            save_path,
        )?;

        Ok(())
    }
}

pub struct EG19;

impl Example for EG19 {
    fn description(&self) -> String {
        let desc = "[Bonus from DPO notebook] Example usage of \
            `PreferenceDataCollator::custom_collate_fn`.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch05::token_ids_to_text,
            ch07::bonus::{
                CustomCollator, EncodedPreferenceExample, PreferenceDataCollator, PreferenceExample,
            },
            ch07::{load_instruction_data_from_json, AlpacaPromptFormatter, DATA_DIR},
        };
        use candle_core::{Device, IndexOp};
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        let sample = preference_data
            .into_iter()
            .take(2_usize)
            .collect::<Vec<_>>();

        println!("Sample Preference Examples:\n{:#?}", sample);

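        // Tokenize the two sampled examples into prompt/chosen/rejected ids, then
        // collate them into padded tensors plus masks that (presumably) select the
        // response positions used in the DPO loss.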
        let batch = sample
            .into_iter()
            .map(|el| EncodedPreferenceExample::from_example(&el, &prompt_formatter, &tokenizer))
            .collect::<Vec<_>>();

        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let collated_item = collator.collate(batch)?;

        println!(
            "\nCollated Batch: Prompt Tokens\n\n{:?}\n",
            collated_item
                .prompt()
                .iter()
                .map(|el| el.to_vec1::<u32>())
                .collect::<Vec<_>>()
        );

        println!(
            "\nCollated Batch: Chosen Tokens\n\n{:?}\n",
            collated_item.chosen().to_vec2::<u32>()?
        );

        let prompt_text = token_ids_to_text(collated_item.prompt()[1].clone(), &tokenizer)?;
        println!("\nCollated Batch Item 1: Prompt Text\n\n{}\n", prompt_text);

        let chosen = collated_item.chosen().i((1, ..))?;
        let chosen_text = token_ids_to_text(chosen.clone(), &tokenizer)?;
        println!("\nCollated Batch Item 1: Chosen Text\n\n{}\n", chosen_text);

        let rejected = collated_item.rejected().i((1, ..))?;
        let rejected_text = token_ids_to_text(rejected.clone(), &tokenizer)?;
        println!(
            "\nCollated Batch Item 1: Rejected Text\n\n{}\n",
            rejected_text
        );

        let chosen_mask = &collated_item.chosen_mask().i((1, ..))?;
        println!("\nCollated Batch: Masks\n");
        println!("Chosen inputs: {:?}", chosen);
        println!("Chosen mask: {:?}", chosen_mask);

        println!(
            "\nCollated Batch Item 1: Chosen Mask Indexes\n\n{:?}\n",
            chosen_mask.to_vec1::<u32>()?
        );

        Ok(())
    }
}

pub struct EG20;

impl Example for EG20 {
    fn description(&self) -> String {
        "[BONUS] Creating a `PreferenceDataLoader`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch07::bonus::{
                PreferenceDataCollator, PreferenceDataLoader, PreferenceDataset, PreferenceExample,
            },
            ch07::{
                load_instruction_data_from_json, partition_data, AlpacaPromptFormatter, DataLoader,
                DATA_DIR,
            },
        };
        use candle_core::Device;
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        let (train_data, val_data, test_data) =
            partition_data(preference_data, 0.85_f32, 0.05_f32)?;
        let train_dataset = PreferenceDataset::new(train_data, &tokenizer, &prompt_formatter);
        let val_dataset = PreferenceDataset::new(val_data, &tokenizer, &prompt_formatter);
        let test_dataset = PreferenceDataset::new(test_data, &tokenizer, &prompt_formatter);

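        // Mirror EG 07.07: one loader per split, with shuffling and last-batch
        // dropping enabled only for training.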
        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let batch_size = 3_usize;
        let train_loader =
            PreferenceDataLoader::new(train_dataset, batch_size, true, true, collator.clone());
        let val_loader =
            PreferenceDataLoader::new(val_dataset, batch_size, false, false, collator.clone());
        let test_loader =
            PreferenceDataLoader::new(test_dataset, batch_size, false, false, collator);

        println!("Train loader:");
        let mut batcher = train_loader.batcher();
        while let Some(Ok(batch)) = batcher.next() {
            println!("batch: {:#?}", batch);
        }

        println!("Val loader:");
        let mut batcher = val_loader.batcher();
        while let Some(Ok(batch)) = batcher.next() {
            println!("batch: {:#?}", batch);
        }

        println!("Test loader:");
        let mut batcher = test_loader.batcher();
        while let Some(Ok(batch)) = batcher.next() {
            println!("batch: {:#?}", batch);
        }

        Ok(())
    }
}

pub struct EG21;

impl Example for EG21 {
    fn description(&self) -> String {
        "[BONUS] Example usage of `compute_dpo_loss_batch`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch07::bonus::{
                compute_dpo_loss_batch, PreferenceDataCollator, PreferenceDataLoader,
                PreferenceDataset, PreferenceExample,
            },
            ch07::{
                load_instruction_data_from_json, partition_data, AlpacaPromptFormatter, DataLoader,
                DATA_DIR,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let policy_model = GPTModel::new(cfg, vb.pp("model"))?;

        varmap
            .load("ift.checkpoint.safetensors")
            .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;

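        // NOTE: the reference model is built from a clone of the same VarBuilder,
        // so it shares the loaded IFT weights with the policy model; with identical
        // models the DPO loss should come out near -ln(sigmoid(0)) = ln 2 ≈ 0.693.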
        let vb2 = vb.clone();
        let reference_model = GPTModel::new(cfg, vb2.pp("model"))?;

        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        let (train_data, _val_data, _test_data) =
            partition_data(preference_data, 0.85_f32, 0.05_f32)?;
        let train_dataset = PreferenceDataset::new(train_data, &tokenizer, &prompt_formatter);

        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let batch_size = 2_usize;
        let train_loader =
            PreferenceDataLoader::new(train_dataset, batch_size, true, true, collator.clone());

        let mut batcher = train_loader.batcher();

        if let Some(Ok(batch)) = batcher.next() {
            let (loss, chosen_r, rejected_r) =
                compute_dpo_loss_batch(&batch, &policy_model, &reference_model, 0.1, false)?;
            println!("{:?}", loss);
            println!("{:?}", chosen_r);
            println!("{:?}", rejected_r);
        }

        Ok(())
    }
}

pub struct EG22;

impl Example for EG22 {
    fn description(&self) -> String {
        "[BONUS] Example usage of `compute_dpo_loss_loader`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch07::bonus::{
                compute_dpo_loss_loader, PreferenceDataCollator, PreferenceDataLoader,
                PreferenceDataset, PreferenceExample,
            },
            ch07::{
                load_instruction_data_from_json, partition_data, AlpacaPromptFormatter, DATA_DIR,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let policy_model = GPTModel::new(cfg, vb.pp("model"))?;

        varmap
            .load("ift.checkpoint.safetensors")
            .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;

        let vb2 = vb.clone();
        let reference_model = GPTModel::new(cfg, vb2.pp("model"))?;

        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        let (train_data, _val_data, _test_data) =
            partition_data(preference_data, 0.85_f32, 0.05_f32)?;
        let train_dataset = PreferenceDataset::new(train_data, &tokenizer, &prompt_formatter);

        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let batch_size = 2_usize;
        let train_loader =
            PreferenceDataLoader::new(train_dataset, batch_size, true, true, collator.clone());

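        // Average the DPO loss and the implied chosen/rejected rewards over the
        // first five batches (beta = 0.1).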
        let num_batches = Some(5_usize);
        let (train_loss, chosen_rewards, rejected_rewards) = compute_dpo_loss_loader(
            &train_loader,
            &policy_model,
            &reference_model,
            0.1,
            num_batches,
            false,
        )?;

        println!("Training loss: {}", train_loss);
        println!("Chosen rewards: {}", chosen_rewards);
        println!("Rejected rewards: {}", rejected_rewards);

        Ok(())
    }
}

pub struct EG23;

impl Example for EG23 {
    fn description(&self) -> String {
        "[BONUS] Example usage of `train_model_dpo_simple`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch07::bonus::{
                train_model_dpo_simple, PreferenceDataCollator, PreferenceDataLoader,
                PreferenceDataset, PreferenceExample,
            },
            ch07::{
                load_instruction_data_from_json, partition_data, AlpacaPromptFormatter,
                PromptFormatter, DATA_DIR,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
        use std::path::{Path, PathBuf};
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        let (train_data, val_data, _test_data) =
            partition_data(preference_data, 0.85_f32, 0.05_f32)?;
        let train_dataset = PreferenceDataset::new(train_data, &tokenizer, &prompt_formatter);
        let val_dataset = PreferenceDataset::new(val_data, &tokenizer, &prompt_formatter);

        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let batch_size = 3_usize;
        let train_loader =
            PreferenceDataLoader::new(train_dataset, batch_size, true, true, collator.clone());
        let val_loader =
            PreferenceDataLoader::new(val_dataset, batch_size, false, false, collator.clone());

        println!("Train loader: {}", train_loader.len());
        println!("Val loader: {}", val_loader.len());

        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let checkpoint_path = PathBuf::from("ift.checkpoint.safetensors");
        let policy_model =
            addons::load_ift_gpt_model(&mut varmap, vb.pp("model"), cfg, &checkpoint_path)
                .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;
        let mut varmap_reference = VarMap::new();
        let vb_reference = VarBuilder::from_varmap(
            &varmap_reference,
            DType::F32,
            &Device::cuda_if_available(0)?,
        );
        let reference_model = addons::load_ift_gpt_model(
            &mut varmap_reference,
            vb_reference.pp("model"),
            cfg,
            &checkpoint_path,
        )?;

        let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 2_usize);
        let optimizer = AdamW::new(
            varmap.all_vars(),
            ParamsAdamW {
                lr: 0.00005,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;
        let start_context = prompt_formatter.format_input(&val_loader.dataset().data()[0]);
        let tracking = train_model_dpo_simple(
            &policy_model,
            &reference_model,
            &train_loader,
            &val_loader,
            0.1,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            start_context.as_str(),
            &tokenizer,
        )?;

        println!("{:#?}", tracking);

        println!("Saving weights to `./dpo.checkpoint.safetensors`");
        varmap.save("dpo.checkpoint.safetensors")?;

        Ok(())
    }
}

pub mod addons {
    use crate::listings::ch04::{Config, GPTModel};
    use anyhow::Context;
    use candle_nn::{VarBuilder, VarMap};

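    /// Builds a `GPTModel` from `cfg`, then loads instruction-fine-tuned weights
    /// from the given safetensors checkpoint into `varmap`.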
    pub fn load_ift_gpt_model<P>(
        varmap: &mut VarMap,
        vb: VarBuilder<'_>,
        cfg: Config,
        checkpoint_path: &P,
    ) -> anyhow::Result<GPTModel>
    where
        P: AsRef<std::path::Path> + std::fmt::Debug,
    {
        let ift_model = GPTModel::new(cfg, vb)?;

        varmap
            .load(checkpoint_path)
            .with_context(|| format!("Missing '{:?}' file.", checkpoint_path))?;

        Ok(ift_model)
    }
}