llms_from_scratch_rs/exercises/ch06.rs

//! Exercises from Chapter 6
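//!
//! Each exercise implements the [`Exercise`] trait used throughout this crate
//! (`name`, `title`, `statement`, `main`). A minimal sketch of driving one
//! programmatically, assuming a `Result`-returning caller:
//!
//! ```ignore
//! use llms_from_scratch_rs::exercises::ch06::X1;
//! use llms_from_scratch_rs::Exercise;
//!
//! let exercise = X1;
//! println!("{}: {}", exercise.name(), exercise.title());
//! println!("{}", exercise.statement());
//! exercise.main()?; // builds the datasets, fine-tunes, and reports accuracy
//! ```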

use crate::Exercise;
use anyhow::Result;

/// # Increasing the context length
///
/// #### Id
/// 6.1
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 6.1
///
/// # with cuda
/// cargo run --features cuda exercise 6.1
/// ```
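///
/// #### Sketch
/// The gist of the change (a sketch; the full setup is in `main` below): pad
/// every sequence to GPT-2's full 1024-token context via the dataset builder.
/// ```ignore
/// let dataset = SpamDatasetBuilder::new(&tokenizer)
///     .load_data_from_parquet(path)
///     .max_length(Some(1024_usize)) // the model's maximum context length
///     .build();
/// ```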
pub struct X1;

impl Exercise for X1 {
    fn name(&self) -> String {
        String::from("6.1")
    }

    fn title(&self) -> String {
        String::from("Increasing the context length")
    }

    fn statement(&self) -> String {
        let stmt = "Pad the inputs to the maximum number of tokens the model \
        supports and observe how it affects the predictive performance.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch06::{
                calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
                train_classifier_simple, SpamDataLoader, SpamDatasetBuilder, HF_GPT2_MODEL_ID,
            },
        };
        use anyhow::anyhow;
        use candle_core::{DType, Device, Var};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
        use std::ops::Not;
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        println!("Creating train, val, test datasets");
        // create datasets
        let tokenizer = get_bpe_from_model("gpt2")?;
        // the exercise asks for the maximum the model supports: GPT-2 uses a
        // 1024-token context, so pad to 1024 rather than a shorter length
        let max_length = Some(1024_usize);

        let train_path = Path::new("data").join("train.parquet");
        if train_path.exists().not() {
            return Err(anyhow!(
                "Missing 'data/train.parquet' file. Please run EG 06.04."
            ));
        }
        let train_dataset = SpamDatasetBuilder::new(&tokenizer)
            .load_data_from_parquet(train_path)
            .max_length(max_length)
            .build();
        println!(
            "...train dataset max length: {}",
            train_dataset.max_length()
        );

        let val_path = Path::new("data").join("validation.parquet");
        if val_path.exists().not() {
            return Err(anyhow!(
                "Missing 'data/validation.parquet' file. Please run EG 06.04."
            ));
        }
        let val_dataset = SpamDatasetBuilder::new(&tokenizer)
            .load_data_from_parquet(val_path)
            .max_length(max_length)
            .build();
        println!("...val dataset max length: {}", val_dataset.max_length());

        let test_path = Path::new("data").join("test.parquet");
        if test_path.exists().not() {
            return Err(anyhow!(
                "Missing 'data/test.parquet' file. Please run EG 06.04."
            ));
        }
        let test_dataset = SpamDatasetBuilder::new(&tokenizer)
            .load_data_from_parquet(test_path)
            .max_length(max_length)
            .build();
        println!("...test dataset max length: {}", test_dataset.max_length());

        // create loaders
        let batch_size = 2_usize;
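        // the two boolean flags are assumed to be (shuffle, drop_last),
        // mirroring the chapter's loaders: only training shuffles and drops
        // the final partial batch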
        let train_loader = SpamDataLoader::new(train_dataset, batch_size, true, true);
        let val_loader = SpamDataLoader::new(val_dataset, batch_size, false, false);
        let test_loader = SpamDataLoader::new(test_dataset, batch_size, false, false);

        // print total number of batches in each data loader
        println!("...{:?} training batches", train_loader.len());
        println!("...{:?} validation batches", val_loader.len());
        println!("...{:?} test batches", test_loader.len());

        // get model
        println!("Loading pre-trained GPT-2 and modifying prediction head");
        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;

        // train model
        // trainable: last trf block, final layer norm, classification head
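        // ("trf.11" is the final block: GPT-2 124M stacks 12 transformer blocks, indexed 0-11)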
        let mut training_vars: Vec<Var> = vec![];
        let tensor_data = varmap.data().lock().unwrap();
        let var_names: Vec<&String> = tensor_data
            .keys()
            .filter(|k| k.contains("final_norm") || k.contains("out_head") || k.contains("trf.11"))
            .collect();
        for var_name in var_names.into_iter() {
            let var = tensor_data.get(var_name).unwrap();
            training_vars.push(var.clone());
        }
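        // release the mutex guard on the VarMap's data before training begins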
        drop(tensor_data);

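        // only the vars handed to the optimizer receive updates; all other weights stay frozen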
        let optimizer = AdamW::new(
            training_vars,
            ParamsAdamW {
                lr: 5e-5,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;

        println!("Fine-tuning GPT-2 on spam training dataset");
        let (eval_freq, eval_iter, num_epochs) = (50_usize, 1_usize, 2_usize);
        let _ = train_classifier_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            None,
        );

        println!("Computing performance metrics");
        // compute accuracies
        let num_batches = None;
        let train_accuracy =
            calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
        let val_accuracy =
            calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
        let test_accuracy =
            calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;

        println!("Training accuracy: {}", train_accuracy);
        println!("Validation accuracy: {}", val_accuracy);
        println!("Test accuracy: {}", test_accuracy);

        Ok(())
    }
}

/// # Fine-tuning the whole model
///
/// #### Id
/// 6.2
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 6.2
///
/// # with cuda
/// cargo run --features cuda exercise 6.2
/// ```
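///
/// #### Sketch
/// The one-line difference from exercise 6.1 (a sketch; the full setup is in
/// `main` below): hand the optimizer every variable in the `VarMap` instead of
/// a filtered subset, so all weights are trainable.
/// ```ignore
/// let optimizer = AdamW::new(varmap.all_vars(), params)?;
/// ```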
pub struct X2;

impl Exercise for X2 {
    fn name(&self) -> String {
        "6.2".to_string()
    }

    fn title(&self) -> String {
        "Fine-tuning the whole model".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Instead of fine-tuning just the final transformer \
        block, fine-tune the entire model and assess the effect on predictive \
        performance.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch06::{
                calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
                train_classifier_simple, HF_GPT2_MODEL_ID,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};

        // get gpt model with classification head
        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;

        // get data loaders
        let eg06 = crate::examples::ch06::EG06; // re-use
        let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;

        // trainable params and optimizer
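        // unlike exercises 6.1 and 6.3, no vars are filtered out here:
        // every pre-trained weight is updated during fine-tuning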
        let optimizer = AdamW::new(
            varmap.all_vars(), // train on all vars
            ParamsAdamW {
                lr: 5e-5,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;

        println!("Fine-tuning ENTIRE GPT-2 on spam training dataset");
        let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize);
        let _ = train_classifier_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            None,
        );

        println!("Computing performance metrics");
        // compute accuracies
        let num_batches = None;
        let train_accuracy =
            calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
        let val_accuracy =
            calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
        let test_accuracy =
            calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;

        println!("Training accuracy: {}", train_accuracy);
        println!("Validation accuracy: {}", val_accuracy);
        println!("Test accuracy: {}", test_accuracy);

        Ok(())
    }
}

/// # Fine-tuning the first vs. last token
///
/// #### Id
/// 6.3
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run exercise 6.3
///
/// # with cuda
/// cargo run --features cuda exercise 6.3
/// ```
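///
/// #### Sketch
/// The key difference from exercise 6.1 (a sketch; the full setup is in
/// `main` below): classify from the first token's logits rather than the
/// last token's.
/// ```ignore
/// let custom_pred_token_index = Some(0_usize); // first token instead of last
/// ```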
pub struct X3;

impl Exercise for X3 {
    fn name(&self) -> String {
        "6.3".to_string()
    }

    fn title(&self) -> String {
        "Fine-tuning the first vs. last token".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Rather than fine-tuning on the last output token, \
        fine-tune on the first output token and observe the changes in \
        predictive performance.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch06::{
                calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
                train_classifier_simple, HF_GPT2_MODEL_ID,
            },
        };
        use candle_core::{DType, Device, Var};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};

        // get gpt model with classification head
        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;

        // get data loaders
        let eg06 = crate::examples::ch06::EG06; // re-use
        let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;

        // trainable params and optimizer
        // trainable: last trf block, final layer norm, classification head
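        // (same trainable/frozen split as exercise 6.1)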
        let mut training_vars: Vec<Var> = vec![];
        let tensor_data = varmap.data().lock().unwrap();
        let var_names: Vec<&String> = tensor_data
            .keys()
            .filter(|k| k.contains("final_norm") || k.contains("out_head") || k.contains("trf.11"))
            .collect();
        for var_name in var_names.into_iter() {
            let var = tensor_data.get(var_name).unwrap();
            training_vars.push(var.clone());
        }
        drop(tensor_data);

        let optimizer = AdamW::new(
            training_vars,
            ParamsAdamW {
                lr: 5e-5,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;

        println!("Fine-tuning GPT-2 on spam training dataset using the first token");
        let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize);
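        // under causal attention the first token attends only to itself, so its
        // hidden state carries no information about the rest of the sequence;
        // predictive performance is therefore expected to drop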
        let custom_pred_token_index = Some(0_usize); // use the first token!
        let _ = train_classifier_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            custom_pred_token_index,
        );

        println!("Computing performance metrics");
        // compute accuracies
        let num_batches = None;
        let train_accuracy = calc_accuracy_loader(
            &train_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;
        let val_accuracy = calc_accuracy_loader(
            &val_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;
        let test_accuracy = calc_accuracy_loader(
            &test_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;

        println!("Training accuracy: {}", train_accuracy);
        println!("Validation accuracy: {}", val_accuracy);
        println!("Test accuracy: {}", test_accuracy);

        Ok(())
    }
}