use crate::{Error, Result};
use crate::storage::{PoolConfig, RedisClient, RedisMode};
use crate::middleware::MiddlewareChain;
use crate::aggregator::{AggregatorConfig, AggregatorManager};
use crate::server::JanitorConfig;
use std::sync::Arc;
use uuid::Uuid;
/// Runtime configuration for a server, assembled by [`ServerBuilder`] and shared
/// (behind an `Arc`) through [`ServerState`].
///
/// Time-valued fields use the units named by the corresponding builder setters:
/// seconds for `heartbeat_interval`, `worker_timeout` and `dequeue_timeout`,
/// milliseconds for `poll_interval`.
#[derive(Debug, Clone)]
pub struct ServerConfig {
/// Redis connection URL; interpreted according to `redis_mode` when the client is built.
pub redis_url: String,
/// Which Redis topology to connect to (standalone, cluster or sentinel).
pub redis_mode: RedisMode,
/// Connection-pool settings forwarded to the `RedisClient` constructor.
pub pool_config: PoolConfig,
/// Names of the queues this server works on; `ServerBuilder::build` rejects an empty list.
pub queues: Vec<String>,
/// Concurrency level; `ServerBuilder::build` rejects 0.
/// NOTE(review): presumably the number of concurrently processed tasks — confirm.
pub concurrency: usize,
/// Heartbeat interval in seconds; `ServerBuilder::build` rejects 0.
pub heartbeat_interval: u64,
/// Worker timeout in seconds.
pub worker_timeout: u64,
/// Heartbeat TTL multiplier (default 2.0); the builder setter warns when it is <= 1.0.
/// NOTE(review): presumably TTL = heartbeat_interval * multiplier — confirm at the use site.
pub heartbeat_ttl_multiplier: f64,
/// Dequeue timeout in seconds.
pub dequeue_timeout: u64,
/// Poll interval in milliseconds.
pub poll_interval: u64,
/// Name identifying this server; defaults to `rediq-server-<random UUID v4>`.
pub server_name: String,
/// Whether the scheduler is enabled (default `true`; cleared by `ServerBuilder::disable_scheduler`).
pub enable_scheduler: bool,
/// Optional aggregator configuration, installed as the `AggregatorManager` default by `build`.
pub aggregator_config: Option<AggregatorConfig>,
/// Optional janitor configuration; not read in this file (presumably consumed elsewhere — confirm).
pub janitor_config: Option<JanitorConfig>,
}
impl Default for ServerConfig {
    /// Builds the baseline configuration.
    ///
    /// It targets a standalone Redis on `localhost:6379` and consumes the single queue
    /// `"default"` with a concurrency of 10. The scheduler is enabled and there are no
    /// aggregator or janitor overrides. Each call generates a fresh, random server name.
    fn default() -> Self {
        // Unique per instance so that several servers can share one Redis.
        let server_name = format!("rediq-server-{}", Uuid::new_v4());
        let queues = vec![String::from("default")];
        let redis_url = String::from("redis://localhost:6379");

        Self {
            redis_url,
            redis_mode: RedisMode::Standalone,
            pool_config: PoolConfig::default(),
            queues,
            concurrency: 10,
            // Timing defaults: seconds unless noted otherwise.
            heartbeat_interval: 5,
            worker_timeout: 30,
            heartbeat_ttl_multiplier: 2.0,
            dequeue_timeout: 2,
            // Milliseconds.
            poll_interval: 100,
            server_name,
            enable_scheduler: true,
            aggregator_config: None,
            janitor_config: None,
        }
    }
}
/// Fluent builder for [`ServerState`].
///
/// Each setter consumes the builder and returns it. [`ServerBuilder::build`] then
/// validates the configuration, connects to Redis and produces the shared state.
///
/// NOTE(review): the derived `Default` builds the chain with `MiddlewareChain::default()`,
/// whereas `ServerBuilder::new` uses `MiddlewareChain::new()`. They are presumably
/// equivalent — confirm.
#[derive(Debug, Default)]
pub struct ServerBuilder {
/// Configuration accumulated by the setter methods.
config: ServerConfig,
/// Middleware registered through `ServerBuilder::middleware`.
middleware: MiddlewareChain,
}
impl ServerBuilder {
#[must_use]
pub fn new() -> Self {
Self {
config: ServerConfig::default(),
middleware: MiddlewareChain::new(),
}
}
#[must_use]
pub fn middleware<M: crate::middleware::Middleware + 'static>(mut self, middleware: M) -> Self {
self.middleware = self.middleware.add(middleware);
self
}
#[must_use]
pub fn redis_url(mut self, url: impl Into<String>) -> Self {
self.config.redis_url = url.into();
self
}
#[must_use]
pub fn cluster_mode(mut self) -> Self {
self.config.redis_mode = RedisMode::Cluster;
self
}
#[must_use]
pub fn sentinel_mode(mut self) -> Self {
self.config.redis_mode = RedisMode::Sentinel;
self
}
#[must_use]
pub fn queues(mut self, queues: &[&str]) -> Self {
self.config.queues = queues.iter().map(|s| s.to_string()).collect();
self
}
#[must_use]
pub fn concurrency(mut self, concurrency: usize) -> Self {
self.config.concurrency = concurrency;
self
}
#[must_use]
pub fn heartbeat_interval(mut self, seconds: u64) -> Self {
self.config.heartbeat_interval = seconds;
self
}
#[must_use]
pub fn worker_timeout(mut self, seconds: u64) -> Self {
self.config.worker_timeout = seconds;
self
}
#[must_use]
pub fn heartbeat_ttl_multiplier(mut self, multiplier: f64) -> Self {
if multiplier <= 1.0 {
tracing::warn!("heartbeat_ttl_multiplier should be > 1.0 for reliable operation, got {}", multiplier);
}
self.config.heartbeat_ttl_multiplier = multiplier;
self
}
#[must_use]
pub fn dequeue_timeout(mut self, seconds: u64) -> Self {
self.config.dequeue_timeout = seconds;
self
}
#[must_use]
pub fn poll_interval(mut self, milliseconds: u64) -> Self {
self.config.poll_interval = milliseconds;
self
}
#[must_use]
pub fn server_name(mut self, name: impl Into<String>) -> Self {
self.config.server_name = name.into();
self
}
#[must_use]
pub fn disable_scheduler(mut self) -> Self {
self.config.enable_scheduler = false;
self
}
#[must_use]
pub fn pool_size(mut self, size: usize) -> Self {
self.config.pool_config.pool_size = size;
self
}
#[must_use]
pub fn min_idle(mut self, min_idle: usize) -> Self {
self.config.pool_config.min_idle = Some(min_idle);
self
}
#[must_use]
pub fn connection_timeout(mut self, timeout: u64) -> Self {
self.config.pool_config.connection_timeout = Some(timeout);
self
}
#[must_use]
pub fn idle_timeout(mut self, timeout: u64) -> Self {
self.config.pool_config.idle_timeout = Some(timeout);
self
}
#[must_use]
pub fn max_lifetime(mut self, lifetime: u64) -> Self {
self.config.pool_config.max_lifetime = Some(lifetime);
self
}
#[must_use]
pub fn aggregator_config(mut self, config: AggregatorConfig) -> Self {
self.config.aggregator_config = Some(config);
self
}
#[must_use]
pub fn janitor_config(mut self, config: JanitorConfig) -> Self {
self.config.janitor_config = Some(config);
self
}
pub async fn build(self) -> Result<ServerState> {
if self.config.concurrency == 0 {
return Err(Error::Config("concurrency must be greater than 0".into()));
}
if self.config.queues.is_empty() {
return Err(Error::Config("at least one queue must be specified".into()));
}
if self.config.heartbeat_interval == 0 {
return Err(Error::Config("heartbeat_interval must be greater than 0".into()));
}
let redis = match self.config.redis_mode {
RedisMode::Standalone => RedisClient::from_url_with_pool_config(&self.config.redis_url, self.config.pool_config.clone()).await?,
RedisMode::Cluster => RedisClient::from_cluster_url_with_pool_config(&self.config.redis_url, self.config.pool_config.clone()).await?,
RedisMode::Sentinel => RedisClient::from_sentinel_url_with_pool_config(&self.config.redis_url, self.config.pool_config.clone()).await?,
};
redis.ping().await?;
let mode_str = match self.config.redis_mode {
RedisMode::Standalone => "Standalone",
RedisMode::Cluster => "Cluster",
RedisMode::Sentinel => "Sentinel",
};
tracing::info!("Connected to Redis ({}) at {}", mode_str, self.config.redis_url);
let mut aggregator_manager = AggregatorManager::new();
if let Some(ref agg_config) = self.config.aggregator_config {
aggregator_manager.set_default_config(agg_config.clone());
}
Ok(ServerState {
config: Arc::new(self.config),
redis,
middleware: Arc::new(self.middleware),
aggregator: Arc::new(aggregator_manager),
})
}
}
/// Fully initialized server state produced by [`ServerBuilder::build`].
///
/// The configuration, middleware and aggregator are behind `Arc`, so cloning shares
/// rather than copies them.
/// NOTE(review): the cost of cloning `RedisClient` depends on its implementation — confirm.
#[derive(Clone)]
pub struct ServerState {
/// Immutable configuration the server was built with.
pub config: Arc<ServerConfig>,
/// Redis client; `ServerBuilder::build` pings it before returning this state.
pub redis: RedisClient,
/// Middleware chain registered through `ServerBuilder::middleware`.
pub middleware: Arc<MiddlewareChain>,
/// Aggregator manager, seeded with `ServerConfig::aggregator_config` when one was set.
pub aggregator: Arc<AggregatorManager>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every documented default, including the generated server name and the optional configs.
    #[test]
    fn test_default_config() {
        let config = ServerConfig::default();
        assert_eq!(config.redis_url, "redis://localhost:6379");
        assert!(matches!(config.redis_mode, RedisMode::Standalone));
        assert_eq!(config.queues, vec!["default"]);
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.heartbeat_interval, 5);
        assert_eq!(config.worker_timeout, 30);
        assert!((config.heartbeat_ttl_multiplier - 2.0).abs() < f64::EPSILON);
        assert_eq!(config.dequeue_timeout, 2);
        assert_eq!(config.poll_interval, 100);
        assert!(config.server_name.starts_with("rediq-server-"));
        assert!(config.enable_scheduler);
        assert!(config.aggregator_config.is_none());
        assert!(config.janitor_config.is_none());
    }

    /// Each default config gets its own random server name.
    #[test]
    fn test_default_server_names_are_unique() {
        assert_ne!(
            ServerConfig::default().server_name,
            ServerConfig::default().server_name
        );
    }

    #[test]
    fn test_builder() {
        let builder = ServerBuilder::new()
            .redis_url("redis://localhost:6380")
            .queues(&["critical", "low"])
            .concurrency(20)
            .heartbeat_interval(10)
            .worker_timeout(45)
            .heartbeat_ttl_multiplier(3.0)
            .dequeue_timeout(5)
            .poll_interval(200)
            .server_name("test-server");
        assert_eq!(builder.config.redis_url, "redis://localhost:6380");
        assert_eq!(builder.config.queues, vec!["critical", "low"]);
        assert_eq!(builder.config.concurrency, 20);
        assert_eq!(builder.config.heartbeat_interval, 10);
        assert_eq!(builder.config.worker_timeout, 45);
        assert!((builder.config.heartbeat_ttl_multiplier - 3.0).abs() < f64::EPSILON);
        assert_eq!(builder.config.dequeue_timeout, 5);
        assert_eq!(builder.config.poll_interval, 200);
        assert_eq!(builder.config.server_name, "test-server");
    }

    #[test]
    fn test_builder_redis_modes() {
        assert!(matches!(
            ServerBuilder::new().cluster_mode().config.redis_mode,
            RedisMode::Cluster
        ));
        assert!(matches!(
            ServerBuilder::new().sentinel_mode().config.redis_mode,
            RedisMode::Sentinel
        ));
    }

    #[test]
    fn test_builder_pool_settings() {
        let builder = ServerBuilder::new()
            .pool_size(16)
            .min_idle(2)
            .connection_timeout(3)
            .idle_timeout(60)
            .max_lifetime(600);
        let pool = &builder.config.pool_config;
        assert_eq!(pool.pool_size, 16);
        assert_eq!(pool.min_idle, Some(2));
        assert_eq!(pool.connection_timeout, Some(3));
        assert_eq!(pool.idle_timeout, Some(60));
        assert_eq!(pool.max_lifetime, Some(600));
    }

    #[test]
    fn test_builder_disable_scheduler() {
        let builder = ServerBuilder::new().disable_scheduler();
        assert!(!builder.config.enable_scheduler);
    }

    /// Validation errors are raised before any connection attempt, so no Redis is needed.
    #[tokio::test]
    async fn test_build_rejects_invalid_config() {
        let zero_concurrency = ServerBuilder::new().concurrency(0).build().await;
        assert!(matches!(zero_concurrency, Err(Error::Config(_))));

        let no_queues = ServerBuilder::new().queues(&[]).build().await;
        assert!(matches!(no_queues, Err(Error::Config(_))));

        let zero_heartbeat = ServerBuilder::new().heartbeat_interval(0).build().await;
        assert!(matches!(zero_heartbeat, Err(Error::Config(_))));
    }

    #[tokio::test]
    #[ignore = "Requires Redis server"]
    async fn test_build_server() {
        let redis_url = std::env::var("REDIS_URL")
            .unwrap_or_else(|_| "redis://localhost:6379".to_string());
        let state = ServerBuilder::new()
            .redis_url(&redis_url)
            .queues(&["default"])
            .concurrency(5)
            .build()
            .await
            .unwrap();
        assert_eq!(state.config.queues.len(), 1);
        assert_eq!(state.config.concurrency, 5);
    }
}