use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use axum::Router;
use clap::{Parser, Subcommand};
use tower_http::trace::TraceLayer;
use tracing::info;
use aranet_service::config::default_config_path;
use aranet_service::middleware::{self, RateLimitState};
use aranet_service::{AppState, Collector, Config, api, ws};
use aranet_store::Store;
mod service;
// Top-level command-line arguments. This is a plain comment rather than a doc
// comment so it does not feed into clap's generated help; `about` comes from the
// `#[command(...)]` attribute below. The global options only affect the server;
// `service` subcommands accept them but do not read them.
#[derive(Parser, Debug)]
#[command(name = "aranet-service")]
#[command(version, about, long_about = None)]
struct Args {
    // Optional subcommand; `None` behaves like `run`.
    #[command(subcommand)]
    command: Option<Command>,
    /// Path to the configuration file (a built-in default path is used when omitted)
    #[arg(short, long, global = true)]
    config: Option<PathBuf>,
    /// Address to listen on as IP:PORT (overrides `server.bind` from the config)
    #[arg(short, long, global = true)]
    bind: Option<String>,
    /// Path to the database (overrides `storage.path` from the config)
    #[arg(short, long, global = true)]
    database: Option<PathBuf>,
    /// Disable the background collector
    #[arg(long, global = true)]
    no_collector: bool,
}
// Top-level subcommands. When none is given, `main` behaves as if `run` was passed.
#[derive(Subcommand, Debug)]
enum Command {
    /// Run the HTTP server (the default when no subcommand is given)
    Run,
    /// Install, uninstall, or control aranet-service as a system or user service
    Service {
        #[command(subcommand)]
        action: ServiceAction,
    },
}
// Actions accepted by `aranet-service service <action>`. Each one takes `--user`
// to select the user-level service instead of the system-level one.
#[derive(Subcommand, Debug)]
enum ServiceAction {
    /// Install aranet-service as a service
    Install {
        /// Target the user-level service instead of the system-level one
        #[arg(long)]
        user: bool,
    },
    /// Uninstall the aranet-service service
    Uninstall {
        /// Target the user-level service instead of the system-level one
        #[arg(long)]
        user: bool,
    },
    /// Start the aranet-service service
    Start {
        /// Target the user-level service instead of the system-level one
        #[arg(long)]
        user: bool,
    },
    /// Stop the aranet-service service
    Stop {
        /// Target the user-level service instead of the system-level one
        #[arg(long)]
        user: bool,
    },
    /// Show whether the aranet-service service is running or stopped
    Status {
        /// Target the user-level service instead of the system-level one
        #[arg(long)]
        user: bool,
    },
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let args = Args::parse();
match args.command {
Some(Command::Service { action }) => handle_service_action(action),
Some(Command::Run) | None => run_server(args).await,
}
}
/// Executes a `service` subcommand through the local `service` module.
///
/// Passing `--user` selects `Level::User`; otherwise `Level::System` is used.
/// Success is reported on stdout and failures on stderr.
///
/// # Errors
///
/// Returns the error produced by the underlying `service` call, converted
/// into [`anyhow::Error`], when the requested action fails.
fn handle_service_action(action: ServiceAction) -> anyhow::Result<()> {
    use service::{Level, ServiceStatus};

    let level_for = |user: bool| if user { Level::User } else { Level::System };

    // Each arm yields (verb, past tense, outcome). The past tense is spelled out
    // explicitly because appending "ed" to the verb would print "stoped".
    let (verb, past_tense, result) = match action {
        ServiceAction::Install { user } => {
            ("install", "installed", service::install(level_for(user)))
        }
        ServiceAction::Uninstall { user } => {
            ("uninstall", "uninstalled", service::uninstall(level_for(user)))
        }
        ServiceAction::Start { user } => ("start", "started", service::start(level_for(user))),
        ServiceAction::Stop { user } => ("stop", "stopped", service::stop(level_for(user))),
        // `status` reports a state rather than a success/failure of an action,
        // so it prints its own message and returns directly.
        ServiceAction::Status { user } => {
            return match service::status(level_for(user)) {
                Ok(ServiceStatus::Running) => {
                    println!("aranet-service is running");
                    Ok(())
                }
                Ok(ServiceStatus::Stopped) => {
                    println!("aranet-service is stopped");
                    Ok(())
                }
                Err(e) => {
                    eprintln!("Failed to get status: {}", e);
                    Err(e.into())
                }
            };
        }
    };

    match result {
        Ok(()) => {
            println!("Successfully {} aranet-service", past_tense);
            Ok(())
        }
        Err(e) => {
            eprintln!("Failed to {} service: {}", verb, e);
            Err(e.into())
        }
    }
}
async fn run_server(args: Args) -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive("aranet_service=info".parse()?)
.add_directive("tower_http=debug".parse()?),
)
.init();
let config_path = args.config.clone().unwrap_or_else(default_config_path);
let mut config = if config_path.exists() {
Config::load(&config_path)?
} else {
Config::default()
};
if let Some(bind) = args.bind {
config.server.bind = bind;
}
if let Some(db_path) = args.database {
config.storage.path = db_path;
}
info!("Opening database at {:?}", config.storage.path);
let store = Store::open(&config.storage.path)?;
let state = AppState::with_config_path(store, config.clone(), config_path);
let security_config = Arc::new(config.security.clone());
let rate_limit_state = Arc::new(RateLimitState::new());
{
let rate_limit_state = Arc::clone(&rate_limit_state);
let window_secs = config.security.rate_limit_window_secs;
let max_entries = config.security.rate_limit_max_entries;
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
rate_limit_state.cleanup(window_secs, max_entries).await;
}
});
}
let collector = if !args.no_collector {
let mut collector = Collector::new(Arc::clone(&state));
collector.start().await;
Some(collector)
} else {
info!("Background collector disabled");
None
};
#[cfg(feature = "mqtt")]
{
use aranet_service::mqtt::MqttPublisher;
let mqtt_publisher = MqttPublisher::new(Arc::clone(&state));
mqtt_publisher.start().await;
}
#[cfg(feature = "prometheus")]
{
use aranet_service::prometheus::PrometheusPusher;
let prometheus_pusher = PrometheusPusher::new(Arc::clone(&state));
prometheus_pusher.start().await;
}
let app = Router::new()
.merge(api::router())
.merge(ws::router())
.layer(axum::middleware::from_fn_with_state(
security_config.clone(),
middleware::api_key_auth,
))
.layer(axum::middleware::from_fn_with_state(
(security_config, rate_limit_state),
middleware::rate_limit,
))
.layer(TraceLayer::new_for_http())
.layer(middleware::cors_layer(&config.security))
.with_state(Arc::clone(&state));
let addr: SocketAddr = config.server.bind.parse()?;
info!("Starting server on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal(collector, state))
.await?;
Ok(())
}
async fn shutdown_signal(mut collector: Option<Collector>, state: Arc<AppState>) {
let ctrl_c = async {
if let Err(e) = tokio::signal::ctrl_c().await {
tracing::error!("Failed to install Ctrl+C handler: {}", e);
std::future::pending::<()>().await;
}
};
#[cfg(unix)]
let terminate = async {
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
Ok(mut signal) => {
signal.recv().await;
}
Err(e) => {
tracing::error!("Failed to install SIGTERM handler: {}", e);
std::future::pending::<()>().await;
}
}
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
info!("Shutdown signal received, stopping services...");
if let Some(ref mut collector) = collector {
collector.stop().await;
}
state.collector.signal_stop();
info!("Graceful shutdown complete");
}