use crate::request_headers;
use crate::Key;
use crate::Provider;
use base64::prelude::*;
use bytes::Bytes;
use reqwest;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use serde::Deserialize;
use serde::Serialize;
use serde_json::json;
use serde_json::Value;
use std::collections::HashMap;
use std::error::Error;
/// Configuration options for a text-to-speech request.
///
/// All fields are optional; the provider-specific request builders decide
/// which fields are honored and how they are encoded.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct TTSConfig {
    // Audio output format requested from the provider (e.g. "mp3").
    // NOTE(review): for ElevenLabs this is appended verbatim to the query
    // string, so it presumably holds a full `key=value` pair — confirm.
    pub output_format: Option<String>,
    // Voice name or id; required by some providers (e.g. ElevenLabs).
    pub voice: Option<String>,
    // Speaking-speed value, where supported (rejected for ElevenLabs).
    pub speed: Option<f64>,
    // Language code, where supported (sent as `languageCode` to Google).
    pub language_code: Option<String>,
    // Sampling seed, where supported (only forwarded to ElevenLabs here).
    pub seed: Option<u64>,
    // Extra provider-specific body fields; applied last in `tts_body`, so
    // they can override any field set above.
    pub other: Option<HashMap<String, Value>>,
}
impl PartialEq for TTSConfig {
fn eq(&self, other: &Self) -> bool {
fn compare_hashmap(
a: &Option<HashMap<String, Value>>,
b: &Option<HashMap<String, Value>>,
) -> bool {
match (a, b) {
(Some(a), Some(b)) => {
let mut a_keys: Vec<_> = a.keys().collect();
let mut b_keys: Vec<_> = b.keys().collect();
a_keys.sort();
b_keys.sort();
if a_keys != b_keys {
return false;
}
a_keys.iter().all(|key| a.get(*key) == b.get(*key))
}
(None, None) => true,
_ => false,
}
}
self.output_format == other.output_format
&& self.voice == other.voice
&& self.speed == other.speed
&& self.language_code == other.language_code
&& self.seed == other.seed
&& compare_hashmap(&self.other, &other.other)
}
}
#[test]
fn test_ttsconfig_eq() {
    // Two configs whose `other` maps were built with reversed insertion
    // order must still compare equal. Repeating the check exercises
    // different per-instance HashMap iteration orders.
    for i in 0..100 {
        let a = TTSConfig {
            other: Some(HashMap::from([
                ("foo".to_string(), json!("bar")),
                ("baz".to_string(), json!(42)),
            ])),
            ..TTSConfig::default()
        };
        let b = TTSConfig {
            other: Some(HashMap::from([
                ("baz".to_string(), json!(42)),
                ("foo".to_string(), json!("bar")),
            ])),
            ..TTSConfig::default()
        };
        assert_eq!(a, b, "during iteration {i}");
    }
}
/// Returns true when `provider` is a custom OpenAI-compatible endpoint.
fn is_openai_compatible(provider: &Provider) -> bool {
    match provider {
        Provider::OpenAICompatible(_) => true,
        _ => false,
    }
}
/// Builds the request URL for a TTS call to the given provider.
///
/// # Panics
///
/// Panics when the provider has no TTS endpoint, or for ElevenLabs when no
/// voice is configured.
fn address(provider: &Provider, key: &Key, model: Option<&str>, config: &TTSConfig) -> String {
    match provider {
        Provider::ElevenLabs => {
            let voice = config
                .voice
                .as_ref()
                .expect("voice is required for ElevenLabs");
            let base = format!("{}/v1/text-to-speech/{voice}", provider.domain());
            // NOTE(review): output_format is appended verbatim after `?`, so
            // it is presumably a full `key=value` query pair — confirm.
            match &config.output_format {
                Some(output_format) => format!("{base}?{output_format}"),
                None => base,
            }
        }
        Provider::DeepInfra => {
            // DeepInfra routes TTS through its generic inference endpoint.
            let model = model.unwrap_or("hexgrad/Kokoro-82M");
            format!("{}/v1/inference/{model}", provider.domain())
        }
        Provider::Hyperbolic => format!("{}/v1/audio/generation", provider.domain()),
        Provider::OpenAI => format!("{}/v1/audio/speech", provider.domain()),
        Provider::OpenAICompatible(domain) => format!("{domain}/v1/audio/speech"),
        Provider::Google => {
            // Google authenticates via an API-key query parameter.
            format!(
                "https://texttospeech.googleapis.com/v1beta1/text:synthesize?key={}",
                key.key
            )
        }
        other => panic!("Unsupported TTS provider: {}", other),
    }
}
/// A decoded text-to-speech result.
#[derive(Debug)]
pub struct Speech {
    // Provider-assigned request id, when the provider reports one
    // (currently only DeepInfra).
    pub request_id: Option<String>,
    // File format of `audio` (e.g. "mp3").
    pub file_format: String,
    // The decoded audio bytes.
    pub audio: Bytes,
}
impl Speech {
    /// Decodes a base64-encoded audio payload into raw bytes.
    ///
    /// DeepInfra wraps its base64 payload in a `data:` URI prefix that
    /// depends on `output_format`; that prefix is stripped before decoding.
    /// All other providers return bare base64.
    ///
    /// # Errors
    ///
    /// Returns an error when the DeepInfra `output_format` is missing or
    /// unsupported, when the expected `data:` prefix is absent, or when the
    /// payload is not valid base64. (These cases previously panicked even
    /// though the function returns `Result`.)
    pub fn decode_speech(
        audio: &str,
        provider: &Provider,
        output_format: Option<&str>,
    ) -> Result<Bytes, Box<dyn Error + Send + Sync>> {
        let stripped = if provider == &Provider::DeepInfra {
            let output_format = output_format.ok_or("no output format for DeepInfra audio")?;
            tracing::debug!("Decoding DeepInfra speech with output format: {output_format}");
            let deepinfra_prefix = match output_format {
                "mp3" => "data:audio/mp3;base64,",
                "opus" => "data:audio/ogg; codec=\"opus\";base64,",
                other => return Err(format!("Unsupported output format: {other}").into()),
            };
            audio
                .strip_prefix(deepinfra_prefix)
                .ok_or_else(|| format!("prefix '{deepinfra_prefix}' not found"))?
        } else {
            audio
        };
        // base64::DecodeError converts into the boxed error via `?`.
        let bytes = BASE64_STANDARD.decode(stripped)?;
        Ok(Bytes::from(bytes))
    }
}
/// Raw TTS response paired with the provider that produced it.
pub struct SpeechResponse {
    // Provider the request was sent to; drives response interpretation.
    provider: Provider,
    // Unparsed response body (JSON or raw audio, depending on provider).
    resp: Bytes,
}
impl SpeechResponse {
    /// Raw response body exactly as received from the provider.
    pub fn bytes(&self) -> &Bytes {
        &self.resp
    }
    /// Parses the response body as JSON.
    ///
    /// # Errors
    ///
    /// Fails when the body is not valid JSON (e.g. raw audio bytes).
    pub fn raw_value(&self) -> Result<Value, Box<dyn Error + Send + Sync>> {
        Ok(serde_json::from_slice::<Value>(&self.resp)?)
    }
    /// Interprets the provider-specific response and extracts the audio.
    ///
    /// # Errors
    ///
    /// Returns an error when the provider reported a failure or when an
    /// expected response field is missing or has the wrong type. (Several of
    /// these cases previously panicked via `unwrap`/`expect` on
    /// provider-controlled data.)
    pub fn structured(&self) -> Result<Speech, Box<dyn Error + Send + Sync>> {
        if self.provider == Provider::ElevenLabs {
            // ElevenLabs responds with the binary audio directly.
            Ok(Speech {
                request_id: None,
                file_format: "mp3".to_string(),
                audio: self.resp.clone(),
            })
        } else if self.provider == Provider::DeepInfra {
            let resp = self.raw_value()?;
            tracing::debug!("Response: {resp}");
            if resp.get("detail").is_some() {
                return Err(format!("DeepInfra returned an error: {}", resp["detail"]).into());
            }
            let audio = resp["audio"].as_str().ok_or("no audio in resp")?;
            let output_format = resp["output_format"]
                .as_str()
                .ok_or("no output_format in resp")?
                .to_string();
            let request_id = resp["request_id"]
                .as_str()
                .ok_or("no request_id in resp")?
                .to_string();
            Ok(Speech {
                request_id: Some(request_id),
                file_format: output_format.clone(),
                audio: Speech::decode_speech(audio, &self.provider, Some(&output_format))?,
            })
        } else if self.provider == Provider::Hyperbolic {
            let resp = self.raw_value()?;
            tracing::debug!("Response: {resp}");
            let audio = resp["audio"].as_str().ok_or("no audio in resp")?;
            Ok(Speech {
                request_id: None,
                file_format: "mp3".to_string(),
                audio: Speech::decode_speech(audio, &self.provider, None)?,
            })
        } else if self.provider == Provider::OpenAI || is_openai_compatible(&self.provider) {
            let audio = self.resp.clone();
            // Success responses are raw audio; a JSON body signals an error.
            if let Ok(resp) = serde_json::from_slice::<Value>(&self.resp) {
                tracing::debug!("Response: {resp}");
                if resp.get("error").is_some() {
                    return Err(resp["error"].to_string().into());
                }
            }
            Ok(Speech {
                request_id: None,
                file_format: "mp3".to_string(),
                audio,
            })
        } else if self.provider == Provider::Google {
            let resp = self.raw_value()?;
            tracing::debug!("Response: {resp}");
            if resp.get("error").is_some() {
                return Err(resp["error"].to_string().into());
            }
            let audio = resp["audioContent"]
                .as_str()
                .ok_or("no audioContent in resp")?;
            // The previous code unwrapped an unused `timepoints` array, which
            // panicked whenever Google omitted it; it is simply ignored now.
            // NOTE(review): the request asks Google for LINEAR16 audio, yet
            // the format is reported as "mp3" here — confirm which is right.
            Ok(Speech {
                request_id: None,
                file_format: "mp3".to_string(),
                audio: Speech::decode_speech(audio, &self.provider, None)?,
            })
        } else {
            // Invariant: `address` already rejects unsupported providers
            // before any request is made, so this is a programming error.
            panic!("Unsupported TTS provider: {}", self.provider);
        }
    }
}
/// Builds the HTTP headers for a TTS request.
///
/// Google authenticates via a `key` query parameter and ElevenLabs via its
/// `xi-api-key` header, so both drop the default `Authorization` header
/// produced by `request_headers`.
///
/// # Errors
///
/// Fails when the base headers cannot be built or the API key contains
/// characters invalid in a header value.
fn tts_headers(provider: &Provider, key: &Key) -> Result<HeaderMap, Box<dyn Error + Send + Sync>> {
    let mut headers = request_headers(key)?;
    match provider {
        Provider::Google => {
            headers.remove("Authorization");
        }
        Provider::ElevenLabs => {
            headers.insert("xi-api-key", HeaderValue::from_str(&key.key)?);
            headers.remove("Authorization");
        }
        _ => {}
    }
    Ok(headers)
}
/// Builds the provider-specific JSON request body for a TTS call.
///
/// ElevenLabs uses a dedicated schema (`model_id`, seed as a string) and is
/// handled up front; all remaining providers share the generic path below.
fn tts_body(config: &TTSConfig, provider: &Provider, model: Option<&str>, text: &str) -> Value {
    if provider == &Provider::ElevenLabs {
        let mut body = json!({});
        body["text"] = Value::String(text.to_string());
        if let Some(model) = &model {
            body["model_id"] = Value::String(model.to_string());
        }
        if let Some(language_code) = &config.language_code {
            body["language_code"] = Value::String(language_code.clone());
        }
        // Per the message below, ElevenLabs speed is configured on the voice
        // itself, so a per-request speed is rejected outright.
        if let Some(_speed) = &config.speed {
            panic!("Set speed for ElevenLabs via stored settings for voice.");
        }
        if let Some(seed) = &config.seed {
            // Seed is serialized as a string, not a number.
            body["seed"] = Value::String(seed.to_string());
        }
        return body;
    }
    let mut body = json!({});
    // The text field differs per provider: OpenAI(-compatible) expects
    // `input`, Google expects nested `input.text`, the rest take `text`.
    if provider == &Provider::OpenAI || is_openai_compatible(provider) {
        body["input"] = Value::String(text.to_string());
    } else if provider == &Provider::Google {
        body["input"] = json!({
            "text": text.to_string()
        });
    } else {
        body["text"] = Value::String(text.to_string());
    }
    if let Some(model) = &model {
        body["model"] = Value::String(model.to_string());
    }
    if let Some(voice) = &config.voice {
        if provider == &Provider::OpenAI || is_openai_compatible(provider) {
            body["voice"] = Value::String(voice.clone());
        } else if provider == &Provider::Google {
            body["voice"] = json!({
                "name": voice.clone()
            });
            if let Some(language_code) = &config.language_code {
                body["voice"]["languageCode"] = Value::String(language_code.clone());
            }
            // NOTE(review): audioConfig is only emitted when a voice is set,
            // and is hard-coded to LINEAR16 — confirm both are intentional.
            body["audioConfig"] = json!({
                "audioEncoding": "LINEAR16",
                "pitch": 0,
                "speakingRate": 1
            });
        } else if provider == &Provider::DeepInfra {
            body["preset_voice"] = Value::String(voice.clone());
        } else {
            panic!("Unsupported TTS provider: {}", provider);
        }
    }
    if let Some(speed) = config.speed {
        body["speed"] = Value::from(speed);
    }
    if let Some(output_format) = &config.output_format {
        body["output_format"] = Value::String(output_format.clone());
    }
    // `other` entries are applied last so they can override anything above.
    if let Some(other) = &config.other {
        for (key, value) in other {
            body[key] = value.clone();
        }
    }
    body
}
/// Sends a text-to-speech request and returns the raw provider response.
///
/// The returned [`SpeechResponse`] can be inspected via `bytes`,
/// `raw_value`, or decoded with `structured`.
///
/// # Errors
///
/// Fails when header construction or the HTTP request itself fails. Provider
/// errors embedded in the body surface later, from `structured`.
pub async fn tts(
    provider: &Provider,
    key: &Key,
    config: &TTSConfig,
    model: Option<&str>,
    text: &str,
) -> Result<SpeechResponse, Box<dyn Error + Send + Sync>> {
    // Reuse one client across calls: `reqwest::Client` holds a connection
    // pool internally, and constructing a fresh client per request (as the
    // previous code did) discards that pool every time.
    static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
    let client = CLIENT.get_or_init(reqwest::Client::new);
    let address = address(provider, key, model, config);
    let headers = tts_headers(provider, key)?;
    let body = tts_body(config, provider, model, text);
    tracing::debug!("Requesting {address} for text-to-speech with {body}");
    let resp = client
        .post(address)
        .headers(headers)
        .json(&body)
        .send()
        .await?;
    Ok(SpeechResponse {
        provider: provider.clone(),
        resp: resp.bytes().await?,
    })
}