gpt_model/model.rs

//! Runtime wrappers for running inference
//! on a GPT model saved in ONNX format.
use anyhow::Result;
use ndarray::{Array, ArrayD, ArrayViewMut, Axis, Ix1, Ix2, Ix3, Ix6};
use rand::{distributions::WeightedIndex, prelude::Distribution};
use tract_onnx::prelude::{
    tvec, DatumExt, Framework, Graph, InferenceModelExt, SimplePlan, Tensor, TypedFact, TypedOp,
};

/// Alias for the type returned by Tract
/// for an optimized and strongly-typed
/// runnable ML model.
type OptimizedOnnxModel =
    SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;

/// Alias for the shape of the GPT-2 `tokens` input tensor.
type TokensInput = Array<i32, Ix2>;

/// Alias for the shape of the GPT-2 `token_predictions` output tensor.
///
/// The shape axes correspond to:
/// - `0`: Input batch size
/// - `1`: Input token sequence length
/// - `2`: Model vocabulary size
type InferenceOutput = Array<f32, Ix3>;

/// Alias for the shape of the GPT-2 `token_embeddings` output tensor.
///
/// The shape axes correspond to:
/// - `0`: Input batch size
/// - `1`: Model layer count
/// - `2`: Key / value pairs (always `2` "rows")
/// - `3`: Model head count
/// - `4`: Input token sequence length
/// - `5`: Model embeddings per layer / model head count
type HiddenLayersOutput = Array<f32, Ix6>;

/// Token vocabulary size of the GPT-2 models supported
/// by this library.
const GPT2_VOCABULARY_SIZE: usize = 50257;

/// Number of layers used by the GPT-2 models supported
/// by this library.
const GPT2_LAYER_COUNT: usize = 12;

/// Number of heads used by the GPT-2 models supported
/// by this library.
const GPT2_HEAD_COUNT: usize = 12;

/// Number of embeddings used by each layer of the
/// GPT-2 models supported by this library.
pub const GPT2_EMBEDDING_SIZE: usize = 768;

/// Sampling temperature, which affects
/// the entropy of inferences.
///
/// Temperatures of:
/// - `0.0` will result in no entropy (deterministic outputs).
/// - `1.0` will defer to the model's internal entropy.
/// - `> 1` will exaggerate the model's entropy.
///
/// In general, _higher_ temperatures result in more
/// "creative" samples of the model's inferences.
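///
/// As a rough illustration with toy scores (not taken from a real model):
/// sampling divides each raw score by this temperature before applying a
/// softmax, so a temperature below `1.0` sharpens the distribution, while
/// one above `1.0` flattens it.
///
/// ```text
/// [2.0, 1.0] / 0.5 => [4.0, 2.0]   (sharper after softmax)
/// [2.0, 1.0] / 2.0 => [1.0, 0.5]   (flatter after softmax)
/// ```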
const SAMPLE_TEMPERATURE: f32 = 0.9;

/// Sampling filter which restricts samples
/// of the model's inference for a token to
/// the `P` most confident inferences.
///
/// P-values of:
/// - `0.0` will select only the most likely inference.
/// - `1.0` will select all inferences (i.e., the entire
///   vocabulary of the model).
///
/// In general, _higher_ P-values result in
/// more "creative" samples of the model's inferences.
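///
/// As a rough illustration with toy softmaxed scores (not taken from a real
/// model) of `[0.6, 0.3, 0.1]`: a P-value of `0.5` keeps only the top-scoring
/// token, while a P-value of `0.95` keeps the top two.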
const SAMPLE_MIN_P_VALUE: f32 = 0.5;

/// The GPT-2 natural language ML model.
///
/// ## Example Usage
///
/// ```rust
/// # use gpt::tokenizer::Tokenizer;
/// # use gpt::model::Gpt2Model;
/// #
/// # let bpe_path = "./gpt-2-model/saved_models/124M_vocab.bpe";
/// # let encoder_path = "./gpt-2-model/saved_models/124M_encoder.json";
/// # let model_path = "./gpt-2-model/saved_models/gpt-2-124M.onnx";
/// #
/// # let batch_size = 1;
/// # let sequence_length = 128;
/// #
/// // Load tokenizer and GPT-2 model.
/// let tokenizer = Tokenizer::new(bpe_path, encoder_path);
/// let gpt_model = Gpt2Model::new(model_path, batch_size, sequence_length).unwrap();
///
/// // Convert input text to a token sequence.
/// let text_in = "Horses aren't real; they can't hurt you.";
/// let (tokens_in, padding) = tokenizer.encode_to_length(text_in, sequence_length);
///
/// // Convert token sequence to an input tensor, and get
/// // an inference from the model.
/// let tensor_in = gpt_model.tensor_from_tokens(&[tokens_in]);
/// let (inference, hidden_layers) = gpt_model.infer(tensor_in);
///
/// // Generate the next tokens based on the inference,
/// // and convert the tokens to text.
/// let tokens_out = gpt_model.tokens_from_inference(inference, &[padding]);
/// let generated_text = tokenizer.decode(tokens_out);
///
/// // Bonus: Extract the embedding of the input text from
/// //        the hidden layers.
/// let text_embedding = gpt_model.embeddings_from_layers(&hidden_layers, &[padding], 11);
/// ```
pub struct Gpt2Model {
    /// The loaded ONNX model.
    model: OptimizedOnnxModel,

    /// The index of the model's token inference output.
    out_inference_index: usize,

    /// The index of the model's hidden layers output.
    out_hidden_layers_index: usize,

    /// The number of token sequences
    /// (i.e., "sentences") given to
    /// the model during inference.
    batch_size: usize,

    /// The length of each token sequence
    /// (i.e., "sentence") given to the
    /// model during inference.
    sequence_length: usize,
}

impl Gpt2Model {
    /// Creates a new GPT-2 model from the ONNX
    /// model saved at `onnx_model_path`, with fixed
    /// `batch_size` and `sequence_length`.
    ///
    /// `batch_size` specifies the maximum number of
    /// texts ("token sequences") that can be processed
    /// during each inference request.
    ///
    /// `sequence_length` specifies the number of tokens
    /// that can be processed by the model in a single
    /// token sequence. Sequences will be truncated and/or
    /// padded to match this length.
    pub fn new(onnx_model_path: &str, batch_size: usize, sequence_length: usize) -> Result<Self> {
        // Load the model into memory.
        let mut model = tract_onnx::onnx()
            .with_ignore_output_shapes(true)
            .with_ignore_output_types(true)
            .model_for_path(onnx_model_path)?;

        // Configure shape of the input tokens.
        model.set_input_fact(0, i32::fact([batch_size, sequence_length]).into())?;

        // Configure shape of the output inferences.
        let out_inference = model
            .find_outlet_label("next_token_inferences")
            .expect("missing inference output");
        model.set_outlet_fact(
            out_inference,
            f32::fact([batch_size, sequence_length, GPT2_VOCABULARY_SIZE]).into(),
        )?;
        let out_inference_index = model
            .output_outlets()?
            .iter()
            .position(|o| o == &out_inference)
            .expect("missing inference output");

        // Configure shape of the output hidden layers.
        let out_hidden_layers = model
            .find_outlet_label("hidden_layers")
            .expect("missing hidden layers output");
        model.set_outlet_fact(
            out_hidden_layers,
            f32::fact([
                batch_size,
                GPT2_LAYER_COUNT,
                2,
                GPT2_HEAD_COUNT,
                sequence_length,
                GPT2_EMBEDDING_SIZE / GPT2_HEAD_COUNT,
            ])
            .into(),
        )?;
        let out_hidden_layers_index = model
            .output_outlets()?
            .iter()
            .position(|o| o == &out_hidden_layers)
            .expect("missing hidden layers output");

        // Prepare model for execution.
        let model = model.into_optimized()?;
        let model = model.into_runnable()?;

        Ok(Gpt2Model {
            model,
            out_inference_index,
            out_hidden_layers_index,
            batch_size,
            sequence_length,
        })
    }

    /// Converts a slice of one or more token sequences
    /// into a single tensor which may be passed into
    /// the GPT-2 model.
    ///
    /// ## Panics
    ///
    /// If `tokens` contains any token sequences not
    /// matching this model's `sequence_length`, or if
    /// the number of token sequences in `tokens` does
    /// not match this model's `batch_size`.
    pub fn tensor_from_tokens(&self, tokens: &[Vec<i32>]) -> TokensInput {
        assert_eq!(self.batch_size, tokens.len());

        TokensInput::from_shape_fn(
            (self.batch_size, self.sequence_length),
            |(batch_index, sequence_index)| tokens[batch_index][sequence_index],
        )
    }

    /// Runs the model to generate an inference for `tensor`.
    ///
    /// The returned tuple will contain `(inference, hidden_layers)`,
    /// where `inference` is a 3D tensor of shape
    /// `[batch_size, sequence_length, vocabulary size]`,
    /// and `hidden_layers` is a 6D tensor of shape
    /// `[batch_size, layers, 2, head count, sequence_length, embeddings per head]`.
    ///
    /// For most GPT-2 models, the vocabulary size is `50257`.
    ///
    /// For the 124M ("small") GPT-2 model, there will be
    /// `12` layers, `12` heads, and `64` embeddings per head,
    /// for a total of `768` embeddings per layer.
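    ///
    /// For example, with `batch_size = 1` and `sequence_length = 128`,
    /// `inference` will have shape `[1, 128, 50257]` and `hidden_layers`
    /// will have shape `[1, 12, 2, 12, 128, 64]`.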
    pub fn infer(&self, tensor: TokensInput) -> (InferenceOutput, HiddenLayersOutput) {
        // Convert input into a concrete Tract tensor.
        let tensor: Tensor = tensor.into();

        // Run inference.
        let model_outputs = self.model.run(tvec!(tensor)).expect("inference");

        // Extract inference data.
        let inference = model_outputs[self.out_inference_index].clone();
        let hidden_layers = model_outputs[self.out_hidden_layers_index].clone();

        // Convert inference data to f32 arrays.
        let inference = (*inference).clone();
        let inference: ArrayD<f32> = inference.into_array().unwrap();
        let inference: InferenceOutput = inference.into_dimensionality().unwrap();
        let hidden_layers = (*hidden_layers).clone();
        let hidden_layers: ArrayD<f32> = hidden_layers.into_array().unwrap();
        let hidden_layers: HiddenLayersOutput = hidden_layers.into_dimensionality().unwrap();

        (inference, hidden_layers)
    }

    /// Returns the number of hidden layers within `hidden_layers`.
    pub fn count_layers(&self, hidden_layers: &HiddenLayersOutput) -> usize {
        hidden_layers.dim().1
    }

    /// Samples `inference` for the next
    /// token for each sequence in the batch.
    ///
    /// `tokens_padding` must be a slice of length
    /// `batch_size`, where each element corresponds
    /// to the number of padding tokens added onto
    /// the input token sequence for that batch element.
    ///
    /// Returns a `Vec` of length `batch_size`,
    /// where each entry is the next token for the corresponding sequence.
    pub fn tokens_from_inference(
        &self,
        mut inference: InferenceOutput,
        tokens_padding: &[usize],
    ) -> Vec<i32> {
        // Extract and check inference dimensions.
        let batch_size = inference.dim().0;
        let sequence_length = inference.dim().1;
        assert_eq!(self.batch_size, batch_size);
        assert_eq!(self.sequence_length, sequence_length);
        assert_eq!(batch_size, tokens_padding.len());

        // Iterate over all token sequences in
        // the batch.
        let mut token_indexes = Vec::with_capacity(batch_size);
        let axis = Axis(0);
        for (index, padding) in tokens_padding.iter().enumerate().take(batch_size) {
            let mut inference = inference.index_axis_mut(axis, index);
            let sample = sample_nucleus(
                &mut inference,
                Self::last_token_inference_index(sequence_length, *padding),
            );
            token_indexes.push(sample as i32);
        }

        token_indexes
    }

    /// Post-processes `hidden_layers` to extract
    /// the embedding of each sequence in the batch.
    ///
    /// Returns a 2D tensor of shape `[batch_size, embeddings per layer]`,
    /// where each batch entry is the embedding of the
    /// entire _input_ sequence for that entry.
    ///
    /// For the 124M ("small") GPT-2 model, there
    /// are `768` embeddings per layer.
    ///
    /// `tokens_padding` must be a slice of length
    /// `batch_size`, where each element corresponds
    /// to the number of padding tokens added onto
    /// the input token sequence for that batch element.
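    ///
    /// `hidden_layer_index` selects which hidden layer (`0`-based; see
    /// [`Gpt2Model::count_layers`]) the embeddings are extracted from.
    ///
    /// For example, with `batch_size = 1` and the 124M model, the returned
    /// array has shape `[1, 768]`.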
    pub fn embeddings_from_layers(
        &self,
        hidden_layers: &HiddenLayersOutput,
        tokens_padding: &[usize],
        hidden_layer_index: usize,
    ) -> Array<f32, Ix2> {
        // Extract dimensional data from the layers.
        let batch_size = hidden_layers.dim().0;
        assert_eq!(2, hidden_layers.dim().2);
        let head_count = hidden_layers.dim().3;
        let token_sequence_length = hidden_layers.dim().4;
        let embeddings_per_head = hidden_layers.dim().5;
        let embeddings_per_layer = embeddings_per_head * head_count;

        // Iterate over all final hidden layers in the batch.
        let mut embeddings = Array::zeros((0, embeddings_per_layer));
        for (index, padding) in tokens_padding.iter().enumerate().take(batch_size) {
            // Restrict view to the hidden layers for this batch.
            let hidden_layer = hidden_layers.index_axis(Axis(0), index);

            // TODO: This line restricts the view to the _last_
            // hidden layer of this batch. However, "lower" (earlier)
            // layers may perform better in tasks where over-contextualization
            // of embeddings isn't desirable:
            // https://kawine.github.io/blog/nlp/2020/02/03/contextual.html
            let hidden_layer = hidden_layer.index_axis(Axis(0), hidden_layer_index);

            // Restrict view to the "value" axis of the hidden layer.
            let hidden_layer = hidden_layer.index_axis(Axis(0), 1);

            // Concatenate embeddings across all GPT model "heads."
            let mut embedding = Vec::with_capacity(embeddings_per_layer);
            for head in 0..head_count {
                // Restrict view to the current head.
                let hidden_layer = hidden_layer.index_axis(Axis(0), head);

                // Restrict view to the last non-padding token.
                let token_index = Self::last_token_inference_index(token_sequence_length, *padding);
                let hidden_layer = hidden_layer.index_axis(Axis(0), token_index);

                embedding.extend(hidden_layer.iter());
            }
            let embedding: Array<f32, Ix1> = Array::from_vec(embedding);

            // Copy embeddings into output.
            embeddings.push_row(embedding.view()).expect("row");
        }

        embeddings
    }

    /// Returns the last index which should
    /// contain an inference on non-padding
    /// token data.
    ///
    /// In the case where `token_padding >= token_sequence_length`,
    /// `0` will be returned.
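    ///
    /// For example, a sequence of length `128` with `28` padding tokens has
    /// its last non-padding inference at index `99`.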
    pub fn last_token_inference_index(token_sequence_length: usize, token_padding: usize) -> usize {
        if token_padding >= token_sequence_length {
            0
        } else {
            token_sequence_length - token_padding - 1
        }
    }
}

/// Performs nucleus sampling of an `inference`
/// of shape `[sequence_length, vocabulary]`
/// for the token at `token_index` in the sequence.
fn sample_nucleus(inference: &mut ArrayViewMut<f32, Ix2>, token_index: usize) -> usize {
    // Restrict our view to the inference of the `token_index`th token.
    let mut inference = inference.index_axis_mut(Axis(0), token_index);

    // Apply sampling temperature.
    inference.mapv_inplace(|score| score / SAMPLE_TEMPERATURE);

    // Each value in `inference` is a "score" of how likely
    // it is that a specific token comes _after_ the token
    // that the inference ran on.
    //
    // Here, we create a clone of the inference and sort it
    // from the highest to lowest scores.
    let mut sorted_scores: Vec<f32> = inference.iter().copied().collect();
    sorted_scores.sort_by(|a, b| a.total_cmp(b).reverse());
    let mut sorted_scores: Array<f32, Ix1> = sorted_scores.into();
    assert!(sorted_scores[0] > sorted_scores[sorted_scores.len() - 1]);

    // A clone of the original scores will be needed later,
    // when performing the final sampling of the scores.
    let original_sorted_scores = sorted_scores.clone();

    // Softmax the sorted scores.
    softmax(&mut sorted_scores.view_mut());

    // Cumulative sum the sorted scores.
    sorted_scores.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);

    // Find the lowest score in `k`, which
    // is the set of scores that have a
    // cumulative probability greater
    // than the sampling P-value.
    //
    // Because the scores are sorted
    // in descending order, we can use
    // the count of all scores `<=` the
    // sampling P-value, minus one,
    // as the index of the lowest
    // score in `k`.
    //
    // In "Top-K" sampling, we would
    // stop processing at this stage
    // and randomly sample from the set
    // of scores in `k`.
    let iter = sorted_scores
        .iter()
        .filter(|score| score <= &&SAMPLE_MIN_P_VALUE);
    let k_min_index = iter.count().saturating_sub(1);
    let k_min_score = original_sorted_scores[k_min_index];

    // "Mask" or "drop out" all scores lower
    // than `k_min_score` by replacing them
    // with a tiny number.
    //
    // This masking will cause these scores
    // to be effectively removed from consideration
    // during sampling when we softmax the scores.
    inference.mapv_inplace(|score| {
        if score < k_min_score {
            return -1e10;
        }

        score
    });

    // Calculate the softmax of the scores.
    softmax(&mut inference.view_mut());

    // Draw a weighted sample from the inference.
    // Although not _technically_ a multinomial sample,
    // the resulting inferences are good enough!
    let inference = inference.mapv(|score| score as f64);
    let multinomial = WeightedIndex::new(inference.view()).unwrap();

    multinomial.sample(&mut rand::thread_rng())
}

/// Calculates the `softmax` of a 1-dimensional
/// `tensor` in-place, replacing its contents
/// with their softmax'ed equivalents.
///
/// ## What's a `softmax`?
///
/// The `softmax` function converts a vector
/// (1-dimensional tensor, or "array") of `n`
/// values into a vector of `n` values _that
/// sum to `1.0`_.
///
/// Regardless of what values are in the original
/// inputs, the output will always contain values
/// in the range of `0.0` to `1.0`. This property
/// makes `softmax` similar to a normalization
/// function that can turn arbitrary data into
/// a `0-1` scale.
///
/// _Unlike_ a "typical" normalization function,
/// which maps values to a `0-1` scale based on
/// some known lower and upper bound (e.g., mapping
/// any byte in the range `0-255` to `0-1`),
/// `softmax` maps values based on their relative
/// "weights".
///
/// For example, a vector containing
/// `(-0.3, 1000000)` might produce a `softmax`
/// vector of `(0.1, 0.9)` (these numbers are
/// purely illustrative, not exact).
/// This mapping shows that the first element was
/// _very_ small compared to the second element
/// in the input vector.
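///
/// As a concrete worked example (values rounded to two decimal places):
///
/// ```text
/// softmax([1.0, 2.0, 3.0])
///   = [e^1, e^2, e^3] / (e^1 + e^2 + e^3)
///   ≈ [0.09, 0.24, 0.67]
/// ```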
fn softmax(tensor: &mut ArrayViewMut<f32, Ix1>) {
    // Shift all values to handle under/over flow.
    let max_value = *tensor.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
    tensor.mapv_inplace(|value| value - max_value);

    // Perform the softmax operation, which:
    //
    // 1. Replaces each value `v` with the value of
    //    Euler's number (`e`) raised to that value. We'll
    //    call each of these new values `e^v`.
    //
    // 2. Sums all `e^v`. We'll call this sum `sum(e^v)`.
    //
    // 3. Replaces each `e^v` with `e^v / sum(e^v)`.
    //
    // The final values will be equivalent to their
    // normalized probabilities on a 0-1 scale that sums to 1.
    tensor.mapv_inplace(|value| value.exp());
    let sum_exps = tensor.sum();
    tensor.mapv_inplace(|value| value / sum_exps);

    // Handle rounding errors to ensure all values sum to 1.
    let sum_values = tensor.sum();
    tensor.mapv_inplace(|value| value / sum_values);
}

#[cfg(test)]
pub mod test {
    use crate::tokenizer::{self, Tokenizer};

    use super::*;

    // Paths to the OpenAI-trained 124M (smallest) GPT-2 model and tokenizer data.
    const MODEL_PATH: &str = "./gpt-2-model/saved_models/gpt-2-124M.onnx";
    const BPE_PATH: &str = "./gpt-2-model/saved_models/124M_vocab.bpe";
    const ENCODER_PATH: &str = "./gpt-2-model/saved_models/124M_encoder.json";

    // Expected model hyperparameters.
    const BATCH_SIZE: usize = 1;
    const SEQUENCE_LENGTH: usize = 128;

    // Sample input text for inference.
    const INPUT_TEXT_STR: &str =
        "GPT-2 is a machine learning model for natural language-processing;";

    #[test]
    fn infers_and_samples_sentence() {
        // Load model.
        let model = Gpt2Model::new(MODEL_PATH, BATCH_SIZE, SEQUENCE_LENGTH).expect("load failed");

        // Load tokenizer.
        let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);

        // Prepare initial set of tokens.
        let tokens = tokenizer.encode(INPUT_TEXT_STR);
        let mut all_tokens = tokens.clone();

        eprintln!("   Prompt: `{}`", INPUT_TEXT_STR);
        eprint!("Inference: ");

        // Predict the next full sentence from the model.
        let mut full_sentence = String::from(INPUT_TEXT_STR);
        for _ in 0..64 {
            // Prepare input tokens, padding as necessary.
            let mut inference_tokens = all_tokens.clone();
            let padding = SEQUENCE_LENGTH - inference_tokens.len();
            for _ in 0..padding {
                inference_tokens.push(tokenizer::PAD_TOKEN);
            }

            // Prepare inference tensor.
            let tensor = model.tensor_from_tokens(&[inference_tokens]);

            // Run inference.
            let (inference, hidden_layers) = model.infer(tensor);

            // Sample the next token in the sentence based on inference.
            let next_token = model.tokens_from_inference(inference, &[padding])[0];
            all_tokens.push(next_token);

            // Decode the token and add it to the sentence.
            let next_word = tokenizer.decode(vec![next_token]);
            full_sentence.push_str(&next_word);

            eprint!("{}", next_word);

            // Quit early if the model emits a full-stop.
            // In these tests, we always embed from the final
            // ("highest") hidden layer.
            let hidden_layer_index = model.count_layers(&hidden_layers) - 1;
            if full_sentence.ends_with('.') {
                eprintln!();
                eprintln!(
                    "Final inference embedding: {:?}",
                    model.embeddings_from_layers(&hidden_layers, &[padding], hidden_layer_index)
                );
                break;
            }

            assert_eq!(tokenizer.decode(all_tokens.clone()), full_sentence);
        }
    }
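
    // The following tests are model-free checks of the pure helpers above.
    // They are small sketches: they only assume the current signatures of
    // `Gpt2Model::last_token_inference_index` and `softmax`, and do not
    // require the saved ONNX model files to be present.

    #[test]
    fn last_token_inference_index_accounts_for_padding() {
        // No padding: the last inference index is the final position.
        assert_eq!(Gpt2Model::last_token_inference_index(128, 0), 127);
        // Some padding: skip back over the padding tokens.
        assert_eq!(Gpt2Model::last_token_inference_index(128, 28), 99);
        // Fully padded (or over-padded) sequences fall back to index `0`.
        assert_eq!(Gpt2Model::last_token_inference_index(128, 128), 0);
        assert_eq!(Gpt2Model::last_token_inference_index(128, 200), 0);
    }

    #[test]
    fn softmax_produces_a_probability_distribution() {
        let mut scores: Array<f32, Ix1> = Array::from_vec(vec![1.0, 2.0, 3.0]);
        softmax(&mut scores.view_mut());

        // The softmaxed values sum to (approximately) `1.0`...
        assert!((scores.sum() - 1.0).abs() < 1e-6);
        // ...and preserve the ordering of the original scores.
        assert!(scores[2] > scores[1] && scores[1] > scores[0]);
    }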
}