use reqwest::multipart::{Form, Part};
use serde::Deserialize;
use std::path::Path;
use thiserror::Error;
use tracing::{error, warn};
/// Errors that can occur while transcribing an audio file.
#[derive(Error, Debug)]
pub enum TranscriptionError {
    /// No API key was available when a transcription was requested.
    #[error("API key not configured")]
    NoApiKey,
    /// The audio file does not exist; carries the path as a (lossily converted) string.
    #[error("File not found: {0}")]
    FileNotFound(String),
    /// An I/O error occurred while reading the audio file.
    #[error("Failed to read file: {0}")]
    FileReadError(#[from] std::io::Error),
    /// The HTTP request failed, or the response body could not be decoded.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),
    /// The API answered with a non-success status; carries the status and response body text.
    #[error("API error: {0}")]
    ApiError(String),
}
/// JSON body deserialized from a successful transcription response.
///
/// Only the `text` field is read.
#[derive(Debug, Deserialize)]
struct GroqResponse {
    /// The transcribed text returned to the caller.
    text: String,
}
/// Client configuration for an audio transcription HTTP endpoint.
///
/// Cloning copies the configuration strings; the type holds no other state.
#[derive(Clone)]
pub struct TranscriptionService {
    /// Bearer token sent with each request; `None` means the service is not configured.
    api_key: Option<String>,
    /// Full URL of the transcription endpoint.
    api_url: String,
    /// Model name sent in the `model` form field.
    model: String,
}

/// Manual `Debug` so the service can be logged or embedded in `Debug` types
/// without ever printing the API key.
impl std::fmt::Debug for TranscriptionService {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the Some/None distinction visible, but never the secret itself.
        let redacted_key = self.api_key.as_ref().map(|_| "<redacted>");
        f.debug_struct("TranscriptionService")
            .field("api_key", &redacted_key)
            .field("api_url", &self.api_url)
            .field("model", &self.model)
            .finish()
    }
}
impl TranscriptionService {
    /// Endpoint used by [`TranscriptionService::new`].
    const DEFAULT_API_URL: &'static str = "https://api.groq.com/openai/v1/audio/transcriptions";
    /// Model name used by [`TranscriptionService::new`].
    const DEFAULT_MODEL: &'static str = "whisper-large-v3";
    /// File name sent in the upload when the path has no UTF-8 file name.
    const FALLBACK_FILE_NAME: &'static str = "audio.mp3";
    /// Timeout passed to each transcription request.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

    /// Creates a service that uses the default endpoint and model.
    ///
    /// When `api_key` is `None`, the `GROQ_API_KEY` environment variable is
    /// used instead if it is set.
    pub fn new(api_key: Option<String>) -> Self {
        Self::with_settings(
            api_key,
            Self::DEFAULT_API_URL.to_string(),
            Self::DEFAULT_MODEL.to_string(),
        )
    }

    /// Creates a service with a custom endpoint URL and model name.
    ///
    /// When `api_key` is `None`, the `GROQ_API_KEY` environment variable is
    /// used instead if it is set.
    pub fn with_settings(api_key: Option<String>, api_url: String, model: String) -> Self {
        Self {
            api_key: api_key.or_else(|| std::env::var("GROQ_API_KEY").ok()),
            api_url,
            model,
        }
    }

    /// Returns `true` when an API key is available.
    pub fn is_configured(&self) -> bool {
        self.api_key.is_some()
    }

    /// Uploads the audio file at `file_path` and returns the transcribed text.
    ///
    /// The file is sent as the `file` part of a multipart form together with
    /// the configured model name.
    ///
    /// # Errors
    ///
    /// * [`TranscriptionError::NoApiKey`] if no API key is available.
    /// * [`TranscriptionError::FileNotFound`] if the file does not exist.
    /// * [`TranscriptionError::FileReadError`] for any other I/O failure while reading.
    /// * [`TranscriptionError::HttpError`] if sending the request or decoding
    ///   the JSON response fails.
    /// * [`TranscriptionError::ApiError`] if the response status is not a success.
    pub async fn transcribe<P: AsRef<Path>>(
        &self,
        file_path: P,
    ) -> Result<String, TranscriptionError> {
        let api_key = self
            .api_key
            .as_deref()
            .ok_or(TranscriptionError::NoApiKey)?;
        let path = file_path.as_ref();

        // Read directly instead of probing with `Path::exists()` first: that
        // probe is a blocking call inside an async fn and leaves a window in
        // which the file can vanish between the check and the read.
        let file_bytes = tokio::fs::read(path).await.map_err(|err| {
            if err.kind() == std::io::ErrorKind::NotFound {
                TranscriptionError::FileNotFound(path.to_string_lossy().into_owned())
            } else {
                TranscriptionError::FileReadError(err)
            }
        })?;

        let file_name = path
            .file_name()
            .and_then(|name| name.to_str())
            .unwrap_or(Self::FALLBACK_FILE_NAME)
            .to_string();

        let form = Form::new()
            .part("file", Part::bytes(file_bytes).file_name(file_name))
            .text("model", self.model.clone());

        let response = reqwest::Client::new()
            .post(&self.api_url)
            .bearer_auth(api_key)
            .multipart(form)
            .timeout(Self::REQUEST_TIMEOUT)
            .send()
            .await?;

        let status = response.status();
        if !status.is_success() {
            // The body is only used for diagnostics, so a failure to read it
            // must not mask the status error.
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            error!("Groq transcription failed: {} - {}", status, error_text);
            return Err(TranscriptionError::ApiError(format!(
                "{}: {}",
                status, error_text
            )));
        }

        let data: GroqResponse = response.json().await?;
        Ok(data.text)
    }

    /// Like [`TranscriptionService::transcribe`], but never fails.
    ///
    /// Any error is logged as a warning and an empty string is returned.
    pub async fn transcribe_safe<P: AsRef<Path>>(&self, file_path: P) -> String {
        self.transcribe(file_path).await.unwrap_or_else(|err| {
            warn!("Transcription failed: {}", err);
            String::new()
        })
    }
}
impl Default for TranscriptionService {
fn default() -> Self {
Self::new(None)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a service that has no API key regardless of the environment.
    ///
    /// The public constructors fall back to `GROQ_API_KEY`, so tests that need
    /// a key-less service must bypass them to stay deterministic.
    fn unconfigured_service() -> TranscriptionService {
        TranscriptionService {
            api_key: None,
            api_url: "https://api.groq.com/openai/v1/audio/transcriptions".to_string(),
            model: "whisper-large-v3".to_string(),
        }
    }

    #[test]
    fn test_new_service() {
        let service = TranscriptionService::new(Some("test_key".to_string()));
        assert!(service.is_configured());
        assert_eq!(service.model, "whisper-large-v3");
    }

    #[test]
    fn test_new_service_no_key() {
        // Without an explicit key the constructor consults GROQ_API_KEY, so
        // the expected result depends on whether that variable is readable.
        let env_key_present = std::env::var("GROQ_API_KEY").is_ok();
        let service = TranscriptionService::new(None);
        assert_eq!(service.is_configured(), env_key_present);
    }

    #[test]
    fn test_unconfigured_service_has_no_key() {
        assert!(!unconfigured_service().is_configured());
    }

    #[test]
    fn test_with_settings() {
        let service = TranscriptionService::with_settings(
            Some("test_key".to_string()),
            "https://custom.api.com/v1/audio/transcriptions".to_string(),
            "custom-model".to_string(),
        );
        assert!(service.is_configured());
        assert_eq!(
            service.api_url,
            "https://custom.api.com/v1/audio/transcriptions"
        );
        assert_eq!(service.model, "custom-model");
    }

    #[tokio::test]
    async fn test_transcribe_no_api_key() {
        let service = unconfigured_service();
        let result = service.transcribe("test.mp3").await;
        assert!(matches!(result, Err(TranscriptionError::NoApiKey)));
    }

    #[tokio::test]
    async fn test_transcribe_file_not_found() {
        let service = TranscriptionService::new(Some("test_key".to_string()));
        let result = service.transcribe("/nonexistent/file.mp3").await;
        assert!(matches!(result, Err(TranscriptionError::FileNotFound(_))));
    }

    #[tokio::test]
    async fn test_transcribe_safe_returns_empty_on_error() {
        let service = unconfigured_service();
        let result = service.transcribe_safe("test.mp3").await;
        assert_eq!(result, "");
    }

    #[test]
    fn test_default() {
        let service = TranscriptionService::default();
        assert_eq!(service.model, "whisper-large-v3");
    }
}