use dashmap::DashSet;
use log::{debug, trace, warn};
use rustc_hash::FxBuildHasher;
use std::{
sync::{mpsc, Arc},
time::Duration,
};
use miette::{Context as _, Error, IntoDiagnostic as _, NamedSource, Result};
use oxc::allocator::Allocator;
use ureq::{Agent, AgentBuilder};
use url::Url;
use crate::{
http::random_ua, walk::Script, ApiKeyExtractor, Config, ScriptMessage, ScriptReceiver,
};
use super::{error::DownloadScriptDiagnostic, ApiKeyError};
/// Messages emitted by [`ApiKeyCollector`] over an [`ApiKeySender`].
#[derive(Debug)]
pub enum ApiKeyMessage {
    /// API keys found in a single scanned script.
    Keys(Vec<ApiKeyError>),
    /// A script could not be parsed; the collector continues with the next one.
    RecoverableFailure(Error),
    /// One script was handed to the extractor (sent whether or not parsing succeeded).
    DidScanScript,
    /// A number of pages were walked (the collector in this file always sends `1`).
    DidScrapePages(usize),
    /// NOTE(review): not sent anywhere in this file; presumably tells consumers to stop — confirm.
    Stop,
}
impl From<Vec<ApiKeyError>> for ApiKeyMessage {
fn from(keys: Vec<ApiKeyError>) -> Self {
Self::Keys(keys)
}
}
pub type ApiKeySender = mpsc::Sender<ApiKeyMessage>;
pub type ApiKeyReceiver = mpsc::Receiver<ApiKeyMessage>;
/// Worker that receives scripts, downloads remote ones, scans them for API keys
/// and reports the results as [`ApiKeyMessage`]s.
#[derive(Debug)]
pub struct ApiKeyCollector {
    /// Shared configuration, passed to the extractor and to each reported key.
    config: Arc<Config>,
    /// Parses script source and finds API keys in it.
    extractor: ApiKeyExtractor,
    /// Incoming scripts / page-walk notifications.
    receiver: ScriptReceiver,
    /// Outgoing findings and progress notifications.
    sender: ApiKeySender,
    /// HTTP client used to download remote scripts.
    agent: Agent,
    /// `User-Agent` header sent with downloads, if set.
    ua: Option<&'static str>,
    /// Exact domains whose scripts are never downloaded.
    skip_domains: DashSet<&'static str, FxBuildHasher>,
    /// Substrings of a URL path that cause the script to be skipped.
    skip_paths: Vec<&'static str>,
    /// Extra `(name, value)` headers added to every download request.
    extra_headers: Vec<(String, String)>,
}
impl ApiKeyCollector {
pub fn new(config: Arc<Config>, recv: ScriptReceiver, sender: ApiKeySender) -> Self {
let agent = AgentBuilder::new().timeout(Duration::from_secs(10)).build();
let skip_domains: DashSet<&'static str, FxBuildHasher> = Default::default();
skip_domains.insert("ajax.googleapis.com");
skip_domains.insert("apis.google.com");
skip_domains.insert("youtube.com");
skip_domains.insert("www.googletagmanager.com");
skip_domains.insert("assets.calendly.com");
skip_domains.insert("cdn.jsdelivr.net");
skip_domains.insert("unpkg.com");
skip_domains.insert("events.framer.com");
let skip_paths: Vec<&'static str> = vec!["jquery", "react", "lodash", "unpkg"];
let extractor = ApiKeyExtractor::new(Arc::clone(&config));
Self {
config,
extractor,
receiver: recv,
sender,
agent,
ua: None,
skip_domains,
skip_paths,
extra_headers: vec![],
}
}
pub fn with_random_ua(mut self, yes: bool) -> Self {
if yes && self.ua.is_none() {
self.ua = Some(random_ua(&mut rand::thread_rng()));
} else {
self.ua = None;
}
self
}
pub fn with_headers<I>(mut self, headers: I) -> Self
where
I: IntoIterator<Item = (String, String)>,
{
self.extra_headers.extend(headers);
self
}
pub fn collect(self) {
while let Ok(msg) = self.receiver.recv() {
match msg {
ScriptMessage::Done => {
break;
}
ScriptMessage::DidWalkPage => {
self.send(ApiKeyMessage::DidScrapePages(1));
}
ScriptMessage::Scripts(scripts) => {
for script in scripts {
match script {
Script::Url(url) => {
if self.should_skip_url(&url) {
continue;
}
debug!("({url}) checking for api keys...");
let js = self.download_script(&url);
match js {
Ok(js) => {
self.parse_and_send(url, &js);
}
#[allow(unused_variables)]
Err(DownloadScriptDiagnostic::NotJavascript(url, ct)) => {
#[cfg(debug_assertions)]
warn!(
"({url}) Skipping non-JS script with content type {ct}"
);
}
Err(e) => {
let report = Error::from(e)
.context(format!("Could not download script at {url}"));
warn!("{report:?}");
}
}
}
Script::Embedded(js, page_url) => self.parse_and_send(page_url, &js),
}
}
}
}
}
}
fn download_script(&self, url: &Url) -> Result<String, DownloadScriptDiagnostic> {
let request = self.agent.get(url.as_str());
let request = if let Some(ua) = self.ua {
request.set("User-Agent", ua)
} else {
request
};
let request = self
.extra_headers
.iter()
.fold(request, |req, (key, value)| req.set(key, value));
let res = request.call()?;
if !res.content_type().contains("javascript") {
return Err(DownloadScriptDiagnostic::NotJavascript(
url.to_string(),
res.content_type().to_string(),
));
}
let js: String = res
.into_string()
.map_err(|e| DownloadScriptDiagnostic::CannotReadBody(url.to_string(), e))?;
trace!("({url}) Downloaded script");
Ok(js)
}
fn parse_and_send(&self, url: Url, script: &str) {
trace!("({url}) Parsing script");
let alloc = Allocator::default();
let extract_result = self
.extractor
.extract_api_keys(&alloc, script)
.with_context(|| format!("Failed to parse script at '{url}'"));
self.sender
.send(ApiKeyMessage::DidScanScript)
.into_diagnostic()
.unwrap();
let api_keys = match extract_result {
Ok(api_keys) => api_keys,
Err(e) => {
self.send(ApiKeyMessage::RecoverableFailure(e));
return;
}
};
if !api_keys.is_empty() {
let num_keys = api_keys.len();
let url_string = url.to_string();
let source = Arc::new(
NamedSource::new(url_string.clone(), script.to_string())
.with_language("javascript"),
);
let api_keys = api_keys
.into_iter()
.map(|api_key| ApiKeyError::new(api_key, url_string.clone(), &source, &self.config))
.collect::<Vec<_>>();
self.sender
.send(ApiKeyMessage::Keys(api_keys))
.into_diagnostic()
.context(format!(
"Failed to send {} keys over channel: channel is closed",
num_keys
))
.unwrap();
}
}
fn should_skip_url(&self, url: &Url) -> bool {
if let Some(domain) = url.domain() {
if self.skip_domains.contains(domain) {
trace!("({url}) URL has an ignored domain, skipping");
return true;
}
}
for skip_path_pattern in &self.skip_paths {
if url.path().contains(skip_path_pattern) {
trace!(
"({url}) URL has a path matching ignored pattern {skip_path_pattern}, skipping"
);
return true;
}
}
false
}
fn send(&self, msg: ApiKeyMessage) {
self.sender
.send(msg)
.into_diagnostic()
.context("Failed to send message over API key channel")
.unwrap();
}
}