#![cfg(feature = "stt-cloud-wasm")]
#![allow(missing_docs)]
use std::fmt;
use async_trait::async_trait;
use serde::Deserialize;
use super::SttError;
pub mod anthropic;
pub mod groq;
pub mod openai;
#[cfg(feature = "stt-candle")]
pub mod local_candle;
/// Speech-to-text provider contract (native targets).
///
/// Implementations receive a raw audio buffer plus its MIME type and an
/// optional language hint, and return the transcript text on success.
/// `Send + Sync` bounds let a provider chain be shared across threads/tasks.
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
pub trait SttProvider: Send + Sync + fmt::Debug {
    /// Transcribes `audio_bytes`, whose encoding is described by `audio_mime`
    /// (e.g. "audio/ogg"). `lang_hint` is an optional language hint such as
    /// "es" or "es-AR"; `None` leaves language detection to the provider.
    async fn transcribe(
        &self,
        audio_bytes: Vec<u8>,
        audio_mime: &str,
        lang_hint: Option<&str>,
    ) -> Result<String, SttError>;

    /// Short static identifier used in log fields (e.g. "composite").
    fn name(&self) -> &'static str;
}
/// Speech-to-text provider contract (wasm32 build).
///
/// Mirrors the native trait above but drops the `Send + Sync` bounds and uses
/// `#[async_trait(?Send)]`, since wasm futures are single-threaded.
///
/// Note: the `#[cfg]` attribute is placed first, matching the native trait's
/// attribute order, so the item is configured out before macro expansion.
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
pub trait SttProvider: fmt::Debug {
    /// Transcribes `audio_bytes`, whose encoding is described by `audio_mime`.
    /// `lang_hint` is an optional language hint; `None` leaves detection to
    /// the provider.
    async fn transcribe(
        &self,
        audio_bytes: Vec<u8>,
        audio_mime: &str,
        lang_hint: Option<&str>,
    ) -> Result<String, SttError>;

    /// Short static identifier used in log fields.
    fn name(&self) -> &'static str;
}
/// Minimal response shape shared by OpenAI-compatible transcription
/// endpoints; only the `text` field is consumed.
#[derive(Deserialize)]
struct OpenAiCompatibleTranscription {
    /// Full transcript as returned by the service (untrimmed).
    text: String,
}
/// POSTs the audio as a `multipart/form-data` body to an OpenAI-compatible
/// transcription endpoint and returns the trimmed transcript text.
///
/// Errors are reported as [`SttError::Whisper`] for transport / HTTP-status /
/// JSON failures, and [`SttError::EmptyTranscript`] when the service returns
/// a blank transcript.
async fn post_openai_compatible(
    endpoint: &str,
    api_key: &str,
    model: &str,
    audio_bytes: Vec<u8>,
    audio_mime: &str,
    lang_hint: Option<&str>,
) -> Result<String, SttError> {
    // Random boundary so the audio payload cannot collide with the framing.
    let boundary = format!("nexo-stt-{}", uuid::Uuid::new_v4().simple());
    let multipart = build_openai_multipart_body(&boundary, model, audio_bytes, audio_mime, lang_hint);

    let response = reqwest::Client::new()
        .post(endpoint)
        .bearer_auth(api_key)
        .header(
            "content-type",
            format!("multipart/form-data; boundary={boundary}"),
        )
        .body(multipart)
        .send()
        .await
        .map_err(|e| SttError::Whisper(format!("cloud HTTP send: {e}")))?;

    let status = response.status();
    if !status.is_success() {
        // Include the response body (best effort) to aid debugging.
        let payload = response.text().await.unwrap_or_default();
        return Err(SttError::Whisper(format!(
            "cloud STT HTTP {status}: {payload}"
        )));
    }

    let parsed: OpenAiCompatibleTranscription = response
        .json()
        .await
        .map_err(|e| SttError::Whisper(format!("cloud JSON parse: {e}")))?;

    // A whitespace-only transcript is treated the same as no transcript.
    match parsed.text.trim() {
        "" => Err(SttError::EmptyTranscript),
        text => Ok(text.to_string()),
    }
}
/// Builds a `multipart/form-data` body for an OpenAI-compatible
/// transcription request.
///
/// Text fields are `model`, `response_format` (always "json"), and an
/// optional `language` field derived from `lang_hint`: empty hints and the
/// sentinel "auto" are dropped, and BCP-47 style hints are reduced to the
/// lowercase primary subtag ("es-AR" -> "es"). The audio bytes are appended
/// verbatim as the `file` field, followed by the closing boundary.
fn build_openai_multipart_body(
    boundary: &str,
    model: &str,
    audio_bytes: Vec<u8>,
    audio_mime: &str,
    lang_hint: Option<&str>,
) -> Vec<u8> {
    // All text fields share one shape; render them into a string prefix first.
    let mut head = String::new();
    let mut text_part = |name: &str, value: &str| {
        head.push_str("--");
        head.push_str(boundary);
        head.push_str("\r\nContent-Disposition: form-data; name=\"");
        head.push_str(name);
        head.push_str("\"\r\n\r\n");
        head.push_str(value);
        head.push_str("\r\n");
    };

    text_part("model", model);
    text_part("response_format", "json");
    if let Some(lang) = lang_hint.filter(|l| !l.is_empty() && *l != "auto") {
        let base = lang
            .split(|c: char| matches!(c, '-' | '_'))
            .next()
            .unwrap_or(lang);
        text_part("language", &base.to_lowercase());
    }

    // File part: textual header, the raw audio, then the closing boundary.
    let mut body = Vec::with_capacity(head.len() + audio_bytes.len() + 256);
    body.extend_from_slice(head.as_bytes());
    body.extend_from_slice(
        format!(
            "--{boundary}\r\nContent-Disposition: form-data; name=\"file\"; filename=\"audio\"\r\nContent-Type: {audio_mime}\r\n\r\n"
        )
        .as_bytes(),
    );
    body.extend_from_slice(&audio_bytes);
    body.extend_from_slice(format!("\r\n--{boundary}--\r\n").as_bytes());
    body
}
/// Ordered fail-over chain of [`SttProvider`] legs: earlier legs are tried
/// first, later legs act as fallbacks.
#[derive(Debug)]
pub struct CompositeProvider {
    providers: Vec<Box<dyn SttProvider>>,
}

impl CompositeProvider {
    /// Appends `provider` as the lowest-priority fallback leg.
    pub fn push(&mut self, provider: Box<dyn SttProvider>) {
        self.providers.push(provider);
    }

    /// Builds a chain from an already-ordered list of legs.
    pub fn new(providers: Vec<Box<dyn SttProvider>>) -> Self {
        CompositeProvider { providers }
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl SttProvider for CompositeProvider {
    /// Runs the legs in order and returns the first successful transcript.
    ///
    /// Input-shaped errors (`EmptyAudio`, `UnsupportedFormat`, `Decode`)
    /// short-circuit immediately — no other provider can fix bad input.
    /// Transport/service errors fall through to the next leg; if every leg
    /// fails, the last such error is returned.
    async fn transcribe(
        &self,
        audio_bytes: Vec<u8>,
        audio_mime: &str,
        lang_hint: Option<&str>,
    ) -> Result<String, SttError> {
        if self.providers.is_empty() {
            return Err(SttError::Whisper(
                "CompositeProvider has no legs configured — set at least one provider".into(),
            ));
        }
        let legs = self.providers.len();
        // Hold the audio in an Option so the final leg can take ownership
        // instead of paying for one more full-buffer clone.
        let mut audio = Some(audio_bytes);
        let mut last_err: Option<SttError> = None;
        for (idx, provider) in self.providers.iter().enumerate() {
            let bytes = if idx + 1 == legs {
                audio.take().expect("audio buffer present for final leg")
            } else {
                audio
                    .as_ref()
                    .expect("audio buffer present for non-final leg")
                    .clone()
            };
            match provider.transcribe(bytes, audio_mime, lang_hint).await {
                Ok(text) => {
                    tracing::info!(
                        target: "stt.cloud",
                        provider = provider.name(),
                        leg = idx,
                        transcript_len = text.len(),
                        "composite STT transcription ok"
                    );
                    return Ok(text);
                }
                Err(err) => {
                    // Bad input cannot be fixed by retrying elsewhere.
                    if matches!(
                        &err,
                        SttError::EmptyAudio | SttError::UnsupportedFormat(_) | SttError::Decode(_)
                    ) {
                        return Err(err);
                    }
                    tracing::warn!(
                        target: "stt.cloud",
                        provider = provider.name(),
                        leg = idx,
                        error = %err,
                        "composite STT leg failed, trying next"
                    );
                    last_err = Some(err);
                }
            }
        }
        Err(last_err.unwrap_or_else(|| {
            SttError::Whisper("CompositeProvider exhausted every leg without success".into())
        }))
    }

    fn name(&self) -> &'static str {
        "composite"
    }
}
/// Reads an audio file from disk and dispatches it through `chain`, deriving
/// the MIME type from the file extension.
///
/// Returns [`SttError::Io`] if the file cannot be read and
/// [`SttError::EmptyAudio`] for a zero-byte file (no provider is invoked).
#[cfg(not(target_arch = "wasm32"))]
pub async fn transcribe_file_with_chain(
    path: &std::path::Path,
    chain: &dyn SttProvider,
    lang_hint: Option<&str>,
) -> Result<String, SttError> {
    let audio = tokio::fs::read(path).await.map_err(SttError::Io)?;
    if audio.is_empty() {
        Err(SttError::EmptyAudio)
    } else {
        chain.transcribe(audio, mime_from_path(path), lang_hint).await
    }
}
/// Maps a file's extension (case-insensitively) to its audio MIME type.
///
/// Unknown or missing extensions fall back to `application/octet-stream`.
#[cfg(not(target_arch = "wasm32"))]
pub fn mime_from_path(path: &std::path::Path) -> &'static str {
    const FALLBACK: &str = "application/octet-stream";
    // Extension groups and the MIME type each group maps to.
    const TABLE: &[(&[&str], &str)] = &[
        (&["ogg", "oga", "opus"], "audio/ogg"),
        (&["wav"], "audio/wav"),
        (&["mp3"], "audio/mpeg"),
        (&["m4a", "mp4", "aac"], "audio/mp4"),
        (&["flac"], "audio/flac"),
        (&["webm"], "audio/webm"),
    ];
    let ext = match path.extension().and_then(|s| s.to_str()) {
        Some(e) => e.to_ascii_lowercase(),
        None => return FALLBACK,
    };
    TABLE
        .iter()
        .find(|(exts, _)| exts.contains(&ext.as_str()))
        .map_or(FALLBACK, |(_, mime)| *mime)
}
/// Convenience chain: Anthropic as the primary leg, local Candle inference
/// as the offline fallback.
// NOTE(review): the `local_candle` module is gated on feature "stt-candle"
// (top of file) while this constructor is gated on "stt-cloud-local-candle" —
// confirm Cargo.toml makes the latter imply the former.
#[cfg(all(feature = "stt-cloud-anthropic", feature = "stt-cloud-local-candle"))]
pub fn anthropic_then_candle(
    anthropic_token: impl Into<String>,
    candle_cfg: std::sync::Arc<crate::stt::TranscribeConfig>,
) -> CompositeProvider {
    let mut chain = CompositeProvider::new(Vec::new());
    chain.push(Box::new(anthropic::AnthropicVoiceStream::new(anthropic_token)));
    chain.push(Box::new(local_candle::LocalCandleProvider::new(candle_cfg)));
    chain
}
/// Convenience chain: OpenAI as the primary leg, local Candle inference as
/// the offline fallback.
// NOTE(review): the `local_candle` module is gated on feature "stt-candle"
// (top of file) while this constructor is gated on "stt-cloud-local-candle" —
// confirm Cargo.toml makes the latter imply the former.
#[cfg(feature = "stt-cloud-local-candle")]
pub fn openai_then_candle(
    api_key: impl Into<String>,
    candle_cfg: std::sync::Arc<crate::stt::TranscribeConfig>,
) -> CompositeProvider {
    let mut chain = CompositeProvider::new(Vec::new());
    chain.push(Box::new(openai::OpenAiProvider::new(api_key)));
    chain.push(Box::new(local_candle::LocalCandleProvider::new(candle_cfg)));
    chain
}
/// Convenience chain: Groq as the primary leg, local Candle inference as the
/// offline fallback.
// NOTE(review): the `local_candle` module is gated on feature "stt-candle"
// (top of file) while this constructor is gated on "stt-cloud-local-candle" —
// confirm Cargo.toml makes the latter imply the former.
#[cfg(feature = "stt-cloud-local-candle")]
pub fn groq_then_candle(
    api_key: impl Into<String>,
    candle_cfg: std::sync::Arc<crate::stt::TranscribeConfig>,
) -> CompositeProvider {
    let mut chain = CompositeProvider::new(Vec::new());
    chain.push(Box::new(groq::GroqProvider::new(api_key)));
    chain.push(Box::new(local_candle::LocalCandleProvider::new(candle_cfg)));
    chain
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scriptable test double: each instance returns whatever its `outcome`
    /// closure produces, letting tests simulate success or specific errors.
    struct StubProvider {
        // Closure invoked on every transcribe call.
        outcome: Box<dyn Fn() -> Result<String, SttError> + Send + Sync>,
        // Reported via SttProvider::name; tests reuse the transcript text.
        name: &'static str,
    }

    // Hand-written Debug because the boxed closure cannot derive it.
    impl fmt::Debug for StubProvider {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("StubProvider")
                .field("name", &self.name)
                .field("outcome", &"<dyn Fn>")
                .finish()
        }
    }

    #[async_trait]
    impl SttProvider for StubProvider {
        async fn transcribe(
            &self,
            _audio: Vec<u8>,
            _mime: &str,
            _lang: Option<&str>,
        ) -> Result<String, SttError> {
            (self.outcome)()
        }
        fn name(&self) -> &'static str {
            self.name
        }
    }

    // Leg that always succeeds, echoing `text` as the transcript.
    fn ok_leg(text: &'static str) -> Box<dyn SttProvider> {
        Box::new(StubProvider {
            outcome: Box::new(move || Ok(text.to_string())),
            name: text,
        })
    }

    // Leg that always fails with a retryable transport-style error.
    fn whisper_err_leg(name: &'static str) -> Box<dyn SttProvider> {
        Box::new(StubProvider {
            outcome: Box::new(|| Err(SttError::Whisper("transport blow-up".into()))),
            name,
        })
    }

    // Leg that always fails with a non-retryable input error.
    fn decode_err_leg(name: &'static str) -> Box<dyn SttProvider> {
        Box::new(StubProvider {
            outcome: Box::new(|| Err(SttError::Decode("bad audio".into()))),
            name,
        })
    }

    /// First leg succeeding returns immediately; the backup is never used.
    #[tokio::test]
    async fn composite_first_leg_ok_returns_immediately() {
        let chain = CompositeProvider::new(vec![ok_leg("primary"), ok_leg("backup")]);
        let out = chain
            .transcribe(vec![1, 2, 3], "audio/ogg", None)
            .await
            .unwrap();
        assert_eq!(out, "primary");
    }

    /// A Whisper (transport) failure on leg 0 falls through to leg 1.
    #[tokio::test]
    async fn composite_transport_failure_falls_through_to_next() {
        let chain = CompositeProvider::new(vec![whisper_err_leg("primary"), ok_leg("backup")]);
        let out = chain
            .transcribe(vec![1, 2, 3], "audio/ogg", None)
            .await
            .unwrap();
        assert_eq!(out, "backup");
    }

    /// A decode error is input-shaped: it aborts the chain rather than retry.
    #[tokio::test]
    async fn composite_decode_failure_short_circuits() {
        let chain = CompositeProvider::new(vec![decode_err_leg("primary"), ok_leg("backup")]);
        let err = match chain.transcribe(vec![1, 2, 3], "audio/ogg", None).await {
            Ok(t) => panic!("expected error, got {t:?}"),
            Err(e) => e,
        };
        assert!(matches!(err, SttError::Decode(_)));
    }

    /// Zero configured legs surfaces a configuration error (Whisper variant).
    #[tokio::test]
    async fn composite_empty_chain_errors() {
        let chain = CompositeProvider::new(vec![]);
        let err = match chain.transcribe(vec![], "audio/ogg", None).await {
            Ok(_) => panic!("expected error"),
            Err(e) => e,
        };
        assert!(matches!(err, SttError::Whisper(_)));
    }

    /// When every leg fails retryably, the last leg's error is returned.
    #[tokio::test]
    async fn composite_all_legs_fail_returns_last_error() {
        let chain =
            CompositeProvider::new(vec![whisper_err_leg("primary"), whisper_err_leg("backup")]);
        let err = match chain.transcribe(vec![1], "audio/ogg", None).await {
            Ok(_) => panic!("expected error"),
            Err(e) => e,
        };
        assert!(matches!(err, SttError::Whisper(_)));
    }

    /// Extension-to-MIME matrix, including case-insensitivity and fallbacks.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn mime_from_path_matrix() {
        use std::path::Path;
        assert_eq!(mime_from_path(Path::new("voice.ogg")), "audio/ogg");
        assert_eq!(mime_from_path(Path::new("voice.OGG")), "audio/ogg");
        assert_eq!(mime_from_path(Path::new("voice.opus")), "audio/ogg");
        assert_eq!(mime_from_path(Path::new("voice.oga")), "audio/ogg");
        assert_eq!(mime_from_path(Path::new("voice.wav")), "audio/wav");
        assert_eq!(mime_from_path(Path::new("voice.mp3")), "audio/mpeg");
        assert_eq!(mime_from_path(Path::new("voice.m4a")), "audio/mp4");
        assert_eq!(mime_from_path(Path::new("voice.mp4")), "audio/mp4");
        assert_eq!(mime_from_path(Path::new("voice.aac")), "audio/mp4");
        assert_eq!(mime_from_path(Path::new("voice.flac")), "audio/flac");
        assert_eq!(mime_from_path(Path::new("voice.webm")), "audio/webm");
        assert_eq!(
            mime_from_path(Path::new("voice.unknown")),
            "application/octet-stream"
        );
        assert_eq!(
            mime_from_path(Path::new("noext")),
            "application/octet-stream"
        );
        // Only the final extension component counts.
        assert_eq!(mime_from_path(Path::new("/tmp/a.b.c.ogg")), "audio/ogg");
    }

    /// End-to-end file path: bytes, MIME, and lang hint reach the provider.
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn transcribe_file_with_chain_reads_bytes_and_dispatches() {
        use std::sync::{Arc, Mutex};
        // Provider that records exactly what it was called with.
        struct CapturingProvider {
            captured: Arc<Mutex<Option<(Vec<u8>, String, Option<String>)>>>,
        }
        impl fmt::Debug for CapturingProvider {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "CapturingProvider")
            }
        }
        #[async_trait]
        impl SttProvider for CapturingProvider {
            async fn transcribe(
                &self,
                audio: Vec<u8>,
                mime: &str,
                lang: Option<&str>,
            ) -> Result<String, SttError> {
                *self.captured.lock().unwrap() =
                    Some((audio, mime.to_string(), lang.map(str::to_string)));
                Ok("captured".into())
            }
            fn name(&self) -> &'static str {
                "capturing"
            }
        }
        // Real temp file on disk with an .ogg suffix so MIME detection fires.
        let tmp = tempfile::Builder::new()
            .prefix("stt-chain-test-")
            .suffix(".ogg")
            .tempfile()
            .unwrap();
        let payload = b"OggS\0\xdeadbeef".to_vec();
        tokio::fs::write(tmp.path(), &payload).await.unwrap();
        let captured = Arc::new(Mutex::new(None));
        let provider = CapturingProvider {
            captured: captured.clone(),
        };
        let result = transcribe_file_with_chain(tmp.path(), &provider, Some("es"))
            .await
            .unwrap();
        assert_eq!(result, "captured");
        let captured = captured.lock().unwrap().clone().unwrap();
        assert_eq!(captured.0, payload);
        assert_eq!(captured.1, "audio/ogg");
        assert_eq!(captured.2.as_deref(), Some("es"));
    }

    /// A zero-byte file must be rejected before any provider runs.
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn transcribe_file_with_chain_rejects_empty_file() {
        let tmp = tempfile::Builder::new()
            .prefix("stt-empty-")
            .suffix(".ogg")
            .tempfile()
            .unwrap();
        let chain = CompositeProvider::new(vec![ok_leg("never-called")]);
        let err = transcribe_file_with_chain(tmp.path(), &chain, None)
            .await
            .expect_err("empty file should error");
        assert!(matches!(err, SttError::EmptyAudio));
    }

    /// A missing file surfaces the underlying I/O error.
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn transcribe_file_with_chain_io_error_on_missing_path() {
        use std::path::PathBuf;
        let nonexistent = PathBuf::from("/nonexistent/path/voice.ogg");
        let chain = CompositeProvider::new(vec![ok_leg("never-called")]);
        let err = transcribe_file_with_chain(&nonexistent, &chain, None)
            .await
            .expect_err("missing file should error");
        assert!(matches!(err, SttError::Io(_)));
    }

    /// Baseline multipart layout: three opening boundaries (model,
    /// response_format, file), one closing boundary, no language field.
    #[test]
    fn multipart_body_minimum_shape() {
        let body = build_openai_multipart_body(
            "TEST-BOUNDARY",
            "whisper-1",
            b"audio-payload".to_vec(),
            "audio/ogg",
            None,
        );
        let s = String::from_utf8(body).unwrap();
        assert_eq!(s.matches("--TEST-BOUNDARY\r\n").count(), 3);
        assert!(s.ends_with("--TEST-BOUNDARY--\r\n"));
        assert!(s.contains("name=\"model\"\r\n\r\nwhisper-1\r\n"));
        assert!(s.contains("name=\"response_format\"\r\n\r\njson\r\n"));
        // The `\` line continuations strip the newline and leading whitespace,
        // so this is one contiguous expected substring.
        assert!(s.contains(
            "name=\"file\"; filename=\"audio\"\r\n\
             Content-Type: audio/ogg\r\n\r\n\
             audio-payload\r\n"
        ));
        assert!(!s.contains("name=\"language\""));
    }

    /// A plain two-letter hint is forwarded as the `language` field.
    #[test]
    fn multipart_body_includes_language_when_hint_given() {
        let body = build_openai_multipart_body("X", "m", b"a".to_vec(), "audio/ogg", Some("es"));
        let s = String::from_utf8(body).unwrap();
        assert!(s.contains("name=\"language\"\r\n\r\nes\r\n"));
    }

    /// BCP-47 style hints are reduced to the lowercase primary subtag.
    #[test]
    fn multipart_body_strips_bcp47_region_subtag() {
        let body = build_openai_multipart_body("X", "m", b"a".to_vec(), "audio/ogg", Some("es-AR"));
        let s = String::from_utf8(body).unwrap();
        assert!(s.contains("name=\"language\"\r\n\r\nes\r\n"));
        assert!(!s.contains("es-AR"));
        assert!(!s.contains("AR"));
    }

    /// The sentinel "auto" means "let the service detect the language".
    #[test]
    fn multipart_body_omits_language_for_auto() {
        let body = build_openai_multipart_body("X", "m", b"a".to_vec(), "audio/ogg", Some("auto"));
        let s = String::from_utf8(body).unwrap();
        assert!(!s.contains("name=\"language\""));
    }

    /// An empty hint is treated the same as no hint at all.
    #[test]
    fn multipart_body_omits_language_for_empty_hint() {
        let body = build_openai_multipart_body("X", "m", b"a".to_vec(), "audio/ogg", Some(""));
        let s = String::from_utf8(body).unwrap();
        assert!(!s.contains("name=\"language\""));
    }

    /// Every byte value 0..=255 must survive the multipart framing verbatim.
    #[test]
    fn multipart_body_preserves_binary_audio_verbatim() {
        let audio: Vec<u8> = (0u8..=255).collect();
        let body = build_openai_multipart_body("X", "m", audio.clone(), "audio/L16", None);
        // Locate the end of the file-part header, then check the raw payload.
        let needle = b"Content-Type: audio/L16\r\n\r\n";
        let idx = body
            .windows(needle.len())
            .position(|w| w == needle)
            .expect("file field header must be present");
        let audio_start = idx + needle.len();
        let audio_end = audio_start + audio.len();
        assert_eq!(&body[audio_start..audio_end], audio.as_slice());
        // The payload must be followed by the CRLF that precedes the close.
        assert_eq!(&body[audio_end..audio_end + 2], b"\r\n");
    }
}