pocket_tts_cli/server/
handlers.rs

1//! HTTP request handlers
2
3use crate::server::state::AppState;
4use crate::voice::resolve_voice;
5use axum::{
6    Json,
7    body::Body,
8    extract::{Multipart, State},
9    http::{HeaderMap, StatusCode, header},
10    response::{Html, IntoResponse, Response},
11};
12use rust_embed::Embed;
13use serde::{Deserialize, Serialize};
14use tokio_stream::StreamExt as _;
15
/// Static assets from the `static/` directory, embedded into the binary at
/// compile time via `rust_embed`; looked up by `serve_index`/`serve_static`.
#[derive(Embed)]
#[folder = "static/"]
struct StaticAssets;
20
21// ============================================================================
22// Static file serving
23// ============================================================================
24
25/// Serve the main index.html
26pub async fn serve_index() -> impl IntoResponse {
27    match StaticAssets::get("index.html") {
28        Some(content) => Html(content.data.to_vec()).into_response(),
29        None => (StatusCode::NOT_FOUND, "index.html not found").into_response(),
30    }
31}
32
33/// Serve static files (CSS, JS, images)
34pub async fn serve_static(uri: axum::http::Uri) -> impl IntoResponse {
35    let path = uri.path().trim_start_matches('/');
36    let path = percent_encoding::percent_decode_str(path).decode_utf8_lossy();
37
38    match StaticAssets::get(&path) {
39        Some(content) => {
40            let mime = mime_guess::from_path(path.as_ref()).first_or_octet_stream();
41            let mut headers = HeaderMap::new();
42            headers.insert(header::CONTENT_TYPE, mime.as_ref().parse().unwrap());
43            (headers, content.data.to_vec()).into_response()
44        }
45        None => (StatusCode::NOT_FOUND, "File not found").into_response(),
46    }
47}
48
49// ============================================================================
50// Health check
51// ============================================================================
52
/// JSON body returned by the `/health` endpoint.
#[derive(Serialize)]
pub struct HealthResponse {
    /// Fixed value `"healthy"` whenever the server can respond at all.
    status: String,
    /// Crate version, taken from `CARGO_PKG_VERSION` at compile time.
    version: String,
}
58
59pub async fn health_check() -> impl IntoResponse {
60    Json(HealthResponse {
61        status: "healthy".to_string(),
62        version: env!("CARGO_PKG_VERSION").to_string(),
63    })
64}
65
66// ============================================================================
67// Generation (JSON API)
68// ============================================================================
69
/// JSON request body shared by the generation endpoints.
#[derive(Deserialize)]
pub struct GenerateRequest {
    /// Text to synthesize.
    text: String,
    /// Optional voice specifier; when absent the server's preloaded
    /// default voice state is used.
    voice: Option<String>,
}
75
/// JSON error payload returned on generation/validation failures.
#[derive(Serialize)]
struct ErrorResponse {
    /// Human-readable error description.
    error: String,
}
80
81pub async fn generate(
82    State(state): State<AppState>,
83    Json(payload): Json<GenerateRequest>,
84) -> Response {
85    // Acquire lock for sequential processing
86    let _guard = state.lock.lock().await;
87
88    let model = state.model.clone();
89    let default_voice = state.default_voice_state.clone();
90    let text = payload.text.clone();
91    let voice_spec = payload.voice.clone();
92
93    // Run generation in blocking thread
94    let result = tokio::task::spawn_blocking(move || {
95        // Resolve voice (use default if not specified)
96        let voice_state = if voice_spec.is_some() {
97            resolve_voice(&model, voice_spec.as_deref())?
98        } else {
99            (*default_voice).clone()
100        };
101
102        // Generate audio
103        tracing::info!("Starting generation for text length: {} chars", text.len());
104        let mut audio_chunks = Vec::new();
105        for chunk in model.generate_stream_long(&text, &voice_state) {
106            audio_chunks.push(chunk?);
107        }
108        if audio_chunks.is_empty() {
109            anyhow::bail!("No audio generated");
110        }
111        let audio = candle_core::Tensor::cat(&audio_chunks, 2)?;
112        let audio = audio.squeeze(0)?;
113
114        // Encode as WAV
115        let mut buffer = std::io::Cursor::new(Vec::new());
116        pocket_tts::audio::write_wav_to_writer(&mut buffer, &audio, model.sample_rate as u32)?;
117
118        Ok::<Vec<u8>, anyhow::Error>(buffer.into_inner())
119    })
120    .await;
121
122    match result {
123        Ok(Ok(wav_bytes)) => {
124            let mut headers = HeaderMap::new();
125            headers.insert(header::CONTENT_TYPE, "audio/wav".parse().unwrap());
126            headers.insert(
127                header::CONTENT_DISPOSITION,
128                "attachment; filename=\"pocket-tts-output.wav\""
129                    .parse()
130                    .unwrap(),
131            );
132            (StatusCode::OK, headers, Body::from(wav_bytes)).into_response()
133        }
134        Ok(Err(e)) => (
135            StatusCode::INTERNAL_SERVER_ERROR,
136            Json(ErrorResponse {
137                error: e.to_string(),
138            }),
139        )
140            .into_response(),
141        Err(e) => (
142            StatusCode::INTERNAL_SERVER_ERROR,
143            Json(ErrorResponse {
144                error: format!("Task error: {}", e),
145            }),
146        )
147            .into_response(),
148    }
149}
150
151// ============================================================================
152// Streaming generation
153// ============================================================================
154
155pub async fn generate_stream(
156    State(state): State<AppState>,
157    Json(payload): Json<GenerateRequest>,
158) -> Response {
159    let model = state.model.clone();
160    let default_voice = state.default_voice_state.clone();
161    let text = payload.text.clone();
162    let voice_spec = payload.voice.clone();
163    let lock = state.lock.clone();
164
165    // Channel for streaming chunks
166    let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, anyhow::Error>>(10);
167
168    // Spawn generation task
169    tokio::spawn(async move {
170        let _guard = lock.lock().await;
171
172        let tx_inner = tx.clone();
173        let result = tokio::task::spawn_blocking(move || {
174            // Resolve voice
175            let voice_state = if voice_spec.is_some() {
176                resolve_voice(&model, voice_spec.as_deref())?
177            } else {
178                (*default_voice).clone()
179            };
180
181            // Stream audio chunks
182            tracing::info!(
183                "Starting streaming generation for text length: {} chars",
184                text.len()
185            );
186            for (i, chunk_res) in model.generate_stream_long(&text, &voice_state).enumerate() {
187                if i > 0 && i % 20 == 0 {
188                    tracing::info!("Generated chunk {}", i);
189                }
190                match chunk_res {
191                    Ok(chunk) => {
192                        // Convert tensor to 16-bit PCM bytes
193                        let chunk = chunk.squeeze(0).map_err(|e| anyhow::anyhow!(e))?;
194                        let data = chunk.to_vec2::<f32>().map_err(|e| anyhow::anyhow!(e))?;
195
196                        let mut bytes = Vec::new();
197                        for (i, _) in data[0].iter().enumerate() {
198                            for channel_data in &data {
199                                // Hard clamp to [-1, 1] to match Python's behavior
200                                let val = channel_data[i].clamp(-1.0, 1.0);
201                                let val = (val * 32767.0) as i16;
202                                bytes.extend_from_slice(&val.to_le_bytes());
203                            }
204                        }
205
206                        if tx_inner.blocking_send(Ok(bytes)).is_err() {
207                            break; // Receiver dropped
208                        }
209                    }
210                    Err(e) => {
211                        let _ = tx_inner.blocking_send(Err(anyhow::anyhow!(e)));
212                        break;
213                    }
214                }
215            }
216            Ok::<(), anyhow::Error>(())
217        });
218
219        if let Err(e) = result.await {
220            let _ = tx.send(Err(anyhow::anyhow!("Task error: {}", e))).await;
221        }
222    });
223
224    // Convert channel to stream
225    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
226    let body_stream = stream.map(|res| match res {
227        Ok(bytes) => Ok(axum::body::Bytes::from(bytes)),
228        Err(e) => Err(std::io::Error::other(e.to_string())),
229    });
230
231    Response::builder()
232        .header(header::CONTENT_TYPE, "application/octet-stream")
233        .body(Body::from_stream(body_stream))
234        .unwrap()
235}
236
237// ============================================================================
238// Python API compatibility (/tts with multipart form)
239// ============================================================================
240
241pub async fn tts_form(State(state): State<AppState>, mut multipart: Multipart) -> Response {
242    let mut text: Option<String> = None;
243    let mut voice_url: Option<String> = None;
244    let mut voice_wav_bytes: Option<Vec<u8>> = None;
245
246    // Parse multipart form
247    while let Ok(Some(field)) = multipart.next_field().await {
248        let name = field.name().unwrap_or("").to_string();
249        match name.as_str() {
250            "text" => {
251                text = field.text().await.ok();
252            }
253            "voice_url" => {
254                voice_url = field.text().await.ok();
255            }
256            "voice_wav" => {
257                voice_wav_bytes = field.bytes().await.ok().map(|b| b.to_vec());
258            }
259            _ => {}
260        }
261    }
262
263    let text = match text {
264        Some(t) if !t.trim().is_empty() => t,
265        _ => {
266            return (
267                StatusCode::BAD_REQUEST,
268                Json(ErrorResponse {
269                    error: "Text is required".to_string(),
270                }),
271            )
272                .into_response();
273        }
274    };
275
276    // Determine voice
277    let voice = if let Some(bytes) = voice_wav_bytes {
278        // Use uploaded WAV - encode as base64 for our resolver
279        use base64::{Engine as _, engine::general_purpose};
280        Some(format!(
281            "data:audio/wav;base64,{}",
282            general_purpose::STANDARD.encode(&bytes)
283        ))
284    } else {
285        voice_url
286    };
287
288    // Delegate to JSON generate handler
289    generate(State(state), Json(GenerateRequest { text, voice })).await
290}
291
292// ============================================================================
293// OpenAI compatibility
294// ============================================================================
295
/// Request body for the OpenAI-compatible speech endpoint.
#[derive(Deserialize)]
#[allow(dead_code)] // `model` and `response_format` are accepted but not read.
pub struct OpenAIRequest {
    /// Model name; accepted for API compatibility, ignored by the handler.
    model: String,
    /// Text to synthesize (mapped to `GenerateRequest.text`).
    input: String,
    /// Optional voice specifier (mapped to `GenerateRequest.voice`).
    voice: Option<String>,
    /// Requested output format; accepted for compatibility, ignored here.
    response_format: Option<String>,
}
304
305pub async fn openai_speech(state: State<AppState>, Json(payload): Json<OpenAIRequest>) -> Response {
306    // Map OpenAI format to our format
307    let req = GenerateRequest {
308        text: payload.input,
309        voice: payload.voice,
310    };
311    generate(state, Json(req)).await
312}