// rust-bert 0.20.0
//
// Ready-to-use NLP pipelines and language models
// Documentation
// Copyright 2018-2020 The HuggingFace Inc. team.
// Copyright 2020 Marian Team Authors
// Copyright 2019-2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use tch::Device;

use crate::common::error::RustBertError;
use crate::m2m_100::M2M100Generator;
use crate::marian::MarianGenerator;
use crate::mbart::MBartGenerator;
use crate::pipelines::common::ModelType;
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
use crate::pipelines::generation_utils::{GenerateConfig, GenerateOptions, LanguageGenerator};
use crate::resources::ResourceProvider;
use crate::t5::T5Generator;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt;
use std::fmt::{Debug, Display};

/// Natural language that can be requested as a translation source or target.
///
/// The derived `Debug` output is the bare variant name (e.g. `ChineseMandarin`), while the
/// `Display` implementation splits it into words (`Chinese Mandarin`). Language codes are
/// available through [`Language::get_iso_639_1_code`] and [`Language::get_iso_639_3_code`].
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub enum Language {
    Afrikaans,
    Danish,
    Dutch,
    German,
    English,
    Icelandic,
    Luxembourgish,
    Norwegian,
    Swedish,
    WesternFrisian,
    Yiddish,
    Asturian,
    Catalan,
    French,
    Galician,
    Italian,
    Occitan,
    Portuguese,
    Romanian,
    Spanish,
    Belarusian,
    Bosnian,
    Bulgarian,
    Croatian,
    Czech,
    Macedonian,
    Polish,
    Russian,
    Serbian,
    Slovak,
    Slovenian,
    Ukrainian,
    Estonian,
    Finnish,
    Hungarian,
    Latvian,
    Lithuanian,
    Albanian,
    Armenian,
    Georgian,
    Greek,
    Breton,
    Irish,
    ScottishGaelic,
    Welsh,
    Azerbaijani,
    Bashkir,
    Kazakh,
    Turkish,
    Uzbek,
    Japanese,
    Korean,
    Vietnamese,
    ChineseMandarin,
    Bengali,
    Gujarati,
    Hindi,
    Kannada,
    Marathi,
    Nepali,
    Oriya,
    Panjabi,
    Sindhi,
    Sinhala,
    Urdu,
    Tamil,
    Cebuano,
    Iloko,
    Indonesian,
    Javanese,
    Malagasy,
    Malay,
    Malayalam,
    Sundanese,
    Tagalog,
    Burmese,
    CentralKhmer,
    Lao,
    Thai,
    Mongolian,
    Arabic,
    Hebrew,
    Pashto,
    Farsi,
    Amharic,
    Fulah,
    Hausa,
    Igbo,
    Lingala,
    Luganda,
    NorthernSotho,
    Somali,
    Swahili,
    Swati,
    Tswana,
    Wolof,
    Xhosa,
    Yoruba,
    Zulu,
    HaitianCreole,
}

impl Display for Language {
    /// Writes the language name as space-separated words, e.g. `ChineseMandarin` is rendered as
    /// `Chinese Mandarin` and `English` stays `English`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // The derived `Debug` representation is the bare CamelCase variant name.
        let variant_name = format!("{:?}", self);
        let mut words = String::new();
        for (index, character) in variant_name.char_indices() {
            // Every uppercase letter after the very first character opens a new word.
            if index > 0 && character.is_uppercase() {
                words.push(' ');
            }
            words.push(character);
        }
        f.write_str(&words)
    }
}

impl Language {
    /// Returns the ISO 639-1 (two-letter) code of the language, e.g. `"en"` for English.
    ///
    /// Asturian, Cebuano, Iloko and Northern Sotho return a three-letter code instead
    /// (identical to the one returned by [`Language::get_iso_639_3_code`]), so callers must not
    /// assume the result is always two characters long. The M2M100 token formatting in this
    /// module branches on the code length for exactly that reason.
    pub fn get_iso_639_1_code(&self) -> &'static str {
        match self {
            Language::Afrikaans => "af",
            Language::Danish => "da",
            Language::Dutch => "nl",
            Language::German => "de",
            Language::English => "en",
            Language::Icelandic => "is",
            Language::Luxembourgish => "lb",
            Language::Norwegian => "no",
            Language::Swedish => "sv",
            Language::WesternFrisian => "fy",
            Language::Yiddish => "yi",
            // Three-letter code (same as the ISO 639-3 table below).
            Language::Asturian => "ast",
            Language::Catalan => "ca",
            Language::French => "fr",
            Language::Galician => "gl",
            Language::Italian => "it",
            Language::Occitan => "oc",
            Language::Portuguese => "pt",
            Language::Romanian => "ro",
            Language::Spanish => "es",
            Language::Belarusian => "be",
            Language::Bosnian => "bs",
            Language::Bulgarian => "bg",
            Language::Croatian => "hr",
            Language::Czech => "cs",
            Language::Macedonian => "mk",
            Language::Polish => "pl",
            Language::Russian => "ru",
            Language::Serbian => "sr",
            Language::Slovak => "sk",
            Language::Slovenian => "sl",
            Language::Ukrainian => "uk",
            Language::Estonian => "et",
            Language::Finnish => "fi",
            Language::Hungarian => "hu",
            Language::Latvian => "lv",
            Language::Lithuanian => "lt",
            Language::Albanian => "sq",
            Language::Armenian => "hy",
            Language::Georgian => "ka",
            Language::Greek => "el",
            Language::Breton => "br",
            Language::Irish => "ga",
            Language::ScottishGaelic => "gd",
            Language::Welsh => "cy",
            Language::Azerbaijani => "az",
            Language::Bashkir => "ba",
            Language::Kazakh => "kk",
            Language::Turkish => "tr",
            Language::Uzbek => "uz",
            Language::Japanese => "ja",
            Language::Korean => "ko",
            Language::Vietnamese => "vi",
            // Generic Chinese code; the ISO 639-3 table below uses the Mandarin-specific "cmn".
            Language::ChineseMandarin => "zh",
            Language::Bengali => "bn",
            Language::Gujarati => "gu",
            Language::Hindi => "hi",
            Language::Kannada => "kn",
            Language::Marathi => "mr",
            Language::Nepali => "ne",
            Language::Oriya => "or",
            Language::Panjabi => "pa",
            Language::Sindhi => "sd",
            Language::Sinhala => "si",
            Language::Urdu => "ur",
            Language::Tamil => "ta",
            // Three-letter codes (same as the ISO 639-3 table below).
            Language::Cebuano => "ceb",
            Language::Iloko => "ilo",
            Language::Indonesian => "id",
            Language::Javanese => "jv",
            Language::Malagasy => "mg",
            Language::Malay => "ms",
            Language::Malayalam => "ml",
            Language::Sundanese => "su",
            Language::Tagalog => "tl",
            Language::Burmese => "my",
            Language::CentralKhmer => "km",
            Language::Lao => "lo",
            Language::Thai => "th",
            Language::Mongolian => "mn",
            Language::Arabic => "ar",
            Language::Hebrew => "he",
            Language::Pashto => "ps",
            Language::Farsi => "fa",
            Language::Amharic => "am",
            Language::Fulah => "ff",
            Language::Hausa => "ha",
            Language::Igbo => "ig",
            Language::Lingala => "ln",
            Language::Luganda => "lg",
            // Three-letter code (same as the ISO 639-3 table below).
            Language::NorthernSotho => "nso",
            Language::Somali => "so",
            Language::Swahili => "sw",
            Language::Swati => "ss",
            Language::Tswana => "tn",
            Language::Wolof => "wo",
            Language::Xhosa => "xh",
            Language::Yoruba => "yo",
            Language::Zulu => "zu",
            Language::HaitianCreole => "ht",
        }
    }

    /// Returns the ISO 639-3 (three-letter) code of the language, e.g. `"eng"` for English.
    pub fn get_iso_639_3_code(&self) -> &'static str {
        match self {
            Language::Afrikaans => "afr",
            Language::Danish => "dan",
            Language::Dutch => "nld",
            Language::German => "deu",
            Language::English => "eng",
            Language::Icelandic => "isl",
            Language::Luxembourgish => "ltz",
            Language::Norwegian => "nor",
            Language::Swedish => "swe",
            Language::WesternFrisian => "fry",
            Language::Yiddish => "yid",
            Language::Asturian => "ast",
            Language::Catalan => "cat",
            Language::French => "fra",
            Language::Galician => "glg",
            Language::Italian => "ita",
            Language::Occitan => "oci",
            Language::Portuguese => "por",
            Language::Romanian => "ron",
            Language::Spanish => "spa",
            Language::Belarusian => "bel",
            Language::Bosnian => "bos",
            Language::Bulgarian => "bul",
            Language::Croatian => "hrv",
            Language::Czech => "ces",
            Language::Macedonian => "mkd",
            Language::Polish => "pol",
            Language::Russian => "rus",
            Language::Serbian => "srp",
            Language::Slovak => "slk",
            Language::Slovenian => "slv",
            Language::Ukrainian => "ukr",
            Language::Estonian => "est",
            Language::Finnish => "fin",
            Language::Hungarian => "hun",
            Language::Latvian => "lav",
            Language::Lithuanian => "lit",
            Language::Albanian => "sqi",
            Language::Armenian => "hye",
            Language::Georgian => "kat",
            Language::Greek => "ell",
            Language::Breton => "bre",
            Language::Irish => "gle",
            Language::ScottishGaelic => "gla",
            Language::Welsh => "cym",
            Language::Azerbaijani => "aze",
            Language::Bashkir => "bak",
            Language::Kazakh => "kaz",
            Language::Turkish => "tur",
            Language::Uzbek => "uzb",
            Language::Japanese => "jpn",
            Language::Korean => "kor",
            Language::Vietnamese => "vie",
            Language::ChineseMandarin => "cmn",
            Language::Bengali => "ben",
            Language::Gujarati => "guj",
            Language::Hindi => "hin",
            Language::Kannada => "kan",
            Language::Marathi => "mar",
            Language::Nepali => "nep",
            Language::Oriya => "ori",
            Language::Panjabi => "pan",
            Language::Sindhi => "snd",
            Language::Sinhala => "sin",
            Language::Urdu => "urd",
            Language::Tamil => "tam",
            Language::Cebuano => "ceb",
            Language::Iloko => "ilo",
            Language::Indonesian => "ind",
            Language::Javanese => "jav",
            Language::Malagasy => "mlg",
            Language::Malay => "msa",
            Language::Malayalam => "mal",
            Language::Sundanese => "sun",
            Language::Tagalog => "tgl",
            Language::Burmese => "mya",
            Language::CentralKhmer => "khm",
            Language::Lao => "lao",
            Language::Thai => "tha",
            Language::Mongolian => "mon",
            Language::Arabic => "ara",
            Language::Hebrew => "heb",
            Language::Pashto => "pus",
            Language::Farsi => "fas",
            Language::Amharic => "amh",
            Language::Fulah => "ful",
            Language::Hausa => "hau",
            Language::Igbo => "ibo",
            Language::Lingala => "lin",
            Language::Luganda => "lug",
            Language::NorthernSotho => "nso",
            Language::Somali => "som",
            Language::Swahili => "swa",
            Language::Swati => "ssw",
            Language::Tswana => "tsn",
            Language::Wolof => "wol",
            Language::Xhosa => "xho",
            Language::Yoruba => "yor",
            Language::Zulu => "zul",
            Language::HaitianCreole => "hat",
        }
    }
}

/// # Configuration for text translation
/// Contains information regarding the model to load, mirrors the `GenerateConfig` with a
/// different set of default parameters, and sets the device to place the model on.
///
/// The defaults documented on each field are the values set by [`TranslationConfig::new`];
/// all fields are public and can be adjusted after construction.
pub struct TranslationConfig {
    /// Model type used for translation
    pub model_type: ModelType,
    /// Model weights resource
    pub model_resource: Box<dyn ResourceProvider + Send>,
    /// Config resource
    pub config_resource: Box<dyn ResourceProvider + Send>,
    /// Vocab resource
    pub vocab_resource: Box<dyn ResourceProvider + Send>,
    /// Merges resource (optional; the Marian examples pass the SentencePiece model here)
    pub merges_resource: Option<Box<dyn ResourceProvider + Send>>,
    /// Supported source languages
    pub source_languages: HashSet<Language>,
    /// Supported target languages
    pub target_languages: HashSet<Language>,
    /// Minimum sequence length (default: 0)
    pub min_length: i64,
    /// Maximum sequence length (default: 512)
    pub max_length: Option<i64>,
    /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: false)
    pub do_sample: bool,
    /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: true)
    pub early_stopping: bool,
    /// Number of beams for beam search (default: 3)
    pub num_beams: i64,
    /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
    pub temperature: f64,
    /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 50)
    pub top_k: i64,
    /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 1.0)
    pub top_p: f64,
    /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
    pub repetition_penalty: f64,
    /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
    pub length_penalty: f64,
    /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 0, disabled)
    pub no_repeat_ngram_size: i64,
    /// Number of sequences to return for each prompt text (default: 1)
    pub num_return_sequences: i64,
    /// Device to place the model on (default: CUDA/GPU when available)
    pub device: Device,
    /// Number of beam groups for diverse beam generation. If provided and higher than 1, will split the beams into beam subgroups leading to more diverse generation. (default: None)
    pub num_beam_groups: Option<i64>,
    /// Diversity penalty for diverse beam search. High values will enforce more difference between beam groups (default: None)
    pub diversity_penalty: Option<f64>,
}

impl TranslationConfig {
    /// Create a new `TranslationConfig` from a model type, its resources and the languages it supports.
    ///
    /// Generation parameters are initialized with translation-oriented defaults:
    /// * beam search with 3 beams, no sampling and early stopping;
    /// * `max_length` of 512;
    /// * `top_k` of 50 and `top_p` of 1.0.
    ///
    /// They can be adjusted afterwards through the public fields.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` of the translation model (e.g. `ModelType::Marian`)
    /// * `model_resource` - Resource pointing to the model weights
    /// * `config_resource` - Resource pointing to the model configuration
    /// * `vocab_resource` - Resource pointing to the vocabulary
    /// * `merges_resource` - Optional merges resource, of the same resource type as
    ///   `vocab_resource` (the Marian example passes the SentencePiece model here)
    /// * `source_languages` - Languages accepted as translation input (e.g. `MarianSourceLanguages::ROMANCE2ENGLISH`)
    /// * `target_languages` - Languages the model can translate into
    /// * `device` - `Device` to place the model on (CPU/GPU); `None` falls back to
    ///   `Device::cuda_if_available()`
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::marian::{
    ///     MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
    ///     MarianTargetLanguages, MarianVocabResources,
    /// };
    /// use rust_bert::pipelines::common::ModelType;
    /// use rust_bert::pipelines::translation::TranslationConfig;
    /// use rust_bert::resources::RemoteResource;
    /// use tch::Device;
    ///
    /// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
    /// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
    /// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
    /// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
    ///
    /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
    /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
    ///
    /// let translation_config = TranslationConfig::new(
    ///     ModelType::Marian,
    ///     model_resource,
    ///     config_resource,
    ///     vocab_resource,
    ///     Some(spm_resource),
    ///     source_languages,
    ///     target_languages,
    ///     Device::cuda_if_available(),
    /// );
    /// # Ok(())
    /// # }
    /// ```
    pub fn new<RM, RC, RV, S, T>(
        model_type: ModelType,
        model_resource: RM,
        config_resource: RC,
        vocab_resource: RV,
        merges_resource: Option<RV>,
        source_languages: S,
        target_languages: T,
        device: impl Into<Option<Device>>,
    ) -> TranslationConfig
    where
        RM: ResourceProvider + Send + 'static,
        RC: ResourceProvider + Send + 'static,
        RV: ResourceProvider + Send + 'static,
        S: AsRef<[Language]>,
        T: AsRef<[Language]>,
    {
        let device = device.into().unwrap_or_else(Device::cuda_if_available);

        TranslationConfig {
            model_type,
            model_resource: Box::new(model_resource),
            config_resource: Box::new(config_resource),
            vocab_resource: Box::new(vocab_resource),
            merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>),
            // `Language` is `Copy`, so the slices are copied element-wise into the sets.
            source_languages: source_languages.as_ref().iter().copied().collect(),
            target_languages: target_languages.as_ref().iter().copied().collect(),
            device,
            min_length: 0,
            max_length: Some(512),
            do_sample: false,
            early_stopping: true,
            num_beams: 3,
            temperature: 1.0,
            top_k: 50,
            top_p: 1.0,
            repetition_penalty: 1.0,
            length_penalty: 1.0,
            no_repeat_ngram_size: 0,
            num_return_sequences: 1,
            num_beam_groups: None,
            diversity_penalty: None,
        }
    }
}

impl From<TranslationConfig> for GenerateConfig {
    fn from(config: TranslationConfig) -> GenerateConfig {
        GenerateConfig {
            model_resource: config.model_resource,
            config_resource: config.config_resource,
            merges_resource: config.merges_resource,
            vocab_resource: config.vocab_resource,
            min_length: config.min_length,
            max_length: config.max_length,
            do_sample: config.do_sample,
            early_stopping: config.early_stopping,
            num_beams: config.num_beams,
            temperature: config.temperature,
            top_k: config.top_k,
            top_p: config.top_p,
            repetition_penalty: config.repetition_penalty,
            length_penalty: config.length_penalty,
            no_repeat_ngram_size: config.no_repeat_ngram_size,
            num_return_sequences: config.num_return_sequences,
            num_beam_groups: config.num_beam_groups,
            diversity_penalty: config.diversity_penalty,
            device: config.device,
        }
    }
}

// Clippy's `large_enum_variant` lint is silenced: each variant stores a different generator
// type by value rather than behind a `Box`. NOTE(review): presumably acceptable because a
// `TranslationOption` is built once per `TranslationModel` — confirm before boxing variants.
#[allow(clippy::large_enum_variant)]
/// # Abstraction that holds one particular translation model, for any of the supported models
///
/// Built from a `TranslationConfig` by `TranslationOption::new`, which selects the variant
/// from the configured `ModelType`.
pub enum TranslationOption {
    /// Translator based on Marian model
    Marian(MarianGenerator),
    /// Translator based on T5 model
    T5(T5Generator),
    /// Translator based on MBart50 model
    MBart(MBartGenerator),
    /// Translator based on M2M100 model
    M2M100(M2M100Generator),
}

impl TranslationOption {
    pub fn new(config: TranslationConfig) -> Result<Self, RustBertError> {
        match config.model_type {
            ModelType::Marian => Ok(TranslationOption::Marian(MarianGenerator::new(
                config.into(),
            )?)),
            ModelType::T5 => Ok(TranslationOption::T5(T5Generator::new(config.into())?)),
            ModelType::MBart => Ok(TranslationOption::MBart(MBartGenerator::new(
                config.into(),
            )?)),
            ModelType::M2M100 => Ok(TranslationOption::M2M100(M2M100Generator::new(
                config.into(),
            )?)),
            _ => Err(RustBertError::InvalidConfigurationError(format!(
                "Translation not implemented for {:?}!",
                config.model_type
            ))),
        }
    }

    /// Returns the `ModelType` for this TranslationOption
    pub fn model_type(&self) -> ModelType {
        match *self {
            Self::Marian(_) => ModelType::Marian,
            Self::T5(_) => ModelType::T5,
            Self::MBart(_) => ModelType::MBart,
            Self::M2M100(_) => ModelType::M2M100,
        }
    }

    fn validate_and_get_prefix_and_forced_bos_id(
        &self,
        source_language: Option<&Language>,
        target_language: Option<&Language>,
        supported_source_languages: &HashSet<Language>,
        supported_target_languages: &HashSet<Language>,
    ) -> Result<(Option<String>, Option<i64>), RustBertError> {
        if let Some(source_language) = source_language {
            if !supported_source_languages.contains(source_language) {
                return Err(RustBertError::ValueError(format!(
                    "{} not in list of supported languages: {:?}",
                    source_language, supported_source_languages
                )));
            }
        }

        if let Some(target_language) = target_language {
            if !supported_target_languages.contains(target_language) {
                return Err(RustBertError::ValueError(format!(
                    "{} not in list of supported languages: {:?}",
                    target_language, supported_target_languages
                )));
            }
        }

        Ok(match *self {
            Self::Marian(_) => {
                if supported_target_languages.len() > 1 {
                    (
                        Some(format!(
                            ">>{}<< ",
                            match target_language {
                                Some(value) => value.get_iso_639_1_code(),
                                None => {
                                    return Err(RustBertError::ValueError(format!(
                                        "Missing target language for Marian \
                                        (multiple languages supported by model: {:?}, \
                                        need to specify target language)",
                                        supported_target_languages
                                    )));
                                }
                            }
                        )),
                        None,
                    )
                } else {
                    (None, None)
                }
            }
            Self::T5(_) => (
                Some(format!(
                    "translate {} to {}:",
                    match source_language {
                        Some(value) => value,
                        None => {
                            return Err(RustBertError::ValueError(
                                "Missing source language for T5".to_string(),
                            ));
                        }
                    },
                    match target_language {
                        Some(value) => value,
                        None => {
                            return Err(RustBertError::ValueError(
                                "Missing target language for T5".to_string(),
                            ));
                        }
                    }
                )),
                None,
            ),
            Self::MBart(ref model) => (
                Some(format!(
                    ">>{}<< ",
                    match source_language {
                        Some(value) => value.get_iso_639_1_code(),
                        None => {
                            return Err(RustBertError::ValueError(format!(
                                "Missing source language for MBart\
                                (multiple languages supported by model: {:?}, \
                                need to specify target language)",
                                supported_source_languages
                            )));
                        }
                    }
                )),
                if let Some(target_language) = target_language {
                    Some(
                        model._get_tokenizer().convert_tokens_to_ids(&[format!(
                            ">>{}<<",
                            target_language.get_iso_639_1_code()
                        )])[0],
                    )
                } else {
                    return Err(RustBertError::ValueError(format!(
                        "Missing target language for MBart\
                        (multiple languages supported by model: {:?}, \
                        need to specify target language)",
                        supported_target_languages
                    )));
                },
            ),
            Self::M2M100(ref model) => (
                Some(match source_language {
                    Some(value) => {
                        let language_code = value.get_iso_639_1_code();
                        match language_code.len() {
                            2 => format!(">>{}.<< ", language_code),
                            3 => format!(">>{}<< ", language_code),
                            _ => {
                                return Err(RustBertError::ValueError(
                                    "Invalid ISO 639-3 code".to_string(),
                                ));
                            }
                        }
                    }
                    None => {
                        return Err(RustBertError::ValueError(format!(
                            "Missing source language for M2M100 \
                            (multiple languages supported by model: {:?}, \
                            need to specify target language)",
                            supported_source_languages
                        )));
                    }
                }),
                if let Some(target_language) = target_language {
                    let language_code = target_language.get_iso_639_1_code();
                    Some(
                        model._get_tokenizer().convert_tokens_to_ids(&[
                            match language_code.len() {
                                2 => format!(">>{}.<<", language_code),
                                3 => format!(">>{}<<", language_code),
                                _ => {
                                    return Err(RustBertError::ValueError(
                                        "Invalid ISO 639-3 code".to_string(),
                                    ));
                                }
                            },
                        ])[0],
                    )
                } else {
                    return Err(RustBertError::ValueError(format!(
                        "Missing target language for M2M100 \
                        (multiple languages supported by model: {:?}, \
                        need to specify target language)",
                        supported_target_languages
                    )));
                },
            ),
        })
    }

    /// Interface method to generate() of the particular models.
    pub fn generate<S>(
        &self,
        prompt_texts: Option<&[S]>,
        forced_bos_token_id: Option<i64>,
    ) -> Vec<String>
    where
        S: AsRef<str> + Sync,
    {
        match *self {
            Self::Marian(ref model) => model
                .generate(prompt_texts, None)
                .into_iter()
                .map(|output| output.text)
                .collect(),
            Self::T5(ref model) => model
                .generate(prompt_texts, None)
                .into_iter()
                .map(|output| output.text)
                .collect(),
            Self::MBart(ref model) => {
                let generate_options = GenerateOptions {
                    forced_bos_token_id,
                    ..Default::default()
                };
                model
                    .generate(prompt_texts, Some(generate_options))
                    .into_iter()
                    .map(|output| output.text)
                    .collect()
            }
            Self::M2M100(ref model) => {
                let generate_options = GenerateOptions {
                    forced_bos_token_id,
                    ..Default::default()
                };
                model
                    .generate(prompt_texts, Some(generate_options))
                    .into_iter()
                    .map(|output| output.text)
                    .collect()
            }
        }
    }
}

/// # TranslationModel to perform translation
///
/// Wraps one `TranslationOption` together with the languages it was configured for, so that
/// `translate` can validate requested languages before generating.
pub struct TranslationModel {
    /// Underlying generator (Marian, T5, MBart or M2M100)
    model: TranslationOption,
    /// Languages accepted as `source_language` by `translate` (copied from the configuration)
    supported_source_languages: HashSet<Language>,
    /// Languages accepted as `target_language` by `translate` (copied from the configuration)
    supported_target_languages: HashSet<Language>,
}

impl TranslationModel {
    /// Build a new `TranslationModel`
    ///
    /// # Arguments
    ///
    /// * `translation_config` - `TranslationConfig` object containing the resource references (model, vocabulary, configuration), translation options and device placement (CPU/GPU)
    ///
    /// # Errors
    ///
    /// Returns an error if the configured `ModelType` does not support translation, or if the
    /// underlying generator cannot be built from the provided resources.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::marian::{
    ///     MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
    ///     MarianTargetLanguages, MarianVocabResources,
    /// };
    /// use rust_bert::pipelines::common::ModelType;
    /// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel};
    /// use rust_bert::resources::RemoteResource;
    /// use tch::Device;
    ///
    /// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
    /// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
    /// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
    /// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);
    ///
    /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
    /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;
    ///
    /// let translation_config = TranslationConfig::new(
    ///     ModelType::Marian,
    ///     model_resource,
    ///     config_resource,
    ///     vocab_resource,
    ///     Some(spm_resource),
    ///     source_languages,
    ///     target_languages,
    ///     Device::cuda_if_available(),
    /// );
    /// let translation_model = TranslationModel::new(translation_config)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(translation_config: TranslationConfig) -> Result<TranslationModel, RustBertError> {
        // The language sets are kept here for validation; the generator itself does not
        // receive them (they are not part of `GenerateConfig`).
        let supported_source_languages = translation_config.source_languages.clone();
        let supported_target_languages = translation_config.target_languages.clone();

        let model = TranslationOption::new(translation_config)?;

        Ok(TranslationModel {
            model,
            supported_source_languages,
            supported_target_languages,
        })
    }

    /// Translates texts provided
    ///
    /// # Arguments
    /// * `texts` - `&[S]` Array of texts to translate (anything implementing `AsRef<str>`).
    /// * `source_language` - Language of the input texts (`Language` or `Option<Language>`).
    ///   Required by T5, MBart and M2M100; may be `None` for Marian.
    /// * `target_language` - Language to translate into (`Language` or `Option<Language>`).
    ///   Required by T5, MBart, M2M100 and by Marian models supporting several target languages.
    ///
    /// # Returns
    /// * `Result<Vec<String>, RustBertError>` Translated texts
    ///
    /// # Errors
    ///
    /// Returns `RustBertError::ValueError` in two cases:
    /// * a provided language is not supported by the model;
    /// * a language required by the model type is missing.
    ///
    /// # Example
    ///
    /// ```no_run
    /// # fn main() -> anyhow::Result<()> {
    /// use rust_bert::marian::{
    ///     MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
    ///     MarianTargetLanguages, MarianVocabResources,
    /// };
    /// use rust_bert::pipelines::common::ModelType;
    /// use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
    /// use rust_bert::resources::RemoteResource;
    /// use tch::Device;
    ///
    /// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ENGLISH2ROMANCE);
    /// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ENGLISH2ROMANCE);
    /// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ENGLISH2ROMANCE);
    /// let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ENGLISH2ROMANCE);
    /// let source_languages = MarianSourceLanguages::ENGLISH2ROMANCE;
    /// let target_languages = MarianTargetLanguages::ENGLISH2ROMANCE;
    ///
    /// let translation_config = TranslationConfig::new(
    ///     ModelType::Marian,
    ///     model_resource,
    ///     config_resource,
    ///     vocab_resource,
    ///     Some(merges_resource),
    ///     source_languages,
    ///     target_languages,
    ///     Device::cuda_if_available(),
    /// );
    /// let model = TranslationModel::new(translation_config)?;
    ///
    /// let input = ["This is a sentence to be translated"];
    /// let source_language = None;
    /// let target_language = Language::French;
    ///
    /// let output = model.translate(&input, source_language, target_language)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn translate<S>(
        &self,
        texts: &[S],
        source_language: impl Into<Option<Language>>,
        target_language: impl Into<Option<Language>>,
    ) -> Result<Vec<String>, RustBertError>
    where
        S: AsRef<str> + Sync,
    {
        let (prefix, forced_bos_token_id) = self.model.validate_and_get_prefix_and_forced_bos_id(
            source_language.into().as_ref(),
            target_language.into().as_ref(),
            &self.supported_source_languages,
            &self.supported_target_languages,
        )?;

        Ok(match prefix {
            // The model-specific prefix (language tag or task description) is prepended to
            // every input text.
            Some(value) => {
                let texts = texts
                    .iter()
                    .map(|v| format!("{}{}", value, v.as_ref()))
                    .collect::<Vec<String>>();
                self.model.generate(Some(&texts), forced_bos_token_id)
            }
            None => self.model.generate(Some(texts), forced_bos_token_id),
        })
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::marian::{
        MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources,
        MarianTargetLanguages, MarianVocabResources,
    };
    use crate::resources::RemoteResource;

    /// Compile-time check that the pipeline can be moved across threads. Unlike `test` below,
    /// this needs no network access and therefore runs with the regular test suite.
    #[test]
    fn translation_model_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<TranslationModel>();
        assert_send::<Result<TranslationModel, RustBertError>>();
    }

    #[test]
    #[ignore] // no need to run, compilation is enough to verify it is Send
    fn test() {
        let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH);
        let config_resource =
            RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH);
        let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH);
        let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH);

        let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH;
        let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH;

        let translation_config = TranslationConfig::new(
            ModelType::Marian,
            model_resource,
            config_resource,
            vocab_resource,
            Some(merges_resource),
            source_languages,
            target_languages,
            Device::cuda_if_available(),
        );
        let _: Box<dyn Send> = Box::new(TranslationModel::new(translation_config));
    }
}