gpt_model/model.rs
//! Runtime wrappers for running inference
//! on a GPT model saved in ONNX format.
use anyhow::Result;
use ndarray::{Array, ArrayD, ArrayViewMut, Axis, Ix1, Ix2, Ix3, Ix6};
use rand::{distributions::WeightedIndex, prelude::Distribution};
use tract_onnx::prelude::{
    tvec, DatumExt, Framework, Graph, InferenceModelExt, SimplePlan, Tensor, TypedFact, TypedOp,
};

/// Alias for the type returned by Tract
/// for an optimized and strongly-typed
/// runnable ML model.
type OptimizedOnnxModel =
    SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;

/// Alias for the shape of the GPT-2 `tokens` input tensor.
type TokensInput = Array<i32, Ix2>;

/// Alias for the shape of the GPT-2 `token_predictions` output tensor.
///
/// The shape axes correspond to:
/// - `0`: Input batch size
/// - `1`: Input token sequence length
/// - `2`: Model vocabulary size
type InferenceOutput = Array<f32, Ix3>;

/// Alias for the shape of the GPT-2 `token_embeddings` output tensor.
///
/// The shape axes correspond to:
/// - `0`: Input batch size
/// - `1`: Model layer count
/// - `2`: Key / value pairs (always `2` "rows")
/// - `3`: Model head count
/// - `4`: Input token sequence length
/// - `5`: Model embeddings per layer / model head count
type HiddenLayersOutput = Array<f32, Ix6>;

/// Token vocabulary size of the GPT-2 models supported
/// by this library.
const GPT2_VOCABULARY_SIZE: usize = 50257;

/// Number of layers used by the GPT-2 models supported
/// by this library.
const GPT2_LAYER_COUNT: usize = 12;

/// Number of heads used by the GPT-2 models supported
/// by this library.
const GPT2_HEAD_COUNT: usize = 12;

/// Number of embeddings used by each layer of the
/// GPT-2 models supported by this library.
pub const GPT2_EMBEDDING_SIZE: usize = 768;

/// Sampling temperature which affects
/// the entropy of inferences.
///
/// Temperatures of:
/// - `0.0` will result in no entropy (deterministic outputs).
/// - `1.0` will defer to the model's internal entropy.
/// - `> 1` will exaggerate the model's entropy.
///
/// In general, _higher_ temperatures result in more
/// "creative" samples of the model's inferences.
const SAMPLE_TEMPERATURE: f32 = 0.9;

/// Sampling filter which restricts samples
/// of the model's inference for a token to
/// the `P` most confident inferences.
///
/// P-values of:
/// - `0.0` will select only the most likely inference.
/// - `1.0` will select all inferences (i.e., the entire
///   vocabulary of the model).
///
/// In general, _higher_ P-values result in
/// more "creative" samples of the model's inferences.
const SAMPLE_MIN_P_VALUE: f32 = 0.5;

/// The GPT-2 natural language ML model.
///
/// ## Example Usage
///
/// ```rust
/// # use gpt::tokenizer::Tokenizer;
/// # use gpt::model::Gpt2Model;
/// #
/// # let bpe_path = "./gpt-2-model/saved_models/124M_vocab.bpe";
/// # let encoder_path = "./gpt-2-model/saved_models/124M_encoder.json";
/// # let model_path = "./gpt-2-model/saved_models/gpt-2-124M.onnx";
/// #
/// # let batch_size = 1;
/// # let sequence_length = 128;
/// #
/// // Load tokenizer and GPT-2 model.
/// let tokenizer = Tokenizer::new(bpe_path, encoder_path);
/// let gpt_model = Gpt2Model::new(model_path, batch_size, sequence_length).unwrap();
///
/// // Convert input text to a token sequence.
/// let text_in = "Horses aren't real; they can't hurt you.";
/// let (tokens_in, padding) = tokenizer.encode_to_length(text_in, sequence_length);
///
/// // Convert token sequence to an input tensor, and get
/// // an inference from the model.
/// let tensor_in = gpt_model.tensor_from_tokens(&[tokens_in]);
/// let (inference, hidden_layers) = gpt_model.infer(tensor_in);
///
/// // Generate the next tokens based on the inference,
/// // and convert the tokens to text.
/// let tokens_out = gpt_model.tokens_from_inference(inference, &[padding]);
/// let generated_text = tokenizer.decode(tokens_out);
///
/// // Bonus: Extract the embedding of the input text from
/// // the hidden layers.
/// let text_embedding = gpt_model.embeddings_from_layers(&hidden_layers, &[padding], 11);
/// ```
pub struct Gpt2Model {
    /// The loaded ONNX model.
    model: OptimizedOnnxModel,

    /// The index of the model's token inference output.
    out_inference_index: usize,

    /// The index of the model's hidden layers output.
    out_hidden_layers_index: usize,

    /// The number of token sequences
    /// (i.e., "sentences") given to
    /// the model during inference.
    batch_size: usize,

    /// The length of each token sequence
    /// (i.e., "sentence") given to the
    /// model during inference.
    sequence_length: usize,
}

impl Gpt2Model {
    /// Creates a new GPT-2 model from the ONNX
    /// model saved at `onnx_model_path`, with fixed
    /// `batch_size` and `sequence_length`.
    ///
    /// `batch_size` specifies the maximum number of
    /// texts ("token sequences") that can be processed
    /// during each inference request.
    ///
    /// `sequence_length` specifies the number of tokens
    /// that can be processed by the model in a single
    /// token sequence. Sequences will be truncated and/or
    /// padded to match this length.
    pub fn new(onnx_model_path: &str, batch_size: usize, sequence_length: usize) -> Result<Self> {
        // Load the model into memory.
        let mut model = tract_onnx::onnx()
            .with_ignore_output_shapes(true)
            .with_ignore_output_types(true)
            .model_for_path(onnx_model_path)?;

        // Configure shape of the input tokens.
        model.set_input_fact(0, i32::fact([batch_size, sequence_length]).into())?;

        // Configure shape of the output inferences.
        let out_inference = model
            .find_outlet_label("next_token_inferences")
            .expect("missing inference output");
        model.set_outlet_fact(
            out_inference,
            f32::fact([batch_size, sequence_length, GPT2_VOCABULARY_SIZE]).into(),
        )?;
        let out_inference_index = model
            .output_outlets()?
            .iter()
            .position(|o| o == &out_inference)
            .expect("missing inference output");

        // Configure shape of the output hidden layers.
        let out_hidden_layers = model
            .find_outlet_label("hidden_layers")
            .expect("missing hidden layers output");
        model.set_outlet_fact(
            out_hidden_layers,
            f32::fact([
                batch_size,
                GPT2_LAYER_COUNT,
                2,
                GPT2_HEAD_COUNT,
                sequence_length,
                GPT2_EMBEDDING_SIZE / GPT2_HEAD_COUNT,
            ])
            .into(),
        )?;
        let out_hidden_layers_index = model
            .output_outlets()?
            .iter()
            .position(|o| o == &out_hidden_layers)
            .expect("missing hidden layers output");

        // Prepare model for execution.
        let model = model.into_optimized()?;
        let model = model.into_runnable()?;

        Ok(Gpt2Model {
            model,
            out_inference_index,
            out_hidden_layers_index,
            batch_size,
            sequence_length,
        })
    }

    /// Converts a slice of one or more token sequences
    /// into a single tensor which may be passed into
    /// the GPT-2 model.
    ///
    /// ## Panics
    ///
    /// If `tokens` contains any token sequences not
    /// matching this model's `sequence_length`, or if
    /// the number of token sequences in `tokens` does
    /// not match this model's `batch_size`.
    pub fn tensor_from_tokens(&self, tokens: &[Vec<i32>]) -> TokensInput {
        assert_eq!(self.batch_size, tokens.len());

        TokensInput::from_shape_fn(
            (self.batch_size, self.sequence_length),
            |(batch_index, sequence_index)| tokens[batch_index][sequence_index],
        )
    }

    /// Runs the model to generate an inference for `tensor`.
    ///
    /// The returned tuple will contain `(inference, hidden_layers)`,
    /// where `inference` is a 3D tensor of shape
    /// `[batch_size, sequence_length, vocabulary size]`,
    /// and `hidden_layers` is a 6D tensor of shape
    /// `[batch_size, layers, 2, head count, sequence_length, embeddings per head]`.
    ///
    /// For most GPT-2 models, the vocabulary size is `50257`.
    ///
    /// For the 124M ("small") GPT-2 model, there will be
    /// `12` layers, `12` heads, and `64` embeddings per head,
    /// for a total of `768` embeddings per layer.
    pub fn infer(&self, tensor: TokensInput) -> (InferenceOutput, HiddenLayersOutput) {
        // Convert input into a concrete Tract tensor.
        let tensor: Tensor = tensor.into();

        // Run inference.
        let model_outputs = self.model.run(tvec!(tensor)).expect("inference");

        // Extract inference data.
        let inference = model_outputs[self.out_inference_index].clone();
        let hidden_layers = model_outputs[self.out_hidden_layers_index].clone();

        // Convert inference data to f32 arrays.
        let inference = (*inference).clone();
        let inference: ArrayD<f32> = inference.into_array().unwrap();
        let inference: InferenceOutput = inference.into_dimensionality().unwrap();
        let hidden_layers = (*hidden_layers).clone();
        let hidden_layers: ArrayD<f32> = hidden_layers.into_array().unwrap();
        let hidden_layers: HiddenLayersOutput = hidden_layers.into_dimensionality().unwrap();

        (inference, hidden_layers)
    }

    /// Returns the number of hidden layers within `hidden_layers`.
    pub fn count_layers(&self, hidden_layers: &HiddenLayersOutput) -> usize {
        hidden_layers.dim().1
    }

    /// Samples `inference` for the next
    /// token for each sequence in the batch.
    ///
    /// `tokens_padding` must be a slice of the
    /// same length as `batch_size`, where each
    /// element corresponds to the number of padding
    /// tokens added onto the input token sequence
    /// for that batch element.
    ///
    /// Returns a 1D tensor of shape `[batch_size]`,
    /// where each batch entry is the next token in a sequence.
    pub fn tokens_from_inference(
        &self,
        mut inference: InferenceOutput,
        tokens_padding: &[usize],
    ) -> Vec<i32> {
        // Extract and check inference dimensions.
        let batch_size = inference.dim().0;
        let sequence_length = inference.dim().1;
        assert_eq!(self.batch_size, batch_size);
        assert_eq!(self.sequence_length, sequence_length);
        assert_eq!(batch_size, tokens_padding.len());

        // Iterate over all token sequences in
        // the batch.
        let mut token_indexes = Vec::with_capacity(batch_size);
        let axis = Axis(0);
        for (index, padding) in tokens_padding.iter().enumerate().take(batch_size) {
            let mut inference = inference.index_axis_mut(axis, index);
            let sample = sample_nucleus(
                &mut inference,
                Self::last_token_inference_index(sequence_length, *padding),
            );
            token_indexes.push(sample as i32);
        }

        token_indexes
    }

    /// Post-processes `hidden_layers` to extract
    /// the embedding of each sequence in the batch.
    ///
    /// Returns a 2D tensor of shape `[batch_size, embeddings per layer]`,
    /// where each batch entry is the embedding of the
    /// entire _input_ sequence for that entry.
    ///
    /// For the 124M ("small") GPT-2 model, there
    /// are `768` embeddings per layer.
    ///
    /// `tokens_padding` must be a slice of the
    /// same length as `batch_size`, where each
    /// element corresponds to the number of padding
    /// tokens added onto the input token sequence
    /// for that batch element.
    ///
    /// `hidden_layer_index` selects which hidden layer
    /// the embeddings are extracted from.
    pub fn embeddings_from_layers(
        &self,
        hidden_layers: &HiddenLayersOutput,
        tokens_padding: &[usize],
        hidden_layer_index: usize,
    ) -> Array<f32, Ix2> {
        // Extract dimensional data from the layers.
        let batch_size = hidden_layers.dim().0;
        assert_eq!(2, hidden_layers.dim().2);
        let head_count = hidden_layers.dim().3;
        let token_sequence_length = hidden_layers.dim().4;
        let embeddings_per_head = hidden_layers.dim().5;
        let embeddings_per_layer = embeddings_per_head * head_count;

        // Iterate over all final hidden layers in the batch.
        let mut embeddings = Array::zeros((0, embeddings_per_layer));
        for (index, padding) in tokens_padding.iter().enumerate().take(batch_size) {
            // Restrict view to the hidden layers for this batch.
            let hidden_layer = hidden_layers.index_axis(Axis(0), index);

            // TODO: This line restricts the view to the hidden layer
            // selected by `hidden_layer_index`. Note that "lower" (earlier)
            // layers may perform better than the last layer in tasks where
            // over-contextualization of embeddings isn't desirable:
            // https://kawine.github.io/blog/nlp/2020/02/03/contextual.html
            let hidden_layer = hidden_layer.index_axis(Axis(0), hidden_layer_index);

            // Restrict view to the "value" axis of the hidden layer.
            let hidden_layer = hidden_layer.index_axis(Axis(0), 1);

            // Concatenate embeddings across all GPT model "heads."
            let mut embedding = Vec::with_capacity(embeddings_per_layer);
            for head in 0..head_count {
                // Restrict view to the current head.
                let hidden_layer = hidden_layer.index_axis(Axis(0), head);

                // Restrict view to the last non-padding token.
                let token_index = Self::last_token_inference_index(token_sequence_length, *padding);
                let hidden_layer = hidden_layer.index_axis(Axis(0), token_index);

                embedding.extend(hidden_layer.iter());
            }
            let embedding: Array<f32, Ix1> = Array::from_vec(embedding);

            // Copy embeddings into output.
            embeddings.push_row(embedding.view()).expect("row");
        }

        embeddings
    }

    /// Returns the last index which should
    /// contain an inference on non-padding
    /// token data.
    ///
    /// In the case where `token_padding >= token_sequence_length`,
    /// `0` will be returned.
    pub fn last_token_inference_index(token_sequence_length: usize, token_padding: usize) -> usize {
        if token_padding >= token_sequence_length {
            0
        } else {
            token_sequence_length - token_padding - 1
        }
    }
}

/// Performs nucleus sampling of an `inference`
/// of shape `[sequence_length, vocabulary]`
/// for the token at `token_index` in the sequence.
fn sample_nucleus(inference: &mut ArrayViewMut<f32, Ix2>, token_index: usize) -> usize {
    // Restrict our view to the inference of the `token_index`th token.
    let mut inference = inference.index_axis_mut(Axis(0), token_index);

    // Apply sampling temperature.
    inference.mapv_inplace(|score| score / SAMPLE_TEMPERATURE);

    // Each value in `inference` is a "score" of how likely
    // it is that a specific token comes _after_ the token
    // that inference ran on.
    //
    // Here, we create a clone of the inference and sort it
    // from the highest to lowest scores.
    let mut sorted_scores: Vec<f32> = inference.iter().copied().collect();
    sorted_scores.sort_by(|a, b| a.total_cmp(b).reverse());
    let mut sorted_scores: Array<f32, Ix1> = sorted_scores.into();
    assert!(sorted_scores[0] > sorted_scores[sorted_scores.len() - 1]);

    // A clone of the original scores will be needed later,
    // when performing the final sampling of the scores.
    let original_sorted_scores = sorted_scores.clone();

    // Softmax the sorted scores.
    softmax(&mut sorted_scores.view_mut());

    // Cumulative sum the sorted scores.
    sorted_scores.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);

    // Find the lowest score in `k`, which
    // is the set of scores that have a
    // cumulative probability greater
    // than the sampling P-value.
    //
    // Because the scores are sorted
    // in descending order, we can use
    // the count of all scores `<=` the
    // sampling P-value, minus one,
    // as the index of the lowest
    // score in `k`.
    //
    // In "Top-K" sampling, we would
    // stop processing at this stage
    // and randomly sample from the set
    // of scores in `k`.
    let iter = sorted_scores
        .iter()
        .filter(|score| score <= &&SAMPLE_MIN_P_VALUE);
    let k_min_index = iter.count().saturating_sub(1);
    let k_min_score = original_sorted_scores[k_min_index];

    // "Mask" or "drop out" all scores lower
    // than `k_min_score` by replacing them
    // with a tiny number.
    //
    // This masking will cause these scores
    // to be effectively removed from consideration
    // during sampling when we softmax the scores.
    inference.mapv_inplace(|score| {
        if score < k_min_score {
            return -1e10;
        }

        score
    });

    // Calculate the softmax of the scores.
    softmax(&mut inference.view_mut());

    // Draw a weighted sample from the inference.
    // Although not _technically_ a multinomial sample,
    // the resulting inferences are good enough!
    let inference = inference.mapv(|score| score as f64);
    let multinomial = WeightedIndex::new(inference.view()).unwrap();

    multinomial.sample(&mut rand::thread_rng())
}

/// Calculates the `softmax` of a 1-dimensional
/// `tensor` in-place, replacing its contents
/// with their softmax'ed equivalents.
///
/// ## What's a `softmax`?
///
/// The `softmax` function converts a vector
/// (1-dimensional tensor, or "array") of `n`
/// values into a vector of `n` values _that
/// sum to `1.0`_.
///
/// Regardless of what values are in the original
/// inputs, the output will always contain values
/// in the range of `0.0` to `1.0`. This property
/// makes `softmax` similar to a normalization
/// function that can turn arbitrary data into
/// a `0-1` scale.
///
/// _Unlike_ a "typical" normalization function,
/// which maps values to a `0-1` scale based on
/// some known lower and upper bound (e.g., mapping
/// any byte in the range `0-255` to `0-1`),
/// `softmax` maps values based on their relative
/// "weights".
///
/// For example, a vector containing
/// `(-0.3, 1000000)` might produce a `softmax`
/// vector of `(0.1, 0.9)` (fyi, these numbers
/// are for illustration and not technically correct).
/// This mapping shows that the first element was
/// _very_ small compared to the second element
/// in the input vector.
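///
/// Formally (this is the standard softmax definition, stated here
/// only for reference):
///
/// `softmax(x)_i = exp(x_i) / (exp(x_1) + ... + exp(x_n))`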
fn softmax(tensor: &mut ArrayViewMut<f32, Ix1>) {
    // Shift all values to handle under/overflow.
    let max_value = *tensor.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
    tensor.mapv_inplace(|value| value - max_value);

    // Perform the softmax operation, which:
    //
    // 1. Replaces each value `v` with the value of
    //    Euler's number (`e`) raised to that value. We'll
    //    call each of these new values `e^v`.
    //
    // 2. Sums all `e^v`. We'll call this sum `sum(e^v)`.
    //
    // 3. Replaces each `e^v` with `e^v / sum(e^v)`.
    //
    // The final values will be equivalent to their
    // normalized probabilities on a 0-1 scale that sums to 1.
    tensor.mapv_inplace(|value| value.exp());
    let sum_exps = tensor.sum();
    tensor.mapv_inplace(|value| value / sum_exps);

    // Handle rounding errors to ensure all values sum to 1.
    let sum_values = tensor.sum();
    tensor.mapv_inplace(|value| value / sum_values);
}

#[cfg(test)]
pub mod test {
    use crate::tokenizer::{self, Tokenizer};

    use super::*;

    // Paths to OpenAI training data for the 124M (smallest) GPT-2 model.
    const MODEL_PATH: &str = "./gpt-2-model/saved_models/gpt-2-124M.onnx";
    const BPE_PATH: &str = "./gpt-2-model/saved_models/124M_vocab.bpe";
    const ENCODER_PATH: &str = "./gpt-2-model/saved_models/124M_encoder.json";

    // Expected model hyperparameters.
    const BATCH_SIZE: usize = 1;
    const SEQUENCE_LENGTH: usize = 128;

    // Sample input text for inference.
    const INPUT_TEXT_STR: &str =
        "GPT-2 is a machine learning model for natural language-processing;";

    #[test]
    fn infers_and_samples_sentence() {
        // Load model.
        let model = Gpt2Model::new(MODEL_PATH, BATCH_SIZE, SEQUENCE_LENGTH).expect("load failed");

        // Load tokenizer.
        let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);

        // Prepare initial set of tokens.
        let tokens = tokenizer.encode(INPUT_TEXT_STR);
        let mut all_tokens = tokens.clone();

        eprintln!(" Prompt: `{}`", INPUT_TEXT_STR);
        eprint!("Inference: ");

        // Predict the next full sentence from the model.
        let mut full_sentence = String::from(INPUT_TEXT_STR);
        for _ in 0..64 {
            // Prepare input tokens, padding as necessary.
            let mut inference_tokens = all_tokens.clone();
            let padding = SEQUENCE_LENGTH - inference_tokens.len();
            for _ in 0..padding {
                inference_tokens.push(tokenizer::PAD_TOKEN);
            }

            // Prepare inference tensor.
            let tensor = model.tensor_from_tokens(&[inference_tokens]);

            // Run inference.
            let (inference, hidden_layers) = model.infer(tensor);

            // Sample the next token in the sentence based on inference.
            let next_token = model.tokens_from_inference(inference, &[padding])[0];
            all_tokens.push(next_token);

            // Decode the token and add it to the sentence.
            let next_word = tokenizer.decode(vec![next_token]);
            full_sentence.push_str(&next_word);

            eprint!("{}", next_word);

            // Quit early if the model emits a full-stop.
            // In these tests, we always embed from the final
            // ("highest") hidden layer.
            let hidden_layer_index = model.count_layers(&hidden_layers) - 1;
            if full_sentence.ends_with('.') {
                eprintln!();
                eprintln!(
                    "Final inference embedding: {:?}",
                    model.embeddings_from_layers(&hidden_layers, &[padding], hidden_layer_index)
                );
                break;
            }

            assert_eq!(tokenizer.decode(all_tokens.clone()), full_sentence);
        }
    }
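
    // The tests below are small, self-contained sanity checks added for
    // illustration; they exercise only the pure helpers in this module and
    // make no claims about the ONNX model itself.

    // `last_token_inference_index` should point at the last non-padding
    // token, and clamp to `0` when a sequence is entirely padding.
    #[test]
    fn last_token_inference_index_handles_padding() {
        // No padding: the last token in the sequence.
        assert_eq!(Gpt2Model::last_token_inference_index(128, 0), 127);
        // Some padding: step back past the padding tokens.
        assert_eq!(Gpt2Model::last_token_inference_index(128, 28), 99);
        // All (or more than all) padding: clamp to the first token.
        assert_eq!(Gpt2Model::last_token_inference_index(128, 128), 0);
        assert_eq!(Gpt2Model::last_token_inference_index(128, 200), 0);
    }

    // `softmax` should produce a probability distribution: values in
    // `0.0..=1.0` that sum to `1.0`, preserving the relative ordering
    // of the inputs.
    #[test]
    fn softmax_produces_probabilities() {
        let mut scores: Array<f32, Ix1> = Array::from_vec(vec![1.0, 2.0, 3.0]);
        softmax(&mut scores.view_mut());

        // The outputs sum to (approximately) 1.0.
        assert!((scores.sum() - 1.0).abs() < 1e-6);
        // Larger inputs map to larger probabilities.
        assert!(scores[2] > scores[1] && scores[1] > scores[0]);
        // All outputs are valid probabilities.
        assert!(scores.iter().all(|&p| (0.0..=1.0).contains(&p)));
    }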
}