use anyhow::Context;
use clap::{Parser, Subcommand};
use gigastt::server;
use gigastt::server::{OriginPolicy, RuntimeLimits, ServerConfig};
use gigastt_core::export::{ExportFormat, RenderOpts};
use gigastt_core::model::ModelVariant;
use gigastt_core::{inference, model};
use std::net::IpAddr;
use std::str::FromStr;
use tracing_subscriber::EnvFilter;
// Top-level command-line interface.
//
// Plain `//` comments are used deliberately: clap's derive turns `///` doc
// comments into `--help` text, which would change the CLI's output.
#[derive(Parser)]
#[command(name = "gigastt", version, about = "Local STT server powered by GigaAM v3")]
struct Cli {
    // `global = true` makes `--log-level` accepted on every subcommand too.
    #[arg(global = true, long, default_value = "info")]
    log_level: String,
    #[command(subcommand)]
    command: Commands,
}
// Subcommands of the `gigastt` binary.
//
// Plain `//` comments are used throughout on purpose: clap's derive turns
// `///` doc comments into `--help` text, so they would change the CLI output.
//
// `Serve` carries far more fields than the other variants. The enum is only
// constructed once (by `Cli::parse` in `main`), so boxing the large variant
// would buy nothing; hence the clippy allow.
#[allow(clippy::large_enum_variant)]
#[derive(Subcommand)]
enum Commands {
    // Prepare models, build the engine and run the server
    // (`server::run_with_config`).
    Serve {
        #[arg(short, long, default_value_t = 9876)]
        port: u16,
        // Non-loopback hosts are refused unless `--bind-all` or
        // GIGASTT_ALLOW_BIND_ANY opts in (see `ensure_bind_allowed`).
        #[arg(long, default_value = "127.0.0.1")]
        host: String,
        #[arg(long, default_value_t = model::default_model_dir())]
        model_dir: String,
        // `None` is forwarded unchanged to `model::ensure_model_variant`.
        #[arg(
            long,
            env = "GIGASTT_MODEL_VARIANT",
            value_parser = parse_model_variant
        )]
        model_variant: Option<ModelVariant>,
        // on / off / auto (plus boolean aliases); `auto` enables punctuation
        // only for `ModelVariant::Rnnt` (see `resolve_punctuation`).
        #[arg(
            long,
            env = "GIGASTT_PUNCTUATION",
            default_value = "auto",
            value_parser = parse_punctuation_mode
        )]
        punctuation: PunctuationMode,
        #[arg(
            long,
            env = "GIGASTT_PUNCT_MODEL_DIR",
            default_value_t = model::default_punct_model_dir()
        )]
        punct_model_dir: String,
        // on / off / auto; `auto` enables ITN only for `ModelVariant::Rnnt`
        // (see `resolve_itn`).
        #[arg(
            long,
            env = "GIGASTT_ITN",
            default_value = "auto",
            value_parser = parse_itn_mode
        )]
        itn: ItnMode,
        // Optional hotwords file, parsed by `parse_hotwords_file`.
        #[arg(long, env = "GIGASTT_HOTWORDS_FILE")]
        hotwords_file: Option<String>,
        // Also append `gigastt_core::lexicon::default_hotword_pairs()`.
        #[arg(long, env = "GIGASTT_HOTWORDS_DEFAULT", default_value_t = false)]
        hotwords_default: bool,
        // Falls back to `DEFAULT_HOTWORDS_BOOST` when unset.
        #[arg(long, env = "GIGASTT_HOTWORDS_BOOST")]
        hotwords_boost: Option<f32>,
        #[arg(long, env = "GIGASTT_VAD", default_value_t = false)]
        vad: bool,
        // Clamped to [0, 1] by `build_vad_config`.
        #[arg(long, env = "GIGASTT_VAD_THRESHOLD")]
        vad_threshold: Option<f32>,
        #[arg(long, env = "GIGASTT_VAD_MIN_SILENCE_MS")]
        vad_min_silence_ms: Option<u32>,
        #[arg(long, env = "GIGASTT_VAD_MODEL_DIR", default_value_t = model::default_vad_model_dir())]
        vad_model_dir: String,
        // The four values below are passed straight to
        // `inference::Engine::load_with_pools_threads`.
        #[arg(long, default_value_t = 2)]
        pool_size: usize,
        #[arg(long, env = "GIGASTT_POOL_MIN_SIZE", default_value_t = 1)]
        pool_min_size: usize,
        #[arg(long, env = "GIGASTT_BATCH_POOL_SIZE", default_value_t = 0)]
        batch_pool_size: usize,
        #[arg(long, env = "GIGASTT_ENCODER_INTRA_THREADS", default_value_t = 1)]
        encoder_intra_threads: usize,
        // Explicit opt-in for non-loopback `--host` / `--metrics-listen`.
        #[arg(long, default_value_t = false)]
        bind_all: bool,
        // Repeatable; becomes `OriginPolicy::allowed_origins`.
        #[arg(long = "allow-origin", value_name = "URL")]
        allow_origin: Vec<String>,
        // Becomes `OriginPolicy::allow_any`.
        #[arg(long, default_value_t = false)]
        cors_allow_any: bool,
        // Runtime-limit overrides: `None` keeps the value from `--config`
        // (or `RuntimeLimits::default()`); see `build_limits`.
        #[arg(long, env = "GIGASTT_IDLE_TIMEOUT_SECS")]
        idle_timeout_secs: Option<u64>,
        #[arg(long, env = "GIGASTT_WS_FRAME_MAX_BYTES")]
        ws_frame_max_bytes: Option<usize>,
        #[arg(long, env = "GIGASTT_BODY_LIMIT_BYTES")]
        body_limit_bytes: Option<usize>,
        #[arg(long, env = "GIGASTT_RATE_LIMIT_PER_MINUTE")]
        rate_limit_per_minute: Option<u32>,
        #[arg(long, env = "GIGASTT_RATE_LIMIT_BURST")]
        rate_limit_burst: Option<u32>,
        #[arg(long, env = "GIGASTT_METRICS", default_value_t = false)]
        metrics: bool,
        // Defaults to `server::config::default_metrics_listen()`; only
        // bind-checked when `--metrics` is set.
        #[arg(long, env = "GIGASTT_METRICS_LISTEN")]
        metrics_listen: Option<std::net::SocketAddr>,
        #[arg(long, env = "GIGASTT_MAX_SESSION_SECS")]
        max_session_secs: Option<u64>,
        #[arg(long, env = "GIGASTT_SHUTDOWN_DRAIN_SECS")]
        shutdown_drain_secs: Option<u64>,
        #[arg(long, env = "GIGASTT_POOL_CHECKOUT_TIMEOUT_SECS")]
        pool_checkout_timeout_secs: Option<u64>,
        #[arg(long, env = "GIGASTT_INFERENCE_TIMEOUT_SECS")]
        inference_timeout_secs: Option<u64>,
        // Do not create the INT8 encoder if it is missing
        // (see `ensure_int8_encoder`).
        #[arg(long, env = "GIGASTT_SKIP_QUANTIZE", default_value_t = false)]
        skip_quantize: bool,
        #[arg(long, env = "GIGASTT_TRUST_PROXY", default_value_t = false)]
        trust_proxy: bool,
        // Config file: loaded by `build_limits` and also stored as
        // `ServerConfig::config_path`.
        #[arg(long)]
        config: Option<String>,
    },
    // Fetch the model files (and, with the `diarization` feature, the
    // speaker model), then ensure the INT8 encoder exists.
    Download {
        #[arg(long, default_value_t = model::default_model_dir())]
        model_dir: String,
        #[arg(
            long,
            env = "GIGASTT_MODEL_VARIANT",
            default_value = "rnnt",
            value_parser = parse_model_variant
        )]
        model_variant: ModelVariant,
        #[cfg(feature = "diarization")]
        #[arg(long, default_value_t = false)]
        skip_diarization: bool,
        #[arg(long, env = "GIGASTT_SKIP_QUANTIZE", default_value_t = false)]
        skip_quantize: bool,
    },
    // Quantize the FP32 encoder to INT8; `--force` re-quantizes even when the
    // INT8 file already exists.
    Quantize {
        #[arg(long, default_value_t = model::default_model_dir())]
        model_dir: String,
        #[arg(long)]
        force: bool,
    },
    // Transcribe a single audio file and render it in `--format`.
    Transcribe {
        file: String,
        #[arg(long, default_value_t = model::default_model_dir())]
        model_dir: String,
        #[arg(
            long,
            env = "GIGASTT_MODEL_VARIANT",
            value_parser = parse_model_variant
        )]
        model_variant: Option<ModelVariant>,
        #[arg(
            long,
            env = "GIGASTT_PUNCTUATION",
            default_value = "auto",
            value_parser = parse_punctuation_mode
        )]
        punctuation: PunctuationMode,
        #[arg(
            long,
            env = "GIGASTT_PUNCT_MODEL_DIR",
            default_value_t = model::default_punct_model_dir()
        )]
        punct_model_dir: String,
        #[arg(
            long,
            env = "GIGASTT_ITN",
            default_value = "auto",
            value_parser = parse_itn_mode
        )]
        itn: ItnMode,
        #[arg(long, env = "GIGASTT_HOTWORDS_FILE")]
        hotwords_file: Option<String>,
        #[arg(long, env = "GIGASTT_HOTWORDS_DEFAULT", default_value_t = false)]
        hotwords_default: bool,
        #[arg(long, env = "GIGASTT_HOTWORDS_BOOST")]
        hotwords_boost: Option<f32>,
        #[arg(long, env = "GIGASTT_VAD", default_value_t = false)]
        vad: bool,
        #[arg(long, env = "GIGASTT_VAD_THRESHOLD")]
        vad_threshold: Option<f32>,
        #[arg(long, env = "GIGASTT_VAD_MIN_SILENCE_MS")]
        vad_min_silence_ms: Option<u32>,
        #[arg(long, env = "GIGASTT_VAD_MODEL_DIR", default_value_t = model::default_vad_model_dir())]
        vad_model_dir: String,
        #[arg(long, env = "GIGASTT_ENCODER_INTRA_THREADS", default_value_t = 1)]
        encoder_intra_threads: usize,
        // Parsed into `ExportFormat` via `FromStr`.
        #[arg(short, long, env = "GIGASTT_FORMAT", default_value = "txt")]
        format: String,
        // Output file; when absent the rendered text goes to stdout.
        #[arg(short, long, env = "GIGASTT_OUTPUT")]
        output: Option<String>,
        // Subtitle layout; `main` falls back to 80 chars / 14 words per line.
        #[arg(long, env = "GIGASTT_MAX_CHARS_PER_LINE")]
        max_chars_per_line: Option<usize>,
        #[arg(long, env = "GIGASTT_MAX_WORDS_PER_LINE")]
        max_words_per_line: Option<usize>,
        #[arg(long, env = "GIGASTT_WORD_TIMESTAMPS", default_value_t = false)]
        word_timestamps: bool,
    },
}
/// Builds the effective [`RuntimeLimits`].
///
/// The base comes from the `--config` file when `config_path` is given,
/// otherwise from `RuntimeLimits::default()`. Every override that is `Some(..)`
/// (CLI flag or env var) then replaces the corresponding base value.
///
/// # Errors
/// Fails if the config file cannot be loaded, or if rate limiting is enabled
/// (`rate_limit_per_minute > 0`) while the resulting burst is zero.
// One parameter per CLI override keeps the call site in `main` explicit.
#[allow(clippy::too_many_arguments)]
fn build_limits(
    config_path: Option<&str>,
    idle_timeout_secs: Option<u64>,
    ws_frame_max_bytes: Option<usize>,
    body_limit_bytes: Option<usize>,
    rate_limit_per_minute: Option<u32>,
    rate_limit_burst: Option<u32>,
    max_session_secs: Option<u64>,
    shutdown_drain_secs: Option<u64>,
    pool_checkout_timeout_secs: Option<u64>,
    inference_timeout_secs: Option<u64>,
) -> anyhow::Result<RuntimeLimits> {
    let mut limits = match config_path {
        Some(path) => server::config::load_config_file(std::path::Path::new(path))?,
        None => RuntimeLimits::default(),
    };

    // Explicit overrides win; `None` keeps the file/default value.
    limits.idle_timeout_secs = idle_timeout_secs.unwrap_or(limits.idle_timeout_secs);
    limits.ws_frame_max_bytes = ws_frame_max_bytes.unwrap_or(limits.ws_frame_max_bytes);
    limits.body_limit_bytes = body_limit_bytes.unwrap_or(limits.body_limit_bytes);
    limits.rate_limit_per_minute =
        rate_limit_per_minute.unwrap_or(limits.rate_limit_per_minute);
    limits.rate_limit_burst = rate_limit_burst.unwrap_or(limits.rate_limit_burst);
    limits.max_session_secs = max_session_secs.unwrap_or(limits.max_session_secs);
    limits.shutdown_drain_secs = shutdown_drain_secs.unwrap_or(limits.shutdown_drain_secs);
    limits.pool_checkout_timeout_secs =
        pool_checkout_timeout_secs.unwrap_or(limits.pool_checkout_timeout_secs);
    limits.inference_timeout_secs =
        inference_timeout_secs.unwrap_or(limits.inference_timeout_secs);

    let rate_limiting_enabled = limits.rate_limit_per_minute > 0;
    if rate_limiting_enabled && limits.rate_limit_burst == 0 {
        anyhow::bail!("--rate-limit-burst must be > 0 when --rate-limit-per-minute is enabled");
    }
    Ok(limits)
}
/// Assembles the [`ServerConfig`] handed to `server::run_with_config`.
///
/// `allow_origin` / `cors_allow_any` become the [`OriginPolicy`], `metrics`
/// maps to `metrics_enabled`, and the optional `config` path string is
/// converted to a `PathBuf`.
// Mirrors the CLI flags one-to-one, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn build_server_config(
    port: u16,
    host: String,
    allow_origin: Vec<String>,
    cors_allow_any: bool,
    limits: RuntimeLimits,
    metrics: bool,
    metrics_listen: std::net::SocketAddr,
    trust_proxy: bool,
    config: Option<String>,
) -> ServerConfig {
    let origin_policy = OriginPolicy {
        allowed_origins: allow_origin,
        allow_any: cors_allow_any,
    };
    let config_path = config.map(std::path::PathBuf::from);
    ServerConfig {
        port,
        host,
        origin_policy,
        limits,
        metrics_enabled: metrics,
        metrics_listen,
        trust_proxy,
        config_path,
    }
}
fn log_rss() {
#[cfg(target_os = "linux")]
{
if let Ok(status) = std::fs::read_to_string("/proc/self/status")
&& let Some(line) = status.lines().find(|l| l.starts_with("VmRSS:"))
{
tracing::info!("{}", line.trim());
}
}
#[cfg(not(target_os = "linux"))]
{
if let Ok(output) = std::process::Command::new("ps")
.args(["-o", "rss=", "-p", &std::process::id().to_string()])
.output()
&& let Ok(rss) = String::from_utf8_lossy(&output.stdout)
.trim()
.parse::<u64>()
{
tracing::info!(rss_mb = rss / 1024, "memory_after_load");
}
}
}
/// Guards against accidentally exposing the server on the network.
///
/// Loopback hosts (see `is_loopback_host`) are always allowed. Any other
/// host is allowed only when `bind_all_flag` (`--bind-all`) is set or the
/// `GIGASTT_ALLOW_BIND_ANY` env var is `1`, `true` or `yes` (ASCII
/// case-insensitive, surrounding whitespace ignored); such a bind is logged
/// as a warning.
///
/// # Errors
/// Returns an error for a non-loopback `host` without either opt-in.
fn ensure_bind_allowed(host: &str, bind_all_flag: bool) -> anyhow::Result<()> {
    if is_loopback_host(host) {
        return Ok(());
    }
    // Case-insensitive, consistent with the on/off/auto CLI parsers, so that
    // e.g. `True` / `Yes` are honoured as well as `true` / `TRUE`.
    let env_opt_in = std::env::var("GIGASTT_ALLOW_BIND_ANY")
        .map(|raw| {
            let value = raw.trim();
            ["1", "true", "yes"]
                .iter()
                .any(|accepted| value.eq_ignore_ascii_case(accepted))
        })
        .unwrap_or(false);
    if bind_all_flag || env_opt_in {
        tracing::warn!(
            host = %host,
            "binding to non-loopback address — anyone on the network can reach this server"
        );
        return Ok(());
    }
    anyhow::bail!(
        "refusing to bind to '{host}': non-loopback addresses require \
         `--bind-all` (or env GIGASTT_ALLOW_BIND_ANY=1) to prevent accidental \
         public exposure of local transcription"
    )
}
/// Returns `true` if `host` names a loopback address.
///
/// Accepts `localhost` (case-insensitive, surrounding whitespace ignored),
/// any IPv4/IPv6 loopback literal, a bracketed IPv6 literal such as `[::1]`,
/// and IPv4-mapped IPv6 loopback such as `::ffff:127.0.0.1`. Anything else,
/// including other hostnames and unbalanced brackets like `[::1`, is treated
/// as non-loopback (the conservative answer for a bind guard).
fn is_loopback_host(host: &str) -> bool {
    let lowered = host.trim().to_ascii_lowercase();
    if lowered == "localhost" {
        return true;
    }
    // Strip only a matched `[...]` pair (URL authority form of IPv6).
    let literal = lowered
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(&lowered);
    // `to_canonical` folds `::ffff:a.b.c.d` into its IPv4 form; plain
    // `Ipv6Addr::is_loopback` is only true for `::1`.
    literal
        .parse::<IpAddr>()
        .map(|ip| ip.to_canonical().is_loopback())
        .unwrap_or(false)
}
fn parse_model_variant(s: &str) -> Result<ModelVariant, String> {
s.parse()
}
/// Punctuation-restoration setting selected on the command line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PunctuationMode {
    /// Always enable punctuation restoration.
    On,
    /// Never enable punctuation restoration.
    Off,
    /// Decide from the model variant (see `resolve_punctuation`).
    Auto,
}

impl std::str::FromStr for PunctuationMode {
    type Err = String;

    /// Accepts `on`/`off`/`auto` plus the boolean aliases
    /// `true`/`1`/`yes` and `false`/`0`/`no`, ignoring ASCII case and
    /// surrounding whitespace. The error echoes the normalized input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.trim().to_ascii_lowercase();
        let key = normalized.as_str();
        if matches!(key, "on" | "true" | "1" | "yes") {
            Ok(Self::On)
        } else if matches!(key, "off" | "false" | "0" | "no") {
            Ok(Self::Off)
        } else if key == "auto" {
            Ok(Self::Auto)
        } else {
            Err(format!(
                "unknown punctuation mode '{key}' (expected 'on', 'off', or 'auto')"
            ))
        }
    }
}

/// clap `value_parser` adapter for [`PunctuationMode`].
fn parse_punctuation_mode(s: &str) -> Result<PunctuationMode, String> {
    s.parse::<PunctuationMode>()
}
/// Inverse-text-normalization setting selected on the command line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ItnMode {
    /// Always enable ITN.
    On,
    /// Never enable ITN.
    Off,
    /// Decide from the model variant (see `resolve_itn`).
    Auto,
}

impl std::str::FromStr for ItnMode {
    type Err = String;

    /// Table-driven parse of `on`/`off`/`auto` and their boolean aliases,
    /// ignoring ASCII case and surrounding whitespace.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const ALIASES: [(&str, ItnMode); 9] = [
            ("on", ItnMode::On),
            ("true", ItnMode::On),
            ("1", ItnMode::On),
            ("yes", ItnMode::On),
            ("off", ItnMode::Off),
            ("false", ItnMode::Off),
            ("0", ItnMode::Off),
            ("no", ItnMode::Off),
            ("auto", ItnMode::Auto),
        ];
        let key = s.trim().to_ascii_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == key.as_str())
            .map(|&(_, mode)| mode)
            .ok_or_else(|| {
                format!("unknown ITN mode '{key}' (expected 'on', 'off', or 'auto')")
            })
    }
}

/// clap `value_parser` adapter for [`ItnMode`].
fn parse_itn_mode(s: &str) -> Result<ItnMode, String> {
    s.parse::<ItnMode>()
}
/// Whether ITN should be applied: forced by `On`/`Off`, and under `Auto`
/// enabled only for `ModelVariant::Rnnt`.
fn resolve_itn(mode: ItnMode, variant: ModelVariant) -> bool {
    if mode == ItnMode::Auto {
        return variant == ModelVariant::Rnnt;
    }
    mode == ItnMode::On
}
/// Whether punctuation restoration should be used: forced by `On`/`Off`,
/// and under `Auto` enabled only for `ModelVariant::Rnnt`.
fn resolve_punctuation(mode: PunctuationMode, variant: ModelVariant) -> bool {
    let auto_enables = matches!(mode, PunctuationMode::Auto) && variant == ModelVariant::Rnnt;
    matches!(mode, PunctuationMode::On) || auto_enables
}
/// Loads the punctuation model from `punct_model_dir` when
/// `resolve_punctuation(mode, variant)` says it is wanted.
///
/// A load failure is logged as a warning and yields `None`, so transcription
/// continues without punctuation restoration instead of aborting.
fn maybe_load_punctuator(
    mode: PunctuationMode,
    punct_model_dir: &str,
    variant: ModelVariant,
) -> Option<gigastt_core::punctuation::Punctuator> {
    if !resolve_punctuation(mode, variant) {
        return None;
    }
    let dir = std::path::Path::new(punct_model_dir);
    gigastt_core::punctuation::Punctuator::load(dir)
        .inspect(|_| {
            tracing::info!("Punctuation restoration enabled (model dir: {punct_model_dir})")
        })
        .inspect_err(|e| {
            tracing::warn!(
                "Punctuation model unavailable at {punct_model_dir} ({e:#}); \
                 continuing without punctuation restoration"
            )
        })
        .ok()
}
async fn maybe_download_punct_model(
mode: PunctuationMode,
punct_model_dir: &str,
variant: ModelVariant,
) {
if !resolve_punctuation(mode, variant) {
return;
}
if let Err(e) = model::ensure_punct_model(punct_model_dir).await {
tracing::warn!(
"Punctuation model download failed for {punct_model_dir} ({e:#}); \
continuing without punctuation restoration"
);
}
}
fn build_vad_config(
threshold: Option<f32>,
min_silence_ms: Option<u32>,
) -> gigastt_core::vad::VadConfig {
let mut cfg = gigastt_core::vad::VadConfig::default();
if let Some(t) = threshold {
cfg.threshold = t.clamp(0.0, 1.0);
}
if let Some(ms) = min_silence_ms {
cfg.min_silence_ms = ms;
}
cfg
}
/// Loads the Silero VAD model (`VAD_MODEL_FILE` inside `vad_model_dir`) when
/// `enabled`. A load failure is logged as a warning and yields `None`.
fn maybe_load_vad(enabled: bool, vad_model_dir: &str) -> Option<gigastt_core::vad::SileroVad> {
    if !enabled {
        return None;
    }
    let model_path =
        std::path::Path::new(vad_model_dir).join(gigastt_core::vad::VAD_MODEL_FILE);
    let loaded = gigastt_core::vad::SileroVad::load(&model_path);
    match loaded {
        Err(e) => {
            tracing::warn!(
                "VAD model unavailable at {vad_model_dir} ({e:#}); continuing without VAD"
            );
            None
        }
        Ok(vad) => {
            tracing::info!("VAD enabled (model dir: {vad_model_dir})");
            Some(vad)
        }
    }
}
async fn maybe_download_vad_model(enabled: bool, vad_model_dir: &str) {
if !enabled {
return;
}
if let Err(e) = model::ensure_vad_model(vad_model_dir).await {
tracing::warn!(
"VAD model download failed for {vad_model_dir} ({e:#}); continuing without VAD"
);
}
}
/// Boost passed to `Engine::with_hotwords` when `--hotwords-boost` is not given.
const DEFAULT_HOTWORDS_BOOST: f32 = 5.0;
fn parse_hotwords_file(path: &str) -> anyhow::Result<Vec<(String, f32)>> {
let content = std::fs::read_to_string(path)
.with_context(|| format!("failed to read hotwords file: {path}"))?;
let mut pairs = Vec::new();
for line in content.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
let (phrase, weight) = match line.split_once('\t') {
Some((p, w)) => (p.trim(), w.trim().parse::<f32>().unwrap_or(1.0)),
None => (line, 1.0),
};
if !phrase.is_empty() {
pairs.push((phrase.to_string(), weight));
}
}
Ok(pairs)
}
/// Collects hotwords from the optional file, then (when `hotwords_default`)
/// the built-in default pack, in that order.
///
/// An unreadable file is logged and skipped rather than treated as fatal.
/// Returns `None` when no hotwords were collected at all.
fn resolve_hotwords(
    hotwords_file: Option<&str>,
    hotwords_default: bool,
) -> Option<Vec<(String, f32)>> {
    let mut collected: Vec<(String, f32)> = hotwords_file
        .and_then(|path| match parse_hotwords_file(path) {
            Ok(from_file) => Some(from_file),
            Err(e) => {
                tracing::warn!("{e:#}; continuing without file hotwords");
                None
            }
        })
        .unwrap_or_default();
    if hotwords_default {
        collected.extend(gigastt_core::lexicon::default_hotword_pairs());
    }
    (!collected.is_empty()).then_some(collected)
}
fn ensure_int8_encoder(variant: ModelVariant, model_dir: &str, skip: bool) -> anyhow::Result<()> {
let dir = std::path::Path::new(model_dir);
let int8_path = dir.join(variant.encoder_int8_file());
if int8_path.exists() {
return Ok(());
}
if skip {
tracing::info!(
"Skipping INT8 quantization (--skip-quantize). Engine will load the FP32 encoder."
);
return Ok(());
}
let input = dir.join(variant.encoder_file());
if !input.exists() {
anyhow::bail!(
"Cannot quantize: FP32 encoder not found at {}",
input.display()
);
}
tracing::info!("Quantizing encoder to INT8 (~2 min, one-time)…");
gigastt_core::quantize::quantize_model(&input, &int8_path)?;
tracing::info!("INT8 encoder saved to {}", int8_path.display());
Ok(())
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let cli = Cli::parse();
let directive = format!("gigastt={}", cli.log_level);
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env().add_directive(directive.parse()?))
.init();
match cli.command {
Commands::Serve {
port,
host,
model_dir,
model_variant,
punctuation,
punct_model_dir,
itn,
hotwords_file,
hotwords_default,
hotwords_boost,
vad,
vad_threshold,
vad_min_silence_ms,
vad_model_dir,
pool_size,
pool_min_size,
batch_pool_size,
encoder_intra_threads,
bind_all,
allow_origin,
cors_allow_any,
idle_timeout_secs,
ws_frame_max_bytes,
body_limit_bytes,
rate_limit_per_minute,
rate_limit_burst,
metrics,
metrics_listen,
max_session_secs,
shutdown_drain_secs,
pool_checkout_timeout_secs,
inference_timeout_secs,
skip_quantize,
trust_proxy,
config,
} => {
ensure_bind_allowed(&host, bind_all)?;
let resolved = model::ensure_model_variant(model_variant, &model_dir).await?;
ensure_int8_encoder(resolved, &model_dir, skip_quantize)?;
maybe_download_punct_model(punctuation, &punct_model_dir, resolved).await;
maybe_download_vad_model(vad, &vad_model_dir).await;
let punctuator = maybe_load_punctuator(punctuation, &punct_model_dir, resolved);
let hotwords = resolve_hotwords(hotwords_file.as_deref(), hotwords_default);
let mut engine = inference::Engine::load_with_pools_threads(
&model_dir,
pool_size,
pool_min_size,
batch_pool_size,
encoder_intra_threads,
)?
.with_punctuator(punctuator)
.with_itn(resolve_itn(itn, resolved))
.with_vad(
maybe_load_vad(vad, &vad_model_dir),
build_vad_config(vad_threshold, vad_min_silence_ms),
);
if let Some(pairs) = hotwords {
engine =
engine.with_hotwords(&pairs, hotwords_boost.unwrap_or(DEFAULT_HOTWORDS_BOOST));
}
log_rss();
let limits = build_limits(
config.as_deref(),
idle_timeout_secs,
ws_frame_max_bytes,
body_limit_bytes,
rate_limit_per_minute,
rate_limit_burst,
max_session_secs,
shutdown_drain_secs,
pool_checkout_timeout_secs,
inference_timeout_secs,
)?;
let metrics_listen =
metrics_listen.unwrap_or_else(server::config::default_metrics_listen);
if metrics {
ensure_bind_allowed(&metrics_listen.ip().to_string(), bind_all)?;
}
let config = build_server_config(
port,
host,
allow_origin,
cors_allow_any,
limits,
metrics,
metrics_listen,
trust_proxy,
config,
);
server::run_with_config(engine, config, None).await?;
}
Commands::Download {
model_dir,
model_variant,
#[cfg(feature = "diarization")]
skip_diarization,
skip_quantize,
} => {
let requested = Some(model_variant);
let resolved = model::ensure_model_variant(requested, &model_dir).await?;
#[cfg(feature = "diarization")]
{
if !skip_diarization {
model::ensure_speaker_model(&model_dir).await?;
}
}
ensure_int8_encoder(resolved, &model_dir, skip_quantize)?;
tracing::info!("Model ready at {model_dir}");
}
Commands::Quantize { model_dir, force } => {
let dir = std::path::Path::new(&model_dir);
let resolved = model::ensure_model_variant(None, &model_dir).await?;
let input = dir.join(resolved.encoder_file());
let output = dir.join(resolved.encoder_int8_file());
if output.exists() && !force {
tracing::info!("INT8 model already exists: {}", output.display());
tracing::info!("Use --force to re-quantize.");
return Ok(());
}
gigastt_core::quantize::quantize_model(&input, &output)?;
tracing::info!("Quantized model saved to {}", output.display());
}
Commands::Transcribe {
file,
model_dir,
model_variant,
punctuation,
punct_model_dir,
itn,
hotwords_file,
hotwords_default,
hotwords_boost,
vad,
vad_threshold,
vad_min_silence_ms,
vad_model_dir,
encoder_intra_threads,
format,
output,
max_chars_per_line,
max_words_per_line,
word_timestamps,
} => {
let resolved = model::ensure_model_variant(model_variant, &model_dir).await?;
maybe_download_punct_model(punctuation, &punct_model_dir, resolved).await;
maybe_download_vad_model(vad, &vad_model_dir).await;
let punctuator = maybe_load_punctuator(punctuation, &punct_model_dir, resolved);
let hotwords = resolve_hotwords(hotwords_file.as_deref(), hotwords_default);
let mut engine = inference::Engine::load_with_pools_threads(
&model_dir,
1,
1,
0,
encoder_intra_threads,
)?
.with_punctuator(punctuator)
.with_itn(resolve_itn(itn, resolved))
.with_vad(
maybe_load_vad(vad, &vad_model_dir),
build_vad_config(vad_threshold, vad_min_silence_ms),
);
if let Some(pairs) = hotwords {
engine =
engine.with_hotwords(&pairs, hotwords_boost.unwrap_or(DEFAULT_HOTWORDS_BOOST));
}
log_rss();
let mut guard = engine.pool.checkout().await?;
let result = engine.transcribe_file(&file, &mut guard);
drop(guard);
let result = result?;
let format = ExportFormat::from_str(&format).map_err(|e| anyhow::anyhow!("{e}"))?;
let opts = RenderOpts {
max_chars_per_line: max_chars_per_line.unwrap_or(80),
max_words_per_line: max_words_per_line.unwrap_or(14),
include_word_timestamps: word_timestamps,
};
let rendered = format.render(&result, &opts);
match output {
Some(path) => {
std::fs::write(&path, rendered)
.with_context(|| format!("failed to write {path}"))?;
tracing::info!("Wrote {} export to {path}", format);
}
None => println!("{rendered}"),
}
}
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
#[test]
fn test_is_loopback_host_recognises_common_forms() {
assert!(is_loopback_host("127.0.0.1"));
assert!(is_loopback_host("localhost"));
assert!(is_loopback_host("::1"));
assert!(is_loopback_host("[::1]"));
assert!(is_loopback_host("127.0.0.2")); assert!(!is_loopback_host("0.0.0.0"));
assert!(!is_loopback_host("192.168.1.10"));
assert!(!is_loopback_host("example.com"));
}
#[test]
fn test_ensure_bind_allowed_loopback_ok() {
ensure_bind_allowed("127.0.0.1", false).expect("loopback must be allowed");
ensure_bind_allowed("localhost", false).expect("localhost must be allowed");
}
#[test]
fn test_ensure_bind_allowed_non_loopback_requires_flag() {
let _guard = ENV_LOCK.lock().unwrap();
let previous = std::env::var("GIGASTT_ALLOW_BIND_ANY").ok();
unsafe {
std::env::remove_var("GIGASTT_ALLOW_BIND_ANY");
}
let result = ensure_bind_allowed("0.0.0.0", false);
if let Some(v) = previous {
unsafe {
std::env::set_var("GIGASTT_ALLOW_BIND_ANY", v);
}
}
assert!(
result.is_err(),
"0.0.0.0 without --bind-all must be rejected"
);
}
#[test]
fn test_ensure_bind_allowed_explicit_flag_ok() {
ensure_bind_allowed("0.0.0.0", true).expect("explicit --bind-all must pass");
}
#[test]
fn test_cli_serve_parsing() {
let cli = Cli::parse_from(["gigastt", "serve", "--port", "1234", "--bind-all"]);
match cli.command {
Commands::Serve {
port,
bind_all,
metrics,
model_variant,
..
} => {
assert_eq!(port, 1234);
assert!(bind_all);
assert!(!metrics);
assert_eq!(model_variant, None);
}
_ => panic!("expected Serve"),
}
}
struct EnvRestore(&'static str, Option<String>);
impl Drop for EnvRestore {
fn drop(&mut self) {
match &self.1 {
Some(v) => unsafe { std::env::set_var(self.0, v) },
None => unsafe { std::env::remove_var(self.0) },
}
}
}
#[test]
fn test_cli_serve_encoder_intra_threads_default() {
let _guard = ENV_LOCK.lock().unwrap();
let _restore = EnvRestore(
"GIGASTT_ENCODER_INTRA_THREADS",
std::env::var("GIGASTT_ENCODER_INTRA_THREADS").ok(),
);
unsafe {
std::env::remove_var("GIGASTT_ENCODER_INTRA_THREADS");
}
let cli = Cli::parse_from(["gigastt", "serve"]);
match cli.command {
Commands::Serve {
encoder_intra_threads,
..
} => assert_eq!(encoder_intra_threads, 1),
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_serve_encoder_intra_threads_flag() {
let _guard = ENV_LOCK.lock().unwrap();
let _restore = EnvRestore(
"GIGASTT_ENCODER_INTRA_THREADS",
std::env::var("GIGASTT_ENCODER_INTRA_THREADS").ok(),
);
unsafe {
std::env::remove_var("GIGASTT_ENCODER_INTRA_THREADS");
}
let cli = Cli::parse_from(["gigastt", "serve", "--encoder-intra-threads", "4"]);
match cli.command {
Commands::Serve {
encoder_intra_threads,
..
} => assert_eq!(encoder_intra_threads, 4),
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_serve_encoder_intra_threads_env() {
let _guard = ENV_LOCK.lock().unwrap();
let _restore = EnvRestore(
"GIGASTT_ENCODER_INTRA_THREADS",
std::env::var("GIGASTT_ENCODER_INTRA_THREADS").ok(),
);
unsafe {
std::env::set_var("GIGASTT_ENCODER_INTRA_THREADS", "6");
}
let cli = Cli::parse_from(["gigastt", "serve"]);
match cli.command {
Commands::Serve {
encoder_intra_threads,
..
} => assert_eq!(encoder_intra_threads, 6),
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_transcribe_encoder_intra_threads_flag() {
let _guard = ENV_LOCK.lock().unwrap();
let _restore = EnvRestore(
"GIGASTT_ENCODER_INTRA_THREADS",
std::env::var("GIGASTT_ENCODER_INTRA_THREADS").ok(),
);
unsafe {
std::env::remove_var("GIGASTT_ENCODER_INTRA_THREADS");
}
let cli = Cli::parse_from([
"gigastt",
"transcribe",
"audio.wav",
"--encoder-intra-threads",
"3",
]);
match cli.command {
Commands::Transcribe {
encoder_intra_threads,
..
} => assert_eq!(encoder_intra_threads, 3),
_ => panic!("expected Transcribe"),
}
}
#[test]
fn test_cli_serve_model_variant_override() {
let cli = Cli::parse_from(["gigastt", "serve", "--model-variant", "e2e_rnnt"]);
match cli.command {
Commands::Serve { model_variant, .. } => {
assert_eq!(model_variant, Some(ModelVariant::E2eRnnt));
}
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_serve_model_variant_explicit_rnnt() {
let cli = Cli::parse_from(["gigastt", "serve", "--model-variant", "rnnt"]);
match cli.command {
Commands::Serve { model_variant, .. } => {
assert_eq!(model_variant, Some(ModelVariant::Rnnt));
}
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_download_parsing() {
let cli = Cli::parse_from(["gigastt", "download", "--model-dir", "/tmp/models"]);
match cli.command {
Commands::Download {
model_dir,
model_variant,
..
} => {
assert_eq!(model_dir, "/tmp/models");
assert_eq!(model_variant, ModelVariant::Rnnt);
}
_ => panic!("expected Download"),
}
}
#[test]
fn test_cli_download_model_variant_override() {
let cli = Cli::parse_from(["gigastt", "download", "--model-variant", "e2e_rnnt"]);
match cli.command {
Commands::Download { model_variant, .. } => {
assert_eq!(model_variant, ModelVariant::E2eRnnt);
}
_ => panic!("expected Download"),
}
}
#[test]
fn test_cli_quantize_parsing() {
let cli = Cli::parse_from(["gigastt", "quantize", "--force"]);
match cli.command {
Commands::Quantize { force, .. } => {
assert!(force);
}
_ => panic!("expected Quantize"),
}
}
#[test]
fn test_cli_transcribe_parsing() {
let cli = Cli::parse_from(["gigastt", "transcribe", "audio.wav"]);
match cli.command {
Commands::Transcribe {
file,
model_variant,
format,
output,
..
} => {
assert_eq!(file, "audio.wav");
assert_eq!(model_variant, None);
assert_eq!(format, "txt");
assert!(output.is_none());
}
_ => panic!("expected Transcribe"),
}
}
#[test]
fn test_cli_transcribe_format_and_output() {
let cli = Cli::parse_from([
"gigastt",
"transcribe",
"audio.wav",
"--format",
"srt",
"-o",
"out.srt",
]);
match cli.command {
Commands::Transcribe {
file,
format,
output,
..
} => {
assert_eq!(file, "audio.wav");
assert_eq!(format, "srt");
assert_eq!(output, Some("out.srt".to_string()));
}
_ => panic!("expected Transcribe"),
}
}
#[test]
fn test_cli_transcribe_subtitle_options() {
let cli = Cli::parse_from([
"gigastt",
"transcribe",
"audio.wav",
"--format",
"vtt",
"--max-chars-per-line",
"60",
"--max-words-per-line",
"10",
"--word-timestamps",
]);
match cli.command {
Commands::Transcribe {
format,
max_chars_per_line,
max_words_per_line,
word_timestamps,
..
} => {
assert_eq!(format, "vtt");
assert_eq!(max_chars_per_line, Some(60));
assert_eq!(max_words_per_line, Some(10));
assert!(word_timestamps);
}
_ => panic!("expected Transcribe"),
}
}
#[test]
fn test_cli_serve_rejects_unknown_model_variant() {
let res = Cli::try_parse_from(["gigastt", "serve", "--model-variant", "whisper"]);
assert!(res.is_err(), "unknown variant must be rejected by clap");
}
#[test]
fn test_punctuation_mode_from_str() {
use std::str::FromStr;
assert_eq!(
PunctuationMode::from_str("on").unwrap(),
PunctuationMode::On
);
assert_eq!(
PunctuationMode::from_str("OFF").unwrap(),
PunctuationMode::Off
);
assert_eq!(
PunctuationMode::from_str(" auto ").unwrap(),
PunctuationMode::Auto
);
assert!(PunctuationMode::from_str("maybe").is_err());
}
#[test]
fn test_cli_serve_punctuation_defaults_auto() {
let cli = Cli::parse_from(["gigastt", "serve"]);
match cli.command {
Commands::Serve {
punctuation,
punct_model_dir,
..
} => {
assert_eq!(punctuation, PunctuationMode::Auto);
assert!(punct_model_dir.contains("punct"));
}
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_serve_punctuation_override() {
let cli = Cli::parse_from([
"gigastt",
"serve",
"--punctuation",
"on",
"--punct-model-dir",
"/tmp/punct",
]);
match cli.command {
Commands::Serve {
punctuation,
punct_model_dir,
..
} => {
assert_eq!(punctuation, PunctuationMode::On);
assert_eq!(punct_model_dir, "/tmp/punct");
}
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_transcribe_punctuation_off() {
let cli = Cli::parse_from(["gigastt", "transcribe", "a.wav", "--punctuation", "off"]);
match cli.command {
Commands::Transcribe { punctuation, .. } => {
assert_eq!(punctuation, PunctuationMode::Off);
}
_ => panic!("expected Transcribe"),
}
}
#[test]
fn test_itn_mode_from_str() {
use std::str::FromStr;
assert_eq!(ItnMode::from_str("on").unwrap(), ItnMode::On);
assert_eq!(ItnMode::from_str("OFF").unwrap(), ItnMode::Off);
assert_eq!(ItnMode::from_str(" auto ").unwrap(), ItnMode::Auto);
assert!(ItnMode::from_str("maybe").is_err());
}
#[test]
fn test_resolve_itn_auto_per_variant() {
assert!(resolve_itn(ItnMode::Auto, ModelVariant::Rnnt));
assert!(!resolve_itn(ItnMode::Auto, ModelVariant::E2eRnnt));
assert!(resolve_itn(ItnMode::On, ModelVariant::E2eRnnt));
assert!(!resolve_itn(ItnMode::Off, ModelVariant::Rnnt));
}
#[test]
fn test_cli_serve_itn_defaults_auto() {
let cli = Cli::parse_from(["gigastt", "serve"]);
match cli.command {
Commands::Serve { itn, .. } => assert_eq!(itn, ItnMode::Auto),
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_transcribe_itn_override() {
let cli = Cli::parse_from(["gigastt", "transcribe", "a.wav", "--itn", "on"]);
match cli.command {
Commands::Transcribe { itn, .. } => assert_eq!(itn, ItnMode::On),
_ => panic!("expected Transcribe"),
}
}
#[test]
fn test_maybe_load_punctuator_off_skips_load() {
assert!(
maybe_load_punctuator(PunctuationMode::Off, "/nonexistent", ModelVariant::Rnnt)
.is_none()
);
}
#[test]
fn test_maybe_load_punctuator_auto_e2e_skips_load() {
assert!(
maybe_load_punctuator(PunctuationMode::Auto, "/nonexistent", ModelVariant::E2eRnnt)
.is_none()
);
}
#[test]
fn test_maybe_load_punctuator_missing_model_falls_back_to_none() {
let tmp = tempfile::tempdir().unwrap();
let missing = tmp.path().join("absent");
assert!(
maybe_load_punctuator(
PunctuationMode::On,
missing.to_str().unwrap(),
ModelVariant::Rnnt
)
.is_none()
);
}
#[test]
fn test_cli_serve_hotwords_flags() {
let cli = Cli::parse_from([
"gigastt",
"serve",
"--hotwords-file",
"/tmp/hw.txt",
"--hotwords-default",
"--hotwords-boost",
"8.5",
]);
match cli.command {
Commands::Serve {
hotwords_file,
hotwords_default,
hotwords_boost,
..
} => {
assert_eq!(hotwords_file, Some("/tmp/hw.txt".to_string()));
assert!(hotwords_default);
assert_eq!(hotwords_boost, Some(8.5));
}
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_serve_hotwords_default_off() {
let cli = Cli::parse_from(["gigastt", "serve"]);
match cli.command {
Commands::Serve {
hotwords_file,
hotwords_default,
hotwords_boost,
..
} => {
assert_eq!(hotwords_file, None);
assert!(!hotwords_default);
assert_eq!(hotwords_boost, None);
}
_ => panic!("expected Serve"),
}
}
#[test]
fn test_cli_transcribe_hotwords_flags() {
let cli = Cli::parse_from([
"gigastt",
"transcribe",
"a.wav",
"--hotwords-file",
"hw.txt",
]);
match cli.command {
Commands::Transcribe {
hotwords_file,
hotwords_default,
..
} => {
assert_eq!(hotwords_file, Some("hw.txt".to_string()));
assert!(!hotwords_default);
}
_ => panic!("expected Transcribe"),
}
}
#[test]
fn test_parse_hotwords_file_lines_and_weights() {
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(
tmp.path(),
b"# comment\n\nsynergy\nyoutube\t2.5\n spaced \nbadweight\tnope\n",
)
.unwrap();
let pairs = parse_hotwords_file(tmp.path().to_str().unwrap()).unwrap();
assert_eq!(
pairs,
vec![
("synergy".to_string(), 1.0),
("youtube".to_string(), 2.5),
("spaced".to_string(), 1.0),
("badweight".to_string(), 1.0), ]
);
}
#[test]
fn test_resolve_hotwords_none_when_unset() {
assert!(resolve_hotwords(None, false).is_none());
}
#[test]
fn test_resolve_hotwords_default_pack_only() {
let pairs = resolve_hotwords(None, true).expect("default pack present");
assert_eq!(pairs.len(), gigastt_core::lexicon::DEFAULT_HOTWORDS.len());
}
#[test]
fn test_resolve_hotwords_file_plus_default() {
let tmp = tempfile::NamedTempFile::new().unwrap();
std::fs::write(tmp.path(), "мойбренд\n").unwrap();
let pairs = resolve_hotwords(tmp.path().to_str().unwrap().into(), true).unwrap();
assert_eq!(
pairs.len(),
1 + gigastt_core::lexicon::DEFAULT_HOTWORDS.len()
);
assert_eq!(pairs[0].0, "мойбренд");
}
#[test]
fn test_resolve_hotwords_missing_file_is_graceful() {
assert!(resolve_hotwords(Some("/nonexistent/hw.txt"), false).is_none());
}
#[test]
fn test_cli_serve_with_metrics() {
let cli = Cli::parse_from(["gigastt", "serve", "--metrics"]);
match cli.command {
Commands::Serve {
metrics,
metrics_listen,
..
} => {
assert!(metrics);
assert!(metrics_listen.is_none());
}
_ => panic!("expected Serve"),
}
}
/// `--metrics-listen` is parsed into a socket address carrying the requested
/// loopback IP and port.
#[test]
fn test_cli_serve_metrics_listen_override() {
    let cli = Cli::parse_from([
        "gigastt",
        "serve",
        "--metrics",
        "--metrics-listen",
        "127.0.0.1:9123",
    ]);
    match cli.command {
        Commands::Serve {
            metrics,
            metrics_listen,
            ..
        } => {
            assert!(metrics);
            let addr = metrics_listen.expect("--metrics-listen must parse");
            assert_eq!(addr.port(), 9123);
            assert!(addr.ip().is_loopback());
        }
        _ => panic!("expected Serve"),
    }
}

/// The server's default metrics listen address uses port 9090. This is kept
/// separate from the CLI-parsing test so a change to the default is reported
/// under its own name.
#[test]
fn test_server_default_metrics_listen_port() {
    assert_eq!(server::config::default_metrics_listen().port(), 9090);
}
/// Bracketed IPv6 literals are recognised. `[::1]` counts as loopback, while an
/// address from the `2001:db8::/32` documentation range does not.
#[test]
fn test_is_loopback_host_ipv6_bracketed() {
    let loopback = "[::1]";
    let routable = "[2001:db8::1]";
    assert!(is_loopback_host(loopback));
    assert!(!is_loopback_host(routable));
}
/// If the int8 RNNT encoder file is already in the model directory,
/// `ensure_int8_encoder` succeeds. This holds even though no other encoder
/// file is present to quantize from.
#[test]
fn test_ensure_int8_encoder_already_exists() {
    let dir = tempfile::tempdir().unwrap();
    std::fs::write(dir.path().join("v3_rnnt_encoder_int8.onnx"), b"fake").unwrap();

    let model_dir = dir.path().to_str().unwrap();
    ensure_int8_encoder(ModelVariant::Rnnt, model_dir, false).unwrap();
}
/// With the skip flag set, an empty model directory is not an error.
#[test]
fn test_ensure_int8_encoder_skip_flag() {
    let dir = tempfile::tempdir().unwrap();
    let skip = true;
    let model_dir = dir.path().to_str().unwrap();
    ensure_int8_encoder(ModelVariant::Rnnt, model_dir, skip).unwrap();
}
/// With neither an encoder to quantize nor an existing int8 encoder in the
/// directory, the call fails with an error mentioning "Cannot quantize".
#[test]
fn test_ensure_int8_encoder_missing_input() {
    let dir = tempfile::tempdir().unwrap();
    let model_dir = dir.path().to_str().unwrap();

    let outcome = ensure_int8_encoder(ModelVariant::Rnnt, model_dir, false);
    let msg = outcome.unwrap_err().to_string();
    assert!(msg.contains("Cannot quantize"), "unexpected error: {msg}");
}
/// The `E2eRnnt` variant does not take the RNNT encoder (`v3_rnnt_encoder.onnx`)
/// as its quantization input. With only that file present, the call still
/// fails with "Cannot quantize".
#[test]
fn test_ensure_int8_encoder_e2e_targets_e2e_encoder_name() {
    let dir = tempfile::tempdir().unwrap();
    // Only the RNNT encoder exists; the E2E one is deliberately absent.
    std::fs::write(dir.path().join("v3_rnnt_encoder.onnx"), b"rnnt").unwrap();

    let model_dir = dir.path().to_str().unwrap();
    let err = ensure_int8_encoder(ModelVariant::E2eRnnt, model_dir, false).unwrap_err();
    let rendered = err.to_string();
    assert!(rendered.contains("Cannot quantize"));
}
/// Smoke test: `log_rss` must complete without panicking.
#[test]
fn test_log_rss_does_not_panic() {
    // Nothing is inspected; returning normally is the whole assertion.
    log_rss();
}
/// Setting `GIGASTT_ALLOW_BIND_ANY=1` lets `ensure_bind_allowed` accept the
/// non-loopback host `0.0.0.0` even when its second argument is `false`.
#[test]
fn test_ensure_bind_allowed_env_opt_in() {
    const KEY: &str = "GIGASTT_ALLOW_BIND_ANY";

    /// Restores an environment variable to its previous value on drop. The
    /// process environment is therefore reset even if the code under test
    /// panics, and a previous value that is not valid UTF-8 is preserved too.
    struct EnvRestore {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            // SAFETY: the guard is declared after `_env_lock` below. Locals drop
            // in reverse order, so this runs while ENV_LOCK is still held. This
            // assumes every test that touches the environment takes ENV_LOCK.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    // Recover from poisoning. A panic in another env-mutating test must not
    // turn this one into a spurious failure.
    let _env_lock = ENV_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    let _restore = EnvRestore {
        key: KEY,
        previous: std::env::var_os(KEY),
    };
    // SAFETY: ENV_LOCK is held for the rest of the test (see `EnvRestore::drop`).
    unsafe {
        std::env::set_var(KEY, "1");
    }

    let result = ensure_bind_allowed("0.0.0.0", false);
    assert!(result.is_ok(), "env opt-in must allow non-loopback bind");
}
/// Without a config file or any CLI override, `build_limits` produces a 300 s
/// idle timeout and a 512 KiB WebSocket frame cap.
#[test]
fn test_build_limits_defaults_when_no_config() {
    let limits = build_limits(None, None, None, None, None, None, None, None, None, None)
        .unwrap();
    assert_eq!(
        (limits.idle_timeout_secs, limits.ws_frame_max_bytes),
        (300, 512 * 1024)
    );
}
/// Every CLI override passed to `build_limits` ends up in the matching
/// `RuntimeLimits` field. No config file is used.
#[test]
fn test_build_limits_applies_overrides() {
    let limits = build_limits(
        None,                   // config file
        Some(600),              // idle_timeout_secs
        Some(1024),             // ws_frame_max_bytes
        Some(10 * 1024 * 1024), // body_limit_bytes
        Some(60),               // rate_limit_per_minute
        Some(20),               // rate_limit_burst
        Some(1800),             // max_session_secs
        Some(5),                // shutdown_drain_secs
        Some(15),               // pool_checkout_timeout_secs
        Some(45),               // inference_timeout_secs
    )
    .unwrap();

    // Timeouts.
    assert_eq!(
        (
            limits.idle_timeout_secs,
            limits.max_session_secs,
            limits.shutdown_drain_secs,
            limits.pool_checkout_timeout_secs,
            limits.inference_timeout_secs,
        ),
        (600, 1800, 5, 15, 45)
    );
    // Size caps.
    assert_eq!(
        (limits.ws_frame_max_bytes, limits.body_limit_bytes),
        (1024, 10 * 1024 * 1024)
    );
    // Rate limiting.
    assert_eq!(
        (limits.rate_limit_per_minute, limits.rate_limit_burst),
        (60, 20)
    );
}
/// A well-formed TOML config file is loaded, and its `idle_timeout_secs`
/// value is applied.
#[test]
fn test_build_limits_with_valid_config_file() {
    let config = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(config.path(), b"idle_timeout_secs = 123\n").unwrap();

    let config_path = config.path().to_str().unwrap();
    let limits = build_limits(
        Some(config_path),
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    )
    .unwrap();
    assert_eq!(limits.idle_timeout_secs, 123);
}
/// A config file that is not valid TOML makes `build_limits` return an error.
#[test]
fn test_build_limits_with_invalid_config_file() {
    let config = tempfile::NamedTempFile::new().unwrap();
    std::fs::write(config.path(), b"not valid toml {{{").unwrap();

    let config_path = config.path().to_str().unwrap();
    let outcome = build_limits(
        Some(config_path),
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    );
    assert!(outcome.is_err());
}
/// A positive requests-per-minute paired with a zero burst is rejected. The
/// error message names the `rate-limit-burst` option.
#[test]
fn test_build_limits_rejects_zero_burst_with_nonzero_rpm() {
    let rpm = Some(30);
    let burst = Some(0);
    let outcome = build_limits(None, None, None, None, rpm, burst, None, None, None, None);
    assert!(outcome.is_err());

    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("rate-limit-burst"));
}
/// A requests-per-minute of 0 may be paired with a zero burst, and both zeros
/// are kept as given. Presumably 0 rpm disables rate limiting; confirm against
/// `build_limits`.
#[test]
fn test_build_limits_allows_zero_rpm() {
    let rpm = Some(0);
    let burst = Some(0);
    let limits = build_limits(None, None, None, None, rpm, burst, None, None, None, None)
        .unwrap();
    assert_eq!(
        (limits.rate_limit_per_minute, limits.rate_limit_burst),
        (0, 0)
    );
}
/// `build_server_config` carries the supplied settings into the resulting
/// config: port, host, metrics address, origin policy, metrics/proxy flags,
/// config path and limits.
#[test]
fn test_build_server_config() {
    let limits = RuntimeLimits::default();
    let cfg = build_server_config(
        1234,
        "127.0.0.1".into(),
        vec!["https://app.example.com".into()],
        false,
        limits.clone(),
        true,
        "127.0.0.1:9099".parse().unwrap(),
        true,
        Some("/tmp/config.toml".into()),
    );

    // Network addresses.
    assert_eq!(cfg.port, 1234);
    assert_eq!(cfg.host, "127.0.0.1");
    assert_eq!(cfg.metrics_listen.port(), 9099);

    // Origin policy: the single explicit origin, with allow-any off.
    let policy = &cfg.origin_policy;
    assert_eq!(policy.allowed_origins.len(), 1);
    assert!(!policy.allow_any);

    // Feature flags.
    assert!(cfg.metrics_enabled);
    assert!(cfg.trust_proxy);

    // Config path is stored as a `PathBuf`; the limits' idle timeout is carried over.
    let expected_path = std::path::PathBuf::from("/tmp/config.toml");
    assert_eq!(cfg.config_path, Some(expected_path));
    assert_eq!(cfg.limits.idle_timeout_secs, limits.idle_timeout_secs);
}
}