pocket_tts_cli/server/
handlers.rs1use crate::server::state::AppState;
4use crate::voice::resolve_voice;
5use axum::{
6 Json,
7 body::Body,
8 extract::{Multipart, State},
9 http::{HeaderMap, StatusCode, header},
10 response::{Html, IntoResponse, Response},
11};
12use rust_embed::Embed;
13use serde::{Deserialize, Serialize};
14use tokio_stream::StreamExt as _;
15
16#[derive(Embed)]
18#[folder = "static/"]
19struct StaticAssets;
20
21pub async fn serve_index() -> impl IntoResponse {
27 match StaticAssets::get("index.html") {
28 Some(content) => Html(content.data.to_vec()).into_response(),
29 None => (StatusCode::NOT_FOUND, "index.html not found").into_response(),
30 }
31}
32
33pub async fn serve_static(uri: axum::http::Uri) -> impl IntoResponse {
35 let path = uri.path().trim_start_matches('/');
36 let path = percent_encoding::percent_decode_str(path).decode_utf8_lossy();
37
38 match StaticAssets::get(&path) {
39 Some(content) => {
40 let mime = mime_guess::from_path(path.as_ref()).first_or_octet_stream();
41 let mut headers = HeaderMap::new();
42 headers.insert(header::CONTENT_TYPE, mime.as_ref().parse().unwrap());
43 (headers, content.data.to_vec()).into_response()
44 }
45 None => (StatusCode::NOT_FOUND, "File not found").into_response(),
46 }
47}
48
49#[derive(Serialize)]
54pub struct HealthResponse {
55 status: String,
56 version: String,
57}
58
59pub async fn health_check() -> impl IntoResponse {
60 Json(HealthResponse {
61 status: "healthy".to_string(),
62 version: env!("CARGO_PKG_VERSION").to_string(),
63 })
64}
65
66#[derive(Deserialize)]
71pub struct GenerateRequest {
72 text: String,
73 voice: Option<String>,
74}
75
76#[derive(Serialize)]
77struct ErrorResponse {
78 error: String,
79}
80
81pub async fn generate(
82 State(state): State<AppState>,
83 Json(payload): Json<GenerateRequest>,
84) -> Response {
85 let _guard = state.lock.lock().await;
87
88 let model = state.model.clone();
89 let default_voice = state.default_voice_state.clone();
90 let text = payload.text.clone();
91 let voice_spec = payload.voice.clone();
92
93 let result = tokio::task::spawn_blocking(move || {
95 let voice_state = if voice_spec.is_some() {
97 resolve_voice(&model, voice_spec.as_deref())?
98 } else {
99 (*default_voice).clone()
100 };
101
102 tracing::info!("Starting generation for text length: {} chars", text.len());
104 let mut audio_chunks = Vec::new();
105 for chunk in model.generate_stream_long(&text, &voice_state) {
106 audio_chunks.push(chunk?);
107 }
108 if audio_chunks.is_empty() {
109 anyhow::bail!("No audio generated");
110 }
111 let audio = candle_core::Tensor::cat(&audio_chunks, 2)?;
112 let audio = audio.squeeze(0)?;
113
114 let mut buffer = std::io::Cursor::new(Vec::new());
116 pocket_tts::audio::write_wav_to_writer(&mut buffer, &audio, model.sample_rate as u32)?;
117
118 Ok::<Vec<u8>, anyhow::Error>(buffer.into_inner())
119 })
120 .await;
121
122 match result {
123 Ok(Ok(wav_bytes)) => {
124 let mut headers = HeaderMap::new();
125 headers.insert(header::CONTENT_TYPE, "audio/wav".parse().unwrap());
126 headers.insert(
127 header::CONTENT_DISPOSITION,
128 "attachment; filename=\"pocket-tts-output.wav\""
129 .parse()
130 .unwrap(),
131 );
132 (StatusCode::OK, headers, Body::from(wav_bytes)).into_response()
133 }
134 Ok(Err(e)) => (
135 StatusCode::INTERNAL_SERVER_ERROR,
136 Json(ErrorResponse {
137 error: e.to_string(),
138 }),
139 )
140 .into_response(),
141 Err(e) => (
142 StatusCode::INTERNAL_SERVER_ERROR,
143 Json(ErrorResponse {
144 error: format!("Task error: {}", e),
145 }),
146 )
147 .into_response(),
148 }
149}
150
151pub async fn generate_stream(
156 State(state): State<AppState>,
157 Json(payload): Json<GenerateRequest>,
158) -> Response {
159 let model = state.model.clone();
160 let default_voice = state.default_voice_state.clone();
161 let text = payload.text.clone();
162 let voice_spec = payload.voice.clone();
163 let lock = state.lock.clone();
164
165 let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, anyhow::Error>>(10);
167
168 tokio::spawn(async move {
170 let _guard = lock.lock().await;
171
172 let tx_inner = tx.clone();
173 let result = tokio::task::spawn_blocking(move || {
174 let voice_state = if voice_spec.is_some() {
176 resolve_voice(&model, voice_spec.as_deref())?
177 } else {
178 (*default_voice).clone()
179 };
180
181 tracing::info!(
183 "Starting streaming generation for text length: {} chars",
184 text.len()
185 );
186 for (i, chunk_res) in model.generate_stream_long(&text, &voice_state).enumerate() {
187 if i > 0 && i % 20 == 0 {
188 tracing::info!("Generated chunk {}", i);
189 }
190 match chunk_res {
191 Ok(chunk) => {
192 let chunk = chunk.squeeze(0).map_err(|e| anyhow::anyhow!(e))?;
194 let data = chunk.to_vec2::<f32>().map_err(|e| anyhow::anyhow!(e))?;
195
196 let mut bytes = Vec::new();
197 for (i, _) in data[0].iter().enumerate() {
198 for channel_data in &data {
199 let val = channel_data[i].clamp(-1.0, 1.0);
201 let val = (val * 32767.0) as i16;
202 bytes.extend_from_slice(&val.to_le_bytes());
203 }
204 }
205
206 if tx_inner.blocking_send(Ok(bytes)).is_err() {
207 break; }
209 }
210 Err(e) => {
211 let _ = tx_inner.blocking_send(Err(anyhow::anyhow!(e)));
212 break;
213 }
214 }
215 }
216 Ok::<(), anyhow::Error>(())
217 });
218
219 if let Err(e) = result.await {
220 let _ = tx.send(Err(anyhow::anyhow!("Task error: {}", e))).await;
221 }
222 });
223
224 let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
226 let body_stream = stream.map(|res| match res {
227 Ok(bytes) => Ok(axum::body::Bytes::from(bytes)),
228 Err(e) => Err(std::io::Error::new(
229 std::io::ErrorKind::Other,
230 e.to_string(),
231 )),
232 });
233
234 Response::builder()
235 .header(header::CONTENT_TYPE, "application/octet-stream")
236 .body(Body::from_stream(body_stream))
237 .unwrap()
238}
239
240pub async fn tts_form(State(state): State<AppState>, mut multipart: Multipart) -> Response {
245 let mut text: Option<String> = None;
246 let mut voice_url: Option<String> = None;
247 let mut voice_wav_bytes: Option<Vec<u8>> = None;
248
249 while let Ok(Some(field)) = multipart.next_field().await {
251 let name = field.name().unwrap_or("").to_string();
252 match name.as_str() {
253 "text" => {
254 text = field.text().await.ok();
255 }
256 "voice_url" => {
257 voice_url = field.text().await.ok();
258 }
259 "voice_wav" => {
260 voice_wav_bytes = field.bytes().await.ok().map(|b| b.to_vec());
261 }
262 _ => {}
263 }
264 }
265
266 let text = match text {
267 Some(t) if !t.trim().is_empty() => t,
268 _ => {
269 return (
270 StatusCode::BAD_REQUEST,
271 Json(ErrorResponse {
272 error: "Text is required".to_string(),
273 }),
274 )
275 .into_response();
276 }
277 };
278
279 let voice = if voice_wav_bytes.is_some() {
281 let bytes = voice_wav_bytes.unwrap();
283 use base64::{Engine as _, engine::general_purpose};
284 Some(format!(
285 "data:audio/wav;base64,{}",
286 general_purpose::STANDARD.encode(&bytes)
287 ))
288 } else {
289 voice_url
290 };
291
292 generate(State(state), Json(GenerateRequest { text, voice })).await
294}
295
296#[derive(Deserialize)]
301#[allow(dead_code)]
302pub struct OpenAIRequest {
303 model: String,
304 input: String,
305 voice: Option<String>,
306 response_format: Option<String>,
307}
308
309pub async fn openai_speech(state: State<AppState>, Json(payload): Json<OpenAIRequest>) -> Response {
310 let req = GenerateRequest {
312 text: payload.input,
313 voice: payload.voice,
314 };
315 generate(state, Json(req)).await
316}