use anyhow::{Context, Result};
use reqwest::Client;
use serde::Deserialize;
use tokio::time::{Duration, sleep};
use super::openai_tts::build_endpoint_url;
/// Upper bound on how long `poll_until_completed` keeps polling a generation
/// before giving up with a timeout error.
const POLL_TIMEOUT: Duration = Duration::from_secs(120);
/// Pause inserted before every status poll request.
const POLL_INTERVAL: Duration = Duration::from_millis(500);
/// Client for a Voicebox text-to-speech HTTP server.
///
/// Holds one reusable HTTP client plus the two values every request needs.
pub struct VoiceboxTts {
// HTTP client shared by the generate, status and audio requests.
client: Client,
// Root URL that endpoint paths are resolved against via `build_endpoint_url`.
base_url: String,
// Value sent as the `profile_id` field of each generate request.
profile_id: String,
}
/// JSON body returned by the `generate` endpoint.
///
/// Every field is `#[serde(default)]`, so a key that is absent from the
/// response deserializes to an empty string / `0.0` / `None` instead of
/// failing.
// NOTE(review): `default` only covers *missing* keys; an explicit JSON `null`
// for a `String` field would still fail to deserialize — confirm the server
// never sends `null` for `id`, `status` or `audio_path`.
#[derive(Deserialize)]
struct GenerateResponse {
// Generation identifier; used to build the status and audio URLs.
#[serde(default)]
id: String,
// Job state; `"completed"` together with a non-empty `audio_path` lets the
// caller skip polling.
#[serde(default)]
status: String,
// Location of the finished audio: an absolute http(s) URL or a path
// relative to the base URL.
#[serde(default)]
audio_path: String,
// Forwarded to `fetch_audio`, which ignores it; the allow keeps the field
// warning-free either way.
#[serde(default)]
#[allow(dead_code)]
duration: f64,
// Not read by `synthesize` as written here, hence the allow.
#[serde(default)]
#[allow(dead_code)]
error: Option<String>,
}
/// JSON payload returned by the `generate/{id}/status` endpoint.
///
/// Every field is `#[serde(default)]`, so absent keys fall back to their
/// `Default` value rather than failing deserialization.
#[derive(Deserialize)]
struct StatusResponse {
// Job state; the poller treats `"completed"` as success, `"failed"` /
// `"error"` as fatal, and anything else as "keep polling".
#[serde(default)]
status: String,
// Deserialized but not read anywhere in this file's polling code, hence the
// allow.
#[serde(default)]
#[allow(dead_code)]
duration: f64,
// Failure description reported when `status` is `"failed"` or `"error"`.
#[serde(default)]
error: Option<String>,
}
impl VoiceboxTts {
    /// Creates a Voicebox TTS client.
    ///
    /// `base_url` is the server root that endpoint paths are resolved against;
    /// `profile_id` is sent as the `profile_id` field of every generate request.
    pub fn new(base_url: &str, profile_id: &str) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.to_owned(),
            profile_id: profile_id.to_owned(),
        }
    }

    /// Synthesizes `text` and returns the raw audio bytes produced by the server.
    ///
    /// Posts a generate request. If the response is already `completed` and
    /// carries an `audio_path`, the audio is downloaded from there; otherwise the
    /// generation is polled until it completes and the audio is fetched by id.
    ///
    /// # Errors
    ///
    /// Fails when `text` is empty, when any HTTP request fails or returns a
    /// non-success status, when the server reports the generation as failed,
    /// when the response carries no generation id to poll, when polling exceeds
    /// `POLL_TIMEOUT`, or when the downloaded audio body is empty.
    pub async fn synthesize(&self, text: &str) -> Result<Vec<u8>> {
        if text.is_empty() {
            anyhow::bail!("Cannot synthesize empty text");
        }

        let generate_url = build_endpoint_url(&self.base_url, "generate")?;
        let body = serde_json::json!({
            "profile_id": self.profile_id,
            "text": text,
        });

        // `RequestBuilder::json` sets `Content-Type: application/json` itself.
        let response = self
            .client
            .post(&generate_url)
            .json(&body)
            .send()
            .await
            .context("Failed to send Voicebox TTS request")?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Voicebox TTS error ({}): {}", status, error_text);
        }

        let result: GenerateResponse = response
            .json()
            .await
            .context("Failed to parse Voicebox TTS response")?;

        match result.status.as_str() {
            // A job that already failed will never complete; report it now
            // instead of polling until the timeout.
            "failed" | "error" => {
                anyhow::bail!(
                    "Voicebox TTS generation failed: {}",
                    result.error.unwrap_or_else(|| "Unknown error".to_string())
                );
            }
            // Synchronous completion: the audio location is already known.
            "completed" if !result.audio_path.is_empty() => {
                return self.fetch_audio(&result.audio_path, result.duration).await;
            }
            _ => {}
        }

        let generation_id = result.id;
        if generation_id.is_empty() {
            // Without an id the status URL would be malformed and polling could
            // only ever end in a timeout.
            anyhow::bail!(
                "Voicebox TTS response did not include a generation id (status: '{}')",
                result.status
            );
        }

        self.poll_until_completed(&generation_id).await?;
        self.fetch_audio_by_id(&generation_id).await
    }

    /// Polls `generate/{id}/status` until the generation reaches a terminal state.
    ///
    /// Sleeps `POLL_INTERVAL` before every request. Returns the final status
    /// payload on `completed`. Empty, `204 No Content` and non-success responses
    /// are retried; the most recent such problem is included in the timeout
    /// error to make a stuck poll diagnosable.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be sent, the status body cannot be parsed,
    /// the server reports `failed`/`error`, or `POLL_TIMEOUT` elapses (including
    /// while a single status request is still in flight).
    async fn poll_until_completed(&self, generation_id: &str) -> Result<StatusResponse> {
        let status_url = build_endpoint_url(
            &self.base_url,
            &format!("generate/{}/status", generation_id),
        )?;
        let deadline = tokio::time::Instant::now() + POLL_TIMEOUT;
        let mut last_issue: Option<String> = None;

        loop {
            if tokio::time::Instant::now() >= deadline {
                return Err(Self::poll_timeout_error(
                    generation_id,
                    last_issue.as_deref(),
                ));
            }
            sleep(POLL_INTERVAL).await;

            // Bound the request and the body read by the overall deadline so a
            // slow or streaming response cannot outlive `POLL_TIMEOUT`.
            let attempt = async {
                let response = self
                    .client
                    .get(&status_url)
                    .send()
                    .await
                    .context("Failed to poll Voicebox generation status")?;
                let http_status = response.status();
                let body = response.text().await;
                Ok::<_, anyhow::Error>((http_status, body))
            };
            let (http_status, body) = match tokio::time::timeout_at(deadline, attempt).await {
                Ok(outcome) => outcome?,
                Err(_elapsed) => {
                    return Err(Self::poll_timeout_error(
                        generation_id,
                        last_issue.as_deref(),
                    ));
                }
            };

            if http_status == reqwest::StatusCode::NO_CONTENT {
                continue;
            }
            if !http_status.is_success() {
                // Still retried (best effort), but no longer invisible.
                tracing::debug!(
                    "Voicebox TTS: status poll for generation {} returned HTTP {}",
                    generation_id,
                    http_status
                );
                last_issue = Some(format!("HTTP {}", http_status));
                continue;
            }

            let body = match body {
                Ok(text) => text,
                Err(err) => {
                    tracing::debug!(
                        "Voicebox TTS: failed to read status body for generation {}: {}",
                        generation_id,
                        err
                    );
                    last_issue = Some(format!("failed to read status body: {}", err));
                    continue;
                }
            };
            if body.trim().is_empty() {
                continue;
            }

            let result = Self::parse_status_body(&body)?;
            match result.status.as_str() {
                "completed" => return Ok(result),
                "failed" | "error" => {
                    anyhow::bail!(
                        "Voicebox TTS generation failed: {}",
                        result.error.unwrap_or_else(|| "Unknown error".to_string())
                    );
                }
                "generating" | "queued" | "pending" => continue,
                other => {
                    tracing::warn!(
                        "Voicebox TTS: unexpected status '{}', continuing to poll",
                        other
                    );
                    continue;
                }
            }
        }
    }

    /// Builds the timeout error, appending the last retried poll problem if any.
    fn poll_timeout_error(generation_id: &str, last_issue: Option<&str>) -> anyhow::Error {
        let detail = last_issue
            .map(|issue| format!(" (last poll problem: {})", issue))
            .unwrap_or_default();
        anyhow::anyhow!(
            "Voicebox TTS timed out after {}s waiting for generation {}{}",
            POLL_TIMEOUT.as_secs(),
            generation_id,
            detail
        )
    }

    /// Extracts a [`StatusResponse`] from a status body.
    ///
    /// The body is either plain JSON or event-stream text with `data:` lines.
    /// When several `data:` events are present, the newest one that parses is
    /// used, because it reflects the most recent state. Otherwise the whole
    /// trimmed body is parsed as JSON.
    fn parse_status_body(body: &str) -> Result<StatusResponse> {
        let trimmed = body.trim();

        let newest_event = trimmed
            .lines()
            .rev()
            .filter_map(|line| line.strip_prefix("data:"))
            .find_map(|payload| serde_json::from_str::<StatusResponse>(payload.trim()).ok());
        if let Some(status) = newest_event {
            return Ok(status);
        }

        serde_json::from_str(trimmed).with_context(|| {
            format!(
                "Failed to parse Voicebox status response: '{}'",
                trimmed.chars().take(200).collect::<String>()
            )
        })
    }

    /// Downloads the audio of a finished generation from `audio/{generation_id}`.
    async fn fetch_audio_by_id(&self, generation_id: &str) -> Result<Vec<u8>> {
        let audio_url = build_endpoint_url(&self.base_url, &format!("audio/{}", generation_id))?;
        self.download_audio(&audio_url, "generation", generation_id)
            .await
    }

    /// Downloads audio from `audio_path`.
    ///
    /// An absolute `http://` / `https://` URL is used as-is; anything else is
    /// resolved against the base URL after stripping leading slashes.
    /// `_duration` is accepted for call-site compatibility and is not used.
    async fn fetch_audio(&self, audio_path: &str, _duration: f64) -> Result<Vec<u8>> {
        let audio_url = if audio_path.starts_with("http://") || audio_path.starts_with("https://")
        {
            audio_path.to_owned()
        } else {
            build_endpoint_url(&self.base_url, audio_path.trim_start_matches('/'))?
        };
        self.download_audio(&audio_url, "path", audio_path).await
    }

    /// Performs the GET for `audio_url`, validates the response and logs what was
    /// received. `source_kind` / `source` only label the log line
    /// (`generation=<id>` or `path=<audio_path>`).
    ///
    /// # Errors
    ///
    /// Fails when the request fails, the server answers with an error status,
    /// the body cannot be read, or the body is empty.
    async fn download_audio(
        &self,
        audio_url: &str,
        source_kind: &str,
        source: &str,
    ) -> Result<Vec<u8>> {
        let resp = self
            .client
            .get(audio_url)
            .send()
            .await
            .with_context(|| format!("Failed to fetch audio from {}", audio_url))?;

        // Captured before `error_for_status` consumes the response.
        let content_type = resp
            .headers()
            .get(reqwest::header::CONTENT_TYPE)
            .and_then(|value| value.to_str().ok())
            .unwrap_or("unknown")
            .to_owned();

        let resp = resp
            .error_for_status()
            .with_context(|| format!("Voicebox audio fetch returned error for {}", audio_url))?;

        let audio_bytes = resp
            .bytes()
            .await
            .context("Failed to read audio bytes from response")?
            .to_vec();

        // A successful response with no payload is never playable audio.
        if audio_bytes.is_empty() {
            anyhow::bail!("Voicebox returned an empty audio body from {}", audio_url);
        }

        tracing::info!(
            "Voicebox TTS: fetched {} bytes (content_type={}, detected={}, profile={}, {}={})",
            audio_bytes.len(),
            content_type,
            Self::detect_audio_format(&audio_bytes),
            self.profile_id,
            source_kind,
            source,
        );
        Ok(audio_bytes)
    }

    /// Best-effort container sniffing from the leading bytes; used for logging only.
    fn detect_audio_format(bytes: &[u8]) -> &'static str {
        // MPEG audio frame header: 11 sync bits set and layer bits `01` (Layer III).
        // This covers FF FB / FF FA / FF F3 / FF F2 / FF E3 while rejecting ADTS AAC.
        let is_mp3_frame = matches!(bytes, [0xFF, second, ..] if second & 0xE6 == 0xE2);

        if bytes.starts_with(b"RIFF") {
            "WAV"
        } else if bytes.starts_with(b"ID3") || is_mp3_frame {
            "MP3"
        } else if bytes.starts_with(b"OggS") {
            "OGG"
        } else if bytes.get(4..8) == Some(&b"ftyp"[..]) {
            // The `ftyp` box type sits after a 4-byte box size that varies by file.
            "M4A"
        } else {
            "unknown"
        }
    }
}