use crate::adaptive::trust::{NodeStatisticsUpdate, TrustEngine};
use crate::dht_network_manager::{DhtNetworkConfig, DhtNetworkManager};
use crate::{MultiAddr, PeerId};
use crate::error::P2pResult as Result;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Block threshold used by `AdaptiveDhtConfig::default()`.
const DEFAULT_BLOCK_THRESHOLD: f64 = 0.15;
/// Upper bound applied to the weight carried by `TrustEvent::ApplicationSuccess` and
/// `TrustEvent::ApplicationFailure` before it is handed to the trust engine.
const MAX_CONSUMER_WEIGHT: f64 = 5.0;
/// Tunable settings for [`AdaptiveDHT`].
///
/// `#[serde(default)]` means fields missing from serialized input fall back to the values
/// produced by the `Default` implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AdaptiveDhtConfig {
/// Trust score below which `AdaptiveDHT::report_trust_event` evicts a peer.
///
/// `validate` requires the value to lie in `[0.0, 0.5)`; a value that is not greater than
/// zero disables the eviction check.
pub block_threshold: f64,
}
impl Default for AdaptiveDhtConfig {
fn default() -> Self {
Self {
block_threshold: DEFAULT_BLOCK_THRESHOLD,
}
}
}
impl AdaptiveDhtConfig {
pub fn validate(&self) -> crate::error::P2pResult<()> {
if !(0.0..0.5).contains(&self.block_threshold) || self.block_threshold.is_nan() {
return Err(crate::error::P2PError::Validation(
format!(
"block_threshold must be in [0.0, 0.5), got {}",
self.block_threshold
)
.into(),
));
}
Ok(())
}
}
/// Peer-interaction outcomes that can be reported through `AdaptiveDHT::report_trust_event`.
///
/// `to_stats_update` reduces every variant to either a correct-response or a failed-response
/// statistic.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TrustEvent {
/// Recorded as a correct response.
SuccessfulResponse,
/// Recorded as a correct response.
SuccessfulConnection,
/// Recorded as a failed response.
ConnectionFailed,
/// Recorded as a failed response.
ConnectionTimeout,
/// Application-reported success carrying a weight. `report_trust_event` ignores the event
/// unless the weight is greater than zero and caps it at `MAX_CONSUMER_WEIGHT`.
ApplicationSuccess(f64),
/// Application-reported failure carrying a weight, subject to the same weight handling as
/// `ApplicationSuccess`.
ApplicationFailure(f64),
}
impl TrustEvent {
    /// Collapses the event into the success/failure statistic understood by the trust engine.
    ///
    /// Any weight carried by an application event is discarded here; only the direction
    /// (success or failure) survives.
    fn to_stats_update(self) -> NodeStatisticsUpdate {
        match self {
            Self::ConnectionFailed | Self::ConnectionTimeout | Self::ApplicationFailure(_) => {
                NodeStatisticsUpdate::FailedResponse
            }
            Self::SuccessfulResponse
            | Self::SuccessfulConnection
            | Self::ApplicationSuccess(_) => NodeStatisticsUpdate::CorrectResponse,
        }
    }
}
/// Pairs a [`DhtNetworkManager`] with a [`TrustEngine`] so that reported peer events update
/// trust scores and low-scoring peers are evicted from the DHT manager.
pub struct AdaptiveDHT {
/// Underlying DHT manager; it is constructed with a handle to `trust_engine`.
dht_manager: Arc<DhtNetworkManager>,
/// Trust scorer updated by `report_trust_event` and queried by `peer_trust`.
trust_engine: Arc<TrustEngine>,
/// Adaptive configuration that passed `validate` at construction time.
config: AdaptiveDhtConfig,
}
impl AdaptiveDHT {
    /// Builds an adaptive DHT on top of `transport`.
    ///
    /// The adaptive configuration is validated first; its `block_threshold` then overwrites
    /// the one in `dht_config`. A new [`TrustEngine`] is created and a handle to it is passed
    /// to the [`DhtNetworkManager`].
    ///
    /// # Errors
    ///
    /// Returns the error from `AdaptiveDhtConfig::validate`, or any error produced by
    /// `DhtNetworkManager::new`.
    pub async fn new(
        transport: Arc<crate::transport_handle::TransportHandle>,
        mut dht_config: DhtNetworkConfig,
        adaptive_config: AdaptiveDhtConfig,
    ) -> Result<Self> {
        adaptive_config.validate()?;
        dht_config.block_threshold = adaptive_config.block_threshold;

        let trust_engine = Arc::new(TrustEngine::new());
        let manager =
            DhtNetworkManager::new(transport, Some(trust_engine.clone()), dht_config).await?;

        Ok(Self {
            dht_manager: Arc::new(manager),
            trust_engine,
            config: adaptive_config,
        })
    }

    /// Records `event` for `peer_id` in the trust engine, then evicts the peer from the DHT
    /// manager if its score is below the configured block threshold.
    ///
    /// Application events carry a caller-chosen weight: a weight that is not strictly
    /// positive (NaN included) leaves the statistics untouched, and larger weights are capped
    /// at `MAX_CONSUMER_WEIGHT`. Every other event is recorded unweighted.
    ///
    /// The eviction check runs on every call, but is skipped when `block_threshold` is not
    /// greater than zero.
    pub async fn report_trust_event(&self, peer_id: &PeerId, event: TrustEvent) {
        let update = event.to_stats_update();

        match event {
            TrustEvent::ApplicationSuccess(raw_weight)
            | TrustEvent::ApplicationFailure(raw_weight) => {
                if raw_weight > 0.0 {
                    let weight = raw_weight.min(MAX_CONSUMER_WEIGHT);
                    self.trust_engine
                        .update_node_stats_weighted(peer_id, update, weight);
                }
            }
            TrustEvent::SuccessfulResponse
            | TrustEvent::SuccessfulConnection
            | TrustEvent::ConnectionFailed
            | TrustEvent::ConnectionTimeout => {
                self.trust_engine.update_node_stats(peer_id, update);
            }
        }

        let threshold = self.config.block_threshold;
        let blocking_enabled = threshold > 0.0;
        // The score is only looked up when blocking is enabled.
        if blocking_enabled && self.trust_engine.score(peer_id) < threshold {
            self.dht_manager.evict_blocked_peer(peer_id).await;
        }
    }

    /// Returns the trust engine's current score for `peer_id`.
    pub fn peer_trust(&self, peer_id: &PeerId) -> f64 {
        self.trust_engine.score(peer_id)
    }

    /// Borrows the shared handle to the trust engine.
    pub fn trust_engine(&self) -> &Arc<TrustEngine> {
        &self.trust_engine
    }

    /// Borrows the adaptive configuration supplied at construction.
    pub fn config(&self) -> &AdaptiveDhtConfig {
        &self.config
    }

    /// Borrows the shared handle to the underlying DHT manager.
    pub fn dht_manager(&self) -> &Arc<DhtNetworkManager> {
        &self.dht_manager
    }

    /// Delegates to `DhtNetworkManager::start` on a cloned `Arc` handle.
    pub async fn start(&self) -> Result<()> {
        Arc::clone(&self.dht_manager).start().await
    }

    /// Delegates to `DhtNetworkManager::stop`.
    pub async fn stop(&self) -> Result<()> {
        self.dht_manager.stop().await
    }

    /// Delegates to `DhtNetworkManager::trigger_self_lookup`.
    pub async fn trigger_self_lookup(&self) -> Result<()> {
        self.dht_manager.trigger_self_lookup().await
    }

    /// Delegates to `DhtNetworkManager::peer_addresses_for_dial` for `peer_id`.
    pub(crate) async fn peer_addresses_for_dial(&self, peer_id: &PeerId) -> Vec<MultiAddr> {
        self.dht_manager.peer_addresses_for_dial(peer_id).await
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::adaptive::trust::DEFAULT_NEUTRAL_TRUST;
#[test]
fn test_trust_event_mapping() {
    // Events describing a healthy interaction must count as correct responses.
    for event in [
        TrustEvent::SuccessfulResponse,
        TrustEvent::SuccessfulConnection,
        TrustEvent::ApplicationSuccess(1.0),
    ] {
        assert!(matches!(
            event.to_stats_update(),
            NodeStatisticsUpdate::CorrectResponse
        ));
    }

    // Events describing a broken interaction must count as failed responses.
    for event in [
        TrustEvent::ConnectionFailed,
        TrustEvent::ConnectionTimeout,
        TrustEvent::ApplicationFailure(1.0),
    ] {
        assert!(matches!(
            event.to_stats_update(),
            NodeStatisticsUpdate::FailedResponse
        ));
    }
}
#[test]
fn test_adaptive_dht_config_defaults() {
    // A default-constructed config must carry the module's default block threshold.
    let defaults = AdaptiveDhtConfig::default();
    let deviation = (defaults.block_threshold - DEFAULT_BLOCK_THRESHOLD).abs();
    assert!(deviation < f64::EPSILON);
}
#[test]
fn test_block_threshold_validation_rejects_invalid() {
    // Negative values, values at or above 0.5, and non-finite values must all be refused.
    let rejected = [
        -0.1,
        0.5,
        1.0,
        1.1,
        f64::NAN,
        f64::INFINITY,
        f64::NEG_INFINITY,
    ];
    for bad in rejected {
        let outcome = AdaptiveDhtConfig {
            block_threshold: bad,
        }
        .validate();
        assert!(
            outcome.is_err(),
            "block_threshold {bad} should fail validation"
        );
    }
}
#[test]
fn test_block_threshold_validation_accepts_valid() {
    // The lower bound, the default, and a value just under the upper bound must be accepted.
    let accepted = [0.0, 0.15, 0.49];
    for good in accepted {
        let outcome = AdaptiveDhtConfig {
            block_threshold: good,
        }
        .validate();
        assert!(
            outcome.is_ok(),
            "block_threshold {good} should pass validation"
        );
    }
}
// Ten successful responses must lift a previously unseen peer above the neutral score.
#[tokio::test]
async fn test_trust_events_affect_scores() {
let engine = Arc::new(TrustEngine::new());
let peer = PeerId::random();
// A peer with no recorded history is expected to score exactly the neutral value.
assert!((engine.score(&peer) - DEFAULT_NEUTRAL_TRUST).abs() < f64::EPSILON);
for _ in 0..10 {
engine.update_node_stats(&peer, TrustEvent::SuccessfulResponse.to_stats_update());
}
assert!(engine.score(&peer) > DEFAULT_NEUTRAL_TRUST);
}
// Twenty consecutive connection failures must push a peer's score under the default block
// threshold.
#[tokio::test]
async fn test_failures_reduce_trust_below_block_threshold() {
let engine = Arc::new(TrustEngine::new());
let bad_peer = PeerId::random();
for _ in 0..20 {
engine.update_node_stats(&bad_peer, TrustEvent::ConnectionFailed.to_stats_update());
}
let trust = engine.score(&bad_peer);
assert!(
trust < DEFAULT_BLOCK_THRESHOLD,
"Bad peer trust {trust} should be below block threshold {DEFAULT_BLOCK_THRESHOLD}"
);
}
// After one hundred correct responses the score must still lie within [0.0, 1.0].
#[tokio::test]
async fn test_trust_scores_bounded() {
let engine = Arc::new(TrustEngine::new());
let peer = PeerId::random();
for _ in 0..100 {
engine.update_node_stats(&peer, NodeStatisticsUpdate::CorrectResponse);
}
let score = engine.score(&peer);
assert!(score >= 0.0, "Score must be >= 0.0, got {score}");
assert!(score <= 1.0, "Score must be <= 1.0, got {score}");
}
#[test]
fn test_all_trust_events_produce_valid_updates() {
    // Smoke test: converting any variant to a statistics update must not panic.
    for event in [
        TrustEvent::SuccessfulResponse,
        TrustEvent::SuccessfulConnection,
        TrustEvent::ConnectionFailed,
        TrustEvent::ConnectionTimeout,
        TrustEvent::ApplicationSuccess(1.0),
        TrustEvent::ApplicationFailure(3.0),
    ] {
        let _mapped = event.to_stats_update();
    }
}
// Walks one peer through four phases: unknown, trusted, blocked, and recovered after idling.
#[tokio::test]
async fn test_peer_lifecycle_block_and_recovery() {
let engine = TrustEngine::new();
let peer = PeerId::random();
// Phase 1: a peer with no history must not start out below the block threshold.
assert!(
engine.score(&peer) >= DEFAULT_BLOCK_THRESHOLD,
"New peer should not be blocked"
);
// Phase 2: a run of correct responses must raise the score above neutral.
for _ in 0..20 {
engine.update_node_stats(&peer, NodeStatisticsUpdate::CorrectResponse);
}
let good_score = engine.score(&peer);
assert!(
good_score > DEFAULT_NEUTRAL_TRUST,
"Trusted peer: {good_score}"
);
// Phase 3: a much longer run of failures must drive the score under the block threshold.
for _ in 0..200 {
engine.update_node_stats(&peer, NodeStatisticsUpdate::FailedResponse);
}
let bad_score = engine.score(&peer);
assert!(
bad_score < DEFAULT_BLOCK_THRESHOLD,
"After many failures, peer should be blocked: {bad_score}"
);
// Phase 4: after three simulated idle days the score must be back at or above the threshold.
// NOTE(review): presumably `simulate_elapsed` applies the engine's time-based decay for the
// given duration — confirm against `TrustEngine`.
let three_days = std::time::Duration::from_secs(3 * 24 * 3600);
engine.simulate_elapsed(&peer, three_days).await;
let recovered_score = engine.score(&peer);
assert!(
recovered_score >= DEFAULT_BLOCK_THRESHOLD,
"After 3 days idle, peer should be unblocked: {recovered_score}"
);
}
// Checks which side of the default block threshold three kinds of peer land on: one with only
// successes, one with only failures, and one with no history.
#[tokio::test]
async fn test_block_threshold_is_binary() {
let engine = TrustEngine::new();
let threshold = DEFAULT_BLOCK_THRESHOLD;
let peer_above = PeerId::random();
let peer_below = PeerId::random();
for _ in 0..5 {
engine.update_node_stats(&peer_above, NodeStatisticsUpdate::CorrectResponse);
}
assert!(
engine.score(&peer_above) >= threshold,
"Peer with successes should be above threshold"
);
for _ in 0..50 {
engine.update_node_stats(&peer_below, NodeStatisticsUpdate::FailedResponse);
}
assert!(
engine.score(&peer_below) < threshold,
"Peer with only failures should be below threshold"
);
// A peer the engine has never seen must not count as blocked.
let unknown = PeerId::random();
assert!(
engine.score(&unknown) >= threshold,
"Unknown peer at neutral should not be blocked"
);
}
// A single failed response recorded for a fresh peer must leave it at or above the default
// block threshold.
#[tokio::test]
async fn test_single_failure_does_not_block() {
let engine = TrustEngine::new();
let peer = PeerId::random();
engine.update_node_stats(&peer, NodeStatisticsUpdate::FailedResponse);
assert!(
engine.score(&peer) >= DEFAULT_BLOCK_THRESHOLD,
"One failure from neutral should not block: {}",
engine.score(&peer)
);
}
// Three failures after fifty successes must lower the score without dropping it under the
// default block threshold.
#[tokio::test]
async fn test_trusted_peer_resilient_to_occasional_failures() {
let engine = TrustEngine::new();
let peer = PeerId::random();
for _ in 0..50 {
engine.update_node_stats(&peer, NodeStatisticsUpdate::CorrectResponse);
}
// Snapshot taken before the failures so the decrease can be verified afterwards.
let trusted_score = engine.score(&peer);
for _ in 0..3 {
engine.update_node_stats(&peer, NodeStatisticsUpdate::FailedResponse);
}
assert!(
engine.score(&peer) >= DEFAULT_BLOCK_THRESHOLD,
"3 failures after 50 successes should not block: {}",
engine.score(&peer)
);
assert!(
engine.score(&peer) < trusted_score,
"Score should have decreased"
);
}
// After `remove_node`, a peer that had been driven under the block threshold must score
// exactly the neutral value again.
#[tokio::test]
async fn test_removed_peer_starts_fresh() {
let engine = TrustEngine::new();
let peer = PeerId::random();
for _ in 0..100 {
engine.update_node_stats(&peer, NodeStatisticsUpdate::FailedResponse);
}
// Precondition: the failures really did push the peer under the threshold.
assert!(engine.score(&peer) < DEFAULT_BLOCK_THRESHOLD);
engine.remove_node(&peer);
assert!(
(engine.score(&peer) - DEFAULT_NEUTRAL_TRUST).abs() < f64::EPSILON,
"Removed peer should return to neutral"
);
}
// One application-success event, recorded through the unweighted update path, must raise a
// fresh peer's score.
#[tokio::test]
async fn test_consumer_reward_improves_trust() {
let engine = Arc::new(TrustEngine::new());
let peer = PeerId::random();
let before = engine.score(&peer);
// Only the success/failure direction of the event is used here; its weight is discarded by
// `to_stats_update`.
engine.update_node_stats(&peer, TrustEvent::ApplicationSuccess(1.0).to_stats_update());
let after = engine.score(&peer);
assert!(
after > before,
"consumer reward should improve trust: {before} -> {after}"
);
}
// A failure recorded with weight 5.0 must lower the score more than one recorded with
// weight 1.0.
#[tokio::test]
async fn test_higher_weight_larger_impact() {
let engine = Arc::new(TrustEngine::new());
// Two separate fresh peers so each weighted failure is measured from the same starting point.
let peer_a = PeerId::random();
let peer_b = PeerId::random();
engine.update_node_stats_weighted(&peer_a, NodeStatisticsUpdate::FailedResponse, 1.0);
engine.update_node_stats_weighted(&peer_b, NodeStatisticsUpdate::FailedResponse, 5.0);
assert!(
engine.score(&peer_b) < engine.score(&peer_a),
"weight-5 failure should have larger impact than weight-1"
);
}
// A weighted failure with weight 0.0 must leave a fresh peer's score unchanged (to within
// 1e-10).
// NOTE(review): the test name also mentions negative weights, but only the zero case is
// exercised here — consider adding a negative-weight case or renaming the test.
#[tokio::test]
async fn test_zero_negative_weights_noop() {
let engine = Arc::new(TrustEngine::new());
let peer = PeerId::random();
let neutral = engine.score(&peer);
engine.update_node_stats_weighted(&peer, NodeStatisticsUpdate::FailedResponse, 0.0);
let after_zero = engine.score(&peer);
assert!(
(after_zero - neutral).abs() < 1e-10,
"zero-weight should not change score: {neutral} -> {after_zero}"
);
}
// The engine itself must not cap weights: a failure with weight 100.0 has to lower the score
// more than one with weight `MAX_CONSUMER_WEIGHT`. In the non-test code the cap is applied in
// `AdaptiveDHT::report_trust_event` before the engine is called.
#[tokio::test]
async fn test_trust_engine_does_not_clamp_weights() {
let engine = Arc::new(TrustEngine::new());
let peer_clamped = PeerId::random();
let peer_unclamped = PeerId::random();
engine.update_node_stats_weighted(
&peer_clamped,
NodeStatisticsUpdate::FailedResponse,
MAX_CONSUMER_WEIGHT,
);
let score_at_max = engine.score(&peer_clamped);
engine.update_node_stats_weighted(
&peer_unclamped,
NodeStatisticsUpdate::FailedResponse,
100.0,
);
let score_at_100 = engine.score(&peer_unclamped);
assert!(
score_at_100 < score_at_max,
"TrustEngine should not clamp: weight-100 ({score_at_100}) should have more impact than weight-{MAX_CONSUMER_WEIGHT} ({score_at_max})"
);
}
// Five unweighted failures must leave a peer above the default block threshold; ten further
// failures at `MAX_CONSUMER_WEIGHT` must then push it below.
#[tokio::test]
async fn test_consumer_penalty_triggers_blocking() {
let engine = Arc::new(TrustEngine::new());
let peer = PeerId::random();
for _ in 0..5 {
engine.update_node_stats(&peer, NodeStatisticsUpdate::FailedResponse);
}
// Precondition: the peer is degraded but not yet under the threshold.
let score_before = engine.score(&peer);
assert!(
score_before > DEFAULT_BLOCK_THRESHOLD,
"should be above block threshold: {score_before}"
);
for _ in 0..10 {
engine.update_node_stats_weighted(
&peer,
NodeStatisticsUpdate::FailedResponse,
MAX_CONSUMER_WEIGHT,
);
}
let score_after = engine.score(&peer);
assert!(
score_after < DEFAULT_BLOCK_THRESHOLD,
"after heavy consumer failures, score {score_after} should be below block threshold {DEFAULT_BLOCK_THRESHOLD}"
);
}
#[test]
fn test_consumer_event_direction_mapping() {
    // The direction of an application event must not depend on the weight it carries.
    let weights = [0.5, 1.0, 5.0];

    for event in weights.map(TrustEvent::ApplicationSuccess) {
        let update = event.to_stats_update();
        assert!(
            matches!(update, NodeStatisticsUpdate::CorrectResponse),
            "{event:?} should map to CorrectResponse"
        );
    }

    for event in weights.map(TrustEvent::ApplicationFailure) {
        let update = event.to_stats_update();
        assert!(
            matches!(update, NodeStatisticsUpdate::FailedResponse),
            "{event:?} should map to FailedResponse"
        );
    }
}
}