use std::sync::Arc;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use clap::Parser;
use serde::{Deserialize, Serialize};
use tower_http::cors::CorsLayer;
use kittentts::{download, AudioFormat, EncoderFactory, KittenTTS, SAMPLE_RATE};
// Command-line options for the server binary, parsed with clap's derive API.
// NOTE: plain `//` comments are used here on purpose. clap's derive turns `///`
// doc comments into `--help` text, which would change the program's output.
#[derive(Parser)]
#[command(name = "kittentts-server")]
#[command(about = "OpenAI-compatible TTS server powered by KittenTTS")]
struct Args {
// Host the listener binds to; joined with `port` to form the bind address in `main`.
#[arg(long, default_value = "0.0.0.0")]
host: String,
// TCP port the listener binds to.
#[arg(long, default_value_t = 8080)]
port: u16,
// Model identifier handed to `download::load_from_hub` and reported by `/v1/models`.
#[arg(long, default_value = "KittenML/kitten-tts-mini-0.8")]
model: String,
// Audio format used when a request omits `response_format`.
// Validated with `AudioFormat::from_str_openai` at startup.
#[arg(long, default_value = "mp3")]
default_format: String,
}
/// State shared by all request handlers (wrapped in an `Arc` by `main`).
struct AppState {
/// The loaded TTS engine; handlers read `available_voices` and call `generate` on it.
tts: KittenTTS,
/// Model identifier given on the command line, echoed back by `/v1/models`.
model_id: String,
/// Format name used when a speech request does not specify `response_format`.
default_format: String,
}
/// JSON body accepted by `POST /v1/audio/speech`.
#[derive(Deserialize)]
struct SpeechRequest {
// Accepted so clients may send it, but never read: the server always uses the
// model loaded at startup. Hence the `dead_code` allowance.
#[serde(default)]
#[allow(dead_code)]
model: Option<String>,
/// Text to synthesize. The handler rejects empty input and input longer than 4096 characters.
input: String,
/// Voice name: either an OpenAI alias or one of the engine's available voices.
voice: String,
/// Requested audio format name; the server's default format is used when absent.
response_format: Option<String>,
/// Playback speed; defaults to 1.0 and must lie within 0.25..=4.0.
speed: Option<f32>,
}
/// Top-level JSON error envelope: serializes as `{"error": {...}}`.
#[derive(Serialize)]
struct ErrorResponse {
/// The error payload carried inside the envelope.
error: ErrorDetail,
}
/// Error payload nested inside [`ErrorResponse`].
#[derive(Serialize)]
struct ErrorDetail {
/// Human-readable description of what went wrong.
message: String,
/// Error category; serialized under the JSON key `type`.
#[serde(rename = "type")]
error_type: String,
/// Optional error code. The helpers in this file always leave it as `None`.
code: Option<String>,
}
/// JSON body returned by `GET /v1/models`.
#[derive(Serialize)]
struct ModelsResponse {
/// Object kind tag; `list_models` sets it to `"list"`.
object: &'static str,
/// The models this server exposes.
data: Vec<ModelObject>,
}
/// One entry in the [`ModelsResponse`] list.
#[derive(Serialize)]
struct ModelObject {
/// Model identifier.
id: String,
/// Object kind tag; `list_models` sets it to `"model"`.
object: &'static str,
/// Owner label; `list_models` sets it to `"kittentts"`.
owned_by: &'static str,
}
/// Translates an OpenAI voice alias into the KittenTTS voice that stands in for it.
///
/// The comparison is exact (case-sensitive). Returns `None` for any name that is
/// not one of the known aliases.
fn openai_voice_to_kittentts(name: &str) -> Option<&'static str> {
    // (OpenAI alias, KittenTTS voice) pairs. Two aliases share "Rosie".
    const ALIASES: [(&str, &str); 9] = [
        ("alloy", "Luna"),
        ("echo", "Hugo"),
        ("fable", "Kiki"),
        ("onyx", "Bruno"),
        ("nova", "Bella"),
        ("shimmer", "Rosie"),
        ("ash", "Jasper"),
        ("sage", "Leo"),
        ("coral", "Rosie"),
    ];

    ALIASES
        .iter()
        .find(|(alias, _)| *alias == name)
        .map(|&(_, voice)| voice)
}
/// Resolves a client-supplied voice name to one of the engine's `available` voices.
///
/// Resolution order (first hit wins):
/// 1. `name` is an exact OpenAI alias whose mapped voice is available.
/// 2. `name` matches an available voice, ignoring ASCII case.
/// 3. `name` is an OpenAI alias once ASCII-lowercased (e.g. `"Alloy"`) and the
///    mapped voice is available.
///
/// Step 3 comes last so that every name that resolved before still resolves to
/// the same voice. It only makes alias matching as lenient about case as native
/// voice matching already is.
///
/// Returns `None` when nothing matches, including when an alias maps to a voice
/// that the loaded model does not provide.
fn resolve_voice(name: &str, available: &[String]) -> Option<String> {
    // Maps an alias to its voice, but only if that voice is actually available.
    let alias_target = |alias: &str| -> Option<String> {
        openai_voice_to_kittentts(alias)
            .filter(|mapped| available.iter().any(|v| v.as_str() == *mapped))
            .map(String::from)
    };

    alias_target(name)
        .or_else(|| {
            available
                .iter()
                .find(|v| v.eq_ignore_ascii_case(name))
                .cloned()
        })
        .or_else(|| alias_target(&name.to_ascii_lowercase()))
}
fn bad_request(msg: impl Into<String>) -> (StatusCode, Json<ErrorResponse>) {
(
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: ErrorDetail {
message: msg.into(),
error_type: "invalid_request_error".to_string(),
code: None,
},
}),
)
}
fn server_error(msg: impl Into<String>) -> (StatusCode, Json<ErrorResponse>) {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: ErrorDetail {
message: msg.into(),
error_type: "server_error".to_string(),
code: None,
},
}),
)
}
async fn speech_handler(
State(state): State<Arc<AppState>>,
Json(req): Json<SpeechRequest>,
) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
if req.input.is_empty() {
return Err(bad_request("Input text must not be empty."));
}
if req.input.chars().count() > 4096 {
return Err(bad_request("Input text must be at most 4096 characters."));
}
let speed = req.speed.unwrap_or(1.0);
if !(0.25..=4.0).contains(&speed) {
return Err(bad_request("Speed must be between 0.25 and 4.0."));
}
let voice = resolve_voice(&req.voice, &state.tts.available_voices)
.ok_or_else(|| {
bad_request(format!(
"Unknown voice '{}'. Available voices: {} (or OpenAI names: alloy, echo, fable, onyx, nova, shimmer, ash, sage, coral).",
req.voice,
state.tts.available_voices.join(", ")
))
})?;
let format_str = req
.response_format
.as_deref()
.unwrap_or(&state.default_format);
let format = AudioFormat::from_str_openai(format_str)
.ok_or_else(|| bad_request(format!("Unsupported format '{format_str}'. Supported: mp3, wav, opus, flac, pcm.")))?;
let encoder = EncoderFactory::create(format).map_err(|e| bad_request(e.to_string()))?;
let display_input: String = req.input.chars().take(80).collect();
let truncated = if display_input.len() < req.input.len() { "..." } else { "" };
eprintln!(
"POST /v1/audio/speech voice={voice} format={format_str} speed={speed} input=\"{display_input}{truncated}\""
);
let state_clone = Arc::clone(&state);
let input = req.input.clone();
let voice_clone = voice.clone();
let audio = tokio::task::spawn_blocking(move || {
state_clone.tts.generate(&input, &voice_clone, speed, true)
})
.await
.map_err(|e| server_error(format!("TTS task panicked: {e}")))?
.map_err(|e| server_error(format!("TTS generation failed: {e}")))?;
let bytes = encoder
.encode(&audio, SAMPLE_RATE)
.map_err(|e| server_error(format!("Encoding failed: {e}")))?;
let mut headers = HeaderMap::new();
headers.insert(
"content-type",
encoder.content_type().parse().unwrap(),
);
Ok((headers, bytes))
}
/// Handles `GET /v1/models`: reports the single model this server was started with.
async fn list_models(State(state): State<Arc<AppState>>) -> Json<ModelsResponse> {
    let loaded_model = ModelObject {
        id: state.model_id.clone(),
        object: "model",
        owned_by: "kittentts",
    };
    let listing = ModelsResponse {
        object: "list",
        data: vec![loaded_model],
    };
    Json(listing)
}
/// Handles `GET /v1/voices`: returns the engine's available voices together with
/// the table of OpenAI voice aliases and the KittenTTS voices they map to.
async fn list_voices(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
    // Alias table, kept in the same order as `openai_voice_to_kittentts`.
    let openai_mapping = serde_json::json!({
        "alloy": "Luna",
        "echo": "Hugo",
        "fable": "Kiki",
        "onyx": "Bruno",
        "nova": "Bella",
        "shimmer": "Rosie",
        "ash": "Jasper",
        "sage": "Leo",
        "coral": "Rosie",
    });

    let body = serde_json::json!({
        "voices": state.tts.available_voices,
        "openai_mapping": openai_mapping,
    });
    Json(body)
}
/// Handles `GET /health`: a liveness probe that always answers with the body `ok`.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
async fn shutdown_signal() {
tokio::signal::ctrl_c().await.ok();
eprintln!("\nShutting down...");
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
if AudioFormat::from_str_openai(&args.default_format).is_none() {
anyhow::bail!(
"Invalid default format '{}'. Supported: mp3, wav, opus, flac, pcm.",
args.default_format
);
}
eprintln!("Loading model {}...", args.model);
let tts = download::load_from_hub(&args.model)?;
eprintln!(
"Model loaded. Available voices: {:?}",
tts.available_voices
);
let state = Arc::new(AppState {
tts,
model_id: args.model,
default_format: args.default_format,
});
let app = Router::new()
.route("/v1/audio/speech", post(speech_handler))
.route("/v1/models", get(list_models))
.route("/v1/voices", get(list_voices))
.route("/health", get(health))
.layer(CorsLayer::permissive())
.with_state(state);
let addr = format!("{}:{}", args.host, args.port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
eprintln!("Listening on http://{addr}");
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await?;
Ok(())
}