1use libp2p::PeerId;
10use parking_lot::RwLock;
11use serde::Serialize;
12use std::collections::{HashMap, HashSet};
13use std::time::{Duration, Instant};
14use tracing::{debug, info, warn};
15
/// Tunable limits enforced by [`ConnectionManager`].
#[derive(Clone, Debug)]
pub struct ConnectionLimitsConfig {
    /// Total number of tracked connections above which new (non-reserved-slot)
    /// connections are rejected.
    pub max_connections: usize,
    /// Cap on tracked inbound connections.
    pub max_inbound: usize,
    /// Cap on tracked outbound connections.
    pub max_outbound: usize,
    /// How many connections flagged as reserved may be admitted while bypassing
    /// the total and per-direction caps.
    pub reserved_slots: usize,
    /// Connections without recorded activity for longer than this are reported
    /// as idle.
    pub idle_timeout: Duration,
    /// Connections whose score is strictly below this value become eligible for
    /// pruning.
    pub min_score_threshold: u8,
}

impl Default for ConnectionLimitsConfig {
    /// 256 total connections (128 per direction), 8 reserved slots, a five-minute
    /// idle timeout and a pruning threshold of 30.
    fn default() -> Self {
        Self {
            max_inbound: 128,
            max_outbound: 128,
            max_connections: 256,
            reserved_slots: 8,
            min_score_threshold: 30,
            idle_timeout: Duration::from_secs(5 * 60),
        }
    }
}
45
/// Which side opened a connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConnectionDirection {
    /// The remote peer dialed us.
    Inbound,
    /// We dialed the remote peer.
    Outbound,
}
52
53#[derive(Debug, Clone)]
55struct ConnectionInfo {
56 peer_id: PeerId,
58 direction: ConnectionDirection,
60 established_at: Instant,
62 last_activity: Instant,
64 score: u8,
66 reserved: bool,
68 messages_sent: u64,
70 messages_received: u64,
72 avg_latency_ms: Option<u64>,
74}
75
76impl ConnectionInfo {
77 fn new(peer_id: PeerId, direction: ConnectionDirection) -> Self {
78 let now = Instant::now();
79 Self {
80 peer_id,
81 direction,
82 established_at: now,
83 last_activity: now,
84 score: 50, reserved: false,
86 messages_sent: 0,
87 messages_received: 0,
88 avg_latency_ms: None,
89 }
90 }
91
92 fn is_idle(&self, timeout: Duration) -> bool {
93 self.last_activity.elapsed() > timeout
94 }
95
96 fn touch(&mut self) {
97 self.last_activity = Instant::now();
98 }
99
100 fn calculate_value(&self) -> u64 {
102 let age_secs = self.established_at.elapsed().as_secs();
103 let activity = self.messages_sent + self.messages_received;
104
105 let base_value = self.score as u64 * 10;
107 let activity_rate = if age_secs > 0 {
108 activity * 60 / age_secs } else {
110 activity * 60
111 };
112 let latency_bonus = match self.avg_latency_ms {
113 Some(lat) if lat < 50 => 20,
114 Some(lat) if lat < 100 => 10,
115 Some(lat) if lat < 200 => 5,
116 _ => 0,
117 };
118
119 base_value + activity_rate + latency_bonus
120 }
121}
122
123pub struct ConnectionManager {
125 config: ConnectionLimitsConfig,
127 connections: RwLock<HashMap<PeerId, ConnectionInfo>>,
129 reserved_peers: RwLock<HashSet<PeerId>>,
131 banned_peers: RwLock<HashSet<PeerId>>,
133}
134
135impl ConnectionManager {
136 pub fn new(config: ConnectionLimitsConfig) -> Self {
138 Self {
139 config,
140 connections: RwLock::new(HashMap::new()),
141 reserved_peers: RwLock::new(HashSet::new()),
142 banned_peers: RwLock::new(HashSet::new()),
143 }
144 }
145
146 pub fn should_accept(&self, peer_id: &PeerId, direction: ConnectionDirection) -> bool {
148 if self.banned_peers.read().contains(peer_id) {
150 debug!("Rejecting banned peer: {}", peer_id);
151 return false;
152 }
153
154 if self.reserved_peers.read().contains(peer_id) {
156 let reserved_count = self
157 .connections
158 .read()
159 .values()
160 .filter(|c| c.reserved)
161 .count();
162 if reserved_count < self.config.reserved_slots {
163 return true;
164 }
165 }
166
167 let connections = self.connections.read();
168
169 if connections.len() >= self.config.max_connections {
171 debug!(
172 "At max connections ({}), rejecting {}",
173 self.config.max_connections, peer_id
174 );
175 return false;
176 }
177
178 let (inbound, outbound) =
180 connections
181 .values()
182 .fold((0, 0), |(i, o), c| match c.direction {
183 ConnectionDirection::Inbound => (i + 1, o),
184 ConnectionDirection::Outbound => (i, o + 1),
185 });
186
187 match direction {
188 ConnectionDirection::Inbound => {
189 if inbound >= self.config.max_inbound {
190 debug!(
191 "At max inbound ({}), rejecting {}",
192 self.config.max_inbound, peer_id
193 );
194 return false;
195 }
196 }
197 ConnectionDirection::Outbound => {
198 if outbound >= self.config.max_outbound {
199 debug!(
200 "At max outbound ({}), rejecting {}",
201 self.config.max_outbound, peer_id
202 );
203 return false;
204 }
205 }
206 }
207
208 true
209 }
210
211 pub fn connection_established(&self, peer_id: PeerId, direction: ConnectionDirection) {
213 let is_reserved = self.reserved_peers.read().contains(&peer_id);
214
215 let mut connections = self.connections.write();
216 let mut info = ConnectionInfo::new(peer_id, direction);
217 info.reserved = is_reserved;
218
219 connections.insert(peer_id, info);
220 info!("Connection established: {} ({:?})", peer_id, direction);
221 }
222
223 pub fn connection_closed(&self, peer_id: &PeerId) {
225 let mut connections = self.connections.write();
226 if connections.remove(peer_id).is_some() {
227 debug!("Connection closed: {}", peer_id);
228 }
229 }
230
231 pub fn record_activity(&self, peer_id: &PeerId, sent: bool) {
233 let mut connections = self.connections.write();
234 if let Some(info) = connections.get_mut(peer_id) {
235 info.touch();
236 if sent {
237 info.messages_sent += 1;
238 } else {
239 info.messages_received += 1;
240 }
241 }
242 }
243
244 pub fn update_score(&self, peer_id: &PeerId, delta: i16) {
246 let mut connections = self.connections.write();
247 if let Some(info) = connections.get_mut(peer_id) {
248 let new_score = (info.score as i16 + delta).clamp(0, 100) as u8;
249 info.score = new_score;
250 }
251 }
252
253 pub fn update_latency(&self, peer_id: &PeerId, latency_ms: u64) {
255 let mut connections = self.connections.write();
256 if let Some(info) = connections.get_mut(peer_id) {
257 info.avg_latency_ms = Some(latency_ms);
258 info.touch();
259 }
260 }
261
262 pub fn add_reserved(&self, peer_id: PeerId) {
264 self.reserved_peers.write().insert(peer_id);
265
266 if let Some(info) = self.connections.write().get_mut(&peer_id) {
268 info.reserved = true;
269 }
270
271 info!("Added reserved peer: {}", peer_id);
272 }
273
274 pub fn remove_reserved(&self, peer_id: &PeerId) {
276 self.reserved_peers.write().remove(peer_id);
277
278 if let Some(info) = self.connections.write().get_mut(peer_id) {
280 info.reserved = false;
281 }
282
283 debug!("Removed reserved peer: {}", peer_id);
284 }
285
286 pub fn ban_peer(&self, peer_id: PeerId) {
288 self.banned_peers.write().insert(peer_id);
289 self.reserved_peers.write().remove(&peer_id);
290 warn!("Banned peer: {}", peer_id);
291 }
292
293 pub fn unban_peer(&self, peer_id: &PeerId) {
295 self.banned_peers.write().remove(peer_id);
296 info!("Unbanned peer: {}", peer_id);
297 }
298
299 pub fn is_banned(&self, peer_id: &PeerId) -> bool {
301 self.banned_peers.read().contains(peer_id)
302 }
303
304 pub fn get_prune_candidates(&self, count: usize) -> Vec<PeerId> {
306 let connections = self.connections.read();
307
308 let mut candidates: Vec<_> = connections
310 .values()
311 .filter(|c| !c.reserved && c.score < self.config.min_score_threshold)
312 .map(|c| (c.peer_id, c.calculate_value()))
313 .collect();
314
315 candidates.sort_by_key(|(_, value)| *value);
317
318 candidates
319 .into_iter()
320 .take(count)
321 .map(|(peer_id, _)| peer_id)
322 .collect()
323 }
324
325 pub fn get_idle_connections(&self) -> Vec<PeerId> {
327 let connections = self.connections.read();
328 let timeout = self.config.idle_timeout;
329
330 connections
331 .values()
332 .filter(|c| !c.reserved && c.is_idle(timeout))
333 .map(|c| c.peer_id)
334 .collect()
335 }
336
337 pub fn prune_to_limit(&self) -> Vec<PeerId> {
341 let connections = self.connections.read();
342 let current = connections.len();
343
344 if current <= self.config.max_connections {
345 return vec![];
346 }
347
348 let to_prune = current - self.config.max_connections;
349 drop(connections);
350
351 let candidates = self.get_prune_candidates(to_prune);
352 info!(
353 "Pruning {} connections to stay within limit",
354 candidates.len()
355 );
356 candidates
357 }
358
359 pub fn connected_peers(&self) -> Vec<PeerId> {
361 self.connections.read().keys().cloned().collect()
362 }
363
364 pub fn connection_count(&self) -> usize {
366 self.connections.read().len()
367 }
368
369 pub fn is_connected(&self, peer_id: &PeerId) -> bool {
371 self.connections.read().contains_key(peer_id)
372 }
373
374 pub fn stats(&self) -> ConnectionManagerStats {
376 let connections = self.connections.read();
377
378 let (inbound, outbound) =
379 connections
380 .values()
381 .fold((0, 0), |(i, o), c| match c.direction {
382 ConnectionDirection::Inbound => (i + 1, o),
383 ConnectionDirection::Outbound => (i, o + 1),
384 });
385
386 let reserved = connections.values().filter(|c| c.reserved).count();
387
388 let avg_score = if connections.is_empty() {
389 0
390 } else {
391 connections.values().map(|c| c.score as u64).sum::<u64>() / connections.len() as u64
392 };
393
394 ConnectionManagerStats {
395 total_connections: connections.len(),
396 max_connections: self.config.max_connections,
397 inbound_connections: inbound,
398 outbound_connections: outbound,
399 reserved_connections: reserved,
400 banned_peers: self.banned_peers.read().len(),
401 average_score: avg_score as u8,
402 }
403 }
404}
405
406impl Default for ConnectionManager {
407 fn default() -> Self {
408 Self::new(ConnectionLimitsConfig::default())
409 }
410}
411
412#[derive(Debug, Clone, Serialize)]
414pub struct ConnectionManagerStats {
415 pub total_connections: usize,
417 pub max_connections: usize,
419 pub inbound_connections: usize,
421 pub outbound_connections: usize,
423 pub reserved_connections: usize,
425 pub banned_peers: usize,
427 pub average_score: u8,
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    fn random_peer() -> PeerId {
        PeerId::random()
    }

    /// Reads the tracked score of `peer`.
    ///
    /// `tests` is a child module, so the manager's private state is reachable.
    fn score_of(manager: &ConnectionManager, peer: &PeerId) -> u8 {
        manager
            .connections
            .read()
            .get(peer)
            .map(|info| info.score)
            .expect("peer should be tracked")
    }

    #[test]
    fn test_connection_manager_basic() {
        let manager = ConnectionManager::default();
        let peer1 = random_peer();
        let peer2 = random_peer();

        assert!(manager.should_accept(&peer1, ConnectionDirection::Inbound));

        manager.connection_established(peer1, ConnectionDirection::Inbound);
        assert!(manager.is_connected(&peer1));
        assert_eq!(manager.connection_count(), 1);

        manager.connection_established(peer2, ConnectionDirection::Outbound);
        assert_eq!(manager.connection_count(), 2);

        manager.connection_closed(&peer1);
        assert!(!manager.is_connected(&peer1));
        assert_eq!(manager.connection_count(), 1);
        assert_eq!(manager.connected_peers(), vec![peer2]);
    }

    #[test]
    fn test_connection_limits() {
        let config = ConnectionLimitsConfig {
            max_connections: 3,
            max_inbound: 2,
            max_outbound: 2,
            ..Default::default()
        };
        let manager = ConnectionManager::new(config);

        // Fill the inbound cap.
        let peer1 = random_peer();
        let peer2 = random_peer();
        manager.connection_established(peer1, ConnectionDirection::Inbound);
        manager.connection_established(peer2, ConnectionDirection::Inbound);

        // Inbound is full, outbound still has room.
        let peer3 = random_peer();
        assert!(!manager.should_accept(&peer3, ConnectionDirection::Inbound));
        assert!(manager.should_accept(&peer3, ConnectionDirection::Outbound));
        manager.connection_established(peer3, ConnectionDirection::Outbound);

        // The total cap now rejects both directions.
        let peer4 = random_peer();
        assert!(!manager.should_accept(&peer4, ConnectionDirection::Inbound));
        assert!(!manager.should_accept(&peer4, ConnectionDirection::Outbound));
    }

    #[test]
    fn test_reserved_peers() {
        let config = ConnectionLimitsConfig {
            max_connections: 2,
            reserved_slots: 1,
            ..Default::default()
        };
        let manager = ConnectionManager::new(config);

        let reserved_peer = random_peer();
        manager.add_reserved(reserved_peer);

        let peer1 = random_peer();
        let peer2 = random_peer();
        manager.connection_established(peer1, ConnectionDirection::Inbound);
        manager.connection_established(peer2, ConnectionDirection::Outbound);

        // A regular peer is turned away at the cap, the reserved one is not.
        assert!(!manager.should_accept(&random_peer(), ConnectionDirection::Inbound));
        assert!(manager.should_accept(&reserved_peer, ConnectionDirection::Inbound));

        manager.connection_established(reserved_peer, ConnectionDirection::Inbound);
        assert_eq!(manager.stats().reserved_connections, 1);

        // With the only reserved slot taken, another reserved peer falls back to
        // the regular limits and is rejected.
        let second_reserved = random_peer();
        manager.add_reserved(second_reserved);
        assert!(!manager.should_accept(&second_reserved, ConnectionDirection::Inbound));
    }

    #[test]
    fn test_banned_peers() {
        let manager = ConnectionManager::default();
        let peer = random_peer();

        assert!(manager.should_accept(&peer, ConnectionDirection::Inbound));

        manager.add_reserved(peer);
        manager.ban_peer(peer);
        assert!(manager.is_banned(&peer));
        assert!(!manager.reserved_peers.read().contains(&peer));
        assert!(!manager.should_accept(&peer, ConnectionDirection::Inbound));

        manager.unban_peer(&peer);
        assert!(!manager.is_banned(&peer));
        assert!(manager.should_accept(&peer, ConnectionDirection::Inbound));
    }

    #[test]
    fn test_activity_tracking() {
        let manager = ConnectionManager::default();
        let peer = random_peer();

        manager.connection_established(peer, ConnectionDirection::Outbound);

        manager.record_activity(&peer, true);
        manager.record_activity(&peer, false);
        manager.record_activity(&peer, true);

        {
            let connections = manager.connections.read();
            let info = connections.get(&peer).expect("peer should be tracked");
            assert_eq!(info.messages_sent, 2);
            assert_eq!(info.messages_received, 1);
        }

        // Activity for an untracked peer is ignored.
        manager.record_activity(&random_peer(), true);

        let stats = manager.stats();
        assert_eq!(stats.total_connections, 1);
    }

    #[test]
    fn test_score_update() {
        let manager = ConnectionManager::default();
        let peer = random_peer();

        manager.connection_established(peer, ConnectionDirection::Inbound);
        assert_eq!(score_of(&manager, &peer), 50);

        manager.update_score(&peer, 20);
        assert_eq!(score_of(&manager, &peer), 70);

        manager.update_score(&peer, -40);
        assert_eq!(score_of(&manager, &peer), 30);

        // Clamped at the lower bound.
        manager.update_score(&peer, -100);
        assert_eq!(score_of(&manager, &peer), 0);

        // Clamped at the upper bound.
        manager.update_score(&peer, 200);
        assert_eq!(score_of(&manager, &peer), 100);
    }

    #[test]
    fn test_prune_candidates() {
        let config = ConnectionLimitsConfig {
            min_score_threshold: 50,
            ..Default::default()
        };
        let manager = ConnectionManager::new(config);

        let high_score = random_peer();
        let low_score1 = random_peer();
        let low_score2 = random_peer();
        let reserved = random_peer();

        manager.connection_established(high_score, ConnectionDirection::Inbound);
        manager.connection_established(low_score1, ConnectionDirection::Inbound);
        manager.connection_established(low_score2, ConnectionDirection::Outbound);
        manager.add_reserved(reserved);
        manager.connection_established(reserved, ConnectionDirection::Inbound);

        manager.update_score(&high_score, 30); // 80
        manager.update_score(&low_score1, -30); // 20
        manager.update_score(&low_score2, -25); // 25

        let candidates = manager.get_prune_candidates(2);

        assert!(!candidates.contains(&reserved));
        assert!(!candidates.contains(&high_score));
        // Lowest value first.
        assert_eq!(candidates, vec![low_score1, low_score2]);
        assert_eq!(manager.get_prune_candidates(1), vec![low_score1]);
    }

    #[test]
    fn test_prune_to_limit() {
        let config = ConnectionLimitsConfig {
            max_connections: 1,
            min_score_threshold: 50,
            ..Default::default()
        };
        let manager = ConnectionManager::new(config);

        let healthy = random_peer();
        let weak = random_peer();

        manager.connection_established(healthy, ConnectionDirection::Inbound);
        assert!(manager.prune_to_limit().is_empty());

        // `connection_established` does not enforce limits, so go over the cap.
        manager.connection_established(weak, ConnectionDirection::Outbound);
        manager.update_score(&weak, -30);
        assert_eq!(manager.prune_to_limit(), vec![weak]);
    }

    #[test]
    fn test_idle_connections() {
        let config = ConnectionLimitsConfig {
            idle_timeout: Duration::from_millis(50),
            ..Default::default()
        };
        let manager = ConnectionManager::new(config);

        let peer = random_peer();
        manager.connection_established(peer, ConnectionDirection::Inbound);

        // Freshly established connections are not idle.
        assert!(manager.get_idle_connections().is_empty());

        std::thread::sleep(Duration::from_millis(100));

        let idle = manager.get_idle_connections();
        assert_eq!(idle.len(), 1);
        assert_eq!(idle[0], peer);

        // Recording activity makes the connection non-idle again.
        manager.record_activity(&peer, false);
        assert!(manager.get_idle_connections().is_empty());
    }

    #[test]
    fn test_stats() {
        let manager = ConnectionManager::default();

        let peer1 = random_peer();
        let peer2 = random_peer();
        let reserved = random_peer();

        manager.connection_established(peer1, ConnectionDirection::Inbound);
        manager.connection_established(peer2, ConnectionDirection::Outbound);
        manager.add_reserved(reserved);
        manager.connection_established(reserved, ConnectionDirection::Inbound);

        let banned = random_peer();
        manager.ban_peer(banned);

        let stats = manager.stats();
        assert_eq!(stats.total_connections, 3);
        assert_eq!(stats.max_connections, 256);
        assert_eq!(stats.inbound_connections, 2);
        assert_eq!(stats.outbound_connections, 1);
        assert_eq!(stats.reserved_connections, 1);
        assert_eq!(stats.banned_peers, 1);
        assert_eq!(stats.average_score, 50);
    }
}