use clap::Parser;
use harborshield::{Harborshield, VERSION, check_kernel_version, parse_duration, shutdown_signal};
use std::path::{Path, PathBuf};
use std::process::ExitCode;
use std::time::Duration;
use tracing::{error, info};
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
// Command-line options for harborshield.
//
// NOTE: plain `//` comments are used on purpose. clap's derive turns `///` doc
// comments into `--help` text, so adding them would change the CLI output.
#[derive(Parser, Debug)]
#[command(author, version, about = "Automate management of firewall rules for Docker containers", long_about = None)]
struct Args {
    // Clear all harborshield rules and exit instead of starting the rule handlers.
    #[arg(long)]
    clear: bool,
    // Directory that holds the SQLite database file (`db.sqlite`). It is
    // canonicalized at startup, so it must already exist.
    #[arg(short = 'd', long, default_value = ".")]
    data_dir: PathBuf,
    // Raise the log level from `info` to `debug`.
    #[arg(long)]
    debug: bool,
    // Log destination: `stdout` or `stderr` select a console stream; any other
    // value is treated as a log file path.
    #[arg(short = 'l', long, default_value = "stdout")]
    log_path: String,
    // Timeout handed to the Harborshield builder, parsed by `parse_duration`
    // (e.g. the default `10s`).
    #[arg(short = 't', long, default_value = "10s", value_parser = parse_duration)]
    timeout: Duration,
    // Optional health-server address passed to the builder when set.
    // NOTE(review): presumably `host:port` — confirm against the builder.
    #[arg(long)]
    health_server: Option<String>,
    // Print version information and exit. Exposed as `--version-info`, which is
    // separate from clap's built-in `--version` flag.
    #[arg(long = "version-info")]
    version_info: bool,
}
#[tokio::main]
async fn main() {
if let Err(e) = dotenvy::dotenv() {
if e.not_found() {
} else {
eprintln!("Error loading .env file: {}", e);
}
}
let args = Args::parse();
if args.version_info {
println!("harborshield {}", VERSION);
println!("harborshield-rust (Rust port)");
return;
}
let env_filter = if args.debug {
EnvFilter::new("debug")
} else {
EnvFilter::new("info")
};
let subscriber = tracing_subscriber::registry().with(env_filter);
if args.log_path == "stdout" || args.log_path == "stderr" {
let subscriber = subscriber.with(fmt::layer());
tracing::subscriber::set_global_default(subscriber)
.expect("Failed to set tracing subscriber");
} else {
let file_appender = tracing_appender::rolling::never("", &args.log_path);
let (non_blocking, _guard) = tracing_appender::non_blocking(file_appender);
let subscriber = subscriber.with(fmt::layer().with_writer(non_blocking));
tracing::subscriber::set_global_default(subscriber)
.expect("Failed to set tracing subscriber");
}
check_kernel_version();
#[cfg(target_os = "linux")]
{
if let Err(e) = harborshield::security::check_capabilities() {
error!("Capability check failed: {}", e);
std::process::exit(1);
}
info!("All required capabilities are present");
}
let data_dir = match args.data_dir.canonicalize() {
Ok(path) => path,
Err(e) => {
error!("Failed to get absolute path for data directory: {}", e);
std::process::exit(1);
}
};
let db_path = data_dir.join("db.sqlite");
let harborshield = match Harborshield::builder()
.db_path(&db_path)
.timeout(args.timeout)
.maybe_health_server_addr(args.health_server.as_deref())
.build()
.await
{
Ok(handlers) => handlers,
Err(e) => {
error!("Failed to initialize rule handlers: {}", e);
std::process::exit(1);
}
};
let log_path = if args.log_path != "stdout" && args.log_path != "stderr" {
Some(PathBuf::from(&args.log_path))
} else {
None
};
#[cfg(target_os = "linux")]
{
if let Err(e) = harborshield::security::apply_restrictions(&db_path, log_path.as_deref()) {
error!("Failed to apply security restrictions: {}", e);
std::process::exit(1);
}
}
if args.clear {
info!("Clearing all harborshield rules");
if let Err(e) = harborshield.clear().await {
error!("Failed to clear rules: {}", e);
std::process::exit(1);
}
return;
}
info!("Starting harborshield v{}", VERSION);
let harborshield = match harborshield.start().await {
Ok(started_handlers) => started_handlers,
Err(e) => {
error!("Failed to start rule handlers: {}", e);
std::process::exit(1);
}
};
shutdown_signal().await;
info!("Shutting down");
harborshield.stop().await;
}