packet_builder/
ipv4.rs

1use pnet::packet::ipv4::{MutableIpv4Packet};
2
/// Size in bytes of an IPv4 header without options (an IHL of 5 32-bit words).
pub const IPV4_HEADER_LEN: usize = 20;
/// Fallback source address (127.0.0.1) used by `ipv4!` when no `set_source` is supplied.
pub const DEFAULT_SOURCE: std::net::Ipv4Addr = std::net::Ipv4Addr::LOCALHOST;
/// Fallback destination address (127.0.0.1) used by `ipv4!` when no `set_destination` is supplied.
pub const DEFAULT_DESTINATION: std::net::Ipv4Addr = std::net::Ipv4Addr::LOCALHOST;
6
7pub fn init_ipv4_pkt(pkt: &mut MutableIpv4Packet, len: u16) -> () {
8  pkt.set_version(4);
9  pkt.set_header_length(5);
10  pkt.set_total_length(len);
11  pkt.set_ttl(128);
12  // TODO make the ID a random value
13  pkt.set_identification(256);
14  pkt.set_fragment_offset(0);
15  pkt.set_flags(pnet::packet::ipv4::Ipv4Flags::DontFragment);
16}
17
/// Helper for `ipv4!`: yields `$value` unchanged for `set_source` / `set_destination`.
///
/// For any other setter it prints a diagnostic and yields `DEFAULT_SOURCE`. `ipv4!` expands
/// this macro inside every arm of a runtime `match` on the setter name. Matching on the
/// identifier at macro-expansion time keeps non-address values (e.g. the `4` of
/// `set_version => 4`) out of the address assignments, so the generated code type-checks.
///
/// The fallback uses a `$crate::` path so the macro works even when the caller has not
/// imported the `ipv4` module.
#[macro_export]
macro_rules! extract_address {
    (set_source, $value:expr) => {{
        $value
    }};
    (set_destination, $value:expr) => {{
        $value
    }};
    ($func:ident, $value:expr) => {{
        println!(
            "Unexpected case matched in extract_address: {} {}",
            stringify!($func),
            stringify!($value)
        );
        $crate::ipv4::DEFAULT_SOURCE
    }};
}
31
32
/// Wraps an already-built layer-4 packet in an IPv4 header written into `$buf`.
///
/// * `{ setter => value, ... }` — `MutableIpv4Packet` setters applied after the defaults
///   from `init_ipv4_pkt`. The `set_source` / `set_destination` values are also used to
///   compute the layer-4 checksum. When either is omitted, `DEFAULT_SOURCE` /
///   `DEFAULT_DESTINATION` is used for both the checksum and the header.
/// * `$l4_pkt` — the layer-4 packet. It is evaluated exactly once. The macro does not copy
///   its bytes. It expects them to already sit at the tail of `$buf`, directly after where
///   the IPv4 header will go (as `payload!` arranges in the tests).
/// * `$protocol` — value for the next-level-protocol field.
/// * `$buf` — the backing buffer. The IPv4 packet occupies its last
///   `IPV4_HEADER_LEN + $l4_pkt.packet().len()` bytes.
///
/// Evaluates to `(MutableIpv4Packet, EtherTypes::Ipv4)`, with the header checksum filled in.
/// The call site needs `pnet::packet::Packet` and the crate's `L4Checksum` trait in scope.
///
/// # Panics
///
/// Panics if the packet does not fit in `$buf`, or if its length exceeds the 16-bit
/// total-length field.
#[macro_export]
macro_rules! ipv4 {
    ({$($func:ident => $value:expr), *}, $l4_pkt:expr, $protocol:expr, $buf:expr) => {{
        // Evaluate `$l4_pkt` once, and finish with it inside this inner scope. It borrows the
        // same buffer that the MutableIpv4Packet below needs mutably, so that borrow must end
        // before the IPv4 view is created.
        let (total_len, source, dest) = {
            let l4_pkt = &mut $l4_pkt;
            let mut source = $crate::ipv4::DEFAULT_SOURCE;
            let mut dest = $crate::ipv4::DEFAULT_DESTINATION;
            $(
                // A bare `match` would make the compiler type-check every value against every
                // address assignment. extract_address! picks the right expansion per setter
                // name, so non-address values never reach an address variable.
                match stringify!($func) {
                    "set_source" => source = extract_address!($func, $value),
                    "set_destination" => dest = extract_address!($func, $value),
                    _ => (),
                }
            )*
            l4_pkt.checksum_ipv4(&source, &dest);
            ($crate::ipv4::IPV4_HEADER_LEN + l4_pkt.packet().len(), source, dest)
        };

        let buf_len = $buf.len();
        assert!(
            total_len <= buf_len,
            "ipv4!: a {}-byte IPv4 packet does not fit in a {}-byte buffer",
            total_len,
            buf_len
        );
        assert!(
            total_len <= 0xFFFF,
            "ipv4!: packet length {} exceeds the 16-bit total-length field",
            total_len
        );
        let mut pkt = pnet::packet::ipv4::MutableIpv4Packet::new(&mut $buf[buf_len - total_len..])
            .expect("ipv4!: slice holds at least IPV4_HEADER_LEN bytes");
        pkt.set_next_level_protocol($protocol);
        // Lossless: checked against 0xFFFF above.
        $crate::ipv4::init_ipv4_pkt(&mut pkt, total_len as u16);
        // Write the addresses the layer-4 checksum was computed with, so header and checksum
        // agree even when the caller omitted set_source / set_destination.
        pkt.set_source(source);
        pkt.set_destination(dest);
        $(
            pkt.$func($value);
        )*
        // The header checksum covers every other header field, so it must be set last.
        pkt.set_checksum(pnet::packet::ipv4::checksum(&pkt.to_immutable()));

        (pkt, pnet::packet::ethernet::EtherTypes::Ipv4)
    }};
}
68
69
/// Parses `$addr_str` (anything with `str::parse`, e.g. `&str` or `String`) into the
/// address type the surrounding context requires, such as `std::net::Ipv4Addr`.
///
/// # Panics
///
/// Panics with a message naming this macro if the string is not a valid address for the
/// inferred type. A bad hard-coded address is treated as a programming error.
#[macro_export]
macro_rules! ipv4addr {
    ($addr_str:expr) => {{
        $addr_str
            .parse()
            .expect("ipv4addr!: invalid address literal")
    }};
}
76
#[cfg(test)]
mod tests {
    use pnet::packet::Packet;
    use pnet::packet::ethernet::EtherTypes::Ipv4;
    use L4Checksum;
    use ::payload;
    use payload::PayloadData;
    use ipv4;

    /// Builds a 25-byte IPv4 packet around a "hello" payload with `ipv4!` and compares it
    /// byte-for-byte with a packet assembled field by field.
    #[test]
    fn macro_ipv4_basic() {
        let mut buf = [0; 25];
        let (pkt, proto) = ipv4!(
            {
                set_source => ipv4addr!("127.0.0.1"),
                set_destination => ipv4addr!("192.168.1.1"),
                set_version => 4
            },
            payload!({"hello".to_string().into_bytes()}, buf).0,
            pnet::packet::ip::IpNextHeaderProtocols::Udp,
            buf
        );
        assert_eq!(proto, Ipv4);

        let mut expected = pnet::packet::ipv4::MutableIpv4Packet::owned(vec![0; 25]).unwrap();
        expected.set_destination(ipv4addr!("192.168.1.1"));
        expected.set_source(ipv4addr!("127.0.0.1"));
        expected.set_version(4);
        expected.set_header_length(5);
        expected.set_total_length(25);
        expected.set_payload(b"hello");
        expected.set_ttl(128);
        expected.set_identification(256);
        expected.set_flags(pnet::packet::ipv4::Ipv4Flags::DontFragment);
        expected.set_next_level_protocol(pnet::packet::ip::IpNextHeaderProtocols::Udp);
        // Checksum last: it is computed over all the fields set above.
        let header_checksum = pnet::packet::ipv4::checksum(&expected.to_immutable());
        expected.set_checksum(header_checksum);

        assert_eq!(expected.packet(), pkt.packet());
    }
}
109