1use super::spec_driven::AnnotatedOperation;
22use reqwest::{Client, Method};
23use std::collections::BTreeMap;
24use std::time::Duration;
25
/// Settings that control how a self-test run talks to the target service.
#[derive(Clone, Debug)]
pub struct SelfTestConfig {
    /// Base URL of the service under test; operation paths are appended to it.
    pub target_url: String,
    /// When `true`, the HTTP client accepts invalid TLS certificates.
    pub skip_tls_verify: bool,
    /// Timeout handed to the HTTP client builder.
    pub timeout: Duration,
    /// `(name, value)` header pairs added to every request that is sent.
    pub extra_headers: Vec<(String, String)>,
    /// Pause used to pace outgoing traffic; a zero duration disables pacing.
    pub delay_between_requests: Duration,
}
37
38impl Default for SelfTestConfig {
39 fn default() -> Self {
40 Self {
41 target_url: "http://localhost:3000".into(),
42 skip_tls_verify: false,
43 timeout: Duration::from_secs(15),
44 extra_headers: Vec::new(),
45 delay_between_requests: Duration::from_millis(0),
46 }
47 }
48}
49
50#[derive(Debug, Clone, serde::Serialize)]
52pub struct CaseOutcome {
53 pub label: String,
54 pub expected_4xx: bool,
55 pub actual_status: u16,
56 pub passed: bool,
59}
60
61#[derive(Debug, Clone, serde::Serialize)]
63pub struct OperationResult {
64 pub method: String,
65 pub path: String,
66 pub positive: Option<CaseOutcome>,
67 pub negatives: Vec<CaseOutcome>,
68}
69
70#[derive(Debug, Default, Clone, serde::Serialize)]
72pub struct SelfTestReport {
73 pub positive_pass: usize,
74 pub positive_fail: usize,
75 pub negative_caught: BTreeMap<String, usize>,
78 pub negative_missed: BTreeMap<String, usize>,
81 pub operations: Vec<OperationResult>,
82}
83
84impl SelfTestReport {
85 pub fn all_passed(&self) -> bool {
88 self.positive_fail == 0 && self.negative_missed.values().sum::<usize>() == 0
89 }
90
91 pub fn render_summary(&self) -> String {
95 let mut out = String::new();
96 out.push_str(&format!(
97 "Positives: {} pass / {} fail\n",
98 self.positive_pass, self.positive_fail
99 ));
100 let mut keys: Vec<&String> =
101 self.negative_caught.keys().chain(self.negative_missed.keys()).collect();
102 keys.sort();
103 keys.dedup();
104 for cat in keys {
105 let caught = self.negative_caught.get(cat).copied().unwrap_or(0);
106 let missed = self.negative_missed.get(cat).copied().unwrap_or(0);
107 let mark = if missed == 0 { "✓" } else { "⚠" };
108 out.push_str(&format!(
109 "Negatives [{}]: {} caught / {} missed {}\n",
110 cat, caught, missed, mark
111 ));
112 }
113 out
114 }
115}
116
117pub async fn run_self_test(
122 operations: &[AnnotatedOperation],
123 config: &SelfTestConfig,
124) -> Result<SelfTestReport, reqwest::Error> {
125 let mut builder = Client::builder().timeout(config.timeout);
126 if config.skip_tls_verify {
127 builder = builder.danger_accept_invalid_certs(true);
128 }
129 let client = builder.build()?;
130
131 let mut report = SelfTestReport::default();
132 for op in operations {
133 let result = test_operation(&client, config, op).await;
134 if let Some(p) = &result.positive {
135 if p.passed {
136 report.positive_pass += 1;
137 } else {
138 report.positive_fail += 1;
139 }
140 }
141 for neg in &result.negatives {
142 let cat = neg.label.split(':').next().unwrap_or("other").to_string();
143 if neg.passed {
144 *report.negative_caught.entry(cat).or_insert(0) += 1;
145 } else {
146 *report.negative_missed.entry(cat).or_insert(0) += 1;
147 }
148 }
149 report.operations.push(result);
150 if !config.delay_between_requests.is_zero() {
151 tokio::time::sleep(config.delay_between_requests).await;
152 }
153 }
154 Ok(report)
155}
156
157async fn test_operation(
158 client: &Client,
159 config: &SelfTestConfig,
160 op: &AnnotatedOperation,
161) -> OperationResult {
162 let url = build_url(&config.target_url, &op.path, &op.path_params);
163 let method = Method::from_bytes(op.method.to_uppercase().as_bytes()).unwrap_or(Method::GET);
164
165 let positive = send_case(
167 client,
168 config,
169 method.clone(),
170 &url,
171 "positive",
172 false,
173 op.sample_body.as_deref(),
174 op.query_params.clone(),
175 op.header_params.clone(),
176 )
177 .await;
178
179 let mut negatives = Vec::new();
181
182 if op.request_body_content_type.is_some() && op.sample_body.is_some() {
184 negatives.push(
185 send_case(
186 client,
187 config,
188 method.clone(),
189 &url,
190 "request-body:empty",
191 true,
192 Some("{}"),
193 op.query_params.clone(),
194 op.header_params.clone(),
195 )
196 .await,
197 );
198
199 negatives.push(
203 send_case(
204 client,
205 config,
206 method.clone(),
207 &url,
208 "request-body:wrong-type",
209 true,
210 Some("[]"),
211 op.query_params.clone(),
212 op.header_params.clone(),
213 )
214 .await,
215 );
216 }
217
218 if !op.query_params.is_empty() {
220 let mut q = op.query_params.clone();
221 q.remove(0);
222 negatives.push(
223 send_case(
224 client,
225 config,
226 method.clone(),
227 &url,
228 "parameters:missing-query",
229 true,
230 op.sample_body.as_deref(),
231 q,
232 op.header_params.clone(),
233 )
234 .await,
235 );
236 }
237
238 if !op.header_params.is_empty() {
240 let mut h = op.header_params.clone();
241 h.remove(0);
242 negatives.push(
243 send_case(
244 client,
245 config,
246 method.clone(),
247 &url,
248 "parameters:missing-header",
249 true,
250 op.sample_body.as_deref(),
251 op.query_params.clone(),
252 h,
253 )
254 .await,
255 );
256 }
257
258 OperationResult {
259 method: op.method.clone(),
260 path: op.path.clone(),
261 positive: Some(positive),
262 negatives,
263 }
264}
265
266#[allow(clippy::too_many_arguments)]
267async fn send_case(
268 client: &Client,
269 config: &SelfTestConfig,
270 method: Method,
271 url: &str,
272 label: &str,
273 expected_4xx: bool,
274 body: Option<&str>,
275 query: Vec<(String, String)>,
276 headers: Vec<(String, String)>,
277) -> CaseOutcome {
278 let mut req = client.request(method, url);
279 for (k, v) in &query {
280 req = req.query(&[(k.as_str(), v.as_str())]);
281 }
282 for (k, v) in &headers {
283 req = req.header(k, v);
284 }
285 for (k, v) in &config.extra_headers {
286 req = req.header(k, v);
287 }
288 if let Some(b) = body {
289 req = req
290 .header(reqwest::header::CONTENT_TYPE, "application/json")
291 .body(b.to_string());
292 }
293
294 let actual_status = match req.send().await {
295 Ok(resp) => resp.status().as_u16(),
296 Err(e) if e.is_timeout() => 0,
297 Err(_) => 0,
298 };
299
300 let passed = if expected_4xx {
301 (400..500).contains(&actual_status)
302 } else {
303 (200..400).contains(&actual_status)
304 };
305
306 CaseOutcome {
307 label: label.to_string(),
308 expected_4xx,
309 actual_status,
310 passed,
311 }
312}
313
/// Builds the request URL for an operation.
///
/// Each `{name}` placeholder in `path_template` is replaced by the matching
/// value from `path_params`. A parameter with an empty value is skipped, and
/// placeholders without a parameter are left as-is. The resulting path is
/// joined to `target` with exactly one `/` between them.
///
/// Values are percent-encoded as a single path segment, so characters such
/// as `/`, `?`, `#` or a space cannot alter the structure of the URL.
fn build_url(target: &str, path_template: &str, path_params: &[(String, String)]) -> String {
    let mut path = path_template.to_string();
    for (name, value) in path_params {
        if value.is_empty() {
            continue;
        }
        let placeholder = format!("{{{}}}", name);
        path = path.replace(&placeholder, &encode_path_segment(value));
    }

    let target = target.trim_end_matches('/');
    if path.starts_with('/') {
        format!("{}{}", target, path)
    } else {
        format!("{}/{}", target, path)
    }
}

/// Percent-encodes `value` so it is safe to embed as one URL path segment.
///
/// Bytes allowed in a path segment by RFC 3986 (unreserved characters,
/// sub-delimiters, `:` and `@`) and well-formed `%HH` escapes pass through
/// unchanged, so a value that was already a valid segment is not altered.
/// Every other byte becomes an uppercase `%HH` escape.
fn encode_path_segment(value: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    const SEGMENT_PUNCTUATION: &[u8] = b"-._~!$&'()*+,;=:@";

    let bytes = value.as_bytes();
    let mut encoded = String::with_capacity(bytes.len());
    for (index, &byte) in bytes.iter().enumerate() {
        let starts_escape = byte == b'%'
            && matches!(
                bytes.get(index + 1..index + 3),
                Some([hi, lo]) if hi.is_ascii_hexdigit() && lo.is_ascii_hexdigit()
            );
        let allowed = byte.is_ascii_alphanumeric() || SEGMENT_PUNCTUATION.contains(&byte);

        if allowed || starts_escape {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts borrowed `(name, value)` pairs into owned pairs.
    fn owned_pairs(pairs: Vec<(&str, &str)>) -> Vec<(String, String)> {
        pairs
            .into_iter()
            .map(|(name, value)| (name.into(), value.into()))
            .collect()
    }

    /// Builds an `AnnotatedOperation` fixture; a JSON content type is set
    /// exactly when a sample body is supplied.
    fn op(
        method: &str,
        path: &str,
        body: Option<&str>,
        query: Vec<(&str, &str)>,
        headers: Vec<(&str, &str)>,
        path_params: Vec<(&str, &str)>,
    ) -> AnnotatedOperation {
        let content_type = body.map(|_| "application/json".into());
        let sample_body = body.map(str::to_string);
        AnnotatedOperation {
            method: method.into(),
            path: path.into(),
            features: Vec::new(),
            request_body_content_type: content_type,
            sample_body,
            query_params: owned_pairs(query),
            header_params: owned_pairs(headers),
            path_params: owned_pairs(path_params),
            response_schema: None,
            security_schemes: Vec::new(),
        }
    }

    #[test]
    fn build_url_substitutes_path_params() {
        let params = [("id".into(), "42".into()), ("pid".into(), "7".into())];
        let built = build_url("https://api.test/", "/users/{id}/posts/{pid}", &params);
        assert_eq!(built, "https://api.test/users/42/posts/7");
    }

    #[test]
    fn build_url_keeps_placeholders_when_no_sample() {
        let built = build_url("https://api.test", "/users/{id}", &[]);
        assert_eq!(built, "https://api.test/users/{id}");
    }

    #[test]
    fn report_summary_calls_out_misses() {
        let report = SelfTestReport {
            positive_pass: 3,
            positive_fail: 0,
            negative_caught: BTreeMap::from([("request-body".into(), 2)]),
            negative_missed: BTreeMap::from([("request-body".into(), 1)]),
            operations: Vec::new(),
        };
        let text = report.render_summary();
        assert!(text.contains("Positives: 3 pass / 0 fail"));
        assert!(text.contains("Negatives [request-body]: 2 caught / 1 missed"));
        assert!(text.contains("⚠"));
        assert!(!report.all_passed());
    }

    #[test]
    fn report_all_passed_when_no_miss() {
        let report = SelfTestReport {
            positive_pass: 5,
            positive_fail: 0,
            negative_caught: BTreeMap::from([("parameters".into(), 3)]),
            negative_missed: BTreeMap::new(),
            operations: Vec::new(),
        };
        assert!(report.all_passed());
        assert!(report.render_summary().contains("✓"));
    }

    #[tokio::test]
    async fn run_self_test_against_unreachable_target_marks_all_failed() {
        // Presumably nothing listens on loopback port 1, so every request
        // fails before a response arrives; the short timeout keeps this fast.
        let config = SelfTestConfig {
            target_url: "http://127.0.0.1:1".into(),
            timeout: Duration::from_millis(200),
            ..Default::default()
        };
        let create_user = op(
            "POST",
            "/users",
            Some("{\"name\":\"a\"}"),
            vec![],
            vec![],
            vec![],
        );
        let operations = vec![create_user];

        let report = run_self_test(&operations, &config)
            .await
            .expect("client builds");

        assert_eq!(report.positive_fail, 1);
        let missed_total: usize = report.negative_missed.values().sum();
        assert!(missed_total >= 1);
        assert!(!report.all_passed());
    }

    #[test]
    fn json_serialises_report() {
        let positive = CaseOutcome {
            label: "positive".into(),
            expected_4xx: false,
            actual_status: 200,
            passed: true,
        };
        let operation = OperationResult {
            method: "GET".into(),
            path: "/x".into(),
            positive: Some(positive),
            negatives: Vec::new(),
        };
        let report = SelfTestReport {
            positive_pass: 1,
            positive_fail: 0,
            negative_caught: BTreeMap::new(),
            negative_missed: BTreeMap::new(),
            operations: vec![operation],
        };

        let json = serde_json::to_value(&report).expect("serialises");
        assert_eq!(json["positive_pass"], serde_json::json!(1));
        assert_eq!(
            json["operations"][0]["positive"]["actual_status"],
            serde_json::json!(200)
        );
    }
}