pub struct ConditionalGenerationPipeline<'a> { /* private fields */ }
Expand description

Wraps a Hugging Face Optimum pipeline exported to ONNX with the `causal-lm` task.

Note: this pipeline does not add any special tokens to the input text. If you need special tokens, include them in the prompt yourself.

Export documentation: https://huggingface.co/docs/optimum/exporters/onnx/usage_guides/export_a_model

Example

use std::fs;
use ort::{GraphOptimizationLevel, LoggingLevel};
use ort::environment::Environment;
use edge_transformers::{ConditionalGenerationPipeline, TopKSampler, Device};

let environment = Environment::builder()
  .with_name("test")
 .with_log_level(LoggingLevel::Verbose)
.build()
.unwrap();

let sampler = TopKSampler::new(50, 0.9);
let pipeline = ConditionalGenerationPipeline::from_pretrained(
 environment.into_arc(),
"optimum/gpt2".to_string(),
Device::CPU,
GraphOptimizationLevel::Level3,
).unwrap();

let input = "This is a test";

println!("{}", pipeline.generate(input, 10, &sampler).unwrap());

Implementations§

source§

impl<'a> ConditionalGenerationPipeline<'a>

source

pub fn from_pretrained( env: Arc<Environment>, model_id: String, device: Device, optimization_level: GraphOptimizationLevel ) -> Result<Self, Error>

source

pub fn new_from_memory( environment: Arc<Environment>, model: &'a [u8], tokenizer_config: String, special_tokens_map: String, device: Device, optimization_level: GraphOptimizationLevel ) -> Result<Self, Error>

Creates a new pipeline from ONNX model bytes, a tokenizer configuration, and a special-tokens map.

source

pub fn new_from_files( environment: Arc<Environment>, model: PathBuf, tokenizer_config: PathBuf, special_tokens_map: PathBuf, device: Device, optimization_level: GraphOptimizationLevel ) -> Result<Self, Error>

Creates a new pipeline from an ONNX model file, a tokenizer configuration file, and a special-tokens map file.

source

pub fn generate<'sampler>( &self, prompt: &str, max_length: i32, sampler: &'sampler dyn Sampler ) -> Result<String, Error>

source

pub fn generate_batch<'sampler>( &self, prompt: Vec<String>, max_length: i32, sampler: &'sampler dyn Sampler ) -> Result<Vec<String>, Error>

Generates text for each prompt in the input batch, returning one output string per prompt.

Auto Trait Implementations§

Blanket Implementations§

source§

impl<T> Any for T where T: 'static + ?Sized,

source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
source§

impl<T> Borrow<T> for T where T: ?Sized,

const: unstable · source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
source§

impl<T> BorrowMut<T> for T where T: ?Sized,

const: unstable · source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
source§

impl<T> From<T> for T

const: unstable · source§

fn from(t: T) -> T

Returns the argument unchanged.

source§

impl<T> Instrument for T

source§

fn instrument(self, span: Span) -> Instrumented<Self>

Instruments this type with the provided Span, returning an Instrumented wrapper. Read more
source§

fn in_current_span(self) -> Instrumented<Self>

Instruments this type with the current Span, returning an Instrumented wrapper. Read more
source§

impl<T, U> Into<U> for T where U: From<T>,

const: unstable · source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

§

impl<T> Pointable for T

§

const ALIGN: usize = mem::align_of::<T>()

The alignment of the pointer.
§

type Init = T

The type for initializers.
§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a value with the given initializer. Read more
§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer. Read more
§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer. Read more
§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer. Read more
source§

impl<T> Same<T> for T

§

type Output = T

Should always be Self
source§

impl<T, U> TryFrom<U> for T where U: Into<T>,

§

type Error = Infallible

The type returned in the event of a conversion error.
const: unstable · source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
source§

impl<T, U> TryInto<U> for T where U: TryFrom<T>,

§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
const: unstable · source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
§

impl<V, T> VZip<V> for T where V: MultiLane<T>,

§

fn vzip(self) -> V

source§

impl<T> WithSubscriber for T

source§

fn with_subscriber<S>(self, subscriber: S) -> WithDispatch<Self> where S: Into<Dispatch>,

Attaches the provided Subscriber to this type, returning a WithDispatch wrapper. Read more
source§

fn with_current_subscriber(self) -> WithDispatch<Self>

Attaches the current default Subscriber to this type, returning a WithDispatch wrapper. Read more