use super::*;
use crate::message::decode::DecodedIcmpMsg;
use crate::message::echo::{parse_echo_reply, EchoId, EchoSeq, IcmpEchoRequest};
use crate::message::{IcmpV4MsgType, IcmpV6MsgType};
use crate::{platform, Icmpv4, Icmpv6};
use hex_literal::hex;
use itertools::Itertools;
use log::debug;
use std::sync::Arc;
use std::time;
/// A header with IHL = 5 (first byte 0x45) is the bare 20-byte IPv4 header, so the
/// ICMP payload is everything from byte 20 to the end of the 60-byte packet.
#[test]
fn strip_ipv4_header_no_options() {
    const HEADER_WORDS: usize = 5;
    let packet = hex!("4500003c99090000400100007f0000017f0000010000f07a695700004bf68a200b3877b6b8e38b893f57185e7b7f477b3d1f687f22c4b8d78355e97a");

    let stripped = strip_ipv4_header(&packet).unwrap();
    let (payload, payload_range) = stripped;

    let header_bytes = HEADER_WORDS * 4;
    assert_eq!(packet.len() - header_bytes, payload.len());
    assert_eq!(20..60, payload_range);
}
/// A header with IHL = 6 (first byte 0x46) carries one 4-byte options word
/// (`FFFFFFFF` here), so the ICMP payload starts at byte 24 of the 64-byte buffer.
#[test]
fn strip_ipv4_header_with_options() {
    const HEADER_WORDS: usize = 6;
    let packet = hex!("4600003c99090000400100007f0000017f000001FFFFFFFF0000f07a695700004bf68a200b3877b6b8e38b893f57185e7b7f477b3d1f687f22c4b8d78355e97a");

    let stripped = strip_ipv4_header(&packet).unwrap();
    let (payload, payload_range) = stripped;

    let header_bytes = HEADER_WORDS * 4;
    assert_eq!(packet.len() - header_bytes, payload.len());
    assert_eq!(24..64, payload_range);
}
/// A version nibble of 5 (first byte 0x55) must be rejected, and the error's context
/// chain must consist of exactly the "Invalid version" message.
#[test]
fn strip_ipv4_header_wrong_version_err() {
    let packet = hex!("5500003c99090000400100007f0000017f0000010000f07a695700004bf68a200b3877b6b8e38b893f57185e7b7f477b3d1f687f22c4b8d78355e97a");

    let failure = strip_ipv4_header(&packet).unwrap_err();
    let inner = failure.into_inner();
    let messages = inner.context().cloned().collect_vec();

    assert_eq!(vec!["Invalid version"], messages);
}
/// Sends an echo request to 127.0.0.1 and waits for the matching reply.
///
/// The expected start of the ICMP message in the receive buffer is 20 when
/// `platform::ipv4_recv_prefix_ipv4_header()` reports that received IPv4 datagrams
/// are prefixed with a header, and 0 otherwise.
#[tokio::test]
async fn ping_localhost_ipv4_strips_header() -> anyhow::Result<()> {
    let socket = Arc::new(IcmpSocket::<Icmpv4>::new(SocketConfig::default())?);
    let payload = rand::random::<[u8; 32]>();
    let id = socket.platform_echo_id().unwrap_or_else(rand::random);
    let seq = EchoSeq::from_be(0x33_44);

    let mut request = IcmpEchoRequest::from_fields(id, seq, &payload);
    socket.send_to(&mut request, net::Ipv4Addr::LOCALHOST).await?;

    let icmp_offset = match platform::ipv4_recv_prefix_ipv4_header() {
        true => 20,
        false => 0,
    };
    let reply_type = IcmpV4MsgType::EchoReply as u8;
    recv_loop_until_echo_reply(Arc::clone(&socket), reply_type, id, seq, &payload, icmp_offset)
        .await?;
    Ok(())
}
/// Sends an echo request to ::1 and expects the reply's ICMP message to start at
/// offset 0 of the receive buffer.
///
/// The receive loop is spawned before the request is sent.
#[tokio::test]
async fn ping_localhost_ipv6_returns_complete_msg() -> anyhow::Result<()> {
    let socket = Arc::new(IcmpSocket::<Icmpv6>::new(SocketConfig::default())?);
    let payload = &[0xFF, 0x00, 0x00];
    let id = socket.platform_echo_id().unwrap_or_else(rand::random);
    let mut request = IcmpEchoRequest::from_fields(id, EchoSeq::from_be(0x33_44), payload);

    // Snapshot of the request as built, so the receiver can read its sequence number.
    let snapshot = request.clone();
    let receiver = tokio::spawn({
        let socket = Arc::clone(&socket);
        async move {
            recv_loop_until_echo_reply(
                socket,
                IcmpV6MsgType::EchoReply as u8,
                id,
                snapshot.seq(),
                payload,
                0,
            )
            .await
        }
    });

    socket.send_to(&mut request, net::Ipv6Addr::LOCALHOST).await?;
    receiver.await??;
    Ok(())
}
/// `local_port()` on a freshly created ICMPv6 socket is non-zero exactly when
/// `platform::socket_bind_sets_nonzero_local_port()` says it should be.
#[tokio::test]
async fn local_port_doesnt_crash() -> anyhow::Result<()> {
    let socket = IcmpSocket::<Icmpv6>::new(SocketConfig::default())?;
    let expect_nonzero = platform::socket_bind_sets_nonzero_local_port();
    let port_is_nonzero = socket.local_port() != 0;
    assert_eq!(expect_nonzero, port_is_nonzero);
    Ok(())
}
/// IPv4 variant of `raw_socket_bind_sets_nonzero_port`.
///
/// This test is a no-op on platforms where binding does not assign a local port.
#[tokio::test]
async fn socket_linux_bind_sets_nonzero_port_ipv4() -> anyhow::Result<()> {
    if !platform::socket_bind_sets_nonzero_local_port() {
        return Ok(());
    }
    raw_socket_bind_sets_nonzero_port::<Icmpv4>(
        net::Ipv4Addr::LOCALHOST,
        IcmpV4MsgType::EchoReply as u8,
    )
    .await
}
/// IPv6 variant of `raw_socket_bind_sets_nonzero_port`.
///
/// This test is a no-op on platforms where binding does not assign a local port.
#[tokio::test]
async fn socket_linux_bind_sets_nonzero_port_ipv6() -> anyhow::Result<()> {
    if !platform::socket_bind_sets_nonzero_local_port() {
        return Ok(());
    }
    raw_socket_bind_sets_nonzero_port::<Icmpv6>(
        net::Ipv6Addr::LOCALHOST,
        IcmpV6MsgType::EchoReply as u8,
    )
    .await
}
/// IPv4 variant of `check_bind_doesnt_affect_local_addr`.
///
/// This test only runs where binding does not assign a local port.
#[tokio::test]
async fn socket_macos_bind_doesnt_set_local_port_ipv4() -> anyhow::Result<()> {
    if platform::socket_bind_sets_nonzero_local_port() {
        return Ok(());
    }
    check_bind_doesnt_affect_local_addr::<Icmpv4>(net::Ipv4Addr::LOCALHOST).await
}
/// IPv6 variant of `check_bind_doesnt_affect_local_addr`.
///
/// This test only runs where binding does not assign a local port.
#[tokio::test]
async fn socket_macos_bind_doesnt_set_local_port_ipv6() -> anyhow::Result<()> {
    if platform::socket_bind_sets_nonzero_local_port() {
        return Ok(());
    }
    check_bind_doesnt_affect_local_addr::<Icmpv6>(net::Ipv6Addr::LOCALHOST).await
}
/// Builds an [`IcmpSocket`] around a `socket2` datagram ICMP socket bound to
/// `V::DEFAULT_BIND`, then pings `localhost` through it.
///
/// Asserts that:
/// - the unbound socket reports `V::DEFAULT_BIND` as its local address;
/// - after `bind` the IP is unchanged but the port is non-zero;
/// - the bound local address is the same after sending the request and after
///   receiving the reply of type `echo_reply_type`.
async fn raw_socket_bind_sets_nonzero_port<V>(
    localhost: V::Address,
    echo_reply_type: u8,
) -> anyhow::Result<()>
where
    V: IcmpVersion,
    IcmpEchoRequest: EncodeIcmpMessage<V>,
{
    let raw = socket2::Socket::new(V::DOMAIN, socket2::Type::DGRAM, Some(V::PROTOCOL))?;
    raw.set_nonblocking(true)?;
    let unspecified: std::net::SocketAddr = V::DEFAULT_BIND.into();

    let before_bind = raw.local_addr()?.as_socket().unwrap();
    assert_eq!(unspecified, before_bind);

    raw.bind(&unspecified.into())?;
    let bound = raw.local_addr()?.as_socket().unwrap();
    assert_eq!(unspecified.ip(), bound.ip());
    assert_ne!(0, bound.port());

    let inner = IcmpSocketInner {
        socket: raw,
        marker: marker::PhantomData::<V>,
    };
    let socket = IcmpSocket {
        fd: unix::AsyncFd::new(inner)?,
        local_port: bound.port(),
    };

    let id = socket.platform_echo_id().unwrap_or_else(rand::random);
    let seq = EchoSeq::from_be(0x5555);
    let payload = &[0x66, 0x66];
    let mut request = IcmpEchoRequest::from_fields(id, seq, payload);
    socket.send_to(&mut request, localhost).await?;

    // Sending must not rebind the socket.
    let after_send = socket.fd.get_ref().socket.local_addr()?.as_socket().unwrap();
    assert_eq!(bound, after_send);

    let shared = Arc::new(socket);
    recv_loop_until_echo_reply(Arc::clone(&shared), echo_reply_type, id, seq, payload, 0).await?;

    // Receiving must not rebind the socket either.
    let after_recv = shared.fd.get_ref().socket.local_addr()?.as_socket().unwrap();
    assert_eq!(bound, after_recv);
    Ok(())
}
/// Builds an [`IcmpSocket`] around a `socket2` datagram ICMP socket bound to
/// `V::DEFAULT_BIND`, then sends one echo request to `localhost`.
///
/// Asserts that the socket's local address equals `V::DEFAULT_BIND` at three points:
/// before the bind, after the bind, and after the send.
async fn check_bind_doesnt_affect_local_addr<V>(localhost: V::Address) -> anyhow::Result<()>
where
    V: IcmpVersion,
    IcmpEchoRequest: EncodeIcmpMessage<V>,
{
    let raw = socket2::Socket::new(V::DOMAIN, socket2::Type::DGRAM, Some(V::PROTOCOL))?;
    raw.set_nonblocking(true)?;
    let unspecified: std::net::SocketAddr = V::DEFAULT_BIND.into();

    let before_bind = raw.local_addr()?.as_socket().unwrap();
    assert_eq!(unspecified, before_bind);

    raw.bind(&unspecified.into())?;
    let after_bind = raw.local_addr()?.as_socket().unwrap();
    assert_eq!(unspecified, after_bind);

    let inner = IcmpSocketInner {
        socket: raw,
        marker: marker::PhantomData::<V>,
    };
    let socket = IcmpSocket {
        fd: unix::AsyncFd::new(inner)?,
        local_port: after_bind.port(),
    };

    let id = socket.platform_echo_id().unwrap_or_else(rand::random);
    let seq = EchoSeq::from_be(0x5555);
    let payload = &[0x66, 0x66];
    let mut request = IcmpEchoRequest::from_fields(id, seq, payload);
    socket.send_to(&mut request, localhost).await?;

    let after_send = socket.fd.get_ref().socket.local_addr()?.as_socket().unwrap();
    assert_eq!(unspecified, after_send);
    Ok(())
}
/// Receives from `socket` until an Echo Reply with `expected_id` and `expected_data`
/// arrives, then validates it.
///
/// A message is skipped, and logged at debug level, if any of these hold:
/// - its type is not `icmp_msg_type`;
/// - its code is non-zero;
/// - its id or payload differs from the expected values.
///
/// For the matching reply, the helper asserts:
/// - the sequence number equals `expected_seq`;
/// - the body is `id ++ seq ++ expected_data`;
/// - the range reported by `recv` starts at `expected_range_start` and spans the
///   8-byte echo header plus the payload.
///
/// # Errors
///
/// Returns an error in these cases:
/// - `recv` fails;
/// - a message cannot be decoded;
/// - no matching reply arrives within the overall time budget (`REPLY_TIMEOUT`).
///
/// # Panics
///
/// Panics in these cases:
/// - `parse_echo_reply` fails on a message of the expected type;
/// - any of the assertions above fails.
async fn recv_loop_until_echo_reply<V: IcmpVersion>(
    socket: Arc<IcmpSocket<V>>,
    icmp_msg_type: u8,
    expected_id: EchoId,
    expected_seq: EchoSeq,
    expected_data: &[u8],
    expected_range_start: usize,
) -> anyhow::Result<()> {
    // One budget for the whole search. A per-`recv` timeout would restart every
    // time a skipped packet arrived, so the loop could run without bound while
    // unrelated ICMP traffic keeps flowing.
    const REPLY_TIMEOUT: time::Duration = time::Duration::from_millis(100);
    // Echo header in front of the payload: type (1) + code (1) + checksum (2)
    // + identifier (2) + sequence number (2).
    const ECHO_HEADER_LEN: usize = 1 + 1 + 2 + 2 + 2;

    let deadline = time::Instant::now() + REPLY_TIMEOUT;
    let mut buf = vec![0; 10_000];
    loop {
        let remaining = deadline.saturating_duration_since(time::Instant::now());
        let (msg, range) = tokio::time::timeout(remaining, socket.recv(&mut buf))
            .await
            .map_err(|_elapsed| {
                anyhow::anyhow!(
                    "no echo reply with id {:?} received within {:?}",
                    expected_id,
                    REPLY_TIMEOUT
                )
            })??;
        let decoded = DecodedIcmpMsg::decode(msg)?;
        if decoded.msg_type() != icmp_msg_type || decoded.msg_code() != 0 {
            debug!("Not an Echo Reply: {:?}", decoded);
            continue;
        }
        let (id, seq, recv_data) = parse_echo_reply(decoded.body()).unwrap();
        if id != expected_id || recv_data != expected_data {
            debug!(
                "Skipping unexpected id {:?} data {}",
                id,
                hex::encode(recv_data)
            );
            continue;
        }
        assert_eq!(expected_seq, seq);
        assert_eq!(
            [expected_id.as_slice(), seq.as_slice(), expected_data].concat(),
            decoded.body()
        );
        assert_eq!(
            expected_range_start..(expected_range_start + ECHO_HEADER_LEN + expected_data.len()),
            range
        );
        break;
    }
    Ok(())
}