use super::spec_driven::{AnnotatedOperation, ApiKeyLocation, SecuritySchemeInfo};
use reqwest::{Client, Method};
use std::collections::BTreeMap;
use std::net::IpAddr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Maximum number of bytes of a request or response body kept in a capture
/// entry (16 KiB); longer bodies are cut back to a UTF-8 character boundary.
const CAPTURE_BODY_CAP_BYTES: usize = 16_384;

/// Upper bound on the schema-driven body mutations sent for one operation.
const SCHEMA_MUTATION_CAP: usize = 12;

/// Non-JSON content types swapped in for JSON request bodies, each paired with
/// the outcome label recorded for that probe.
const CONTENT_TYPE_SWAP_VARIANTS: &[(&str, &str)] = &[
    ("application/xml", "request-body:content-type-mismatch:xml"),
    ("application/yaml", "request-body:content-type-mismatch:yaml"),
    ("multipart/form-data", "request-body:content-type-mismatch:multipart"),
    ("application/x-www-form-urlencoded", "request-body:content-type-mismatch:urlencoded"),
];
/// Settings for one self-test run against a live target.
#[derive(Debug, Clone)]
pub struct SelfTestConfig {
    /// Base URL of the target (scheme, host, optional port); a trailing `/` is ignored.
    pub target_url: String,
    /// Accept invalid TLS certificates on every client in the pool.
    pub skip_tls_verify: bool,
    /// Per-request timeout applied to every client in the pool.
    pub timeout: Duration,
    /// Operator-supplied headers (e.g. credentials) added to probe requests.
    /// Security probes send a copy with the headers matching the operation's
    /// declared auth schemes removed.
    pub extra_headers: Vec<(String, String)>,
    /// Pause between consecutive operations (each operation sends several probes
    /// back to back); zero disables pacing.
    pub delay_between_requests: Duration,
    /// Optional path prefix inserted between `target_url` and each operation path.
    pub base_path: Option<String>,
    /// Local addresses to bind outgoing connections to; one client per address,
    /// rotated round-robin per operation. Empty means the OS default.
    pub source_ips: Vec<IpAddr>,
    /// Client IPs rotated round-robin per operation and advertised through
    /// `geo_source_headers`. Empty disables geo rotation.
    pub geo_source_ips: Vec<IpAddr>,
    /// Header names used to advertise the geo IP; names the operation already
    /// declares (case-insensitively) are left untouched.
    pub geo_source_headers: Vec<String>,
    /// Optional sink that receives one `CaseCapture` per probe sent.
    pub capture: Option<Arc<Mutex<Vec<CaseCapture>>>>,
    /// When set (and `capture` is enabled), captured response bodies are checked
    /// against the operation's response schemas.
    pub validate_response_schemas: bool,
}
/// Full record of one probe exchange, pushed to `SelfTestConfig::capture`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CaseCapture {
    /// Probe label, e.g. `positive` or `request-body:empty`.
    pub label: String,
    /// HTTP method as sent.
    pub method: String,
    /// Request URL including the percent-encoded query string.
    pub url: String,
    /// Request headers keyed by the name the caller used; a later header with the
    /// same name overwrites an earlier one in this map.
    pub request_headers: BTreeMap<String, String>,
    /// Request body, truncated to `CAPTURE_BODY_CAP_BYTES`.
    pub request_body: Option<String>,
    /// Whether `request_body` was truncated.
    pub request_body_truncated: bool,
    /// Response status, or `0` when no response was received.
    pub response_status: u16,
    /// Response headers; values that are not visible ASCII are recorded as "".
    pub response_headers: BTreeMap<String, String>,
    /// Response body, truncated to `CAPTURE_BODY_CAP_BYTES`; `None` when the
    /// request failed.
    pub response_body: Option<String>,
    /// Whether `response_body` was truncated.
    pub response_body_truncated: bool,
    /// Error message when the request could not be completed.
    pub error: Option<String>,
    /// First response-schema violation, filled in after the operation's probes ran
    /// when response-schema validation is enabled.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_schema_error: Option<String>,
}
impl Default for SelfTestConfig {
fn default() -> Self {
Self {
target_url: "http://localhost:3000".into(),
skip_tls_verify: false,
timeout: Duration::from_secs(15),
extra_headers: Vec::new(),
delay_between_requests: Duration::from_millis(0),
base_path: None,
source_ips: Vec::new(),
geo_source_ips: Vec::new(),
geo_source_headers: default_geo_source_headers(),
capture: None,
validate_response_schemas: false,
}
}
}
/// Copies `body` for a capture entry, keeping at most `CAPTURE_BODY_CAP_BYTES` bytes.
///
/// Returns the (possibly shortened) copy and `true` when bytes were dropped. The
/// cut point moves back to the nearest character boundary, so the result is
/// always valid UTF-8.
fn truncate_body_for_capture(body: &str) -> (String, bool) {
    if body.len() > CAPTURE_BODY_CAP_BYTES {
        // Index 0 is always a boundary, so the search cannot come up empty.
        let cut = (0..=CAPTURE_BODY_CAP_BYTES)
            .rev()
            .find(|&idx| body.is_char_boundary(idx))
            .unwrap_or(0);
        (body[..cut].to_owned(), true)
    } else {
        (body.to_owned(), false)
    }
}
/// Header names that carry the rotated client IP when geo rotation is enabled:
/// the generic `X-Forwarded-For` plus the CDN-style `True-Client-IP` and
/// `CF-Connecting-IP`.
pub fn default_geo_source_headers() -> Vec<String> {
    ["X-Forwarded-For", "True-Client-IP", "CF-Connecting-IP"]
        .iter()
        .map(|name| (*name).to_owned())
        .collect()
}
/// Result of a single probe request.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CaseOutcome {
    /// Probe label; for negatives the text before the first `:` is its category.
    pub label: String,
    /// `true` for negative probes (a 4xx is expected), `false` for the positive case
    /// (a 2xx/3xx is expected).
    pub expected_4xx: bool,
    /// Status received, or `0` when the request failed before a response arrived.
    pub actual_status: u16,
    /// Whether `actual_status` fell in the expected range.
    pub passed: bool,
}
/// All probe outcomes for one spec operation.
#[derive(Debug, Clone, serde::Serialize)]
pub struct OperationResult {
    /// Method as written in the spec.
    pub method: String,
    /// Path template as written in the spec (placeholders not substituted).
    pub path: String,
    /// Outcome of the sample request; `test_operation` always fills this in.
    pub positive: Option<CaseOutcome>,
    /// Outcomes of every negative probe, in the order they were sent.
    pub negatives: Vec<CaseOutcome>,
}
/// Aggregated outcome of a self-test run.
#[derive(Debug, Default, Clone, serde::Serialize)]
pub struct SelfTestReport {
    /// Number of operations whose positive case passed.
    pub positive_pass: usize,
    /// Number of operations whose positive case failed.
    pub positive_fail: usize,
    /// Negative probes rejected with a 4xx, counted per category (the label text
    /// before the first `:`).
    pub negative_caught: BTreeMap<String, usize>,
    /// Negative probes that did not get a 4xx, counted per category.
    pub negative_missed: BTreeMap<String, usize>,
    /// Per-operation details, in run order.
    pub operations: Vec<OperationResult>,
}
impl SelfTestReport {
    /// `true` when every positive passed and no negative probe was missed.
    pub fn all_passed(&self) -> bool {
        if self.positive_fail != 0 {
            return false;
        }
        self.negative_missed.values().sum::<usize>() == 0
    }

    /// Heuristic check for a mis-pointed target.
    ///
    /// Returns the shared status when no positive passed, at least 10 failed, and
    /// every recorded positive failed with the same status (for example a 404 on
    /// every route, presumably a wrong base URL or base path). Otherwise `None`.
    pub fn detect_target_misconfiguration(&self) -> Option<u16> {
        if self.positive_pass > 0 || self.positive_fail < 10 {
            return None;
        }
        self.operations
            .iter()
            .filter_map(|op| op.positive.as_ref())
            .try_fold(None, |shared: Option<u16>, case| {
                if case.passed {
                    return Err(());
                }
                match shared {
                    Some(status) if status != case.actual_status => Err(()),
                    _ => Ok(Some(case.actual_status)),
                }
            })
            .ok()
            .flatten()
    }

    /// Renders a human-readable summary: one line of positive totals, then one
    /// line per negative category (sorted), flagged `⚠` when anything was missed.
    pub fn render_summary(&self) -> String {
        let mut text = format!(
            "Positives: {} pass / {} fail\n",
            self.positive_pass, self.positive_fail
        );
        let categories: std::collections::BTreeSet<&String> = self
            .negative_caught
            .keys()
            .chain(self.negative_missed.keys())
            .collect();
        for category in categories {
            let caught = self.negative_caught.get(category).copied().unwrap_or(0);
            let missed = self.negative_missed.get(category).copied().unwrap_or(0);
            let mark = match missed {
                0 => "✓",
                _ => "⚠",
            };
            text += &format!(
                "Negatives [{}]: {} caught / {} missed {}\n",
                category, caught, missed, mark
            );
        }
        text
    }
}
/// Runs the self-test for every operation and aggregates the outcomes.
///
/// Operations run sequentially. Each one gets the next client from the source-IP
/// pool and the next geo IP, both assigned round-robin. `config.delay_between_requests`
/// is slept between consecutive operations. Negative outcomes are tallied by
/// category, which is the label text before the first `:`.
///
/// # Errors
/// Returns the `reqwest` error if an HTTP client in the pool cannot be built.
/// Failed requests do not abort the run; they are recorded with status `0`.
pub async fn run_self_test(
    operations: &[AnnotatedOperation],
    config: &SelfTestConfig,
) -> Result<SelfTestReport, reqwest::Error> {
    let clients = build_client_pool(config)?;
    let client_cursor = AtomicUsize::new(0);
    let geo_cursor = AtomicUsize::new(0);
    let mut report = SelfTestReport::default();
    for (index, op) in operations.iter().enumerate() {
        // Pace *between* operations only; sleeping after the last operation would
        // just delay returning the report.
        if index > 0 && !config.delay_between_requests.is_zero() {
            tokio::time::sleep(config.delay_between_requests).await;
        }
        // `build_client_pool` always returns at least one client.
        let client_idx = client_cursor.fetch_add(1, Ordering::Relaxed) % clients.len();
        let client = &clients[client_idx];
        let geo_ip = if config.geo_source_ips.is_empty() {
            None
        } else {
            let idx = geo_cursor.fetch_add(1, Ordering::Relaxed) % config.geo_source_ips.len();
            Some(config.geo_source_ips[idx])
        };
        let result = test_operation(client, config, op, geo_ip).await;
        if let Some(p) = &result.positive {
            if p.passed {
                report.positive_pass += 1;
            } else {
                report.positive_fail += 1;
            }
        }
        for neg in &result.negatives {
            let category = neg.label.split(':').next().unwrap_or("other").to_string();
            let tally = if neg.passed {
                &mut report.negative_caught
            } else {
                &mut report.negative_missed
            };
            *tally.entry(category).or_insert(0) += 1;
        }
        report.operations.push(result);
    }
    Ok(report)
}
/// Returns `base` plus one header per name in `geo_headers` carrying `geo_ip`.
///
/// A geo header is skipped when a header with the same name (ASCII
/// case-insensitive) is already present, including one added earlier in this
/// call. Without a geo IP the result is a plain copy of `base`.
fn effective_op_headers(
    base: &[(String, String)],
    geo_ip: Option<IpAddr>,
    geo_headers: &[String],
) -> Vec<(String, String)> {
    let mut headers = base.to_vec();
    if let Some(ip) = geo_ip {
        let ip_text = ip.to_string();
        for name in geo_headers {
            let present = headers
                .iter()
                .any(|(existing, _)| existing.eq_ignore_ascii_case(name));
            if !present {
                headers.push((name.clone(), ip_text.clone()));
            }
        }
    }
    headers
}
/// Builds one HTTP client per configured source IP, or a single unbound client
/// when `config.source_ips` is empty.
///
/// Every client uses `config.timeout` and, if requested, accepts invalid TLS
/// certificates. The first builder error aborts the whole pool.
fn build_client_pool(config: &SelfTestConfig) -> Result<Vec<Client>, reqwest::Error> {
    let bindings: Vec<Option<IpAddr>> = if config.source_ips.is_empty() {
        vec![None]
    } else {
        config.source_ips.iter().copied().map(Some).collect()
    };
    bindings
        .into_iter()
        .map(|bind| {
            let mut builder = Client::builder().timeout(config.timeout);
            if config.skip_tls_verify {
                builder = builder.danger_accept_invalid_certs(true);
            }
            if let Some(local) = bind {
                builder = builder.local_address(local);
            }
            builder.build()
        })
        .collect()
}
/// Runs the positive case and every applicable negative probe for one operation.
///
/// Probes, in send order:
/// - `positive`: the spec sample request, expected to succeed (2xx/3xx);
/// - when a request body is declared: `{}` and `[]` bodies, content-type swaps
///   (JSON bodies only), and up to `SCHEMA_MUTATION_CAP` schema-driven mutations;
/// - `parameters:uri-too-long`, then `parameters:bad-path-param` and
///   `parameters:missing-query` when the operation has such parameters;
/// - bad-credential / no-credential probes for each declared security scheme;
/// - `parameters:missing-header` when header parameters are declared;
/// - one OWASP payload per category when an injection target exists.
///
/// Every negative is expected to be rejected with a 4xx. Afterwards, when
/// response-schema validation is enabled, the captures this operation produced
/// are annotated with schema violations.
async fn test_operation(
    client: &Client,
    config: &SelfTestConfig,
    op: &AnnotatedOperation,
    geo_ip: Option<IpAddr>,
) -> OperationResult {
    // Number of captures before this operation ran; the schema pass below only
    // touches entries recorded after this point.
    let sink_start = config.capture.as_ref().and_then(|s| s.lock().ok().map(|g| g.len()));
    let url = build_url_with_base(
        &config.target_url,
        config.base_path.as_deref(),
        &op.path,
        &op.path_params,
    );
    // Unparseable method names fall back to GET rather than aborting the run.
    let method = Method::from_bytes(op.method.to_uppercase().as_bytes()).unwrap_or(Method::GET);
    let op_headers = effective_op_headers(&op.header_params, geo_ip, &config.geo_source_headers);

    let positive = send_case(
        client,
        config,
        method.clone(),
        &url,
        "positive",
        false,
        op.sample_body.as_deref(),
        op.query_params.clone(),
        op_headers.clone(),
    )
    .await;

    let mut negatives = Vec::new();

    // Request-body probes.
    if op.request_body_content_type.is_some() {
        for (label, payload) in [("request-body:empty", "{}"), ("request-body:wrong-type", "[]")] {
            negatives.push(
                send_case(
                    client,
                    config,
                    method.clone(),
                    &url,
                    label,
                    true,
                    Some(payload),
                    op.query_params.clone(),
                    op_headers.clone(),
                )
                .await,
            );
        }

        let is_json_body = op
            .request_body_content_type
            .as_deref()
            .map(|ct| ct.contains("json"))
            .unwrap_or(false);
        if is_json_body {
            let payload = op.sample_body.as_deref().unwrap_or("{}");
            let swap_headers: Vec<(String, String)> = op_headers
                .iter()
                .filter(|(k, _)| !k.eq_ignore_ascii_case("content-type"))
                .cloned()
                .collect();
            // Keep the operator's extra headers (credentials etc.) so the swap changes
            // only the content type; without them a 401 for missing credentials would
            // be counted as the server catching the mismatch.
            let base_extra: Vec<(String, String)> = config
                .extra_headers
                .iter()
                .filter(|(k, _)| !k.eq_ignore_ascii_case("content-type"))
                .cloned()
                .collect();
            for (ct, label) in CONTENT_TYPE_SWAP_VARIANTS {
                let mut extra = base_extra.clone();
                extra.push(("Content-Type".to_string(), (*ct).to_string()));
                negatives.push(
                    send_case_with_extra(
                        client,
                        config,
                        method.clone(),
                        &url,
                        label,
                        true,
                        Some(payload),
                        op.query_params.clone(),
                        swap_headers.clone(),
                        extra,
                    )
                    .await,
                );
            }
        }

        if let (Some(sample_str), Some(schema)) =
            (op.sample_body.as_deref(), op.request_body_schema.as_ref())
        {
            if let Ok(sample) = serde_json::from_str::<serde_json::Value>(sample_str) {
                let mutations = super::schema_mutator::mutate_body(&sample, schema);
                for m in mutations.into_iter().take(SCHEMA_MUTATION_CAP) {
                    let body_str = serde_json::to_string(&m.body).unwrap_or_default();
                    negatives.push(
                        send_case(
                            client,
                            config,
                            method.clone(),
                            &url,
                            &m.label,
                            true,
                            Some(&body_str),
                            op.query_params.clone(),
                            op_headers.clone(),
                        )
                        .await,
                    );
                }
            }
        }
    }

    // Oversized URI: ~9 KB of padding appended as an extra query parameter.
    {
        let pad = "p=".to_string() + &"x".repeat(9_000);
        let bad_url = if url.contains('?') {
            format!("{url}&{pad}")
        } else {
            format!("{url}?{pad}")
        };
        negatives.push(
            send_case(
                client,
                config,
                method.clone(),
                &bad_url,
                "parameters:uri-too-long",
                true,
                op.sample_body.as_deref(),
                op.query_params.clone(),
                op_headers.clone(),
            )
            .await,
        );
    }

    // Invalid value for the first path parameter; the others keep their samples.
    if let Some((first_name, _)) = op.path_params.first() {
        let mut templated = op.path.clone();
        for (name, value) in op.path_params.iter().skip(1) {
            if !value.is_empty() {
                templated = templated.replace(&format!("{{{name}}}"), value);
            }
        }
        templated = templated.replace(&format!("{{{first_name}}}"), "self-test-invalid-id");
        let bad_url = build_url_with_base(
            &config.target_url,
            config.base_path.as_deref(),
            &templated,
            &[],
        );
        negatives.push(
            send_case(
                client,
                config,
                method.clone(),
                &bad_url,
                "parameters:bad-path-param",
                true,
                op.sample_body.as_deref(),
                op.query_params.clone(),
                op_headers.clone(),
            )
            .await,
        );
    }

    // Drop the first declared query parameter.
    if !op.query_params.is_empty() {
        let mut q = op.query_params.clone();
        q.remove(0);
        negatives.push(
            send_case(
                client,
                config,
                method.clone(),
                &url,
                "parameters:missing-query",
                true,
                op.sample_body.as_deref(),
                q,
                op_headers.clone(),
            )
            .await,
        );
    }

    // Auth probes: strip declared credentials once, then add each probe's bad ones.
    let security_probes = build_security_probes(&op.security_schemes);
    if !security_probes.is_empty() {
        let stripped_extra = strip_auth(&config.extra_headers, &op.security_schemes);
        let stripped_headers = strip_auth(&op.header_params, &op.security_schemes);
        let stripped_query = strip_auth_query(&op.query_params, &op.security_schemes);
        for probe in security_probes {
            let mut probe_headers = stripped_headers.clone();
            probe_headers.extend(probe.headers);
            let req_headers =
                effective_op_headers(&probe_headers, geo_ip, &config.geo_source_headers);
            let mut req_query = stripped_query.clone();
            req_query.extend(probe.query);
            negatives.push(
                send_case_with_extra(
                    client,
                    config,
                    method.clone(),
                    &url,
                    &probe.label,
                    true,
                    op.sample_body.as_deref(),
                    req_query,
                    req_headers,
                    stripped_extra.clone(),
                )
                .await,
            );
        }
    }

    // Drop the first declared header (it is also the first entry of `op_headers`).
    if !op.header_params.is_empty() {
        let mut h = op_headers.clone();
        h.remove(0);
        negatives.push(
            send_case(
                client,
                config,
                method.clone(),
                &url,
                "parameters:missing-header",
                true,
                op.sample_body.as_deref(),
                op.query_params.clone(),
                h,
            )
            .await,
        );
    }

    for probe in build_owasp_probes(op) {
        negatives.push(
            send_case(
                client,
                config,
                method.clone(),
                &url,
                &probe.label,
                true,
                probe.body.as_deref(),
                probe.query,
                op_headers.clone(),
            )
            .await,
        );
    }

    annotate_response_schema_errors(config, op, sink_start);

    OperationResult {
        method: op.method.clone(),
        path: op.path.clone(),
        positive: Some(positive),
        negatives,
    }
}

/// Fills `response_schema_error` on the captures recorded since `start`.
///
/// No-op unless response-schema validation is enabled, captures are on, `start`
/// is known and the operation declares response schemas. Entries without a body,
/// or whose status has no schema, are left untouched.
fn annotate_response_schema_errors(
    config: &SelfTestConfig,
    op: &AnnotatedOperation,
    start: Option<usize>,
) {
    if !config.validate_response_schemas || op.response_schemas.is_empty() {
        return;
    }
    let (Some(sink), Some(start)) = (config.capture.as_ref(), start) else {
        return;
    };
    let Ok(mut guard) = sink.lock() else {
        return;
    };
    for entry in guard.iter_mut().skip(start) {
        let Some(body) = entry.response_body.as_deref() else {
            continue;
        };
        let Some(schema) = op.response_schemas.get(&entry.response_status) else {
            continue;
        };
        entry.response_schema_error = validate_body_against_schema(body, schema);
    }
}
/// Checks a captured response body against a JSON Schema.
///
/// Returns `None` when the body is not JSON, the schema does not compile, or the
/// body conforms. Otherwise it returns `at <pointer>: <Kind>`, where the pointer
/// is the JSON-pointer location of the first violation (`/` for the root). `<Kind>`
/// is the name of the violated `ValidationErrorKind` variant.
fn validate_body_against_schema(body: &str, schema: &serde_json::Value) -> Option<String> {
    let parsed: serde_json::Value = serde_json::from_str(body).ok()?;
    let validator = jsonschema::validator_for(schema).ok()?;
    let mut errors = validator.iter_errors(&parsed);
    let first = errors.next()?;
    let path = first.instance_path.to_string();
    let path = if path.is_empty() { "/" } else { path.as_str() };
    // The kind's Debug output is e.g. `Type { kind: Single(Integer) }` or
    // `FalseSchema`. Keep only the leading variant name: splitting on '(' alone
    // left fragments like `Type { kind: Single` for struct-like variants.
    let kind_debug = format!("{:?}", first.kind);
    let kind = kind_debug
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .next()
        .filter(|name| !name.is_empty())
        .unwrap_or("unknown");
    Some(format!("at {path}: {kind}"))
}
/// One injection probe built from a security payload.
#[derive(Debug, Clone)]
struct OwaspProbe {
    /// Outcome label, `owasp:<category>`.
    label: String,
    /// Request body to send; the payload is injected here for body targets.
    body: Option<String>,
    /// Query pairs to send; the payload is injected here for query targets.
    query: Vec<(String, String)>,
}
/// Builds one injection probe per OWASP payload category.
///
/// Each category uses the first payload `SecurityPayloads` returns for it,
/// substituted into the target chosen by `pick_injection_target`. If the operation
/// has no injectable slot, or a category has no payloads, nothing is produced for it.
fn build_owasp_probes(op: &AnnotatedOperation) -> Vec<OwaspProbe> {
    use crate::security_payloads::{SecurityCategory, SecurityPayloads};
    let categories = [
        SecurityCategory::SqlInjection,
        SecurityCategory::Xss,
        SecurityCategory::CommandInjection,
        SecurityCategory::PathTraversal,
        SecurityCategory::Ssti,
        SecurityCategory::LdapInjection,
        SecurityCategory::Xxe,
    ];
    let Some(target) = pick_injection_target(op) else {
        return Vec::new();
    };
    let mut probes = Vec::with_capacity(categories.len());
    for cat in categories {
        let first_payload = SecurityPayloads::get_by_category(cat).into_iter().next();
        let Some(payload) = first_payload else {
            continue;
        };
        let (body, query) = match &target {
            InjectionTarget::Query(idx) => {
                let mut query = op.query_params.clone();
                if let Some((_, value)) = query.get_mut(*idx) {
                    *value = payload.payload.clone();
                }
                (op.sample_body.clone(), query)
            }
            InjectionTarget::BodyStringField(field) => (
                inject_into_body_field(op.sample_body.as_deref(), field, &payload.payload),
                op.query_params.clone(),
            ),
        };
        probes.push(OwaspProbe {
            label: format!("owasp:{}", cat),
            body,
            query,
        });
    }
    probes
}
/// Where an injection payload is placed in a request.
#[derive(Debug, Clone)]
enum InjectionTarget {
    /// Value of the query parameter at this index in `query_params`.
    Query(usize),
    /// Top-level string field of the JSON sample body, by key.
    BodyStringField(String),
}
/// Chooses where to inject payloads for `op`.
///
/// Prefers the first query parameter. Otherwise it uses the first top-level
/// string field of the JSON sample body, in `serde_json::Map` iteration order.
/// Returns `None` when neither exists or the sample is not a JSON object.
fn pick_injection_target(op: &AnnotatedOperation) -> Option<InjectionTarget> {
    if op.query_params.is_empty() {
        let raw = op.sample_body.as_deref()?;
        let parsed: serde_json::Value = serde_json::from_str(raw).ok()?;
        parsed
            .as_object()?
            .iter()
            .find(|(_, value)| value.is_string())
            .map(|(key, _)| InjectionTarget::BodyStringField(key.clone()))
    } else {
        Some(InjectionTarget::Query(0))
    }
}
/// Returns `body` re-serialized with `field` set to the string `payload`.
///
/// Returns `None` when there is no body, it is not valid JSON, or it is not a
/// JSON object. The field is inserted if it was absent.
fn inject_into_body_field(body: Option<&str>, field: &str, payload: &str) -> Option<String> {
    let mut document = serde_json::from_str::<serde_json::Value>(body?).ok()?;
    document
        .as_object_mut()?
        .insert(field.to_owned(), serde_json::Value::String(payload.to_owned()));
    serde_json::to_string(&document).ok()
}
/// One authentication probe: headers and query pairs added to a request from
/// which the operation's declared credentials have been stripped.
// The former `#[allow(clippy::too_many_arguments)]` here was a stray copy of the
// function-level lint allowance; it has no effect on a struct.
#[derive(Debug, Clone)]
struct SecurityProbe {
    /// Outcome label, e.g. `security:bad-bearer` or `security:no-auth`.
    label: String,
    /// Headers added on top of the stripped operation headers.
    headers: Vec<(String, String)>,
    /// Query pairs added on top of the stripped operation query.
    query: Vec<(String, String)>,
}
/// Builds bad-credential probes for an operation's declared security schemes.
///
/// Bearer and Basic each yield at most one probe with a syntactically valid but
/// bogus credential. API keys yield one probe per distinct (location, name) pair.
/// Other scheme kinds add nothing of their own. A final `security:no-auth` probe
/// sends no credentials at all. Returns an empty list when no security is declared.
fn build_security_probes(schemes: &[SecuritySchemeInfo]) -> Vec<SecurityProbe> {
    const INVALID_KEY: &str = "self-test-invalid-key";
    if schemes.is_empty() {
        return Vec::new();
    }
    let mut probes: Vec<SecurityProbe> = Vec::new();
    let (mut bearer_done, mut basic_done) = (false, false);
    let mut apikeys_done: std::collections::BTreeSet<(&'static str, String)> = Default::default();
    for scheme in schemes {
        let probe = match scheme {
            SecuritySchemeInfo::Bearer if !bearer_done => {
                bearer_done = true;
                SecurityProbe {
                    label: "security:bad-bearer".into(),
                    headers: vec![(
                        "Authorization".into(),
                        "Bearer self-test-invalid-token".into(),
                    )],
                    query: Vec::new(),
                }
            }
            SecuritySchemeInfo::Basic if !basic_done => {
                basic_done = true;
                SecurityProbe {
                    label: "security:bad-basic".into(),
                    headers: vec![(
                        "Authorization".into(),
                        "Basic c2VsZi10ZXN0OmludmFsaWQ=".into(),
                    )],
                    query: Vec::new(),
                }
            }
            SecuritySchemeInfo::ApiKey { location, name } => {
                let location_tag = match location {
                    ApiKeyLocation::Header => "header",
                    ApiKeyLocation::Query => "query",
                    ApiKeyLocation::Cookie => "cookie",
                };
                // `insert` returns false for a pair we already probed.
                if !apikeys_done.insert((location_tag, name.clone())) {
                    continue;
                }
                let (headers, query) = match location {
                    ApiKeyLocation::Header => {
                        (vec![(name.clone(), INVALID_KEY.to_string())], Vec::new())
                    }
                    ApiKeyLocation::Query => {
                        (Vec::new(), vec![(name.clone(), INVALID_KEY.to_string())])
                    }
                    ApiKeyLocation::Cookie => (
                        vec![("Cookie".into(), format!("{}={}", name, INVALID_KEY))],
                        Vec::new(),
                    ),
                };
                SecurityProbe {
                    label: format!("security:bad-apikey:{}", name),
                    headers,
                    query,
                }
            }
            _ => continue,
        };
        probes.push(probe);
    }
    probes.push(SecurityProbe {
        label: "security:no-auth".into(),
        headers: Vec::new(),
        query: Vec::new(),
    });
    probes
}
/// Removes credential-bearing headers from `headers`.
///
/// Always drops `Authorization`. It also drops every header named by a
/// header-located API key, and `Cookie` when any API key lives in a cookie. All
/// comparisons are case-insensitive.
fn strip_auth(
    headers: &[(String, String)],
    schemes: &[SecuritySchemeInfo],
) -> Vec<(String, String)> {
    let credential_headers: std::collections::BTreeSet<String> = schemes
        .iter()
        .filter_map(|scheme| match scheme {
            SecuritySchemeInfo::ApiKey {
                location: ApiKeyLocation::Header,
                name,
            } => Some(name.to_lowercase()),
            SecuritySchemeInfo::ApiKey {
                location: ApiKeyLocation::Cookie,
                ..
            } => Some("cookie".to_string()),
            _ => None,
        })
        .collect();
    let mut kept = Vec::new();
    for (key, value) in headers {
        let lowered = key.to_lowercase();
        let is_credential = lowered == "authorization" || credential_headers.contains(&lowered);
        if !is_credential {
            kept.push((key.clone(), value.clone()));
        }
    }
    kept
}
/// Removes query parameters that carry a query-located API key.
///
/// Names are matched exactly (case-sensitive), as query keys are.
fn strip_auth_query(
    query: &[(String, String)],
    schemes: &[SecuritySchemeInfo],
) -> Vec<(String, String)> {
    let key_params: std::collections::BTreeSet<&str> = schemes
        .iter()
        .filter_map(|scheme| match scheme {
            SecuritySchemeInfo::ApiKey {
                location: ApiKeyLocation::Query,
                name,
            } => Some(name.as_str()),
            _ => None,
        })
        .collect();
    let mut kept = Vec::with_capacity(query.len());
    for (key, value) in query {
        if !key_params.contains(key.as_str()) {
            kept.push((key.clone(), value.clone()));
        }
    }
    kept
}
/// Sends one probe request and reports whether its status matched expectations.
///
/// - `headers` are the per-operation headers. `extra_headers` are appended after
///   them: `config.extra_headers` via `send_case`, or a modified copy from callers.
/// - `query` pairs are appended to `url`.
/// - When `body` is present it is sent as-is with `Content-Type: application/json`,
///   unless one of the supplied header lists already contains a `Content-Type`.
/// - The case passes on a 4xx when `expected_4xx`, otherwise on a 2xx/3xx.
///   Transport failures are reported as status `0`, which never passes.
/// - When `config.capture` is set, a `CaseCapture` of the exchange is pushed to it.
#[allow(clippy::too_many_arguments)]
async fn send_case_with_extra(
    client: &Client,
    config: &SelfTestConfig,
    method: Method,
    url: &str,
    label: &str,
    expected_4xx: bool,
    body: Option<&str>,
    query: Vec<(String, String)>,
    headers: Vec<(String, String)>,
    extra_headers: Vec<(String, String)>,
) -> CaseOutcome {
    // Captured response data: (truncated body, response headers, error message).
    type Observed = (Option<(String, bool)>, BTreeMap<String, String>, Option<String>);

    let mut req = client.request(method.clone(), url);
    let mut capture_headers: BTreeMap<String, String> = BTreeMap::new();
    for (k, v) in &query {
        req = req.query(&[(k.as_str(), v.as_str())]);
    }
    if let Some(b) = body {
        // `RequestBuilder::header` appends rather than replaces. Adding the JSON
        // default on top of a caller-supplied Content-Type (the content-type swap
        // probes) would therefore send two conflicting Content-Type headers.
        let caller_sets_content_type = headers
            .iter()
            .chain(extra_headers.iter())
            .any(|(k, _)| k.eq_ignore_ascii_case("content-type"));
        if !caller_sets_content_type {
            req = req.header(reqwest::header::CONTENT_TYPE, "application/json");
            capture_headers.insert("Content-Type".to_string(), "application/json".to_string());
        }
        req = req.body(b.to_string());
    }
    for (k, v) in headers.iter().chain(extra_headers.iter()) {
        req = req.header(k, v);
        capture_headers.insert(k.clone(), v.clone());
    }

    let (actual_status, observed): (u16, Option<Observed>) = match req.send().await {
        Ok(resp) => {
            let status = resp.status().as_u16();
            let observed = if config.capture.is_some() {
                let resp_headers: BTreeMap<String, String> = resp
                    .headers()
                    .iter()
                    .map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap_or("").to_string()))
                    .collect();
                // Record a failed body read instead of silently capturing "".
                match resp.text().await {
                    Ok(text) => Some((Some(truncate_body_for_capture(&text)), resp_headers, None)),
                    Err(e) => Some((
                        None,
                        resp_headers,
                        Some(format!("failed to read response body: {e}")),
                    )),
                }
            } else {
                None
            };
            (status, observed)
        }
        Err(e) => {
            let observed = config
                .capture
                .as_ref()
                .map(|_| (None, BTreeMap::new(), Some(e.to_string())));
            (0, observed)
        }
    };

    let passed = if expected_4xx {
        (400..500).contains(&actual_status)
    } else {
        (200..400).contains(&actual_status)
    };

    if let (Some(sink), Some((resp_body, resp_headers, error))) = (&config.capture, observed) {
        let (request_body, request_body_truncated) = match body {
            Some(b) => {
                let (rb, t) = truncate_body_for_capture(b);
                (Some(rb), t)
            }
            None => (None, false),
        };
        let (response_body, response_body_truncated) = match resp_body {
            Some((rb, t)) => (Some(rb), t),
            None => (None, false),
        };
        let entry = CaseCapture {
            label: label.to_string(),
            method: method.to_string(),
            url: build_query_url(url, &query),
            request_headers: capture_headers,
            request_body,
            request_body_truncated,
            response_status: actual_status,
            response_headers: resp_headers,
            response_body,
            response_body_truncated,
            error,
            response_schema_error: None,
        };
        // A poisoned sink is treated as "capture unavailable"; the probe result
        // itself is still reported.
        if let Ok(mut guard) = sink.lock() {
            guard.push(entry);
        }
    }

    CaseOutcome {
        label: label.to_string(),
        expected_4xx,
        actual_status,
        passed,
    }
}
/// Sends one probe with the operator's `config.extra_headers` appended after
/// `headers`.
///
/// Convenience wrapper over `send_case_with_extra` for probes that do not need
/// to replace or strip the extra headers.
#[allow(clippy::too_many_arguments)]
async fn send_case(
    client: &Client,
    config: &SelfTestConfig,
    method: Method,
    url: &str,
    label: &str,
    expected_4xx: bool,
    body: Option<&str>,
    query: Vec<(String, String)>,
    headers: Vec<(String, String)>,
) -> CaseOutcome {
    let operator_headers = config.extra_headers.clone();
    let outcome = send_case_with_extra(
        client,
        config,
        method,
        url,
        label,
        expected_4xx,
        body,
        query,
        headers,
        operator_headers,
    );
    outcome.await
}
/// Returns `base` with `query` appended as a percent-encoded query string.
///
/// Uses `&` when `base` already contains a `?`, otherwise `?`. An empty `query`
/// returns `base` unchanged.
fn build_query_url(base: &str, query: &[(String, String)]) -> String {
    if query.is_empty() {
        return base.to_string();
    }
    let mut url = String::from(base);
    url.push(if base.contains('?') { '&' } else { '?' });
    for (position, (key, value)) in query.iter().enumerate() {
        if position > 0 {
            url.push('&');
        }
        url.push_str(&urlencoding::encode(key));
        url.push('=');
        url.push_str(&urlencoding::encode(value));
    }
    url
}
/// Builds a request URL without a base path; see `build_url_with_base`.
#[allow(dead_code)] // only used by the unit tests
fn build_url(target: &str, path_template: &str, path_params: &[(String, String)]) -> String {
    build_url_with_base(target, None, path_template, path_params)
}

/// Joins `target`, an optional `base_path` and `path_template` into one URL.
///
/// - Each `{name}` placeholder is replaced verbatim (no percent-encoding) by the
///   matching non-empty value in `path_params`. Placeholders with an empty value
///   or no entry stay in place.
/// - Trailing `/` is removed from `target` and `base_path`, and a leading `/` is
///   added to the base path and the path when missing.
/// - A base path made only of slashes (e.g. `"/"`) adds no prefix, so it does not
///   produce a `//` after the host.
fn build_url_with_base(
    target: &str,
    base_path: Option<&str>,
    path_template: &str,
    path_params: &[(String, String)],
) -> String {
    let mut url = path_template.to_string();
    for (name, value) in path_params {
        let placeholder = format!("{{{}}}", name);
        if !value.is_empty() {
            url = url.replace(&placeholder, value);
        }
    }
    let target = target.trim_end_matches('/');
    // Trim before the emptiness check: `Some("/")` used to become a "/" prefix
    // and yield `https://host//path`.
    let prefix = match base_path.map(|bp| bp.trim_end_matches('/')) {
        Some(bp) if !bp.is_empty() => {
            if bp.starts_with('/') {
                bp.to_string()
            } else {
                format!("/{}", bp)
            }
        }
        _ => String::new(),
    };
    let path = if url.starts_with('/') {
        url
    } else {
        format!("/{url}")
    };
    format!("{target}{prefix}{path}")
}
// Unit tests for URL building, report aggregation, header handling and the probe
// generators. Network tests target 127.0.0.1:1; port 1 on loopback is assumed to
// refuse connections (NOTE(review): confirm on CI), so every probe fails fast with
// status 0 and only the generated probe set is checked.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a minimal `AnnotatedOperation`. A JSON request content type is declared
    // exactly when `body` is `Some`; schemas and security schemes are left empty.
    fn op(
        method: &str,
        path: &str,
        body: Option<&str>,
        query: Vec<(&str, &str)>,
        headers: Vec<(&str, &str)>,
        path_params: Vec<(&str, &str)>,
    ) -> AnnotatedOperation {
        AnnotatedOperation {
            method: method.into(),
            path: path.into(),
            features: Vec::new(),
            request_body_content_type: body.map(|_| "application/json".into()),
            sample_body: body.map(|s| s.to_string()),
            query_params: query.into_iter().map(|(a, b)| (a.into(), b.into())).collect(),
            header_params: headers.into_iter().map(|(a, b)| (a.into(), b.into())).collect(),
            path_params: path_params.into_iter().map(|(a, b)| (a.into(), b.into())).collect(),
            response_schema: None,
            response_schemas: std::collections::BTreeMap::new(),
            request_body_schema: None,
            security_schemes: Vec::new(),
        }
    }

    // Placeholders are substituted and the target's trailing slash is dropped.
    #[test]
    fn build_url_substitutes_path_params() {
        let url = build_url(
            "https://api.test/",
            "/users/{id}/posts/{pid}",
            &[("id".into(), "42".into()), ("pid".into(), "7".into())],
        );
        assert_eq!(url, "https://api.test/users/42/posts/7");
    }

    // 50 failed positives that all share 404 are reported as a misconfiguration.
    #[test]
    fn detect_target_misconfiguration_when_all_positives_share_status() {
        let mut report = SelfTestReport {
            positive_pass: 0,
            positive_fail: 50,
            ..Default::default()
        };
        for i in 0..50 {
            report.operations.push(OperationResult {
                method: "GET".into(),
                path: format!("/r/{i}"),
                positive: Some(CaseOutcome {
                    label: "positive".into(),
                    expected_4xx: false,
                    actual_status: 404,
                    passed: false,
                }),
                negatives: Vec::new(),
            });
        }
        assert_eq!(report.detect_target_misconfiguration(), Some(404));
    }

    // Any passing positive rules out the misconfiguration heuristic.
    #[test]
    fn detect_target_misconfiguration_returns_none_when_some_pass() {
        let mut report = SelfTestReport {
            positive_pass: 5,
            positive_fail: 50,
            ..Default::default()
        };
        for i in 0..55 {
            report.operations.push(OperationResult {
                method: "GET".into(),
                path: format!("/r/{i}"),
                positive: Some(CaseOutcome {
                    label: "positive".into(),
                    expected_4xx: false,
                    actual_status: if i < 5 { 200 } else { 404 },
                    passed: i < 5,
                }),
                negatives: Vec::new(),
            });
        }
        assert_eq!(report.detect_target_misconfiguration(), None);
    }

    // The base path sits between the target and the operation path.
    #[test]
    fn build_url_applies_base_path_when_present() {
        let url = build_url_with_base(
            "https://api.example.com",
            Some("/api"),
            "/users/{id}",
            &[("id".into(), "42".into())],
        );
        assert_eq!(url, "https://api.example.com/api/users/42");
    }

    // Leading/trailing slashes on the base path are normalised; empty/None add nothing.
    #[test]
    fn build_url_normalises_base_path() {
        let no_slash = build_url_with_base("https://t", Some("api"), "/x", &[]);
        assert_eq!(no_slash, "https://t/api/x");
        let trailing = build_url_with_base("https://t", Some("/api/"), "/x", &[]);
        assert_eq!(trailing, "https://t/api/x");
        let empty = build_url_with_base("https://t", Some(""), "/x", &[]);
        assert_eq!(empty, "https://t/x");
        let none = build_url_with_base("https://t", None, "/x", &[]);
        assert_eq!(none, "https://t/x");
    }

    // Placeholders without a sample value are left untouched.
    #[test]
    fn build_url_keeps_placeholders_when_no_sample() {
        let url = build_url("https://api.test", "/users/{id}", &[]);
        assert_eq!(url, "https://api.test/users/{id}");
    }

    // A missed negative shows up in the summary with the warning mark.
    #[test]
    fn report_summary_calls_out_misses() {
        let r = SelfTestReport {
            positive_pass: 3,
            positive_fail: 0,
            negative_caught: BTreeMap::from([("request-body".into(), 2)]),
            negative_missed: BTreeMap::from([("request-body".into(), 1)]),
            operations: Vec::new(),
        };
        let summary = r.render_summary();
        assert!(summary.contains("Positives: 3 pass / 0 fail"));
        assert!(summary.contains("Negatives [request-body]: 2 caught / 1 missed"));
        assert!(summary.contains("⚠"));
        assert!(!r.all_passed());
    }

    // No misses and no positive failures means the run passed.
    #[test]
    fn report_all_passed_when_no_miss() {
        let r = SelfTestReport {
            positive_pass: 5,
            positive_fail: 0,
            negative_caught: BTreeMap::from([("parameters".into(), 3)]),
            negative_missed: BTreeMap::new(),
            operations: Vec::new(),
        };
        assert!(r.all_passed());
        assert!(r.render_summary().contains("✓"));
    }

    // Connection failures (status 0) count as failed positives and missed negatives.
    #[tokio::test]
    async fn run_self_test_against_unreachable_target_marks_all_failed() {
        let cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(200),
            ..Default::default()
        };
        let ops = vec![op(
            "POST",
            "/users",
            Some("{\"name\":\"a\"}"),
            vec![],
            vec![],
            vec![],
        )];
        let report = run_self_test(&ops, &cfg).await.expect("client builds");
        assert_eq!(report.positive_fail, 1);
        assert!(report.negative_missed.values().sum::<usize>() >= 1);
        assert!(!report.all_passed());
    }

    // A request-body schema yields schema-mutation probes alongside the fixed ones.
    #[tokio::test]
    async fn schema_driven_negatives_fire_when_schema_present() {
        use openapiv3::{ObjectType, ReferenceOr, Schema, SchemaData, SchemaKind, Type};
        let cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(200),
            ..Default::default()
        };
        let mut obj = ObjectType::default();
        obj.properties.insert(
            "name".to_string(),
            ReferenceOr::Item(Box::new(Schema {
                schema_data: SchemaData::default(),
                schema_kind: SchemaKind::Type(Type::String(Default::default())),
            })),
        );
        obj.properties.insert(
            "age".to_string(),
            ReferenceOr::Item(Box::new(Schema {
                schema_data: SchemaData::default(),
                schema_kind: SchemaKind::Type(Type::Integer(Default::default())),
            })),
        );
        obj.required = vec!["name".into(), "age".into()];
        let schema = Schema {
            schema_data: SchemaData::default(),
            schema_kind: SchemaKind::Type(Type::Object(obj)),
        };
        let mut o =
            op("POST", "/users", Some(r#"{"name":"Ada","age":30}"#), vec![], vec![], vec![]);
        o.request_body_schema = Some(schema);
        let report = run_self_test(&[o], &cfg).await.expect("client builds");
        let labels: std::collections::BTreeSet<String> = report
            .operations
            .iter()
            .flat_map(|op| op.negatives.iter().map(|n| n.label.clone()))
            .collect();
        assert!(
            labels.iter().any(|l| l.starts_with("request-body:type-mismatch:")),
            "missing type-mismatch negative; got {labels:?}"
        );
        assert!(
            labels.iter().any(|l| l.starts_with("request-body:required-removed:")),
            "missing required-removed negative; got {labels:?}"
        );
        assert!(
            labels.iter().any(|l| l == "parameters:uri-too-long"),
            "missing URI-length negative; got {labels:?}"
        );
    }

    // A declared body content type alone (no sample) still triggers `{}`/`[]` probes.
    #[tokio::test]
    async fn no_sample_body_still_produces_request_body_negatives() {
        let cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(200),
            ..Default::default()
        };
        let ops = vec![op("POST", "/x", None, vec![], vec![], vec![])];
        let mut ops_fixed = ops;
        ops_fixed[0].request_body_content_type = Some("application/json".into());
        let report = run_self_test(&ops_fixed, &cfg).await.expect("client builds");
        assert!(
            report.negative_missed.values().sum::<usize>() >= 2,
            "expected ≥2 request-body negatives, got {:?}",
            report.negative_missed
        );
    }

    // An operation with only a path parameter still gets at least one negative.
    #[tokio::test]
    async fn path_param_only_endpoint_produces_a_probe() {
        let cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(200),
            ..Default::default()
        };
        let ops = vec![op(
            "GET",
            "/teams/{team-id}",
            None,
            vec![],
            vec![],
            vec![("team-id", "1")],
        )];
        let report = run_self_test(&ops, &cfg).await.expect("client builds");
        let total: usize = report.negative_caught.values().sum::<usize>()
            + report.negative_missed.values().sum::<usize>();
        assert!(total >= 1, "expected ≥1 path-param probe, got {:?}", report);
    }

    // Every default geo header is appended with the geo IP as its value.
    #[test]
    fn effective_op_headers_appends_geo_ip_to_default_headers() {
        let ip: IpAddr = "203.0.113.42".parse().unwrap();
        let headers = effective_op_headers(
            &[("Accept".into(), "application/json".into())],
            Some(ip),
            &default_geo_source_headers(),
        );
        let names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect();
        assert!(names.contains(&"Accept"));
        assert!(names.contains(&"X-Forwarded-For"));
        assert!(names.contains(&"True-Client-IP"));
        assert!(names.contains(&"CF-Connecting-IP"));
        let geo_values: Vec<&str> =
            headers.iter().filter(|(k, _)| k != "Accept").map(|(_, v)| v.as_str()).collect();
        for v in geo_values {
            assert_eq!(v, "203.0.113.42");
        }
    }

    // A spec-declared header with the same name (any case) wins over the geo IP.
    #[test]
    fn effective_op_headers_respects_spec_declared_header() {
        let ip: IpAddr = "203.0.113.99".parse().unwrap();
        let headers = effective_op_headers(
            &[("x-forwarded-for".into(), "10.0.0.1".into())],
            Some(ip),
            &["X-Forwarded-For".to_string()],
        );
        let xff: Vec<&str> = headers
            .iter()
            .filter(|(k, _)| k.eq_ignore_ascii_case("x-forwarded-for"))
            .map(|(_, v)| v.as_str())
            .collect();
        assert_eq!(xff, vec!["10.0.0.1"]);
    }

    // Without a geo IP, or with no geo header names, the base headers pass through.
    #[test]
    fn effective_op_headers_is_a_noop_without_geo_ip() {
        let base = vec![("Accept".into(), "json".into())];
        let h1 = effective_op_headers(&base, None, &default_geo_source_headers());
        assert_eq!(h1, base);
        let ip: IpAddr = "10.0.0.1".parse().unwrap();
        let h2 = effective_op_headers(&base, Some(ip), &[]);
        assert_eq!(h2, base);
    }

    // One client when unbound, otherwise one per source IP.
    #[test]
    fn build_client_pool_one_per_source_ip() {
        let mut cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(200),
            ..Default::default()
        };
        assert_eq!(build_client_pool(&cfg).expect("default builds").len(), 1);
        cfg.source_ips = vec!["127.0.0.1".parse().unwrap()];
        assert_eq!(build_client_pool(&cfg).expect("bind loopback").len(), 1);
    }

    // Geo rotation over more operations than IPs completes and reports every operation.
    #[tokio::test]
    async fn run_self_test_with_geo_source_completes() {
        let cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(200),
            geo_source_ips: vec![
                "203.0.113.1".parse().unwrap(),
                "203.0.113.2".parse().unwrap(),
            ],
            ..Default::default()
        };
        let ops = vec![
            op("GET", "/a", None, vec![], vec![], vec![]),
            op("GET", "/b", None, vec![], vec![], vec![]),
            op("GET", "/c", None, vec![], vec![], vec![]),
        ];
        let report = run_self_test(&ops, &cfg).await.expect("client builds");
        assert_eq!(report.operations.len(), 3);
    }

    // Every captured probe, negatives included, carries a geo header with the geo IP.
    #[tokio::test]
    async fn geo_headers_present_on_every_probe_with_capture() {
        let sink: Arc<Mutex<Vec<CaseCapture>>> = Arc::new(Mutex::new(Vec::new()));
        let cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(50),
            geo_source_ips: vec!["203.0.113.5".parse().unwrap()],
            capture: Some(sink.clone()),
            ..Default::default()
        };
        let ops = vec![op(
            "GET",
            "/items",
            Some("{}"),
            vec![("id", "1")],
            vec![("X-Trace", "x")],
            vec![],
        )];
        let _ = run_self_test(&ops, &cfg).await.expect("client builds");
        let captures = sink.lock().unwrap();
        assert!(!captures.is_empty(), "self-test should record probes");
        let geo_headers: std::collections::HashSet<&str> =
            ["X-Forwarded-For", "True-Client-IP", "CF-Connecting-IP"].into_iter().collect();
        for c in captures.iter() {
            let has_geo = c
                .request_headers
                .iter()
                .any(|(k, v)| geo_headers.contains(k.as_str()) && v == "203.0.113.5");
            assert!(
                has_geo,
                "probe `{}` is missing the geo IP header; got headers: {:?}",
                c.label, c.request_headers
            );
        }
    }

    // JSON bodies get exactly one swap probe per variant, each sending its own
    // Content-Type; body-less operations get none.
    #[tokio::test]
    async fn content_type_swap_probes_fire_for_json_bodies() {
        let sink: Arc<Mutex<Vec<CaseCapture>>> = Arc::new(Mutex::new(Vec::new()));
        let cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(50),
            capture: Some(sink.clone()),
            ..Default::default()
        };
        let ops = vec![
            op("POST", "/users", Some("{\"name\":\"a\"}"), vec![], vec![], vec![]),
            op("GET", "/ping", None, vec![], vec![], vec![]),
        ];
        let _ = run_self_test(&ops, &cfg).await.expect("client builds");
        let captures = sink.lock().unwrap();
        let swap_labels: Vec<&str> = captures
            .iter()
            .filter(|c| c.label.starts_with("request-body:content-type-mismatch:"))
            .map(|c| c.label.as_str())
            .collect();
        assert_eq!(
            swap_labels.len(),
            4,
            "expected 4 content-type-swap probes (one per variant), got: {swap_labels:?}"
        );
        let expected_labels = [
            "request-body:content-type-mismatch:xml",
            "request-body:content-type-mismatch:yaml",
            "request-body:content-type-mismatch:multipart",
            "request-body:content-type-mismatch:urlencoded",
        ];
        for want in expected_labels {
            assert!(swap_labels.contains(&want), "missing swap probe `{want}`");
        }
        for c in captures.iter() {
            let Some(suffix) = c.label.strip_prefix("request-body:content-type-mismatch:") else {
                continue;
            };
            let want_ct = match suffix {
                "xml" => "application/xml",
                "yaml" => "application/yaml",
                "multipart" => "multipart/form-data",
                "urlencoded" => "application/x-www-form-urlencoded",
                _ => continue,
            };
            let got_ct = c
                .request_headers
                .iter()
                .find(|(k, _)| k.eq_ignore_ascii_case("content-type"))
                .map(|(_, v)| v.as_str())
                .unwrap_or("");
            assert_eq!(got_ct, want_ct, "swap probe `{}` sent wrong CT", c.label);
        }
        let body_less_swaps = captures
            .iter()
            .filter(|c| {
                c.label.starts_with("request-body:content-type-mismatch:")
                    && c.url.ends_with("/ping")
            })
            .count();
        assert_eq!(
            body_less_swaps, 0,
            "GET /ping has no request body; should not produce content-type-swap probes"
        );
    }

    // The report serialises to JSON with its nested outcomes intact.
    #[test]
    fn json_serialises_report() {
        let r = SelfTestReport {
            positive_pass: 1,
            positive_fail: 0,
            negative_caught: BTreeMap::new(),
            negative_missed: BTreeMap::new(),
            operations: vec![OperationResult {
                method: "GET".into(),
                path: "/x".into(),
                positive: Some(CaseOutcome {
                    label: "positive".into(),
                    expected_4xx: false,
                    actual_status: 200,
                    passed: true,
                }),
                negatives: Vec::new(),
            }],
        };
        let json = serde_json::to_value(&r).expect("serialises");
        assert_eq!(json["positive_pass"], serde_json::json!(1));
        assert_eq!(json["operations"][0]["positive"]["actual_status"], serde_json::json!(200));
    }
}