use std::path::PathBuf;
use std::process::ExitCode;
use clap::Parser;
use epics_bridge_rs::ca_gateway::{GatewayConfig, GatewayServer, RestartPolicy, supervise};
// Command-line options for `ca-gateway-rs`.
//
// Plain `//` comments are used deliberately: clap's derive turns `///` doc
// comments into `--help` text, so these notes leave the CLI output unchanged.
// Field names map to kebab-case flags (e.g. `read_only` -> `--read-only`).
// `main` copies every field into `GatewayConfig` or `RestartPolicy`.
#[derive(Parser, Debug)]
#[command(
name = "ca-gateway-rs",
about = "Pure Rust port of the EPICS CA gateway",
version
)]
struct Args {
// `--pvlist`: copied to `GatewayConfig::pvlist_path`.
#[arg(long)]
pvlist: Option<PathBuf>,
// `--access`: copied to `GatewayConfig::access_path`.
#[arg(long)]
access: Option<PathBuf>,
// `--preload`: copied to `GatewayConfig::preload_path`.
#[arg(long)]
preload: Option<PathBuf>,
// `--putlog`: copied to `GatewayConfig::putlog_path`.
#[arg(long)]
putlog: Option<PathBuf>,
// `--command`: copied to `GatewayConfig::command_path`.
#[arg(long)]
command: Option<PathBuf>,
// `--port`: copied to `GatewayConfig::server_port`.
// NOTE(review): the default `0` is passed through unchanged; how the server
// interprets it is decided elsewhere — confirm.
#[arg(long, default_value_t = 0)]
port: u16,
// `--read-only`: copied to `GatewayConfig::read_only`.
#[arg(long)]
read_only: bool,
// `--no-stats`: when set, `main` passes an empty `stats_prefix`, so
// `--stats-prefix` is ignored.
#[arg(long)]
no_stats: bool,
// `--stats-prefix`: copied to `GatewayConfig::stats_prefix` unless
// `--no-stats` is given.
#[arg(long, default_value = "gateway:")]
stats_prefix: String,
// `--heartbeat-interval` in seconds; `0` becomes
// `GatewayConfig::heartbeat_interval = None`.
#[arg(long, default_value_t = 1)]
heartbeat_interval: u64,
// `--cleanup-interval` in seconds -> `GatewayConfig::cleanup_interval`.
// NOTE(review): `0` is passed through as a zero `Duration`; confirm the
// server tolerates it.
#[arg(long, default_value_t = 10)]
cleanup_interval: u64,
// `--stats-interval` in seconds -> `GatewayConfig::stats_interval`.
// NOTE(review): `0` is passed through as a zero `Duration`; confirm the
// server tolerates it.
#[arg(long, default_value_t = 10)]
stats_interval: u64,
// `--supervised`: run the gateway through `supervise` with a
// `RestartPolicy` built from the three flags below instead of running once.
#[arg(long)]
supervised: bool,
// `--max-restarts` -> `RestartPolicy::max_restarts`.
#[arg(long, default_value_t = 10)]
max_restarts: u32,
// `--restart-window` in seconds -> `RestartPolicy::window`.
#[arg(long, default_value_t = 600)]
restart_window: u64,
// `--restart-delay` in seconds -> `RestartPolicy::delay`.
#[arg(long, default_value_t = 10)]
restart_delay: u64,
// `--tls-cert`: server certificate chain (PEM); must be given together
// with `--tls-key` (checked in `build_tls`).
#[cfg(feature = "ca-gateway-tls")]
#[arg(long)]
tls_cert: Option<PathBuf>,
// `--tls-key`: server private key (PEM); must be given together with
// `--tls-cert`.
#[cfg(feature = "ca-gateway-tls")]
#[arg(long)]
tls_key: Option<PathBuf>,
// `--tls-client-ca`: only read by `build_tls`; selects the mutual-TLS
// server builder (`server_mtls_from_pem`) using this CA.
#[cfg(feature = "ca-gateway-tls")]
#[arg(long)]
tls_client_ca: Option<PathBuf>,
// `--upstream-tls-roots`: root store for upstream connections; only read
// by `build_upstream_tls`, which returns no config without it.
#[cfg(feature = "ca-gateway-tls")]
#[arg(long)]
upstream_tls_roots: Option<PathBuf>,
// `--upstream-tls-client-cert`: upstream client certificate chain; must be
// given together with `--upstream-tls-client-key`.
#[cfg(feature = "ca-gateway-tls")]
#[arg(long)]
upstream_tls_client_cert: Option<PathBuf>,
// `--upstream-tls-client-key`: upstream client private key; must be given
// together with `--upstream-tls-client-cert`.
#[cfg(feature = "ca-gateway-tls")]
#[arg(long)]
upstream_tls_client_key: Option<PathBuf>,
// `--upstream-tls-server-name`: copied to
// `GatewayConfig::upstream_tls_server_name`.
#[cfg(feature = "ca-gateway-tls")]
#[arg(long)]
upstream_tls_server_name: Option<String>,
}
/// Builds the downstream (server-side) TLS configuration from the CLI flags.
///
/// Returns `Ok(None)` when neither `--tls-cert` nor `--tls-key` is given. When
/// both are given, the certificate chain and private key are loaded and:
/// - if `--tls-client-ca` is also given, the config is built with
///   `TlsConfig::server_mtls_from_pem` using that CA as the client root store;
/// - otherwise it is built with `TlsConfig::server_from_pem`.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - only one of `--tls-cert` / `--tls-key` is set;
/// - `--tls-client-ca` is set without them (it would otherwise be silently
///   ignored and the gateway would run without TLS);
/// - a path is not valid UTF-8;
/// - a PEM file fails to load;
/// - the builder fails or does not yield a server-side config.
#[cfg(feature = "ca-gateway-tls")]
fn build_tls(
    args: &Args,
) -> Result<Option<std::sync::Arc<epics_ca_rs::tls::ServerConfig>>, String> {
    use epics_ca_rs::tls::{TlsConfig, load_certs, load_private_key, load_root_store};

    // Borrows `path` as `&str`, rejecting non-UTF-8 paths instead of silently
    // substituting an empty string (which would produce a misleading load error).
    fn utf8<'a>(path: &'a std::path::Path, flag: &str) -> Result<&'a str, String> {
        path.to_str()
            .ok_or_else(|| format!("{flag}: path is not valid UTF-8: {}", path.display()))
    }

    let (cert_path, key_path) = match (&args.tls_cert, &args.tls_key) {
        (Some(c), Some(k)) => (c, k),
        (None, None) if args.tls_client_ca.is_some() => {
            return Err("--tls-client-ca requires --tls-cert and --tls-key".into());
        }
        (None, None) => return Ok(None),
        _ => {
            return Err("--tls-cert and --tls-key must both be set or both unset".into());
        }
    };

    let chain = load_certs(utf8(cert_path, "--tls-cert")?)
        .map_err(|e| format!("loading cert chain {}: {e}", cert_path.display()))?;
    let key = load_private_key(utf8(key_path, "--tls-key")?)
        .map_err(|e| format!("loading key {}: {e}", key_path.display()))?;

    let cfg = match &args.tls_client_ca {
        Some(ca_path) => {
            let roots = load_root_store(utf8(ca_path, "--tls-client-ca")?)
                .map_err(|e| format!("loading client CA {}: {e}", ca_path.display()))?;
            TlsConfig::server_mtls_from_pem(chain, key, roots)
        }
        None => TlsConfig::server_from_pem(chain, key),
    }
    .map_err(|e| format!("TLS server build: {e}"))?;

    // The server builders return the general `TlsConfig` enum; only the
    // server variant is usable here.
    match cfg {
        TlsConfig::Server(arc) => Ok(Some(arc)),
        TlsConfig::Client(_) => Err("expected server TlsConfig".into()),
    }
}
/// Builds the optional client-side TLS configuration for upstream connections.
///
/// Returns `Ok(None)` when `--upstream-tls-roots` is not given. Otherwise the
/// root store is loaded, and:
/// - with no client identity, the config is built by
///   `TlsConfig::client_from_roots`;
/// - when both `--upstream-tls-client-cert` and `--upstream-tls-client-key` are
///   given, it is built by `TlsConfig::client_mtls` with that identity.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - exactly one of the client cert/key flags is set;
/// - either is set without `--upstream-tls-roots` (they would otherwise be
///   silently ignored);
/// - a path is not valid UTF-8;
/// - a PEM file or the mTLS builder fails.
#[cfg(feature = "ca-gateway-tls")]
fn build_upstream_tls(args: &Args) -> Result<Option<epics_ca_rs::tls::TlsConfig>, String> {
    use epics_ca_rs::tls::{TlsConfig, load_certs, load_private_key, load_root_store};

    // Borrows `path` as `&str`, rejecting non-UTF-8 paths instead of silently
    // substituting an empty string (which would produce a misleading load error).
    fn utf8<'a>(path: &'a std::path::Path, flag: &str) -> Result<&'a str, String> {
        path.to_str()
            .ok_or_else(|| format!("{flag}: path is not valid UTF-8: {}", path.display()))
    }

    let Some(roots_path) = &args.upstream_tls_roots else {
        // Without roots no upstream TLS config is built, so a client identity
        // given on the command line would be dropped without notice.
        if args.upstream_tls_client_cert.is_some() || args.upstream_tls_client_key.is_some() {
            return Err("--upstream-tls-client-cert/--upstream-tls-client-key require \
                        --upstream-tls-roots"
                .into());
        }
        return Ok(None);
    };

    let roots = load_root_store(utf8(roots_path, "--upstream-tls-roots")?)
        .map_err(|e| format!("loading upstream TLS roots {}: {e}", roots_path.display()))?;

    let cfg = match (&args.upstream_tls_client_cert, &args.upstream_tls_client_key) {
        (None, None) => TlsConfig::client_from_roots(roots),
        (Some(cert), Some(key)) => {
            let chain = load_certs(utf8(cert, "--upstream-tls-client-cert")?)
                .map_err(|e| format!("loading upstream client cert {}: {e}", cert.display()))?;
            let priv_key = load_private_key(utf8(key, "--upstream-tls-client-key")?)
                .map_err(|e| format!("loading upstream client key {}: {e}", key.display()))?;
            TlsConfig::client_mtls(roots, chain, priv_key)
                .map_err(|e| format!("upstream mTLS build: {e}"))?
        }
        _ => {
            return Err("--upstream-tls-client-cert and --upstream-tls-client-key must both be \
                        set or both unset"
                .into());
        }
    };
    Ok(Some(cfg))
}
/// Builds a [`GatewayServer`] from `config` and awaits its `run()` future.
///
/// # Errors
///
/// Returns `"build failed: …"` when `GatewayServer::build` fails, and
/// `"runtime error: …"` when `run()` returns an error.
async fn run_once(config: GatewayConfig) -> Result<(), String> {
    tracing::info!("ca-gateway-rs: starting");

    let gateway = match GatewayServer::build(config).await {
        Ok(gateway) => gateway,
        Err(err) => return Err(format!("build failed: {err}")),
    };

    gateway.run().await.map_err(|err| format!("runtime error: {err}"))
}
/// Entry point: parses the CLI, assembles a [`GatewayConfig`], and runs the
/// gateway either once or, with `--supervised`, under [`supervise`].
///
/// Exit codes:
/// - `ExitCode::SUCCESS` when the gateway (or supervisor) returns `Ok`;
/// - `ExitCode::FAILURE` when it returns an error;
/// - `2` when TLS setup from the command line fails (only with the
///   `ca-gateway-tls` feature).
#[tokio::main(flavor = "multi_thread")]
async fn main() -> ExitCode {
    use std::time::Duration;

    // `RUST_LOG` wins when set; otherwise log at `info`. `try_init` so an
    // already-installed subscriber is not treated as fatal.
    let _ = tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .try_init();

    let args = Args::parse();

    // TLS setup can fail on bad flag combinations or unreadable PEM files.
    // Return exit code 2 from `main` instead of calling `std::process::exit`,
    // so destructors (including the Tokio runtime's) still run.
    #[cfg(feature = "ca-gateway-tls")]
    let (tls, upstream_tls) = {
        let tls = match build_tls(&args) {
            Ok(tls) => tls,
            Err(e) => {
                tracing::error!(error = %e, "ca-gateway-rs: TLS init failed");
                return ExitCode::from(2);
            }
        };
        let upstream_tls = match build_upstream_tls(&args) {
            Ok(upstream_tls) => upstream_tls,
            Err(e) => {
                tracing::error!(error = %e, "ca-gateway-rs: upstream TLS init failed");
                return ExitCode::from(2);
            }
        };
        (tls, upstream_tls)
    };

    let config = GatewayConfig {
        pvlist_path: args.pvlist.clone(),
        pvlist_content: None,
        access_path: args.access.clone(),
        putlog_path: args.putlog.clone(),
        command_path: args.command.clone(),
        preload_path: args.preload.clone(),
        server_port: args.port,
        timeouts: Default::default(),
        // `--no-stats` is expressed as an empty stats prefix.
        stats_prefix: if args.no_stats {
            String::new()
        } else {
            args.stats_prefix.clone()
        },
        cleanup_interval: Duration::from_secs(args.cleanup_interval),
        stats_interval: Duration::from_secs(args.stats_interval),
        // `--heartbeat-interval 0` maps to `None`.
        heartbeat_interval: if args.heartbeat_interval == 0 {
            None
        } else {
            Some(Duration::from_secs(args.heartbeat_interval))
        },
        read_only: args.read_only,
        #[cfg(feature = "ca-gateway-tls")]
        tls,
        #[cfg(feature = "ca-gateway-tls")]
        upstream_tls,
        #[cfg(feature = "ca-gateway-tls")]
        upstream_tls_server_name: args.upstream_tls_server_name.clone(),
    };

    if args.supervised {
        let policy = RestartPolicy {
            max_restarts: args.max_restarts,
            window: Duration::from_secs(args.restart_window),
            delay: Duration::from_secs(args.restart_delay),
        };
        tracing::info!(
            max_restarts = args.max_restarts,
            window_secs = args.restart_window,
            delay_secs = args.restart_delay,
            "ca-gateway-rs: running under supervisor"
        );
        // The closure clones the config on every call, so each run started by
        // `supervise` gets its own copy.
        match supervise(policy, || run_once(config.clone())).await {
            Ok(()) => ExitCode::SUCCESS,
            Err(e) => {
                tracing::error!(error = %e, "ca-gateway-rs: supervisor exit");
                ExitCode::FAILURE
            }
        }
    } else {
        match run_once(config).await {
            Ok(()) => ExitCode::SUCCESS,
            Err(e) => {
                tracing::error!(error = %e, "ca-gateway-rs: error");
                ExitCode::FAILURE
            }
        }
    }
}