pub mod http;
use crate::inference::{Engine, SessionTriplet};
use crate::protocol::{ClientMessage, ServerMessage};
use anyhow::{Context, Result};
use axum::Router;
use axum::extract::DefaultBodyLimit;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use axum::routing::{get, post};
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
/// Serializes `msg` into a JSON string suitable for a WebSocket text frame.
///
/// Serialization failures are logged and replaced by a fixed JSON error payload, so the caller
/// always receives a string it can send.
fn json_text(msg: &impl serde::Serialize) -> String {
    match serde_json::to_string(msg) {
        Ok(encoded) => encoded,
        Err(e) => {
            tracing::error!("Failed to serialize server message: {e}");
            String::from(
                r#"{"type":"error","message":"Internal serialization error","code":"internal"}"#,
            )
        }
    }
}
/// Client sample rates (Hz) accepted in a `Configure` message and advertised in `Ready`.
const SUPPORTED_RATES: &[u32] = &[8000, 16000, 24000, 44100, 48000];
/// Sample rate (Hz) assumed for a WebSocket session until the client configures another one.
const DEFAULT_SAMPLE_RATE: u32 = 48000;
/// Retry hint (milliseconds) attached to the WebSocket "Server busy" error sent when the pool
/// checkout times out.
pub(crate) const POOL_RETRY_AFTER_MS: u32 = 30_000;
/// Retry hint in seconds. Not referenced in this file — presumably consumed by the HTTP
/// handlers; verify against the `http` module.
pub(crate) const POOL_RETRY_AFTER_SECS: u64 = 30;
/// Rules deciding which request `Origin` values the server accepts.
///
/// The derived `Default` (`allow_any == false`, empty allowlist) accepts loopback origins only.
#[derive(Debug, Clone, Default)]
pub struct OriginPolicy {
/// When `true`, every origin is accepted and responses carry `Access-Control-Allow-Origin: *`.
pub allow_any: bool,
/// Extra origins accepted in addition to loopback; compared ASCII case-insensitively.
pub allowed_origins: Vec<String>,
}
impl OriginPolicy {
    /// Builds the strictest policy: no wildcard and no extra allowlisted origins, so only
    /// loopback origins (and requests without an `Origin`) are accepted.
    pub fn loopback_only() -> Self {
        Self {
            allow_any: false,
            allowed_origins: Vec::new(),
        }
    }
}
/// Outcome of checking a request's `Origin` header against an [`OriginPolicy`].
#[derive(Debug)]
enum OriginVerdict {
/// Request is forwarded and no CORS headers are added to the response.
AllowedNoEcho,
/// Request is forwarded; the carried origin string is what the middleware may echo back in
/// `Access-Control-Allow-Origin`.
Allowed(String),
/// Request is rejected by the middleware with `403 Forbidden`.
Denied,
}
/// Reports whether `origin` names a loopback host (`localhost`, `127.0.0.1` or `[::1]`) over
/// `http` or `https`.
///
/// Matching is ASCII case-insensitive. After the scheme and host, the origin must either end or
/// continue with `:` or `/`, which rejects DNS continuations such as
/// `http://localhost.evil.example.com`.
fn is_loopback_origin(origin: &str) -> bool {
    const LOOPBACK_BASES: [&str; 6] = [
        "http://localhost",
        "https://localhost",
        "http://127.0.0.1",
        "https://127.0.0.1",
        "http://[::1]",
        "https://[::1]",
    ];

    let normalized = origin.to_ascii_lowercase();
    for base in LOOPBACK_BASES.iter() {
        let Some(tail) = normalized.strip_prefix(*base) else {
            continue;
        };
        // The host must be terminated: end of string, a port separator, or a path separator.
        match tail.as_bytes().first() {
            None | Some(b':') | Some(b'/') => return true,
            Some(_) => {}
        }
    }
    false
}
impl OriginPolicy {
    /// Classifies a request by its optional `Origin` header value.
    ///
    /// * No header, or the literal `null` (any ASCII case): [`OriginVerdict::AllowedNoEcho`].
    /// * `allow_any`, a loopback origin, or an allowlisted origin (ASCII case-insensitive):
    ///   [`OriginVerdict::Allowed`] carrying the origin exactly as received.
    /// * Anything else: [`OriginVerdict::Denied`].
    fn evaluate(&self, origin: Option<&str>) -> OriginVerdict {
        // NOTE(review): browsers serialize opaque origins (e.g. sandboxed iframes) as `null`,
        // so such pages pass this check. Confirm this is intended, in particular for `/ws`,
        // where no CORS response header gates access.
        let candidate = match origin {
            None => return OriginVerdict::AllowedNoEcho,
            Some(value) if value.eq_ignore_ascii_case("null") => {
                return OriginVerdict::AllowedNoEcho;
            }
            Some(value) => value,
        };

        let permitted = self.allow_any
            || is_loopback_origin(candidate)
            || self
                .allowed_origins
                .iter()
                .any(|listed| listed.eq_ignore_ascii_case(candidate));

        if permitted {
            OriginVerdict::Allowed(candidate.to_string())
        } else {
            OriginVerdict::Denied
        }
    }
}
/// Listener and origin settings for the server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
/// TCP port to bind.
pub port: u16,
/// Host part of the bind address; combined with `port` and parsed as a socket address.
pub host: String,
/// Origin policy enforced by the origin middleware on every route except `/health`.
pub origin_policy: OriginPolicy,
}
impl ServerConfig {
    /// Creates a configuration that binds `127.0.0.1` on `port` and accepts loopback origins
    /// only.
    pub fn local(port: u16) -> Self {
        let host = String::from("127.0.0.1");
        let origin_policy = OriginPolicy::loopback_only();
        Self {
            port,
            host,
            origin_policy,
        }
    }
}
/// Runs the server on `host:port` with the loopback-only origin policy until Ctrl-C.
///
/// Thin wrapper over [`run_with_shutdown`] that supplies no external shutdown channel.
pub async fn run(engine: Engine, port: u16, host: &str) -> Result<()> {
    let no_external_shutdown = None;
    run_with_shutdown(engine, port, host, no_external_shutdown).await
}
pub async fn run_with_shutdown(
engine: Engine,
port: u16,
host: &str,
shutdown: Option<tokio::sync::oneshot::Receiver<()>>,
) -> Result<()> {
let config = ServerConfig {
port,
host: host.to_string(),
origin_policy: OriginPolicy::loopback_only(),
};
run_with_config(engine, config, shutdown).await
}
pub async fn run_with_config(
engine: Engine,
config: ServerConfig,
shutdown: Option<tokio::sync::oneshot::Receiver<()>>,
) -> Result<()> {
let addr: SocketAddr = format!("{}:{}", config.host, config.port)
.parse()
.context("Invalid host:port")?;
let state = Arc::new(http::AppState {
engine: Arc::new(engine),
});
let policy = Arc::new(config.origin_policy.clone());
let origin_layer = {
let policy = policy.clone();
axum::middleware::from_fn(move |req, next| {
let policy = policy.clone();
async move { origin_middleware(policy, req, next).await }
})
};
let app = Router::new()
.route("/health", get(http::health))
.route("/v1/models", get(http::models))
.route("/v1/transcribe", post(http::transcribe))
.route("/v1/transcribe/stream", post(http::transcribe_stream))
.route("/ws", get(ws_handler))
.layer(DefaultBodyLimit::max(50 * 1024 * 1024)) .layer(origin_layer)
.with_state(state);
tracing::info!("gigastt server listening on http://{addr}");
tracing::info!(" WebSocket: ws://{addr}/ws");
tracing::info!(" REST API: http://{addr}/health, /v1/transcribe, /v1/transcribe/stream");
if config.origin_policy.allow_any {
tracing::warn!(
"CORS allow-any is ON: any cross-origin page can call this server. \
Only use with trusted callers."
);
} else if !config.origin_policy.allowed_origins.is_empty() {
tracing::info!(
"CORS allowlist (in addition to loopback): {:?}",
config.origin_policy.allowed_origins
);
}
let listener = tokio::net::TcpListener::bind(&addr).await?;
let shutdown_fut = async {
match shutdown {
Some(rx) => {
rx.await.ok();
}
None => {
tokio::signal::ctrl_c().await.ok();
}
}
tracing::info!("Shutting down server");
};
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_fut)
.await?;
Ok(())
}
/// Enforces `policy` on every request except `/health` and adds CORS response headers.
///
/// * `/health` bypasses the check entirely.
/// * Requests without an `Origin` header (or with `Origin: null`) are forwarded without CORS
///   headers.
/// * Requests from an allowed origin are forwarded and the response gets
///   `Access-Control-Allow-Origin` (`*` under `allow_any`, otherwise the request's origin plus
///   `Vary: Origin`), `Access-Control-Allow-Methods` and `Access-Control-Allow-Headers`.
/// * CORS preflight requests from an allowed origin are answered directly with
///   `204 No Content`, because the routes only register GET/POST handlers and would otherwise
///   answer the preflight with a non-2xx status that browsers reject.
/// * Every other origin — including an `Origin` header that is not valid visible ASCII — gets
///   `403 Forbidden` with a JSON body whose `code` is `origin_denied`.
async fn origin_middleware(
    policy: Arc<OriginPolicy>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    use axum::http::{HeaderValue, Method, StatusCode, header};
    use axum::response::IntoResponse;

    if req.uri().path() == "/health" {
        return next.run(req).await;
    }

    // Fail closed: a present-but-undecodable Origin must not be treated like an absent one.
    let verdict = match req.headers().get(header::ORIGIN) {
        None => policy.evaluate(None),
        Some(raw) => match raw.to_str() {
            Ok(text) => policy.evaluate(Some(text)),
            Err(_) => OriginVerdict::Denied,
        },
    };

    match verdict {
        OriginVerdict::AllowedNoEcho => next.run(req).await,
        OriginVerdict::Allowed(echo) => {
            let is_preflight = req.method() == Method::OPTIONS
                && req
                    .headers()
                    .contains_key(header::ACCESS_CONTROL_REQUEST_METHOD);
            let mut response = if is_preflight {
                StatusCode::NO_CONTENT.into_response()
            } else {
                next.run(req).await
            };
            let headers = response.headers_mut();
            if policy.allow_any {
                headers.insert(
                    header::ACCESS_CONTROL_ALLOW_ORIGIN,
                    HeaderValue::from_static("*"),
                );
            } else {
                if let Ok(v) = HeaderValue::from_str(&echo) {
                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, v);
                }
                // The allow-origin value depends on the request, so caches must key on Origin.
                headers.append(header::VARY, HeaderValue::from_static("Origin"));
            }
            headers.insert(
                header::ACCESS_CONTROL_ALLOW_METHODS,
                HeaderValue::from_static("GET, POST, OPTIONS"),
            );
            headers.insert(
                header::ACCESS_CONTROL_ALLOW_HEADERS,
                HeaderValue::from_static("*"),
            );
            response
        }
        OriginVerdict::Denied => {
            let origin_str = req
                .headers()
                .get(header::ORIGIN)
                .map(|raw| String::from_utf8_lossy(raw.as_bytes()).into_owned())
                .unwrap_or_default();
            let path = req.uri().path().to_string();
            tracing::warn!(
                origin = %origin_str,
                path = %path,
                "cross-origin request denied by default policy"
            );
            (
                StatusCode::FORBIDDEN,
                axum::response::Json(serde_json::json!({
                    "error": "Origin not allowed",
                    "code": "origin_denied",
                })),
            )
                .into_response()
        }
    }
}
/// Upgrades a `/ws` request to a WebSocket session handled by [`handle_ws`].
///
/// Both the maximum message size and the maximum frame size are capped at 512 KiB.
async fn ws_handler(
    ws: WebSocketUpgrade,
    axum::extract::ConnectInfo(peer): axum::extract::ConnectInfo<SocketAddr>,
    State(state): State<Arc<http::AppState>>,
) -> Response {
    const MAX_WS_BYTES: usize = 512 * 1024;
    let upgrade = ws
        .max_message_size(MAX_WS_BYTES)
        .max_frame_size(MAX_WS_BYTES);
    upgrade.on_upgrade(move |socket| handle_ws(socket, peer, state))
}
async fn handle_ws(socket: WebSocket, peer: SocketAddr, state: Arc<http::AppState>) {
let triplet = match tokio::time::timeout(
std::time::Duration::from_secs(30),
state.engine.pool.checkout(),
)
.await
{
Ok(triplet) => triplet,
Err(_) => {
tracing::warn!("WebSocket pool checkout timeout for {peer}");
let (mut sink, _) = socket.split();
let err = ServerMessage::Error {
message: "Server busy, try again later".into(),
code: "timeout".into(),
retry_after_ms: Some(POOL_RETRY_AFTER_MS),
};
let _ = sink.send(WsMessage::Text(json_text(&err).into())).await;
return;
}
};
let (triplet_opt, result) = handle_ws_inner(socket, peer, &state.engine, triplet).await;
if let Err(e) = result {
tracing::error!("WebSocket error from {peer}: {e}");
}
if let Some(triplet) = triplet_opt {
state.engine.pool.checkin(triplet).await;
}
}
/// Tells the WebSocket receive loop what to do after one message has been handled.
enum FrameOutcome {
/// Keep reading messages from the client.
Continue,
/// Leave the receive loop and end the session successfully.
Break,
}
/// Sending half of a split [`WebSocket`].
type WsSink = futures_util::stream::SplitSink<WebSocket, WsMessage>;
/// Encodes `msg` as JSON and sends it to the client as a WebSocket text frame.
///
/// # Errors
///
/// Returns the underlying send error if the frame cannot be written.
async fn send_server_message(sink: &mut WsSink, msg: &ServerMessage) -> Result<()> {
    let payload = json_text(msg);
    sink.send(WsMessage::Text(payload.into())).await?;
    Ok(())
}
/// Processes one binary WebSocket frame of 16-bit little-endian PCM audio.
///
/// The samples are converted to `f32`, resampled to 16 kHz when the client rate differs, and
/// run through `Engine::process_chunk` on a blocking thread. Each resulting segment is sent to
/// the client as a `Final` or `Partial` message.
///
/// `state_opt` and `triplet_opt` are emptied while inference runs and refilled afterwards,
/// except when the blocking task itself cannot be joined.
///
/// # Errors
///
/// Returns an error if resampling fails, if the streaming state or triplet is missing, if a
/// message cannot be sent, or if the blocking task cannot be joined. Inference errors and
/// panics are reported to the client and do not end the session.
// The argument list mirrors the per-session state owned by the receive loop.
#[allow(clippy::too_many_arguments)]
async fn handle_binary_frame(
sink: &mut WsSink,
engine: &Arc<Engine>,
state_opt: &mut Option<crate::inference::StreamingState>,
triplet_opt: &mut Option<SessionTriplet>,
audio_received: &mut bool,
client_sample_rate: u32,
peer: SocketAddr,
data: axum::body::Bytes,
) -> Result<FrameOutcome> {
// Empty frames carry no audio and do not count as "audio received".
if data.is_empty() {
tracing::debug!("Empty binary frame from {peer}, skipping");
return Ok(FrameOutcome::Continue);
}
// From here on, `Configure` messages are rejected as too late.
*audio_received = true;
if !data.len().is_multiple_of(2) {
tracing::warn!(
"Odd-length PCM frame ({} bytes) from {peer}, dropping last byte",
data.len()
);
}
// `chunks_exact(2)` ignores a trailing odd byte; samples are scaled by 1/32768.
let samples_f32: Vec<f32> = data
.chunks_exact(2)
.map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / 32768.0)
.collect();
let samples_16k = if client_sample_rate == 16000 {
samples_f32
} else {
crate::inference::audio::resample(&samples_f32, client_sample_rate, 16000)?
};
// Take ownership of the session state so it can be moved into the blocking task.
let state = state_opt
.take()
.ok_or_else(|| anyhow::anyhow!("Streaming state lost"))?;
let triplet = triplet_opt.take().ok_or_else(|| {
tracing::error!("Triplet unexpectedly missing for {peer}");
anyhow::anyhow!("Triplet lost")
})?;
let eng = engine.clone();
let join_result = tokio::task::spawn_blocking(move || {
let mut state = state;
let mut triplet = triplet;
// Catch panics inside the task so state and triplet are always returned to the caller.
let r = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
eng.process_chunk(&samples_16k, &mut state, &mut triplet)
}));
(r, state, triplet)
})
.await;
match join_result {
// Inference succeeded: restore the session and forward every segment.
Ok((Ok(Ok(segments)), state_back, triplet_back)) => {
*state_opt = Some(state_back);
*triplet_opt = Some(triplet_back);
for seg in segments {
let msg = if seg.is_final {
ServerMessage::Final {
text: seg.text,
timestamp: seg.timestamp,
words: seg.words,
}
} else {
ServerMessage::Partial {
text: seg.text,
timestamp: seg.timestamp,
words: seg.words,
}
};
send_server_message(sink, &msg).await?;
}
Ok(FrameOutcome::Continue)
}
// Inference returned an error: restore the session, tell the client, keep going.
Ok((Ok(Err(e)), state_back, triplet_back)) => {
*state_opt = Some(state_back);
*triplet_opt = Some(triplet_back);
tracing::error!("Inference error for {peer}: {e:#}");
send_server_message(
sink,
&ServerMessage::Error {
message: "Inference failed. Please check audio format.".into(),
code: "inference_error".into(),
retry_after_ms: None,
},
)
.await?;
Ok(FrameOutcome::Continue)
}
// Inference panicked: keep the triplet but replace the streaming state with a fresh one.
Ok((Err(_panic), _state_back, triplet_back)) => {
tracing::error!(
"Panic in WS inference for {peer} — triplet recovered, streaming state reset"
);
*triplet_opt = Some(triplet_back);
*state_opt = Some(engine.create_state(
#[cfg(feature = "diarization")]
false,
));
send_server_message(
sink,
&ServerMessage::Error {
message: "Inference failed unexpectedly. Session reset.".into(),
code: "inference_panic".into(),
retry_after_ms: None,
},
)
.await?;
Ok(FrameOutcome::Continue)
}
// The blocking task could not be joined; state and triplet are not recovered here.
Err(e) => {
tracing::error!("spawn_blocking join error for {peer}: {e}");
Err(anyhow::anyhow!("Blocking task join failed"))
}
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_configure_message(
sink: &mut WsSink,
engine: &Arc<Engine>,
state_opt: &mut Option<crate::inference::StreamingState>,
client_sample_rate: &mut u32,
audio_received: bool,
sample_rate: Option<u32>,
diarization: Option<bool>,
peer: SocketAddr,
) -> Result<FrameOutcome> {
if audio_received {
send_server_message(
sink,
&ServerMessage::Error {
message: "Configure must be sent before first audio frame".into(),
code: "configure_too_late".into(),
retry_after_ms: None,
},
)
.await?;
return Ok(FrameOutcome::Continue);
}
if let Some(rate) = sample_rate {
if SUPPORTED_RATES.contains(&rate) {
*client_sample_rate = rate;
tracing::info!("Client {peer} configured sample rate: {rate}Hz");
} else {
send_server_message(
sink,
&ServerMessage::Error {
message: format!(
"Unsupported sample rate: {rate}Hz. Supported: {SUPPORTED_RATES:?}"
),
code: "invalid_sample_rate".into(),
retry_after_ms: None,
},
)
.await?;
}
}
#[cfg(feature = "diarization")]
if let Some(enable_dia) = diarization {
tracing::info!("Client {peer} configured diarization: {enable_dia}");
*state_opt = Some(engine.create_state(enable_dia));
}
#[cfg(not(feature = "diarization"))]
{
let _ = (engine, state_opt, diarization);
}
Ok(FrameOutcome::Continue)
}
/// Handles a client `Stop` message: flushes the streaming state and sends a closing `Final`.
///
/// If flushing yields a segment, its text, timestamp and words are sent; otherwise an empty
/// `Final` stamped with the current time is sent. When no streaming state is present nothing
/// is sent. In every case the receive loop is asked to stop.
///
/// # Errors
///
/// Returns an error if the `Final` message cannot be sent.
async fn handle_stop_message(
    sink: &mut WsSink,
    engine: &Arc<Engine>,
    state_opt: &mut Option<crate::inference::StreamingState>,
    peer: SocketAddr,
) -> Result<FrameOutcome> {
    tracing::info!("Stop received from {peer}, finalizing");

    let Some(mut state) = state_opt.take() else {
        return Ok(FrameOutcome::Break);
    };
    let flushed = engine.flush_state(&mut state);
    drop(state);

    let final_msg = match flushed {
        Some(seg) => ServerMessage::Final {
            text: seg.text,
            timestamp: seg.timestamp,
            words: seg.words,
        },
        None => ServerMessage::Final {
            text: String::new(),
            timestamp: crate::inference::now_timestamp(),
            words: Vec::new(),
        },
    };
    send_server_message(sink, &final_msg).await?;
    Ok(FrameOutcome::Break)
}
/// Returns the longest prefix of `text` that is at most `max_bytes` long and ends on a UTF-8
/// character boundary, so truncating multi-byte text for logging can never panic.
fn log_preview(text: &str, max_bytes: usize) -> &str {
    let mut end = text.len().min(max_bytes);
    // Index 0 is always a boundary, so this loop terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}

/// Runs the message loop of one WebSocket session.
///
/// Sends the `Ready` handshake, then dispatches incoming frames: binary frames go to
/// [`handle_binary_frame`], `Configure`/`Stop` text messages go to their handlers, unparseable
/// text is logged and ignored, `Close` ends the session and every other frame type is skipped.
/// The session also ends after 300 seconds without any message from the client.
///
/// Returns the session triplet (if it is still held) together with the session result, so the
/// caller can return the triplet to the pool regardless of how the session ended.
async fn handle_ws_inner(
    socket: WebSocket,
    peer: SocketAddr,
    engine: &Arc<Engine>,
    triplet: SessionTriplet,
) -> (Option<SessionTriplet>, Result<()>) {
    let (mut sink, mut source) = socket.split();
    tracing::info!("Client connected: {peer}");

    #[cfg(feature = "diarization")]
    let diarization_available = engine.has_speaker_encoder();
    #[cfg(not(feature = "diarization"))]
    let diarization_available = false;

    let ready = ServerMessage::Ready {
        model: "gigaam-v3-e2e-rnnt".into(),
        sample_rate: DEFAULT_SAMPLE_RATE,
        version: crate::protocol::PROTOCOL_VERSION.into(),
        supported_rates: SUPPORTED_RATES.to_vec(),
        diarization: diarization_available,
    };
    if let Err(e) = send_server_message(&mut sink, &ready).await {
        // The handshake failed before the triplet was used; hand it straight back.
        return (Some(triplet), Err(e));
    }

    let mut state_opt = Some(engine.create_state(
        #[cfg(feature = "diarization")]
        false,
    ));
    let mut triplet_opt = Some(triplet);
    let mut client_sample_rate: u32 = DEFAULT_SAMPLE_RATE;
    let mut audio_received = false;

    let result: Result<()> = loop {
        let msg =
            match tokio::time::timeout(std::time::Duration::from_secs(300), source.next()).await {
                Ok(Some(Ok(msg))) => msg,
                Ok(Some(Err(e))) => break Err(e.into()),
                Ok(None) => break Ok(()),
                Err(_) => {
                    tracing::info!("Client {peer} idle timeout (300s)");
                    break Ok(());
                }
            };

        let outcome = match msg {
            WsMessage::Binary(data) => {
                handle_binary_frame(
                    &mut sink,
                    engine,
                    &mut state_opt,
                    &mut triplet_opt,
                    &mut audio_received,
                    client_sample_rate,
                    peer,
                    data,
                )
                .await
            }
            WsMessage::Text(text) => match serde_json::from_str::<ClientMessage>(&text) {
                Ok(ClientMessage::Configure {
                    sample_rate,
                    diarization,
                }) => {
                    handle_configure_message(
                        &mut sink,
                        engine,
                        &mut state_opt,
                        &mut client_sample_rate,
                        audio_received,
                        sample_rate,
                        diarization,
                        peer,
                    )
                    .await
                }
                Ok(ClientMessage::Stop) => {
                    handle_stop_message(&mut sink, engine, &mut state_opt, peer).await
                }
                Err(_) => {
                    // Truncate on a character boundary: a raw byte slice at 100 would panic
                    // on multi-byte text and take the whole session task down with it.
                    tracing::debug!(
                        "Unrecognized text message from {peer}: {}",
                        log_preview(&text, 100)
                    );
                    Ok(FrameOutcome::Continue)
                }
            },
            WsMessage::Close(_) => Ok(FrameOutcome::Break),
            // Remaining frame types (e.g. ping/pong) need no action here.
            _ => Ok(FrameOutcome::Continue),
        };

        match outcome {
            Ok(FrameOutcome::Continue) => continue,
            Ok(FrameOutcome::Break) => break Ok(()),
            Err(e) => break Err(e),
        }
    };

    tracing::info!("Client disconnected: {peer}");
    (triplet_opt, result)
}
// Unit tests for the sample-rate constants, the origin policy, and the origin middleware.
#[cfg(test)]
mod tests {
use super::*;
/// The commonly used rates 8 kHz, 16 kHz and 48 kHz must all be supported.
#[test]
fn test_supported_rates_contains_common() {
assert!(
SUPPORTED_RATES.contains(&8000),
"SUPPORTED_RATES must include 8000 Hz"
);
assert!(
SUPPORTED_RATES.contains(&16000),
"SUPPORTED_RATES must include 16000 Hz"
);
assert!(
SUPPORTED_RATES.contains(&48000),
"SUPPORTED_RATES must include 48000 Hz"
);
}
/// The default session rate must itself be one of the supported rates.
#[test]
fn test_default_sample_rate_in_supported() {
assert!(
SUPPORTED_RATES.contains(&DEFAULT_SAMPLE_RATE),
"DEFAULT_SAMPLE_RATE ({DEFAULT_SAMPLE_RATE}) must be present in SUPPORTED_RATES"
);
}
/// Loopback hosts match with or without a port and regardless of case; other hosts and
/// `localhost.*` DNS continuations do not.
#[test]
fn test_loopback_origin_matcher() {
assert!(is_loopback_origin("http://localhost"));
assert!(is_loopback_origin("https://localhost:3000"));
assert!(is_loopback_origin("http://127.0.0.1:9876"));
assert!(is_loopback_origin("HTTPS://127.0.0.1")); assert!(is_loopback_origin("http://[::1]:9876"));
assert!(!is_loopback_origin("https://evil.example.com"));
assert!(!is_loopback_origin("http://192.168.1.10"));
assert!(!is_loopback_origin("http://localhost.evil.example.com"));
}
/// The loopback-only policy rejects a third-party origin.
#[test]
fn test_origin_policy_default_denies_third_party() {
let policy = OriginPolicy::loopback_only();
assert!(matches!(
policy.evaluate(Some("https://evil.example.com")),
OriginVerdict::Denied
));
}
/// The loopback-only policy accepts a localhost origin.
#[test]
fn test_origin_policy_allows_loopback_by_default() {
let policy = OriginPolicy::loopback_only();
assert!(matches!(
policy.evaluate(Some("http://localhost:3000")),
OriginVerdict::Allowed(_)
));
}
/// An allowlisted origin is accepted, but an origin that merely starts with it is not.
#[test]
fn test_origin_policy_allows_listed_origin() {
let policy = OriginPolicy {
allow_any: false,
allowed_origins: vec!["https://app.example.com".into()],
};
assert!(matches!(
policy.evaluate(Some("https://app.example.com")),
OriginVerdict::Allowed(_)
));
assert!(matches!(
policy.evaluate(Some("https://app.example.com.evil.com")),
OriginVerdict::Denied
));
}
/// With `allow_any` set, an arbitrary origin is accepted even with an empty allowlist.
#[test]
fn test_origin_policy_allow_any_short_circuits() {
let policy = OriginPolicy {
allow_any: true,
allowed_origins: vec![],
};
assert!(matches!(
policy.evaluate(Some("https://anything.example.com")),
OriginVerdict::Allowed(_)
));
}
/// A missing `Origin` and the literal `null` origin are both allowed without an echo.
#[test]
fn test_origin_policy_no_header_allowed() {
let policy = OriginPolicy::loopback_only();
assert!(matches!(
policy.evaluate(None),
OriginVerdict::AllowedNoEcho
));
assert!(matches!(
policy.evaluate(Some("null")),
OriginVerdict::AllowedNoEcho
));
}
/// End-to-end check of the middleware over a real listener on an ephemeral port, using stub
/// handlers in place of the production routes.
#[tokio::test]
async fn test_origin_middleware_integration() {
use axum::Router;
use axum::routing::get;
let policy = Arc::new(OriginPolicy::loopback_only());
let origin_layer = {
let policy = policy.clone();
axum::middleware::from_fn(move |req, next| {
let policy = policy.clone();
async move { origin_middleware(policy, req, next).await }
})
};
let app = Router::new()
.route("/health", get(|| async { "ok" }))
.route("/v1/transcribe", get(|| async { "ok" }))
.layer(origin_layer);
// Port 0 lets the OS pick a free port.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
tokio::spawn(async move {
let _ = axum::serve(listener, app).await;
});
// Short pause before the first request — presumably to let the spawned server task start.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let client = reqwest::Client::new();
let base = format!("http://127.0.0.1:{port}");
// 1. `/health` is exempt from the origin check.
let r = client
.get(format!("{base}/health"))
.header("Origin", "https://evil.example.com")
.send()
.await
.unwrap();
assert_eq!(r.status(), 200, "/health must skip the Origin guard");
// 2. A third-party origin is rejected with the JSON error body.
let r = client
.get(format!("{base}/v1/transcribe"))
.header("Origin", "https://evil.example.com")
.send()
.await
.unwrap();
assert_eq!(
r.status(),
403,
"non-loopback Origin must receive 403 Forbidden"
);
let text = r.text().await.unwrap();
let body: serde_json::Value = serde_json::from_str(&text).unwrap();
assert_eq!(body["code"], "origin_denied");
// 3. A loopback origin is allowed and echoed back verbatim.
let r = client
.get(format!("{base}/v1/transcribe"))
.header("Origin", "http://localhost:3000")
.send()
.await
.unwrap();
assert_eq!(r.status(), 200, "loopback Origin must be allowed");
assert_eq!(
r.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok()),
Some("http://localhost:3000"),
"CORS echo must mirror the incoming Origin (no wildcard by default)",
);
// 4. Requests without an `Origin` header pass through.
let r = client
.get(format!("{base}/v1/transcribe"))
.send()
.await
.unwrap();
assert_eq!(r.status(), 200, "requests without Origin must pass");
// 5. A hostname that merely starts with `localhost` is still rejected.
let r = client
.get(format!("{base}/v1/transcribe"))
.header("Origin", "http://localhost.evil.example.com")
.send()
.await
.unwrap();
assert_eq!(
r.status(),
403,
"localhost.* DNS continuation must not impersonate loopback"
);
}
/// Values mutated inside a `catch_unwind` closure stay accessible after the closure panics;
/// stand-in values are used instead of the real streaming state and triplet.
#[test]
fn test_catch_unwind_preserves_ownership_across_panic() {
use std::panic::{AssertUnwindSafe, catch_unwind};
let mut state = 42u32;
let mut triplet_marker = String::from("pool_slot");
let result = catch_unwind(AssertUnwindSafe(|| {
state = 99;
triplet_marker.push_str("/taken");
panic!("simulated inference panic");
}));
assert!(result.is_err(), "catch_unwind must report the panic");
assert_eq!(state, 99, "state must remain accessible after panic");
assert_eq!(
triplet_marker, "pool_slot/taken",
"triplet marker must survive panic"
);
}
}