use axum::body::Bytes;
use axum::extract::State;
use axum::http::StatusCode;
use axum::http::header;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Json, Response};
use futures_util::StreamExt;
use futures_util::stream::Stream;
use serde::Serialize;
use std::sync::Arc;
use super::metrics::MetricsRegistry;
use super::{POOL_RETRY_AFTER_MS, POOL_RETRY_AFTER_SECS, RuntimeLimits};
use crate::inference::Engine;
/// Shared application state handed to every handler via `State<Arc<AppState>>`.
pub struct AppState {
    /// Inference engine; owns the session pool used by the transcription endpoints.
    pub engine: Arc<Engine>,
    /// Runtime-configurable limits (e.g. request-body size) enforced per request.
    pub limits: RuntimeLimits,
    /// Prometheus metrics registry; `None` makes `/metrics` return 404.
    pub metrics_registry: Option<Arc<MetricsRegistry>>,
    /// Cancelled on shutdown; the streaming worker polls it to stop early.
    pub shutdown: tokio_util::sync::CancellationToken,
    /// Tracks spawned blocking workers so shutdown can wait for them.
    pub tracker: tokio_util::task::TaskTracker,
}
/// GET `/metrics` — Prometheus text exposition, or 404 when metrics are disabled.
pub async fn metrics(State(state): State<Arc<AppState>>) -> Response {
    // No registry configured: the endpoint is considered disabled.
    let Some(registry) = &state.metrics_registry else {
        return (
            StatusCode::NOT_FOUND,
            Json(serde_json::json!({
                "error": "metrics endpoint disabled",
                "code": "metrics_disabled",
            })),
        )
            .into_response();
    };
    let content_type = [(
        header::CONTENT_TYPE,
        "text/plain; version=0.0.4; charset=utf-8",
    )];
    (StatusCode::OK, content_type, registry.render_prometheus()).into_response()
}
/// Payload for GET `/health`.
#[derive(Debug, Serialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct HealthResponse {
    /// "ok" or "degraded" depending on pool availability.
    pub status: String,
    /// Model identifier served by this instance.
    pub model: String,
    /// Crate version (from `CARGO_PKG_VERSION`).
    pub version: String,
}
/// Payload for GET `/v1/models`: static model metadata plus live pool stats.
#[derive(Debug, Serialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct ModelInfo {
    /// Stable model id.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Crate version.
    pub version: String,
    /// Encoder quantization level.
    pub encoder: String,
    /// Token vocabulary size reported by the engine.
    pub vocab_size: usize,
    /// Sample rate (Hz) audio is processed at.
    pub sample_rate: u32,
    /// Total sessions in the pool.
    pub pool_size: usize,
    /// Sessions currently available for checkout.
    pub pool_available: usize,
    /// Accepted container/codec formats.
    pub supported_formats: Vec<String>,
    /// Accepted input sample rates.
    pub supported_rates: Vec<u32>,
    /// Whether speaker diarization is compiled in and available.
    pub diarization: bool,
}
/// Payload for POST `/v1/transcribe`.
#[derive(Debug, Serialize)]
#[cfg_attr(feature = "openapi", derive(utoipa::ToSchema))]
pub struct TranscribeResponse {
    /// Full transcript text.
    pub text: String,
    /// Per-word timing/confidence details.
    pub words: Vec<crate::inference::WordInfo>,
    /// Audio duration in seconds.
    pub duration: f64,
}
type ApiError = Response;
fn api_error(status: StatusCode, msg: &str, code: &str) -> ApiError {
(
status,
Json(serde_json::json!({"error": msg, "code": code})),
)
.into_response()
}
/// 503 for pool-checkout timeouts; carries a `Retry-After` header plus a
/// machine-readable `retry_after_ms` hint in the body.
fn api_timeout_error() -> ApiError {
    let retry_after = [(header::RETRY_AFTER, POOL_RETRY_AFTER_SECS.to_string())];
    let body = Json(serde_json::json!({
        "error": "Server busy, try again later",
        "code": "timeout",
        "retry_after_ms": POOL_RETRY_AFTER_MS,
    }));
    (StatusCode::SERVICE_UNAVAILABLE, retry_after, body).into_response()
}
/// 503 for checkout attempts against a closed pool (server shutting down).
/// Deliberately has no `Retry-After`: the instance is going away.
fn api_pool_closed_error() -> ApiError {
    let body = Json(serde_json::json!({
        "error": "Server is shutting down",
        "code": "pool_closed",
    }));
    (StatusCode::SERVICE_UNAVAILABLE, body).into_response()
}
async fn checkout_triplet(
engine: &std::sync::Arc<Engine>,
) -> Result<
(
crate::inference::SessionTriplet,
crate::inference::OwnedReservation<crate::inference::SessionTriplet>,
),
ApiError,
> {
match tokio::time::timeout(std::time::Duration::from_secs(30), engine.pool.checkout()).await {
Ok(Ok(guard)) => {
let (triplet, reservation) = guard.into_owned();
Ok((triplet, reservation))
}
Ok(Err(_pool_closed)) => Err(api_pool_closed_error()),
Err(_timeout) => Err(api_timeout_error()),
}
}
/// GET `/health` — liveness/readiness probe.
///
/// Reports "ok" while at least one pool session is free (or the pool is
/// intentionally empty), "degraded" once every session is checked out.
#[cfg_attr(feature = "openapi", utoipa::path(
    get,
    path = "/health",
    responses((status = 200, body = HealthResponse))
))]
pub async fn health(State(state): State<Arc<AppState>>) -> Json<HealthResponse> {
    let pool = &state.engine.pool;
    let healthy = pool.available() > 0 || pool.total() == 0;
    Json(HealthResponse {
        status: if healthy { "ok" } else { "degraded" }.into(),
        model: "zipformer-vi-rnnt".into(),
        version: env!("CARGO_PKG_VERSION").into(),
    })
}
/// GET `/v1/models` — static model metadata plus live pool occupancy.
#[cfg_attr(feature = "openapi", utoipa::path(
    get,
    path = "/v1/models",
    responses((status = 200, body = ModelInfo))
))]
pub async fn models(State(state): State<Arc<AppState>>) -> Json<ModelInfo> {
    let engine = &state.engine;
    // Diarization is a compile-time feature; only probe the speaker encoder
    // when it is built in.
    #[cfg(feature = "diarization")]
    let diarization = engine.has_speaker_encoder();
    #[cfg(not(feature = "diarization"))]
    let diarization = false;
    let supported_formats = ["wav", "mp3", "m4a", "ogg", "flac"]
        .into_iter()
        .map(String::from)
        .collect();
    Json(ModelInfo {
        id: "zipformer-vi-rnnt".into(),
        name: "Zipformer-vi RNN-T".into(),
        version: env!("CARGO_PKG_VERSION").into(),
        encoder: "int8".into(),
        vocab_size: engine.vocab_size(),
        sample_rate: crate::inference::TARGET_SAMPLE_RATE,
        pool_size: engine.pool.total(),
        pool_available: engine.pool.available(),
        supported_formats,
        supported_rates: super::SUPPORTED_RATES.to_vec(),
        diarization,
    })
}
/// POST `/v1/transcribe` — one-shot transcription of a complete audio upload.
///
/// Validates the raw body, checks a session triplet out of the pool, runs
/// inference on the blocking thread pool, and returns the full transcript.
///
/// # Errors
/// * 400 `empty_body` — request body is empty.
/// * 413 `payload_too_large` — body exceeds `limits.body_limit_bytes`.
/// * 503 `timeout` / `pool_closed` — from [`checkout_triplet`].
/// * 422 `transcription_error` — decode/inference failed.
/// * 500 `internal` — the blocking task could not be joined.
#[cfg_attr(feature = "openapi", utoipa::path(
    post,
    path = "/v1/transcribe",
    request_body = Vec<u8>,
    responses((status = 200, body = TranscribeResponse))
))]
pub async fn transcribe(
    State(state): State<Arc<AppState>>,
    body: Bytes,
) -> Result<Json<TranscribeResponse>, ApiError> {
    if body.is_empty() {
        return Err(api_error(
            StatusCode::BAD_REQUEST,
            "Empty request body",
            "empty_body",
        ));
    }
    if body.len() > state.limits.body_limit_bytes {
        return Err(api_error(
            StatusCode::PAYLOAD_TOO_LARGE,
            "Request body exceeds the configured size limit",
            "payload_too_large",
        ));
    }
    // Reserve an inference session; may wait up to the pool checkout timeout.
    let (triplet, reservation) = checkout_triplet(&state.engine).await?;
    let engine = state.engine.clone();
    // Inference is CPU-bound, so run it off the async runtime. catch_unwind
    // keeps a panicking inference call from losing the triplet: the triplet is
    // always returned from the closure so it can be checked back into the pool.
    let result = tokio::task::spawn_blocking(move || {
        let mut triplet = triplet;
        let r = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            engine.transcribe_bytes_shared(body, &mut triplet)
        }));
        match r {
            Ok(inference_result) => (inference_result, triplet),
            Err(_) => {
                tracing::error!("Panic in REST transcribe — triplet recovered");
                (
                    Err(crate::error::PhosttError::Inference(
                        "Inference thread panicked".into(),
                    )),
                    triplet,
                )
            }
        }
    })
    .await;
    match result {
        // Success: return the session to the pool, then respond.
        Ok((Ok(result), triplet)) => {
            reservation.checkin(triplet);
            Ok(Json(TranscribeResponse {
                text: result.text,
                words: result.words,
                duration: result.duration_s,
            }))
        }
        // Inference error: the triplet still goes back; the client gets a
        // generic 422 while details are only logged server-side.
        Ok((Err(e), triplet)) => {
            reservation.checkin(triplet);
            tracing::error!("Transcription error: {e}");
            Err(api_error(
                StatusCode::UNPROCESSABLE_ENTITY,
                "Transcription failed. Check audio format.",
                "transcription_error",
            ))
        }
        // Join failure (e.g. task aborted): the triplet is lost with the task,
        // so the reservation is dropped without checkin here.
        // NOTE(review): presumably OwnedReservation's Drop releases the pool
        // slot safely in that case — confirm against its implementation.
        Err(e) => {
            tracing::error!("spawn_blocking join error: {e}");
            Err(api_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Internal server error",
                "internal",
            ))
        }
    }
}
/// POST `/v1/transcribe/stream` — transcribe an upload and stream transcript
/// segments back to the client as Server-Sent Events.
///
/// Audio is decoded incrementally and fed to the engine in fixed-size chunks;
/// each resulting segment (partial or final) is serialized by
/// [`segment_to_event`] and pushed over the SSE stream. Inference errors are
/// reported in-band as `{"type": "error", ...}` events rather than failing
/// the HTTP response.
///
/// # Errors
/// * 400 `empty_body` — request body is empty.
/// * 413 `payload_too_large` — body exceeds `limits.body_limit_bytes`.
/// * 503 `timeout` / `pool_closed` — from [`checkout_triplet`].
#[cfg_attr(feature = "openapi", utoipa::path(
    post,
    path = "/v1/transcribe/stream",
    request_body = Vec<u8>,
    responses((status = 200, description = "SSE stream of transcript segments"))
))]
pub async fn transcribe_stream(
    State(state): State<Arc<AppState>>,
    body: Bytes,
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
    if body.is_empty() {
        return Err(api_error(
            StatusCode::BAD_REQUEST,
            "Empty request body",
            "empty_body",
        ));
    }
    if body.len() > state.limits.body_limit_bytes {
        return Err(api_error(
            StatusCode::PAYLOAD_TOO_LARGE,
            "Request body exceeds the configured size limit",
            "payload_too_large",
        ));
    }
    let (triplet, reservation) = checkout_triplet(&state.engine).await?;
    // Bounded channel: the blocking worker back-pressures when the client is
    // slow to consume events.
    let (tx, rx) =
        tokio::sync::mpsc::channel::<Result<crate::inference::TranscriptSegment, String>>(16);
    let engine = state.engine.clone();
    let cancel = state.shutdown.clone();
    let tracker = state.tracker.clone();
    // Spawn on the task tracker so graceful shutdown can wait for the worker.
    tracker.spawn_blocking(move || {
        let mut triplet = triplet;
        // catch_unwind ensures the triplet is checked back in even if
        // inference panics partway through the stream.
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let mut stream_state = match engine.create_state(false) {
                Ok(s) => s,
                Err(e) => {
                    let _ = tx.blocking_send(Err(format!("{e}")));
                    return;
                }
            };
            // Feed the engine fixed chunks of TARGET_SAMPLE_RATE samples,
            // i.e. one second of audio at the target rate.
            let chunk_size = crate::inference::TARGET_SAMPLE_RATE as usize;
            let mut chunk_buf: Vec<f32> = Vec::with_capacity(chunk_size);
            let decode_result = crate::inference::audio::decode_audio_streaming(body, |samples| {
                // On shutdown, keep draining the decoder but skip inference.
                if cancel.is_cancelled() {
                    return Ok(());
                }
                chunk_buf.extend_from_slice(samples);
                while chunk_buf.len() >= chunk_size {
                    let chunk: Vec<f32> = chunk_buf.drain(..chunk_size).collect();
                    match engine.process_chunk(&chunk, &mut stream_state, &mut triplet) {
                        Ok(segs) => {
                            for seg in segs {
                                // Send failure means the client disconnected;
                                // abort the decode loop.
                                if tx.blocking_send(Ok(seg)).is_err() {
                                    return Err(anyhow::anyhow!("receiver dropped"));
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.blocking_send(Err(format!("{e}")));
                            return Err(anyhow::anyhow!("inference failed"));
                        }
                    }
                }
                Ok(())
            });
            if let Err(e) = decode_result {
                tracing::error!("Streaming decode error: {e:#}");
                let _ = tx.blocking_send(Err(format!("{e}")));
                return;
            }
            // Tail: process whatever partial chunk remains after decoding ends.
            if !chunk_buf.is_empty() && !cancel.is_cancelled() {
                match engine.process_chunk(&chunk_buf, &mut stream_state, &mut triplet) {
                    Ok(segs) => {
                        for seg in segs {
                            if tx.blocking_send(Ok(seg)).is_err() {
                                return;
                            }
                        }
                    }
                    Err(e) => {
                        let _ = tx.blocking_send(Err(format!("{e}")));
                        return;
                    }
                }
            }
            // Flush any final buffered segment out of the streaming state.
            if !cancel.is_cancelled()
                && let Some(seg) = engine.flush_state(&mut stream_state, &mut triplet)
            {
                let _ = tx.blocking_send(Ok(seg));
            }
        }));
        if result.is_err() {
            tracing::error!("Panic in SSE inference task — triplet recovered");
        }
        // Always return the session to the pool, panic or not.
        reservation.checkin(triplet);
    });
    // Errors travel inside the event payloads, so the SSE stream itself is
    // infallible from axum's point of view.
    let stream =
        tokio_stream::wrappers::ReceiverStream::new(rx).map(|result| Ok(segment_to_event(result)));
    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
}
/// Serialize a transcript segment (or error) into an SSE `data:` event.
fn segment_to_event(result: Result<crate::inference::TranscriptSegment, String>) -> Event {
    let payload = segment_to_json_value(result).to_string();
    Event::default().data(payload)
}
/// Map a transcript segment (or error) to the JSON body sent over SSE.
///
/// Segments become `{"type": "partial"|"final", "text", "timestamp", "words"}`;
/// errors become a generic `{"type": "error", ...}` payload (details stay
/// server-side).
fn segment_to_json_value(
    result: Result<crate::inference::TranscriptSegment, String>,
) -> serde_json::Value {
    match result {
        Ok(seg) => {
            let kind = if seg.is_final { "final" } else { "partial" };
            serde_json::json!({
                "type": kind,
                "text": seg.text.as_ref(),
                "timestamp": seg.timestamp,
                "words": seg.words.as_ref(),
            })
        }
        Err(_) => serde_json::json!({
            "type": "error",
            "message": "Transcription failed.",
            "code": "inference_error",
        }),
    }
}
#[cfg(test)]
mod tests {
    //! Handler-level tests against a stub engine: request validation, pool
    //! timeout/closure paths, metrics toggling, and SSE payload shapes.
    use super::*;
    use crate::inference::{Engine, TranscriptSegment, WordInfo};
    use axum::body::to_bytes;
    use std::sync::Arc;
    use tokio_util::sync::CancellationToken;
    use tokio_util::task::TaskTracker;

    /// Build an `AppState` around `Engine::test_stub()` with the given limits
    /// and optional metrics registry.
    fn test_state(limits: RuntimeLimits, metrics: Option<Arc<MetricsRegistry>>) -> Arc<AppState> {
        Arc::new(AppState {
            engine: Arc::new(Engine::test_stub()),
            limits,
            metrics_registry: metrics,
            shutdown: CancellationToken::new(),
            tracker: TaskTracker::new(),
        })
    }

    // Empty upload is rejected up-front with 400 / "empty_body".
    #[tokio::test]
    async fn test_transcribe_empty_body() {
        let state = test_state(RuntimeLimits::default(), None);
        let result = transcribe(State(state), Bytes::new()).await;
        assert!(result.is_err());
        let resp = result.unwrap_err();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "empty_body");
    }

    // Same empty-body validation applies to the SSE endpoint.
    #[tokio::test]
    async fn test_stream_empty_body() {
        let state = test_state(RuntimeLimits::default(), None);
        let result = transcribe_stream(State(state), Bytes::new()).await;
        assert!(result.is_err());
        let resp = result.unwrap_err();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "empty_body");
    }

    // Body larger than the configured limit yields 413 / "payload_too_large".
    #[tokio::test]
    async fn test_transcribe_payload_too_large() {
        let limits = RuntimeLimits {
            body_limit_bytes: 10,
            ..RuntimeLimits::default()
        };
        let state = test_state(limits, None);
        let result = transcribe(State(state), Bytes::from(vec![0u8; 100])).await;
        assert!(result.is_err());
        let resp = result.unwrap_err();
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "payload_too_large");
    }

    // Same size limit applies to the SSE endpoint.
    #[tokio::test]
    async fn test_stream_payload_too_large() {
        let limits = RuntimeLimits {
            body_limit_bytes: 10,
            ..RuntimeLimits::default()
        };
        let state = test_state(limits, None);
        let result = transcribe_stream(State(state), Bytes::from(vec![0u8; 100])).await;
        assert!(result.is_err());
        let resp = result.unwrap_err();
        assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE);
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "payload_too_large");
    }

    // Advance tokio's virtual clock past the 30 s checkout timeout and expect
    // 503 / "timeout" with a Retry-After header and retry_after_ms hint.
    #[tokio::test]
    async fn test_transcribe_pool_timeout() {
        tokio::time::pause();
        let state = test_state(RuntimeLimits::default(), None);
        let handle =
            tokio::spawn(async move { transcribe(State(state), Bytes::from(vec![1u8])).await });
        tokio::time::advance(std::time::Duration::from_secs(31)).await;
        let result = handle.await.unwrap();
        assert!(result.is_err());
        let resp = result.unwrap_err();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(
            resp.headers().get(header::RETRY_AFTER).unwrap(),
            POOL_RETRY_AFTER_SECS.to_string().as_str()
        );
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "timeout");
        assert_eq!(json["retry_after_ms"], POOL_RETRY_AFTER_MS);
    }

    // Same checkout-timeout behavior on the SSE endpoint.
    #[tokio::test]
    async fn test_stream_pool_timeout() {
        tokio::time::pause();
        let state = test_state(RuntimeLimits::default(), None);
        let handle =
            tokio::spawn(
                async move { transcribe_stream(State(state), Bytes::from(vec![1u8])).await },
            );
        tokio::time::advance(std::time::Duration::from_secs(31)).await;
        let result = handle.await.unwrap();
        assert!(result.is_err());
        let resp = result.unwrap_err();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(
            resp.headers().get(header::RETRY_AFTER).unwrap(),
            POOL_RETRY_AFTER_SECS.to_string().as_str()
        );
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "timeout");
        assert_eq!(json["retry_after_ms"], POOL_RETRY_AFTER_MS);
    }

    // A closed pool (shutdown) yields 503 / "pool_closed" with no retry hint.
    #[tokio::test]
    async fn test_transcribe_pool_closed() {
        let state = test_state(RuntimeLimits::default(), None);
        state.engine.pool.close();
        let result = transcribe(State(state), Bytes::from(vec![1u8])).await;
        assert!(result.is_err());
        let resp = result.unwrap_err();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "pool_closed");
        assert!(json.get("retry_after_ms").is_none());
    }

    // Same pool-closed behavior on the SSE endpoint.
    #[tokio::test]
    async fn test_stream_pool_closed() {
        let state = test_state(RuntimeLimits::default(), None);
        state.engine.pool.close();
        let result = transcribe_stream(State(state), Bytes::from(vec![1u8])).await;
        assert!(result.is_err());
        let resp = result.unwrap_err();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "pool_closed");
        assert!(json.get("retry_after_ms").is_none());
    }

    // Without a registry, /metrics is 404 / "metrics_disabled".
    #[tokio::test]
    async fn test_metrics_disabled() {
        let state = test_state(RuntimeLimits::default(), None);
        let resp = metrics(State(state)).await;
        assert_eq!(resp.status(), StatusCode::NOT_FOUND);
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json["code"], "metrics_disabled");
    }

    // With a registry, /metrics renders Prometheus text with the right
    // content type and includes registered metric names.
    #[tokio::test]
    async fn test_metrics_enabled() {
        let registry = Arc::new(MetricsRegistry::new());
        registry.register_counter("requests_total", "Total requests");
        registry.counter_inc("requests_total", vec![], 1);
        let state = test_state(RuntimeLimits::default(), Some(registry));
        let resp = metrics(State(state)).await;
        assert_eq!(resp.status(), StatusCode::OK);
        let ct = resp.headers().get(header::CONTENT_TYPE).unwrap();
        assert!(ct.to_str().unwrap().contains("text/plain"));
        let body = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
        let text = String::from_utf8(body.to_vec()).unwrap();
        assert!(text.contains("requests_total"));
    }

    // Non-final segments serialize as "partial" events.
    #[test]
    fn test_sse_partial_event() {
        let seg = TranscriptSegment {
            text: Arc::new("hello".into()),
            words: Arc::new(vec![]),
            is_final: false,
            timestamp: 1.5,
        };
        let json = segment_to_json_value(Ok(seg));
        assert_eq!(json["type"], "partial");
        assert_eq!(json["text"], "hello");
        assert_eq!(json["timestamp"], 1.5);
        assert!(json["words"].is_array());
    }

    // Final segments serialize as "final" events with word details included.
    #[test]
    fn test_sse_final_event() {
        let word = WordInfo {
            word: "world".into(),
            start: 0.0,
            end: 1.0,
            confidence: 0.95,
            speaker: None,
        };
        let seg = TranscriptSegment {
            text: Arc::new("world".into()),
            words: Arc::new(vec![word]),
            is_final: true,
            timestamp: 2.0,
        };
        let json = segment_to_json_value(Ok(seg));
        assert_eq!(json["type"], "final");
        assert_eq!(json["text"], "world");
        let words = json["words"].as_array().unwrap();
        assert_eq!(words.len(), 1);
        assert_eq!(words[0]["word"], "world");
    }

    // Errors serialize as a generic "error" event; the message is not leaked.
    #[test]
    fn test_sse_error_event() {
        let json = segment_to_json_value(Err("boom".into()));
        assert_eq!(json["type"], "error");
        assert_eq!(json["code"], "inference_error");
    }
}