use std::{io::Write, mem};
use byteorder::{BigEndian, LittleEndian, WriteBytesExt};
use bytes::BytesMut;
use tokio_util::codec::Decoder;
use zebra_chain::{
parameters::Network,
serialization::{sha256d, TrustedPreallocate, ZcashDeserialize},
};
use crate::{
constants::MAX_ADDRS_IN_MESSAGE,
protocol::external::{addr::AddrV2, Codec},
};
/// Smallest possible serialized size, in bytes, of one `AddrV2` entry as
/// written by `build_attack_message`: every variable-length field is empty.
// The zero-length address field is spelled out for clarity, which clippy's
// `identity_op` would otherwise flag as a pointless `+ 0`.
#[allow(clippy::identity_op)]
const ADDR_V2_MIN_SIZE: usize = {
    let time: usize = 4; // u32 timestamp
    let services: usize = 1; // CompactSize(0)
    let network_id: usize = 1; // single id byte
    let addr_len: usize = 1; // CompactSize(0)
    let addr_bytes: usize = 0; // no address payload
    let port: usize = 2; // u16
    time + services + network_id + addr_len + addr_bytes + port
};
fn build_attack_message(count: usize) -> Vec<u8> {
let mut body = Vec::new();
if count < 253 {
body.write_u8(count as u8).unwrap();
} else if count <= 0xFFFF {
body.write_u8(0xFD).unwrap();
body.write_u16::<LittleEndian>(count as u16).unwrap();
} else {
body.write_u8(0xFE).unwrap();
body.write_u32::<LittleEndian>(count as u32).unwrap();
}
for _ in 0..count {
body.write_u32::<LittleEndian>(0x495FAB29).unwrap(); body.write_u8(0).unwrap(); body.write_u8(0xFF).unwrap(); body.write_u8(0).unwrap(); body.write_u16::<BigEndian>(0).unwrap(); }
let mut msg = Vec::with_capacity(24 + body.len());
msg.write_all(&Network::Mainnet.magic().0).unwrap();
msg.write_all(b"addrv2\0\0\0\0\0\0").unwrap();
msg.write_u32::<LittleEndian>(body.len() as u32).unwrap();
msg.write_all(&sha256d::Checksum::from(body.as_slice()).0)
.unwrap();
msg.write_all(&body).unwrap();
msg
}
/// Proof-of-concept check that an `addrv2` message packed with minimum-size
/// entries cannot drive an oversized allocation.
///
/// The test asserts three things:
/// - `AddrV2::max_allocation()` does not exceed `MAX_ADDRS_IN_MESSAGE`.
/// - The codec rejects a message that fills 2 MiB with entries.
/// - `Vec::<AddrV2>::zcash_deserialize` rejects a body that has one entry
///   more than the cap.
#[test]
fn poc_remote_addrv2_resource_exhaustion() {
    let _init_guard = zebra_test::init();

    // Budget 2_097_152 bytes (2 MiB). Subtract the 5-byte `0xFE` + u32 count
    // prefix, then fit as many minimum-size entries as possible.
    let attack_count = (2_097_152 - 5) / ADDR_V2_MIN_SIZE;
    let wire_message = build_attack_message(attack_count);
    let forced_heap_bytes = attack_count * mem::size_of::<AddrV2>();

    let mut codec = Codec::builder().finish();
    let mut buffer = BytesMut::from(wire_message.as_slice());
    let decode_outcome = codec.decode(&mut buffer);

    let max_alloc = AddrV2::max_allocation();
    let forced_mib = forced_heap_bytes as f64 / (1024.0 * 1024.0);
    assert!(
        max_alloc <= MAX_ADDRS_IN_MESSAGE as u64,
        "max_allocation() is {} — a remote peer can force {:.1} MiB heap allocation",
        max_alloc,
        forced_mib,
    );
    assert!(
        decode_outcome.is_err(),
        "message with {attack_count} entries was accepted"
    );

    // Skip the codec entirely: strip the wire header and feed a body with one
    // entry too many directly to the length-prefixed Vec deserializer.
    const HEADER_LEN: usize = 24;
    let over_cap = MAX_ADDRS_IN_MESSAGE + 1;
    let over_cap_message = build_attack_message(over_cap);
    let payload = &over_cap_message[HEADER_LEN..];
    assert!(
        Vec::<AddrV2>::zcash_deserialize(payload).is_err(),
        "Vec<AddrV2>::zcash_deserialize accepted {} entries (protocol cap: {})",
        over_cap,
        MAX_ADDRS_IN_MESSAGE,
    );
}