use crate::socket::{ProbeProtocol, SocketMode};
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::time::Duration;
/// Timing parameters embedded in [`TracerouteConfig::timing`].
///
/// Every field is a plain [`Duration`]. The `Default` impl fills them from the
/// `DEFAULT_*_MS` millisecond constants in `crate::config::timing`. How each
/// interval is consumed is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimingConfig {
    /// Receiver poll interval (default: `DEFAULT_RECEIVER_POLL_INTERVAL_MS`).
    pub receiver_poll_interval: Duration,
    /// Main-loop poll interval (default: `DEFAULT_MAIN_LOOP_POLL_INTERVAL_MS`).
    pub main_loop_poll_interval: Duration,
    /// Enrichment wait time (default: `DEFAULT_ENRICHMENT_WAIT_TIME_MS`).
    pub enrichment_wait_time: Duration,
    /// Socket read timeout (default: `DEFAULT_SOCKET_READ_TIMEOUT_MS`).
    pub socket_read_timeout: Duration,
    /// UDP retry delay (default: `DEFAULT_UDP_RETRY_DELAY_MS`).
    pub udp_retry_delay: Duration,
}
impl Default for TimingConfig {
    /// Builds the timing configuration from the crate-wide `DEFAULT_*_MS`
    /// constants in `crate::config::timing`, each interpreted as milliseconds.
    fn default() -> Self {
        use crate::config::timing::{
            DEFAULT_ENRICHMENT_WAIT_TIME_MS, DEFAULT_MAIN_LOOP_POLL_INTERVAL_MS,
            DEFAULT_RECEIVER_POLL_INTERVAL_MS, DEFAULT_SOCKET_READ_TIMEOUT_MS,
            DEFAULT_UDP_RETRY_DELAY_MS,
        };

        // All defaults are whole-millisecond values.
        let ms = Duration::from_millis;

        Self {
            receiver_poll_interval: ms(DEFAULT_RECEIVER_POLL_INTERVAL_MS),
            main_loop_poll_interval: ms(DEFAULT_MAIN_LOOP_POLL_INTERVAL_MS),
            enrichment_wait_time: ms(DEFAULT_ENRICHMENT_WAIT_TIME_MS),
            socket_read_timeout: ms(DEFAULT_SOCKET_READ_TIMEOUT_MS),
            udp_retry_delay: ms(DEFAULT_UDP_RETRY_DELAY_MS),
        }
    }
}
/// Options for a traceroute run.
///
/// Build one with [`TracerouteConfig::builder`], whose `build` validates.
/// Alternatively, start from `Default` and call [`TracerouteConfig::validate`]
/// yourself.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracerouteConfig {
    /// Target as supplied by the caller; may be empty when `target_ip` is set.
    pub target: String,
    /// Explicit target address; `validate` accepts an empty `target` when this is set.
    pub target_ip: Option<IpAddr>,
    /// First TTL probed (default 1); `validate` requires it to be at least 1.
    pub start_ttl: u8,
    /// Highest hop limit (default 30); `validate` requires it to be `>= start_ttl`.
    pub max_hops: u8,
    /// Per-probe timeout (default 1000 ms). `validate` compares whole
    /// milliseconds, so anything below 1 ms is rejected.
    pub probe_timeout: Duration,
    /// Delay between probe sends (default 0). The builder's `parallel_probes`
    /// overwrites this value.
    pub send_interval: Duration,
    /// Overall timeout (default 3000 ms); not checked by `validate`.
    pub overall_timeout: Duration,
    /// Probes sent per hop (default 1); `validate` requires at least 1.
    pub queries_per_hop: u8,
    /// Probe protocol; `None` by default.
    /// NOTE(review): `None` presumably means "choose automatically" — confirm with the consumer.
    pub protocol: Option<ProbeProtocol>,
    /// Socket mode; `None` by default (same caveat as `protocol`).
    pub socket_mode: Option<SocketMode>,
    /// Destination port (default 443).
    pub port: u16,
    /// Whether ASN lookup is enabled (default `true`).
    pub enable_asn_lookup: bool,
    /// Whether reverse-DNS lookup is enabled (default `true`).
    pub enable_rdns: bool,
    /// Verbosity level (default 0).
    pub verbose: u8,
    /// Optional public IP address (default `None`); its use is not visible in this file.
    pub public_ip: Option<IpAddr>,
    /// Lower-level timing parameters; see [`TimingConfig`].
    pub timing: TimingConfig,
}
impl Default for TracerouteConfig {
    /// Baseline configuration:
    /// - TTLs 1 through 30;
    /// - a one-second per-probe timeout and a three-second overall timeout;
    /// - no delay between sends;
    /// - one query per hop on port 443;
    /// - ASN and reverse-DNS lookups enabled.
    ///
    /// The target is left empty, so `validate` rejects this value until a
    /// `target` or `target_ip` is provided.
    fn default() -> Self {
        let timing = TimingConfig::default();

        Self {
            // No destination yet: callers must supply one.
            target: String::new(),
            target_ip: None,
            public_ip: None,

            // Hop range.
            start_ttl: 1,
            max_hops: 30,

            // Pacing and timeouts.
            probe_timeout: Duration::from_secs(1),
            send_interval: Duration::ZERO,
            overall_timeout: Duration::from_secs(3),
            queries_per_hop: 1,

            // Transport selection is deferred (left unset).
            protocol: None,
            socket_mode: None,
            port: 443,

            // Result enrichment and diagnostics.
            enable_asn_lookup: true,
            enable_rdns: true,
            verbose: 0,

            timing,
        }
    }
}
impl TracerouteConfig {
pub fn builder() -> TracerouteConfigBuilder {
TracerouteConfigBuilder::new()
}
pub fn validate(&self) -> Result<(), String> {
if self.target.is_empty() && self.target_ip.is_none() {
return Err("Target must be specified".to_string());
}
if self.start_ttl < 1 {
return Err("start_ttl must be at least 1".to_string());
}
if self.max_hops < self.start_ttl {
return Err("max_hops must be greater than or equal to start_ttl".to_string());
}
if self.probe_timeout.as_millis() == 0 {
return Err("probe_timeout must be greater than 0".to_string());
}
if self.queries_per_hop < 1 {
return Err("queries_per_hop must be at least 1".to_string());
}
Ok(())
}
}
/// Fluent builder for [`TracerouteConfig`].
///
/// It starts from `TracerouteConfig::default()`, and each setter consumes and
/// returns the builder. `build` runs `TracerouteConfig::validate` before handing
/// back the finished configuration.
pub struct TracerouteConfigBuilder {
    // The configuration being assembled; only reachable through the setters and `build`.
    config: TracerouteConfig,
}
impl TracerouteConfigBuilder {
    /// Creates a builder seeded with `TracerouteConfig::default()`.
    pub fn new() -> Self {
        let config = TracerouteConfig::default();
        Self { config }
    }

    /// Applies `edit` to the configuration under construction and returns the
    /// builder, so each setter is a single expression.
    fn edit_config(mut self, edit: impl FnOnce(&mut TracerouteConfig)) -> Self {
        edit(&mut self.config);
        self
    }

    /// Sets the target, converting it into an owned `String`.
    pub fn target(self, target: impl Into<String>) -> Self {
        let target = target.into();
        self.edit_config(|c| c.target = target)
    }

    /// Sets an explicit target address; with it, an empty `target` still validates.
    pub fn target_ip(self, ip: IpAddr) -> Self {
        self.edit_config(|c| c.target_ip = Some(ip))
    }

    /// Sets the first TTL to probe; `build` rejects 0 or a value above `max_hops`.
    pub fn start_ttl(self, ttl: u8) -> Self {
        self.edit_config(|c| c.start_ttl = ttl)
    }

    /// Sets the highest hop limit; `build` rejects a value below `start_ttl`.
    pub fn max_hops(self, hops: u8) -> Self {
        self.edit_config(|c| c.max_hops = hops)
    }

    /// Sets the per-probe timeout; `build` rejects anything under 1 ms.
    pub fn probe_timeout(self, timeout: Duration) -> Self {
        self.edit_config(|c| c.probe_timeout = timeout)
    }

    /// Sets the delay between probe sends.
    ///
    /// A later `parallel_probes` call overwrites this value, and vice versa.
    pub fn send_interval(self, interval: Duration) -> Self {
        self.edit_config(|c| c.send_interval = interval)
    }

    /// Sets the overall timeout (not checked by `build`).
    pub fn overall_timeout(self, timeout: Duration) -> Self {
        self.edit_config(|c| c.overall_timeout = timeout)
    }

    /// Sets how many probes are sent per hop; `build` rejects 0.
    pub fn queries_per_hop(self, queries: u8) -> Self {
        self.edit_config(|c| c.queries_per_hop = queries)
    }

    /// Selects a specific probe protocol.
    pub fn protocol(self, protocol: ProbeProtocol) -> Self {
        self.edit_config(|c| c.protocol = Some(protocol))
    }

    /// Selects a specific socket mode.
    pub fn socket_mode(self, mode: SocketMode) -> Self {
        self.edit_config(|c| c.socket_mode = Some(mode))
    }

    /// Sets the destination port.
    pub fn port(self, port: u16) -> Self {
        self.edit_config(|c| c.port = port)
    }

    /// Enables or disables ASN lookup.
    pub fn enable_asn_lookup(self, enable: bool) -> Self {
        self.edit_config(|c| c.enable_asn_lookup = enable)
    }

    /// Enables or disables reverse-DNS lookup.
    pub fn enable_rdns(self, enable: bool) -> Self {
        self.edit_config(|c| c.enable_rdns = enable)
    }

    /// Sets the verbosity level.
    pub fn verbose(self, verbose: u8) -> Self {
        self.edit_config(|c| c.verbose = verbose)
    }

    /// Records a public IP address in the configuration.
    pub fn public_ip(self, ip: IpAddr) -> Self {
        self.edit_config(|c| c.public_ip = Some(ip))
    }

    /// Replaces the whole [`TimingConfig`].
    pub fn timing(self, timing: TimingConfig) -> Self {
        self.edit_config(|c| c.timing = timing)
    }

    /// Validates and returns the finished configuration.
    ///
    /// # Errors
    /// Returns the message produced by `TracerouteConfig::validate`.
    pub fn build(self) -> Result<TracerouteConfig, String> {
        let config = self.config;
        config.validate().map(|()| config)
    }

    /// Alias for [`Self::queries_per_hop`].
    pub fn queries(self, queries: u8) -> Self {
        self.queries_per_hop(queries)
    }

    /// Derives `send_interval` from the requested degree of parallelism.
    /// More parallel probes give a shorter gap between sends:
    ///
    /// | `parallel` | interval |
    /// |------------|----------|
    /// | 0–1        | 50 ms    |
    /// | 2–10       | 20 ms    |
    /// | 11–30      | 10 ms    |
    /// | 31–50      | 5 ms     |
    /// | > 50       | 2 ms     |
    ///
    /// This overwrites any interval set earlier through `send_interval`.
    pub fn parallel_probes(self, parallel: u8) -> Self {
        let gap_ms: u64 = if parallel <= 1 {
            50
        } else if parallel <= 10 {
            20
        } else if parallel <= 30 {
            10
        } else if parallel <= 50 {
            5
        } else {
            2
        };
        self.send_interval(Duration::from_millis(gap_ms))
    }
}
impl Default for TracerouteConfigBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Returns a builder that already has a valid target, so each test only sets
    /// the field it exercises.
    fn with_target() -> TracerouteConfigBuilder {
        TracerouteConfig::builder().target("example.com")
    }

    #[test]
    fn test_default_config() {
        let config = TracerouteConfig::default();
        assert_eq!(config.start_ttl, 1);
        assert_eq!(config.max_hops, 30);
        assert_eq!(config.probe_timeout.as_millis(), 1000);
        assert_eq!(config.queries_per_hop, 1);
        assert!(config.enable_asn_lookup);
        assert!(config.enable_rdns);
        assert!(config.target.is_empty());
        assert!(config.target_ip.is_none());
        assert_eq!(config.send_interval, Duration::from_millis(0));
        assert_eq!(config.overall_timeout.as_millis(), 3000);
        assert!(config.protocol.is_none());
        assert!(config.socket_mode.is_none());
        assert_eq!(config.port, 443);
        assert_eq!(config.verbose, 0);
        assert!(config.public_ip.is_none());
    }

    #[test]
    fn test_config_builder() {
        let config = TracerouteConfig::builder()
            .target("google.com")
            .max_hops(20)
            .probe_timeout(Duration::from_millis(500))
            .queries_per_hop(3)
            .build()
            .unwrap();
        assert_eq!(config.target, "google.com");
        assert_eq!(config.max_hops, 20);
        assert_eq!(config.probe_timeout.as_millis(), 500);
        assert_eq!(config.queries_per_hop, 3);
    }

    #[test]
    fn test_config_validation() {
        let result = TracerouteConfig::builder().build();
        assert_eq!(result.unwrap_err(), "Target must be specified");
        let result = with_target().start_ttl(0).build();
        assert!(result.is_err());
        let result = with_target().start_ttl(10).max_hops(5).build();
        assert!(result.is_err());
        let result = with_target()
            .probe_timeout(Duration::from_millis(0))
            .build();
        assert!(result.is_err());
        let result = with_target().queries_per_hop(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn test_config_validation_accepts_boundaries() {
        // start_ttl equal to max_hops is allowed (the rule is max_hops >= start_ttl).
        let config = with_target().start_ttl(5).max_hops(5).build().unwrap();
        assert_eq!((config.start_ttl, config.max_hops), (5, 5));
        // 1 ms is the smallest accepted probe timeout.
        assert!(with_target()
            .probe_timeout(Duration::from_millis(1))
            .build()
            .is_ok());
        // The full u8 hop range is accepted.
        assert!(with_target().max_hops(u8::MAX).build().is_ok());
    }

    #[test]
    fn test_config_with_ip() {
        let ip = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));
        let config = TracerouteConfig::builder().target_ip(ip).build().unwrap();
        assert_eq!(config.target_ip, Some(ip));
    }

    #[test]
    fn test_queries_alias() {
        let config = with_target().queries(4).build().unwrap();
        assert_eq!(config.queries_per_hop, 4);
    }

    #[test]
    fn test_parallel_probes_sets_send_interval() {
        // Check both edges of every bucket in the mapping.
        let cases: [(u8, u64); 10] = [
            (0, 50),
            (1, 50),
            (2, 20),
            (10, 20),
            (11, 10),
            (30, 10),
            (31, 5),
            (50, 5),
            (51, 2),
            (u8::MAX, 2),
        ];
        for (parallel, expected_ms) in cases {
            let config = with_target().parallel_probes(parallel).build().unwrap();
            assert_eq!(
                config.send_interval,
                Duration::from_millis(expected_ms),
                "parallel = {}",
                parallel
            );
        }
    }
}