simple_waf_scanner/
scanner.rs1use crate::{
2 config::Config,
3 evasion,
4 extractor::DataExtractor,
5 fingerprints::{DetectionResponse, WafDetector},
6 http::{build_client, send_request},
7 payloads::PayloadManager,
8 types::{Finding, ScanResults, ScanSummary},
9};
10use std::collections::HashSet;
11use std::sync::Arc;
12use std::time::Instant;
13use tokio::sync::Semaphore;
14use tokio::time::{sleep, Duration};
15
16pub struct Scanner {
18 config: Config,
19 client: reqwest::Client,
20 payload_manager: PayloadManager,
21 waf_detector: WafDetector,
22 data_extractor: DataExtractor,
23}
24
25impl Scanner {
26 pub async fn new(config: Config) -> crate::error::Result<Self> {
28 config.validate()?;
29
30 let client = build_client(&config)?;
31
32 let payload_manager = if let Some(ref payload_file) = config.payload_file {
33 tracing::info!("Loading custom payloads from: {}", payload_file);
34 PayloadManager::from_file(payload_file).await?
35 } else {
36 tracing::info!("Loading default embedded payloads");
37 PayloadManager::with_defaults()?
38 };
39
40 let waf_detector = WafDetector::new()?;
41 let data_extractor = DataExtractor::new();
42
43 Ok(Self {
44 config,
45 client,
46 payload_manager,
47 waf_detector,
48 data_extractor,
49 })
50 }
51
52 #[tracing::instrument(skip(self), fields(target = %self.config.target))]
54 pub async fn scan(&self) -> crate::error::Result<ScanResults> {
55 let start_time = Instant::now();
56
57 tracing::info!("Starting WAF scan on {}", self.config.target);
58
59 let waf_detected = self.detect_waf().await?;
61
62 if let Some(ref waf_name) = waf_detected {
63 tracing::info!("Detected WAF: {}", waf_name);
64 } else {
65 tracing::info!("No WAF detected");
66 }
67
68 let mut results = ScanResults::new(self.config.target.clone(), waf_detected);
70 let findings = self.test_payloads().await?;
71
72 for finding in findings {
73 results.add_finding(finding);
74 }
75
76 results.sort_by_severity();
78
79 let techniques_used: HashSet<_> = results
80 .findings
81 .iter()
82 .filter_map(|f| f.technique_used.as_ref())
83 .collect();
84
85 results.summary = ScanSummary {
86 total_payloads: self.payload_manager.payloads().len(),
87 successful_bypasses: results.findings.len(),
88 techniques_effective: techniques_used.len(),
89 duration_secs: start_time.elapsed().as_secs_f64(),
90 };
91
92 tracing::info!(
93 "Scan complete. Found {} successful bypasses in {:.2}s",
94 results.summary.successful_bypasses,
95 results.summary.duration_secs
96 );
97
98 Ok(results)
99 }
100
101 async fn detect_waf(&self) -> crate::error::Result<Option<String>> {
103 tracing::debug!("Sending baseline request for WAF detection");
104
105 let response = send_request(&self.client, &self.config.target, None)
106 .await
107 .map_err(|e| {
108 tracing::error!("Connection failed: {}", e);
109 e })?;
111
112 tracing::info!(
114 "Target {} is using HTTP version: {}",
115 self.config.target,
116 response.http_version
117 );
118
119 if response.http_version.contains("HTTP/2") {
120 tracing::info!("✓ HTTP/2 protocol detected - production-ready configuration active");
121 } else {
122 tracing::warn!("⚠ HTTP/1.x detected - some HTTP/2 tests may not apply");
123 }
124
125 let detection_response = DetectionResponse::new(
126 response.status_code,
127 response.headers,
128 response.body,
129 response.cookies,
130 );
131
132 Ok(self.waf_detector.detect(&detection_response))
133 }
134
135 async fn test_payloads(&self) -> crate::error::Result<Vec<Finding>> {
137 let payloads = self.payload_manager.payloads();
138 let semaphore = Arc::new(Semaphore::new(self.config.concurrency));
139 let mut tasks = Vec::new();
140
141 tracing::info!("Testing {} payloads", payloads.len());
142
143 for payload in payloads {
144 for payload_test in &payload.payloads {
145 let technique_variants = evasion::apply_all_techniques(
147 &payload_test.value,
148 self.config.enabled_techniques.as_deref(),
149 );
150
151 for (technique_name, transformed_payload) in technique_variants {
152 let sem = semaphore.clone();
153 let client = self.client.clone();
154 let target = self.config.target.clone();
155 let delay_ms = self.config.delay_ms;
156 let payload_id = payload.id.clone();
157 let severity = payload.info.severity;
158 let category = payload.info.category.clone();
159 let description = payload.info.description.clone();
160 let matchers = payload.matchers.clone();
161 let extractor = self.data_extractor.clone();
162
163 let task = tokio::spawn(async move {
164 let _permit = sem.acquire().await.unwrap();
165
166 if delay_ms > 0 {
168 sleep(Duration::from_millis(delay_ms)).await;
169 }
170
171 let response =
173 send_request(&client, &target, Some(("test", &transformed_payload)))
174 .await;
175
176 match response {
177 Ok(resp) => {
178 let matched = check_matchers(&resp, &matchers);
180
181 if matched {
182 tracing::debug!(
183 "Payload {} matched with technique: {} (HTTP version: {})",
184 payload_id,
185 technique_name,
186 resp.http_version
187 );
188
189 let extracted_data = extractor.extract(
191 &resp.body,
192 &resp.headers,
193 &resp.cookies,
194 );
195
196 Some(Finding {
197 payload_id: payload_id.clone(),
198 severity,
199 category: category.clone(),
200 owasp_category: crate::types::OwaspCategory::from_attack_type(&category),
201 payload_value: transformed_payload,
202 technique_used: if technique_name == "Original" {
203 None
204 } else {
205 Some(technique_name)
206 },
207 response_status: resp.status_code,
208 description,
209 http_version: Some(resp.http_version),
210 extracted_data: if extracted_data.has_data() {
211 Some(extracted_data)
212 } else {
213 None
214 },
215 })
216 } else {
217 None
218 }
219 }
220 Err(e) => {
221 tracing::warn!("Request failed for payload {}: {}", payload_id, e);
222 None
223 }
224 }
225 });
226
227 tasks.push(task);
228 }
229 }
230 }
231
232 let results = futures::future::join_all(tasks).await;
234
235 let findings: Vec<Finding> = results
237 .into_iter()
238 .filter_map(|r| r.ok())
239 .flatten()
240 .collect();
241
242 Ok(findings)
243 }
244}
245
246fn check_matchers(
248 response: &crate::http::HttpResponse,
249 matchers: &[crate::payloads::Matcher],
250) -> bool {
251 for matcher in matchers {
252 match matcher.matcher_type.as_str() {
253 "response_body" => {
254 if matcher.condition == "contains" {
255 for pattern in &matcher.patterns {
256 if response.body.contains(pattern) {
257 return true;
258 }
259 }
260 } else if matcher.condition == "not_contains" {
261 let mut all_not_found = true;
262 for pattern in &matcher.patterns {
263 if response.body.contains(pattern) {
264 all_not_found = false;
265 break;
266 }
267 }
268 if all_not_found {
269 return true;
270 }
271 }
272 }
273 "response_time" => {
274 if matcher.condition == "greater_than" {
275 if let Some(threshold) = matcher.patterns.first() {
276 if let Ok(threshold_ms) = threshold.parse::<u64>() {
277 if response.response_time_ms > threshold_ms {
278 return true;
279 }
280 }
281 }
282 }
283 }
284 "response_status" => {
285 if matcher.condition == "equals" {
286 for pattern in &matcher.patterns {
287 if let Ok(expected_status) = pattern.parse::<u16>() {
288 if response.status_code == expected_status {
289 return true;
290 }
291 }
292 }
293 } else if matcher.condition == "not_equals" {
294 let mut all_different = true;
295 for pattern in &matcher.patterns {
296 if let Ok(expected_status) = pattern.parse::<u16>() {
297 if response.status_code == expected_status {
298 all_different = false;
299 break;
300 }
301 }
302 }
303 if all_different {
304 return true;
305 }
306 }
307 }
308 "response_header" => {
309 if matcher.condition == "contains" {
310 for pattern in &matcher.patterns {
311 for (_, header_value) in &response.headers {
312 if header_value.contains(pattern) {
313 return true;
314 }
315 }
316 }
317 } else if matcher.condition == "not_contains" {
318 let mut found = false;
319 for pattern in &matcher.patterns {
320 for (_, header_value) in &response.headers {
321 if header_value.contains(pattern) {
322 found = true;
323 break;
324 }
325 }
326 if found {
327 break;
328 }
329 }
330 if !found {
331 return true;
332 }
333 }
334 }
335 "ai_refusal_bypass" => {
336 if matcher.condition == "not_contains" {
338 let mut all_not_found = true;
339 for pattern in &matcher.patterns {
340 if response.body.to_lowercase().contains(&pattern.to_lowercase()) {
341 all_not_found = false;
342 break;
343 }
344 }
345 if all_not_found {
346 return true;
347 }
348 }
349 }
350 "system_prompt_leak" => {
351 if matcher.condition == "contains" {
353 for pattern in &matcher.patterns {
354 if response.body.contains(pattern) {
355 return true;
356 }
357 }
358 }
359 }
360 "jailbreak_success" => {
361 if matcher.condition == "contains" {
363 for pattern in &matcher.patterns {
364 if response.body.to_lowercase().contains(&pattern.to_lowercase()) {
365 return true;
366 }
367 }
368 }
369 }
370 "response_json" => {
371 if matcher.condition == "valid" {
373 if serde_json::from_str::<serde_json::Value>(&response.body).is_ok() {
375 return true;
376 }
377 } else if matcher.condition == "contains_field" {
378 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&response.body) {
380 for pattern in &matcher.patterns {
381 if json.get(pattern).is_some() {
382 return true;
383 }
384 }
385 }
386 }
387 }
388 _ => {
389 tracing::warn!("Unknown matcher type: {}", matcher.matcher_type);
390 }
391 }
392 }
393
394 false
395}