llms_from_scratch_rs/examples/

//! Examples from Chapter 2

use crate::Example;
use anyhow::Result;

/// # Example of reading text files into Rust
///
/// #### Id
/// 02.01
///
/// #### Page
/// This example starts on page 22
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 02.01
///
/// # with cuda
/// cargo run --features cuda example 02.01
/// ```
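///
/// #### Sketch
/// A minimal sketch of the idea this example exercises, assuming the same
/// source file that Example 02.04 below reads (`sample_read_text` itself may
/// use a different path or print extra detail):
/// ```rust,ignore
/// let raw_text = std::fs::read_to_string("data/the-verdict.txt")?;
/// println!("total number of characters: {}", raw_text.chars().count());
/// ```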
pub struct EG01;

impl Example for EG01 {
    fn description(&self) -> String {
        String::from("Example usage of `listings::ch02::sample_read_text`")
    }

    fn page_source(&self) -> usize {
        22_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch02::sample_read_text;
        let _raw_text = sample_read_text(true)?;
        Ok(())
    }
}

/// # Example of building a vocabulary
///
/// #### Id
/// 02.02
///
/// #### Page
/// This example starts on page 25
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 02.02
///
/// # with cuda
/// cargo run --features cuda example 02.02
/// ```
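///
/// #### Sketch
/// A vocabulary maps each unique token to an integer id. A minimal sketch of
/// the concept, assuming simple whitespace tokenization (the listing's
/// tokenizer and vocab type may differ):
/// ```rust,ignore
/// use std::collections::HashMap;
///
/// let text = "every effort moves you";
/// let mut vocab: HashMap<&str, usize> = HashMap::new();
/// for token in text.split_whitespace() {
///     let next_id = vocab.len();
///     vocab.entry(token).or_insert(next_id);
/// }
/// assert_eq!(vocab.len(), 4); // four unique tokens, ids 0..=3
/// ```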
pub struct EG02;

impl Example for EG02 {
    fn description(&self) -> String {
        String::from("Example usage of `listings::ch02::sample_create_vocab`")
    }

    fn page_source(&self) -> usize {
        25_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch02::sample_create_vocab;

        let vocab = sample_create_vocab()?;
        // NOTE: iteration order is arbitrary since the vocab map is unsorted
        for (i, item) in vocab.iter().enumerate() {
            println!("{:?}", item);
            // stop after printing the first 51 entries
            if i >= 50 {
                break;
            }
        }
        Ok(())
    }
}

/// # Use candle to generate an Embedding Layer
///
/// #### Id
/// 02.03
///
/// #### Page
/// This example starts on page 42
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 02.03
///
/// # with cuda
/// cargo run --features cuda example 02.03
/// ```
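///
/// #### Sketch
/// An embedding layer is a lookup table: token id `i` selects row `i` of a
/// `[vocab_size, output_dim]` weight matrix. A minimal sketch of the lookup
/// performed below (same candle calls, toy sizes taken from this example):
/// ```rust,ignore
/// // the weight matrix has shape [6, 3]; token id 3 selects its fourth row,
/// // so the result has shape [1, 3]
/// let token_ids = Tensor::new(&[3u32], &dev)?;
/// let row = emb.embeddings().index_select(&token_ids, 0)?;
/// assert_eq!(row.dims(), &[1, 3]);
/// ```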
pub struct EG03;

impl Example for EG03 {
    fn description(&self) -> String {
        String::from("Use candle to generate an Embedding Layer.")
    }

    fn page_source(&self) -> usize {
        42_usize
    }

    fn main(&self) -> Result<()> {
        use candle_core::{DType, Device, Tensor};
        use candle_nn::{embedding, VarBuilder, VarMap};

        let vocab_size = 6_usize;
        let output_dim = 3_usize;
        let varmap = VarMap::new();
        let dev = Device::cuda_if_available(0)?;
        let vs = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
        let emb = embedding(vocab_size, output_dim, vs)?;
        // print the embedding layer's underlying weight matrix
        println!("{:?}", emb.embeddings().to_vec2::<f32>()?);
        // print the specific embedding of a given token id
        let token_ids = Tensor::new(&[3u32], &dev)?;
        println!(
            "{:?}",
            emb.embeddings()
                .index_select(&token_ids, 0)?
                .to_vec2::<f32>()?
        );
        Ok(())
    }
}

/// # Create absolute positional embeddings
///
/// #### Id
/// 02.04
///
/// #### Page
/// This example starts on page 47
///
/// #### CLI command
/// ```sh
/// # without cuda
/// cargo run example 02.04
///
/// # with cuda
/// cargo run --features cuda example 02.04
/// ```
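///
/// #### Sketch
/// The input embeddings are the element-wise sum of the token embeddings and
/// the position embeddings, with the latter broadcast over the batch
/// dimension. With the sizes used below:
/// ```text
/// token_embeddings: [8, 4, 256]   (batch_size, max_length, output_dim)
/// pos_embeddings:      [4, 256]   (max_length, output_dim)
/// input_embeddings: [8, 4, 256]   (token_embeddings.broadcast_add(&pos_embeddings))
/// ```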
pub struct EG04;

impl Example for EG04 {
    fn description(&self) -> String {
        String::from("Create absolute positional embeddings.")
    }

    fn page_source(&self) -> usize {
        47_usize
    }

    fn main(&self) -> Result<()> {
        use crate::listings::ch02::{create_dataloader_v1, DataLoader};
        use candle_core::{DType, Tensor};
        use candle_nn::{embedding, VarBuilder, VarMap};
        use std::fs;

        // create data batcher
        let raw_text = fs::read_to_string("data/the-verdict.txt").expect("Unable to read the file");
        let max_length = 4_usize;
        let stride = max_length;
        let shuffle = false;
        let drop_last = false;
        let batch_size = 8_usize;
        let data_loader = create_dataloader_v1(
            &raw_text[..],
            batch_size,
            max_length,
            stride,
            shuffle,
            drop_last,
        );

        let mut batch_iter = data_loader.batcher();

        // get embeddings of first batch inputs
        match batch_iter.next() {
            Some(Ok((inputs, _targets))) => {
                let varmap = VarMap::new();
                let vs = VarBuilder::from_varmap(&varmap, DType::F32, inputs.device());

                let vocab_size = 50_257_usize;
                let output_dim = 256_usize;
                let mut final_dims = inputs.dims().to_vec();
                final_dims.push(output_dim);

                // token embeddings of the current batch inputs
                let token_embedding_layer = embedding(vocab_size, output_dim, vs.pp("tok_emb"))?;
                let token_embeddings = token_embedding_layer
                    .embeddings()
                    .index_select(&inputs.flatten_all()?, 0)?;
                let token_embeddings = token_embeddings.reshape(final_dims)?;
                println!("token embeddings dims: {:?}", token_embeddings.dims());

                // position embeddings
                let context_length = max_length;
                let pos_embedding_layer = embedding(context_length, output_dim, vs.pp("pos_emb"))?;
                let pos_ids = Tensor::arange(0u32, context_length as u32, inputs.device())?;
                let pos_embeddings = pos_embedding_layer.embeddings().index_select(&pos_ids, 0)?;
                println!("pos embeddings dims: {:?}", pos_embeddings.dims());

                // incorporate positional embeddings
                let input_embeddings = token_embeddings.broadcast_add(&pos_embeddings)?;
                println!("input embeddings dims: {:?}", input_embeddings.dims());
            }
            Some(Err(err)) => panic!("{}", err),
            None => panic!("expected at least one batch from the data loader"),
        }
        Ok(())
    }
}