use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
mod deepgram;
mod elevenlabs;
mod groq;
#[cfg(feature = "local-whisper")]
mod local_whisper;
mod mistral;
mod openai;
/// Default HTTP request timeout, in seconds (300 s = 5 minutes).
///
/// Applied to the blocking client built in `openai_compatible_transcribe_sync`.
/// NOTE(review): the generous value presumably accommodates large audio uploads and
/// slow transcriptions — confirm before lowering it.
pub const DEFAULT_TIMEOUT_SECS: u64 = 300;
pub use deepgram::DeepgramProvider;
pub use elevenlabs::ElevenLabsProvider;
pub use groq::GroqProvider;
#[cfg(feature = "local-whisper")]
pub use local_whisper::LocalWhisperProvider;
pub use mistral::MistralProvider;
pub use openai::OpenAIProvider;
use crate::config::TranscriptionProvider;
/// A single audio file to be transcribed.
///
/// The OpenAI-compatible helpers in this module send it as a multipart upload:
/// `audio_data` becomes the `file` part, which is named `filename` and typed `mime_type`.
/// `language` is sent as an optional `language` field.
#[derive(Clone)]
pub struct TranscriptionRequest {
    /// Raw bytes of the audio file, uploaded as-is.
    pub audio_data: Vec<u8>,
    /// Optional language hint; `None` omits the field from the request.
    pub language: Option<String>,
    /// File name attached to the uploaded multipart part.
    pub filename: String,
    /// MIME type of the audio (e.g. `audio/wav`); must be a syntactically valid MIME string.
    pub mime_type: String,
}

// Manual `Debug` so logging a request does not dump the full audio buffer,
// which is typically large.
impl std::fmt::Debug for TranscriptionRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TranscriptionRequest")
            .field(
                "audio_data",
                &format_args!("<{} bytes>", self.audio_data.len()),
            )
            .field("language", &self.language)
            .field("filename", &self.filename)
            .field("mime_type", &self.mime_type)
            .finish()
    }
}
/// Outcome of a successful transcription.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TranscriptionResult {
    /// The transcribed text as returned by the backend.
    pub text: String,
}
/// Minimal view of an OpenAI-compatible transcription response body.
///
/// Only `text` is deserialized. Any other fields in the JSON are ignored, because
/// `deny_unknown_fields` is not set.
#[derive(Deserialize)]
struct OpenAICompatibleResponse {
    /// The transcribed text.
    text: String,
}
pub(crate) fn openai_compatible_transcribe_sync(
api_url: &str,
model: &str,
api_key: &str,
request: TranscriptionRequest,
) -> Result<TranscriptionResult> {
let client = reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS))
.build()
.context("Failed to create HTTP client")?;
let mut form = reqwest::blocking::multipart::Form::new()
.text("model", model.to_string())
.part(
"file",
reqwest::blocking::multipart::Part::bytes(request.audio_data)
.file_name(request.filename)
.mime_str(&request.mime_type)?,
);
if let Some(lang) = request.language {
form = form.text("language", lang);
}
let response = client
.post(api_url)
.header("Authorization", format!("Bearer {api_key}"))
.multipart(form)
.send()
.context("Failed to send request")?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.unwrap_or_else(|_| "Unknown error".to_string());
anyhow::bail!("API error ({status}): {error_text}");
}
let text = response.text().context("Failed to get response text")?;
let resp: OpenAICompatibleResponse =
serde_json::from_str(&text).context("Failed to parse API response")?;
Ok(TranscriptionResult { text: resp.text })
}
pub(crate) async fn openai_compatible_transcribe_async(
client: &reqwest::Client,
api_url: &str,
model: &str,
api_key: &str,
request: TranscriptionRequest,
) -> Result<TranscriptionResult> {
let mut form = reqwest::multipart::Form::new()
.text("model", model.to_string())
.part(
"file",
reqwest::multipart::Part::bytes(request.audio_data)
.file_name(request.filename)
.mime_str(&request.mime_type)?,
);
if let Some(lang) = request.language {
form = form.text("language", lang);
}
let response = client
.post(api_url)
.header("Authorization", format!("Bearer {api_key}"))
.multipart(form)
.send()
.await
.context("Failed to send request")?;
if !response.status().is_success() {
let status = response.status();
let error_text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
anyhow::bail!("API error ({status}): {error_text}");
}
let text = response
.text()
.await
.context("Failed to get response text")?;
let resp: OpenAICompatibleResponse =
serde_json::from_str(&text).context("Failed to parse API response")?;
Ok(TranscriptionResult { text: resp.text })
}
/// A speech-to-text backend that can be stored in and looked up from a [`ProviderRegistry`].
///
/// Implementors must be `Send + Sync`, because the registry shares them as
/// `Arc<dyn TranscriptionBackend>`. `#[async_trait]` makes the async method object-safe.
#[async_trait]
pub trait TranscriptionBackend: Send + Sync {
    /// Short machine-readable identifier of the backend.
    /// NOTE(review): presumably matches its `ProviderRegistry` key (e.g. `"openai"`);
    /// this is not enforced here — confirm per implementation.
    fn name(&self) -> &'static str;
    /// Human-readable backend name (presumably for UI/log output — confirm).
    fn display_name(&self) -> &'static str;
    /// Synchronous variant: transcribes `request` and returns once the result is available.
    ///
    /// `api_key` is the credential for the backend.
    /// NOTE(review): backends that need no credential (e.g. a local model) may ignore it — confirm.
    fn transcribe_sync(
        &self,
        api_key: &str,
        request: TranscriptionRequest,
    ) -> Result<TranscriptionResult>;
    /// Asynchronous variant that performs HTTP calls through the caller-supplied `client`.
    ///
    /// `api_key` is the credential for the backend; see `transcribe_sync`.
    async fn transcribe_async(
        &self,
        client: &reqwest::Client,
        api_key: &str,
        request: TranscriptionRequest,
    ) -> Result<TranscriptionResult>;
}
/// Lookup table from a provider key (e.g. `"openai"`, `"groq"`) to a shared
/// [`TranscriptionBackend`] instance.
///
/// Use [`registry`] to obtain the process-wide instance.
pub struct ProviderRegistry {
    /// Backends keyed by the static names inserted in `ProviderRegistry::new`.
    providers: HashMap<&'static str, Arc<dyn TranscriptionBackend>>,
}
impl ProviderRegistry {
    /// Builds a registry containing every built-in transcription backend.
    ///
    /// The `local-whisper` backend is only registered when the crate is compiled with
    /// the `local-whisper` feature.
    pub fn new() -> Self {
        let mut providers: HashMap<&'static str, Arc<dyn TranscriptionBackend>> = HashMap::new();
        providers.insert("openai", Arc::new(OpenAIProvider));
        providers.insert("mistral", Arc::new(MistralProvider));
        providers.insert("groq", Arc::new(GroqProvider));
        providers.insert("deepgram", Arc::new(DeepgramProvider));
        providers.insert("elevenlabs", Arc::new(ElevenLabsProvider));
        #[cfg(feature = "local-whisper")]
        providers.insert("local-whisper", Arc::new(LocalWhisperProvider));
        Self { providers }
    }

    /// Looks up a backend by its registry key (exact, case-sensitive match).
    ///
    /// Returns a cheap `Arc` clone of the shared backend, or `None` if no backend
    /// is registered under `name`.
    pub fn get(&self, name: &str) -> Option<Arc<dyn TranscriptionBackend>> {
        self.providers.get(name).map(Arc::clone)
    }

    /// Returns the keys of all registered backends, sorted alphabetically.
    ///
    /// Sorting makes the output deterministic. `HashMap` iteration order is arbitrary
    /// and can differ from one run to the next.
    pub fn list(&self) -> Vec<&'static str> {
        let mut names: Vec<&'static str> = self.providers.keys().copied().collect();
        names.sort_unstable();
        names
    }

    /// Looks up the backend registered for a configured provider kind.
    ///
    /// # Errors
    ///
    /// Returns an error if `kind.as_str()` has no registered backend. This indicates a
    /// mismatch between `TranscriptionProvider` and `new()`. It can also happen if `kind`
    /// names a feature-gated backend, such as `local-whisper`, that was not compiled in.
    pub fn get_by_kind(
        &self,
        kind: &TranscriptionProvider,
    ) -> Result<Arc<dyn TranscriptionBackend>> {
        self.get(kind.as_str()).ok_or_else(|| {
            anyhow::anyhow!(
                "Provider '{}' not found in registry. This is a bug - \
                 all TranscriptionProvider variants must have registered providers.",
                kind.as_str()
            )
        })
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
/// Returns the process-wide [`ProviderRegistry`], building it on first use.
///
/// `OnceLock` runs the initializer at most once, even under concurrent first calls.
/// Every later call returns the same `'static` reference.
pub fn registry() -> &'static ProviderRegistry {
    static SHARED: OnceLock<ProviderRegistry> = OnceLock::new();
    SHARED.get_or_init(ProviderRegistry::default)
}