Skip to main content

abtc_adapters/network/
mod.rs

//! P2P Network Implementation
//!
//! Provides a TCP-based peer manager that handles the Bitcoin P2P protocol:
//! - Outbound peer connections with a (simplified) version/verack handshake
//! - Peer lifecycle management (connect, disconnect, ban)
//! - Transaction and block inventory (`inv`) broadcast to all connected peers
//! - Ping/pong message encoding (keepalive scheduling is not implemented here)
//! - Misbehaviour scoring with automatic banning
//!
//! The network layer uses Bitcoin protocol message framing:
//!   `[4-byte magic][12-byte command][4-byte payload length][4-byte checksum][payload]`
//!
//! ## BIP324 v2 encrypted transport
//!
//! The [`v2`] submodule implements encrypted peer-to-peer transport using
//! ECDH key exchange, HKDF session key derivation, and ChaCha20-Poly1305
//! AEAD with periodic rekeying (FSChaCha20Poly1305).

19pub mod v2;
20
21use abtc_domain::primitives::{Block, Transaction};
22use abtc_ports::{NetworkMessage, PeerInfo, PeerManager};
23use async_trait::async_trait;
24use std::collections::HashMap;
25use std::net::SocketAddr;
26use std::sync::atomic::{AtomicU64, Ordering};
27use std::sync::Arc;
28use std::time::{SystemTime, UNIX_EPOCH};
29use tokio::io::{AsyncReadExt, AsyncWriteExt};
30use tokio::net::TcpStream;
31use tokio::sync::RwLock;
32
// ----- Protocol constants -----

/// Network magic for Bitcoin mainnet, in wire (little-endian) byte order.
///
/// Interpreted as a little-endian `u32` these bytes read `0xD9B4BEF9`.
const MAINNET_MAGIC: [u8; 4] = 0xD9B4_BEF9u32.to_le_bytes();

/// P2P protocol version advertised in our `version` message.
const PROTOCOL_VERSION: u32 = 70016;

/// User-agent string advertised to peers.
const USER_AGENT: &str = "/AgenticBitcoin:0.1.0/";

/// Misbehaviour score at or above which a peer is disconnected and banned.
const BAN_SCORE_THRESHOLD: i32 = 100;

/// Length of an automatic misbehaviour ban: 24 hours, expressed in seconds.
const DEFAULT_BAN_TIME: u64 = 24 * 60 * 60;

/// Upper bound on simultaneously connected peers.
const MAX_PEERS: usize = 125;

53// ----- Message serialisation helpers -----
54
55/// A raw Bitcoin protocol message header (24 bytes total).
56#[derive(Debug, Clone)]
57struct MessageHeader {
58    magic: [u8; 4],
59    command: [u8; 12],
60    length: u32,
61    checksum: [u8; 4],
62}
63
64impl MessageHeader {
65    /// Create a new message header for the given command and payload
66    fn new(command_name: &str, payload: &[u8]) -> Self {
67        let mut command = [0u8; 12];
68        let bytes = command_name.as_bytes();
69        let copy_len = bytes.len().min(12);
70        command[..copy_len].copy_from_slice(&bytes[..copy_len]);
71
72        // Bitcoin checksum: first 4 bytes of double-SHA256
73        let checksum = compute_checksum(payload);
74
75        MessageHeader {
76            magic: MAINNET_MAGIC,
77            command,
78            length: payload.len() as u32,
79            checksum,
80        }
81    }
82
83    /// Serialise header to 24 bytes
84    fn to_bytes(&self) -> Vec<u8> {
85        let mut buf = Vec::with_capacity(24);
86        buf.extend_from_slice(&self.magic);
87        buf.extend_from_slice(&self.command);
88        buf.extend_from_slice(&self.length.to_le_bytes());
89        buf.extend_from_slice(&self.checksum);
90        buf
91    }
92}
93
/// Bitcoin v1 message checksum: the first 4 bytes of `SHA256(SHA256(data))`.
///
/// Receiving peers recompute this value and discard any message whose
/// checksum does not match, so it must be the real double-SHA256. For an
/// empty payload (e.g. `verack`) the result is `5d f6 e0 e2`.
///
/// NOTE(review): if `abtc_domain` already exposes SHA-256, the local
/// implementation below could be replaced by it.
fn compute_checksum(data: &[u8]) -> [u8; 4] {
    let digest = sha256(&sha256(data));
    [digest[0], digest[1], digest[2], digest[3]]
}

/// Minimal FIPS 180-4 SHA-256, used for v1 message checksums.
fn sha256(data: &[u8]) -> [u8; 32] {
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Compress every full 64-byte block straight from the input (no copy of
    // potentially large block payloads).
    let mut blocks = data.chunks_exact(64);
    for block in &mut blocks {
        sha256_compress(&mut state, block);
    }

    // Padding: 0x80, zeros until the length is 56 mod 64, then the message
    // length in bits as a big-endian u64. This yields one or two final blocks.
    let bit_len = (data.len() as u64).wrapping_mul(8);
    let mut tail = Vec::with_capacity(128);
    tail.extend_from_slice(blocks.remainder());
    tail.push(0x80);
    while tail.len() % 64 != 56 {
        tail.push(0);
    }
    tail.extend_from_slice(&bit_len.to_be_bytes());
    for block in tail.chunks_exact(64) {
        sha256_compress(&mut state, block);
    }

    let mut digest = [0u8; 32];
    for (out, word) in digest.chunks_exact_mut(4).zip(state) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}

/// Apply the SHA-256 compression function to one 64-byte `block`.
fn sha256_compress(state: &mut [u32; 8], block: &[u8]) {
    const K: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    // Message schedule: 16 big-endian words from the block, expanded to 64.
    let mut w = [0u32; 64];
    for (slot, word) in w.iter_mut().zip(block.chunks_exact(4)) {
        *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
    }
    for i in 16..64 {
        let s0 = w[i - 15].rotate_right(7) ^ w[i - 15].rotate_right(18) ^ (w[i - 15] >> 3);
        let s1 = w[i - 2].rotate_right(17) ^ w[i - 2].rotate_right(19) ^ (w[i - 2] >> 10);
        w[i] = w[i - 16]
            .wrapping_add(s0)
            .wrapping_add(w[i - 7])
            .wrapping_add(s1);
    }

    let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut h] = *state;
    for (&k, &wi) in K.iter().zip(w.iter()) {
        let big_s1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
        let ch = (e & f) ^ (!e & g);
        let t1 = h
            .wrapping_add(big_s1)
            .wrapping_add(ch)
            .wrapping_add(k)
            .wrapping_add(wi);
        let big_s0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
        let maj = (a & b) ^ (a & c) ^ (b & c);
        let t2 = big_s0.wrapping_add(maj);

        h = g;
        g = f;
        f = e;
        e = d.wrapping_add(t1);
        d = c;
        c = b;
        b = a;
        a = t1.wrapping_add(t2);
    }

    for (s, v) in state.iter_mut().zip([a, b, c, d, e, f, g, h]) {
        *s = s.wrapping_add(v);
    }
}

112/// Build a Bitcoin protocol "version" message payload
113fn build_version_payload(
114    local_addr: SocketAddr,
115    remote_addr: SocketAddr,
116    start_height: u32,
117) -> Vec<u8> {
118    let mut payload = Vec::with_capacity(86 + USER_AGENT.len());
119
120    // Protocol version (4 bytes LE)
121    payload.extend_from_slice(&PROTOCOL_VERSION.to_le_bytes());
122
123    // Services (8 bytes LE) - NODE_NETWORK = 1
124    payload.extend_from_slice(&1u64.to_le_bytes());
125
126    // Timestamp (8 bytes LE)
127    let timestamp = SystemTime::now()
128        .duration_since(UNIX_EPOCH)
129        .unwrap_or_default()
130        .as_secs() as i64;
131    payload.extend_from_slice(&timestamp.to_le_bytes());
132
133    // addr_recv: services(8) + ipv6-mapped ipv4(16) + port(2 big-endian)
134    payload.extend_from_slice(&1u64.to_le_bytes()); // services
135    payload.extend_from_slice(&encode_net_addr(remote_addr));
136
137    // addr_from: services(8) + ipv6-mapped ipv4(16) + port(2 big-endian)
138    payload.extend_from_slice(&0u64.to_le_bytes()); // services
139    payload.extend_from_slice(&encode_net_addr(local_addr));
140
141    // Nonce (8 bytes)
142    let nonce: u64 = rand_u64();
143    payload.extend_from_slice(&nonce.to_le_bytes());
144
145    // User agent (varint length + string)
146    payload.push(USER_AGENT.len() as u8);
147    payload.extend_from_slice(USER_AGENT.as_bytes());
148
149    // Start height (4 bytes LE)
150    payload.extend_from_slice(&start_height.to_le_bytes());
151
152    // Relay (1 byte)
153    payload.push(1u8);
154
155    payload
156}
157
/// Encode `addr` as a 16-byte IP field followed by the port in big-endian
/// order (18 bytes total).
///
/// IPv4 addresses are written in IPv4-mapped IPv6 form (`::ffff:a.b.c.d`);
/// IPv6 addresses are written unchanged.
fn encode_net_addr(addr: SocketAddr) -> Vec<u8> {
    let ip16: [u8; 16] = match addr {
        SocketAddr::V6(v6) => v6.ip().octets(),
        SocketAddr::V4(v4) => v4.ip().to_ipv6_mapped().octets(),
    };

    let mut out = Vec::with_capacity(18);
    out.extend_from_slice(&ip16);
    out.extend_from_slice(&addr.port().to_be_bytes());
    out
}

/// Return a pseudo-random `u64` for protocol nonces (e.g. the `version` nonce).
///
/// The current time is fed through a freshly keyed `RandomState` hasher.
/// A time-only generator returns the same value for calls within one clock
/// tick, and each `RandomState::new()` carries different keys, so
/// back-to-back nonces differ and are not trivially predictable.
/// Not suitable for cryptographic use.
fn rand_u64() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos() as u64;

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u64(nanos);
    hasher.finish()
}

190// ----- Connection state -----
191
192/// State of a single peer connection
193struct PeerConnection {
194    info: PeerInfo,
195    /// Misbehaviour score (ban when >= BAN_SCORE_THRESHOLD)
196    ban_score: i32,
197    /// Whether the handshake is complete
198    _handshake_complete: bool,
199    /// Optional TCP stream (kept for sending messages)
200    stream: Option<Arc<RwLock<TcpStream>>>,
201}
202
203// ----- TcpPeerManager -----
204
205/// TCP-based Bitcoin P2P peer manager.
206///
207/// Manages outbound connections, performs version/verack handshakes,
208/// tracks misbehaviour, and broadcasts blocks and transactions.
209pub struct TcpPeerManager {
210    /// Active peer connections indexed by peer_id
211    peers: Arc<RwLock<HashMap<u64, PeerConnection>>>,
212    /// Banned addresses with ban expiry timestamps
213    banned: Arc<RwLock<HashMap<SocketAddr, u64>>>,
214    /// Monotonically increasing peer ID counter
215    next_peer_id: AtomicU64,
216    /// Current best block height (for version message)
217    start_height: Arc<RwLock<u32>>,
218    /// Local listening address
219    local_addr: SocketAddr,
220}
221
222impl TcpPeerManager {
223    /// Create a new TCP peer manager
224    pub fn new(local_addr: SocketAddr) -> Self {
225        TcpPeerManager {
226            peers: Arc::new(RwLock::new(HashMap::new())),
227            banned: Arc::new(RwLock::new(HashMap::new())),
228            next_peer_id: AtomicU64::new(1),
229            start_height: Arc::new(RwLock::new(0)),
230            local_addr,
231        }
232    }
233
234    /// Update the current chain height (used in version messages)
235    pub async fn set_height(&self, height: u32) {
236        let mut h = self.start_height.write().await;
237        *h = height;
238    }
239
240    /// Get count of connected peers
241    pub async fn peer_count(&self) -> usize {
242        let peers = self.peers.read().await;
243        peers.len()
244    }
245
246    /// Check whether an address is currently banned
247    async fn is_banned(&self, addr: &SocketAddr) -> bool {
248        let banned = self.banned.read().await;
249        if let Some(&expiry) = banned.get(addr) {
250            let now = SystemTime::now()
251                .duration_since(UNIX_EPOCH)
252                .unwrap_or_default()
253                .as_secs();
254            if now < expiry {
255                return true;
256            }
257        }
258        false
259    }
260
261    /// Send a raw Bitcoin protocol message to a peer
262    async fn send_message(
263        stream: &Arc<RwLock<TcpStream>>,
264        command: &str,
265        payload: &[u8],
266    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
267        let header = MessageHeader::new(command, payload);
268        let mut msg = header.to_bytes();
269        msg.extend_from_slice(payload);
270
271        let mut writer = stream.write().await;
272        writer.write_all(&msg).await?;
273        writer.flush().await?;
274        Ok(())
275    }
276
277    /// Send a raw message to all connected peers, returning number of successful sends
278    async fn broadcast_message(
279        &self,
280        command: &str,
281        payload: &[u8],
282    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
283        let peers = self.peers.read().await;
284        let mut count = 0usize;
285
286        for conn in peers.values() {
287            if let Some(ref stream) = conn.stream {
288                if Self::send_message(stream, command, payload).await.is_ok() {
289                    count += 1;
290                }
291            }
292        }
293
294        Ok(count)
295    }
296
297    /// Increase a peer's misbehaviour score; ban if threshold exceeded
298    pub async fn add_misbehaviour(&self, peer_id: u64, score: i32) {
299        let mut peers = self.peers.write().await;
300        if let Some(conn) = peers.get_mut(&peer_id) {
301            conn.ban_score += score;
302            tracing::warn!(
303                "Peer {} misbehaviour score: {} (+{})",
304                peer_id,
305                conn.ban_score,
306                score
307            );
308            if conn.ban_score >= BAN_SCORE_THRESHOLD {
309                let addr = conn.info.addr;
310                tracing::warn!(
311                    "Banning peer {} ({}) for exceeding threshold",
312                    peer_id,
313                    addr
314                );
315
316                let now = SystemTime::now()
317                    .duration_since(UNIX_EPOCH)
318                    .unwrap_or_default()
319                    .as_secs();
320
321                // Remove the peer
322                peers.remove(&peer_id);
323
324                // Drop peers lock before acquiring banned lock
325                drop(peers);
326
327                let mut banned = self.banned.write().await;
328                banned.insert(addr, now + DEFAULT_BAN_TIME);
329            }
330        }
331    }
332}
333
334impl Default for TcpPeerManager {
335    fn default() -> Self {
336        let addr: SocketAddr = "0.0.0.0:8333".parse().unwrap();
337        Self::new(addr)
338    }
339}
340
341#[async_trait]
342impl PeerManager for TcpPeerManager {
343    async fn connect_peer(
344        &self,
345        addr: SocketAddr,
346    ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
347        // Check ban list
348        if self.is_banned(&addr).await {
349            return Err(format!("Peer {} is banned", addr).into());
350        }
351
352        // Check max peers
353        {
354            let peers = self.peers.read().await;
355            if peers.len() >= MAX_PEERS {
356                return Err("Maximum number of peers reached".into());
357            }
358
359            // Check for duplicate connections
360            if peers.values().any(|c| c.info.addr == addr) {
361                return Err(format!("Already connected to {}", addr).into());
362            }
363        }
364
365        // Attempt TCP connection with a 10-second timeout
366        let stream =
367            tokio::time::timeout(std::time::Duration::from_secs(10), TcpStream::connect(addr))
368                .await
369                .map_err(|_| format!("Connection to {} timed out", addr))?
370                .map_err(|e| format!("Failed to connect to {}: {}", addr, e))?;
371
372        let stream = Arc::new(RwLock::new(stream));
373
374        // Perform version handshake
375        let height = *self.start_height.read().await;
376        let version_payload = build_version_payload(self.local_addr, addr, height);
377        Self::send_message(&stream, "version", &version_payload).await?;
378
379        // Read the peer's response (version + verack)
380        // In a production implementation we'd parse the response properly.
381        // Here we read up to 4KB and check for data to confirm connectivity.
382        {
383            let mut reader = stream.write().await;
384            let mut buf = vec![0u8; 4096];
385            match tokio::time::timeout(std::time::Duration::from_secs(10), reader.read(&mut buf))
386                .await
387            {
388                Ok(Ok(n)) if n > 0 => {
389                    tracing::debug!("Received {} bytes from peer {}", n, addr);
390                }
391                Ok(Ok(_)) => {
392                    return Err(format!("Peer {} closed connection during handshake", addr).into());
393                }
394                Ok(Err(e)) => {
395                    return Err(format!("Handshake read error from {}: {}", addr, e).into());
396                }
397                Err(_) => {
398                    return Err(format!("Handshake timeout with {}", addr).into());
399                }
400            }
401        }
402
403        // Send verack
404        Self::send_message(&stream, "verack", &[]).await?;
405
406        let peer_id = self.next_peer_id.fetch_add(1, Ordering::SeqCst);
407
408        let peer_info = PeerInfo {
409            id: peer_id,
410            addr,
411            services: 1, // NODE_NETWORK
412            version: PROTOCOL_VERSION,
413            subver: USER_AGENT.to_string(),
414            start_height: height,
415            relay_txs: true,
416        };
417
418        let connection = PeerConnection {
419            info: peer_info,
420            ban_score: 0,
421            _handshake_complete: true,
422            stream: Some(stream),
423        };
424
425        {
426            let mut peers = self.peers.write().await;
427            peers.insert(peer_id, connection);
428        }
429
430        tracing::info!("Connected to peer at {} (id: {})", addr, peer_id);
431        Ok(peer_id)
432    }
433
434    async fn disconnect_peer(
435        &self,
436        peer_id: u64,
437    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
438        let mut peers = self.peers.write().await;
439        if let Some(conn) = peers.remove(&peer_id) {
440            // Drop the stream to close the TCP connection
441            drop(conn.stream);
442            tracing::info!("Disconnected from peer {} ({})", peer_id, conn.info.addr);
443        }
444        Ok(())
445    }
446
447    async fn ban_peer(
448        &self,
449        addr: SocketAddr,
450        ban_time: u64,
451    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
452        let now = SystemTime::now()
453            .duration_since(UNIX_EPOCH)
454            .unwrap_or_default()
455            .as_secs();
456
457        // Ban the address
458        {
459            let mut banned = self.banned.write().await;
460            banned.insert(addr, now + ban_time);
461        }
462
463        // Disconnect any peers at this address
464        let mut peers = self.peers.write().await;
465        let to_remove: Vec<u64> = peers
466            .iter()
467            .filter(|(_, conn)| conn.info.addr == addr)
468            .map(|(id, _)| *id)
469            .collect();
470
471        for id in to_remove {
472            peers.remove(&id);
473        }
474
475        tracing::warn!("Banned peer {} for {} seconds", addr, ban_time);
476        Ok(())
477    }
478
479    async fn get_connected_peers(
480        &self,
481    ) -> Result<Vec<PeerInfo>, Box<dyn std::error::Error + Send + Sync>> {
482        let peers = self.peers.read().await;
483        Ok(peers.values().map(|c| c.info.clone()).collect())
484    }
485
486    async fn broadcast_transaction(
487        &self,
488        tx: &Transaction,
489    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
490        // Build an "inv" message with the transaction hash
491        // inv payload: count(varint) + [type(4LE) + hash(32)]
492        let txid = tx.txid();
493        let mut payload = Vec::with_capacity(37);
494        payload.push(1u8); // count = 1 (varint)
495        payload.extend_from_slice(&1u32.to_le_bytes()); // type = MSG_TX (1)
496        payload.extend_from_slice(txid.as_bytes()); // 32-byte hash
497
498        let count = self.broadcast_message("inv", &payload).await?;
499        tracing::debug!("Broadcast tx {} inv to {} peers", txid, count);
500        Ok(count)
501    }
502
503    async fn broadcast_block(
504        &self,
505        block: &Block,
506    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
507        // Build an "inv" message with the block hash
508        let hash = block.block_hash();
509        let mut payload = Vec::with_capacity(37);
510        payload.push(1u8); // count = 1 (varint)
511        payload.extend_from_slice(&2u32.to_le_bytes()); // type = MSG_BLOCK (2)
512        payload.extend_from_slice(hash.as_bytes()); // 32-byte hash
513
514        let count = self.broadcast_message("inv", &payload).await?;
515        tracing::debug!("Broadcast block {} inv to {} peers", hash, count);
516        Ok(count)
517    }
518
519    async fn send_to_peer(
520        &self,
521        peer_id: u64,
522        message: NetworkMessage,
523    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
524        let peers = self.peers.read().await;
525        let conn = peers
526            .get(&peer_id)
527            .ok_or_else(|| format!("Unknown peer {}", peer_id))?;
528        let stream = conn
529            .stream
530            .as_ref()
531            .ok_or_else(|| format!("No stream for peer {}", peer_id))?;
532
533        let (command, payload) = encode_network_message(&message);
534        Self::send_message(stream, &command, &payload).await
535    }
536}
537
538/// Encode a NetworkMessage into a command string and payload bytes.
539fn encode_network_message(msg: &NetworkMessage) -> (String, Vec<u8>) {
540    match msg {
541        NetworkMessage::GetHeaders {
542            version,
543            block_locator,
544            hash_stop,
545        } => {
546            let mut payload = Vec::new();
547            payload.extend_from_slice(&version.to_le_bytes());
548            // varint count of locator hashes
549            push_varint_net(&mut payload, block_locator.len() as u64);
550            for hash in block_locator {
551                payload.extend_from_slice(hash.as_bytes());
552            }
553            payload.extend_from_slice(hash_stop.as_bytes());
554            ("getheaders".to_string(), payload)
555        }
556        NetworkMessage::GetBlocks {
557            version,
558            block_locator,
559            hash_stop,
560        } => {
561            let mut payload = Vec::new();
562            payload.extend_from_slice(&version.to_le_bytes());
563            push_varint_net(&mut payload, block_locator.len() as u64);
564            for hash in block_locator {
565                payload.extend_from_slice(hash.as_bytes());
566            }
567            payload.extend_from_slice(hash_stop.as_bytes());
568            ("getblocks".to_string(), payload)
569        }
570        NetworkMessage::GetData { items } => {
571            let mut payload = Vec::new();
572            push_varint_net(&mut payload, items.len() as u64);
573            for item in items {
574                match item {
575                    abtc_ports::InventoryItem::Tx(txid) => {
576                        payload.extend_from_slice(&1u32.to_le_bytes()); // MSG_TX
577                        payload.extend_from_slice(txid.as_bytes());
578                    }
579                    abtc_ports::InventoryItem::Block(hash) => {
580                        payload.extend_from_slice(&2u32.to_le_bytes()); // MSG_BLOCK
581                        payload.extend_from_slice(hash.as_bytes());
582                    }
583                }
584            }
585            ("getdata".to_string(), payload)
586        }
587        NetworkMessage::Inv { items } => {
588            let mut payload = Vec::new();
589            push_varint_net(&mut payload, items.len() as u64);
590            for item in items {
591                match item {
592                    abtc_ports::InventoryItem::Tx(txid) => {
593                        payload.extend_from_slice(&1u32.to_le_bytes());
594                        payload.extend_from_slice(txid.as_bytes());
595                    }
596                    abtc_ports::InventoryItem::Block(hash) => {
597                        payload.extend_from_slice(&2u32.to_le_bytes());
598                        payload.extend_from_slice(hash.as_bytes());
599                    }
600                }
601            }
602            ("inv".to_string(), payload)
603        }
604        NetworkMessage::Ping { nonce } => ("ping".to_string(), nonce.to_le_bytes().to_vec()),
605        NetworkMessage::Pong { nonce } => ("pong".to_string(), nonce.to_le_bytes().to_vec()),
606        NetworkMessage::Verack => ("verack".to_string(), Vec::new()),
607        NetworkMessage::Headers { headers } => {
608            let mut payload = Vec::new();
609            push_varint_net(&mut payload, headers.len() as u64);
610            for hdr in headers {
611                payload.extend_from_slice(&hdr.version.to_le_bytes());
612                payload.extend_from_slice(hdr.prev_block_hash.as_bytes());
613                payload.extend_from_slice(hdr.merkle_root.as_bytes());
614                payload.extend_from_slice(&hdr.time.to_le_bytes());
615                payload.extend_from_slice(&hdr.bits.to_le_bytes());
616                payload.extend_from_slice(&hdr.nonce.to_le_bytes());
617                payload.push(0); // tx_count = 0 for headers message
618            }
619            ("headers".to_string(), payload)
620        }
621        // For Version, Tx, Block — these are typically handled at a higher level
622        // or via dedicated methods. Provide a basic fallback:
623        NetworkMessage::Version { .. } => {
624            // The version message is built via build_version_payload() above
625            ("version".to_string(), Vec::new())
626        }
627        NetworkMessage::Tx { .. } => ("tx".to_string(), Vec::new()),
628        NetworkMessage::Block { .. } => ("block".to_string(), Vec::new()),
629        NetworkMessage::Addr { .. } => ("addr".to_string(), Vec::new()),
630        NetworkMessage::GetAddr => ("getaddr".to_string(), Vec::new()),
631        NetworkMessage::PackageTx { .. } => ("pkgtxns".to_string(), Vec::new()),
632        NetworkMessage::SendPackages { version } => {
633            ("sendpackages".to_string(), version.to_le_bytes().to_vec())
634        }
635    }
636}
637
/// Append `value` to `buf` as a Bitcoin CompactSize ("varint") integer.
///
/// Values below `0xfd` take one byte. Larger values are written as a
/// `0xfd` / `0xfe` / `0xff` marker followed by a 2- / 4- / 8-byte
/// little-endian integer.
fn push_varint_net(buf: &mut Vec<u8>, value: u64) {
    match value {
        0..=0xfc => buf.push(value as u8),
        0xfd..=0xffff => {
            buf.push(0xfd);
            buf.extend_from_slice(&(value as u16).to_le_bytes());
        }
        0x1_0000..=0xffff_ffff => {
            buf.push(0xfe);
            buf.extend_from_slice(&(value as u32).to_le_bytes());
        }
        _ => {
            buf.push(0xff);
            buf.extend_from_slice(&value.to_le_bytes());
        }
    }
}

654// ----- Stub peer manager (kept for tests and simple setups) -----
655
656/// Stub P2P peer manager implementation
657///
658/// A lightweight stub suitable for testing and development when
659/// actual network connectivity is not needed.
660pub struct StubPeerManager {
661    connected_peers: Arc<RwLock<Vec<PeerInfo>>>,
662}
663
664impl StubPeerManager {
665    /// Create a new stub peer manager
666    pub fn new() -> Self {
667        StubPeerManager {
668            connected_peers: Arc::new(RwLock::new(Vec::new())),
669        }
670    }
671
672    /// Get count of connected peers
673    pub async fn peer_count(&self) -> usize {
674        let peers = self.connected_peers.read().await;
675        peers.len()
676    }
677}
678
679impl Default for StubPeerManager {
680    fn default() -> Self {
681        Self::new()
682    }
683}
684
685#[async_trait]
686impl PeerManager for StubPeerManager {
687    async fn connect_peer(
688        &self,
689        addr: SocketAddr,
690    ) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
691        let peer_id = (addr.ip().to_string().len() as u64) * 1000 + addr.port() as u64;
692
693        let peer_info = PeerInfo {
694            id: peer_id,
695            addr,
696            services: 1,
697            version: 70015,
698            subver: "/StubNode:0.1.0/".to_string(),
699            start_height: 0,
700            relay_txs: true,
701        };
702
703        let mut peers = self.connected_peers.write().await;
704        peers.push(peer_info);
705
706        tracing::info!("Connected to peer at {} (id: {})", addr, peer_id);
707        Ok(peer_id)
708    }
709
710    async fn disconnect_peer(
711        &self,
712        peer_id: u64,
713    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
714        let mut peers = self.connected_peers.write().await;
715        peers.retain(|p| p.id != peer_id);
716        tracing::info!("Disconnected from peer {}", peer_id);
717        Ok(())
718    }
719
720    async fn ban_peer(
721        &self,
722        addr: SocketAddr,
723        ban_time: u64,
724    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
725        tracing::warn!("Banning peer {} for {} seconds", addr, ban_time);
726        Ok(())
727    }
728
729    async fn get_connected_peers(
730        &self,
731    ) -> Result<Vec<PeerInfo>, Box<dyn std::error::Error + Send + Sync>> {
732        let peers = self.connected_peers.read().await;
733        Ok(peers.clone())
734    }
735
736    async fn broadcast_transaction(
737        &self,
738        _tx: &Transaction,
739    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
740        let peers = self.connected_peers.read().await;
741        let count = peers.len();
742        tracing::debug!("Broadcasting transaction to {} peers", count);
743        Ok(count)
744    }
745
746    async fn broadcast_block(
747        &self,
748        block: &Block,
749    ) -> Result<usize, Box<dyn std::error::Error + Send + Sync>> {
750        let peers = self.connected_peers.read().await;
751        let count = peers.len();
752        tracing::debug!(
753            "Broadcasting block {} to {} peers",
754            block.block_hash(),
755            count
756        );
757        Ok(count)
758    }
759
760    async fn send_to_peer(
761        &self,
762        peer_id: u64,
763        message: NetworkMessage,
764    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
765        tracing::debug!("Stub: send message to peer {}: {:?}", peer_id, message);
766        Ok(())
767    }
768}
769
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a socket-address literal used by the tests.
    fn sock(text: &str) -> SocketAddr {
        text.parse().expect("valid socket address literal")
    }

    #[tokio::test]
    async fn test_stub_peer_manager_creation() {
        assert_eq!(StubPeerManager::new().peer_count().await, 0);
    }

    #[tokio::test]
    async fn test_connect_disconnect_peer() {
        let stub = StubPeerManager::new();

        let id = stub.connect_peer(sock("127.0.0.1:8333")).await.unwrap();
        assert_eq!(stub.peer_count().await, 1);

        stub.disconnect_peer(id).await.unwrap();
        assert_eq!(stub.peer_count().await, 0);
    }

    #[tokio::test]
    async fn test_tcp_peer_manager_creation() {
        let manager = TcpPeerManager::new(sock("127.0.0.1:8333"));
        assert_eq!(manager.peer_count().await, 0);
    }

    #[tokio::test]
    async fn test_tcp_peer_manager_ban() {
        let manager = TcpPeerManager::new(sock("127.0.0.1:8333"));
        let target = sock("10.0.0.1:8333");

        assert!(!manager.is_banned(&target).await);
        manager.ban_peer(target, 3600).await.unwrap();
        assert!(manager.is_banned(&target).await);
    }

    #[test]
    fn test_message_header() {
        let header = MessageHeader::new("version", b"test payload");

        assert_eq!(header.magic, MAINNET_MAGIC);
        assert_eq!(&header.command[..7], b"version");
        assert_eq!(header.length, 12);
        assert_eq!(header.to_bytes().len(), 24);
    }

    #[test]
    fn test_version_payload() {
        let payload = build_version_payload(sock("127.0.0.1:8333"), sock("10.0.0.1:8333"), 100);

        // The fixed-size fields alone take 86 bytes.
        assert!(payload.len() >= 86);

        let mut version_bytes = [0u8; 4];
        version_bytes.copy_from_slice(&payload[..4]);
        assert_eq!(u32::from_le_bytes(version_bytes), PROTOCOL_VERSION);
    }

    #[test]
    fn test_encode_net_addr_v4() {
        let encoded = encode_net_addr(sock("1.2.3.4:8333"));

        assert_eq!(encoded.len(), 18);
        assert_eq!(&encoded[10..12], &[0xff, 0xff]); // IPv4-mapped prefix
        assert_eq!(&encoded[12..16], &[1, 2, 3, 4]); // IPv4 octets
        assert_eq!(&encoded[16..18], &[0x20, 0x8D]); // port 8333, big-endian
    }
}