use crate::core::providers::unified_provider::ProviderError;
use serde::{Deserialize, Serialize};
/// Provider identifier passed to `ProviderError` constructors in this module.
const PROVIDER_NAME: &str = "elevenlabs";
/// Path of the speech-to-text endpoint; `build_stt_url` appends it to a base URL.
pub const STT_ENDPOINT_PATH: &str = "/v1/speech-to-text";
/// Input for a speech-to-text call, turned into a multipart form by
/// `create_multipart_form`.
#[derive(Debug, Clone)]
pub struct TranscriptionRequest {
    /// Raw audio bytes, uploaded as the `file` form part.
    pub file: Vec<u8>,
    /// Model identifier, sent as the `model_id` form field.
    pub model_id: String,
    /// Sent as the `language_code` form field when present; omitted otherwise.
    pub language_code: Option<String>,
    /// Sent as the `temperature` form field (stringified) when present.
    pub temperature: Option<f32>,
    /// Upload filename; also drives MIME detection. When absent the upload is
    /// named `audio.mp3` and typed `audio/mpeg`.
    pub filename: Option<String>,
}
/// Transcription payload deserialized from the provider's response body.
///
/// This type only derives `Deserialize`, so the optional fields are marked
/// `#[serde(default)]`: a missing key deserializes to `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionResponse {
    /// Full transcript text.
    pub text: String,
    /// Language code reported by the provider, if any.
    #[serde(default)]
    pub language_code: Option<String>,
    /// Per-token entries, if the provider returned them.
    #[serde(default)]
    pub words: Option<Vec<WordInfo>>,
}
/// One entry of `TranscriptionResponse::words`.
///
/// NOTE(review): `start` and `end` are not optional, so a payload that omits
/// either one fails to deserialize — confirm the provider always sends them.
#[derive(Debug, Clone, Deserialize)]
pub struct WordInfo {
    /// Text of this entry.
    pub text: String,
    /// Start offset of the entry (presumably seconds — confirm against the API).
    pub start: f32,
    /// End offset of the entry (same unit as `start`).
    pub end: f32,
    /// Entry kind, read from the JSON key `type`. The OpenAI conversion keeps
    /// only entries whose kind is `"word"`.
    #[serde(rename = "type")]
    pub word_type: String,
}
/// Serializable transcription result in the OpenAI-style field layout.
///
/// Built from a `TranscriptionResponse` through the `From` implementation.
#[derive(Debug, Clone, Serialize)]
pub struct OpenAITranscriptionResponse {
    /// Full transcript text.
    pub text: String,
    /// Task label; the `From` conversion always sets `"transcribe"`.
    pub task: String,
    /// Language of the transcript; the `From` conversion falls back to
    /// `"unknown"` when the provider reported none.
    pub language: String,
    /// Word timings; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<OpenAIWordInfo>>,
}
/// A single word with its timing, as serialized in
/// `OpenAITranscriptionResponse::words`.
#[derive(Debug, Clone, Serialize)]
pub struct OpenAIWordInfo {
    /// The word text.
    pub word: String,
    /// Start offset, copied unchanged from `WordInfo::start`.
    pub start: f32,
    /// End offset, copied unchanged from `WordInfo::end`.
    pub end: f32,
}
impl From<TranscriptionResponse> for OpenAITranscriptionResponse {
    /// Converts a provider response into the OpenAI-style layout.
    ///
    /// * Entries whose `word_type` is not `"word"` are dropped.
    /// * `task` is always `"transcribe"`.
    /// * `language` becomes `"unknown"` when no language code was reported.
    /// * A `None` word list stays `None`.
    fn from(response: TranscriptionResponse) -> Self {
        let words = response.words.map(|entries| {
            entries
                .into_iter()
                .filter_map(|entry| {
                    // Keep real words only; other entry kinds carry no word timing.
                    (entry.word_type == "word").then(|| OpenAIWordInfo {
                        word: entry.text,
                        start: entry.start,
                        end: entry.end,
                    })
                })
                .collect()
        });

        let language = match response.language_code {
            Some(code) => code,
            None => String::from("unknown"),
        };

        Self {
            text: response.text,
            task: String::from("transcribe"),
            language,
            words,
        }
    }
}
/// Builds the multipart form body for a speech-to-text request.
///
/// Parts are added in this order: `file`, `model_id`, then `language_code`
/// and `temperature` when they are set on the request.
///
/// The upload name and MIME type come from `request.filename`; without a
/// filename the part is named `audio.mp3` and typed `audio/mpeg`.
///
/// # Errors
///
/// Returns `ProviderError::invalid_request` if the MIME type string is
/// rejected when attaching it to the file part.
pub fn create_multipart_form(
    request: &TranscriptionRequest,
) -> Result<reqwest::multipart::Form, ProviderError> {
    use reqwest::multipart::{Form, Part};

    let (upload_name, content_type) = match request.filename.as_deref() {
        None => (String::from("audio.mp3"), "audio/mpeg"),
        Some(name) => (name.to_owned(), detect_audio_mime_type(name)),
    };

    // The part takes ownership of its bytes, so the borrowed buffer is copied.
    let audio_part = Part::bytes(request.file.clone())
        .file_name(upload_name)
        .mime_str(content_type)
        .map_err(|e| {
            ProviderError::invalid_request(PROVIDER_NAME, format!("Invalid MIME type: {}", e))
        })?;

    let mut form = Form::new()
        .part("file", audio_part)
        .text("model_id", request.model_id.clone());

    if let Some(code) = request.language_code.as_ref() {
        form = form.text("language_code", code.clone());
    }
    if let Some(value) = request.temperature {
        form = form.text("temperature", value.to_string());
    }

    Ok(form)
}
/// Maps an audio filename to the MIME type used for the upload part.
///
/// The extension is the text after the last `.` and is compared
/// case-insensitively. A name that contains no `.` has no extension, so it
/// receives the default rather than having the whole name matched as if it
/// were an extension (e.g. a file literally named `wav`).
///
/// Returns `"audio/mpeg"` when the extension is missing or unrecognized.
fn detect_audio_mime_type(filename: &str) -> &'static str {
    const DEFAULT_MIME: &str = "audio/mpeg";

    let extension = match filename.rsplit_once('.') {
        Some((_, ext)) => ext.to_ascii_lowercase(),
        None => return DEFAULT_MIME,
    };

    match extension.as_str() {
        "mp3" => "audio/mpeg",
        "mp4" | "m4a" => "audio/mp4",
        "wav" => "audio/wav",
        "webm" => "audio/webm",
        "ogg" | "oga" => "audio/ogg",
        "flac" => "audio/flac",
        _ => DEFAULT_MIME,
    }
}
/// Returns the file extensions (lowercase, without the dot) that this module
/// advertises as supported audio formats.
pub fn supported_audio_formats() -> &'static [&'static str] {
    const FORMATS: &[&str] = &["mp3", "mp4", "m4a", "wav", "webm", "ogg", "flac"];
    FORMATS
}
/// 100 MiB expressed in bytes.
///
/// NOTE(review): presumably the upper bound on `TranscriptionRequest::file`;
/// nothing visible in this file enforces it — confirm at the call sites.
pub const MAX_FILE_SIZE: usize = 100 * 1024 * 1024;
/// Speech-to-text models known to this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum STTModel {
    /// Identified on the wire as `scribe_v1`; this is the default model.
    #[default]
    ScribeV1,
}

impl STTModel {
    /// Returns the identifier string for this model.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::ScribeV1 => "scribe_v1",
        }
    }

    /// Looks up a model by its exact identifier string.
    ///
    /// Returns `None` when `s` does not match any known model. The comparison
    /// is case-sensitive.
    pub fn parse(s: &str) -> Option<Self> {
        // Table of every variant; lookup goes through `as_str` so the two
        // directions cannot drift apart.
        const KNOWN: [STTModel; 1] = [STTModel::ScribeV1];
        KNOWN.iter().copied().find(|model| model.as_str() == s)
    }
}
pub fn build_stt_url(base_url: &str) -> String {
let base = base_url.trim_end_matches('/');
format!("{}{}", base, STT_ENDPOINT_PATH)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `WordInfo` fixture with the given text, timing and entry kind.
    fn entry(text: &str, start: f32, end: f32, kind: &str) -> WordInfo {
        WordInfo {
            text: text.to_string(),
            start,
            end,
            word_type: kind.to_string(),
        }
    }

    #[test]
    fn test_detect_audio_mime_type() {
        assert_eq!(detect_audio_mime_type("audio.mp3"), "audio/mpeg");
        assert_eq!(detect_audio_mime_type("audio.wav"), "audio/wav");
        assert_eq!(detect_audio_mime_type("audio.m4a"), "audio/mp4");
        assert_eq!(detect_audio_mime_type("audio.webm"), "audio/webm");
        assert_eq!(detect_audio_mime_type("audio.unknown"), "audio/mpeg");
    }

    /// Covers the mappings the basic test leaves out.
    #[test]
    fn test_detect_audio_mime_type_remaining_extensions() {
        assert_eq!(detect_audio_mime_type("audio.mp4"), "audio/mp4");
        assert_eq!(detect_audio_mime_type("audio.ogg"), "audio/ogg");
        assert_eq!(detect_audio_mime_type("audio.oga"), "audio/ogg");
        assert_eq!(detect_audio_mime_type("audio.flac"), "audio/flac");
    }

    /// Extension matching ignores case and uses only the last dot segment.
    #[test]
    fn test_detect_audio_mime_type_case_and_multiple_dots() {
        assert_eq!(detect_audio_mime_type("CLIP.WAV"), "audio/wav");
        assert_eq!(detect_audio_mime_type("clip.Mp3"), "audio/mpeg");
        assert_eq!(detect_audio_mime_type("archive.tar.flac"), "audio/flac");
    }

    #[test]
    fn test_stt_model_as_str() {
        assert_eq!(STTModel::ScribeV1.as_str(), "scribe_v1");
    }

    #[test]
    fn test_stt_model_from_str() {
        assert_eq!(STTModel::parse("scribe_v1"), Some(STTModel::ScribeV1));
        assert_eq!(STTModel::parse("unknown"), None);
    }

    #[test]
    fn test_stt_model_default() {
        assert_eq!(STTModel::default(), STTModel::ScribeV1);
    }

    #[test]
    fn test_build_stt_url() {
        let url = build_stt_url("https://api.elevenlabs.io");
        assert_eq!(url, "https://api.elevenlabs.io/v1/speech-to-text");
    }

    /// Trailing slashes on the base URL must not produce a doubled slash.
    #[test]
    fn test_build_stt_url_trims_trailing_slashes() {
        let expected = "https://api.elevenlabs.io/v1/speech-to-text";
        assert_eq!(build_stt_url("https://api.elevenlabs.io/"), expected);
        assert_eq!(build_stt_url("https://api.elevenlabs.io///"), expected);
    }

    #[test]
    fn test_transcription_response_to_openai() {
        let response = TranscriptionResponse {
            text: "Hello world".to_string(),
            language_code: Some("en".to_string()),
            words: Some(vec![
                entry("Hello", 0.0, 0.5, "word"),
                entry(" ", 0.5, 0.6, "spacing"),
                entry("world", 0.6, 1.0, "word"),
            ]),
        };

        let openai_response: OpenAITranscriptionResponse = response.into();

        assert_eq!(openai_response.text, "Hello world");
        assert_eq!(openai_response.task, "transcribe");
        assert_eq!(openai_response.language, "en");

        // The "spacing" entry is filtered out; timings are carried over as-is.
        let words = openai_response.words.unwrap();
        assert_eq!(words.len(), 2);
        assert_eq!(words[0].word, "Hello");
        assert_eq!(words[0].start, 0.0);
        assert_eq!(words[0].end, 0.5);
        assert_eq!(words[1].word, "world");
        assert_eq!(words[1].start, 0.6);
        assert_eq!(words[1].end, 1.0);
    }

    /// Missing language and word list: fallback language, `words` stays `None`.
    #[test]
    fn test_transcription_response_to_openai_without_optionals() {
        let response = TranscriptionResponse {
            text: "Hi".to_string(),
            language_code: None,
            words: None,
        };

        let openai_response = OpenAITranscriptionResponse::from(response);

        assert_eq!(openai_response.text, "Hi");
        assert_eq!(openai_response.task, "transcribe");
        assert_eq!(openai_response.language, "unknown");
        assert!(openai_response.words.is_none());
    }

    /// A word list holding only non-word entries converts to an empty list.
    #[test]
    fn test_transcription_response_to_openai_only_spacing() {
        let response = TranscriptionResponse {
            text: " ".to_string(),
            language_code: Some("en".to_string()),
            words: Some(vec![entry(" ", 0.0, 0.1, "spacing")]),
        };

        let openai_response = OpenAITranscriptionResponse::from(response);

        assert_eq!(openai_response.words.map(|w| w.len()), Some(0));
    }

    #[test]
    fn test_supported_audio_formats() {
        let formats = supported_audio_formats();
        assert!(formats.contains(&"mp3"));
        assert!(formats.contains(&"wav"));
        assert!(formats.contains(&"webm"));
        assert!(formats.contains(&"mp4"));
        assert!(formats.contains(&"m4a"));
        assert!(formats.contains(&"ogg"));
        assert!(formats.contains(&"flac"));
    }
}