dns_checker/
lib.rs

use std::collections::HashMap;
use std::collections::hash_map::Entry;
use std::error::Error;
use std::num::NonZeroUsize;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};

use hickory_resolver::TokioAsyncResolver;
use hickory_resolver::config::{LookupIpStrategy, ResolverConfig, ResolverOpts};
use hickory_resolver::error::ResolveErrorKind;
use hickory_resolver::system_conf;
use log::{info, warn};
use serde::Serialize;
use tokio::sync::Semaphore;
use tokio::task::JoinSet;
16
17#[cfg(all(target_os = "linux", feature = "gnu-c"))]
18mod gnu_c_backend;
19#[cfg(feature = "python")]
20mod python;
21
/// DNS resolution backend used to check domains.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Backend {
    /// Asynchronous lookups through the Hickory resolver.
    Hickory,
    /// Batched lookups through the `gnu_c_backend` module (Linux with the `gnu-c` feature only).
    #[cfg(all(target_os = "linux", feature = "gnu-c"))]
    GnuC,
}

impl Backend {
    /// Short, stable name of the backend, used in log output.
    fn as_str(self) -> &'static str {
        match self {
            Self::Hickory => "hickory",
            #[cfg(all(target_os = "linux", feature = "gnu-c"))]
            Self::GnuC => "gnu-c",
        }
    }
}
38
39#[derive(Debug)]
40pub struct Config {
41    pub input: PathBuf,
42    pub output: PathBuf,
43    pub backend: Backend,
44    pub concurrency: NonZeroUsize,
45}
46
/// One meaningful input line together with the host extracted from it.
#[derive(Debug)]
struct LineEntry {
    /// The input line with surrounding whitespace removed.
    line: String,
    /// Lowercased host extracted from the line, or `None` when no usable host was found.
    domain: Option<String>,
}
52
/// Outcome recorded for one input line before it is turned into an output record.
#[derive(Debug)]
enum LineResult {
    /// The line's host was looked up; `alive` is `true` when it resolved to an address.
    Checked { alive: bool },
    /// No host could be extracted from the line.
    Invalid,
    /// The host could not be checked.
    Error,
}
59
60#[derive(Debug, Serialize)]
61#[serde(rename_all = "lowercase")]
62enum Status {
63    Alive,
64    Dead,
65    Invalid,
66    Error,
67}
68
69#[derive(Debug, Serialize)]
70struct OutputRecord {
71    input: String,
72    domain: Option<String>,
73    status: Status,
74}
75
/// Extracts the host part of one input line, ASCII-lowercased.
///
/// Accepts bare hostnames (`example.com`), `host:port` pairs, full URLs (scheme, credentials,
/// path, query and fragment are stripped) and bracketed IPv6 literals (`[::1]:443` -> `::1`).
/// Text after a `#` is ignored, so `example.com # note` yields `example.com`.
///
/// Returns `None` for blank lines, `#` comment lines, and lines where no plausible host can be
/// extracted: an empty host, an unterminated or empty `[...]` literal, or a host that still
/// contains whitespace (hostnames never do).
fn extract_domain(raw: &str) -> Option<String> {
    let trimmed = raw.trim();
    if trimmed.is_empty() || trimmed.starts_with('#') {
        return None;
    }

    let without_scheme = trimmed.split_once("://").map_or(trimmed, |(_, rest)| rest);
    let without_path = without_scheme
        .split(['/', '?', '#'])
        .next()
        .unwrap_or_default();
    // Drop `user:pass@`; trim so whitespace before an inline `#` comment does not leak into
    // the host.
    let authority = without_path
        .rsplit_once('@')
        .map_or(without_path, |(_, host)| host)
        .trim();

    let host = if let Some(bracketed) = authority.strip_prefix('[') {
        // Bracketed IPv6 literal: a missing `]` means the line is malformed.
        let (ipv6, _) = bracketed.split_once(']')?;
        ipv6
    } else {
        authority.split(':').next().unwrap_or_default()
    };

    if host.is_empty() || host.contains(char::is_whitespace) {
        return None;
    }
    Some(host.to_ascii_lowercase())
}
104
105fn tune_resolver_opts(opts: &mut ResolverOpts) {
106    opts.cache_size = 1024;
107    opts.timeout = Duration::from_secs(3);
108    opts.attempts = 1;
109    opts.ip_strategy = LookupIpStrategy::Ipv4Only;
110    opts.positive_min_ttl = Some(Duration::from_secs(30));
111    opts.negative_min_ttl = Some(Duration::from_secs(30));
112}
113
114pub async fn run(config: Config) -> Result<(), Box<dyn Error>> {
115    let started = Instant::now();
116    info!("Reading input from {}", config.input.display());
117    let input = tokio::fs::read_to_string(&config.input).await?;
118    info!("Using backend {}", config.backend.as_str());
119
120    let entries: Vec<LineEntry> = input
121        .lines()
122        .filter_map(|line| {
123            let trimmed = line.trim();
124            if trimmed.is_empty() || trimmed.starts_with('#') {
125                return None;
126            }
127            let domain = extract_domain(trimmed);
128            Some(LineEntry {
129                line: trimmed.to_string(),
130                domain,
131            })
132        })
133        .collect();
134
135    let total_entries = entries.len();
136    info!("Parsed {} entries", total_entries);
137
138    let mut results: Vec<Option<LineResult>> = (0..entries.len()).map(|_| None).collect();
139    let mut domain_indices: HashMap<String, Vec<usize>> = HashMap::new();
140    let mut unique_domains = Vec::new();
141
142    for (idx, entry) in entries.iter().enumerate() {
143        if let Some(domain) = entry.domain.as_ref() {
144            match domain_indices.entry(domain.clone()) {
145                Entry::Occupied(mut existing) => existing.get_mut().push(idx),
146                Entry::Vacant(vacant) => {
147                    vacant.insert(vec![idx]);
148                    unique_domains.push(domain.clone());
149                }
150            }
151        } else {
152            results[idx] = Some(LineResult::Invalid);
153        }
154    }
155
156    let unique_total = unique_domains.len();
157    info!(
158        "Unique domains: {} (deduped from {})",
159        unique_total,
160        entries.len()
161    );
162    let log_every = if unique_total < 10 {
163        1
164    } else {
165        unique_total / 10
166    };
167    let mut processed = 0usize;
168    let log_progress = |processed: usize| {
169        if unique_total > 0 && (processed.is_multiple_of(log_every) || processed == unique_total) {
170            let percent = processed * 100 / unique_total;
171            info!("Progress: {}/{} ({}%)", processed, unique_total, percent);
172        }
173    };
174
175    match config.backend {
176        Backend::Hickory => {
177            let concurrency = config.concurrency.get();
178            info!("Using concurrency {}", concurrency);
179            let resolver = match system_conf::read_system_conf() {
180                Ok((config, mut opts)) => {
181                    tune_resolver_opts(&mut opts);
182                    info!(
183                        "Resolver opts: cache_size={}, attempts={}, timeout={:?}, ipv4_only=true",
184                        opts.cache_size, opts.attempts, opts.timeout
185                    );
186                    TokioAsyncResolver::tokio(config, opts)
187                }
188                Err(err) => {
189                    warn!(
190                        "Failed to load system DNS config ({}), falling back to default resolver",
191                        err
192                    );
193                    let mut opts = ResolverOpts::default();
194                    tune_resolver_opts(&mut opts);
195                    info!(
196                        "Resolver opts: cache_size={}, attempts={}, timeout={:?}, ipv4_only=true",
197                        opts.cache_size, opts.attempts, opts.timeout
198                    );
199                    TokioAsyncResolver::tokio(ResolverConfig::default(), opts)
200                }
201            };
202
203            info!("Clearing Hickory resolver cache before lookups");
204            resolver.clear_cache();
205
206            let mut join_set = JoinSet::new();
207            let semaphore = Arc::new(Semaphore::new(concurrency));
208            info!(
209                "Scheduling {} lookups (concurrency={}) for Hickory backend",
210                unique_total, concurrency
211            );
212
213            for domain in unique_domains.iter().cloned() {
214                let resolver = resolver.clone();
215                let permit = semaphore.clone().acquire_owned().await?;
216                join_set.spawn(async move {
217                    let _permit = permit;
218                    let alive = resolver
219                        .lookup_ip(domain.as_str())
220                        .await
221                        .map(|ips| ips.iter().next().is_some())
222                        .unwrap_or(false);
223                    (domain, alive)
224                });
225            }
226
227            while let Some(result) = join_set.join_next().await {
228                match result {
229                    Ok((domain, alive)) => {
230                        if let Some(indices) = domain_indices.remove(&domain) {
231                            for idx in &indices {
232                                results[*idx] = Some(LineResult::Checked { alive });
233                            }
234                            processed += 1;
235                            log_progress(processed);
236                        } else {
237                            warn!("Received result for unknown domain {}", domain);
238                        }
239                    }
240                    Err(err) => {
241                        warn!("DNS check task failed: {}", err);
242                    }
243                }
244            }
245        }
246        #[cfg(all(target_os = "linux", feature = "gnu-c"))]
247        Backend::GnuC => {
248            let concurrency = config.concurrency.get();
249            let ipv4_only = true;
250            info!(
251                "GNU C backend batch_size {}, ipv4_only={}",
252                concurrency, ipv4_only
253            );
254            let mut join_set = JoinSet::new();
255            let domains = Arc::new(unique_domains);
256            let total_batches = (domains.len() + concurrency.saturating_sub(1)) / concurrency;
257            info!(
258                "Scheduling {} batches (batch_size={}) for GNU C backend",
259                total_batches, concurrency
260            );
261            for batch_idx in 0..total_batches {
262                let start = batch_idx * concurrency;
263                let end = (start + concurrency).min(domains.len());
264                info!(
265                    "Starting batch {}/{} ({} domains)",
266                    batch_idx + 1,
267                    total_batches,
268                    end - start
269                );
270                let domains = Arc::clone(&domains);
271                join_set.spawn_blocking(move || {
272                    let resolved =
273                        gnu_c_backend::resolve_domains_gnu_c(&domains[start..end], ipv4_only);
274                    (start, resolved)
275                });
276            }
277
278            while let Some(result) = join_set.join_next().await {
279                match result {
280                    Ok((start, resolved)) => {
281                        info!("Batch completed with {} results", resolved.len());
282                        for (offset, outcome) in resolved.into_iter().enumerate() {
283                            let domain = &domains[start + offset];
284                            match outcome {
285                                Some(alive) => {
286                                    if let Some(indices) = domain_indices.remove(domain) {
287                                        for idx in &indices {
288                                            results[*idx] = Some(LineResult::Checked { alive });
289                                        }
290                                        processed += 1;
291                                        log_progress(processed);
292                                    } else {
293                                        warn!("Received result for unknown domain {}", domain);
294                                    }
295                                }
296                                None => {
297                                    if let Some(indices) = domain_indices.remove(domain) {
298                                        for idx in &indices {
299                                            results[*idx] = Some(LineResult::Error);
300                                        }
301                                        processed += 1;
302                                        log_progress(processed);
303                                    } else {
304                                        warn!("Received result for unknown domain {}", domain);
305                                    }
306                                }
307                            }
308                        }
309                    }
310                    Err(err) => {
311                        warn!("DNS check task failed: {}", err);
312                    }
313                }
314            }
315        }
316    }
317
318    if !domain_indices.is_empty() {
319        for (_, indices) in domain_indices.drain() {
320            for idx in indices {
321                results[idx] = Some(LineResult::Error);
322            }
323            processed += 1;
324            log_progress(processed);
325        }
326    }
327
328    for slot in results.iter_mut() {
329        if slot.is_none() {
330            *slot = Some(LineResult::Error);
331        }
332    }
333
334    let mut alive_count = 0usize;
335    let mut dead_count = 0usize;
336    let mut invalid_count = 0usize;
337    let mut error_count = 0usize;
338    let mut records = Vec::with_capacity(entries.len());
339
340    for (idx, result) in results.into_iter().enumerate() {
341        let entry = &entries[idx];
342        let (status, domain) = match result {
343            Some(LineResult::Checked { alive }) => {
344                if alive {
345                    alive_count += 1;
346                    (Status::Alive, entry.domain.clone())
347                } else {
348                    dead_count += 1;
349                    (Status::Dead, entry.domain.clone())
350                }
351            }
352            Some(LineResult::Invalid) => {
353                invalid_count += 1;
354                (Status::Invalid, None)
355            }
356            Some(LineResult::Error) | None => {
357                error_count += 1;
358                (Status::Error, entry.domain.clone())
359            }
360        };
361
362        records.push(OutputRecord {
363            input: entry.line.clone(),
364            domain,
365            status,
366        });
367    }
368
369    let output = serde_json::to_string_pretty(&records)?;
370    tokio::fs::write(&config.output, output).await?;
371
372    info!(
373        "Wrote {} results to {}",
374        records.len(),
375        config.output.display()
376    );
377    let elapsed = started.elapsed();
378    let elapsed_secs = elapsed.as_secs_f64();
379    let speed = if elapsed_secs > 0.0 {
380        records.len() as f64 / elapsed_secs
381    } else {
382        0.0
383    };
384    info!(
385        "Summary: alive={}, dead={}, invalid={}, error={}",
386        alive_count, dead_count, invalid_count, error_count
387    );
388    info!("Elapsed: {:.2?}", elapsed);
389    info!("Speed: {:.2} entries/sec", speed);
390    Ok(())
391}