use std::ffi::{OsStr, OsString};
use std::fmt::{Debug, Formatter};
use std::path::Path;
use anyhow::{anyhow, Error, Result};
use cxx::UniquePtr;
use super::{
config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, ScoringOptions, ScoringResult,
VecStr, VecString, VecUSize,
};
/// Object-safe wrapper around a per-step generation callback.
///
/// It exists so that a callback can be stored as a trait object
/// (see `GenerationCallbackBox`) and handed across the FFI boundary.
trait GenerationCallback {
/// Invokes the callback with the result of one generation step and returns the
/// callback's `bool` result.
///
/// NOTE(review): the meaning of the returned flag is decided by the C++ caller and is
/// not visible in this file — presumably `true` asks generation to stop; confirm.
fn execute(&mut self, res: GenerationStepResult) -> bool;
}
/// Blanket implementation: every `FnMut(GenerationStepResult) -> bool` callable can be
/// used as a `GenerationCallback`.
impl<F> GenerationCallback for F
where
    F: FnMut(GenerationStepResult) -> bool,
{
    fn execute(&mut self, step: GenerationStepResult) -> bool {
        // Forward the step result to the wrapped callable unchanged.
        (*self)(step)
    }
}
/// Boxed, type-erased generation callback. This is the opaque Rust type that the bridge
/// module declares in its `extern "Rust"` section.
type GenerationCallbackBox<'a> = Box<dyn GenerationCallback + 'a>;
/// Builds the boxed callback from an optional user-supplied closure.
///
/// When no closure is given, a placeholder that ignores its argument and always returns
/// `false` is boxed instead, so a callback object is always available.
impl<'a> From<Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>>
    for GenerationCallbackBox<'a>
{
    fn from(opt: Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>) -> Self {
        if let Some(callback) = opt {
            // A `&mut dyn FnMut` is itself `FnMut`, so it can be boxed directly.
            return Box::new(callback);
        }
        Box::new(|_: GenerationStepResult| false)
    }
}
/// Runs the boxed callback `f` with `arg` and returns the callback's result.
///
/// This free function is the one declared in the bridge's `extern "Rust"` section.
fn execute_generation_callback(f: &mut GenerationCallbackBox, arg: GenerationStepResult) -> bool {
    // Dereference through the reference and the box to reach the trait object itself.
    (**f).execute(arg)
}
// FFI bridge between this module and the C++ generator wrapper declared in
// `ct2rs/include/generator.h`.
#[cxx::bridge]
mod ffi {
// Generation options in the borrowed form handed to C++.
// Built from the public `GenerationOptions` by its `to_ffi` method; string data is
// borrowed for `'a` rather than copied.
struct GenerationOptions<'a> {
beam_size: usize,
patience: f32,
length_penalty: f32,
repetition_penalty: f32,
no_repeat_ngram_size: usize,
disable_unk: bool,
suppress_sequences: Vec<VecStr<'a>>,
end_token: Vec<&'a str>,
return_end_token: bool,
max_length: usize,
min_length: usize,
sampling_topk: usize,
sampling_topp: f32,
sampling_temperature: f32,
num_hypotheses: usize,
return_scores: bool,
return_logits_vocab: bool,
return_alternatives: bool,
min_alternative_expansion_prob: f32,
static_prompt: Vec<&'a str>,
cache_static_prompt: bool,
include_prompt_in_result: bool,
max_batch_size: usize,
batch_type: BatchType,
}
// Raw element of the vector returned by `generate_batch`; converted into the public
// `GenerationResult` through its `From` implementation.
struct GenerationResult {
sequences: Vec<VecString>,
sequences_ids: Vec<VecUSize>,
scores: Vec<f32>,
}
// Rust items made available to C++: the opaque callback box and the function used to
// invoke it.
extern "Rust" {
type GenerationCallbackBox<'a>;
fn execute_generation_callback(
f: &mut GenerationCallbackBox,
arg: GenerationStepResult,
) -> bool;
}
unsafe extern "C++" {
include!("ct2rs/include/generator.h");
include!("ct2rs/src/sys/types.rs.h");
include!("ct2rs/src/sys/scoring.rs.h");
// Types defined in the parent module and reused here by alias.
type VecString = super::VecString;
type VecStr<'a> = super::VecStr<'a>;
type VecUSize = super::VecUSize;
type Config = super::config::ffi::Config;
type BatchType = super::BatchType;
type GenerationStepResult = super::GenerationStepResult;
type ScoringOptions = super::ScoringOptions;
type ScoringResult = super::ScoringResult;
// Opaque C++ generator type; only used behind `UniquePtr` or a reference.
type Generator;
// Creates a generator for the model at `model_path`, consuming `config`.
fn generator(model_path: &str, config: UniquePtr<Config>) -> Result<UniquePtr<Generator>>;
// The Rust wrapper always passes a callback object and sets `has_callback` from
// whether the caller actually supplied one; how C++ uses the flag is not visible here.
fn generate_batch(
self: &Generator,
start_tokens: &Vec<VecStr>,
options: &GenerationOptions,
has_callback: bool,
callback: &mut GenerationCallbackBox,
) -> Result<Vec<GenerationResult>>;
fn score_batch(
self: &Generator,
tokens: &Vec<VecStr>,
options: &ScoringOptions,
) -> Result<Vec<ScoringResult>>;
fn num_queued_batches(self: &Generator) -> Result<usize>;
fn num_active_batches(self: &Generator) -> Result<usize>;
fn num_replicas(self: &Generator) -> Result<usize>;
}
}
// Marks the opaque C++ generator as sendable to and shareable between threads.
// NOTE(review): the soundness of these impls rests on the C++ generator being safe to use
// from multiple threads, which cannot be verified from this file — confirm against the
// C++ implementation.
unsafe impl Send for ffi::Generator {}
unsafe impl Sync for ffi::Generator {}
/// Safe handle to a generator created through the FFI bridge.
pub struct Generator {
// Last component of the model path passed to `new` (empty if the path has none);
// shown in the `Debug` output.
model: OsString,
// Owning pointer to the C++ generator object.
ptr: UniquePtr<ffi::Generator>,
}
impl Generator {
pub fn new<T: AsRef<Path>>(model_path: T, config: &Config) -> Result<Generator> {
let model_path = model_path.as_ref();
Ok(Generator {
model: model_path
.file_name()
.map(OsStr::to_os_string)
.unwrap_or_default(),
ptr: ffi::generator(
model_path
.to_str()
.ok_or_else(|| anyhow!("invalid path: {}", model_path.display()))?,
config.to_ffi(),
)?,
})
}
pub fn generate_batch<T: AsRef<str>, U: AsRef<str>, V: AsRef<str>, W: AsRef<str>>(
&self,
start_tokens: &[Vec<T>],
options: &GenerationOptions<U, V, W>,
callback: Option<&mut dyn FnMut(GenerationStepResult) -> bool>,
) -> Result<Vec<GenerationResult>> {
Ok(self
.ptr
.generate_batch(
&vec_ffi_vecstr(start_tokens),
&options.to_ffi(),
callback.is_some(),
&mut GenerationCallbackBox::from(callback),
)?
.into_iter()
.map(GenerationResult::from)
.collect())
}
pub fn score_batch<T: AsRef<str>>(
&self,
tokens: &[Vec<T>],
options: &ScoringOptions,
) -> Result<Vec<ScoringResult>> {
self.ptr
.score_batch(&vec_ffi_vecstr(tokens), options)
.map_err(Error::from)
}
#[inline]
pub fn num_queued_batches(&self) -> Result<usize> {
self.ptr.num_queued_batches().map_err(Error::from)
}
#[inline]
pub fn num_active_batches(&self) -> Result<usize> {
self.ptr.num_active_batches().map_err(Error::from)
}
#[inline]
pub fn num_replicas(&self) -> Result<usize> {
self.ptr.num_replicas().map_err(Error::from)
}
}
/// Debug output lists the model name and the three counters. Each counter is formatted
/// as the `Result` returned by its accessor, so a failing query shows up as `Err(..)`.
impl Debug for Generator {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Generator");
        out.field("model", &self.model);
        out.field("queued_batches", &self.num_queued_batches());
        out.field("active_batches", &self.num_active_batches());
        out.field("replicas", &self.num_replicas());
        out.finish()
    }
}
// Windows-only teardown that bypasses `UniquePtr`'s own destructor.
// NOTE(review): the reason for the Windows special case is not visible in this file —
// presumably running the normal C++ destructor is problematic there; confirm before
// changing this.
#[cfg(target_os = "windows")]
impl Drop for Generator {
fn drop(&mut self) {
// Swap a null pointer into the field so that the field's automatic drop, which runs
// after this function, has nothing left to release.
let ptr = std::mem::replace(&mut self.ptr, UniquePtr::null());
unsafe {
// `into_raw` gives up ownership, so `UniquePtr`'s destructor never runs for this
// object; the pointee is only dropped in place as a Rust value.
// NOTE(review): this looks like it intentionally leaks the C++ object — confirm.
std::ptr::drop_in_place(ptr.into_raw());
}
}
}
/// Options passed to `Generator::generate_batch`.
///
/// The three type parameters are the string-like element types of the token-carrying
/// fields: `T` for `suppress_sequences`, `U` for `end_token`, and `V` for `static_prompt`.
/// A `Default` implementation is provided for `GenerationOptions<String, String, String>`.
///
/// NOTE(review): the fields mirror the FFI struct of the same name one-to-one; their
/// exact semantics are defined by the C++ library and are not visible in this file —
/// consult the upstream generation-options documentation.
#[derive(Clone, Debug)]
pub struct GenerationOptions<T: AsRef<str>, U: AsRef<str>, V: AsRef<str>> {
pub beam_size: usize,
pub patience: f32,
pub length_penalty: f32,
pub repetition_penalty: f32,
pub no_repeat_ngram_size: usize,
pub disable_unk: bool,
// Each inner vector is one token sequence.
pub suppress_sequences: Vec<Vec<T>>,
pub end_token: Vec<U>,
pub return_end_token: bool,
pub max_length: usize,
pub min_length: usize,
pub sampling_topk: usize,
pub sampling_topp: f32,
pub sampling_temperature: f32,
pub num_hypotheses: usize,
pub return_scores: bool,
pub return_logits_vocab: bool,
pub return_alternatives: bool,
pub min_alternative_expansion_prob: f32,
pub static_prompt: Vec<V>,
pub cache_static_prompt: bool,
pub include_prompt_in_result: bool,
pub max_batch_size: usize,
pub batch_type: BatchType,
}
impl Default for GenerationOptions<String, String, String> {
fn default() -> Self {
Self {
beam_size: 1,
patience: 1.,
length_penalty: 1.,
repetition_penalty: 1.,
no_repeat_ngram_size: 0,
disable_unk: false,
suppress_sequences: vec![],
end_token: vec![],
return_end_token: false,
max_length: 512,
min_length: 0,
sampling_topk: 1,
sampling_topp: 1.,
sampling_temperature: 1.,
num_hypotheses: 1,
return_scores: false,
return_logits_vocab: false,
return_alternatives: false,
min_alternative_expansion_prob: 0.,
static_prompt: vec![],
cache_static_prompt: true,
include_prompt_in_result: true,
max_batch_size: 0,
batch_type: Default::default(),
}
}
}
impl<T: AsRef<str>, U: AsRef<str>, V: AsRef<str>> GenerationOptions<T, U, V> {
    /// Converts these options into the borrowed FFI representation.
    ///
    /// Scalar fields are copied; the token-carrying fields are turned into vectors of
    /// `&str` that borrow from `self`.
    #[inline]
    fn to_ffi(&self) -> ffi::GenerationOptions {
        // Build the borrowed views first, then assemble the FFI struct.
        let suppress_sequences = vec_ffi_vecstr(self.suppress_sequences.as_slice());
        let end_token: Vec<&str> = self.end_token.iter().map(U::as_ref).collect();
        let static_prompt: Vec<&str> = self.static_prompt.iter().map(V::as_ref).collect();

        ffi::GenerationOptions {
            beam_size: self.beam_size,
            patience: self.patience,
            length_penalty: self.length_penalty,
            repetition_penalty: self.repetition_penalty,
            no_repeat_ngram_size: self.no_repeat_ngram_size,
            disable_unk: self.disable_unk,
            suppress_sequences,
            end_token,
            return_end_token: self.return_end_token,
            max_length: self.max_length,
            min_length: self.min_length,
            sampling_topk: self.sampling_topk,
            sampling_topp: self.sampling_topp,
            sampling_temperature: self.sampling_temperature,
            num_hypotheses: self.num_hypotheses,
            return_scores: self.return_scores,
            return_logits_vocab: self.return_logits_vocab,
            return_alternatives: self.return_alternatives,
            min_alternative_expansion_prob: self.min_alternative_expansion_prob,
            static_prompt,
            cache_static_prompt: self.cache_static_prompt,
            include_prompt_in_result: self.include_prompt_in_result,
            max_batch_size: self.max_batch_size,
            batch_type: self.batch_type,
        }
    }
}
/// One element of the vector returned by `Generator::generate_batch`, with the FFI
/// container types converted to plain `Vec`s.
#[derive(Clone, Debug)]
pub struct GenerationResult {
/// Generated sequences, each as a list of string tokens.
pub sequences: Vec<Vec<String>>,
/// Generated sequences, each as a list of numeric ids.
pub sequences_ids: Vec<Vec<usize>>,
/// Scores; may be empty (see `has_scores`).
pub scores: Vec<f32>,
}
/// Converts the raw FFI result into the public result type by unwrapping the FFI
/// container types into plain vectors; `scores` is moved over unchanged.
impl From<ffi::GenerationResult> for GenerationResult {
    fn from(res: ffi::GenerationResult) -> Self {
        let ffi::GenerationResult {
            sequences,
            sequences_ids,
            scores,
        } = res;
        let sequences: Vec<Vec<String>> = sequences.into_iter().map(Into::into).collect();
        let sequences_ids: Vec<Vec<usize>> = sequences_ids.into_iter().map(Into::into).collect();
        Self {
            sequences,
            sequences_ids,
            scores,
        }
    }
}
impl GenerationResult {
    /// Returns how many sequences this result holds.
    #[inline]
    pub fn num_sequences(&self) -> usize {
        let Self { sequences, .. } = self;
        sequences.len()
    }

    /// Returns `true` when at least one score is present.
    #[inline]
    pub fn has_scores(&self) -> bool {
        self.scores.first().is_some()
    }
}
#[cfg(test)]
mod tests {
    use super::ffi::{VecStr, VecString, VecUSize};
    use super::{ffi, GenerationOptions, GenerationResult};

    /// Every field of the FFI options must carry the value of the matching public field.
    #[test]
    fn options_to_ffi() {
        let opts = GenerationOptions {
            suppress_sequences: vec![vec!["x".to_string(), "y".to_string(), "z".to_string()]],
            end_token: vec!["1".to_string(), "2".to_string()],
            static_prompt: vec!["one".to_string(), "two".to_string()],
            ..Default::default()
        };
        let res = opts.to_ffi();

        assert_eq!(res.beam_size, opts.beam_size);
        assert_eq!(res.patience, opts.patience);
        assert_eq!(res.length_penalty, opts.length_penalty);
        assert_eq!(res.repetition_penalty, opts.repetition_penalty);
        assert_eq!(res.no_repeat_ngram_size, opts.no_repeat_ngram_size);
        assert_eq!(res.disable_unk, opts.disable_unk);
        assert_eq!(
            res.suppress_sequences,
            opts.suppress_sequences
                .iter()
                .map(|v| VecStr {
                    v: v.iter().map(AsRef::as_ref).collect()
                })
                .collect::<Vec<VecStr>>()
        );
        assert_eq!(
            res.end_token,
            opts.end_token
                .iter()
                .map(String::as_str)
                .collect::<Vec<&str>>()
        );
        assert_eq!(res.return_end_token, opts.return_end_token);
        assert_eq!(res.max_length, opts.max_length);
        assert_eq!(res.min_length, opts.min_length);
        assert_eq!(res.sampling_topk, opts.sampling_topk);
        assert_eq!(res.sampling_topp, opts.sampling_topp);
        assert_eq!(res.sampling_temperature, opts.sampling_temperature);
        assert_eq!(res.num_hypotheses, opts.num_hypotheses);
        assert_eq!(res.return_scores, opts.return_scores);
        // Previously unchecked: a wrong mapping of this field would have gone unnoticed.
        assert_eq!(res.return_logits_vocab, opts.return_logits_vocab);
        assert_eq!(res.return_alternatives, opts.return_alternatives);
        assert_eq!(
            res.min_alternative_expansion_prob,
            opts.min_alternative_expansion_prob
        );
        assert_eq!(
            res.static_prompt,
            opts.static_prompt
                .iter()
                .map(String::as_str)
                .collect::<Vec<&str>>()
        );
        assert_eq!(res.cache_static_prompt, opts.cache_static_prompt);
        assert_eq!(res.include_prompt_in_result, opts.include_prompt_in_result);
        assert_eq!(res.max_batch_size, opts.max_batch_size);
        assert_eq!(res.batch_type, opts.batch_type);
    }

    /// A populated FFI result converts field-by-field into the public result.
    #[test]
    fn generation_result() {
        let sequences = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["x".to_string(), "y".to_string(), "z".to_string()],
        ];
        let sequences_ids: Vec<Vec<usize>> = vec![vec![1, 2], vec![10, 20, 30]];
        let scores: Vec<f32> = vec![1., 2., 3.];
        let res: GenerationResult = ffi::GenerationResult {
            sequences: sequences
                .iter()
                .map(|v| VecString::from(v.clone()))
                .collect(),
            sequences_ids: sequences_ids
                .iter()
                .map(|v| VecUSize::from(v.clone()))
                .collect(),
            scores: scores.clone(),
        }
        .into();

        assert_eq!(res.sequences, sequences);
        assert_eq!(res.sequences_ids, sequences_ids);
        assert_eq!(res.scores, scores);
        assert_eq!(res.num_sequences(), sequences.len());
        assert!(res.has_scores());
    }

    /// An empty FFI result converts into an empty public result without scores.
    #[test]
    fn generation_empty_result() {
        let res: GenerationResult = ffi::GenerationResult {
            sequences: vec![],
            sequences_ids: vec![],
            scores: vec![],
        }
        .into();

        assert!(res.sequences.is_empty());
        assert!(res.sequences_ids.is_empty());
        assert!(res.scores.is_empty());
        assert_eq!(res.num_sequences(), 0);
        assert!(!res.has_scores());
    }

    #[cfg(feature = "hub")]
    mod hub {
        use crate::sys::Generator;
        use crate::{download_model, Config, Device};

        const MODEL_ID: &str = "jkawamoto/gpt2-ct2";

        /// The `Debug` output of a generator must contain the model directory name.
        /// Ignored by default because it downloads a model.
        #[test]
        #[ignore]
        fn test_generator_debug() {
            let model_path = download_model(MODEL_ID).unwrap();
            let generator = Generator::new(
                &model_path,
                &Config {
                    device: if cfg!(feature = "cuda") {
                        Device::CUDA
                    } else {
                        Device::CPU
                    },
                    ..Default::default()
                },
            )
            .unwrap();

            assert!(format!("{:?}", generator)
                .contains(model_path.file_name().unwrap().to_str().unwrap()));
        }
    }
}