use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, OnceLock};
/// Stage of a transcription job, as handed to a progress callback.
///
/// Each stage maps to a short status string via [`TranscriptionStage::message`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TranscriptionStage {
    /// Status text: `"Recording..."`.
    Recording,
    /// Status text: `"Encoding..."`.
    Encoding,
    /// Status text: `"Uploading..."`.
    Uploading,
    /// Status text: `"Transcribing..."`.
    Transcribing,
    /// Status text: `"Post-processing..."`.
    PostProcessing,
    /// Status text: `"Done!"`.
    Complete,
}

impl TranscriptionStage {
    /// Returns the fixed status string associated with this stage.
    pub fn message(&self) -> &'static str {
        match *self {
            TranscriptionStage::Recording => "Recording...",
            TranscriptionStage::Encoding => "Encoding...",
            TranscriptionStage::Uploading => "Uploading...",
            TranscriptionStage::Transcribing => "Transcribing...",
            TranscriptionStage::PostProcessing => "Post-processing...",
            TranscriptionStage::Complete => "Done!",
        }
    }
}
/// Shared `Send + Sync` callback invoked with the current [`TranscriptionStage`].
///
/// Stored in `TranscriptionRequest::progress` and called by `TranscriptionRequest::report`.
pub type ProgressCallback = Arc<dyn Fn(TranscriptionStage) + Send + Sync>;
mod deepgram;
mod elevenlabs;
mod groq;
#[cfg(feature = "local-transcription")]
mod local_parakeet;
#[cfg(feature = "local-transcription")]
mod local_whisper;
mod mistral;
mod openai;
#[cfg(feature = "realtime")]
mod openai_realtime;
/// Default HTTP request timeout in seconds (five minutes).
///
/// Applied to the blocking client built by `openai_compatible_transcribe_sync`.
pub const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;
pub use deepgram::DeepgramProvider;
pub use elevenlabs::ElevenLabsProvider;
pub use groq::GroqProvider;
#[cfg(feature = "local-transcription")]
pub use local_parakeet::LocalParakeetProvider;
#[cfg(feature = "local-transcription")]
pub use local_parakeet::transcribe_raw as transcribe_raw_parakeet;
#[cfg(feature = "local-transcription")]
pub use local_whisper::LocalWhisperProvider;
#[cfg(feature = "local-transcription")]
pub use local_whisper::transcribe_raw;
pub use mistral::MistralProvider;
pub use openai::OpenAIProvider;
#[cfg(feature = "realtime")]
pub use openai_realtime::OpenAIRealtimeProvider;
use crate::config::TranscriptionProvider;
/// Input for a single transcription call: the audio payload plus upload metadata.
#[derive(Clone)]
pub struct TranscriptionRequest {
/// Raw audio bytes; the OpenAI-compatible helpers in this file send them as the `file` multipart part.
pub audio_data: Vec<u8>,
/// Optional language value; when set, the OpenAI-compatible helpers send it as the `language` form field.
pub language: Option<String>,
/// File name attached to the uploaded part (`"audio.mp3"` when built with `new`).
pub filename: String,
/// MIME type attached to the uploaded part (`"audio/mpeg"` when built with `new`).
pub mime_type: String,
/// Optional progress callback; invoked by `report`, `None` when built with `new`.
pub progress: Option<ProgressCallback>,
}
impl TranscriptionRequest {
    /// Creates a request for `audio_data` with the default upload metadata.
    ///
    /// The file name defaults to `"audio.mp3"`, the MIME type to `"audio/mpeg"`,
    /// and no progress callback is attached.
    pub fn new(audio_data: Vec<u8>, language: Option<String>) -> Self {
        let filename = String::from("audio.mp3");
        let mime_type = String::from("audio/mpeg");
        Self {
            audio_data,
            language,
            filename,
            mime_type,
            progress: None,
        }
    }

    /// Returns the request with `callback` installed as its progress callback,
    /// replacing any callback that was set before.
    pub fn with_progress(self, callback: ProgressCallback) -> Self {
        Self {
            progress: Some(callback),
            ..self
        }
    }

    /// Invokes the progress callback with `stage`; does nothing when none is set.
    pub fn report(&self, stage: TranscriptionStage) {
        let Some(notify) = self.progress.as_ref() else {
            return;
        };
        notify(stage);
    }
}
/// Outcome of a successful transcription.
pub struct TranscriptionResult {
/// The transcribed text; the OpenAI-compatible helpers in this file fill it from the response's `text` field.
pub text: String,
}
/// JSON response body expected from OpenAI-compatible transcription endpoints.
///
/// Only the `text` field is read by the helpers in this file.
#[derive(Deserialize)]
struct OpenAICompatibleResponse {
/// Value of the response's `text` field.
text: String,
}
/// Posts `request` to an OpenAI-compatible transcription endpoint and blocks until it answers.
///
/// The request is sent as a multipart form with a `model` field, a `file` part
/// (audio bytes, file name and MIME type taken from `request`) and, when
/// `request.language` is set, a `language` field. Authentication uses an
/// `Authorization: Bearer <api_key>` header. A fresh blocking client with a
/// [`DEFAULT_TIMEOUT_SECS`] timeout is built for every call.
///
/// Progress is reported as `Uploading` on entry and `Transcribing` immediately
/// before the request is sent.
///
/// # Errors
///
/// Returns an error when the HTTP client cannot be built, the MIME type is
/// rejected, the request fails to send, the server answers with a non-success
/// status (the response body is included in the message), or the body is not
/// JSON containing a `text` string.
pub(crate) fn openai_compatible_transcribe_sync(
    api_url: &str,
    model: &str,
    api_key: &str,
    mut request: TranscriptionRequest,
) -> Result<TranscriptionResult> {
    request.report(TranscriptionStage::Uploading);

    let client = reqwest::blocking::Client::builder()
        .timeout(std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS))
        .build()
        .context("Failed to create HTTP client")?;

    // The request is owned by this function, so move the payload into the form
    // instead of cloning it: the audio buffer can be large and is not used again.
    let audio_data = std::mem::take(&mut request.audio_data);
    let filename = std::mem::take(&mut request.filename);
    let file_part = reqwest::blocking::multipart::Part::bytes(audio_data)
        .file_name(filename)
        .mime_str(&request.mime_type)
        .with_context(|| format!("Invalid MIME type '{}'", request.mime_type))?;

    let mut form = reqwest::blocking::multipart::Form::new()
        .text("model", model.to_string())
        .part("file", file_part);
    if let Some(lang) = request.language.take() {
        form = form.text("language", lang);
    }

    request.report(TranscriptionStage::Transcribing);

    let response = client
        .post(api_url)
        .header("Authorization", format!("Bearer {api_key}"))
        .multipart(form)
        .send()
        .context("Failed to send request")?;

    let status = response.status();
    if !status.is_success() {
        // Best effort: surface the server's explanation when the body is readable.
        let error_text = response
            .text()
            .unwrap_or_else(|_| "Unknown error".to_string());
        anyhow::bail!("API error ({status}): {error_text}");
    }

    let text = response.text().context("Failed to get response text")?;
    let resp: OpenAICompatibleResponse =
        serde_json::from_str(&text).context("Failed to parse API response")?;
    Ok(TranscriptionResult { text: resp.text })
}
/// Posts `request` to an OpenAI-compatible transcription endpoint using the caller's async `client`.
///
/// The request is sent as a multipart form with a `model` field, a `file` part
/// (audio bytes, file name and MIME type taken from `request`) and, when
/// `request.language` is set, a `language` field. Authentication uses an
/// `Authorization: Bearer <api_key>` header. No timeout is configured here;
/// whatever `client` was built with applies.
///
/// Progress is reported as `Uploading` on entry and `Transcribing` immediately
/// before the request is sent.
///
/// # Errors
///
/// Returns an error when the MIME type is rejected, the request fails to send,
/// the server answers with a non-success status (the response body is included
/// in the message), or the body is not JSON containing a `text` string.
pub(crate) async fn openai_compatible_transcribe_async(
    client: &reqwest::Client,
    api_url: &str,
    model: &str,
    api_key: &str,
    mut request: TranscriptionRequest,
) -> Result<TranscriptionResult> {
    request.report(TranscriptionStage::Uploading);

    // The request is owned by this function, so move the payload into the form
    // instead of cloning it: the audio buffer can be large and is not used again.
    let audio_data = std::mem::take(&mut request.audio_data);
    let filename = std::mem::take(&mut request.filename);
    let file_part = reqwest::multipart::Part::bytes(audio_data)
        .file_name(filename)
        .mime_str(&request.mime_type)
        .with_context(|| format!("Invalid MIME type '{}'", request.mime_type))?;

    let mut form = reqwest::multipart::Form::new()
        .text("model", model.to_string())
        .part("file", file_part);
    if let Some(lang) = request.language.take() {
        form = form.text("language", lang);
    }

    request.report(TranscriptionStage::Transcribing);

    let response = client
        .post(api_url)
        .header("Authorization", format!("Bearer {api_key}"))
        .multipart(form)
        .send()
        .await
        .context("Failed to send request")?;

    let status = response.status();
    if !status.is_success() {
        // Best effort: surface the server's explanation when the body is readable.
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        anyhow::bail!("API error ({status}): {error_text}");
    }

    let text = response
        .text()
        .await
        .context("Failed to get response text")?;
    let resp: OpenAICompatibleResponse =
        serde_json::from_str(&text).context("Failed to parse API response")?;
    Ok(TranscriptionResult { text: resp.text })
}
/// Common interface for transcription providers stored in `ProviderRegistry`.
///
/// Implementors must be `Send + Sync`; the registry holds them as
/// `Arc<dyn TranscriptionBackend>`.
#[async_trait]
pub trait TranscriptionBackend: Send + Sync {
/// Static identifier for the provider.
// NOTE(review): presumably the same string as the key used in `ProviderRegistry::new` — confirm in the implementations.
fn name(&self) -> &'static str;
/// Static display name for the provider.
// NOTE(review): presumably the human-readable form shown to users — confirm.
fn display_name(&self) -> &'static str;
/// Transcribes `request` synchronously and returns the resulting text.
///
/// `api_key` is the credential passed by the caller; how (or whether) it is
/// used is up to the implementation.
fn transcribe_sync(
&self,
api_key: &str,
request: TranscriptionRequest,
) -> Result<TranscriptionResult>;
/// Transcribes `request` asynchronously and returns the resulting text.
///
/// `client` is the caller-supplied async HTTP client; `api_key` is the
/// credential passed by the caller. How either is used is up to the
/// implementation.
async fn transcribe_async(
&self,
client: &reqwest::Client,
api_key: &str,
request: TranscriptionRequest,
) -> Result<TranscriptionResult>;
}
/// Lookup table from provider name to its `TranscriptionBackend` implementation.
pub struct ProviderRegistry {
/// Backends keyed by their registry name (for example `"openai"`).
providers: HashMap<&'static str, Arc<dyn TranscriptionBackend>>,
}
impl ProviderRegistry {
    /// Builds a registry containing every backend compiled into this build.
    ///
    /// `"openai-realtime"` is registered only with the `realtime` feature, and
    /// `"local-whisper"` / `"local-parakeet"` only with `local-transcription`.
    pub fn new() -> Self {
        let mut providers: HashMap<&'static str, Arc<dyn TranscriptionBackend>> = HashMap::new();
        providers.insert("openai", Arc::new(OpenAIProvider));
        #[cfg(feature = "realtime")]
        providers.insert("openai-realtime", Arc::new(OpenAIRealtimeProvider));
        providers.insert("mistral", Arc::new(MistralProvider));
        providers.insert("groq", Arc::new(GroqProvider));
        providers.insert("deepgram", Arc::new(DeepgramProvider));
        providers.insert("elevenlabs", Arc::new(ElevenLabsProvider));
        #[cfg(feature = "local-transcription")]
        providers.insert("local-whisper", Arc::new(LocalWhisperProvider));
        #[cfg(feature = "local-transcription")]
        providers.insert("local-parakeet", Arc::new(LocalParakeetProvider));
        Self { providers }
    }

    /// Returns a shared handle to the backend registered under `name`, if any.
    pub fn get(&self, name: &str) -> Option<Arc<dyn TranscriptionBackend>> {
        self.providers.get(name).cloned()
    }

    /// Returns the names of all registered backends in ascending order.
    ///
    /// The names are sorted because `HashMap` iteration order is arbitrary and
    /// may differ between runs; sorting keeps listings deterministic.
    pub fn list(&self) -> Vec<&'static str> {
        let mut names: Vec<&'static str> = self.providers.keys().copied().collect();
        names.sort_unstable();
        names
    }

    /// Returns the backend for a configured provider `kind`, looked up by `kind.as_str()`.
    ///
    /// # Errors
    ///
    /// Fails when no backend is registered under that name. The message marks
    /// this as a bug, since every `TranscriptionProvider` variant is expected
    /// to have a registered backend.
    pub fn get_by_kind(
        &self,
        kind: &TranscriptionProvider,
    ) -> Result<Arc<dyn TranscriptionBackend>> {
        let key = kind.as_str();
        self.get(key).ok_or_else(|| {
            anyhow::anyhow!(
                "Provider '{}' not found in registry. This is a bug - \
                 all TranscriptionProvider variants must have registered providers.",
                key
            )
        })
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
/// Returns the process-wide provider registry.
///
/// The registry is built with [`ProviderRegistry::new`] on the first call and
/// the same instance is returned on every later call.
pub fn registry() -> &'static ProviderRegistry {
    static INSTANCE: OnceLock<ProviderRegistry> = OnceLock::new();
    let shared: &'static ProviderRegistry = INSTANCE.get_or_init(ProviderRegistry::new);
    shared
}