pub mod http;
pub mod metrics;
pub mod rate_limit;
use crate::inference::{Engine, SessionTriplet};
use crate::protocol::{ClientMessage, ServerMessage};
use anyhow::{Context, Result};
use axum::Router;
use axum::extract::DefaultBodyLimit;
use axum::extract::State;
use axum::extract::ws::{Message as WsMessage, WebSocket, WebSocketUpgrade};
use axum::response::Response;
use axum::routing::{get, post};
use futures_util::{SinkExt, StreamExt};
use std::net::SocketAddr;
use std::sync::Arc;
/// Serializes `msg` to a JSON string suitable for sending to a client.
///
/// A serialization failure is logged and replaced by a fixed JSON error
/// payload, so the caller always receives a sendable string.
fn json_text(msg: &impl serde::Serialize) -> String {
    match serde_json::to_string(msg) {
        Ok(encoded) => encoded,
        Err(e) => {
            tracing::error!("Failed to serialize server message: {e}");
            String::from(
                r#"{"type":"error","message":"Internal serialization error","code":"internal"}"#,
            )
        }
    }
}
/// Sample rates (Hz) a WebSocket client may select via a Configure message;
/// the same list is advertised in the Ready message.
pub(crate) const SUPPORTED_RATES: &[u32] = &[8000, 16000, 24000, 44100, 48000];
/// Sample rate (Hz) assumed for incoming PCM until the client configures another one.
const DEFAULT_SAMPLE_RATE: u32 = 48000;
/// Retry hint in milliseconds attached to the "Server busy" error sent when
/// the pool checkout times out.
pub(crate) const POOL_RETRY_AFTER_MS: u32 = 30_000;
/// The same retry hint in seconds. Not referenced in this part of the file —
/// presumably consumed by the HTTP handlers (TODO confirm).
pub(crate) const POOL_RETRY_AFTER_SECS: u64 = 30;
/// Describes which browser `Origin` values are allowed to call the server.
///
/// The fields only widen the accepted set; the default value has no wildcard
/// and an empty allowlist.
#[derive(Debug, Clone, Default)]
pub struct OriginPolicy {
    /// When `true`, every origin is accepted.
    pub allow_any: bool,
    /// Explicitly allowed origins.
    pub allowed_origins: Vec<String>,
}

impl OriginPolicy {
    /// Returns the most restrictive policy: no wildcard, empty allowlist.
    pub fn loopback_only() -> Self {
        Self {
            allow_any: false,
            allowed_origins: Vec::new(),
        }
    }
}
// Outcome of checking a request's `Origin` header against an `OriginPolicy`.
#[derive(Debug)]
enum OriginVerdict {
// Request may proceed; no CORS headers are added to the response.
AllowedNoEcho,
// Request may proceed; the carried string is the origin that was accepted.
Allowed(String),
// Request must be rejected.
Denied,
}
/// Reports whether `origin` names a loopback host (`localhost`, `127.0.0.1`
/// or `[::1]`) over `http` or `https`.
///
/// The scheme and host are compared ASCII case-insensitively. After the host
/// the origin must either end or continue with `:` (port) or `/` (path), so
/// look-alikes such as `http://localhost.evil.example.com` are rejected.
fn is_loopback_origin(origin: &str) -> bool {
    const LOOPBACK_BASES: [&str; 6] = [
        "http://localhost",
        "https://localhost",
        "http://127.0.0.1",
        "https://127.0.0.1",
        "http://[::1]",
        "https://[::1]",
    ];
    let raw = origin.as_bytes();
    for base in LOOPBACK_BASES.iter() {
        let host_end = base.len();
        if raw.len() < host_end || !raw[..host_end].eq_ignore_ascii_case(base.as_bytes()) {
            continue;
        }
        // Only the byte right after the host decides whether this is a real match.
        match raw.get(host_end) {
            None | Some(b':') | Some(b'/') => return true,
            Some(_) => {}
        }
    }
    false
}
impl OriginPolicy {
    /// Classifies the value of a request's `Origin` header.
    ///
    /// * A missing header or the literal `null` (any ASCII case) yields
    ///   [`OriginVerdict::AllowedNoEcho`].
    /// * Any origin when `allow_any` is set, a loopback origin, or an origin
    ///   found in `allowed_origins` (ASCII case-insensitive) yields
    ///   [`OriginVerdict::Allowed`] carrying the origin as received.
    /// * Everything else yields [`OriginVerdict::Denied`].
    fn evaluate(&self, origin: Option<&str>) -> OriginVerdict {
        let origin = match origin {
            None => return OriginVerdict::AllowedNoEcho,
            Some(value) if value.eq_ignore_ascii_case("null") => {
                return OriginVerdict::AllowedNoEcho;
            }
            Some(value) => value,
        };
        let permitted = self.allow_any
            || is_loopback_origin(origin)
            || self
                .allowed_origins
                .iter()
                .any(|entry| entry.eq_ignore_ascii_case(origin));
        if permitted {
            OriginVerdict::Allowed(origin.to_string())
        } else {
            OriginVerdict::Denied
        }
    }
}
/// Tunable resource and time limits applied by the server.
#[derive(Debug, Clone)]
pub struct RuntimeLimits {
    /// Seconds a WebSocket session may wait for the next client message
    /// before it is closed as idle.
    pub idle_timeout_secs: u64,
    /// Upper bound, in bytes, used for both the WebSocket message size and
    /// the WebSocket frame size.
    pub ws_frame_max_bytes: usize,
    /// Maximum HTTP request body size in bytes.
    pub body_limit_bytes: usize,
    /// Per-IP requests per minute; rate limiting is only installed when this
    /// is greater than zero.
    pub rate_limit_per_minute: u32,
    /// Burst size handed to the rate limiter.
    pub rate_limit_burst: u32,
    /// Maximum WebSocket session length in seconds; `0` means no cap.
    pub max_session_secs: u64,
    /// Seconds to wait for tracked tasks to finish during shutdown
    /// (the server treats values below 1 as 1).
    pub shutdown_drain_secs: u64,
}

impl Default for RuntimeLimits {
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;
        const SECS_PER_MINUTE: u64 = 60;

        Self {
            // Time-based limits.
            idle_timeout_secs: 5 * SECS_PER_MINUTE,
            max_session_secs: 60 * SECS_PER_MINUTE,
            shutdown_drain_secs: 10,
            // Size-based limits.
            ws_frame_max_bytes: 512 * KIB,
            body_limit_bytes: 50 * MIB,
            // Rate limiting is off unless explicitly configured.
            rate_limit_per_minute: 0,
            rate_limit_burst: 10,
        }
    }
}
/// Everything `run_with_config` needs to start the server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// TCP port to listen on.
    pub port: u16,
    /// Listen host; it is joined with `port` as `host:port` and parsed as a
    /// socket address.
    pub host: String,
    /// Cross-origin access policy.
    pub origin_policy: OriginPolicy,
    /// Resource and time limits.
    pub limits: RuntimeLimits,
    /// Whether the metrics registry is created and populated.
    pub metrics_enabled: bool,
}

impl ServerConfig {
    /// Builds a configuration bound to `127.0.0.1:<port>` with the
    /// loopback-only origin policy, default limits and metrics disabled.
    pub fn local(port: u16) -> Self {
        let host = String::from("127.0.0.1");
        Self {
            port,
            host,
            origin_policy: OriginPolicy::loopback_only(),
            limits: RuntimeLimits::default(),
            metrics_enabled: false,
        }
    }
}
/// Runs the server on `host:port` with default settings and no external
/// shutdown channel (the server then stops on Ctrl-C).
pub async fn run(engine: Engine, port: u16, host: &str) -> Result<()> {
    let no_external_trigger = None;
    run_with_shutdown(engine, port, host, no_external_trigger).await
}
pub async fn run_with_shutdown(
engine: Engine,
port: u16,
host: &str,
shutdown: Option<tokio::sync::oneshot::Receiver<()>>,
) -> Result<()> {
let config = ServerConfig {
port,
host: host.to_string(),
origin_policy: OriginPolicy::loopback_only(),
limits: RuntimeLimits::default(),
metrics_enabled: false,
};
run_with_config(engine, config, shutdown).await
}
/// Builds the router, binds the listener and serves until shutdown.
///
/// Shutdown is triggered by `shutdown` resolving (when `Some`) or by Ctrl-C
/// (when `None`). After the HTTP server stops, tracked tasks get up to
/// `limits.shutdown_drain_secs` (minimum 1 s) to finish.
///
/// # Errors
/// Fails if `host:port` does not parse as a socket address, if binding the
/// listener fails, or if `axum::serve` returns an error.
pub async fn run_with_config(
engine: Engine,
config: ServerConfig,
shutdown: Option<tokio::sync::oneshot::Receiver<()>>,
) -> Result<()> {
// `host` must parse together with the port as a `SocketAddr`.
let addr: SocketAddr = format!("{}:{}", config.host, config.port)
.parse()
.context("Invalid host:port")?;
// The metrics registry only exists when metrics are enabled; the HTTP
// metrics middleware passes requests through untouched when it is `None`.
let metrics_registry = if config.metrics_enabled {
let reg = std::sync::Arc::new(self::metrics::MetricsRegistry::new());
reg.register_counter(
"phostt_http_requests_total",
"Total HTTP requests processed",
);
reg.register_histogram(
"phostt_http_request_duration_seconds",
"HTTP request duration in seconds",
self::metrics::DEFAULT_BUCKETS,
);
tracing::info!("Prometheus /metrics endpoint enabled");
Some(reg)
} else {
None
};
// Configuration sanity warning only; the values are used as given.
if config.limits.max_session_secs != 0
&& config.limits.max_session_secs < config.limits.idle_timeout_secs
{
tracing::warn!(
max_session_secs = config.limits.max_session_secs,
idle_timeout_secs = config.limits.idle_timeout_secs,
"max_session_secs < idle_timeout_secs — sessions will be capped before \
the idle timer can fire; this is probably not what you want"
);
}
// `shutdown_root` is cancelled when shutdown starts; `tracker` is awaited
// during the drain phase at the end of this function.
let shutdown_root = tokio_util::sync::CancellationToken::new();
let tracker = tokio_util::task::TaskTracker::new();
let state = Arc::new(http::AppState {
engine: Arc::new(engine),
limits: config.limits.clone(),
metrics_registry,
shutdown: shutdown_root.clone(),
tracker: tracker.clone(),
});
let policy = Arc::new(config.origin_policy.clone());
let origin_layer = {
let policy = policy.clone();
axum::middleware::from_fn(move |req, next| {
let policy = policy.clone();
async move { origin_middleware(policy, req, next).await }
})
};
// Routes that get the metrics middleware (and the rate limiter when it is
// enabled). `/health` is registered separately further down.
let protected = Router::new()
.route("/v1/models", get(http::models))
.route("/v1/transcribe", post(http::transcribe))
.route("/v1/transcribe/stream", post(http::transcribe_stream))
.route("/v1/ws", get(ws_handler))
.route("/ws", get(ws_handler_legacy))
.route("/metrics", get(http::metrics))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
http_metrics_middleware,
))
.with_state(state.clone());
// Per-IP rate limiting is opt-in: only installed when the per-minute limit is non-zero.
let protected = if config.limits.rate_limit_per_minute > 0 {
let limiter = Arc::new(rate_limit::RateLimiter::new(
config.limits.rate_limit_per_minute,
config.limits.rate_limit_burst,
));
let interval_ms = limiter.interval_ms();
let evict_limiter = limiter.clone();
let evict_cancel = shutdown_root.clone();
// Background task: once a minute, ask the limiter to evict entries using a
// 300 s staleness threshold; it exits when shutdown is signalled.
tokio::spawn(async move {
let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60));
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
// Consume the first tick up front (tokio intervals complete their first tick immediately).
ticker.tick().await;
loop {
tokio::select! {
biased;
_ = evict_cancel.cancelled() => break,
_ = ticker.tick() => {
evict_limiter.evict_stale(std::time::Duration::from_secs(300));
}
}
}
});
tracing::info!(
rpm = config.limits.rate_limit_per_minute,
interval_ms,
burst = config.limits.rate_limit_burst,
"per-IP rate limiting enabled"
);
let layer_limiter = limiter.clone();
protected.layer(axum::middleware::from_fn(move |req, next| {
let limiter = layer_limiter.clone();
async move { rate_limit::rate_limit_middleware(limiter, req, next).await }
}))
} else {
protected
};
// Kept separately so the shutdown future can close the engine pool.
let shutdown_engine = state.engine.clone();
// The body limit and the origin layer are added after the merge, so they
// cover `/health` as well as the protected routes.
let app = Router::new()
.route("/health", get(http::health))
.merge(protected)
.layer(DefaultBodyLimit::max(config.limits.body_limit_bytes))
.layer(origin_layer)
.with_state(state);
tracing::info!("phostt server listening on http://{addr}");
tracing::info!(" WebSocket: ws://{addr}/v1/ws (legacy alias: ws://{addr}/ws)");
tracing::info!(" REST API: http://{addr}/health, /v1/transcribe, /v1/transcribe/stream");
if config.origin_policy.allow_any {
tracing::warn!(
"CORS allow-any is ON: any cross-origin page can call this server. \
Only use with trusted callers."
);
} else if !config.origin_policy.allowed_origins.is_empty() {
tracing::info!(
"CORS allowlist (in addition to loopback): {:?}",
config.origin_policy.allowed_origins
);
}
let listener = tokio::net::TcpListener::bind(&addr).await?;
// A configured drain window of 0 is treated as 1 second.
let shutdown_drain_secs = config.limits.shutdown_drain_secs.max(1);
// Resolves when shutdown is requested, then cancels the root token and
// closes the engine pool.
let shutdown_fut = {
let shutdown_root = shutdown_root.clone();
async move {
match shutdown {
Some(rx) => {
// The result is ignored, so a dropped sender also triggers shutdown.
rx.await.ok();
}
None => {
tokio::signal::ctrl_c().await.ok();
}
}
tracing::info!("Shutting down server");
shutdown_root.cancel();
shutdown_engine.pool.close();
}
};
// Connect info is required by the WebSocket handlers, which extract the peer address.
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_fut)
.await?;
// Drain phase: stop accepting new tracked tasks, then wait (bounded) for
// the ones still running.
tracker.close();
match tokio::time::timeout(
std::time::Duration::from_secs(shutdown_drain_secs),
tracker.wait(),
)
.await
{
Ok(()) => tracing::info!("Drain complete: all tracked WS/SSE tasks finished"),
Err(_) => tracing::warn!(
drain_secs = shutdown_drain_secs,
pending = tracker.len(),
"Drain window expired with tracked tasks still running — forcing exit"
),
}
Ok(())
}
/// Middleware that records a request counter and a duration histogram for
/// each request it wraps.
///
/// When no metrics registry is configured the request is forwarded without
/// any bookkeeping. Labels are the HTTP method, the request path and (for the
/// counter only) the response status code.
async fn http_metrics_middleware(
    State(state): State<Arc<http::AppState>>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    let registry = match state.metrics_registry.clone() {
        Some(registry) => registry,
        None => return next.run(req).await,
    };

    // Capture the labels before the request is consumed by the inner service.
    let method = req.method().as_str().to_string();
    let path = req.uri().path().to_string();

    let started = std::time::Instant::now();
    let response = next.run(req).await;
    let elapsed = started.elapsed().as_secs_f64();

    let status = response.status().as_u16().to_string();
    registry.counter_inc(
        "phostt_http_requests_total",
        vec![
            ("method".into(), method.clone()),
            ("path".into(), path.clone()),
            ("status".into(), status),
        ],
        1,
    );
    registry.histogram_record(
        "phostt_http_request_duration_seconds",
        vec![("method".into(), method), ("path".into(), path)],
        elapsed,
    );

    response
}
/// Middleware enforcing the [`OriginPolicy`] and attaching CORS headers.
///
/// * `/health` bypasses the check entirely.
/// * Requests without an `Origin` header (or with `null`) pass through with no
///   CORS headers added.
/// * Allowed origins pass through and the response gains
///   `Access-Control-Allow-Origin` (`*` under allow-any, otherwise the request's
///   origin echoed back), `Access-Control-Allow-Methods` and
///   `Access-Control-Allow-Headers`. When a specific origin is echoed, `Vary: Origin`
///   is appended so shared caches do not reuse the response for another origin.
/// * Anything else is answered with `403` and a JSON error body.
async fn origin_middleware(
    policy: Arc<OriginPolicy>,
    req: axum::extract::Request,
    next: axum::middleware::Next,
) -> Response {
    use axum::http::{HeaderValue, StatusCode, header};
    use axum::response::IntoResponse;

    if req.uri().path() == "/health" {
        return next.run(req).await;
    }

    // Owned copy: `req` is moved into `next.run` before the origin is used again.
    let origin = req
        .headers()
        .get("origin")
        .and_then(|v| v.to_str().ok())
        .map(str::to_string);

    match policy.evaluate(origin.as_deref()) {
        OriginVerdict::AllowedNoEcho => next.run(req).await,
        OriginVerdict::Allowed(echo) => {
            let mut response = next.run(req).await;
            let headers = response.headers_mut();
            if policy.allow_any {
                headers.insert(
                    header::ACCESS_CONTROL_ALLOW_ORIGIN,
                    HeaderValue::from_static("*"),
                );
            } else {
                if let Ok(v) = HeaderValue::from_str(&echo) {
                    headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, v);
                }
                // The response now depends on the request's Origin; tell caches so.
                headers.append(header::VARY, HeaderValue::from_static("Origin"));
            }
            headers.insert(
                header::ACCESS_CONTROL_ALLOW_METHODS,
                HeaderValue::from_static("GET, POST, OPTIONS"),
            );
            headers.insert(
                header::ACCESS_CONTROL_ALLOW_HEADERS,
                HeaderValue::from_static("*"),
            );
            response
        }
        OriginVerdict::Denied => {
            let origin_str = origin.as_deref().unwrap_or("");
            let path = req.uri().path().to_string();
            tracing::warn!(
                origin = %origin_str,
                path = %path,
                "cross-origin request denied by default policy"
            );
            (
                StatusCode::FORBIDDEN,
                axum::response::Json(serde_json::json!({
                    "error": "Origin not allowed",
                    "code": "origin_denied",
                })),
            )
                .into_response()
        }
    }
}
/// Handles WebSocket upgrades on `/v1/ws`.
///
/// Once shutdown has been signalled the upgrade is refused with `503` and a
/// JSON body. Otherwise the socket is upgraded with the configured size limit
/// and the session future is registered with the task tracker.
async fn ws_handler(
    ws: WebSocketUpgrade,
    axum::extract::ConnectInfo(peer): axum::extract::ConnectInfo<SocketAddr>,
    State(state): State<Arc<http::AppState>>,
) -> Response {
    use axum::http::StatusCode;
    use axum::response::IntoResponse;

    if state.shutdown.is_cancelled() {
        tracing::warn!(peer = %peer, "Rejecting WS upgrade after shutdown");
        let body = axum::response::Json(serde_json::json!({
            "error": "Server shutting down",
            "code": "shutting_down",
        }));
        return (StatusCode::SERVICE_UNAVAILABLE, body).into_response();
    }

    // One limit governs both the whole message and each individual frame.
    let size_cap = state.limits.ws_frame_max_bytes;
    let upgrade = ws.max_message_size(size_cap).max_frame_size(size_cap);
    upgrade.on_upgrade(move |socket| async move {
        let tracker = state.tracker.clone();
        tracker.track_future(handle_ws(socket, peer, state)).await
    })
}
/// Handles WebSocket upgrades on the deprecated `/ws` path.
///
/// Behaves like the `/v1/ws` handler, but logs a deprecation warning for every
/// connection and adds `deprecation` / `link` headers to the upgrade response
/// pointing at `/v1/ws`.
async fn ws_handler_legacy(
    ws: WebSocketUpgrade,
    axum::extract::ConnectInfo(peer): axum::extract::ConnectInfo<SocketAddr>,
    State(state): State<Arc<http::AppState>>,
) -> Response {
    use axum::http::{HeaderValue, StatusCode};
    use axum::response::IntoResponse;

    tracing::warn!(
        peer = %peer,
        "WebSocket client connected to deprecated /ws path — switch to /v1/ws before v1.0"
    );

    if state.shutdown.is_cancelled() {
        let body = axum::response::Json(serde_json::json!({
            "error": "Server shutting down",
            "code": "shutting_down",
        }));
        return (StatusCode::SERVICE_UNAVAILABLE, body).into_response();
    }

    // One limit governs both the whole message and each individual frame.
    let size_cap = state.limits.ws_frame_max_bytes;
    let mut response = ws
        .max_message_size(size_cap)
        .max_frame_size(size_cap)
        .on_upgrade(move |socket| async move {
            let tracker = state.tracker.clone();
            tracker.track_future(handle_ws(socket, peer, state)).await
        });

    let deprecation_flag = HeaderValue::from_static("true");
    let successor_link = HeaderValue::from_static("</v1/ws>; rel=\"successor-version\"");
    let headers = response.headers_mut();
    headers.insert("deprecation", deprecation_flag);
    headers.insert("link", successor_link);
    response
}
async fn handle_ws(socket: WebSocket, peer: SocketAddr, state: Arc<http::AppState>) {
let guard = tokio::select! {
biased;
_ = state.shutdown.cancelled() => {
tracing::info!(peer = %peer, "Shutdown requested before pool checkout");
let (mut sink, _) = socket.split();
let _ = sink
.send(WsMessage::Close(Some(axum::extract::ws::CloseFrame {
code: 1001,
reason: "server shutdown".into(),
})))
.await;
return;
}
res = tokio::time::timeout(
std::time::Duration::from_secs(30),
state.engine.pool.checkout(),
) => match res {
Ok(Ok(guard)) => guard,
Ok(Err(_pool_closed)) => {
tracing::info!("WebSocket pool closed for {peer} — server is shutting down");
let (mut sink, _) = socket.split();
let err = ServerMessage::Error {
message: "Server is shutting down".into(),
code: "pool_closed".into(),
retry_after_ms: None,
};
let _ = sink.send(WsMessage::Text(json_text(&err).into())).await;
return;
}
Err(_) => {
tracing::warn!("WebSocket pool checkout timeout for {peer}");
let (mut sink, _) = socket.split();
let err = ServerMessage::Error {
message: "Server busy, try again later".into(),
code: "timeout".into(),
retry_after_ms: Some(POOL_RETRY_AFTER_MS),
};
let _ = sink.send(WsMessage::Text(json_text(&err).into())).await;
return;
}
}
};
let (triplet, reservation) = guard.into_owned();
let (triplet_opt, result) = handle_ws_inner(
socket,
peer,
&state.engine,
&state.limits,
triplet,
state.shutdown.clone(),
)
.await;
if let Err(e) = result {
tracing::error!("WebSocket error from {peer}: {e}");
}
if let Some(triplet) = triplet_opt {
reservation.checkin(triplet);
}
}
// What the WebSocket session loop should do after one client message has been handled.
enum FrameOutcome {
// Keep reading messages.
Continue,
// End the session loop without an error.
Break,
}
// Write half of a split WebSocket, as produced by `socket.split()`.
type WsSink = futures_util::stream::SplitSink<WebSocket, WsMessage>;
/// Serializes `msg` as JSON and sends it to the client as a text frame.
///
/// # Errors
/// Returns the underlying send error if the frame cannot be written.
async fn send_server_message(sink: &mut WsSink, msg: &ServerMessage) -> Result<()> {
    let frame = WsMessage::Text(json_text(msg).into());
    sink.send(frame).await?;
    Ok(())
}
#[allow(clippy::too_many_arguments)]
async fn handle_binary_frame(
sink: &mut WsSink,
engine: &Arc<Engine>,
state_opt: &mut Option<crate::inference::StreamingState>,
triplet_opt: &mut Option<SessionTriplet>,
audio_received: &mut bool,
client_sample_rate: u32,
pending_byte: &mut Option<u8>,
peer: SocketAddr,
data: axum::body::Bytes,
) -> Result<FrameOutcome> {
if data.is_empty() {
tracing::debug!("Empty binary frame from {peer}, skipping");
return Ok(FrameOutcome::Continue);
}
*audio_received = true;
let carry_prev = pending_byte.take();
let samples_f32: Vec<f32> = if carry_prev.is_some() || !data.len().is_multiple_of(2) {
let mut combined = Vec::with_capacity(data.len() + 1);
if let Some(prev) = carry_prev {
combined.push(prev);
}
combined.extend_from_slice(&data);
if !combined.len().is_multiple_of(2) {
tracing::warn!(
"Odd-length PCM stream from {peer}: {} bytes incl. carry, deferring 1 byte",
combined.len()
);
*pending_byte = combined.pop();
}
combined
.chunks_exact(2)
.map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / 32768.0)
.collect()
} else {
data.chunks_exact(2)
.map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]]) as f32 / 32768.0)
.collect()
};
let samples_16k = if client_sample_rate == 16000 {
samples_f32
} else {
crate::inference::audio::resample(&samples_f32, client_sample_rate, 16000)?
};
let state = state_opt
.take()
.ok_or_else(|| anyhow::anyhow!("Streaming state lost"))?;
let triplet = triplet_opt.take().ok_or_else(|| {
tracing::error!("Triplet unexpectedly missing for {peer}");
anyhow::anyhow!("Triplet lost")
})?;
let eng = engine.clone();
let join_result = tokio::task::spawn_blocking(move || {
let mut state = state;
let mut triplet = triplet;
let r = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
eng.process_chunk(&samples_16k, &mut state, &mut triplet)
}));
(r, state, triplet)
})
.await;
match join_result {
Ok((Ok(Ok(segments)), state_back, triplet_back)) => {
*state_opt = Some(state_back);
*triplet_opt = Some(triplet_back);
for seg in segments {
let msg = if seg.is_final {
ServerMessage::Final {
text: seg.text.to_string(),
timestamp: seg.timestamp,
words: seg.words.to_vec(),
}
} else {
ServerMessage::Partial {
text: seg.text.to_string(),
timestamp: seg.timestamp,
words: seg.words.to_vec(),
}
};
send_server_message(sink, &msg).await?;
}
Ok(FrameOutcome::Continue)
}
Ok((Ok(Err(e)), state_back, triplet_back)) => {
*state_opt = Some(state_back);
*triplet_opt = Some(triplet_back);
tracing::error!("Inference error for {peer}: {e:#}");
send_server_message(
sink,
&ServerMessage::Error {
message: "Inference failed. Please check audio format.".into(),
code: "inference_error".into(),
retry_after_ms: None,
},
)
.await?;
Ok(FrameOutcome::Continue)
}
Ok((Err(_panic), _state_back, triplet_back)) => {
tracing::error!(
"Panic in WS inference for {peer} — triplet recovered, streaming state reset"
);
*triplet_opt = Some(triplet_back);
match engine.create_state(false) {
Ok(state) => *state_opt = Some(state),
Err(e) => {
tracing::error!("Failed to create streaming state after panic: {e}");
}
}
send_server_message(
sink,
&ServerMessage::Error {
message: "Inference failed unexpectedly. Session reset.".into(),
code: "inference_panic".into(),
retry_after_ms: None,
},
)
.await?;
Ok(FrameOutcome::Continue)
}
Err(e) => {
tracing::error!("spawn_blocking join error for {peer}: {e}");
Err(anyhow::anyhow!("Blocking task join failed"))
}
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_configure_message(
sink: &mut WsSink,
engine: &Arc<Engine>,
state_opt: &mut Option<crate::inference::StreamingState>,
client_sample_rate: &mut u32,
audio_received: bool,
sample_rate: Option<u32>,
diarization: Option<bool>,
peer: SocketAddr,
) -> Result<FrameOutcome> {
if audio_received {
send_server_message(
sink,
&ServerMessage::Error {
message: "Configure must be sent before first audio frame".into(),
code: "configure_too_late".into(),
retry_after_ms: None,
},
)
.await?;
return Ok(FrameOutcome::Continue);
}
if let Some(rate) = sample_rate {
if SUPPORTED_RATES.contains(&rate) {
*client_sample_rate = rate;
tracing::info!("Client {peer} configured sample rate: {rate}Hz");
} else {
send_server_message(
sink,
&ServerMessage::Error {
message: format!(
"Unsupported sample rate: {rate}Hz. Supported: {SUPPORTED_RATES:?}"
),
code: "invalid_sample_rate".into(),
retry_after_ms: None,
},
)
.await?;
}
}
#[cfg(feature = "diarization")]
if let Some(enable_dia) = diarization {
tracing::info!("Client {peer} configured diarization: {enable_dia}");
*state_opt = Some(
engine
.create_state(enable_dia)
.map_err(|e| anyhow::anyhow!("State init failed: {e}"))?,
);
}
#[cfg(not(feature = "diarization"))]
{
let _ = (engine, state_opt, diarization);
}
Ok(FrameOutcome::Continue)
}
/// Handles a client `Stop` message: flushes the streaming state, sends a
/// `Final` message (empty when the flush yields nothing) and ends the session.
///
/// If the streaming state or the triplet is missing, the session ends without
/// a `Final` message. The streaming state is consumed; the triplet is put back
/// into `triplet_opt` after the flush.
///
/// # Errors
/// Returns an error when the `Final` message cannot be sent.
async fn handle_stop_message(
    sink: &mut WsSink,
    engine: &Arc<Engine>,
    state_opt: &mut Option<crate::inference::StreamingState>,
    triplet_opt: &mut Option<SessionTriplet>,
    peer: SocketAddr,
) -> Result<FrameOutcome> {
    tracing::info!("Stop received from {peer}, finalizing");

    // Order matters: the triplet is only taken once a state is known to exist.
    let mut state = match state_opt.take() {
        Some(state) => state,
        None => return Ok(FrameOutcome::Break),
    };
    let mut triplet = match triplet_opt.take() {
        Some(triplet) => triplet,
        None => return Ok(FrameOutcome::Break),
    };

    let flushed = engine.flush_state(&mut state, &mut triplet);
    *triplet_opt = Some(triplet);

    let final_msg = match flushed {
        Some(seg) => ServerMessage::Final {
            text: seg.text.to_string(),
            timestamp: seg.timestamp,
            words: seg.words.to_vec(),
        },
        None => ServerMessage::Final {
            text: String::new(),
            timestamp: crate::inference::now_timestamp(),
            words: vec![],
        },
    };
    send_server_message(sink, &final_msg).await?;
    Ok(FrameOutcome::Break)
}
/// Flushes whatever the streaming state still holds and always sends one
/// `Final` message to the client.
///
/// The message is empty (current timestamp, no words) when the triplet or the
/// streaming state is missing or when the flush produces no segment. The
/// streaming state stays in `state_opt`; the triplet is put back after use.
///
/// # Errors
/// Returns an error when the `Final` message cannot be sent.
async fn flush_and_final(
    sink: &mut WsSink,
    engine: &Arc<Engine>,
    state_opt: &mut Option<crate::inference::StreamingState>,
    triplet_opt: &mut Option<SessionTriplet>,
) -> Result<()> {
    let flushed = match triplet_opt.take() {
        Some(mut triplet) => {
            let segment = match state_opt.as_mut() {
                Some(state) => engine.flush_state(state, &mut triplet),
                None => None,
            };
            *triplet_opt = Some(triplet);
            segment
        }
        None => None,
    };

    let final_msg = if let Some(seg) = flushed {
        ServerMessage::Final {
            text: seg.text.to_string(),
            timestamp: seg.timestamp,
            words: seg.words.to_vec(),
        }
    } else {
        ServerMessage::Final {
            text: String::new(),
            timestamp: crate::inference::now_timestamp(),
            words: vec![],
        }
    };
    send_server_message(sink, &final_msg).await
}
async fn handle_ws_inner(
socket: WebSocket,
peer: SocketAddr,
engine: &Arc<Engine>,
limits: &RuntimeLimits,
triplet: SessionTriplet,
cancel: tokio_util::sync::CancellationToken,
) -> (Option<SessionTriplet>, Result<()>) {
let (mut sink, mut source) = socket.split();
tracing::info!("Client connected: {peer}");
#[cfg(feature = "diarization")]
let diarization_available = engine.has_speaker_encoder();
#[cfg(not(feature = "diarization"))]
let diarization_available = false;
let ready = ServerMessage::Ready {
model: "zipformer-vi-rnnt".into(),
sample_rate: DEFAULT_SAMPLE_RATE,
version: crate::protocol::PROTOCOL_VERSION.into(),
supported_rates: SUPPORTED_RATES.to_vec(),
diarization: diarization_available,
};
if let Err(e) = send_server_message(&mut sink, &ready).await {
return (Some(triplet), Err(e));
}
let mut state_opt = match engine.create_state(false) {
Ok(state) => Some(state),
Err(e) => {
tracing::error!("State init failed: {e}");
return (Some(triplet), Err(anyhow::anyhow!("State init failed: {e}")));
}
};
let mut triplet_opt = Some(triplet);
let mut client_sample_rate: u32 = DEFAULT_SAMPLE_RATE;
let mut audio_received = false;
let mut pending_byte: Option<u8> = None;
let idle_timeout = std::time::Duration::from_secs(limits.idle_timeout_secs);
let session_deadline = if limits.max_session_secs == 0 {
tokio::time::Instant::now() + std::time::Duration::from_secs(u64::MAX / 2)
} else {
tokio::time::Instant::now() + std::time::Duration::from_secs(limits.max_session_secs)
};
let result: Result<()> = loop {
if cancel.is_cancelled() {
tracing::info!(peer = %peer, "Shutdown signalled — flushing WS session");
let _ = flush_and_final(&mut sink, engine, &mut state_opt, &mut triplet_opt).await;
let _ = sink
.send(WsMessage::Close(Some(axum::extract::ws::CloseFrame {
code: 1001,
reason: "server shutdown".into(),
})))
.await;
break Ok(());
}
if tokio::time::Instant::now() >= session_deadline {
tracing::warn!(
peer = %peer,
max_session_secs = limits.max_session_secs,
"Session cap reached — closing WS"
);
let _ = send_server_message(
&mut sink,
&ServerMessage::Error {
message: "Maximum session duration exceeded".into(),
code: "max_session_duration_exceeded".into(),
retry_after_ms: None,
},
)
.await;
let _ = flush_and_final(&mut sink, engine, &mut state_opt, &mut triplet_opt).await;
let _ = sink
.send(WsMessage::Close(Some(axum::extract::ws::CloseFrame {
code: 1008,
reason: "max session duration".into(),
})))
.await;
break Ok(());
}
tokio::select! {
biased;
_ = cancel.cancelled() => {
tracing::info!(peer = %peer, "Shutdown signalled — flushing WS session");
let _ = flush_and_final(&mut sink, engine, &mut state_opt, &mut triplet_opt).await;
let _ = sink
.send(WsMessage::Close(Some(axum::extract::ws::CloseFrame {
code: 1001,
reason: "server shutdown".into(),
})))
.await;
break Ok(());
}
_ = tokio::time::sleep_until(session_deadline) => {
tracing::warn!(
peer = %peer,
max_session_secs = limits.max_session_secs,
"Session cap reached — closing WS"
);
let _ = send_server_message(
&mut sink,
&ServerMessage::Error {
message: "Maximum session duration exceeded".into(),
code: "max_session_duration_exceeded".into(),
retry_after_ms: None,
},
)
.await;
let _ = flush_and_final(&mut sink, engine, &mut state_opt, &mut triplet_opt).await;
let _ = sink
.send(WsMessage::Close(Some(axum::extract::ws::CloseFrame {
code: 1008,
reason: "max session duration".into(),
})))
.await;
break Ok(());
}
maybe_msg = tokio::time::timeout(idle_timeout, source.next()) => {
let msg = match maybe_msg {
Ok(Some(Ok(msg))) => msg,
Ok(Some(Err(e))) => break Err(e.into()),
Ok(None) => break Ok(()),
Err(_) => {
tracing::info!(
"Client {peer} idle timeout ({}s)",
limits.idle_timeout_secs
);
break Ok(());
}
};
let outcome = match msg {
WsMessage::Binary(data) => {
handle_binary_frame(
&mut sink,
engine,
&mut state_opt,
&mut triplet_opt,
&mut audio_received,
client_sample_rate,
&mut pending_byte,
peer,
data,
)
.await
}
WsMessage::Text(text) => match serde_json::from_str::<ClientMessage>(&text) {
Ok(ClientMessage::Configure {
sample_rate,
diarization,
}) => {
handle_configure_message(
&mut sink,
engine,
&mut state_opt,
&mut client_sample_rate,
audio_received,
sample_rate,
diarization,
peer,
)
.await
}
Ok(ClientMessage::Stop) => {
handle_stop_message(&mut sink, engine, &mut state_opt, &mut triplet_opt, peer).await
}
Err(_) => {
tracing::debug!(
"Unrecognized text message from {peer}: {}",
&text[..text.len().min(100)]
);
Ok(FrameOutcome::Continue)
}
},
WsMessage::Close(_) => Ok(FrameOutcome::Break),
_ => Ok(FrameOutcome::Continue), };
match outcome {
Ok(FrameOutcome::Continue) => continue,
Ok(FrameOutcome::Break) => break Ok(()),
Err(e) => break Err(e),
}
}
}
};
tracing::info!("Client disconnected: {peer}");
(triplet_opt, result)
}
// Unit tests for configuration defaults and origin policy, plus one
// integration test that runs the origin middleware behind a real listener.
#[cfg(test)]
mod tests {
use super::*;
// Rate limiting defaults: disabled (0 requests per minute) with a burst of 10.
#[test]
fn test_runtime_limits_default_rate_limit_disabled() {
let limits = RuntimeLimits::default();
assert_eq!(
limits.rate_limit_per_minute, 0,
"rate limiting must be off by default (privacy-first)"
);
assert_eq!(limits.rate_limit_burst, 10, "default burst size must be 10");
}
// Session cap defaults to one hour and the shutdown drain window to ten seconds.
#[test]
fn test_runtime_limits_default_session_and_drain() {
let limits = RuntimeLimits::default();
assert_eq!(
limits.max_session_secs, 3600,
"default session cap must be 1 hour to stop silence-streamers from \
holding a triplet forever"
);
assert_eq!(
limits.shutdown_drain_secs, 10,
"default shutdown drain must be 10 s — comfortably inside the usual \
k8s terminationGracePeriodSeconds = 30"
);
}
// The advertised rate list must contain the 8 kHz, 16 kHz and 48 kHz rates.
#[test]
fn test_supported_rates_contains_common() {
assert!(
SUPPORTED_RATES.contains(&8000),
"SUPPORTED_RATES must include 8000 Hz"
);
assert!(
SUPPORTED_RATES.contains(&16000),
"SUPPORTED_RATES must include 16000 Hz"
);
assert!(
SUPPORTED_RATES.contains(&48000),
"SUPPORTED_RATES must include 48000 Hz"
);
}
// The default rate must itself be one of the supported rates.
#[test]
fn test_default_sample_rate_in_supported() {
assert!(
SUPPORTED_RATES.contains(&DEFAULT_SAMPLE_RATE),
"DEFAULT_SAMPLE_RATE ({DEFAULT_SAMPLE_RATE}) must be present in SUPPORTED_RATES"
);
}
// Loopback matcher: accepts localhost / 127.0.0.1 / [::1] with optional port
// (case-insensitively) and rejects other hosts, including hosts that merely
// start with "localhost".
#[test]
fn test_loopback_origin_matcher() {
assert!(is_loopback_origin("http://localhost"));
assert!(is_loopback_origin("https://localhost:3000"));
assert!(is_loopback_origin("http://127.0.0.1:9876"));
assert!(is_loopback_origin("HTTPS://127.0.0.1")); assert!(is_loopback_origin("http://[::1]:9876"));
assert!(!is_loopback_origin("https://evil.example.com"));
assert!(!is_loopback_origin("http://192.168.1.10"));
assert!(!is_loopback_origin("http://localhost.evil.example.com"));
}
// With the default policy a third-party origin is denied.
#[test]
fn test_origin_policy_default_denies_third_party() {
let policy = OriginPolicy::loopback_only();
assert!(matches!(
policy.evaluate(Some("https://evil.example.com")),
OriginVerdict::Denied
));
}
// With the default policy a loopback origin is allowed.
#[test]
fn test_origin_policy_allows_loopback_by_default() {
let policy = OriginPolicy::loopback_only();
assert!(matches!(
policy.evaluate(Some("http://localhost:3000")),
OriginVerdict::Allowed(_)
));
}
// An allowlisted origin is accepted; an origin that only starts with the
// allowlisted value is not.
#[test]
fn test_origin_policy_allows_listed_origin() {
let policy = OriginPolicy {
allow_any: false,
allowed_origins: vec!["https://app.example.com".into()],
};
assert!(matches!(
policy.evaluate(Some("https://app.example.com")),
OriginVerdict::Allowed(_)
));
assert!(matches!(
policy.evaluate(Some("https://app.example.com.evil.com")),
OriginVerdict::Denied
));
}
// `allow_any` accepts an arbitrary origin even with an empty allowlist.
#[test]
fn test_origin_policy_allow_any_short_circuits() {
let policy = OriginPolicy {
allow_any: true,
allowed_origins: vec![],
};
assert!(matches!(
policy.evaluate(Some("https://anything.example.com")),
OriginVerdict::Allowed(_)
));
}
// A missing Origin header and the literal "null" both pass without an echo.
#[test]
fn test_origin_policy_no_header_allowed() {
let policy = OriginPolicy::loopback_only();
assert!(matches!(
policy.evaluate(None),
OriginVerdict::AllowedNoEcho
));
assert!(matches!(
policy.evaluate(Some("null")),
OriginVerdict::AllowedNoEcho
));
}
// End-to-end check of the origin middleware: serves a two-route router on an
// ephemeral loopback port and issues real HTTP requests against it.
#[tokio::test]
async fn test_origin_middleware_integration() {
use axum::Router;
use axum::routing::get;
let policy = Arc::new(OriginPolicy::loopback_only());
let origin_layer = {
let policy = policy.clone();
axum::middleware::from_fn(move |req, next| {
let policy = policy.clone();
async move { origin_middleware(policy, req, next).await }
})
};
let app = Router::new()
.route("/health", get(|| async { "ok" }))
.route("/v1/transcribe", get(|| async { "ok" }))
.layer(origin_layer);
// Port 0 lets the OS pick a free port.
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let port = listener.local_addr().unwrap().port();
tokio::spawn(async move {
let _ = axum::serve(listener, app).await;
});
// Short pause before the first request so the spawned server task can start.
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let client = reqwest::Client::new();
let base = format!("http://127.0.0.1:{port}");
// Case 1: /health is exempt from the origin check.
let r = client
.get(format!("{base}/health"))
.header("Origin", "https://evil.example.com")
.send()
.await
.unwrap();
assert_eq!(r.status(), 200, "/health must skip the Origin guard");
// Case 2: a foreign origin is rejected with 403 and the origin_denied code.
let r = client
.get(format!("{base}/v1/transcribe"))
.header("Origin", "https://evil.example.com")
.send()
.await
.unwrap();
assert_eq!(
r.status(),
403,
"non-loopback Origin must receive 403 Forbidden"
);
let text = r.text().await.unwrap();
let body: serde_json::Value = serde_json::from_str(&text).unwrap();
assert_eq!(body["code"], "origin_denied");
// Case 3: a loopback origin is allowed and echoed back in the CORS header.
let r = client
.get(format!("{base}/v1/transcribe"))
.header("Origin", "http://localhost:3000")
.send()
.await
.unwrap();
assert_eq!(r.status(), 200, "loopback Origin must be allowed");
assert_eq!(
r.headers()
.get("access-control-allow-origin")
.and_then(|v| v.to_str().ok()),
Some("http://localhost:3000"),
"CORS echo must mirror the incoming Origin (no wildcard by default)",
);
// Case 4: requests without an Origin header pass.
let r = client
.get(format!("{base}/v1/transcribe"))
.send()
.await
.unwrap();
assert_eq!(r.status(), 200, "requests without Origin must pass");
// Case 5: a host that merely starts with "localhost" is still rejected.
let r = client
.get(format!("{base}/v1/transcribe"))
.header("Origin", "http://localhost.evil.example.com")
.send()
.await
.unwrap();
assert_eq!(
r.status(),
403,
"localhost.* DNS continuation must not impersonate loopback"
);
}
// Checks the rpm -> interval formula using a local copy of the computation.
// NOTE(review): this does not call the real rate limiter, so it only stays
// meaningful while the copy matches that implementation — confirm.
#[test]
fn test_rate_limit_interval_formula() {
const MAX_RPM: u64 = 60_000;
fn interval_ms_for(rpm: u32) -> u64 {
let rpm = (rpm as u64).min(MAX_RPM);
(60_000u64 / rpm).max(1)
}
// (rpm, expected interval in ms); the last entry exercises the MAX_RPM clamp.
let cases: &[(u32, u64)] = &[
(1, 60_000),
(10, 6_000),
(30, 2_000),
(59, 1_016), (60, 1_000),
(600, 100),
(60_000, 1),
(120_000, 1), ];
for (rpm, expected) in cases {
assert_eq!(
interval_ms_for(*rpm),
*expected,
"rpm={rpm} should map to interval_ms={expected}"
);
}
}
// Demonstrates that values mutated inside `catch_unwind` remain usable after
// the closure panics — the property the WebSocket inference path relies on
// to get its state and triplet back.
#[test]
fn test_catch_unwind_preserves_ownership_across_panic() {
use std::panic::{AssertUnwindSafe, catch_unwind};
let mut state = 42u32;
let mut triplet_marker = String::from("pool_slot");
let result = catch_unwind(AssertUnwindSafe(|| {
state = 99;
triplet_marker.push_str("/taken");
panic!("simulated inference panic");
}));
assert!(result.is_err(), "catch_unwind must report the panic");
assert_eq!(state, 99, "state must remain accessible after panic");
assert_eq!(
triplet_marker, "pool_slot/taken",
"triplet marker must survive panic"
);
}
}