Trait rust_bert::pipelines::generation_utils::LanguageGenerator

pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: PrivateLanguageGenerator<T, V, U> {
    fn generate<'a, S>(
        &self,
        prompt_texts: Option<S>,
        attention_mask: Option<Tensor>,
        min_length: impl Into<Option<i64>>,
        max_length: impl Into<Option<i64>>,
        decoder_start_token_id: impl Into<Option<i64>>
    ) -> Vec<String>
    where
        S: AsRef<[&'a str]>
, { ... }
fn generate_indices<'a, S>(
        &self,
        prompt_texts: Option<S>,
        attention_mask: Option<Tensor>,
        min_length: impl Into<Option<i64>>,
        max_length: impl Into<Option<i64>>,
        decoder_start_token_id: impl Into<Option<i64>>
    ) -> Vec<Vec<i64>>
    where
        S: AsRef<[&'a str]>
, { ... }
fn generate_from_ids_and_past(
        &self,
        input_ids: Tensor,
        attention_mask: Option<Tensor>,
        min_length: impl Into<Option<i64>>,
        max_length: impl Into<Option<i64>>,
        decoder_start_token_id: impl Into<Option<i64>>
    ) -> Vec<Vec<i64>> { ... } }

Common trait for text generation models.

Main API for text generation

Provided methods

fn generate<'a, S>(
    &self,
    prompt_texts: Option<S>,
    attention_mask: Option<Tensor>,
    min_length: impl Into<Option<i64>>,
    max_length: impl Into<Option<i64>>,
    decoder_start_token_id: impl Into<Option<i64>>
) -> Vec<String> where
    S: AsRef<[&'a str]>, 

Generate text based on a vector of prompt texts.

Arguments

  • prompt_texts - Option<Vec<&str>> Optional vector of text prompts. An empty prompt to the model may be passed if the model implements a bos_id.
  • attention_mask - Option<Tensor> Optional attention mask to hide portions of the prompt.

Returns

  • Vec<String> Vector of generated strings based on the prompts. The vector has length number_of_prompts x num_return_sequences.

Example

use rust_bert::pipelines::generation_utils::{
    GPT2Generator, GenerateConfig, LanguageGenerator,
};
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
    max_length: 30,
    do_sample: true,
    num_beams: 5,
    temperature: 1.1,
    num_return_sequences: 3,
    ..Default::default()
};
let mut gpt2_generator = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let second_input_context = "The cat was";

let attention_mask = None;
let min_length = 32;
let max_length = 128;
let decoder_start_token_id = None;

let output = gpt2_generator.generate(
    Some(vec![input_context, second_input_context]),
    attention_mask,
    min_length,
    max_length,
    decoder_start_token_id,
);

Example output:

[
    "The dog's owners, however, did not want to be named. According to the lawsuit, the animal's owner, a 29-year",
    "The dog has always been part of the family. \"He was always going to be my dog and he was always looking out for me",
    "The dog has been able to stay in the home for more than three months now. \"It's a very good dog. She's",
    "The cat was discovered earlier this month in the home of a relative of the deceased. The cat\'s owner, who wished to remain anonymous,",
    "The cat was pulled from the street by two-year-old Jazmine.\"I didn't know what to do,\" she said",
    "The cat was attacked by two stray dogs and was taken to a hospital. Two other cats were also injured in the attack and are being treated."
]

fn generate_indices<'a, S>(
    &self,
    prompt_texts: Option<S>,
    attention_mask: Option<Tensor>,
    min_length: impl Into<Option<i64>>,
    max_length: impl Into<Option<i64>>,
    decoder_start_token_id: impl Into<Option<i64>>
) -> Vec<Vec<i64>> where
    S: AsRef<[&'a str]>, 

Generate token indices without decoding (useful for token-level operations before returning the final text, or as a validation step during training).

Arguments

  • prompt_texts - Option<Vec<&str>> Optional vector of text prompts. An empty prompt to the model may be passed if the model implements a bos_id.
  • attention_mask - Option<Tensor> Optional attention mask to hide portions of the prompt.

Returns

  • Vec<Vec<i64>> Vector of vectors of generated token indices based on the prompts. The outer vector has length number_of_prompts x num_return_sequences.

Example

use rust_bert::pipelines::generation_utils::{
    GPT2Generator, GenerateConfig, LanguageGenerator,
};
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
    max_length: 30,
    do_sample: true,
    num_beams: 5,
    temperature: 1.1,
    num_return_sequences: 3,
    ..Default::default()
};
let mut gpt2_generator = GPT2Generator::new(generate_config)?;
let input_context = "The dog";
let second_input_context = "The cat was";
let attention_mask = None;
let min_length = 32;
let max_length = 128;
let decoder_start_token_id = None;

let output = gpt2_generator.generate_indices(
    Some(vec![input_context, second_input_context]),
    attention_mask,
    min_length,
    max_length,
    decoder_start_token_id,
);

fn generate_from_ids_and_past(
    &self,
    input_ids: Tensor,
    attention_mask: Option<Tensor>,
    min_length: impl Into<Option<i64>>,
    max_length: impl Into<Option<i64>>,
    decoder_start_token_id: impl Into<Option<i64>>
) -> Vec<Vec<i64>>

Loading content...

Implementors

Loading content...