use std::collections::HashSet;
use std::fs;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, ToSocketAddrs};
use std::path::Path;
use std::str::FromStr;
use ipnet::{Ipv4Net, Ipv6Net};
use rand::Rng;
use thiserror::Error;
/// Upper bound on the number of addresses a single target token (CIDR block or
/// octet range) may expand to; larger expansions are rejected as invalid.
const MAX_HOSTS_PER_TARGET: usize = 65_536;
/// Options controlling how target tokens are parsed and resolved.
#[derive(Debug, Clone)]
pub struct ExpandOpts {
/// Enables IPv6 literal/CIDR parsing and makes hostname resolution keep only
/// IPv6 results; when `false`, resolution keeps only IPv4 results.
pub ipv6: bool,
/// When set, `expand_target` rejects any token that is not a numeric
/// address, CIDR block, or octet range instead of resolving it.
pub no_dns: bool,
/// When `false`, hostname resolution returns at most one address (the lowest
/// after sorting); when `true`, every distinct matching address is returned.
pub resolve_all: bool,
/// Explicit name servers for the async resolver. If empty, the async path
/// uses `tokio::net::lookup_host` instead.
pub dns_servers: Vec<IpAddr>,
}
/// Errors produced while reading, parsing, or resolving scan targets.
#[derive(Debug, Error)]
pub enum TargetError {
/// The target specification was rejected. The payload is either the
/// offending token or a human-readable explanation (oversized expansion,
/// empty list file, unreadable exclude file, ...).
#[error("invalid target: {0}")]
Invalid(String),
/// Hostname resolution failed: `(host, reason)`.
#[error("DNS resolution failed for {0}: {1}")]
Dns(String, String),
/// An underlying I/O failure, converted automatically via `?`.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
}
/// Reads a target list file and returns its entries, one per line.
///
/// Each line is trimmed; blank lines and lines starting with `#` are skipped.
///
/// # Errors
///
/// * [`TargetError::Io`] if the file cannot be read.
/// * [`TargetError::Invalid`] if no usable entry remains after filtering.
pub fn read_input_list(path: &Path) -> Result<Vec<String>, TargetError> {
    let contents = fs::read_to_string(path)?;

    let entries: Vec<String> = contents
        .lines()
        .map(str::trim)
        .filter(|entry| !(entry.is_empty() || entry.starts_with('#')))
        .map(String::from)
        .collect();

    if entries.is_empty() {
        Err(TargetError::Invalid("empty -iL file".into()))
    } else {
        Ok(entries)
    }
}
/// Generates `count` random addresses of the requested family.
///
/// IPv4 addresses are drawn uniformly from the whole 32-bit space. IPv6
/// addresses are 16 random bytes with the top nibble forced to `0x2`, so they
/// always fall inside `2000::/4`.
///
/// The initial allocation is capped at 65 536 entries; the vector still grows
/// to hold all `count` addresses.
pub fn random_addresses(count: u64, ipv6: bool) -> Vec<IpAddr> {
    let mut rng = rand::thread_rng();
    let mut addrs = Vec::with_capacity(count.min(65_536) as usize);

    for _ in 0..count {
        let addr = match ipv6 {
            true => {
                let mut octets = [0u8; 16];
                rng.fill(&mut octets);
                // Keep the low nibble random, pin the high nibble to 0x2.
                octets[0] = (octets[0] & 0x0f) | 0x20;
                IpAddr::V6(Ipv6Addr::from(octets))
            }
            false => IpAddr::V4(Ipv4Addr::from(rng.gen::<u32>())),
        };
        addrs.push(addr);
    }

    addrs
}
/// Expands one target token into the list of addresses it denotes.
///
/// The token is trimmed and then matched against these forms, in order:
///
/// 1. IPv4 CIDR (`10.0.0.0/24`) — expanded via `Ipv4Net::hosts`.
/// 2. IPv6 CIDR (only when `opts.ipv6` is set).
/// 3. A single IPv4 literal.
/// 4. A single IPv6 literal (only when `opts.ipv6` is set).
/// 5. A dotted IPv4 octet range such as `192.0.2.1-20` or `10.0-3.0.1`.
/// 6. Anything else is treated as a hostname and resolved, unless
///    `opts.no_dns` is set.
///
/// # Errors
///
/// * [`TargetError::Invalid`] for an empty token, a CIDR block or range that
///   expands to more than `MAX_HOSTS_PER_TARGET` addresses, a token containing
///   `/` that did not parse as IPv4 CIDR while `opts.ipv6` is off, a malformed
///   octet range, or a non-numeric token when `opts.no_dns` is set.
/// * [`TargetError::Dns`] when hostname resolution fails.
pub async fn expand_target(token: &str, opts: &ExpandOpts) -> Result<Vec<IpAddr>, TargetError> {
    let token = token.trim();
    if token.is_empty() {
        return Err(TargetError::Invalid(token.to_string()));
    }

    if let Ok(net) = Ipv4Net::from_str(token) {
        // Size the block arithmetically BEFORE materialising it. Collecting
        // first would allocate billions of addresses for e.g. `0.0.0.0/0`
        // just to reject them afterwards.
        let prefix = u32::from(net.prefix_len());
        let block_size = 1u64 << (32 - prefix);
        // The test suite pins hosts() to 2 for /30 and /31 and 1 for /32, so
        // prefixes below 31 lose two addresses from the block.
        let host_count = if prefix >= 31 {
            block_size
        } else {
            block_size - 2
        };
        if host_count > MAX_HOSTS_PER_TARGET as u64 {
            return Err(TargetError::Invalid(format!(
                "CIDR {token} expands to {host_count} hosts (> {MAX_HOSTS_PER_TARGET})"
            )));
        }
        return Ok(net.hosts().map(IpAddr::V4).collect());
    }

    if opts.ipv6 {
        if let Ok(net) = Ipv6Net::from_str(token) {
            // Pull at most one element past the limit so oversized blocks are
            // detected without enumerating them.
            let hosts: Vec<Ipv6Addr> = net.hosts().take(MAX_HOSTS_PER_TARGET + 1).collect();
            if hosts.len() > MAX_HOSTS_PER_TARGET {
                return Err(TargetError::Invalid(format!(
                    "IPv6 CIDR {token} expands to too many hosts (> {MAX_HOSTS_PER_TARGET})"
                )));
            }
            return Ok(hosts.into_iter().map(IpAddr::V6).collect());
        }
    }

    if token.contains('/') && !opts.ipv6 {
        return Err(TargetError::Invalid(
            "IPv6 CIDR requires -6 (or use IPv4 CIDR)".into(),
        ));
    }

    if let Ok(ip) = Ipv4Addr::from_str(token) {
        return Ok(vec![IpAddr::V4(ip)]);
    }
    if opts.ipv6 {
        if let Ok(ip) = Ipv6Addr::from_str(token) {
            return Ok(vec![IpAddr::V6(ip)]);
        }
    }

    // Octet-range syntax consists solely of digits, dots and hyphens. Requiring
    // that keeps hostnames which merely happen to contain three dots and a
    // hyphen (e.g. `web-01.prod.example.com`) on the DNS path instead of
    // rejecting them as malformed ranges.
    let range_shaped = token
        .bytes()
        .all(|b| b.is_ascii_digit() || b == b'.' || b == b'-')
        && token.bytes().filter(|&b| b == b'.').count() == 3
        && token.contains('-');
    if range_shaped {
        return expand_ipv4_ranges(token).map(|v| v.into_iter().map(IpAddr::V4).collect());
    }

    if opts.no_dns {
        return Err(TargetError::Invalid(
            "numeric IP or CIDR required when -n is set".into(),
        ));
    }
    resolve_host(token, opts).await
}
fn expand_ipv4_ranges(spec: &str) -> Result<Vec<Ipv4Addr>, TargetError> {
let parts: Vec<&str> = spec.split('.').collect();
if parts.len() != 4 {
return Err(TargetError::Invalid(spec.to_string()));
}
let mut octets: [Vec<u8>; 4] = [vec![], vec![], vec![], vec![]];
for (i, p) in parts.iter().enumerate() {
octets[i] = expand_octet(p)?;
}
let mut out = Vec::new();
for a in &octets[0] {
for b in &octets[1] {
for c in &octets[2] {
for d in &octets[3] {
out.push(Ipv4Addr::new(*a, *b, *c, *d));
if out.len() > MAX_HOSTS_PER_TARGET {
return Err(TargetError::Invalid(format!(
"range {spec} expands to > {MAX_HOSTS_PER_TARGET} hosts"
)));
}
}
}
}
}
Ok(out)
}
/// Parses one dotted-quad field: either a single octet (`"42"`) or an
/// inclusive range (`"1-3"`), returning every value it covers in ascending
/// order.
///
/// # Errors
///
/// Returns [`TargetError::Invalid`] carrying the original field when it is
/// neither form, when either bound is not a valid `u8`, or when the range is
/// reversed (`start > end`).
fn expand_octet(part: &str) -> Result<Vec<u8>, TargetError> {
    let invalid = || TargetError::Invalid(part.to_string());

    if let Ok(single) = part.parse::<u8>() {
        return Ok(vec![single]);
    }

    match part.split_once('-') {
        Some((lo, hi)) => match (lo.parse::<u8>(), hi.parse::<u8>()) {
            (Ok(lo), Ok(hi)) if lo <= hi => Ok((lo..=hi).collect()),
            _ => Err(invalid()),
        },
        None => Err(invalid()),
    }
}
/// Resolves `host` asynchronously and returns the addresses matching the
/// requested family.
///
/// With no entries in `opts.dns_servers` the lookup goes through
/// `tokio::net::lookup_host`; otherwise a hickory resolver is built around
/// the listed servers on port 53. Results are filtered to IPv6 or IPv4
/// according to `opts.ipv6`, sorted, and de-duplicated. Unless
/// `opts.resolve_all` is set, only the first (lowest) address is kept.
///
/// # Errors
///
/// Returns [`TargetError::Dns`] if the lookup fails or yields no address of
/// the requested family.
async fn resolve_host(host: &str, opts: &ExpandOpts) -> Result<Vec<IpAddr>, TargetError> {
    /// Wraps any displayable lookup failure in the crate's DNS error.
    fn dns_failure(host: &str, cause: impl ToString) -> TargetError {
        TargetError::Dns(host.to_string(), cause.to_string())
    }

    let mut addrs: Vec<IpAddr> = if !opts.dns_servers.is_empty() {
        use hickory_resolver::config::{NameServerConfigGroup, ResolverConfig};
        use hickory_resolver::name_server::TokioConnectionProvider;
        use hickory_resolver::Resolver;

        // NOTE(review): the trailing `true` looks like hickory's
        // trust-negative-responses flag — confirm against the crate docs.
        let servers = NameServerConfigGroup::from_ips_clear(&opts.dns_servers, 53, true);
        let config = ResolverConfig::from_parts(None, vec![], servers);
        let resolver =
            Resolver::builder_with_config(config, TokioConnectionProvider::default()).build();
        let answer = resolver
            .lookup_ip(host)
            .await
            .map_err(|e| dns_failure(host, e))?;
        answer.iter().collect()
    } else {
        let sockets = tokio::net::lookup_host((host, 0))
            .await
            .map_err(|e| dns_failure(host, e))?;
        sockets.map(|sock| sock.ip()).collect()
    };

    // An address is exactly one of v4/v6, so a single comparison selects the
    // requested family.
    addrs.retain(|ip| ip.is_ipv6() == opts.ipv6);
    addrs.sort_unstable();
    addrs.dedup();

    if addrs.is_empty() {
        return Err(TargetError::Dns(
            host.to_string(),
            "no matching addresses for this address family".into(),
        ));
    }
    if !opts.resolve_all {
        addrs.truncate(1);
    }
    Ok(addrs)
}
/// Resolves `host` synchronously through the standard library's
/// `ToSocketAddrs` and returns the addresses matching the requested family.
///
/// Results are filtered to IPv6 or IPv4 according to `opts.ipv6` and returned
/// sorted without duplicates. Unless `opts.resolve_all` is set, only the
/// first (lowest) address is kept. `opts.dns_servers` is not consulted here.
///
/// # Errors
///
/// Returns [`TargetError::Dns`] if the lookup fails or yields no address of
/// the requested family.
pub fn resolve_host_blocking(host: &str, opts: &ExpandOpts) -> Result<Vec<IpAddr>, TargetError> {
    let sockets = (host, 0)
        .to_socket_addrs()
        .map_err(|e| TargetError::Dns(host.to_string(), e.to_string()))?;

    // An ordered set gives the sorted, duplicate-free view in one step.
    let unique: std::collections::BTreeSet<IpAddr> = sockets
        .map(|sock| sock.ip())
        .filter(|ip| ip.is_ipv6() == opts.ipv6)
        .collect();

    if unique.is_empty() {
        return Err(TargetError::Dns(
            host.to_string(),
            "no matching addresses for this address family".into(),
        ));
    }

    let keep = if opts.resolve_all { unique.len() } else { 1 };
    Ok(unique.into_iter().take(keep).collect())
}
/// Removes excluded addresses from `hosts`, preserving the order of the rest.
///
/// Exclusions come from two optional sources, processed in this order:
///
/// * `exclude` — a comma-separated list of target tokens; each piece is
///   trimmed and empty pieces are ignored.
/// * `exclude_file` — a file with one target token per line; lines are
///   trimmed, and blank lines and lines starting with `#` are ignored.
///
/// Every token is expanded with `expand_target_blocking`, so it may be an
/// address, a CIDR block, an octet range, or a hostname.
///
/// # Errors
///
/// Returns [`TargetError::Invalid`] if the exclude file cannot be read, and
/// propagates any error from expanding an exclusion token.
pub fn apply_exclude(
    hosts: Vec<IpAddr>,
    exclude: Option<&str>,
    exclude_file: Option<&std::path::Path>,
    opts: &ExpandOpts,
) -> Result<Vec<IpAddr>, TargetError> {
    let mut banned: HashSet<IpAddr> = HashSet::new();

    let inline_tokens = exclude
        .into_iter()
        .flat_map(|csv| csv.split(','))
        .map(str::trim)
        .filter(|piece| !piece.is_empty());
    for piece in inline_tokens {
        banned.extend(expand_target_blocking(piece, opts)?);
    }

    if let Some(path) = exclude_file {
        let contents = fs::read_to_string(path)
            .map_err(|e| TargetError::Invalid(format!("excludefile: {e}")))?;
        let entries = contents
            .lines()
            .map(str::trim)
            .filter(|entry| !entry.is_empty() && !entry.starts_with('#'));
        for entry in entries {
            banned.extend(expand_target_blocking(entry, opts)?);
        }
    }

    let mut remaining = hosts;
    remaining.retain(|addr| !banned.contains(addr));
    Ok(remaining)
}
/// Synchronous counterpart of `expand_target`, used when building exclusion
/// sets.
///
/// The token is trimmed and then matched against these forms, in order:
/// IPv4 CIDR, IPv6 CIDR (only with `opts.ipv6`), IPv4 literal, IPv6 literal
/// (only with `opts.ipv6`), dotted IPv4 octet range, and finally a hostname
/// resolved through `resolve_host_blocking`.
///
/// NOTE(review): unlike `expand_target`, this path does not consult
/// `opts.no_dns` and has no dedicated error for a `/` token that fails to
/// parse; such tokens fall through to the resolver. Confirm whether that
/// difference is intended.
///
/// # Errors
///
/// * [`TargetError::Invalid`] for an empty token, a CIDR block or range that
///   expands to more than `MAX_HOSTS_PER_TARGET` addresses, or a malformed
///   octet range.
/// * [`TargetError::Dns`] when hostname resolution fails.
fn expand_target_blocking(token: &str, opts: &ExpandOpts) -> Result<Vec<IpAddr>, TargetError> {
    let token = token.trim();
    if token.is_empty() {
        return Err(TargetError::Invalid(token.to_string()));
    }

    if let Ok(net) = Ipv4Net::from_str(token) {
        // Size the block arithmetically BEFORE materialising it. Collecting
        // first would allocate billions of addresses for e.g. `0.0.0.0/0`
        // just to reject them afterwards.
        let prefix = u32::from(net.prefix_len());
        let block_size = 1u64 << (32 - prefix);
        // The test suite pins hosts() to 2 for /30 and /31 and 1 for /32, so
        // prefixes below 31 lose two addresses from the block.
        let host_count = if prefix >= 31 {
            block_size
        } else {
            block_size - 2
        };
        if host_count > MAX_HOSTS_PER_TARGET as u64 {
            return Err(TargetError::Invalid(format!(
                "CIDR {token} expands to {host_count} hosts (> {MAX_HOSTS_PER_TARGET})"
            )));
        }
        return Ok(net.hosts().map(IpAddr::V4).collect());
    }

    if opts.ipv6 {
        if let Ok(net) = Ipv6Net::from_str(token) {
            // Pull at most one element past the limit so oversized blocks are
            // detected without enumerating them.
            let hosts: Vec<Ipv6Addr> = net.hosts().take(MAX_HOSTS_PER_TARGET + 1).collect();
            if hosts.len() > MAX_HOSTS_PER_TARGET {
                return Err(TargetError::Invalid(format!(
                    "IPv6 CIDR {token} expands to too many hosts (> {MAX_HOSTS_PER_TARGET})"
                )));
            }
            return Ok(hosts.into_iter().map(IpAddr::V6).collect());
        }
    }

    if let Ok(ip) = Ipv4Addr::from_str(token) {
        return Ok(vec![IpAddr::V4(ip)]);
    }
    if opts.ipv6 {
        if let Ok(ip) = Ipv6Addr::from_str(token) {
            return Ok(vec![IpAddr::V6(ip)]);
        }
    }

    // Octet-range syntax consists solely of digits, dots and hyphens. Requiring
    // that keeps hostnames which merely happen to contain three dots and a
    // hyphen (e.g. `web-01.prod.example.com`) on the DNS path instead of
    // rejecting them as malformed ranges.
    let range_shaped = token
        .bytes()
        .all(|b| b.is_ascii_digit() || b == b'.' || b == b'-')
        && token.bytes().filter(|&b| b == b'.').count() == 3
        && token.contains('-');
    if range_shaped {
        return expand_ipv4_ranges(token).map(|v| v.into_iter().map(IpAddr::V4).collect());
    }

    resolve_host_blocking(token, opts)
}
// Unit tests for target parsing, list reading, random address generation and
// exclusion handling. Every test uses numeric targets with DNS disabled, so
// `expand_target` never reaches the resolver.
#[cfg(test)]
mod tests {
use std::io::Write;
use std::net::{IpAddr, Ipv4Addr};
use tempfile::NamedTempFile;
use super::*;
// IPv4-only options with `no_dns` set: non-numeric tokens passed to
// `expand_target` are rejected instead of being resolved.
fn opts_no_dns_v4() -> ExpandOpts {
ExpandOpts {
ipv6: false,
no_dns: true,
resolve_all: false,
dns_servers: vec![],
}
}
// A /31 yields both of its addresses. This test drives the future with a
// hand-built runtime; the other async tests use `#[tokio::test]`.
#[test]
fn cidr_expands_v4() {
let rt = tokio::runtime::Runtime::new().unwrap();
let ips = rt
.block_on(expand_target("10.0.0.0/31", &opts_no_dns_v4()))
.unwrap();
assert_eq!(ips.len(), 2);
}
// Comment lines and blank lines are dropped; surviving entries are trimmed.
#[test]
fn read_input_list_skips_comments_and_blanks() {
let mut f = NamedTempFile::new().unwrap();
writeln!(f, "# skip").unwrap();
writeln!(f).unwrap();
writeln!(f, " 10.0.0.1 ").unwrap();
writeln!(f, "10.0.0.2").unwrap();
f.flush().unwrap();
let lines = read_input_list(f.path()).unwrap();
assert_eq!(lines, vec!["10.0.0.1", "10.0.0.2"]);
}
// A file with no usable entries is an error rather than an empty list.
#[test]
fn read_input_list_empty_file_errors() {
let f = NamedTempFile::new().unwrap();
let err = read_input_list(f.path()).unwrap_err();
assert!(err.to_string().contains("empty") || err.to_string().contains("-iL"));
}
// A range in the last octet expands to each address in the range.
#[tokio::test]
async fn expand_ipv4_octet_range_last_octet() {
let ips = expand_target("192.0.2.1-2", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 2);
assert!(ips.contains(&IpAddr::V4(Ipv4Addr::new(192, 0, 2, 1))));
assert!(ips.contains(&IpAddr::V4(Ipv4Addr::new(192, 0, 2, 2))));
}
// A whitespace-only token trims down to empty and is rejected.
#[tokio::test]
async fn expand_empty_token_errors() {
let e = expand_target(" ", &opts_no_dns_v4()).await.unwrap_err();
assert!(matches!(e, TargetError::Invalid(_)), "{e:?}");
}
// With `ipv6` off, a token containing `/` that is not IPv4 CIDR is rejected
// with a message pointing at the -6 flag.
#[tokio::test]
async fn ipv6_cidr_without_dash6_flag_errors() {
let e = expand_target("2001:db8::/126", &opts_no_dns_v4())
.await
.unwrap_err();
let s = e.to_string();
assert!(s.contains("-6") || s.contains("IPv6"), "unexpected: {s}");
}
// Excluding one literal address removes exactly that host.
#[test]
fn apply_exclude_single_ip() {
let hosts = vec![
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
];
let out = apply_exclude(hosts, Some("10.0.0.1"), None, &opts_no_dns_v4()).unwrap();
assert_eq!(out, vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2))]);
}
// The requested count and address family are honoured for both families.
#[test]
fn random_addresses_count_and_family() {
let v4 = random_addresses(7, false);
assert_eq!(v4.len(), 7);
assert!(v4.iter().all(|a| a.is_ipv4()));
let v6 = random_addresses(4, true);
assert_eq!(v6.len(), 4);
assert!(v6.iter().all(|a| a.is_ipv6()));
}
// A plain IPv4 literal expands to itself.
#[tokio::test]
async fn expand_single_ipv4_literal() {
let ips = expand_target("192.0.2.55", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::new(192, 0, 2, 55))]);
}
// IPv6 literals are accepted once the `ipv6` option is enabled.
#[tokio::test]
async fn expand_ipv6_with_dash6_flag() {
let opts = ExpandOpts {
ipv6: true,
no_dns: true,
resolve_all: false,
dns_servers: vec![],
};
let ips = expand_target("2001:db8::1", &opts).await.unwrap();
assert_eq!(ips.len(), 1);
assert!(ips[0].is_ipv6());
}
// NOTE(review): 10.0.0.0/31 covers .0 and .1, yet the expected output still
// contains 10.0.0.2 and 10.0.0.3 only because 10.0.0.1 was removed; the test
// name says "block" but only one listed host falls inside the excluded /31.
#[test]
fn apply_exclude_cidr_removes_block() {
let hosts = vec![
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 3)),
];
let out = apply_exclude(hosts, Some("10.0.0.0/31"), None, &opts_no_dns_v4()).unwrap();
assert_eq!(
out,
vec![
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 3)),
]
);
}
// Exclusions can be loaded from a file; comment lines in it are ignored.
#[test]
fn apply_exclude_file_reads_lines() {
let mut f = NamedTempFile::new().unwrap();
writeln!(f, "10.0.0.2").unwrap();
writeln!(f, "# comment").unwrap();
f.flush().unwrap();
let hosts = vec![
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
];
let out = apply_exclude(hosts, None, Some(f.path()), &opts_no_dns_v4()).unwrap();
assert_eq!(out, vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))]);
}
// A count of zero produces an empty vector for either family.
#[test]
fn random_addresses_zero_returns_empty() {
assert!(random_addresses(0, false).is_empty());
assert!(random_addresses(0, true).is_empty());
}
// A /32 expands to its single address.
#[tokio::test]
async fn cidr_slash_32_is_one_host() {
let ips = expand_target("203.0.113.5/32", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 1);
}
// CRLF line endings do not leak a trailing `\r` into the entries.
#[test]
fn read_input_list_strips_windows_crlf() {
let mut f = NamedTempFile::new().unwrap();
write!(f, "10.0.0.9\r\n").unwrap();
f.flush().unwrap();
let lines = read_input_list(f.path()).unwrap();
assert_eq!(lines, vec!["10.0.0.9"]);
}
// With neither exclusion source supplied the host list is returned unchanged.
#[test]
fn apply_exclude_empty_exclude_leaves_hosts() {
let hosts = vec![IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))];
let out = apply_exclude(hosts.clone(), None, None, &opts_no_dns_v4()).unwrap();
assert_eq!(out, hosts);
}
#[tokio::test]
async fn expand_ipv4_range_three_hosts() {
let ips = expand_target("192.0.2.1-3", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 3);
}
// Each comma-separated exclusion is applied independently.
#[test]
fn apply_exclude_comma_separated_two_ips() {
let hosts = vec![
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 3)),
];
let out = apply_exclude(hosts, Some("10.0.0.1,10.0.0.3"), None, &opts_no_dns_v4()).unwrap();
assert_eq!(out, vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2))]);
}
// Out-of-range octets do not parse as a literal; with `no_dns` set the token
// is then rejected by the numeric-only guard.
#[tokio::test]
async fn expand_invalid_ipv4_literal_errors() {
let e = expand_target("999.999.999.999", &opts_no_dns_v4())
.await
.unwrap_err();
assert!(matches!(e, TargetError::Invalid(_)));
}
// A /127 yields both of its addresses when IPv6 is enabled.
#[tokio::test]
async fn expand_ipv6_cidr_with_dash6() {
let opts = ExpandOpts {
ipv6: true,
no_dns: true,
resolve_all: false,
dns_servers: vec![],
};
let ips = expand_target("2001:db8::/127", &opts).await.unwrap();
assert_eq!(ips.len(), 2);
}
// Surrounding whitespace in the exclusion list is ignored.
#[test]
fn apply_exclude_trims_whitespace_in_csv() {
let hosts = vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))];
let out = apply_exclude(hosts, Some(" 10.0.0.1 "), None, &opts_no_dns_v4()).unwrap();
assert!(out.is_empty());
}
// Probabilistic sanity check: 20 draws should not all be the same address.
#[test]
fn random_addresses_distinct_samples_usually() {
let v4 = random_addresses(20, false);
let uniq: std::collections::HashSet<_> = v4.iter().collect();
assert!(uniq.len() > 1, "20 random v4 draws should not all collide");
}
#[tokio::test]
async fn expand_ipv4_last_octet_range_four_hosts() {
let ips = expand_target("10.0.0.1-4", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 4);
}
// NOTE(review): "10.0.5-1.0.1" has five dot-separated fields (four dots), so
// it is never routed to the octet-range parser; the `Invalid` error comes
// from the `no_dns` guard instead. A four-field input such as "10.0.5-1.1"
// would be needed to exercise the reversed-range check through this entry
// point.
#[tokio::test]
async fn expand_ipv4_reversed_octet_range_errors() {
let e = expand_target("10.0.5-1.0.1", &opts_no_dns_v4())
.await
.unwrap_err();
assert!(matches!(e, TargetError::Invalid(_)));
}
// Overlaps with the IPv6 half of `random_addresses_count_and_family`.
#[test]
fn random_addresses_ipv6_returns_requested_count() {
let v6 = random_addresses(12, true);
assert_eq!(v6.len(), 12);
assert!(v6.iter().all(|a| a.is_ipv6()));
}
#[test]
fn read_input_list_trims_leading_trailing_whitespace() {
let mut f = NamedTempFile::new().unwrap();
writeln!(f, " 203.0.113.5 ").unwrap();
f.flush().unwrap();
assert_eq!(read_input_list(f.path()).unwrap(), vec!["203.0.113.5"]);
}
// Same scenario and assertions as `apply_exclude_single_ip`.
#[test]
fn apply_exclude_single_ip_removes_one_host() {
let hosts = vec![
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2)),
];
let out = apply_exclude(hosts, Some("10.0.0.1"), None, &opts_no_dns_v4()).unwrap();
assert_eq!(out, vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 2))]);
}
#[tokio::test]
async fn expand_ipv4_literal_loopback() {
let ips = expand_target("127.0.0.1", &opts_no_dns_v4()).await.unwrap();
assert_eq!(ips, vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]);
}
// Only two of the four addresses in a /30 are returned.
#[tokio::test]
async fn expand_ipv4_slash_30_yields_two_usable_hosts() {
let ips = expand_target("192.0.2.0/30", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 2);
}
// A leading comment line in the exclude file does not hide later entries.
#[test]
fn apply_exclude_file_ignores_hash_comments() {
let mut f = NamedTempFile::new().unwrap();
writeln!(f, "# excluded").unwrap();
writeln!(f, "10.0.0.99").unwrap();
f.flush().unwrap();
let hosts = vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 99))];
let out = apply_exclude(hosts, None, Some(f.path()), &opts_no_dns_v4()).unwrap();
assert!(out.is_empty());
}
// Ranges are accepted in any octet position, not only the last one.
#[tokio::test]
async fn expand_ipv4_third_octet_range() {
let ips = expand_target("10.0.1-2.5", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 2);
assert!(ips.contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 1, 5))));
assert!(ips.contains(&IpAddr::V4(Ipv4Addr::new(10, 0, 2, 5))));
}
#[test]
fn random_addresses_one_ipv4_is_valid() {
let v = random_addresses(1, false);
assert_eq!(v.len(), 1);
assert!(v[0].is_ipv4());
}
#[tokio::test]
async fn expand_ipv4_first_octet_range() {
let ips = expand_target("10-11.0.0.1", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 2);
}
// Direct tests of the private octet parser follow.
#[test]
fn expand_octet_single_number() {
assert_eq!(super::expand_octet("42").unwrap(), vec![42]);
}
#[test]
fn expand_octet_range_inclusive() {
assert_eq!(super::expand_octet("1-3").unwrap(), vec![1, 2, 3]);
}
#[test]
fn expand_octet_invalid_token_errors() {
assert!(super::expand_octet("x").is_err());
}
// Overlaps with the IPv6 half of `random_addresses_count_and_family`.
#[test]
fn random_addresses_ipv6_when_requested() {
let v = random_addresses(3, true);
assert_eq!(v.len(), 3);
assert!(v.iter().all(|ip| ip.is_ipv6()));
}
// A range whose bounds are equal yields that single value.
#[test]
fn expand_octet_range_single_value() {
assert_eq!(super::expand_octet("5-5").unwrap(), vec![5]);
}
#[test]
fn expand_octet_reversed_range_errors() {
assert!(super::expand_octet("5-3").is_err());
}
// 256 does not fit in a `u8`, so it is rejected.
#[test]
fn expand_octet_out_of_range_errors() {
assert!(super::expand_octet("256").is_err());
}
// Whitespace-only lines count as blank.
#[test]
fn read_input_list_skips_blank_lines() {
let mut f = NamedTempFile::new().unwrap();
writeln!(f).unwrap();
writeln!(f, "10.0.0.1").unwrap();
writeln!(f, " ").unwrap();
f.flush().unwrap();
assert_eq!(read_input_list(f.path()).unwrap(), vec!["10.0.0.1"]);
}
#[test]
fn read_input_list_skips_hash_comments() {
let mut f = NamedTempFile::new().unwrap();
writeln!(f, "# skip").unwrap();
writeln!(f, "10.0.0.2").unwrap();
f.flush().unwrap();
assert_eq!(read_input_list(f.path()).unwrap(), vec!["10.0.0.2"]);
}
// 100_000 exceeds the 65_536 preallocation cap inside `random_addresses`,
// confirming the cap limits only the initial capacity, not the result size.
#[test]
fn random_addresses_large_count_honors_request() {
let v = random_addresses(100_000, false);
assert_eq!(v.len(), 100_000);
}
// 10.0.0.1 lies inside the excluded /30 and is removed.
#[test]
fn apply_exclude_ipv4_cidr() {
let hosts = vec![IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))];
let out = apply_exclude(hosts, Some("10.0.0.0/30"), None, &opts_no_dns_v4()).unwrap();
assert!(out.is_empty());
}
#[tokio::test]
async fn expand_ipv4_second_octet_range() {
let ips = expand_target("10.1-2.0.1", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 2);
}
// Same behaviour as `cidr_expands_v4`, on a different network.
#[tokio::test]
async fn expand_ipv4_slash_31_two_hosts() {
let ips = expand_target("192.0.2.0/31", &opts_no_dns_v4())
.await
.unwrap();
assert_eq!(ips.len(), 2);
}
// Boundary values of the octet domain.
#[test]
fn expand_octet_zero() {
assert_eq!(super::expand_octet("0").unwrap(), vec![0]);
}
#[test]
fn expand_octet_max_255() {
assert_eq!(super::expand_octet("255").unwrap(), vec![255]);
}
}