1use crate::{ApiResult, Error, FakeYou};
7use reqwest::StatusCode;
8use serde::{Deserialize, Serialize};
9
10use super::{STORAGE_URL, TTS_INFERENCE, TTS_JOB};
11
12fn uuid_idemptency_token_serialize<S>(
13 maybe_uuid: &Option<String>,
14 serializer: S,
15) -> Result<S::Ok, S::Error>
16where
17 S: serde::Serializer,
18{
19 match maybe_uuid {
20 Some(uuid) => serializer.serialize_str(uuid),
21 None => serializer.serialize_str(uuid::Uuid::new_v4().to_string().as_str()),
22 }
23}
24
25#[derive(Debug, Serialize, Deserialize)]
26pub struct InferenceBody {
27 pub tts_model_token: String,
29 pub inference_text: String,
31 #[serde(serialize_with = "uuid_idemptency_token_serialize")]
34 pub uuid_idempotency_token: Option<String>,
35}
36
37impl InferenceBody {
38 pub fn new(tts_model_token: &str, inference_text: &str) -> Self {
39 Self {
40 tts_model_token: tts_model_token.to_string(),
41 inference_text: inference_text.to_string(),
42 uuid_idempotency_token: None,
43 }
44 }
45}
46
47#[derive(Debug, Serialize, Deserialize)]
48pub struct TtsInferenceResult {
49 pub success: bool,
51 pub inference_job_token: String,
53}
54
55#[derive(Debug, Serialize, Deserialize)]
56pub struct TtsJobResult {
57 pub success: bool,
59 pub state: TtsJobState,
61}
62
63#[derive(Debug, Serialize, Deserialize)]
64pub struct TtsJobState {
65 pub job_token: String,
67 pub status: TtsJobStatus,
70 pub maybe_extra_status_description: Option<String>,
73 pub attempt_count: u32,
75 pub maybe_result_token: Option<String>,
79 pub maybe_public_bucket_wav_audio_path: Option<String>,
83 pub model_token: String,
85 pub tts_model_type: String,
87 pub title: String,
91 pub raw_inference_text: String,
93 pub created_at: String,
95 pub updated_at: String,
97}
98
99#[derive(Debug, Serialize, Deserialize)]
100#[serde(rename_all = "snake_case")]
101pub enum TtsJobStatus {
102 Pending,
103 Started,
104 CompleteSuccess,
105 CompleteFailure,
106 AttemptFailed,
107 Dead,
108}
109
/// Raw audio payload downloaded by `TtsApi::tts_output`.
///
/// Comparable with `==` so that two downloads (or a download and a fixture)
/// can be checked for equality directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TtsOutputResult {
    /// Response body bytes, exactly as received.
    pub bytes: Vec<u8>,
}
115
116pub trait TtsApi {
117 fn tts_inference(&self, inference_body: &InferenceBody) -> ApiResult<TtsInferenceResult>;
118 fn tts_job(&self, job_id: &str) -> ApiResult<TtsJobResult>;
119 fn tts_output(&self, public_bucket_wav_audio_path: &str) -> ApiResult<TtsOutputResult>;
120}
121
122impl TtsApi for FakeYou {
123 fn tts_inference(&self, voice_settings: &InferenceBody) -> ApiResult<TtsInferenceResult> {
124 let voice_settings = serde_json::to_value(voice_settings).unwrap();
125
126 let url = format!("{}/{}", &self.api_url, TTS_INFERENCE);
127
128 let response = self
129 .client
130 .post(url.as_str())
131 .header("Accept", "application/json")
132 .header("Content-Type", "application/json")
133 .json(&voice_settings)
134 .send()
135 .map_err(|e| Error::RequestFailed(e.to_string()))?;
136
137 match response.status() {
138 StatusCode::OK => response
139 .json::<TtsInferenceResult>()
140 .map_err(|e| Error::ParseError(e.to_string())),
141 StatusCode::BAD_REQUEST => Err(Error::BadRequest),
142 StatusCode::UNAUTHORIZED => Err(Error::Unauthorized),
143 StatusCode::TOO_MANY_REQUESTS => Err(Error::TooManyRequests),
144 StatusCode::INTERNAL_SERVER_ERROR => Err(Error::InternalServerError),
145 code => Err(Error::Unknown(code.as_u16())),
146 }
147 }
148
149 fn tts_job(&self, job_id: &str) -> ApiResult<TtsJobResult> {
150 let url = format!("{}/{}/{}", &self.api_url, TTS_JOB, job_id);
151
152 let response = self
153 .client
154 .get(url.as_str())
155 .header("Accept", "application/json")
156 .send()
157 .map_err(|e| Error::RequestFailed(e.to_string()))?;
158
159 match response.status() {
160 StatusCode::OK => response
161 .json::<TtsJobResult>()
162 .map_err(|e| Error::ParseError(e.to_string())),
163 code => Err(Error::Unknown(code.as_u16())),
164 }
165 }
166
167 fn tts_output(&self, public_bucket_wav_audio_path: &str) -> ApiResult<TtsOutputResult> {
168 let url = format!("{}{}", STORAGE_URL, public_bucket_wav_audio_path);
169
170 let response = self
171 .client
172 .get(url.as_str())
173 .header("Accept", "audio/wav")
174 .send()
175 .map_err(|e| Error::RequestFailed(e.to_string()))?;
176
177 let bytes = response
178 .bytes()
179 .map_err(|e| Error::ParseError(e.to_string()))?
180 .to_vec();
181
182 Ok(TtsOutputResult { bytes })
183 }
184}