pocket_tts_cli/server/handlers.rs

//! HTTP request handlers

use crate::server::state::AppState;
use crate::voice::resolve_voice;
#[cfg(feature = "web-ui")]
use axum::response::Html;
use axum::{
    Json,
    body::Body,
    extract::{Multipart, State},
    http::{HeaderMap, StatusCode, header},
    response::{IntoResponse, Response},
};
#[cfg(feature = "web-ui")]
use rust_embed::Embed;
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt as _;

// Embed static files at compile time
#[cfg(feature = "web-ui")]
#[derive(Embed)]
#[folder = "web/dist"]
struct StaticAssets;

// ============================================================================
// Static file serving
// ============================================================================

/// Serve static files (CSS, JS, images)
#[cfg(feature = "web-ui")]
pub async fn serve_static(uri: axum::http::Uri) -> impl IntoResponse {
    let path = uri.path().trim_start_matches('/');

    // If path is empty, serve index.html
    let path = if path.is_empty() { "index.html" } else { path };
    let path = percent_encoding::percent_decode_str(path).decode_utf8_lossy();

    match StaticAssets::get(&path) {
        Some(content) => {
            let mime = mime_guess::from_path(path.as_ref()).first_or_octet_stream();
            let mut headers = HeaderMap::new();
            headers.insert(header::CONTENT_TYPE, mime.as_ref().parse().unwrap());
            (headers, content.data.to_vec()).into_response()
        }
        None => {
            // If the request doesn't match a static file, return index.html (for SPA routing)
            // But only if it doesn't look like a file request (to avoid infinite loops on 404s)
            if !path.contains('.')
                && let Some(content) = StaticAssets::get("index.html")
            {
                return Html(content.data.to_vec()).into_response();
            }
            (StatusCode::NOT_FOUND, "File not found").into_response()
        }
    }
}

// ============================================================================
// Health check
// ============================================================================

#[derive(Serialize)]
pub struct HealthResponse {
    status: String,
    version: String,
}

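/// Liveness probe. Responds with `HealthResponse` serialized as JSON, e.g.
/// `{"status":"healthy","version":"1.2.3"}` (the version string comes from
/// `CARGO_PKG_VERSION`, so the value shown here is only illustrative).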
pub async fn health_check() -> impl IntoResponse {
    Json(HealthResponse {
        status: "healthy".to_string(),
        version: env!("CARGO_PKG_VERSION").to_string(),
    })
}

// ============================================================================
// Generation (JSON API)
// ============================================================================

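/// JSON request body shared by the generation endpoints. Only `text` is required;
/// each optional field overrides the corresponding server-side default when present.
///
/// An illustrative body (the parameter values below are placeholders, not tuned
/// defaults, and the voice spec is whatever string `resolve_voice` accepts):
///
/// ```json
/// { "text": "Hello there.", "voice": "reference.wav",
///   "temperature": 0.7, "lsd_steps": 4, "eos_threshold": 0.5, "noise_clamp": 1.0 }
/// ```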
#[derive(Deserialize)]
pub struct GenerateRequest {
    text: String,
    voice: Option<String>,
    temperature: Option<f32>,
    lsd_steps: Option<usize>,
    eos_threshold: Option<f32>,
    noise_clamp: Option<f32>,
}

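/// JSON error body returned by the handlers on failure, e.g. `{"error":"No audio generated"}`.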
#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}

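/// Generate speech for the whole request in one shot and return a complete WAV file
/// as an attachment. Requests are serialized by `state.lock`, and inference runs on a
/// blocking thread so the async executor stays responsive.
///
/// A minimal client sketch (marked `ignore`: it assumes `reqwest` and `serde_json`,
/// which are not dependencies of this crate, and assumes the handler is mounted at
/// `/generate` on a local port):
///
/// ```ignore
/// let wav = reqwest::Client::new()
///     .post("http://127.0.0.1:8000/generate")
///     .json(&serde_json::json!({ "text": "Hello from the server" }))
///     .send()
///     .await?
///     .bytes()
///     .await?;
/// std::fs::write("output.wav", &wav)?;
/// ```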
pub async fn generate(
    State(state): State<AppState>,
    Json(payload): Json<GenerateRequest>,
) -> Response {
    // Acquire lock for sequential processing
    let _guard = state.lock.lock().await;

    let model = state.model.clone();
    let default_voice = state.default_voice_state.clone();
    let text = payload.text.clone();
    let voice_spec = payload.voice.clone();

    // Run generation in blocking thread
    let result = tokio::task::spawn_blocking(move || {
        // Resolve voice (use default if not specified)
        let voice_state = if voice_spec.is_some() {
            resolve_voice(&model, voice_spec.as_deref())?
        } else {
            (*default_voice).clone()
        };

        // Override model params if provided in request
        let mut model_cloned = (*model).clone();
        if let Some(t) = payload.temperature {
            model_cloned.temp = t;
        }
        if let Some(s) = payload.lsd_steps {
            model_cloned.lsd_decode_steps = s;
        }
        if let Some(e) = payload.eos_threshold {
            model_cloned.eos_threshold = e;
        }
        if let Some(nc) = payload.noise_clamp {
            model_cloned.noise_clamp = Some(nc);
        }

        // Generate audio
        tracing::info!("Starting generation for text length: {} chars", text.len());
        let mut audio_chunks = Vec::new();
        for chunk in model_cloned.generate_stream_long(&text, &voice_state) {
            audio_chunks.push(chunk?);
        }
        if audio_chunks.is_empty() {
            anyhow::bail!("No audio generated");
        }
        let audio = candle_core::Tensor::cat(&audio_chunks, 2)?;
        let audio = audio.squeeze(0)?;

        // Encode as WAV
        let mut buffer = std::io::Cursor::new(Vec::new());
        pocket_tts::audio::write_wav_to_writer(&mut buffer, &audio, model.sample_rate as u32)?;

        Ok::<Vec<u8>, anyhow::Error>(buffer.into_inner())
    })
    .await;

    match result {
        Ok(Ok(wav_bytes)) => {
            let mut headers = HeaderMap::new();
            headers.insert(header::CONTENT_TYPE, "audio/wav".parse().unwrap());
            headers.insert(
                header::CONTENT_DISPOSITION,
                "attachment; filename=\"pocket-tts-output.wav\""
                    .parse()
                    .unwrap(),
            );
            (StatusCode::OK, headers, Body::from(wav_bytes)).into_response()
        }
        Ok(Err(e)) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: e.to_string(),
            }),
        )
            .into_response(),
        Err(e) => (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Task error: {}", e),
            }),
        )
            .into_response(),
    }
}

// ============================================================================
// Streaming generation
// ============================================================================

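/// Generate speech and stream it back while inference is still running.
///
/// The response body is headerless 16-bit little-endian PCM with channels interleaved
/// (see the conversion loop below); there is no WAV header, so the client must know the
/// sample rate and channel count out of band.
///
/// A sketch of decoding the payload on the client side (plain std Rust; `raw_bytes` is
/// assumed to be the collected response body):
///
/// ```ignore
/// let samples: Vec<i16> = raw_bytes
///     .chunks_exact(2)
///     .map(|pair| i16::from_le_bytes([pair[0], pair[1]]))
///     .collect();
/// ```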
pub async fn generate_stream(
    State(state): State<AppState>,
    Json(payload): Json<GenerateRequest>,
) -> Response {
    let model = state.model.clone();
    let default_voice = state.default_voice_state.clone();
    let text = payload.text.clone();
    let voice_spec = payload.voice.clone();
    let lock = state.lock.clone();

    // Channel for streaming chunks
    let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, anyhow::Error>>(10);

    // Spawn generation task
    tokio::spawn(async move {
        let _guard = lock.lock().await;

        let tx_inner = tx.clone();
        let result = tokio::task::spawn_blocking(move || {
            // Resolve voice
            let voice_state = if voice_spec.is_some() {
                resolve_voice(&model, voice_spec.as_deref())?
            } else {
                (*default_voice).clone()
            };

            // Override model params if provided in request
            let mut model_cloned = (*model).clone();
            if let Some(t) = payload.temperature {
                model_cloned.temp = t;
            }
            if let Some(s) = payload.lsd_steps {
                model_cloned.lsd_decode_steps = s;
            }
            if let Some(e) = payload.eos_threshold {
                model_cloned.eos_threshold = e;
            }
            if let Some(nc) = payload.noise_clamp {
                model_cloned.noise_clamp = Some(nc);
            }

            // Stream audio chunks
            tracing::info!(
                "Starting streaming generation for text length: {} chars",
                text.len()
            );
            for (i, chunk_res) in model_cloned
                .generate_stream_long(&text, &voice_state)
                .enumerate()
            {
                if i > 0 && i % 20 == 0 {
                    tracing::info!("Generated chunk {}", i);
                }
                match chunk_res {
                    Ok(chunk) => {
                        // Convert tensor to interleaved 16-bit PCM bytes
                        let chunk = chunk.squeeze(0).map_err(|e| anyhow::anyhow!(e))?;
                        let data = chunk.to_vec2::<f32>().map_err(|e| anyhow::anyhow!(e))?;

                        let mut bytes = Vec::new();
                        for sample_idx in 0..data[0].len() {
                            for channel_data in &data {
                                // Hard clamp to [-1, 1] to match Python's behavior
                                let val = channel_data[sample_idx].clamp(-1.0, 1.0);
                                let val = (val * 32767.0) as i16;
                                bytes.extend_from_slice(&val.to_le_bytes());
                            }
                        }

                        if tx_inner.blocking_send(Ok(bytes)).is_err() {
                            break; // Receiver dropped
                        }
                    }
                    Err(e) => {
                        let _ = tx_inner.blocking_send(Err(anyhow::anyhow!(e)));
                        break;
                    }
                }
            }
            Ok::<(), anyhow::Error>(())
        });

        if let Err(e) = result.await {
            let _ = tx.send(Err(anyhow::anyhow!("Task error: {}", e))).await;
        }
    });

    // Convert channel to stream
    let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
    let body_stream = stream.map(
        |res| -> std::result::Result<axum::body::Bytes, std::io::Error> {
            match res {
                Ok(bytes) => Ok(axum::body::Bytes::from(bytes)),
                Err(e) => Err(std::io::Error::other(e.to_string())),
            }
        },
    );

    Response::builder()
        .header(header::CONTENT_TYPE, "application/octet-stream")
        .body(Body::from_stream(body_stream))
        .unwrap()
}

// ============================================================================
// Python API compatibility (/tts with multipart form)
// ============================================================================

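/// Multipart-form endpoint kept for compatibility with the Python server's `/tts` route.
///
/// Recognized fields: `text` (required), `voice_url` (a voice spec string passed to the
/// resolver), and `voice_wav` (an uploaded WAV, re-encoded below as a base64 data URL).
/// Unknown fields are ignored, and the request is ultimately delegated to `generate`.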
pub async fn tts_form(State(state): State<AppState>, mut multipart: Multipart) -> Response {
    let mut text: Option<String> = None;
    let mut voice_url: Option<String> = None;
    let mut voice_wav_bytes: Option<Vec<u8>> = None;

    // Parse multipart form
    while let Ok(Some(field)) = multipart.next_field().await {
        let name = field.name().unwrap_or("").to_string();
        match name.as_str() {
            "text" => {
                text = field.text().await.ok();
            }
            "voice_url" => {
                voice_url = field.text().await.ok();
            }
            "voice_wav" => {
                voice_wav_bytes = field.bytes().await.ok().map(|b| b.to_vec());
            }
            _ => {}
        }
    }

    let text = match text {
        Some(t) if !t.trim().is_empty() => t,
        _ => {
            return (
                StatusCode::BAD_REQUEST,
                Json(ErrorResponse {
                    error: "Text is required".to_string(),
                }),
            )
                .into_response();
        }
    };

    // Determine voice
    let voice = if let Some(bytes) = voice_wav_bytes {
        // Use uploaded WAV - encode as base64 for our resolver
        use base64::{Engine as _, engine::general_purpose};
        Some(format!(
            "data:audio/wav;base64,{}",
            general_purpose::STANDARD.encode(&bytes)
        ))
    } else {
        voice_url
    };

    // Delegate to JSON generate handler
    generate(
        State(state),
        Json(GenerateRequest {
            text,
            voice,
            temperature: None,
            lsd_steps: None,
            eos_threshold: None,
            noise_clamp: None,
        }),
    )
    .await
}

// ============================================================================
// OpenAI compatibility
// ============================================================================

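/// Request body mirroring the shape of OpenAI's `/v1/audio/speech` API.
///
/// An illustrative body (values are placeholders; `model` and `response_format` are
/// accepted for compatibility but currently ignored by the handler):
///
/// ```json
/// { "model": "pocket-tts", "input": "Hello!", "voice": "default", "response_format": "wav" }
/// ```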
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct OpenAIRequest {
    model: String,
    input: String,
    voice: Option<String>,
    response_format: Option<String>,
}

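/// OpenAI-compatible speech handler: maps `input` and `voice` onto a `GenerateRequest`
/// and reuses `generate`, so the response is the same WAV attachment as the native
/// JSON endpoint.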
pub async fn openai_speech(state: State<AppState>, Json(payload): Json<OpenAIRequest>) -> Response {
    // Map OpenAI format to our format
    let req = GenerateRequest {
        text: payload.input,
        voice: payload.voice,
        temperature: None,
        lsd_steps: None,
        eos_threshold: None,
        noise_clamp: None,
    };
    generate(state, Json(req)).await
}