1use bytes::Bytes;
5use parking_lot::{Mutex, RwLock};
6use reqwest::Client;
7use serde_json::Value;
8use std::{
9 collections::HashMap,
10 sync::{
11 atomic::{AtomicU64, Ordering},
12 Arc,
13 },
14 time::{Duration, Instant},
15};
16use tracing::{info, warn};
17
18use crate::{
19 config::Config,
20 config_reloader::{self, ReloadResponse},
21 cooldown,
22 endpoint::{EndpointMetrics, LoadBalancerError, RpcEndpoint},
23 forwarder, health,
24 metrics::{
25 COOLDOWNS_TRIGGERED, COOLDOWN_SECONDS_GAUGE, ENDPOINT_RATE_LIMIT_DEFERRED,
26 HEALTHCHECK_FAILED, HEALTHY_ENDPOINTS, RPC_REQUESTS_FAILED, RPC_REQUESTS_SUCCEEDED,
27 TOTAL_ENDPOINTS,
28 },
29 shutdown::ShutdownManager,
30 strategy,
31};
32
/// Derives a short, label-friendly name from an endpoint URL.
///
/// The result is the last two dot-separated labels of the URL's host, joined
/// with `_`, with every `-` replaced by `_`, and lowercased. For example
/// `https://api.mainnet-beta.solana.com/path` becomes `solana_com`.
///
/// The scheme (`...://`), any `user:password@` prefix, the port, and
/// everything from the first `/`, `?` or `#` onward are ignored, so
/// credentials or API keys embedded in the URL never end up in the name.
/// Input without a scheme (e.g. `localhost:8899`) is treated as starting
/// directly with the host. An empty input yields an empty string.
pub(crate) fn extract_domain_name(url: &str) -> String {
    // Everything after the first "://", or the whole input when no scheme is present.
    let after_scheme = url.split("://").nth(1).unwrap_or(url);

    // The authority ends at the first path, query or fragment delimiter. Cutting
    // only at '/' would let `host.com?api-key=...` leak the query into the name.
    let authority = after_scheme
        .split(|c: char| matches!(c, '/' | '?' | '#'))
        .next()
        .unwrap_or("");

    // Drop optional userinfo ("user:pass@") before looking for the port separator.
    let host_and_port = authority.rsplit('@').next().unwrap_or(authority);
    let host = host_and_port.split(':').next().unwrap_or("");

    // Keep at most the last two labels (registrable-looking part of the host).
    let labels: Vec<&str> = host.split('.').collect();
    let first_kept = labels.len().saturating_sub(2);

    labels[first_kept..]
        .join("_")
        .replace('-', "_")
        .to_lowercase()
}
55
56#[derive(Debug)]
57pub struct LoadBalancer {
58 pub endpoints: Arc<RwLock<Vec<RpcEndpoint>>>,
59 pub rate_limiters: Arc<RwLock<HashMap<String, Arc<Mutex<ratelimit_meter::DirectRateLimiter>>>>>,
60 pub client: Arc<RwLock<Client>>,
61 pub bind_addr: String,
62 pub health_check_interval_secs: Arc<RwLock<u64>>,
63 pub health_check_timeout_secs: Arc<RwLock<u64>>,
64 pub base_cooldown_secs: Arc<RwLock<u64>>,
65 pub max_cooldown_secs: Arc<RwLock<u64>>,
66 pub max_batch_size: Arc<RwLock<usize>>,
67 pub latency_smoothing_factor: Arc<RwLock<f64>>,
68 pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
69 pub config_path: Arc<RwLock<String>>,
70 pub connect_timeout_ms: Arc<RwLock<u64>>,
72 pub timeout_secs: Arc<RwLock<u64>>,
73 pub pool_idle_timeout_secs: Arc<RwLock<u64>>,
74 pub pool_max_idle_per_host: Arc<RwLock<usize>>,
75}
76
77impl LoadBalancer {
78 pub fn new(config: Option<Config>, config_path: String) -> Self {
79 let config = config
81 .unwrap_or_default()
82 .finalize()
83 .expect("Configuration failed to finalize at startup");
84
85 let server_cfg = config.server.unwrap();
87 let balancer_cfg = config.balancer.unwrap();
88
89 let bind_addr = server_cfg.bind_addr.unwrap();
90 let health_check_interval = balancer_cfg.health_check_interval_secs.unwrap();
91 let health_check_timeout = balancer_cfg.health_check_timeout_secs.unwrap();
92 let base_cooldown = balancer_cfg.base_cooldown_secs.unwrap();
93 let max_cooldown = balancer_cfg.max_cooldown_secs.unwrap();
94 let max_batch_size = balancer_cfg.max_batch_size.unwrap();
95 let latency_smoothing_factor = balancer_cfg.latency_smoothing_factor.unwrap();
96 let max_concurrency = balancer_cfg.max_concurrency.unwrap();
97 let connect_timeout_ms = balancer_cfg.connect_timeout_ms.unwrap();
98 let timeout_secs = balancer_cfg.timeout_secs.unwrap();
99 let pool_idle_timeout_secs = balancer_cfg.pool_idle_timeout_secs.unwrap();
100 let pool_max_idle_per_host = balancer_cfg.pool_max_idle_per_host.unwrap();
101
102 let endpoints_list = balancer_cfg.endpoints.unwrap();
103
104 let endpoints: Vec<RpcEndpoint> = endpoints_list
105 .into_iter()
106 .enumerate()
107 .map(|(index, e)| {
108 let name = e.name.unwrap_or_else(|| {
110 let domain_name = extract_domain_name(&e.url);
111 format!("{:03}_{}", index + 1, domain_name)
112 });
113
114 RpcEndpoint {
115 name, url: e.url.clone(),
117 healthy: true,
118 last_check: Instant::now(),
119 cooldown_until: None,
120 cooldown_attempts: 0,
121 rate_limit_per_sec: e.rate_limit_per_sec,
122 burst_size: e.burst_size,
123 weight: e.weight.unwrap(),
124 metrics: Arc::new(EndpointMetrics {
125 ema_latency_ms: AtomicU64::new(timeout_secs * 1000),
126 ..Default::default()
127 }),
128 }
129 })
130 .collect();
131
132 TOTAL_ENDPOINTS.set(endpoints.len() as i64);
133 HEALTHY_ENDPOINTS.set(endpoints.len() as i64);
134
135 for ep in &endpoints {
136 COOLDOWN_SECONDS_GAUGE.with_label_values(&[&ep.name]).set(0);
137 COOLDOWNS_TRIGGERED.with_label_values(&[&ep.name]).inc_by(0);
138 ENDPOINT_RATE_LIMIT_DEFERRED.with_label_values(&[&ep.name]).inc_by(0);
139 RPC_REQUESTS_SUCCEEDED.with_label_values(&[&ep.name]).inc_by(0);
140 RPC_REQUESTS_FAILED.with_label_values(&[&ep.name]).inc_by(0);
141 HEALTHCHECK_FAILED.with_label_values(&[&ep.name]).inc_by(0);
142 }
143
144 let rate_limiters = Arc::new(RwLock::new(HashMap::new()));
145 for ep in &endpoints {
146 rate_limiters.write().insert(
147 ep.url.clone(),
148 config_reloader::create_rate_limiter(ep.burst_size, ep.rate_limit_per_sec),
149 );
150 }
151
152 let client = Client::builder()
153 .tcp_nodelay(true)
154 .connect_timeout(Duration::from_millis(connect_timeout_ms))
155 .timeout(Duration::from_secs(timeout_secs))
156 .pool_idle_timeout(Some(Duration::from_secs(pool_idle_timeout_secs)))
157 .pool_max_idle_per_host(pool_max_idle_per_host)
158 .http1_title_case_headers()
159 .build()
160 .expect("Failed to create HTTP client");
161
162 info!(
163 bind_addr = %bind_addr,
164 health_check_interval = health_check_interval,
165 health_check_timeout = health_check_timeout,
166 base_cooldown = base_cooldown,
167 max_cooldown = max_cooldown,
168 max_batch_size = max_batch_size,
169 max_concurrency = max_concurrency,
170 latency_smoothing_factor = latency_smoothing_factor,
171 connect_timeout_ms = connect_timeout_ms,
172 timeout_secs = timeout_secs,
173 pool_idle_timeout_secs = pool_idle_timeout_secs,
174 pool_max_idle_per_host = pool_max_idle_per_host,
175 endpoints_count = endpoints.len(),
176 "LoadBalancer initialized"
177 );
178
179 Self {
180 endpoints: Arc::new(RwLock::new(endpoints)),
181 rate_limiters,
182 client: Arc::new(RwLock::new(client)),
183 bind_addr,
184 health_check_interval_secs: Arc::new(RwLock::new(health_check_interval)),
185 health_check_timeout_secs: Arc::new(RwLock::new(health_check_timeout)),
186 base_cooldown_secs: Arc::new(RwLock::new(base_cooldown)),
187 max_cooldown_secs: Arc::new(RwLock::new(max_cooldown)),
188 max_batch_size: Arc::new(RwLock::new(max_batch_size)),
189 latency_smoothing_factor: Arc::new(RwLock::new(latency_smoothing_factor)),
190 concurrency_limiter: Arc::new(tokio::sync::Semaphore::new(max_concurrency)),
191 config_path: Arc::new(RwLock::new(config_path)),
192 connect_timeout_ms: Arc::new(RwLock::new(connect_timeout_ms)),
193 timeout_secs: Arc::new(RwLock::new(timeout_secs)),
194 pool_idle_timeout_secs: Arc::new(RwLock::new(pool_idle_timeout_secs)),
195 pool_max_idle_per_host: Arc::new(RwLock::new(pool_max_idle_per_host)),
196 }
197 }
198
199 pub fn reload_config(&self) -> Result<ReloadResponse, LoadBalancerError> {
200 config_reloader::reload(self)
201 }
202
203 pub fn get_next_endpoint(&self, batch_size: usize) -> Option<RpcEndpoint> {
205 let endpoints = self.endpoints.read();
206 let limiters = self.rate_limiters.read();
207
208 if let Some(endpoint) = strategy::select_best_endpoint(&endpoints, &limiters, batch_size) {
209 endpoint.metrics.last_selected.store(
211 std::time::SystemTime::now()
212 .duration_since(std::time::UNIX_EPOCH)
213 .unwrap_or_default()
214 .as_secs(),
215 std::sync::atomic::Ordering::Relaxed,
216 );
217 Some(endpoint.clone())
218 } else {
219 None
220 }
221 }
222
223 pub async fn forward_raw_request(
225 &self,
226 request_body: Bytes,
227 endpoint: &RpcEndpoint,
228 method: &str,
229 ) -> Result<Bytes, LoadBalancerError> {
230 forwarder::forward_request(self, request_body, endpoint, method).await
231 }
232
233 pub async fn mark_rate_limited(&self, url: &str) {
234 let mut endpoints = self.endpoints.write();
235 if let Some(ep) = endpoints.iter_mut().find(|e| e.url == url) {
236 cooldown::trigger_cooldown(
237 ep,
238 *self.base_cooldown_secs.read(),
239 *self.max_cooldown_secs.read(),
240 );
241 warn!(
242 endpoint = %ep.name,
243 cooldown_secs = ep.cooldown_remaining_secs(),
244 "Endpoint put into cooldown due to rate limiting"
245 );
246 }
247 }
248
249 pub async fn mark_unhealthy(&self, url: &str) {
250 let mut endpoints = self.endpoints.write();
251 if let Some(ep) = endpoints.iter_mut().find(|e| e.url == url) {
252 if ep.healthy {
253 ep.healthy = false;
254 cooldown::trigger_cooldown(
255 ep,
256 *self.base_cooldown_secs.read(),
257 *self.max_cooldown_secs.read(),
258 );
259 warn!(
260 endpoint = %ep.name,
261 cooldown_secs = ep.cooldown_remaining_secs(),
262 "Marked endpoint as unhealthy and put into cooldown"
263 );
264 }
265 }
266 }
267
268 pub async fn update_healthy_count(&self) {
269 let endpoints = self.endpoints.read();
270 let healthy_count = endpoints.iter().filter(|e| e.is_available()).count() as i64;
271 HEALTHY_ENDPOINTS.set(healthy_count);
272 }
273
274 pub fn run_background_tasks(self: &Arc<Self>, shutdown_manager: &mut ShutdownManager) {
275 let health_checker = self.clone();
276 shutdown_manager
277 .spawn_task(health::health_check_loop(health_checker, shutdown_manager.subscribe()));
278
279 let cooldown_updater = self.clone();
280 shutdown_manager.spawn_task(cooldown::cooldown_gauge_updater(
281 cooldown_updater,
282 shutdown_manager.subscribe(),
283 ));
284 }
285
286 fn get_endpoint_priority_stats(&self) -> Value {
287 let endpoints = self.endpoints.read();
288 let mut stats: Vec<_> = endpoints
289 .iter()
290 .map(|ep| {
291 serde_json::json!({
292 "name": ep.name,
293 "weight": ep.weight,
294 "healthy": ep.healthy,
295 "available": ep.is_available(),
296 "in_cooldown": ep.is_in_cooldown(),
297 "cooldown_remaining_secs": ep.cooldown_remaining_secs(),
298 "total_cost_ms": ep.metrics.get_total_cost(),
299 "last_selected_secs_ago": std::time::SystemTime::now()
300 .duration_since(std::time::UNIX_EPOCH)
301 .unwrap_or_default()
302 .as_secs()
303 .saturating_sub(ep.metrics.last_selected.load(Ordering::Relaxed)),
304 })
305 })
306 .collect();
307
308 stats.sort_by(|a, b| {
309 let weight_a = a["weight"].as_u64().unwrap_or(0);
310 let weight_b = b["weight"].as_u64().unwrap_or(0);
311 let cost_a = a["total_cost_ms"].as_u64().unwrap_or(u64::MAX);
312 let cost_b = b["total_cost_ms"].as_u64().unwrap_or(u64::MAX);
313 let time_a = a["last_selected_secs_ago"].as_u64().unwrap_or(0);
314 let time_b = b["last_selected_secs_ago"].as_u64().unwrap_or(0);
315
316 weight_b
317 .cmp(&weight_a)
318 .then_with(|| cost_a.cmp(&cost_b))
319 .then_with(|| time_b.cmp(&time_a)) });
321
322 let ranked_stats: Vec<_> = stats
323 .into_iter()
324 .enumerate()
325 .map(|(i, mut stat)| {
326 stat["priority_rank"] = serde_json::json!(i + 1);
327 stat
328 })
329 .collect();
330
331 serde_json::json!({
332 "strategy": "priority_based_with_lru_tie_breaking",
333 "endpoints": ranked_stats
334 })
335 }
336
337 pub fn get_status(&self) -> Value {
338 let endpoints = self.endpoints.read();
339 let available_endpoints: Vec<&RpcEndpoint> =
340 endpoints.iter().filter(|ep| ep.is_available()).collect();
341
342 let healthy_count = available_endpoints.len();
343 let config_path = self.config_path.read().clone();
344
345 serde_json::json!({
346 "bind_addr": self.bind_addr,
347 "config_path": config_path,
348 "total_endpoints": endpoints.len(),
349 "healthy_endpoints": healthy_count,
350 "balancer_settings": {
351 "health_check_interval_secs": *self.health_check_interval_secs.read(),
352 "health_check_timeout_secs": *self.health_check_timeout_secs.read(),
353 "base_cooldown_secs": *self.base_cooldown_secs.read(),
354 "max_cooldown_secs": *self.max_cooldown_secs.read(),
355 "max_batch_size": *self.max_batch_size.read(),
356 "latency_smoothing_factor": *self.latency_smoothing_factor.read(),
357 },
358 "client_settings": {
359 "connect_timeout_ms": *self.connect_timeout_ms.read(),
360 "timeout_secs": *self.timeout_secs.read(),
361 "pool_idle_timeout_secs": *self.pool_idle_timeout_secs.read(),
362 "pool_max_idle_per_host": *self.pool_max_idle_per_host.read(),
363 },
364 "priority_info": self.get_endpoint_priority_stats(),
365 "endpoints_details": endpoints.iter().map(|ep| {
366 serde_json::json!({
367 "name": ep.name, "healthy": ep.healthy,
370 "available": ep.is_available(),
371 "in_cooldown": ep.is_in_cooldown(),
372 "cooldown_remaining_secs": ep.cooldown_remaining_secs(),
373 "cooldown_attempts": ep.cooldown_attempts,
374 "consecutive_failures": ep.metrics.consecutive_failures.load(Ordering::Relaxed),
375 "ema_latency_ms": ep.metrics.ema_latency_ms.load(Ordering::Relaxed),
376 "ema_error_penalty_ms": ep.metrics.ema_error_penalty_ms.load(Ordering::Relaxed),
377 "total_cost_ms": ep.metrics.get_total_cost(),
378 "last_selected_timestamp": ep.metrics.last_selected.load(Ordering::Relaxed),
379 "last_check_secs_ago": ep.last_check.elapsed().as_secs(),
380 "rate_limit_per_sec": ep.rate_limit_per_sec,
381 "burst_size": ep.burst_size,
382 "weight": ep.weight
383 })
384 }).collect::<Vec<_>>()
385 })
386 }
387}