use crate::Example;
use anyhow::Result;

pub struct EG01;

impl Example for EG01 {
    fn description(&self) -> String {
        String::from("Computing attention scores as a dot product.")
    }

    fn page_source(&self) -> usize {
        57_usize
    }

    fn main(&self) -> Result<()> {
        use candle_core::{IndexOp, Tensor};
        use candle_nn::ops::softmax;

        let inputs = addons::get_inputs();
        let dev = inputs.device().to_owned();

        let query = inputs.index_select(&Tensor::new(&[1u32], &dev)?, 0)?;

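        // Attention scores for query x^(2): the dot product of the query
        // with every row x^(i) of the input sequence.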
        let mut optional_attn_scores_2: Option<Tensor> = None;
        for i in 0..inputs.dims()[0] {
            let x_i = inputs.index_select(&Tensor::new(&[i as u32], &dev)?, 0)?;
            let a_i = x_i.matmul(&query.t()?)?.flatten_all()?;
            optional_attn_scores_2 = match optional_attn_scores_2 {
                Some(attn_scores_2) => Some(Tensor::cat(&[&attn_scores_2, &a_i], 0)?),
                None => Some(a_i),
            }
        }

        if let Some(attn_scores_2) = optional_attn_scores_2 {
            println!("Raw attention scores: {:?}", attn_scores_2);

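            // Naive normalization: divide each score by the sum of all scores.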
            let sum = attn_scores_2.sum_all()?;
            let normalized_attn_scores = (attn_scores_2.broadcast_div(&sum))?.to_vec1::<f32>();
            println!("Normalized attention scores: {:?}", normalized_attn_scores);

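            // Naive softmax: exp(a_i) / sum_j exp(a_j). Not numerically
            // stable for large scores, since exp can overflow.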
            let exponentiator = attn_scores_2.exp()?;
            let exponentiator_sum = exponentiator.sum_all()?;
            let naive_softmax_attn_scores = exponentiator.broadcast_div(&exponentiator_sum)?;
            println!(
                "Naive Softmax-normalized attention scores: {:?}",
                naive_softmax_attn_scores
            );

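            // candle's softmax subtracts the row max before exponentiating,
            // giving the same weights in a numerically stable way.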
            let softmax_attn_scores = softmax(&attn_scores_2, 0)?;
            println!(
                "Softmax-normalized attention scores: {:?}",
                softmax_attn_scores
            );

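            // Context vector z^(2): the attention-weighted sum of the inputs.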
            let mut context_vec_2 = Tensor::zeros_like(&query)?;
            for i in 0..inputs.dims()[0] {
                let x_i = inputs.index_select(&Tensor::new(&[i as u32], &dev)?, 0)?;
                context_vec_2 =
                    context_vec_2.add(&x_i.broadcast_mul(&softmax_attn_scores.i(i)?)?)?;
            }
            println!("Context vector 2: {:?}", context_vec_2.to_vec2::<f32>());
        }
        Ok(())
    }
}

pub struct EG02;

impl Example for EG02 {
    fn description(&self) -> String {
        String::from("Manual computation of multiple context vectors simultaneously.")
    }

    fn page_source(&self) -> usize {
        62_usize
    }

    fn main(&self) -> Result<()> {
        use candle_nn::ops::softmax;

        let inputs = addons::get_inputs();

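        // All pairwise attention scores at once: inputs @ inputs^T.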
        let attn_scores = inputs.matmul(&inputs.t()?)?;

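        // Softmax over dim 1 normalizes each row into attention weights.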
        let attn_weights = softmax(&attn_scores, 1)?;

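        // Sanity check: every row of weights should sum to 1.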
        let sum = attn_weights.sum(1)?;

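        // All context vectors in a single matmul: weights @ inputs.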
        let all_context_vectors = attn_weights.matmul(&inputs)?;

        println!("Attention Weights: {:?}\n", attn_weights.to_vec2::<f32>());
        println!("All Rows Sum: {:?}\n\n", sum.flatten_all());
        println!(
            "Context Vectors: {:?}",
            all_context_vectors.to_vec2::<f32>()
        );
        Ok(())
    }
}

pub struct EG03;

impl Example for EG03 {
    fn description(&self) -> String {
        let desc = "Implementing the self-attention mechanism with \
        trainable weights to compute a single context vector.";
        String::from(desc)
    }

    fn page_source(&self) -> usize {
        66_usize
    }

    fn main(&self) -> Result<()> {
        use candle_core::{DType, Tensor};
        use candle_nn::init::DEFAULT_KAIMING_NORMAL;
        use candle_nn::ops::softmax;
        use candle_nn::{VarBuilder, VarMap};

        let inputs = addons::get_inputs();
        let dev = inputs.device().to_owned();
        let varmap = VarMap::new();
        let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);

        let x_2 = inputs.index_select(&Tensor::new(&[1u32], &dev)?, 0)?;
        let d_in = x_2.dims()[1];
        let d_out = 2_usize;

        let init = DEFAULT_KAIMING_NORMAL;
        let w_query = vs.get_with_hints((d_in, d_out), "query", init)?;
        let w_key = vs.get_with_hints((d_in, d_out), "key", init)?;
        let w_value = vs.get_with_hints((d_in, d_out), "value", init)?;

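        // Project x^(2) into query, key, and value spaces.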
        let query_2 = x_2.matmul(&w_query)?;
        let key_2 = x_2.matmul(&w_key)?;
        let value_2 = x_2.matmul(&w_value)?;

        println!("Query 2: {:?}", query_2.to_vec2::<f32>());
        println!("Key 2: {:?}", key_2.to_vec2::<f32>());
        println!("Value 2: {:?}", value_2.to_vec2::<f32>());

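        // Keys and values for every input token, needed to score x^(2)
        // against the full sequence.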
        let keys = inputs.matmul(&w_key)?;
        let values = inputs.matmul(&w_value)?;

        println!("Keys shape: {:?}", keys);
        println!("Values shape: {:?}", values);

        let attn_scores = query_2.matmul(&keys.t()?)?;
        println!("Attn scores: {:?}", attn_scores.to_vec2::<f32>());

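        // Scaled dot-product attention: divide the scores by sqrt(d_k)
        // before the softmax to keep the distribution from saturating.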
        let d_k = Tensor::new(&[f32::powf(keys.dims()[1] as f32, 0.5_f32)], &dev)?;
        let attn_weights = softmax(&attn_scores.broadcast_div(&d_k)?, 1)?;
        println!("Attn weights: {:?}", attn_weights.to_vec2::<f32>());

        let context_vec_2 = attn_weights.matmul(&values)?;
        println!("Context vector 2: {:?}", context_vec_2.to_vec2::<f32>());
        Ok(())
    }
}

pub struct EG04;

impl Example for EG04 {
    fn description(&self) -> String {
        String::from(
            "Implement the self-attention mechanism to compute context vectors for the input sequence.",
        )
    }

    fn page_source(&self) -> usize {
        71_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch03::SelfAttentionV1;
        use candle_core::{DType, Module};
        use candle_nn::{VarBuilder, VarMap};

        let inputs = addons::get_inputs();
        let d_in = inputs.dims()[1];
        let d_out = 2_usize;

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
        let attn_v1_layer = SelfAttentionV1::new(d_in, d_out, vb.pp("attn"))?;

        let context_vectors = attn_v1_layer.forward(&inputs)?;

        println!("context vectors: {:?}", context_vectors.to_vec2::<f32>());
        Ok(())
    }
}

pub struct EG05;

impl Example for EG05 {
    fn description(&self) -> String {
        let desc = "Implement the self-attention mechanism to compute \
        contextualized vectors, using candle_nn::Linear.";
        String::from(desc)
    }

    fn page_source(&self) -> usize {
        73_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch03::SelfAttentionV2;
        use candle_core::{DType, Module};
        use candle_nn::{VarBuilder, VarMap};

        let inputs = addons::get_inputs();
        let d_in = inputs.dims()[1];
        let d_out = 2_usize;

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?;

        let context_vectors = attn_v2_layer.forward(&inputs)?;

        println!("context vectors: {:?}", context_vectors.to_vec2::<f32>());
        Ok(())
    }
}

pub struct EG06;

impl Example for EG06 {
    fn description(&self) -> String {
        String::from("Compute causal attention weights.")
    }

    fn page_source(&self) -> usize {
        75_usize
    }

    fn main(&self) -> Result<()> {
        let _ = self.main_with_return()?;
        Ok(())
    }
}

impl EG06 {
    fn main_with_return(&self) -> Result<candle_core::Tensor> {
        use crate::listings::ch03::SelfAttentionV2;
        use candle_core::{DType, Module, D};
        use candle_nn::ops::softmax;
        use candle_nn::{VarBuilder, VarMap};

        let inputs = addons::get_inputs();
        let d_in = inputs.dims()[1];
        let d_out = 2_usize;

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?;

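        // First compute the usual (non-causal) attention weights via the
        // scaled dot-product and a row-wise softmax, as in EG05.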
        let queries = attn_v2_layer.w_query().forward(&inputs)?;
        let keys = attn_v2_layer.w_key().forward(&inputs)?;
        let attn_scores = queries.matmul(&keys.t()?)?;
        let scaling = 1. / (keys.dims()[1] as f64).sqrt();
        let attn_weights = softmax(&(attn_scores * scaling)?, 1)?;

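        // Simple causal mask: a lower-triangular matrix of ones (j <= i).
        // Multiplying by it zeroes out every weight above the diagonal,
        // i.e. the attention paid to future tokens.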
        let context_length = inputs.dims()[0];
        let mask_simple: Vec<_> = (0..context_length as u32)
            .flat_map(|i| (0..context_length as u32).map(move |j| f32::from(j <= i)))
            .collect();
        let mask_simple = candle_core::Tensor::from_slice(
            &mask_simple,
            (context_length, context_length),
            inputs.device(),
        )?;
        let masked_simple = (attn_weights * mask_simple)?;
        println!("masked_simple: {:?}", masked_simple.to_vec2::<f32>());

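        // Renormalize each row so the surviving weights sum to 1 again.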
        let row_sums = masked_simple.sum_keepdim(D::Minus1)?;
        let attn_weights = masked_simple.broadcast_div(&row_sums)?;
        println!("masked_simple_norm: {:?}", attn_weights.to_vec2::<f32>());
        Ok(attn_weights)
    }
}

pub struct EG07;

impl Example for EG07 {
    fn description(&self) -> String {
        let desc = "Compute causal attention weights more efficiently \
        using `f32::NEG_INFINITY` and `masked_fill()`.";
        String::from(desc)
    }

    fn page_source(&self) -> usize {
        77_usize
    }

    fn main(&self) -> Result<()> {
        let _ = self.main_with_return()?;
        Ok(())
    }
}

impl EG07 {
    fn main_with_return(&self) -> Result<candle_core::Tensor> {
        use crate::listings::ch03::SelfAttentionV2;
        use candle_core::{DType, Module};
        use candle_nn::ops::softmax;
        use candle_nn::{VarBuilder, VarMap};

        let inputs = addons::get_inputs();
        let d_in = inputs.dims()[1];
        let d_out = 2_usize;

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
        let attn_v2_layer = SelfAttentionV2::new(d_in, d_out, false, vb.pp("attn"))?;

        let queries = attn_v2_layer.w_query().forward(&inputs)?;
        let keys = attn_v2_layer.w_key().forward(&inputs)?;
        let attn_scores = queries.matmul(&keys.t()?)?;

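        // Upper-triangular mask (j > i) marks future positions. Filling them
        // with -inf makes the softmax assign them exactly zero weight, so no
        // separate renormalization pass is needed.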
        let context_length = attn_scores.dims()[0];
        let mask: Vec<_> = (0..context_length as u32)
            .flat_map(|i| (0..context_length as u32).map(move |j| u32::from(j > i)))
            .collect();
        let mask = candle_core::Tensor::from_slice(
            &mask,
            (context_length, context_length),
            inputs.device(),
        )?;
        let masked = addons::masked_fill(&attn_scores, &mask, f32::NEG_INFINITY)?;
        println!("masked: {:?}", masked.to_vec2::<f32>());

        let scaling = 1. / (keys.dims()[1] as f64).sqrt();
        let attn_weights = softmax(&(masked * scaling)?, 1)?;
        println!("attn_weights: {:?}", attn_weights.to_vec2::<f32>());
        Ok(attn_weights)
    }
}

pub struct EG08;

impl Example for EG08 {
    fn description(&self) -> String {
        String::from("Dropout on attention weights.")
    }

    fn page_source(&self) -> usize {
        80_usize
    }

    fn main(&self) -> Result<()> {
        use candle_nn::Dropout;

        let eg07 = EG07;
        let attn_weights = eg07.main_with_return()?;
        let dropout = Dropout::new(0.5);

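        // In training mode, dropout zeroes each weight with probability 0.5
        // and rescales the survivors by 1 / (1 - 0.5) = 2.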
        let dropped_out = dropout.forward(&attn_weights, true)?;
        println!("dropped_out: {:?}", dropped_out.to_vec2::<f32>());
        Ok(())
    }
}

pub struct EG09;

impl Example for EG09 {
    fn description(&self) -> String {
        String::from("Example usage of `CausalAttention`.")
    }

    fn page_source(&self) -> usize {
        81_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch03::CausalAttention;
        use candle_core::{DType, Module, Tensor};
        use candle_nn::{VarBuilder, VarMap};

        let inputs = addons::get_inputs();
        let d_in = inputs.dims()[1];
        let d_out = 2_usize;
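        // Stack two copies of the inputs to simulate a batch of size 2.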
        let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?;
        println!("batch shape: {:?}", batch);

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
        let causal_attn = CausalAttention::new(d_in, d_out, 0.0_f32, false, vb.pp("causal_attn"))?;

        let context_vectors = causal_attn.forward(&batch)?;
        println!("context_vectors.shape: {:?}", context_vectors);
        Ok(())
    }
}

pub struct EG10;

impl Example for EG10 {
    fn description(&self) -> String {
        String::from("Example usage of `MultiHeadAttentionWrapper`.")
    }

    fn page_source(&self) -> usize {
        85_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch03::MultiHeadAttentionWrapper;
        use candle_core::{DType, Module, Tensor};
        use candle_nn::{VarBuilder, VarMap};

        let inputs = addons::get_inputs();
        let d_in = inputs.dims()[1];
        let d_out = 2_usize;
        let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?;
        println!("batch shape: {:?}", batch);

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
        let num_heads = 2_usize;
        let mha =
            MultiHeadAttentionWrapper::new(num_heads, d_in, d_out, 0.0_f32, false, vb.pp("mha"))?;

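        // The wrapper runs `num_heads` independent causal-attention modules
        // sequentially and concatenates their outputs along the last
        // dimension, so the final feature size is d_out * num_heads.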
        let context_vectors = mha.forward(&batch)?;
        println!("context_vectors.shape: {:?}", context_vectors);
        println!("context_vectors: {:?}", context_vectors.to_vec3::<f32>());
        Ok(())
    }
}

pub struct EG11;

impl Example for EG11 {
    fn description(&self) -> String {
        String::from("Example usage of `MultiHeadAttention`.")
    }

    fn page_source(&self) -> usize {
        90_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch03::MultiHeadAttention;
        use candle_core::{DType, Tensor};
        use candle_nn::{VarBuilder, VarMap};

        let inputs = addons::get_inputs();
        let d_in = inputs.dims()[1];
        let d_out = 2_usize;
        let batch = Tensor::stack(&[&inputs, &inputs], 0usize)?;
        println!("batch shape: {:?}", batch);

        let varmap = VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());
        let num_heads = 2_usize;
        let mha = MultiHeadAttention::new(d_in, d_out, 0.0_f32, num_heads, false, vb.pp("mha"))?;

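        // Unlike the wrapper, `MultiHeadAttention` splits d_out across heads
        // inside a single module (head_dim = d_out / num_heads), performing
        // the attention for all heads with batched matmuls.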
        let context_vectors = mha.forward(&batch)?;
        println!("mha.head_dim: {:?}", mha.head_dim());
        println!("context_vectors.shape: {:?}", context_vectors);
        println!("context_vectors: {:?}", context_vectors.to_vec3::<f32>());
        Ok(())
    }
}

pub mod addons {
    use candle_core::{Device, Result, Tensor};

    pub fn get_inputs() -> Tensor {
        let dev = Device::cuda_if_available(0).unwrap();
        Tensor::new(
            &[
                [0.43_f32, 0.15, 0.89],
                [0.55, 0.87, 0.66],
                [0.57, 0.85, 0.64],
                [0.22, 0.58, 0.33],
                [0.77, 0.25, 0.10],
                [0.05, 0.80, 0.55],
            ],
            &dev,
        )
        .unwrap()
    }

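    // Replaces the positions where `mask` is nonzero with `on_true`
    // (e.g. f32::NEG_INFINITY), keeping `on_false`'s values elsewhere.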
    pub fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
        let shape = mask.shape();
        let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
        let m = mask.where_cond(&on_true, on_false)?;
        Ok(m)
    }
}
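
// A minimal test sketch, not part of the book's listings: it checks that
// `addons::masked_fill` replaces exactly the positions where the mask is 1.
// The test tensor values and the CPU device are illustrative assumptions.
#[cfg(test)]
mod tests {
    use super::addons;
    use candle_core::{DType, Device, Tensor};

    #[test]
    fn masked_fill_replaces_masked_positions() -> candle_core::Result<()> {
        let dev = Device::Cpu;
        // 2x2 scores of zeros; mask marks only the (0, 1) entry.
        let scores = Tensor::zeros((2, 2), DType::F32, &dev)?;
        let mask = Tensor::new(&[[0u32, 1], [0, 0]], &dev)?;
        let filled = addons::masked_fill(&scores, &mask, f32::NEG_INFINITY)?;
        let rows = filled.to_vec2::<f32>()?;
        // Masked position becomes -inf; everything else is untouched.
        assert_eq!(rows[0][1], f32::NEG_INFINITY);
        assert_eq!(rows[0][0], 0.0);
        assert_eq!(rows[1][0], 0.0);
        assert_eq!(rows[1][1], 0.0);
        Ok(())
    }
}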