xdpsock 0.1.0

Rust bindings for Linux AF_XDP sockets
Documentation
use std::collections::HashSet;
use std::error::Error;
use std::io::Cursor;
use std::net::Ipv4Addr;
use std::thread;
use std::time::{Duration, Instant};

use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use clap::Clap;
use etherparse::{
    InternetSlice, PacketBuilder, PacketBuilderStep, SlicedPacket, TransportSlice, UdpHeader,
};

use xdpsock::{
    socket::{BindFlags, SocketConfig, SocketConfigBuilder, XdpFlags},
    umem::{UmemConfig, UmemConfigBuilder},
    xsk::{Xsk2, MAX_PACKET_SIZE},
};

// Direction of traffic for a run. `Opts::mode` marks this enum as the clap
// subcommand, and `main` dispatches on it.
// NOTE: plain `//` comments are used on purpose; clap's derive turns `///`
// doc comments into CLI help text.
#[derive(Clap, Debug, Clone)]
enum Mode {
    // Transmit: `main` hands the socket to `spawn_tx`.
    Tx,
    // Receive: `main` hands the socket to `spawn_rx`.
    Rx,
}

/// Send or Receive UDP packets on the specified interface
// NOTE: clap's derive turns `///` doc comments into CLI help text, so the
// maintainer notes below deliberately use plain `//` comments and leave the
// existing `///` lines untouched.
#[derive(Debug, Clone, Clap)]
#[clap(version = "1.0", author = "Collins Huff")]
struct Opts {
    /// interface name
    // Passed to `Xsk2::new` in `main`.
    #[clap(short, long)]
    dev: String,

    /// source MAC address
    // Colon-separated hex string; parsed by `parse_mac` in `spawn_tx`.
    #[clap(long)]
    src_mac: String,

    /// destination MAC address
    // Colon-separated hex string; parsed by `parse_mac` in `spawn_tx`.
    #[clap(long)]
    dest_mac: String,

    /// source IP address
    // Parsed as an IPv4 address by `Filter::new`.
    #[clap(long)]
    src_ip: String,

    /// source port
    #[clap(long)]
    src_port: u16,

    /// destination IP address
    // Parsed as an IPv4 address by `Filter::new`.
    #[clap(long)]
    dest_ip: String,

    /// destination port
    #[clap(long)]
    dest_port: u16,

    /// A level of verbosity, and can be used multiple times
    // Counted from flag occurrences; not read anywhere in the visible part of
    // this file.
    #[clap(short, long, parse(from_occurrences))]
    verbose: i32,

    /// Transmit or Receive mode
    #[clap(subcommand)]
    mode: Mode,

    /// Number of packets to transmit or wait to receive
    // In Tx mode this count is divided among the sender threads; in Rx mode
    // the receiver reports how many of the numbers `0..n_pkts` never arrived.
    #[clap(short, long)]
    n_pkts: u64,

    // Number of sender threads; only read by `spawn_tx`, which falls back to
    // one thread when the flag is omitted.
    #[clap(long)]
    n_threads: Option<u64>,
}

/// Program entry point.
///
/// Initialises logging, parses the command line, builds the UMEM and socket
/// configurations, opens an `Xsk2` on the requested interface and then hands
/// it to the transmit or receive driver depending on the selected mode.
fn main() {
    env_logger::init();
    let opts: Opts = Opts::parse();

    let umem_cfg = UmemConfigBuilder::new()
        .frame_count(8192)
        .comp_queue_size(4096)
        .fill_queue_size(4096)
        .build()
        .unwrap();

    let sock_cfg = SocketConfigBuilder::new()
        .tx_queue_size(4096)
        .rx_queue_size(4096)
        .bind_flags(BindFlags::XDP_COPY)
        .xdp_flags(XdpFlags::XDP_FLAGS_SKB_MODE)
        .build()
        .unwrap();

    // Half of the UMEM frames are handed to the socket as its TX frame count;
    // this must be computed before `umem_cfg` is moved into `Xsk2::new`.
    let tx_frame_count = (umem_cfg.frame_count() / 2) as usize;

    // The literal `0` is the second constructor argument (presumably the
    // queue id -- confirm against the `Xsk2::new` signature).
    let xsk = Xsk2::new(&opts.dev, 0, umem_cfg, sock_cfg, tx_frame_count);

    match opts.mode {
        Mode::Tx => spawn_tx(xsk, opts),
        Mode::Rx => spawn_rx(xsk, opts),
    }
}

/// Returns the half-open range `(start, end)` of sequence numbers assigned to
/// sender `thread_idx` when `n_pkts` packets are shared among `n_threads`
/// senders.
///
/// Every sender gets `n_pkts / n_threads` packets and the first
/// `n_pkts % n_threads` senders get one extra, so the ranges of all senders
/// tile `0..n_pkts` exactly with no packet dropped. `n_threads` must be
/// non-zero.
fn thread_pkt_range(n_pkts: u64, n_threads: u64, thread_idx: u64) -> (u64, u64) {
    let base = n_pkts / n_threads;
    let extra = n_pkts % n_threads;
    // `start` never exceeds `n_pkts`, so none of this arithmetic can overflow.
    let start = thread_idx * base + thread_idx.min(extra);
    let len = base + u64::from(thread_idx < extra);
    (start, start + len)
}

/// Transmit driver: sends `opts.n_pkts` numbered UDP packets through `xsk`.
///
/// The sequence numbers `0..n_pkts` are split across `opts.n_threads` sender
/// threads (one thread if the option is absent or zero). Each thread builds an
/// Ethernet/IPv4/UDP frame per number with `generate_pkt` and pushes it onto
/// the socket's TX channel. After all senders have finished, both directions
/// of the socket are shut down and their statistics are printed to stderr.
///
/// # Panics
///
/// Panics if a MAC or IP address fails to parse, if the TX sender cannot be
/// obtained, if a frame cannot be queued, or if shutting the socket down fails.
fn spawn_tx(mut xsk: Xsk2, opts: Opts) {
    // Zero sender threads would divide by zero below; treat it as one.
    let n_send_threads = opts.n_threads.unwrap_or(1).max(1);
    eprintln!("sending {} pkts", opts.n_pkts);

    let src_mac = parse_mac(&opts.src_mac).expect("failed to parse src mac addr");
    let dest_mac = parse_mac(&opts.dest_mac).expect("failed to parse dest mac addr");
    let filter = Filter::new(&opts.src_ip, opts.src_port, &opts.dest_ip, opts.dest_port).unwrap();

    let tx_send = xsk.tx_sender().unwrap();
    let mut send_handles = Vec::new();
    for i in 0..n_send_threads {
        // Remainder packets are spread over the first threads instead of being
        // silently dropped by the integer division.
        let (n_start, n_end) = thread_pkt_range(opts.n_pkts, n_send_threads, i);
        eprintln!("thread {} sending nums {} to {}", i, n_start, n_end);
        let filter = filter.clone();
        let tx_send = tx_send.clone();

        let send_handle = thread::spawn(move || {
            for n in n_start..n_end {
                // The builder is consumed by `generate_pkt`, so a fresh one is
                // needed for every packet.
                let pkt_builder = PacketBuilder::ethernet2(src_mac, dest_mac)
                    .ipv4(filter.src_ip, filter.dest_ip, 20)
                    .udp(filter.src_port, filter.dest_port);

                let frame = generate_pkt(pkt_builder, n);

                // Copy into the fixed-size buffer the TX channel carries and
                // report the number of bytes actually copied, so the length
                // can never point past the end of the buffer.
                let mut packet = [0u8; MAX_PACKET_SIZE];
                let len = frame.len().min(MAX_PACKET_SIZE);
                packet[..len].copy_from_slice(&frame[..len]);

                tx_send
                    .send((packet, len))
                    .expect("failed to put packet on tx queue");
            }
        });

        send_handles.push(send_handle);
    }
    // Release this function's own sender handle; only the clones owned by the
    // worker threads remain (presumably so the channel closes once they exit).
    drop(tx_send);

    for handle in send_handles {
        handle.join().expect("failed to join tx handle");
    }
    let tx_stats = xsk.shutdown_tx().expect("failed to shutdown tx");
    let rx_stats = xsk.shutdown_rx().expect("failed to shut down rx");
    eprintln!("tx_stats = {:?}", tx_stats);
    eprintln!("tx duration = {:?}", tx_stats.duration());
    eprintln!("tx pps = {:?}", tx_stats.pps());

    eprintln!("rx_stats = {:?}", rx_stats);
}

fn generate_pkt(pkt_builder: PacketBuilderStep<UdpHeader>, n: u64) -> Vec<u8> {
    let mut payload = vec![];
    payload.write_u64::<LittleEndian>(n).unwrap();
    //get some memory to store the result
    let mut result = Vec::<u8>::with_capacity(pkt_builder.size(payload.len()));

    //serialize
    pkt_builder
        .write(&mut result, &payload)
        .expect("failed to build packet");
    result
}

/// The UDP/IPv4 flow this program sends or listens for: source and
/// destination address (as raw octets) plus source and destination port.
#[derive(Debug, Clone)]
struct Filter {
    src_ip: [u8; 4],
    src_port: u16,
    dest_ip: [u8; 4],
    dest_port: u16,
}

impl Filter {
    /// Builds a filter from textual IPv4 addresses and numeric ports.
    ///
    /// # Errors
    ///
    /// Returns the address-parse error if `src_ip` or `dest_ip` is not a valid
    /// IPv4 address; `src_ip` is checked first.
    fn new(
        src_ip: &str,
        src_port: u16,
        dest_ip: &str,
        dest_port: u16,
    ) -> Result<Self, Box<dyn Error>> {
        // Shadow the textual addresses with their octet form so the struct
        // can be built with field-init shorthand.
        let src_ip = src_ip.parse::<Ipv4Addr>()?.octets();
        let dest_ip = dest_ip.parse::<Ipv4Addr>()?.octets();

        Ok(Self {
            src_ip,
            src_port,
            dest_ip,
            dest_port,
        })
    }
}

/// Receive driver: collects numbered UDP packets from `xsk` for 30 seconds and
/// reports how many of the expected sequence numbers never arrived.
///
/// A background thread drains the socket's RX channel, keeps only frames that
/// match the configured flow (see `filter_pkt`) and records the little-endian
/// `u64` sequence number found in the first 8 payload bytes. Frames that
/// cannot be parsed, or whose payload is too short to hold a sequence number,
/// are logged and skipped rather than aborting the thread. After the fixed
/// 30-second window both directions of the socket are shut down, their
/// statistics are printed to stderr, and every number in `0..opts.n_pkts`
/// that was not seen is counted as missing.
///
/// # Panics
///
/// Panics if an IP address fails to parse, if the RX receiver cannot be
/// obtained, if shutting the socket down fails, or if the receive thread
/// panicked.
fn spawn_rx(mut xsk: Xsk2, opts: Opts) {
    let rx_recv = xsk.rx_receiver().unwrap();

    let filter = Filter::new(&opts.src_ip, opts.src_port, &opts.dest_ip, opts.dest_port).unwrap();

    let recv_handle = thread::spawn(move || {
        let mut recvd_nums: HashSet<u64> = HashSet::new();
        for (frame, len) in rx_recv.iter() {
            // Guard against a reported length larger than the frame buffer
            // instead of panicking on an out-of-range slice.
            let bytes = match frame.get(..len) {
                Some(bytes) => bytes,
                None => {
                    log::warn!("reported packet length {} exceeds frame buffer", len);
                    continue;
                }
            };

            match SlicedPacket::from_ethernet(bytes) {
                Ok(parsed) => {
                    if filter_pkt(&parsed, &filter) {
                        // Reading through the cursor yields an error (rather
                        // than a panic) when a matching packet carries fewer
                        // than 8 payload bytes.
                        match Cursor::new(&parsed.payload[..]).read_u64::<LittleEndian>() {
                            Ok(n) => {
                                recvd_nums.insert(n);
                            }
                            Err(e) => {
                                log::warn!("payload too short for sequence number {:?}", e)
                            }
                        }
                    }
                }
                Err(e) => log::warn!("failed to parse packet {:?}", e),
            }
        }
        recvd_nums
    });

    // Fixed receive window; the receive loop above only ends once the RX
    // channel is closed (presumably by `shutdown_rx` -- confirm).
    thread::sleep(Duration::from_secs(30));
    let rx_stats = xsk.shutdown_rx().expect("failed to shut down rx");
    eprintln!("rx_stats = {:?}", rx_stats);
    eprintln!("rx duration = {:?}", rx_stats.duration());
    eprintln!("rx pps = {:?}", rx_stats.pps());

    let tx_stats = xsk.shutdown_tx().expect("failed to shut down tx");
    eprintln!("tx_stats = {:?}", tx_stats);

    let recvd_nums = recv_handle.join().expect("failed to join recv handle");

    // Count the gaps straight from the range; no need to materialise a Vec of
    // every expected number first.
    let n_missing = (0..opts.n_pkts)
        .filter(|n| !recvd_nums.contains(n))
        .count();
    eprintln!("missing {} packets", n_missing);
}

/// Returns `true` when `parsed_pkt` belongs to the flow described by `filter`.
///
/// A packet matches only if it has an IPv4 header whose source and destination
/// addresses equal the filter's, and a UDP header whose source and destination
/// ports equal the filter's. Anything else (no IP layer, a non-IPv4 IP layer,
/// no transport layer, a non-UDP transport) does not match.
fn filter_pkt(parsed_pkt: &SlicedPacket, filter: &Filter) -> bool {
    let addresses_match = matches!(
        &parsed_pkt.ip,
        Some(InternetSlice::Ipv4(ipv4))
            if (ipv4.source() == filter.src_ip) && (ipv4.destination() == filter.dest_ip)
    );

    let ports_match = matches!(
        &parsed_pkt.transport,
        Some(TransportSlice::Udp(udp))
            if (udp.source_port() == filter.src_port)
                && (udp.destination_port() == filter.dest_port)
    );

    addresses_match && ports_match
}

/// Parses a colon-separated MAC address such as `"aa:bb:cc:dd:ee:ff"` into its
/// six bytes.
///
/// Each of the six groups must consist of one or two hexadecimal digits
/// (upper or lower case).
///
/// # Errors
///
/// Returns an error if the string does not contain exactly six
/// colon-separated groups, or if any group is empty, longer than two
/// characters, or contains a non-hexadecimal character.
fn parse_mac(mac: &str) -> Result<[u8; 6], Box<dyn Error>> {
    let parts: Vec<&str> = mac.split(':').collect();
    if parts.len() != 6 {
        return Err("wrong len".into());
    }

    let mut mac_bytes = [0u8; 6];
    for (byte, part) in mac_bytes.iter_mut().zip(&parts) {
        // `u8::from_str_radix` on its own tolerates a leading '+' and any
        // number of leading zeros ("+a", "00ff"), neither of which is a valid
        // MAC octet, so the shape is validated explicitly first.
        let well_formed = !part.is_empty()
            && part.len() <= 2
            && part.bytes().all(|b| b.is_ascii_hexdigit());
        if !well_formed {
            return Err(format!("invalid MAC octet {:?}", part).into());
        }
        *byte = u8::from_str_radix(part, 16)?;
    }
    Ok(mac_bytes)
}