async_traceroute/
traceroute.rs

1use std::cell::RefCell;
2use std::cmp::min;
3use std::collections::HashMap;
4use std::future::Future;
5use std::net::IpAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::Duration;
9
10use async_stream::stream;
11use futures_core::stream::Stream;
12use tokio::select;
13use tokio::task::JoinSet;
14
15use crate::traceroute::probe::{ProbeError, ProbeResult};
16use crate::traceroute::probe::generator::ProbeTaskGenerator;
17use crate::traceroute::probe::sniffer::{IcmpProbeResponseSniffer, Sniffer};
18use crate::traceroute::utils::dns;
19
20pub mod utils;
21pub mod terminal;
22pub mod probe;
23mod async_socket;
24pub mod builder;
25
/// Asynchronous traceroute driver.
///
/// Sends `nqueries` probes per TTL, keeping up to `sim_queries` of them in flight, and
/// streams the per-probe outcomes from [`Traceroute::trace`].
pub struct Traceroute {
    /// Address handed to the probe task generator as the probe source.
    source_address: IpAddr,
    /// Address being traced; handed to the generator and compared against reply addresses.
    destination_address: IpAddr,
    /// Highest TTL that will be probed.
    max_ttl: u8,
    /// Number of probes sent per TTL.
    nqueries: u16,
    /// Number of probes kept in flight concurrently (capped in `new`).
    sim_queries: u16,
    /// Timeout handed to each probe when it is sent.
    max_wait_probe: Duration,
    /// Whether reply addresses are reverse-resolved to hostnames.
    is_active_dns_lookup: bool,
    /// TTL used for the next generated probe (starts at 1). The `RefCell` lets `&self`
    /// methods advance it.
    // NOTE(review): the outer `Box` around the `RefCell` looks redundant.
    current_ttl: Box<RefCell<u8>>,
    /// 1-based index of the next query within the current TTL.
    current_query: Box<RefCell<u16>>,
    /// Source of probe tasks; borrowed mutably once per generated probe.
    probe_task_generator: Box<RefCell<Box<dyn ProbeTaskGenerator>>>,
    /// Sniffer shared (via `Arc`) with the background sniffing task started by `trace`.
    icmp_probe_response_sniffer: Arc<IcmpProbeResponseSniffer>,
}
39
40impl Traceroute {
41    pub fn new(
42        source_address: IpAddr,
43        destination_address: IpAddr,
44        max_ttl: u8,
45        nqueries: u16,
46        sim_queries: u16,
47        max_wait_probe: Duration,
48        is_active_dns_lookup: bool,
49        probe_task_generator: Box<dyn ProbeTaskGenerator>,
50        icmp_probe_response_sniffer: IcmpProbeResponseSniffer,
51    ) -> Self {
52        Self {
53            source_address,
54            destination_address,
55            max_ttl,
56            nqueries,
57            sim_queries: min(sim_queries, (max_ttl as u16) * nqueries),
58            max_wait_probe,
59            is_active_dns_lookup,
60            current_ttl: Box::new(RefCell::new(1)),
61            current_query: Box::new(RefCell::new(1)),
62            probe_task_generator: Box::new(RefCell::new(probe_task_generator)),
63            icmp_probe_response_sniffer: Arc::new(icmp_probe_response_sniffer),
64        }
65    }
66
67    pub fn trace(self) -> impl Stream<Item=Result<ProbeResult, ProbeError>> {
68        let mut probe_tasks = JoinSet::new();
69
70        for _ in 0..self.sim_queries {
71            let probe_task = self.generate_probe_task(&self.icmp_probe_response_sniffer);
72            self.increment_ttl_query_counter();
73            probe_tasks.spawn(probe_task);
74        }
75
76        let icmp_probe_response_sniffer = Arc::clone(&self.icmp_probe_response_sniffer);
77        tokio::spawn(async move {
78            icmp_probe_response_sniffer.sniff().await
79        });
80
81        let mut stop_send_probes = false;
82        let mut query_count_by_ttl = HashMap::<u8, u16>::new();
83        let mut ttl_target_address = u8::MAX;
84        let mut target_address_found = false;
85
86        stream! {
87            loop {
88                if *self.current_ttl.borrow() > self.max_ttl {
89                    stop_send_probes = true;
90                }
91
92                select! {
93                    Some(Ok(probe_result)) = probe_tasks.join_next() => {
94                        let (probe_result, ttl) = match probe_result {
95                            Ok(mut probe_result) => {
96                                if self.is_active_dns_lookup {
97                                    Self::reverse_dns_lookup(&mut probe_result).await;
98                                }
99
100                                if !target_address_found && probe_result.from_address() == self.destination_address {
101                                    ttl_target_address = probe_result.ttl();
102                                    target_address_found = true;
103                                }
104
105                                let ttl = probe_result.ttl();
106                                (Ok(probe_result), ttl)
107                            },
108                            Err(probe_error) => {
109                                let ttl = probe_error.get_ttl();
110                                (Err(probe_error), ttl)
111                            },
112                        };
113
114                        let query_count = query_count_by_ttl
115                            .entry(ttl)
116                            .or_insert(0);
117                        *query_count += 1;
118
119                        if ttl <= ttl_target_address {
120                            yield probe_result;
121                        }
122
123                        if ttl == ttl_target_address && *query_count == self.nqueries {
124                            stop_send_probes = true;
125                        }
126
127                        if !stop_send_probes {
128                            let probe_task =
129                                self.generate_probe_task(&self.icmp_probe_response_sniffer);
130
131                            self.increment_ttl_query_counter();
132                            probe_tasks.spawn(probe_task);
133                        }
134                    },
135                    else => break
136                }
137            }
138        }
139    }
140
141    fn generate_probe_task(
142        &self,
143        icmp_probe_response_sniffer: &IcmpProbeResponseSniffer,
144    ) -> Pin<Box<impl Future<Output=Result<ProbeResult, ProbeError>>>> {
145        let mut probe_task_generator = self.probe_task_generator.borrow_mut();
146        match probe_task_generator.generate_probe_task(
147            self.source_address,
148            self.destination_address,
149            &icmp_probe_response_sniffer,
150        ) {
151            Ok((_, mut probe_task)) => {
152                let current_ttl = *self.current_ttl.borrow();
153                let timeout = self.max_wait_probe;
154                let probe_task_future = Box::pin(async move {
155                    probe_task.send_probe(current_ttl, timeout).await
156                });
157                probe_task_future
158            }
159            Err(_) => todo!()
160        }
161    }
162
163    fn increment_ttl_query_counter(&self) {
164        let mut current_query = self.current_query.borrow_mut();
165        *current_query += 1;
166        if *current_query > self.nqueries {
167            *current_query = 1;
168            let mut current_ttl = self.current_ttl.borrow_mut();
169            *current_ttl += 1;
170        }
171    }
172
173    async fn reverse_dns_lookup(probe_result: &mut ProbeResult) {
174        let ip_addr = &IpAddr::V4(probe_result.from_address());
175        if let Some(hostname) = dns::reverse_dns_lookup_first_hostname(ip_addr).await {
176            probe_result.set_hostname(&hostname);
177        }
178    }
179
180    pub fn get_nqueries(&self) -> u16 {
181        self.nqueries
182    }
183
184    pub fn get_max_ttl(&self) -> u8 {
185        self.max_ttl
186    }
187}