mod actor;
mod config;
mod metrics;
mod store;
mod transport;
mod types;
#[cfg(test)]
mod actor_tests;
use anyhow::Result;
use std::sync::Arc;
use tokio::signal;
use tokio::task::JoinSet;
use crate::config::Config;
use crate::metrics::Metrics;
use crate::transport::{
Transport, grpc::GrpcTransport, http::HttpTransport, redis::RedisTransport,
};
#[tokio::main]
async fn main() -> Result<()> {
let config = Config::from_env_and_args()?;
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::from_default_env()
.add_directive(format!("throttlecrab={}", config.log_level).parse()?),
)
.init();
let metrics = Arc::new(
Metrics::builder()
.max_denied_keys(config.max_denied_keys as usize)
.build(),
);
let limiter =
store::create_rate_limiter(&config.store, config.buffer_size, Arc::clone(&metrics));
let mut transport_tasks = JoinSet::new();
if let Some(http_config) = &config.transports.http {
let limiter_handle = limiter.clone();
let host = http_config.host.clone();
let port = http_config.port;
let metrics_clone = Arc::clone(&metrics);
transport_tasks.spawn(async move {
tracing::info!("Starting HTTP transport on {}:{}", host, port);
let transport = HttpTransport::new(&host, port, metrics_clone);
transport.start(limiter_handle).await
});
}
if let Some(grpc_config) = &config.transports.grpc {
let limiter_handle = limiter.clone();
let host = grpc_config.host.clone();
let port = grpc_config.port;
let metrics_clone = Arc::clone(&metrics);
transport_tasks.spawn(async move {
tracing::info!("Starting gRPC transport on {}:{}", host, port);
let transport = GrpcTransport::new(&host, port, metrics_clone);
transport.start(limiter_handle).await
});
}
if let Some(redis_config) = &config.transports.redis {
let limiter_handle = limiter.clone();
let host = redis_config.host.clone();
let port = redis_config.port;
let metrics_clone = Arc::clone(&metrics);
transport_tasks.spawn(async move {
tracing::info!("Starting Redis transport on {}:{}", host, port);
let transport = RedisTransport::new(&host, port, metrics_clone)?;
transport.start(limiter_handle).await
});
}
let shutdown_signal = async {
let ctrl_c = signal::ctrl_c();
#[cfg(unix)]
let sigterm = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("Failed to install SIGTERM handler")
.recv()
.await;
};
#[cfg(not(unix))]
let sigterm = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => {
tracing::info!("Received SIGINT (Ctrl+C), initiating graceful shutdown...");
}
_ = sigterm => {
tracing::info!("Received SIGTERM, initiating graceful shutdown...");
}
}
};
tokio::select! {
_ = shutdown_signal => {
tracing::info!("Shutdown signal received, stopping all transports...");
transport_tasks.abort_all();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
tracing::info!("ThrottleCrab server shutdown complete");
return Ok(());
}
result = transport_tasks.join_next() => {
if let Some(result) = result {
match result {
Ok(Ok(())) => {
tracing::info!("Transport task completed successfully");
}
Ok(Err(e)) => {
tracing::error!("Transport task failed: {}", e);
return Err(e);
}
Err(e) => {
tracing::error!("Transport task panicked: {}", e);
return Err(anyhow::anyhow!("Transport task panicked"));
}
}
}
}
}
tracing::info!(
"ThrottleCrab server started with store type: {:?}",
config.store.store_type
);
tracing::info!(
"Store capacity: {}, Buffer size: {}",
config.store.capacity,
config.buffer_size
);
Ok(())
}