use super::spec_driven::AnnotatedOperation;
use reqwest::{Client, Method};
use std::collections::BTreeMap;
use std::time::Duration;
/// Settings controlling how [`run_self_test`] talks to the target API.
#[derive(Debug, Clone)]
pub struct SelfTestConfig {
    /// Base URL that operation paths are appended to.
    pub target_url: String,
    /// When `true`, the HTTP client accepts invalid TLS certificates.
    pub skip_tls_verify: bool,
    /// Timeout passed to the HTTP client builder for every request.
    pub timeout: Duration,
    /// Headers added to every request, after the operation's own header parameters.
    pub extra_headers: Vec<(String, String)>,
    /// Pause used to space out consecutive operations; the individual requests
    /// belonging to one operation are sent back-to-back.
    pub delay_between_requests: Duration,
}

impl Default for SelfTestConfig {
    /// Targets `http://localhost:3000` with TLS verification enabled, a
    /// 15-second timeout, no extra headers and no delay.
    fn default() -> Self {
        let timeout = Duration::from_secs(15);
        SelfTestConfig {
            target_url: String::from("http://localhost:3000"),
            skip_tls_verify: false,
            timeout,
            extra_headers: vec![],
            delay_between_requests: Duration::ZERO,
        }
    }
}
/// Outcome of a single request sent during the self-test.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CaseOutcome {
    /// Case label: `"positive"` for the positive case; negative cases in this
    /// module use `"<category>:<detail>"` (e.g. `"request-body:empty"`).
    pub label: String,
    /// `true` for negative cases, which pass only on a 4xx response; `false`
    /// for the positive case, which passes on any 2xx or 3xx response.
    pub expected_4xx: bool,
    /// HTTP status received, or `0` when sending the request returned an error
    /// (no response was obtained).
    pub actual_status: u16,
    /// Whether `actual_status` matched the expectation given by `expected_4xx`.
    pub passed: bool,
}
/// Results of every case run against one operation.
#[derive(Debug, Clone, serde::Serialize)]
pub struct OperationResult {
    /// HTTP method string exactly as given by the operation (not upper-cased).
    pub method: String,
    /// Path template of the operation, before path-parameter substitution.
    pub path: String,
    /// Outcome of the positive case (the runner in this module always sets it).
    pub positive: Option<CaseOutcome>,
    /// Outcomes of the negative cases that applied to the operation, in the
    /// order they were sent.
    pub negatives: Vec<CaseOutcome>,
}
/// Aggregated results of a self-test run.
#[derive(Debug, Default, Clone, serde::Serialize)]
pub struct SelfTestReport {
    /// Number of operations whose positive case passed.
    pub positive_pass: usize,
    /// Number of operations whose positive case failed.
    pub positive_fail: usize,
    /// Negative cases that received the expected 4xx, counted per category
    /// (the part of the case label before the first `':'`).
    pub negative_caught: BTreeMap<String, usize>,
    /// Negative cases that did NOT receive a 4xx, counted per category.
    pub negative_missed: BTreeMap<String, usize>,
    /// Per-operation details, in the order the operations were tested.
    pub operations: Vec<OperationResult>,
}
impl SelfTestReport {
pub fn all_passed(&self) -> bool {
self.positive_fail == 0 && self.negative_missed.values().sum::<usize>() == 0
}
pub fn render_summary(&self) -> String {
let mut out = String::new();
out.push_str(&format!(
"Positives: {} pass / {} fail\n",
self.positive_pass, self.positive_fail
));
let mut keys: Vec<&String> =
self.negative_caught.keys().chain(self.negative_missed.keys()).collect();
keys.sort();
keys.dedup();
for cat in keys {
let caught = self.negative_caught.get(cat).copied().unwrap_or(0);
let missed = self.negative_missed.get(cat).copied().unwrap_or(0);
let mark = if missed == 0 { "✓" } else { "⚠" };
out.push_str(&format!(
"Negatives [{}]: {} caught / {} missed {}\n",
cat, caught, missed, mark
));
}
out
}
}
/// Tests every operation in `operations` against `config.target_url` and
/// aggregates the outcomes into a [`SelfTestReport`].
///
/// Operations are tested one after another. When
/// `config.delay_between_requests` is non-zero, that pause is inserted
/// *between* consecutive operations, never before the first or after the
/// last. The requests belonging to a single operation are sent back-to-back.
///
/// Negative outcomes are tallied per category. The category is the part of
/// the case label before the first `':'`, or the whole label if it contains
/// no `':'`.
///
/// # Errors
///
/// Returns the `reqwest::Error` raised while building the HTTP client.
/// Failures of individual requests are recorded inside the returned report,
/// not returned as an `Err`.
pub async fn run_self_test(
    operations: &[AnnotatedOperation],
    config: &SelfTestConfig,
) -> Result<SelfTestReport, reqwest::Error> {
    let mut builder = Client::builder().timeout(config.timeout);
    if config.skip_tls_verify {
        builder = builder.danger_accept_invalid_certs(true);
    }
    let client = builder.build()?;

    let mut report = SelfTestReport::default();
    for (index, op) in operations.iter().enumerate() {
        // Space operations out, without a useless wait after the final one.
        if index > 0 && !config.delay_between_requests.is_zero() {
            tokio::time::sleep(config.delay_between_requests).await;
        }

        let result = test_operation(&client, config, op).await;

        if let Some(positive) = &result.positive {
            if positive.passed {
                report.positive_pass += 1;
            } else {
                report.positive_fail += 1;
            }
        }

        for neg in &result.negatives {
            let category = neg
                .label
                .split_once(':')
                .map_or(neg.label.as_str(), |(cat, _)| cat)
                .to_string();
            let tally = if neg.passed {
                &mut report.negative_caught
            } else {
                &mut report.negative_missed
            };
            *tally.entry(category).or_default() += 1;
        }

        report.operations.push(result);
    }
    Ok(report)
}
/// Exercises one operation: a "positive" request built from the operation's
/// samples, then every negative case that applies to it.
///
/// Negative cases, in the order they are sent (all expect a 4xx response):
/// - `request-body:empty` and `request-body:wrong-type` are sent only when
///   the operation declares a request body content type *and* has a sample
///   body. The body is replaced by `{}` and `[]` respectively.
/// - `parameters:missing-query` is sent when there are sample query
///   parameters, with the first one dropped.
/// - `parameters:missing-header` is sent when there are sample header
///   parameters, with the first one dropped.
///
/// A method string that is not a valid HTTP method falls back to `GET`.
async fn test_operation(
    client: &Client,
    config: &SelfTestConfig,
    op: &AnnotatedOperation,
) -> OperationResult {
    let url = build_url(&config.target_url, &op.path, &op.path_params);
    let method = Method::from_bytes(op.method.to_uppercase().as_bytes()).unwrap_or(Method::GET);
    let sample_body = op.sample_body.as_deref();

    let positive = send_case(
        client,
        config,
        method.clone(),
        &url,
        "positive",
        false,
        sample_body,
        op.query_params.clone(),
        op.header_params.clone(),
    )
    .await;

    // Each planned negative case: (label, body, query params, header params).
    let mut plan = Vec::new();
    if op.request_body_content_type.is_some() && sample_body.is_some() {
        plan.push((
            "request-body:empty",
            Some("{}"),
            op.query_params.clone(),
            op.header_params.clone(),
        ));
        plan.push((
            "request-body:wrong-type",
            Some("[]"),
            op.query_params.clone(),
            op.header_params.clone(),
        ));
    }
    if let Some((_, remaining_query)) = op.query_params.split_first() {
        plan.push((
            "parameters:missing-query",
            sample_body,
            remaining_query.to_vec(),
            op.header_params.clone(),
        ));
    }
    if let Some((_, remaining_headers)) = op.header_params.split_first() {
        plan.push((
            "parameters:missing-header",
            sample_body,
            op.query_params.clone(),
            remaining_headers.to_vec(),
        ));
    }

    let mut negatives = Vec::with_capacity(plan.len());
    for (label, body, query, headers) in plan {
        let outcome = send_case(
            client,
            config,
            method.clone(),
            &url,
            label,
            true,
            body,
            query,
            headers,
        )
        .await;
        negatives.push(outcome);
    }

    OperationResult {
        method: op.method.clone(),
        path: op.path.clone(),
        positive: Some(positive),
        negatives,
    }
}
/// Sends one request for a test case and records whether it met its expectation.
///
/// `query` and `headers` are the operation parameters for this case.
/// `config.extra_headers` are added after `headers`. When `body` is `Some`,
/// it is sent verbatim with `Content-Type: application/json`.
///
/// A case with `expected_4xx` passes only on a 4xx status. Any other case
/// passes on a 2xx or 3xx status. If sending the request returns an error
/// (e.g. connection failure or timeout), `actual_status` is recorded as `0`,
/// which fails both kinds of case.
// Flat argument list keeps each call site in `test_operation` fully explicit.
#[allow(clippy::too_many_arguments)]
async fn send_case(
    client: &Client,
    config: &SelfTestConfig,
    method: Method,
    url: &str,
    label: &str,
    expected_4xx: bool,
    body: Option<&str>,
    query: Vec<(String, String)>,
    headers: Vec<(String, String)>,
) -> CaseOutcome {
    let mut req = client.request(method, url);
    if !query.is_empty() {
        req = req.query(&query);
    }
    for (name, value) in headers.iter().chain(&config.extra_headers) {
        req = req.header(name, value);
    }
    if let Some(b) = body {
        req = req
            .header(reqwest::header::CONTENT_TYPE, "application/json")
            .body(b.to_string());
    }

    // Every transport error, timeouts included, is recorded as status 0.
    let actual_status = req
        .send()
        .await
        .map_or(0, |resp| resp.status().as_u16());

    let accepted = if expected_4xx { 400..500 } else { 200..400 };
    CaseOutcome {
        label: label.to_string(),
        expected_4xx,
        actual_status,
        passed: accepted.contains(&actual_status),
    }
}
/// Builds the absolute request URL for an operation.
///
/// Each `{name}` placeholder in `path_template` that has a non-empty sample
/// value in `path_params` is replaced by that value, percent-encoded as a
/// single path segment. Without encoding, characters such as `/`, `?`, `#` or
/// braces could change the URL's structure, or be substituted again by a
/// later placeholder. Placeholders without a non-empty sample are left
/// untouched. `target` and the path are joined with exactly one `/`.
fn build_url(target: &str, path_template: &str, path_params: &[(String, String)]) -> String {
    let mut path = path_template.to_string();
    for (name, value) in path_params {
        if value.is_empty() {
            continue;
        }
        let placeholder = format!("{{{}}}", name);
        path = path.replace(&placeholder, &percent_encode_path_segment(value));
    }
    let target = target.trim_end_matches('/');
    if path.starts_with('/') {
        format!("{}{}", target, path)
    } else {
        format!("{}/{}", target, path)
    }
}

/// Percent-encodes `value` (as UTF-8) so that only RFC 3986 unreserved
/// characters (`A-Z a-z 0-9 - . _ ~`) remain literal.
fn percent_encode_path_segment(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut out = String::with_capacity(value.len());
    for byte in value.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~') {
            out.push(char::from(byte));
        } else {
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `AnnotatedOperation` from borrowed sample data. The request
    /// body content type is set to `application/json` exactly when `body` is
    /// `Some`.
    fn op(
        method: &str,
        path: &str,
        body: Option<&str>,
        query: Vec<(&str, &str)>,
        headers: Vec<(&str, &str)>,
        path_params: Vec<(&str, &str)>,
    ) -> AnnotatedOperation {
        AnnotatedOperation {
            method: method.into(),
            path: path.into(),
            features: Vec::new(),
            request_body_content_type: body.map(|_| "application/json".into()),
            sample_body: body.map(|s| s.to_string()),
            query_params: query.into_iter().map(|(a, b)| (a.into(), b.into())).collect(),
            header_params: headers.into_iter().map(|(a, b)| (a.into(), b.into())).collect(),
            path_params: path_params.into_iter().map(|(a, b)| (a.into(), b.into())).collect(),
            response_schema: None,
            security_schemes: Vec::new(),
        }
    }

    #[test]
    fn build_url_substitutes_path_params() {
        let url = build_url(
            "https://api.test/",
            "/users/{id}/posts/{pid}",
            &[("id".into(), "42".into()), ("pid".into(), "7".into())],
        );
        assert_eq!(url, "https://api.test/users/42/posts/7");
    }

    #[test]
    fn build_url_keeps_placeholders_when_no_sample() {
        let url = build_url("https://api.test", "/users/{id}", &[]);
        assert_eq!(url, "https://api.test/users/{id}");
    }

    #[test]
    fn report_summary_calls_out_misses() {
        let r = SelfTestReport {
            positive_pass: 3,
            positive_fail: 0,
            negative_caught: BTreeMap::from([("request-body".into(), 2)]),
            negative_missed: BTreeMap::from([("request-body".into(), 1)]),
            operations: Vec::new(),
        };
        let summary = r.render_summary();
        assert!(summary.contains("Positives: 3 pass / 0 fail"));
        assert!(summary.contains("Negatives [request-body]: 2 caught / 1 missed"));
        assert!(summary.contains("⚠"));
        assert!(!r.all_passed());
    }

    #[test]
    fn report_summary_merges_categories_from_both_maps() {
        // A category present in only one map still gets a line, with 0 on the other side.
        let r = SelfTestReport {
            positive_pass: 0,
            positive_fail: 1,
            negative_caught: BTreeMap::from([("parameters".into(), 1)]),
            negative_missed: BTreeMap::from([("request-body".into(), 2)]),
            operations: Vec::new(),
        };
        assert_eq!(
            r.render_summary(),
            "Positives: 0 pass / 1 fail\n\
             Negatives [parameters]: 1 caught / 0 missed ✓\n\
             Negatives [request-body]: 0 caught / 2 missed ⚠\n"
        );
        assert!(!r.all_passed());
    }

    #[test]
    fn report_all_passed_when_no_miss() {
        let r = SelfTestReport {
            positive_pass: 5,
            positive_fail: 0,
            negative_caught: BTreeMap::from([("parameters".into(), 3)]),
            negative_missed: BTreeMap::new(),
            operations: Vec::new(),
        };
        assert!(r.all_passed());
        assert!(r.render_summary().contains("✓"));
    }

    #[tokio::test]
    async fn run_self_test_against_unreachable_target_marks_all_failed() {
        let cfg = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(200),
            ..Default::default()
        };
        let ops = vec![op(
            "POST",
            "/users",
            Some("{\"name\":\"a\"}"),
            vec![],
            vec![],
            vec![],
        )];
        let report = run_self_test(&ops, &cfg).await.expect("client builds");

        assert_eq!(report.positive_pass, 0);
        assert_eq!(report.positive_fail, 1);
        // A JSON body with no query/header params yields exactly the two
        // request-body negatives, and neither can be caught without a server.
        assert!(report.negative_caught.is_empty());
        assert_eq!(report.negative_missed.get("request-body"), Some(&2));
        assert_eq!(report.negative_missed.values().sum::<usize>(), 2);
        assert!(!report.all_passed());

        // No response was ever received, so every case records status 0.
        assert_eq!(report.operations.len(), 1);
        let result = &report.operations[0];
        assert!(result
            .positive
            .iter()
            .chain(&result.negatives)
            .all(|case| case.actual_status == 0 && !case.passed));
    }

    #[test]
    fn json_serialises_report() {
        let r = SelfTestReport {
            positive_pass: 1,
            positive_fail: 0,
            negative_caught: BTreeMap::new(),
            negative_missed: BTreeMap::new(),
            operations: vec![OperationResult {
                method: "GET".into(),
                path: "/x".into(),
                positive: Some(CaseOutcome {
                    label: "positive".into(),
                    expected_4xx: false,
                    actual_status: 200,
                    passed: true,
                }),
                negatives: Vec::new(),
            }],
        };
        let json = serde_json::to_value(&r).expect("serialises");
        assert_eq!(json["positive_pass"], serde_json::json!(1));
        assert_eq!(json["operations"][0]["positive"]["actual_status"], serde_json::json!(200));
    }
}