use crate::services::Services;
use crate::traceroute::AsnInfo;
use crate::traceroute::TracerouteError;
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use tokio::task::JoinSet;
/// Outcome of enriching one IP address with hostname and ASN data.
///
/// Both lookups are best-effort: a lookup that returns an error is recorded
/// as `None` rather than failing the whole enrichment.
#[derive(Debug, Clone)]
pub struct EnrichmentResult {
/// The address that was looked up; also used as the key in result maps.
pub addr: IpAddr,
/// Hostname returned by the `rdns` lookup service, or `None` if that lookup
/// returned an error.
pub hostname: Option<String>,
/// ASN details returned by the `asn` lookup service, or `None` if that
/// lookup returned an error.
pub asn_info: Option<AsnInfo>,
}
/// Enriches IP addresses with hostname and ASN information, either for an
/// explicit batch of addresses or for addresses pushed onto an internal queue.
pub struct EnrichmentService {
/// Shared lookup services; cloned (by `Arc`) into every spawned lookup task.
services: Arc<Services>,
/// Addresses that have already been passed to `enqueue`, used to make sure
/// each address is queued at most once.
seen_addresses: Arc<RwLock<HashSet<IpAddr>>>,
/// Sending half of the enrichment queue, fed by `enqueue`.
enrichment_tx: mpsc::UnboundedSender<IpAddr>,
/// Receiving half of the enrichment queue. It sits behind a lock so that
/// `start_background_enrichment` can obtain mutable access to it through a
/// shared `Arc<Self>`.
enrichment_rx: Arc<RwLock<mpsc::UnboundedReceiver<IpAddr>>>,
}
impl EnrichmentService {
pub async fn new() -> Result<Self, TracerouteError> {
let services = Arc::new(Services::new());
let (enrichment_tx, enrichment_rx) = mpsc::unbounded_channel();
Ok(Self {
services,
seen_addresses: Arc::new(RwLock::new(HashSet::new())),
enrichment_tx,
enrichment_rx: Arc::new(RwLock::new(enrichment_rx)),
})
}
pub async fn enqueue(&self, addr: IpAddr) -> Result<(), TracerouteError> {
let mut seen = self.seen_addresses.write().await;
if seen.insert(addr) {
self.enrichment_tx
.send(addr)
.map_err(|e| TracerouteError::Other(e.to_string()))?;
}
Ok(())
}
pub async fn start_background_enrichment(self: Arc<Self>) -> HashMap<IpAddr, EnrichmentResult> {
let mut results = HashMap::new();
let mut enrichment_futures = JoinSet::new();
let mut rx = self.enrichment_rx.write().await;
loop {
tokio::select! {
Some(addr) = rx.recv() => {
let services = Arc::clone(&self.services);
enrichment_futures.spawn(async move {
let dns_future = services.rdns.lookup(addr);
let asn_future = services.asn.lookup(addr);
let (hostname_result, asn_result) = tokio::join!(dns_future, asn_future);
let hostname = hostname_result.ok();
let asn_info = asn_result.ok();
EnrichmentResult {
addr,
hostname,
asn_info,
}
});
}
Some(Ok(result)) = enrichment_futures.join_next() => {
results.insert(result.addr, result);
}
else => {
if enrichment_futures.is_empty() {
break;
}
}
}
}
results
}
pub async fn enrich_addresses(
&self,
addresses: Vec<IpAddr>,
) -> HashMap<IpAddr, EnrichmentResult> {
let mut enrichment_futures = JoinSet::new();
for addr in addresses {
let services = Arc::clone(&self.services);
enrichment_futures.spawn(async move {
let dns_future = services.rdns.lookup(addr);
let asn_future = services.asn.lookup(addr);
let (hostname_result, asn_result) = tokio::join!(dns_future, asn_future);
let hostname = hostname_result.ok();
let asn_info = asn_result.ok();
EnrichmentResult {
addr,
hostname,
asn_info,
}
});
}
let mut results = HashMap::new();
while let Some(Ok(result)) = enrichment_futures.join_next().await {
results.insert(result.addr, result);
}
results
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Constructing the service reports success.
    #[tokio::test]
    async fn test_enrichment_service_creation() {
        assert!(EnrichmentService::new().await.is_ok());
    }

    /// Enqueuing the same address twice succeeds both times but records it once.
    #[tokio::test]
    async fn test_enqueue_deduplication() {
        let service = Arc::new(EnrichmentService::new().await.unwrap());
        let target = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));

        for _ in 0..2 {
            assert!(service.enqueue(target).await.is_ok());
        }

        let seen = service.seen_addresses.read().await;
        assert_eq!(seen.len(), 1);
        assert!(seen.contains(&target));
    }

    /// A batch enrichment yields one entry per input address.
    ///
    /// NOTE(review): presumably exercises the real lookup services — confirm
    /// whether this needs network access.
    #[tokio::test]
    async fn test_enrich_addresses() {
        let google = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));
        let cloudflare = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
        let localhost = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
        let addresses = vec![google, cloudflare, localhost];

        let service = EnrichmentService::new().await.unwrap();
        let deadline = std::time::Duration::from_secs(30);
        let results = tokio::time::timeout(deadline, service.enrich_addresses(addresses.clone()))
            .await
            .expect("Enrichment timed out");

        assert_eq!(results.len(), addresses.len());
        for addr in &addresses {
            assert!(results.contains_key(addr));
            assert_eq!(results[addr].addr, *addr);
        }

        // The localhost hostname is only read, never asserted on.
        let _ = results[&localhost].hostname.as_ref();
        assert!(results[&google].asn_info.is_some());
    }

    /// A hand-built result exposes exactly the values it was built with.
    #[tokio::test]
    async fn test_enrichment_result_fields() {
        let addr = IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8));
        let asn_info = AsnInfo {
            asn: 15169,
            name: "GOOGLE".to_string(),
            prefix: "8.8.8.0/24".to_string(),
            country_code: "US".to_string(),
            registry: "arin".to_string(),
        };
        let result = EnrichmentResult {
            addr,
            hostname: Some("dns.google".to_string()),
            asn_info: Some(asn_info),
        };

        assert_eq!(result.addr, addr);
        assert_eq!(result.hostname, Some("dns.google".to_string()));
        assert!(result.asn_info.is_some());

        let asn = result.asn_info.unwrap();
        assert_eq!(asn.asn, 15169);
        assert_eq!(asn.name, "GOOGLE");
    }

    /// Queued addresses are enriched by the background loop once the queue's
    /// sender has been dropped.
    #[tokio::test]
    async fn test_background_enrichment() {
        let first: IpAddr = "8.8.8.8".parse().unwrap();
        let second: IpAddr = "1.1.1.1".parse().unwrap();

        let mut service = EnrichmentService::new().await.unwrap();
        service.enqueue(first).await.unwrap();
        service.enqueue(second).await.unwrap();

        // Swap in a sender from an unrelated channel and drop the real one so
        // the queue closes and the background loop can finish.
        let (detached_tx, _) = mpsc::unbounded_channel::<IpAddr>();
        drop(std::mem::replace(&mut service.enrichment_tx, detached_tx));

        let service = Arc::new(service);
        let deadline = std::time::Duration::from_secs(30);
        let results = tokio::time::timeout(deadline, service.start_background_enrichment())
            .await
            .expect("Background enrichment timed out");

        for addr in &[first, second] {
            assert!(results.contains_key(addr));
        }
    }

    /// An IPv6 address yields a result entry without ASN information.
    #[tokio::test]
    async fn test_ipv6_enrichment() {
        let ipv6_addr: IpAddr = "2001:4860:4860::8888".parse().unwrap();
        let service = EnrichmentService::new().await.unwrap();

        let results = service.enrich_addresses(vec![ipv6_addr]).await;

        assert!(results.contains_key(&ipv6_addr));
        let entry = &results[&ipv6_addr];
        assert!(entry.asn_info.is_none());
        assert_eq!(entry.addr, ipv6_addr);
    }

    /// Private-range addresses are reported with ASN 0 and the
    /// "Private Network" name.
    #[tokio::test]
    async fn test_private_ip_enrichment() {
        let private_addrs = vec![
            IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)),
            IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
            IpAddr::V4(Ipv4Addr::new(172, 16, 0, 1)),
        ];
        let service = EnrichmentService::new().await.unwrap();

        let results = service.enrich_addresses(private_addrs.clone()).await;

        assert_eq!(results.len(), private_addrs.len());
        for addr in &private_addrs {
            let entry = &results[addr];
            assert!(entry.asn_info.is_some());
            let asn_info = entry.asn_info.as_ref().unwrap();
            assert_eq!(asn_info.asn, 0);
            assert_eq!(asn_info.name, "Private Network");
        }
    }
}