llms_from_scratch_rs/examples/ch03.rs

//! Examples from Chapter 3
2
3use crate::Example;
4use anyhow::Result;
5
6/// # Computing attention scores as a dot product
7///
8/// #### Id
9/// 03.01
10///
11/// #### Page
12/// This example starts on page 57
13///
14/// #### CLI command
15/// ```sh
16/// # without cuda
17/// cargo run example 03.01
18///
19/// # with cuda
20/// cargo run --features cuda example 03.01
21/// ```
22pub struct EG01;
23
24impl Example for EG01 {
25    fn description(&self) -> String {
26        String::from("Computing attention scores as a dot product.")
27    }
28
29    fn page_source(&self) -> usize {
30        57_usize
31    }
32
33    fn main(&self) -> Result<()> {
34        use candle_core::{IndexOp, Tensor};
35        use candle_nn::ops::softmax;
36
37        let inputs = addons::get_inputs();
38        let dev = inputs.device().to_owned();
39
40        let query = inputs.index_select(&Tensor::new(&[1u32], &dev)?, 0)?;
41
42        // compute attention scores
43        let mut optional_attn_scores_2: Option<Tensor> = None;
44        for i in 0..inputs.dims()[0] {
45            let x_i = inputs.index_select(&Tensor::new(&[i as u32], &dev)?, 0)?;
46            let a_i = x_i.matmul(&query.t()?)?.flatten_all()?;
47            optional_attn_scores_2 = match optional_attn_scores_2 {
48                Some(attn_scores_2) => Some(Tensor::cat(&[&attn_scores_2, &a_i], 0)?),
49                None => Some(a_i),
50            }
51        }
52
53        if let Some(attn_scores_2) = optional_attn_scores_2 {
54            // raw attention scores
55            println!("Raw attention scores: {:?}", attn_scores_2);
56
57            // basic normalization
58            let sum = attn_scores_2.sum_all()?;
59            let normalized_attn_scores = (attn_scores_2.broadcast_div(&sum))?.to_vec1::<f32>();
60            println!("Normalized attention scores: {:?}", normalized_attn_scores);
61
62            // naive softmax normalization
63            let exponentiator = attn_scores_2.exp()?;
64            let exponentiator_sum = exponentiator.sum_all()?;
65            let naive_softmax_attn_scores = exponentiator.broadcast_div(&exponentiator_sum)?;
66            println!(
67                "Naive Softmax-normalized attention scores: {:?}",
68                naive_softmax_attn_scores
69            );
70
71            // candle softmax
72            let softmax_attn_scores = softmax(&attn_scores_2, 0)?;
73            println!(
74                "Softmax-normalized attention scores: {:?}",
75                softmax_attn_scores
76            );
77
78            // compute second context vector
79            let mut context_vec_2 = Tensor::zeros_like(&query)?;
80            for i in 0..inputs.dims()[0] {
81                let x_i = inputs.index_select(&Tensor::new(&[i as u32], &dev)?, 0)?;
82                context_vec_2 =
83                    context_vec_2.add(&x_i.broadcast_mul(&softmax_attn_scores.i(i)?)?)?;
84            }
85            println!("Context vector 2: {:?}", context_vec_2.to_vec2::<f32>());
86        }
87        Ok(())
88    }
89}
90
91/// # Manual computation of multiple context vectors simultaneously
92///
93/// #### Id
94/// 03.02
95///
96/// #### Page
97/// This example starts on page 62
98///
99/// #### CLI command
100/// ```sh
101/// # without cuda
102/// cargo run example 03.02
103///
104/// # with cuda
105/// cargo run --features cuda example 03.02
106/// ```
107pub struct EG02;
108
109impl Example for EG02 {
110    fn description(&self) -> String {
111        String::from("Manual computation of multiple context vectors simultaneously.")
112    }
113
114    fn page_source(&self) -> usize {
115        62_usize
116    }
117
118    fn main(&self) -> Result<()> {
119        use candle_nn::ops::softmax;
120
121        let inputs = addons::get_inputs();
122
123        // matmul to get attn scores
124        let attn_scores = inputs.matmul(&inputs.t()?)?;
125
126        // apply softmax
127        let attn_weights = softmax(&attn_scores, 1)?;
128
129        // check sums along rows equal to 1
130        let sum = attn_weights.sum(1)?;
131
132        // context vectors
133        let all_context_vectors = attn_weights.matmul(&inputs)?;
134
135        println!("Attention Weights: {:?}\n", attn_weights.to_vec2::<f32>());
136        println!("All Rows Sum: {:?}\n\n", sum.flatten_all());
137        println!(
138            "Context Vectors: {:?}",
139            all_context_vectors.to_vec2::<f32>()
140        );
141        Ok(())
142    }
143}
144
145/// # Implementing the self-attention mechanism with trainable weights
146///
147/// #### Id
148/// 03.03
149///
150/// #### Page
151/// This example starts on page 66
152///
153/// #### CLI command
154/// ```sh
155/// # without cuda
156/// cargo run example 03.03
157///
158/// # with cuda
159/// cargo run --features cuda example 03.03
160/// ```
161pub struct EG03;
162
163impl Example for EG03 {
164    fn description(&self) -> String {
165        let desc = "Implementing the self-attention mechanism with \
166        trainable weights to compute single context vector.";
167        String::from(desc)
168    }
169
170    fn page_source(&self) -> usize {
171        66_usize
172    }
173
174    fn main(&self) -> Result<()> {
175        use candle_core::{DType, Tensor};
176        use candle_nn::init::DEFAULT_KAIMING_NORMAL;
177        use candle_nn::ops::softmax;
178        use candle_nn::{VarBuilder, VarMap};
179
180        let inputs = addons::get_inputs();
181        let dev = inputs.device().to_owned();
182        let varmap = VarMap::new();
183        let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
184
185        let x_2 = inputs.index_select(&Tensor::new(&[1u32], &dev)?, 0)?;
186        let d_in = x_2.dims()[1]; // input embedding dim
187        let d_out = 2_usize;
188
189        // projections
190        let init = DEFAULT_KAIMING_NORMAL;
191        let w_query = vs.get_with_hints((d_in, d_out), "query", init)?;
192        let w_key = vs.get_with_hints((d_in, d_out), "key", init)?;
193        let w_value = vs.get_with_hints((d_in, d_out), "value", init)?;
194
195        // query, key, value vectors
196        let query_2 = x_2.matmul(&w_query)?;
197        let key_2 = x_2.matmul(&w_key)?;
198        let value_2 = x_2.matmul(&w_value)?;
199
200        println!("Query 2: {:?}", query_2.to_vec2::<f32>());
201        println!("Key 2: {:?}", key_2.to_vec2::<f32>());
202        println!("Value 2: {:?}", value_2.to_vec2::<f32>());
203
204        // key and value vectors all input elements
205        let keys = inputs.matmul(&w_key)?;
206        let values = inputs.matmul(&w_value)?;
207
208        println!("Keys shape: {:?}", keys);
209        println!("Values shape: {:?}", values);
210
211        // compute attn scores
212        let attn_scores = query_2.matmul(&keys.t()?)?;
213        println!("Attn scores: {:?}", attn_scores.to_vec2::<f32>());
214
215        // compute attns weights by first scaling then softmax
216        let d_k = Tensor::new(&[f32::powf(keys.dims()[1] as f32, 0.5_f32)], &dev)?;
217        let attn_weights = softmax(&attn_scores.broadcast_div(&d_k)?, 1)?;
218        println!("Attn weights: {:?}", attn_weights.to_vec2::<f32>());
219
220        // compute context vector
221        let context_vec_2 = attn_weights.matmul(&values)?;
222        println!("Context vector 2: {:?}", context_vec_2.to_vec2::<f32>());
223        Ok(())
224    }
225}
226
227/// # Example usage of `SelfAttentionV1` to compute context vectors
228///
229/// #### Id
230/// 03.04
231///
232/// #### Page
233/// This example starts on page 71
234///
235/// #### CLI command
236/// ```sh
237/// # without cuda
238/// cargo run example 03.04
239///
240/// # with cuda
241/// cargo run --features cuda example 03.04
242/// ```
243pub struct EG04;
244
245impl Example for EG04 {
246    fn description(&self) -> String {
247        String::from(
248            "Implement self-attention mechanism to compute context vectors in the input sequence.",
249        )
250    }
251
252    fn page_source(&self) -> usize {
253        71_usize
254    }
255
256    fn main(&self) -> Result<()> {
257        use crate::listings::ch03::SelfAttentionV1;
258        use candle_core::{DType, Module};
259        use candle_nn::{VarBuilder, VarMap};
260
261        let inputs = addons::get_inputs();
262        let d_in = inputs.dims()[1]; // input embedding dim
263        let d_out = 2_usize;
264
265        // construct self attention layer
266        let varmap = VarMap::new();
267        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
268        let attn_v1_layer = SelfAttentionV1::new(d_in, d_out, vb.pp("attn"))?;
269
270        // run a random, embedded input sequence through self-attention
271        let context_vectors = attn_v1_layer.forward(&inputs)?;
272
273        println!("context vectors: {:?}", context_vectors.to_vec2::<f32>());
274        Ok(())
275    }
276}
277
278/// # Example usage of `SelfAttentionV2` to compute context vectors
279///
280/// #### Id
281/// 03.05
282///
283/// #### Page
284/// This example starts on page 73
285///
286/// #### CLI command
287/// ```sh
288/// # without cuda
289/// cargo run example 03.05
290///
291/// # with cuda
292/// cargo run --features cuda example 03.05
293/// ```
294pub struct EG05;
295
296impl Example for EG05 {
297    fn description(&self) -> String {
298        let desc = "Implement self-attention mechanism to compute \
299        contextualized vectors, using candle_nn::Linear.";
300        String::from(desc)
301    }
302
303    fn page_source(&self) -> usize {
304        73_usize
305    }
306
307    fn main(&self) -> Result<()> {
308        use crate::listings::ch03::SelfAttentionV2;
309        use candle_core::{DType, Module};
310        use candle_nn::{VarBuilder, VarMap};
311
312        let inputs = addons::get_inputs();
313        let d_in = inputs.dims()[1]; // input embedding dim
314        let d_out = 2_usize;
315
316        // construct self attention layer
317        let varmap = VarMap::new();
318        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
319        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?;
320
321        // run a random, embedded input sequence through self-attention
322        let context_vectors = attn_v2_layer.forward(&inputs)?;
323
324        println!("context vectors: {:?}", context_vectors.to_vec2::<f32>());
325        Ok(())
326    }
327}
328
329/// # Compute causal attention weights
330///
331/// #### Id
332/// 03.06
333///
334/// #### Page
335/// This example starts on page 75
336///
337/// #### CLI command
338/// ```sh
339/// # without cuda
340/// cargo run example 03.06
341///
342/// # with cuda
343/// cargo run --features cuda example 03.06
344/// ```
345pub struct EG06;
346
347impl Example for EG06 {
348    fn description(&self) -> String {
349        String::from("Compute causal attention weights.")
350    }
351
352    fn page_source(&self) -> usize {
353        75_usize
354    }
355
356    fn main(&self) -> Result<()> {
357        let _ = self.main_with_return()?;
358        Ok(())
359    }
360}
361
362impl EG06 {
363    fn main_with_return(&self) -> Result<candle_core::Tensor> {
364        use crate::listings::ch03::SelfAttentionV2;
365        use candle_core::{DType, Module, D};
366        use candle_nn::ops::softmax;
367        use candle_nn::{VarBuilder, VarMap};
368
369        let inputs = addons::get_inputs();
370        let d_in = inputs.dims()[1]; // input embedding dim
371        let d_out = 2_usize;
372
373        // construct self attention layer
374        let varmap = VarMap::new();
375        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
376        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?;
377
378        // attn scores
379        let queries = attn_v2_layer.w_query().forward(&inputs)?;
380        let keys = attn_v2_layer.w_key().forward(&inputs)?;
381        let attn_scores = queries.matmul(&keys.t()?)?;
382        let scaling = 1. / (keys.dims()[1] as f64).sqrt();
383        let attn_weights = softmax(&(attn_scores * scaling)?, 1)?;
384
385        // causal mask
386        let context_length = inputs.dims()[0];
387        let mask_simple: Vec<_> = (0..context_length as u32)
388            .flat_map(|i| (0..context_length as u32).map(move |j| f32::from(j <= i)))
389            .collect();
390        let mask_simple = candle_core::Tensor::from_slice(
391            &mask_simple,
392            (context_length, context_length),
393            inputs.device(),
394        )?;
395        let masked_simple = (attn_weights * mask_simple)?;
396        println!("masked_simple: {:?}", masked_simple.to_vec2::<f32>());
397
398        // normalize
399        let row_sums = masked_simple.sum_keepdim(D::Minus1)?;
400        let attn_weights = masked_simple.broadcast_div(&row_sums)?;
401        println!("masked_simple_norm: {:?}", attn_weights.to_vec2::<f32>());
402        Ok(attn_weights)
403    }
404}
405
406/// # Compute causal attention weights more efficiently with `f32::NEGATIVE_INFINITY`
407///
408/// #### Id
409/// 03.07
410///
411/// #### Page
412/// This example starts on page 77
413///
414/// #### CLI command
415/// ```sh
416/// # without cuda
417/// cargo run example 03.07
418///
419/// # with cuda
420/// cargo run --features cuda example 03.07
421/// ```
422pub struct EG07;
423
424impl Example for EG07 {
425    fn description(&self) -> String {
426        let desc = "Compute causal attention weights more efficiently \
427        using `f32::NEGATIVE_INFINITY` and `masked_fill()`.";
428        String::from(desc)
429    }
430
431    fn page_source(&self) -> usize {
432        77_usize
433    }
434
435    fn main(&self) -> Result<()> {
436        let _ = self.main_with_return()?;
437        Ok(())
438    }
439}
440
441impl EG07 {
442    fn main_with_return(&self) -> Result<candle_core::Tensor> {
443        use crate::listings::ch03::SelfAttentionV2;
444        use candle_core::{DType, Module};
445        use candle_nn::ops::softmax;
446        use candle_nn::{VarBuilder, VarMap};
447
448        let inputs = addons::get_inputs();
449        let d_in = inputs.dims()[1]; // input embedding dim
450        let d_out = 2_usize;
451
452        // construct self attention layer
453        let varmap = VarMap::new();
454        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
455        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?;
456
457        // attn scores
458        let queries = attn_v2_layer.w_query().forward(&inputs)?;
459        let keys = attn_v2_layer.w_key().forward(&inputs)?;
460        let attn_scores = queries.matmul(&keys.t()?)?;
461
462        // efficient computation of causal mask
463        let context_length = attn_scores.dims()[0];
464        let mask: Vec<_> = (0..context_length as u32)
465            .flat_map(|i| (0..context_length as u32).map(move |j| u32::from(j > i)))
466            .collect();
467        let mask = candle_core::Tensor::from_slice(
468            &mask,
469            (context_length, context_length),
470            inputs.device(),
471        )?;
472        let masked = addons::masked_fill(&attn_scores, &mask, f32::NEG_INFINITY)?;
473        println!("masked: {:?}", masked.to_vec2::<f32>());
474
475        // masked attn weights
476        let scaling = 1. / (keys.dims()[1] as f64).sqrt();
477        let attn_weights = softmax(&(masked * scaling)?, 1)?;
478        println!("attn_weights: {:?}", attn_weights.to_vec2::<f32>());
479        Ok(attn_weights)
480    }
481}
482
483/// # Dropout on attention weights
484///
485/// #### Id
486/// 03.08
487///
488/// #### Page
489/// This example starts on page 80
490///
491/// #### CLI command
492/// ```sh
493/// # without cuda
494/// cargo run example 03.08
495///
496/// # with cuda
497/// cargo run --features cuda example 03.08
498/// ```
499pub struct EG08;
500
501impl Example for EG08 {
502    fn description(&self) -> String {
503        String::from("Dropout on attention weights.")
504    }
505
506    fn page_source(&self) -> usize {
507        80_usize
508    }
509
510    fn main(&self) -> Result<()> {
511        use candle_nn::Dropout;
512
513        let eg07 = EG07;
514        let attn_weights = eg07.main_with_return()?;
515        let dropout = Dropout::new(0.5);
516
517        // could have also just used the candle_nn::ops::dropout directly
518        let dropped_out = dropout.forward(&attn_weights, true)?;
519        println!("dropped_out: {:?}", dropped_out.to_vec2::<f32>());
520        Ok(())
521    }
522}
523
524/// # Example usage of `CausalAttention`
525///
526/// #### Id
527/// 03.09
528///
529/// #### Page
530/// This example starts on page 81
531///
532/// #### CLI command
533/// ```sh
534/// # without cuda
535/// cargo run example 03.09
536///
537/// # with cuda
538/// cargo run --features cuda example 03.09
539/// ```
540pub struct EG09;
541
542impl Example for EG09 {
543    fn description(&self) -> String {
544        String::from("Example usage of `CausalAttention`.")
545    }
546
547    fn page_source(&self) -> usize {
548        81_usize
549    }
550
551    fn main(&self) -> Result<()> {
552        use crate::listings::ch03::CausalAttention;
553        use candle_core::{DType, Module, Tensor};
554        use candle_nn::{VarBuilder, VarMap};
555
556        // create batch
557        let inputs = addons::get_inputs();
558        let d_in = inputs.dims()[1]; // input embedding dim
559        let d_out = 2_usize;
560        let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?;
561        println!("batch shape: {:?}", batch);
562
563        // build causal attn layer
564        let varmap = VarMap::new();
565        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
566        let causal_attn = CausalAttention::new(d_in, d_out, 0.0_f32, false, vb.pp("casual_attn"))?;
567
568        // context vectors
569        let context_vectors = causal_attn.forward(&batch)?;
570        println!("context_vectors.shape: {:?}", context_vectors);
571        Ok(())
572    }
573}
574
575/// # Example usage of `MultiHeadAttentionWrapper`
576///
577/// #### Id
578/// 03.10
579///
580/// #### Page
581/// This example starts on page 85
582///
583/// #### CLI command
584/// ```sh
585/// # without cuda
586/// cargo run example 03.10
587///
588/// # with cuda
589/// cargo run --features cuda example 03.10
590/// ```
591pub struct EG10;
592
593impl Example for EG10 {
594    fn description(&self) -> String {
595        String::from("Example usage of `MultiHeadAttentionWrapper`.")
596    }
597
598    fn page_source(&self) -> usize {
599        85_usize
600    }
601
602    fn main(&self) -> Result<()> {
603        use crate::listings::ch03::MultiHeadAttentionWrapper;
604        use candle_core::{DType, Module, Tensor};
605        use candle_nn::{VarBuilder, VarMap};
606
607        // create batch
608        let inputs = addons::get_inputs();
609        let d_in = inputs.dims()[1]; // input embedding dim
610        let d_out = 2_usize;
611        let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?;
612        println!("batch shape: {:?}", batch);
613
614        // build causal attn layer
615        let varmap = VarMap::new();
616        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
617        let num_heads = 2_usize;
618        let mha =
619            MultiHeadAttentionWrapper::new(num_heads, d_in, d_out, 0.0_f32, false, vb.pp("mha"))?;
620
621        // context vectors
622        let context_vectors = mha.forward(&batch)?;
623        println!("context_vectors.shape: {:?}", context_vectors);
624        println!("context_vectors: {:?}", context_vectors.to_vec3::<f32>());
625        Ok(())
626    }
627}
628
629/// # Example usage of `MultiHeadAttention`
630///
631/// #### Id
632/// 03.11
633///
634/// #### Page
635/// This example starts on page 90
636///
637/// #### CLI command
638/// ```sh
639/// # without cuda
640/// cargo run example 03.11
641///
642/// # with cuda
643/// cargo run --features cuda example 03.11
644/// ```
645pub struct EG11;
646
647impl Example for EG11 {
648    fn description(&self) -> String {
649        String::from("Example usage of `MultiHeadAttention`.")
650    }
651
652    fn page_source(&self) -> usize {
653        90_usize
654    }
655
656    fn main(&self) -> Result<()> {
657        use crate::listings::ch03::MultiHeadAttention;
658        use candle_core::{DType, Tensor};
659        use candle_nn::{VarBuilder, VarMap};
660
661        // create batch
662        let inputs = addons::get_inputs();
663        let d_in = inputs.dims()[1]; // input embedding dim
664        let d_out = 2_usize;
665        let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?;
666        println!("batch shape: {:?}", batch);
667
668        // build causal attn layer
669        let varmap = VarMap::new();
670        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
671        let num_heads = 2_usize;
672        let mha = MultiHeadAttention::new(d_in, d_out, 0.0_f32, num_heads, false, vb.pp("mha"))?;
673
674        // context vectors
675        let context_vectors = mha.forward(&batch)?;
676        println!("mha.head_dim: {:?}", mha.head_dim());
677        println!("context_vectors.shape: {:?}", context_vectors);
678        println!("context_vectors: {:?}", context_vectors.to_vec3::<f32>());
679        Ok(())
680    }
681}
682
683pub mod addons {
684    //! Auxiliary module for examples::ch03
685    use candle_core::{Device, Result, Tensor};
686
687    /// Helper function for getting the sample input token ids
688    pub fn get_inputs() -> Tensor {
689        let dev = Device::cuda_if_available(0).unwrap();
690        Tensor::new(
691            &[
692                [0.43_f32, 0.15, 0.89], // Your
693                [0.55, 0.87, 0.66],     // journey
694                [0.57, 0.85, 0.64],     // starts
695                [0.22, 0.58, 0.33],     // with
696                [0.77, 0.25, 0.10],     // one
697                [0.05, 0.80, 0.55],     // step
698            ],
699            &dev,
700        )
701        .unwrap()
702    }
703
704    /// Helper function for providing a masked `Tensor` specifying `on_false` and `on_true`
705    pub fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
706        let shape = mask.shape();
707        let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
708        let m = mask.where_cond(&on_true, on_false)?;
709        Ok(m)
710    }
711}