// llms_from_scratch_rs/examples/ch07.rs

//! Examples from Chapter 7

use crate::Example;
use anyhow::{anyhow, Context, Result};

/// # Example usage of `download_and_load_file`
///
/// #### Id
/// 07.01
///
/// #### Page
/// This example starts on page 207
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.01
///
/// # with cuda
/// cargo run --features cuda example 07.01
/// ```
pub struct EG01;

impl Example for EG01 {
    fn description(&self) -> String {
        "Example usage of `download_and_load_file`.".to_string()
    }

    fn page_source(&self) -> usize {
        207_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            download_and_load_file, DATA_DIR, INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
        };
        use std::path::Path;

        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;
        println!("Number of entries: {}", data.len());

        // See example at index 50
        println!("Example entry:\n{}\n", data[50]);

        // See another example at index 999
        println!("Another example entry:\n{}", data[999]);

        Ok(())
    }
}

/// # Example usage of `format_input`
///
/// #### Id
/// 07.02
///
/// #### Page
/// This example starts on page 209
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.02
///
/// # with cuda
/// cargo run --features cuda example 07.02
/// ```
pub struct EG02;

impl Example for EG02 {
    fn description(&self) -> String {
        "Example usage of `format_input`.".to_string()
    }

    fn page_source(&self) -> usize {
        209_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            download_and_load_file, AlpacaPromptFormatter, InstructionExample, PromptFormatter,
            DATA_DIR, INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
        };
        use std::path::Path;

        // load instruction examples
        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;
        let prompt_formatter = AlpacaPromptFormatter;

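        // `format_input` renders an entry into an Alpaca-style prompt, roughly:
        //
        //   Below is an instruction that describes a task. Write a response
        //   that appropriately completes the request.
        //
        //   ### Instruction:
        //   {instruction}
        //
        //   ### Input:
        //   {input}    <- section omitted when the entry has no input
        //
        // (Exact wording per the book's Alpaca template; shown here for orientation.)
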
        // first model input
        let model_input = prompt_formatter.format_input(&data[50]);
        let detailed_response = format!("\n\n### Response:\n{}", data[50].output());
        println!("{}", model_input + &detailed_response);

        // print a separator
        println!("\n---\n");

        // print another model input
        let model_input = prompt_formatter.format_input(&data[999]);
        let detailed_response = format!("\n\n### Response:\n{}", data[999].output());
        println!("{}", model_input + &detailed_response);

        Ok(())
    }
}

/// # Example usage of `partition_data`
///
/// #### Id
/// 07.03
///
/// #### Page
/// This example starts on page 210
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.03
///
/// # with cuda
/// cargo run --features cuda example 07.03
/// ```
pub struct EG03;

impl Example for EG03 {
    fn description(&self) -> String {
        String::from("Example usage of `partition_data`")
    }

    fn page_source(&self) -> usize {
        210_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            download_and_load_file, partition_data, DATA_DIR, INSTRUCTION_DATA_FILENAME,
            INSTRUCTION_DATA_URL,
        };
        use std::path::Path;

        // load instruction examples
        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

        // partition data
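        // 85% of entries go to training and 5% to validation; the remaining
        // ~10% form the test split (assuming the two fractions are the train
        // and validation shares).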
        let (train_data, val_data, test_data) = partition_data(data, 0.85_f32, 0.05_f32)?;

        println!("Training set length: {}", train_data.len());
        println!("Validation set length: {}", val_data.len());
        println!("Test set length: {}", test_data.len());

        Ok(())
    }
}

/// # Example usage of `<|endoftext|>` special token with tiktoken
///
/// #### Id
/// 07.04
///
/// #### Page
/// This example starts on page 214
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.04
///
/// # with cuda
/// cargo run --features cuda example 07.04
/// ```
pub struct EG04;

impl Example for EG04 {
    fn description(&self) -> String {
        "Example usage of `<|endoftext|>` special token with tiktoken.".to_string()
    }

    fn page_source(&self) -> usize {
        214_usize
    }

    fn main(&self) -> Result<()> {
        use std::collections::HashSet;
        use tiktoken_rs::get_bpe_from_model;

        let allowed_special = HashSet::from(["<|endoftext|>"]);
        let tokenizer = get_bpe_from_model("gpt2")?;
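        // In GPT-2's vocabulary, `<|endoftext|>` encodes to the single id 50256.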
        println!("{:?}", tokenizer.encode("<|endoftext|>", allowed_special));

        Ok(())
    }
}

/// # Example usage of `InstructionDataCollator.custom_collate_fn`
///
/// #### Id
/// 07.05
///
/// #### Page
/// This example starts on page 220
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.05
///
/// # with cuda
/// cargo run --features cuda example 07.05
/// ```
pub struct EG05;

impl Example for EG05 {
    fn description(&self) -> String {
        String::from("Example usage of `InstructionDataCollator.custom_collate_fn`.")
    }

    fn page_source(&self) -> usize {
        220_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::InstructionDataCollator;
        use candle_core::{Device, Tensor};

        let device = Device::cuda_if_available(0)?;
        let inputs_1 = Tensor::new(&[0_u32, 1, 2, 3, 4], &device)?;
        let inputs_2 = Tensor::new(&[5_u32, 6], &device)?;
        let inputs_3 = Tensor::new(&[7_u32, 8, 9], &device)?;
        let batch = vec![inputs_1, inputs_2, inputs_3];

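        // Per the book's collation scheme: each sequence is padded with the
        // `<|endoftext|>` id (50256) up to the longest sequence in the batch,
        // and the targets are the inputs shifted left by one position, with
        // padded positions (after the first) replaced by the ignore index
        // (-100), hence the i64 dtype of `targets` below.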
        let collator = InstructionDataCollator::new();
        let (inputs, targets) = collator.custom_collate_fn(batch)?;

        println!("inputs:\n{:?}", inputs.to_vec2::<u32>()?);
        println!("targets:\n{:?}", targets.to_vec2::<i64>()?);

        Ok(())
    }
}

/// # An adapted example demonstrating the effect of `ignore_index` in `calc_loss_batch`
///
/// #### Id
/// 07.06
///
/// #### Page
/// This example starts on page 221
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.06
///
/// # with cuda
/// cargo run --features cuda example 07.06
/// ```
pub struct EG06;

impl Example for EG06 {
    fn description(&self) -> String {
        "An adapted example demonstrating the effect of `ignore_index` in `calc_loss_batch`."
            .to_string()
    }

    fn page_source(&self) -> usize {
        221_usize
    }

    /// This example differs slightly from the one found in the book because
    /// `candle_nn::loss::cross_entropy()` does not support an `ignore_index`
    /// argument. Instead, we implement that logic within `calc_loss_batch`.
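    ///
    /// A minimal sketch of the masking idea on plain Rust values (illustrative
    /// only, not the actual `calc_loss_batch` implementation): per-token losses
    /// at ignored target positions are dropped before averaging.
    ///
    /// ```
    /// const IGNORE_INDEX: i64 = -100;
    /// let targets = [1_i64, 2, IGNORE_INDEX];
    /// let per_token_loss = [0.5_f32, 0.7, 9.9]; // pretend per-token losses
    /// let (sum, n) = targets
    ///     .iter()
    ///     .zip(per_token_loss)
    ///     .filter(|(t, _)| **t != IGNORE_INDEX)
    ///     .fold((0.0_f32, 0_usize), |(s, n), (_, l)| (s + l, n + 1));
    /// assert!((sum / n as f32 - 0.6).abs() < 1e-6); // third loss is ignored
    /// ```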
    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch05::{calc_loss_batch, DEFAULT_IGNORE_INDEX},
        };
        use candle_core::{DType, Device, Tensor};
        use candle_nn::{VarBuilder, VarMap};

        // create model
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let cfg = Config::gpt_sm_test();
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        // create sample inputs
        let inputs = Tensor::new(&[[100_u32, 20, 300]], vb.device())?;
        let targets = Tensor::new(&[[1_u32, 2, 3]], vb.device())?;
        let loss = calc_loss_batch(&inputs, &targets, &model, vb.device(), false, None)?;

        println!("Inputs: {:?}", inputs.to_vec2::<u32>()?);
        println!("Targets: {:?}", targets.to_vec2::<u32>()?);
        println!("Loss: {:?}", loss);

        // Note: targets that use ignore_index will now be a Tensor of DType::I64
        let inputs_2 = Tensor::new(&[[100_u32, 20, 300], [400, 7, 88]], vb.device())?;
        let targets_2 = Tensor::new(
            &[
                [1_i64, 2, 3],
                [
                    DEFAULT_IGNORE_INDEX,
                    DEFAULT_IGNORE_INDEX,
                    DEFAULT_IGNORE_INDEX,
                ],
            ],
            vb.device(),
        )?;
        let loss_2 = calc_loss_batch(
            &inputs_2,
            &targets_2,
            &model,
            vb.device(),
            false,
            Some(DEFAULT_IGNORE_INDEX),
        )?;

        println!(
            "---\nSimilar inputs but now a second sequence whose target has the ignore index:\n"
        );

        println!("Inputs: {:?}", inputs_2.to_vec2::<u32>()?);
        println!("Targets: {:?}", targets_2.to_vec2::<i64>()?);
        println!("Loss: {:?}", loss_2);

        Ok(())
    }
}

/// # Creating an `InstructionDataLoader` for each of the train, val and test data partitions
///
/// #### Id
/// 07.07
///
/// #### Page
/// This example starts on page 225
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.07
///
/// # with cuda
/// cargo run --features cuda example 07.07
/// ```
pub struct EG07;

/// Convenience alias for the loader type produced by [`EG07::main_with_return`].
pub type Ch07InstructionLoader =
    crate::listings::ch07::InstructionDataLoader<crate::listings::ch07::InstructionDataCollator>;

impl EG07 {
    pub fn main_with_return(
        &self,
        batch_size: usize,
        verbose: bool,
    ) -> Result<(
        Ch07InstructionLoader,
        Ch07InstructionLoader,
        Ch07InstructionLoader,
    )> {
        use crate::listings::ch07::{
            download_and_load_file, partition_data, AlpacaPromptFormatter, DataLoader,
            InstructionDataCollator, InstructionDataLoader, InstructionDataset, DATA_DIR,
            INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
        };
        use candle_core::Device;
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;

        // load instruction examples
        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

        // partition data and create train, val, test datasets
        let (train_data, val_data, test_data) = partition_data(data, 0.85_f32, 0.05_f32)?;
        let prompt_formatter = AlpacaPromptFormatter;
        let train_dataset = InstructionDataset::new(train_data, &tokenizer, &prompt_formatter);
        let val_dataset = InstructionDataset::new(val_data, &tokenizer, &prompt_formatter);
        let test_dataset = InstructionDataset::new(test_data, &tokenizer, &prompt_formatter);

        // create loaders
        let collator = InstructionDataCollator::new()
            .device(Device::cuda_if_available(0)?)
            .allowed_max_length(Some(1024_usize));
        let train_loader =
            InstructionDataLoader::new(train_dataset, batch_size, true, true, collator.clone());
        let val_loader =
            InstructionDataLoader::new(val_dataset, batch_size, false, false, collator.clone());
        let test_loader =
            InstructionDataLoader::new(test_dataset, batch_size, false, false, collator);

        if verbose {
            println!("Train loader:");
            let mut batcher = train_loader.batcher();
            while let Some(Ok((inputs, targets))) = batcher.next() {
                println!("inputs: {:?} targets: {:?}", inputs, targets);
            }
        }

        Ok((train_loader, val_loader, test_loader))
    }
}

impl Example for EG07 {
    fn description(&self) -> String {
        let desc = "Creating an `InstructionDataLoader` for each of the train, \
        val and test data partitions";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        225_usize
    }

    fn main(&self) -> Result<()> {
        self.main_with_return(8_usize, true)?;
        Ok(())
    }
}

/// # Example usage of `download_and_load_gpt2` and sample instruction inference
///
/// #### Id
/// 07.08
///
/// #### Page
/// This example starts on page 227
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.08
///
/// # with cuda
/// cargo run --features cuda example 07.08
/// ```
pub struct EG08;

impl Example for EG08 {
    fn description(&self) -> String {
        "Example usage of `download_and_load_gpt2` and sample instruction inference.".to_string()
    }

    fn page_source(&self) -> usize {
        227_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch05::{generate, text_to_token_ids, token_ids_to_text},
            ch07::{
                download_and_load_file, download_and_load_gpt2, partition_data,
                AlpacaPromptFormatter, PromptFormatter, DATA_DIR, INSTRUCTION_DATA_FILENAME,
                INSTRUCTION_DATA_URL,
            },
        };
        use candle_core::{DType, Device, Tensor};
        use candle_nn::{VarBuilder, VarMap};
        use rand::{rngs::StdRng, SeedableRng};
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        // partition data and create train, val, test datasets
        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;
        let (_train_data, val_data, _test_data) = partition_data(data, 0.85_f32, 0.05_f32)?;

        // use `download_and_load_gpt2` for gpt2-medium
        let mut cfg = Config::gpt2_medium();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model_id = "openai-community/gpt2-medium";
        let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, model_id)?;

        // input instructions
        let prompt_formatter = AlpacaPromptFormatter;
        let input_text = prompt_formatter.format_input(&val_data[0]);
        println!("{}", input_text);

        // run inference
        let tokenizer = get_bpe_from_model("gpt2")?;
        let mut rng = StdRng::seed_from_u64(42_u64);
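        // Greedy decoding: temperature and top-k are both `None`, generating up
        // to 35 new tokens and stopping early at the `<|endoftext|>` id (50256)
        // passed as the eos token (argument order per chapter 5's `generate`).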
        let token_ids = generate(
            &model,
            text_to_token_ids(input_text.as_str(), &tokenizer, vb.device())?,
            35_usize,
            cfg.context_length,
            None,
            None,
            Some(Tensor::new(&[50_256_u32], vb.device())?),
            &mut rng,
        )?;
        let generated_text = token_ids_to_text(token_ids, &tokenizer)?;
        let response_text = &generated_text[input_text.len()..].trim();

        println!("---generated-text-below---\n{}", response_text);

        Ok(())
    }
}

/// # Example usage of `calc_loss_loader` to compute cross-entropy loss on the train and val sets
///
/// #### Id
/// 07.09
///
/// #### Page
/// This example starts on page 230
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.09
///
/// # with cuda
/// cargo run --features cuda example 07.09
/// ```
pub struct EG09;

impl Example for EG09 {
    fn description(&self) -> String {
        let desc = "Example usage of `calc_loss_loader` to compute the \
        cross-entropy loss on the train and validation instruction datasets.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        230_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch05::DEFAULT_IGNORE_INDEX,
            ch07::{calc_loss_loader, download_and_load_gpt2},
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};

        // use `download_and_load_gpt2` for gpt2-medium
        let mut cfg = Config::gpt2_medium();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model_id = "openai-community/gpt2-medium";
        let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, model_id)?;

        // re-use eg 07.07
        let eg07 = EG07;
        let (train_loader, val_loader, _test_loader) = eg07.main_with_return(8_usize, false)?;

        // compute losses
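        // `Some(5)` caps the estimate at the first five batches per loader,
        // trading precision for speed.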
        let num_batches = Some(5_usize);
        let train_loss = calc_loss_loader(
            &train_loader,
            &model,
            vb.device(),
            num_batches,
            Some(DEFAULT_IGNORE_INDEX),
        )?;
        let val_loss = calc_loss_loader(
            &val_loader,
            &model,
            vb.device(),
            num_batches,
            Some(DEFAULT_IGNORE_INDEX),
        )?;

        println!("Training loss: {}", train_loss);
        println!("Validation loss: {}", val_loss);

        Ok(())
    }
}

/// # Example usage of `train_model_simple` and `plot_losses` functions
///
/// NOTE: In the book this material is encapsulated within Listing 7.8. Here,
/// we present it as an example instead.
///
/// #### Id
/// 07.10
///
/// #### Page
/// This example starts on page 231
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.10
///
/// # with cuda
/// cargo run --features cuda example 07.10
/// ```
pub struct EG10;

impl Example for EG10 {
    fn description(&self) -> String {
        "Example usage of `train_model_simple` and `plot_losses` functions".to_string()
    }

    fn page_source(&self) -> usize {
        231_usize
    }

    // TODO: This fails silently if it runs into OOM issues.
    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch05::plot_losses,
            ch07::{
                download_and_load_gpt2, train_model_simple, AlpacaPromptFormatter, PromptFormatter,
                DEFAULT_IGNORE_INDEX,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
        use ndarray::linspace;
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        // use `download_and_load_gpt2`
        let model_id = "openai-community/gpt2"; // use `gpt2-medium` for med instead
        let mut cfg = Config::gpt2_124m(); // use `gpt2_medium()` for med instead

        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, model_id)?;

        // re-use eg 07.07
        let eg07 = EG07;
        let (train_loader, val_loader, _test_loader) = eg07.main_with_return(8_usize, false)?;

        // invoke training
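        // Hyperparameters mirror the book: train for 2 epochs, evaluate every
        // 5 steps over 5 batches, AdamW with lr 5e-5 and weight decay 0.1.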
        let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 2_usize);
        let optimizer = AdamW::new(
            varmap.all_vars(),
            ParamsAdamW {
                lr: 0.00005,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;
        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;
        let start_context = prompt_formatter.format_input(&val_loader.dataset().data()[0]);
        let (train_losses, val_losses, tokens_seen) = train_model_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            start_context.as_str(),
            &tokenizer,
            Some(DEFAULT_IGNORE_INDEX),
        )?;

        // save model
        println!("Saving weights to `./ift.checkpoint.safetensors`");
        varmap.save("ift.checkpoint.safetensors")?;

        // plot loss curves
        println!("Saving plot to `./plot_ift_loss.html`");
        let epochs_seen = Vec::from_iter(linspace(0_f32, num_epochs as f32, train_losses.len()));
        let tokens_seen = tokens_seen
            .into_iter()
            .map(|el| el as f32)
            .collect::<Vec<_>>();
        let save_path = Path::new("plot_ift_loss.html").to_path_buf();
        plot_losses(
            epochs_seen,
            tokens_seen,
            train_losses,
            val_losses,
            save_path,
        )?;

        Ok(())
    }
}

/// # Example of extracting model-generated responses and comparing to correct ones
///
/// #### Id
/// 07.11
///
/// #### Page
/// This example starts on page 234
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.11
///
/// # with cuda
/// cargo run --features cuda example 07.11
/// ```
pub struct EG11;

impl Example for EG11 {
    fn description(&self) -> String {
        let desc = "Example of extracting model-generated responses and \
        comparing to correct ones";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        234_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch05::{generate, text_to_token_ids, token_ids_to_text},
            ch07::{AlpacaPromptFormatter, InstructionExample, PromptFormatter},
        };
        use candle_core::{DType, Device, Tensor};
        use candle_nn::{VarBuilder, VarMap};
        use rand::{rngs::StdRng, SeedableRng};
        use tiktoken_rs::get_bpe_from_model;

        // setup the gpt2 model
        let mut cfg = Config::gpt2_124m(); // must match model size used in EG10
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        // load instruction-finetuned weights
        varmap
            .load("ift.checkpoint.safetensors")
            .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;

        // extract responses
        let eg07 = EG07;
        let (_train_loader, _val_loader, test_loader) = eg07.main_with_return(8_usize, false)?;
        let tokenizer = get_bpe_from_model("gpt2")?;
        let mut rng = StdRng::seed_from_u64(42_u64);
        let prompt_formatter = AlpacaPromptFormatter;

        for entry in &test_loader.dataset().data()[..3] {
            let input_text = prompt_formatter.format_input(entry);
            let token_ids = generate(
                &model,
                text_to_token_ids(&input_text[..], &tokenizer, vb.device())?,
                256_usize,
                cfg.context_length,
                None,
                None,
                Some(Tensor::new(&[50_256_u32], vb.device())?),
                &mut rng,
            )?;
            let generated_text = token_ids_to_text(token_ids, &tokenizer)?;
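            // Strip the echoed prompt so only the newly generated response remains.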
            let response_text = &generated_text[input_text.len()..].trim();

            // print
            println!("{}", input_text);
            println!("\nCorrect response:\n>>{}", entry.output());
            println!("\nModel response:\n>>{}", response_text);
            println!("-----------------------------------------");
        }

        Ok(())
    }
}

/// # Example usage of `generate_test_set_responses`
///
/// #### Id
/// 07.12
///
/// #### Page
/// This example starts on page 237
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.12
///
/// # with cuda
/// cargo run --features cuda example 07.12
/// ```
pub struct EG12;

impl Example for EG12 {
    fn description(&self) -> String {
        "Example usage of `generate_test_set_responses`.".to_string()
    }

    fn page_source(&self) -> usize {
        237_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch07::{generate_test_set_responses, AlpacaPromptFormatter, DATA_DIR},
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};
        use std::path::Path;

        // setup the gpt2 model
        let mut cfg = Config::gpt2_124m(); // must match model size used in EG10
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let model = GPTModel::new(cfg, vb.pp("model"))?;

        // load instruction-finetuned weights
        varmap
            .load("ift.checkpoint.safetensors")
            .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;

        // get data loaders
        let eg07 = EG07;
        let (_train_loader, _val_loader, test_loader) = eg07.main_with_return(8_usize, false)?;

        // generate test set responses
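        // Each test entry is augmented with the model's generated response and
        // the result is saved to `instruction_data_with_response.json`, which
        // EG 07.15 and 07.16 consume for evaluation.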
        let save_path = Path::new(DATA_DIR).join("instruction_data_with_response.json");
        let prompt_formatter = AlpacaPromptFormatter;
        let mut test_data = test_loader.dataset().data().clone();
        generate_test_set_responses(
            &mut test_data,
            &model,
            cfg.context_length,
            vb.device(),
            save_path,
            &prompt_formatter,
        )?;

        println!("{}", test_data[0]);

        Ok(())
    }
}

/// # An example to check whether the `ollama` process is running
///
/// #### Id
/// 07.13
///
/// #### Page
/// This example starts on page 241
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.13
///
/// # with cuda
/// cargo run --features cuda example 07.13
/// ```
pub struct EG13;

impl Example for EG13 {
    fn description(&self) -> String {
        "An example to check whether the `ollama` process is running.".to_string()
    }

    fn page_source(&self) -> usize {
        241_usize
    }

    fn main(&self) -> Result<()> {
        use anyhow::anyhow;
        use sysinfo::System;

        let sys = System::new_all();
        let mut ollama_processes = sys.processes_by_exact_name("ollama".as_ref());
        let _ = ollama_processes.next().ok_or(anyhow!(
            "Ollama not running. Launch ollama before proceeding."
        ))?;

        println!("Ollama running");

        Ok(())
    }
}

/// # Example usage of `query_model`
///
/// #### Id
/// 07.14
///
/// #### Page
/// This example starts on page 243
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.14
///
/// # with cuda
/// cargo run --features cuda example 07.14
/// ```
pub struct EG14;

impl Example for EG14 {
    fn description(&self) -> String {
        "Example usage of `query_model`.".to_string()
    }

    fn page_source(&self) -> usize {
        243_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{query_model, DEFAULT_OLLAMA_API_URL};

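        // Requires a running Ollama server with the `llama3` model pulled
        // (EG 07.13 shows how to check that the process is up).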
        let model = "llama3";
        let result = query_model("What do Llamas eat?", model, DEFAULT_OLLAMA_API_URL)?;

        println!("{}", result);
        Ok(())
    }
}

/// # Using Llama3.2 (via Ollama) as the LLM judge to evaluate model responses
///
/// #### Id
/// 07.15
///
/// #### Page
/// This example starts on page 244
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.15
///
/// # with cuda
/// cargo run --features cuda example 07.15
/// ```
pub struct EG15;

impl Example for EG15 {
    fn description(&self) -> String {
        let desc = "Using Llama3.2 (via Ollama) as the LLM judge to evaluate model responses.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        244_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            load_instruction_data_from_json, query_model, AlpacaPromptFormatter,
            InstructionExample, InstructionResponseExample, PromptFormatter, DATA_DIR,
            DEFAULT_OLLAMA_API_URL,
        };
        use std::path::Path;

        // load test instruction data with response
        let file_path = Path::new(DATA_DIR).join("instruction_data_with_response.json");
        let test_data: Vec<InstructionResponseExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_response.json' file. Please run EG 07.12."
            })?;

        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        for (ix, entry) in test_data.iter().enumerate().take(3_usize) {
            let model_response = entry
                .model_response()
                .as_ref()
                .ok_or_else(|| anyhow!("Entry {ix} is missing a model response."))?;
            let prompt = format!(
                "Given the input `{}` and the correct output `{}`, score the \
                model response `{}` on a scale from 0 to 100, where 100 is the \
                best score.",
                prompt_formatter.format_input(entry),
                entry.output(),
                model_response
            );

            println!("\nDataset response:");
            println!("\n>>{}", entry.output());
            println!("\nModel response:");
            println!("\n>>{}", model_response);
            println!("\nScore:");
            println!(
                "\n>>{}",
                query_model(prompt.as_str(), model, DEFAULT_OLLAMA_API_URL)?
            );
        }

        Ok(())
    }
}

/// # Example usage of `generate_model_scores`
///
/// #### Id
/// 07.16
///
/// #### Page
/// This example starts on page 246
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.16
///
/// # with cuda
/// cargo run --features cuda example 07.16
/// ```
pub struct EG16;

impl Example for EG16 {
    fn description(&self) -> String {
        "Example usage of `generate_model_scores`.".to_string()
    }

    fn page_source(&self) -> usize {
        246_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            generate_model_scores, load_instruction_data_from_json, AlpacaPromptFormatter,
            DATA_DIR, DEFAULT_OLLAMA_API_URL,
        };
        use std::path::Path;

        // load test instruction data with response
        let file_path = Path::new(DATA_DIR).join("instruction_data_with_response.json");
        let test_data = load_instruction_data_from_json(file_path).with_context(|| {
            "Missing 'instruction_data_with_response.json' file. Please run EG 07.12."
        })?;

        // invoke generate_model_scores
        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        let scores =
            generate_model_scores(&test_data, DEFAULT_OLLAMA_API_URL, model, &prompt_formatter)?;

        // print stats
        println!("Number of scores: {} of {}", scores.len(), test_data.len());
        let average_score = scores.iter().sum::<f32>() / scores.len() as f32;
        println!("Average score: {}", average_score);

        Ok(())
    }
}

/// # [BONUS] Usage of `generate_chosen_and_rejected_response` to create preference example
///
/// #### Id
/// 07.17
///
/// #### Page
/// This example is from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.17
///
/// # with cuda
/// cargo run --features cuda example 07.17
/// ```
pub struct EG17;

impl Example for EG17 {
    fn description(&self) -> String {
        let desc = "[Bonus from DPO notebook] Usage of \
        `generate_chosen_and_rejected_response` to create preference example.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch07::{
            bonus::generate_chosen_and_rejected_response, download_and_load_file,
            AlpacaPromptFormatter, DATA_DIR, DEFAULT_OLLAMA_API_URL, INSTRUCTION_DATA_FILENAME,
            INSTRUCTION_DATA_URL,
        };
        use rand::{rngs::StdRng, SeedableRng};
        use std::path::Path;

        // load instruction examples
        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

        // invoke generate_chosen_and_rejected_response
        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        let mut rng = StdRng::seed_from_u64(42_u64);
        let preference_example = generate_chosen_and_rejected_response(
            &data[42],
            DEFAULT_OLLAMA_API_URL,
            model,
            &prompt_formatter,
            &mut rng,
        )?;

        println!("{:#?}", preference_example);

        Ok(())
    }
}

/// # [BONUS] Example usage of `generate_preference_dataset`
///
/// #### Id
/// 07.18
///
/// #### Page
/// This example is adapted from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.18
///
/// # with cuda
/// cargo run --features cuda example 07.18
/// ```
pub struct EG18;

impl Example for EG18 {
    fn description(&self) -> String {
        "[Bonus from DPO notebook] Example usage of `generate_preference_dataset`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch07::bonus::generate_preference_dataset,
            ch07::{
                download_and_load_file, AlpacaPromptFormatter, DATA_DIR, DEFAULT_OLLAMA_API_URL,
                INSTRUCTION_DATA_FILENAME, INSTRUCTION_DATA_URL,
            },
        };
        use std::path::Path;

        // load instruction examples
        let file_path = Path::new(DATA_DIR).join(INSTRUCTION_DATA_FILENAME);
        let data = download_and_load_file(file_path, INSTRUCTION_DATA_URL, false)?;

        // invoke generate_preference_dataset
        let model = "llama3";
        let prompt_formatter = AlpacaPromptFormatter;
        let save_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        generate_preference_dataset(
            &data,
            DEFAULT_OLLAMA_API_URL,
            model,
            &prompt_formatter,
            save_path,
        )?;

        Ok(())
    }
}

/// # [BONUS] Example usage of `PreferenceDataCollator.custom_collate_fn`
///
/// #### Id
/// 07.19
///
/// #### Page
/// This example is adapted from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.19
///
/// # with cuda
/// cargo run --features cuda example 07.19
/// ```
pub struct EG19;

impl Example for EG19 {
    fn description(&self) -> String {
        let desc = "[Bonus from DPO notebook] Example usage of \
        `PreferenceDataCollator.custom_collate_fn`.";
        desc.to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch05::token_ids_to_text,
            ch07::bonus::{
                CustomCollator, EncodedPreferenceExample, PreferenceDataCollator, PreferenceExample,
            },
            ch07::{load_instruction_data_from_json, AlpacaPromptFormatter, DATA_DIR},
        };
        use candle_core::{Device, IndexOp};
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        // load preference examples
        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        // take a small sample and encode
        let sample = preference_data
            .into_iter()
            .take(2_usize)
            .collect::<Vec<_>>();

        println!("Sample Preference Examples:\n{:#?}", sample);

        let batch = sample
            .into_iter()
            .map(|el| EncodedPreferenceExample::from_example(&el, &prompt_formatter, &tokenizer))
            .collect::<Vec<_>>();

        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let collated_item = collator.collate(batch)?;

        // print prompts
        println!(
            "\nCollated Batch: Prompt Tokens\n\n{:?}\n",
            collated_item
                .prompt()
                .iter()
                .map(|el| el.to_vec1::<u32>())
                .collect::<Vec<_>>()
        );

        // print chosen
        println!(
            "\nCollated Batch: Chosen Tokens\n\n{:?}\n",
            collated_item.chosen().to_vec2::<u32>()?
        );

        // Decode prompt and print
        let prompt_text = token_ids_to_text(collated_item.prompt()[1].clone(), &tokenizer)?;
        println!("\nCollated Batch Item 1: Prompt Text\n\n{}\n", prompt_text);

        // Decode chosen and print
        let chosen = collated_item.chosen().i((1, ..))?;
        let chosen_text = token_ids_to_text(chosen.clone(), &tokenizer)?;
        println!("\nCollated Batch Item 1: Chosen Text\n\n{}\n", chosen_text);

        // Decode rejected and print
        let rejected = collated_item.rejected().i((1, ..))?;
        let rejected_text = token_ids_to_text(rejected.clone(), &tokenizer)?;
        println!(
            "\nCollated Batch Item 1: Rejected Text\n\n{}\n",
            rejected_text
        );

        // Print masks and their shapes
        let chosen_mask = &collated_item.chosen_mask().i((1, ..))?;
        println!("\nCollated Batch: Masks\n");
        println!("Chosen inputs: {:?}", chosen);
        println!("Chosen mask: {:?}", chosen_mask);

        println!(
            "\nCollated Batch Item 1: Chosen Mask Indexes\n\n{:?}\n",
            chosen_mask.to_vec1::<u32>()?
        );

        // decode chosen mask
        // let chosen_masked_text = token_ids_to_text(
        //     chosen.index_select(chosen_mask.clone(), 0)?,
        //     &tokenizer,
        // )?;
        // println!(
        //     "\nCollated Batch Item 1: Chosen Mask Text\n\n{}\n",
        //     chosen_masked_text
        // );

        // decode rejected mask
        // let rejected_masked_text = token_ids_to_text(
        //     rejected.index_select(&collated_item.rejected_mask()[1], 0)?,
        //     &tokenizer,
        // )?;
        // println!(
        //     "\nCollated Batch Item 1: Rejected Mask Text\n\n{}\n",
        //     rejected_masked_text
        // );

        Ok(())
    }
}

/// # [BONUS] Creating a `PreferenceDataLoader`
///
/// #### Id
/// 07.20
///
/// #### Page
/// This example is adapted from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.20
///
/// # with cuda
/// cargo run --features cuda example 07.20
/// ```
pub struct EG20;

impl Example for EG20 {
    fn description(&self) -> String {
        "[BONUS] Creating a `PreferenceDataLoader`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch07::bonus::{
                PreferenceDataCollator, PreferenceDataLoader, PreferenceDataset, PreferenceExample,
            },
            ch07::{
                load_instruction_data_from_json, partition_data, AlpacaPromptFormatter, DataLoader,
                DATA_DIR,
            },
        };
        use candle_core::Device;
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        // load preference examples
        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        // partition data and create train, val, test datasets
        let (train_data, val_data, test_data) =
            partition_data(preference_data, 0.85_f32, 0.05_f32)?;
        let train_dataset = PreferenceDataset::new(train_data, &tokenizer, &prompt_formatter);
        let val_dataset = PreferenceDataset::new(val_data, &tokenizer, &prompt_formatter);
        let test_dataset = PreferenceDataset::new(test_data, &tokenizer, &prompt_formatter);

        // create loaders
        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let batch_size = 3_usize;
        let train_loader =
            PreferenceDataLoader::new(train_dataset, batch_size, true, true, collator.clone());
        let val_loader =
            PreferenceDataLoader::new(val_dataset, batch_size, false, false, collator.clone());
        let test_loader =
            PreferenceDataLoader::new(test_dataset, batch_size, false, false, collator);

        println!("Train loader:");
        let mut batcher = train_loader.batcher();
        while let Some(Ok(batch)) = batcher.next() {
            println!("batch: {:#?}", batch);
        }

        println!("Val loader:");
        let mut batcher = val_loader.batcher();
        while let Some(Ok(batch)) = batcher.next() {
            println!("batch: {:#?}", batch);
        }

        println!("Test loader:");
        let mut batcher = test_loader.batcher();
        while let Some(Ok(batch)) = batcher.next() {
            println!("batch: {:#?}", batch);
        }

        Ok(())
    }
}

/// # [BONUS] Example usage of `compute_dpo_loss_batch`
///
/// #### Id
/// 07.21
///
/// #### Page
/// This example is adapted from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.21
///
/// # with cuda
/// cargo run --features cuda example 07.21
/// ```
pub struct EG21;

impl Example for EG21 {
    fn description(&self) -> String {
        "[BONUS] Example usage of `compute_dpo_loss_batch`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch07::bonus::{
                compute_dpo_loss_batch, PreferenceDataCollator, PreferenceDataLoader,
                PreferenceDataset, PreferenceExample,
            },
            ch07::{
                load_instruction_data_from_json, partition_data, AlpacaPromptFormatter, DataLoader,
                DATA_DIR,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        // load reference and policy model
        let mut cfg = Config::gpt2_124m(); // must match model size used in EG10
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let policy_model = GPTModel::new(cfg, vb.pp("model"))?;

        // load instruction-finetuned weights
        varmap
            .load("ift.checkpoint.safetensors")
            .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;

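        // Building the reference model from a clone of the same `VarBuilder`
        // means both models share the instruction-finetuned weights loaded
        // above, so they start out identical. (EG 07.23, which actually trains,
        // gives the frozen reference model its own `VarMap` instead.)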
        let vb2 = vb.clone();
        let reference_model = GPTModel::new(cfg, vb2.pp("model"))?;

        // load preference examples
        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        // partition data and create train, val, test datasets
        let (train_data, _val_data, _test_data) =
            partition_data(preference_data, 0.85_f32, 0.05_f32)?;
        let train_dataset = PreferenceDataset::new(train_data, &tokenizer, &prompt_formatter);

        // create loaders
        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let batch_size = 2_usize;
        let train_loader =
            PreferenceDataLoader::new(train_dataset, batch_size, true, true, collator.clone());

        // get batcher
        let mut batcher = train_loader.batcher();

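        // The DPO objective with beta = 0.1:
        //   loss = -log sigmoid(beta * ((log p_policy(chosen) - log p_ref(chosen))
        //                             - (log p_policy(rejected) - log p_ref(rejected))))
        // averaged over the batch; the returned chosen/rejected "rewards" are
        // (assuming the standard DPO formulation) the beta-scaled log-ratios.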
        if let Some(Ok(batch)) = batcher.next() {
            let (loss, chosen_r, rejected_r) =
                compute_dpo_loss_batch(&batch, &policy_model, &reference_model, 0.1, false)?;
            println!("{:?}", loss);
            println!("{:?}", chosen_r);
            println!("{:?}", rejected_r);
        }

        Ok(())
    }
}

/// # [BONUS] Example usage of `compute_dpo_loss_loader`
///
/// #### Id
/// 07.22
///
/// #### Page
/// This example is adapted from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.22
///
/// # with cuda
/// cargo run --features cuda example 07.22
/// ```
pub struct EG22;

impl Example for EG22 {
    fn description(&self) -> String {
        "[BONUS] Example usage of `compute_dpo_loss_loader`.".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::{Config, GPTModel},
            ch07::bonus::{
                compute_dpo_loss_loader, PreferenceDataCollator, PreferenceDataLoader,
                PreferenceDataset, PreferenceExample,
            },
            ch07::{
                load_instruction_data_from_json, partition_data, AlpacaPromptFormatter, DATA_DIR,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{VarBuilder, VarMap};
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        // load reference and policy model
        let mut cfg = Config::gpt2_124m(); // must match model size used in EG10
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let policy_model = GPTModel::new(cfg, vb.pp("model"))?;

        // load instruction-finetuned weights
        varmap
            .load("ift.checkpoint.safetensors")
            .with_context(|| "Missing 'ift.checkpoint.safetensors' file. Please run EG 07.10.")?;

        let vb2 = vb.clone();
        let reference_model = GPTModel::new(cfg, vb2.pp("model"))?;

        // load preference examples
        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        // partition data and create train, val, test datasets
        let (train_data, _val_data, _test_data) =
            partition_data(preference_data, 0.85_f32, 0.05_f32)?;
        let train_dataset = PreferenceDataset::new(train_data, &tokenizer, &prompt_formatter);

        // create loaders
        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let batch_size = 2_usize;
        let train_loader =
            PreferenceDataLoader::new(train_dataset, batch_size, true, true, collator.clone());

        // compute dpo loss with loader
        let num_batches = Some(5_usize);
        let (train_loss, chosen_rewards, rejected_rewards) = compute_dpo_loss_loader(
            &train_loader,
            &policy_model,
            &reference_model,
            0.1,
            num_batches,
            false,
        )?;

        println!("Training loss: {}", train_loss);
        println!("Chosen rewards: {}", chosen_rewards);
        println!("Rejected rewards: {}", rejected_rewards);

        Ok(())
    }
}

/// # [BONUS] Example usage of `train_model_dpo_simple` and `plot_losses` functions
///
/// #### Id
/// 07.23
///
/// #### Page
/// This example is adapted from `04_preference-tuning-with-dpo/create-preference-data-ollama.ipynb`
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 07.23
///
/// # with cuda
/// cargo run --features cuda example 07.23
/// ```
pub struct EG23;

impl Example for EG23 {
    fn description(&self) -> String {
        "[BONUS] Example usage of `train_model_dpo_simple` and `plot_losses` functions".to_string()
    }

    fn page_source(&self) -> usize {
        0_usize
    }

    // TODO: This fails silently if it runs into OOM issues.
    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch07::bonus::{
                train_model_dpo_simple, PreferenceDataCollator, PreferenceDataLoader,
                PreferenceDataset, PreferenceExample,
            },
            ch07::{
                load_instruction_data_from_json, partition_data, AlpacaPromptFormatter,
                PromptFormatter, DATA_DIR,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
        use std::path::{Path, PathBuf};
        use tiktoken_rs::get_bpe_from_model;

        let tokenizer = get_bpe_from_model("gpt2")?;
        let prompt_formatter = AlpacaPromptFormatter;

        // load preference examples
        let file_path = Path::new(DATA_DIR).join("instruction_data_with_preference.json");
        let preference_data: Vec<PreferenceExample> = load_instruction_data_from_json(file_path)
            .with_context(|| {
                "Missing 'instruction_data_with_preference.json' file. Please run EG 07.18."
            })?;

        // partition data and create train, val, test datasets
        let (train_data, val_data, _test_data) =
            partition_data(preference_data, 0.85_f32, 0.05_f32)?;
        let train_dataset = PreferenceDataset::new(train_data, &tokenizer, &prompt_formatter);
        let val_dataset = PreferenceDataset::new(val_data, &tokenizer, &prompt_formatter);

        // create loaders
        let collator = PreferenceDataCollator::new().device(Device::cuda_if_available(0)?);
        let batch_size = 3_usize;
        let train_loader =
            PreferenceDataLoader::new(train_dataset, batch_size, true, true, collator.clone());
        let val_loader =
            PreferenceDataLoader::new(val_dataset, batch_size, false, false, collator.clone());

        println!("Train loader: {}", train_loader.len());
        println!("Val loader: {}", val_loader.len());

        // load policy model
        let mut cfg = Config::gpt2_124m(); // must match model size used in EG10
        cfg.qkv_bias = true;
        let mut varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        // checkpoint produced by EG 07.10; loading fails below if it is missing
        let checkpoint_path = PathBuf::from("ift.checkpoint.safetensors");
        let policy_model =
            addons::load_ift_gpt_model(&mut varmap, vb.pp("model"), cfg, &checkpoint_path)?;

        // load reference model
        let mut varmap_reference = VarMap::new();
        let vb_reference = VarBuilder::from_varmap(
            &varmap_reference,
            DType::F32,
            &Device::cuda_if_available(0)?,
        );
        let reference_model = addons::load_ift_gpt_model(
            &mut varmap_reference,
            vb_reference.pp("model"),
            cfg,
            &checkpoint_path,
        )?;

        // invoke training
        let (eval_freq, eval_iter, num_epochs) = (5_usize, 5_usize, 2_usize);
        let optimizer = AdamW::new(
            varmap.all_vars(),
            ParamsAdamW {
                lr: 0.00005,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;
        let start_context = prompt_formatter.format_input(&val_loader.dataset().data()[0]);
        let tracking = train_model_dpo_simple(
            &policy_model,
            &reference_model,
            &train_loader,
            &val_loader,
            0.1,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            start_context.as_str(),
            &tokenizer,
        )?;

        println!("{:#?}", tracking);

        // save model
        println!("Saving weights to `./dpo.checkpoint.safetensors`");
        varmap.save("dpo.checkpoint.safetensors")?;

        Ok(())
    }
}

pub mod addons {
    //! Auxiliary module for examples::ch07
    use crate::listings::ch04::{Config, GPTModel};
    use anyhow::Context;
    use candle_nn::{VarBuilder, VarMap};

    /// Helper function to load reference and policy models
    pub fn load_ift_gpt_model<P>(
        varmap: &mut VarMap,
        vb: VarBuilder<'_>,
        cfg: Config,
        checkpoint_path: &P,
    ) -> anyhow::Result<GPTModel>
    where
        P: AsRef<std::path::Path> + std::fmt::Debug,
    {
        let ift_model = GPTModel::new(cfg, vb)?;

        // load instruction-finetuned weights
        varmap
            .load(checkpoint_path)
            .with_context(|| format!("Missing '{:?}' file.", checkpoint_path))?;

        Ok(ift_model)
    }
}