use crate::ai_context::AiContext;
use crate::ai_types::{AiError, TranscriptionResult};
use crate::config::AiCapability;
use std::io::Cursor;
use bytes::Bytes;
use reqwest::blocking::{RequestBuilder, multipart};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TranscriptionTask {
Transcribe,
Translate,
}
impl TranscriptionTask {
pub fn as_str(self) -> &'static str {
match self {
Self::Transcribe => "transcribe",
Self::Translate => "translate",
}
}
fn capability(self) -> AiCapability {
match self {
Self::Transcribe => AiCapability::AudioTranscribe,
Self::Translate => AiCapability::AudioTranslate,
}
}
fn endpoint_path(self) -> &'static str {
match self {
Self::Transcribe => "/v1/audio/transcriptions",
Self::Translate => "/v1/audio/translations",
}
}
}
pub fn transcribe(
cfg: &AiContext,
bytes: Vec<u8>,
file_name: &str,
mime: &str,
task: TranscriptionTask,
language: Option<&str>,
) -> Result<TranscriptionResult, AiError> {
let transport = super::AiTransport::new(cfg)?;
let capability = task.capability();
let url = endpoint_url(cfg, task)?;
let bytes = Bytes::from(bytes);
let file_name = file_name.to_string();
let mime = mime.to_string();
let language = language.map(str::to_string);
let _permit = cfg.limiter.acquire();
let value = super::retry_with_backoff(
|| {
let request = build_request(
&transport,
capability,
&url,
bytes.clone(),
&file_name,
&mime,
language.as_deref(),
)?;
super::parse_json_response(request.send().map_err(super::reqwest_error)?)
},
std::thread::sleep,
)?;
transport.parse_transcription(value)
}
fn endpoint_url(cfg: &AiContext, task: TranscriptionTask) -> Result<String, AiError> {
let capability = task.capability();
let binding = cfg.binding(capability);
let api_base = binding
.api_base
.as_deref()
.map(str::trim)
.filter(|value| !value.is_empty())
.ok_or_else(|| {
AiError::not_configured(
Some(capability.as_str().to_string()),
format!(
"{}.api_base is required for direct audio {}",
capability.namespace(),
task.as_str()
),
)
})?;
Ok(format!(
"{}{}",
api_base.trim_end_matches('/'),
task.endpoint_path()
))
}
fn build_request(
transport: &super::AiTransport<'_>,
capability: AiCapability,
url: &str,
bytes: Bytes,
file_name: &str,
mime: &str,
language: Option<&str>,
) -> Result<RequestBuilder, AiError> {
let binding = transport.context.binding(capability);
let file_len = u64::try_from(bytes.len()).map_err(|_| {
AiError::parse_failure("transcription payload is too large to send".to_string())
})?;
let file_part = multipart::Part::reader_with_length(Cursor::new(bytes), file_len)
.file_name(file_name.to_string())
.mime_str(mime)
.map_err(|error| {
AiError::parse_failure(format!("invalid transcription MIME type {mime}: {error}"))
})?;
let mut form = multipart::Form::new()
.part("file", file_part)
.text("response_format", "verbose_json");
if let Some(model) = binding.model.as_deref().filter(|value| !value.is_empty()) {
form = form.text("model", model.to_string());
}
if let Some(language) = language
.filter(|value| !value.is_empty())
.or(binding.language.as_deref())
{
form = form.text("language", language.to_string());
}
Ok(super::apply_api_key(
transport
.client
.post(url)
.timeout(super::timeout_for(capability))
.multipart(form),
binding,
))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ai_context::{AiBindings, AiLimiter};
use crate::config::{AiRouting, AiTuning, CapabilityBinding};
use crate::test_http::{RequestHandle, spawn_json_response};
#[test]
fn builds_multipart_and_parses_segments() {
let response = r#"{"text":"hello world","language":"es","model":"whisper-1","task":"translate","segments":[{"start":0.125,"end":1.5,"text":"hello"},{"start":1.5,"end":2.0,"text":" world"}]}"#;
let (api_base, request) = spawn_server(response);
let cfg = test_context(&api_base, None);
let result = transcribe(
&cfg,
b"hola mundo".to_vec(),
"clip.webm",
"audio/webm",
TranscriptionTask::Translate,
Some("es"),
)
.unwrap();
let request = request.join().unwrap().unwrap();
assert!(request.starts_with("POST /v1/audio/translations HTTP/1.1"));
assert!(request.contains("filename=\"clip.webm\""));
assert!(request.contains("name=\"response_format\"\r\n\r\nverbose_json"));
assert!(request.contains("name=\"language\"\r\n\r\nes"));
assert!(request.contains("name=\"model\"\r\n\r\nwhisper-1"));
assert_eq!(result.language.as_deref(), Some("es"));
assert_eq!(result.segments[0].start_ms, 125);
assert_eq!(result.segments[0].end_ms, 1500);
assert_eq!(result.segments[1].start_ms, 1500);
assert_eq!(result.segments[1].end_ms, 2000);
}
#[test]
fn wire_multipart_filename_and_auth() {
let response = r#"{"text":"hello","language":"en","segments":[]}"#;
let (api_base, request) = spawn_server(response);
let cfg = test_context(&api_base, Some("test-token"));
transcribe(
&cfg,
b"audio bytes".to_vec(),
"meeting.m4a",
"audio/mp4",
TranscriptionTask::Transcribe,
None,
)
.unwrap();
let request = request.join().unwrap().unwrap();
assert!(request.starts_with("POST /v1/audio/transcriptions HTTP/1.1"));
assert!(has_header(&request, "authorization", "Bearer test-token"));
assert!(request.contains("name=\"file\"; filename=\"meeting.m4a\""));
assert!(request.contains("Content-Type: audio/mp4"));
}
fn spawn_server(response: &'static str) -> (String, RequestHandle) {
spawn_json_response(response).expect("spawn test server")
}
fn has_header(request: &str, name: &str, value: &str) -> bool {
request.lines().any(|line| {
let Some((header_name, header_value)) = line.split_once(':') else {
return false;
};
header_name.eq_ignore_ascii_case(name) && header_value.trim() == value
})
}
fn test_context(api_base: &str, api_key: Option<&str>) -> AiContext {
let binding = binding(api_base, api_key);
AiContext {
bindings: AiBindings {
embed: binding.clone(),
audio_transcribe: binding.clone(),
audio_translate: binding.clone(),
vision_extract: binding.clone(),
text_generate: binding,
},
tuning: AiTuning {
max_concurrency: 1,
keep_alive: None,
},
limiter: AiLimiter::new(1),
project_id: None,
}
}
fn binding(api_base: &str, api_key: Option<&str>) -> CapabilityBinding {
CapabilityBinding {
routing: AiRouting::Direct,
transport: None,
api_base: Some(api_base.to_string()),
api_key: api_key.map(str::to_string),
model: Some("whisper-1".to_string()),
provider: None,
task: None,
language: None,
target_lang: None,
}
}
}