1use crate::server::state::AppState;
4use crate::voice::resolve_voice;
5#[cfg(feature = "web-ui")]
6use axum::response::Html;
7use axum::{
8 Json,
9 body::Body,
10 extract::{Multipart, State},
11 http::{HeaderMap, StatusCode, header},
12 response::{IntoResponse, Response},
13};
14#[cfg(feature = "web-ui")]
15use rust_embed::Embed;
16use serde::{Deserialize, Serialize};
17use tokio_stream::StreamExt as _;
18
#[cfg(feature = "web-ui")]
/// Frontend assets embedded into the binary at compile time from `web/dist`.
#[derive(Embed)]
#[folder = "web/dist"]
struct StaticAssets;
24
#[cfg(feature = "web-ui")]
/// Fallback handler serving the embedded web UI.
///
/// Resolution order:
/// 1. Exact embedded-asset match (empty path maps to `index.html`).
/// 2. SPA fallback: extensionless paths (client-side routes) get `index.html`.
/// 3. 404 for anything else (e.g. a missing `.js`/`.css` asset).
pub async fn serve_static(uri: axum::http::Uri) -> impl IntoResponse {
    let path = uri.path().trim_start_matches('/');

    let path = if path.is_empty() { "index.html" } else { path };
    // Decode %-escapes so assets with encoded names (spaces, unicode) resolve.
    let path = percent_encoding::percent_decode_str(path).decode_utf8_lossy();

    match StaticAssets::get(&path) {
        Some(content) => {
            let mime = mime_guess::from_path(path.as_ref()).first_or_octet_stream();
            let mut headers = HeaderMap::new();
            // Mime strings are always valid header values, so parse() cannot fail.
            headers.insert(header::CONTENT_TYPE, mime.as_ref().parse().unwrap());
            (headers, content.data.to_vec()).into_response()
        }
        None => {
            // No '.' suggests a client-side route rather than an asset:
            // serve the SPA shell and let the frontend router handle it.
            if !path.contains('.')
                && let Some(content) = StaticAssets::get("index.html")
            {
                return Html(content.data.to_vec()).into_response();
            }
            (StatusCode::NOT_FOUND, "File not found").into_response()
        }
    }
}
57
/// JSON body returned by `health_check`.
#[derive(Serialize)]
pub struct HealthResponse {
    // Always "healthy" when the server is able to respond at all.
    status: String,
    // Crate version baked in at compile time via CARGO_PKG_VERSION.
    version: String,
}
67
68pub async fn health_check() -> impl IntoResponse {
69 Json(HealthResponse {
70 status: "healthy".to_string(),
71 version: env!("CARGO_PKG_VERSION").to_string(),
72 })
73}
74
/// Request body shared by the JSON generation endpoints.
#[derive(Deserialize)]
pub struct GenerateRequest {
    // Text to synthesize.
    text: String,
    // Optional voice spec resolved by `resolve_voice` (e.g. a URL or data URI);
    // `None` uses the server's preloaded default voice state.
    voice: Option<String>,
    // Per-request model overrides; unset fields keep the loaded model's values.
    temperature: Option<f32>,
    lsd_steps: Option<usize>,
    eos_threshold: Option<f32>,
    noise_clamp: Option<f32>,
}
88
/// JSON error payload returned when generation fails.
#[derive(Serialize)]
struct ErrorResponse {
    error: String,
}
93
94pub async fn generate(
95 State(state): State<AppState>,
96 Json(payload): Json<GenerateRequest>,
97) -> Response {
98 let _guard = state.lock.lock().await;
100
101 let model = state.model.clone();
102 let default_voice = state.default_voice_state.clone();
103 let text = payload.text.clone();
104 let voice_spec = payload.voice.clone();
105
106 let result = tokio::task::spawn_blocking(move || {
108 let voice_state = if voice_spec.is_some() {
110 resolve_voice(&model, voice_spec.as_deref())?
111 } else {
112 (*default_voice).clone()
113 };
114
115 let mut model_cloned = (*model).clone();
117 if let Some(t) = payload.temperature {
118 model_cloned.temp = t;
119 }
120 if let Some(s) = payload.lsd_steps {
121 model_cloned.lsd_decode_steps = s;
122 }
123 if let Some(e) = payload.eos_threshold {
124 model_cloned.eos_threshold = e;
125 }
126 if let Some(nc) = payload.noise_clamp {
127 model_cloned.noise_clamp = Some(nc);
128 }
129
130 tracing::info!("Starting generation for text length: {} chars", text.len());
132 let mut audio_chunks = Vec::new();
133 for chunk in model_cloned.generate_stream_long(&text, &voice_state) {
134 audio_chunks.push(chunk?);
135 }
136 if audio_chunks.is_empty() {
137 anyhow::bail!("No audio generated");
138 }
139 let audio = candle_core::Tensor::cat(&audio_chunks, 2)?;
140 let audio = audio.squeeze(0)?;
141
142 let mut buffer = std::io::Cursor::new(Vec::new());
144 pocket_tts::audio::write_wav_to_writer(&mut buffer, &audio, model.sample_rate as u32)?;
145
146 Ok::<Vec<u8>, anyhow::Error>(buffer.into_inner())
147 })
148 .await;
149
150 match result {
151 Ok(Ok(wav_bytes)) => {
152 let mut headers = HeaderMap::new();
153 headers.insert(header::CONTENT_TYPE, "audio/wav".parse().unwrap());
154 headers.insert(
155 header::CONTENT_DISPOSITION,
156 "attachment; filename=\"pocket-tts-output.wav\""
157 .parse()
158 .unwrap(),
159 );
160 (StatusCode::OK, headers, Body::from(wav_bytes)).into_response()
161 }
162 Ok(Err(e)) => (
163 StatusCode::INTERNAL_SERVER_ERROR,
164 Json(ErrorResponse {
165 error: e.to_string(),
166 }),
167 )
168 .into_response(),
169 Err(e) => (
170 StatusCode::INTERNAL_SERVER_ERROR,
171 Json(ErrorResponse {
172 error: format!("Task error: {}", e),
173 }),
174 )
175 .into_response(),
176 }
177}
178
179pub async fn generate_stream(
184 State(state): State<AppState>,
185 Json(payload): Json<GenerateRequest>,
186) -> Response {
187 let model = state.model.clone();
188 let default_voice = state.default_voice_state.clone();
189 let text = payload.text.clone();
190 let voice_spec = payload.voice.clone();
191 let lock = state.lock.clone();
192
193 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, anyhow::Error>>(10);
195
196 tokio::spawn(async move {
198 let _guard = lock.lock().await;
199
200 let tx_inner = tx.clone();
201 let result = tokio::task::spawn_blocking(move || {
202 let voice_state = if voice_spec.is_some() {
204 resolve_voice(&model, voice_spec.as_deref())?
205 } else {
206 (*default_voice).clone()
207 };
208
209 let mut model_cloned = (*model).clone();
211 if let Some(t) = payload.temperature {
212 model_cloned.temp = t;
213 }
214 if let Some(s) = payload.lsd_steps {
215 model_cloned.lsd_decode_steps = s;
216 }
217 if let Some(e) = payload.eos_threshold {
218 model_cloned.eos_threshold = e;
219 }
220 if let Some(nc) = payload.noise_clamp {
221 model_cloned.noise_clamp = Some(nc);
222 }
223
224 tracing::info!(
226 "Starting streaming generation for text length: {} chars",
227 text.len()
228 );
229 for (i, chunk_res) in model_cloned
230 .generate_stream_long(&text, &voice_state)
231 .enumerate()
232 {
233 if i > 0 && i % 20 == 0 {
234 tracing::info!("Generated chunk {}", i);
235 }
236 match chunk_res {
237 Ok(chunk) => {
238 let chunk = chunk.squeeze(0).map_err(|e| anyhow::anyhow!(e))?;
240 let data = chunk.to_vec2::<f32>().map_err(|e| anyhow::anyhow!(e))?;
241
242 let mut bytes = Vec::new();
243 for (i, _) in data[0].iter().enumerate() {
244 for channel_data in &data {
245 let val = channel_data[i].clamp(-1.0, 1.0);
247 let val = (val * 32767.0) as i16;
248 bytes.extend_from_slice(&val.to_le_bytes());
249 }
250 }
251
252 if tx_inner.blocking_send(Ok(bytes)).is_err() {
253 break; }
255 }
256 Err(e) => {
257 let _ = tx_inner.blocking_send(Err(anyhow::anyhow!(e)));
258 break;
259 }
260 }
261 }
262 Ok::<(), anyhow::Error>(())
263 });
264
265 if let Err(e) = result.await {
266 let _ = tx.send(Err(anyhow::anyhow!("Task error: {}", e))).await;
267 }
268 });
269
270 let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
272 let body_stream = stream.map(
273 |res| -> std::result::Result<axum::body::Bytes, std::io::Error> {
274 match res {
275 Ok(bytes) => Ok(axum::body::Bytes::from(bytes)),
276 Err(e) => Err(std::io::Error::other(e.to_string())),
277 }
278 },
279 );
280
281 Response::builder()
282 .header(header::CONTENT_TYPE, "application/octet-stream")
283 .body(Body::from_stream(body_stream))
284 .unwrap()
285}
286
287pub async fn tts_form(State(state): State<AppState>, mut multipart: Multipart) -> Response {
292 let mut text: Option<String> = None;
293 let mut voice_url: Option<String> = None;
294 let mut voice_wav_bytes: Option<Vec<u8>> = None;
295
296 while let Ok(Some(field)) = multipart.next_field().await {
298 let name = field.name().unwrap_or("").to_string();
299 match name.as_str() {
300 "text" => {
301 text = field.text().await.ok();
302 }
303 "voice_url" => {
304 voice_url = field.text().await.ok();
305 }
306 "voice_wav" => {
307 voice_wav_bytes = field.bytes().await.ok().map(|b| b.to_vec());
308 }
309 _ => {}
310 }
311 }
312
313 let text = match text {
314 Some(t) if !t.trim().is_empty() => t,
315 _ => {
316 return (
317 StatusCode::BAD_REQUEST,
318 Json(ErrorResponse {
319 error: "Text is required".to_string(),
320 }),
321 )
322 .into_response();
323 }
324 };
325
326 let voice = if let Some(bytes) = voice_wav_bytes {
328 use base64::{Engine as _, engine::general_purpose};
330 Some(format!(
331 "data:audio/wav;base64,{}",
332 general_purpose::STANDARD.encode(&bytes)
333 ))
334 } else {
335 voice_url
336 };
337
338 generate(
340 State(state),
341 Json(GenerateRequest {
342 text,
343 voice,
344 temperature: None,
345 lsd_steps: None,
346 eos_threshold: None,
347 noise_clamp: None,
348 }),
349 )
350 .await
351}
352
/// OpenAI-compatible speech request (subset of the `/v1/audio/speech` schema).
#[derive(Deserialize)]
#[allow(dead_code)]
pub struct OpenAIRequest {
    // Accepted for API compatibility but never read; the server always uses
    // its loaded model.
    model: String,
    // Text to synthesize (OpenAI names this field `input`).
    input: String,
    voice: Option<String>,
    // Accepted for API compatibility but never read.
    response_format: Option<String>,
}
365
366pub async fn openai_speech(state: State<AppState>, Json(payload): Json<OpenAIRequest>) -> Response {
367 let req = GenerateRequest {
369 text: payload.input,
370 voice: payload.voice,
371 temperature: None,
372 lsd_steps: None,
373 eos_threshold: None,
374 noise_clamp: None,
375 };
376 generate(state, Json(req)).await
377}