elikoga_textsynth/
translate.rs

//! Provides the TextSynth translation API.
2
3use serde::{Deserialize, Serialize};
4use serde_with::skip_serializing_none;
5use thiserror::Error;
6
7use crate::{IsEngine, TextSynthClient};
8
9/// Enum for the different translation engines available for TextSynth
10#[derive(strum::Display)]
11pub enum Engine {
12    /// M2M100 1.2B is a 1.2 billion parameter language model specialized for
13    /// translation. It supports multilingual translation between 100 languages.
14    #[strum(serialize = "m2m100_1_2B")]
15    M2M10012B,
16}
17
18impl IsEngine for Engine {
19    fn is_translation(&self) -> bool {
20        true
21    }
22}
23
24/// Struct for a translation request
25#[skip_serializing_none]
26#[derive(Serialize, Builder)]
27#[builder(setter(into))]
28#[builder(build_fn(validate = "Self::validate"))]
29pub struct Request {
30    /// Each string is an independent text to translate. Batches of at most 64
31    /// texts can be provided.
32    text: Vec<String>,
33    /// Two or three character ISO language code for the source language. The
34    /// special value "auto" indicates to auto-detect the source language. The
35    /// language auto-detection does not support all languages and is based on
36    /// heuristics. Hence if you know the source language you should explicitly
37    /// indicate it.
38    source_lang: String,
39    /// Two or three character ISO language code for the target language.
40    target_lang: String,
41    /// Number of beams used to generate the translated text. The translation is
42    /// usually better with a larger number of beams. Each beam requires
43    /// generating a separate translated text, hence the number of generated
44    /// tokens is multiplied by the number of beams.
45    #[builder(setter(strip_option))]
46    #[builder(default)]
47    num_beams: Option<u32>,
48    /// The translation model only translates one sentence at a time. Hence the
49    /// input must be split into sentences. When split_sentences = true
50    /// (default), each input text is automatically split into sentences using
51    /// source language specific heuristics. If you are sure that each input
52    /// text contains only one sentence, it is better to disable the automatic
53    /// sentence splitting.
54    #[builder(setter(strip_option))]
55    #[builder(default)]
56    split_sentences: Option<bool>,
57}
58
59impl RequestBuilder {
60    fn validate(&self) -> Result<(), String> {
61        // text has length 1 to 64
62        match &self.text {
63            Some(text) if !(1..=64).contains(&text.len()) => {
64                return Err("text has to have 1 to 64 elements".to_string());
65            }
66            _ => {}
67        }
68        // source_lang is 2 or 3 characters long or is "auto"
69        match &self.source_lang {
70            Some(source_lang)
71                if !(source_lang.len() == 2 || source_lang.len() == 3 || source_lang == "auto") =>
72            {
73                return Err(
74                    "source_lang has to be a 2 or 3 characters long iso language code or be \"auto\""
75                        .to_string(),
76                );
77            }
78            _ => {}
79        }
80        // target_lang is 2 or 3 characters long
81        match &self.target_lang {
82            Some(target_lang) if !(target_lang.len() == 2 || target_lang.len() == 3) => {
83                return Err(
84                    "target_lang has to be a 2 or 3 characters long iso language code".to_string(),
85                );
86            }
87            _ => {}
88        }
89        // num_beams has range 1 to 5
90        match self.num_beams {
91            Some(Some(num_beams)) if !(1..=5).contains(&num_beams) => {
92                return Err("num_beams has to be in the range 1 to 5".to_string());
93            }
94            _ => {}
95        }
96        Ok(())
97    }
98}
99
100/// Struct for a translation answer
101#[derive(Deserialize, Debug)]
102pub struct Response {
103    /// Array of translation objects.
104    pub translations: Vec<Translation>,
105    /// Indicate the total number of input tokens. It is useful to estimate the
106    /// number of compute resources used by the request.
107    pub input_tokens: u32,
108    /// Indicate the total number of generated tokens. It is useful to estimate
109    /// the number of compute resources used by the request.
110    pub output_tokens: u32,
111}
112
113/// a single translation result
114#[derive(Deserialize, Debug)]
115pub struct Translation {
116    /// translated text
117    pub text: String,
118    /// ISO language code corresponding to the detected lang (identical to
119    /// source_lang if language auto-detection is not enabled)
120    pub detected_source_lang: String,
121}
122
123#[derive(Error, Debug)]
124/// Error for a completion answer
125pub enum Error {
126    /// Serde error
127    #[error("Serde error: {0}")]
128    SerdeError(#[from] serde_json::Error),
129    /// Error from Reqwest
130    #[error("Reqwest error: {0}")]
131    RequestError(#[from] reqwest::Error),
132}
133
134impl TextSynthClient {
135    /// Perform a completion request
136    pub async fn translate(&self, engine: &Engine, request: &Request) -> Result<Response, Error> {
137        let request_json = serde_json::to_string(&request)?;
138        let url = format!("{}/engines/{}/translate", self.base_url, engine);
139        let response = self.client.post(&url).body(request_json).send().await?;
140        response.json().await.map_err(|e| e.into())
141    }
142}