use std::time::Duration;
use anyhow::Result;
use reqwest::multipart;
/// Connection settings for a transcription provider endpoint.
#[derive(Clone)]
pub struct ProviderConfig {
    /// URL the transcription request is POSTed to.
    pub url: String,
    /// Model name sent as the `model` field of the multipart form.
    pub model: String,
    /// Optional key; when present it is sent as a bearer token.
    pub api_key: Option<String>,
}

// `Debug` is written by hand rather than derived so that formatting a config
// with `{:?}` (e.g. in a log line) never prints the API key itself.
impl std::fmt::Debug for ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderConfig")
            .field("url", &self.url)
            .field("model", &self.model)
            // Still shows whether a key is configured (`Some`/`None`).
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl ProviderConfig {
    /// Test-only config pointing at a provider on `127.0.0.1:5200`, with no API key.
    #[cfg(test)]
    pub fn default_local() -> Self {
        Self {
            url: "http://127.0.0.1:5200/v1/audio/transcriptions".to_string(),
            model: "Systran/faster-whisper-small".to_string(),
            api_key: None,
        }
    }
}
/// Failure modes of a transcription request. Each variant carries a
/// human-readable detail string.
#[derive(Debug)]
pub enum TranscribeError {
    /// The provider could not be reached (connect failure or timeout) or
    /// answered with HTTP 502, 503 or 504.
    ProviderUnavailable(String),
    /// The provider answered with another non-success status, or its success
    /// body could not be decoded as the expected JSON.
    ProviderError(String),
    /// Building the HTTP client or request failed, or sending failed for a
    /// reason other than connect/timeout.
    NetworkError(String),
}

impl std::fmt::Display for TranscribeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::ProviderUnavailable(s) => ("provider unavailable", s),
            Self::ProviderError(s) => ("provider error", s),
            Self::NetworkError(s) => ("network error", s),
        };
        write!(f, "{label}: {detail}")
    }
}

// Lets the error be boxed as `dyn Error` and converted by `?` into error
// types that accept any `std::error::Error`.
impl std::error::Error for TranscribeError {}
/// Uploads `audio_bytes` to the provider described by `config` and returns
/// the transcript.
///
/// The audio is sent as a multipart form: a `file` part (named `filename`,
/// MIME type `audio/wav`) plus a `model` text field. When `config.api_key` is
/// set it is attached as a bearer token. The request uses a 60-second timeout.
///
/// Empty `audio_bytes` short-circuits to `Ok("")` without any network call.
/// On success the `text` field of the JSON response is returned, trimmed.
///
/// # Errors
///
/// * [`TranscribeError::ProviderUnavailable`] — connect failure, timeout, or
///   an HTTP 502/503/504 reply.
/// * [`TranscribeError::ProviderError`] — any other non-success status, or a
///   success body that is not the expected JSON.
/// * [`TranscribeError::NetworkError`] — client/request construction failed,
///   or sending failed for another reason.
pub async fn transcribe_with_provider(
    config: &ProviderConfig,
    audio_bytes: Vec<u8>,
    filename: &str,
) -> Result<String, TranscribeError> {
    /// Success payload; only the `text` field is read.
    #[derive(serde::Deserialize)]
    struct TranscribeResponse {
        text: String,
    }

    // Nothing to transcribe, so skip the round-trip entirely.
    if audio_bytes.is_empty() {
        return Ok(String::new());
    }

    let http = reqwest::Client::builder()
        .timeout(Duration::from_secs(60))
        .build()
        .map_err(|err| TranscribeError::NetworkError(err.to_string()))?;

    let audio_part = multipart::Part::bytes(audio_bytes)
        .file_name(filename.to_string())
        .mime_str("audio/wav")
        .map_err(|err| TranscribeError::NetworkError(err.to_string()))?;
    let payload = multipart::Form::new()
        .part("file", audio_part)
        .text("model", config.model.clone());

    let request = http.post(&config.url).multipart(payload);
    let request = match &config.api_key {
        Some(key) => request.bearer_auth(key),
        None => request,
    };

    let response = match request.send().await {
        Ok(reply) => reply,
        // Connect failures and timeouts mean the provider itself is not usable.
        Err(err) if err.is_connect() || err.is_timeout() => {
            return Err(TranscribeError::ProviderUnavailable(err.to_string()));
        }
        Err(err) => return Err(TranscribeError::NetworkError(err.to_string())),
    };

    let status = response.status();
    if status.is_success() {
        let parsed: TranscribeResponse = response
            .json()
            .await
            .map_err(|e| TranscribeError::ProviderError(format!("invalid json: {e}")))?;
        return Ok(parsed.text.trim().to_string());
    }

    // Failure path: keep whatever body the provider sent (best effort) as detail.
    let code = status.as_u16();
    let body = response.text().await.unwrap_or_default();
    let detail = format!("{code} {body}");
    Err(match code {
        502 | 503 | 504 => TranscribeError::ProviderUnavailable(detail),
        _ => TranscribeError::ProviderError(detail),
    })
}
pub async fn probe_provider(url: &str) -> bool {
let health_url = match reqwest::Url::parse(url) {
Ok(mut u) => {
u.set_path("/health");
u.set_query(None);
u.to_string()
}
Err(_) => format!("{}/health", url.trim_end_matches('/')),
};
let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(3))
.build()
{
Ok(c) => c,
Err(_) => return false,
};
if client
.get(&health_url)
.send()
.await
.map(|r| r.status().as_u16() < 500)
.unwrap_or(false)
{
return true;
}
client
.head(url)
.send()
.await
.map(|r| r.status().as_u16() < 500)
.unwrap_or(false)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The test-only local config targets port 5200 and carries no API key.
    #[test]
    fn default_local_config_has_expected_url() {
        let cfg = ProviderConfig::default_local();
        assert!(cfg.url.contains("5200"));
        assert!(cfg.api_key.is_none());
    }

    /// Empty audio must yield an empty transcript.
    #[tokio::test]
    async fn transcribe_with_empty_bytes_returns_empty_string() {
        let cfg = ProviderConfig::default_local();
        let result = transcribe_with_provider(&cfg, vec![], "speech.wav").await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "");
    }

    /// A request to an address nobody listens on must fail with a
    /// connectivity-class error.
    #[tokio::test]
    async fn transcribe_unreachable_provider_returns_unavailable() {
        // Ask the OS for a free port and release it straight away, instead of
        // hard-coding a port number that another process may be using.
        let free_port = {
            let placeholder = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            placeholder.local_addr().unwrap().port()
        };
        let mut cfg = ProviderConfig::default_local();
        cfg.url = format!("http://127.0.0.1:{free_port}/v1/audio/transcriptions");
        let result = transcribe_with_provider(&cfg, vec![0u8; 100], "speech.wav").await;
        assert!(matches!(
            result,
            Err(TranscribeError::ProviderUnavailable(_) | TranscribeError::NetworkError(_))
        ));
    }

    /// Round-trip against an in-process axum server that always replies with
    /// `{"text": "hello world"}`.
    #[tokio::test]
    async fn transcribe_mock_server() {
        use axum::{routing::post, Json as AxumJson, Router};
        use std::net::SocketAddr;

        async fn mock_handler() -> AxumJson<serde_json::Value> {
            AxumJson(serde_json::json!({ "text": "hello world" }))
        }

        let app = Router::new().route("/v1/audio/transcriptions", post(mock_handler));
        // Port 0 lets the OS choose a free port for the mock server.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr: SocketAddr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        let mut cfg = ProviderConfig::default_local();
        cfg.url = format!("http://{addr}/v1/audio/transcriptions");
        let audio = vec![0u8; 100];
        let result = transcribe_with_provider(&cfg, audio, "speech.wav").await;
        assert!(result.is_ok(), "mock server should return ok: {result:?}");
        assert_eq!(result.unwrap(), "hello world");
    }
}