use axum::body::Bytes;
use axum::extract::State;
use axum::http::StatusCode;
use axum::http::header;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Json, Response};
use futures_util::StreamExt;
use futures_util::stream::Stream;
use serde::Serialize;
use std::sync::Arc;
use super::metrics::MetricsRegistry;
use super::{POOL_RETRY_AFTER_MS, POOL_RETRY_AFTER_SECS, RuntimeLimits};
use crate::inference::Engine;
/// Shared per-process state handed to every axum handler via `State`.
pub struct AppState {
    /// Inference engine (model pool, vocab, decoders); shared across handlers.
    pub engine: Arc<Engine>,
    /// Runtime request limits (e.g. `body_limit_bytes`).
    pub limits: RuntimeLimits,
    /// Prometheus registry; `None` means the /metrics endpoint is disabled.
    pub metrics_registry: Option<Arc<MetricsRegistry>>,
    /// Cancelled on shutdown; streaming inference tasks poll it to stop early.
    pub shutdown: tokio_util::sync::CancellationToken,
    /// Tracks spawned blocking inference tasks so shutdown can await them.
    pub tracker: tokio_util::task::TaskTracker,
}
/// GET /metrics — Prometheus exposition endpoint.
///
/// Renders the registry in the Prometheus text format when metrics are
/// enabled; otherwise answers 404 with a JSON error payload.
pub async fn metrics(State(state): State<Arc<AppState>>) -> Response {
    if let Some(registry) = state.metrics_registry.as_ref() {
        let content_type = [(
            header::CONTENT_TYPE,
            "text/plain; version=0.0.4; charset=utf-8",
        )];
        (StatusCode::OK, content_type, registry.render_prometheus()).into_response()
    } else {
        let body = Json(serde_json::json!({
            "error": "metrics endpoint disabled",
            "code": "metrics_disabled",
        }));
        (StatusCode::NOT_FOUND, body).into_response()
    }
}
/// JSON body for GET /health.
#[derive(Serialize)]
pub struct HealthResponse {
    /// Liveness status string ("ok" when the server can respond).
    pub status: String,
    /// Identifier of the loaded model.
    pub model: String,
    /// Server crate version (from `CARGO_PKG_VERSION`).
    pub version: String,
}
/// JSON body for GET /models — static model metadata plus live pool stats.
#[derive(Serialize)]
pub struct ModelInfo {
    /// Stable model identifier.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Server crate version.
    pub version: String,
    /// Encoder quantization (e.g. "int8").
    pub encoder: String,
    /// Size of the model's token vocabulary.
    pub vocab_size: usize,
    /// Expected input sample rate in Hz.
    pub sample_rate: u32,
    /// Total number of model triplets in the pool.
    pub pool_size: usize,
    /// Triplets currently free for checkout (snapshot; may be stale).
    pub pool_available: usize,
    /// Container formats accepted by the audio decoder.
    pub supported_formats: Vec<String>,
    /// Input sample rates accepted by the server.
    pub supported_rates: Vec<u32>,
    /// Whether speaker diarization is compiled in and available.
    pub diarization: bool,
}
/// JSON body for POST /transcribe.
#[derive(Serialize)]
pub struct TranscribeResponse {
    /// Full transcript text.
    pub text: String,
    /// Per-word timing/confidence info produced by the engine.
    pub words: Vec<crate::inference::WordInfo>,
    /// Audio duration in seconds.
    pub duration: f64,
}
type ApiError = Response;
fn api_error(status: StatusCode, msg: &str, code: &str) -> ApiError {
(
status,
Json(serde_json::json!({"error": msg, "code": code})),
)
.into_response()
}
/// 503 response for a pool-checkout timeout. Advertises when to retry via
/// both the `Retry-After` header (seconds) and a JSON `retry_after_ms` field.
fn api_timeout_error() -> ApiError {
    let retry_after = [(header::RETRY_AFTER, POOL_RETRY_AFTER_SECS.to_string())];
    let body = Json(serde_json::json!({
        "error": "Server busy, try again later",
        "code": "timeout",
        "retry_after_ms": POOL_RETRY_AFTER_MS,
    }));
    (StatusCode::SERVICE_UNAVAILABLE, retry_after, body).into_response()
}
/// 503 response emitted once the engine pool has been closed for shutdown.
fn api_pool_closed_error() -> ApiError {
    let body = Json(serde_json::json!({
        "error": "Server is shutting down",
        "code": "pool_closed",
    }));
    (StatusCode::SERVICE_UNAVAILABLE, body).into_response()
}
/// GET /health — liveness probe reporting model identity and build version.
pub async fn health(State(state): State<Arc<AppState>>) -> Json<HealthResponse> {
    // Touch the engine so the handler observably depends on shared state.
    let _ = &state.engine;
    let payload = HealthResponse {
        status: String::from("ok"),
        model: String::from("zipformer-vi-rnnt"),
        version: String::from(env!("CARGO_PKG_VERSION")),
    };
    Json(payload)
}
/// GET /models — metadata about the loaded model and a pool snapshot.
pub async fn models(State(state): State<Arc<AppState>>) -> Json<ModelInfo> {
    let engine = &state.engine;
    // Diarization support is only queryable when the feature is compiled in.
    #[cfg(feature = "diarization")]
    let diarization = engine.has_speaker_encoder();
    #[cfg(not(feature = "diarization"))]
    let diarization = false;
    let supported_formats = ["wav", "mp3", "m4a", "ogg", "flac"]
        .into_iter()
        .map(String::from)
        .collect();
    let info = ModelInfo {
        id: String::from("zipformer-vi-rnnt"),
        name: String::from("Zipformer-vi RNN-T"),
        version: String::from(env!("CARGO_PKG_VERSION")),
        encoder: String::from("int8"),
        vocab_size: engine.vocab_size(),
        sample_rate: 16000,
        pool_size: engine.pool.total(),
        pool_available: engine.pool.available(),
        supported_formats,
        supported_rates: super::SUPPORTED_RATES.to_vec(),
        diarization,
    };
    Json(info)
}
/// POST /transcribe — batch transcription of a complete audio upload.
///
/// Body: raw audio bytes (format detected downstream by the engine).
/// Errors: 400 empty body, 413 over the configured size limit, 503 pool
/// timeout or shutdown, 422 decode/inference failure, 500 join error.
pub async fn transcribe(
    State(state): State<Arc<AppState>>,
    body: Bytes,
) -> Result<Json<TranscribeResponse>, ApiError> {
    if body.is_empty() {
        return Err(api_error(
            StatusCode::BAD_REQUEST,
            "Empty request body",
            "empty_body",
        ));
    }
    if body.len() > state.limits.body_limit_bytes {
        return Err(api_error(
            StatusCode::PAYLOAD_TOO_LARGE,
            "Request body exceeds the configured size limit",
            "payload_too_large",
        ));
    }
    // Borrow a model triplet from the pool, but give up after 30s so a
    // saturated pool yields 503 + Retry-After instead of hanging clients.
    let guard = match tokio::time::timeout(
        std::time::Duration::from_secs(30),
        state.engine.pool.checkout(),
    )
    .await
    {
        Ok(Ok(guard)) => guard,
        Ok(Err(_pool_closed)) => return Err(api_pool_closed_error()),
        Err(_timeout) => return Err(api_timeout_error()),
    };
    // Detach the triplet from the guard so it can move into the blocking
    // task; `reservation` is the handle used to return it afterwards.
    let (triplet, reservation) = guard.into_owned();
    let engine = state.engine.clone();
    // Inference is CPU-bound, so it runs on the blocking pool. catch_unwind
    // ensures a panicking inference still hands the triplet back, so the
    // pool does not permanently lose capacity.
    let result = tokio::task::spawn_blocking(move || {
        let mut triplet = triplet;
        let r = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            engine.transcribe_bytes_shared(body, &mut triplet)
        }));
        match r {
            Ok(inference_result) => (inference_result, triplet),
            Err(_) => {
                tracing::error!("Panic in REST transcribe — triplet recovered");
                (
                    Err(crate::error::GigasttError::Inference(
                        "Inference thread panicked".into(),
                    )),
                    triplet,
                )
            }
        }
    })
    .await;
    match result {
        // Success: check the triplet back in before answering.
        Ok((Ok(result), triplet)) => {
            reservation.checkin(triplet);
            Ok(Json(TranscribeResponse {
                text: result.text,
                words: result.words,
                duration: result.duration_s,
            }))
        }
        // Inference/decode error: the triplet still returns to the pool
        // before we answer 422.
        Ok((Err(e), triplet)) => {
            reservation.checkin(triplet);
            tracing::error!("Transcription error: {e}");
            Err(api_error(
                StatusCode::UNPROCESSABLE_ENTITY,
                "Transcription failed. Check audio format.",
                "transcription_error",
            ))
        }
        // Join error: the blocking task was aborted or panicked outside
        // catch_unwind; nothing to check in.
        Err(e) => {
            tracing::error!("spawn_blocking join error: {e}");
            Err(api_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Internal server error",
                "internal",
            ))
        }
    }
}
/// POST /transcribe/stream — chunked transcription streamed back as SSE.
///
/// Decodes the whole upload first, then feeds it chunk-by-chunk through a
/// streaming decoder state on the blocking pool, emitting partial/final
/// segment events. Validation errors are returned before the SSE stream
/// starts; later errors arrive as `{"type": "error"}` events in-stream.
pub async fn transcribe_stream(
    State(state): State<Arc<AppState>>,
    body: Bytes,
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
    if body.is_empty() {
        return Err(api_error(
            StatusCode::BAD_REQUEST,
            "Empty request body",
            "empty_body",
        ));
    }
    if body.len() > state.limits.body_limit_bytes {
        return Err(api_error(
            StatusCode::PAYLOAD_TOO_LARGE,
            "Request body exceeds the configured size limit",
            "payload_too_large",
        ));
    }
    // Decode on the blocking pool BEFORE taking a pool slot, so bad uploads
    // are rejected with 422 without ever tying up a model triplet.
    let samples = tokio::task::spawn_blocking(move || {
        crate::inference::audio::decode_audio_bytes_shared(body)
    })
    .await
    .map_err(|e| {
        tracing::error!("spawn_blocking join error: {e}");
        api_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "Internal server error",
            "internal",
        )
    })?
    .map_err(|e| {
        tracing::error!("Audio decode error: {e:#}");
        api_error(
            StatusCode::UNPROCESSABLE_ENTITY,
            "Failed to decode audio file. Check format (WAV, MP3, M4A, OGG, FLAC supported).",
            "invalid_audio",
        )
    })?;
    // Same bounded checkout as the batch endpoint: 30s, then 503.
    let guard = match tokio::time::timeout(
        std::time::Duration::from_secs(30),
        state.engine.pool.checkout(),
    )
    .await
    {
        Ok(Ok(guard)) => guard,
        Ok(Err(_pool_closed)) => return Err(api_pool_closed_error()),
        Err(_timeout) => return Err(api_timeout_error()),
    };
    let (triplet, reservation) = guard.into_owned();
    // Bounded channel (16): inference backpressures if the client reads slowly.
    let (tx, rx) =
        tokio::sync::mpsc::channel::<Result<crate::inference::TranscriptSegment, String>>(16);
    let engine = state.engine.clone();
    let cancel = state.shutdown.clone();
    let tracker = state.tracker.clone();
    // Spawn via the tracker so graceful shutdown can await this task.
    tracker.spawn_blocking(move || {
        let mut triplet = triplet;
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut stream_state = match engine.create_state(false) {
                Ok(s) => s,
                Err(e) => {
                    let _ = tx.blocking_send(Err(format!("{e}")));
                    return;
                }
            };
            // 16000 samples — presumably 1s of audio at 16 kHz (see
            // ModelInfo.sample_rate); TODO confirm against the engine.
            let chunk_size = 16000;
            for chunk in samples.chunks(chunk_size) {
                // Server shutdown aborts mid-stream; the SSE just ends.
                if cancel.is_cancelled() {
                    tracing::info!("SSE transcription cancelled by shutdown");
                    return;
                }
                match engine.process_chunk(chunk, &mut stream_state, &mut triplet) {
                    Ok(segs) => {
                        for seg in segs {
                            // send error => client disconnected; stop early.
                            if tx.blocking_send(Ok(seg)).is_err() {
                                return;
                            }
                        }
                    }
                    Err(e) => {
                        let _ = tx.blocking_send(Err(format!("{e}")));
                        return;
                    }
                }
            }
            // Flush any trailing segment the decoder is still holding.
            if !cancel.is_cancelled()
                && let Some(seg) = engine.flush_state(&mut stream_state, &mut triplet)
            {
                let _ = tx.blocking_send(Ok(seg));
            }
        }));
        if result.is_err() {
            tracing::error!("Panic in SSE inference task — triplet recovered");
        }
        // Runs on every exit path (success, error, panic, cancel): the
        // triplet always goes back to the pool.
        reservation.checkin(triplet);
    });
    // Map channel items to SSE events. Internal error details are logged
    // earlier but never leaked to the client.
    let stream = tokio_stream::wrappers::ReceiverStream::new(rx)
        .map(|result| {
            let event = match result {
                Ok(seg) => {
                    let msg = if seg.is_final {
                        serde_json::json!({"type": "final", "text": seg.text.as_ref(), "timestamp": seg.timestamp, "words": seg.words.as_ref()})
                    } else {
                        serde_json::json!({"type": "partial", "text": seg.text.as_ref(), "timestamp": seg.timestamp, "words": seg.words.as_ref()})
                    };
                    Event::default().data(msg.to_string())
                }
                Err(_) => {
                    let msg = serde_json::json!({"type": "error", "message": "Transcription failed.", "code": "inference_error"});
                    Event::default().data(msg.to_string())
                }
            };
            Ok(event)
        });
    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// HealthResponse must expose its fields under the expected JSON keys.
    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: String::from("ok"),
            model: String::from("test"),
            version: String::from("0.3.0"),
        };
        let serialized = serde_json::to_string(&resp).unwrap();
        let v: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(v["status"], "ok");
        assert_eq!(v["model"], "test");
    }

    /// TranscribeResponse must round-trip text and duration through JSON.
    #[test]
    fn test_transcribe_response_serialization() {
        let resp = TranscribeResponse {
            text: String::from("hello"),
            words: Vec::new(),
            duration: 1.5,
        };
        let serialized = serde_json::to_string(&resp).unwrap();
        let v: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(v["text"], "hello");
        assert_eq!(v["duration"], 1.5);
    }
}