#![deny(unsafe_code)]
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use axum::Router;
use tower_http::trace::TraceLayer;
pub mod api;
pub mod collector;
pub mod config;
pub mod dashboard;
pub mod middleware;
pub mod state;
pub mod ws;
pub use collector::Collector;
pub use config::{
Config, ConfigError, DeviceConfig, InfluxDbConfig, MqttConfig, NotificationConfig,
PrometheusConfig, SecurityConfig, ServerConfig, StorageConfig, WebhookConfig, WebhookEndpoint,
};
pub use state::{AppState, ReadingEvent};
#[cfg(feature = "mqtt")]
pub mod mqtt;
#[cfg(feature = "prometheus")]
pub mod prometheus;
pub mod influxdb;
pub mod mdns;
pub mod webhook;
/// Startup overrides passed to [`run`].
///
/// The default value means "no overrides":
/// - every `None` field keeps the value from the configuration file, or from the
///   built-in default configuration when the file does not exist;
/// - `no_collector = false` keeps the background collector enabled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RunOptions {
    /// Configuration file to load; `None` falls back to `config::default_config_path()`.
    pub config: Option<PathBuf>,
    /// HTTP bind address (must parse as a `SocketAddr`); overrides `server.bind`.
    pub bind: Option<String>,
    /// Database location; overrides `storage.path` from the configuration.
    pub database: Option<PathBuf>,
    /// When `true`, the background collector is not started.
    pub no_collector: bool,
}
pub fn init_tracing() -> anyhow::Result<()> {
let filter = tracing_subscriber::EnvFilter::from_default_env()
.add_directive("aranet_service=info".parse()?)
.add_directive("tower_http=debug".parse()?);
let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
Ok(())
}
/// Builds the HTTP router with every route and the security middleware stack.
///
/// The routes from [`api`], [`ws`] and [`dashboard`] are merged into a single
/// router. Each `.layer(...)` call wraps everything added before it, so an
/// incoming request passes through the layers in this order before reaching a
/// handler:
///
/// 1. `TraceLayer` (HTTP request tracing),
/// 2. `middleware::rate_limit`, given `(security_config, rate_limit_state)`,
/// 3. `middleware::api_key_auth`, given `security_config`.
///
/// `state` is attached as the router state, producing a stateless `Router`.
pub fn app(
    state: Arc<AppState>,
    security_config: Arc<SecurityConfig>,
    rate_limit_state: Arc<middleware::RateLimitState>,
) -> Router {
    let routes = Router::new()
        .merge(api::router())
        .merge(ws::router())
        .merge(dashboard::router());

    // The auth layer needs its own handle; the original Arc moves into the
    // rate-limit layer's state tuple below.
    let auth_config = Arc::clone(&security_config);

    routes
        .layer(axum::middleware::from_fn_with_state(
            auth_config,
            middleware::api_key_auth,
        ))
        .layer(axum::middleware::from_fn_with_state(
            (security_config, rate_limit_state),
            middleware::rate_limit,
        ))
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}
pub async fn run(options: RunOptions) -> anyhow::Result<()> {
let config_path = options
.config
.clone()
.unwrap_or_else(config::default_config_path);
let mut config = if config_path.exists() {
Config::load(&config_path)?
} else {
Config::default()
};
if let Some(bind) = options.bind {
config.server.bind = bind;
}
if let Some(db_path) = options.database {
config.storage.path = db_path;
}
config.validate()?;
tracing::info!("Opening database at {:?}", config.storage.path);
let store = aranet_store::Store::open(&config.storage.path)?;
let state = AppState::with_config_path(store, config.clone(), config_path);
let security_config = Arc::new(config.security.clone());
let rate_limit_state = Arc::new(middleware::RateLimitState::new());
{
let rate_limit_state = Arc::clone(&rate_limit_state);
let window_secs = config.security.rate_limit_window_secs;
let max_entries = config.security.rate_limit_max_entries;
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
rate_limit_state.cleanup(window_secs, max_entries).await;
}
});
}
let collector = if !options.no_collector {
let collector = Collector::new(Arc::clone(&state));
collector.start().await;
Some(collector)
} else {
tracing::info!("Background collector disabled");
None
};
#[cfg(feature = "mqtt")]
{
use crate::mqtt::MqttPublisher;
let mqtt_publisher = MqttPublisher::new(Arc::clone(&state));
mqtt_publisher.start().await;
}
#[cfg(feature = "prometheus")]
{
use crate::prometheus::PrometheusPusher;
let prometheus_pusher = PrometheusPusher::new(Arc::clone(&state));
prometheus_pusher.start().await;
}
{
use crate::webhook::WebhookDispatcher;
let webhook_dispatcher = WebhookDispatcher::new(Arc::clone(&state));
webhook_dispatcher.start().await;
}
{
use crate::influxdb::InfluxDbWriter;
let influxdb_writer = InfluxDbWriter::new(Arc::clone(&state));
influxdb_writer.start().await;
}
let _mdns_handle = {
use crate::mdns::MdnsAdvertiser;
let advertiser = MdnsAdvertiser::new(Arc::clone(&state));
advertiser.start().await
};
let app = app(
Arc::clone(&state),
Arc::clone(&security_config),
Arc::clone(&rate_limit_state),
)
.layer(middleware::cors_layer(&config.security));
let addr: SocketAddr = config.server.bind.parse()?;
tracing::info!("Starting server on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.with_graceful_shutdown(shutdown_signal(collector, state))
.await?;
Ok(())
}
async fn shutdown_signal(mut collector: Option<Collector>, state: Arc<AppState>) {
let ctrl_c = async {
if let Err(e) = tokio::signal::ctrl_c().await {
tracing::error!("Failed to install Ctrl+C handler: {}", e);
std::future::pending::<()>().await;
}
};
#[cfg(unix)]
let terminate = async {
match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
Ok(mut signal) => {
signal.recv().await;
}
Err(e) => {
tracing::error!("Failed to install SIGTERM handler: {}", e);
std::future::pending::<()>().await;
}
}
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {},
_ = terminate => {},
}
tracing::info!("Shutdown signal received, stopping services...");
if let Some(ref mut collector) = collector {
collector.stop().await;
}
state.signal_shutdown();
state.collector.signal_stop();
tracing::info!("Graceful shutdown complete");
}