1use std::{
12 collections::{BTreeMap, HashSet},
13 future::Future,
14 sync::Arc,
15 time::{Duration, Instant},
16};
17
18use rand::seq::SliceRandom;
19use tokio::{
20 sync::{mpsc, Mutex, Semaphore},
21 task::JoinSet,
22};
23use tracing::{debug, error, info, warn};
24use url::Url;
25
26use crate::{
27 config::Config,
28 discovery::{
29 common_paths::CommonPathDiscovery, headers::HeaderDiscovery, js::JsDiscovery,
30 robots::RobotsDiscovery, sitemap::SitemapDiscovery, swagger::SwaggerDiscovery,
31 },
32 error::CapturedError,
33 http_client::HttpClient,
34 progress_tracker::{ProgressConfig, ProgressTracker},
35 reports::{Finding, Reporter},
36 scanner::{
37 api_security::ApiSecurityScanner, cors::CorsScanner, csp::CspScanner,
38 cve_templates::CveTemplateScanner, graphql::GraphqlScanner, jwt::JwtScanner,
39 mass_assignment::MassAssignmentScanner, oauth_oidc::OAuthOidcScanner,
40 openapi::OpenApiScanner, rate_limit::RateLimitScanner, websocket::WebSocketScanner,
41 Scanner,
42 },
43};
44
/// Aggregated outcome of a complete scan run, returned by [`run`].
#[derive(Debug, Default)]
pub struct RunResult {
    /// All deduplicated findings, sorted by severity (descending), URL, check.
    pub findings: Vec<Finding>,
    /// Deduplicated errors captured during discovery and scanning.
    pub errors: Vec<CapturedError>,
    /// Wall-clock duration of the whole run.
    pub elapsed: Duration,
    /// Number of unique URLs that were actually scanned.
    pub scanned: usize,
    /// URLs dropped by deduplication plus the per-site endpoint cap.
    pub skipped: usize,
    /// Runtime counters (HTTP traffic, per-scanner tallies).
    pub metrics: RuntimeMetrics,
}
63
/// Counters collected while the scan ran; serialisable for report output.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct RuntimeMetrics {
    /// Total HTTP requests sent (primary + optional secondary client).
    pub http_requests: u64,
    /// Total HTTP retries performed (primary + optional secondary client).
    pub http_retries: u64,
    /// Finding counts keyed by scanner name.
    pub scanner_findings: BTreeMap<String, usize>,
    /// Error counts keyed by scanner name.
    pub scanner_errors: BTreeMap<String, usize>,
}
76
/// Per-scanner tallies accumulated while scanning URLs.
#[derive(Debug, Default, Clone)]
struct ScannerRunStats {
    /// Number of findings this scanner produced.
    findings: usize,
    /// Number of errors this scanner captured.
    errors: usize,
}
82
/// Scanner name -> per-scanner tallies.
type ScannerStatsMap = BTreeMap<String, ScannerRunStats>;
/// Dedup key for streamed findings: (finding URL, check name).
type StreamFindingKey = (String, String);
/// Set of already-streamed finding keys, shared across worker tasks.
type StreamSeenSet = Arc<Mutex<HashSet<StreamFindingKey>>>;
86
/// Severity roll-up for a single scanned URL, used to build the progress line.
#[derive(Debug, Default, Clone, Copy)]
struct UrlScanSummary {
    /// Total findings across all scanners for this URL.
    findings: usize,
    /// Count of Critical-severity findings.
    critical: usize,
    /// Count of High-severity findings.
    high: usize,
    /// Count of Medium-severity findings.
    medium: usize,
}
94
95pub async fn run(
97 urls: Vec<String>,
98 config: Arc<Config>,
99 http_client: Arc<HttpClient>,
100 http_client_b: Option<Arc<HttpClient>>,
101 reporter: Arc<Reporter>,
102 _quiet: bool,
103) -> RunResult {
104 let start = Instant::now();
105
106 let (unique_seeds, skipped_dedup) = dedup(urls);
108 info!(
109 seeds = unique_seeds.len(),
110 discovery_enabled = !config.no_discovery,
111 active_checks = config.active_checks,
112 "Scan lifecycle: seeds prepared"
113 );
114
115 let (discovered, mut discovery_errors, skipped_cap) = if config.no_discovery {
117 (Vec::new(), Vec::new(), 0usize)
118 } else {
119 run_discovery_per_site(&unique_seeds, &config, &http_client).await
120 };
121 if config.no_discovery {
122 info!(
123 seeds = unique_seeds.len(),
124 "Discovery skipped (--no-discovery)"
125 );
126 } else {
127 info!(
128 discovered = discovered.len() + unique_seeds.len(),
129 "Discovery complete"
130 );
131 }
132
133 let mut merged = unique_seeds;
134 merged.extend(discovered);
135 let (work_list, skipped_merged) = dedup(merged);
136
137 info!(
138 discovered = work_list.len(),
139 skipped_dedup, skipped_merged, "Scan lifecycle: discovery merged"
140 );
141
142 let scanned = work_list.len();
143 let skipped = skipped_dedup + skipped_merged + skipped_cap;
144
145 if scanned == 0 {
146 return RunResult {
147 elapsed: start.elapsed(),
148 skipped,
149 ..Default::default()
150 };
151 }
152
153 let semaphore = Arc::new(Semaphore::new(config.concurrency));
155 let scanners = build_scanners(&config, http_client_b.clone());
156 let scanner_names: Vec<&str> = scanners.iter().map(|s| s.name()).collect();
157 info!(
158 scanner_count = scanners.len(),
159 scanners = ?scanner_names,
160 concurrency = config.concurrency,
161 "Scan lifecycle: scanner registry ready"
162 );
163
164 let tracker = Arc::new(ProgressTracker::with_config(ProgressConfig {
166 total: scanned,
167 tty_update_frequency: 1,
168 non_tty_update_frequency: 1,
169 show_elapsed: false,
170 show_eta: false,
171 show_rate: false,
172 prefix: "".to_string(),
173 show_details: false,
174 }));
175
176 let (finding_tx, mut finding_rx) = mpsc::unbounded_channel::<Vec<Finding>>();
178 let (error_tx, mut error_rx) = mpsc::unbounded_channel::<Vec<CapturedError>>();
179 let (scanner_stats_tx, mut scanner_stats_rx) = mpsc::unbounded_channel::<ScannerStatsMap>();
180 let stream_seen: StreamSeenSet = Arc::new(Mutex::new(HashSet::new()));
181
182 let mut join_set: JoinSet<()> = JoinSet::new();
184
185 info!(
186 started_at = %chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
187 scanned,
188 "Scan started"
189 );
190
191 for url in work_list {
192 let sem = Arc::clone(&semaphore);
193 let client = Arc::clone(&http_client);
194 let scanners = scanners.clone();
195 let ftx = finding_tx.clone();
196 let etx = error_tx.clone();
197 let stx = scanner_stats_tx.clone();
198 let cfg = Arc::clone(&config);
199 let rpt = Arc::clone(&reporter);
200 let stream_seen = Arc::clone(&stream_seen);
201 let progress_handle = tracker.handle();
202
203 join_set.spawn(async move {
204 let _permit = match sem.acquire().await {
205 Ok(p) => p,
206 Err(e) => {
207 error!(url = %url, "Semaphore closed: {e}");
208 return;
209 }
210 };
211
212 let (url_summary, scanner_stats) = scan_url_with_results(
213 url.clone(),
214 &client,
215 &scanners,
216 &cfg,
217 &rpt,
218 (ftx.clone(), etx.clone()),
219 stream_seen,
220 )
221 .await;
222
223 if !scanner_stats.is_empty() {
224 let _ = stx.send(scanner_stats);
225 }
226
227 let mut msg = url.clone();
229 if url_summary.findings > 0 {
230 msg.push_str(&format!(" | 🔍 {} findings", url_summary.findings));
231 if url_summary.critical > 0 {
232 msg.push_str(&format!(" (🔴 {}C", url_summary.critical));
233 }
234 if url_summary.high > 0 {
235 msg.push_str(&format!(" 🟠 {}H", url_summary.high));
236 }
237 if url_summary.medium > 0 {
238 msg.push_str(&format!(" 🟡 {}M", url_summary.medium));
239 }
240 if url_summary.critical > 0 || url_summary.high > 0 || url_summary.medium > 0 {
241 msg.push(')');
242 }
243 } else {
244 msg.push_str(" | ✅ Clean");
245 }
246
247 progress_handle.increment(Some(&msg)).await;
249 });
250 }
251
252 drop(finding_tx);
254 drop(error_tx);
255 drop(scanner_stats_tx);
256
257 let mut findings: Vec<Finding> = Vec::new();
259 let mut errors: Vec<CapturedError> = Vec::new();
260 let mut scanner_stats: ScannerStatsMap = BTreeMap::new();
261 errors.append(&mut discovery_errors);
262
263 loop {
264 tokio::select! {
265 Some(result) = join_set.join_next() => {
266 match result {
267 Ok(()) => {}
268 Err(e) => error!("Worker task panicked: {e}"),
269 }
270 }
271 Some(batch) = finding_rx.recv() => {
272 findings.extend(batch);
273 }
274 Some(batch) = error_rx.recv() => {
275 errors.extend(batch);
276 }
277 Some(batch) = scanner_stats_rx.recv() => {
278 merge_scanner_stats(&mut scanner_stats, batch);
279 }
280 else => break,
281 }
282 }
283
284 tracker.finish().await;
286
287 findings = crate::reports::dedup_findings(findings);
288 sort_findings(&mut findings);
289 dedup_errors(&mut errors);
290
291 let elapsed = start.elapsed();
292 let primary_http_metrics = http_client.runtime_metrics();
293 let secondary_http_metrics = http_client_b
294 .as_ref()
295 .map(|client| client.runtime_metrics())
296 .unwrap_or_default();
297
298 let mut scanner_findings = BTreeMap::new();
299 let mut scanner_errors = BTreeMap::new();
300 for (name, stats) in scanner_stats {
301 scanner_findings.insert(name.clone(), stats.findings);
302 scanner_errors.insert(name, stats.errors);
303 }
304
305 info!(
306 findings = findings.len(),
307 errors = errors.len(),
308 scanned,
309 skipped,
310 elapsed_ms = elapsed.as_millis(),
311 "Scan lifecycle: completed"
312 );
313
314 info!(
315 finished_at = %chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
316 findings = findings.len(),
317 scanned,
318 elapsed_secs = elapsed.as_secs_f64(),
319 "Scan finished"
320 );
321
322 RunResult {
323 findings,
324 errors,
325 elapsed,
326 scanned,
327 skipped,
328 metrics: RuntimeMetrics {
329 http_requests: primary_http_metrics.requests_sent
330 + secondary_http_metrics.requests_sent,
331 http_retries: primary_http_metrics.retries_performed
332 + secondary_http_metrics.retries_performed,
333 scanner_findings,
334 scanner_errors,
335 },
336 }
337}
338
339async fn scan_url_with_results(
342 url: String,
343 client: &HttpClient,
344 scanners: &[RegisteredScanner],
345 config: &Config,
346 reporter: &Reporter,
347 channels: (
348 mpsc::UnboundedSender<Vec<Finding>>,
349 mpsc::UnboundedSender<Vec<CapturedError>>,
350 ),
351 stream_seen: StreamSeenSet,
352) -> (UrlScanSummary, ScannerStatsMap) {
353 debug!(url = %url, scanners = scanners.len(), "Scanning URL");
354
355 let mut scanner_set: JoinSet<(String, Vec<Finding>, Vec<CapturedError>)> = JoinSet::new();
356 let mut summary = UrlScanSummary::default();
357 let mut scanner_stats: ScannerStatsMap = BTreeMap::new();
358
359 for scanner in scanners {
360 let scanner_name = scanner.name().to_string();
361 let s = Arc::clone(&scanner.scanner);
362 let u = url.clone();
363 let client = client.clone();
364 let cfg = config.clone();
365
366 scanner_set.spawn(async move {
367 let (findings, errors) = s.scan(&u, &client, &cfg).await;
368 (scanner_name, findings, errors)
369 });
370 }
371
372 while let Some(result) = scanner_set.join_next().await {
373 match result {
374 Ok((scanner_name, mut f, e)) => {
375 let stats = scanner_stats.entry(scanner_name).or_default();
376 stats.findings += f.len();
377 stats.errors += e.len();
378
379 summary.findings += f.len();
380 summary.critical += f
381 .iter()
382 .filter(|x| matches!(x.severity, crate::reports::Severity::Critical))
383 .count();
384 summary.high += f
385 .iter()
386 .filter(|x| matches!(x.severity, crate::reports::Severity::High))
387 .count();
388 summary.medium += f
389 .iter()
390 .filter(|x| matches!(x.severity, crate::reports::Severity::Medium))
391 .count();
392
393 for finding in &mut f {
394 if finding.url.is_empty() {
395 finding.url = url.clone();
396 }
397 }
398 if reporter.stream_enabled() {
399 let mut seen = stream_seen.lock().await;
400 for finding in &f {
401 if seen.insert((finding.url.clone(), finding.check.clone())) {
402 reporter.flush_finding(finding);
403 }
404 }
405 }
406 if !f.is_empty() {
407 let (ftx, _) = &channels;
408 let _ = ftx.send(f);
409 }
410 if !e.is_empty() {
411 let (_, etx) = &channels;
412 let _ = etx.send(e);
413 }
414 }
415 Err(join_err) => {
416 let ce = CapturedError::internal(format!("Scanner panic on {url}: {join_err}"));
417 let (_, etx) = &channels;
418 let _ = etx.send(vec![ce]);
419 }
420 }
421 }
422
423 (summary, scanner_stats)
424}
425
426fn build_scanners(
429 config: &Config,
430 http_client_b: Option<Arc<HttpClient>>,
431) -> Vec<RegisteredScanner> {
432 let mut scanners: Vec<RegisteredScanner> = Vec::new();
433
434 if config.toggles.cors {
435 scanners.push(RegisteredScanner::new(Arc::new(CorsScanner::new(config))));
436 }
437 if config.toggles.csp {
438 scanners.push(RegisteredScanner::new(Arc::new(CspScanner::new(config))));
439 }
440 if config.toggles.graphql {
441 scanners.push(RegisteredScanner::new(Arc::new(GraphqlScanner::new(
442 config,
443 ))));
444 }
445 if config.toggles.api_security {
446 scanners.push(RegisteredScanner::new(Arc::new(ApiSecurityScanner::new(
447 config,
448 http_client_b.clone(),
449 ))));
450 }
451 if config.toggles.jwt {
452 scanners.push(RegisteredScanner::new(Arc::new(JwtScanner::new(config))));
453 }
454 if config.toggles.openapi {
455 scanners.push(RegisteredScanner::new(Arc::new(OpenApiScanner::new(
456 config,
457 ))));
458 }
459 if config.active_checks {
460 if config.toggles.mass_assignment {
461 scanners.push(RegisteredScanner::new(Arc::new(
462 MassAssignmentScanner::new(config),
463 )));
464 }
465 if config.toggles.oauth_oidc {
466 scanners.push(RegisteredScanner::new(Arc::new(OAuthOidcScanner::new(
467 config,
468 ))));
469 }
470 if config.toggles.rate_limit {
471 scanners.push(RegisteredScanner::new(Arc::new(RateLimitScanner::new(
472 config,
473 ))));
474 }
475 if config.toggles.cve_templates {
476 scanners.push(RegisteredScanner::new(Arc::new(CveTemplateScanner::new(
477 config,
478 ))));
479 }
480 if config.toggles.websocket {
481 scanners.push(RegisteredScanner::new(Arc::new(WebSocketScanner::new(
482 config,
483 ))));
484 }
485 }
486
487 if scanners.is_empty() {
488 warn!("All scanners disabled");
489 } else if scanners.len() > 1 && should_shuffle_scanners() {
490 let mut rng = rand::thread_rng();
492 scanners.shuffle(&mut rng);
493 }
494
495 scanners
496}
497
/// Cheap-to-clone handle around a dynamically dispatched scanner.
#[derive(Clone)]
struct RegisteredScanner {
    /// Shared scanner implementation; `Arc` so worker tasks can clone it.
    scanner: Arc<dyn Scanner>,
}

impl RegisteredScanner {
    /// Wraps an already-constructed scanner.
    fn new(scanner: Arc<dyn Scanner>) -> Self {
        Self { scanner }
    }

    /// The scanner's stable name, used as the key in stats maps.
    fn name(&self) -> &'static str {
        self.scanner.name()
    }
}
512
513fn dedup(raw: Vec<String>) -> (Vec<String>, usize) {
516 let mut seen = HashSet::with_capacity(raw.len());
517 let mut unique = Vec::with_capacity(raw.len());
518 let mut dropped = 0usize;
519
520 for raw_url in raw {
521 let canonical = match canonicalise(&raw_url) {
522 Some(c) => c,
523 None => {
524 debug!(url = %raw_url, "URL canonicalization failed; using raw value");
525 raw_url.clone()
526 }
527 };
528
529 if seen.insert(canonical.clone()) {
530 unique.push(canonical);
531 } else {
532 dropped += 1;
533 debug!(url = %raw_url, "Duplicate URL dropped");
534 }
535 }
536
537 (unique, dropped)
538}
539
540fn canonicalise(raw: &str) -> Option<String> {
541 let mut u = Url::parse(raw).ok()?;
542
543 let host = u.host_str()?.to_ascii_lowercase();
544 u.set_host(Some(&host)).ok()?;
545
546 let default_port = match u.scheme() {
547 "https" => Some(443),
548 "http" => Some(80),
549 _ => None,
550 };
551 if u.port() == default_port {
552 u.set_port(None).ok()?;
553 }
554
555 u.set_fragment(None);
556
557 let path = u.path().to_owned();
558 if path.len() > 1 && path.ends_with('/') {
559 u.set_path(path.trim_end_matches('/'));
560 }
561
562 if let Some(query) = u.query() {
563 let mut pairs: Vec<(String, String)> = url::form_urlencoded::parse(query.as_bytes())
564 .map(|(k, v)| (k.into_owned(), v.into_owned()))
565 .collect();
566 if pairs.is_empty() {
567 u.set_query(None);
568 } else {
569 pairs.sort_unstable();
570 let mut serializer = url::form_urlencoded::Serializer::new(String::new());
571 for (k, v) in pairs {
572 serializer.append_pair(&k, &v);
573 }
574 let normalised_query = serializer.finish();
575 u.set_query(Some(&normalised_query));
576 }
577 }
578
579 Some(u.to_string())
580}
581
/// Runs every discovery source once per site and merges the results.
///
/// Sites are derived from the seeds via [`group_seeds_by_site`]. For each
/// site, six sources (robots.txt, sitemaps, swagger, JS, headers, common
/// paths) run concurrently, each under its own timeout. Returns the
/// discovered absolute URLs, all captured errors, and the number of URLs
/// dropped by the per-site `max_endpoints` cap.
async fn run_discovery_per_site(
    seeds: &[String],
    config: &Config,
    client: &HttpClient,
) -> (Vec<String>, Vec<CapturedError>, usize) {
    // Upper bounds on how many sitemap files / JS scripts are fetched per site.
    const MAX_SITEMAPS: usize = 5;
    const MAX_SCRIPTS: usize = 10;
    // Discovery gets twice the normal request timeout, with a 1s floor.
    let discovery_timeout_secs = config.politeness.timeout_secs.saturating_mul(2).max(1);
    let discovery_timeout = Duration::from_secs(discovery_timeout_secs);

    let mut all_discovered: HashSet<String> = HashSet::new();
    let mut all_errors: Vec<CapturedError> = Vec::new();
    let mut all_capped_dropped = 0usize;

    let sites = group_seeds_by_site(seeds);

    for (base, (host, site_seeds)) in sites {
        // JS discovery crawls scripts starting from a concrete seed page.
        let js_seed = match site_seeds.first() {
            Some(s) => s.as_str(),
            None => continue,
        };

        let mut site_discovered: HashSet<String> = HashSet::new();
        let mut errors: Vec<CapturedError> = Vec::new();
        let robots = RobotsDiscovery::new(client, &base, &host);
        let sitemap = SitemapDiscovery::new(client, &base, &host, MAX_SITEMAPS);
        let swagger = SwaggerDiscovery::new(client, &base, &host);
        let js = JsDiscovery::new(client, js_seed, &host, MAX_SCRIPTS);
        let headers = HeaderDiscovery::new(client, &base, &host);
        let common_paths = CommonPathDiscovery::new(client, &base, config.concurrency, Vec::new());

        // Each source gets its own timeout so one slow source cannot stall
        // the whole site's discovery.
        let robots_fut =
            run_discovery_with_timeout(&base, "robots", discovery_timeout, robots.run());
        let sitemap_fut =
            run_discovery_with_timeout(&base, "sitemap", discovery_timeout, sitemap.run());
        let swagger_fut =
            run_discovery_with_timeout(&base, "swagger", discovery_timeout, swagger.run());
        let js_fut = run_discovery_with_timeout(&base, "js", discovery_timeout, js.run());
        let headers_fut =
            run_discovery_with_timeout(&base, "headers", discovery_timeout, headers.run());
        let common_paths_fut = run_discovery_with_timeout(
            &base,
            "common-paths",
            discovery_timeout,
            common_paths.run(),
        );

        // All six sources run concurrently for this site.
        // (NB: `common_paths` below shadows the discovery struct with the
        // resulting path set.)
        let (
            (robots_paths, robots_errs),
            (sitemap_paths, sitemap_errs),
            (swagger_paths, swagger_errs),
            (js_paths, js_errs),
            (header_paths, header_errs),
            (common_paths, common_errs),
        ) = tokio::join!(
            robots_fut,
            sitemap_fut,
            swagger_fut,
            js_fut,
            headers_fut,
            common_paths_fut
        );

        errors.extend(robots_errs);
        insert_paths(&base, robots_paths, &mut site_discovered);

        errors.extend(sitemap_errs);
        insert_paths(&base, sitemap_paths, &mut site_discovered);

        errors.extend(swagger_errs);
        insert_paths(&base, swagger_paths, &mut site_discovered);

        errors.extend(js_errs);
        insert_paths(&base, js_paths, &mut site_discovered);

        errors.extend(header_errs);
        insert_paths(&base, header_paths, &mut site_discovered);

        errors.extend(common_errs);
        insert_paths(&base, common_paths, &mut site_discovered);

        // max_endpoints == 0 means "no cap".
        let max_per_site = if config.max_endpoints == 0 {
            usize::MAX
        } else {
            config.max_endpoints
        };

        let site_urls: Vec<String> = site_discovered.into_iter().collect();
        let capped_count = site_urls.len().min(max_per_site);
        let dropped_by_cap = site_urls.len().saturating_sub(capped_count);
        all_capped_dropped += dropped_by_cap;

        if site_urls.len() > max_per_site {
            debug!(
                site = %host,
                discovered = site_urls.len(),
                capped = capped_count,
                dropped_by_cap,
                "Site endpoints capped"
            );
        }

        // NOTE(review): `site_urls` comes from a HashSet, so *which* URLs
        // survive the cap is not deterministic — confirm this is acceptable.
        all_discovered.extend(site_urls.into_iter().take(capped_count));
        all_errors.extend(errors);
    }

    let urls = all_discovered.into_iter().collect();
    (urls, all_errors, all_capped_dropped)
}
693
694fn group_seeds_by_site(seeds: &[String]) -> BTreeMap<String, (String, Vec<String>)> {
695 let mut sites: BTreeMap<String, (String, Vec<String>)> = BTreeMap::new();
696
697 for seed in seeds {
698 let parsed = match Url::parse(seed) {
699 Ok(u) => u,
700 Err(_) => continue,
701 };
702
703 let host = match parsed.host_str() {
704 Some(h) => h.to_string(),
705 None => continue,
706 };
707
708 let base = {
709 let mut b = format!("{}://{}", parsed.scheme(), host);
710 if let Some(port) = parsed.port() {
711 b.push_str(&format!(":{port}"));
712 }
713 b
714 };
715
716 sites
717 .entry(base)
718 .and_modify(|(_, list)| list.push(seed.clone()))
719 .or_insert_with(|| (host, vec![seed.clone()]));
720 }
721
722 sites
723}
724
/// Joins each discovered path onto the site base and inserts the absolute
/// URLs into `out`.
///
/// The base's trailing slash is stripped once so the concatenation never
/// doubles the separator. NOTE(review): paths are assumed to start with
/// '/' — confirm against the discovery sources.
fn insert_paths(base: &str, paths: HashSet<String>, out: &mut HashSet<String>) {
    let trimmed = base.trim_end_matches('/');
    out.extend(paths.into_iter().map(|path| format!("{trimmed}{path}")));
}
732
733async fn run_discovery_with_timeout<F>(
734 base: &str,
735 step: &'static str,
736 timeout: Duration,
737 fut: F,
738) -> (HashSet<String>, Vec<CapturedError>)
739where
740 F: Future<Output = (HashSet<String>, Vec<CapturedError>)>,
741{
742 match tokio::time::timeout(timeout, fut).await {
743 Ok((paths, errs)) => (paths, errs),
744 Err(_) => (
745 HashSet::new(),
746 vec![CapturedError::from_str(
747 format!("discovery/{step}"),
748 Some(base.to_string()),
749 format!(
750 "Discovery step '{step}' timed out after {}s",
751 timeout.as_secs()
752 ),
753 )],
754 ),
755 }
756}
757
758fn sort_findings(findings: &mut [Finding]) {
759 findings.sort_by(|a, b| {
760 b.severity
761 .rank()
762 .cmp(&a.severity.rank())
763 .then_with(|| a.url.cmp(&b.url))
764 .then_with(|| a.check.cmp(&b.check))
765 });
766}
767
768fn dedup_errors(errors: &mut Vec<CapturedError>) {
769 let mut seen = HashSet::new();
770 errors.retain(|error| {
771 seen.insert((
772 error.context.clone(),
773 error.url.clone(),
774 error.error_type.clone(),
775 error.message.clone(),
776 ))
777 });
778}
779
780fn merge_scanner_stats(target: &mut ScannerStatsMap, batch: ScannerStatsMap) {
781 for (name, stats) in batch {
782 let entry = target.entry(name).or_default();
783 entry.findings += stats.findings;
784 entry.errors += stats.errors;
785 }
786}
787
/// Decides whether the scanner registry should be shuffled.
///
/// Shuffling is suppressed under test builds (`cfg!(test)`), when the test
/// harness env var `RUST_TEST_THREADS` is set (covers integration tests),
/// or when `APIHUNTER_DETERMINISTIC_SCANNER_ORDER` explicitly requests a
/// deterministic order.
fn should_shuffle_scanners() -> bool {
    let running_tests = cfg!(test) || std::env::var_os("RUST_TEST_THREADS").is_some();
    let forced_deterministic =
        std::env::var_os("APIHUNTER_DETERMINISTIC_SCANNER_ORDER").is_some();
    !running_tests && !forced_deterministic
}