use crate::tracing::error::TracerError::InvalidSourceAddr;
use crate::tracing::error::{TraceResult, TracerError};
use crate::tracing::net::platform::PlatformIpv4FieldByteOrder;
use crate::tracing::net::{ipv4, ipv6, Network};
use crate::tracing::probe::ProbeResponse;
use crate::tracing::types::{PacketSize, PayloadPattern, Port, Sequence, TraceId, TypeOfService};
use crate::tracing::util::Required;
use crate::tracing::{
MultipathStrategy, PortDirection, Probe, TracerAddrFamily, TracerChannelConfig, TracerProtocol,
};
use arrayvec::ArrayVec;
use itertools::Itertools;
use nix::sys::select::FdSet;
use nix::sys::time::{TimeVal, TimeValLike};
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use std::net::{IpAddr, SocketAddr};
use std::os::unix::io::AsRawFd;
use std::time::{Duration, SystemTime};
/// Largest `packet_size` accepted by `TracerChannel::connect`; larger values are
/// rejected with `TracerError::InvalidPacketSize`.
pub const MAX_PACKET_SIZE: usize = 1024;
/// Capacity of the fixed-size buffer that holds TCP probes whose connection
/// outcome has not yet been observed.
const MAX_TCP_PROBES: usize = 256;
/// Port used for source-address discovery when the configured port direction
/// does not specify a destination port.
const DISCOVERY_PORT: Port = Port(80);
/// A channel for sending trace probes and receiving their responses.
///
/// Owns the sockets created by `TracerChannel::connect` together with the
/// per-trace settings that the `ipv4` / `ipv6` dispatch and receive functions need.
pub struct TracerChannel {
    /// Probe protocol; selects which dispatch / receive path is used.
    protocol: TracerProtocol,
    /// Address family of this trace.
    addr_family: TracerAddrFamily,
    /// Local address probes are sent from, resolved by `make_src_addr` in `connect`.
    src_addr: IpAddr,
    /// Platform byte-order setting derived from `src_addr`; passed to the IPv4
    /// ICMP and UDP dispatchers.
    ipv4_length_order: PlatformIpv4FieldByteOrder,
    /// Address of the trace target.
    dest_addr: IpAddr,
    /// Trace identifier; passed to the ICMP dispatchers.
    identifier: TraceId,
    /// Requested probe packet size; never larger than `MAX_PACKET_SIZE` (checked in `connect`).
    packet_size: PacketSize,
    /// Payload pattern passed to the ICMP and UDP dispatchers.
    payload_pattern: PayloadPattern,
    /// Type-of-service value; only passed to the IPv4 TCP dispatcher.
    tos: TypeOfService,
    /// Initial sequence value; passed to the IPv4 UDP dispatcher.
    initial_sequence: Sequence,
    /// Multipath strategy; used by the IPv4 UDP dispatcher and the IPv4 ICMP receiver.
    multipath_strategy: MultipathStrategy,
    /// Port configuration used by the UDP/TCP dispatchers, the ICMP receivers and
    /// source-address discovery.
    port_direction: PortDirection,
    /// Maximum time `recv_icmp_probe` waits for `recv_socket` to become readable.
    read_timeout: Duration,
    /// Age after which a pending TCP probe is dropped from `tcp_probes`.
    tcp_connect_timeout: Duration,
    /// Socket used to send ICMP probes.
    icmp_send_socket: Socket,
    /// Socket used to send UDP probes.
    udp_send_socket: Socket,
    /// Socket from which probe responses are read by `recv_icmp_probe`.
    recv_socket: Socket,
    /// TCP probes whose connection outcome has not been observed yet.
    tcp_probes: ArrayVec<TcpProbe, MAX_TCP_PROBES>,
}
impl TracerChannel {
pub fn connect(config: &TracerChannelConfig) -> TraceResult<Self> {
if usize::from(config.packet_size.0) > MAX_PACKET_SIZE {
return Err(TracerError::InvalidPacketSize(usize::from(
config.packet_size.0,
)));
}
let src_addr = make_src_addr(
config.source_addr,
config.target_addr,
config.port_direction,
config.interface.as_deref(),
config.addr_family,
)?;
let ipv4_length_order = PlatformIpv4FieldByteOrder::for_address(src_addr)?;
let icmp_send_socket = make_icmp_send_socket(config.addr_family)?;
let udp_send_socket = make_udp_send_socket(config.addr_family)?;
let recv_socket = make_recv_socket(config.addr_family)?;
Ok(Self {
protocol: config.protocol,
addr_family: config.addr_family,
src_addr,
ipv4_length_order,
dest_addr: config.target_addr,
identifier: config.identifier,
packet_size: config.packet_size,
payload_pattern: config.payload_pattern,
tos: config.tos,
initial_sequence: config.initial_sequence,
multipath_strategy: config.multipath_strategy,
port_direction: config.port_direction,
read_timeout: config.read_timeout,
tcp_connect_timeout: config.tcp_connect_timeout,
icmp_send_socket,
udp_send_socket,
recv_socket,
tcp_probes: ArrayVec::new(),
})
}
#[must_use]
pub fn src_addr(&self) -> IpAddr {
self.src_addr
}
}
impl Network for TracerChannel {
    /// Send `probe` using the channel's configured protocol.
    fn send_probe(&mut self, probe: Probe) -> TraceResult<()> {
        match self.protocol {
            TracerProtocol::Icmp => self.dispatch_icmp_probe(probe),
            TracerProtocol::Udp => self.dispatch_udp_probe(probe),
            TracerProtocol::Tcp => self.dispatch_tcp_probe(probe),
        }
    }

    /// Receive the next probe response, if any.
    ///
    /// ICMP and UDP traces read only the ICMP receive socket. TCP traces first check
    /// for a completed TCP connection and read the ICMP socket only when none is ready.
    fn recv_probe(&mut self) -> TraceResult<Option<ProbeResponse>> {
        match self.protocol {
            TracerProtocol::Icmp | TracerProtocol::Udp => self.recv_icmp_probe(),
            // Do not use `Option::or` here. Its argument is evaluated eagerly, so the
            // ICMP socket would always be read (waiting up to `read_timeout`). Any ICMP
            // response read that way would be discarded whenever a TCP probe was
            // already ready.
            TracerProtocol::Tcp => match self.recv_tcp_sockets()? {
                Some(response) => Ok(Some(response)),
                None => self.recv_icmp_probe(),
            },
        }
    }
}
impl TracerChannel {
fn dispatch_icmp_probe(&mut self, probe: Probe) -> TraceResult<()> {
match (self.addr_family, self.src_addr, self.dest_addr) {
(TracerAddrFamily::Ipv4, IpAddr::V4(src_addr), IpAddr::V4(dest_addr)) => {
ipv4::dispatch_icmp_probe(
&mut self.icmp_send_socket,
probe,
src_addr,
dest_addr,
self.identifier,
self.packet_size,
self.payload_pattern,
self.ipv4_length_order,
)
}
(TracerAddrFamily::Ipv6, IpAddr::V6(src_addr), IpAddr::V6(dest_addr)) => {
ipv6::dispatch_icmp_probe(
&mut self.icmp_send_socket,
probe,
src_addr,
dest_addr,
self.identifier,
self.packet_size,
self.payload_pattern,
)
}
_ => unreachable!(),
}
}
fn dispatch_udp_probe(&mut self, probe: Probe) -> TraceResult<()> {
match (self.addr_family, self.src_addr, self.dest_addr) {
(TracerAddrFamily::Ipv4, IpAddr::V4(src_addr), IpAddr::V4(dest_addr)) => {
ipv4::dispatch_udp_probe(
&mut self.udp_send_socket,
probe,
src_addr,
dest_addr,
self.initial_sequence,
self.multipath_strategy,
self.port_direction,
self.packet_size,
self.payload_pattern,
self.ipv4_length_order,
)
}
(TracerAddrFamily::Ipv6, IpAddr::V6(src_addr), IpAddr::V6(dest_addr)) => {
ipv6::dispatch_udp_probe(
&mut self.udp_send_socket,
probe,
src_addr,
dest_addr,
self.port_direction,
self.packet_size,
self.payload_pattern,
)
}
_ => unreachable!(),
}
}
fn dispatch_tcp_probe(&mut self, probe: Probe) -> TraceResult<()> {
let socket = match (self.addr_family, self.src_addr, self.dest_addr) {
(TracerAddrFamily::Ipv4, IpAddr::V4(src_addr), IpAddr::V4(dest_addr)) => {
ipv4::dispatch_tcp_probe(probe, src_addr, dest_addr, self.port_direction, self.tos)
}
(TracerAddrFamily::Ipv6, IpAddr::V6(src_addr), IpAddr::V6(dest_addr)) => {
ipv6::dispatch_tcp_probe(probe, src_addr, dest_addr, self.port_direction)
}
_ => unreachable!(),
}?;
self.tcp_probes
.push(TcpProbe::new(socket, SystemTime::now()));
Ok(())
}
fn recv_icmp_probe(&mut self) -> TraceResult<Option<ProbeResponse>> {
if is_readable(&self.recv_socket, self.read_timeout)? {
match self.addr_family {
TracerAddrFamily::Ipv4 => ipv4::recv_icmp_probe(
&mut self.recv_socket,
self.protocol,
self.multipath_strategy,
self.port_direction,
),
TracerAddrFamily::Ipv6 => {
ipv6::recv_icmp_probe(&mut self.recv_socket, self.protocol, self.port_direction)
}
}
} else {
Ok(None)
}
}
fn recv_tcp_sockets(&mut self) -> TraceResult<Option<ProbeResponse>> {
self.tcp_probes
.retain(|probe| probe.start.elapsed().unwrap_or_default() < self.tcp_connect_timeout);
let found_index = self
.tcp_probes
.iter()
.find_position(|&probe| is_writable(&probe.socket).unwrap_or_default())
.map(|(i, _)| i);
if let Some(i) = found_index {
let probe = self.tcp_probes.remove(i);
match self.addr_family {
TracerAddrFamily::Ipv4 => ipv4::recv_tcp_socket(&probe.socket, self.dest_addr),
TracerAddrFamily::Ipv6 => ipv6::recv_tcp_socket(&probe.socket, self.dest_addr),
}
} else {
Ok(None)
}
}
}
/// A TCP probe whose connection outcome has not been observed yet.
#[derive(Debug)]
struct TcpProbe {
    /// Socket returned by the family-specific TCP dispatcher.
    socket: Socket,
    /// Time the probe was dispatched; used to expire it after `tcp_connect_timeout`.
    start: SystemTime,
}
impl TcpProbe {
pub fn new(socket: Socket, start: SystemTime) -> Self {
Self { socket, start }
}
}
/// Block for up to `timeout` (millisecond resolution) waiting for `sock` to
/// become readable.
///
/// Returns `true` only when `select` reports exactly one ready descriptor.
fn is_readable(sock: &Socket, timeout: Duration) -> TraceResult<bool> {
    let mut read_fds = FdSet::new();
    read_fds.insert(sock.as_raw_fd());
    let mut wait = TimeVal::milliseconds(timeout.as_millis() as i64);
    let ready_count =
        nix::sys::select::select(None, Some(&mut read_fds), None, None, Some(&mut wait))
            .map_err(|err| TracerError::IoError(std::io::Error::from(err)))?;
    Ok(ready_count == 1)
}
/// Poll, without waiting, whether `sock` is writable.
///
/// Returns `true` only when `select` reports exactly one ready descriptor.
fn is_writable(sock: &Socket) -> TraceResult<bool> {
    let mut write_fds = FdSet::new();
    write_fds.insert(sock.as_raw_fd());
    let mut no_wait = TimeVal::zero();
    let ready_count =
        nix::sys::select::select(None, None, Some(&mut write_fds), None, Some(&mut no_wait))
            .map_err(|err| TracerError::IoError(std::io::Error::from(err)))?;
    Ok(ready_count == 1)
}
/// Determine the local source address for the trace.
///
/// - An explicit `source_addr` is validated by binding to it.
/// - An `interface` name is resolved to its address for `addr_family`.
/// - With neither, the address the OS would use to reach `target_addr` is
///   discovered, using the configured destination port or `DISCOVERY_PORT`.
///
/// # Panics
///
/// Panics (via `unreachable!`) if both `source_addr` and `interface` are given.
fn make_src_addr(
    source_addr: Option<IpAddr>,
    target_addr: IpAddr,
    port_direction: PortDirection,
    interface: Option<&str>,
    addr_family: TracerAddrFamily,
) -> TraceResult<IpAddr> {
    match (source_addr, interface) {
        (Some(addr), None) => validate_local_addr(addr_family, addr),
        (None, Some(name)) => lookup_interface_addr(addr_family, name),
        (None, None) => {
            let probe_port = port_direction.dest().unwrap_or(DISCOVERY_PORT).0;
            discover_local_addr(addr_family, target_addr, probe_port)
        }
        (Some(_), Some(_)) => unreachable!(),
    }
}
/// Resolve the address of interface `name` for `addr_family`, delegating to the
/// `ipv4` or `ipv6` module.
fn lookup_interface_addr(addr_family: TracerAddrFamily, name: &str) -> TraceResult<IpAddr> {
    let lookup: fn(&str) -> TraceResult<IpAddr> = match addr_family {
        TracerAddrFamily::Ipv4 => ipv4::lookup_interface_addr,
        TracerAddrFamily::Ipv6 => ipv6::lookup_interface_addr,
    };
    lookup(name)
}
/// Discover the local address the OS would route from to reach `target:port`.
///
/// "Connects" a UDP socket and reads back its local address. No packet is sent
/// by this function itself.
fn discover_local_addr(
    addr_family: TracerAddrFamily,
    target: IpAddr,
    port: u16,
) -> TraceResult<IpAddr> {
    let socket = udp_socket_for_addr_family(addr_family)?;
    let remote = SockAddr::from(SocketAddr::new(target, port));
    socket.connect(&remote)?;
    let bound = socket.local_addr()?;
    let local = bound.as_socket().req()?;
    Ok(local.ip())
}
/// Check that `source_addr` is usable as a local address by binding a UDP
/// socket of `addr_family` to it (port 0).
///
/// # Errors
///
/// Propagates socket-creation errors. Any bind failure is reported as
/// `InvalidSourceAddr(source_addr)`.
fn validate_local_addr(addr_family: TracerAddrFamily, source_addr: IpAddr) -> TraceResult<IpAddr> {
    let socket = udp_socket_for_addr_family(addr_family)?;
    let bind_target = SockAddr::from(SocketAddr::new(source_addr, 0));
    socket
        .bind(&bind_target)
        .map(|()| source_addr)
        .map_err(|_| InvalidSourceAddr(source_addr))
}
/// Create an unbound UDP datagram socket for `addr_family`.
fn udp_socket_for_addr_family(addr_family: TracerAddrFamily) -> TraceResult<Socket> {
    let domain = match addr_family {
        TracerAddrFamily::Ipv4 => Domain::IPV4,
        TracerAddrFamily::Ipv6 => Domain::IPV6,
    };
    let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
    Ok(socket)
}
/// Open the ICMP send socket for `addr_family` via the `ipv4`/`ipv6` module.
fn make_icmp_send_socket(addr_family: TracerAddrFamily) -> TraceResult<Socket> {
    let open: fn() -> TraceResult<Socket> = match addr_family {
        TracerAddrFamily::Ipv4 => ipv4::make_icmp_send_socket,
        TracerAddrFamily::Ipv6 => ipv6::make_icmp_send_socket,
    };
    open()
}
/// Open the UDP send socket for `addr_family` via the `ipv4`/`ipv6` module.
fn make_udp_send_socket(addr_family: TracerAddrFamily) -> TraceResult<Socket> {
    let open: fn() -> TraceResult<Socket> = match addr_family {
        TracerAddrFamily::Ipv4 => ipv4::make_udp_send_socket,
        TracerAddrFamily::Ipv6 => ipv6::make_udp_send_socket,
    };
    open()
}
/// Open the response receive socket for `addr_family` via the `ipv4`/`ipv6` module.
fn make_recv_socket(addr_family: TracerAddrFamily) -> TraceResult<Socket> {
    let open: fn() -> TraceResult<Socket> = match addr_family {
        TracerAddrFamily::Ipv4 => ipv4::make_recv_socket,
        TracerAddrFamily::Ipv6 => ipv6::make_recv_socket,
    };
    open()
}