use crate::core::{Result, RpcEndpoint, SolanaRecoverError};
use crate::rpc::{ConnectionPoolTrait, RpcClientWrapper};
use crate::utils::enhanced_metrics::ConnectionPoolMetrics;
use async_trait::async_trait;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use solana_client::rpc_client::RpcClient;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

#[cfg(test)]
include!("enhanced_pool_tests.rs");

/// Minimal RPC surface shared by the raw client wrapper and the
/// metrics-aware decorator defined later in this module.
#[async_trait]
pub trait RpcClientTrait: Send + Sync {
    async fn get_minimum_balance_for_rent_exemption(&self, account_size: usize) -> Result<u64>;
}
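/// Connection pool that layers weighted load balancing, periodic health
/// checks, and per-endpoint circuit breakers over a set of Solana RPC
/// endpoints.
///
/// A minimal usage sketch (`my_endpoints` is hypothetical; construct the
/// `RpcEndpoint`s however `crate::core` defines them):
///
/// ```ignore
/// let config = EnhancedPoolConfig {
///     endpoints: my_endpoints,
///     ..EnhancedPoolConfig::default()
/// };
/// let pool = Arc::new(EnhancedConnectionPool::with_config(config));
/// pool.clone().start_health_checks().await;
/// let client = pool.get_client().await?;
/// ```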
pub struct EnhancedConnectionPool {
    endpoints: Arc<RwLock<Vec<WeightedEndpoint>>>,
    connection_pools: Arc<DashMap<String, Arc<BasicConnectionPool>>>,
    health_checker: Arc<HealthChecker>,
    load_balancer: Arc<LoadBalancer>,
    circuit_breakers: Arc<DashMap<String, Arc<CircuitBreaker>>>,
    metrics: Arc<RwLock<EnhancedPoolMetrics>>,
    #[allow(dead_code)]
    config: EnhancedPoolConfig,
}
#[derive(Debug, Clone, Serialize)]
pub struct WeightedEndpoint {
    pub endpoint: RpcEndpoint,
    pub weight: f64,
    pub priority: u8,
    pub region: String,
    pub response_time_ms: f64,
    pub success_rate: f64,
    pub last_health_check_ms: Option<u64>,
    pub consecutive_failures: u32,
}
#[derive(Debug, Clone)]
pub struct EnhancedPoolConfig {
    pub endpoints: Vec<RpcEndpoint>,
    pub max_connections_per_endpoint: usize,
    pub health_check_interval: Duration,
    pub circuit_breaker_threshold: u32,
    pub circuit_breaker_timeout: Duration,
    pub enable_load_balancing: bool,
    pub request_timeout: Duration,
    pub enable_connection_multiplexing: bool,
    pub enable_compression: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalanceStrategy {
    /// Cycle through healthy endpoints in order.
    RoundRobin,
    /// Sample endpoints with probability proportional to their weight.
    WeightedRoundRobin,
    /// Intended to prefer the endpoint with the fewest open connections;
    /// currently approximated by response time (see `LoadBalancer`).
    LeastConnections,
    /// Prefer the endpoint with the lowest average response time.
    ResponseTime,
}
#[derive(Debug, Default, Clone, Serialize)]
pub struct EnhancedPoolMetrics {
    pub total_requests: u64,
    pub successful_requests: u64,
    pub failed_requests: u64,
    pub active_connections: u64,
    pub avg_response_time_ms: f64,
    pub endpoint_metrics: HashMap<String, EndpointMetrics>,
    pub circuit_breaker_activations: u64,
    pub last_health_check: Option<chrono::DateTime<chrono::Utc>>,
}
#[derive(Debug, Default, Clone, Serialize)]
pub struct EndpointMetrics {
    pub requests: u64,
    pub successes: u64,
    pub failures: u64,
    pub avg_response_time_ms: f64,
    pub last_success_ms: Option<u64>,
    pub last_failure_ms: Option<u64>,
}
impl Default for EnhancedPoolConfig {
    fn default() -> Self {
        Self {
            endpoints: vec![],
            max_connections_per_endpoint: 50,
            health_check_interval: Duration::from_secs(30),
            circuit_breaker_threshold: 5,
            circuit_breaker_timeout: Duration::from_secs(60),
            enable_load_balancing: true,
            request_timeout: Duration::from_secs(30),
            enable_connection_multiplexing: true,
            enable_compression: true,
        }
    }
}
impl EnhancedConnectionPool {
    pub fn with_config(config: EnhancedPoolConfig) -> Self {
        let weighted_endpoints: Vec<WeightedEndpoint> = config
            .endpoints
            .iter()
            .enumerate()
            .map(|(i, endpoint)| WeightedEndpoint {
                endpoint: endpoint.clone(),
                // Seed weights favor earlier-listed endpoints until real
                // response-time data replaces these defaults.
                weight: 1.0 / (i as f64 + 1.0),
                priority: endpoint.priority,
                region: Self::extract_region(&endpoint.url),
                response_time_ms: 100.0,
                success_rate: 1.0,
                last_health_check_ms: None,
                consecutive_failures: 0,
            })
            .collect();
        let pool = Self {
            endpoints: Arc::new(RwLock::new(weighted_endpoints)),
            connection_pools: Arc::new(DashMap::new()),
            health_checker: Arc::new(HealthChecker::new(config.health_check_interval)),
            load_balancer: Arc::new(LoadBalancer::new(LoadBalanceStrategy::WeightedRoundRobin)),
            circuit_breakers: Arc::new(DashMap::new()),
            metrics: Arc::new(RwLock::new(EnhancedPoolMetrics::default())),
            config: config.clone(),
        };
        pool.initialize_components(config);
        pool
    }
    /// Best-effort region inference from the endpoint URL, used for
    /// diagnostics and region-aware reporting.
    fn extract_region(url: &str) -> String {
        if url.contains("us-east") {
            "us-east".to_string()
        } else if url.contains("us-west") {
            "us-west".to_string()
        } else if url.contains("eu") {
            "eu-west".to_string()
        } else {
            "global".to_string()
        }
    }
    fn initialize_components(&self, config: EnhancedPoolConfig) {
        // Build per-endpoint pools and circuit breakers directly from the
        // config rather than via `self.endpoints.blocking_read()`, which
        // panics when called from inside a Tokio runtime.
        for endpoint in &config.endpoints {
            let pool = Arc::new(BasicConnectionPool::new(
                endpoint.clone(),
                config.max_connections_per_endpoint,
            ));
            self.connection_pools.insert(endpoint.url.clone(), pool);
            let circuit_breaker = Arc::new(CircuitBreaker::new(
                endpoint.url.clone(),
                config.circuit_breaker_threshold,
                config.circuit_breaker_timeout,
            ));
            self.circuit_breakers.insert(endpoint.url.clone(), circuit_breaker);
        }
    }
    /// Spawns the background health-check loop; takes `Arc<Self>` so the
    /// spawned task can keep the pool alive.
    pub async fn start_health_checks(self: Arc<Self>) {
        self.health_checker.start(self.clone()).await;
    }
    async fn select_endpoint(&self) -> Result<String> {
        let endpoints = self.endpoints.read().await;
        self.load_balancer.select_endpoint(&endpoints).await
    }
    async fn get_client_for_endpoint(&self, endpoint_url: &str) -> Result<Arc<RpcClientWrapper>> {
        // Clone the Arcs out of the DashMap so no shard lock is held across
        // the `.await` points below.
        let circuit_breaker = self
            .circuit_breakers
            .get(endpoint_url)
            .map(|entry| entry.value().clone())
            .ok_or_else(|| SolanaRecoverError::ConfigError("No circuit breaker for endpoint".to_string()))?;
        if !circuit_breaker.allow_request().await {
            return Err(SolanaRecoverError::NetworkError("Circuit breaker is open".to_string()));
        }
        let pool = self
            .connection_pools
            .get(endpoint_url)
            .map(|entry| entry.value().clone())
            .ok_or_else(|| SolanaRecoverError::ConfigError("No connection pool for endpoint".to_string()))?;
        let client = pool.get_client().await?;
        {
            let mut metrics = self.metrics.write().await;
            metrics.total_requests += 1;
            let endpoint_metrics = metrics
                .endpoint_metrics
                .entry(endpoint_url.to_string())
                .or_default();
            endpoint_metrics.requests += 1;
        }
        Ok(client)
    }
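    /// Blends each observation into the endpoint's rolling statistics with an
    /// exponentially weighted moving average (decay factor 0.9), so a single
    /// outlier moves the average by only 10% of the difference. For example,
    /// a 500 ms response against a 100 ms average yields
    /// `0.9 * 100.0 + 0.1 * 500.0 = 140.0` ms.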
    pub async fn update_endpoint_metrics(&self, endpoint_url: &str, success: bool, response_time_ms: f64) {
        let mut endpoints = self.endpoints.write().await;
        if let Some(endpoint) = endpoints.iter_mut().find(|e| e.endpoint.url == endpoint_url) {
            if success {
                endpoint.success_rate = (endpoint.success_rate * 0.9) + 0.1;
                endpoint.response_time_ms =
                    (endpoint.response_time_ms * 0.9) + (response_time_ms * 0.1);
                endpoint.consecutive_failures = 0;
                // Faster and more reliable endpoints earn a higher weight.
                endpoint.weight =
                    1.0 / (1.0 + endpoint.response_time_ms / 1000.0) * endpoint.success_rate;
            } else {
                endpoint.consecutive_failures += 1;
                endpoint.success_rate *= 0.9;
                endpoint.weight *= 0.8;
            }
            endpoint.last_health_check_ms = Some(Self::now_ms());
        }
        drop(endpoints);
        let mut metrics = self.metrics.write().await;
        if success {
            metrics.successful_requests += 1;
            if let Some(endpoint_metrics) = metrics.endpoint_metrics.get_mut(endpoint_url) {
                endpoint_metrics.successes += 1;
                endpoint_metrics.last_success_ms = Some(Self::now_ms());
                // Incremental mean over successful requests only.
                endpoint_metrics.avg_response_time_ms = (endpoint_metrics.avg_response_time_ms
                    * (endpoint_metrics.successes - 1) as f64
                    + response_time_ms)
                    / endpoint_metrics.successes as f64;
            }
        } else {
            metrics.failed_requests += 1;
            if let Some(endpoint_metrics) = metrics.endpoint_metrics.get_mut(endpoint_url) {
                endpoint_metrics.failures += 1;
                endpoint_metrics.last_failure_ms = Some(Self::now_ms());
            }
        }
    }
    /// Milliseconds since the Unix epoch, used for coarse timestamps.
    fn now_ms() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_millis() as u64
    }
    pub async fn get_metrics(&self) -> EnhancedPoolMetrics {
        // `EnhancedPoolMetrics` is `Clone`, so a field-by-field copy is
        // unnecessary.
        self.metrics.read().await.clone()
    }
}
#[async_trait]
impl ConnectionPoolTrait for EnhancedConnectionPool {
    async fn get_client(&self) -> Result<Arc<RpcClientWrapper>> {
        let endpoint_url = self.select_endpoint().await?;
        self.get_client_for_endpoint(&endpoint_url).await
    }
}
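/// Decorator that records per-call outcomes into shared
/// `ConnectionPoolMetrics` while delegating the RPC work to the wrapped
/// client.
///
/// A minimal wiring sketch (assumes `ConnectionPoolMetrics` implements
/// `Default`; the endpoint URL and pool are whatever the caller owns):
///
/// ```ignore
/// let metrics = Arc::new(RwLock::new(ConnectionPoolMetrics::default()));
/// let client = pool.get_client().await?;
/// let instrumented = MetricsAwareClient::new(client, endpoint_url, metrics);
/// let rent = instrumented.get_minimum_balance_for_rent_exemption(165).await?;
/// ```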
pub struct MetricsAwareClient {
    client: Arc<RpcClientWrapper>,
    #[allow(dead_code)]
    endpoint_url: String,
    metrics: Arc<RwLock<ConnectionPoolMetrics>>,
}
impl MetricsAwareClient {
    pub fn new(client: Arc<RpcClientWrapper>, endpoint_url: String, metrics: Arc<RwLock<ConnectionPoolMetrics>>) -> Self {
        Self {
            client,
            endpoint_url,
            metrics,
        }
    }
}
#[async_trait]
impl RpcClientTrait for RpcClientWrapper {
    async fn get_minimum_balance_for_rent_exemption(&self, account_size: usize) -> Result<u64> {
        // Delegates to the wrapper's inherent method of the same name;
        // inherent methods take precedence over trait methods, so this call
        // does not recurse.
        self.get_minimum_balance_for_rent_exemption(account_size).await
    }
}
#[async_trait]
impl RpcClientTrait for MetricsAwareClient {
    async fn get_minimum_balance_for_rent_exemption(&self, account_size: usize) -> Result<u64> {
        let result = self.client.get_minimum_balance_for_rent_exemption(account_size).await;
        // Record the outcome in the shared pool metrics.
        {
            let mut metrics = self.metrics.write().await;
            if result.is_ok() {
                metrics.active_connections += 1;
            } else {
                metrics.connection_errors += 1;
            }
        }
        result
    }
}
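/// Periodically probes every endpoint with a lightweight RPC call and feeds
/// the results into both the endpoint health flags and the circuit breakers.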
pub struct HealthChecker {
    check_interval: Duration,
}
impl HealthChecker {
    pub fn new(check_interval: Duration) -> Self {
        Self { check_interval }
    }
    /// Spawns a detached task that re-checks every endpoint on a fixed
    /// interval for as long as the pool `Arc` stays alive.
    pub async fn start(&self, pool: Arc<EnhancedConnectionPool>) {
        let interval = self.check_interval;
        tokio::spawn(async move {
            let mut ticker = tokio::time::interval(interval);
            loop {
                ticker.tick().await;
                Self::perform_health_check(&pool).await;
            }
        });
    }
    async fn perform_health_check(pool: &EnhancedConnectionPool) {
        // Snapshot the URLs first so the read lock is not held across the
        // network probes below.
        let endpoint_urls: Vec<String> = {
            let endpoints = pool.endpoints.read().await;
            endpoints.iter().map(|e| e.endpoint.url.clone()).collect()
        };
        for endpoint_url in endpoint_urls {
            let is_healthy = Self::check_endpoint_health(&endpoint_url).await;
            {
                let mut endpoints_guard = pool.endpoints.write().await;
                if let Some(ep) = endpoints_guard.iter_mut().find(|e| e.endpoint.url == endpoint_url) {
                    if is_healthy {
                        ep.consecutive_failures = 0;
                        ep.endpoint.healthy = true;
                    } else {
                        ep.consecutive_failures += 1;
                        // Only mark an endpoint unhealthy after three
                        // consecutive failed probes.
                        if ep.consecutive_failures >= 3 {
                            ep.endpoint.healthy = false;
                        }
                    }
                }
            }
            // Clone the Arc out of the DashMap so no shard lock is held
            // across the `.await` points below.
            let circuit_breaker = pool
                .circuit_breakers
                .get(&endpoint_url)
                .map(|entry| entry.value().clone());
            if let Some(circuit_breaker) = circuit_breaker {
                if is_healthy {
                    circuit_breaker.record_success().await;
                } else {
                    circuit_breaker.record_failure().await;
                }
            }
        }
        pool.metrics.write().await.last_health_check = Some(chrono::Utc::now());
    }
    async fn check_endpoint_health(url: &str) -> bool {
        // Probe with a lightweight call on the blocking client, moved onto
        // the blocking thread pool; the outer timeout bounds the whole probe.
        let client = RpcClient::new_with_timeout(url.to_string(), Duration::from_secs(5));
        tokio::time::timeout(Duration::from_secs(5), async {
            tokio::task::spawn_blocking(move || client.get_latest_blockhash().is_ok())
                .await
                .unwrap_or(false)
        })
        .await
        .unwrap_or(false)
    }
}
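/// Picks the next endpoint according to the configured strategy.
///
/// For `WeightedRoundRobin`, selection is probabilistic: an endpoint's chance
/// of being picked is its weight divided by the total weight. For example,
/// with weights 1.0, 0.5, and 0.5, the first endpoint is chosen about half
/// the time and the other two about a quarter of the time each.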
pub struct LoadBalancer {
    strategy: LoadBalanceStrategy,
    round_robin_counter: tokio::sync::Mutex<usize>,
}
impl LoadBalancer {
    pub fn new(strategy: LoadBalanceStrategy) -> Self {
        Self {
            strategy,
            round_robin_counter: tokio::sync::Mutex::new(0),
        }
    }
    async fn select_endpoint(&self, endpoints: &[WeightedEndpoint]) -> Result<String> {
        let healthy_endpoints: Vec<&WeightedEndpoint> = endpoints
            .iter()
            .filter(|e| e.endpoint.healthy)
            .collect();
        if healthy_endpoints.is_empty() {
            return Err(SolanaRecoverError::ConfigError("No healthy endpoints available".to_string()));
        }
        match self.strategy {
            LoadBalanceStrategy::RoundRobin => {
                let mut counter = self.round_robin_counter.lock().await;
                let index = *counter % healthy_endpoints.len();
                *counter = counter.wrapping_add(1);
                Ok(healthy_endpoints[index].endpoint.url.clone())
            }
            LoadBalanceStrategy::WeightedRoundRobin => {
                // Sample an endpoint with probability proportional to its weight.
                let total_weight: f64 = healthy_endpoints.iter().map(|e| e.weight).sum();
                let mut random_weight = rand::random::<f64>() * total_weight;
                for endpoint in &healthy_endpoints {
                    random_weight -= endpoint.weight;
                    if random_weight <= 0.0 {
                        return Ok(endpoint.endpoint.url.clone());
                    }
                }
                // Floating-point rounding can exhaust the loop; fall back to
                // the first healthy endpoint.
                Ok(healthy_endpoints[0].endpoint.url.clone())
            }
            // Per-endpoint connection counts are not tracked here, so
            // `LeastConnections` uses response time as a proxy and behaves
            // like `ResponseTime`.
            LoadBalanceStrategy::LeastConnections | LoadBalanceStrategy::ResponseTime => {
                let endpoint = healthy_endpoints
                    .iter()
                    .min_by(|a, b| a.response_time_ms.total_cmp(&b.response_time_ms))
                    .expect("non-empty after the healthy check above");
                Ok(endpoint.endpoint.url.clone())
            }
        }
    }
}
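/// Three-state circuit breaker (`Closed` -> `Open` -> `HalfOpen`) guarding a
/// single endpoint: repeated failures open the circuit, requests are refused
/// until `timeout` elapses, and then a single probe is allowed through.
///
/// A minimal sketch of the intended call pattern (the URL is illustrative
/// and `do_rpc_call` is a hypothetical placeholder):
///
/// ```ignore
/// let breaker = CircuitBreaker::new("https://example-rpc.invalid".into(), 5, Duration::from_secs(60));
/// if breaker.allow_request().await {
///     match do_rpc_call().await {
///         Ok(_) => breaker.record_success().await,
///         Err(_) => breaker.record_failure().await,
///     }
/// }
/// ```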
pub struct CircuitBreaker {
    #[allow(dead_code)]
    endpoint_url: String,
    failure_threshold: u32,
    timeout: Duration,
    state: tokio::sync::Mutex<CircuitBreakerState>,
    consecutive_failures: tokio::sync::Mutex<u32>,
    last_state_change: tokio::sync::Mutex<Instant>,
}
#[derive(Debug, Clone)]
enum CircuitBreakerState {
    Closed,
    Open,
    HalfOpen,
}
impl CircuitBreaker {
    pub fn new(endpoint_url: String, failure_threshold: u32, timeout: Duration) -> Self {
        Self {
            endpoint_url,
            failure_threshold,
            timeout,
            state: tokio::sync::Mutex::new(CircuitBreakerState::Closed),
            consecutive_failures: tokio::sync::Mutex::new(0),
            last_state_change: tokio::sync::Mutex::new(Instant::now()),
        }
    }
    pub async fn allow_request(&self) -> bool {
        let state = self.state.lock().await;
        let last_change = self.last_state_change.lock().await;
        match *state {
            CircuitBreakerState::Closed => true,
            CircuitBreakerState::Open => {
                if last_change.elapsed() > self.timeout {
                    // The cool-down has elapsed; let one probe request
                    // through in the half-open state.
                    drop(state);
                    drop(last_change);
                    self.transition_to_half_open().await;
                    true
                } else {
                    false
                }
            }
            CircuitBreakerState::HalfOpen => true,
        }
    }
    pub async fn record_success(&self) {
        let state = self.state.lock().await;
        match *state {
            CircuitBreakerState::HalfOpen => {
                drop(state);
                self.transition_to_closed().await;
            }
            CircuitBreakerState::Closed => {
                // A success while closed resets the failure streak.
                drop(state);
                *self.consecutive_failures.lock().await = 0;
            }
            CircuitBreakerState::Open => {}
        }
    }
    pub async fn record_failure(&self) {
        let state = self.state.lock().await;
        match *state {
            CircuitBreakerState::Closed => {
                drop(state);
                // Open the circuit only once `failure_threshold` consecutive
                // failures accumulate.
                let mut failures = self.consecutive_failures.lock().await;
                *failures += 1;
                if *failures >= self.failure_threshold {
                    *failures = 0;
                    drop(failures);
                    self.transition_to_open().await;
                }
            }
            CircuitBreakerState::HalfOpen => {
                drop(state);
                self.transition_to_open().await;
            }
            CircuitBreakerState::Open => {}
        }
    }
    async fn transition_to_closed(&self) {
        *self.state.lock().await = CircuitBreakerState::Closed;
        *self.consecutive_failures.lock().await = 0;
        *self.last_state_change.lock().await = Instant::now();
    }
    async fn transition_to_open(&self) {
        *self.state.lock().await = CircuitBreakerState::Open;
        *self.last_state_change.lock().await = Instant::now();
    }
    async fn transition_to_half_open(&self) {
        *self.state.lock().await = CircuitBreakerState::HalfOpen;
        *self.last_state_change.lock().await = Instant::now();
    }
}
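/// Simple bounded stack of reusable `RpcClientWrapper`s for one endpoint:
/// `get_client` pops an idle client or creates a new one, and
/// `return_client` pushes it back unless the pool is already at capacity.
///
/// A minimal check-out/check-in sketch (the endpoint value is hypothetical):
///
/// ```ignore
/// let pool = BasicConnectionPool::new(endpoint, 50);
/// let client = pool.get_client().await?;
/// // ... use `client` ...
/// pool.return_client(client).await;
/// ```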
pub struct BasicConnectionPool {
    endpoint: RpcEndpoint,
    clients: tokio::sync::Mutex<Vec<Arc<RpcClientWrapper>>>,
    max_connections: usize,
}
impl BasicConnectionPool {
    pub fn new(endpoint: RpcEndpoint, max_connections: usize) -> Self {
        Self {
            endpoint,
            clients: tokio::sync::Mutex::new(Vec::with_capacity(max_connections)),
            max_connections,
        }
    }
    pub async fn get_client(&self) -> Result<Arc<RpcClientWrapper>> {
        let mut clients = self.clients.lock().await;
        if let Some(client) = clients.pop() {
            Ok(client)
        } else {
            // Release the lock before building a fresh client.
            drop(clients);
            self.create_client().await
        }
    }
    async fn create_client(&self) -> Result<Arc<RpcClientWrapper>> {
        let client = Arc::new(RpcClientWrapper::from_url(
            &self.endpoint.url,
            self.endpoint.timeout_ms,
        )?);
        Ok(client)
    }
    pub async fn return_client(&self, client: Arc<RpcClientWrapper>) {
        let mut clients = self.clients.lock().await;
        // Keep the client for reuse only while under capacity; otherwise it
        // is dropped here.
        if clients.len() < self.max_connections {
            clients.push(client);
        }
    }
}