elikoga_textsynth/translate.rs
//! Bindings for the TextSynth translation API.
2
3use serde::{Deserialize, Serialize};
4use serde_with::skip_serializing_none;
5use thiserror::Error;
6
7use crate::{IsEngine, TextSynthClient};
8
9/// Enum for the different translation engines available for TextSynth
10#[derive(strum::Display)]
11pub enum Engine {
12 /// M2M100 1.2B is a 1.2 billion parameter language model specialized for
13 /// translation. It supports multilingual translation between 100 languages.
14 #[strum(serialize = "m2m100_1_2B")]
15 M2M10012B,
16}
17
18impl IsEngine for Engine {
19 fn is_translation(&self) -> bool {
20 true
21 }
22}
23
24/// Struct for a translation request
25#[skip_serializing_none]
26#[derive(Serialize, Builder)]
27#[builder(setter(into))]
28#[builder(build_fn(validate = "Self::validate"))]
29pub struct Request {
30 /// Each string is an independent text to translate. Batches of at most 64
31 /// texts can be provided.
32 text: Vec<String>,
33 /// Two or three character ISO language code for the source language. The
34 /// special value "auto" indicates to auto-detect the source language. The
35 /// language auto-detection does not support all languages and is based on
36 /// heuristics. Hence if you know the source language you should explicitly
37 /// indicate it.
38 source_lang: String,
39 /// Two or three character ISO language code for the target language.
40 target_lang: String,
41 /// Number of beams used to generate the translated text. The translation is
42 /// usually better with a larger number of beams. Each beam requires
43 /// generating a separate translated text, hence the number of generated
44 /// tokens is multiplied by the number of beams.
45 #[builder(setter(strip_option))]
46 #[builder(default)]
47 num_beams: Option<u32>,
48 /// The translation model only translates one sentence at a time. Hence the
49 /// input must be split into sentences. When split_sentences = true
50 /// (default), each input text is automatically split into sentences using
51 /// source language specific heuristics. If you are sure that each input
52 /// text contains only one sentence, it is better to disable the automatic
53 /// sentence splitting.
54 #[builder(setter(strip_option))]
55 #[builder(default)]
56 split_sentences: Option<bool>,
57}
58
59impl RequestBuilder {
60 fn validate(&self) -> Result<(), String> {
61 // text has length 1 to 64
62 match &self.text {
63 Some(text) if !(1..=64).contains(&text.len()) => {
64 return Err("text has to have 1 to 64 elements".to_string());
65 }
66 _ => {}
67 }
68 // source_lang is 2 or 3 characters long or is "auto"
69 match &self.source_lang {
70 Some(source_lang)
71 if !(source_lang.len() == 2 || source_lang.len() == 3 || source_lang == "auto") =>
72 {
73 return Err(
74 "source_lang has to be a 2 or 3 characters long iso language code or be \"auto\""
75 .to_string(),
76 );
77 }
78 _ => {}
79 }
80 // target_lang is 2 or 3 characters long
81 match &self.target_lang {
82 Some(target_lang) if !(target_lang.len() == 2 || target_lang.len() == 3) => {
83 return Err(
84 "target_lang has to be a 2 or 3 characters long iso language code".to_string(),
85 );
86 }
87 _ => {}
88 }
89 // num_beams has range 1 to 5
90 match self.num_beams {
91 Some(Some(num_beams)) if !(1..=5).contains(&num_beams) => {
92 return Err("num_beams has to be in the range 1 to 5".to_string());
93 }
94 _ => {}
95 }
96 Ok(())
97 }
98}
99
100/// Struct for a translation answer
101#[derive(Deserialize, Debug)]
102pub struct Response {
103 /// Array of translation objects.
104 pub translations: Vec<Translation>,
105 /// Indicate the total number of input tokens. It is useful to estimate the
106 /// number of compute resources used by the request.
107 pub input_tokens: u32,
108 /// Indicate the total number of generated tokens. It is useful to estimate
109 /// the number of compute resources used by the request.
110 pub output_tokens: u32,
111}
112
113/// a single translation result
114#[derive(Deserialize, Debug)]
115pub struct Translation {
116 /// translated text
117 pub text: String,
118 /// ISO language code corresponding to the detected lang (identical to
119 /// source_lang if language auto-detection is not enabled)
120 pub detected_source_lang: String,
121}
122
123#[derive(Error, Debug)]
124/// Error for a completion answer
125pub enum Error {
126 /// Serde error
127 #[error("Serde error: {0}")]
128 SerdeError(#[from] serde_json::Error),
129 /// Error from Reqwest
130 #[error("Reqwest error: {0}")]
131 RequestError(#[from] reqwest::Error),
132}
133
134impl TextSynthClient {
135 /// Perform a completion request
136 pub async fn translate(&self, engine: &Engine, request: &Request) -> Result<Response, Error> {
137 let request_json = serde_json::to_string(&request)?;
138 let url = format!("{}/engines/{}/translate", self.base_url, engine);
139 let response = self.client.post(&url).body(request_json).send().await?;
140 response.json().await.map_err(|e| e.into())
141 }
142}