1use std::net::SocketAddr;
2use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use chia::protocol::{Message, NewPeakWallet, ProtocolMessageTypes};
7use chia::traits::Streamable;
8use futures_util::stream::{FuturesUnordered, StreamExt};
9use tokio::sync::{mpsc, RwLock};
10
11use chia_wallet_sdk::client::Peer;
12use tokio_tungstenite::Connector;
13
14use crate::types::ChiaQueryError;
15use crate::NetworkType;
16
17use super::connect;
18
19struct PeerEntry {
24 peer: Peer,
25 address: SocketAddr,
26}
27
28pub struct PeerPool {
33 entries: RwLock<Vec<PeerEntry>>,
34 next_idx: AtomicUsize,
35 max_peers: usize,
36 tls: Connector,
37 network: NetworkType,
38 connect_timeout: Duration,
39 peak_height: Arc<AtomicU32>,
42}
43
44impl PeerPool {
45 pub async fn new(
49 network: NetworkType,
50 tls: Connector,
51 max_peers: usize,
52 connect_timeout: Duration,
53 ) -> Result<Self, ChiaQueryError> {
54 let peak_height = Arc::new(AtomicU32::new(0));
55
56 let mut futures = FuturesUnordered::new();
58 for _ in 0..max_peers {
59 let t = tls.clone();
60 futures.push(async move {
61 connect::connect_random_peer(network, &t, connect_timeout).await
62 });
63 }
64
65 let mut initial: Vec<PeerEntry> = Vec::new();
66 let mut receivers = Vec::new();
67 while let Some(result) = futures.next().await {
68 match result {
69 Ok((peer, addr, receiver)) => {
70 initial.push(PeerEntry {
71 peer,
72 address: addr,
73 });
74 receivers.push(receiver);
75 }
76 Err(e) => log::debug!("initial peer connect failed: {e}"),
77 }
78 }
79
80 if initial.is_empty() {
81 return Err(ChiaQueryError::PeerDiscoveryFailed);
82 }
83
84 let pool = Self {
85 entries: RwLock::new(initial),
86 next_idx: AtomicUsize::new(0),
87 max_peers,
88 tls,
89 network,
90 connect_timeout,
91 peak_height,
92 };
93
94 for receiver in receivers {
97 pool.spawn_receiver_handler(receiver);
98 }
99
100 Ok(pool)
101 }
102
103 pub fn peak_height(&self) -> u32 {
106 self.peak_height.load(Ordering::Relaxed)
107 }
108
109 pub async fn select_peer(&self) -> Option<(Peer, SocketAddr)> {
112 let entries = self.entries.read().await;
113 if entries.is_empty() {
114 return None;
115 }
116 let idx = self.next_idx.fetch_add(1, Ordering::Relaxed) % entries.len();
117 let entry = &entries[idx];
118 Some((entry.peer.clone(), entry.address))
119 }
120
121 pub async fn eject_peer(&self, addr: SocketAddr) {
123 {
124 let mut entries = self.entries.write().await;
125 entries.retain(|e| e.address != addr);
126 }
127 log::debug!(
128 "peer ejected from pool; will refill on next request (network={:?})",
129 self.network,
130 );
131 }
132
133 pub async fn has_peers(&self) -> bool {
135 !self.entries.read().await.is_empty()
136 }
137
138 pub async fn try_refill(&self) {
142 let current = self.entries.read().await.len();
143 if current >= self.max_peers {
144 return;
145 }
146 match connect::connect_random_peer(self.network, &self.tls, self.connect_timeout).await {
147 Ok((peer, addr, receiver)) => {
148 self.spawn_receiver_handler(receiver);
149 let mut entries = self.entries.write().await;
150 if entries.len() < self.max_peers {
151 entries.push(PeerEntry {
152 peer,
153 address: addr,
154 });
155 log::debug!("replacement peer connected: {addr}");
156 }
157 }
158 Err(e) => log::warn!("replacement peer connect failed: {e}"),
159 }
160 }
161
162 pub fn spawn_receiver_handler(&self, mut receiver: mpsc::Receiver<Message>) {
170 let peak = Arc::clone(&self.peak_height);
171 tokio::spawn(async move {
172 while let Some(msg) = receiver.recv().await {
173 if msg.msg_type == ProtocolMessageTypes::NewPeakWallet {
174 if let Ok(new_peak) = NewPeakWallet::from_bytes(&msg.data) {
175 let prev = peak.fetch_max(new_peak.height, Ordering::Relaxed);
176 if new_peak.height > prev {
177 log::debug!("new peak from peer: {}", new_peak.height);
178 }
179 }
180 }
181 }
182 });
183 }
184}