use self::state::TracerState;
use crate::tracing::error::{TraceResult, TracerError};
use crate::tracing::net::Network;
use crate::tracing::probe::ProbeResponse;
use crate::tracing::types::{MaxInflight, MaxRounds, Sequence, TimeToLive, TraceId};
use crate::tracing::util::Required;
use crate::tracing::TracerProtocol;
use crate::tracing::{Probe, TracerConfig};
use std::net::IpAddr;
use std::time::{Duration, SystemTime};
/// A completed round of probes, as passed to the tracer's publish callback.
#[derive(Debug, Clone)]
pub struct TracerRound<'a> {
/// The probes issued during the round, in sequence order.
pub probes: &'a [Probe],
/// The largest ttl the tracer reports for this round (the target's ttl when it is known).
pub largest_ttl: TimeToLive,
/// Why the round ended.
pub reason: CompletionReason,
}
impl<'a> TracerRound<'a> {
/// Bundle a round's probes, largest ttl and completion reason.
#[must_use]
pub fn new(probes: &'a [Probe], largest_ttl: TimeToLive, reason: CompletionReason) -> Self {
Self {
probes,
largest_ttl,
reason,
}
}
}
/// Why a round ended, as reported in `TracerRound::reason`.
#[derive(Debug, Copy, Clone)]
pub enum CompletionReason {
/// A response flagged as reaching the target was received during the round.
TargetFound,
/// The round ended without any response flagged as reaching the target.
RoundTimeLimitExceeded,
}
/// Drives a trace: issues probes over a `Network`, matches responses to them and
/// hands each completed round to the `publish` callback.
#[derive(Debug, Clone)]
pub struct Tracer<F> {
/// Address whose TimeExceeded responses mark the target as found.
target_addr: IpAddr,
/// Protocol used to send probes; TCP probes are re-issued when the address is unavailable.
protocol: TracerProtocol,
/// Identifier that TimeExceeded/DestinationUnreachable/EchoReply responses must carry
/// (a response identifier of 0 is also accepted).
trace_identifier: TraceId,
/// Round limit passed to `TracerState::finished`; `None` means no limit.
max_rounds: Option<MaxRounds>,
/// Ttl of the first probe in every round.
first_ttl: TimeToLive,
/// Largest ttl a probe may be sent with.
max_ttl: TimeToLive,
/// Once the target is found, how long after the most recent response the round waits before ending.
grace_duration: Duration,
/// While the target ttl is unknown, the maximum gap between the next ttl to send and the
/// largest ttl that has received a response.
max_inflight: MaxInflight,
/// Sequence number of the first probe; the state also resets to it when sequences wrap.
initial_sequence: Sequence,
/// A round in which the target was found does not end before this much time has elapsed.
min_round_duration: Duration,
/// A round ends once this much time has elapsed, whether or not the target was found.
max_round_duration: Duration,
/// Callback invoked with each completed round.
publish: F,
}
impl<F: Fn(&TracerRound<'_>)> Tracer<F> {
    /// Build a tracer from `config`; `publish` is invoked with every completed round.
    pub fn new(config: &TracerConfig, publish: F) -> Self {
        Self {
            target_addr: config.target_addr,
            protocol: config.protocol,
            trace_identifier: config.trace_identifier,
            max_rounds: config.max_rounds,
            first_ttl: config.first_ttl,
            max_ttl: config.max_ttl,
            grace_duration: config.grace_duration,
            max_inflight: config.max_inflight,
            initial_sequence: config.initial_sequence,
            min_round_duration: config.min_round_duration,
            max_round_duration: config.max_round_duration,
            publish,
        }
    }

    /// Run the trace until `TracerState::finished` reports the round limit reached.
    ///
    /// Each iteration tries to send one probe, handles at most one response, and then
    /// decides whether the current round is over. Without a round limit this only
    /// returns when sending or receiving fails, and the error is propagated.
    pub fn trace<N: Network>(self, mut network: N) -> TraceResult<()> {
        let mut state = TracerState::new(self.first_ttl, self.initial_sequence);
        loop {
            if state.finished(self.max_rounds) {
                return Ok(());
            }
            self.send_request(&mut network, &mut state)?;
            self.recv_response(&mut network, &mut state)?;
            self.update_round(&mut state);
        }
    }

    /// Send the next probe of the round, if the ttl window allows one.
    ///
    /// Nothing is sent once the target has been found this round, once the next ttl
    /// exceeds `max_ttl`, or when the next ttl falls outside the allowed window.
    fn send_request<N: Network>(&self, network: &mut N, st: &mut TracerState) -> TraceResult<()> {
        let next_ttl = st.ttl();
        let ttl_allowed = match st.target_ttl() {
            // The target's ttl is known: don't probe beyond it.
            Some(target_ttl) => next_ttl <= target_ttl,
            // Otherwise cap how far ahead of the furthest response we probe.
            None => {
                let ahead = next_ttl - st.max_received_ttl().unwrap_or_default();
                ahead < TimeToLive(self.max_inflight.0)
            }
        };
        let should_send = ttl_allowed && !st.target_found() && next_ttl <= self.max_ttl;
        if !should_send {
            return Ok(());
        }
        match self.protocol {
            TracerProtocol::Icmp | TracerProtocol::Udp => network.send_probe(st.next_probe()),
            TracerProtocol::Tcp => Self::send_tcp_probe(network, st),
        }
    }

    /// Send a TCP probe, re-issuing it under a fresh sequence each time the network
    /// reports `AddressNotAvailable`, until it is sent or the round buffer is full.
    fn send_tcp_probe<N: Network>(network: &mut N, st: &mut TracerState) -> TraceResult<()> {
        if !st.round_has_capacity() {
            return Err(TracerError::InsufficientCapacity);
        }
        let mut probe = st.next_probe();
        loop {
            match network.send_probe(probe) {
                Ok(()) => return Ok(()),
                Err(TracerError::AddressNotAvailable(_)) if st.round_has_capacity() => {
                    probe = st.reissue_probe();
                }
                Err(TracerError::AddressNotAvailable(_)) => {
                    return Err(TracerError::InsufficientCapacity);
                }
                Err(other) => return Err(other),
            }
        }
    }

    /// Receive at most one response and record it against the matching probe of the
    /// current round.
    ///
    /// TimeExceeded, DestinationUnreachable and EchoReply responses are matched by
    /// sequence and must pass `check_trace_id`. TCP responses are matched by ttl to a
    /// probe issued this round. Matches outside the current round are ignored.
    ///
    /// NOTE(review): a TCP response whose ttl matches no probe issued this round
    /// presumably makes `req()` return an error, which ends the whole trace. Confirm
    /// that this is intended rather than ignoring the stray response.
    fn recv_response<N: Network>(&self, network: &mut N, st: &mut TracerState) -> TraceResult<()> {
        match network.recv_probe()? {
            None => {}
            Some(ProbeResponse::TimeExceeded(data)) => {
                let seq = Sequence(data.sequence);
                if self.check_trace_id(TraceId(data.identifier)) && st.in_round(seq) {
                    let from_target = data.addr == self.target_addr;
                    st.complete_probe_time_exceeded(seq, data.addr, data.recv, from_target);
                }
            }
            Some(ProbeResponse::DestinationUnreachable(data)) => {
                let seq = Sequence(data.sequence);
                if self.check_trace_id(TraceId(data.identifier)) && st.in_round(seq) {
                    st.complete_probe_unreachable(seq, data.addr, data.recv);
                }
            }
            Some(ProbeResponse::EchoReply(data)) => {
                let seq = Sequence(data.sequence);
                if self.check_trace_id(TraceId(data.identifier)) && st.in_round(seq) {
                    st.complete_probe_echo_reply(seq, data.addr, data.recv);
                }
            }
            Some(ProbeResponse::TcpReply(data) | ProbeResponse::TcpRefused(data)) => {
                let seq = st.probe_for_ttl(TimeToLive(data.ttl)).req()?.sequence;
                if st.in_round(seq) {
                    st.complete_probe_other(seq, data.addr, data.recv);
                }
            }
        }
        Ok(())
    }

    /// End the round, publishing it and advancing the state, when either condition holds:
    /// - the target was found, the minimum round duration has passed, and the grace
    ///   period since the last response has elapsed; or
    /// - the maximum round duration has passed.
    fn update_round(&self, st: &mut TracerState) {
        let now = SystemTime::now();
        let elapsed = now.duration_since(st.round_start()).unwrap_or_default();
        let target_settled = st.target_found()
            && elapsed > self.min_round_duration
            && exceeds(st.received_time(), now, self.grace_duration);
        let timed_out = elapsed > self.max_round_duration;
        if target_settled || timed_out {
            self.publish_trace(st);
            st.advance_round(self.first_ttl);
        }
    }

    /// Hand the current round to the publish callback.
    ///
    /// The reported largest ttl is the target ttl when known. Otherwise it is one past
    /// the largest ttl that received a response, capped at the largest ttl sent, or 0
    /// if nothing was received.
    fn publish_trace(&self, state: &TracerState) {
        let largest_ttl = match (state.target_ttl(), state.max_received_ttl()) {
            (Some(target_ttl), _) => target_ttl,
            (None, Some(max_received_ttl)) => {
                let max_sent_ttl = state.ttl() - TimeToLive(1);
                max_sent_ttl.min(max_received_ttl + TimeToLive(1))
            }
            (None, None) => TimeToLive(0),
        };
        let reason = if state.target_found() {
            CompletionReason::TargetFound
        } else {
            CompletionReason::RoundTimeLimitExceeded
        };
        let round = TracerRound::new(state.probes(), largest_ttl, reason);
        (self.publish)(&round);
    }

    /// Accept responses carrying this tracer's identifier, or an identifier of 0.
    fn check_trace_id(&self, trace_id: TraceId) -> bool {
        trace_id == TraceId(0) || trace_id == self.trace_identifier
    }
}
mod state {
use crate::tracing::types::{MaxRounds, Round, Sequence, TimeToLive};
use crate::tracing::{IcmpPacketType, Probe, ProbeStatus};
use std::net::IpAddr;
use std::time::SystemTime;
/// Number of probe slots per round; also the width of the sequence window accepted by `in_round`.
const BUFFER_SIZE: u16 = 1024;
/// Once `sequence` reaches this value, `advance_round` resets it to the initial sequence,
/// so a full round of `BUFFER_SIZE` sequences always fits below `u16::MAX`.
const MAX_SEQUENCE: Sequence = Sequence(u16::MAX - BUFFER_SIZE);
/// Mutable state of a trace: the probe buffer for the current round, the sequence and
/// ttl counters, and what has been received so far.
#[derive(Debug)]
pub struct TracerState {
/// Probes of the current round, indexed by `sequence - round_sequence`.
buffer: [Probe; BUFFER_SIZE as usize],
/// Sequence that `sequence` is reset to when it reaches `MAX_SEQUENCE`.
initial_sequence: Sequence,
/// Sequence number to assign to the next probe issued.
sequence: Sequence,
/// Sequence number of the first probe of the current round.
round_sequence: Sequence,
/// Time-to-live to assign to the next probe issued.
ttl: TimeToLive,
/// Current round number, starting at 0.
round: Round,
/// When the current round started.
round_start: SystemTime,
/// Whether a probe completed this round was flagged as reaching the target.
target_found: bool,
/// Largest ttl of any probe completed this round.
max_received_ttl: Option<TimeToLive>,
/// Lowest ttl at which the target has responded; not reset by `advance_round`.
target_ttl: Option<TimeToLive>,
/// When the most recent response of the current round was received.
received_time: Option<SystemTime>,
}
impl TracerState {
    /// Create the state for a new trace: round 0, with the first probe to be issued at
    /// `first_ttl` and `initial_sequence`.
    pub fn new(first_ttl: TimeToLive, initial_sequence: Sequence) -> Self {
        Self {
            buffer: [Probe::default(); BUFFER_SIZE as usize],
            initial_sequence,
            sequence: initial_sequence,
            round_sequence: initial_sequence,
            ttl: first_ttl,
            round: Round(0),
            round_start: SystemTime::now(),
            target_found: false,
            max_received_ttl: None,
            target_ttl: None,
            received_time: None,
        }
    }

    /// The probes issued so far in the current round, in sequence order.
    pub fn probes(&self) -> &[Probe] {
        let round_size = self.sequence - self.round_sequence;
        &self.buffer[..round_size.0 as usize]
    }

    /// A copy of the buffer slot for `sequence` in the current round.
    ///
    /// Callers are expected to check `in_round(sequence)` first; a sequence outside the
    /// round's window indexes outside the buffer.
    pub fn probe_at(&self, sequence: Sequence) -> Probe {
        self.buffer[usize::from(sequence - self.round_sequence)]
    }

    /// The first probe issued this round with the given `ttl`, if any.
    pub fn probe_for_ttl(&self, ttl: TimeToLive) -> Option<&Probe> {
        self.probes().iter().find(|p| p.ttl == ttl)
    }

    /// Time-to-live the next issued probe will use.
    pub const fn ttl(&self) -> TimeToLive {
        self.ttl
    }

    /// When the current round started.
    pub const fn round_start(&self) -> SystemTime {
        self.round_start
    }

    /// Whether a response flagged as reaching the target was recorded this round.
    pub const fn target_found(&self) -> bool {
        self.target_found
    }

    /// Largest ttl of any probe completed this round.
    pub const fn max_received_ttl(&self) -> Option<TimeToLive> {
        self.max_received_ttl
    }

    /// Lowest ttl at which the target has responded (kept across rounds).
    pub const fn target_ttl(&self) -> Option<TimeToLive> {
        self.target_ttl
    }

    /// When the most recent response of the current round was received.
    pub const fn received_time(&self) -> Option<SystemTime> {
        self.received_time
    }

    /// Whether `sequence` falls in the current round's window of `BUFFER_SIZE` sequences.
    ///
    /// The window also covers sequences not yet issued this round.
    pub fn in_round(&self, sequence: Sequence) -> bool {
        sequence >= self.round_sequence && sequence.0 - self.round_sequence.0 < BUFFER_SIZE
    }

    /// Whether another probe can be issued this round without overflowing the buffer.
    pub fn round_has_capacity(&self) -> bool {
        let round_size = self.sequence - self.round_sequence;
        round_size.0 < BUFFER_SIZE
    }

    /// Whether `max_rounds` rounds have been completed; always `false` without a limit.
    pub fn finished(&self, max_rounds: Option<MaxRounds>) -> bool {
        match max_rounds {
            None => false,
            // `round` counts completed rounds: it starts at 0 and `advance_round`
            // increments it. So the limit is reached when it equals `max_rounds`; a
            // strict `>` would run one extra round.
            Some(max_rounds) => self.round.0 >= max_rounds.0,
        }
    }

    /// Issue the next probe at the current ttl and sequence, and store it in the round
    /// buffer. Then advance both counters by one.
    ///
    /// Does not check buffer capacity itself.
    pub fn next_probe(&mut self) -> Probe {
        let probe = Probe::new(self.sequence, self.ttl, self.round, SystemTime::now());
        self.buffer[usize::from(self.sequence - self.round_sequence)] = probe;
        debug_assert!(self.ttl < TimeToLive(u8::MAX));
        self.ttl += TimeToLive(1);
        debug_assert!(self.sequence < Sequence(u16::MAX));
        self.sequence += Sequence(1);
        probe
    }

    /// Replace the most recently issued probe with a new one at the same ttl but the
    /// next sequence number, clearing the old probe's slot.
    ///
    /// Must only be called after at least one probe has been issued this round.
    pub fn reissue_probe(&mut self) -> Probe {
        self.buffer[usize::from(self.sequence - self.round_sequence) - 1] = Probe::default();
        let probe = Probe::new(
            self.sequence,
            self.ttl - TimeToLive(1),
            self.round,
            SystemTime::now(),
        );
        self.buffer[usize::from(self.sequence - self.round_sequence)] = probe;
        debug_assert!(self.sequence < Sequence(u16::MAX));
        self.sequence += Sequence(1);
        probe
    }

    /// Record a TimeExceeded response for `sequence`; `is_target` says whether it came
    /// from the target.
    pub fn complete_probe_time_exceeded(
        &mut self,
        sequence: Sequence,
        host: IpAddr,
        received: SystemTime,
        is_target: bool,
    ) {
        self.complete_probe(
            sequence,
            IcmpPacketType::TimeExceeded,
            host,
            received,
            is_target,
        );
    }

    /// Record a DestinationUnreachable response for `sequence`; always treated as the target.
    pub fn complete_probe_unreachable(
        &mut self,
        sequence: Sequence,
        host: IpAddr,
        received: SystemTime,
    ) {
        self.complete_probe(sequence, IcmpPacketType::Unreachable, host, received, true);
    }

    /// Record an EchoReply response for `sequence`; always treated as the target.
    pub fn complete_probe_echo_reply(
        &mut self,
        sequence: Sequence,
        host: IpAddr,
        received: SystemTime,
    ) {
        self.complete_probe(sequence, IcmpPacketType::EchoReply, host, received, true);
    }

    /// Record a non-ICMP response (e.g. TCP) for `sequence`; always treated as the target.
    pub fn complete_probe_other(
        &mut self,
        sequence: Sequence,
        host: IpAddr,
        received: SystemTime,
    ) {
        self.complete_probe(
            sequence,
            IcmpPacketType::NotApplicable,
            host,
            received,
            true,
        );
    }

    /// Mark the probe at `sequence` complete and update the round's received state.
    fn complete_probe(
        &mut self,
        sequence: Sequence,
        icmp_packet_type: IcmpPacketType,
        host: IpAddr,
        received: SystemTime,
        is_target: bool,
    ) {
        // `in_round` accepts the whole buffer window, but only slots below `sequence`
        // hold probes issued this round. The rest are default or left over from an
        // earlier round, and completing them would record a ttl never sent this round.
        if sequence >= self.sequence {
            return;
        }
        let probe = self
            .probe_at(sequence)
            .with_status(ProbeStatus::Complete)
            .with_icmp_packet_type(icmp_packet_type)
            .with_host(host)
            .with_received(received);
        self.buffer[usize::from(sequence - self.round_sequence)] = probe;
        self.target_ttl = if is_target {
            // Keep the lowest ttl at which the target answered; responses may arrive
            // out of order.
            match self.target_ttl {
                None => Some(probe.ttl),
                Some(ttl) if probe.ttl < ttl => Some(probe.ttl),
                Some(ttl) => Some(ttl),
            }
        } else {
            // A non-target response at or beyond the known target ttl means the path to
            // the target has grown (e.g. a route change). Forget the stale value so that
            // sending is no longer capped at it and the new target ttl can be discovered.
            match self.target_ttl {
                Some(ttl) if probe.ttl >= ttl => None,
                unchanged => unchanged,
            }
        };
        self.max_received_ttl = match self.max_received_ttl {
            None => Some(probe.ttl),
            Some(max_received_ttl) => Some(max_received_ttl.max(probe.ttl)),
        };
        self.received_time = Some(received);
        self.target_found |= is_target;
    }

    /// Start a new round: reset the per-round fields, restart the ttl at `first_ttl`,
    /// and continue the sequence. The sequence wraps back to the initial sequence once
    /// it reaches `MAX_SEQUENCE`. `target_ttl` is deliberately kept.
    pub fn advance_round(&mut self, first_ttl: TimeToLive) {
        if self.sequence >= MAX_SEQUENCE {
            self.sequence = self.initial_sequence;
        }
        self.target_found = false;
        self.round_sequence = self.sequence;
        self.received_time = None;
        self.round_start = SystemTime::now();
        self.max_received_ttl = None;
        self.round += Round(1);
        self.ttl = first_ttl;
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tracing::probe::IcmpPacketType;
use crate::tracing::ProbeStatus;
use rand::Rng;
use std::net::{IpAddr, Ipv4Addr};
// Walks a `TracerState` through two rounds, checking the state fields and the probe
// buffer after each issue / complete / advance step. The lints are allowed because the
// test is a single long linear script of assertions that compares bools with
// `assert_eq!` for uniformity with the other field checks.
#[allow(
clippy::cognitive_complexity,
clippy::too_many_lines,
clippy::bool_assert_comparison
)]
#[test]
fn test_state() {
let mut state = TracerState::new(TimeToLive(1), Sequence(33000));
// Fresh state: round 0, nothing issued or received yet.
assert_eq!(state.round, Round(0));
assert_eq!(state.sequence, Sequence(33000));
assert_eq!(state.round_sequence, Sequence(33000));
assert_eq!(state.ttl, TimeToLive(1));
assert_eq!(state.max_received_ttl, None);
assert_eq!(state.received_time, None);
assert_eq!(state.target_ttl, None);
assert_eq!(state.target_found, false);
// The slot for the initial sequence still holds a default, unsent probe.
let prob_init = state.probe_at(Sequence(33000));
assert_eq!(prob_init.sequence, Sequence(0));
assert_eq!(prob_init.ttl, TimeToLive(0));
assert_eq!(prob_init.round, Round(0));
assert_eq!(prob_init.received, None);
assert_eq!(prob_init.host, None);
assert_eq!(prob_init.sent.is_some(), false);
assert_eq!(prob_init.status, ProbeStatus::NotSent);
assert_eq!(prob_init.icmp_packet_type, None);
// Issue the first probe (ttl 1).
let probe_1 = state.next_probe();
assert_eq!(probe_1.sequence, Sequence(33000));
assert_eq!(probe_1.ttl, TimeToLive(1));
assert_eq!(probe_1.round, Round(0));
assert_eq!(probe_1.received, None);
assert_eq!(probe_1.host, None);
assert_eq!(probe_1.sent.is_some(), true);
assert_eq!(probe_1.status, ProbeStatus::Awaited);
assert_eq!(probe_1.icmp_packet_type, None);
// Complete it with a TimeExceeded response that is not from the target.
let received_1 = SystemTime::now();
let host = IpAddr::V4(Ipv4Addr::LOCALHOST);
state.complete_probe_time_exceeded(Sequence(33000), host, received_1, false);
let probe_1_fetch = state.probe_at(Sequence(33000));
assert_eq!(probe_1_fetch.sequence, Sequence(33000));
assert_eq!(probe_1_fetch.ttl, TimeToLive(1));
assert_eq!(probe_1_fetch.round, Round(0));
assert_eq!(probe_1_fetch.received, Some(received_1));
assert_eq!(probe_1_fetch.host, Some(host));
assert_eq!(probe_1_fetch.sent.is_some(), true);
assert_eq!(probe_1_fetch.status, ProbeStatus::Complete);
assert_eq!(
probe_1_fetch.icmp_packet_type,
Some(IcmpPacketType::TimeExceeded)
);
assert_eq!(state.round, Round(0));
assert_eq!(state.sequence, Sequence(33001));
assert_eq!(state.round_sequence, Sequence(33000));
assert_eq!(state.ttl, TimeToLive(2));
assert_eq!(state.max_received_ttl, Some(TimeToLive(1)));
assert_eq!(state.received_time, Some(received_1));
assert_eq!(state.target_ttl, None);
assert_eq!(state.target_found, false);
// The round's probe slice contains exactly the one completed probe.
{
let mut probe_iter = state.probes().iter();
let probe_next1 = *probe_iter.next().unwrap();
assert_eq!(probe_1_fetch, probe_next1);
assert_eq!(None, probe_iter.next());
}
// Advancing resets the per-round fields; the sequence continues from 33001.
state.advance_round(TimeToLive(1));
assert_eq!(state.round, Round(1));
assert_eq!(state.sequence, Sequence(33001));
assert_eq!(state.round_sequence, Sequence(33001));
assert_eq!(state.ttl, TimeToLive(1));
assert_eq!(state.max_received_ttl, None);
assert_eq!(state.received_time, None);
assert_eq!(state.target_ttl, None);
assert_eq!(state.target_found, false);
// Issue two probes in round 1 (ttl 1 and ttl 2).
let probe_2 = state.next_probe();
assert_eq!(probe_2.sequence, Sequence(33001));
assert_eq!(probe_2.ttl, TimeToLive(1));
assert_eq!(probe_2.round, Round(1));
assert_eq!(probe_2.received, None);
assert_eq!(probe_2.host, None);
assert_eq!(probe_2.sent.is_some(), true);
assert_eq!(probe_2.status, ProbeStatus::Awaited);
assert_eq!(probe_2.icmp_packet_type, None);
let probe_3 = state.next_probe();
assert_eq!(probe_3.sequence, Sequence(33002));
assert_eq!(probe_3.ttl, TimeToLive(2));
assert_eq!(probe_3.round, Round(1));
assert_eq!(probe_3.received, None);
assert_eq!(probe_3.host, None);
assert_eq!(probe_3.sent.is_some(), true);
assert_eq!(probe_3.status, ProbeStatus::Awaited);
assert_eq!(probe_3.icmp_packet_type, None);
// Complete the ttl 1 probe with a non-target TimeExceeded response.
let received_2 = SystemTime::now();
let host = IpAddr::V4(Ipv4Addr::LOCALHOST);
state.complete_probe_time_exceeded(Sequence(33001), host, received_2, false);
let probe_2_recv = state.probe_at(Sequence(33001));
assert_eq!(state.round, Round(1));
assert_eq!(state.sequence, Sequence(33003));
assert_eq!(state.round_sequence, Sequence(33001));
assert_eq!(state.ttl, TimeToLive(3));
assert_eq!(state.max_received_ttl, Some(TimeToLive(1)));
assert_eq!(state.received_time, Some(received_2));
assert_eq!(state.target_ttl, None);
assert_eq!(state.target_found, false);
{
let mut probe_iter = state.probes().iter();
let probe_next1 = *probe_iter.next().unwrap();
assert_eq!(probe_2_recv, probe_next1);
let probe_next2 = *probe_iter.next().unwrap();
assert_eq!(probe_3, probe_next2);
}
// Complete the ttl 2 probe with an EchoReply, which marks the target found at ttl 2.
let received_3 = SystemTime::now();
let host = IpAddr::V4(Ipv4Addr::LOCALHOST);
state.complete_probe_echo_reply(Sequence(33002), host, received_3);
let probe_3_recv = state.probe_at(Sequence(33002));
assert_eq!(state.round, Round(1));
assert_eq!(state.sequence, Sequence(33003));
assert_eq!(state.round_sequence, Sequence(33001));
assert_eq!(state.ttl, TimeToLive(3));
assert_eq!(state.max_received_ttl, Some(TimeToLive(2)));
assert_eq!(state.received_time, Some(received_3));
assert_eq!(state.target_ttl, Some(TimeToLive(2)));
assert_eq!(state.target_found, true);
{
let mut probe_iter = state.probes().iter();
let probe_next1 = *probe_iter.next().unwrap();
assert_eq!(probe_2_recv, probe_next1);
let probe_next2 = *probe_iter.next().unwrap();
assert_eq!(probe_3_recv, probe_next2);
}
}
// Starting at 65278, which is >= MAX_SEQUENCE, forces `advance_round` to reset the
// sequence back to the initial sequence after every round.
#[test]
fn test_sequence_wrap1() {
let initial_sequence = Sequence(65278);
let mut state = TracerState::new(TimeToLive(1), initial_sequence);
assert_eq!(state.round, Round(0));
assert_eq!(state.sequence, initial_sequence);
assert_eq!(state.round_sequence, initial_sequence);
assert_eq!(state.next_probe().sequence, Sequence(65278));
assert_eq!(state.sequence, Sequence(65279));
{
let mut iter = state.probes().iter();
assert_eq!(iter.next().unwrap().sequence, Sequence(65278));
// NOTE(review): `probes()` only holds the one issued probe here, so this
// loop visits no elements and asserts nothing.
iter.take(BUFFER_SIZE as usize - 1)
.for_each(|p| assert_eq!(p.sequence, Sequence(0)));
}
state.advance_round(TimeToLive(1));
assert_eq!(state.round, Round(1));
assert_eq!(state.sequence, initial_sequence);
assert_eq!(state.round_sequence, initial_sequence);
assert_eq!(state.next_probe().sequence, Sequence(65278));
assert_eq!(state.sequence, Sequence(65279));
{
let mut iter = state.probes().iter();
assert_eq!(iter.next().unwrap().sequence, Sequence(65278));
iter.take(BUFFER_SIZE as usize - 1)
.for_each(|p| assert_eq!(p.sequence, Sequence(0)));
}
}
// 254 probes per round from 33000. The sequence wraps every 125 rounds
// (33000 + 125 * 254 >= MAX_SEQUENCE), and 2000 is a multiple of 125, so it ends exactly
// back at the initial sequence.
#[test]
fn test_sequence_wrap2() {
let total_rounds = 2000;
let max_probe_per_round = 254;
let mut state = TracerState::new(TimeToLive(1), Sequence(33000));
for _ in 0..total_rounds {
for _ in 0..max_probe_per_round {
let _probe = state.next_probe();
}
state.advance_round(TimeToLive(1));
}
assert_eq!(state.round, Round(2000));
assert_eq!(state.round_sequence, Sequence(33000));
assert_eq!(state.sequence, Sequence(33000));
}
// A random number of probes per round; only checks that wrapping never panics.
#[test]
fn test_sequence_wrap3() {
let total_rounds = 2000;
let max_probe_per_round = 20;
let mut state = TracerState::new(TimeToLive(1), Sequence(33000));
let mut rng = rand::thread_rng();
for _ in 0..total_rounds {
for _ in 0..rng.gen_range(0..max_probe_per_round) {
state.next_probe();
}
state.advance_round(TimeToLive(1));
}
}
// Every probe is re-issued once, so each round consumes 508 sequence numbers.
#[test]
fn test_sequence_wrap_with_skip() {
let total_rounds = 2000;
let max_probe_per_round = 254;
let mut state = TracerState::new(TimeToLive(1), Sequence(33000));
for _ in 0..total_rounds {
for _ in 0..max_probe_per_round {
let _ = state.next_probe();
let _ = state.reissue_probe();
}
state.advance_round(TimeToLive(1));
}
assert_eq!(state.round, Round(2000));
assert_eq!(state.round_sequence, Sequence(56876));
assert_eq!(state.sequence, Sequence(56876));
}
// The round window is [round_sequence, round_sequence + BUFFER_SIZE).
#[test]
fn test_in_round() {
let state = TracerState::new(TimeToLive(1), Sequence(33000));
assert!(state.in_round(Sequence(33000)));
assert!(state.in_round(Sequence(34023)));
assert!(!state.in_round(Sequence(34024)));
}
}
}
/// Whether strictly more than `dur` elapsed between `start` and `end`.
///
/// Returns `false` when there is no `start`. It also returns `false` when `end` precedes
/// `start`: a negative interval counts as zero elapsed time, which never exceeds `dur`.
fn exceeds(start: Option<SystemTime>, end: SystemTime, dur: Duration) -> bool {
    match start {
        None => false,
        Some(begin) => match end.duration_since(begin) {
            Ok(elapsed) => elapsed > dur,
            Err(_) => false,
        },
    }
}