use std::net::SocketAddr;
use std::time::Duration;
use anyhow::Context;
use axum::http::{HeaderName, Request, StatusCode};
use clap::Parser;
use tokio::signal;
use tower::limit::ConcurrencyLimitLayer;
use tower_http::limit::RequestBodyLimitLayer;
use tower_http::request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer};
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::{DefaultOnRequest, DefaultOnResponse, TraceLayer};
use tracing::{info, info_span, warn, Level};
use bezant_server::{router, AppState};
// Runtime configuration for the server binary, parsed by clap from
// command-line flags with environment-variable fallbacks.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into help text, which would change the `--help` output.
#[derive(Debug, Parser)]
#[command(version, about)]
struct Args {
// Socket address the HTTP listener binds to (`--bind` / `BEZANT_BIND`).
// Defaults to all interfaces on port 8080.
#[arg(long, env = "BEZANT_BIND", default_value = "0.0.0.0:8080")]
bind: SocketAddr,
// Base URL handed to `bezant::Client::builder` (`--gateway-url` /
// `IBKR_GATEWAY_URL`). Defaults to `bezant::DEFAULT_BASE_URL`.
#[arg(long, env = "IBKR_GATEWAY_URL", default_value = bezant::DEFAULT_BASE_URL)]
gateway_url: String,
// Number of seconds converted to the `Duration` passed to
// `Client::spawn_keepalive` (`--keepalive-secs` / `BEZANT_KEEPALIVE_SECS`).
// NOTE(review): 0 is accepted by the parser; confirm `spawn_keepalive`
// tolerates a zero duration.
#[arg(long, env = "BEZANT_KEEPALIVE_SECS", default_value_t = 60)]
keepalive_secs: u64,
// Whether the gateway's TLS certificate is verified (`--verify-tls` /
// `BEZANT_VERIFY_TLS`). Off by default: `main` builds the client with
// `accept_invalid_certs(!verify_tls)`, so invalid certificates are accepted
// unless this is set.
#[arg(long, env = "BEZANT_VERIFY_TLS", default_value_t = false)]
verify_tls: bool,
// Optional token (`--debug-token` / `BEZANT_DEBUG_TOKEN`). When present,
// `main` builds the state with `AppState::with_debug_token`; when absent it
// uses `AppState::new`.
#[arg(long, env = "BEZANT_DEBUG_TOKEN")]
debug_token: Option<String>,
}
/// Entry point: parses configuration, builds the gateway client and the HTTP
/// router, serves until a shutdown signal arrives, then stops the keepalive.
///
/// # Errors
///
/// Returns an error if the bezant client cannot be built, the listener cannot
/// be bound (or cannot report its local address), or the server future fails.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Honour the standard env filter when it is set and valid; otherwise fall
    // back to info-level output for this crate, the client and tower-http.
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "bezant_server=info,bezant=info,tower_http=info".into()),
        )
        .with_target(false)
        .init();

    let args = Args::parse();
    info!(
        bind = %args.bind,
        gateway = %args.gateway_url,
        keepalive = args.keepalive_secs,
        "bezant-server starting"
    );

    let client = bezant::Client::builder(&args.gateway_url)
        .accept_invalid_certs(!args.verify_tls)
        .follow_redirects(false)
        .build()
        .context("building bezant client")?;
    let keepalive = client.spawn_keepalive(Duration::from_secs(args.keepalive_secs));

    // An environment variable that is set but empty (or `--debug-token ""`)
    // parses as `Some("")`. Treat a blank token as "not configured" so the
    // debug endpoints are never gated by an empty secret.
    let state = match args.debug_token {
        Some(token) if !token.trim().is_empty() => {
            info!("debug endpoints enabled (token gating active)");
            AppState::with_debug_token(client, token)
        }
        Some(_) => {
            warn!("debug token is empty; debug endpoints remain disabled");
            AppState::new(client)
        }
        None => AppState::new(client),
    };

    // One span per request carrying method, path and the request id header
    // (or "-" when the header is missing or not valid text).
    let trace = TraceLayer::new_for_http()
        .make_span_with(|req: &Request<_>| {
            let request_id = req
                .headers()
                .get("x-request-id")
                .and_then(|v| v.to_str().ok())
                .unwrap_or("-");
            info_span!(
                "http",
                method = %req.method(),
                path = %req.uri().path(),
                request_id = %request_id,
            )
        })
        .on_request(DefaultOnRequest::new().level(Level::DEBUG))
        .on_response(DefaultOnResponse::new().level(Level::DEBUG));

    let req_id_header = HeaderName::from_static("x-request-id");

    // Each `.layer(..)` wraps everything added before it, so the last layer is
    // the outermost: the request id is assigned first, then the body-size and
    // concurrency limits apply, then the timeout, then tracing (which can
    // already see the id), and finally the id is propagated to the response.
    let app = router(state)
        .layer(PropagateRequestIdLayer::new(req_id_header.clone()))
        .layer(trace)
        .layer(TimeoutLayer::with_status_code(
            StatusCode::GATEWAY_TIMEOUT,
            Duration::from_secs(35),
        ))
        .layer(ConcurrencyLimitLayer::new(256))
        .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024))
        .layer(SetRequestIdLayer::new(req_id_header, MakeRequestUuid));

    let listener = tokio::net::TcpListener::bind(args.bind)
        .await
        .with_context(|| format!("binding {}", args.bind))?;
    info!(addr = %listener.local_addr()?, "bezant-server listening");

    // Serve until `shutdown_signal` resolves.
    axum::serve(listener, app.into_make_service())
        .with_graceful_shutdown(shutdown_signal())
        .await
        .context("server crashed")?;

    info!("server drained; stopping keepalive");
    // A keepalive failure at this point is logged but does not fail shutdown.
    if let Err(e) = keepalive.stop().await {
        warn!(error = %e, "keepalive stop returned error");
    }
    info!("bezant-server shutdown complete");
    Ok(())
}
async fn shutdown_signal() {
let ctrl_c = async {
if let Err(e) = signal::ctrl_c().await {
warn!(error = %e, "ctrl-c signal handler failed to install");
std::future::pending::<()>().await;
}
};
#[cfg(unix)]
let terminate = async {
match signal::unix::signal(signal::unix::SignalKind::terminate()) {
Ok(mut s) => {
s.recv().await;
}
Err(e) => {
warn!(error = %e, "SIGTERM handler failed to install");
std::future::pending::<()>().await;
}
}
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => info!("shutdown: SIGINT received"),
_ = terminate => info!("shutdown: SIGTERM received"),
}
}