use std::collections::HashSet;
use std::error::Error;
use std::io::Cursor;
use std::net::Ipv4Addr;
use std::thread;
use std::time::{Duration, Instant};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use clap::Clap;
use etherparse::{
InternetSlice, PacketBuilder, PacketBuilderStep, SlicedPacket, TransportSlice, UdpHeader,
};
use xdpsock::{
socket::{BindFlags, SocketConfig, SocketConfigBuilder, XdpFlags},
umem::{UmemConfig, UmemConfigBuilder},
xsk::{Xsk2, MAX_PACKET_SIZE},
};
// Which half of the test this process runs; `main` dispatches `Tx` to
// `spawn_tx` and `Rx` to `spawn_rx`.
// NOTE(review): plain `//` comments are used on purpose; `///` doc comments on
// a clap-derived type are typically picked up as CLI help text.
#[derive(Clap, Debug, Clone)]
enum Mode {
// Build numbered UDP packets and hand them to the socket's tx queue.
Tx,
// Collect sequence numbers from matching packets and report how many are missing.
Rx,
}
// Command-line options, parsed in `main` via `Opts::parse()`.
// NOTE(review): plain `//` comments are used on purpose; `///` doc comments on
// a clap-derived type are typically picked up as CLI help text.
#[derive(Debug, Clone, Clap)]
#[clap(version = "1.0", author = "Collins Huff")]
struct Opts {
// Network interface name handed to `Xsk2::new`.
#[clap(short, long)]
dev: String,
// Source MAC address as six colon-separated hex octets (parsed by `parse_mac`);
// only read in tx mode.
#[clap(long)]
src_mac: String,
// Destination MAC address, same format as `src_mac`; only read in tx mode.
#[clap(long)]
dest_mac: String,
// Source IPv4 address (parsed by `Filter::new`): written into generated
// packets in tx mode, matched against received packets in rx mode.
#[clap(long)]
src_ip: String,
// Source UDP port, used the same way as `src_ip`.
#[clap(long)]
src_port: u16,
// Destination IPv4 address (parsed by `Filter::new`), used the same way as `src_ip`.
#[clap(long)]
dest_ip: String,
// Destination UDP port, used the same way as `src_ip`.
#[clap(long)]
dest_port: u16,
// Number of times the flag was given (`from_occurrences`); not read anywhere
// in the visible code.
#[clap(short, long, parse(from_occurrences))]
verbose: i32,
// Subcommand selecting tx or rx behaviour.
#[clap(subcommand)]
mode: Mode,
// Tx: total number of packets to send. Rx: sequence numbers `0..n_pkts` are
// the ones expected to arrive.
#[clap(short, long)]
n_pkts: u64,
// Number of sender threads in tx mode; `spawn_tx` uses 1 when absent.
// Not read by `spawn_rx`.
#[clap(long)]
n_threads: Option<u64>,
}
fn main() {
env_logger::init();
let opts: Opts = Opts::parse();
let umem_config = UmemConfigBuilder::new()
.frame_count(8192)
.comp_queue_size(4096)
.fill_queue_size(4096)
.build()
.unwrap();
let socket_config = SocketConfigBuilder::new()
.tx_queue_size(4096)
.rx_queue_size(4096)
.bind_flags(BindFlags::XDP_COPY)
.xdp_flags(XdpFlags::XDP_FLAGS_SKB_MODE)
.build()
.unwrap();
let n_tx_frames = umem_config.frame_count() / 2;
let dev_ifname = opts.dev.clone();
let mut xsk = Xsk2::new(
&dev_ifname,
0,
umem_config,
socket_config,
n_tx_frames as usize,
);
match opts.mode {
Mode::Tx => spawn_tx(xsk, opts),
Mode::Rx => spawn_rx(xsk, opts),
}
}
/// Returns the half-open range `[start, end)` of sequence numbers that sender
/// thread `i` (0-based) out of `n_threads` is responsible for.
///
/// The first `n_pkts % n_threads` threads take one extra packet, so the ranges
/// of all threads together cover `0..n_pkts` exactly with no gaps or overlap.
/// `n_threads` must be non-zero.
fn thread_pkt_range(n_pkts: u64, n_threads: u64, i: u64) -> (u64, u64) {
    let base = n_pkts / n_threads;
    let extra = n_pkts % n_threads;
    let start = i * base + i.min(extra);
    let end = start + base + u64::from(i < extra);
    (start, end)
}

/// Runs the transmit side of the test.
///
/// Spawns `opts.n_threads` sender threads (one if unset or zero). Each thread
/// builds Ethernet/IPv4/UDP packets whose payload is a sequence number and
/// pushes them onto the socket's tx channel. Between them the threads send
/// every number in `0..opts.n_pkts` exactly once. After all senders finish,
/// both halves of the socket are shut down and their statistics are printed
/// to stderr.
///
/// # Panics
///
/// Panics if the MAC or IP options fail to parse, if the tx sender cannot be
/// obtained, if a packet cannot be queued, if a sender thread panics, or if
/// shutting down either half of the socket fails.
fn spawn_tx(mut xsk: Xsk2, opts: Opts) {
    // A thread count of zero would divide by zero when splitting the work, so
    // it is treated the same as "not specified".
    let n_send_threads = opts.n_threads.unwrap_or(1).max(1);
    eprintln!("sending {} pkts", opts.n_pkts);

    let src_mac = parse_mac(&opts.src_mac).expect("failed to parse src mac addr");
    let dest_mac = parse_mac(&opts.dest_mac).expect("failed to parse dest mac addr");
    let filter = Filter::new(&opts.src_ip, opts.src_port, &opts.dest_ip, opts.dest_port)
        .expect("failed to parse src/dest ip addr");
    let tx_send = xsk.tx_sender().unwrap();

    let mut send_handles = Vec::new();
    for i in 0..n_send_threads {
        // Splitting with the remainder spread over the first threads ensures
        // no sequence number is dropped when n_pkts is not a multiple of the
        // thread count.
        let (n_start, n_end) = thread_pkt_range(opts.n_pkts, n_send_threads, i);
        eprintln!("thread {} sending nums {} to {}", i, n_start, n_end);

        let filter = filter.clone();
        let tx_send = tx_send.clone();
        let send_handle = thread::spawn(move || {
            for n in n_start..n_end {
                // `generate_pkt` consumes the builder, so a fresh one is made
                // for every packet.
                let pkt_builder = PacketBuilder::ethernet2(src_mac, dest_mac)
                    .ipv4(filter.src_ip, filter.dest_ip, 20)
                    .udp(filter.src_port, filter.dest_port);
                let pkt_with_payload = generate_pkt(pkt_builder, n);

                // Copy into a fixed-size frame buffer, truncating if needed,
                // and report the number of bytes actually copied.
                let mut packet = [0u8; MAX_PACKET_SIZE];
                let len = pkt_with_payload.len().min(MAX_PACKET_SIZE);
                packet[..len].copy_from_slice(&pkt_with_payload[..len]);
                tx_send
                    .send((packet, len))
                    .expect("failed to put packet on tx queue");
            }
        });
        send_handles.push(send_handle);
    }

    // Release this function's own handle; only the worker clones remain.
    drop(tx_send);
    for handle in send_handles {
        handle.join().expect("failed to join tx handle");
    }

    let tx_stats = xsk.shutdown_tx().expect("failed to shutdown tx");
    let rx_stats = xsk.shutdown_rx().expect("failed to shut down rx");
    eprintln!("tx_stats = {:?}", tx_stats);
    eprintln!("tx duration = {:?}", tx_stats.duration());
    eprintln!("tx pps = {:?}", tx_stats.pps());
    eprintln!("rx_stats = {:?}", rx_stats);
}
fn generate_pkt(pkt_builder: PacketBuilderStep<UdpHeader>, n: u64) -> Vec<u8> {
let mut payload = vec![];
payload.write_u64::<LittleEndian>(n).unwrap();
let mut result = Vec::<u8>::with_capacity(pkt_builder.size(payload.len()));
pkt_builder
.write(&mut result, &payload)
.expect("failed to build packet");
result
}
/// The IPv4/UDP 4-tuple identifying the test flow.
///
/// Used both to fill in headers of generated packets and to pick out matching
/// packets on the receive side.
#[derive(Clone, Debug)]
struct Filter {
    /// Source IPv4 address as four octets.
    src_ip: [u8; 4],
    /// Source UDP port.
    src_port: u16,
    /// Destination IPv4 address as four octets.
    dest_ip: [u8; 4],
    /// Destination UDP port.
    dest_port: u16,
}

impl Filter {
    /// Builds a filter from textual IPv4 addresses and numeric ports.
    ///
    /// # Errors
    ///
    /// Returns the parse error if `src_ip` or `dest_ip` is not a valid IPv4
    /// address; `src_ip` is checked first.
    fn new(
        src_ip: &str,
        src_port: u16,
        dest_ip: &str,
        dest_port: u16,
    ) -> Result<Self, Box<dyn Error>> {
        let src_ip = src_ip.parse::<Ipv4Addr>()?.octets();
        let dest_ip = dest_ip.parse::<Ipv4Addr>()?.octets();
        Ok(Filter {
            src_ip,
            src_port,
            dest_ip,
            dest_port,
        })
    }
}
/// Runs the receive side of the test.
///
/// A background thread drains the socket's rx channel, parses each frame, and
/// records the sequence number (first eight payload bytes, little-endian) of
/// every packet that matches the configured IPv4/UDP 4-tuple. The calling
/// thread sleeps for 30 seconds, shuts the socket down, prints statistics to
/// stderr and finally reports how many numbers in `0..opts.n_pkts` never
/// arrived.
///
/// Frames that fail to parse, and matching packets whose payload is too short
/// to hold a sequence number, are logged and skipped.
///
/// # Panics
///
/// Panics if the IP options fail to parse, if the rx receiver cannot be
/// obtained, if shutting down either half of the socket fails, or if the
/// receive thread panics.
fn spawn_rx(mut xsk: Xsk2, opts: Opts) {
    let rx_recv = xsk.rx_receiver().unwrap();
    let filter = Filter::new(&opts.src_ip, opts.src_port, &opts.dest_ip, opts.dest_port)
        .expect("failed to parse src/dest ip addr");

    let recv_handle = thread::spawn(move || {
        let mut recvd_nums: HashSet<u64> = HashSet::new();
        for (frame, len) in rx_recv.iter() {
            let parsed = match SlicedPacket::from_ethernet(&frame[..len]) {
                Ok(parsed) => parsed,
                Err(e) => {
                    log::warn!("failed to parse packet {:?}", e);
                    continue;
                }
            };
            if !filter_pkt(&parsed, &filter) {
                continue;
            }
            // Anything on the wire can match the 4-tuple, so the payload may
            // be shorter than eight bytes; reading through the cursor turns
            // that into an error instead of an out-of-bounds panic.
            match Cursor::new(parsed.payload).read_u64::<LittleEndian>() {
                Ok(n) => {
                    recvd_nums.insert(n);
                }
                Err(e) => log::warn!("matching packet has truncated payload {:?}", e),
            }
        }
        recvd_nums
    });

    // Fixed receive window before the socket is torn down.
    thread::sleep(Duration::from_secs(30));

    let rx_stats = xsk.shutdown_rx().expect("failed to shut down rx");
    eprintln!("rx_stats = {:?}", rx_stats);
    eprintln!("rx duration = {:?}", rx_stats.duration());
    eprintln!("rx pps = {:?}", rx_stats.pps());
    let tx_stats = xsk.shutdown_tx().expect("failed to shut down tx");
    eprintln!("tx_stats = {:?}", tx_stats);

    let recvd_nums = recv_handle.join().expect("failed to join recv handle");

    // Count directly over the expected range; no intermediate Vec of every
    // expected number is needed.
    let n_missing = (0..opts.n_pkts)
        .filter(|n| !recvd_nums.contains(n))
        .count();
    eprintln!("missing {} packets", n_missing);
}
/// Reports whether `parsed_pkt` belongs to the flow described by `filter`.
///
/// Returns `true` only when the packet has an IPv4 header whose source and
/// destination addresses equal the filter's, and a UDP header whose source
/// and destination ports equal the filter's. Any other IP or transport
/// variant, or a missing layer, yields `false`.
fn filter_pkt(parsed_pkt: &SlicedPacket, filter: &Filter) -> bool {
    let ip_match = match &parsed_pkt.ip {
        Some(InternetSlice::Ipv4(ipv4)) => {
            (ipv4.source() == filter.src_ip) && (ipv4.destination() == filter.dest_ip)
        }
        _ => false,
    };

    let transport_match = match &parsed_pkt.transport {
        Some(TransportSlice::Udp(udp)) => {
            (udp.source_port() == filter.src_port)
                && (udp.destination_port() == filter.dest_port)
        }
        _ => false,
    };

    ip_match && transport_match
}
/// Parses a MAC address written as six colon-separated hexadecimal octets,
/// e.g. `"00:1a:2b:3c:4d:5e"`.
///
/// Each octet must consist of one or two hex digits (either case). Signs,
/// whitespace and longer zero-padded groups are rejected.
///
/// # Errors
///
/// Returns an error if the input does not split into exactly six parts, or if
/// any part is not a one- or two-digit hexadecimal number.
fn parse_mac(mac: &str) -> Result<[u8; 6], Box<dyn Error>> {
    let parts: Vec<&str> = mac.split(':').collect();
    if parts.len() != 6 {
        return Err("wrong len".into());
    }

    let mut mac_bytes = [0u8; 6];
    for (byte, part) in mac_bytes.iter_mut().zip(&parts) {
        // `from_str_radix` on its own accepts a leading `+` and any number of
        // leading zeros, neither of which is valid in a MAC octet.
        let well_formed =
            (1..=2).contains(&part.len()) && part.bytes().all(|b| b.is_ascii_hexdigit());
        if !well_formed {
            return Err(format!("invalid mac octet {:?}", part).into());
        }
        *byte = u8::from_str_radix(part, 16)?;
    }
    Ok(mac_bytes)
}