use super::error::{PocketTTSError, Result};
use super::model::ModelVariant;
use super::voices::PredefinedVoice;
use crate::{AudioData, SpeechRequest, SpeechResponse, VoiceIdentifier};
use pocket_tts::TTSModel;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Text-to-speech backend built around a loaded `pocket_tts::TTSModel`.
///
/// Besides the model, the backend remembers the generation parameters it was
/// constructed with and keeps a cache of voice states that have already been
/// loaded, keyed by voice name.
pub struct PocketTTSBackend {
/// The loaded model. It is cloned into a blocking task for every
/// generation call and used directly to load voice states.
pub(crate) model: TTSModel,
// The four fields below record the values that were handed to the model
// loader in `new`. No method shown here reads them back, which is why each
// carries a `dead_code` allowance.
/// `temperature` value the model was loaded with.
#[allow(dead_code)]
temperature: f32,
/// `lsd_decode_steps` value the model was loaded with.
#[allow(dead_code)]
lsd_decode_steps: usize,
/// `eos_threshold` value the model was loaded with.
#[allow(dead_code)]
eos_threshold: f32,
/// `noise_clamp` value the model was loaded with (`None` when not supplied).
#[allow(dead_code)]
noise_clamp: Option<f32>,
/// Voice states that have already been loaded, keyed by the voice name
/// string. Filled lazily by `load_predefined_voice`.
voice_cache: Arc<RwLock<HashMap<String, pocket_tts::ModelState>>>,
}
impl PocketTTSBackend {
/// Loads the model for `model_variant` and wraps it in a backend.
///
/// `temperature`, `lsd_decode_steps`, `eos_threshold` and `noise_clamp` are
/// forwarded unchanged to `TTSModel::load_with_params_device` and are also
/// stored on the returned backend. The model is always loaded onto the CPU
/// device. The voice cache starts out empty.
///
/// # Errors
///
/// Propagates any failure reported by the model loader.
pub fn new(
    model_variant: ModelVariant,
    temperature: f32,
    lsd_decode_steps: usize,
    eos_threshold: f32,
    noise_clamp: Option<f32>,
) -> Result<Self> {
    let cpu = candle_core::Device::Cpu;
    let loaded = TTSModel::load_with_params_device(
        model_variant.hf_id(),
        temperature,
        lsd_decode_steps,
        eos_threshold,
        noise_clamp,
        &cpu,
    );
    let model = loaded?;

    // No voices are loaded up front; the cache fills on first use.
    let voice_cache = Arc::new(RwLock::new(HashMap::new()));

    Ok(Self {
        model,
        temperature,
        lsd_decode_steps,
        eos_threshold,
        noise_clamp,
        voice_cache,
    })
}
/// Wraps a buffer of samples in an `AudioData` with a channel count of 1
/// and the given `sample_rate`.
fn samples_to_audio_data(samples: Vec<f32>, sample_rate: u32) -> AudioData {
    let channels = 1;
    AudioData {
        sample_rate,
        channels,
        samples,
    }
}
/// Synthesizes the whole of `request.text` in a single call.
///
/// The voice named by `request.voice` is resolved first, then the model is
/// run on a blocking task via `tokio::task::spawn_blocking`. If the tensor
/// returned by the model has more than one dimension, dimension 0 is
/// squeezed before the samples are read out as a flat `f32` vector. The
/// response carries the original request text, the samples wrapped as
/// single-channel audio at the model's sample rate, and a duration in whole
/// milliseconds derived from the sample count and sample rate.
///
/// # Errors
///
/// Returns an error when the voice cannot be resolved, when the model
/// reports a generation failure, when the blocking task cannot be joined,
/// or when the output tensor cannot be squeezed or converted to samples.
///
/// NOTE(review): `resolve_voice` runs on the calling task, before the
/// blocking task is spawned. On a cache miss it reaches
/// `download_if_necessary`, which looks synchronous — confirm that is
/// acceptable on the async executor.
pub async fn generate(&self, request: SpeechRequest) -> Result<SpeechResponse> {
    let voice_state = self.resolve_voice(&request.voice)?;
    let model = self.model.clone();
    let sample_rate = self.model.sample_rate as u32;

    // The blocking task needs an owned copy of the text. Move the request's
    // text in and have the task hand it back rather than cloning it.
    let text = request.text;
    let (generated, text) = tokio::task::spawn_blocking(move || {
        let generated = model.generate(&text, &voice_state).map_err(|e| {
            PocketTTSError::generation_error_detailed(
                e.to_string(),
                "model generation",
                format!("text length: {}", text.len()),
            )
        });
        (generated, text)
    })
    .await
    .map_err(|e| {
        PocketTTSError::generation_error_detailed(
            format!("Task join error: {}", e),
            "spawn_blocking",
            "generation task failed to join",
        )
    })?;
    let result = generated?;

    // Drop the leading dimension when the model returns more than a flat
    // vector; `to_vec1` below needs a one-dimensional tensor.
    let tensor = if result.dims().len() > 1 {
        result.squeeze(0).map_err(|e| {
            PocketTTSError::tensor_error(
                format!("Failed to squeeze tensor: {}", e),
                "1D or 2D",
                format!("{:?}", result.dims()),
            )
        })?
    } else {
        result
    };

    let samples = tensor.to_vec1::<f32>().map_err(|e| {
        PocketTTSError::generation_error_detailed(
            format!("Failed to extract samples: {}", e),
            "tensor conversion",
            format!("tensor shape: {:?}", tensor.dims()),
        )
    })?;

    let audio_data = Self::samples_to_audio_data(samples, sample_rate);
    // Sample count over sample rate gives seconds; the cast truncates to
    // whole milliseconds.
    let duration_ms =
        (audio_data.samples.len() as f64 / audio_data.sample_rate as f64 * 1000.0) as u64;

    Ok(SpeechResponse {
        text,
        audio: audio_data,
        duration_ms,
    })
}
/// Synthesizes `request.text` incrementally and returns the chunks as a
/// stream.
///
/// The voice is resolved up front; a failure there is returned directly
/// from this function. Generation then runs on a blocking task whose join
/// handle is not kept, and every chunk is forwarded through an unbounded
/// channel whose receiving half is returned as the stream.
///
/// For each chunk the tensor has dimension 0 squeezed twice and is then
/// read out as a flat `f32` vector. The chunk's `text` is the request text
/// followed by `_` and the zero-based chunk index.
///
/// Per-chunk failures are delivered as `Err` items on the stream; the
/// producer keeps iterating after sending one. It stops only when the model
/// iterator ends or the receiving side has been dropped.
///
/// # Errors
///
/// Returns an error when the voice cannot be resolved.
pub async fn generate_stream(
    &self,
    request: SpeechRequest,
) -> Result<impl futures::Stream<Item = Result<SpeechResponse>> + Send> {
    let voice_state = self.resolve_voice(&request.voice)?;
    let sample_rate = self.model.sample_rate as u32;
    let model = self.model.clone();
    // `request` is owned and not used again, so take its text instead of
    // cloning it.
    let text = request.text;

    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
    tokio::task::spawn_blocking(move || {
        let stream_iter = model.generate_stream(&text, &voice_state);
        for (idx, result) in stream_iter.enumerate() {
            let response = result
                .map_err(|e| {
                    PocketTTSError::generation_error_detailed(
                        e.to_string(),
                        "streaming generation",
                        format!("chunk index: {}", idx),
                    )
                })
                .and_then(|tensor| {
                    let tensor = tensor.squeeze(0).and_then(|t| t.squeeze(0)).map_err(|e| {
                        PocketTTSError::tensor_error(
                            format!("Failed to squeeze tensor: {}", e),
                            "[samples]",
                            format!("{:?}", tensor.dims()),
                        )
                    })?;
                    let samples = tensor.to_vec1::<f32>().map_err(|e| {
                        PocketTTSError::generation_error_detailed(
                            format!("Failed to extract samples: {}", e),
                            "streaming tensor conversion",
                            format!("chunk {}, tensor shape: {:?}", idx, tensor.dims()),
                        )
                    })?;
                    let audio_data = Self::samples_to_audio_data(samples, sample_rate);
                    let duration_ms = (audio_data.samples.len() as f64
                        / audio_data.sample_rate as f64
                        * 1000.0) as u64;
                    Ok(SpeechResponse {
                        text: format!("{}_{}", text, idx),
                        audio: audio_data,
                        duration_ms,
                    })
                });
            // A failed send means the receiver is gone; stop generating.
            if tx.send(response).is_err() {
                break;
            }
        }
    });

    Ok(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
/// Turns a request's voice identifier into a model voice state.
///
/// Only the identifier's `name` is consulted; it is looked up as a
/// predefined voice.
fn resolve_voice(&self, voice_id: &VoiceIdentifier) -> Result<pocket_tts::ModelState> {
    let requested: &str = &voice_id.name;
    self.load_predefined_voice(requested)
}
/// Returns the voice state for the predefined voice called `name`.
///
/// A cached state is returned as a clone when one exists. Otherwise the
/// name is parsed as a `PredefinedVoice`, its file is obtained through
/// `pocket_tts::weights::download_if_necessary`, the state is loaded from
/// that file by the model, and a clone is stored in the cache under `name`
/// before the state is returned.
///
/// The cache lock is not held between the lookup and the insert.
///
/// # Errors
///
/// Returns an error when the cache lock is poisoned, when `name` does not
/// parse as a predefined voice, when the download fails, or when the model
/// cannot load the voice state from the file.
fn load_predefined_voice(&self, name: &str) -> Result<pocket_tts::ModelState> {
    // Look in the cache first; the read guard is released at the end of
    // this block, before any download or model work starts.
    let cached = {
        let guard = self.voice_cache.read().map_err(|e| {
            PocketTTSError::cache_error(
                format!("Cache lock poisoned: {}", e),
                "voice cache read",
            )
        })?;
        guard.get(name).cloned()
    };
    if let Some(hit) = cached {
        return Ok(hit);
    }

    let voice = name.parse::<PredefinedVoice>().map_err(|reason: String| {
        PocketTTSError::voice_error_detailed(reason, name.to_string(), "parsing voice name")
    })?;

    let hf_path = voice.hf_path();
    let local_path = pocket_tts::weights::download_if_necessary(&hf_path).map_err(|e| {
        PocketTTSError::download_error(
            format!("Failed to download voice embeddings: {}", e),
            name.to_string(),
            hf_path.clone(),
        )
    })?;

    let loaded = self.model.get_voice_state_from_prompt_file(&local_path);
    let state = loaded.map_err(|e| {
        PocketTTSError::voice_error_detailed(
            e.to_string(),
            name.to_string(),
            "loading voice embeddings from file",
        )
    })?;

    // Remember the state for later calls; the write guard is a temporary
    // that is dropped at the end of this statement.
    self.voice_cache
        .write()
        .map_err(|e| {
            PocketTTSError::cache_error(
                format!("Cache lock poisoned: {}", e),
                "voice cache write",
            )
        })?
        .insert(name.to_string(), state.clone());

    Ok(state)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a backend can be constructed for the default model
    /// variant. Ignored by default (see the `ignore` reason).
    #[test]
    #[ignore = "requires HuggingFace model download"]
    fn test_backend_creation() {
        let variant = ModelVariant::default();
        let backend = PocketTTSBackend::new(variant, 0.7, 1, -4.0, None);
        assert!(backend.is_ok());
    }
}