use std::fmt::{Debug, Formatter};
use std::path::Path;
use anyhow::{anyhow, Result};
pub use sys::GenerationOptions;
use crate::tokenizer::encode_all;
use super::{sys, Config, GenerationStepResult, ScoringOptions, ScoringResult, Tokenizer};
pub struct Generator<T: Tokenizer> {
generator: sys::Generator,
tokenizer: T,
}
impl Generator<crate::tokenizers::auto::Tokenizer> {
    /// Creates a generator that loads the converted model at `path` together
    /// with an auto-detected tokenizer for that model.
    ///
    /// # Errors
    /// Returns an error if the tokenizer cannot be detected or the underlying
    /// generator cannot be created.
    pub fn new<T: AsRef<Path>>(path: T, config: &Config) -> anyhow::Result<Self> {
        let tokenizer = crate::tokenizers::auto::Tokenizer::new(&path)?;
        Self::with_tokenizer(&path, tokenizer, config)
    }
}
impl<T: Tokenizer> Generator<T> {
    /// Creates a new generator from the converted model at `path`, using the
    /// supplied `tokenizer` instead of auto-detecting one.
    ///
    /// # Errors
    /// Returns an error if the underlying CTranslate2 generator cannot be
    /// created (e.g. the model files are missing or invalid).
    pub fn with_tokenizer<U: AsRef<Path>>(
        path: U,
        tokenizer: T,
        config: &Config,
    ) -> Result<Self> {
        Ok(Generator {
            generator: sys::Generator::new(path, config)?,
            tokenizer,
        })
    }

    /// Generates text continuations for each prompt in `prompts`.
    ///
    /// Prompts are tokenized with this generator's tokenizer, passed to the
    /// underlying model, and each resulting token sequence is decoded back to
    /// a string. Returns one `(hypotheses, scores)` pair per prompt.
    ///
    /// If `callback` is given, it is invoked once per generation step; an
    /// `Err` returned from the callback stops generation early and is
    /// propagated to the caller.
    pub fn generate_batch<U, V, W, E>(
        &self,
        prompts: &[U],
        options: &GenerationOptions<V, E, W>,
        callback: Option<&mut dyn FnMut(GenerationStepResult) -> Result<()>>,
    ) -> Result<Vec<(Vec<String>, Vec<f32>)>>
    where
        U: AsRef<str>,
        V: AsRef<str>,
        W: AsRef<str>,
        E: AsRef<str>,
    {
        let output = if let Some(callback) = callback {
            // The FFI callback can only signal "stop" through its bool return
            // value, so stash the first error here and surface it afterwards.
            let mut callback_result = Ok(());
            let mut wrapped_callback = |r: sys::GenerationStepResult| -> bool {
                if let Err(e) =
                    GenerationStepResult::from_ffi(r, &self.tokenizer).and_then(|r| callback(r))
                {
                    callback_result = Err(e);
                    return true; // `true` asks the generator to stop early.
                }
                false
            };
            let output = self.generator.generate_batch(
                &encode_all(&self.tokenizer, prompts)?,
                options,
                Some(&mut wrapped_callback),
            )?;
            // Propagate a callback failure even though the FFI call itself
            // reported success.
            callback_result?;
            output
        } else {
            self.generator
                .generate_batch(&encode_all(&self.tokenizer, prompts)?, options, None)?
        };

        // The result count is known up front, so reserve once.
        let mut res = Vec::with_capacity(output.len());
        for r in output {
            let sequences = r
                .sequences
                .into_iter()
                .map(|seq| self.tokenizer.decode(seq))
                .collect::<Result<Vec<String>, _>>()
                .map_err(|err| anyhow!("failed to decode: {err}"))?;
            res.push((sequences, r.scores));
        }
        Ok(res)
    }

    /// Scores each prompt in `prompts` with the model, returning one
    /// [`ScoringResult`] per prompt.
    pub fn score_batch<U>(
        &self,
        prompts: &[U],
        options: &ScoringOptions,
    ) -> Result<Vec<ScoringResult>>
    where
        U: AsRef<str>,
    {
        self.generator
            .score_batch(&encode_all(&self.tokenizer, prompts)?, options)
    }

    /// Number of batches waiting to be processed, as reported by the
    /// underlying generator.
    #[inline]
    pub fn num_queued_batches(&self) -> Result<usize> {
        self.generator.num_queued_batches()
    }

    /// Number of batches currently being processed, as reported by the
    /// underlying generator.
    #[inline]
    pub fn num_active_batches(&self) -> Result<usize> {
        self.generator.num_active_batches()
    }

    /// Number of model replicas, as reported by the underlying generator.
    #[inline]
    pub fn num_replicas(&self) -> Result<usize> {
        self.generator.num_replicas()
    }
}
impl<T: Tokenizer> Debug for Generator<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{:?}", self.generator)
}
}
#[cfg(test)]
#[cfg(feature = "hub")]
mod tests {
    use super::Generator;
    use crate::tokenizers::auto::Tokenizer;
    use crate::{download_model, Config, Device, GenerationOptions};
    use anyhow::Result;
    use std::path::Path;

    const MODEL_ID: &str = "jkawamoto/gpt2-ct2";

    /// Builds a generator for the downloaded test model, using CUDA when the
    /// feature is enabled and falling back to the CPU otherwise.
    ///
    /// Takes `&Path` rather than `&PathBuf`; `&PathBuf` callers deref-coerce.
    fn new_generator(model_path: &Path) -> Result<Generator<Tokenizer>> {
        Generator::new(
            model_path,
            &Config {
                device: if cfg!(feature = "cuda") {
                    Device::CUDA
                } else {
                    Device::CPU
                },
                ..Default::default()
            },
        )
    }

    #[test]
    #[ignore]
    fn test_generate() {
        let model_path = download_model(MODEL_ID).unwrap();
        let g = new_generator(&model_path).unwrap();

        let prompt = "CTranslate2 is a library";
        let res = g
            .generate_batch(
                &[prompt],
                &GenerationOptions {
                    max_length: 32,
                    ..Default::default()
                },
                None,
            )
            .unwrap();
        // The model continues the prompt, so the first hypothesis must start
        // with the prompt itself.
        assert!(res[0].0[0].starts_with(prompt));
    }

    #[test]
    #[ignore]
    fn test_scoring() {
        let model_path = download_model(MODEL_ID).unwrap();
        let g = new_generator(&model_path).unwrap();

        let prompt = "CTranslate2 is a library";
        let res = g.score_batch(&[prompt], &Default::default()).unwrap();
        // Expected token split of the prompt; "Ġ" marks a token that begins
        // with a space in the GPT-2 vocabulary.
        assert_eq!(
            res[0].tokens,
            vec!["Trans", "late", "2", "Ġis", "Ġa", "Ġlibrary"]
                .into_iter()
                .map(String::from)
                .collect::<Vec<String>>()
        );
        assert_ne!(res[0].normalized_score(), 0.0);
    }

    #[test]
    #[ignore]
    fn test_generator_debug() {
        let model_path = download_model(MODEL_ID).unwrap();
        let g = new_generator(&model_path).unwrap();
        // The debug output should mention the model directory name.
        assert!(format!("{:?}", g).contains(model_path.file_name().unwrap().to_str().unwrap()));
    }
}