use crate::stt::SttProvider;
use crate::{Result, VoiceConfig, VoiceError};
use async_trait::async_trait;
/// ElevenLabs speech-to-text REST endpoint that `transcribe` POSTs the
/// multipart upload to.
const STT_ENDPOINT: &str = "https://api.elevenlabs.io/v1/speech-to-text";
/// Model identifier sent as the `model_id` form field of every request.
const STT_MODEL: &str = "scribe_v1";
/// Speech-to-text provider backed by the ElevenLabs HTTP API.
///
/// Construct it with [`ElevenLabsSttProvider::new`] or
/// [`ElevenLabsSttProvider::from_config`]; the `Debug` output never includes
/// the API key.
pub struct ElevenLabsSttProvider {
    // HTTP client used to issue the upload request.
    client: reqwest::Client,
    // Secret sent in the `xi-api-key` request header.
    api_key: String,
    // Configured language; `transcribe` derives the `language_code` form
    // field from it.
    language: String,
}
impl std::fmt::Debug for ElevenLabsSttProvider {
    /// Renders the provider for diagnostics without exposing the credential.
    ///
    /// The API key is always shown as `<redacted>`, and the HTTP client is
    /// deliberately left out; only the configured language is printed as-is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        let mut builder = f.debug_struct("ElevenLabsSttProvider");
        builder.field("api_key", &REDACTED);
        builder.field("language", &self.language);
        builder.finish()
    }
}
impl ElevenLabsSttProvider {
    /// Creates a provider from an explicit API key and language.
    ///
    /// Each provider owns a freshly created `reqwest::Client`. `language` is
    /// stored unchanged; `transcribe` derives the request's `language_code`
    /// from it.
    pub fn new(api_key: impl Into<String>, language: impl Into<String>) -> Self {
        Self {
            client: reqwest::Client::new(),
            api_key: api_key.into(),
            language: language.into(),
        }
    }

    /// Builds a provider from a [`VoiceConfig`].
    ///
    /// The API key is taken from `config.elevenlabs_api_key` when it is set and
    /// not blank, with surrounding whitespace trimmed. Otherwise it is resolved
    /// through `car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY")`.
    /// The language comes from `config.language`.
    ///
    /// # Errors
    ///
    /// Returns [`VoiceError::Config`] when no API key can be found from either
    /// source.
    pub fn from_config(config: &VoiceConfig) -> Result<Self> {
        // Trim first. A key pasted with a trailing newline would otherwise
        // become an invalid HTTP header value (newlines are not allowed).
        // A whitespace-only key then falls through to the env/keychain lookup.
        let configured_key = config
            .elevenlabs_api_key
            .as_deref()
            .map(str::trim)
            .filter(|k| !k.is_empty())
            .map(str::to_owned);
        let api_key = configured_key
            .or_else(|| car_secrets::resolve_env_or_keychain("ELEVENLABS_API_KEY"))
            .ok_or_else(|| {
                VoiceError::Config(
                    "ELEVENLABS_API_KEY not set; set the env var or store it via \
                     `car secrets put ELEVENLABS_API_KEY`"
                        .into(),
                )
            })?;
        // Delegate so both constructors build the struct the same way.
        Ok(Self::new(api_key, config.language.clone()))
    }
}
/// Maps a configured language to the `language_code` form value sent to
/// ElevenLabs.
///
/// Returns `None` for a blank value, so the request omits the field instead of
/// sending an empty code. `"en"` is rewritten to `"eng"`. Any other value is
/// passed through with surrounding whitespace trimmed.
fn elevenlabs_language_code(language: &str) -> Option<&str> {
    match language.trim() {
        "" => None,
        "en" => Some("eng"),
        other => Some(other),
    }
}

#[async_trait]
impl SttProvider for ElevenLabsSttProvider {
    /// Transcribes mono `f32` samples by uploading them to ElevenLabs as WAV.
    ///
    /// The audio is encoded with [`encode_wav`] and POSTed as multipart form
    /// data to the speech-to-text endpoint. The form carries `model_id`, an
    /// optional `language_code` and `file`. The API key is sent in the
    /// `xi-api-key` header.
    ///
    /// Returns the response's `text` field with surrounding whitespace trimmed.
    /// A response without a string `text` field yields an empty string.
    ///
    /// # Errors
    ///
    /// Returns [`VoiceError::Stt`] in any of these cases:
    /// - the multipart part cannot be built;
    /// - the HTTP request fails;
    /// - the API answers with a non-success status (the message includes the
    ///   status and response body);
    /// - the response body is not valid JSON.
    async fn transcribe(&self, samples: &[f32], sample_rate: u32) -> Result<String> {
        let wav_data = encode_wav(samples, sample_rate);
        let part = reqwest::multipart::Part::bytes(wav_data)
            .file_name("recording.wav")
            .mime_str("audio/wav")
            .map_err(|e| VoiceError::Stt(format!("mime: {e}")))?;

        let mut form = reqwest::multipart::Form::new().text("model_id", STT_MODEL);
        // Only send a language hint when one is configured; an empty
        // `language_code` is not a meaningful value to submit.
        if let Some(code) = elevenlabs_language_code(&self.language) {
            form = form.text("language_code", code.to_owned());
        }
        let form = form.part("file", part);

        let resp = self
            .client
            .post(STT_ENDPOINT)
            .header("xi-api-key", &self.api_key)
            .multipart(form)
            .send()
            .await
            .map_err(|e| VoiceError::Stt(format!("http: {e}")))?;

        let status = resp.status();
        if !status.is_success() {
            // Reading the body is best-effort, so a read failure still
            // reports the status.
            let body = resp.text().await.unwrap_or_default();
            return Err(VoiceError::Stt(format!("API {status}: {body}")));
        }

        let json: serde_json::Value = resp
            .json()
            .await
            .map_err(|e| VoiceError::Stt(format!("json: {e}")))?;
        Ok(json["text"].as_str().unwrap_or("").trim().to_string())
    }
}
/// Convenience wrapper that transcribes `samples` with a one-off provider.
///
/// Builds a fresh [`ElevenLabsSttProvider`] for language `"en"` from the given
/// `api_key` and delegates to [`SttProvider::transcribe`].
///
/// # Errors
///
/// Propagates any error returned by the provider's `transcribe`.
pub async fn transcribe(samples: &[f32], sample_rate: u32, api_key: &str) -> Result<String> {
    ElevenLabsSttProvider::new(api_key, "en")
        .transcribe(samples, sample_rate)
        .await
}
/// Encodes mono `f32` samples as an in-memory 16-bit PCM WAV file.
///
/// The output is a canonical 44-byte header followed by one little-endian
/// `i16` per input sample. The header contains `RIFF`/`WAVE`, a 16-byte
/// `fmt ` chunk declaring PCM, 1 channel and 16 bits at `sample_rate`, and a
/// `data` chunk header.
///
/// Each sample is clamped to `[-1.0, 1.0]` and scaled by `32767`, truncating
/// toward zero, so full scale maps to `±32767` (never `-32768`). NaN encodes
/// as `0`, because Rust float-to-int `as` casts map NaN to zero.
///
/// The RIFF size fields are 32-bit. If the payload is too large to describe,
/// those fields saturate at `u32::MAX` instead of silently wrapping, and all
/// sample data is still written.
pub fn encode_wav(samples: &[f32], sample_rate: u32) -> Vec<u8> {
    const HEADER_LEN: usize = 44;
    const FMT_CHUNK_LEN: u32 = 16;
    const FORMAT_PCM: u16 = 1;
    const CHANNELS: u16 = 1;
    const BITS_PER_SAMPLE: u16 = 16;
    // Bytes per sample frame across all channels: 2 for mono 16-bit.
    const BLOCK_ALIGN: u16 = CHANNELS * (BITS_PER_SAMPLE / 8);
    // Header bytes that follow the 8-byte "RIFF" + size preamble.
    const RIFF_HEADER_REMAINDER: u32 = (HEADER_LEN - 8) as u32;

    let data_len = samples.len() * usize::from(BLOCK_ALIGN);
    // RIFF sizes are u32: saturate instead of letting a bare `as` cast wrap.
    let data_size = data_len.min(u32::MAX as usize) as u32;
    let riff_size = data_size.saturating_add(RIFF_HEADER_REMAINDER);
    let byte_rate = sample_rate.saturating_mul(u32::from(BLOCK_ALIGN));

    let mut buf = Vec::with_capacity(HEADER_LEN + data_len);

    // RIFF container header.
    buf.extend_from_slice(b"RIFF");
    buf.extend_from_slice(&riff_size.to_le_bytes());
    buf.extend_from_slice(b"WAVE");

    // `fmt ` chunk: describes the uncompressed PCM stream.
    buf.extend_from_slice(b"fmt ");
    buf.extend_from_slice(&FMT_CHUNK_LEN.to_le_bytes());
    buf.extend_from_slice(&FORMAT_PCM.to_le_bytes());
    buf.extend_from_slice(&CHANNELS.to_le_bytes());
    buf.extend_from_slice(&sample_rate.to_le_bytes());
    buf.extend_from_slice(&byte_rate.to_le_bytes());
    buf.extend_from_slice(&BLOCK_ALIGN.to_le_bytes());
    buf.extend_from_slice(&BITS_PER_SAMPLE.to_le_bytes());

    // `data` chunk: the PCM samples themselves.
    buf.extend_from_slice(b"data");
    buf.extend_from_slice(&data_size.to_le_bytes());
    for &sample in samples {
        // Clamp first so over-range input maps to ±32767. A bare `as` cast
        // would saturate negative overflow at -32768 instead.
        let pcm = (sample.clamp(-1.0, 1.0) * 32767.0) as i16;
        buf.extend_from_slice(&pcm.to_le_bytes());
    }
    buf
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a little-endian `u16` at byte offset `at`.
    fn le_u16(bytes: &[u8], at: usize) -> u16 {
        u16::from_le_bytes([bytes[at], bytes[at + 1]])
    }

    /// Reads a little-endian `u32` at byte offset `at`.
    fn le_u32(bytes: &[u8], at: usize) -> u32 {
        u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]])
    }

    #[test]
    fn encode_wav_has_riff_header_and_data_chunk() {
        let samples = vec![0.0_f32; 16_000];
        let wav = encode_wav(&samples, 16_000);
        assert_eq!(&wav[0..4], b"RIFF");
        assert_eq!(&wav[8..12], b"WAVE");
        assert_eq!(&wav[12..16], b"fmt ");
        assert_eq!(&wav[36..40], b"data");
        assert_eq!(wav.len(), 44 + 16_000 * 2);
    }

    #[test]
    fn encode_wav_writes_pcm16_mono_fmt_chunk() {
        let wav = encode_wav(&[0.0], 22_050);
        assert_eq!(le_u32(&wav, 16), 16); // fmt chunk body length
        assert_eq!(le_u16(&wav, 20), 1); // PCM
        assert_eq!(le_u16(&wav, 22), 1); // mono
        assert_eq!(le_u32(&wav, 24), 22_050); // sample rate
        assert_eq!(le_u32(&wav, 28), 44_100); // byte rate
        assert_eq!(le_u16(&wav, 32), 2); // block align
        assert_eq!(le_u16(&wav, 34), 16); // bits per sample
    }

    #[test]
    fn encode_wav_size_fields_match_payload() {
        let wav = encode_wav(&[0.0, 0.25, -0.25], 16_000);
        assert_eq!(wav.len(), 44 + 6);
        assert_eq!(le_u32(&wav, 4) as usize, wav.len() - 8);
        assert_eq!(le_u32(&wav, 40), 6);
    }

    #[test]
    fn encode_wav_empty_input_is_header_only() {
        let wav = encode_wav(&[], 16_000);
        assert_eq!(wav.len(), 44);
        assert_eq!(le_u32(&wav, 4), 36);
        assert_eq!(le_u32(&wav, 40), 0);
    }

    #[test]
    fn encode_wav_clamps_out_of_range_samples() {
        let samples = vec![2.0_f32, -2.0_f32];
        let wav = encode_wav(&samples, 16_000);
        let s0 = i16::from_le_bytes([wav[44], wav[45]]);
        let s1 = i16::from_le_bytes([wav[46], wav[47]]);
        assert_eq!(s0, 32_767);
        assert_eq!(s1, -32_767);
    }

    #[test]
    fn encode_wav_maps_nan_to_silence() {
        let wav = encode_wav(&[f32::NAN], 16_000);
        assert_eq!(i16::from_le_bytes([wav[44], wav[45]]), 0);
    }

    #[test]
    fn debug_redacts_api_key() {
        let provider = ElevenLabsSttProvider::new("sk-test-secret", "en");
        let rendered = format!("{provider:?}");
        assert!(!rendered.contains("sk-test-secret"));
        assert!(rendered.contains("<redacted>"));
        assert!(rendered.contains("\"en\""));
    }
}