use crate::tts::{AudioFormat, Speaker, SynthesizedAudio};
use crate::{Result, VoiceConfig, VoiceError};
use async_trait::async_trait;
use futures::StreamExt;
/// Text-to-speech backend that talks to a local HTTP server.
///
/// Each synthesis request is a JSON `POST` to `{base_url}/audio/speech`; the
/// server's response body is returned as MP3 audio.
pub struct LocalTtsSpeaker {
    /// Base URL of the local TTS server; `/audio/speech` is appended per request.
    base_url: String,
    /// Value sent as the `model` field of the request body.
    model: String,
    /// Value sent as the `voice` field of the request body.
    voice: String,
    /// Value sent as the `speed` field of the request body.
    speed: f32,
    /// Value for the optional `temperature` request field; `request_body`
    /// decides whether it is actually sent.
    temperature: f32,
    /// Optional `ref_audio` request field used for voice cloning.
    /// NOTE(review): presumably a path or URL the server can read — confirm.
    ref_audio: Option<String>,
    /// Optional `ref_text` request field that accompanies `ref_audio`.
    ref_text: Option<String>,
    /// Optional `instruct` request field (a free-form voice description).
    instruct: Option<String>,
    /// HTTP client reused for every synthesis request.
    client: reqwest::Client,
}
impl std::fmt::Debug for LocalTtsSpeaker {
    /// Formats the server and voice settings of the speaker.
    ///
    /// The HTTP client and the optional `ref_audio` / `ref_text` / `instruct`
    /// strings are intentionally left out. `finish_non_exhaustive` renders a
    /// trailing `..` so readers can tell the output is partial.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LocalTtsSpeaker")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            .field("voice", &self.voice)
            .field("speed", &self.speed)
            .field("temperature", &self.temperature)
            .finish_non_exhaustive()
    }
}
impl LocalTtsSpeaker {
pub fn from_config(config: &VoiceConfig) -> Self {
Self {
client: reqwest::Client::new(),
base_url: config.local_tts_url.clone(),
model: config.local_tts_model.clone(),
voice: config.local_tts_voice.clone(),
speed: config.local_tts_speed,
temperature: config.local_tts_temperature,
ref_audio: config.local_tts_ref_audio.clone(),
ref_text: config.local_tts_ref_text.clone(),
instruct: config.local_tts_instruct.clone(),
}
}
fn request_body(&self, text: &str) -> serde_json::Value {
let mut body = serde_json::json!({
"model": self.model,
"input": text,
"voice": self.voice,
"speed": self.speed,
"response_format": "mp3",
});
let obj = body.as_object_mut().unwrap();
if let Some(ref ref_audio) = self.ref_audio {
obj.insert("ref_audio".into(), serde_json::json!(ref_audio));
}
if let Some(ref ref_text) = self.ref_text {
obj.insert("ref_text".into(), serde_json::json!(ref_text));
}
if let Some(ref instruct) = self.instruct {
obj.insert("instruct".into(), serde_json::json!(instruct));
}
if self.temperature != 0.7 {
obj.insert("temperature".into(), serde_json::json!(self.temperature));
}
body
}
}
#[async_trait]
impl Speaker for LocalTtsSpeaker {
    /// Synthesizes `text` into MP3 audio using the local TTS server.
    ///
    /// Sends `POST {base_url}/audio/speech` with the body from `request_body`
    /// and buffers the entire streamed response in memory.
    ///
    /// # Errors
    ///
    /// Returns [`VoiceError::Tts`] when `text` is empty or whitespace-only,
    /// when the request cannot be sent, when the server answers with a
    /// non-success status (the response body is included in the message when
    /// readable), when the response stream fails mid-transfer, or when the
    /// response body is empty.
    async fn synth(&self, text: &str) -> Result<SynthesizedAudio> {
        if text.trim().is_empty() {
            return Err(VoiceError::Tts("empty text".into()));
        }

        // Tolerate a base URL configured with a trailing slash
        // (e.g. `http://localhost:8880/v1/`). Plain concatenation would yield
        // `...//audio/speech`, which some servers do not route.
        let url = format!("{}/audio/speech", self.base_url.trim_end_matches('/'));

        let resp = self
            .client
            .post(&url)
            .json(&self.request_body(text))
            .send()
            .await
            .map_err(|e| VoiceError::Tts(format!("local TTS at {url}: {e}")))?;

        if !resp.status().is_success() {
            let status = resp.status();
            // Best effort: an unreadable error body should not mask the status.
            let body = resp.text().await.unwrap_or_default();
            return Err(VoiceError::Tts(format!("local TTS {status}: {body}")));
        }

        let mut bytes = Vec::with_capacity(64 * 1024);
        let mut stream = resp.bytes_stream();
        while let Some(chunk) = stream.next().await {
            let chunk = chunk.map_err(|e| VoiceError::Tts(format!("stream: {e}")))?;
            bytes.extend_from_slice(&chunk);
        }

        if bytes.is_empty() {
            return Err(VoiceError::Tts(
                "empty audio response from local TTS".into(),
            ));
        }

        Ok(SynthesizedAudio {
            bytes,
            format: AudioFormat::Mp3,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn request_body_has_required_fields() {
        let speaker = LocalTtsSpeaker::from_config(&VoiceConfig::default());
        let body = speaker.request_body("hello");
        assert_eq!(body["input"], "hello");
        assert_eq!(body["voice"], "af_heart");
        assert_eq!(body["response_format"], "mp3");
        assert!(body.get("model").is_some());
    }

    #[test]
    fn request_body_includes_voice_cloning_params() {
        let config = VoiceConfig {
            local_tts_ref_audio: Some("/tmp/ref.wav".into()),
            local_tts_ref_text: Some("reference text".into()),
            ..VoiceConfig::default()
        };
        let speaker = LocalTtsSpeaker::from_config(&config);
        let body = speaker.request_body("test");
        assert_eq!(body["ref_audio"], "/tmp/ref.wav");
        assert_eq!(body["ref_text"], "reference text");
    }

    #[test]
    fn request_body_includes_instruct_param() {
        let config = VoiceConfig {
            local_tts_instruct: Some("warm female voice with Southern accent".into()),
            ..VoiceConfig::default()
        };
        let speaker = LocalTtsSpeaker::from_config(&config);
        let body = speaker.request_body("test");
        assert_eq!(body["instruct"], "warm female voice with Southern accent");
    }

    /// A temperature of 0.7 is treated as the default and left out of the body.
    #[test]
    fn request_body_omits_default_temperature() {
        let config = VoiceConfig {
            local_tts_temperature: 0.7,
            ..VoiceConfig::default()
        };
        let body = LocalTtsSpeaker::from_config(&config).request_body("test");
        assert!(body.get("temperature").is_none());
    }

    /// Any other finite temperature is forwarded to the server.
    #[test]
    fn request_body_includes_non_default_temperature() {
        let config = VoiceConfig {
            local_tts_temperature: 0.3,
            ..VoiceConfig::default()
        };
        let body = LocalTtsSpeaker::from_config(&config).request_body("test");
        let temperature = body["temperature"]
            .as_f64()
            .expect("temperature should be serialized as a number");
        assert!((temperature - 0.3).abs() < 1e-6);
    }

    #[tokio::test]
    async fn synth_rejects_empty_text() {
        let speaker = LocalTtsSpeaker::from_config(&VoiceConfig::default());
        let err = speaker.synth(" ").await.unwrap_err();
        assert!(matches!(err, VoiceError::Tts(_)));
        let err = speaker.synth("").await.unwrap_err();
        assert!(matches!(err, VoiceError::Tts(_)));
    }
}