use axum::body::Bytes;
use axum::extract::State;
use axum::http::StatusCode;
use axum::http::header;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::{IntoResponse, Json, Response};
use futures_util::StreamExt;
use futures_util::stream::Stream;
use serde::Serialize;
use std::sync::Arc;
use super::{POOL_RETRY_AFTER_MS, POOL_RETRY_AFTER_SECS, RuntimeLimits};
use crate::inference::Engine;
use metrics_exporter_prometheus::PrometheusHandle;
/// Shared state handed to every handler in this module via axum's `State` extractor.
pub struct AppState {
/// Inference engine; handlers check model triplets in and out of `engine.pool`.
pub engine: Arc<Engine>,
/// Runtime limits. NOTE(review): not referenced by the handlers visible in this file.
pub limits: RuntimeLimits,
/// Prometheus recorder handle; when `None`, the `metrics` handler responds with 404.
pub metrics_handle: Option<PrometheusHandle>,
}
/// Serve metrics rendered from the installed Prometheus recorder handle.
///
/// Responds `200` with the Prometheus text exposition content type
/// (`version=0.0.4`) when `AppState::metrics_handle` is set. Responds `404`
/// with a JSON error body (`code: "metrics_disabled"`) when it is `None`.
pub async fn metrics(State(state): State<Arc<AppState>>) -> Response {
    if let Some(handle) = state.metrics_handle.as_ref() {
        let headers = [(
            header::CONTENT_TYPE,
            "text/plain; version=0.0.4; charset=utf-8",
        )];
        return (StatusCode::OK, headers, handle.render()).into_response();
    }
    let disabled = serde_json::json!({
        "error": "metrics endpoint disabled",
        "code": "metrics_disabled",
    });
    (StatusCode::NOT_FOUND, Json(disabled)).into_response()
}
/// JSON body returned by the `health` handler.
///
/// Field order is significant: it determines key order in the serialized output.
#[derive(Serialize)]
pub struct HealthResponse {
/// Service status; `health` always sets this to `"ok"`.
pub status: String,
/// Identifier of the served model.
pub model: String,
/// Crate version (`CARGO_PKG_VERSION` at build time).
pub version: String,
}
/// JSON body returned by the `models` handler describing the served model.
///
/// Field order is significant: it determines key order in the serialized output.
#[derive(Serialize)]
pub struct ModelInfo {
/// Stable model identifier.
pub id: String,
/// Human-readable model name.
pub name: String,
/// Crate version (`CARGO_PKG_VERSION` at build time).
pub version: String,
/// Encoder precision: `"int8"` or `"fp32"` depending on `Engine::is_int8()`.
pub encoder: String,
/// Vocabulary size reported to clients (hard-coded by `models`).
pub vocab_size: usize,
/// Model input sample rate as reported by `models` (16000; presumably Hz).
pub sample_rate: u32,
/// Total pool capacity as reported by `engine.pool.total()`.
pub pool_size: usize,
/// Currently available pool entries as reported by `engine.pool.available()`.
pub pool_available: usize,
/// Audio container formats advertised to clients.
pub supported_formats: Vec<String>,
/// Input sample rates advertised to clients.
pub supported_rates: Vec<u32>,
/// Whether speaker diarization is available (false unless the `diarization` feature is compiled in).
pub diarization: bool,
}
/// JSON body returned by the non-streaming `transcribe` handler.
///
/// Field order is significant: it determines key order in the serialized output.
#[derive(Serialize)]
pub struct TranscribeResponse {
/// Full transcript text.
pub text: String,
/// Per-word details produced by the inference engine.
pub words: Vec<crate::inference::WordInfo>,
/// Audio duration, copied from the engine result's `duration_s` (seconds, per that name).
pub duration: f64,
}
/// Handler error type: a fully built HTTP response (status, optional headers, JSON body).
type ApiError = Response;
fn api_error(status: StatusCode, msg: &str, code: &str) -> ApiError {
(
status,
Json(serde_json::json!({"error": msg, "code": code})),
)
.into_response()
}
/// Build the `503 Service Unavailable` response returned when checking out a
/// pool triplet times out.
///
/// The response carries a `Retry-After` header set from `POOL_RETRY_AFTER_SECS`.
/// It also includes a `retry_after_ms` hint (from `POOL_RETRY_AFTER_MS`) in the JSON body.
fn api_timeout_error() -> ApiError {
    let retry_after = [(header::RETRY_AFTER, POOL_RETRY_AFTER_SECS.to_string())];
    let payload = serde_json::json!({
        "error": "Server busy, try again later",
        "code": "timeout",
        "retry_after_ms": POOL_RETRY_AFTER_MS,
    });
    (StatusCode::SERVICE_UNAVAILABLE, retry_after, Json(payload)).into_response()
}
/// Report service health: always `status: "ok"`, plus the model id and crate version.
///
/// The shared state is extracted but not consulted; this handler does not touch the engine.
pub async fn health(State(_state): State<Arc<AppState>>) -> Json<HealthResponse> {
    let body = HealthResponse {
        status: String::from("ok"),
        model: String::from("gigaam-v3-e2e-rnnt"),
        version: String::from(env!("CARGO_PKG_VERSION")),
    };
    Json(body)
}
/// Describe the served model, its encoder precision, and current pool occupancy.
pub async fn models(State(state): State<Arc<AppState>>) -> Json<ModelInfo> {
    let engine = &state.engine;
    // Diarization is only reported when compiled in AND a speaker encoder is loaded.
    #[cfg(feature = "diarization")]
    let diarization = engine.has_speaker_encoder();
    #[cfg(not(feature = "diarization"))]
    let diarization = false;
    let encoder = if engine.is_int8() { "int8" } else { "fp32" };
    let supported_formats = ["wav", "mp3", "m4a", "ogg", "flac"]
        .iter()
        .map(|&fmt| String::from(fmt))
        .collect();
    Json(ModelInfo {
        id: String::from("gigaam-v3-e2e-rnnt"),
        name: String::from("GigaAM v3 RNN-T"),
        version: String::from(env!("CARGO_PKG_VERSION")),
        encoder: String::from(encoder),
        vocab_size: 1025,
        sample_rate: 16000,
        pool_size: engine.pool.total(),
        pool_available: engine.pool.available(),
        supported_formats,
        supported_rates: vec![8000, 16000, 24000, 44100, 48000],
        diarization,
    })
}
pub async fn transcribe(
State(state): State<Arc<AppState>>,
body: Bytes,
) -> Result<Json<TranscribeResponse>, ApiError> {
if body.is_empty() {
return Err(api_error(
StatusCode::BAD_REQUEST,
"Empty request body",
"empty_body",
));
}
let triplet = tokio::time::timeout(
std::time::Duration::from_secs(30),
state.engine.pool.checkout(),
)
.await
.map_err(|_| api_timeout_error())?;
let body_bytes = body.to_vec();
drop(body);
let engine = state.engine.clone();
let result = tokio::task::spawn_blocking(move || {
let mut triplet = triplet;
let r = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
engine.transcribe_bytes(&body_bytes, &mut triplet)
}));
match r {
Ok(inference_result) => (inference_result, triplet),
Err(_) => {
tracing::error!("Panic in REST transcribe — triplet recovered");
(
Err(crate::error::GigasttError::Inference(
"Inference thread panicked".into(),
)),
triplet,
)
}
}
})
.await;
match result {
Ok((Ok(result), triplet)) => {
state.engine.pool.checkin(triplet).await;
Ok(Json(TranscribeResponse {
text: result.text,
words: result.words,
duration: result.duration_s,
}))
}
Ok((Err(e), triplet)) => {
state.engine.pool.checkin(triplet).await;
tracing::error!("Transcription error: {e}");
Err(api_error(
StatusCode::UNPROCESSABLE_ENTITY,
"Transcription failed. Check audio format.",
"transcription_error",
))
}
Err(e) => {
tracing::error!("spawn_blocking join error: {e}");
Err(api_error(
StatusCode::INTERNAL_SERVER_ERROR,
"Internal server error",
"internal",
))
}
}
}
pub async fn transcribe_stream(
State(state): State<Arc<AppState>>,
body: Bytes,
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
if body.is_empty() {
return Err(api_error(
StatusCode::BAD_REQUEST,
"Empty request body",
"empty_body",
));
}
let body_bytes = body.to_vec();
drop(body);
let samples = {
let bytes = body_bytes;
tokio::task::spawn_blocking(move || {
crate::inference::audio::decode_audio_bytes(&bytes)
})
.await
.map_err(|e| {
tracing::error!("spawn_blocking join error: {e}");
api_error(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error", "internal")
})?
.map_err(|e| {
tracing::error!("Audio decode error: {e:#}");
api_error(StatusCode::UNPROCESSABLE_ENTITY, "Failed to decode audio file. Check format (WAV, MP3, M4A, OGG, FLAC supported).", "invalid_audio")
})?
};
let triplet = tokio::time::timeout(
std::time::Duration::from_secs(30),
state.engine.pool.checkout(),
)
.await
.map_err(|_| api_timeout_error())?;
let (tx, rx) =
tokio::sync::mpsc::channel::<Result<crate::inference::TranscriptSegment, String>>(16);
let engine = state.engine.clone();
tokio::task::spawn_blocking(move || {
let mut triplet = triplet;
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut stream_state = engine.create_state(
#[cfg(feature = "diarization")]
false,
);
let chunk_size = 16000;
for chunk in samples.chunks(chunk_size) {
match engine.process_chunk(chunk, &mut stream_state, &mut triplet) {
Ok(segs) => {
for seg in segs {
if tx.blocking_send(Ok(seg)).is_err() {
return;
}
}
}
Err(e) => {
let _ = tx.blocking_send(Err(format!("{e}")));
return;
}
}
}
if let Some(seg) = engine.flush_state(&mut stream_state) {
let _ = tx.blocking_send(Ok(seg));
}
}));
if result.is_err() {
tracing::error!("Panic in SSE inference task — triplet recovered");
}
engine.pool.blocking_checkin(triplet);
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx)
.map(|result| {
let event = match result {
Ok(seg) => {
let msg = if seg.is_final {
serde_json::json!({"type": "final", "text": seg.text, "timestamp": seg.timestamp, "words": seg.words})
} else {
serde_json::json!({"type": "partial", "text": seg.text, "timestamp": seg.timestamp, "words": seg.words})
};
Event::default().data(msg.to_string())
}
Err(_) => {
let msg = serde_json::json!({"type": "error", "message": "Transcription failed.", "code": "inference_error"});
Event::default().data(msg.to_string())
}
};
Ok(event)
});
Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_health_response_serialization() {
        let resp = HealthResponse {
            status: "ok".into(),
            model: "test".into(),
            version: "0.3.0".into(),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["status"], "ok");
        assert_eq!(v["model"], "test");
    }

    #[test]
    fn test_transcribe_response_serialization() {
        let resp = TranscribeResponse {
            text: "hello".into(),
            words: vec![],
            duration: 1.5,
        };
        let json = serde_json::to_string(&resp).unwrap();
        let v: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(v["text"], "hello");
        assert_eq!(v["duration"], 1.5);
    }

    // The error helpers shape every non-2xx response; pin their status codes.
    #[test]
    fn test_api_error_uses_given_status() {
        let resp = api_error(StatusCode::BAD_REQUEST, "Empty request body", "empty_body");
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    // Clients rely on Retry-After to back off when the pool is exhausted.
    #[test]
    fn test_api_timeout_error_sets_retry_after() {
        let resp = api_timeout_error();
        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
        let retry_after = resp
            .headers()
            .get(header::RETRY_AFTER)
            .expect("Retry-After header must be set");
        assert_eq!(
            retry_after.to_str().unwrap(),
            POOL_RETRY_AFTER_SECS.to_string()
        );
    }
}