use std::net::SocketAddr;
use ahash::RandomState;
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use crate::message::Transaction;
/// Number of bits in the Bloom filter.
///
/// Must be a multiple of 8 (bits are packed 8 per byte) and a power of two
/// (reducing a hash modulo a power of two keeps only its low bits, so indices do
/// not depend on the platform's `usize` width). Both are checked at compile time below.
const BLOOM_BITS: usize = 512;
/// Size of the packed bit array in bytes. `hash_keys` is skipped by serde, so
/// these bytes are the entire serialized payload.
const BLOOM_BYTES: usize = BLOOM_BITS / 8;
/// Number of bit positions set when marking, and probed when querying, an address.
const NUM_HASHES: usize = 4;

// Compile-time guard for the invariants documented on `BLOOM_BITS`. If the value
// were not a multiple of 8, `BLOOM_BYTES` would truncate and the top indices would
// panic on indexing.
const _: () = assert!(BLOOM_BITS % 8 == 0 && BLOOM_BITS.is_power_of_two());
/// Fixed-size Bloom filter recording which peer addresses a transaction has visited.
///
/// Bits are only ever set, never cleared. Probing therefore has no false negatives
/// as long as marking and probing use the same `hash_keys`. False positives
/// (an unvisited peer reported as visited) are possible.
#[serde_as]
#[derive(Clone, Serialize, Deserialize)]
pub struct VisitedPeers {
/// Packed filter bits: `BLOOM_BITS` bits stored in `BLOOM_BYTES` bytes.
/// This is the only field that is serialized.
#[serde_as(as = "[_; BLOOM_BYTES]")]
bits: [u8; BLOOM_BYTES],
/// Hash seeds derived from the owning transaction's id.
/// Not serialized: after deserialization this holds the `Default` value `(0, 0)`
/// and must be re-derived from the transaction (see `with_transaction`).
#[serde(skip)]
hash_keys: (u64, u64),
}
impl std::fmt::Debug for VisitedPeers {
    /// Summarises the filter as its population count rather than dumping the
    /// raw bytes, e.g. `VisitedPeers { set_bits: 11, total_bits: 512 }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let populated: u32 = self.bits.iter().copied().map(u8::count_ones).sum();
        let mut out = f.debug_struct("VisitedPeers");
        out.field("set_bits", &populated);
        out.field("total_bits", &BLOOM_BITS);
        out.finish()
    }
}
impl VisitedPeers {
    /// FNV-1a 64-bit offset basis: the starting state of `keyed_hash`.
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    /// FNV-1a 64-bit prime.
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    /// Added to the transaction keys to obtain the keys of the second hash.
    const SECOND_HASH_OFFSETS: (u64, u64) = (0x9e37_79b9_7f4a_7c15, 0x517c_c1b7_2722_0a95);

    /// Creates an empty filter whose hash keys are derived from `tx`'s id.
    /// As a result, the same address maps to different bits for different transactions.
    pub fn new(tx: &Transaction) -> Self {
        let tx_bytes = tx.id_bytes();
        Self {
            bits: [0u8; BLOOM_BYTES],
            hash_keys: Self::derive_hash_keys(&tx_bytes),
        }
    }

    /// Re-keys the filter for `tx` while keeping the bits already set.
    ///
    /// `hash_keys` is `#[serde(skip)]`, so a deserialized filter starts with the
    /// default `(0, 0)` keys. Call this with the owning transaction before marking
    /// or probing; otherwise lookups consult the wrong bit positions.
    pub fn with_transaction(mut self, tx: &Transaction) -> Self {
        let tx_bytes = tx.id_bytes();
        self.hash_keys = Self::derive_hash_keys(&tx_bytes);
        self
    }

    /// Splits a 16-byte transaction id into two little-endian `u64` keys
    /// (bytes `0..8` and `8..16`).
    fn derive_hash_keys(tx_bytes: &[u8; 16]) -> (u64, u64) {
        let mut low = [0u8; 8];
        let mut high = [0u8; 8];
        low.copy_from_slice(&tx_bytes[..8]);
        high.copy_from_slice(&tx_bytes[8..]);
        (u64::from_le_bytes(low), u64::from_le_bytes(high))
    }

    /// Records `addr` by setting each of its `NUM_HASHES` bit positions.
    pub fn mark_visited(&mut self, addr: SocketAddr) {
        for idx in self.hash_indices(&addr) {
            self.bits[idx / 8] |= 1 << (idx % 8);
        }
    }

    /// Returns `true` if every bit position for `addr` is set.
    ///
    /// A `false` result is definitive; a `true` result may be a false positive.
    pub fn probably_visited(&self, addr: SocketAddr) -> bool {
        self.hash_indices(&addr)
            .iter()
            .all(|&idx| self.bits[idx / 8] & (1 << (idx % 8)) != 0)
    }

    /// Computes the `NUM_HASHES` bit positions for `addr` by double hashing
    /// (`h1 + i * h2`).
    ///
    /// The reduction modulo `BLOOM_BITS` is done in `u64`, so the result does not
    /// depend on the width of `usize`.
    fn hash_indices(&self, addr: &SocketAddr) -> [usize; NUM_HASHES] {
        let (k0, k1) = self.hash_keys;
        let (off0, off1) = Self::SECOND_HASH_OFFSETS;
        let h1 = Self::keyed_hash((k0, k1), addr);
        // An odd step is coprime with the power-of-two BLOOM_BITS, so the probe
        // positions are pairwise distinct. A zero step would collapse them onto one bit.
        let h2 = Self::keyed_hash((k0.wrapping_add(off0), k1.wrapping_add(off1)), addr) | 1;
        std::array::from_fn(|i| {
            let combined = h1.wrapping_add(h2.wrapping_mul(i as u64));
            (combined % BLOOM_BITS as u64) as usize
        })
    }

    /// Keyed 64-bit hash of the canonical byte encoding of `addr`.
    ///
    /// Both the encoding and the hash are spelled out here byte-for-byte, so the
    /// result is identical on every platform and build:
    /// - Encoding: family tag, address octets and big-endian port; for IPv6 also
    ///   flowinfo and scope id.
    /// - Hash: FNV-1a absorption followed by a SplitMix64 finalizer.
    ///
    /// `std::hash::Hash` data and `ahash` output carry no such guarantee. That
    /// matters because filters are serialized and may be probed by a different
    /// process or machine than the one that built them.
    ///
    /// Changing this function changes the bit layout. Every node that exchanges
    /// filters must agree on it.
    fn keyed_hash(key: (u64, u64), addr: &SocketAddr) -> u64 {
        let mut state = key.0 ^ Self::FNV_OFFSET_BASIS;
        let mut absorb = |bytes: &[u8]| {
            for &byte in bytes {
                state = (state ^ u64::from(byte)).wrapping_mul(Self::FNV_PRIME);
            }
        };
        match addr {
            SocketAddr::V4(v4) => {
                absorb(&[4u8]);
                absorb(&v4.ip().octets());
                absorb(&v4.port().to_be_bytes());
            }
            SocketAddr::V6(v6) => {
                absorb(&[6u8]);
                absorb(&v6.ip().octets());
                absorb(&v6.port().to_be_bytes());
                absorb(&v6.flowinfo().to_be_bytes());
                absorb(&v6.scope_id().to_be_bytes());
            }
        }
        Self::finalize(state ^ key.1)
    }

    /// SplitMix64 finalizer. It mixes high and low bits so that the low bits used
    /// for indexing depend on the whole FNV state.
    fn finalize(mut z: u64) -> u64 {
        z = (z ^ (z >> 30)).wrapping_mul(0xbf58_476d_1ce4_e5b9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94d0_49bb_1331_11eb);
        z ^ (z >> 31)
    }
}
impl Default for VisitedPeers {
    /// An empty filter with all-zero hash keys.
    ///
    /// A deserialized filter also ends up with zero keys, because `hash_keys` is
    /// not serialized. Re-key with the owning transaction before marking or probing.
    fn default() -> Self {
        let hash_keys = (0, 0);
        let bits = [0u8; BLOOM_BYTES];
        Self { bits, hash_keys }
    }
}
impl crate::util::Contains<SocketAddr> for VisitedPeers {
fn has_element(&self, target: SocketAddr) -> bool {
self.probably_visited(target)
}
}
impl crate::util::Contains<SocketAddr> for &VisitedPeers {
    /// Membership test on the borrowed filter, forwarded to
    /// `VisitedPeers::probably_visited`; may report false positives.
    fn has_element(&self, target: SocketAddr) -> bool {
        let filter: &VisitedPeers = self;
        filter.probably_visited(target)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::Transaction;
    use crate::operations::get::GetMsg;

    fn test_transaction() -> Transaction {
        Transaction::new::<GetMsg>()
    }

    /// Parses a socket-address literal used as test input.
    fn addr(s: &str) -> SocketAddr {
        s.parse().expect("test address literal must be valid")
    }

    /// `ip:port` for every port in `ports`, built without string formatting.
    fn addrs_on(ip: [u8; 4], ports: std::ops::Range<u16>) -> Vec<SocketAddr> {
        ports.map(|port| SocketAddr::from((ip, port))).collect()
    }

    #[test]
    fn test_basic_operations() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        let target = addr("127.0.0.1:8000");
        assert!(!visited.probably_visited(target));
        visited.mark_visited(target);
        assert!(visited.probably_visited(target));
        assert!(!visited.probably_visited(addr("127.0.0.1:9000")));
    }

    #[test]
    fn test_multiple_addresses() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        let peers = addrs_on([127, 0, 0, 1], 8000..8020);
        for peer in &peers {
            visited.mark_visited(*peer);
        }
        for peer in &peers {
            assert!(
                visited.probably_visited(*peer),
                "Address {} should be visited",
                peer
            );
        }
    }

    #[test]
    fn test_transaction_isolation() {
        let tx1 = test_transaction();
        let tx2 = test_transaction();
        let v1 = VisitedPeers::new(&tx1);
        let v2 = VisitedPeers::new(&tx2);
        assert_ne!(
            v1.hash_keys, v2.hash_keys,
            "Different transactions must produce different hash keys"
        );
        let v1_again = VisitedPeers::new(&tx1);
        assert_eq!(
            v1.hash_keys, v1_again.hash_keys,
            "Same transaction must produce identical hash keys"
        );
    }

    #[test]
    fn test_serialization_roundtrip() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        let peers = [
            addr("127.0.0.1:8000"),
            addr("192.168.1.1:9000"),
            addr("[::1]:8080"),
        ];
        for peer in &peers {
            visited.mark_visited(*peer);
        }
        let bytes = bincode::serialize(&visited).expect("serialization failed");
        let deserialized: VisitedPeers =
            bincode::deserialize(&bytes).expect("deserialization failed");
        let deserialized = deserialized.with_transaction(&tx);
        for peer in &peers {
            assert!(
                deserialized.probably_visited(*peer),
                "Address {} should be visited after roundtrip",
                peer
            );
        }
    }

    /// The hash keys are not part of the wire format. A deserialized filter keeps
    /// the bits but carries default keys until it is re-keyed with its transaction.
    #[test]
    fn test_hash_keys_not_serialized() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        visited.mark_visited(addr("127.0.0.1:8000"));
        let bytes = bincode::serialize(&visited).expect("serialization failed");
        let deserialized: VisitedPeers =
            bincode::deserialize(&bytes).expect("deserialization failed");
        assert_eq!(deserialized.bits, visited.bits);
        assert_eq!(deserialized.hash_keys, (0, 0));
        let rekeyed = deserialized.with_transaction(&tx);
        assert_eq!(rekeyed.hash_keys, visited.hash_keys);
    }

    #[test]
    fn test_size_is_fixed() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        assert_eq!(std::mem::size_of_val(&visited.bits), BLOOM_BYTES);
        for peer in addrs_on([127, 0, 0, 1], 8000..8100) {
            visited.mark_visited(peer);
        }
        assert_eq!(std::mem::size_of_val(&visited.bits), BLOOM_BYTES);
    }

    #[test]
    fn test_false_positive_rate() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        for peer in addrs_on([127, 0, 0, 1], 8000..8020) {
            visited.mark_visited(peer);
        }
        let false_positives = (10000..20000u16)
            .map(|port| SocketAddr::from(([10, 0, 0, 1], port)))
            .filter(|probe| visited.probably_visited(*probe))
            .count();
        let fp_rate = false_positives as f64 / 10000.0 * 100.0;
        assert!(
            false_positives <= 50,
            "False positive rate too high: {}/10000 = {:.3}% (expected ~0.04%)",
            false_positives,
            fp_rate
        );
    }

    #[test]
    fn test_no_false_negatives() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        let peers = addrs_on([127, 0, 0, 1], 8000..8050);
        for peer in &peers {
            visited.mark_visited(*peer);
        }
        for peer in &peers {
            assert!(
                visited.probably_visited(*peer),
                "False negative detected for {}",
                peer
            );
        }
    }

    #[test]
    fn test_serialized_size() {
        let tx = test_transaction();
        let visited = VisitedPeers::new(&tx);
        let bytes = bincode::serialize(&visited).expect("serialization failed");
        // Pinned literally (not via BLOOM_BYTES) so a wire-size change is caught.
        assert_eq!(
            bytes.len(),
            64,
            "Serialized size should be exactly 64 bytes, got {}",
            bytes.len()
        );
        let mut visited_full = VisitedPeers::new(&tx);
        for peer in addrs_on([127, 0, 0, 1], 8000..8100) {
            visited_full.mark_visited(peer);
        }
        let bytes_full = bincode::serialize(&visited_full).expect("serialization failed");
        assert_eq!(
            bytes.len(),
            bytes_full.len(),
            "Serialized size should be fixed regardless of marked peers"
        );
    }

    #[test]
    fn test_forwarding_must_mark_both_this_peer_and_sender() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        let this_peer = addr("10.0.0.1:8000");
        let sender = addr("10.0.0.2:8000");
        let originator = addr("10.0.0.3:8000");
        visited.mark_visited(this_peer);
        visited.mark_visited(sender);
        assert!(
            visited.probably_visited(this_peer),
            "this_peer must be marked to prevent routing back to processing node"
        );
        assert!(
            visited.probably_visited(sender),
            "sender must be marked to prevent routing back to request source"
        );
        assert!(
            !visited.probably_visited(originator),
            "originator not marked yet - they should mark themselves when initiating"
        );
    }

    #[test]
    fn test_request_cycle_prevention() {
        let tx = test_transaction();
        let mut visited = VisitedPeers::new(&tx);
        let originator = addr("10.0.0.1:8000");
        let gateway = addr("10.0.0.2:8000");
        let peer1 = addr("10.0.0.3:8000");
        let peer2 = addr("10.0.0.4:8000");
        // Each hop marks the node it came from and the node handling it.
        visited.mark_visited(gateway);
        visited.mark_visited(originator);
        visited.mark_visited(peer1);
        visited.mark_visited(gateway);
        visited.mark_visited(peer2);
        visited.mark_visited(peer1);
        assert!(
            visited.probably_visited(originator),
            "originator must be blocked"
        );
        assert!(visited.probably_visited(gateway), "gateway must be blocked");
        assert!(visited.probably_visited(peer1), "peer1 must be blocked");
        assert!(visited.probably_visited(peer2), "peer2 must be blocked");
        assert!(
            !visited.probably_visited(addr("10.0.0.5:8000")),
            "unvisited peer should still be valid for routing"
        );
    }
}