1use std::collections::HashMap;
2use std::collections::hash_map::Entry;
3use std::error::Error;
4use std::num::NonZeroUsize;
5use std::path::PathBuf;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use hickory_resolver::TokioAsyncResolver;
10use hickory_resolver::config::{LookupIpStrategy, ResolverConfig, ResolverOpts};
11use hickory_resolver::system_conf;
12use log::{info, warn};
13use serde::Serialize;
14use tokio::sync::Semaphore;
15use tokio::task::JoinSet;
16
17#[cfg(all(target_os = "linux", feature = "gnu-c"))]
18mod gnu_c_backend;
19#[cfg(feature = "python")]
20mod python;
21
/// DNS resolution backend used to check whether a domain resolves.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Backend {
    /// Lookups performed through the Hickory resolver.
    Hickory,
    /// Lookups performed through the `gnu_c_backend` module; only compiled on
    /// Linux when the `gnu-c` feature is enabled.
    #[cfg(all(target_os = "linux", feature = "gnu-c"))]
    GnuC,
}

impl Backend {
    /// Returns the short, lowercase backend name that is written to the log.
    fn as_str(self) -> &'static str {
        match self {
            Backend::Hickory => "hickory",
            #[cfg(all(target_os = "linux", feature = "gnu-c"))]
            Backend::GnuC => "gnu-c",
        }
    }
}
38
39#[derive(Debug)]
40pub struct Config {
41 pub input: PathBuf,
42 pub output: PathBuf,
43 pub backend: Backend,
44 pub concurrency: NonZeroUsize,
45}
46
/// One non-blank, non-comment line of the input file.
#[derive(Debug, Clone, PartialEq, Eq)]
struct LineEntry {
    /// The input line with surrounding whitespace removed.
    line: String,
    /// Host extracted from `line`, or `None` when no host could be extracted.
    domain: Option<String>,
}
52
/// Outcome recorded for a single input line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LineResult {
    /// The lookup completed; `alive` tells whether the domain resolved.
    Checked { alive: bool },
    /// No domain could be extracted from the line.
    Invalid,
    /// The lookup could not be completed, so liveness is unknown.
    Error,
}
59
60#[derive(Debug, Serialize)]
61#[serde(rename_all = "lowercase")]
62enum Status {
63 Alive,
64 Dead,
65 Invalid,
66 Error,
67}
68
69#[derive(Debug, Serialize)]
70struct OutputRecord {
71 input: String,
72 domain: Option<String>,
73 status: Status,
74}
75
/// Extracts the lowercase host from a URL-like or bare-host string.
///
/// Handles an optional `scheme://` prefix, `user:pass@` credentials, a
/// trailing `:port`, a path / query / fragment, bracketed IPv6 literals
/// (`[::1]:53`) and bare IPv6 literals (`2001:db8::1`).
///
/// Returns `None` for blank input, lines starting with `#`, and input whose
/// host part is empty.
fn extract_domain(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() || trimmed.starts_with('#') {
        return None;
    }

    // Only treat the text before "://" as a scheme when it is built from
    // scheme characters. Otherwise the "://" belongs to a path or query
    // (e.g. `example.com/r?u=http://other.com`) and must not be used to cut
    // the real host off. An empty prefix is still accepted so that
    // `://host` keeps resolving to `host`.
    let is_scheme = |candidate: &str| {
        candidate
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    };
    let after_scheme = match trimmed.split_once("://") {
        Some((scheme, rest)) if is_scheme(scheme) => rest,
        _ => trimmed,
    };

    let authority = after_scheme.split(['/', '?', '#']).next().unwrap_or("");
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);

    // Bracketed IPv6 literal, optionally followed by `:port`.
    if let Some(inner) = host_port.strip_prefix('[') {
        if let Some(end) = inner.find(']') {
            let ipv6 = &inner[..end];
            if !ipv6.is_empty() {
                return Some(ipv6.to_ascii_lowercase());
            }
        }
    }

    // A bare IPv6 literal contains colons that are not port separators;
    // splitting on ':' would truncate it to its first group.
    if host_port.parse::<std::net::Ipv6Addr>().is_ok() {
        return Some(host_port.to_ascii_lowercase());
    }

    let host = host_port.split(':').next().unwrap_or("");
    if host.is_empty() {
        return None;
    }
    Some(host.to_ascii_lowercase())
}
104
105fn tune_resolver_opts(opts: &mut ResolverOpts) {
106 opts.cache_size = 1024;
107 opts.timeout = Duration::from_secs(3);
108 opts.attempts = 1;
109 opts.ip_strategy = LookupIpStrategy::Ipv4Only;
110 opts.positive_min_ttl = Some(Duration::from_secs(30));
111 opts.negative_min_ttl = Some(Duration::from_secs(30));
112}
113
114pub async fn run(config: Config) -> Result<(), Box<dyn Error>> {
115 let started = Instant::now();
116 info!("Reading input from {}", config.input.display());
117 let input = tokio::fs::read_to_string(&config.input).await?;
118 info!("Using backend {}", config.backend.as_str());
119
120 let entries: Vec<LineEntry> = input
121 .lines()
122 .filter_map(|line| {
123 let trimmed = line.trim();
124 if trimmed.is_empty() || trimmed.starts_with('#') {
125 return None;
126 }
127 let domain = extract_domain(trimmed);
128 Some(LineEntry {
129 line: trimmed.to_string(),
130 domain,
131 })
132 })
133 .collect();
134
135 let total_entries = entries.len();
136 info!("Parsed {} entries", total_entries);
137
138 let mut results: Vec<Option<LineResult>> = (0..entries.len()).map(|_| None).collect();
139 let mut domain_indices: HashMap<String, Vec<usize>> = HashMap::new();
140 let mut unique_domains = Vec::new();
141
142 for (idx, entry) in entries.iter().enumerate() {
143 if let Some(domain) = entry.domain.as_ref() {
144 match domain_indices.entry(domain.clone()) {
145 Entry::Occupied(mut existing) => existing.get_mut().push(idx),
146 Entry::Vacant(vacant) => {
147 vacant.insert(vec![idx]);
148 unique_domains.push(domain.clone());
149 }
150 }
151 } else {
152 results[idx] = Some(LineResult::Invalid);
153 }
154 }
155
156 let unique_total = unique_domains.len();
157 info!(
158 "Unique domains: {} (deduped from {})",
159 unique_total,
160 entries.len()
161 );
162 let log_every = if unique_total < 10 {
163 1
164 } else {
165 unique_total / 10
166 };
167 let mut processed = 0usize;
168 let log_progress = |processed: usize| {
169 if unique_total > 0 && (processed.is_multiple_of(log_every) || processed == unique_total) {
170 let percent = processed * 100 / unique_total;
171 info!("Progress: {}/{} ({}%)", processed, unique_total, percent);
172 }
173 };
174
175 match config.backend {
176 Backend::Hickory => {
177 let concurrency = config.concurrency.get();
178 info!("Using concurrency {}", concurrency);
179 let resolver = match system_conf::read_system_conf() {
180 Ok((config, mut opts)) => {
181 tune_resolver_opts(&mut opts);
182 info!(
183 "Resolver opts: cache_size={}, attempts={}, timeout={:?}, ipv4_only=true",
184 opts.cache_size, opts.attempts, opts.timeout
185 );
186 TokioAsyncResolver::tokio(config, opts)
187 }
188 Err(err) => {
189 warn!(
190 "Failed to load system DNS config ({}), falling back to default resolver",
191 err
192 );
193 let mut opts = ResolverOpts::default();
194 tune_resolver_opts(&mut opts);
195 info!(
196 "Resolver opts: cache_size={}, attempts={}, timeout={:?}, ipv4_only=true",
197 opts.cache_size, opts.attempts, opts.timeout
198 );
199 TokioAsyncResolver::tokio(ResolverConfig::default(), opts)
200 }
201 };
202
203 info!("Clearing Hickory resolver cache before lookups");
204 resolver.clear_cache();
205
206 let mut join_set = JoinSet::new();
207 let semaphore = Arc::new(Semaphore::new(concurrency));
208 info!(
209 "Scheduling {} lookups (concurrency={}) for Hickory backend",
210 unique_total, concurrency
211 );
212
213 for domain in unique_domains.iter().cloned() {
214 let resolver = resolver.clone();
215 let permit = semaphore.clone().acquire_owned().await?;
216 join_set.spawn(async move {
217 let _permit = permit;
218 let alive = resolver
219 .lookup_ip(domain.as_str())
220 .await
221 .map(|ips| ips.iter().next().is_some())
222 .unwrap_or(false);
223 (domain, alive)
224 });
225 }
226
227 while let Some(result) = join_set.join_next().await {
228 match result {
229 Ok((domain, alive)) => {
230 if let Some(indices) = domain_indices.remove(&domain) {
231 for idx in &indices {
232 results[*idx] = Some(LineResult::Checked { alive });
233 }
234 processed += 1;
235 log_progress(processed);
236 } else {
237 warn!("Received result for unknown domain {}", domain);
238 }
239 }
240 Err(err) => {
241 warn!("DNS check task failed: {}", err);
242 }
243 }
244 }
245 }
246 #[cfg(all(target_os = "linux", feature = "gnu-c"))]
247 Backend::GnuC => {
248 let concurrency = config.concurrency.get();
249 let ipv4_only = true;
250 info!(
251 "GNU C backend batch_size {}, ipv4_only={}",
252 concurrency, ipv4_only
253 );
254 let mut join_set = JoinSet::new();
255 let domains = Arc::new(unique_domains);
256 let total_batches = (domains.len() + concurrency.saturating_sub(1)) / concurrency;
257 info!(
258 "Scheduling {} batches (batch_size={}) for GNU C backend",
259 total_batches, concurrency
260 );
261 for batch_idx in 0..total_batches {
262 let start = batch_idx * concurrency;
263 let end = (start + concurrency).min(domains.len());
264 info!(
265 "Starting batch {}/{} ({} domains)",
266 batch_idx + 1,
267 total_batches,
268 end - start
269 );
270 let domains = Arc::clone(&domains);
271 join_set.spawn_blocking(move || {
272 let resolved =
273 gnu_c_backend::resolve_domains_gnu_c(&domains[start..end], ipv4_only);
274 (start, resolved)
275 });
276 }
277
278 while let Some(result) = join_set.join_next().await {
279 match result {
280 Ok((start, resolved)) => {
281 info!("Batch completed with {} results", resolved.len());
282 for (offset, outcome) in resolved.into_iter().enumerate() {
283 let domain = &domains[start + offset];
284 match outcome {
285 Some(alive) => {
286 if let Some(indices) = domain_indices.remove(domain) {
287 for idx in &indices {
288 results[*idx] = Some(LineResult::Checked { alive });
289 }
290 processed += 1;
291 log_progress(processed);
292 } else {
293 warn!("Received result for unknown domain {}", domain);
294 }
295 }
296 None => {
297 if let Some(indices) = domain_indices.remove(domain) {
298 for idx in &indices {
299 results[*idx] = Some(LineResult::Error);
300 }
301 processed += 1;
302 log_progress(processed);
303 } else {
304 warn!("Received result for unknown domain {}", domain);
305 }
306 }
307 }
308 }
309 }
310 Err(err) => {
311 warn!("DNS check task failed: {}", err);
312 }
313 }
314 }
315 }
316 }
317
318 if !domain_indices.is_empty() {
319 for (_, indices) in domain_indices.drain() {
320 for idx in indices {
321 results[idx] = Some(LineResult::Error);
322 }
323 processed += 1;
324 log_progress(processed);
325 }
326 }
327
328 for slot in results.iter_mut() {
329 if slot.is_none() {
330 *slot = Some(LineResult::Error);
331 }
332 }
333
334 let mut alive_count = 0usize;
335 let mut dead_count = 0usize;
336 let mut invalid_count = 0usize;
337 let mut error_count = 0usize;
338 let mut records = Vec::with_capacity(entries.len());
339
340 for (idx, result) in results.into_iter().enumerate() {
341 let entry = &entries[idx];
342 let (status, domain) = match result {
343 Some(LineResult::Checked { alive }) => {
344 if alive {
345 alive_count += 1;
346 (Status::Alive, entry.domain.clone())
347 } else {
348 dead_count += 1;
349 (Status::Dead, entry.domain.clone())
350 }
351 }
352 Some(LineResult::Invalid) => {
353 invalid_count += 1;
354 (Status::Invalid, None)
355 }
356 Some(LineResult::Error) | None => {
357 error_count += 1;
358 (Status::Error, entry.domain.clone())
359 }
360 };
361
362 records.push(OutputRecord {
363 input: entry.line.clone(),
364 domain,
365 status,
366 });
367 }
368
369 let output = serde_json::to_string_pretty(&records)?;
370 tokio::fs::write(&config.output, output).await?;
371
372 info!(
373 "Wrote {} results to {}",
374 records.len(),
375 config.output.display()
376 );
377 let elapsed = started.elapsed();
378 let elapsed_secs = elapsed.as_secs_f64();
379 let speed = if elapsed_secs > 0.0 {
380 records.len() as f64 / elapsed_secs
381 } else {
382 0.0
383 };
384 info!(
385 "Summary: alive={}, dead={}, invalid={}, error={}",
386 alive_count, dead_count, invalid_count, error_count
387 );
388 info!("Elapsed: {:.2?}", elapsed);
389 info!("Speed: {:.2} entries/sec", speed);
390 Ok(())
391}