1mod auth;
12mod aws;
13mod credential;
14mod multi_step;
15mod request;
16mod response;
17
18use std::collections::HashMap;
19use std::sync::Arc;
20use std::time::Duration;
21
22use dashmap::DashMap;
23use keyhog_core::{VerificationResult, VerifiedFinding};
24use reqwest::Client;
25use tokio::sync::{Notify, Semaphore};
26use tokio::task::JoinSet;
27
28use crate::cache;
29use crate::{into_finding, DedupedMatch, VerificationEngine, VerifyConfig, VerifyError};
30
31pub(crate) use aws::build_aws_probe;
32pub(crate) use credential::{verify_with_retry, VerificationAttempt};
33pub(crate) use request::{
34 build_request_for_step, execute_request, resolved_client_for_url, RequestBuildResult,
35};
36pub(crate) use response::{
37 body_indicates_error, evaluate_success, extract_metadata, read_response_body,
38};
39
/// Permit count for the fallback semaphore that `verify_group_task` creates
/// when a group's service has no entry in the per-service semaphore map.
const DEFAULT_SERVICE_CONCURRENCY: usize = 5;
41
/// Engine state handed to each spawned `verify_group_task`.
///
/// Built once per `verify_all` call from the engine's own fields and then
/// cloned once per spawned task; everything except the scalar flags, the
/// timeout and the `reqwest::Client` is behind an `Arc`.
#[derive(Clone)]
struct VerifyTaskShared {
    /// Caps the number of concurrently running verifications across all services.
    global_semaphore: Arc<Semaphore>,
    /// Per-service concurrency limits, keyed by the detector `service` name.
    service_semaphores: Arc<HashMap<Arc<str>, Arc<Semaphore>>>,
    /// HTTP client forwarded to `verify_with_retry`.
    client: Client,
    /// Detector specs keyed by detector id.
    detectors: Arc<HashMap<Arc<str>, keyhog_core::DetectorSpec>>,
    /// Timeout forwarded to `verify_with_retry`.
    timeout: Duration,
    /// Result cache consulted before verification and updated after it.
    cache: Arc<cache::VerificationCache>,
    /// `(detector_id, credential)` keys currently being verified, each mapped to
    /// the `Notify` that tasks waiting on that key wait on.
    inflight: Arc<DashMap<(Arc<str>, Arc<str>), Arc<Notify>>>,
    /// Once `inflight` holds this many keys or more, tasks verify without
    /// claiming a key (i.e. without in-flight deduplication).
    max_inflight_keys: usize,
    /// Forwarded to `verify_with_retry`; presumably allows requests to private
    /// IP ranges — confirm in `request.rs`.
    danger_allow_private_ips: bool,
    /// Forwarded to `verify_with_retry`; presumably allows plain-HTTP endpoints
    /// — confirm in `request.rs`.
    danger_allow_http: bool,
    /// Optional out-of-band interaction session forwarded to `verify_with_retry`.
    oob_session: Option<Arc<crate::oob::OobSession>>,
}
56
/// Ownership claim on one `(detector_id, credential)` key in the shared
/// in-flight map.
///
/// Created by the task that inserted the key. Its `Drop` impl removes the key
/// from `inflight` and then wakes every task waiting on `notify`.
struct InflightGuard {
    /// The `(detector_id, credential)` pair this guard owns.
    key: (Arc<str>, Arc<str>),
    /// The map `key` was inserted into; the key is removed from it on drop.
    inflight: Arc<DashMap<(Arc<str>, Arc<str>), Arc<Notify>>>,
    /// The same `Notify` stored under `key`; waiters clone it and wait on it.
    notify: Arc<Notify>,
}
62
63impl Drop for InflightGuard {
64 fn drop(&mut self) {
65 self.inflight.remove(&self.key);
70 self.notify.notify_waiters();
71 }
72}
73
74async fn verify_group_task(shared: VerifyTaskShared, group: DedupedMatch) -> VerifiedFinding {
75 let global = shared.global_semaphore;
76 let service_sem = shared
77 .service_semaphores
78 .get(&*group.service)
79 .cloned()
80 .unwrap_or_else(|| Arc::new(Semaphore::new(DEFAULT_SERVICE_CONCURRENCY)));
81 let client = shared.client;
82 let detector = shared.detectors.get(&*group.detector_id).cloned();
83 let timeout = shared.timeout;
84
85 let cache = shared.cache;
86 let inflight = shared.inflight;
87 let max_inflight_keys = shared.max_inflight_keys;
88
89 let Ok(_global_permit) = global.acquire().await else {
90 return into_finding(
91 group,
92 VerificationResult::Error("semaphore closed".into()),
93 HashMap::new(),
94 );
95 };
96 let Ok(_service_permit) = service_sem.acquire().await else {
97 return into_finding(
98 group,
99 VerificationResult::Error("service semaphore closed".into()),
100 HashMap::new(),
101 );
102 };
103
104 if let Some((cached_result, cached_meta)) = cache.get(&group.credential, &group.detector_id) {
105 return into_finding(group, cached_result, cached_meta);
106 }
107
108 let _inflight_guard = loop {
109 let notify_to_await: Option<Arc<Notify>> = {
110 if inflight.len() >= max_inflight_keys {
115 break None;
116 }
117
118 let key = (group.detector_id.clone(), group.credential.clone());
119 if let Some((cached_result, cached_meta)) =
120 cache.get(&group.credential, &group.detector_id)
121 {
122 return into_finding(group, cached_result, cached_meta);
123 }
124
125 match inflight.entry(key.clone()) {
126 dashmap::mapref::entry::Entry::Occupied(entry) => Some(entry.get().clone()),
127 dashmap::mapref::entry::Entry::Vacant(entry) => {
128 let notify = Arc::new(Notify::new());
129 entry.insert(notify.clone());
130 break Some(InflightGuard {
131 key,
132 inflight: inflight.clone(),
133 notify,
134 });
135 }
136 }
137 };
138
139 if let Some(notify) = notify_to_await {
140 notify.notified().await;
141 } else {
142 break None;
143 }
144 };
145
146 let (verification, metadata) = if let Some(custom_verifier) =
147 keyhog_core::registry::get_verifier_registry().get(&group.detector_id)
148 {
149 custom_verifier.verify(&group).await
150 } else {
151 match &detector {
152 Some(det) => match &det.verify {
153 Some(verify_spec) => {
154 verify_with_retry(
155 &client,
156 verify_spec,
157 &group.credential,
158 &group.companions,
159 timeout,
160 shared.danger_allow_private_ips,
161 shared.danger_allow_http,
162 shared.oob_session.as_ref(),
163 )
164 .await
165 }
166 None => (VerificationResult::Unverifiable, HashMap::new()),
167 },
168 None => (VerificationResult::Unverifiable, HashMap::new()),
169 }
170 };
171
172 cache.put(
173 &group.credential,
174 &group.detector_id,
175 verification.clone(),
176 metadata.clone(),
177 );
178
179 into_finding(group, verification, metadata)
180}
181
182impl VerificationEngine {
183 pub fn new(
185 detectors: &[keyhog_core::DetectorSpec],
186 config: VerifyConfig,
187 ) -> Result<Self, VerifyError> {
188 let client = Client::builder()
189 .timeout(config.timeout)
190 .danger_accept_invalid_certs(false)
192 .redirect(reqwest::redirect::Policy::none())
193 .build()
194 .map_err(VerifyError::ClientBuild)?;
195
196 let detector_map: HashMap<Arc<str>, keyhog_core::DetectorSpec> = detectors
197 .iter()
198 .cloned()
199 .map(|d| (d.id.clone().into(), d))
200 .collect();
201
202 let mut service_semaphores = HashMap::new();
203 for d in detectors {
204 service_semaphores
205 .entry(d.service.clone().into())
206 .or_insert_with(|| {
207 Arc::new(Semaphore::new(config.max_concurrent_per_service.max(1)))
208 });
209 }
210
211 Ok(Self {
212 client,
213 detectors: Arc::new(detector_map),
214 service_semaphores: Arc::new(service_semaphores),
215 global_semaphore: Arc::new(Semaphore::new(config.max_concurrent_global.max(1))),
216 timeout: config.timeout,
217 cache: Arc::new(cache::VerificationCache::default_ttl()),
218 inflight: Arc::new(DashMap::new()),
219 max_inflight_keys: config.max_inflight_keys.max(1),
220 danger_allow_private_ips: config.danger_allow_private_ips,
221 danger_allow_http: config.danger_allow_http,
222 oob_session: None,
223 })
224 }
225
226 pub async fn verify_all(&self, groups: Vec<DedupedMatch>) -> Vec<VerifiedFinding> {
228 let max_active = self.global_semaphore.available_permits().max(1);
229 let total = groups.len();
230 let shared = VerifyTaskShared {
231 global_semaphore: self.global_semaphore.clone(),
232 service_semaphores: self.service_semaphores.clone(),
233 client: self.client.clone(),
234 detectors: self.detectors.clone(),
235 timeout: self.timeout,
236 cache: self.cache.clone(),
237 inflight: self.inflight.clone(),
238 max_inflight_keys: self.max_inflight_keys,
239 danger_allow_private_ips: self.danger_allow_private_ips,
240 danger_allow_http: self.danger_allow_http,
241 oob_session: self.oob_session.clone(),
242 };
243 let mut pending = groups.into_iter();
244 let mut join_set = JoinSet::new();
245
246 while join_set.len() < max_active {
247 let Some(group) = pending.next() else {
248 break;
249 };
250 join_set.spawn(verify_group_task(shared.clone(), group));
251 }
252
253 let mut findings = Vec::with_capacity(total);
254 while let Some(result) = join_set.join_next().await {
255 match result {
256 Ok(finding) => findings.push(finding),
257 Err(e) => tracing::error!("verification task panicked: {}", e),
258 }
259
260 if let Some(group) = pending.next() {
261 join_set.spawn(verify_group_task(shared.clone(), group));
262 }
263 }
264 findings
265 }
266
267 pub async fn enable_oob(
277 &mut self,
278 config: crate::oob::OobConfig,
279 ) -> Result<(), crate::oob::InteractshError> {
280 if let Some(old) = self.oob_session.take() {
281 old.shutdown().await;
282 }
283 let session = crate::oob::OobSession::start(self.client.clone(), config).await?;
284 self.oob_session = Some(session);
285 Ok(())
286 }
287
288 pub async fn shutdown_oob(&mut self) {
291 if let Some(session) = self.oob_session.take() {
292 session.shutdown().await;
293 }
294 }
295}
296
297impl Drop for VerificationEngine {
298 fn drop(&mut self) {
299 if let Some(session) = self.oob_session.take() {
310 session.abort_poller_for_drop();
311 }
312 }
313}