use super::{AnalysisError, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// A span of transcribed text, optionally with per-word timings.
///
/// Optional fields that are `None` are omitted when serialized. The vLLM
/// transcriber in this file only receives whole-file text, so it emits a single
/// segment whose `start` and `end` are both `0.0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TranscriptSegment {
    /// Segment start time (presumably seconds from the start of the audio — TODO confirm).
    pub start: f64,
    /// Segment end time, in the same unit as `start`.
    pub end: f64,
    /// Transcribed text of the segment.
    pub text: String,
    /// Per-word timings, when the backend provides them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<WordTiming>>,
    /// Language of the text (e.g. `"en"`), as supplied by the backend or caller.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    /// Segment-level confidence, when provided (range presumably 0.0–1.0 — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub confidence: Option<f32>,
}
/// Timing information for a single word of a transcript segment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WordTiming {
    /// The word as transcribed.
    pub word: String,
    /// Word start time (presumably seconds, like the segment times — TODO confirm).
    pub start: f64,
    /// Word end time, in the same unit as `start`.
    pub end: f64,
    /// Per-word confidence, when provided; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub confidence: Option<f32>,
}
/// Directories searched, in order, for a Parakeet `.gguf` model file.
///
/// A leading `~` is expanded to `$HOME`; such entries are skipped when `$HOME`
/// is unset or empty.
const PARAKEET_MODEL_SEARCH_PATHS: &[&str] = &[
    "~/.cache/nab/models",
    "~/.cache/parakeet",
    "/usr/local/share/parakeet/models",
    "/opt/parakeet/models",
];

/// Directories searched, in order, for a Parakeet executable when `which`
/// finds none. A leading `~` is expanded as for the model search paths.
const PARAKEET_BINARY_SEARCH_PATHS: &[&str] =
    &["~/.local/bin", "/usr/local/bin", "/opt/parakeet/bin"];

/// Executable names accepted as the Parakeet CLI, in order of preference.
const PARAKEET_BINARY_NAMES: [&str; 2] = ["parakeet", "parakeet-cli"];

/// Returns `$HOME` as a path, or `None` when it is unset or empty.
fn home_directory() -> Option<std::path::PathBuf> {
    std::env::var_os("HOME")
        .filter(|home| !home.is_empty())
        .map(std::path::PathBuf::from)
}

/// Expands a leading `~` (alone or followed by `/`) in a search-path entry.
///
/// Only a leading tilde is special: `~` elsewhere in the entry, and `~user/...`
/// forms, are kept literally. Returns `None` for a tilde entry when `home` is
/// `None`, so a missing `$HOME` never turns `~/.cache` into `/.cache`.
fn expand_search_dir(dir: &str, home: Option<&Path>) -> Option<std::path::PathBuf> {
    match dir.strip_prefix('~') {
        Some(rest) if rest.is_empty() || rest.starts_with('/') => {
            Some(home?.join(rest.trim_start_matches('/')))
        }
        _ => Some(std::path::PathBuf::from(dir)),
    }
}

/// Resolves `name` with the `which` command.
///
/// Returns the first path `which` prints, provided that `which` exits
/// successfully and the path is a regular file (or a symlink to one).
fn resolve_with_which(name: &str) -> Option<std::path::PathBuf> {
    let output = std::process::Command::new("which").arg(name).output().ok()?;
    if !output.status.success() {
        return None;
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    let resolved = std::path::PathBuf::from(stdout.lines().next()?.trim());
    resolved.is_file().then_some(resolved)
}

/// Locates a Parakeet CLI executable.
///
/// `which` is consulted first for each name in `PARAKEET_BINARY_NAMES`. If that
/// finds nothing, each directory in `PARAKEET_BINARY_SEARCH_PATHS` is checked
/// for those names, in order. Only regular files (or symlinks to them) match, so
/// a directory that happens to be called `parakeet` is ignored.
fn detect_parakeet_binary() -> Option<std::path::PathBuf> {
    PARAKEET_BINARY_NAMES
        .iter()
        .copied()
        .find_map(resolve_with_which)
        .or_else(|| {
            let home = home_directory();
            PARAKEET_BINARY_SEARCH_PATHS
                .iter()
                .filter_map(|dir| expand_search_dir(dir, home.as_deref()))
                .flat_map(|dir| PARAKEET_BINARY_NAMES.map(|name| dir.join(name)))
                .find(|candidate| candidate.is_file())
        })
}

/// Locates a Parakeet model file.
///
/// Directories in `PARAKEET_MODEL_SEARCH_PATHS` are tried in order, and the
/// first one containing a match wins. See `find_parakeet_model_in` for the
/// matching rule.
fn detect_parakeet_model() -> Option<std::path::PathBuf> {
    let home = home_directory();
    PARAKEET_MODEL_SEARCH_PATHS
        .iter()
        .filter_map(|dir| expand_search_dir(dir, home.as_deref()))
        .find_map(|dir| find_parakeet_model_in(&dir))
}

/// Returns the lexicographically smallest file directly inside `dir` whose name
/// contains `parakeet` and ends with `.gguf`.
///
/// Returns `None` if there is no such file or `dir` cannot be read. `read_dir`
/// yields entries in a platform-dependent order, so the smallest path is chosen
/// to make the selection deterministic. Unreadable entries are skipped, because
/// this is a best-effort search.
fn find_parakeet_model_in(dir: &Path) -> Option<std::path::PathBuf> {
    std::fs::read_dir(dir)
        .ok()?
        .flatten()
        .filter(|entry| {
            let file_name = entry.file_name();
            let name = file_name.to_string_lossy();
            name.contains("parakeet") && name.ends_with(".gguf")
        })
        .map(|entry| entry.path())
        .filter(|path| path.is_file())
        .min()
}
/// Model requested when none is configured explicitly. Used by
/// `VllmTranscriber::default_local` and by the vLLM fallback of
/// `TranscriptionBackend::auto_detect`.
pub const DEFAULT_VLLM_MODEL: &str = "Qwen/Qwen3-ASR-1.7B";
/// Base URL of a server on the local machine. Endpoint paths such as
/// `/v1/audio/transcriptions` are appended to it.
pub const DEFAULT_VLLM_BASE_URL: &str = "http://localhost:8000";
/// Which speech-to-text engine to use.
///
/// Serialized as an internally tagged object whose `type` field is the
/// snake_case variant name, e.g. `{"type":"whisper"}` or
/// `{"type":"vllm_api","base_url":"...","model":"..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TranscriptionBackend {
    /// Local Parakeet CLI with a `.gguf` model. `auto_detect` chooses it when
    /// both are found on this machine.
    Parakeet,
    /// Presumably Parakeet running on another host. It is not constructed or
    /// dispatched in this file — TODO confirm.
    ParakeetRemote,
    /// Whisper backend. `auto_detect` chooses it when `which python3` succeeds
    /// (NOTE(review): presumably a Python Whisper install — confirm).
    Whisper,
    /// Presumably Whisper running on another host. It is not constructed or
    /// dispatched in this file — TODO confirm.
    WhisperRemote,
    /// Transcription over HTTP via a vLLM server.
    VllmApi {
        /// Server root URL, e.g. `http://localhost:8000`.
        base_url: String,
        /// Model identifier to request, e.g. `Qwen/Qwen3-ASR-1.7B`.
        model: String,
    },
}
impl TranscriptionBackend {
    /// Chooses a transcription backend by probing the local machine.
    ///
    /// Preference order:
    /// 1. `Parakeet`, when both a Parakeet executable and a Parakeet model are found;
    /// 2. `Whisper`, when `which python3` succeeds. Only the interpreter is probed,
    ///    not any Whisper package (NOTE(review): confirm nothing more is needed);
    /// 3. otherwise `VllmApi`, pointed at `DEFAULT_VLLM_BASE_URL` with
    ///    `DEFAULT_VLLM_MODEL`. The server is not contacted.
    ///
    /// Probing spawns `which` and reads directories, so keep it off hot paths.
    #[must_use]
    pub fn auto_detect() -> Self {
        if Self::parakeet_available() {
            Self::Parakeet
        } else if Self::python3_available() {
            Self::Whisper
        } else {
            Self::VllmApi {
                base_url: String::from(DEFAULT_VLLM_BASE_URL),
                model: String::from(DEFAULT_VLLM_MODEL),
            }
        }
    }

    /// Returns true when a Parakeet executable and a Parakeet model can both be
    /// located. The model search only runs once an executable has been found.
    fn parakeet_available() -> bool {
        detect_parakeet_binary()
            .and_then(|_binary| detect_parakeet_model())
            .is_some()
    }

    /// Returns true when `which python3` can be run and exits successfully.
    fn python3_available() -> bool {
        let probe = std::process::Command::new("which").arg("python3").output();
        matches!(probe, Ok(ref out) if out.status.success())
    }
}
/// Body of a successful `/v1/audio/transcriptions` response.
///
/// Only `text` is read. Any other fields (language, duration, ...) are ignored
/// during deserialization.
#[derive(Debug, Deserialize)]
struct TranscriptionApiResponse {
    /// The full transcript.
    text: String,
}
/// Nested `error` object of an error response shaped like `{"error": {"message": "..."}}`.
#[derive(Debug, Deserialize)]
struct ApiErrorBody {
    /// Human-readable error message; an empty string when the field is absent.
    #[serde(default)]
    message: String,
}
#[derive(Debug, Deserialize)]
struct ApiErrorEnvelope {
#[serde(default)]
error: Option<ApiErrorBody>,
#[serde(default)]
message: Option<String>,
}
impl ApiErrorEnvelope {
fn into_message(self) -> String {
self.error
.map(|e| e.message)
.or(self.message)
.filter(|s| !s.is_empty())
.unwrap_or_else(|| "unknown API error".to_string())
}
}
/// HTTP client for a server's `{base_url}/v1/audio/transcriptions` endpoint
/// (a vLLM server, per this type's name and defaults).
#[derive(Debug, Clone)]
pub struct VllmTranscriber {
    /// Server root URL, stored without trailing `/` (stripped by `new`).
    base_url: String,
    /// Value sent as the `model` multipart field.
    model: String,
}
impl VllmTranscriber {
    /// Creates a transcriber that requests `model` from the server rooted at `base_url`.
    ///
    /// Trailing `/` characters are stripped from `base_url`, so the endpoint URL
    /// built by [`Self::transcription_url`] never contains `//`.
    #[must_use]
    pub fn new(base_url: &str, model: &str) -> Self {
        Self {
            base_url: base_url.trim_end_matches('/').to_string(),
            model: model.to_string(),
        }
    }

    /// Creates a transcriber for [`DEFAULT_VLLM_BASE_URL`] and [`DEFAULT_VLLM_MODEL`].
    #[must_use]
    pub fn default_local() -> Self {
        Self::new(DEFAULT_VLLM_BASE_URL, DEFAULT_VLLM_MODEL)
    }

    /// Returns `{base_url}/v1/audio/transcriptions`.
    #[must_use]
    pub fn transcription_url(&self) -> String {
        format!("{}/v1/audio/transcriptions", self.base_url)
    }

    /// Returns the model identifier sent in the `model` form field.
    #[must_use]
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Transcribes the audio file at `audio_path`, leaving language detection to the server.
    ///
    /// Only the response's `text` is used, so the result is a single segment with
    /// `start == end == 0.0` and no word timings.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read, the request fails, the server
    /// answers with a non-success status, or the response lacks a non-empty `text`.
    pub async fn transcribe(&self, audio_path: &Path) -> Result<Vec<TranscriptSegment>> {
        // NOTE(review): a fresh client per call means connections are not reused across calls.
        let client = reqwest::Client::new();
        let text = self.post_audio(&client, audio_path).await?;
        Ok(vec![Self::untimed_segment(text, None)])
    }

    /// Like [`Self::transcribe`], but sends `language` as the `language` form field
    /// and records it on the returned segment.
    ///
    /// # Errors
    ///
    /// Same as [`Self::transcribe`].
    pub async fn transcribe_with_language(
        &self,
        audio_path: &Path,
        language: &str,
    ) -> Result<Vec<TranscriptSegment>> {
        let client = reqwest::Client::new();
        let text = self
            .post_audio_with_language(&client, audio_path, language)
            .await?;
        Ok(vec![Self::untimed_segment(text, Some(language.to_string()))])
    }

    /// Wraps a whole-file transcript in one segment that carries no timing information.
    fn untimed_segment(text: String, language: Option<String>) -> TranscriptSegment {
        TranscriptSegment {
            start: 0.0,
            end: 0.0,
            text,
            words: None,
            language,
            confidence: None,
        }
    }

    /// Uploads `audio_path` without a language hint and returns the transcript text.
    async fn post_audio(&self, client: &reqwest::Client, audio_path: &Path) -> Result<String> {
        self.post_transcription(client, audio_path, None).await
    }

    /// Uploads `audio_path` with a `language` hint and returns the transcript text.
    async fn post_audio_with_language(
        &self,
        client: &reqwest::Client,
        audio_path: &Path,
        language: &str,
    ) -> Result<String> {
        self.post_transcription(client, audio_path, Some(language))
            .await
    }

    /// Shared request path for both upload variants.
    ///
    /// Builds the multipart form, POSTs it to [`Self::transcription_url`], and
    /// extracts the transcript text from the response.
    async fn post_transcription(
        &self,
        client: &reqwest::Client,
        audio_path: &Path,
        language: Option<&str>,
    ) -> Result<String> {
        let form = self.build_multipart(audio_path, language).await?;
        let resp = client
            .post(self.transcription_url())
            .multipart(form)
            .send()
            .await?;
        self.extract_text(resp).await
    }

    /// Builds the multipart body.
    ///
    /// The body contains `file` (the audio bytes), `model`, and — when given —
    /// `language`. The whole file is read into memory. A path without a UTF-8
    /// file name is uploaded as `audio.wav`. The part's MIME type comes from
    /// [`Self::audio_mime_type`].
    async fn build_multipart(
        &self,
        audio_path: &Path,
        language: Option<&str>,
    ) -> Result<reqwest::multipart::Form> {
        let file_bytes = tokio::fs::read(audio_path).await?;
        let filename = audio_path
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("audio.wav")
            .to_string();
        let file_part = reqwest::multipart::Part::bytes(file_bytes)
            .file_name(filename)
            .mime_str(Self::audio_mime_type(audio_path))
            .map_err(|e| AnalysisError::TranscriptionApi(e.to_string()))?;
        let form = reqwest::multipart::Form::new()
            .part("file", file_part)
            .text("model", self.model.clone());
        Ok(match language {
            Some(lang) => form.text("language", lang.to_string()),
            None => form,
        })
    }

    /// Maps the file extension (case-insensitively) to an audio MIME type.
    ///
    /// Unknown or missing extensions map to `audio/wav`.
    fn audio_mime_type(audio_path: &Path) -> &'static str {
        let extension = audio_path
            .extension()
            .and_then(|ext| ext.to_str())
            .map(str::to_ascii_lowercase);
        match extension.as_deref() {
            Some("mp3" | "mpga" | "mpeg") => "audio/mpeg",
            Some("flac") => "audio/flac",
            Some("ogg" | "oga" | "opus") => "audio/ogg",
            Some("m4a" | "mp4") => "audio/mp4",
            Some("webm") => "audio/webm",
            _ => "audio/wav",
        }
    }

    /// Reads the response body and turns it into transcript text or an error.
    ///
    /// On a success status the body is parsed by [`Self::parse_response`].
    /// Otherwise the error reads `HTTP {status}: {detail}`. `detail` is one of:
    /// - the message from a JSON error body, when one can be parsed;
    /// - the raw body, when it is not such JSON;
    /// - `empty response body`, when the body is blank.
    ///
    /// # Errors
    ///
    /// Returns an error if the body cannot be read, the status is not a success,
    /// or [`Self::parse_response`] rejects the body.
    async fn extract_text(&self, resp: reqwest::Response) -> Result<String> {
        let status = resp.status();
        let body = resp.text().await?;
        if status.is_success() {
            return self.parse_response(&body);
        }
        let detail = if body.trim().is_empty() {
            "empty response body".to_string()
        } else {
            let envelope = serde_json::from_str::<ApiErrorEnvelope>(&body);
            match envelope {
                Ok(envelope) => envelope.into_message(),
                // Not a JSON error object (e.g. an HTML page from a proxy): report it verbatim.
                Err(_) => body,
            }
        };
        Err(AnalysisError::TranscriptionApi(format!(
            "HTTP {status}: {detail}"
        )))
    }

    /// Extracts the `text` field from a successful transcription response.
    ///
    /// Other fields are ignored, and whitespace inside `text` is preserved.
    ///
    /// # Errors
    ///
    /// Returns `AnalysisError::TranscriptionApi` in two cases:
    /// - `json` is not an object with a string `text` field (`malformed response: ...`);
    /// - `text` is empty.
    pub fn parse_response(&self, json: &str) -> Result<String> {
        let parsed: TranscriptionApiResponse = serde_json::from_str(json)
            .map_err(|e| AnalysisError::TranscriptionApi(format!("malformed response: {e}")))?;
        if parsed.text.is_empty() {
            return Err(AnalysisError::TranscriptionApi(
                "empty transcription text in response".to_string(),
            ));
        }
        Ok(parsed.text)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_word_timing_serialization() {
        let word = WordTiming {
            word: "hello".to_string(),
            start: 0.0,
            end: 0.5,
            confidence: Some(0.95),
        };
        let json = serde_json::to_string(&word).unwrap();
        assert!(json.contains("hello"));
        assert!(json.contains("0.95"));
    }

    #[test]
    fn test_segment_serialization() {
        let segment = TranscriptSegment {
            start: 0.0,
            end: 2.5,
            text: "Hello world".to_string(),
            words: Some(vec![
                WordTiming {
                    word: "Hello".to_string(),
                    start: 0.0,
                    end: 0.5,
                    confidence: Some(0.9),
                },
                WordTiming {
                    word: "world".to_string(),
                    start: 0.6,
                    end: 1.2,
                    confidence: Some(0.85),
                },
            ]),
            language: Some("en".to_string()),
            confidence: None,
        };
        let json = serde_json::to_string_pretty(&segment).unwrap();
        assert!(json.contains("Hello world"));
        assert!(json.contains("\"en\""));
    }

    #[test]
    fn vllm_new_uses_supplied_base_url_and_model() {
        let t = VllmTranscriber::new("http://spark:8000", "Qwen/Qwen3-ASR-8B");
        assert_eq!(t.model(), "Qwen/Qwen3-ASR-8B");
        assert_eq!(
            t.transcription_url(),
            "http://spark:8000/v1/audio/transcriptions"
        );
    }

    #[test]
    fn vllm_new_strips_trailing_slash_from_base_url() {
        let t = VllmTranscriber::new("http://localhost:8000/", "model");
        assert_eq!(
            t.transcription_url(),
            "http://localhost:8000/v1/audio/transcriptions"
        );
    }

    #[test]
    fn vllm_default_local_uses_qwen3_asr_1_7b() {
        let t = VllmTranscriber::default_local();
        assert_eq!(t.model(), DEFAULT_VLLM_MODEL);
        assert_eq!(
            t.transcription_url(),
            "http://localhost:8000/v1/audio/transcriptions"
        );
    }

    #[test]
    fn vllm_default_model_constant_is_qwen3_asr_1_7b() {
        assert_eq!(DEFAULT_VLLM_MODEL, "Qwen/Qwen3-ASR-1.7B");
    }

    #[test]
    fn vllm_parse_response_returns_text_for_valid_json() {
        let t = VllmTranscriber::default_local();
        let json = r#"{"text": "Hello, world."}"#;
        let result = t.parse_response(json).unwrap();
        assert_eq!(result, "Hello, world.");
    }

    #[test]
    fn vllm_parse_response_preserves_whitespace_in_text() {
        let t = VllmTranscriber::default_local();
        let json = r#"{"text": " spaced out "}"#;
        let result = t.parse_response(json).unwrap();
        assert_eq!(result, " spaced out ");
    }

    #[test]
    fn vllm_parse_response_handles_extra_fields_in_response() {
        let t = VllmTranscriber::default_local();
        let json =
            r#"{"text": "Bonjour.", "language": "fr", "duration": 1.2, "task": "transcribe"}"#;
        let result = t.parse_response(json).unwrap();
        assert_eq!(result, "Bonjour.");
    }

    #[test]
    fn vllm_parse_response_errors_on_malformed_json() {
        let t = VllmTranscriber::default_local();
        let err = t.parse_response("not json at all").unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("malformed response"),
            "expected 'malformed response' in: {msg}"
        );
    }

    #[test]
    fn vllm_parse_response_errors_on_empty_text_field() {
        let t = VllmTranscriber::default_local();
        let err = t.parse_response(r#"{"text": ""}"#).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("empty transcription"),
            "expected 'empty transcription' in: {msg}"
        );
    }

    #[test]
    fn vllm_parse_response_errors_on_missing_text_field() {
        let t = VllmTranscriber::default_local();
        let err = t.parse_response(r#"{"result": "something"}"#).unwrap_err();
        let msg = err.to_string();
        assert!(
            msg.contains("malformed response"),
            "expected 'malformed response' in: {msg}"
        );
    }

    #[test]
    fn api_error_envelope_prefers_nested_error_message() {
        let envelope: ApiErrorEnvelope = serde_json::from_str(
            r#"{"error": {"message": "audio too long"}, "message": "outer"}"#,
        )
        .unwrap();
        assert_eq!(envelope.into_message(), "audio too long");
    }

    #[test]
    fn api_error_envelope_uses_top_level_message() {
        let envelope: ApiErrorEnvelope = serde_json::from_str(
            r#"{"object": "error", "message": "model not found", "type": "NotFoundError", "code": 404}"#,
        )
        .unwrap();
        assert_eq!(envelope.into_message(), "model not found");
    }

    #[test]
    fn api_error_envelope_without_any_message_is_unknown() {
        let envelope: ApiErrorEnvelope = serde_json::from_str("{}").unwrap();
        assert_eq!(envelope.into_message(), "unknown API error");
    }

    #[test]
    fn backend_enum_vllm_api_round_trips_through_json() {
        let backend = TranscriptionBackend::VllmApi {
            base_url: "http://localhost:8000".to_string(),
            model: "Qwen/Qwen3-ASR-1.7B".to_string(),
        };
        let json = serde_json::to_string(&backend).unwrap();
        assert!(json.contains("vllm_api"), "tag missing: {json}");
        assert!(json.contains("Qwen3-ASR"), "model missing: {json}");
        let back: TranscriptionBackend = serde_json::from_str(&json).unwrap();
        let TranscriptionBackend::VllmApi { base_url, model } = back else {
            panic!("wrong variant after round-trip");
        };
        assert_eq!(base_url, "http://localhost:8000");
        assert_eq!(model, "Qwen/Qwen3-ASR-1.7B");
    }

    #[test]
    fn backend_enum_whisper_variant_serialises_correctly() {
        let backend = TranscriptionBackend::Whisper;
        let json = serde_json::to_string(&backend).unwrap();
        assert!(json.contains("\"whisper\""), "unexpected json: {json}");
    }

    #[test]
    fn backend_enum_whisper_remote_variant_serialises_correctly() {
        let backend = TranscriptionBackend::WhisperRemote;
        let json = serde_json::to_string(&backend).unwrap();
        assert!(json.contains("whisper_remote"), "unexpected json: {json}");
    }

    #[test]
    fn backend_enum_parakeet_variant_serialises_correctly() {
        let backend = TranscriptionBackend::Parakeet;
        let json = serde_json::to_string(&backend).unwrap();
        assert!(json.contains("\"parakeet\""), "unexpected json: {json}");
    }

    #[test]
    fn backend_enum_parakeet_remote_variant_serialises_correctly() {
        let backend = TranscriptionBackend::ParakeetRemote;
        let json = serde_json::to_string(&backend).unwrap();
        assert!(json.contains("parakeet_remote"), "unexpected json: {json}");
    }

    #[test]
    fn backend_enum_parakeet_round_trips_through_json() {
        let backend = TranscriptionBackend::Parakeet;
        let json = serde_json::to_string(&backend).unwrap();
        let back: TranscriptionBackend = serde_json::from_str(&json).unwrap();
        assert!(
            matches!(back, TranscriptionBackend::Parakeet),
            "wrong variant after round-trip: {json}"
        );
    }

    #[test]
    fn auto_detect_result_round_trips_through_json() {
        // The detected variant depends on the host (Parakeet install, python3 on
        // PATH), so only environment-independent properties are asserted here.
        let backend = TranscriptionBackend::auto_detect();
        let json = serde_json::to_string(&backend).unwrap();
        let back: TranscriptionBackend = serde_json::from_str(&json).unwrap();
        assert_eq!(
            serde_json::to_string(&back).unwrap(),
            json,
            "auto-detected backend did not survive a JSON round-trip"
        );
    }

    #[test]
    fn auto_detect_vllm_fallback_uses_default_url_and_model() {
        // Only meaningful on hosts without Parakeet or python3; elsewhere the
        // fallback branch is not taken and there is nothing to check.
        if let TranscriptionBackend::VllmApi { base_url, model } =
            TranscriptionBackend::auto_detect()
        {
            assert_eq!(base_url, DEFAULT_VLLM_BASE_URL);
            assert_eq!(model, DEFAULT_VLLM_MODEL);
        }
    }
}