1use crate::common::{TestConfig, TestResult, TestRunner};
2use crate::error::{Error, Result};
3use async_trait::async_trait;
4use chrono::Utc;
5use jsonschema::JSONSchema;
6use log::{info, warn};
7use reqwest::{Client, Method, Response};
8use serde::{Deserialize, Serialize};
9use std::time::{Duration, Instant};
10use tokio_retry::strategy::{ExponentialBackoff, jitter};
11use tokio_retry::Retry;
12
13#[derive(Debug, Serialize, Deserialize, Default)]
14pub struct RetryConfig {
15 #[serde(default = "default_max_retries")]
16 pub max_retries: u32,
17 #[serde(default = "default_initial_delay")]
18 pub initial_delay_ms: u64,
19 #[serde(default = "default_max_delay")]
20 pub max_delay_ms: u64,
21 #[serde(default = "default_retry_status_codes")]
22 pub retry_status_codes: Vec<u16>,
23 #[serde(default)]
24 pub retry_on_timeout: bool,
25 #[serde(default)]
26 pub retry_on_connection_error: bool,
27}
28
/// Serde default for `RetryConfig::max_retries`: three retries after the first attempt.
fn default_max_retries() -> u32 {
    const RETRIES: u32 = 3;
    RETRIES
}

/// Serde default for `RetryConfig::initial_delay_ms`: 100 milliseconds.
fn default_initial_delay() -> u64 {
    const INITIAL_DELAY_MS: u64 = 100;
    INITIAL_DELAY_MS
}

/// Serde default for `RetryConfig::max_delay_ms`: 5 seconds, expressed in milliseconds.
fn default_max_delay() -> u64 {
    const MAX_DELAY_MS: u64 = 5_000;
    MAX_DELAY_MS
}

/// Serde default for `RetryConfig::retry_status_codes`.
///
/// 408 Request Timeout and 429 Too Many Requests, plus the 5xx codes commonly treated
/// as transient server-side failures (500, 502, 503, 504).
fn default_retry_status_codes() -> Vec<u16> {
    [408_u16, 429, 500, 502, 503, 504].to_vec()
}
44
/// Configuration for a single HTTP API test, deserialized from the generic test config.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiTestConfig {
    /// Common test settings, flattened into the same JSON object; `base.name` is used as
    /// the name of the produced `TestResult`.
    #[serde(flatten)]
    pub base: TestConfig,
    /// Request URL.
    pub url: String,
    /// HTTP method name, parsed with `Method::from_bytes`.
    pub method: String,
    /// Request headers, expected to be a JSON object mapping header name to value.
    pub headers: Option<serde_json::Value>,
    /// Optional request body, sent as JSON.
    pub body: Option<serde_json::Value>,
    /// Status code the response must have, if set.
    pub expected_status: Option<u16>,
    /// Expected response body. For a JSON object, each listed top-level field must be
    /// present in the response with an equal value; extra response fields are allowed.
    pub expected_body: Option<serde_json::Value>,
    /// Per-request timeout in seconds (default: 30).
    #[serde(default = "default_timeout")]
    pub timeout: u64,
    /// Maximum acceptable response time in seconds. The runner starts its clock before
    /// the first attempt, so retries and back-off delays count toward this limit.
    #[serde(default)]
    pub max_response_time: Option<u64>,
    /// Headers the response must carry, as a JSON object mapping header name to the
    /// expected value.
    pub expected_headers: Option<serde_json::Value>,
    /// JSON Schema the response body must validate against, if set.
    pub json_schema: Option<serde_json::Value>,
    /// Retry policy; `RetryConfig::default()` is used when the field is omitted.
    #[serde(default)]
    pub retry: RetryConfig,
}
64
/// Serde default for `ApiTestConfig::timeout`: each request may take up to 30 seconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
68
69pub struct ApiTestRunner {
70 client: Client,
71}
72
73impl ApiTestRunner {
74 pub fn new() -> Self {
75 Self {
76 client: Client::builder()
77 .timeout(Duration::from_secs(30))
78 .build()
79 .unwrap_or_else(|_| Client::new()),
80 }
81 }
82
83 async fn execute_request_with_retry(&self, config: &ApiTestConfig) -> Result<Response> {
84 let retry_strategy = ExponentialBackoff::from_millis(config.retry.initial_delay_ms)
85 .max_delay(Duration::from_millis(config.retry.max_delay_ms))
86 .map(jitter) .take(config.retry.max_retries as usize);
88
89 let mut attempt = 1;
90 let result = Retry::spawn(retry_strategy, || {
91 let attempt_num = attempt;
92 attempt += 1;
93 async move {
94 info!("Attempt {} of {}", attempt_num, config.retry.max_retries + 1);
95 match self.execute_request(config).await {
96 Ok(response) => {
97 let status = response.status();
98 let status_code = status.as_u16();
99
100 if config.retry.retry_status_codes.contains(&status_code) {
102 warn!(
103 "Request failed with status {} on attempt {}, retrying...",
104 status_code, attempt_num
105 );
106 Err(Error::TestError(format!(
107 "Retryable status code: {} (attempt {}/{})",
108 status_code, attempt_num, config.retry.max_retries + 1
109 )))
110 } else {
111 Ok(response)
112 }
113 }
114 Err(e) => {
115 let should_retry = match &e {
117 Error::RequestError(req_err) => {
118 if req_err.is_timeout() && config.retry.retry_on_timeout {
119 warn!("Request timed out on attempt {}, retrying...", attempt_num);
120 true
121 } else if req_err.is_connect() && config.retry.retry_on_connection_error {
122 warn!("Connection error on attempt {}, retrying...", attempt_num);
123 true
124 } else {
125 false
126 }
127 }
128 _ => false,
129 };
130
131 if should_retry {
132 Err(e)
133 } else {
134 Err(Error::TestError(format!(
135 "Non-retryable error on attempt {}: {}",
136 attempt_num, e
137 )))
138 }
139 }
140 }
141 }
142 }).await;
143
144 match result {
145 Ok(response) => {
146 info!("Request succeeded after {} attempts", attempt - 1);
147 Ok(response)
148 }
149 Err(e) => {
150 warn!("All {} retry attempts failed: {}", config.retry.max_retries + 1, e);
151 Err(e)
152 }
153 }
154 }
155
156 async fn execute_request(&self, config: &ApiTestConfig) -> Result<Response> {
157 let method = Method::from_bytes(config.method.as_bytes())
158 .map_err(|e| Error::ValidationError(format!("Invalid HTTP method: {}", e)))?;
159
160 let mut request = self.client.request(method, &config.url)
161 .timeout(Duration::from_secs(config.timeout));
162
163 if let Some(headers) = &config.headers {
164 if let Some(headers_obj) = headers.as_object() {
165 for (key, value) in headers_obj {
166 if let Some(value_str) = value.as_str() {
167 request = request.header(key, value_str);
168 }
169 }
170 }
171 }
172
173 if let Some(body) = &config.body {
174 request = request.json(body);
175 }
176
177 info!("Sending {} request to {}", config.method, config.url);
178 let response = request.send().await?;
179 Ok(response)
180 }
181
182 async fn validate_response(&self, response: Response, config: &ApiTestConfig, duration: f64) -> Result<serde_json::Value> {
183 if let Some(max_time) = config.max_response_time {
185 if duration > max_time as f64 {
186 return Err(Error::TestError(format!(
187 "Response time exceeded maximum allowed time. Expected: {}s, Got: {:.2}s",
188 max_time, duration
189 )));
190 }
191 }
192
193 if let Some(expected_status) = config.expected_status {
195 if response.status().as_u16() != expected_status {
196 return Err(Error::TestError(format!(
197 "Expected status {} but got {}",
198 expected_status,
199 response.status().as_u16()
200 )));
201 }
202 }
203
204 if let Some(expected_headers) = &config.expected_headers {
206 if let Some(expected_obj) = expected_headers.as_object() {
207 for (key, expected_value) in expected_obj {
208 if let Some(actual_value) = response.headers().get(key) {
209 let actual_str = actual_value.to_str().unwrap_or("");
210 let expected_str = expected_value.as_str().unwrap_or("");
211 if actual_str != expected_str {
212 return Err(Error::TestError(format!(
213 "Response header mismatch for '{}'. Expected: {}, Got: {}",
214 key, expected_str, actual_str
215 )));
216 }
217 } else {
218 return Err(Error::TestError(format!(
219 "Expected header '{}' not found in response",
220 key
221 )));
222 }
223 }
224 }
225 }
226
227 let body_text = response.text().await?;
228 let actual_body: serde_json::Value = match serde_json::from_str(&body_text) {
229 Ok(json) => json,
230 Err(_) => {
231 return Err(Error::TestError(format!(
232 "Response is not valid JSON: {}",
233 body_text
234 )));
235 }
236 };
237
238 if let Some(schema) = &config.json_schema {
240 let compiled_schema = JSONSchema::compile(schema)
241 .map_err(|e| Error::ValidationError(format!("Invalid JSON Schema: {}", e)))?;
242
243 let validation_result = compiled_schema.validate(&actual_body);
244 if let Err(errors) = validation_result {
245 let error_messages: Vec<String> = errors
246 .map(|e| format!("{}", e))
247 .collect();
248 return Err(Error::TestError(format!(
249 "JSON Schema validation failed:\n{}",
250 error_messages.join("\n")
251 )));
252 }
253 }
254
255 if let Some(expected_body) = &config.expected_body {
257 if let Some(expected_obj) = expected_body.as_object() {
258 for (key, expected_value) in expected_obj {
259 if let Some(actual_value) = actual_body.get(key) {
260 if actual_value != expected_value {
261 return Err(Error::TestError(format!(
262 "Field '{}' mismatch. Expected: {:?}, Got: {:?}",
263 key, expected_value, actual_value
264 )));
265 }
266 } else {
267 return Err(Error::TestError(format!(
268 "Expected field '{}' not found in response body",
269 key
270 )));
271 }
272 }
273 }
274 }
275
276 Ok(actual_body)
277 }
278}
279
280#[async_trait]
281impl TestRunner for ApiTestRunner {
282 async fn run(&self, config: &(impl serde::Serialize + Send + Sync)) -> Result<TestResult> {
283 let config = serde_json::from_value::<ApiTestConfig>(serde_json::to_value(config)?)?;
284 let start = Instant::now();
285
286 match self.execute_request_with_retry(&config).await {
287 Ok(response) => {
288 let duration = start.elapsed().as_secs_f64();
289 let status = response.status();
290 let headers = response.headers().clone();
291
292 match self.validate_response(response, &config, duration).await {
293 Ok(body) => {
294 Ok(TestResult {
295 name: config.base.name,
296 status: "passed".to_string(),
297 duration,
298 details: Some(serde_json::json!({
299 "status_code": status.as_u16(),
300 "response_time": duration,
301 "headers": headers
302 .iter()
303 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
304 .collect::<std::collections::HashMap<_, _>>(),
305 "body": body
306 })),
307 timestamp: Utc::now().to_rfc3339(),
308 })
309 }
310 Err(e) => {
311 Ok(TestResult {
312 name: config.base.name,
313 status: "failed".to_string(),
314 duration,
315 details: Some(serde_json::json!({
316 "error": e.to_string(),
317 "status_code": status.as_u16(),
318 "response_time": duration
319 })),
320 timestamp: Utc::now().to_rfc3339(),
321 })
322 }
323 }
324 }
325 Err(e) => {
326 let duration = start.elapsed().as_secs_f64();
327 Ok(TestResult {
328 name: config.base.name,
329 status: "failed".to_string(),
330 duration,
331 details: Some(serde_json::json!({
332 "error": e.to_string(),
333 "response_time": duration
334 })),
335 timestamp: Utc::now().to_rfc3339(),
336 })
337 }
338 }
339 }
340}