use clap::{Parser, Subcommand};
use gigastt::server;
use gigastt::server::{OriginPolicy, RuntimeLimits, ServerConfig};
use gigastt_core::{inference, model};
use std::net::IpAddr;
use tracing_subscriber::EnvFilter;
// Top-level command-line interface of the `gigastt` binary, parsed by clap.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into help text, so adding them would change `--help` output.
#[derive(Parser)]
#[command(
name = "gigastt",
version,
about = "Local STT server powered by GigaAM v3"
)]
struct Cli {
// Log level spliced into the `gigastt=<level>` tracing directive built in `main`.
// `global = true` makes the flag available to every subcommand as well.
#[arg(long, global = true, default_value = "info")]
log_level: String,
// The subcommand selected by the user; dispatched on in `main`.
#[command(subcommand)]
command: Commands,
}
// Subcommands of the `gigastt` binary.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros turn
// `///` doc comments into help text, so adding them would change `--help` output.
#[derive(Subcommand)]
enum Commands {
// Prepare the model, load the inference engine and run the server.
Serve {
// Port placed into `ServerConfig::port`.
#[arg(short, long, default_value_t = 9876)]
port: u16,
// Host placed into `ServerConfig::host`. Checked by `ensure_bind_allowed`:
// anything that is not a loopback host needs an explicit opt-in.
#[arg(long, default_value = "127.0.0.1")]
host: String,
// Directory holding the model files; defaults to `model::default_model_dir()`.
#[arg(long, default_value_t = model::default_model_dir())]
model_dir: String,
// Pool size handed to `inference::Engine::load_with_pool_size`.
#[arg(long, default_value_t = 4)]
pool_size: usize,
// Explicit opt-in that lets `ensure_bind_allowed` accept a non-loopback host.
#[arg(long, default_value_t = false)]
bind_all: bool,
// Collected into `OriginPolicy::allowed_origins`; being a `Vec`, the flag can
// be given more than once.
#[arg(long = "allow-origin", value_name = "URL")]
allow_origin: Vec<String>,
// Stored as `OriginPolicy::allow_any`.
#[arg(long, default_value_t = false)]
cors_allow_any: bool,
// The `Option` fields below override the matching `RuntimeLimits` field when
// present (flag or environment variable); when absent, the value from the
// config file or `RuntimeLimits::default()` is kept.
#[arg(long, env = "GIGASTT_IDLE_TIMEOUT_SECS")]
idle_timeout_secs: Option<u64>,
#[arg(long, env = "GIGASTT_WS_FRAME_MAX_BYTES")]
ws_frame_max_bytes: Option<usize>,
#[arg(long, env = "GIGASTT_BODY_LIMIT_BYTES")]
body_limit_bytes: Option<usize>,
// `main` rejects a non-zero per-minute rate combined with a zero burst.
#[arg(long, env = "GIGASTT_RATE_LIMIT_PER_MINUTE")]
rate_limit_per_minute: Option<u32>,
#[arg(long, env = "GIGASTT_RATE_LIMIT_BURST")]
rate_limit_burst: Option<u32>,
// Stored as `ServerConfig::metrics_enabled`.
#[arg(long, env = "GIGASTT_METRICS", default_value_t = false)]
metrics: bool,
#[arg(long, env = "GIGASTT_MAX_SESSION_SECS")]
max_session_secs: Option<u64>,
#[arg(long, env = "GIGASTT_SHUTDOWN_DRAIN_SECS")]
shutdown_drain_secs: Option<u64>,
#[arg(long, env = "GIGASTT_POOL_CHECKOUT_TIMEOUT_SECS")]
pool_checkout_timeout_secs: Option<u64>,
// When set, `ensure_int8_encoder` does not create a missing INT8 encoder.
#[arg(long, env = "GIGASTT_SKIP_QUANTIZE", default_value_t = false)]
skip_quantize: bool,
// Stored as `ServerConfig::trust_proxy`; its effect is defined by the server
// module, not in this file.
#[arg(long, env = "GIGASTT_TRUST_PROXY", default_value_t = false)]
trust_proxy: bool,
// Optional path passed to `server::config::load_config_file`; the overrides
// above are applied on top of the limits it returns.
#[arg(long)]
config: Option<String>,
},
// Make sure the model files (and the INT8 encoder unless skipped) are present,
// without starting the server.
Download {
// Directory holding the model files; defaults to `model::default_model_dir()`.
#[arg(long, default_value_t = model::default_model_dir())]
model_dir: String,
// Only exists with the `diarization` feature: skips `model::ensure_speaker_model`.
#[cfg(feature = "diarization")]
#[arg(long, default_value_t = false)]
skip_diarization: bool,
// When set, `ensure_int8_encoder` does not create a missing INT8 encoder.
#[arg(long, env = "GIGASTT_SKIP_QUANTIZE", default_value_t = false)]
skip_quantize: bool,
},
// Quantize the FP32 encoder into the INT8 encoder file.
Quantize {
// Directory holding the model files; defaults to `model::default_model_dir()`.
#[arg(long, default_value_t = model::default_model_dir())]
model_dir: String,
// Re-run quantization even when the INT8 file already exists.
#[arg(long)]
force: bool,
},
// Transcribe a single file and print the resulting text to stdout.
Transcribe {
// Positional argument: path of the file passed to `Engine::transcribe_file`.
file: String,
// Directory holding the model files; defaults to `model::default_model_dir()`.
#[arg(long, default_value_t = model::default_model_dir())]
model_dir: String,
},
}
fn log_rss() {
#[cfg(target_os = "linux")]
{
if let Ok(status) = std::fs::read_to_string("/proc/self/status")
&& let Some(line) = status.lines().find(|l| l.starts_with("VmRSS:"))
{
tracing::info!("{}", line.trim());
}
}
#[cfg(not(target_os = "linux"))]
{
if let Ok(output) = std::process::Command::new("ps")
.args(["-o", "rss=", "-p", &std::process::id().to_string()])
.output()
&& let Ok(rss) = String::from_utf8_lossy(&output.stdout)
.trim()
.parse::<u64>()
{
tracing::info!(rss_mb = rss / 1024, "memory_after_load");
}
}
}
/// Guards against accidentally exposing the server on a non-loopback address.
///
/// Loopback hosts (as decided by `is_loopback_host`) are always accepted.
/// Any other host is accepted only with an explicit opt-in, which is either
/// `bind_all_flag` (the `--bind-all` CLI flag) or the environment variable
/// `GIGASTT_ALLOW_BIND_ANY` set to `1`, `true` or `yes`. The value is trimmed
/// and compared case-insensitively, so `True` and `Yes` count as well.
///
/// A warning is logged whenever a non-loopback bind is allowed.
///
/// # Errors
///
/// Returns an error naming the offending host when it is not loopback and no
/// opt-in was given.
fn ensure_bind_allowed(host: &str, bind_all_flag: bool) -> anyhow::Result<()> {
    if is_loopback_host(host) {
        return Ok(());
    }

    // An unset (or non-UTF-8) variable counts as "not opted in".
    let env_opt_in = std::env::var("GIGASTT_ALLOW_BIND_ANY").is_ok_and(|raw| {
        let value = raw.trim();
        ["1", "true", "yes"]
            .iter()
            .any(|accepted| value.eq_ignore_ascii_case(accepted))
    });

    if bind_all_flag || env_opt_in {
        tracing::warn!(
            host = %host,
            "binding to non-loopback address — anyone on the network can reach this server"
        );
        return Ok(());
    }

    anyhow::bail!(
        "refusing to bind to '{host}': non-loopback addresses require \
         `--bind-all` (or env GIGASTT_ALLOW_BIND_ANY=1) to prevent accidental \
         public exposure of local transcription"
    )
}
/// Reports whether `host` designates the local machine only.
///
/// Surrounding whitespace is ignored. The name `localhost` (in any ASCII
/// case) and the literal `::1` are loopback. Otherwise any leading `[` and
/// trailing `]` characters are removed and the remainder is parsed as an IP
/// address, which must satisfy [`IpAddr::is_loopback`]. Everything else —
/// other host names, unparsable text, the empty string — is not loopback.
fn is_loopback_host(host: &str) -> bool {
    let candidate = host.trim();
    if candidate.eq_ignore_ascii_case("localhost") || candidate == "::1" {
        return true;
    }

    // Bracketed IPv6 literals such as `[::1]` are accepted by stripping the brackets.
    let unbracketed = candidate.trim_start_matches('[').trim_end_matches(']');
    matches!(unbracketed.parse::<IpAddr>(), Ok(addr) if addr.is_loopback())
}
/// Makes sure the INT8 encoder file exists in `model_dir`, creating it if needed.
///
/// * If `v3_e2e_rnnt_encoder_int8.onnx` is already there, nothing happens.
/// * If it is missing and `skip` is set, an informational message is logged
///   and nothing is created.
/// * Otherwise `v3_e2e_rnnt_encoder.onnx` is quantized into the INT8 file via
///   `gigastt_core::quantize::quantize_model`.
///
/// # Errors
///
/// Fails when quantization is required but the FP32 encoder is missing, or
/// when `quantize_model` itself reports an error.
fn ensure_int8_encoder(model_dir: &str, skip: bool) -> anyhow::Result<()> {
    let dir = std::path::Path::new(model_dir);
    let quantized = dir.join("v3_e2e_rnnt_encoder_int8.onnx");

    match (quantized.exists(), skip) {
        // Already quantized: nothing to do, regardless of `skip`.
        (true, _) => return Ok(()),
        (false, true) => {
            tracing::info!(
                "Skipping INT8 quantization (--skip-quantize). Engine will load the FP32 encoder."
            );
            return Ok(());
        }
        (false, false) => {}
    }

    let fp32 = dir.join("v3_e2e_rnnt_encoder.onnx");
    if !fp32.exists() {
        anyhow::bail!(
            "Cannot quantize: FP32 encoder not found at {}",
            fp32.display()
        );
    }

    tracing::info!("Quantizing encoder to INT8 (~2 min, one-time)…");
    gigastt_core::quantize::quantize_model(&fp32, &quantized)?;
    tracing::info!("INT8 encoder saved to {}", quantized.display());
    Ok(())
}
/// Entry point: parses the command line, installs the tracing subscriber and
/// dispatches to the selected subcommand.
///
/// The log filter is taken from the environment (`EnvFilter::from_default_env`)
/// with an extra `gigastt=<log_level>` directive from `--log-level`.
///
/// # Errors
///
/// Propagates any error from directive parsing, the bind guard, config loading
/// and validation, model preparation, quantization, engine loading, the server
/// or transcription.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let cli = Cli::parse();

    let directive = format!("gigastt={}", cli.log_level);
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive(directive.parse()?))
        .init();

    match cli.command {
        Commands::Serve {
            port,
            host,
            model_dir,
            pool_size,
            bind_all,
            allow_origin,
            cors_allow_any,
            idle_timeout_secs,
            ws_frame_max_bytes,
            body_limit_bytes,
            rate_limit_per_minute,
            rate_limit_burst,
            metrics,
            max_session_secs,
            shutdown_drain_secs,
            pool_checkout_timeout_secs,
            skip_quantize,
            trust_proxy,
            config,
        } => {
            ensure_bind_allowed(&host, bind_all)?;

            // Resolve and validate the runtime limits before touching the model:
            // a broken config file or an invalid flag combination should fail
            // right away, not after model preparation, the one-time quantization
            // and engine loading.
            let mut limits = match config.as_deref() {
                Some(path) => server::config::load_config_file(std::path::Path::new(path))?,
                None => RuntimeLimits::default(),
            };

            // Command-line / environment values take precedence over the file.
            if let Some(v) = idle_timeout_secs {
                limits.idle_timeout_secs = v;
            }
            if let Some(v) = ws_frame_max_bytes {
                limits.ws_frame_max_bytes = v;
            }
            if let Some(v) = body_limit_bytes {
                limits.body_limit_bytes = v;
            }
            if let Some(v) = rate_limit_per_minute {
                limits.rate_limit_per_minute = v;
            }
            if let Some(v) = rate_limit_burst {
                limits.rate_limit_burst = v;
            }
            if let Some(v) = max_session_secs {
                limits.max_session_secs = v;
            }
            if let Some(v) = shutdown_drain_secs {
                limits.shutdown_drain_secs = v;
            }
            if let Some(v) = pool_checkout_timeout_secs {
                limits.pool_checkout_timeout_secs = v;
            }

            // Checked on the merged values so a config-file rate combined with a
            // zero burst is caught too.
            if limits.rate_limit_per_minute > 0 && limits.rate_limit_burst == 0 {
                anyhow::bail!(
                    "--rate-limit-burst must be > 0 when --rate-limit-per-minute is enabled"
                );
            }

            model::ensure_model(&model_dir).await?;
            ensure_int8_encoder(&model_dir, skip_quantize)?;
            let engine = inference::Engine::load_with_pool_size(&model_dir, pool_size)?;
            log_rss();

            let config = ServerConfig {
                port,
                host,
                origin_policy: OriginPolicy {
                    allow_any: cors_allow_any,
                    allowed_origins: allow_origin,
                },
                limits,
                metrics_enabled: metrics,
                trust_proxy,
                config_path: config.map(std::path::PathBuf::from),
            };
            // NOTE(review): the meaning of the trailing `None` is defined by
            // `server::run_with_config`, not visible in this file.
            server::run_with_config(engine, config, None).await?;
        }
        Commands::Download {
            model_dir,
            #[cfg(feature = "diarization")]
            skip_diarization,
            skip_quantize,
        } => {
            model::ensure_model(&model_dir).await?;
            #[cfg(feature = "diarization")]
            {
                if !skip_diarization {
                    model::ensure_speaker_model(&model_dir).await?;
                }
            }
            ensure_int8_encoder(&model_dir, skip_quantize)?;
            tracing::info!("Model ready at {model_dir}");
        }
        Commands::Quantize { model_dir, force } => {
            model::ensure_model(&model_dir).await?;
            let input = std::path::Path::new(&model_dir).join("v3_e2e_rnnt_encoder.onnx");
            let output = std::path::Path::new(&model_dir).join("v3_e2e_rnnt_encoder_int8.onnx");
            // Without `--force` an existing INT8 file is left untouched.
            if output.exists() && !force {
                tracing::info!("INT8 model already exists: {}", output.display());
                tracing::info!("Use --force to re-quantize.");
                return Ok(());
            }
            gigastt_core::quantize::quantize_model(&input, &output)?;
            tracing::info!("Quantized model saved to {}", output.display());
        }
        Commands::Transcribe { file, model_dir } => {
            model::ensure_model(&model_dir).await?;
            // A one-shot transcription uses a pool of size one.
            let engine = inference::Engine::load_with_pool_size(&model_dir, 1)?;
            log_rss();
            let mut guard = engine.pool.checkout().await?;
            let result = engine.transcribe_file(&file, &mut guard);
            // Release the pool guard before inspecting the result.
            drop(guard);
            println!("{}", result?.text);
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable that `ensure_bind_allowed` consults as an opt-in.
    const BIND_ANY_ENV: &str = "GIGASTT_ALLOW_BIND_ANY";

    /// Loopback names and addresses are recognised; everything else is not.
    #[test]
    fn test_is_loopback_host_recognises_common_forms() {
        // Loopback forms.
        assert!(is_loopback_host("127.0.0.1"));
        assert!(is_loopback_host("localhost"));
        assert!(is_loopback_host("::1"));
        assert!(is_loopback_host("[::1]"));
        assert!(is_loopback_host("127.0.0.2"));

        // Non-loopback forms.
        assert!(!is_loopback_host("0.0.0.0"));
        assert!(!is_loopback_host("192.168.1.10"));
        assert!(!is_loopback_host("example.com"));
    }

    /// Loopback hosts pass the guard without any opt-in.
    #[test]
    fn test_ensure_bind_allowed_loopback_ok() {
        ensure_bind_allowed("127.0.0.1", false).expect("loopback must be allowed");
        ensure_bind_allowed("localhost", false).expect("localhost must be allowed");
    }

    /// Without the flag and without the environment opt-in, a wildcard bind is refused.
    #[test]
    fn test_ensure_bind_allowed_non_loopback_requires_flag() {
        let saved = std::env::var(BIND_ANY_ENV).ok();

        // SAFETY: NOTE(review) — the test harness may run tests on several
        // threads; this relies on nothing outside Rust's std reading the
        // process environment while it is being modified. Confirm.
        unsafe {
            std::env::remove_var(BIND_ANY_ENV);
        }

        let outcome = ensure_bind_allowed("0.0.0.0", false);

        // Put the caller's environment back before asserting, so a failing
        // assertion does not leave the variable removed.
        if let Some(value) = saved {
            // SAFETY: same consideration as for the removal above.
            unsafe {
                std::env::set_var(BIND_ANY_ENV, value);
            }
        }

        assert!(
            outcome.is_err(),
            "0.0.0.0 without --bind-all must be rejected"
        );
    }

    /// The explicit flag is sufficient to allow a wildcard bind.
    #[test]
    fn test_ensure_bind_allowed_explicit_flag_ok() {
        ensure_bind_allowed("0.0.0.0", true).expect("explicit --bind-all must pass");
    }
}