use std::ffi::{OsStr, OsString};
use std::fmt::{Debug, Formatter};
use std::path::Path;
use anyhow::{anyhow, Error, Result};
use cxx::UniquePtr;
use super::{config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, VecStr, VecString};
/// Internal abstraction over a per-step generation callback.
///
/// It is the trait-object type stored inside [`TranslationCallbackBox`].
trait GenerationCallback {
/// Handles one [`GenerationStepResult`] and returns a flag to the caller.
// NOTE(review): the meaning of the returned flag is decided by the C++ side of the bridge and
// is not visible in this file; presumably `true` asks to stop the generation — confirm.
fn execute(&mut self, res: GenerationStepResult) -> bool;
}
/// Blanket implementation: any `FnMut(GenerationStepResult) -> bool` callable can be used as a
/// [`GenerationCallback`].
impl<F> GenerationCallback for F
where
    F: FnMut(GenerationStepResult) -> bool,
{
    fn execute(&mut self, res: GenerationStepResult) -> bool {
        // Call the wrapped callable and pass its answer through unchanged.
        (*self)(res)
    }
}
/// Boxed, type-erased [`GenerationCallback`] that may borrow data for `'a`.
///
/// This alias is the Rust type declared in the `extern "Rust"` section of the bridge in this
/// file, so it is the form in which a callback is handed to the C++ side.
type TranslationCallbackBox<'a> = Box<dyn GenerationCallback + 'a>;
/// Converts an optional user callback into the boxed form expected by the bridge.
///
/// `Some(callback)` is boxed as is; `None` becomes a placeholder closure that ignores its
/// argument and always returns `false`.
impl<'a> From<Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>>
    for TranslationCallbackBox<'a>
{
    fn from(opt: Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>) -> Self {
        if let Some(user_callback) = opt {
            // `&mut dyn FnMut(..)` is itself `FnMut(..)`, so the blanket impl applies to it.
            let boxed: TranslationCallbackBox<'a> = Box::new(user_callback);
            return boxed;
        }

        let placeholder: TranslationCallbackBox<'a> = Box::new(|_: GenerationStepResult| false);
        placeholder
    }
}
/// Runs the boxed callback once with `arg` and returns its answer.
///
/// The function is declared in the `extern "Rust"` section of the bridge in this file, which
/// makes it callable from the C++ side.
fn execute_translation_callback(f: &mut TranslationCallbackBox, arg: GenerationStepResult) -> bool {
    // Reach through `&mut Box<dyn _>` explicitly to call the trait object.
    (**f).execute(arg)
}
// cxx bridge between the Rust wrapper in this file and the C++ translator glue code.
#[cxx::bridge]
mod ffi {
/// Translation options in the form shared with C++; string data is borrowed for `'a`.
///
/// The public `TranslationOptions` in the parent module builds this value field by field
/// in its `to_ffi` method.
struct TranslationOptions<'a> {
beam_size: usize,
patience: f32,
length_penalty: f32,
coverage_penalty: f32,
repetition_penalty: f32,
no_repeat_ngram_size: usize,
disable_unk: bool,
suppress_sequences: Vec<VecStr<'a>>,
prefix_bias_beta: f32,
end_token: Vec<&'a str>,
return_end_token: bool,
max_input_length: usize,
max_decoding_length: usize,
min_decoding_length: usize,
sampling_topk: usize,
sampling_topp: f32,
sampling_temperature: f32,
use_vmap: bool,
num_hypotheses: usize,
return_scores: bool,
return_attention: bool,
return_logits_vocab: bool,
return_alternatives: bool,
min_alternative_expansion_prob: f32,
replace_unknowns: bool,
max_batch_size: usize,
batch_type: BatchType,
}
/// One translation result in the form shared with C++.
///
/// The parent module converts it into its public `TranslationResult`.
struct TranslationResult {
hypotheses: Vec<VecString>,
scores: Vec<f32>,
}
// Items implemented in Rust (in the parent module) and exposed to C++.
extern "Rust" {
type TranslationCallbackBox<'a>;
fn execute_translation_callback(
f: &mut TranslationCallbackBox,
arg: GenerationStepResult,
) -> bool;
}
// Items implemented in C++ and exposed to Rust.
unsafe extern "C++" {
include!("ct2rs/include/translator.h");
include!("ct2rs/src/sys/types.rs.h");
// Types declared elsewhere in this crate and reused in the signatures below.
type VecString = super::VecString;
type VecStr<'a> = super::VecStr<'a>;
type Config = super::config::ffi::Config;
type BatchType = super::BatchType;
type GenerationStepResult = super::GenerationStepResult;
// Opaque C++ translator type; Rust only handles it behind a `UniquePtr` or a reference.
type Translator;
// Creates a translator from a model path and a configuration. The `Result` return type
// makes cxx report a thrown C++ exception as `Err`.
fn translator(model_path: &str, config: UniquePtr<Config>)
-> Result<UniquePtr<Translator>>;
// Translates a batch of tokenized sources. The wrapper in the parent module passes
// `has_callback = callback.is_some()`; `callback` is always a valid box, which is a
// placeholder when the user supplied none.
fn translate_batch(
self: &Translator,
source: &Vec<VecStr>,
options: &TranslationOptions,
has_callback: bool,
callback: &mut TranslationCallbackBox,
) -> Result<Vec<TranslationResult>>;
// Same as `translate_batch`, with an additional `target_prefix` batch.
fn translate_batch_with_target_prefix(
self: &Translator,
source: &Vec<VecStr>,
target_prefix: &Vec<VecStr>,
options: &TranslationOptions,
has_callback: bool,
callback: &mut TranslationCallbackBox,
) -> Result<Vec<TranslationResult>>;
// Counters queried from the C++ translator; each call is fallible.
fn num_queued_batches(self: &Translator) -> Result<usize>;
fn num_active_batches(self: &Translator) -> Result<usize>;
fn num_replicas(self: &Translator) -> Result<usize>;
}
}
// SAFETY: NOTE(review) — these impls assert that the opaque C++ `Translator` may be moved to
// and shared between threads. Nothing in this file can verify that; soundness relies on the
// thread-safety of the underlying C++ class — confirm against the C++ implementation.
unsafe impl Send for ffi::Translator {}
unsafe impl Sync for ffi::Translator {}
/// Options passed along with a translation request.
///
/// `T` is the string type of the tokens in `suppress_sequences` and `U` is the string type of
/// the tokens in `end_token`. The translation methods of `Translator` convert this struct
/// field by field into its bridge counterpart before calling into C++.
///
/// The "Default" notes on the fields are the values produced by
/// `TranslationOptions::<String, String>::default()`.
///
/// NOTE(review): the field names appear to mirror the translation options of the underlying
/// C++ library (CTranslate2); consult its documentation for the exact meaning of each option.
#[derive(Clone, Debug)]
pub struct TranslationOptions<T: AsRef<str>, U: AsRef<str>> {
/// Default: `2`.
pub beam_size: usize,
/// Default: `1.0`.
pub patience: f32,
/// Default: `1.0`.
pub length_penalty: f32,
/// Default: `0.0`.
pub coverage_penalty: f32,
/// Default: `1.0`.
pub repetition_penalty: f32,
/// Default: `0`.
pub no_repeat_ngram_size: usize,
/// Default: `false`.
pub disable_unk: bool,
/// Sequences of tokens, one inner vector per sequence. Default: empty.
pub suppress_sequences: Vec<Vec<T>>,
/// Default: `0.0`.
pub prefix_bias_beta: f32,
/// List of tokens. Default: empty.
pub end_token: Vec<U>,
/// Default: `false`.
pub return_end_token: bool,
/// Default: `1024`.
pub max_input_length: usize,
/// Default: `256`.
pub max_decoding_length: usize,
/// Default: `1`.
pub min_decoding_length: usize,
/// Default: `1`.
pub sampling_topk: usize,
/// Default: `1.0`.
pub sampling_topp: f32,
/// Default: `1.0`.
pub sampling_temperature: f32,
/// Default: `false`.
pub use_vmap: bool,
/// Default: `1`.
pub num_hypotheses: usize,
/// Default: `false`.
pub return_scores: bool,
/// Default: `false`.
pub return_attention: bool,
/// Default: `false`.
pub return_logits_vocab: bool,
/// Default: `false`.
pub return_alternatives: bool,
/// Default: `0.0`.
pub min_alternative_expansion_prob: f32,
/// Default: `false`.
pub replace_unknowns: bool,
/// Default: `0`.
pub max_batch_size: usize,
/// Default: `BatchType::default()`.
pub batch_type: BatchType,
}
impl Default for TranslationOptions<String, String> {
fn default() -> Self {
Self {
beam_size: 2,
patience: 1.,
length_penalty: 1.,
coverage_penalty: 0.,
repetition_penalty: 1.,
no_repeat_ngram_size: 0,
disable_unk: false,
suppress_sequences: vec![],
prefix_bias_beta: 0.,
end_token: vec![],
return_end_token: false,
max_input_length: 1024,
max_decoding_length: 256,
min_decoding_length: 1,
sampling_topk: 1,
sampling_topp: 1.,
sampling_temperature: 1.,
use_vmap: false,
num_hypotheses: 1,
return_scores: false,
return_attention: false,
return_logits_vocab: false,
return_alternatives: false,
min_alternative_expansion_prob: 0.,
replace_unknowns: false,
max_batch_size: 0,
batch_type: BatchType::default(),
}
}
}
impl<T: AsRef<str>, U: AsRef<str>> TranslationOptions<T, U> {
    /// Builds the bridge representation of these options.
    ///
    /// Scalar fields are copied. `suppress_sequences` and `end_token` are converted into
    /// borrowed string views, so the returned value borrows from `self`.
    fn to_ffi(&self) -> ffi::TranslationOptions {
        // Exhaustive destructuring: a field added to the struct later cannot be forgotten here.
        let Self {
            beam_size,
            patience,
            length_penalty,
            coverage_penalty,
            repetition_penalty,
            no_repeat_ngram_size,
            disable_unk,
            suppress_sequences,
            prefix_bias_beta,
            end_token,
            return_end_token,
            max_input_length,
            max_decoding_length,
            min_decoding_length,
            sampling_topk,
            sampling_topp,
            sampling_temperature,
            use_vmap,
            num_hypotheses,
            return_scores,
            return_attention,
            return_logits_vocab,
            return_alternatives,
            min_alternative_expansion_prob,
            replace_unknowns,
            max_batch_size,
            batch_type,
        } = self;

        // The two string-carrying fields need a conversion; everything else is a plain copy.
        let suppress_sequences = vec_ffi_vecstr(suppress_sequences.as_slice());
        let end_token: Vec<&str> = end_token.iter().map(|token| token.as_ref()).collect();

        ffi::TranslationOptions {
            beam_size: *beam_size,
            patience: *patience,
            length_penalty: *length_penalty,
            coverage_penalty: *coverage_penalty,
            repetition_penalty: *repetition_penalty,
            no_repeat_ngram_size: *no_repeat_ngram_size,
            disable_unk: *disable_unk,
            suppress_sequences,
            prefix_bias_beta: *prefix_bias_beta,
            end_token,
            return_end_token: *return_end_token,
            max_input_length: *max_input_length,
            max_decoding_length: *max_decoding_length,
            min_decoding_length: *min_decoding_length,
            sampling_topk: *sampling_topk,
            sampling_topp: *sampling_topp,
            sampling_temperature: *sampling_temperature,
            use_vmap: *use_vmap,
            num_hypotheses: *num_hypotheses,
            return_scores: *return_scores,
            return_attention: *return_attention,
            return_logits_vocab: *return_logits_vocab,
            return_alternatives: *return_alternatives,
            min_alternative_expansion_prob: *min_alternative_expansion_prob,
            replace_unknowns: *replace_unknowns,
            max_batch_size: *max_batch_size,
            batch_type: *batch_type,
        }
    }
}
/// Handle to a translator object that lives on the C++ side of the bridge.
pub struct Translator {
// File-name component of the model path given to `new` (empty when the path has none).
// It is only used in the `Debug` output.
model: OsString,
// Owning pointer to the underlying C++ translator.
ptr: UniquePtr<ffi::Translator>,
}
impl Translator {
pub fn new<T: AsRef<Path>>(model_path: T, config: &Config) -> Result<Translator> {
let model_path = model_path.as_ref();
Ok(Translator {
model: model_path
.file_name()
.map(OsStr::to_os_string)
.unwrap_or_default(),
ptr: ffi::translator(
model_path
.to_str()
.ok_or_else(|| anyhow!("invalid path: {}", model_path.display()))?,
config.to_ffi(),
)?,
})
}
pub fn translate_batch<T, U, V>(
&self,
source: &[Vec<T>],
options: &TranslationOptions<U, V>,
callback: Option<&mut dyn FnMut(GenerationStepResult) -> bool>,
) -> Result<Vec<TranslationResult>>
where
T: AsRef<str>,
U: AsRef<str>,
V: AsRef<str>,
{
Ok(self
.ptr
.translate_batch(
&vec_ffi_vecstr(source),
&options.to_ffi(),
callback.is_some(),
&mut TranslationCallbackBox::from(callback),
)?
.into_iter()
.map(TranslationResult::from)
.collect())
}
pub fn translate_batch_with_target_prefix<T, U, V, W>(
&self,
source: &[Vec<T>],
target_prefix: &[Vec<U>],
options: &TranslationOptions<V, W>,
callback: Option<&mut dyn FnMut(GenerationStepResult) -> bool>,
) -> Result<Vec<TranslationResult>>
where
T: AsRef<str>,
U: AsRef<str>,
V: AsRef<str>,
W: AsRef<str>,
{
Ok(self
.ptr
.translate_batch_with_target_prefix(
&vec_ffi_vecstr(source),
&vec_ffi_vecstr(target_prefix),
&options.to_ffi(),
callback.is_some(),
&mut TranslationCallbackBox::from(callback),
)?
.into_iter()
.map(TranslationResult::from)
.collect())
}
#[inline]
pub fn num_queued_batches(&self) -> Result<usize> {
self.ptr.num_queued_batches().map_err(Error::from)
}
#[inline]
pub fn num_active_batches(&self) -> Result<usize> {
self.ptr.num_active_batches().map_err(Error::from)
}
#[inline]
pub fn num_replicas(&self) -> Result<usize> {
self.ptr.num_replicas().map_err(Error::from)
}
}
/// Debug output shows the model label and the three live counters.
///
/// The counters are formatted as the `Result` values returned by their accessors, so a failed
/// query is printed as an `Err(..)` and does not abort the formatting.
impl Debug for Translator {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Translator");
        out.field("model", &self.model);
        out.field("queued_batches", &self.num_queued_batches());
        out.field("active_batches", &self.num_active_batches());
        out.field("replicas", &self.num_replicas());
        out.finish()
    }
}
// Windows-only custom teardown of the C++ translator.
// NOTE(review): the reason this is needed on Windows only is not visible in this file —
// confirm and record it here.
#[cfg(target_os = "windows")]
impl Drop for Translator {
fn drop(&mut self) {
// Swap a null pointer into the field so that the field's own automatic drop, which runs
// after this method returns, only sees a null `UniquePtr`.
let ptr = std::mem::replace(&mut self.ptr, UniquePtr::null());
// SAFETY: NOTE(review) — `into_raw` gives up the `UniquePtr`'s ownership and the raw
// pointer is not used again after this call. Whether `drop_in_place` on the opaque
// `ffi::Translator` runs the C++ destructor or releases the allocation cannot be determined
// from this file; if it does neither, the translator is deliberately leaked — confirm.
unsafe {
std::ptr::drop_in_place(ptr.into_raw());
}
}
}
/// One element of the vector returned by the translation methods of `Translator`.
#[derive(Clone, Debug)]
pub struct TranslationResult {
/// Hypotheses; each inner vector is the list of strings making up one hypothesis.
pub hypotheses: Vec<Vec<String>>,
/// Scores as received from the C++ side.
// NOTE(review): presumably one score per hypothesis, and empty unless scores were requested
// in the options — confirm.
pub scores: Vec<f32>,
}
/// Converts a result received over the bridge into the public result type.
impl From<ffi::TranslationResult> for TranslationResult {
    fn from(r: ffi::TranslationResult) -> Self {
        let ffi::TranslationResult { hypotheses, scores } = r;
        Self {
            // Each bridge-side string vector becomes a plain `Vec<String>`.
            hypotheses: hypotheses.into_iter().map(Into::into).collect(),
            // Scores need no conversion and are moved as is.
            scores,
        }
    }
}
impl TranslationResult {
    /// Returns the first hypothesis, or `None` when there are no hypotheses.
    #[inline]
    pub fn output(&self) -> Option<&Vec<String>> {
        match self.hypotheses.as_slice() {
            [first, ..] => Some(first),
            [] => None,
        }
    }

    /// Returns the first score, or `None` when there are no scores.
    #[inline]
    pub fn score(&self) -> Option<f32> {
        match self.scores.as_slice() {
            [first, ..] => Some(*first),
            [] => None,
        }
    }

    /// Returns how many hypotheses this result holds.
    #[inline]
    pub fn num_hypotheses(&self) -> usize {
        self.hypotheses.len()
    }

    /// Returns `true` when at least one score is present.
    #[inline]
    pub fn has_scores(&self) -> bool {
        // A first score exists exactly when the score list is non-empty.
        self.score().is_some()
    }
}
#[cfg(test)]
mod tests {
    use super::ffi::{VecStr, VecString};
    use super::{ffi, TranslationOptions, TranslationResult};

    /// Every field of the public options must arrive unchanged in the bridge struct.
    #[test]
    fn options_to_ffi() {
        let opts = TranslationOptions {
            suppress_sequences: vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]],
            end_token: vec!["1".to_string(), "2".to_string()],
            ..Default::default()
        };
        let res = opts.to_ffi();

        assert_eq!(res.beam_size, opts.beam_size);
        assert_eq!(res.patience, opts.patience);
        assert_eq!(res.length_penalty, opts.length_penalty);
        assert_eq!(res.coverage_penalty, opts.coverage_penalty);
        assert_eq!(res.repetition_penalty, opts.repetition_penalty);
        assert_eq!(res.no_repeat_ngram_size, opts.no_repeat_ngram_size);
        assert_eq!(res.disable_unk, opts.disable_unk);
        assert_eq!(
            res.suppress_sequences,
            opts.suppress_sequences
                .iter()
                .map(|v| VecStr {
                    v: v.iter().map(AsRef::as_ref).collect()
                })
                .collect::<Vec<VecStr>>()
        );
        assert_eq!(res.prefix_bias_beta, opts.prefix_bias_beta);
        assert_eq!(
            res.end_token,
            opts.end_token
                .iter()
                .map(AsRef::as_ref)
                .collect::<Vec<&str>>()
        );
        assert_eq!(res.return_end_token, opts.return_end_token);
        assert_eq!(res.max_input_length, opts.max_input_length);
        assert_eq!(res.max_decoding_length, opts.max_decoding_length);
        assert_eq!(res.min_decoding_length, opts.min_decoding_length);
        assert_eq!(res.sampling_topk, opts.sampling_topk);
        assert_eq!(res.sampling_topp, opts.sampling_topp);
        assert_eq!(res.sampling_temperature, opts.sampling_temperature);
        assert_eq!(res.use_vmap, opts.use_vmap);
        assert_eq!(res.num_hypotheses, opts.num_hypotheses);
        assert_eq!(res.return_scores, opts.return_scores);
        assert_eq!(res.return_attention, opts.return_attention);
        // This field was previously the only one not checked by the test.
        assert_eq!(res.return_logits_vocab, opts.return_logits_vocab);
        assert_eq!(res.return_alternatives, opts.return_alternatives);
        assert_eq!(
            res.min_alternative_expansion_prob,
            opts.min_alternative_expansion_prob
        );
        assert_eq!(res.replace_unknowns, opts.replace_unknowns);
        assert_eq!(res.max_batch_size, opts.max_batch_size);
        assert_eq!(res.batch_type, opts.batch_type);
    }

    /// A populated bridge result converts losslessly and the accessors report its first entries.
    #[test]
    fn translation_result() {
        let hypotheses = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["x".to_string(), "y".to_string(), "z".to_string()],
        ];
        let scores: Vec<f32> = vec![1., 2., 3.];
        let res: TranslationResult = ffi::TranslationResult {
            hypotheses: hypotheses
                .iter()
                .map(|v| VecString::from(v.clone()))
                .collect(),
            scores: scores.clone(),
        }
        .into();

        assert_eq!(res.hypotheses, hypotheses);
        assert_eq!(res.scores, scores);
        assert_eq!(res.output(), Some(hypotheses.first().unwrap()));
        assert_eq!(res.score(), Some(scores[0]));
        assert_eq!(res.num_hypotheses(), 2);
        assert!(res.has_scores());
    }

    /// An empty bridge result converts to an empty result whose accessors return nothing.
    #[test]
    fn translation_empty_result() {
        let res: TranslationResult = ffi::TranslationResult {
            hypotheses: vec![],
            scores: vec![],
        }
        .into();

        assert!(res.hypotheses.is_empty());
        assert!(res.scores.is_empty());
        assert_eq!(res.output(), None);
        assert_eq!(res.score(), None);
        assert_eq!(res.num_hypotheses(), 0);
        assert!(!res.has_scores());
    }

    /// Tests that need a model downloaded from the hub; ignored by default.
    #[cfg(feature = "hub")]
    mod hub {
        use crate::sys::Translator;
        use crate::{download_model, Config, Device};

        const MODEL_ID: &str = "jkawamoto/fugumt-en-ja-ct2";

        /// The `Debug` output of a translator contains the model directory name.
        #[test]
        #[ignore]
        fn test_translator_debug() {
            let model_path = download_model(MODEL_ID).unwrap();
            let translator = Translator::new(
                &model_path,
                &Config {
                    device: if cfg!(feature = "cuda") {
                        Device::CUDA
                    } else {
                        Device::CPU
                    },
                    ..Default::default()
                },
            )
            .unwrap();

            assert!(format!("{:?}", translator)
                .contains(model_path.file_name().unwrap().to_str().unwrap()));
        }
    }
}