use crate::{ApiResult, Error, FakeYou};
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use super::{STORAGE_URL, TTS_INFERENCE, TTS_JOB};
/// Serde `serialize_with` hook for an optional idempotency token.
///
/// A token supplied by the caller is written out unchanged as a string. When the
/// field is `None`, a fresh random v4 UUID is generated on every call and written
/// instead, so the serialized output always contains a string token.
fn uuid_idemptency_token_serialize<S>(
    maybe_uuid: &Option<String>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    if let Some(existing) = maybe_uuid.as_deref() {
        return serializer.serialize_str(existing);
    }
    let generated = uuid::Uuid::new_v4().to_string();
    serializer.serialize_str(&generated)
}
/// JSON request body sent when submitting a TTS inference job.
#[derive(Debug, Serialize, Deserialize)]
pub struct InferenceBody {
    /// Token of the TTS model to run the inference with.
    pub tts_model_token: String,
    /// Text to synthesize.
    pub inference_text: String,
    /// Idempotency token for the request. When `None`, the custom serializer
    /// generates a fresh random v4 UUID at serialization time, so two
    /// serializations of the same `None` body carry different tokens.
    #[serde(serialize_with = "uuid_idemptency_token_serialize")]
    pub uuid_idempotency_token: Option<String>,
}
impl InferenceBody {
    /// Creates a request body for `tts_model_token` with `inference_text` to synthesize.
    ///
    /// Both arguments are copied into owned strings. The idempotency token is left
    /// as `None`, so the field's serializer generates a random one when the body is
    /// serialized.
    pub fn new(tts_model_token: &str, inference_text: &str) -> Self {
        let model = String::from(tts_model_token);
        let text = String::from(inference_text);
        Self {
            uuid_idempotency_token: None,
            inference_text: text,
            tts_model_token: model,
        }
    }
}
/// Response body returned after successfully submitting a TTS inference request.
#[derive(Debug, Serialize, Deserialize)]
pub struct TtsInferenceResult {
    /// Success flag as reported by the API.
    pub success: bool,
    /// Token of the created inference job.
    /// NOTE(review): presumably the value to pass to `TtsApi::tts_job` for polling — confirm with callers.
    pub inference_job_token: String,
}
/// Response body of a TTS job status query.
#[derive(Debug, Serialize, Deserialize)]
pub struct TtsJobResult {
    /// Success flag as reported by the API.
    pub success: bool,
    /// Current state of the queried job.
    pub state: TtsJobState,
}
/// Detailed state of a TTS job, as returned by the job status endpoint.
///
/// All fields are deserialized as-is; nothing here parses or validates them further.
#[derive(Debug, Serialize, Deserialize)]
pub struct TtsJobState {
    /// Token identifying this job.
    pub job_token: String,
    /// Current job status. Deserialization fails if the API sends a status
    /// string that has no matching [`TtsJobStatus`] variant.
    pub status: TtsJobStatus,
    /// Optional extra description of the status, if the API provides one.
    pub maybe_extra_status_description: Option<String>,
    /// Number of attempts reported by the API for this job.
    pub attempt_count: u32,
    /// Result token, when the API reports one.
    pub maybe_result_token: Option<String>,
    /// Path of the generated WAV file in the public bucket, when available.
    /// NOTE(review): presumably the argument for `TtsApi::tts_output` — confirm with callers.
    pub maybe_public_bucket_wav_audio_path: Option<String>,
    /// Token of the model used for this job.
    pub model_token: String,
    /// Model type string as reported by the API.
    pub tts_model_type: String,
    /// Title string as reported by the API.
    pub title: String,
    /// Inference text as reported by the API.
    pub raw_inference_text: String,
    /// Creation timestamp, kept as an unparsed string.
    pub created_at: String,
    /// Last-update timestamp, kept as an unparsed string.
    pub updated_at: String,
}
/// Status of a TTS job.
///
/// Variants map to snake_case strings on the wire (e.g. `CompleteSuccess` <->
/// `"complete_success"`). There is no catch-all variant, so an unrecognized
/// status string causes deserialization to fail.
/// NOTE(review): presumably `CompleteSuccess`, `CompleteFailure` and `Dead` are
/// terminal states — confirm against the FakeYou API documentation.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TtsJobStatus {
    /// `"pending"`
    Pending,
    /// `"started"`
    Started,
    /// `"complete_success"`
    CompleteSuccess,
    /// `"complete_failure"`
    CompleteFailure,
    /// `"attempt_failed"`
    AttemptFailed,
    /// `"dead"`
    Dead,
}
/// Downloaded TTS output.
#[derive(Debug, Clone)]
pub struct TtsOutputResult {
    /// Raw response body bytes fetched by `tts_output`. The request asks for
    /// `audio/wav`, but the bytes are not decoded or validated here.
    pub bytes: Vec<u8>,
}
/// Operations for running text-to-speech through the API.
pub trait TtsApi {
    /// Submits a TTS inference request built from `inference_body`.
    fn tts_inference(&self, inference_body: &InferenceBody) -> ApiResult<TtsInferenceResult>;
    /// Fetches the current state of the job identified by `job_id`.
    fn tts_job(&self, job_id: &str) -> ApiResult<TtsJobResult>;
    /// Downloads the audio stored at `public_bucket_wav_audio_path`.
    fn tts_output(&self, public_bucket_wav_audio_path: &str) -> ApiResult<TtsOutputResult>;
}
impl TtsApi for FakeYou {
fn tts_inference(&self, voice_settings: &InferenceBody) -> ApiResult<TtsInferenceResult> {
let voice_settings = serde_json::to_value(voice_settings).unwrap();
let url = format!("{}/{}", &self.api_url, TTS_INFERENCE);
let response = self
.client
.post(url.as_str())
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.json(&voice_settings)
.send()
.map_err(|e| Error::RequestFailed(e.to_string()))?;
match response.status() {
StatusCode::OK => response
.json::<TtsInferenceResult>()
.map_err(|e| Error::ParseError(e.to_string())),
StatusCode::BAD_REQUEST => Err(Error::BadRequest),
StatusCode::UNAUTHORIZED => Err(Error::Unauthorized),
StatusCode::TOO_MANY_REQUESTS => Err(Error::TooManyRequests),
StatusCode::INTERNAL_SERVER_ERROR => Err(Error::InternalServerError),
code => Err(Error::Unknown(code.as_u16())),
}
}
fn tts_job(&self, job_id: &str) -> ApiResult<TtsJobResult> {
let url = format!("{}/{}/{}", &self.api_url, TTS_JOB, job_id);
let response = self
.client
.get(url.as_str())
.header("Accept", "application/json")
.send()
.map_err(|e| Error::RequestFailed(e.to_string()))?;
match response.status() {
StatusCode::OK => response
.json::<TtsJobResult>()
.map_err(|e| Error::ParseError(e.to_string())),
code => Err(Error::Unknown(code.as_u16())),
}
}
fn tts_output(&self, public_bucket_wav_audio_path: &str) -> ApiResult<TtsOutputResult> {
let url = format!("{}{}", STORAGE_URL, public_bucket_wav_audio_path);
let response = self
.client
.get(url.as_str())
.header("Accept", "audio/wav")
.send()
.map_err(|e| Error::RequestFailed(e.to_string()))?;
let bytes = response
.bytes()
.map_err(|e| Error::ParseError(e.to_string()))?
.to_vec();
Ok(TtsOutputResult { bytes })
}
}