Skip to main content

chia_query/peer/
pool.rs

1use std::net::SocketAddr;
2use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
3use std::sync::Arc;
4use std::time::Duration;
5
6use chia::protocol::{Message, NewPeakWallet, ProtocolMessageTypes};
7use chia::traits::Streamable;
8use futures_util::stream::{FuturesUnordered, StreamExt};
9use tokio::sync::{mpsc, RwLock};
10
11use chia_wallet_sdk::client::Peer;
12use tokio_tungstenite::Connector;
13
14use crate::types::ChiaQueryError;
15use crate::NetworkType;
16
17use super::connect;
18
19// ---------------------------------------------------------------------------
20// Pool entry
21// ---------------------------------------------------------------------------
22
23struct PeerEntry {
24    peer: Peer,
25    address: SocketAddr,
26}
27
28// ---------------------------------------------------------------------------
29// PeerPool
30// ---------------------------------------------------------------------------
31
32pub struct PeerPool {
33    entries: RwLock<Vec<PeerEntry>>,
34    next_idx: AtomicUsize,
35    max_peers: usize,
36    tls: Connector,
37    network: NetworkType,
38    connect_timeout: Duration,
39    /// Latest peak height observed from any connected peer's NewPeakWallet
40    /// messages.  Updated in the background by receiver handler tasks.
41    peak_height: Arc<AtomicU32>,
42}
43
44impl PeerPool {
45    /// Spin up the pool by connecting to `max_peers` random full-node peers
46    /// concurrently.  At least one peer must succeed; otherwise we return
47    /// [`ChiaQueryError::PeerDiscoveryFailed`].
48    pub async fn new(
49        network: NetworkType,
50        tls: Connector,
51        max_peers: usize,
52        connect_timeout: Duration,
53    ) -> Result<Self, ChiaQueryError> {
54        let peak_height = Arc::new(AtomicU32::new(0));
55
56        // Connect to peers concurrently.
57        let mut futures = FuturesUnordered::new();
58        for _ in 0..max_peers {
59            let t = tls.clone();
60            futures.push(async move {
61                connect::connect_random_peer(network, &t, connect_timeout).await
62            });
63        }
64
65        let mut initial: Vec<PeerEntry> = Vec::new();
66        let mut receivers = Vec::new();
67        while let Some(result) = futures.next().await {
68            match result {
69                Ok((peer, addr, receiver)) => {
70                    initial.push(PeerEntry {
71                        peer,
72                        address: addr,
73                    });
74                    receivers.push(receiver);
75                }
76                Err(e) => log::debug!("initial peer connect failed: {e}"),
77            }
78        }
79
80        if initial.is_empty() {
81            return Err(ChiaQueryError::PeerDiscoveryFailed);
82        }
83
84        let pool = Self {
85            entries: RwLock::new(initial),
86            next_idx: AtomicUsize::new(0),
87            max_peers,
88            tls,
89            network,
90            connect_timeout,
91            peak_height,
92        };
93
94        // Spawn receiver handlers for initial peers (must happen after pool
95        // construction so peak_height Arc is available).
96        for receiver in receivers {
97            pool.spawn_receiver_handler(receiver);
98        }
99
100        Ok(pool)
101    }
102
103    /// Latest peak height observed across all connected peers.
104    /// Returns 0 if no peak has been received yet.
105    pub fn peak_height(&self) -> u32 {
106        self.peak_height.load(Ordering::Relaxed)
107    }
108
109    /// Round-robin select a peer from the pool.
110    /// Returns `None` when the pool is empty.
111    pub async fn select_peer(&self) -> Option<(Peer, SocketAddr)> {
112        let entries = self.entries.read().await;
113        if entries.is_empty() {
114            return None;
115        }
116        let idx = self.next_idx.fetch_add(1, Ordering::Relaxed) % entries.len();
117        let entry = &entries[idx];
118        Some((entry.peer.clone(), entry.address))
119    }
120
121    /// Remove a peer from the pool and asynchronously connect a replacement.
122    pub async fn eject_peer(&self, addr: SocketAddr) {
123        {
124            let mut entries = self.entries.write().await;
125            entries.retain(|e| e.address != addr);
126        }
127        log::debug!(
128            "peer ejected from pool; will refill on next request (network={:?})",
129            self.network,
130        );
131    }
132
133    /// Whether the pool has at least one usable peer.
134    pub async fn has_peers(&self) -> bool {
135        !self.entries.read().await.is_empty()
136    }
137
138    /// If the pool is under capacity, try to connect one new peer.
139    /// Also spawns a background task to handle its inbound `NewPeakWallet`
140    /// messages.
141    pub async fn try_refill(&self) {
142        let current = self.entries.read().await.len();
143        if current >= self.max_peers {
144            return;
145        }
146        match connect::connect_random_peer(self.network, &self.tls, self.connect_timeout).await {
147            Ok((peer, addr, receiver)) => {
148                self.spawn_receiver_handler(receiver);
149                let mut entries = self.entries.write().await;
150                if entries.len() < self.max_peers {
151                    entries.push(PeerEntry {
152                        peer,
153                        address: addr,
154                    });
155                    log::debug!("replacement peer connected: {addr}");
156                }
157            }
158            Err(e) => log::warn!("replacement peer connect failed: {e}"),
159        }
160    }
161
162    // -----------------------------------------------------------------------
163    // Receiver helpers (handle NewPeakWallet from peers)
164    // -----------------------------------------------------------------------
165
166    /// Spawn a background task that reads inbound messages from a peer's
167    /// receiver channel and updates the shared peak height.  This mirrors
168    /// the pattern used by chia-block-listener.
169    pub fn spawn_receiver_handler(&self, mut receiver: mpsc::Receiver<Message>) {
170        let peak = Arc::clone(&self.peak_height);
171        tokio::spawn(async move {
172            while let Some(msg) = receiver.recv().await {
173                if msg.msg_type == ProtocolMessageTypes::NewPeakWallet {
174                    if let Ok(new_peak) = NewPeakWallet::from_bytes(&msg.data) {
175                        let prev = peak.fetch_max(new_peak.height, Ordering::Relaxed);
176                        if new_peak.height > prev {
177                            log::debug!("new peak from peer: {}", new_peak.height);
178                        }
179                    }
180                }
181            }
182        });
183    }
184}