pub mod config;
pub mod http;
pub mod metrics;
pub(crate) mod middleware;
pub mod rate_limit;
mod ws;
pub use config::{OriginPolicy, RuntimeLimits, ServerConfig};
use anyhow::{Context, Result};
use arc_swap::ArcSwap;
use axum::Router;
use axum::extract::DefaultBodyLimit;
use axum::http::StatusCode;
use axum::routing::{get, options, post};
use std::net::SocketAddr;
use std::sync::Arc;
/// Renders an outgoing server message as JSON text.
///
/// When serialization fails, the failure is logged and a fixed, pre-rendered
/// error payload is returned instead, so callers always receive a string.
pub(crate) fn json_text(msg: &impl serde::Serialize) -> String {
    match serde_json::to_string(msg) {
        Ok(rendered) => rendered,
        Err(e) => {
            tracing::error!("Failed to serialize server message: {e}");
            // Hand-written fallback so that the error path cannot itself fail.
            String::from(
                r#"{"type":"error","message":"Internal serialization error","code":"internal"}"#,
            )
        }
    }
}
/// Starts the server on `host:port` without an external shutdown channel.
///
/// Forwards to [`run_with_shutdown`] with `None` as the shutdown receiver.
pub async fn run(engine: gigastt_core::inference::Engine, port: u16, host: &str) -> Result<()> {
    // No caller-provided shutdown trigger for this entry point.
    let external_shutdown = None;
    run_with_shutdown(engine, port, host, external_shutdown).await
}
/// Starts the server on `host:port` using a built-in default configuration.
///
/// The configuration uses the loopback-only origin policy, default runtime
/// limits, metrics and proxy trust switched off, and no config file path.
/// `shutdown` is passed through unchanged to [`run_with_config`].
pub async fn run_with_shutdown(
    engine: gigastt_core::inference::Engine,
    port: u16,
    host: &str,
    shutdown: Option<tokio::sync::oneshot::Receiver<()>>,
) -> Result<()> {
    let host = host.to_owned();
    let origin_policy = OriginPolicy::loopback_only();
    let limits = RuntimeLimits::default();

    let default_config = ServerConfig {
        port,
        host,
        origin_policy,
        limits,
        metrics_enabled: false,
        trust_proxy: false,
        config_path: None,
    };

    run_with_config(engine, default_config, shutdown).await
}
/// Binds a TCP listener on `config.host:config.port` and serves on it.
///
/// The bound listener is handed to [`run_with_config_listener`] together with
/// the unchanged `engine`, `config` and `shutdown` arguments.
///
/// # Errors
///
/// Returns an error when `host:port` does not parse as a socket address, when
/// the listener cannot be bound (the error names the address that failed), or
/// when [`run_with_config_listener`] fails.
pub async fn run_with_config(
    engine: gigastt_core::inference::Engine,
    config: ServerConfig,
    shutdown: Option<tokio::sync::oneshot::Receiver<()>>,
) -> Result<()> {
    let addr: SocketAddr = format!("{}:{}", config.host, config.port)
        .parse()
        .context("Invalid host:port")?;
    // Attach the address to bind failures: a bare "address in use" or
    // "permission denied" does not say which endpoint was being bound.
    let listener = tokio::net::TcpListener::bind(&addr)
        .await
        .with_context(|| format!("Failed to bind TCP listener on {addr}"))?;
    run_with_config_listener(engine, config, shutdown, listener).await
}
pub async fn run_with_config_listener(
engine: gigastt_core::inference::Engine,
mut config: ServerConfig,
shutdown: Option<tokio::sync::oneshot::Receiver<()>>,
listener: tokio::net::TcpListener,
) -> Result<()> {
if config.limits.pool_checkout_timeout_secs == 0 {
tracing::warn!("pool_checkout_timeout_secs=0 would make the pool unusable; clamping to 1");
config.limits.pool_checkout_timeout_secs = 1;
}
let addr: SocketAddr = format!("{}:{}", config.host, config.port)
.parse()
.context("Invalid host:port")?;
let metrics_registry = if config.metrics_enabled {
let reg = std::sync::Arc::new(self::metrics::MetricsRegistry::new());
reg.register_counter(
"gigastt_http_requests_total",
"Total HTTP requests processed",
);
reg.register_histogram(
"gigastt_http_request_duration_seconds",
"HTTP request duration in seconds",
self::metrics::DEFAULT_BUCKETS,
);
reg.register_gauge(
"gigastt_pool_available",
"Number of session triplets currently available in the pool",
);
reg.register_histogram(
"gigastt_pool_checkout_duration_seconds",
"Time spent waiting for a pool checkout",
self::metrics::DEFAULT_BUCKETS,
);
reg.register_counter(
"gigastt_pool_timeouts_total",
"Total pool checkout timeouts",
);
reg.register_gauge(
"gigastt_ws_active_connections",
"Number of active WebSocket connections",
);
reg.register_histogram(
"gigastt_inference_duration_seconds",
"Inference duration in seconds",
self::metrics::DEFAULT_BUCKETS,
);
reg.register_counter(
"gigastt_rate_limit_rejections_total",
"Total requests rejected by rate limiter",
);
tracing::info!("Prometheus /metrics endpoint enabled");
Some(reg)
} else {
None
};
if config.limits.max_session_secs != 0
&& config.limits.max_session_secs < config.limits.idle_timeout_secs
{
tracing::warn!(
max_session_secs = config.limits.max_session_secs,
idle_timeout_secs = config.limits.idle_timeout_secs,
"max_session_secs < idle_timeout_secs — sessions will be capped before \
the idle timer can fire; this is probably not what you want"
);
}
let shutdown_root = tokio_util::sync::CancellationToken::new();
let tracker = tokio_util::task::TaskTracker::new();
let state = Arc::new(http::AppState {
engine: Arc::new(engine),
limits: Arc::new(ArcSwap::from_pointee(config.limits.clone())),
metrics_registry: metrics_registry.clone(),
shutdown: shutdown_root.clone(),
tracker: tracker.clone(),
});
let rate_limiter_swap = if config.limits.rate_limit_per_minute > 0 {
Some(Arc::new(ArcSwap::from(Arc::new(
rate_limit::RateLimiter::new(
config.limits.rate_limit_per_minute,
config.limits.rate_limit_burst,
),
))))
} else {
None
};
#[cfg(unix)]
{
let reload_state = state.clone();
let reload_path = config.config_path.clone();
let reload_shutdown = shutdown_root.clone();
let reload_limiter = rate_limiter_swap.clone();
tokio::spawn(async move {
use tokio::signal::unix::{SignalKind, signal};
let mut sig = signal(SignalKind::hangup()).expect("failed to register SIGHUP handler");
loop {
tokio::select! {
biased;
_ = reload_shutdown.cancelled() => break,
_ = sig.recv() => {
let Some(ref path) = reload_path else {
tracing::info!("No config file specified, ignoring SIGHUP");
continue;
};
match config::load_config_file(path) {
Ok(new_limits) => {
let old = reload_state.limits.load();
tracing::info!(
"RuntimeLimits reloaded from {}: idle_timeout_secs {} → {}, rate_limit_per_minute {} → {}",
path.display(),
old.idle_timeout_secs, new_limits.idle_timeout_secs,
old.rate_limit_per_minute, new_limits.rate_limit_per_minute,
);
if let Some(ref rl) = reload_limiter
&& (old.rate_limit_per_minute
!= new_limits.rate_limit_per_minute
|| old.rate_limit_burst != new_limits.rate_limit_burst)
&& new_limits.rate_limit_per_minute > 0
{
rl.store(Arc::new(rate_limit::RateLimiter::new(
new_limits.rate_limit_per_minute,
new_limits.rate_limit_burst,
)));
tracing::info!(
"Rate limiter recreated: rpm {} → {}, burst {} → {}",
old.rate_limit_per_minute,
new_limits.rate_limit_per_minute,
old.rate_limit_burst,
new_limits.rate_limit_burst,
);
}
reload_state.limits.store(Arc::new(new_limits));
}
Err(e) => {
tracing::error!("Failed to reload config on SIGHUP: {e:#}");
}
}
}
}
}
});
}
let policy = Arc::new(config.origin_policy.clone());
let origin_layer = {
let policy = policy.clone();
axum::middleware::from_fn(move |req, next| {
let policy = policy.clone();
async move { middleware::origin_middleware(policy, req, next).await }
})
};
let protected = Router::new()
.route("/v1/models", get(http::models))
.route("/v1/transcribe", post(http::transcribe))
.route(
"/v1/transcribe",
options(|| async { StatusCode::NO_CONTENT }),
)
.route("/v1/transcribe/stream", post(http::transcribe_stream))
.route(
"/v1/transcribe/stream",
options(|| async { StatusCode::NO_CONTENT }),
)
.route("/v1/ws", get(ws::ws_handler))
.route("/metrics", get(http::metrics))
.layer(axum::middleware::from_fn_with_state(
state.clone(),
middleware::http_metrics_middleware,
))
.with_state(state.clone());
let protected = if let Some(ref limiter_swap) = rate_limiter_swap {
let interval_ms = limiter_swap.load().interval_ms();
let evict_limiter = limiter_swap.clone();
let evict_cancel = shutdown_root.clone();
tokio::spawn(async move {
let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60));
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
ticker.tick().await;
loop {
tokio::select! {
biased;
_ = evict_cancel.cancelled() => break,
_ = ticker.tick() => {
evict_limiter.load().evict_stale(std::time::Duration::from_secs(300));
}
}
}
});
tracing::info!(
rpm = config.limits.rate_limit_per_minute,
interval_ms,
burst = config.limits.rate_limit_burst,
"per-IP rate limiting enabled"
);
let layer_limiter = limiter_swap.clone();
let layer_trust_proxy = config.trust_proxy;
let layer_metrics = metrics_registry.clone();
protected.layer(axum::middleware::from_fn(move |req, next| {
let limiter = layer_limiter.load_full();
let metrics = layer_metrics.clone();
async move {
rate_limit::rate_limit_middleware(limiter, layer_trust_proxy, metrics, req, next)
.await
}
}))
} else {
protected
};
let shutdown_engine = state.engine.clone();
let request_id_layer = axum::middleware::from_fn(middleware::request_id_middleware);
let app = Router::new()
.route("/health", get(http::health))
.route("/ready", get(http::readiness))
.merge(protected)
.layer(DefaultBodyLimit::max(config.limits.body_limit_bytes))
.layer(origin_layer)
.layer(request_id_layer)
.with_state(state);
tracing::info!("gigastt server listening on http://{addr}");
tracing::info!(" WebSocket: ws://{addr}/v1/ws");
tracing::info!(
" REST API: http://{addr}/health, /ready, /v1/transcribe, /v1/transcribe/stream"
);
if config.origin_policy.allow_any {
tracing::warn!(
"CORS allow-any is ON: any cross-origin page can call this server. \
Only use with trusted callers."
);
} else if !config.origin_policy.allowed_origins.is_empty() {
tracing::info!(
"CORS allowlist (in addition to loopback): {:?}",
config.origin_policy.allowed_origins
);
}
let shutdown_drain_secs = config.limits.shutdown_drain_secs.max(1);
let shutdown_fut = {
let shutdown_root = shutdown_root.clone();
async move {
match shutdown {
Some(rx) => {
rx.await.ok();
}
None => {
tokio::signal::ctrl_c().await.ok();
}
}
tracing::info!("Shutting down server");
shutdown_root.cancel();
shutdown_engine.pool.close();
}
};
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_fut)
.await?;
tracker.close();
match tokio::time::timeout(
std::time::Duration::from_secs(shutdown_drain_secs),
tracker.wait(),
)
.await
{
Ok(()) => tracing::info!("Drain complete: all tracked WS/SSE tasks finished"),
Err(_) => tracing::warn!(
drain_secs = shutdown_drain_secs,
pending = tracker.len(),
"Drain window expired with tracked tasks still running — forcing exit"
),
}
Ok(())
}
#[cfg(test)]
mod tests {
    /// Cap applied to the requests-per-minute value before the division.
    const MAX_RPM: u64 = 60_000;

    /// Test-local interval formula: milliseconds per request for a given rpm,
    /// with rpm capped at `MAX_RPM` and the result floored at 1 ms.
    /// NOTE(review): presumably mirrors the production rate-limiter computation — confirm.
    fn interval_ms_for(rpm: u32) -> u64 {
        let capped = u64::from(rpm).min(MAX_RPM);
        (60_000u64 / capped).max(1)
    }

    #[test]
    fn test_rate_limit_interval_formula() {
        // (requests per minute, expected interval in milliseconds)
        let cases: [(u32, u64); 8] = [
            (1, 60_000),
            (10, 6_000),
            (30, 2_000),
            (59, 1_016),
            (60, 1_000),
            (600, 100),
            (60_000, 1),
            (120_000, 1),
        ];
        for (rpm, expected) in cases {
            assert_eq!(
                interval_ms_for(rpm),
                expected,
                "rpm={rpm} should map to interval_ms={expected}"
            );
        }
    }
}