use super::error::ElevenLabsError;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub const DEFAULT_OUTPUT_FORMAT: &str = "mp3_44100_128";
pub const TTS_ENDPOINT_PATH: &str = "/v1/text-to-speech";
pub fn get_voice_mappings() -> HashMap<&'static str, &'static str> {
let mut mappings = HashMap::new();
mappings.insert("alloy", "21m00Tcm4TlvDq8ikWAM"); mappings.insert("amber", "5Q0t7uMcjvnagumLfvZi"); mappings.insert("ash", "AZnzlk1XvdvUeBnXmlld"); mappings.insert("august", "D38z5RcWu1voky8WS1ja"); mappings.insert("blue", "2EiwWnXFnvU5JabPnv8n"); mappings.insert("coral", "9BWtsMINqrJLrRacOk9x"); mappings.insert("lily", "EXAVITQu4vr4xnSDxMaL"); mappings.insert("onyx", "29vD33N1CtxCmqQRPOHJ"); mappings.insert("sage", "CwhRBWXzGAHq8TQ4Fs17"); mappings.insert("verse", "CYw3kZ02Hs0563khs1Fj"); mappings
}
pub fn get_format_mappings() -> HashMap<&'static str, &'static str> {
let mut mappings = HashMap::new();
mappings.insert("mp3", "mp3_44100_128");
mappings.insert("pcm", "pcm_44100");
mappings.insert("opus", "opus_48000_128");
mappings
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct VoiceSettings {
#[serde(skip_serializing_if = "Option::is_none")]
pub stability: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub similarity_boost: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub style: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub use_speaker_boost: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub speed: Option<f32>,
}
#[derive(Debug, Clone, Serialize)]
pub struct TextToSpeechRequest {
pub text: String,
pub model_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub voice_settings: Option<VoiceSettings>,
#[serde(skip_serializing_if = "Option::is_none")]
pub pronunciation_dictionary_locators: Option<Vec<PronunciationDictionaryLocator>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_text: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub previous_request_ids: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub next_request_ids: Option<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PronunciationDictionaryLocator {
pub pronunciation_dictionary_id: String,
pub version_id: String,
}
#[derive(Debug)]
pub struct TextToSpeechResponse {
pub audio_data: Vec<u8>,
pub content_type: String,
pub character_cost: Option<u32>,
pub request_id: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TTSModel {
MonolingualV1,
MultilingualV1,
#[default]
MultilingualV2,
TurboV2,
TurboV2_5,
}
impl TTSModel {
pub fn as_str(&self) -> &'static str {
match self {
TTSModel::MonolingualV1 => "eleven_monolingual_v1",
TTSModel::MultilingualV1 => "eleven_multilingual_v1",
TTSModel::MultilingualV2 => "eleven_multilingual_v2",
TTSModel::TurboV2 => "eleven_turbo_v2",
TTSModel::TurboV2_5 => "eleven_turbo_v2_5",
}
}
pub fn parse(s: &str) -> Option<Self> {
match s {
"eleven_monolingual_v1" => Some(TTSModel::MonolingualV1),
"eleven_multilingual_v1" => Some(TTSModel::MultilingualV1),
"eleven_multilingual_v2" => Some(TTSModel::MultilingualV2),
"eleven_turbo_v2" => Some(TTSModel::TurboV2),
"eleven_turbo_v2_5" => Some(TTSModel::TurboV2_5),
_ => None,
}
}
}
pub fn resolve_voice_id(voice: &str) -> Result<String, ElevenLabsError> {
let voice_mappings = get_voice_mappings();
let normalized = voice.trim().to_lowercase();
if let Some(&voice_id) = voice_mappings.get(normalized.as_str()) {
return Ok(voice_id.to_string());
}
if voice.trim().is_empty() {
return Err(ElevenLabsError::invalid_request(
"elevenlabs",
"Voice ID is required",
));
}
Ok(voice.trim().to_string())
}
pub fn map_output_format(format: Option<&str>) -> &'static str {
let format_mappings = get_format_mappings();
match format {
Some(f) => format_mappings
.get(f)
.copied()
.unwrap_or(DEFAULT_OUTPUT_FORMAT),
None => DEFAULT_OUTPUT_FORMAT,
}
}
pub fn build_tts_url(base_url: &str, voice_id: &str, output_format: Option<&str>) -> String {
let base = base_url.trim_end_matches('/');
let format = map_output_format(output_format);
format!(
"{}{}/{}?output_format={}",
base, TTS_ENDPOINT_PATH, voice_id, format
)
}
pub fn supported_output_formats() -> &'static [&'static str] {
&[
"mp3_22050_32",
"mp3_44100_32",
"mp3_44100_64",
"mp3_44100_96",
"mp3_44100_128",
"mp3_44100_192",
"pcm_16000",
"pcm_22050",
"pcm_24000",
"pcm_44100",
"ulaw_8000",
"opus_32000_24",
"opus_48000_24",
"opus_48000_64",
"opus_48000_128",
]
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_voice_mappings() {
let mappings = get_voice_mappings();
assert_eq!(mappings.get("alloy"), Some(&"21m00Tcm4TlvDq8ikWAM"));
assert_eq!(mappings.get("onyx"), Some(&"29vD33N1CtxCmqQRPOHJ"));
}
#[test]
fn test_format_mappings() {
let mappings = get_format_mappings();
assert_eq!(mappings.get("mp3"), Some(&"mp3_44100_128"));
assert_eq!(mappings.get("pcm"), Some(&"pcm_44100"));
assert_eq!(mappings.get("opus"), Some(&"opus_48000_128"));
}
#[test]
fn test_resolve_voice_id_openai_name() {
let voice_id = resolve_voice_id("alloy").unwrap();
assert_eq!(voice_id, "21m00Tcm4TlvDq8ikWAM");
}
#[test]
fn test_resolve_voice_id_direct_id() {
let voice_id = resolve_voice_id("custom-voice-id-123").unwrap();
assert_eq!(voice_id, "custom-voice-id-123");
}
#[test]
fn test_resolve_voice_id_empty() {
let result = resolve_voice_id("");
assert!(result.is_err());
}
#[test]
fn test_map_output_format() {
assert_eq!(map_output_format(Some("mp3")), "mp3_44100_128");
assert_eq!(map_output_format(Some("pcm")), "pcm_44100");
assert_eq!(map_output_format(None), "mp3_44100_128");
assert_eq!(map_output_format(Some("unknown")), "mp3_44100_128");
}
#[test]
fn test_build_tts_url() {
let url = build_tts_url("https://api.elevenlabs.io", "voice-123", Some("mp3"));
assert_eq!(
url,
"https://api.elevenlabs.io/v1/text-to-speech/voice-123?output_format=mp3_44100_128"
);
}
#[test]
fn test_tts_model_as_str() {
assert_eq!(TTSModel::MultilingualV2.as_str(), "eleven_multilingual_v2");
assert_eq!(TTSModel::TurboV2_5.as_str(), "eleven_turbo_v2_5");
}
#[test]
fn test_tts_model_from_str() {
assert_eq!(
TTSModel::parse("eleven_multilingual_v2"),
Some(TTSModel::MultilingualV2)
);
assert_eq!(TTSModel::parse("unknown"), None);
}
#[test]
fn test_voice_settings_serialization() {
let settings = VoiceSettings {
stability: Some(0.5),
similarity_boost: Some(0.75),
speed: Some(1.2),
..Default::default()
};
let json = serde_json::to_value(&settings).unwrap();
assert!((json["stability"].as_f64().unwrap() - 0.5).abs() < 0.01);
assert!((json["similarity_boost"].as_f64().unwrap() - 0.75).abs() < 0.01);
assert!((json["speed"].as_f64().unwrap() - 1.2).abs() < 0.01);
}
}