use std::sync::Arc;
use log::*;
use super::{
config::ConnectivityConfig,
connection_pool::{ConnectionPool, ConnectionStatus},
error::ConnectivityError,
};
use crate::{
connection_manager::ConnectionManagerRequester,
peer_manager::{NodeId, Peer, PeerManager},
};
/// Log target passed to the logging macros used in this module.
const LOG_TARGET: &str = "comms::connectivity::proactive_dialer";
/// Hard upper bound on the number of peers selected for dialing in a single proactive dialing
/// pass; `ProactiveDialer::calculate_dial_count` never returns more than this.
const MAX_CONCURRENT_DIALS: usize = 30;
/// Tops up the connection pool by proactively sending dial requests for known peers whenever
/// the number of connected nodes is below the configured target.
pub struct ProactiveDialer {
/// Connectivity settings; read for the enable flag, the target connection count, the dialing
/// multiplier, the success-rate tracking window and the circuit breaker retry interval.
config: ConnectivityConfig,
/// Handle used to send dial requests to the connection manager.
connection_manager: ConnectionManagerRequester,
/// Source of dial candidates.
peer_manager: Arc<PeerManager>,
/// Identity of this node; its node id is the reference point for the distance tie-break used
/// when ranking candidates.
node_identity: Arc<crate::NodeIdentity>,
}
impl ProactiveDialer {
    /// Success rate assumed when no connection statistics are available.
    const DEFAULT_SUCCESS_RATE: f32 = 0.25;
    /// Lowest success rate used in calculations. Keeps the dial count finite when (almost) every
    /// recent dial has failed.
    const MIN_SUCCESS_RATE: f32 = 0.1;
    /// Number of candidates requested from the peer manager per peer we intend to dial, so that
    /// the circuit breaker filter and the health ranking have something to choose from.
    const CANDIDATE_OVERSAMPLING: usize = 3;

    /// Creates a new dialer from its collaborators. No work is performed until
    /// [`execute_proactive_dialing`](Self::execute_proactive_dialing) is called.
    pub fn new(
        config: ConnectivityConfig,
        connection_manager: ConnectionManagerRequester,
        peer_manager: Arc<PeerManager>,
        node_identity: Arc<crate::NodeIdentity>,
    ) -> Self {
        Self {
            config,
            connection_manager,
            peer_manager,
            node_identity,
        }
    }

    /// Runs one proactive dialing pass.
    ///
    /// Does nothing when proactive dialing is disabled in the config or when the pool already
    /// holds at least `target_connection_count` connected nodes. Otherwise the number of peers to
    /// dial is derived from the shortfall and the recent success rate, candidates are selected,
    /// and a dial request is sent for each of them.
    ///
    /// * `pool` - current connection pool; provides the connected-node count and the set of
    ///   peers that are already managed.
    /// * `connection_stats` - per-peer connection statistics used for the success rate, the
    ///   circuit breaker check and the health ranking.
    /// * `excluded_peers` - peers passed as exclusions to the primary candidate query.
    /// * `task_id` - identifier that is only used to correlate log lines.
    ///
    /// Returns the number of dial requests that were successfully handed to the connection
    /// manager. This is not the number of connections that end up being established.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while querying the peer manager for candidates.
    pub async fn execute_proactive_dialing(
        &mut self,
        pool: &ConnectionPool,
        connection_stats: &std::collections::HashMap<NodeId, super::connection_stats::PeerConnectionStats>,
        excluded_peers: &[NodeId],
        task_id: u64,
    ) -> Result<usize, ConnectivityError> {
        if !self.config.proactive_dialing_enabled {
            return Ok(0);
        }

        let current_connections = pool.count_connected_nodes();
        let target = self.config.target_connection_count;
        if current_connections >= target {
            debug!(
                target: LOG_TARGET,
                "({task_id}) Current connections ({current_connections}) meet or exceed target ({target}), no proactive dialing needed",
            );
            return Ok(0);
        }

        let needed = target.saturating_sub(current_connections);
        debug!(
            target: LOG_TARGET,
            "({task_id}) Proactive dialing: need {needed} more connections ({current_connections}/{target})",
        );

        let success_rate = self.calculate_recent_success_rate(connection_stats);
        let dial_count = self.calculate_dial_count(needed, success_rate);
        debug!(
            target: LOG_TARGET,
            "({task_id}) Success rate: {success_rate:.2}, will dial {dial_count} peers for {needed} needed connections",
        );

        let candidates = self
            .select_dial_candidates(pool, connection_stats, excluded_peers, dial_count, task_id)
            .await?;
        if candidates.is_empty() {
            warn!(
                target: LOG_TARGET,
                "({task_id}) No peer candidates available for proactive dialing"
            );
            return Ok(0);
        }

        let dialed_count = self.dial_peers_concurrently(candidates, task_id).await;
        info!(
            target: LOG_TARGET,
            "({task_id}) Proactive dialing initiated for {dialed_count} peers ({needed} needed connections)"
        );
        Ok(dialed_count)
    }

    /// Returns the mean of the per-peer success rates over the configured tracking window,
    /// clamped to `[MIN_SUCCESS_RATE, 1.0]`.
    ///
    /// When there are no statistics at all, `DEFAULT_SUCCESS_RATE` is returned instead.
    fn calculate_recent_success_rate(
        &self,
        connection_stats: &std::collections::HashMap<NodeId, super::connection_stats::PeerConnectionStats>,
    ) -> f32 {
        if connection_stats.is_empty() {
            return Self::DEFAULT_SUCCESS_RATE;
        }

        let window = self.config.success_rate_tracking_window;
        let rate_sum: f32 = connection_stats
            .values()
            .map(|stats| stats.success_rate(window))
            .sum();
        // The map is non-empty here, so the divisor is at least 1.
        let average = rate_sum / connection_stats.len() as f32;
        average.clamp(Self::MIN_SUCCESS_RATE, 1.0)
    }

    /// Computes how many peers to dial in order to end up with roughly `needed` new connections.
    ///
    /// The shortfall is scaled by the configured `dialing_multiplier` and divided by the success
    /// rate (floored at `MIN_SUCCESS_RATE`), rounded up. The result is at least `needed` and at
    /// most `MAX_CONCURRENT_DIALS`; the upper bound wins when `needed` itself exceeds it.
    fn calculate_dial_count(&self, needed: usize, success_rate: f32) -> usize {
        let base_count = needed as f32 * self.config.dialing_multiplier;
        let adjusted_count = base_count / success_rate.max(Self::MIN_SUCCESS_RATE);
        // Float-to-integer `as` casts saturate (and map NaN to 0), and the value is clamped
        // immediately below, so the truncation is harmless.
        #[allow(clippy::cast_possible_truncation)]
        let final_count = adjusted_count.ceil() as usize;
        final_count.max(needed).min(MAX_CONCURRENT_DIALS)
    }

    /// Selects up to `count` peers to dial, best candidates first.
    ///
    /// 1. Peers in the pool whose status is neither `Failed` nor `Disconnected` are treated as
    ///    already managed and are never returned.
    /// 2. Up to `count * CANDIDATE_OVERSAMPLING` candidates are requested from the peer manager,
    ///    excluding managed peers and `excluded_peers`. If that yields too few, a second query
    ///    with a different selection flag fills the remainder.
    /// 3. Peers whose statistics report that a connection should not be allowed yet (circuit
    ///    breaker) are dropped.
    /// 4. The remainder is ordered by descending health score (peers without statistics get a
    ///    neutral 0.5), ties broken by ascending distance from this node, and truncated to
    ///    `count`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the peer manager queries.
    async fn select_dial_candidates(
        &self,
        pool: &ConnectionPool,
        connection_stats: &std::collections::HashMap<NodeId, super::connection_stats::PeerConnectionStats>,
        excluded_peers: &[NodeId],
        count: usize,
        task_id: u64,
    ) -> Result<Vec<Peer>, ConnectivityError> {
        let managed: Vec<NodeId> = pool
            .all()
            .iter()
            .filter(|state| {
                !matches!(
                    state.status(),
                    ConnectionStatus::Failed | ConnectionStatus::Disconnected(_)
                )
            })
            .map(|state| state.node_id().clone())
            .collect();
        let managed_and_excluded: Vec<NodeId> = managed.iter().chain(excluded_peers).cloned().collect();

        let requested = count * Self::CANDIDATE_OVERSAMPLING;
        // The meaning of the two boolean flags is defined by `PeerManager` and is not visible
        // from this file.
        let mut candidates = self
            .peer_manager
            .get_available_dial_candidates(&managed_and_excluded, Some(requested), true, true)
            .await?;

        if candidates.len() < requested {
            // NOTE(review): this fallback query excludes the candidates found so far and the
            // managed peers, but NOT `excluded_peers`, so an excluded peer can be returned here.
            // The existing behaviour is preserved; confirm with the caller whether it is intended.
            let mut to_be_excluded: Vec<NodeId> = candidates.iter().map(|p| p.node_id.clone()).collect();
            to_be_excluded.extend(managed);
            let random = self
                .peer_manager
                .get_available_dial_candidates(&to_be_excluded, Some(requested - candidates.len()), false, true)
                .await?;
            candidates.extend(random);
        }

        // Circuit breaker: drop peers whose statistics say they must not be retried yet.
        let retry_interval = self.config.circuit_breaker_retry_interval;
        candidates.retain(|peer| {
            let allowed = connection_stats
                .get(&peer.node_id)
                .map_or(true, |stats| stats.should_allow_connection(retry_interval));
            if !allowed {
                trace!(
                    target: LOG_TARGET,
                    "({}) Skipping peer {} due to circuit breaker",
                    task_id,
                    peer.node_id.short_str()
                );
            }
            allowed
        });

        // Compute each sort key exactly once instead of on every comparison. This also ensures
        // the comparator sees a consistent key for every peer for the duration of the sort.
        let window = self.config.success_rate_tracking_window;
        let own_node_id = self.node_identity.node_id();
        let mut ranked: Vec<_> = candidates
            .into_iter()
            .map(|peer| {
                let health = connection_stats
                    .get(&peer.node_id)
                    .map(|s| s.health_score(window))
                    .unwrap_or(0.5);
                // A NaN score cannot be ordered and would break the total order `sort_by`
                // requires; treat it like a peer without statistics.
                let health = if health.is_nan() { 0.5 } else { health };
                let distance = peer.node_id.distance(own_node_id);
                (health, distance, peer)
            })
            .collect();
        ranked.sort_by(|a, b| {
            // Healthiest first. `partial_cmp` cannot return `None` because NaN was replaced above.
            b.0.partial_cmp(&a.0)
                .unwrap_or(std::cmp::Ordering::Equal)
                // Equal health: prefer the peer closest to this node.
                .then_with(|| a.1.cmp(&b.1))
        });

        let final_candidates: Vec<Peer> = ranked.into_iter().take(count).map(|(_, _, peer)| peer).collect();
        debug!(
            target: LOG_TARGET,
            "({}) Selected {} healthy peer candidates for dialing",
            task_id,
            final_candidates.len()
        );
        Ok(final_candidates)
    }

    /// Sends a dial request to the connection manager for every peer in `peers`.
    ///
    /// Despite the name, the requests are issued one after another. Each `await` covers only
    /// the call to `send_dial_peer`, which is made with `None` as its second argument
    /// (presumably the reply channel, so the dial outcome is not awaited here - TODO confirm).
    /// A request that cannot be sent is logged and skipped.
    ///
    /// Returns the number of requests that were sent successfully.
    async fn dial_peers_concurrently(&mut self, peers: Vec<Peer>, task_id: u64) -> usize {
        let mut successful_dials = 0;
        for peer in peers {
            debug!(
                target: LOG_TARGET,
                "({}) Initiating proactive dial to peer {}",
                task_id,
                peer.node_id.short_str()
            );
            match self.connection_manager.send_dial_peer(peer.node_id.clone(), None).await {
                Ok(_) => {
                    successful_dials += 1;
                },
                Err(err) => {
                    warn!(
                        target: LOG_TARGET,
                        "({}) Failed to send dial request for peer {}: {:?}",
                        task_id,
                        peer.node_id.short_str(),
                        err
                    );
                },
            }
        }
        successful_dials
    }
}
#[cfg(test)]
mod tests {
    use crate::connectivity::proactive_dialer::MAX_CONCURRENT_DIALS;

    /// Checks the dial-count formula: scale by the multiplier, divide by the (floored) success
    /// rate, round up, then bound the result by `needed` and `MAX_CONCURRENT_DIALS`.
    #[test]
    fn test_calculate_dial_count() {
        // Stand-alone copy of the formula used by `ProactiveDialer::calculate_dial_count`, with
        // the multiplier passed in so that no dialer instance has to be constructed.
        fn calculate_dial_count(needed: usize, success_rate: f32, multiplier: f32) -> usize {
            let effective_rate = success_rate.max(0.1);
            let scaled = needed as f32 * multiplier / effective_rate;
            #[allow(clippy::cast_possible_truncation)]
            let rounded = scaled.ceil() as usize;
            rounded.max(needed).min(MAX_CONCURRENT_DIALS)
        }

        // Results below the cap are exact.
        let exact_cases = [(4usize, 1.0f32, 2.0f32, 8usize), (4, 0.5, 2.0, 16)];
        for (needed, rate, multiplier, expected) in exact_cases {
            assert_eq!(calculate_dial_count(needed, rate, multiplier), expected);
        }

        // A very low success rate still produces a value inside the allowed range.
        let result = calculate_dial_count(4, 0.1, 2.0);
        assert!(result >= 4);
        assert!(result <= MAX_CONCURRENT_DIALS);

        // Anything that would exceed the cap is limited to it.
        let capped_cases = [(25usize, 0.8f32, 1.5f32), (25, 0.1, 2.0), (15, 0.5, 3.0)];
        for (needed, rate, multiplier) in capped_cases {
            assert_eq!(calculate_dial_count(needed, rate, multiplier), MAX_CONCURRENT_DIALS);
        }
    }

    /// Checks the default value and the clamping bounds that the success-rate calculation relies
    /// on. The dialer itself is not exercised here.
    #[test]
    fn test_calculate_recent_success_rate() {
        let _empty_stats = std::collections::HashMap::<String, f32>::new();

        // Default used when there are no statistics.
        let default_rate = 0.25f32;
        assert_eq!(default_rate, 0.25);

        // Values above the range are clamped to the upper bound.
        let clamped = 1.5f32.clamp(0.1, 1.0);
        assert_eq!(clamped, 1.0);

        // Values below the range are clamped to the lower bound.
        let clamped_low = 0.05f32.clamp(0.1, 1.0);
        assert_eq!(clamped_low, 0.1);
    }
}