ct2rs/sys/translator.rs
1// translator.rs
2//
3// Copyright (c) 2023-2024 Junpei Kawamoto
4//
5// This software is released under the MIT License.
6//
7// http://opensource.org/licenses/mit-license.php
8
9//! This module provides a Rust binding to the
10//! [`ctranslate2::Translator`](https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html).
11
12use std::ffi::{OsStr, OsString};
13use std::fmt::{Debug, Formatter};
14use std::path::Path;
15
16use anyhow::{anyhow, Error, Result};
17use cxx::UniquePtr;
18
19use super::{config, vec_ffi_vecstr, BatchType, Config, GenerationStepResult, VecStr, VecString};
20
21trait GenerationCallback {
22 fn execute(&mut self, res: GenerationStepResult) -> bool;
23}
24
25impl<F: FnMut(GenerationStepResult) -> bool> GenerationCallback for F {
26 fn execute(&mut self, args: GenerationStepResult) -> bool {
27 self(args)
28 }
29}
30type TranslationCallbackBox<'a> = Box<dyn GenerationCallback + 'a>;
31
32impl<'a> From<Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>>
33 for TranslationCallbackBox<'a>
34{
35 fn from(opt: Option<&'a mut dyn FnMut(GenerationStepResult) -> bool>) -> Self {
36 match opt {
37 None => Box::new(|_| false) as TranslationCallbackBox,
38 Some(c) => Box::new(c) as TranslationCallbackBox,
39 }
40 }
41}
42
43fn execute_translation_callback(f: &mut TranslationCallbackBox, arg: GenerationStepResult) -> bool {
44 f.execute(arg)
45}
46
// Bridge between this module and the C++ wrapper declared in `ct2rs/include/translator.h`.
#[cxx::bridge]
mod ffi {
    // FFI mirror of the public `TranslationOptions`. Token strings are borrowed (`'a`) from the
    // Rust-side options for the duration of a call; see `TranslationOptions::to_ffi`.
    struct TranslationOptions<'a> {
        beam_size: usize,
        patience: f32,
        length_penalty: f32,
        coverage_penalty: f32,
        repetition_penalty: f32,
        no_repeat_ngram_size: usize,
        disable_unk: bool,
        suppress_sequences: Vec<VecStr<'a>>,
        prefix_bias_beta: f32,
        end_token: Vec<&'a str>,
        return_end_token: bool,
        max_input_length: usize,
        max_decoding_length: usize,
        min_decoding_length: usize,
        sampling_topk: usize,
        sampling_topp: f32,
        sampling_temperature: f32,
        use_vmap: bool,
        num_hypotheses: usize,
        return_scores: bool,
        return_attention: bool,
        return_logits_vocab: bool,
        return_alternatives: bool,
        min_alternative_expansion_prob: f32,
        replace_unknowns: bool,
        max_batch_size: usize,
        batch_type: BatchType,
    }

    // Raw per-input result from C++; converted into the public `TranslationResult` via `From`.
    struct TranslationResult {
        hypotheses: Vec<VecString>,
        scores: Vec<f32>,
        // Attention vectors are not bridged yet, so they never reach the Rust side:
        // attention: Vec<Vec<Vec<f32>>>,
    }

    extern "Rust" {
        // Opaque boxed callback owned by Rust; C++ only ever receives it by `&mut` reference.
        type TranslationCallbackBox<'a>;
        // Exported to C++ so it can run the boxed callback for one generation step; the
        // returned `bool` is the callback's request to stop decoding.
        fn execute_translation_callback(
            f: &mut TranslationCallbackBox,
            arg: GenerationStepResult,
        ) -> bool;
    }

    unsafe extern "C++" {
        include!("ct2rs/include/translator.h");
        include!("ct2rs/src/sys/types.rs.h");

        // Types declared elsewhere in this crate and reused by this bridge.
        type VecString = super::VecString;
        type VecStr<'a> = super::VecStr<'a>;

        type Config = super::config::ffi::Config;
        type BatchType = super::BatchType;
        type GenerationStepResult = super::GenerationStepResult;

        // Opaque C++ translator, only handled through `UniquePtr` and references.
        type Translator;

        // Creates a translator for the model at `model_path`; a C++ exception becomes `Err`.
        fn translator(model_path: &str, config: UniquePtr<Config>)
            -> Result<UniquePtr<Translator>>;

        // `has_callback` tells C++ whether `callback` wraps a user callback: the Rust side
        // always passes a box, a no-op one when the caller supplied no callback.
        fn translate_batch(
            self: &Translator,
            source: &Vec<VecStr>,
            options: &TranslationOptions,
            has_callback: bool,
            callback: &mut TranslationCallbackBox,
        ) -> Result<Vec<TranslationResult>>;

        // Same as `translate_batch`, additionally passing target prefixes.
        fn translate_batch_with_target_prefix(
            self: &Translator,
            source: &Vec<VecStr>,
            target_prefix: &Vec<VecStr>,
            options: &TranslationOptions,
            has_callback: bool,
            callback: &mut TranslationCallbackBox,
        ) -> Result<Vec<TranslationResult>>;

        fn num_queued_batches(self: &Translator) -> Result<usize>;

        fn num_active_batches(self: &Translator) -> Result<usize>;

        fn num_replicas(self: &Translator) -> Result<usize>;
    }
}
133
// SAFETY: NOTE(review): these impls assert that the C++ `Translator` may be moved to and shared
// between threads. Every bridged method takes `self: &Translator`, so sharing only exposes those
// calls. Soundness therefore relies on CTranslate2's translator tolerating concurrent calls from
// several threads; confirm this against the CTranslate2 version being linked.
unsafe impl Send for ffi::Translator {}
unsafe impl Sync for ffi::Translator {}
136
137/// Options for translation.
138///
139/// # Examples
140///
141/// Example of creating a default `TranslationOptions`:
142///
143/// ```
144/// # use ct2rs::sys::BatchType;
145/// use ct2rs::sys::TranslationOptions;
146///
147/// let options = TranslationOptions::default();
148/// # assert_eq!(options.beam_size, 2);
149/// # assert_eq!(options.patience, 1.);
150/// # assert_eq!(options.length_penalty, 1.);
151/// # assert_eq!(options.coverage_penalty, 0.);
152/// # assert_eq!(options.repetition_penalty, 1.);
153/// # assert_eq!(options.no_repeat_ngram_size, 0);
154/// # assert!(!options.disable_unk);
155/// # assert!(options.suppress_sequences.is_empty());
156/// # assert_eq!(options.prefix_bias_beta, 0.);
157/// # assert!(options.end_token.is_empty());
158/// # assert!(!options.return_end_token);
159/// # assert_eq!(options.max_input_length, 1024);
160/// # assert_eq!(options.max_decoding_length, 256);
161/// # assert_eq!(options.min_decoding_length, 1);
162/// # assert_eq!(options.sampling_topk, 1);
163/// # assert_eq!(options.sampling_topp, 1.);
164/// # assert_eq!(options.sampling_temperature, 1.);
165/// # assert!(!options.use_vmap);
166/// # assert_eq!(options.num_hypotheses, 1);
167/// # assert!(!options.return_scores);
168/// # assert!(!options.return_attention);
169/// # assert!(!options.return_logits_vocab);
170/// # assert!(!options.return_alternatives);
171/// # assert_eq!(options.min_alternative_expansion_prob, 0.);
172/// # assert!(!options.replace_unknowns);
173/// # assert_eq!(options.max_batch_size, 0);
174/// # assert_eq!(options.batch_type, BatchType::default());
175/// ```
176///
177#[derive(Clone, Debug)]
178pub struct TranslationOptions<T: AsRef<str>, U: AsRef<str>> {
179 /// Beam size to use for beam search (set 1 to run greedy search). (default: 2)
180 pub beam_size: usize,
181 /// Beam search patience factor, as described in <https://arxiv.org/abs/2204.05424>.
182 /// The decoding will continue until beam_size*patience hypotheses are finished.
183 /// (default: 1.0)
184 pub patience: f32,
185 /// Exponential penalty applied to the length during beam search.
186 /// The scores are normalized with:
187 /// ```math
188 /// hypothesis_score /= (hypothesis_length ** length_penalty)
189 /// ```
190 /// (default: 1.0)
191 pub length_penalty: f32,
192 /// Coverage penalty weight applied during beam search. (default: 0)
193 pub coverage_penalty: f32,
194 /// Penalty applied to the score of previously generated tokens, as described in
195 /// <https://arxiv.org/abs/1909.05858> (set > 1 to penalize). (default: 1.0)
196 pub repetition_penalty: f32,
197 /// Prevent repetitions of ngrams with this size (set 0 to disable). (default: 0)
198 pub no_repeat_ngram_size: usize,
199 /// Disable the generation of the unknown token. (default: false)
200 pub disable_unk: bool,
201 /// Disable the generation of some sequences of tokens. (default: empty)
202 pub suppress_sequences: Vec<Vec<T>>,
203 /// Biases decoding towards a given prefix, see <https://arxiv.org/abs/1912.03393> --section 4.2
204 /// Only activates biased-decoding when beta is in range (0, 1) and SearchStrategy is set to
205 /// BeamSearch. The closer beta is to 1, the stronger the bias is towards the given prefix.
206 ///
207 /// If beta <= 0 and a non-empty prefix is given, then the prefix will be used as a
208 /// hard-prefix rather than a soft, biased-prefix. (default: 0)
209 pub prefix_bias_beta: f32,
210 /// Stop the decoding on one of these tokens (defaults to the model EOS token).
211 pub end_token: Vec<U>,
212 /// Include the end token in the result. (default: false)
213 pub return_end_token: bool,
214 /// Truncate the inputs after this many tokens (set 0 to disable truncation). (default: 1024)
215 pub max_input_length: usize,
216 /// Decoding length constraints. (default: 256)
217 pub max_decoding_length: usize,
218 /// Decoding length constraints. (default: 1)
219 pub min_decoding_length: usize,
220 /// Randomly sample from the top K candidates (set 0 to sample from the full output
221 /// distribution). (default: 1)
222 pub sampling_topk: usize,
223 /// Keep the most probable tokens whose cumulative probability exceeds this value.
224 /// (default: 1.0)
225 pub sampling_topp: f32,
226 /// High temperature increase randomness. (default: 1.0)
227 pub sampling_temperature: f32,
228 /// Allow using the vocabulary map included in the model directory, if it exists.
229 /// (default: false)
230 pub use_vmap: bool,
231 /// Number of hypotheses to store in the TranslationResult class. (default: 1)
232 pub num_hypotheses: usize,
233 /// Store scores in the TranslationResult class. (default: false)
234 pub return_scores: bool,
235 /// Store attention vectors in the TranslationResult class. (default: false)
236 pub return_attention: bool,
237 /// Store log probs matrix in the TranslationResult class. (default: false)
238 pub return_logits_vocab: bool,
239 /// Return alternatives at the first unconstrained decoding position. This is typically
240 /// used with a target prefix to provide alternatives at a specific location in the
241 /// translation. (default: false)
242 pub return_alternatives: bool,
243 /// Minimum probability to expand an alternative. (default: 0)
244 pub min_alternative_expansion_prob: f32,
245 /// Replace unknown target tokens by the original source token with the highest attention.
246 /// (default: false)
247 pub replace_unknowns: bool,
248 /// The maximum batch size. If the number of inputs is greater than `max_batch_size`,
249 /// the inputs are sorted by length and split by chunks of `max_batch_size` examples
250 /// so that the number of padding positions is minimized. (default: 0)
251 pub max_batch_size: usize,
252 /// Whether `max_batch_size` is the number of “examples” or “tokens”.
253 pub batch_type: BatchType,
254}
255
256impl Default for TranslationOptions<String, String> {
257 fn default() -> Self {
258 Self {
259 beam_size: 2,
260 patience: 1.,
261 length_penalty: 1.,
262 coverage_penalty: 0.,
263 repetition_penalty: 1.,
264 no_repeat_ngram_size: 0,
265 disable_unk: false,
266 suppress_sequences: vec![],
267 prefix_bias_beta: 0.,
268 end_token: vec![],
269 return_end_token: false,
270 max_input_length: 1024,
271 max_decoding_length: 256,
272 min_decoding_length: 1,
273 sampling_topk: 1,
274 sampling_topp: 1.,
275 sampling_temperature: 1.,
276 use_vmap: false,
277 num_hypotheses: 1,
278 return_scores: false,
279 return_attention: false,
280 return_logits_vocab: false,
281 return_alternatives: false,
282 min_alternative_expansion_prob: 0.,
283 replace_unknowns: false,
284 max_batch_size: 0,
285 batch_type: BatchType::default(),
286 }
287 }
288}
289
290impl<T: AsRef<str>, U: AsRef<str>> TranslationOptions<T, U> {
291 fn to_ffi(&self) -> ffi::TranslationOptions {
292 ffi::TranslationOptions {
293 beam_size: self.beam_size,
294 patience: self.patience,
295 length_penalty: self.length_penalty,
296 coverage_penalty: self.coverage_penalty,
297 repetition_penalty: self.repetition_penalty,
298 no_repeat_ngram_size: self.no_repeat_ngram_size,
299 disable_unk: self.disable_unk,
300 suppress_sequences: vec_ffi_vecstr(self.suppress_sequences.as_ref()),
301 prefix_bias_beta: self.prefix_bias_beta,
302 end_token: self.end_token.iter().map(AsRef::as_ref).collect(),
303 return_end_token: self.return_end_token,
304 max_input_length: self.max_input_length,
305 max_decoding_length: self.max_decoding_length,
306 min_decoding_length: self.min_decoding_length,
307 sampling_topk: self.sampling_topk,
308 sampling_topp: self.sampling_topp,
309 sampling_temperature: self.sampling_temperature,
310 use_vmap: self.use_vmap,
311 num_hypotheses: self.num_hypotheses,
312 return_scores: self.return_scores,
313 return_attention: self.return_attention,
314 return_logits_vocab: self.return_logits_vocab,
315 return_alternatives: self.return_alternatives,
316 min_alternative_expansion_prob: self.min_alternative_expansion_prob,
317 replace_unknowns: self.replace_unknowns,
318 max_batch_size: self.max_batch_size,
319 batch_type: self.batch_type,
320 }
321 }
322}
323
/// A text translator.
///
/// This struct is a Rust binding to the
/// [`ctranslate2::Translator`](https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html).
///
/// # Example
/// Below is an example where a given list of tokens is translated:
///
/// ```no_run
/// # use anyhow::Result;
/// use ct2rs::sys::{Config, Translator};
///
/// # fn main() -> Result<()> {
/// let translator = Translator::new("/path/to/model", &Config::default())?;
/// let res = translator.translate_batch(
///     &[vec!["▁Hello", "▁world", "!", "</s>", "<unk>"]],
///     &Default::default(),
///     None,
/// )?;
/// for r in res {
///     println!("{:?}", r);
/// }
/// # Ok(())
/// # }
/// ```
///
/// If the model requires target prefixes, use [`Translator::translate_batch_with_target_prefix`]
/// instead:
///
/// ```no_run
/// # use anyhow::Result;
/// use ct2rs::sys::{Config, Translator};
///
/// # fn main() -> Result<()> {
/// let translator = Translator::new("/path/to/model", &Config::default())?;
/// let res = translator.translate_batch_with_target_prefix(
///     &[vec!["▁Hello", "▁world", "!", "</s>", "<unk>"]],
///     &[vec!["jpn_Jpan"]],
///     &Default::default(),
///     None,
/// )?;
/// for r in res {
///     println!("{:?}", r);
/// }
/// # Ok(())
/// # }
/// ```
pub struct Translator {
    // File name (last path component) of the model directory given to `new`; shown in the
    // `Debug` output.
    model: OsString,
    // Owning handle to the C++ translator.
    ptr: UniquePtr<ffi::Translator>,
}
375
376impl Translator {
377 /// Creates and initializes an instance of `Translator`.
378 ///
379 /// This function constructs a new `Translator` by loading a language model from the specified
380 /// `model_path` and applying the provided `config` settings.
381 ///
382 /// # Arguments
383 /// * `model_path` - A path to the directory containing the language model to be loaded.
384 /// * `config` - A reference to a `Config` structure that specifies various settings
385 /// and configurations for the `Translator`.
386 ///
387 /// # Returns
388 /// Returns a `Result` that, if successful, contains the initialized `Translator`. If an error
389 /// occurs during initialization, the function will return an error wrapped in the `Result`.
390 ///
391 /// # Example
392 /// ```no_run
393 /// # use anyhow::Result;
394 /// #
395 /// use ct2rs::sys::{Config, Translator};
396 ///
397 /// # fn main() -> Result<()> {
398 /// let config = Config::default();
399 /// let translator = Translator::new("/path/to/model", &config)?;
400 /// # Ok(())
401 /// # }
402 /// ```
403 pub fn new<T: AsRef<Path>>(model_path: T, config: &Config) -> Result<Translator> {
404 let model_path = model_path.as_ref();
405 Ok(Translator {
406 model: model_path
407 .file_name()
408 .map(OsStr::to_os_string)
409 .unwrap_or_default(),
410 ptr: ffi::translator(
411 model_path
412 .to_str()
413 .ok_or_else(|| anyhow!("invalid path: {}", model_path.display()))?,
414 config.to_ffi(),
415 )?,
416 })
417 }
418
419 /// Translates multiple lists of tokens in a batch processing manner.
420 ///
421 /// This function takes a vector of token lists and performs batch translation according to the
422 /// specified settings in `options`. The results of the batch translation are returned as a
423 /// vector. An optional `callback` closure can be provided which is invoked for each new token
424 /// generated during the translation process. This allows for step-by-step reception of the
425 /// batch translation results. If the callback returns `true`, it will stop the translation for
426 /// that batch. Note that if a callback is provided, `options.beam_size` must be set to `1`.
427 ///
428 /// # Arguments
429 /// * `source` - A vector of token lists, each list representing a sequence of tokens to be
430 /// translated.
431 /// * `options` - Settings applied to the batch translation process.
432 /// * `callback` - An optional mutable reference to a closure that is called for each token
433 /// generation step. The closure takes a `GenerationStepResult` and returns a `bool`. If it
434 /// returns `true`, the translation process for the current batch will stop.
435 ///
436 /// # Returns
437 /// Returns a `Result` containing a vector of `TranslationResult` if successful, or an error if
438 /// the translation fails.
439 ///
440 /// # Example
441 /// ```no_run
442 /// # use anyhow::Result;
443 /// #
444 /// use ct2rs::sys::{Config, GenerationStepResult, Translator, TranslationOptions};
445 ///
446 /// # fn main() -> Result<()> {
447 /// let source_tokens = [
448 /// vec!["▁Hall", "o", "▁World", "!", "</s>"],
449 /// vec![
450 /// "▁This", "▁library", "▁is", "▁a", "▁", "Rust", "▁", "binding", "s", "▁of",
451 /// "▁C", "Trans", "late", "2", ".", "</s>"
452 /// ],
453 /// ];
454 /// let options = TranslationOptions::default();
455 /// let mut callback = |step_result: GenerationStepResult| -> bool {
456 /// println!("{:?}", step_result);
457 /// false // Continue processing
458 /// };
459 /// let translator = Translator::new("/path/to/model", &Config::default())?;
460 /// let results = translator.translate_batch(&source_tokens, &options, Some(&mut callback))?;
461 /// # Ok(())
462 /// # }
463 /// ```
464 pub fn translate_batch<T, U, V>(
465 &self,
466 source: &[Vec<T>],
467 options: &TranslationOptions<U, V>,
468 callback: Option<&mut dyn FnMut(GenerationStepResult) -> bool>,
469 ) -> Result<Vec<TranslationResult>>
470 where
471 T: AsRef<str>,
472 U: AsRef<str>,
473 V: AsRef<str>,
474 {
475 Ok(self
476 .ptr
477 .translate_batch(
478 &vec_ffi_vecstr(source),
479 &options.to_ffi(),
480 callback.is_some(),
481 &mut TranslationCallbackBox::from(callback),
482 )?
483 .into_iter()
484 .map(TranslationResult::from)
485 .collect())
486 }
487
488 /// Translates multiple lists of tokens with target prefixes in a batch processing manner.
489 ///
490 /// This function takes a vector of token lists and corresponding target prefixes, performing
491 /// batch translation according to the specified settings in `options`. An optional `callback`
492 /// closure can be provided which is invoked for each new token generated during the translation
493 /// process.
494 ///
495 /// This function is similar to `translate_batch`, with the addition of handling target prefixes
496 /// that guide the translation process. For more detailed parameter and option descriptions,
497 /// refer to the documentation for [`Translator::translate_batch`].
498 ///
499 /// # Arguments
500 /// * `source` - A vector of token lists, each list representing a sequence of tokens to be
501 /// translated.
502 /// * `target_prefix` - A vector of token lists, each list representing a sequence of target
503 /// prefix tokens that provide a starting point for the translation output.
504 /// * `options` - Settings applied to the batch translation process.
505 /// * `callback` - An optional mutable reference to a closure that is called for each token
506 /// generation step. The closure takes a `GenerationStepResult` and returns a `bool`. If it
507 /// returns `true`, the translation process for the current batch will stop.
508 ///
509 /// # Returns
510 /// Returns a `Result` containing a vector of `TranslationResult` if successful, or an error if
511 /// the translation fails.
512 pub fn translate_batch_with_target_prefix<T, U, V, W>(
513 &self,
514 source: &[Vec<T>],
515 target_prefix: &[Vec<U>],
516 options: &TranslationOptions<V, W>,
517 callback: Option<&mut dyn FnMut(GenerationStepResult) -> bool>,
518 ) -> Result<Vec<TranslationResult>>
519 where
520 T: AsRef<str>,
521 U: AsRef<str>,
522 V: AsRef<str>,
523 W: AsRef<str>,
524 {
525 Ok(self
526 .ptr
527 .translate_batch_with_target_prefix(
528 &vec_ffi_vecstr(source),
529 &vec_ffi_vecstr(target_prefix),
530 &options.to_ffi(),
531 callback.is_some(),
532 &mut TranslationCallbackBox::from(callback),
533 )?
534 .into_iter()
535 .map(TranslationResult::from)
536 .collect())
537 }
538
539 /// Number of batches in the work queue.
540 #[inline]
541 pub fn num_queued_batches(&self) -> Result<usize> {
542 self.ptr.num_queued_batches().map_err(Error::from)
543 }
544
545 /// Number of batches in the work queue or currently processed by a worker.
546 #[inline]
547 pub fn num_active_batches(&self) -> Result<usize> {
548 self.ptr.num_active_batches().map_err(Error::from)
549 }
550
551 /// Number of parallel replicas.
552 #[inline]
553 pub fn num_replicas(&self) -> Result<usize> {
554 self.ptr.num_replicas().map_err(Error::from)
555 }
556}
557
558impl Debug for Translator {
559 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
560 f.debug_struct("Translator")
561 .field("model", &self.model)
562 .field("queued_batches", &self.num_queued_batches())
563 .field("active_batches", &self.num_active_batches())
564 .field("replicas", &self.num_replicas())
565 .finish()
566 }
567}
568
// Releasing `UniquePtr<Translator>` invokes joining threads.
// However, on Windows, this causes a deadlock.
// As a workaround, the C++ translator is deliberately leaked on Windows: its destructor never
// runs, and its memory is only reclaimed when the process exits.
// See also https://github.com/jkawamoto/ctranslate2-rs/issues/64
#[cfg(target_os = "windows")]
impl Drop for Translator {
    fn drop(&mut self) {
        // Move the handle out and leave a null pointer behind, so the field's own drop glue
        // (which runs after this method) has nothing to destroy.
        let ptr = std::mem::replace(&mut self.ptr, UniquePtr::null());
        // `into_raw` gives up ownership without calling the C++ destructor. The raw pointer is
        // intentionally discarded. No `unsafe` is needed to leak, and a null pointer is fine here.
        let _leaked: *mut ffi::Translator = ptr.into_raw();
    }
}
582
/// A translation result.
///
/// This struct is a Rust binding to the
/// [`ctranslate2.TranslationResult`](https://opennmt.net/CTranslate2/python/ctranslate2.TranslationResult.html).
///
/// Results implement `PartialEq`, so whole results can be compared directly.
#[derive(Clone, Debug, PartialEq)]
pub struct TranslationResult {
    /// Translation hypotheses, each a sequence of output tokens; the first is the best one.
    pub hypotheses: Vec<Vec<String>>,
    /// Score of each translation hypothesis (empty if `return_scores` was disabled).
    pub scores: Vec<f32>,
}
594
595impl From<ffi::TranslationResult> for TranslationResult {
596 fn from(r: ffi::TranslationResult) -> Self {
597 Self {
598 hypotheses: r.hypotheses.into_iter().map(Vec::<String>::from).collect(),
599 scores: r.scores,
600 }
601 }
602}
603
604impl TranslationResult {
605 /// Returns the first translation hypothesis if exists.
606 #[inline]
607 pub fn output(&self) -> Option<&Vec<String>> {
608 self.hypotheses.first()
609 }
610
611 /// Returns the score of the first translation hypothesis if exists.
612 #[inline]
613 pub fn score(&self) -> Option<f32> {
614 self.scores.first().copied()
615 }
616
617 /// Returns the number of translation hypotheses.
618 #[inline]
619 pub fn num_hypotheses(&self) -> usize {
620 self.hypotheses.len()
621 }
622
623 /// Returns true if this result contains scores.
624 #[inline]
625 pub fn has_scores(&self) -> bool {
626 !self.scores.is_empty()
627 }
628}
629
#[cfg(test)]
mod tests {
    use super::ffi::{VecStr, VecString};
    use super::{ffi, TranslationOptions, TranslationResult};

    #[test]
    fn options_to_ffi() {
        // Use non-default values, with pairwise-distinct numerics, so that a field dropped or
        // wired to the wrong FFI field by `to_ffi` makes an assertion below fail.
        let opts = TranslationOptions {
            beam_size: 4,
            patience: 1.5,
            length_penalty: 0.8,
            coverage_penalty: 0.2,
            repetition_penalty: 1.3,
            no_repeat_ngram_size: 3,
            disable_unk: true,
            suppress_sequences: vec![vec!["a".to_string(), "b".to_string(), "c".to_string()]],
            prefix_bias_beta: 0.4,
            end_token: vec!["1".to_string(), "2".to_string()],
            return_end_token: true,
            max_input_length: 512,
            max_decoding_length: 128,
            min_decoding_length: 5,
            sampling_topk: 10,
            sampling_topp: 0.9,
            sampling_temperature: 0.7,
            use_vmap: true,
            num_hypotheses: 2,
            return_scores: true,
            return_attention: true,
            return_logits_vocab: true,
            return_alternatives: true,
            min_alternative_expansion_prob: 0.05,
            replace_unknowns: true,
            max_batch_size: 16,
            ..Default::default()
        };
        let res = opts.to_ffi();

        assert_eq!(res.beam_size, opts.beam_size);
        assert_eq!(res.patience, opts.patience);
        assert_eq!(res.length_penalty, opts.length_penalty);
        assert_eq!(res.coverage_penalty, opts.coverage_penalty);
        assert_eq!(res.repetition_penalty, opts.repetition_penalty);
        assert_eq!(res.no_repeat_ngram_size, opts.no_repeat_ngram_size);
        assert_eq!(res.disable_unk, opts.disable_unk);
        assert_eq!(
            res.suppress_sequences,
            opts.suppress_sequences
                .iter()
                .map(|v| VecStr {
                    v: v.iter().map(AsRef::as_ref).collect()
                })
                .collect::<Vec<VecStr>>()
        );
        assert_eq!(res.prefix_bias_beta, opts.prefix_bias_beta);
        assert_eq!(
            res.end_token,
            opts.end_token
                .iter()
                .map(AsRef::as_ref)
                .collect::<Vec<&str>>()
        );
        assert_eq!(res.return_end_token, opts.return_end_token);
        assert_eq!(res.max_input_length, opts.max_input_length);
        assert_eq!(res.max_decoding_length, opts.max_decoding_length);
        assert_eq!(res.min_decoding_length, opts.min_decoding_length);
        assert_eq!(res.sampling_topk, opts.sampling_topk);
        assert_eq!(res.sampling_topp, opts.sampling_topp);
        assert_eq!(res.sampling_temperature, opts.sampling_temperature);
        assert_eq!(res.use_vmap, opts.use_vmap);
        assert_eq!(res.num_hypotheses, opts.num_hypotheses);
        assert_eq!(res.return_scores, opts.return_scores);
        assert_eq!(res.return_attention, opts.return_attention);
        assert_eq!(res.return_logits_vocab, opts.return_logits_vocab);
        assert_eq!(res.return_alternatives, opts.return_alternatives);
        assert_eq!(
            res.min_alternative_expansion_prob,
            opts.min_alternative_expansion_prob
        );
        assert_eq!(res.replace_unknowns, opts.replace_unknowns);
        assert_eq!(res.max_batch_size, opts.max_batch_size);
        assert_eq!(res.batch_type, opts.batch_type);
    }

    #[test]
    fn translation_result() {
        let hypotheses = vec![
            vec!["a".to_string(), "b".to_string()],
            vec!["x".to_string(), "y".to_string(), "z".to_string()],
        ];
        let scores: Vec<f32> = vec![1., 2., 3.];
        let res: TranslationResult = ffi::TranslationResult {
            hypotheses: hypotheses
                .iter()
                .map(|v| VecString::from(v.clone()))
                .collect(),
            scores: scores.clone(),
        }
        .into();

        assert_eq!(res.hypotheses, hypotheses);
        assert_eq!(res.scores, scores);
        assert_eq!(res.output(), Some(hypotheses.first().unwrap()));
        assert_eq!(res.score(), Some(scores[0]));
        assert_eq!(res.num_hypotheses(), 2);
        assert!(res.has_scores());
    }

    #[test]
    fn translation_empty_result() {
        let res: TranslationResult = ffi::TranslationResult {
            hypotheses: vec![],
            scores: vec![],
        }
        .into();

        assert!(res.hypotheses.is_empty());
        assert!(res.scores.is_empty());
        assert_eq!(res.output(), None);
        assert_eq!(res.score(), None);
        assert_eq!(res.num_hypotheses(), 0);
        assert!(!res.has_scores());
    }

    #[cfg(feature = "hub")]
    mod hub {
        use crate::sys::Translator;
        use crate::{download_model, Config, Device};

        const MODEL_ID: &str = "jkawamoto/fugumt-en-ja-ct2";
        #[test]
        #[ignore]
        fn test_translator_debug() {
            let model_path = download_model(MODEL_ID).unwrap();

            let translator = Translator::new(
                &model_path,
                &Config {
                    device: if cfg!(feature = "cuda") {
                        Device::CUDA
                    } else {
                        Device::CPU
                    },
                    ..Default::default()
                },
            )
            .unwrap();
            assert!(format!("{:?}", translator)
                .contains(model_path.file_name().unwrap().to_str().unwrap()));
        }
    }
}