azoth_balancer/
balancer.rs

1//! The core logic for the load balancer, including endpoint management,
2//! selection, and state tracking.
3
4use bytes::Bytes;
5use parking_lot::{Mutex, RwLock};
6use reqwest::Client;
7use serde_json::Value;
8use std::{
9    collections::HashMap,
10    sync::{
11        atomic::{AtomicU64, Ordering},
12        Arc,
13    },
14    time::{Duration, Instant},
15};
16use tracing::{info, warn};
17
18use crate::{
19    config::Config,
20    config_reloader::{self, ReloadResponse},
21    cooldown,
22    endpoint::{EndpointMetrics, LoadBalancerError, RpcEndpoint},
23    forwarder, health,
24    metrics::{
25        COOLDOWNS_TRIGGERED, COOLDOWN_SECONDS_GAUGE, ENDPOINT_RATE_LIMIT_DEFERRED,
26        HEALTHCHECK_FAILED, HEALTHY_ENDPOINTS, RPC_REQUESTS_FAILED, RPC_REQUESTS_SUCCEEDED,
27        TOTAL_ENDPOINTS,
28    },
29    shutdown::ShutdownManager,
30    strategy,
31};
32
/// Derives a short, label-safe name from an endpoint URL.
///
/// The result is the last two dot-separated labels of the host joined with `_`,
/// with `-` replaced by `_` and lower-cased. For example,
/// `https://eth-mainnet.alchemy.com/v2/key` becomes `alchemy_com`.
///
/// The following parts of the URL are stripped before the host is inspected:
/// - the scheme,
/// - any path, query string or fragment,
/// - any userinfo (`user:secret@`),
/// - the port.
///
/// Stripping userinfo and query matters because the result is used in endpoint
/// names (metric labels, status output), and credentials or API keys must never
/// leak there. A bracketed IPv6 host (`[::1]:8545`) yields the address inside
/// the brackets. An input with no recognizable host yields an empty string.
pub(crate) fn extract_domain_name(url: &str) -> String {
    // Drop the scheme ("https://") when present.
    let rest = url.split_once("://").map_or(url, |(_, rest)| rest);
    // The authority ends at the first path, query, or fragment delimiter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");
    // Drop userinfo ("user:secret@") so credentials never end up in the name.
    let host_port = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
    // Drop the port; bracketed IPv6 literals contain ':' so handle them first.
    let host = match host_port.strip_prefix('[') {
        Some(v6) => v6.split(']').next().unwrap_or(""),
        None => host_port.split(':').next().unwrap_or(""),
    };
    // Keep only the last two labels (e.g. "alchemy.com").
    let labels: Vec<&str> = host.split('.').collect();
    let tail = &labels[labels.len().saturating_sub(2)..];
    tail.join("_").replace('-', "_").to_lowercase()
}
55
/// Shared load-balancer state.
///
/// This covers the endpoint list, per-endpoint rate limiters, the HTTP client,
/// and the settings that can change at runtime.
///
/// Each adjustable setting sits behind its own `Arc<RwLock<_>>` so it can be
/// replaced while the balancer runs. This is presumably done by `config_reloader`
/// on hot reload; confirm. Because every setting is locked independently, a
/// reader that reads several of them is not guaranteed a consistent snapshot
/// across a concurrent reload.
#[derive(Debug)]
pub struct LoadBalancer {
    /// Every configured upstream endpoint together with its health, cooldown and metrics state.
    pub endpoints: Arc<RwLock<Vec<RpcEndpoint>>>,
    /// One rate limiter per endpoint, keyed by the endpoint's URL.
    pub rate_limiters: Arc<RwLock<HashMap<String, Arc<Mutex<ratelimit_meter::DirectRateLimiter>>>>>,
    /// Shared reqwest client, built in `LoadBalancer::new` from the client settings below.
    pub client: Arc<RwLock<Client>>,
    /// Listen address, taken from `server.bind_addr`; not behind a lock.
    pub bind_addr: String,
    /// Health-check interval in seconds.
    /// Consumed outside this file, presumably by `health::health_check_loop`.
    pub health_check_interval_secs: Arc<RwLock<u64>>,
    /// Health-check timeout in seconds (consumed outside this file).
    pub health_check_timeout_secs: Arc<RwLock<u64>>,
    /// Base cooldown in seconds, passed to `cooldown::trigger_cooldown`.
    pub base_cooldown_secs: Arc<RwLock<u64>>,
    /// Upper bound on cooldown in seconds, passed to `cooldown::trigger_cooldown`.
    pub max_cooldown_secs: Arc<RwLock<u64>>,
    /// Maximum batch size setting (not read within this file besides status output).
    pub max_batch_size: Arc<RwLock<usize>>,
    /// Smoothing factor setting.
    /// NOTE(review): presumably the EMA weight for latency/error metrics; confirm.
    pub latency_smoothing_factor: Arc<RwLock<f64>>,
    /// Semaphore created with `max_concurrency` permits.
    pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
    /// Path of the configuration file; reported by `get_status`.
    pub config_path: Arc<RwLock<String>>,
    // Settings the current `client` was built with.
    // Kept so a hot reload can detect when the client needs to be rebuilt.
    pub connect_timeout_ms: Arc<RwLock<u64>>,
    pub timeout_secs: Arc<RwLock<u64>>,
    pub pool_idle_timeout_secs: Arc<RwLock<u64>>,
    pub pool_max_idle_per_host: Arc<RwLock<usize>>,
}
76
77impl LoadBalancer {
78    pub fn new(config: Option<Config>, config_path: String) -> Self {
79        // Finalize the config first: applies defaults, validates, and sanitizes.
80        let config = config
81            .unwrap_or_default()
82            .finalize()
83            .expect("Configuration failed to finalize at startup");
84
85        // Now, all values are guaranteed to be present and valid.
86        let server_cfg = config.server.unwrap();
87        let balancer_cfg = config.balancer.unwrap();
88
89        let bind_addr = server_cfg.bind_addr.unwrap();
90        let health_check_interval = balancer_cfg.health_check_interval_secs.unwrap();
91        let health_check_timeout = balancer_cfg.health_check_timeout_secs.unwrap();
92        let base_cooldown = balancer_cfg.base_cooldown_secs.unwrap();
93        let max_cooldown = balancer_cfg.max_cooldown_secs.unwrap();
94        let max_batch_size = balancer_cfg.max_batch_size.unwrap();
95        let latency_smoothing_factor = balancer_cfg.latency_smoothing_factor.unwrap();
96        let max_concurrency = balancer_cfg.max_concurrency.unwrap();
97        let connect_timeout_ms = balancer_cfg.connect_timeout_ms.unwrap();
98        let timeout_secs = balancer_cfg.timeout_secs.unwrap();
99        let pool_idle_timeout_secs = balancer_cfg.pool_idle_timeout_secs.unwrap();
100        let pool_max_idle_per_host = balancer_cfg.pool_max_idle_per_host.unwrap();
101
102        let endpoints_list = balancer_cfg.endpoints.unwrap();
103
104        let endpoints: Vec<RpcEndpoint> = endpoints_list
105            .into_iter()
106            .enumerate()
107            .map(|(index, e)| {
108                // Generate name if not provided in config
109                let name = e.name.unwrap_or_else(|| {
110                    let domain_name = extract_domain_name(&e.url);
111                    format!("{:03}_{}", index + 1, domain_name)
112                });
113
114                RpcEndpoint {
115                    name, // Use the generated/configured name
116                    url: e.url.clone(),
117                    healthy: true,
118                    last_check: Instant::now(),
119                    cooldown_until: None,
120                    cooldown_attempts: 0,
121                    rate_limit_per_sec: e.rate_limit_per_sec,
122                    burst_size: e.burst_size,
123                    weight: e.weight.unwrap(),
124                    metrics: Arc::new(EndpointMetrics {
125                        ema_latency_ms: AtomicU64::new(timeout_secs * 1000),
126                        ..Default::default()
127                    }),
128                }
129            })
130            .collect();
131
132        TOTAL_ENDPOINTS.set(endpoints.len() as i64);
133        HEALTHY_ENDPOINTS.set(endpoints.len() as i64);
134
135        for ep in &endpoints {
136            COOLDOWN_SECONDS_GAUGE.with_label_values(&[&ep.name]).set(0);
137            COOLDOWNS_TRIGGERED.with_label_values(&[&ep.name]).inc_by(0);
138            ENDPOINT_RATE_LIMIT_DEFERRED.with_label_values(&[&ep.name]).inc_by(0);
139            RPC_REQUESTS_SUCCEEDED.with_label_values(&[&ep.name]).inc_by(0);
140            RPC_REQUESTS_FAILED.with_label_values(&[&ep.name]).inc_by(0);
141            HEALTHCHECK_FAILED.with_label_values(&[&ep.name]).inc_by(0);
142        }
143
144        let rate_limiters = Arc::new(RwLock::new(HashMap::new()));
145        for ep in &endpoints {
146            rate_limiters.write().insert(
147                ep.url.clone(),
148                config_reloader::create_rate_limiter(ep.burst_size, ep.rate_limit_per_sec),
149            );
150        }
151
152        let client = Client::builder()
153            .tcp_nodelay(true)
154            .connect_timeout(Duration::from_millis(connect_timeout_ms))
155            .timeout(Duration::from_secs(timeout_secs))
156            .pool_idle_timeout(Some(Duration::from_secs(pool_idle_timeout_secs)))
157            .pool_max_idle_per_host(pool_max_idle_per_host)
158            .http1_title_case_headers()
159            .build()
160            .expect("Failed to create HTTP client");
161
162        info!(
163            bind_addr = %bind_addr,
164            health_check_interval = health_check_interval,
165            health_check_timeout = health_check_timeout,
166            base_cooldown = base_cooldown,
167            max_cooldown = max_cooldown,
168            max_batch_size = max_batch_size,
169            max_concurrency = max_concurrency,
170            latency_smoothing_factor = latency_smoothing_factor,
171            connect_timeout_ms = connect_timeout_ms,
172            timeout_secs = timeout_secs,
173            pool_idle_timeout_secs = pool_idle_timeout_secs,
174            pool_max_idle_per_host = pool_max_idle_per_host,
175            endpoints_count = endpoints.len(),
176            "LoadBalancer initialized"
177        );
178
179        Self {
180            endpoints: Arc::new(RwLock::new(endpoints)),
181            rate_limiters,
182            client: Arc::new(RwLock::new(client)),
183            bind_addr,
184            health_check_interval_secs: Arc::new(RwLock::new(health_check_interval)),
185            health_check_timeout_secs: Arc::new(RwLock::new(health_check_timeout)),
186            base_cooldown_secs: Arc::new(RwLock::new(base_cooldown)),
187            max_cooldown_secs: Arc::new(RwLock::new(max_cooldown)),
188            max_batch_size: Arc::new(RwLock::new(max_batch_size)),
189            latency_smoothing_factor: Arc::new(RwLock::new(latency_smoothing_factor)),
190            concurrency_limiter: Arc::new(tokio::sync::Semaphore::new(max_concurrency)),
191            config_path: Arc::new(RwLock::new(config_path)),
192            connect_timeout_ms: Arc::new(RwLock::new(connect_timeout_ms)),
193            timeout_secs: Arc::new(RwLock::new(timeout_secs)),
194            pool_idle_timeout_secs: Arc::new(RwLock::new(pool_idle_timeout_secs)),
195            pool_max_idle_per_host: Arc::new(RwLock::new(pool_max_idle_per_host)),
196        }
197    }
198
199    pub fn reload_config(&self) -> Result<ReloadResponse, LoadBalancerError> {
200        config_reloader::reload(self)
201    }
202
203    /// Selects the best available endpoint, updates its last_selected timestamp, and returns it.
204    pub fn get_next_endpoint(&self, batch_size: usize) -> Option<RpcEndpoint> {
205        let endpoints = self.endpoints.read();
206        let limiters = self.rate_limiters.read();
207
208        if let Some(endpoint) = strategy::select_best_endpoint(&endpoints, &limiters, batch_size) {
209            // Set the timestamp to now before returning the endpoint.
210            endpoint.metrics.last_selected.store(
211                std::time::SystemTime::now()
212                    .duration_since(std::time::UNIX_EPOCH)
213                    .unwrap_or_default()
214                    .as_secs(),
215                std::sync::atomic::Ordering::Relaxed,
216            );
217            Some(endpoint.clone())
218        } else {
219            None
220        }
221    }
222
223    /// A wrapper method that calls the main request forwarding logic.
224    pub async fn forward_raw_request(
225        &self,
226        request_body: Bytes,
227        endpoint: &RpcEndpoint,
228        method: &str,
229    ) -> Result<Bytes, LoadBalancerError> {
230        forwarder::forward_request(self, request_body, endpoint, method).await
231    }
232
233    pub async fn mark_rate_limited(&self, url: &str) {
234        let mut endpoints = self.endpoints.write();
235        if let Some(ep) = endpoints.iter_mut().find(|e| e.url == url) {
236            cooldown::trigger_cooldown(
237                ep,
238                *self.base_cooldown_secs.read(),
239                *self.max_cooldown_secs.read(),
240            );
241            warn!(
242                endpoint = %ep.name,
243                cooldown_secs = ep.cooldown_remaining_secs(),
244                "Endpoint put into cooldown due to rate limiting"
245            );
246        }
247    }
248
249    pub async fn mark_unhealthy(&self, url: &str) {
250        let mut endpoints = self.endpoints.write();
251        if let Some(ep) = endpoints.iter_mut().find(|e| e.url == url) {
252            if ep.healthy {
253                ep.healthy = false;
254                cooldown::trigger_cooldown(
255                    ep,
256                    *self.base_cooldown_secs.read(),
257                    *self.max_cooldown_secs.read(),
258                );
259                warn!(
260                    endpoint = %ep.name,
261                    cooldown_secs = ep.cooldown_remaining_secs(),
262                    "Marked endpoint as unhealthy and put into cooldown"
263                );
264            }
265        }
266    }
267
268    pub async fn update_healthy_count(&self) {
269        let endpoints = self.endpoints.read();
270        let healthy_count = endpoints.iter().filter(|e| e.is_available()).count() as i64;
271        HEALTHY_ENDPOINTS.set(healthy_count);
272    }
273
274    pub fn run_background_tasks(self: &Arc<Self>, shutdown_manager: &mut ShutdownManager) {
275        let health_checker = self.clone();
276        shutdown_manager
277            .spawn_task(health::health_check_loop(health_checker, shutdown_manager.subscribe()));
278
279        let cooldown_updater = self.clone();
280        shutdown_manager.spawn_task(cooldown::cooldown_gauge_updater(
281            cooldown_updater,
282            shutdown_manager.subscribe(),
283        ));
284    }
285
286    fn get_endpoint_priority_stats(&self) -> Value {
287        let endpoints = self.endpoints.read();
288        let mut stats: Vec<_> = endpoints
289            .iter()
290            .map(|ep| {
291                serde_json::json!({
292                    "name": ep.name,
293                    "weight": ep.weight,
294                    "healthy": ep.healthy,
295                    "available": ep.is_available(),
296                    "in_cooldown": ep.is_in_cooldown(),
297                    "cooldown_remaining_secs": ep.cooldown_remaining_secs(),
298                    "total_cost_ms": ep.metrics.get_total_cost(),
299                    "last_selected_secs_ago": std::time::SystemTime::now()
300                        .duration_since(std::time::UNIX_EPOCH)
301                        .unwrap_or_default()
302                        .as_secs()
303                        .saturating_sub(ep.metrics.last_selected.load(Ordering::Relaxed)),
304                })
305            })
306            .collect();
307
308        stats.sort_by(|a, b| {
309            let weight_a = a["weight"].as_u64().unwrap_or(0);
310            let weight_b = b["weight"].as_u64().unwrap_or(0);
311            let cost_a = a["total_cost_ms"].as_u64().unwrap_or(u64::MAX);
312            let cost_b = b["total_cost_ms"].as_u64().unwrap_or(u64::MAX);
313            let time_a = a["last_selected_secs_ago"].as_u64().unwrap_or(0);
314            let time_b = b["last_selected_secs_ago"].as_u64().unwrap_or(0);
315
316            weight_b
317                .cmp(&weight_a)
318                .then_with(|| cost_a.cmp(&cost_b))
319                .then_with(|| time_b.cmp(&time_a)) // Higher secs ago (older) is better
320        });
321
322        let ranked_stats: Vec<_> = stats
323            .into_iter()
324            .enumerate()
325            .map(|(i, mut stat)| {
326                stat["priority_rank"] = serde_json::json!(i + 1);
327                stat
328            })
329            .collect();
330
331        serde_json::json!({
332            "strategy": "priority_based_with_lru_tie_breaking",
333            "endpoints": ranked_stats
334        })
335    }
336
337    pub fn get_status(&self) -> Value {
338        let endpoints = self.endpoints.read();
339        let available_endpoints: Vec<&RpcEndpoint> =
340            endpoints.iter().filter(|ep| ep.is_available()).collect();
341
342        let healthy_count = available_endpoints.len();
343        let config_path = self.config_path.read().clone();
344
345        serde_json::json!({
346            "bind_addr": self.bind_addr,
347            "config_path": config_path,
348            "total_endpoints": endpoints.len(),
349            "healthy_endpoints": healthy_count,
350            "balancer_settings": {
351                "health_check_interval_secs": *self.health_check_interval_secs.read(),
352                "health_check_timeout_secs": *self.health_check_timeout_secs.read(),
353                "base_cooldown_secs": *self.base_cooldown_secs.read(),
354                "max_cooldown_secs": *self.max_cooldown_secs.read(),
355                "max_batch_size": *self.max_batch_size.read(),
356                "latency_smoothing_factor": *self.latency_smoothing_factor.read(),
357            },
358            "client_settings": {
359                "connect_timeout_ms": *self.connect_timeout_ms.read(),
360                "timeout_secs": *self.timeout_secs.read(),
361                "pool_idle_timeout_secs": *self.pool_idle_timeout_secs.read(),
362                "pool_max_idle_per_host": *self.pool_max_idle_per_host.read(),
363            },
364            "priority_info": self.get_endpoint_priority_stats(),
365            "endpoints_details": endpoints.iter().map(|ep| {
366                serde_json::json!({
367                    "name": ep.name,  // Show the name
368                    // REMOVED: "url": ep.url,  - Don't expose URL!
369                    "healthy": ep.healthy,
370                    "available": ep.is_available(),
371                    "in_cooldown": ep.is_in_cooldown(),
372                    "cooldown_remaining_secs": ep.cooldown_remaining_secs(),
373                    "cooldown_attempts": ep.cooldown_attempts,
374                    "consecutive_failures": ep.metrics.consecutive_failures.load(Ordering::Relaxed),
375                    "ema_latency_ms": ep.metrics.ema_latency_ms.load(Ordering::Relaxed),
376                    "ema_error_penalty_ms": ep.metrics.ema_error_penalty_ms.load(Ordering::Relaxed),
377                    "total_cost_ms": ep.metrics.get_total_cost(),
378                    "last_selected_timestamp": ep.metrics.last_selected.load(Ordering::Relaxed),
379                    "last_check_secs_ago": ep.last_check.elapsed().as_secs(),
380                    "rate_limit_per_sec": ep.rate_limit_per_sec,
381                    "burst_size": ep.burst_size,
382                    "weight": ep.weight
383                })
384            }).collect::<Vec<_>>()
385        })
386    }
387}