batuta/serve/banco/
handlers_audio.rs1use axum::{extract::State, http::StatusCode, response::Json};
7use serde::{Deserialize, Serialize};
8
9use super::state::BancoState;
10use super::types::ErrorResponse;
11
12pub async fn transcribe_handler(
14 State(_state): State<BancoState>,
15 Json(request): Json<TranscribeRequest>,
16) -> Result<Json<TranscribeResponse>, (StatusCode, Json<ErrorResponse>)> {
17 transcribe_audio(&request)
18}
19
20pub async fn audio_formats_handler() -> Json<AudioFormatsResponse> {
22 Json(AudioFormatsResponse {
23 formats: vec![
24 AudioFormat { extension: "wav".to_string(), mime: "audio/wav".to_string() },
25 AudioFormat { extension: "mp3".to_string(), mime: "audio/mpeg".to_string() },
26 AudioFormat { extension: "flac".to_string(), mime: "audio/flac".to_string() },
27 AudioFormat { extension: "ogg".to_string(), mime: "audio/ogg".to_string() },
28 ],
29 sample_rate: 16000,
30 engine: if cfg!(feature = "speech") { "whisper-apr" } else { "dry-run" }.to_string(),
31 })
32}
33
/// Transcribes the request's audio with `whisper_apr` (`speech` feature builds only).
///
/// Steps: base64-decode `audio_data`, decode the bytes into samples using the
/// requested format (defaulting to `"wav"`), then run the model with the
/// requested language and task.
///
/// # Errors
///
/// * `400` / `"invalid_audio"` when `audio_data` is not valid base64.
/// * `400` / `"audio_error"` when `whisper_apr` cannot decode the audio bytes.
/// * `500` / `"transcription_error"` when the model's `transcribe` call fails.
#[cfg(feature = "speech")]
fn transcribe_audio(
    request: &TranscribeRequest,
) -> Result<Json<TranscribeResponse>, (StatusCode, Json<ErrorResponse>)> {
    let audio_bytes = match base64_decode(&request.audio_data) {
        Ok(bytes) => bytes,
        Err(e) => {
            let body =
                ErrorResponse::new(format!("Invalid base64 audio: {e}"), "invalid_audio", 400);
            return Err((StatusCode::BAD_REQUEST, Json(body)));
        }
    };

    // The format string is forwarded to the loader as-is; "wav" is the fallback.
    let ext = request.format.as_deref().unwrap_or("wav");

    let samples = match whisper_apr::audio::load_audio_samples(&audio_bytes, ext) {
        Ok(decoded) => decoded,
        Err(e) => {
            let body = ErrorResponse::new(format!("Audio decode failed: {e}"), "audio_error", 400);
            return Err((StatusCode::BAD_REQUEST, Json(body)));
        }
    };

    // A missing `translate` flag is treated the same as `false`.
    let task = match request.translate {
        Some(true) => whisper_apr::Task::Translate,
        Some(false) | None => whisper_apr::Task::Transcribe,
    };
    let options =
        whisper_apr::TranscribeOptions { language: request.language.clone(), task, ..Default::default() };

    // NOTE(review): a model is constructed on every call — confirm whether
    // `WhisperApr::tiny()` is cheap or should be cached in shared state.
    let model = whisper_apr::WhisperApr::tiny();
    let result = match model.transcribe(&samples, options) {
        Ok(transcript) => transcript,
        Err(e) => {
            let body = ErrorResponse::new(
                format!("Transcription failed: {e}"),
                "transcription_error",
                500,
            );
            return Err((StatusCode::INTERNAL_SERVER_ERROR, Json(body)));
        }
    };

    // Duration is sample count / 16000; presumably the loader yields 16 kHz mono
    // samples — TODO confirm against whisper_apr.
    let duration_secs = samples.len() as f32 / 16000.0;

    let segments = result
        .segments
        .into_iter()
        .map(|seg| TranscribeSegment { start: seg.start, end: seg.end, text: seg.text })
        .collect();

    Ok(Json(TranscribeResponse {
        text: result.text,
        language: result.language,
        duration_secs,
        segments,
    }))
}
95
96#[cfg(not(feature = "speech"))]
101fn transcribe_audio(
102 request: &TranscribeRequest,
103) -> Result<Json<TranscribeResponse>, (StatusCode, Json<ErrorResponse>)> {
104 let audio_len = request.audio_data.len();
105 let estimated_bytes = audio_len * 3 / 4; let estimated_duration = estimated_bytes as f32 / 32000.0;
108
109 Ok(Json(TranscribeResponse {
110 text: format!(
111 "[dry-run] Would transcribe {} bytes of {} audio (~{:.1}s). Enable --features speech for real transcription.",
112 audio_len,
113 request.format.as_deref().unwrap_or("wav"),
114 estimated_duration
115 ),
116 language: request.language.clone().unwrap_or_else(|| "en".to_string()),
117 duration_secs: estimated_duration,
118 segments: vec![],
119 }))
120}
121
/// Decodes standard-alphabet base64 (`A-Z`, `a-z`, `0-9`, `+`, `/`) into bytes.
///
/// Leading and trailing whitespace is trimmed, and any `'\n'`, `'\r'` or `' '`
/// inside the input is skipped. Decoding stops at the first `'='`; anything
/// after it is ignored. The decoder is deliberately lenient: padding is
/// optional, and leftover bits that do not complete a byte are dropped rather
/// than reported.
///
/// # Errors
///
/// Returns `"Invalid base64 character"` when a byte outside the alphabet
/// (other than the skipped whitespace and `'='`) is encountered.
pub(crate) fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    // Maps one alphabet byte to its 6-bit value in constant time, replacing a
    // per-call heap-allocated table and a linear scan per character.
    fn sextet(byte: u8) -> Option<u32> {
        match byte {
            b'A'..=b'Z' => Some(u32::from(byte - b'A')),
            b'a'..=b'z' => Some(u32::from(byte - b'a') + 26),
            b'0'..=b'9' => Some(u32::from(byte - b'0') + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let trimmed = input.trim();
    let mut output = Vec::with_capacity(trimmed.len() / 4 * 3 + 3);
    let mut buf: u32 = 0;
    let mut bits: u32 = 0;

    for byte in trimmed.bytes() {
        match byte {
            // Line breaks and spaces are skipped in place instead of building
            // a cleaned copy of the whole input first.
            b'\n' | b'\r' | b' ' => continue,
            // Padding marks the end of the data.
            b'=' => break,
            _ => {}
        }

        let val = sextet(byte).ok_or("Invalid base64 character")?;
        buf = (buf << 6) | val;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // `buf` holds at most `bits + 8` significant bits here, so the
            // shifted value always fits in a byte.
            output.push((buf >> bits) as u8);
            // Keep only the bits that have not been emitted yet.
            buf &= (1 << bits) - 1;
        }
    }
    Ok(output)
}
148
149#[derive(Debug, Clone, Deserialize)]
155pub struct TranscribeRequest {
156 pub audio_data: String,
158 #[serde(default)]
160 pub format: Option<String>,
161 #[serde(default)]
163 pub language: Option<String>,
164 #[serde(default)]
166 pub translate: Option<bool>,
167}
168
169#[derive(Debug, Clone, Serialize)]
171pub struct TranscribeResponse {
172 pub text: String,
173 pub language: String,
174 pub duration_secs: f32,
175 pub segments: Vec<TranscribeSegment>,
176}
177
178#[derive(Debug, Clone, Serialize)]
180pub struct TranscribeSegment {
181 pub start: f32,
182 pub end: f32,
183 pub text: String,
184}
185
186#[derive(Debug, Serialize)]
188pub struct AudioFormatsResponse {
189 pub formats: Vec<AudioFormat>,
190 pub sample_rate: u32,
191 pub engine: String,
192}
193
194#[derive(Debug, Serialize)]
196pub struct AudioFormat {
197 pub extension: String,
198 pub mime: String,
199}