use anyhow::Result;
use ndarray::{Array, ArrayD, ArrayViewMut, Axis, Ix1, Ix2, Ix3, Ix6};
use rand::{distributions::WeightedIndex, prelude::Distribution};
use tract_onnx::prelude::{
tvec, DatumExt, Framework, Graph, InferenceModelExt, SimplePlan, Tensor, TypedFact, TypedOp,
};
/// A tract execution plan built from the typed, optimized ONNX graph.
type OptimizedOnnxModel =
    SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>;
/// Token ids, shaped `(batch, sequence)`.
type TokensInput = Array<i32, Ix2>;
/// Next-token logits, shaped `(batch, sequence, vocabulary)`.
type InferenceOutput = Array<f32, Ix3>;
/// Attention key/value states, shaped
/// `(batch, layer, key|value, head, sequence, embedding per head)`.
type HiddenLayersOutput = Array<f32, Ix6>;

const GPT2_VOCABULARY_SIZE: usize = 50257;
const GPT2_LAYER_COUNT: usize = 12;
const GPT2_HEAD_COUNT: usize = 12;
pub const GPT2_EMBEDDING_SIZE: usize = 768;
/// Temperatures below 1.0 sharpen the logit distribution before sampling.
const SAMPLE_TEMPERATURE: f32 = 0.9;
/// Nucleus (top-p) threshold: sampling is restricted to the smallest set of
/// tokens whose cumulative probability reaches this value.
const SAMPLE_TOP_P: f32 = 0.5;
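
/// GPT-2 wrapped in an optimized tract plan, pinned to a fixed batch size and
/// sequence length chosen at construction time.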
pub struct Gpt2Model {
model: OptimizedOnnxModel,
out_inference_index: usize,
out_hidden_layers_index: usize,
batch_size: usize,
sequence_length: usize,
}
impl Gpt2Model {
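    /// Loads the ONNX graph, pins its input and output shapes, optimizes it,
    /// and builds a runnable plan.
    ///
    /// A minimal usage sketch (the path mirrors the one used in the tests
    /// below; adjust it to wherever your exported model lives):
    ///
    /// ```ignore
    /// let model = Gpt2Model::new("./gpt-2-model/saved_models/gpt-2-124M.onnx", 1, 128)?;
    /// let tensor = model.tensor_from_tokens(&[vec![0; 128]]);
    /// let (logits, hidden_layers) = model.infer(tensor);
    /// ```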
pub fn new(onnx_model_path: &str, batch_size: usize, sequence_length: usize) -> Result<Self> {
let mut model = tract_onnx::onnx()
.with_ignore_output_shapes(true)
.with_ignore_output_types(true)
.model_for_path(onnx_model_path)?;
        // Pin the input shape so tract can specialize and optimize the graph.
        model.set_input_fact(0, i32::fact([batch_size, sequence_length]).into())?;
let out_inference = model
.find_outlet_label("next_token_inferences")
.expect("missing inference output");
model.set_outlet_fact(
out_inference,
f32::fact([batch_size, sequence_length, GPT2_VOCABULARY_SIZE]).into(),
)?;
let out_inference_index = model
.output_outlets()?
.iter()
.position(|o| o == &out_inference)
.expect("missing inference output");
let out_hidden_layers = model
.find_outlet_label("hidden_layers")
.expect("missing hidden layers output");
        model.set_outlet_fact(
            out_hidden_layers,
            // (batch, layer, key|value, head, sequence, embedding per head)
            f32::fact([
                batch_size,
                GPT2_LAYER_COUNT,
                2,
                GPT2_HEAD_COUNT,
                sequence_length,
                GPT2_EMBEDDING_SIZE / GPT2_HEAD_COUNT,
            ])
            .into(),
        )?;
let out_hidden_layers_index = model
.output_outlets()?
.iter()
.position(|o| o == &out_hidden_layers)
.expect("missing hidden layers output");
        // Optimize the typed graph and build an execution plan from it.
        let model = model.into_optimized()?.into_runnable()?;
Ok(Gpt2Model {
model,
out_inference_index,
out_hidden_layers_index,
batch_size,
sequence_length,
})
}
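
    /// Packs a batch of already-padded token sequences into a
    /// `(batch, sequence)` tensor. Each inner `Vec` must contain at least
    /// `sequence_length` tokens.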
pub fn tensor_from_tokens(&self, tokens: &[Vec<i32>]) -> TokensInput {
assert_eq!(self.batch_size, tokens.len());
TokensInput::from_shape_fn(
(self.batch_size, self.sequence_length),
|(batch_index, sequence_index)| tokens[batch_index][sequence_index],
)
}
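
    /// Runs one forward pass and returns the next-token logits together with
    /// the per-layer attention key/value states.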
pub fn infer(&self, tensor: TokensInput) -> (InferenceOutput, HiddenLayersOutput) {
let tensor: Tensor = tensor.into();
let model_outputs = self.model.run(tvec!(tensor)).expect("inference");
        // Each output tensor is reference-counted: clone it out of its
        // handle, convert it to a dynamically ranked ndarray, then fix the rank.
        let inference = (*model_outputs[self.out_inference_index]).clone();
        let inference: ArrayD<f32> = inference.into_array().unwrap();
        let inference: InferenceOutput = inference.into_dimensionality().unwrap();
        let hidden_layers = (*model_outputs[self.out_hidden_layers_index]).clone();
        let hidden_layers: ArrayD<f32> = hidden_layers.into_array().unwrap();
        let hidden_layers: HiddenLayersOutput = hidden_layers.into_dimensionality().unwrap();
(inference, hidden_layers)
}
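
    /// Number of transformer layers in a hidden-layers output.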
pub fn count_layers(&self, hidden_layers: &HiddenLayersOutput) -> usize {
hidden_layers.dim().1
}
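
    /// Samples one next token per batch entry from the logits of that entry's
    /// last real (non-padding) token.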
pub fn tokens_from_inference(
&self,
mut inference: InferenceOutput,
tokens_padding: &[usize],
) -> Vec<i32> {
let batch_size = inference.dim().0;
let sequence_length = inference.dim().1;
assert_eq!(self.batch_size, batch_size);
assert_eq!(self.sequence_length, sequence_length);
assert_eq!(batch_size, tokens_padding.len());
let mut token_indexes = Vec::with_capacity(batch_size);
let axis = Axis(0);
for (index, padding) in tokens_padding.iter().enumerate().take(batch_size) {
let mut inference = inference.index_axis_mut(axis, index);
let sample = sample_nucleus(
&mut inference,
Self::last_token_inference_index(sequence_length, *padding),
);
token_indexes.push(sample as i32);
}
token_indexes
}
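
    /// Extracts one embedding row per batch entry: the attention value
    /// vectors of the last real token at `hidden_layer_index`, concatenated
    /// across all heads.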
pub fn embeddings_from_layers(
&self,
hidden_layers: &HiddenLayersOutput,
tokens_padding: &[usize],
hidden_layer_index: usize,
) -> Array<f32, Ix2> {
        let batch_size = hidden_layers.dim().0;
        assert_eq!(batch_size, tokens_padding.len());
        assert_eq!(2, hidden_layers.dim().2);
        let head_count = hidden_layers.dim().3;
        let token_sequence_length = hidden_layers.dim().4;
        let embeddings_per_head = hidden_layers.dim().5;
        let embeddings_per_layer = embeddings_per_head * head_count;
        let mut embeddings = Array::zeros((0, embeddings_per_layer));
        for (index, padding) in tokens_padding.iter().enumerate().take(batch_size) {
            let hidden_layer = hidden_layers.index_axis(Axis(0), index); // batch entry
            let hidden_layer = hidden_layer.index_axis(Axis(0), hidden_layer_index); // layer
            let hidden_layer = hidden_layer.index_axis(Axis(0), 1); // values, not keys
            let token_index = Self::last_token_inference_index(token_sequence_length, *padding);
            // Concatenate the per-head value vectors of the last real token
            // into a single embedding for this batch entry.
            let mut embedding = Vec::with_capacity(embeddings_per_layer);
            for head in 0..head_count {
                let hidden_layer = hidden_layer.index_axis(Axis(0), head);
                let hidden_layer = hidden_layer.index_axis(Axis(0), token_index);
                embedding.extend(hidden_layer.iter());
            }
            let embedding: Array<f32, Ix1> = Array::from_vec(embedding);
            embeddings.push_row(embedding.view()).expect("push embedding row");
        }
embeddings
}
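
    /// Index of the last real token in a right-padded sequence, clamped to 0
    /// when padding covers the whole sequence.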
pub fn last_token_inference_index(token_sequence_length: usize, token_padding: usize) -> usize {
if token_padding >= token_sequence_length {
0
} else {
token_sequence_length - token_padding - 1
}
}
}
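
/// Nucleus (top-p) sampling over one row of logits: apply temperature, keep
/// only the smallest set of tokens whose cumulative probability reaches
/// `SAMPLE_TOP_P`, and draw from the renormalized remainder.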
fn sample_nucleus(inference: &mut ArrayViewMut<f32, Ix2>, token_index: usize) -> usize {
    // Work on the logits of the last real (non-padding) token only.
    let mut inference = inference.index_axis_mut(Axis(0), token_index);
    inference.mapv_inplace(|score| score / SAMPLE_TEMPERATURE);
    // Sort a copy of the logits in descending order.
    let mut sorted_scores: Vec<f32> = inference.iter().copied().collect();
    sorted_scores.sort_by(|a, b| a.total_cmp(b).reverse());
    let mut sorted_scores: Array<f32, Ix1> = sorted_scores.into();
    assert!(sorted_scores[0] > sorted_scores[sorted_scores.len() - 1]);
    let original_sorted_scores = sorted_scores.clone();
    // Turn the sorted logits into a cumulative probability distribution.
    softmax(&mut sorted_scores.view_mut());
    sorted_scores.accumulate_axis_inplace(Axis(0), |&prev, curr| *curr += prev);
    // The nucleus is the smallest prefix whose cumulative probability reaches
    // SAMPLE_TOP_P; `k_min_score` is the lowest logit still inside it.
    let nucleus_size = sorted_scores
        .iter()
        .filter(|&&score| score <= SAMPLE_TOP_P)
        .count();
    let k_min_index = nucleus_size.saturating_sub(1);
    let k_min_score = original_sorted_scores[k_min_index];
    // Mask everything outside the nucleus with a large negative logit, then
    // renormalize and sample from the remaining tokens.
    inference.mapv_inplace(|score| if score < k_min_score { -1e10 } else { score });
    softmax(&mut inference.view_mut());
    let inference = inference.mapv(|score| score as f64);
    let multinomial = WeightedIndex::new(inference.view()).unwrap();
    multinomial.sample(&mut rand::thread_rng())
}
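
/// Numerically stable softmax, computed in place.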
fn softmax(tensor: &mut ArrayViewMut<f32, Ix1>) {
    // Shift by the maximum before exponentiating to avoid overflow; this
    // does not change the result.
    let max_value = *tensor.iter().max_by(|a, b| a.total_cmp(b)).unwrap();
    tensor.mapv_inplace(|value| (value - max_value).exp());
    let sum_exps = tensor.sum();
    tensor.mapv_inplace(|value| value / sum_exps);
}
#[cfg(test)]
pub mod test {
use crate::tokenizer::{self, Tokenizer};
use super::*;
const MODEL_PATH: &str = "./gpt-2-model/saved_models/gpt-2-124M.onnx";
const BPE_PATH: &str = "./gpt-2-model/saved_models/124M_vocab.bpe";
const ENCODER_PATH: &str = "./gpt-2-model/saved_models/124M_encoder.json";
const BATCH_SIZE: usize = 1;
const SEQUENCE_LENGTH: usize = 128;
    const INPUT_TEXT_STR: &str =
        "GPT-2 is a machine learning model for natural language processing;";
#[test]
fn infers_and_samples_sentence() {
        let model = Gpt2Model::new(MODEL_PATH, BATCH_SIZE, SEQUENCE_LENGTH).expect("load model");
let tokenizer = Tokenizer::new(BPE_PATH, ENCODER_PATH);
let tokens = tokenizer.encode(INPUT_TEXT_STR);
let mut all_tokens = tokens.clone();
eprintln!(" Prompt: `{}`", INPUT_TEXT_STR);
eprint!("Inference: ");
let mut full_sentence = String::from(INPUT_TEXT_STR);
for _ in 0..64 {
            let mut inference_tokens = all_tokens.clone();
            let padding = SEQUENCE_LENGTH - inference_tokens.len();
            // Right-pad the prompt up to the fixed sequence length.
            inference_tokens.resize(SEQUENCE_LENGTH, tokenizer::PAD_TOKEN);
let tensor = model.tensor_from_tokens(&[inference_tokens]);
let (inference, hidden_layers) = model.infer(tensor);
let next_token = model.tokens_from_inference(inference, &[padding])[0];
all_tokens.push(next_token);
let next_word = tokenizer.decode(vec![next_token]);
full_sentence.push_str(&next_word);
eprint!("{}", next_word);
let hidden_layer_index = model.count_layers(&hidden_layers) - 1;
if full_sentence.ends_with('.') {
eprintln!();
eprintln!(
"Final inference embedding: {:?}",
model.embeddings_from_layers(&hidden_layers, &[padding], hidden_layer_index)
);
break;
}
assert_eq!(tokenizer.decode(all_tokens.clone()), full_sentence);
}
}
}