1use std::collections::HashMap;
32
33use serde::{Deserialize, Serialize};
34
35use crate::router::{RouteMetadata, Router};
36
/// Outcome of checking a single route against its contract.
///
/// A result starts out either passing ([`ContractTestResult::passed`]) or
/// failing with a reason ([`ContractTestResult::failed`]). Individual
/// validation problems are appended to `errors` with
/// [`ContractTestResult::with_error`], which also marks the result as failed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContractTestResult {
    /// Path of the route that was checked (e.g. `/users`).
    pub path: String,

    /// Method string of the route as stored on `RouteMetadata` (e.g. `GET`).
    pub method: String,

    /// `true` only if no failure was recorded for this route.
    pub passed: bool,

    /// Human-readable reason the route failed, if one was recorded.
    pub failure_reason: Option<String>,

    /// Individual validation messages collected via `with_error`.
    pub errors: Vec<String>,
}
55
56impl ContractTestResult {
57 pub fn passed(path: impl Into<String>, method: impl Into<String>) -> Self {
59 Self {
60 path: path.into(),
61 method: method.into(),
62 passed: true,
63 failure_reason: None,
64 errors: Vec::new(),
65 }
66 }
67
68 pub fn failed(
70 path: impl Into<String>,
71 method: impl Into<String>,
72 reason: impl Into<String>,
73 ) -> Self {
74 Self {
75 path: path.into(),
76 method: method.into(),
77 passed: false,
78 failure_reason: Some(reason.into()),
79 errors: Vec::new(),
80 }
81 }
82
83 pub fn with_error(mut self, error: impl Into<String>) -> Self {
85 self.errors.push(error.into());
86 self.passed = false;
87 self
88 }
89}
90
/// Aggregate outcome of running contract tests over a set of routes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContractTestResults {
    /// Per-route results, in the order they were produced.
    pub results: Vec<ContractTestResult>,

    /// Number of entries in `results`.
    pub total: usize,

    /// Number of results whose `passed` flag is `true`.
    pub passed: usize,

    /// Number of results whose `passed` flag is `false`.
    pub failed: usize,

    /// Pass rate as a percentage (`passed / total * 100`), or `0.0` when
    /// `total` is 0, as computed by [`ContractTestResults::new`].
    pub coverage: f64,
}
109
110impl ContractTestResults {
111 pub fn new(results: Vec<ContractTestResult>) -> Self {
113 let total = results.len();
114 let passed = results.iter().filter(|r| r.passed).count();
115 let failed = total - passed;
116 let coverage = if total > 0 {
117 (passed as f64 / total as f64) * 100.0
118 } else {
119 0.0
120 };
121
122 Self {
123 results,
124 total,
125 passed,
126 failed,
127 coverage,
128 }
129 }
130
131 pub fn all_passed(&self) -> bool {
133 self.failed == 0
134 }
135
136 pub fn failed_tests(&self) -> Vec<&ContractTestResult> {
138 self.results.iter().filter(|r| !r.passed).collect()
139 }
140}
141
/// Options controlling which checks [`ContractTester`] performs.
///
/// Build with [`ContractTestConfig::new`] (or [`Default`]) and adjust via the
/// consuming builder methods.
#[derive(Debug, Clone)]
pub struct ContractTestConfig {
    /// Report routes whose `request_schema` is missing.
    pub validate_requests: bool,

    /// Report routes whose `response_schema` is missing.
    pub validate_responses: bool,

    /// Breaking-change detection flag.
    /// NOTE(review): not read by `ContractTester` in this module; confirm
    /// where (or whether) it is consumed.
    pub detect_breaking_changes: bool,

    /// Mock-generation flag.
    /// NOTE(review): not read by `ContractTester` in this module; confirm
    /// where (or whether) it is consumed.
    pub generate_mocks: bool,

    /// Stop `ContractTester::test_all_routes` after the first failing route.
    pub fail_fast: bool,
}
160
161impl Default for ContractTestConfig {
162 fn default() -> Self {
163 Self {
164 validate_requests: true,
165 validate_responses: true,
166 detect_breaking_changes: true,
167 generate_mocks: false,
168 fail_fast: false,
169 }
170 }
171}
172
173impl ContractTestConfig {
174 pub fn new() -> Self {
176 Self::default()
177 }
178
179 pub fn validate_requests(mut self, enable: bool) -> Self {
181 self.validate_requests = enable;
182 self
183 }
184
185 pub fn validate_responses(mut self, enable: bool) -> Self {
187 self.validate_responses = enable;
188 self
189 }
190
191 pub fn detect_breaking_changes(mut self, enable: bool) -> Self {
193 self.detect_breaking_changes = enable;
194 self
195 }
196
197 pub fn generate_mocks(mut self, enable: bool) -> Self {
199 self.generate_mocks = enable;
200 self
201 }
202
203 pub fn fail_fast(mut self, enable: bool) -> Self {
205 self.fail_fast = enable;
206 self
207 }
208}
209
/// Runs contract checks against the routes of a [`Router`].
///
/// The router's routes are copied into `routes` when the tester is
/// constructed, and all checks run against that snapshot.
pub struct ContractTester<'a> {
    // Not read by any method in this module (checks use the `routes`
    // snapshot instead), hence the `dead_code` allowance.
    // NOTE(review): presumably kept for future use; confirm or remove.
    #[allow(dead_code)]
    router: &'a Router,
    /// Which checks to run and whether to stop at the first failure.
    config: ContractTestConfig,
    /// Copy of `router.routes()` taken at construction time.
    routes: Vec<RouteMetadata>,
}
217
218impl<'a> ContractTester<'a> {
219 pub fn new(router: &'a Router) -> Self {
221 Self {
222 router,
223 config: ContractTestConfig::default(),
224 routes: router.routes().to_vec(),
225 }
226 }
227
228 pub fn with_config(router: &'a Router, config: ContractTestConfig) -> Self {
230 Self {
231 router,
232 config,
233 routes: router.routes().to_vec(),
234 }
235 }
236
237 pub fn test_all_routes(&self) -> ContractTestResults {
239 let mut results = Vec::new();
240
241 for route in &self.routes {
242 let result = self.test_route(route);
243 results.push(result);
244
245 if self.config.fail_fast && !results.last().unwrap().passed {
246 break;
247 }
248 }
249
250 ContractTestResults::new(results)
251 }
252
253 pub fn test_route(&self, route: &RouteMetadata) -> ContractTestResult {
255 if route.path.is_empty() {
257 return ContractTestResult::failed(&route.path, &route.method, "Route path is empty");
258 }
259
260 let mut result = ContractTestResult::passed(&route.path, &route.method);
261
262 if self.config.validate_requests {
264 if let Some(error) = self.validate_request_schema(route) {
265 result = result.with_error(error);
266 }
267 }
268
269 if self.config.validate_responses {
271 if let Some(error) = self.validate_response_schema(route) {
272 result = result.with_error(error);
273 }
274 }
275
276 result
277 }
278
279 fn validate_request_schema(&self, route: &RouteMetadata) -> Option<String> {
281 if self.config.validate_requests && route.request_schema.is_none() {
283 return Some(format!("Route {} lacks request schema", route.path));
284 }
285 None
286 }
287
288 fn validate_response_schema(&self, route: &RouteMetadata) -> Option<String> {
290 if self.config.validate_responses && route.response_schema.is_none() {
292 return Some(format!("Route {} lacks response schema", route.path));
293 }
294 None
295 }
296
297 pub fn generate_test_code(&self, route: &RouteMetadata) -> String {
299 format!(
300 r#"#[tokio::test]
301async fn test_{}_contract() {{
302 let router = Router::new();
303 let tester = ContractTester::new(&router);
304
305 let route = router.routes()
306 .iter()
307 .find(|r| r.path == "{}" && r.method == "{}")
308 .expect("Route not found");
309
310 let result = tester.test_route(route);
311 assert!(result.passed, "Contract test failed: {{:?}}", result.failure_reason);
312}}
313"#,
314 route.path.replace('/', "_").replace(['{', '}'], ""),
315 route.path,
316 route.method
317 )
318 }
319
320 pub fn coverage_stats(&self) -> HashMap<String, f64> {
322 let mut stats = HashMap::new();
323 let results = self.test_all_routes();
324
325 stats.insert("total_routes".to_string(), self.routes.len() as f64);
326 stats.insert("tested_routes".to_string(), results.passed as f64);
327 stats.insert("coverage_percent".to_string(), results.coverage);
328 stats.insert("failed_tests".to_string(), results.failed as f64);
329
330 stats
331 }
332}
333
/// Types that can run contract tests over their own routes.
pub trait ContractTestable {
    /// Runs contract checks for every route and returns the aggregated results.
    ///
    /// Despite the name, this does not emit test source code. The `Router`
    /// implementation delegates to `ContractTester::test_all_routes`.
    fn generate_contract_tests(&self) -> ContractTestResults;

    /// Runs the contract check for the route matching `path` and `method`.
    ///
    /// The `Router` implementation returns a failed result
    /// ("Route not found: <path> <method>") when no such route exists.
    fn test_route_contract(&self, path: &str, method: &str) -> ContractTestResult;
}
342
343impl ContractTestable for Router {
344 fn generate_contract_tests(&self) -> ContractTestResults {
345 let tester = ContractTester::new(self);
346 tester.test_all_routes()
347 }
348
349 fn test_route_contract(&self, path: &str, method: &str) -> ContractTestResult {
350 let tester = ContractTester::new(self);
351
352 if let Some(route) = self
353 .routes()
354 .iter()
355 .find(|r| r.path == path && r.method == method)
356 {
357 tester.test_route(route)
358 } else {
359 ContractTestResult::failed(
360 path,
361 method,
362 format!("Route not found: {} {}", path, method),
363 )
364 }
365 }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a REST route with no description and no request/response schema.
    fn route_without_schemas(path: &str, method: &str) -> RouteMetadata {
        RouteMetadata {
            path: path.to_string(),
            method: method.to_string(),
            protocol: "rest".to_string(),
            description: None,
            request_schema: None,
            response_schema: None,
        }
    }

    #[test]
    fn test_contract_test_result_passed() {
        let result = ContractTestResult::passed("/users", "GET");
        assert!(result.passed);
        assert_eq!(result.path, "/users");
        assert_eq!(result.method, "GET");
        assert!(result.failure_reason.is_none());
    }

    #[test]
    fn test_contract_test_result_failed() {
        let result = ContractTestResult::failed("/users", "POST", "Invalid schema");
        assert!(!result.passed);
        assert_eq!(result.failure_reason, Some("Invalid schema".to_string()));
    }

    #[test]
    fn test_contract_test_result_with_error() {
        let result = ContractTestResult::passed("/users", "GET").with_error("Missing field: name");

        assert!(!result.passed);
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0], "Missing field: name");
    }

    #[test]
    fn test_contract_test_results() {
        let results = vec![
            ContractTestResult::passed("/users", "GET"),
            ContractTestResult::passed("/posts", "GET"),
            ContractTestResult::failed("/admin", "DELETE", "Unauthorized"),
        ];

        let test_results = ContractTestResults::new(results);

        assert_eq!(test_results.total, 3);
        assert_eq!(test_results.passed, 2);
        assert_eq!(test_results.failed, 1);
        assert!(!test_results.all_passed());
        // Compare with a tolerance rather than an exact float literal.
        assert!((test_results.coverage - 200.0 / 3.0).abs() < 1e-9);
        assert_eq!(test_results.failed_tests().len(), 1);
        assert_eq!(test_results.failed_tests()[0].path, "/admin");
    }

    #[test]
    fn test_contract_test_config() {
        let config = ContractTestConfig::new()
            .validate_requests(true)
            .validate_responses(true)
            .detect_breaking_changes(false)
            .fail_fast(true);

        assert!(config.validate_requests);
        assert!(config.validate_responses);
        assert!(!config.detect_breaking_changes);
        assert!(config.fail_fast);
    }

    #[test]
    fn test_contract_tester_empty_router() {
        let router = Router::new();
        let tester = ContractTester::new(&router);
        let results = tester.test_all_routes();

        assert_eq!(results.total, 0);
        assert_eq!(results.passed, 0);
        assert!(results.all_passed());
    }

    #[test]
    fn test_contract_testable_trait() {
        let router = Router::new();
        let results = router.generate_contract_tests();

        assert_eq!(results.total, 0);
        assert!(results.all_passed());
    }

    #[test]
    fn test_route_contract_unknown_route() {
        let router = Router::new();
        let result = router.test_route_contract("/missing", "GET");

        assert!(!result.passed);
        assert_eq!(
            result.failure_reason.as_deref(),
            Some("Route not found: /missing GET")
        );
    }

    #[test]
    fn test_route_reports_missing_schemas() {
        let router = Router::new();
        let tester = ContractTester::new(&router);
        let result = tester.test_route(&route_without_schemas("/users", "GET"));

        assert!(!result.passed);
        assert_eq!(
            result.errors,
            vec![
                "Route /users lacks request schema".to_string(),
                "Route /users lacks response schema".to_string(),
            ]
        );
    }

    #[test]
    fn test_route_skips_disabled_validations() {
        let router = Router::new();
        let config = ContractTestConfig::new()
            .validate_requests(false)
            .validate_responses(false);
        let tester = ContractTester::with_config(&router, config);
        let result = tester.test_route(&route_without_schemas("/users", "GET"));

        assert!(result.passed);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_route_rejects_empty_path() {
        let router = Router::new();
        let tester = ContractTester::new(&router);
        let result = tester.test_route(&route_without_schemas("", "GET"));

        assert!(!result.passed);
        assert_eq!(result.failure_reason.as_deref(), Some("Route path is empty"));
    }

    #[test]
    fn test_generate_test_code() {
        let router = Router::new();
        let tester = ContractTester::new(&router);

        let route = RouteMetadata {
            path: "/users".to_string(),
            method: "GET".to_string(),
            protocol: "rest".to_string(),
            description: Some("Get users".to_string()),
            request_schema: None,
            response_schema: None,
        };

        let code = tester.generate_test_code(&route);

        assert!(code.contains("fn test__users_contract()"));
        assert!(code.contains("/users"));
        assert!(code.contains("GET"));
    }

    #[test]
    fn test_coverage_stats() {
        let router = Router::new();
        let tester = ContractTester::new(&router);
        let stats = tester.coverage_stats();

        assert_eq!(stats.get("total_routes"), Some(&0.0));
        assert_eq!(stats.get("tested_routes"), Some(&0.0));
        assert_eq!(stats.get("failed_tests"), Some(&0.0));
    }
}