use crate::Exercise;
use anyhow::Result;

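/// Exercise 6.1: increasing the context length.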
pub struct X1;

impl Exercise for X1 {
    fn name(&self) -> String {
        String::from("6.1")
    }

    fn title(&self) -> String {
        String::from("Increasing the context length")
    }

    fn statement(&self) -> String {
        let stmt = "Pad the inputs to the maximum number of tokens the model \
        supports and observe how it affects the predictive performance.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch06::{
                calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
                train_classifier_simple, SpamDataLoader, SpamDatasetBuilder, HF_GPT2_MODEL_ID,
            },
        };
        use anyhow::anyhow;
        use candle_core::{DType, Device, Var};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};
        use std::ops::Not;
        use std::path::Path;
        use tiktoken_rs::get_bpe_from_model;

        println!("Creating train, val, test datasets");
        let tokenizer = get_bpe_from_model("gpt2")?;
        // Pad to GPT-2's maximum supported context length (1024 tokens), per the
        // exercise statement, rather than the training set's longest sequence.
        let max_length = Some(1024_usize);

        let train_path = Path::new("data").join("train.parquet");
        if train_path.exists().not() {
            return Err(anyhow!(
                "Missing 'data/train.parquet' file. Please run EG 06.04."
            ));
        }
        let train_dataset = SpamDatasetBuilder::new(&tokenizer)
            .load_data_from_parquet(train_path)
            .max_length(max_length)
            .build();
        println!(
            "...train dataset max length: {}",
            train_dataset.max_length()
        );

        let val_path = Path::new("data").join("validation.parquet");
        if val_path.exists().not() {
            return Err(anyhow!(
                "Missing 'data/validation.parquet' file. Please run EG 06.04."
            ));
        }
        let val_dataset = SpamDatasetBuilder::new(&tokenizer)
            .load_data_from_parquet(val_path)
            .max_length(max_length)
            .build();
        println!("...val dataset max length: {}", val_dataset.max_length());

        let test_path = Path::new("data").join("test.parquet");
        if test_path.exists().not() {
            return Err(anyhow!(
                "Missing 'data/test.parquet' file. Please run EG 06.04."
            ));
        }
        let test_dataset = SpamDatasetBuilder::new(&tokenizer)
            .load_data_from_parquet(test_path)
            .max_length(max_length)
            .build();
        println!("...test dataset max length: {}", test_dataset.max_length());

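        // With every sequence padded to 1024 tokens, a small batch size keeps
        // memory usage manageable.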
        let batch_size = 2_usize;
        let train_loader = SpamDataLoader::new(train_dataset, batch_size, true, true);
        let val_loader = SpamDataLoader::new(val_dataset, batch_size, false, false);
        let test_loader = SpamDataLoader::new(test_dataset, batch_size, false, false);

        println!("...{:?} training batches", train_loader.len());
        println!("...{:?} validation batches", val_loader.len());
        println!("...{:?} test batches", test_loader.len());

        println!("Loading pre-trained GPT-2 and modifying prediction head");
        let mut cfg = Config::gpt2_124m();
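        // The pretrained GPT-2 checkpoint uses biases on the QKV projections,
        // so enable them before loading the weights.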
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;

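        // Collect only the parameters to fine-tune: the final layer norm, the new
        // classification head, and the last transformer block (`trf.11`). All other
        // weights stay frozen since they are never handed to the optimizer.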
        let mut training_vars: Vec<Var> = vec![];
        let tensor_data = varmap.data().lock().unwrap();
        let var_names: Vec<&String> = tensor_data
            .keys()
            .filter(|k| k.contains("final_norm") || k.contains("out_head") || k.contains("trf.11"))
            .collect();
        for var_name in var_names.into_iter() {
            let var = tensor_data.get(var_name).unwrap();
            training_vars.push(var.clone());
        }
        drop(tensor_data);

        let optimizer = AdamW::new(
            training_vars,
            ParamsAdamW {
                lr: 5e-5,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;

        println!("Fine-tuning GPT2 on spam training dataset");
        let (eval_freq, eval_iter, num_epochs) = (50_usize, 1_usize, 2_usize);
        let _ = train_classifier_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            None,
        );

        println!("Computing performance metrics");
        let num_batches = None;
        let train_accuracy =
            calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
        let val_accuracy =
            calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
        let test_accuracy =
            calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;

        println!("Training accuracy: {}", train_accuracy);
        println!("Validation accuracy: {}", val_accuracy);
        println!("Test accuracy: {}", test_accuracy);

        Ok(())
    }
}

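/// Exercise 6.2: fine-tuning the whole model.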
pub struct X2;

impl Exercise for X2 {
    fn name(&self) -> String {
        "6.2".to_string()
    }

    fn title(&self) -> String {
        "Fine-tuning the whole model".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Instead of fine-tuning just the final transformer \
        block, fine-tune the entire model and assess the effect on predictive \
        performance.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch06::{
                calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
                train_classifier_simple, HF_GPT2_MODEL_ID,
            },
        };
        use candle_core::{DType, Device};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};

        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;

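        // Reuse the train/val/test data loaders built by example EG06 (the `false`
        // argument presumably suppresses that example's console output).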
        let eg06 = crate::examples::ch06::EG06;
        let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;

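        // Unlike Exercise 6.1, pass every variable in the VarMap to the optimizer
        // so the entire model is fine-tuned.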
        let optimizer = AdamW::new(
            varmap.all_vars(),
            ParamsAdamW {
                lr: 5e-5,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;

        println!("Fine-tuning ENTIRE GPT2 on spam training dataset");
        let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize);
        let _ = train_classifier_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            None,
        );

        println!("Computing performance metrics");
        let num_batches = None;
        let train_accuracy =
            calc_accuracy_loader(&train_loader, &model, vb.device(), num_batches, None)?;
        let val_accuracy =
            calc_accuracy_loader(&val_loader, &model, vb.device(), num_batches, None)?;
        let test_accuracy =
            calc_accuracy_loader(&test_loader, &model, vb.device(), num_batches, None)?;

        println!("Training accuracy: {}", train_accuracy);
        println!("Validation accuracy: {}", val_accuracy);
        println!("Test accuracy: {}", test_accuracy);

        Ok(())
    }
}

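/// Exercise 6.3: fine-tuning the first vs. last token.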
pub struct X3;

impl Exercise for X3 {
    fn name(&self) -> String {
        "6.3".to_string()
    }

    fn title(&self) -> String {
        "Fine-tuning the first vs. last token".to_string()
    }

    fn statement(&self) -> String {
        let stmt = "Try fine-tuning the first output token. Notice the \
        changes in predictive performance compared to fine-tuning the last \
        output token.";
        stmt.to_string()
    }

    fn main(&self) -> Result<()> {
        use crate::listings::{
            ch04::Config,
            ch06::{
                calc_accuracy_loader, download_and_load_gpt2, modify_out_head_for_classification,
                train_classifier_simple, HF_GPT2_MODEL_ID,
            },
        };
        use candle_core::{DType, Device, Var};
        use candle_nn::{AdamW, Optimizer, ParamsAdamW, VarBuilder, VarMap};

        let mut cfg = Config::gpt2_124m();
        cfg.qkv_bias = true;
        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::cuda_if_available(0)?);
        let mut model = download_and_load_gpt2(&varmap, vb.pp("model"), cfg, HF_GPT2_MODEL_ID)?;
        modify_out_head_for_classification(&mut model, cfg, 2_usize, &varmap, vb.pp("model"))?;

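        // Reuse the data loaders from EG06, as in Exercise 6.2.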
        let eg06 = crate::examples::ch06::EG06;
        let (train_loader, val_loader, test_loader) = eg06.main_with_return(false)?;

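        // As in Exercise 6.1, fine-tune only the final layer norm, the classification
        // head, and the last transformer block (`trf.11`); everything else stays frozen.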
        let mut training_vars: Vec<Var> = vec![];
        let tensor_data = varmap.data().lock().unwrap();
        let var_names: Vec<&String> = tensor_data
            .keys()
            .filter(|k| k.contains("final_norm") || k.contains("out_head") || k.contains("trf.11"))
            .collect();
        for var_name in var_names.into_iter() {
            let var = tensor_data.get(var_name).unwrap();
            training_vars.push(var.clone());
        }
        drop(tensor_data);

        let optimizer = AdamW::new(
            training_vars,
            ParamsAdamW {
                lr: 5e-5,
                weight_decay: 0.1,
                ..Default::default()
            },
        )?;

        println!("Fine-tuning GPT2 on spam training dataset using first-token");
        let (eval_freq, eval_iter, num_epochs) = (50_usize, 5_usize, 5_usize);
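        // Classify from the logits of the first token (index 0) instead of the
        // default last token, for both training and the evaluation below.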
        let custom_pred_token_index = Some(0_usize);
        let _ = train_classifier_simple(
            &model,
            &train_loader,
            &val_loader,
            optimizer,
            vb.device(),
            num_epochs,
            eval_freq,
            eval_iter,
            custom_pred_token_index,
        );

        println!("Computing performance metrics");
        let num_batches = None;
        let train_accuracy = calc_accuracy_loader(
            &train_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;
        let val_accuracy = calc_accuracy_loader(
            &val_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;
        let test_accuracy = calc_accuracy_loader(
            &test_loader,
            &model,
            vb.device(),
            num_batches,
            custom_pred_token_index,
        )?;

        println!("Training accuracy: {}", train_accuracy);
        println!("Validation accuracy: {}", val_accuracy);
        println!("Test accuracy: {}", test_accuracy);

        Ok(())
    }
}