use std::fmt::{Debug, Formatter};
use std::path::Path;
use anyhow::{anyhow, Result};
pub use sys::TranslationOptions;
use super::tokenizer::encode_all;
use super::{sys, Config, GenerationStepResult, Tokenizer};
/// A text-level translator that pairs the low-level [`sys::Translator`] with a
/// [`Tokenizer`].
///
/// The tokenizer is used to encode source strings before they are handed to the
/// low-level translator and to decode the returned hypotheses back into strings.
pub struct Translator<T: Tokenizer> {
// Low-level translator that every translation and statistics call is delegated to.
translator: sys::Translator,
// Tokenizer used for encoding inputs and decoding outputs.
tokenizer: T,
}
impl Translator<crate::tokenizers::auto::Tokenizer> {
    /// Creates a translator whose tokenizer is built from the same `path`.
    ///
    /// The tokenizer is constructed first via
    /// `crate::tokenizers::auto::Tokenizer::new`, then both the path and the
    /// tokenizer are handed to [`Translator::with_tokenizer`].
    ///
    /// # Errors
    ///
    /// Returns an error if either the tokenizer or the underlying translator
    /// cannot be created.
    pub fn new<U: AsRef<Path>>(path: U, config: &Config) -> anyhow::Result<Self> {
        let auto_tokenizer = crate::tokenizers::auto::Tokenizer::new(&path)?;
        Self::with_tokenizer(&path, auto_tokenizer, config)
    }
}
impl<T: Tokenizer> Translator<T> {
    /// Creates a translator that uses the supplied `tokenizer`.
    ///
    /// `path` and `config` are forwarded unchanged to `sys::Translator::new`
    /// (presumably `path` is the model location — confirm against `sys`).
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying `sys::Translator` cannot be created.
    pub fn with_tokenizer<U: AsRef<Path>>(
        path: U,
        tokenizer: T,
        config: &Config,
    ) -> anyhow::Result<Self> {
        Ok(Translator {
            translator: sys::Translator::new(path, config)?,
            tokenizer,
        })
    }

    /// Translates a batch of source strings.
    ///
    /// Every source is encoded with this translator's tokenizer, the encoded
    /// batch is passed to the underlying translator together with `options`,
    /// and the first hypothesis of each result is decoded back into a string.
    ///
    /// If `callback` is given, it is invoked with every generation step result
    /// (converted with `GenerationStepResult::from_ffi`). When the conversion
    /// or the callback fails, the wrapper reports `true` to the underlying
    /// translator (presumably a request to stop generating — confirm against
    /// `sys`) and the error is returned from this method afterwards.
    ///
    /// # Returns
    ///
    /// One `(text, score)` pair per result, where `score` is whatever the
    /// result's `score()` method reports.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails, the underlying translation fails,
    /// the callback fails, a result contains no hypothesis, or decoding fails.
    pub fn translate_batch<U, V, W>(
        &self,
        sources: &[U],
        options: &TranslationOptions<V, W>,
        callback: Option<&mut dyn FnMut(GenerationStepResult) -> Result<()>>,
    ) -> anyhow::Result<Vec<(String, Option<f32>)>>
    where
        U: AsRef<str>,
        V: AsRef<str>,
        W: AsRef<str>,
    {
        // Encoding is needed on both paths, so do it once up front.
        let tokens = encode_all(&self.tokenizer, sources)?;

        let output = if let Some(callback) = callback {
            // The low-level callback can only return a bool, so the first error
            // is stashed here and surfaced once the translation call returns.
            let mut callback_result = Ok(());
            let mut wrapped_callback = |r: sys::GenerationStepResult| -> bool {
                match GenerationStepResult::from_ffi(r, &self.tokenizer)
                    .and_then(|step| callback(step))
                {
                    Ok(()) => false,
                    Err(e) => {
                        callback_result = Err(e);
                        true
                    }
                }
            };
            let output =
                self.translator
                    .translate_batch(&tokens, options, Some(&mut wrapped_callback))?;
            callback_result?;
            output
        } else {
            self.translator.translate_batch(&tokens, options, None)?
        };

        output
            .into_iter()
            .map(|r| -> anyhow::Result<(String, Option<f32>)> {
                let score = r.score();
                // Only the first hypothesis of each result is returned.
                let hypothesis = r
                    .hypotheses
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("no results are returned"))?;
                let text = self
                    .tokenizer
                    .decode(hypothesis)
                    .map_err(|err| anyhow!("failed to decode: {err}"))?;
                Ok((text, score))
            })
            .collect()
    }

    /// Translates a batch of source strings, passing `target_prefixes` to the
    /// underlying translator.
    ///
    /// Works like [`Translator::translate_batch`], except that
    /// `target_prefixes` is forwarded to the underlying translator and, before
    /// decoding, the leading `target_prefixes[i].len()` tokens are removed from
    /// the first hypothesis of the `i`-th result (presumably because the
    /// returned hypothesis starts with the prefix — confirm against `sys`).
    ///
    /// If a hypothesis is shorter than its prefix, all of its tokens are
    /// removed instead of panicking. If there is no prefix entry for a result,
    /// nothing is removed and the result is still returned.
    ///
    /// # Returns
    ///
    /// One `(text, score)` pair per result, where `score` is whatever the
    /// result's `score()` method reports.
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails, the underlying translation fails,
    /// the callback fails, a result contains no hypothesis, or decoding fails.
    pub fn translate_batch_with_target_prefix<U, V, W, E>(
        &self,
        sources: &[U],
        target_prefixes: &Vec<Vec<V>>,
        options: &TranslationOptions<W, E>,
        callback: Option<&mut dyn FnMut(GenerationStepResult) -> Result<()>>,
    ) -> anyhow::Result<Vec<(String, Option<f32>)>>
    where
        U: AsRef<str>,
        V: AsRef<str>,
        W: AsRef<str>,
        E: AsRef<str>,
    {
        // Encoding is needed on both paths, so do it once up front.
        let tokens = encode_all(&self.tokenizer, sources)?;

        let output = if let Some(callback) = callback {
            // The low-level callback can only return a bool, so the first error
            // is stashed here and surfaced once the translation call returns.
            let mut callback_result = Ok(());
            let mut wrapped_callback = |r: sys::GenerationStepResult| -> bool {
                match GenerationStepResult::from_ffi(r, &self.tokenizer)
                    .and_then(|step| callback(step))
                {
                    Ok(()) => false,
                    Err(e) => {
                        callback_result = Err(e);
                        true
                    }
                }
            };
            let output = self.translator.translate_batch_with_target_prefix(
                &tokens,
                target_prefixes,
                options,
                Some(&mut wrapped_callback),
            )?;
            callback_result?;
            output
        } else {
            self.translator.translate_batch_with_target_prefix(
                &tokens,
                target_prefixes,
                options,
                None,
            )?
        };

        output
            .into_iter()
            .enumerate()
            .map(|(index, r)| -> anyhow::Result<(String, Option<f32>)> {
                let score = r.score();
                // Only the first hypothesis of each result is returned.
                let mut hypothesis = r
                    .hypotheses
                    .into_iter()
                    .next()
                    .ok_or_else(|| anyhow!("no results are returned"))?;
                // Look the prefix up by index rather than zipping, so a missing
                // prefix entry never drops a translation result.
                let prefix_len = target_prefixes
                    .get(index)
                    .map_or(0, |prefix| prefix.len());
                // Clamp to the hypothesis length: draining past the end panics.
                let strip = prefix_len.min(hypothesis.len());
                hypothesis.drain(..strip);
                let text = self
                    .tokenizer
                    .decode(hypothesis)
                    .map_err(|err| anyhow!("failed to decode: {err}"))?;
                Ok((text, score))
            })
            .collect()
    }

    /// Returns the value reported by the underlying translator's
    /// `num_queued_batches`.
    #[inline]
    pub fn num_queued_batches(&self) -> anyhow::Result<usize> {
        self.translator.num_queued_batches()
    }

    /// Returns the value reported by the underlying translator's
    /// `num_active_batches`.
    #[inline]
    pub fn num_active_batches(&self) -> anyhow::Result<usize> {
        self.translator.num_active_batches()
    }

    /// Returns the value reported by the underlying translator's
    /// `num_replicas`.
    #[inline]
    pub fn num_replicas(&self) -> anyhow::Result<usize> {
        self.translator.num_replicas()
    }
}
impl<T: Tokenizer> Debug for Translator<T> {
    /// Writes the `Debug` representation of the wrapped low-level translator;
    /// the tokenizer is not included in the output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("{:?}", self.translator))
    }
}
#[cfg(all(test, feature = "hub"))]
mod tests {
    use crate::{download_model, Config, Device, TranslationOptions, Translator};

    /// Identifier of the model downloaded by the tests in this module.
    const MODEL_ID: &str = "jkawamoto/fugumt-en-ja-ct2";

    /// Builds the configuration shared by the tests: CUDA when the `cuda`
    /// feature is enabled, CPU otherwise; everything else is left at its default.
    fn test_config() -> Config {
        let device = if cfg!(feature = "cuda") {
            Device::CUDA
        } else {
            Device::CPU
        };
        Config {
            device,
            ..Default::default()
        }
    }

    /// Translates a single input and checks the decoded text of the first result.
    #[test]
    #[ignore]
    fn test_translate() {
        let model_path = download_model(MODEL_ID).unwrap();
        let t = Translator::new(&model_path, &test_config()).unwrap();

        let options = TranslationOptions {
            beam_size: 1,
            sampling_temperature: 0.,
            ..Default::default()
        };
        let res = t.translate_batch(&["Hellow"], &options, None).unwrap();

        assert_eq!(res[0].0, "こんにちは");
    }

    /// Checks that the `Debug` output of a translator contains the final
    /// component of the model path.
    #[test]
    #[ignore]
    fn test_translator_debug() {
        let model_path = download_model(MODEL_ID).unwrap();
        let t = Translator::new(&model_path, &test_config()).unwrap();

        let expected = model_path.file_name().unwrap().to_str().unwrap();
        assert!(format!("{:?}", t).contains(expected));
    }
}