1mod auth;
12mod aws;
13mod credential;
14mod multi_step;
15mod request;
16mod response;
17
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::Duration;
21
22use keyhog_core::{VerificationResult, VerifiedFinding};
23use parking_lot::Mutex;
24use reqwest::Client;
25use tokio::sync::{Notify, Semaphore};
26use tokio::task::JoinSet;
27
28use crate::cache;
29use crate::{DedupedMatch, VerificationEngine, VerifyConfig, VerifyError, into_finding};
30
31pub(crate) use aws::build_aws_probe;
32pub(crate) use credential::{VerificationAttempt, verify_with_retry};
33pub(crate) use request::{
34 RequestBuildResult, build_request_for_step, execute_request, resolved_client_for_url,
35};
36pub(crate) use response::{
37 body_indicates_error, evaluate_success, extract_metadata, read_response_body,
38};
39
40const DEFAULT_SERVICE_CONCURRENCY: usize = 5;
41
42#[derive(Clone)]
43struct VerifyTaskShared {
44 global_semaphore: Arc<Semaphore>,
45 service_semaphores: Arc<HashMap<Arc<str>, Arc<Semaphore>>>,
46 client: Client,
47 detectors: Arc<HashMap<Arc<str>, keyhog_core::DetectorSpec>>,
48 timeout: Duration,
49 cache: Arc<cache::VerificationCache>,
50 inflight: Arc<Mutex<HashMap<(Arc<str>, Arc<str>), Arc<Notify>>>>,
51 max_inflight_keys: usize,
52 danger_allow_private_ips: bool,
53}
54
55struct InflightGuard {
56 key: (Arc<str>, Arc<str>),
57 inflight: Arc<Mutex<HashMap<(Arc<str>, Arc<str>), Arc<Notify>>>>,
58 notify: Arc<Notify>,
59}
60
61impl Drop for InflightGuard {
62 fn drop(&mut self) {
63 let mut lock = self.inflight.lock();
64 lock.remove(&self.key);
65 self.notify.notify_waiters();
66 }
67}
68
69async fn verify_group_task(shared: VerifyTaskShared, group: DedupedMatch) -> VerifiedFinding {
70 let global = shared.global_semaphore;
71 let service_sem = shared
72 .service_semaphores
73 .get(&*group.service)
74 .cloned()
75 .unwrap_or_else(|| Arc::new(Semaphore::new(DEFAULT_SERVICE_CONCURRENCY)));
76 let client = shared.client;
77 let detector = shared.detectors.get(&*group.detector_id).cloned();
78 let timeout = shared.timeout;
79
80 let cache = shared.cache;
81 let inflight = shared.inflight;
82 let max_inflight_keys = shared.max_inflight_keys;
83
84 let Ok(_global_permit) = global.acquire().await else {
85 return into_finding(
86 group,
87 VerificationResult::Error("semaphore closed".into()),
88 HashMap::new(),
89 );
90 };
91 let Ok(_service_permit) = service_sem.acquire().await else {
92 return into_finding(
93 group,
94 VerificationResult::Error("service semaphore closed".into()),
95 HashMap::new(),
96 );
97 };
98
99 if let Some((cached_result, cached_meta)) = cache.get(&group.credential, &group.detector_id) {
100 return into_finding(group, cached_result, cached_meta);
101 }
102
103 let _inflight_guard = loop {
104 let notify_to_await: Option<Arc<Notify>> = {
105 let mut lock = inflight.lock();
106 if lock.len() >= max_inflight_keys {
107 break None;
108 }
109
110 let key = (group.detector_id.clone(), group.credential.clone());
111 if let Some((cached_result, cached_meta)) =
112 cache.get(&group.credential, &group.detector_id)
113 {
114 return into_finding(group, cached_result, cached_meta);
115 }
116
117 match lock.entry(key.clone()) {
118 std::collections::hash_map::Entry::Occupied(entry) => Some(entry.get().clone()),
119 std::collections::hash_map::Entry::Vacant(entry) => {
120 let notify = Arc::new(Notify::new());
121 entry.insert(notify.clone());
122 break Some(InflightGuard {
123 key,
124 inflight: inflight.clone(),
125 notify,
126 });
127 }
128 }
129 };
130
131 if let Some(notify) = notify_to_await {
132 notify.notified().await;
133 } else {
134 break None;
135 }
136 };
137
138 crate::rate_limit::get_rate_limiter()
140 .wait(&group.service)
141 .await;
142
143 let (verification, metadata) = if let Some(custom_verifier) =
144 keyhog_core::registry::get_verifier_registry().get(&group.detector_id)
145 {
146 custom_verifier.verify(&group).await
147 } else {
148 match &detector {
149 Some(det) => match &det.verify {
150 Some(verify_spec) => {
151 verify_with_retry(
152 &client,
153 verify_spec,
154 &group.credential,
155 &group.companions,
156 timeout,
157 shared.danger_allow_private_ips,
158 )
159 .await
160 }
161 None => (VerificationResult::Unverifiable, HashMap::new()),
162 },
163 None => (VerificationResult::Unverifiable, HashMap::new()),
164 }
165 };
166
167 cache.put(
168 &group.credential,
169 &group.detector_id,
170 verification.clone(),
171 metadata.clone(),
172 );
173
174 into_finding(group, verification, metadata)
175}
176
177impl VerificationEngine {
178 pub fn new(
180 detectors: &[keyhog_core::DetectorSpec],
181 config: VerifyConfig,
182 ) -> Result<Self, VerifyError> {
183 let client = Client::builder()
184 .timeout(config.timeout)
185 .danger_accept_invalid_certs(false)
187 .redirect(reqwest::redirect::Policy::none())
188 .build()
189 .map_err(VerifyError::ClientBuild)?;
190
191 let detector_map: HashMap<Arc<str>, keyhog_core::DetectorSpec> = detectors
192 .iter()
193 .cloned()
194 .map(|d| (d.id.clone().into(), d))
195 .collect();
196
197 let mut service_semaphores = HashMap::new();
198 for d in detectors {
199 service_semaphores
200 .entry(d.service.clone().into())
201 .or_insert_with(|| {
202 Arc::new(Semaphore::new(config.max_concurrent_per_service.max(1)))
203 });
204 }
205
206 Ok(Self {
207 client,
208 detectors: Arc::new(detector_map),
209 service_semaphores: Arc::new(service_semaphores),
210 global_semaphore: Arc::new(Semaphore::new(config.max_concurrent_global.max(1))),
211 timeout: config.timeout,
212 cache: Arc::new(cache::VerificationCache::new(Duration::from_secs(3600))), inflight: Arc::new(Mutex::new(HashMap::new())),
214 max_inflight_keys: config.max_inflight_keys.max(1),
215 danger_allow_private_ips: config.danger_allow_private_ips,
216 })
217 }
218
219 pub async fn verify_all(&self, groups: Vec<DedupedMatch>) -> Vec<VerifiedFinding> {
221 let max_active = self.global_semaphore.available_permits().max(1);
222 let total = groups.len();
223 let shared = VerifyTaskShared {
224 global_semaphore: self.global_semaphore.clone(),
225 service_semaphores: self.service_semaphores.clone(),
226 client: self.client.clone(),
227 detectors: self.detectors.clone(),
228 timeout: self.timeout,
229 cache: self.cache.clone(),
230 inflight: self.inflight.clone(),
231 max_inflight_keys: self.max_inflight_keys,
232 danger_allow_private_ips: self.danger_allow_private_ips,
233 };
234 let mut pending = groups.into_iter();
235 let mut join_set = JoinSet::new();
236
237 while join_set.len() < max_active {
238 let Some(group) = pending.next() else {
239 break;
240 };
241 join_set.spawn(verify_group_task(shared.clone(), group));
242 }
243
244 let mut findings = Vec::with_capacity(total);
245 while let Some(result) = join_set.join_next().await {
246 match result {
247 Ok(finding) => findings.push(finding),
248 Err(e) => tracing::error!("verification task panicked: {}", e),
249 }
250
251 if let Some(group) = pending.next() {
252 join_set.spawn(verify_group_task(shared.clone(), group));
253 }
254 }
255 findings
256 }
257}