use crate::dns::resolver;
use crate::services::Services;
use crate::socket::factory::create_probe_socket_with_options;
use crate::traceroute::engine::TracerouteEngine;
use crate::traceroute::{TracerouteConfig, TracerouteError, TracerouteResult};
use std::net::IpAddr;
/// An asynchronous traceroute to a single IPv4 target.
///
/// Both constructors ([`Traceroute::new`] and [`Traceroute::new_with_services`]) resolve the
/// configured target up front, so a constructed value always carries a concrete IPv4
/// `target_ip`. Call [`Traceroute::run`] to perform the trace.
#[derive(Debug)]
pub struct Traceroute {
    /// The trace configuration; its `target_ip` is filled in during construction.
    config: TracerouteConfig,
    /// The resolved IPv4 address that probes are sent to.
    target_ip: IpAddr,
    /// Services handed to the engine when present; `None` uses `TracerouteEngine::new`.
    services: Option<Services>,
}

impl Traceroute {
    /// Creates a traceroute that passes `services` to the engine when run.
    ///
    /// # Errors
    ///
    /// Returns the same target-resolution errors as [`Traceroute::new`].
    pub async fn new_with_services(
        mut config: TracerouteConfig,
        services: Services,
    ) -> Result<Self, TracerouteError> {
        let target_ip = Self::resolve_target(&mut config).await?;
        Ok(Self {
            config,
            target_ip,
            services: Some(services),
        })
    }

    /// Rejects IPv6 addresses, which this traceroute does not support.
    fn ensure_ipv4(ip: IpAddr) -> Result<IpAddr, TracerouteError> {
        if ip.is_ipv6() {
            Err(TracerouteError::Ipv6NotSupported)
        } else {
            Ok(ip)
        }
    }

    /// Resolves `config.target` to an IPv4 address and records it in `config.target_ip`.
    ///
    /// Resolution order:
    /// 1. `config.target` is an IP literal (IPv6 literals are rejected).
    /// 2. `config.target` is `localhost` (ASCII case-insensitive): `127.0.0.1`, with no lookup.
    /// 3. A caller-supplied `config.target_ip` is used without a lookup (IPv6 is rejected).
    /// 4. Otherwise, the first A record returned by the resolver.
    ///
    /// # Errors
    ///
    /// * [`TracerouteError::Ipv6NotSupported`] for an IPv6 literal or a pre-set IPv6 `target_ip`.
    /// * [`TracerouteError::ResolutionError`] if the lookup fails or yields no addresses.
    async fn resolve_target(config: &mut TracerouteConfig) -> Result<IpAddr, TracerouteError> {
        if let Ok(ip) = config.target.parse::<IpAddr>() {
            let ip = Self::ensure_ipv4(ip)?;
            config.target_ip = Some(ip);
            return Ok(ip);
        }
        // Hostnames are case-insensitive, so "LOCALHOST" must take the same shortcut.
        if config.target.eq_ignore_ascii_case("localhost") {
            let ip = IpAddr::V4(std::net::Ipv4Addr::LOCALHOST);
            config.target_ip = Some(ip);
            return Ok(ip);
        }
        if let Some(ip) = config.target_ip {
            // A pre-resolved address skips DNS, but it must still satisfy the same
            // IPv4-only restriction that literal targets are held to above.
            return Self::ensure_ipv4(ip);
        }
        let addrs = resolver::resolve_a(&config.target)
            .await
            .map_err(|e| TracerouteError::ResolutionError(e.to_string()))?;
        let ip =
            IpAddr::V4(*addrs.first().ok_or_else(|| {
                TracerouteError::ResolutionError("No addresses found".to_string())
            })?);
        config.target_ip = Some(ip);
        Ok(ip)
    }

    /// Creates a traceroute without extra services, resolving the target up front.
    ///
    /// # Errors
    ///
    /// * [`TracerouteError::Ipv6NotSupported`] if the target is (or is pre-resolved to) IPv6.
    /// * [`TracerouteError::ResolutionError`] if the hostname cannot be resolved.
    pub async fn new(mut config: TracerouteConfig) -> Result<Self, TracerouteError> {
        let target_ip = Self::resolve_target(&mut config).await?;
        Ok(Self {
            config,
            target_ip,
            services: None,
        })
    }

    /// Runs the traceroute, consuming `self`.
    ///
    /// When `config.verbose > 0`, this first sets the process-wide `FTR_VERBOSE`
    /// environment variable to the verbosity level.
    ///
    /// # Errors
    ///
    /// Socket-creation errors are propagated with `?`. Errors from constructing or
    /// running the engine are reported as [`TracerouteError::SocketError`] carrying the
    /// original error's text.
    pub async fn run(self) -> Result<TracerouteResult, TracerouteError> {
        // `self` is consumed, so take ownership of its parts and move the config into
        // the engine instead of cloning it.
        let Self {
            config,
            target_ip,
            services,
        } = self;

        // The socket's read timeout follows the per-probe timeout rather than the value
        // carried in `config.timing`.
        let mut timing_config = config.timing.clone();
        timing_config.socket_read_timeout = config.probe_timeout;

        if config.verbose > 0 {
            // NOTE(review): this mutates process-global state from library code;
            // `set_var` is not safe alongside concurrent environment access on some
            // platforms. Consider passing verbosity explicitly instead.
            std::env::set_var("FTR_VERBOSE", config.verbose.to_string());
        }

        let socket = create_probe_socket_with_options(
            target_ip,
            timing_config,
            config.protocol,
            config.socket_mode,
        )
        .await?;

        let engine = match services {
            Some(services) => TracerouteEngine::new_with_services(
                socket,
                config,
                target_ip,
                std::sync::Arc::new(services),
            )
            .await
            .map_err(|e| TracerouteError::SocketError(e.to_string()))?,
            None => TracerouteEngine::new(socket, config, target_ip)
                .await
                .map_err(|e| TracerouteError::SocketError(e.to_string()))?,
        };

        engine
            .run()
            .await
            .map_err(|e| TracerouteError::SocketError(e.to_string()))
    }
}
/// Traces the route to `target` using an otherwise default configuration.
///
/// `target` may be an IP literal or a hostname; it is resolved by [`Traceroute::new`].
///
/// # Errors
///
/// Returns [`TracerouteError::ConfigError`] if the configuration cannot be built, and
/// otherwise any error produced by [`trace_with_config_async`].
pub async fn trace_async(target: &str) -> Result<TracerouteResult, TracerouteError> {
    let built = TracerouteConfig::builder().target(target).build();
    match built {
        Ok(config) => trace_with_config_async(config).await,
        Err(err) => Err(TracerouteError::ConfigError(err)),
    }
}
/// Resolves the target described by `config` and runs a traceroute against it.
///
/// # Errors
///
/// Returns any resolution error from [`Traceroute::new`], or any error from
/// [`Traceroute::run`].
pub async fn trace_with_config_async(
    config: TracerouteConfig,
) -> Result<TracerouteResult, TracerouteError> {
    Traceroute::new(config).await?.run().await
}
pub(crate) async fn trace_with_services(
config: TracerouteConfig,
services: &Services,
) -> Result<TracerouteResult, TracerouteError> {
let traceroute = Traceroute::new_with_services(config, services.clone()).await?;
traceroute.run().await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Restores an environment variable to its captured value when dropped, so a failing
    /// assertion cannot leak modified state into other tests in the same process.
    struct EnvVarGuard {
        key: &'static str,
        original: Option<String>,
    }

    impl EnvVarGuard {
        /// Records the current value (or absence) of `key`.
        fn capture(key: &'static str) -> Self {
            Self {
                key,
                original: std::env::var(key).ok(),
            }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.original.take() {
                Some(val) => std::env::set_var(self.key, val),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[tokio::test]
    async fn test_async_traceroute_creation() {
        let config = TracerouteConfig::builder()
            .target("127.0.0.1")
            .build()
            .unwrap();
        let result = Traceroute::new(config).await;
        assert!(result.is_ok());
        let traceroute = result.unwrap();
        assert_eq!(
            traceroute.target_ip,
            IpAddr::V4("127.0.0.1".parse().unwrap())
        );
    }

    #[tokio::test]
    async fn test_async_traceroute_ipv6_error() {
        let config = TracerouteConfig::builder()
            .target("::1")
            .target_ip(IpAddr::V6("::1".parse().unwrap()))
            .build()
            .unwrap();
        match Traceroute::new(config).await {
            Err(TracerouteError::Ipv6NotSupported) => {}
            other => panic!("Expected IPv6 not supported error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_async_traceroute_with_ip() {
        let config = TracerouteConfig::builder()
            .target("8.8.8.8")
            .target_ip(IpAddr::V4("8.8.8.8".parse().unwrap()))
            .build()
            .unwrap();
        let result = Traceroute::new(config).await;
        assert!(result.is_ok());
        let traceroute = result.unwrap();
        assert_eq!(traceroute.target_ip, IpAddr::V4("8.8.8.8".parse().unwrap()));
    }

    #[tokio::test]
    async fn test_trace_async_localhost() {
        match trace_async("127.0.0.1").await {
            Ok(trace_result) => {
                assert_eq!(trace_result.target, "127.0.0.1");
            }
            // Acceptable when the test environment cannot open probe sockets.
            Err(
                TracerouteError::InsufficientPermissions { .. } | TracerouteError::SocketError(_),
            ) => {}
            Err(e) => {
                panic!("Unexpected error: {:?}", e);
            }
        }
    }

    #[tokio::test]
    async fn test_trace_with_config_async() {
        let config = TracerouteConfig::builder()
            .target("127.0.0.1")
            .max_hops(3)
            .probe_timeout(Duration::from_millis(100))
            .build()
            .unwrap();
        match trace_with_config_async(config).await {
            Ok(trace_result) => {
                assert_eq!(trace_result.target, "127.0.0.1");
                assert!(trace_result.hops.len() <= 3);
            }
            // Acceptable when the test environment cannot open probe sockets.
            Err(
                TracerouteError::InsufficientPermissions { .. } | TracerouteError::SocketError(_),
            ) => {}
            Err(e) => {
                panic!("Unexpected error: {:?}", e);
            }
        }
    }

    #[tokio::test]
    async fn test_async_traceroute_hostname_resolution() {
        let config = TracerouteConfig::builder()
            .target("localhost")
            .build()
            .unwrap();
        let result = Traceroute::new(config).await;
        assert!(result.is_ok());
        let traceroute = result.unwrap();
        assert_eq!(
            traceroute.target_ip,
            IpAddr::V4("127.0.0.1".parse().unwrap())
        );
    }

    #[tokio::test]
    async fn test_async_traceroute_invalid_hostname() {
        let config = TracerouteConfig::builder()
            .target("this.hostname.definitely.does.not.exist.invalid")
            .build()
            .unwrap();
        match Traceroute::new(config).await {
            Err(TracerouteError::ResolutionError(_)) => {}
            other => panic!("Expected resolution error, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_verbose_environment_setting() {
        let config = TracerouteConfig::builder()
            .target("127.0.0.1")
            .verbose(2)
            .max_hops(1)
            .probe_timeout(Duration::from_millis(100))
            .build()
            .unwrap();
        let traceroute = Traceroute::new(config).await.unwrap();
        // Restores FTR_VERBOSE on scope exit, including when the assertion below panics.
        let _restore = EnvVarGuard::capture("FTR_VERBOSE");
        let _ = tokio::time::timeout(Duration::from_secs(5), traceroute.run()).await;
        assert_eq!(std::env::var("FTR_VERBOSE").ok(), Some("2".to_string()));
    }
}