#![cfg(all(feature = "stt-cloud-wasm", feature = "stt-candle"))]
use std::borrow::Cow;
use std::path::PathBuf;
use std::sync::Arc;
use async_trait::async_trait;
use super::SttProvider;
use crate::stt::{transcribe_candle, SttError, TranscribeConfig};
pub struct LocalCandleProvider {
cfg: Arc<TranscribeConfig>,
}
impl std::fmt::Debug for LocalCandleProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LocalCandleProvider")
.field("model_path", &self.cfg.model_path)
.field("model_id", &self.cfg.model_id)
.field("lang_hint", &self.cfg.lang_hint)
.finish()
}
}
impl LocalCandleProvider {
pub fn new(cfg: Arc<TranscribeConfig>) -> Self {
Self { cfg }
}
}
/// Chooses the temp-file extension for an incoming audio MIME type.
///
/// Only Ogg payloads are recognised: anything beginning with `audio/ogg`
/// (so `audio/ogg; codecs=opus` qualifies) or `audio/opus`. The comparison
/// is ASCII case-insensitive. Every other type, including the empty string,
/// yields `None`.
fn extension_for_mime(mime: &str) -> Option<&'static str> {
    const OGG_PREFIXES: [&str; 2] = ["audio/ogg", "audio/opus"];

    let bytes = mime.as_bytes();
    let is_ogg = OGG_PREFIXES.iter().any(|prefix| {
        // Compare raw bytes so no lowercase copy of `mime` is allocated.
        matches!(
            bytes.get(..prefix.len()),
            Some(head) if head.eq_ignore_ascii_case(prefix.as_bytes())
        )
    });

    is_ogg.then_some("ogg")
}
#[async_trait]
impl SttProvider for LocalCandleProvider {
    /// Transcribes `audio_bytes` with the local candle pipeline.
    ///
    /// `transcribe_candle::transcribe_file` takes a filesystem path, so the
    /// audio is first written to a named temp file. The file's extension comes
    /// from [`extension_for_mime`]. The temp-file guard is held until
    /// transcription returns, so the path stays valid for the whole call.
    ///
    /// `lang_hint` selects the config used for this call:
    /// - If it is `Some` and differs from the configured hint, a per-call copy
    ///   of the config carries the caller's hint.
    /// - Otherwise the shared config is borrowed as-is, with no clone.
    ///
    /// # Errors
    ///
    /// - [`SttError::EmptyAudio`] if `audio_bytes` is empty. This is checked
    ///   before any filesystem work.
    /// - [`SttError::UnsupportedFormat`] if `audio_mime` is not an Ogg/Opus type.
    /// - [`SttError::Io`] if the temp file cannot be created. Errors from writing
    ///   it are converted with `?`.
    /// - Any error propagated from `transcribe_candle::transcribe_file`.
    async fn transcribe(
        &self,
        audio_bytes: Vec<u8>,
        audio_mime: &str,
        lang_hint: Option<&str>,
    ) -> Result<String, SttError> {
        // Reject unusable input before touching the filesystem.
        if audio_bytes.is_empty() {
            return Err(SttError::EmptyAudio);
        }
        let ext = extension_for_mime(audio_mime).ok_or_else(|| {
            SttError::UnsupportedFormat(format!(
                "LocalCandleProvider only decodes ogg-opus (got {audio_mime:?}); \
                 transcode upstream or route through a cloud leg that accepts \
                 the source format"
            ))
        })?;

        // `tempfile` must stay bound until the end of this function: the
        // `NamedTempFile` guard removes the file when dropped, and
        // `transcribe_file` reads it by `path` below.
        let tempfile = tempfile::Builder::new()
            .prefix("nexo-stt-candle-")
            .suffix(&format!(".{ext}"))
            .tempfile()
            .map_err(SttError::Io)?;
        let path: PathBuf = tempfile.path().to_path_buf();
        // `audio_bytes` is not needed afterwards, so hand the buffer over by value.
        tokio::fs::write(&path, audio_bytes).await?;

        // Compare the borrowed hint directly instead of allocating a `String`
        // just for the comparison. The config is cloned only when the caller
        // actually overrides the configured hint.
        let cfg = match lang_hint {
            Some(hint) if self.cfg.lang_hint.as_deref() != Some(hint) => {
                let mut per_call = (*self.cfg).clone();
                per_call.lang_hint = Some(hint.to_owned());
                Cow::Owned(per_call)
            }
            _ => Cow::Borrowed(self.cfg.as_ref()),
        };

        let text = transcribe_candle::transcribe_file(&path, cfg.as_ref()).await?;
        Ok(text)
    }

    /// Stable identifier for this provider: `"local-candle"`.
    fn name(&self) -> &'static str {
        "local-candle"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Provider over a default config. None of the tests below get as far as
    /// `transcribe_file`: each one returns before any model work.
    fn default_provider() -> LocalCandleProvider {
        LocalCandleProvider::new(Arc::new(TranscribeConfig::default()))
    }

    /// Runs `transcribe` with no language hint and returns the error,
    /// panicking if the call unexpectedly succeeds.
    async fn transcribe_err(audio: Vec<u8>, mime: &str) -> SttError {
        match default_provider().transcribe(audio, mime, None).await {
            Err(e) => e,
            Ok(t) => panic!("expected error, got {t:?}"),
        }
    }

    #[test]
    fn name_is_stable() {
        assert_eq!(default_provider().name(), "local-candle");
    }

    #[test]
    fn debug_surfaces_cfg_state() {
        let provider = LocalCandleProvider::new(Arc::new(TranscribeConfig {
            lang_hint: Some("es".into()),
            model_id: Some("openai/whisper-tiny".into()),
            ..Default::default()
        }));
        let rendered = format!("{provider:?}");
        for needle in ["LocalCandleProvider", "\"es\"", "openai/whisper-tiny"] {
            assert!(rendered.contains(needle), "{needle:?} missing from {rendered}");
        }
    }

    #[tokio::test]
    async fn empty_audio_rejected_before_tempfile() {
        let err = transcribe_err(vec![], "audio/ogg").await;
        assert!(matches!(err, SttError::EmptyAudio));
    }

    #[tokio::test]
    async fn unsupported_mime_rejected() {
        let err = transcribe_err(vec![1, 2, 3], "audio/mp3").await;
        assert!(matches!(err, SttError::UnsupportedFormat(_)));
    }

    #[test]
    fn extension_picker_matrix() {
        let cases: [(&str, Option<&str>); 7] = [
            ("audio/ogg", Some("ogg")),
            ("audio/ogg; codecs=opus", Some("ogg")),
            ("AUDIO/OGG", Some("ogg")),
            ("audio/opus", Some("ogg")),
            ("audio/mp3", None),
            ("audio/L16; rate=16000", None),
            ("", None),
        ];
        for &(mime, expected) in &cases {
            assert_eq!(extension_for_mime(mime), expected, "mime {mime:?}");
        }
    }
}