// batuta/serve/banco/handlers_audio.rs
//! Audio transcription handler — speech-to-text via whisper-apr.
//!
//! With `speech` feature: real transcription using whisper-apr.
//! Without: dry-run response for API testing.

6use axum::{extract::State, http::StatusCode, response::Json};
7use serde::{Deserialize, Serialize};
8
9use super::state::BancoState;
10use super::types::ErrorResponse;
11
12/// POST /api/v1/audio/transcriptions — transcribe audio to text.
13pub async fn transcribe_handler(
14    State(_state): State<BancoState>,
15    Json(request): Json<TranscribeRequest>,
16) -> Result<Json<TranscribeResponse>, (StatusCode, Json<ErrorResponse>)> {
17    transcribe_audio(&request)
18}
19
20/// GET /api/v1/audio/formats — list supported audio formats.
21pub async fn audio_formats_handler() -> Json<AudioFormatsResponse> {
22    Json(AudioFormatsResponse {
23        formats: vec![
24            AudioFormat { extension: "wav".to_string(), mime: "audio/wav".to_string() },
25            AudioFormat { extension: "mp3".to_string(), mime: "audio/mpeg".to_string() },
26            AudioFormat { extension: "flac".to_string(), mime: "audio/flac".to_string() },
27            AudioFormat { extension: "ogg".to_string(), mime: "audio/ogg".to_string() },
28        ],
29        sample_rate: 16000,
30        engine: if cfg!(feature = "speech") { "whisper-apr" } else { "dry-run" }.to_string(),
31    })
32}
33
34// ============================================================================
35// whisper-apr transcription (speech feature)
36// ============================================================================
37
#[cfg(feature = "speech")]
fn transcribe_audio(
    request: &TranscribeRequest,
) -> Result<Json<TranscribeResponse>, (StatusCode, Json<ErrorResponse>)> {
    // Reject payloads that are not valid base64 before any audio work.
    let audio_bytes = match base64_decode(&request.audio_data) {
        Ok(bytes) => bytes,
        Err(e) => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(ErrorResponse::new(format!("Invalid base64 audio: {e}"), "invalid_audio", 400)),
            ));
        }
    };

    // Container format defaults to WAV when the client omits it.
    let ext = request.format.as_deref().unwrap_or("wav");

    // Decode the container/codec into raw samples.
    let samples = match whisper_apr::audio::load_audio_samples(&audio_bytes, ext) {
        Ok(s) => s,
        Err(e) => {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(ErrorResponse::new(format!("Audio decode failed: {e}"), "audio_error", 400)),
            ));
        }
    };

    // Task selection: translate-to-English vs. plain transcription.
    let task = if request.translate.unwrap_or(false) {
        whisper_apr::Task::Translate
    } else {
        whisper_apr::Task::Transcribe
    };
    let options = whisper_apr::TranscribeOptions {
        language: request.language.clone(),
        task,
        ..Default::default()
    };

    // Run the tiny whisper model over the decoded samples.
    let model = whisper_apr::WhisperApr::tiny();
    let result = match model.transcribe(&samples, options) {
        Ok(r) => r,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse::new(
                    format!("Transcription failed: {e}"),
                    "transcription_error",
                    500,
                )),
            ));
        }
    };

    let segments = result
        .segments
        .into_iter()
        .map(|s| TranscribeSegment { start: s.start, end: s.end, text: s.text })
        .collect();

    Ok(Json(TranscribeResponse {
        text: result.text,
        language: result.language,
        // NOTE(review): assumes 16 kHz mono samples — confirm against
        // whisper_apr::audio::load_audio_samples' output rate.
        duration_secs: samples.len() as f32 / 16000.0,
        segments,
    }))
}
95
96// ============================================================================
97// Dry-run transcription (no speech feature)
98// ============================================================================
99
100#[cfg(not(feature = "speech"))]
101fn transcribe_audio(
102    request: &TranscribeRequest,
103) -> Result<Json<TranscribeResponse>, (StatusCode, Json<ErrorResponse>)> {
104    let audio_len = request.audio_data.len();
105    // Estimate duration from base64 size (rough: 16kHz mono 16-bit = 32KB/sec)
106    let estimated_bytes = audio_len * 3 / 4; // base64 → raw
107    let estimated_duration = estimated_bytes as f32 / 32000.0;
108
109    Ok(Json(TranscribeResponse {
110        text: format!(
111            "[dry-run] Would transcribe {} bytes of {} audio (~{:.1}s). Enable --features speech for real transcription.",
112            audio_len,
113            request.format.as_deref().unwrap_or("wav"),
114            estimated_duration
115        ),
116        language: request.language.clone().unwrap_or_else(|| "en".to_string()),
117        duration_secs: estimated_duration,
118        segments: vec![],
119    }))
120}
121
/// Simple base64 decoder (no external dependency).
///
/// Decodes the standard alphabet (RFC 4648). Lenient: all ASCII whitespace
/// is skipped (the previous version stripped only `\n`, `\r`, and space, so
/// a tab inside the payload spuriously failed), and decoding stops at the
/// first `=` padding character. Any other non-alphabet byte is an error.
///
/// # Errors
/// Returns `Err(String)` when a byte outside the alphabet (and not
/// whitespace or padding) is encountered.
pub(crate) fn base64_decode(input: &str) -> Result<Vec<u8>, String> {
    /// Map one standard-alphabet byte to its 6-bit value, or `None`.
    /// Constant-time range match instead of an O(64) table scan per byte.
    fn sextet(c: u8) -> Option<u32> {
        match c {
            b'A'..=b'Z' => Some(u32::from(c - b'A')),
            b'a'..=b'z' => Some(u32::from(c - b'a') + 26),
            b'0'..=b'9' => Some(u32::from(c - b'0') + 52),
            b'+' => Some(62),
            b'/' => Some(63),
            _ => None,
        }
    }

    let mut output = Vec::with_capacity(input.len() * 3 / 4);
    let mut buf: u32 = 0; // bit accumulator
    let mut bits: u32 = 0; // number of valid bits in `buf`

    for c in input.bytes() {
        if c == b'=' {
            break; // padding: no more data follows
        }
        if c.is_ascii_whitespace() {
            continue; // tolerate line-wrapped / padded input
        }
        let val = sextet(c).ok_or("Invalid base64 character")?;
        buf = (buf << 6) | val;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            output.push((buf >> bits) as u8);
            buf &= (1 << bits) - 1;
        }
    }
    Ok(output)
}
148
149// ============================================================================
150// Types
151// ============================================================================
152
/// Transcription request body for POST /api/v1/audio/transcriptions.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscribeRequest {
    /// Base64-encoded audio data (standard alphabet; decoded by `base64_decode`).
    pub audio_data: String,
    /// Audio format: "wav", "mp3", "flac", "ogg". Treated as "wav" when omitted.
    #[serde(default)]
    pub format: Option<String>,
    /// Language code (e.g., "en", "es"). Auto-detected if not specified.
    #[serde(default)]
    pub language: Option<String>,
    /// Translate to English instead of transcribing. Only honored by the
    /// `speech` build; the dry-run path ignores it.
    #[serde(default)]
    pub translate: Option<bool>,
}
168
/// Transcription response.
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeResponse {
    /// Full transcript text (or a "[dry-run]" notice without the `speech` feature).
    pub text: String,
    /// Language code: detected by the model, or the requested/default "en" in dry-run.
    pub language: String,
    /// Audio duration in seconds (estimated from payload size in dry-run mode).
    pub duration_secs: f32,
    /// Timestamped segments; always empty in dry-run mode.
    pub segments: Vec<TranscribeSegment>,
}
177
/// A timestamped segment of the transcript.
#[derive(Debug, Clone, Serialize)]
pub struct TranscribeSegment {
    // Segment start offset (presumably seconds, matching `duration_secs`) — TODO confirm.
    pub start: f32,
    // Segment end offset, same unit as `start`.
    pub end: f32,
    // Text recognized within [start, end].
    pub text: String,
}
185
/// Supported audio formats (GET /api/v1/audio/formats).
#[derive(Debug, Serialize)]
pub struct AudioFormatsResponse {
    /// Accepted container/codec formats.
    pub formats: Vec<AudioFormat>,
    /// Expected sample rate in Hz (reported as 16000 by the handler).
    pub sample_rate: u32,
    /// Active transcription engine: "whisper-apr" or "dry-run".
    pub engine: String,
}
193
/// Audio format info.
#[derive(Debug, Serialize)]
pub struct AudioFormat {
    /// File extension, e.g. "wav".
    pub extension: String,
    /// Corresponding MIME type, e.g. "audio/wav".
    pub mime: String,
}