Struct edge_transformers::pipelines::conditional_generation_with_pkvs::ConditionalGenerationPipelineWithPKVs
pub struct ConditionalGenerationPipelineWithPKVs<'a> { /* private fields */ }
Wraps a Hugging Face Optimum pipeline exported to ONNX with the `causal-lm-with-past` task.
Note: this pipeline does not add any special tokens to the input text. If you need special tokens, include them directly in the prompt.
Export docs: https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model
Example
use std::fs;
use ort::environment::Environment;
use ort::{GraphOptimizationLevel, LoggingLevel};
use edge_transformers::{ConditionalGenerationPipelineWithPKVs, TopKSampler, Device};

// Configure the environment: named "test", with verbose logging.
let env_builder = Environment::builder()
    .with_name("test")
    .with_log_level(LoggingLevel::Verbose);
let env = env_builder.build().unwrap();

// NOTE(review): the two arguments are presumably the top-k cutoff and the
// sampling temperature — confirm against `TopKSampler::new`.
let top_k_sampler = TopKSampler::new(50, 0.9);

// Load the pipeline for "optimum/gpt2" on the CPU with Level3 graph optimization.
let generator = ConditionalGenerationPipelineWithPKVs::from_pretrained(
    env.into_arc(),
    "optimum/gpt2".to_string(),
    Device::CPU,
    GraphOptimizationLevel::Level3,
)
.unwrap();

// No special tokens are added to the prompt automatically; include them
// here if the model needs them.
let prompt = "This is a test";

// NOTE(review): `10` is presumably the generation length limit — confirm
// against `generate`.
let completion = generator.generate(prompt, 10, &top_k_sampler).unwrap();
println!("{}", completion);