use anyhow::{Context, Result};
use reqwest::Client;
use serde::Deserialize;
use crate::config::VoiceConfig;
const GROQ_TRANSCRIPTION_URL: &str = "https://api.groq.com/openai/v1/audio/transcriptions";
const OPENAI_SPEECH_URL: &str = "https://api.openai.com/v1/audio/speech";
pub async fn transcribe_audio(audio_bytes: Vec<u8>, groq_api_key: &str) -> Result<String> {
transcribe_audio_with_url(audio_bytes, groq_api_key, GROQ_TRANSCRIPTION_URL).await
}
pub async fn transcribe(audio_bytes: Vec<u8>, voice_config: &VoiceConfig) -> Result<String> {
let primary = resolve_primary_stt(voice_config);
let chain = resolve_fallback_chain(voice_config, primary);
let candidates: Vec<SttProviderKind> = std::iter::once(primary).chain(chain).collect();
let mut attempts: Vec<String> = Vec::with_capacity(candidates.len());
let mut audio = Some(audio_bytes);
for kind in candidates {
let bytes = match audio.take() {
Some(b) => {
if !attempts.is_empty() {
}
b
}
None => break,
};
let bytes_for_next = bytes.clone();
attempts.push(kind.label());
match try_stt(kind, bytes, voice_config).await {
Ok(text) => {
if attempts.len() > 1 {
tracing::info!(
"STT recovered via fallback chain: {} (attempts: {})",
kind.label(),
attempts.join(" → "),
);
}
return Ok(text);
}
Err(e) => {
tracing::warn!(
"STT provider '{}' failed: {} — trying next in chain",
kind.label(),
e,
);
if let Some(s) = attempts.last_mut() {
s.push_str(&format!(": {e}"));
}
audio = Some(bytes_for_next);
}
}
}
anyhow::bail!(
"All STT providers failed. Attempts:\n - {}",
attempts.join("\n - "),
)
}
fn resolve_primary_stt(cfg: &VoiceConfig) -> SttProviderKind {
if cfg.voicebox_stt_enabled {
SttProviderKind::Voicebox
} else if cfg.stt_base_url.is_some() && cfg.stt_model.is_some() && cfg.stt_api_key.is_some() {
SttProviderKind::OpenAiCompatible
} else if cfg
.stt_provider
.as_ref()
.and_then(|p| p.api_key.as_ref())
.is_some()
{
SttProviderKind::Groq
} else {
SttProviderKind::Local
}
}
pub(crate) fn resolve_fallback_chain(
cfg: &VoiceConfig,
primary: SttProviderKind,
) -> Vec<SttProviderKind> {
let user_chain = &cfg.stt_fallback_chain;
let raw: Vec<SttProviderKind> = if user_chain.is_empty() {
vec![
SttProviderKind::Voicebox,
SttProviderKind::OpenAiCompatible,
SttProviderKind::Groq,
SttProviderKind::Local,
]
} else {
user_chain
.iter()
.filter_map(|s| SttProviderKind::from_label(s))
.collect()
};
raw.into_iter()
.filter(|k| *k != primary && provider_is_configured(*k, cfg))
.collect()
}
fn provider_is_configured(kind: SttProviderKind, cfg: &VoiceConfig) -> bool {
match kind {
SttProviderKind::Voicebox => cfg.voicebox_stt_enabled,
SttProviderKind::OpenAiCompatible => {
cfg.stt_base_url.is_some() && cfg.stt_model.is_some() && cfg.stt_api_key.is_some()
}
SttProviderKind::Groq => cfg
.stt_provider
.as_ref()
.and_then(|p| p.api_key.as_ref())
.is_some(),
SttProviderKind::Local => cfg!(feature = "local-stt"),
}
}
async fn try_stt(kind: SttProviderKind, audio_bytes: Vec<u8>, cfg: &VoiceConfig) -> Result<String> {
match kind {
SttProviderKind::Voicebox => {
super::voicebox_stt::transcribe(audio_bytes, &cfg.voicebox_stt_base_url).await
}
SttProviderKind::OpenAiCompatible => {
let base = cfg
.stt_base_url
.as_ref()
.ok_or_else(|| anyhow::anyhow!("openai-compatible STT missing base_url"))?;
let model = cfg
.stt_model
.as_ref()
.ok_or_else(|| anyhow::anyhow!("openai-compatible STT missing model"))?;
let key = cfg
.stt_api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("openai-compatible STT missing api_key"))?;
super::openai_stt::transcribe_audio(audio_bytes, key, model, base).await
}
SttProviderKind::Groq => {
let key = cfg
.stt_provider
.as_ref()
.and_then(|p| p.api_key.as_ref())
.ok_or_else(|| anyhow::anyhow!("Groq STT missing api_key"))?;
transcribe_audio(audio_bytes, key).await
}
SttProviderKind::Local => {
#[cfg(feature = "local-stt")]
{
if !super::local_stt_available() {
anyhow::bail!("Local STT is not supported on this CPU — AVX2 is required.");
}
let model_id = &cfg.local_stt_model;
let preset = super::local_whisper::find_local_model(model_id)
.ok_or_else(|| anyhow::anyhow!("Unknown local STT model: {}", model_id))?;
if !super::local_whisper::is_model_downloaded(preset) {
tracing::info!(
"Local STT model '{}' not in candle format — downloading automatically",
model_id
);
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let preset_id = model_id.to_string();
let download_preset = super::local_whisper::find_local_model(&preset_id)
.ok_or_else(|| anyhow::anyhow!("Unknown model: {}", preset_id))?;
let download_handle = tokio::spawn(async move {
super::local_whisper::download_model(download_preset, tx).await
});
while let Some(progress) = rx.recv().await {
if progress.done {
break;
}
if let Some(err) = progress.error {
anyhow::bail!("Model download failed: {}", err);
}
}
download_handle
.await
.map_err(|e| anyhow::anyhow!("Download task failed: {}", e))??;
tracing::info!("Local STT model '{}' downloaded successfully", model_id);
}
tracing::info!("Local STT: transcribing with model {}", model_id);
transcribe_audio_local(audio_bytes, model_id.clone()).await
}
#[cfg(not(feature = "local-stt"))]
{
let _ = audio_bytes;
anyhow::bail!("Local STT feature not compiled in")
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SttProviderKind {
Voicebox,
OpenAiCompatible,
Groq,
Local,
}
impl SttProviderKind {
pub(crate) fn label(self) -> String {
match self {
SttProviderKind::Voicebox => "voicebox".to_string(),
SttProviderKind::OpenAiCompatible => "openai_compatible".to_string(),
SttProviderKind::Groq => "groq".to_string(),
SttProviderKind::Local => "local".to_string(),
}
}
pub(crate) fn from_label(s: &str) -> Option<Self> {
match s.trim().to_lowercase().as_str() {
"voicebox" => Some(SttProviderKind::Voicebox),
"openai_compatible" | "openai-compatible" | "openai" => {
Some(SttProviderKind::OpenAiCompatible)
}
"groq" => Some(SttProviderKind::Groq),
"local" | "whisper" | "local_whisper" => Some(SttProviderKind::Local),
_ => None,
}
}
}
pub(crate) async fn transcribe_audio_with_url(
audio_bytes: Vec<u8>,
api_key: &str,
url: &str,
) -> Result<String> {
let client = Client::new();
let file_part = reqwest::multipart::Part::bytes(audio_bytes)
.file_name("voice.ogg")
.mime_str("audio/ogg")?;
let form = reqwest::multipart::Form::new()
.part("file", file_part)
.text("model", "whisper-large-v3-turbo")
.text("response_format", "json");
let response = client
.post(url)
.header("Authorization", format!("Bearer {}", api_key))
.multipart(form)
.send()
.await
.context("Failed to send audio to Groq Whisper")?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("Groq STT error ({}): {}", status, error_text);
}
let result: TranscriptionResponse = response
.json()
.await
.context("Failed to parse Groq transcription response")?;
tracing::info!("Groq STT: transcribed {} chars", result.text.len());
Ok(result.text)
}
pub async fn synthesize(text: &str, voice_config: &VoiceConfig) -> Result<Vec<u8>> {
if text.is_empty() {
anyhow::bail!("Cannot synthesize empty text");
}
let primary = resolve_primary_tts(voice_config);
let chain = resolve_tts_fallback_chain(voice_config, primary);
let candidates: Vec<TtsProviderKind> = std::iter::once(primary).chain(chain).collect();
let mut attempts: Vec<String> = Vec::with_capacity(candidates.len());
let mut audio: Option<Vec<u8>> = None;
for kind in &candidates {
attempts.push(kind.label());
match try_tts(*kind, text, voice_config).await {
Ok(bytes) => {
if attempts.len() > 1 {
tracing::info!(
"TTS recovered via fallback chain: {} (attempts: {})",
kind.label(),
attempts.join(" → "),
);
}
audio = Some(bytes);
break;
}
Err(e) => {
tracing::warn!(
"TTS provider '{}' failed: {} — trying next in chain",
kind.label(),
e,
);
if let Some(s) = attempts.last_mut() {
s.push_str(&format!(": {e}"));
}
}
}
}
let audio = audio.ok_or_else(|| {
anyhow::anyhow!(
"All TTS providers failed. Attempts:\n - {}",
attempts.join("\n - "),
)
})?;
Ok(ensure_opus(audio).await)
}
fn resolve_primary_tts(cfg: &VoiceConfig) -> TtsProviderKind {
if cfg.voicebox_tts_enabled {
TtsProviderKind::Voicebox
} else if cfg.tts_base_url.is_some() && cfg.tts_api_key.is_some() {
TtsProviderKind::OpenAiCompatible
} else if cfg
.tts_provider
.as_ref()
.and_then(|p| p.api_key.as_ref())
.is_some()
{
TtsProviderKind::OpenAi
} else {
TtsProviderKind::Local
}
}
pub(crate) fn resolve_tts_fallback_chain(
cfg: &VoiceConfig,
primary: TtsProviderKind,
) -> Vec<TtsProviderKind> {
let user_chain = &cfg.tts_fallback_chain;
let raw: Vec<TtsProviderKind> = if user_chain.is_empty() {
vec![
TtsProviderKind::Voicebox,
TtsProviderKind::OpenAiCompatible,
TtsProviderKind::OpenAi,
TtsProviderKind::Local,
]
} else {
user_chain
.iter()
.filter_map(|s| TtsProviderKind::from_label(s))
.collect()
};
raw.into_iter()
.filter(|k| *k != primary && tts_provider_is_configured(*k, cfg))
.collect()
}
fn tts_provider_is_configured(kind: TtsProviderKind, cfg: &VoiceConfig) -> bool {
match kind {
TtsProviderKind::Voicebox => cfg.voicebox_tts_enabled,
TtsProviderKind::OpenAiCompatible => {
cfg.tts_base_url.is_some() && cfg.tts_api_key.is_some()
}
TtsProviderKind::OpenAi => cfg
.tts_provider
.as_ref()
.and_then(|p| p.api_key.as_ref())
.is_some(),
TtsProviderKind::Local => cfg!(feature = "local-tts"),
}
}
async fn try_tts(kind: TtsProviderKind, text: &str, cfg: &VoiceConfig) -> Result<Vec<u8>> {
match kind {
TtsProviderKind::Voicebox => {
tracing::info!(
"TTS dispatch → Voicebox (base_url={}, profile_id={}, engine={})",
cfg.voicebox_tts_base_url,
cfg.voicebox_tts_profile_id,
cfg.voicebox_tts_engine
);
let client = super::voicebox_tts::VoiceboxTts::new(
&cfg.voicebox_tts_base_url,
&cfg.voicebox_tts_profile_id,
&cfg.voicebox_tts_engine,
);
client.synthesize(text).await
}
TtsProviderKind::OpenAiCompatible => {
let base_url = cfg
.tts_base_url
.as_ref()
.ok_or_else(|| anyhow::anyhow!("openai-compatible TTS missing base_url"))?;
let api_key = cfg
.tts_api_key
.as_ref()
.ok_or_else(|| anyhow::anyhow!("openai-compatible TTS missing api_key"))?;
tracing::info!(
"TTS dispatch → OpenAI-compatible (base_url={}, model={}, voice={})",
base_url,
cfg.tts_model,
cfg.tts_voice
);
super::openai_tts::synthesize_speech(
text,
api_key,
&cfg.tts_voice,
&cfg.tts_model,
base_url,
)
.await
}
TtsProviderKind::OpenAi => {
let api_key = cfg
.tts_provider
.as_ref()
.and_then(|p| p.api_key.as_ref())
.ok_or_else(|| anyhow::anyhow!("OpenAI TTS missing api_key"))?;
tracing::info!(
"TTS dispatch → OpenAI (model={}, voice={})",
cfg.tts_model,
cfg.tts_voice
);
synthesize_speech(text, api_key, &cfg.tts_voice, &cfg.tts_model).await
}
TtsProviderKind::Local => {
#[cfg(feature = "local-tts")]
{
let voice_id = cfg.local_tts_voice.clone();
tracing::info!("TTS dispatch → Local Piper (voice={})", voice_id);
synthesize_speech_local(text, &voice_id).await
}
#[cfg(not(feature = "local-tts"))]
{
let _ = text;
let _ = cfg;
anyhow::bail!("Local TTS feature not compiled in")
}
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TtsProviderKind {
Voicebox,
OpenAiCompatible,
OpenAi,
Local,
}
impl TtsProviderKind {
pub(crate) fn label(self) -> String {
match self {
TtsProviderKind::Voicebox => "voicebox".to_string(),
TtsProviderKind::OpenAiCompatible => "openai_compatible".to_string(),
TtsProviderKind::OpenAi => "openai".to_string(),
TtsProviderKind::Local => "local".to_string(),
}
}
pub(crate) fn from_label(s: &str) -> Option<Self> {
match s.trim().to_lowercase().as_str() {
"voicebox" => Some(TtsProviderKind::Voicebox),
"openai_compatible" | "openai-compatible" => Some(TtsProviderKind::OpenAiCompatible),
"openai" => Some(TtsProviderKind::OpenAi),
"local" | "piper" | "local_piper" => Some(TtsProviderKind::Local),
_ => None,
}
}
}
fn is_opus(audio: &[u8]) -> bool {
audio.starts_with(b"OggS")
}
async fn ensure_opus(audio_bytes: Vec<u8>) -> Vec<u8> {
if is_opus(&audio_bytes) {
return audio_bytes;
}
let audio_len = audio_bytes.len();
let audio_fallback = audio_bytes.clone();
let handle = tokio::task::spawn_blocking(move || {
let ffmpeg_path = std::env::var("FFMPEG_PATH").unwrap_or_else(|_| "ffmpeg".to_string());
let mut cmd = std::process::Command::new(&ffmpeg_path);
cmd.args([
"-i",
"pipe:0",
"-f",
"ogg",
"-c:a",
"libopus",
"-b:a",
"48k",
"-application",
"voip",
"-y",
"pipe:1",
])
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
if let Ok(mut child) = cmd.spawn() {
use std::io::Write;
let write_ok = child
.stdin
.take()
.map(|mut stdin| stdin.write_all(&audio_bytes).is_ok())
.unwrap_or(false);
if write_ok {
if let Ok(output) = child.wait_with_output() {
if output.status.success() && !output.stdout.is_empty() {
tracing::info!(
"TTS: voip opus conversion ok ({} → {} bytes)",
audio_bytes.len(),
output.stdout.len()
);
return output.stdout;
}
let stderr = String::from_utf8_lossy(&output.stderr);
tracing::warn!(
"TTS: voip opus failed ({}): {}",
output.status,
stderr.lines().next().unwrap_or("")
);
}
} else {
tracing::warn!("TTS: failed to write audio to ffmpeg stage 1");
}
}
tracing::info!("TTS: trying basic opus fallback");
let mut cmd2 = std::process::Command::new(&ffmpeg_path);
cmd2.args([
"-i", "pipe:0", "-f", "ogg", "-c:a", "libopus", "-b:a", "64k", "-y", "pipe:1",
])
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.stderr(std::process::Stdio::piped());
if let Ok(mut child) = cmd2.spawn() {
use std::io::Write;
let write_ok = child
.stdin
.take()
.map(|mut stdin| stdin.write_all(&audio_bytes).is_ok())
.unwrap_or(false);
if write_ok {
if let Ok(output) = child.wait_with_output() {
if output.status.success() && !output.stdout.is_empty() {
tracing::info!(
"TTS: basic opus fallback ok ({} → {} bytes)",
audio_bytes.len(),
output.stdout.len()
);
return output.stdout;
}
let stderr = String::from_utf8_lossy(&output.stderr);
tracing::warn!(
"TTS: basic opus fallback failed ({}): {}",
output.status,
stderr.lines().next().unwrap_or("")
);
}
} else {
tracing::warn!("TTS: failed to write audio to ffmpeg stage 2");
}
} else {
tracing::warn!("TTS: ffmpeg not found at '{}'", ffmpeg_path);
}
audio_bytes
});
match tokio::time::timeout(std::time::Duration::from_secs(300), handle).await {
Ok(result) => result.unwrap_or(audio_fallback),
Err(_) => {
tracing::warn!(
"TTS: ffmpeg conversion timed out after 5m ({} bytes), sending original audio",
audio_len
);
audio_fallback
}
}
}
#[cfg(feature = "local-tts")]
async fn synthesize_speech_local(text: &str, voice_id: &str) -> Result<Vec<u8>> {
let text = text.to_string();
let voice_id = voice_id.to_string();
let handle = tokio::task::spawn_blocking(move || {
let engine = super::local_tts::PiperTts::new(&voice_id)?;
engine.synthesize_opus(&text)
});
match tokio::time::timeout(std::time::Duration::from_secs(120), handle).await {
Ok(Ok(result)) => result,
Ok(Err(e)) => anyhow::bail!("Local TTS task panicked: {}", e),
Err(_) => anyhow::bail!("Local TTS timed out after 120s"),
}
}
pub async fn synthesize_speech(
text: &str,
openai_api_key: &str,
voice: &str,
model: &str,
) -> Result<Vec<u8>> {
synthesize_speech_with_url(text, openai_api_key, voice, model, OPENAI_SPEECH_URL).await
}
pub(crate) async fn synthesize_speech_with_url(
text: &str,
api_key: &str,
voice: &str,
model: &str,
url: &str,
) -> Result<Vec<u8>> {
let client = Client::new();
let body = serde_json::json!({
"model": model,
"input": text,
"voice": voice,
"response_format": "opus",
});
let response = client
.post(url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Failed to send TTS request to OpenAI")?;
let status = response.status();
if !status.is_success() {
let error_text = response.text().await.unwrap_or_default();
anyhow::bail!("OpenAI TTS error ({}): {}", status, error_text);
}
let audio_bytes = response
.bytes()
.await
.context("Failed to read TTS audio bytes")?
.to_vec();
tracing::info!(
"OpenAI TTS: generated {} bytes of audio (voice={}, model={})",
audio_bytes.len(),
voice,
model,
);
Ok(audio_bytes)
}
#[derive(Deserialize)]
pub(crate) struct TranscriptionResponse {
pub(crate) text: String,
}
#[cfg(feature = "local-stt")]
static CACHED_WHISPER: tokio::sync::OnceCell<super::local_whisper::LocalWhisper> =
tokio::sync::OnceCell::const_new();
#[cfg(feature = "local-stt")]
pub async fn preload_local_whisper(model_id: &str) -> Result<()> {
CACHED_WHISPER
.get_or_try_init(|| async {
let preset = super::local_whisper::find_local_model(model_id)
.ok_or_else(|| anyhow::anyhow!("Unknown local STT model: {}", model_id))?;
super::local_whisper::LocalWhisper::new(preset).await
})
.await?;
Ok(())
}
#[cfg(feature = "local-stt")]
pub async fn transcribe_audio_local(audio_bytes: Vec<u8>, model_id: String) -> Result<String> {
let len = audio_bytes.len();
tracing::info!(
"Local STT: starting transcription ({} bytes audio, model {})",
len,
model_id
);
let whisper = CACHED_WHISPER
.get_or_try_init(|| async {
let preset = super::local_whisper::find_local_model(&model_id)
.ok_or_else(|| anyhow::anyhow!("Unknown local STT model: {}", model_id))?;
super::local_whisper::LocalWhisper::new(preset).await
})
.await?;
match tokio::time::timeout(
std::time::Duration::from_secs(300),
whisper.transcribe(&audio_bytes),
)
.await
{
Ok(Ok(text)) => {
tracing::info!("Local STT: transcription complete");
Ok(text)
}
Ok(Err(e)) => {
tracing::error!("Local STT: transcription failed: {}", e);
Err(e)
}
Err(_) => {
tracing::error!("Local STT: transcription timed out after 300s");
anyhow::bail!("Local STT transcription timed out (300s)")
}
}
}