Skip to main content

ct2rs/sys/
translator.rs

1// translator.rs
2//
3// Copyright (c) 2023-2024 Junpei Kawamoto
4//
5// This software is released under the MIT License.
6//
7// http://opensource.org/licenses/mit-license.php
8
9//! This module provides a Rust binding to the
10//! [`ctranslate2::Translator`](https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html).
11
12use std::ffi::{OsStr, OsString};
13use std::fmt::{Debug, Formatter};
14use std::path::Path;
15
16use anyhow::{anyhow, Error, Result};
17use cxx::UniquePtr;
18
19use super::{config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, VecStr, VecString};
20
/// Internal abstraction over a per-step generation callback.
///
/// It is implemented below for every `FnMut(GenerationStepResult) -> bool`, which lets such
/// closures be stored as a `dyn GenerationCallback` trait object.
trait GenerationCallback {
    /// Invokes the callback with the result of one generation step and returns its answer.
    fn execute(&mut self, res: GenerationStepResult) -> bool;
}
24
25impl<F: FnMut(GenerationStepResult) -> bool> GenerationCallback for F {
26    fn execute(&mut self, args: GenerationStepResult) -> bool {
27        self(args)
28    }
29}
/// Boxed [`GenerationCallback`] trait object.
///
/// The `ffi` bridge declares this alias as an `extern "Rust"` type so it can be handed to C++.
type TranslationCallbackBox<'a> = Box<dyn GenerationCallback + 'a>;
31
32impl<'a> From<Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>>
33    for TranslationCallbackBox<'a>
34{
35    fn from(opt: Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>) -> Self {
36        match opt {
37            None => Box::new(|_| false) as TranslationCallbackBox,
38            Some(c) => Box::new(c) as TranslationCallbackBox,
39        }
40    }
41}
42
43fn execute_translation_callback(f: &mut TranslationCallbackBox, arg: GenerationStepResult) -> bool {
44    f.execute(arg)
45}
46
// Bridge between this module and the C++ translator wrapper declared in
// `ct2rs/include/translator.h`.
#[cxx::bridge]
mod ffi {
    // FFI counterpart of the public `TranslationOptions`; it is produced by
    // `TranslationOptions::to_ffi` and borrows its string data for `'a`.
    struct TranslationOptions<'a> {
        beam_size: usize,
        patience: f32,
        length_penalty: f32,
        coverage_penalty: f32,
        repetition_penalty: f32,
        no_repeat_ngram_size: usize,
        disable_unk: bool,
        suppress_sequences: Vec<VecStr<'a>>,
        prefix_bias_beta: f32,
        end_token: Vec<&'a str>,
        return_end_token: bool,
        max_input_length: usize,
        max_decoding_length: usize,
        min_decoding_length: usize,
        sampling_topk: usize,
        sampling_topp: f32,
        sampling_temperature: f32,
        use_vmap: bool,
        num_hypotheses: usize,
        return_scores: bool,
        return_attention: bool,
        return_logits_vocab: bool,
        return_alternatives: bool,
        min_alternative_expansion_prob: f32,
        replace_unknowns: bool,
        max_batch_size: usize,
        batch_type: BatchType,
    }

    // FFI counterpart of the public `TranslationResult`; converted via `From`.
    struct TranslationResult {
        hypotheses: Vec<VecString>,
        scores: Vec<f32>,
        // attention: Vec<Vec<Vec<f32>>>,
    }

    // Rust items made available to the C++ side: the boxed callback type and the function
    // that invokes it.
    extern "Rust" {
        type TranslationCallbackBox<'a>;
        fn execute_translation_callback(
            f: &mut TranslationCallbackBox,
            arg: GenerationStepResult,
        ) -> bool;
    }

    unsafe extern "C++" {
        include!("ct2rs/include/translator.h");
        include!("ct2rs/src/sys/types.rs.h");

        // Types shared with the parent module.
        type VecString = super::VecString;
        type VecStr<'a> = super::VecStr<'a>;

        type Config = super::config::ffi::Config;
        type BatchType = super::BatchType;
        type GenerationStepResult = super::GenerationStepResult;

        // Opaque C++ translator type.
        type Translator;

        // Creates a translator for the model at `model_path` with the given configuration.
        fn translator(model_path: &str, config: UniquePtr<Config>)
            -> Result<UniquePtr<Translator>>;

        // The Rust wrappers pass `callback.is_some()` as `has_callback` and always pass a
        // boxed callback (a closure returning `false` when the user gave none).
        // NOTE(review): presumably the C++ side only invokes `callback` when `has_callback`
        // is true -- confirm against translator.h.
        fn translate_batch(
            self: &Translator,
            source: &Vec<VecStr>,
            options: &TranslationOptions,
            has_callback: bool,
            callback: &mut TranslationCallbackBox,
        ) -> Result<Vec<TranslationResult>>;

        // Same as `translate_batch`, with an additional list of target prefixes.
        fn translate_batch_with_target_prefix(
            self: &Translator,
            source: &Vec<VecStr>,
            target_prefix: &Vec<VecStr>,
            options: &TranslationOptions,
            has_callback: bool,
            callback: &mut TranslationCallbackBox,
        ) -> Result<Vec<TranslationResult>>;

        fn num_queued_batches(self: &Translator) -> Result<usize>;

        fn num_active_batches(self: &Translator) -> Result<usize>;

        fn num_replicas(self: &Translator) -> Result<usize>;
    }
}
133
// SAFETY: NOTE(review): these impls assert that the C++ translator may be moved to and shared
// between threads. Nothing in this file proves that; presumably the underlying CTranslate2
// translator is thread-safe -- confirm against the C++ implementation.
unsafe impl Send for ffi::Translator {}
unsafe impl Sync for ffi::Translator {}
136
/// Options for translation.
///
/// The type parameter `T` is the string type of the tokens in `suppress_sequences`, and `U` is
/// the string type of the tokens in `end_token`.
///
/// # Examples
///
/// Example of creating a default `TranslationOptions`:
///
/// ```
/// # use ct2rs::sys::BatchType;
/// use ct2rs::sys::TranslationOptions;
///
/// let options = TranslationOptions::default();
/// # assert_eq!(options.beam_size, 2);
/// # assert_eq!(options.patience, 1.);
/// # assert_eq!(options.length_penalty, 1.);
/// # assert_eq!(options.coverage_penalty, 0.);
/// # assert_eq!(options.repetition_penalty, 1.);
/// # assert_eq!(options.no_repeat_ngram_size, 0);
/// # assert!(!options.disable_unk);
/// # assert!(options.suppress_sequences.is_empty());
/// # assert_eq!(options.prefix_bias_beta, 0.);
/// # assert!(options.end_token.is_empty());
/// # assert!(!options.return_end_token);
/// # assert_eq!(options.max_input_length, 1024);
/// # assert_eq!(options.max_decoding_length, 256);
/// # assert_eq!(options.min_decoding_length, 1);
/// # assert_eq!(options.sampling_topk, 1);
/// # assert_eq!(options.sampling_topp, 1.);
/// # assert_eq!(options.sampling_temperature, 1.);
/// # assert!(!options.use_vmap);
/// # assert_eq!(options.num_hypotheses, 1);
/// # assert!(!options.return_scores);
/// # assert!(!options.return_attention);
/// # assert!(!options.return_logits_vocab);
/// # assert!(!options.return_alternatives);
/// # assert_eq!(options.min_alternative_expansion_prob, 0.);
/// # assert!(!options.replace_unknowns);
/// # assert_eq!(options.max_batch_size, 0);
/// # assert_eq!(options.batch_type, BatchType::default());
/// ```
///
#[derive(Clone, Debug)]
pub struct TranslationOptions<T: AsRef<str>, U: AsRef<str>> {
    /// Beam size to use for beam search (set 1 to run greedy search). (default: 2)
    pub beam_size: usize,
    /// Beam search patience factor, as described in <https://arxiv.org/abs/2204.05424>.
    /// The decoding will continue until beam_size*patience hypotheses are finished.
    /// (default: 1.0)
    pub patience: f32,
    /// Exponential penalty applied to the length during beam search.
    /// The scores are normalized with:
    /// ```math
    /// hypothesis_score /= (hypothesis_length ** length_penalty)
    /// ```
    /// (default: 1.0)
    pub length_penalty: f32,
    /// Coverage penalty weight applied during beam search. (default: 0)
    pub coverage_penalty: f32,
    /// Penalty applied to the score of previously generated tokens, as described in
    /// <https://arxiv.org/abs/1909.05858> (set > 1 to penalize). (default: 1.0)
    pub repetition_penalty: f32,
    /// Prevent repetitions of ngrams with this size (set 0 to disable). (default: 0)
    pub no_repeat_ngram_size: usize,
    /// Disable the generation of the unknown token. (default: false)
    pub disable_unk: bool,
    /// Disable the generation of some sequences of tokens. (default: empty)
    pub suppress_sequences: Vec<Vec<T>>,
    /// Biases decoding towards a given prefix, see <https://arxiv.org/abs/1912.03393> --section 4.2
    /// Only activates biased-decoding when beta is in range (0, 1) and SearchStrategy is set to
    /// BeamSearch. The closer beta is to 1, the stronger the bias is towards the given prefix.
    ///
    /// If beta <= 0 and a non-empty prefix is given, then the prefix will be used as a
    /// hard-prefix rather than a soft, biased-prefix. (default: 0)
    pub prefix_bias_beta: f32,
    /// Stop the decoding on one of these tokens (defaults to the model EOS token).
    /// (default: empty)
    pub end_token: Vec<U>,
    /// Include the end token in the result. (default: false)
    pub return_end_token: bool,
    /// Truncate the inputs after this many tokens (set 0 to disable truncation). (default: 1024)
    pub max_input_length: usize,
    /// Upper bound of the decoding length. (default: 256)
    pub max_decoding_length: usize,
    /// Lower bound of the decoding length. (default: 1)
    pub min_decoding_length: usize,
    /// Randomly sample from the top K candidates (set 0 to sample from the full output
    /// distribution). (default: 1)
    pub sampling_topk: usize,
    /// Keep the most probable tokens whose cumulative probability exceeds this value.
    /// (default: 1.0)
    pub sampling_topp: f32,
    /// Sampling temperature; higher temperatures increase randomness. (default: 1.0)
    pub sampling_temperature: f32,
    /// Allow using the vocabulary map included in the model directory, if it exists.
    /// (default: false)
    pub use_vmap: bool,
    /// Number of hypotheses to store in the TranslationResult class. (default: 1)
    pub num_hypotheses: usize,
    /// Store scores in the TranslationResult class. (default: false)
    pub return_scores: bool,
    /// Store attention vectors in the TranslationResult class. (default: false)
    pub return_attention: bool,
    /// Store log probs matrix in the TranslationResult class. (default: false)
    pub return_logits_vocab: bool,
    /// Return alternatives at the first unconstrained decoding position. This is typically
    /// used with a target prefix to provide alternatives at a specific location in the
    /// translation. (default: false)
    pub return_alternatives: bool,
    /// Minimum probability to expand an alternative. (default: 0)
    pub min_alternative_expansion_prob: f32,
    /// Replace unknown target tokens by the original source token with the highest attention.
    /// (default: false)
    pub replace_unknowns: bool,
    /// The maximum batch size. If the number of inputs is greater than `max_batch_size`,
    /// the inputs are sorted by length and split by chunks of `max_batch_size` examples
    /// so that the number of padding positions is minimized. (default: 0)
    pub max_batch_size: usize,
    /// Whether `max_batch_size` is the number of “examples” or “tokens”.
    /// (default: `BatchType::default()`)
    pub batch_type: BatchType,
}
255
/// Default options, using `String` for both token type parameters.
impl Default for TranslationOptions<String, String> {
    /// Returns the options with the default values stated in the documentation of each field.
    fn default() -> Self {
        Self {
            beam_size: 2,
            patience: 1.,
            length_penalty: 1.,
            coverage_penalty: 0.,
            repetition_penalty: 1.,
            no_repeat_ngram_size: 0,
            disable_unk: false,
            suppress_sequences: vec![],
            prefix_bias_beta: 0.,
            end_token: vec![],
            return_end_token: false,
            max_input_length: 1024,
            max_decoding_length: 256,
            min_decoding_length: 1,
            sampling_topk: 1,
            sampling_topp: 1.,
            sampling_temperature: 1.,
            use_vmap: false,
            num_hypotheses: 1,
            return_scores: false,
            return_attention: false,
            return_logits_vocab: false,
            return_alternatives: false,
            min_alternative_expansion_prob: 0.,
            replace_unknowns: false,
            max_batch_size: 0,
            batch_type: BatchType::default(),
        }
    }
}
289
290impl<T: AsRef<str>, U: AsRef<str>> TranslationOptions<T, U> {
291    fn to_ffi(&self) -> ffi::TranslationOptions {
292        ffi::TranslationOptions {
293            beam_size: self.beam_size,
294            patience: self.patience,
295            length_penalty: self.length_penalty,
296            coverage_penalty: self.coverage_penalty,
297            repetition_penalty: self.repetition_penalty,
298            no_repeat_ngram_size: self.no_repeat_ngram_size,
299            disable_unk: self.disable_unk,
300            suppress_sequences: vec_ffi_vecstr(self.suppress_sequences.as_ref()),
301            prefix_bias_beta: self.prefix_bias_beta,
302            end_token: self.end_token.iter().map(AsRef::as_ref).collect(),
303            return_end_token: self.return_end_token,
304            max_input_length: self.max_input_length,
305            max_decoding_length: self.max_decoding_length,
306            min_decoding_length: self.min_decoding_length,
307            sampling_topk: self.sampling_topk,
308            sampling_topp: self.sampling_topp,
309            sampling_temperature: self.sampling_temperature,
310            use_vmap: self.use_vmap,
311            num_hypotheses: self.num_hypotheses,
312            return_scores: self.return_scores,
313            return_attention: self.return_attention,
314            return_logits_vocab: self.return_logits_vocab,
315            return_alternatives: self.return_alternatives,
316            min_alternative_expansion_prob: self.min_alternative_expansion_prob,
317            replace_unknowns: self.replace_unknowns,
318            max_batch_size: self.max_batch_size,
319            batch_type: self.batch_type,
320        }
321    }
322}
323
/// A text translator.
///
/// This struct is a Rust binding to the
/// [`ctranslate2::Translator`](https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html).
///
/// # Example
/// Below is an example where a given list of tokens is translated:
///
/// ```no_run
/// # use anyhow::Result;
/// use ct2rs::sys::{Config, Translator};
///
/// # fn main() -> Result<()> {
/// let translator = Translator::new("/path/to/model", &Config::default())?;
/// let res = translator.translate_batch(
///     &[vec!["▁Hello", "▁world", "!", "</s>", "<unk>"]],
///     &Default::default(),
///     None,
/// )?;
/// for r in res {
///     println!("{:?}", r);
/// }
/// # Ok(())
/// # }
/// ```
///
/// If the model requires target prefixes, use [`Translator::translate_batch_with_target_prefix`]
/// instead:
///
/// ```no_run
/// # use anyhow::Result;
/// use ct2rs::sys::{Config, Translator};
///
/// # fn main() -> Result<()> {
/// let translator = Translator::new("/path/to/model", &Config::default())?;
/// let res = translator.translate_batch_with_target_prefix(
///     &[vec!["▁Hello", "▁world", "!", "</s>", "<unk>"]],
///     &[vec!["jpn_Jpan"]],
///     &Default::default(),
///     None,
/// )?;
/// for r in res {
///     println!("{:?}", r);
/// }
/// # Ok(())
/// # }
/// ```
pub struct Translator {
    // Last component of the model path (empty if the path has none); shown in `Debug` output.
    model: OsString,
    // Handle to the C++ translator object created by `ffi::translator`.
    ptr: UniquePtr<ffi::Translator>,
}
375
376impl Translator {
377    /// Creates and initializes an instance of `Translator`.
378    ///
379    /// This function constructs a new `Translator` by loading a language model from the specified
380    /// `model_path` and applying the provided `config` settings.
381    ///
382    /// # Arguments
383    /// * `model_path` - A path to the directory containing the language model to be loaded.
384    /// * `config` - A reference to a `Config` structure that specifies various settings
385    ///   and configurations for the `Translator`.
386    ///
387    /// # Returns
388    /// Returns a `Result` that, if successful, contains the initialized `Translator`. If an error
389    /// occurs during initialization, the function will return an error wrapped in the `Result`.
390    ///
391    /// # Example
392    /// ```no_run
393    /// # use anyhow::Result;
394    /// #
395    /// use ct2rs::sys::{Config, Translator};
396    ///
397    /// # fn main() -> Result<()> {
398    /// let config = Config::default();
399    /// let translator = Translator::new("/path/to/model", &config)?;
400    /// # Ok(())
401    /// # }
402    /// ```
403    pub fn new<T: AsRef<Path>>(model_path: T, config: &Config) -> Result<Translator> {
404        let model_path = model_path.as_ref();
405        Ok(Translator {
406            model: model_path
407                .file_name()
408                .map(OsStr::to_os_string)
409                .unwrap_or_default(),
410            ptr: ffi::translator(
411                model_path
412                    .to_str()
413                    .ok_or_else(|| anyhow!("invalid path: {}", model_path.display()))?,
414                config.to_ffi(),
415            )?,
416        })
417    }
418
419    /// Translates multiple lists of tokens in a batch processing manner.
420    ///
421    /// This function takes a vector of token lists and performs batch translation according to the
422    /// specified settings in `options`. The results of the batch translation are returned as a
423    /// vector. An optional `callback` closure can be provided which is invoked for each new token
424    /// generated during the translation process. This allows for step-by-step reception of the
425    /// batch translation results. If the callback returns `true`, it will stop the translation for
426    /// that batch. Note that if a callback is provided, `options.beam_size` must be set to `1`.
427    ///
428    /// # Arguments
429    /// * `source` - A vector of token lists, each list representing a sequence of tokens to be
430    ///    translated.
431    /// * `options` - Settings applied to the batch translation process.
432    /// * `callback` - An optional mutable reference to a closure that is called for each token
433    ///   generation step. The closure takes a `GenerationStepResult` and returns a `bool`. If it
434    ///   returns `true`, the translation process for the current batch will stop.
435    ///
436    /// # Returns
437    /// Returns a `Result` containing a vector of `TranslationResult` if successful, or an error if
438    /// the translation fails.
439    ///
440    /// # Example
441    /// ```no_run
442    /// # use anyhow::Result;
443    /// #
444    /// use ct2rs::sys::{Config, GenerationStepResult, Translator, TranslationOptions};
445    ///
446    /// # fn main() -> Result<()> {
447    /// let source_tokens = [
448    ///     vec!["▁Hall", "o", "▁World", "!", "</s>"],
449    ///     vec![
450    ///         "▁This", "▁library", "▁is", "▁a", "▁", "Rust", "▁", "binding", "s", "▁of",
451    ///         "▁C", "Trans", "late", "2", ".", "</s>"
452    ///     ],
453    /// ];
454    /// let options = TranslationOptions::default();
455    /// let mut callback = |step_result: GenerationStepResult| -> bool {
456    ///     println!("{:?}", step_result);
457    ///     false // Continue processing
458    /// };
459    /// let translator = Translator::new("/path/to/model", &Config::default())?;
460    /// let results = translator.translate_batch(&source_tokens, &options, Some(&mut callback))?;
461    /// # Ok(())
462    /// # }
463    /// ```
464    pub fn translate_batch<T, U, V>(
465        &self,
466        source: &[Vec<T>],
467        options: &TranslationOptions<U, V>,
468        callback: Option<&mut dyn FnMut(GenerationStepResult) -> bool>,
469    ) -> Result<Vec<TranslationResult>>
470    where
471        T: AsRef<str>,
472        U: AsRef<str>,
473        V: AsRef<str>,
474    {
475        Ok(self
476            .ptr
477            .translate_batch(
478                &vec_ffi_vecstr(source),
479                &options.to_ffi(),
480                callback.is_some(),
481                &mut TranslationCallbackBox::from(callback),
482            )?
483            .into_iter()
484            .map(TranslationResult::from)
485            .collect())
486    }
487
488    /// Translates multiple lists of tokens with target prefixes in a batch processing manner.
489    ///
490    /// This function takes a vector of token lists and corresponding target prefixes, performing
491    /// batch translation according to the specified settings in `options`. An optional `callback`
492    /// closure can be provided which is invoked for each new token generated during the translation
493    /// process.
494    ///
495    /// This function is similar to `translate_batch`, with the addition of handling target prefixes
496    /// that guide the translation process. For more detailed parameter and option descriptions,
497    /// refer to the documentation for [`Translator::translate_batch`].
498    ///
499    /// # Arguments
500    /// * `source` - A vector of token lists, each list representing a sequence of tokens to be
501    ///   translated.
502    /// * `target_prefix` - A vector of token lists, each list representing a sequence of target
503    ///   prefix tokens that provide a starting point for the translation output.
504    /// * `options` - Settings applied to the batch translation process.
505    /// * `callback` - An optional mutable reference to a closure that is called for each token
506    ///   generation step. The closure takes a `GenerationStepResult` and returns a `bool`. If it
507    ///   returns `true`, the translation process for the current batch will stop.
508    ///
509    /// # Returns
510    /// Returns a `Result` containing a vector of `TranslationResult` if successful, or an error if
511    /// the translation fails.
512    pub fn translate_batch_with_target_prefix<T, U, V, W>(
513        &self,
514        source: &[Vec<T>],
515        target_prefix: &[Vec<U>],
516        options: &TranslationOptions<V, W>,
517        callback: Option<&mut dyn FnMut(GenerationStepResult) -> bool>,
518    ) -> Result<Vec<TranslationResult>>
519    where
520        T: AsRef<str>,
521        U: AsRef<str>,
522        V: AsRef<str>,
523        W: AsRef<str>,
524    {
525        Ok(self
526            .ptr
527            .translate_batch_with_target_prefix(
528                &vec_ffi_vecstr(source),
529                &vec_ffi_vecstr(target_prefix),
530                &options.to_ffi(),
531                callback.is_some(),
532                &mut TranslationCallbackBox::from(callback),
533            )?
534            .into_iter()
535            .map(TranslationResult::from)
536            .collect())
537    }
538
539    /// Number of batches in the work queue.
540    #[inline]
541    pub fn num_queued_batches(&self) -> Result<usize> {
542        self.ptr.num_queued_batches().map_err(Error::from)
543    }
544
545    /// Number of batches in the work queue or currently processed by a worker.
546    #[inline]
547    pub fn num_active_batches(&self) -> Result<usize> {
548        self.ptr.num_active_batches().map_err(Error::from)
549    }
550
551    /// Number of parallel replicas.
552    #[inline]
553    pub fn num_replicas(&self) -> Result<usize> {
554        self.ptr.num_replicas().map_err(Error::from)
555    }
556}
557
558impl Debug for Translator {
559    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
560        f.debug_struct("Translator")
561            .field("model", &self.model)
562            .field("queued_batches", &self.num_queued_batches())
563            .field("active_batches", &self.num_active_batches())
564            .field("replicas", &self.num_replicas())
565            .finish()
566    }
567}
568
// Releasing `UniquePtr<Translator>` invokes joining threads.
// However, on Windows, this causes a deadlock.
// As a workaround, it is bypassed here.
// See also https://github.com/jkawamoto/ctranslate2-rs/issues/64
#[cfg(target_os = "windows")]
impl Drop for Translator {
    fn drop(&mut self) {
        // Swap a null pointer into `self.ptr`, so the field no longer holds the translator
        // when it is dropped after this function returns.
        let ptr = std::mem::replace(&mut self.ptr, UniquePtr::null());
        // NOTE(review): `into_raw` gives up ownership of the C++ object. Presumably
        // `drop_in_place` on the opaque `ffi::Translator` does not run the C++ destructor, so
        // the object is intentionally leaked to avoid the deadlock described above -- confirm.
        unsafe {
            std::ptr::drop_in_place(ptr.into_raw());
        }
    }
}
582
/// A translation result.
///
/// This struct is a Rust binding to the
/// [`ctranslate2.TranslationResult`](https://opennmt.net/CTranslate2/python/ctranslate2.TranslationResult.html).
///
/// Use [`TranslationResult::output`] and [`TranslationResult::score`] to access the first
/// hypothesis and the first score.
#[derive(Clone, Debug)]
pub struct TranslationResult {
    /// Translation hypotheses; each hypothesis is a list of tokens.
    pub hypotheses: Vec<Vec<String>>,
    /// Score of each translation hypothesis (empty if return_scores was disabled).
    pub scores: Vec<f32>,
}
594
595impl From<ffi::TranslationResult> for TranslationResult {
596    fn from(r: ffi::TranslationResult) -> Self {
597        Self {
598            hypotheses: r.hypotheses.into_iter().map(Vec::<String>::from).collect(),
599            scores: r.scores,
600        }
601    }
602}
603
604impl TranslationResult {
605    /// Returns the first translation hypothesis if exists.
606    #[inline]
607    pub fn output(&self) -> Option<&Vec<String>> {
608        self.hypotheses.first()
609    }
610
611    /// Returns the score of the first translation hypothesis if exists.
612    #[inline]
613    pub fn score(&self) -> Option<f32> {
614        self.scores.first().copied()
615    }
616
617    /// Returns the number of translation hypotheses.
618    #[inline]
619    pub fn num_hypotheses(&self) -> usize {
620        self.hypotheses.len()
621    }
622
623    /// Returns true if this result contains scores.
624    #[inline]
625    pub fn has_scores(&self) -> bool {
626        !self.scores.is_empty()
627    }
628}
629
#[cfg(test)]
mod tests {
    use super::ffi::{VecStr, VecString};
    use super::{ffi, TranslationOptions, TranslationResult};

    /// Every option must be carried over unchanged by `TranslationOptions::to_ffi`.
    #[test]
    fn options_to_ffi() {
        let opts = TranslationOptions {
            suppress_sequences: vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]],
            end_token: vec!["1".to_string(), "2".to_string()],
            ..Default::default()
        };
        let res = opts.to_ffi();

        assert_eq!(res.beam_size, opts.beam_size);
        assert_eq!(res.patience, opts.patience);
        assert_eq!(res.length_penalty, opts.length_penalty);
        assert_eq!(res.coverage_penalty, opts.coverage_penalty);
        assert_eq!(res.repetition_penalty, opts.repetition_penalty);
        assert_eq!(res.no_repeat_ngram_size, opts.no_repeat_ngram_size);
        assert_eq!(res.disable_unk, opts.disable_unk);
        assert_eq!(
            res.suppress_sequences,
            opts.suppress_sequences
                .iter()
                .map(|v| VecStr {
                    v: v.iter().map(AsRef::as_ref).collect()
                })
                .collect::<Vec<VecStr>>()
        );
        assert_eq!(res.prefix_bias_beta, opts.prefix_bias_beta);
        assert_eq!(
            res.end_token,
            opts.end_token
                .iter()
                .map(AsRef::as_ref)
                .collect::<Vec<&str>>()
        );
        assert_eq!(res.return_end_token, opts.return_end_token);
        assert_eq!(res.max_input_length, opts.max_input_length);
        assert_eq!(res.max_decoding_length, opts.max_decoding_length);
        assert_eq!(res.min_decoding_length, opts.min_decoding_length);
        assert_eq!(res.sampling_topk, opts.sampling_topk);
        assert_eq!(res.sampling_topp, opts.sampling_topp);
        assert_eq!(res.sampling_temperature, opts.sampling_temperature);
        assert_eq!(res.use_vmap, opts.use_vmap);
        assert_eq!(res.num_hypotheses, opts.num_hypotheses);
        assert_eq!(res.return_scores, opts.return_scores);
        assert_eq!(res.return_attention, opts.return_attention);
        // This field was previously the only option not covered by the test.
        assert_eq!(res.return_logits_vocab, opts.return_logits_vocab);
        assert_eq!(res.return_alternatives, opts.return_alternatives);
        assert_eq!(
            res.min_alternative_expansion_prob,
            opts.min_alternative_expansion_prob
        );
        assert_eq!(res.replace_unknowns, opts.replace_unknowns);
        assert_eq!(res.max_batch_size, opts.max_batch_size);
        assert_eq!(res.batch_type, opts.batch_type);
    }

    /// Converting a non-empty FFI result keeps the hypotheses and scores, and the accessors
    /// report the first entries.
    #[test]
    fn translation_result() {
        let hypotheses = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["x".to_string(), "y".to_string(), "z".to_string()],
        ];
        let scores: Vec<f32> = vec![1., 2., 3.];
        let res: TranslationResult = ffi::TranslationResult {
            hypotheses: hypotheses
                .iter()
                .map(|v| VecString::from(v.clone()))
                .collect(),
            scores: scores.clone(),
        }
        .into();

        assert_eq!(res.hypotheses, hypotheses);
        assert_eq!(res.scores, scores);
        assert_eq!(res.output(), Some(hypotheses.first().unwrap()));
        assert_eq!(res.score(), Some(scores[0]));
        assert_eq!(res.num_hypotheses(), 2);
        assert!(res.has_scores());
    }

    /// Converting an empty FFI result yields empty collections and `None` from the accessors.
    #[test]
    fn translation_empty_result() {
        let res: TranslationResult = ffi::TranslationResult {
            hypotheses: vec![],
            scores: vec![],
        }
        .into();

        assert!(res.hypotheses.is_empty());
        assert!(res.scores.is_empty());
        assert_eq!(res.output(), None);
        assert_eq!(res.score(), None);
        assert_eq!(res.num_hypotheses(), 0);
        assert!(!res.has_scores());
    }

    #[cfg(feature = "hub")]
    mod hub {
        use crate::sys::Translator;
        use crate::{download_model, Config, Device};

        const MODEL_ID: &str = "jkawamoto/fugumt-en-ja-ct2";

        /// The `Debug` output of a translator must contain the model directory name.
        /// Ignored by default because it downloads a model.
        #[test]
        #[ignore]
        fn test_translator_debug() {
            let model_path = download_model(MODEL_ID).unwrap();

            let translator = Translator::new(
                &model_path,
                &Config {
                    device: if cfg!(feature = "cuda") {
                        Device::CUDA
                    } else {
                        Device::CPU
                    },
                    ..Default::default()
                },
            )
            .unwrap();
            assert!(format!("{:?}", translator)
                .contains(model_path.file_name().unwrap().to_str().unwrap()));
        }
    }
}
756}