use crate::engine::{estimate_word_boundaries, preprocess_speech_markdown, TtsEngine};
use crate::types::{normalize_gender, LanguageCode, TtsError, TtsResult, Voice, WordBoundary};
use std::collections::HashMap;
#[cfg(feature = "cloud")]
use {
tungstenite::{connect, Message},
url::Url,
uuid::Uuid,
};
/// Per-provider description of how to build HTTP requests for a cloud TTS API.
///
/// Values are produced by `build_config`; anything not set there keeps its
/// `Default` value (empty string, `None`, `false` or empty map).
#[derive(Debug, Clone, Default)]
struct CloudConfig {
/// URL the synthesis request is POSTed to.
synth_url: String,
/// Name of the header carrying the credential; no such header is sent when empty.
auth_header: String,
/// Text placed directly in front of the API key inside the auth header value.
auth_prefix: String,
/// JSON body key that receives the voice; omitted when empty or when no voice is known.
voice_param: String,
/// JSON body key that receives the model; only written when `model_default` is also set.
model_param: Option<String>,
/// Model value written under `model_param`.
model_default: Option<String>,
/// Voice used when the caller does not pass one.
default_voice: Option<String>,
/// JSON body key that receives the text to speak; omitted when empty.
text_field: String,
/// Additional key/value pairs copied into the JSON body.
extra_body: HashMap<String, serde_json::Value>,
/// When `true` the request body is the SSML from `build_azure_ssml` instead of JSON.
body_is_ssml: bool,
/// `Content-Type` for SSML bodies; `application/ssml+xml` is used when `None`.
content_type: Option<String>,
/// Additional headers added to every synthesis request.
extra_headers: HashMap<String, String>,
/// Endpoint queried by `get_voices`; `None` makes `get_voices` return an empty list.
voices_url: Option<String>,
/// Provider identifier used for provider-specific branches and by `engine_id`.
provider_id: String,
}
/// A `TtsEngine` that talks to one cloud provider over HTTP (and, for Azure
/// word boundaries with the `cloud` feature, over a WebSocket).
#[derive(Debug)]
pub struct CloudEngine {
/// Request recipe for the selected provider.
config: CloudConfig,
/// Credential sent in the auth header; taken from `apiKey`, `subscriptionKey` or `token`.
api_key: String,
/// Copy of the raw credential map; read again for the Azure `region`.
credentials: HashMap<String, String>,
/// Blocking HTTP client reused for all requests of this engine.
client: reqwest::blocking::Client,
}
impl CloudEngine {
    /// Creates an engine for the provider named `id`.
    ///
    /// Returns `None` when `build_config` does not know the provider. The API
    /// key is the first credential found under `apiKey`, `subscriptionKey` or
    /// `token` (in that order) and is empty when none of them is present.
    pub fn new(id: &str, credentials: &HashMap<String, String>) -> Option<Self> {
        let config = build_config(id, credentials)?;

        // Candidate credential names, highest priority first.
        let api_key = ["apiKey", "subscriptionKey", "token"]
            .iter()
            .find_map(|name| credentials.get(*name))
            .cloned()
            .unwrap_or_default();

        let client = reqwest::blocking::Client::new();
        Some(Self {
            config,
            api_key,
            credentials: credentials.clone(),
            client,
        })
    }
}
/// Returns the request recipe for the provider named `id`, or `None` when the
/// provider is not known.
///
/// `creds` supplies the values that are baked into the configuration itself:
/// ElevenLabs `voiceId`, Azure and Watson `region`, Google `apiKey` (placed in
/// the URL query string), PlayHT `userId`, and Watson `instanceId` / `apiKey`.
/// Missing entries fall back to the defaults visible in each arm.
// One flat table of providers is easier to audit than many tiny functions.
#[allow(clippy::too_many_lines)]
fn build_config(id: &str, creds: &HashMap<String, String>) -> Option<CloudConfig> {
    match id {
        "openai" => Some(CloudConfig {
            synth_url: "https://api.openai.com/v1/audio/speech".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Bearer ".into(),
            voice_param: "voice".into(),
            model_param: Some("model".into()),
            model_default: Some("gpt-4o-mini-tts".into()),
            default_voice: Some("alloy".into()),
            text_field: "input".into(),
            provider_id: "openai".into(),
            ..Default::default()
        }),
        "elevenlabs" => {
            // The voice is part of the URL path, so it has to be known up front.
            let voice_id = creds
                .get("voiceId")
                .cloned()
                .unwrap_or_else(|| "21m00Tcm4TlvDq8ikWAM".into());
            Some(CloudConfig {
                synth_url: format!("https://api.elevenlabs.io/v1/text-to-speech/{voice_id}"),
                auth_header: "xi-api-key".into(),
                model_param: Some("model_id".into()),
                model_default: Some("eleven_multilingual_v2".into()),
                text_field: "text".into(),
                voices_url: Some("https://api.elevenlabs.io/v1/voices".into()),
                provider_id: "elevenlabs".into(),
                ..Default::default()
            })
        }
        "azure" => {
            let region = creds
                .get("region")
                .cloned()
                .unwrap_or_else(|| "eastus".into());
            let mut extra = HashMap::new();
            extra.insert(
                "X-Microsoft-OutputFormat".into(),
                "audio-24khz-96kbitrate-mono-mp3".into(),
            );
            extra.insert("User-Agent".into(), "rust-tts-wrapper".into());
            Some(CloudConfig {
                synth_url: format!(
                    "https://{region}.tts.speech.microsoft.com/cognitiveservices/v1"
                ),
                auth_header: "Ocp-Apim-Subscription-Key".into(),
                default_voice: Some("en-US-AriaNeural".into()),
                body_is_ssml: true,
                content_type: Some("application/ssml+xml".into()),
                extra_headers: extra,
                voices_url: Some(format!(
                    "https://{region}.tts.speech.microsoft.com/cognitiveservices/voices/list"
                )),
                provider_id: "azure".into(),
                ..Default::default()
            })
        }
        "google" => {
            // Google takes the key as a query parameter, so no auth header is configured.
            let api_key = creds.get("apiKey").cloned().unwrap_or_default();
            Some(CloudConfig {
                synth_url: format!(
                    "https://texttospeech.googleapis.com/v1/text:synthesize?key={api_key}"
                ),
                text_field: "text".into(),
                voices_url: Some(format!(
                    "https://texttospeech.googleapis.com/v1/voices?key={api_key}"
                )),
                provider_id: "google".into(),
                ..Default::default()
            })
        }
        "cartesia" => Some(CloudConfig {
            synth_url: "https://api.cartesia.ai/tts/bytes".into(),
            auth_header: "X-API-Key".into(),
            voice_param: "voice_id".into(),
            model_param: Some("model_id".into()),
            model_default: Some("sonic-2".into()),
            text_field: "text".into(),
            voices_url: Some("https://api.cartesia.ai/voices".into()),
            provider_id: "cartesia".into(),
            ..Default::default()
        }),
        "deepgram" => Some(CloudConfig {
            synth_url: "https://api.deepgram.com/v1/speak".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Token ".into(),
            voice_param: "voice".into(),
            default_voice: Some("aura-asteria-en".into()),
            text_field: "text".into(),
            provider_id: "deepgram".into(),
            ..Default::default()
        }),
        "playht" => {
            let user_id = creds.get("userId").cloned().unwrap_or_default();
            let mut extra = HashMap::new();
            extra.insert("user_id".into(), serde_json::Value::String(user_id));
            Some(CloudConfig {
                synth_url: "https://api.play.ht/api/v2/tts".into(),
                auth_header: "Authorization".into(),
                auth_prefix: "Bearer ".into(),
                voice_param: "voice".into(),
                text_field: "text".into(),
                extra_body: extra,
                provider_id: "playht".into(),
                ..Default::default()
            })
        }
        "fishaudio" => Some(CloudConfig {
            synth_url: "https://api.fish.audio/v1/tts".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Bearer ".into(),
            voice_param: "reference_id".into(),
            text_field: "text".into(),
            provider_id: "fishaudio".into(),
            ..Default::default()
        }),
        "hume" => Some(CloudConfig {
            synth_url: "https://api.hume.ai/v0/tts".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Bearer ".into(),
            voice_param: "voice".into(),
            text_field: "text".into(),
            provider_id: "hume".into(),
            ..Default::default()
        }),
        "mistral" => Some(CloudConfig {
            synth_url: "https://api.mistral.ai/v1/tts".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Bearer ".into(),
            voice_param: "voice".into(),
            text_field: "text".into(),
            provider_id: "mistral".into(),
            ..Default::default()
        }),
        "murf" => Some(CloudConfig {
            synth_url: "https://api.murf.ai/v1/speech/generate".into(),
            auth_header: "api-key".into(),
            voice_param: "voice_id".into(),
            text_field: "text".into(),
            provider_id: "murf".into(),
            ..Default::default()
        }),
        "resemble" => Some(CloudConfig {
            synth_url: "https://app.resemble.ai/api/v2/synthesize".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Token ".into(),
            voice_param: "voice_uuid".into(),
            text_field: "text".into(),
            provider_id: "resemble".into(),
            ..Default::default()
        }),
        "unrealspeech" => Some(CloudConfig {
            synth_url: "https://api.v7.unrealspeech.com/speech".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Bearer ".into(),
            voice_param: "voice_id".into(),
            default_voice: Some("Scarlett".into()),
            text_field: "text".into(),
            provider_id: "unrealspeech".into(),
            ..Default::default()
        }),
        "upliftai" => Some(CloudConfig {
            synth_url: "https://api.upliftai.org/v1/tts".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Bearer ".into(),
            voice_param: "voice".into(),
            text_field: "text".into(),
            provider_id: "upliftai".into(),
            ..Default::default()
        }),
        "watson" => {
            let region = creds
                .get("region")
                .cloned()
                .unwrap_or_else(|| "us-east".into());
            let instance_id = creds.get("instanceId").cloned().unwrap_or_default();
            let api_key = creds.get("apiKey").map_or("", String::as_str);
            // Watson uses HTTP Basic auth with the literal user name `apikey`,
            // i.e. `Basic base64("apikey:<key>")`. The generic
            // `auth_prefix + api_key` scheme would append the raw key after the
            // encoded part, so the finished header is supplied through
            // `extra_headers` and `auth_header` is left empty.
            let mut headers = HashMap::new();
            headers.insert(
                "Authorization".into(),
                format!("Basic {}", base64_encode(&format!("apikey:{api_key}"))),
            );
            Some(CloudConfig {
                synth_url: format!(
                    "https://{region}.text-to-speech.watson.cloud.ibm.com/instances/{instance_id}/v1/synthesize"
                ),
                voice_param: "voice".into(),
                text_field: "text".into(),
                extra_headers: headers,
                provider_id: "watson".into(),
                ..Default::default()
            })
        }
        "witai" => Some(CloudConfig {
            synth_url: "https://api.wit.ai/synthesize?v=20240304".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Bearer ".into(),
            provider_id: "witai".into(),
            ..Default::default()
        }),
        "xai" => Some(CloudConfig {
            synth_url: "https://api.x.ai/v1/audio/speech".into(),
            auth_header: "Authorization".into(),
            auth_prefix: "Bearer ".into(),
            voice_param: "voice".into(),
            text_field: "input".into(),
            provider_id: "xai".into(),
            ..Default::default()
        }),
        "modelslab" => Some(CloudConfig {
            synth_url: "https://modelslab.com/api/v1/text_to_speech".into(),
            voice_param: "voice".into(),
            text_field: "text".into(),
            provider_id: "modelslab".into(),
            ..Default::default()
        }),
        "polly" => Some(CloudConfig {
            synth_url: "https://polly.us-east-1.amazonaws.com/v1/speech".into(),
            voice_param: "VoiceId".into(),
            default_voice: Some("Joanna".into()),
            text_field: "Text".into(),
            provider_id: "polly".into(),
            ..Default::default()
        }),
        _ => None,
    }
}
/// Encodes the bytes of `data` with the `base64` crate's `STANDARD` engine.
fn base64_encode(data: &str) -> String {
    use base64::{engine::general_purpose::STANDARD, Engine as _};

    let raw_bytes = data.as_bytes();
    STANDARD.encode(raw_bytes)
}
/// Builds the SSML document sent to Azure for `text` spoken by `voice`.
///
/// * `text` - plain text; `&`, `<` and `>` are replaced by XML entities so the
///   text cannot break or inject markup.
/// * `voice` - voice name written into `<voice name='...'>`; its first five
///   characters are used as `xml:lang`. Both are escaped for use inside a
///   single-quoted attribute.
/// * `rate` / `pitch` - multipliers where `1.0` means "unchanged". Any other
///   value is mapped to one of five named SSML prosody levels and wrapped
///   around the text in a `<prosody>` element.
fn build_azure_ssml(text: &str, voice: &str, rate: f32, pitch: f32) -> String {
    /// Replaces XML-reserved characters; quotes are only escaped for attribute values.
    fn escape_xml(raw: &str, in_attribute: bool) -> String {
        let mut out = String::with_capacity(raw.len());
        for ch in raw.chars() {
            match ch {
                '&' => out.push_str("&amp;"),
                '<' => out.push_str("&lt;"),
                '>' => out.push_str("&gt;"),
                '\'' if in_attribute => out.push_str("&apos;"),
                '"' if in_attribute => out.push_str("&quot;"),
                other => out.push(other),
            }
        }
        out
    }

    /// Picks the named level for a multiplier; NaN falls through to the last name.
    fn level(value: f32, names: [&'static str; 5]) -> &'static str {
        if value < 0.7 {
            names[0]
        } else if value < 0.85 {
            names[1]
        } else if value < 1.15 {
            names[2]
        } else if value < 1.4 {
            names[3]
        } else {
            names[4]
        }
    }

    let lang_raw: String = voice.chars().take(5).collect();
    let lang = escape_xml(&lang_raw, true);
    let voice = escape_xml(voice, true);
    let escaped = escape_xml(text, false);

    let rate_str = level(rate, ["x-slow", "slow", "medium", "fast", "x-fast"]);
    let pitch_str = level(pitch, ["x-low", "low", "medium", "high", "x-high"]);

    // Only emit attributes for values that actually differ from the neutral 1.0.
    let mut prosody_attrs = Vec::new();
    if (rate - 1.0).abs() > f32::EPSILON {
        prosody_attrs.push(format!("rate=\"{rate_str}\""));
    }
    if (pitch - 1.0).abs() > f32::EPSILON {
        prosody_attrs.push(format!("pitch=\"{pitch_str}\""));
    }

    let inner = if prosody_attrs.is_empty() {
        escaped
    } else {
        format!("<prosody {}>{escaped}</prosody>", prosody_attrs.join(" "))
    };

    format!(
        "<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' xml:lang='{lang}'>\
         <voice name='{voice}'>{inner}</voice></speak>"
    )
}
/// Builds the JSON body for a Google `text:synthesize` request.
///
/// * `text` - text to speak.
/// * `voice` - voice name; its first five characters are sent as `languageCode`.
/// * `add_marks` - when `true`, the text is split on whitespace and sent as
///   SSML with a `<mark name="N"/>` in front of the N-th word, and
///   `enableTimePointing` is set to `["SSML_MARK"]`.
///
/// Returns the request body and the list of words in mark order (unescaped).
/// The list is empty when `add_marks` is `false`.
fn build_google_request(
    text: &str,
    voice: &str,
    add_marks: bool,
) -> (serde_json::Value, Vec<String>) {
    let lang: String = voice.chars().take(5).collect();
    let mut words_list = Vec::new();

    let input = if add_marks {
        let mut ssml = String::from("<speak>");
        for (i, word) in text.split_whitespace().enumerate() {
            if i > 0 {
                ssml.push(' ');
            }
            // Writing into a `String` cannot fail, so the result is ignored.
            let _ = std::fmt::Write::write_fmt(&mut ssml, format_args!("<mark name=\"{i}\"/>"));
            // The word becomes SSML character data, so XML-reserved characters
            // must be escaped or the request is rejected as malformed SSML.
            for ch in word.chars() {
                match ch {
                    '&' => ssml.push_str("&amp;"),
                    '<' => ssml.push_str("&lt;"),
                    '>' => ssml.push_str("&gt;"),
                    other => ssml.push(other),
                }
            }
            // Collected once per word (not rebuilt on every iteration).
            words_list.push(word.to_string());
        }
        ssml.push_str("</speak>");
        serde_json::json!({ "ssml": ssml })
    } else {
        serde_json::json!({ "text": text })
    };

    let mut body = serde_json::json!({
        "input": input,
        "voice": { "languageCode": lang, "name": voice },
        "audioConfig": { "audioEncoding": "MP3" }
    });
    if add_marks {
        body["enableTimePointing"] = serde_json::json!(["SSML_MARK"]);
    }
    (body, words_list)
}
fn parse_google_timepoints(
timepoints: &[serde_json::Value],
words: &[String],
) -> Vec<WordBoundary> {
#[derive(Clone)]
struct RawTp {
index: usize,
time_ms: u64,
}
let mut raw: Vec<RawTp> = Vec::new();
for tp in timepoints {
let mark = tp.get("markName").and_then(|v| v.as_str()).unwrap_or("");
let idx: usize = mark.parse().unwrap_or(usize::MAX);
let secs = tp
.get("timeSeconds")
.and_then(serde_json::Value::as_f64)
.unwrap_or(0.0);
if idx < words.len() {
raw.push(RawTp {
index: idx,
time_ms: (secs * 1000.0) as u64,
});
}
}
raw.sort_by_key(|r| r.time_ms);
let mut boundaries = Vec::with_capacity(raw.len());
for (i, tp) in raw.iter().enumerate() {
let word = &words[tp.index];
let duration = if i + 1 < raw.len() {
raw[i + 1].time_ms.saturating_sub(tp.time_ms)
} else {
((word.len() as u64) * 80).max(50)
};
boundaries.push(WordBoundary {
text: word.clone(),
offset: tp.time_ms,
duration,
});
}
boundaries
}
/// Maps the entries of an Azure voice-list response to `Voice` values.
///
/// Entries without a string `ShortName` are skipped. `DisplayName` falls back
/// to the short name, `Locale` to `en-US`, and `LocaleName` to the locale.
fn map_azure_voices(json: &[serde_json::Value]) -> Vec<Voice> {
    /// Reads `key` from `entry` when it holds a string.
    fn str_field<'a>(entry: &'a serde_json::Value, key: &str) -> Option<&'a str> {
        entry.get(key).and_then(|value| value.as_str())
    }

    json.iter()
        .filter_map(|entry| {
            let short_name = str_field(entry, "ShortName")?;
            let locale = str_field(entry, "Locale").unwrap_or("en-US");
            let language = LanguageCode {
                bcp47: locale.to_string(),
                // Primary subtag of the locale, e.g. `en` for `en-US`.
                iso639_3: locale.split('-').next().unwrap_or("en").to_string(),
                display: str_field(entry, "LocaleName")
                    .unwrap_or(locale)
                    .to_string(),
            };
            Some(Voice {
                id: short_name.to_string(),
                name: str_field(entry, "DisplayName")
                    .unwrap_or(short_name)
                    .to_string(),
                gender: normalize_gender(str_field(entry, "Gender").unwrap_or("")),
                provider: "azure".to_string(),
                language_codes: vec![language],
            })
        })
        .collect()
}
fn map_google_voices(json: &[serde_json::Value]) -> Vec<Voice> {
let mut voices = Vec::new();
for v in json {
let Some(name) = v.get("name").and_then(|v| v.as_str()) else {
continue;
};
let gender_raw = v.get("ssmlGender").and_then(|v| v.as_str()).unwrap_or("");
let lang_codes = v
.get("languageCodes")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|c| {
let code = c.as_str()?;
Some(LanguageCode {
iso639_3: code.split('-').next()?.to_string(),
bcp47: code.to_string(),
display: code.to_string(),
})
})
.collect::<Vec<_>>()
})
.unwrap_or_default();
voices.push(Voice {
id: name.to_string(),
name: name.to_string(),
gender: normalize_gender(gender_raw),
provider: "google".to_string(),
language_codes: lang_codes,
});
}
voices
}
/// Fills in missing (zero) durations of word boundaries in place.
///
/// * A single boundary is raised to at least 500.
/// * With several boundaries, each zero duration except the last becomes the
///   distance to the next boundary's offset (never negative), and a zero
///   duration on the last boundary becomes 500.
/// * Non-zero durations in a multi-element slice are left untouched.
#[allow(
    clippy::too_many_lines,
    clippy::cast_precision_loss,
    clippy::map_unwrap_or
)]
#[cfg(feature = "cloud")]
fn compute_durations(boundaries: &mut [WordBoundary]) {
    match boundaries.len() {
        0 => {}
        1 => {
            let only = &mut boundaries[0];
            only.duration = only.duration.max(500);
        }
        count => {
            for next in 1..count {
                let next_offset = boundaries[next].offset;
                let current = &mut boundaries[next - 1];
                if current.duration == 0 {
                    current.duration = next_offset.saturating_sub(current.offset);
                }
            }
            let last = &mut boundaries[count - 1];
            if last.duration == 0 {
                last.duration = 500;
            }
        }
    }
}
impl TtsEngine for CloudEngine {
/// Synthesizes `text` with the configured provider and reports the result
/// through the optional callbacks.
///
/// * `voice` - voice to use; falls back to the provider's default voice.
/// * `rate` / `pitch` - only used when the request body is Azure SSML.
/// * `_volume` - currently unused.
/// * `on_audio` - receives the audio bytes in chunks.
/// * `on_boundary` - receives `(word, start_seconds, end_seconds)` per word.
///
/// Word timing comes from the provider where this code knows how to obtain it
/// (Azure WebSocket with the `cloud` feature, ElevenLabs `with-timestamps`,
/// Google SSML marks). For every other provider it is estimated from the text
/// with `estimate_word_boundaries`.
///
/// # Errors
/// Returns `TtsError` on connection, HTTP status, read, JSON or base64 failures.
// The provider-specific branches are kept in one place on purpose; the
// millisecond-to-seconds conversions intentionally go through `f32`.
#[allow(clippy::too_many_lines, clippy::cast_precision_loss)]
fn speak(
    &self,
    text: &str,
    voice: Option<&str>,
    rate: f32,
    pitch: f32,
    _volume: f32,
    mut on_audio: Option<crate::engine::OnAudioCallback>,
    mut on_boundary: Option<crate::engine::OnBoundaryCallback>,
) -> TtsResult<()> {
    let voice_to_use = voice
        .map(std::string::ToString::to_string)
        .or_else(|| self.config.default_voice.clone())
        .unwrap_or_default();
    let (text, _is_ssml) = preprocess_speech_markdown(text, &self.config.provider_id);

    // Azure only reports word boundaries over its WebSocket protocol, so that
    // path is taken whenever the caller asked for boundaries.
    #[cfg(feature = "cloud")]
    if self.config.provider_id == "azure" && on_boundary.is_some() {
        let Some(on_boundary) = on_boundary.as_mut() else {
            return Ok(());
        };
        let region = self
            .credentials
            .get("region")
            .cloned()
            .unwrap_or_else(|| "eastus".into());
        let request_id = Uuid::new_v4().to_string().to_lowercase();
        let ws_url_str = format!(
            "wss://{}.tts.speech.microsoft.com/cognitiveservices/websocket/v1?Ocp-Apim-Subscription-Key={}",
            region, self.api_key
        );
        let ws_url = Url::parse(&ws_url_str)
            .map_err(|e| TtsError(format!("Invalid Azure WS URL: {e}")))?;
        let (mut socket, _) = connect(ws_url.as_str())
            .map_err(|e| TtsError(format!("Azure WS Connect error: {e}")))?;

        // First message: synthesis configuration with word boundaries enabled.
        let output_format = "audio-24khz-96kbitrate-mono-mp3";
        let config_headers = format!("X-RequestId:{request_id}\r\nContent-Type:application/json; charset=utf-8\r\nPath:speech.config\r\n\r\n");
        let config_body = format!(
            r#"{{"context":{{"synthesis":{{"audio":{{"metadataOptions":{{"sentenceBoundaryEnabled":false,"wordBoundaryEnabled":true}},"outputFormat":"{output_format}"}}}}}}}}"#
        );
        let config_msg = format!("{config_headers}{config_body}");
        socket
            .send(Message::Text(config_msg.into()))
            .map_err(|e| TtsError(format!("WS config send error: {e}")))?;

        // Second message: the SSML to synthesize.
        let ssml = build_azure_ssml(&text, &voice_to_use, rate, pitch);
        let ssml_msg = format!(
            "X-RequestId:{request_id}\r\nContent-Type:application/ssml+xml\r\nX-StreamId:{request_id}\r\nPath:ssml\r\n\r\n{ssml}"
        );
        socket
            .send(Message::Text(ssml_msg.into()))
            .map_err(|e| TtsError(format!("WS ssml send error: {e}")))?;

        let mut collected_boundaries = Vec::new();
        loop {
            let msg = match socket.read() {
                Ok(m) => m,
                Err(
                    tungstenite::error::Error::ConnectionClosed
                    | tungstenite::error::Error::AlreadyClosed,
                ) => break,
                Err(e) => return Err(TtsError(format!("WS receive error: {e}"))),
            };
            match msg {
                Message::Text(t) => {
                    let text_msg = t.as_str();
                    let path_line = text_msg.lines().find(|l| l.starts_with("Path:"));
                    let path = path_line.map_or("", |l| l[5..].trim());
                    if path == "turn.end" {
                        let _ = socket.close(None);
                        break;
                    }
                    if path != "audio.metadata" && path != "word-boundary" {
                        continue;
                    }
                    // The JSON payload follows the blank line after the headers.
                    let body = if let Some(idx) = text_msg.find("\r\n\r\n") {
                        &text_msg[idx + 4..]
                    } else if let Some(idx) = text_msg.find("\n\n") {
                        &text_msg[idx + 2..]
                    } else {
                        text_msg
                    };
                    let Ok(json) = serde_json::from_str::<serde_json::Value>(body) else {
                        continue;
                    };
                    let Some(metadata) = json.get("Metadata").and_then(|v| v.as_array())
                    else {
                        continue;
                    };
                    for item in metadata {
                        if item.get("Type").and_then(|v| v.as_str()) != Some("WordBoundary") {
                            continue;
                        }
                        let Some(data) = item.get("Data") else {
                            continue;
                        };
                        // Clamp to zero so a negative value cannot wrap around
                        // when converted to `u64` below.
                        let offset_ticks = data
                            .get("Offset")
                            .and_then(serde_json::Value::as_i64)
                            .unwrap_or(0)
                            .max(0);
                        let duration_ticks = data
                            .get("Duration")
                            .and_then(serde_json::Value::as_i64)
                            .unwrap_or(0)
                            .max(0);
                        // The word is looked up under `text.Text`, `Text.Text`,
                        // or a plain string `text`, in that order.
                        let word = if let Some(text_obj) =
                            data.get("text").and_then(|v| v.as_object())
                        {
                            text_obj.get("Text").and_then(|v| v.as_str()).unwrap_or("")
                        } else if let Some(text_obj) =
                            data.get("Text").and_then(|v| v.as_object())
                        {
                            text_obj.get("Text").and_then(|v| v.as_str()).unwrap_or("")
                        } else {
                            data.get("text").and_then(|v| v.as_str()).unwrap_or("")
                        };
                        if !word.is_empty() {
                            // The tick values are divided by 10_000 to obtain
                            // the millisecond values stored in `WordBoundary`.
                            collected_boundaries.push(WordBoundary {
                                text: word.to_string(),
                                offset: (offset_ticks / 10_000) as u64,
                                duration: (duration_ticks / 10_000) as u64,
                            });
                        }
                    }
                }
                Message::Binary(b) if b.len() > 2 => {
                    // Two-byte big-endian header length, then headers, then audio.
                    let header_length = ((b[0] as usize) << 8) | (b[1] as usize);
                    if b.len() > 2 + header_length {
                        let audio_start = 2 + header_length;
                        if let Some(cb) = on_audio.as_mut() {
                            cb(&b[audio_start..]);
                        }
                    }
                }
                _ => {}
            }
        }
        compute_durations(&mut collected_boundaries);
        for b in &collected_boundaries {
            on_boundary(
                &b.text,
                b.offset as f32 / 1000.0,
                (b.offset + b.duration) as f32 / 1000.0,
            );
        }
        return Ok(());
    }

    let mut synth_url = self.config.synth_url.clone();
    if self.config.provider_id == "elevenlabs" && on_boundary.is_some() {
        synth_url.push_str("/with-timestamps");
    }
    let mut req = self.client.post(&synth_url);
    if !self.config.auth_header.is_empty() {
        let val = format!("{}{}", self.config.auth_prefix, self.api_key);
        req = req.header(&self.config.auth_header, val);
    }
    for (k, v) in &self.config.extra_headers {
        req = req.header(k.as_str(), v.as_str());
    }

    let resp = if self.config.body_is_ssml {
        let ssml = build_azure_ssml(&text, &voice_to_use, rate, pitch);
        let ct = self
            .config
            .content_type
            .as_deref()
            .unwrap_or("application/ssml+xml");
        req = req.header("Content-Type", ct);
        req.body(ssml).send()
    } else if self.config.provider_id == "google" {
        let (body, _words) =
            build_google_request(&text, &voice_to_use, on_boundary.is_some());
        req = req.json(&body);
        req.send()
    } else {
        let mut body = serde_json::Map::new();
        if !self.config.text_field.is_empty() {
            body.insert(
                self.config.text_field.clone(),
                serde_json::Value::String(text.clone()),
            );
        }
        if !self.config.voice_param.is_empty() && !voice_to_use.is_empty() {
            body.insert(
                self.config.voice_param.clone(),
                serde_json::Value::String(voice_to_use.clone()),
            );
        }
        if let Some(ref model_param) = self.config.model_param {
            if let Some(ref model) = self.config.model_default {
                body.insert(
                    model_param.clone(),
                    serde_json::Value::String(model.clone()),
                );
            }
        }
        for (k, v) in &self.config.extra_body {
            body.insert(k.clone(), v.clone());
        }
        req = req.json(&serde_json::Value::Object(body));
        req.send()
    };
    let resp = resp.map_err(|e| TtsError(format!("HTTP error: {e}")))?;
    if !resp.status().is_success() {
        let status = resp.status();
        let body_text = resp.text().unwrap_or_default();
        return Err(TtsError(format!("API error {status}: {body_text}")));
    }

    if self.config.provider_id == "elevenlabs" && on_boundary.is_some() {
        let resp_text = resp
            .text()
            .map_err(|e| TtsError(format!("Read error: {e}")))?;
        let json: serde_json::Value = serde_json::from_str(&resp_text)
            .map_err(|e| TtsError(format!("JSON parse: {e}")))?;
        if let Some(b64) = json.get("audio_base64").and_then(|v| v.as_str()) {
            use base64::Engine;
            let audio_bytes = base64::engine::general_purpose::STANDARD
                .decode(b64)
                .map_err(|e| TtsError(format!("Base64 decode: {e}")))?;
            if let Some(cb) = on_audio.as_mut() {
                for chunk in audio_bytes.chunks(8192) {
                    cb(chunk);
                }
            }
        }
        if let Some(cb) = on_boundary.as_mut() {
            if let Some(alignment) = json.get("alignment").and_then(|v| v.as_object()) {
                let chars = alignment.get("characters").and_then(|v| v.as_array());
                let starts = alignment
                    .get("character_start_times_seconds")
                    .and_then(|v| v.as_array());
                let ends = alignment
                    .get("character_end_times_seconds")
                    .and_then(|v| v.as_array());
                if let (Some(chars), Some(starts), Some(ends)) = (chars, starts, ends) {
                    // Group per-character timings into whitespace-separated words.
                    let mut current_word = String::new();
                    let mut word_start: f32 = 0.0;
                    let mut has_started = false;
                    // `zip` stops at the shortest array, so a response with
                    // mismatched array lengths cannot cause an index panic.
                    for ((ch, start), end) in chars.iter().zip(starts).zip(ends) {
                        let char_str = ch.as_str().unwrap_or("");
                        let start_time = start.as_f64().unwrap_or(0.0) as f32;
                        let end_time = end.as_f64().unwrap_or(0.0) as f32;
                        if char_str.trim().is_empty() {
                            if has_started {
                                cb(&current_word, word_start, end_time);
                                current_word.clear();
                                has_started = false;
                            }
                        } else {
                            if !has_started {
                                word_start = start_time;
                                has_started = true;
                            }
                            current_word.push_str(char_str);
                        }
                    }
                    // Flush the trailing word when the text does not end in whitespace.
                    if has_started {
                        let end_time = ends
                            .last()
                            .and_then(serde_json::Value::as_f64)
                            .unwrap_or(0.0) as f32;
                        cb(&current_word, word_start, end_time);
                    }
                }
            }
        }
    } else if self.config.provider_id == "google" && on_boundary.is_some() {
        let resp_text = resp
            .text()
            .map_err(|e| TtsError(format!("Read error: {e}")))?;
        let json: serde_json::Value = serde_json::from_str(&resp_text)
            .map_err(|e| TtsError(format!("JSON parse: {e}")))?;
        if let Some(b64) = json.get("audioContent").and_then(|v| v.as_str()) {
            use base64::Engine;
            let audio_bytes = base64::engine::general_purpose::STANDARD
                .decode(b64)
                .map_err(|e| TtsError(format!("Base64 decode: {e}")))?;
            if let Some(cb) = on_audio.as_mut() {
                for chunk in audio_bytes.chunks(8192) {
                    cb(chunk);
                }
            }
        }
        if let Some(cb) = on_boundary.as_mut() {
            // Rebuild the request only to recover the word list that matches the marks.
            let (_, words) = build_google_request(&text, &voice_to_use, true);
            if let Some(tps) = json.get("timepoints").and_then(|v| v.as_array()) {
                let boundaries = parse_google_timepoints(tps, &words);
                for b in &boundaries {
                    cb(
                        &b.text,
                        b.offset as f32 / 1000.0,
                        (b.offset + b.duration) as f32 / 1000.0,
                    );
                }
            } else {
                let estimated = estimate_word_boundaries(&text);
                for b in &estimated {
                    cb(
                        &b.text,
                        b.offset as f32 / 1000.0,
                        (b.offset + b.duration) as f32 / 1000.0,
                    );
                }
            }
        }
    } else {
        if let Some(cb) = on_audio.as_mut() {
            use std::io::Read;
            let mut resp = resp;
            let mut buffer = [0u8; 8192];
            loop {
                let n = resp
                    .read(&mut buffer)
                    .map_err(|e| TtsError(format!("Read error: {e}")))?;
                if n == 0 {
                    break;
                }
                cb(&buffer[..n]);
            }
        } else {
            // Nobody wants the audio; still read the body so read errors surface.
            let _audio_bytes = resp
                .bytes()
                .map_err(|e| TtsError(format!("Read error: {e}")))?;
        }
        // Estimated boundaries are reported whether or not an audio callback
        // was supplied, matching the timestamp-capable providers above.
        if let Some(cb) = on_boundary.as_mut() {
            let estimated = estimate_word_boundaries(&text);
            for b in &estimated {
                cb(
                    &b.text,
                    b.offset as f32 / 1000.0,
                    (b.offset + b.duration) as f32 / 1000.0,
                );
            }
        }
    }
    Ok(())
}
/// Same as `speak`: every argument is forwarded unchanged and the result of
/// `speak` is returned as is.
fn speak_sync(
&self,
text: &str,
voice: Option<&str>,
rate: f32,
pitch: f32,
volume: f32,
on_audio: Option<crate::engine::OnAudioCallback>,
on_boundary: Option<crate::engine::OnBoundaryCallback>,
) -> TtsResult<()> {
self.speak(text, voice, rate, pitch, volume, on_audio, on_boundary)
}
/// Does nothing and always returns `Ok(())`.
fn stop(&self) -> TtsResult<()> {
Ok(())
}
fn get_voices(&self) -> TtsResult<Vec<Voice>> {
let Some(ref voices_url) = self.config.voices_url else {
return Ok(vec![]);
};
let mut req = self.client.get(voices_url.as_str());
if !self.config.auth_header.is_empty() {
let val = format!("{}{}", self.config.auth_prefix, self.api_key);
req = req.header(&self.config.auth_header, val);
}
let resp = req
.send()
.map_err(|e| TtsError(format!("Voice list HTTP error: {e}")))?;
if !resp.status().is_success() {
return Ok(vec![]);
}
let json: serde_json::Value = resp
.json()
.map_err(|e| TtsError(format!("Voice list parse error: {e}")))?;
match self.config.provider_id.as_str() {
"azure" => json
.as_array()
.map_or_else(|| Ok(vec![]), |arr| Ok(map_azure_voices(arr))),
"google" => json
.get("voices")
.and_then(|v| v.as_array())
.map_or_else(|| Ok(vec![]), |arr| Ok(map_google_voices(arr))),
_ => {
json.as_array().map_or_else(
|| Ok(vec![]),
|arr| {
Ok(arr
.iter()
.filter_map(|v| {
let id = v
.get("id")
.or(v.get("voice_id"))
.or(v.get("name"))?
.as_str()?;
Some(Voice {
id: id.to_string(),
name: v
.get("name")
.and_then(|v| v.as_str())
.unwrap_or(id)
.to_string(),
gender: normalize_gender(
v.get("gender")
.or(v.get("labels"))
.and_then(|v| v.as_str())
.unwrap_or(""),
),
provider: self.config.provider_id.clone(),
language_codes: vec![],
})
})
.collect())
},
)
}
}
}
/// Returns the provider id as a `'static` string, or `"cloud"` when the
/// configured id is not one of the known providers.
fn engine_id(&self) -> &'static str {
    // The return type needs `'static` data, so the id is looked up in a
    // fixed table instead of borrowing from `self`.
    const KNOWN_IDS: &[&str] = &[
        "openai",
        "elevenlabs",
        "azure",
        "google",
        "cartesia",
        "deepgram",
        "playht",
        "fishaudio",
        "hume",
        "mistral",
        "murf",
        "resemble",
        "unrealspeech",
        "upliftai",
        "watson",
        "witai",
        "xai",
        "modelslab",
        "polly",
    ];

    let configured = self.config.provider_id.as_str();
    KNOWN_IDS
        .iter()
        .copied()
        .find(|known| *known == configured)
        .unwrap_or("cloud")
}
}
/// Creates a boxed cloud engine for provider `id`.
///
/// `credentials_json` is a JSON object of string keys to string values. An
/// empty string or JSON that does not parse into such a map is treated as
/// "no credentials". Returns `None` when the provider is unknown.
pub fn create_cloud_engine(id: &str, credentials_json: &str) -> Option<Box<dyn TtsEngine>> {
    let creds = match credentials_json {
        "" => HashMap::new(),
        raw => serde_json::from_str::<HashMap<String, String>>(raw).unwrap_or_default(),
    };
    let engine = CloudEngine::new(id, &creds)?;
    Some(Box::new(engine) as Box<dyn TtsEngine>)
}
// Unit tests for the pure helpers of this module; none of them touch the network.
#[cfg(test)]
mod tests {
use super::*;
// Neutral rate and pitch must not produce a <prosody> wrapper.
#[test]
fn test_build_azure_ssml() {
let ssml = build_azure_ssml("Hello world", "en-US-AriaNeural", 1.0, 1.0);
assert!(ssml.contains("<speak"));
assert!(ssml.contains("en-US-AriaNeural"));
assert!(ssml.contains("Hello world"));
assert!(!ssml.contains("<prosody"));
}
// Non-neutral rate and pitch must both show up as prosody attributes.
#[test]
fn test_build_azure_ssml_with_prosody() {
let ssml = build_azure_ssml("Hello world", "en-US-AriaNeural", 1.5, 0.8);
assert!(ssml.contains("<prosody"));
assert!(ssml.contains("rate="));
assert!(ssml.contains("pitch="));
}
// Without marks the text is sent verbatim and no word list is returned.
#[test]
fn test_build_google_request_basic() {
let (body, words) = build_google_request("Hello world", "en-US-Wavenet-D", false);
assert!(body["input"]["text"].as_str().unwrap() == "Hello world");
assert!(words.is_empty());
}
// With marks every word gets a numbered <mark> and time pointing is enabled.
#[test]
fn test_build_google_request_with_marks() {
let (body, words) = build_google_request("Hello world", "en-US-Wavenet-D", true);
let ssml = body["input"]["ssml"].as_str().unwrap();
assert!(ssml.contains("<mark name=\"0\"/>"));
assert!(ssml.contains("<mark name=\"1\"/>"));
assert_eq!(words.len(), 2);
assert_eq!(words[0], "Hello");
assert_eq!(words[1], "world");
assert!(body.get("enableTimePointing").is_some());
}
// Timepoints become millisecond offsets; a word lasts until the next one starts.
#[test]
fn test_parse_google_timepoints() {
let tps = vec![
serde_json::json!({"markName": "0", "timeSeconds": 0.125}),
serde_json::json!({"markName": "1", "timeSeconds": 0.450}),
];
let words = vec!["Hello".to_string(), "world".to_string()];
let boundaries = parse_google_timepoints(&tps, &words);
assert_eq!(boundaries.len(), 2);
assert_eq!(boundaries[0].text, "Hello");
assert_eq!(boundaries[0].offset, 125);
assert_eq!(boundaries[0].duration, 325);
assert_eq!(boundaries[1].text, "world");
assert_eq!(boundaries[1].offset, 450);
}
// One boundary per whitespace-separated word, starting at offset zero.
#[test]
fn test_estimate_word_boundaries() {
let boundaries = estimate_word_boundaries("Hello world this is a test");
assert_eq!(boundaries.len(), 6);
assert_eq!(boundaries[0].text, "Hello");
assert_eq!(boundaries[0].offset, 0);
assert!(boundaries[0].duration > 0);
}
// Gender strings are matched regardless of case; empty input is unknown.
#[test]
fn test_normalize_gender() {
assert_eq!(
super::super::types::normalize_gender("Female"),
super::super::types::Gender::Female
);
assert_eq!(
super::super::types::normalize_gender("male"),
super::super::types::Gender::Male
);
assert_eq!(
super::super::types::normalize_gender(""),
super::super::types::Gender::Unknown
);
}
// Every advertised provider id must yield a configuration, even without credentials.
#[test]
fn test_build_config_all_engines() {
let engines = [
"openai",
"elevenlabs",
"azure",
"google",
"cartesia",
"deepgram",
"playht",
"fishaudio",
"hume",
"mistral",
"murf",
"resemble",
"unrealspeech",
"upliftai",
"watson",
"witai",
"xai",
"modelslab",
"polly",
];
let creds = HashMap::new();
for id in &engines {
assert!(
build_config(id, &creds).is_some(),
"Failed for engine: {id}"
);
}
}
// Unknown provider ids are rejected.
#[test]
fn test_build_config_unknown() {
let creds = HashMap::new();
assert!(build_config("nonexistent", &creds).is_none());
}
// NOTE(review): these assertions cannot fail. The input text itself contains
// `&`, `<` and `>`, and the surrounding SSML tags contain `<` and `>`, so the
// test passes even when no escaping happens. It should assert on the escaped
// entity forms instead.
#[test]
fn test_azure_ssml_escapes_special_chars() {
let ssml = build_azure_ssml("A & B < C > D", "en-US-AriaNeural", 1.0, 1.0);
assert!(ssml.contains("&"));
assert!(ssml.contains("<"));
assert!(ssml.contains(">"));
}
// Speech Markdown input for Azure is reported as SSML and wrapped in <speak>.
#[test]
fn test_speech_markdown_preprocessing() {
use crate::engine::preprocess_speech_markdown;
let (result, is_ssml) =
preprocess_speech_markdown("Hello (world)[emphasis:\"strong\"]", "azure");
assert!(is_ssml);
assert!(result.contains("<speak>"));
}
}