1use std::collections::HashMap;
28
29use serde::{Deserialize, Serialize};
30
31use crate::router::{RouteMetadata, Router};
32
33#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct ContractTestResult {
36 pub path: String,
38
39 pub method: String,
41
42 pub passed: bool,
44
45 pub failure_reason: Option<String>,
47
48 pub errors: Vec<String>,
50}
51
52impl ContractTestResult {
53 pub fn passed(path: impl Into<String>, method: impl Into<String>) -> Self {
55 Self {
56 path: path.into(),
57 method: method.into(),
58 passed: true,
59 failure_reason: None,
60 errors: Vec::new(),
61 }
62 }
63
64 pub fn failed(
66 path: impl Into<String>,
67 method: impl Into<String>,
68 reason: impl Into<String>,
69 ) -> Self {
70 Self {
71 path: path.into(),
72 method: method.into(),
73 passed: false,
74 failure_reason: Some(reason.into()),
75 errors: Vec::new(),
76 }
77 }
78
79 pub fn with_error(mut self, error: impl Into<String>) -> Self {
81 self.errors.push(error.into());
82 self.passed = false;
83 self
84 }
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct ContractTestResults {
90 pub results: Vec<ContractTestResult>,
92
93 pub total: usize,
95
96 pub passed: usize,
98
99 pub failed: usize,
101
102 pub coverage: f64,
104}
105
106impl ContractTestResults {
107 pub fn new(results: Vec<ContractTestResult>) -> Self {
109 let total = results.len();
110 let passed = results.iter().filter(|r| r.passed).count();
111 let failed = total - passed;
112 let coverage = if total > 0 {
113 (passed as f64 / total as f64) * 100.0
114 } else {
115 0.0
116 };
117
118 Self {
119 results,
120 total,
121 passed,
122 failed,
123 coverage,
124 }
125 }
126
127 pub fn all_passed(&self) -> bool {
129 self.failed == 0
130 }
131
132 pub fn failed_tests(&self) -> Vec<&ContractTestResult> {
134 self.results.iter().filter(|r| !r.passed).collect()
135 }
136}
137
/// Switches controlling which checks a contract tester performs.
#[derive(Debug, Clone)]
pub struct ContractTestConfig {
    /// Report routes that declare no request schema.
    pub validate_requests: bool,

    /// Report routes that declare no response schema.
    pub validate_responses: bool,

    /// NOTE(review): not read by `ContractTester` in the visible part of this
    /// module — confirm whether anything consumes it.
    pub detect_breaking_changes: bool,

    /// NOTE(review): not read by `ContractTester` in the visible part of this
    /// module — confirm whether anything consumes it.
    pub generate_mocks: bool,

    /// Stop after the first failing route instead of checking all of them.
    pub fail_fast: bool,
}

impl Default for ContractTestConfig {
    /// Both schema checks and breaking-change detection on; mocks and fail-fast off.
    fn default() -> Self {
        Self {
            fail_fast: false,
            generate_mocks: false,
            detect_breaking_changes: true,
            validate_responses: true,
            validate_requests: true,
        }
    }
}

impl ContractTestConfig {
    /// Same as [`ContractTestConfig::default`].
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Builder: sets [`validate_requests`](Self::validate_requests).
    pub fn validate_requests(self, enable: bool) -> Self {
        Self {
            validate_requests: enable,
            ..self
        }
    }

    /// Builder: sets [`validate_responses`](Self::validate_responses).
    pub fn validate_responses(self, enable: bool) -> Self {
        Self {
            validate_responses: enable,
            ..self
        }
    }

    /// Builder: sets [`detect_breaking_changes`](Self::detect_breaking_changes).
    pub fn detect_breaking_changes(self, enable: bool) -> Self {
        Self {
            detect_breaking_changes: enable,
            ..self
        }
    }

    /// Builder: sets [`generate_mocks`](Self::generate_mocks).
    pub fn generate_mocks(self, enable: bool) -> Self {
        Self {
            generate_mocks: enable,
            ..self
        }
    }

    /// Builder: sets [`fail_fast`](Self::fail_fast).
    pub fn fail_fast(self, enable: bool) -> Self {
        Self {
            fail_fast: enable,
            ..self
        }
    }
}
205
206pub struct ContractTester<'a> {
208 #[allow(dead_code)]
209 router: &'a Router,
210 config: ContractTestConfig,
211 routes: Vec<RouteMetadata>,
212}
213
214impl<'a> ContractTester<'a> {
215 pub fn new(router: &'a Router) -> Self {
217 Self {
218 router,
219 config: ContractTestConfig::default(),
220 routes: router.routes().to_vec(),
221 }
222 }
223
224 pub fn with_config(router: &'a Router, config: ContractTestConfig) -> Self {
226 Self {
227 router,
228 config,
229 routes: router.routes().to_vec(),
230 }
231 }
232
233 pub fn test_all_routes(&self) -> ContractTestResults {
235 let mut results = Vec::new();
236
237 for route in &self.routes {
238 let result = self.test_route(route);
239 results.push(result);
240
241 if self.config.fail_fast && !results.last().unwrap().passed {
242 break;
243 }
244 }
245
246 ContractTestResults::new(results)
247 }
248
249 pub fn test_route(&self, route: &RouteMetadata) -> ContractTestResult {
251 if route.path.is_empty() {
253 return ContractTestResult::failed(&route.path, &route.method, "Route path is empty");
254 }
255
256 let mut result = ContractTestResult::passed(&route.path, &route.method);
257
258 if self.config.validate_requests {
260 if let Some(error) = self.validate_request_schema(route) {
261 result = result.with_error(error);
262 }
263 }
264
265 if self.config.validate_responses {
267 if let Some(error) = self.validate_response_schema(route) {
268 result = result.with_error(error);
269 }
270 }
271
272 result
273 }
274
275 fn validate_request_schema(&self, route: &RouteMetadata) -> Option<String> {
277 if self.config.validate_requests && route.request_schema.is_none() {
279 return Some(format!("Route {} lacks request schema", route.path));
280 }
281 None
282 }
283
284 fn validate_response_schema(&self, route: &RouteMetadata) -> Option<String> {
286 if self.config.validate_responses && route.response_schema.is_none() {
288 return Some(format!("Route {} lacks response schema", route.path));
289 }
290 None
291 }
292
293 pub fn generate_test_code(&self, route: &RouteMetadata) -> String {
295 format!(
296 r#"#[tokio::test]
297async fn test_{}_contract() {{
298 let router = Router::new();
299 let tester = ContractTester::new(&router);
300
301 let route = router.routes()
302 .iter()
303 .find(|r| r.path == "{}" && r.method == "{}")
304 .expect("Route not found");
305
306 let result = tester.test_route(route);
307 assert!(result.passed, "Contract test failed: {{:?}}", result.failure_reason);
308}}
309"#,
310 route.path.replace('/', "_").replace(['{', '}'], ""),
311 route.path,
312 route.method
313 )
314 }
315
316 pub fn coverage_stats(&self) -> HashMap<String, f64> {
318 let mut stats = HashMap::new();
319 let results = self.test_all_routes();
320
321 stats.insert("total_routes".to_string(), self.routes.len() as f64);
322 stats.insert("tested_routes".to_string(), results.passed as f64);
323 stats.insert("coverage_percent".to_string(), results.coverage);
324 stats.insert("failed_tests".to_string(), results.failed as f64);
325
326 stats
327 }
328}
329
330pub trait ContractTestable {
332 fn generate_contract_tests(&self) -> ContractTestResults;
334
335 fn test_route_contract(&self, path: &str, method: &str) -> ContractTestResult;
337}
338
339impl ContractTestable for Router {
340 fn generate_contract_tests(&self) -> ContractTestResults {
341 let tester = ContractTester::new(self);
342 tester.test_all_routes()
343 }
344
345 fn test_route_contract(&self, path: &str, method: &str) -> ContractTestResult {
346 let tester = ContractTester::new(self);
347
348 if let Some(route) = self
349 .routes()
350 .iter()
351 .find(|r| r.path == path && r.method == method)
352 {
353 tester.test_route(route)
354 } else {
355 ContractTestResult::failed(
356 path,
357 method,
358 format!("Route not found: {} {}", path, method),
359 )
360 }
361 }
362}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a REST route that declares neither a request nor a response schema.
    fn route_without_schemas(path: &str, method: &str) -> RouteMetadata {
        RouteMetadata {
            path: path.to_string(),
            method: method.to_string(),
            protocol: "rest".to_string(),
            description: None,
            request_schema: None,
            response_schema: None,
        }
    }

    #[test]
    fn test_contract_test_result_passed() {
        let result = ContractTestResult::passed("/users", "GET");
        assert!(result.passed);
        assert_eq!(result.path, "/users");
        assert_eq!(result.method, "GET");
        assert!(result.failure_reason.is_none());
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_contract_test_result_failed() {
        let result = ContractTestResult::failed("/users", "POST", "Invalid schema");
        assert!(!result.passed);
        assert_eq!(result.failure_reason, Some("Invalid schema".to_string()));
    }

    #[test]
    fn test_contract_test_result_with_error() {
        let result = ContractTestResult::passed("/users", "GET").with_error("Missing field: name");

        assert!(!result.passed);
        assert_eq!(result.errors.len(), 1);
        assert_eq!(result.errors[0], "Missing field: name");
    }

    #[test]
    fn test_contract_test_results() {
        let results = vec![
            ContractTestResult::passed("/users", "GET"),
            ContractTestResult::passed("/posts", "GET"),
            ContractTestResult::failed("/admin", "DELETE", "Unauthorized"),
        ];

        let test_results = ContractTestResults::new(results);

        assert_eq!(test_results.total, 3);
        assert_eq!(test_results.passed, 2);
        assert_eq!(test_results.failed, 1);
        assert!(!test_results.all_passed());
        // Compare with a tolerance rather than pinning the last bit of the f64.
        assert!((test_results.coverage - 200.0 / 3.0).abs() < 1e-9);
    }

    #[test]
    fn test_contract_test_results_empty() {
        let test_results = ContractTestResults::new(Vec::new());

        assert_eq!(test_results.total, 0);
        assert_eq!(test_results.coverage, 0.0);
        assert!(test_results.all_passed());
        assert!(test_results.failed_tests().is_empty());
    }

    #[test]
    fn test_failed_tests_lists_only_failures() {
        let test_results = ContractTestResults::new(vec![
            ContractTestResult::passed("/users", "GET"),
            ContractTestResult::failed("/admin", "DELETE", "Unauthorized"),
            ContractTestResult::passed("/posts", "GET").with_error("boom"),
        ]);

        let failed = test_results.failed_tests();
        assert_eq!(failed.len(), 2);
        assert_eq!(failed[0].path, "/admin");
        assert_eq!(failed[1].path, "/posts");
    }

    #[test]
    fn test_contract_test_config() {
        let config = ContractTestConfig::new()
            .validate_requests(true)
            .validate_responses(true)
            .detect_breaking_changes(false)
            .fail_fast(true);

        assert!(config.validate_requests);
        assert!(config.validate_responses);
        assert!(!config.detect_breaking_changes);
        assert!(!config.generate_mocks);
        assert!(config.fail_fast);
    }

    #[test]
    fn test_contract_tester_empty_router() {
        let router = Router::new();
        let tester = ContractTester::new(&router);
        let results = tester.test_all_routes();

        assert_eq!(results.total, 0);
        assert_eq!(results.passed, 0);
        assert!(results.all_passed());
    }

    #[test]
    fn test_route_with_empty_path_fails() {
        let router = Router::new();
        let tester = ContractTester::new(&router);

        let result = tester.test_route(&route_without_schemas("", "GET"));

        assert!(!result.passed);
        assert_eq!(result.failure_reason.as_deref(), Some("Route path is empty"));
    }

    #[test]
    fn test_route_without_schemas_fails_by_default() {
        let router = Router::new();
        let tester = ContractTester::new(&router);

        let result = tester.test_route(&route_without_schemas("/users", "GET"));

        assert!(!result.passed);
        assert!(result.failure_reason.is_none());
        assert_eq!(
            result.errors,
            vec![
                "Route /users lacks request schema".to_string(),
                "Route /users lacks response schema".to_string(),
            ]
        );
    }

    #[test]
    fn test_route_schema_checks_can_be_disabled() {
        let router = Router::new();
        let config = ContractTestConfig::new()
            .validate_requests(false)
            .validate_responses(false);
        let tester = ContractTester::with_config(&router, config);

        let result = tester.test_route(&route_without_schemas("/users", "GET"));

        assert!(result.passed);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_contract_testable_trait() {
        let router = Router::new();
        let results = router.generate_contract_tests();

        assert_eq!(results.total, 0);
        assert!(results.all_passed());
    }

    #[test]
    fn test_route_contract_unknown_route() {
        let router = Router::new();
        let result = router.test_route_contract("/missing", "GET");

        assert!(!result.passed);
        assert_eq!(result.path, "/missing");
        assert_eq!(result.method, "GET");
        assert_eq!(
            result.failure_reason.as_deref(),
            Some("Route not found: /missing GET")
        );
    }

    #[test]
    fn test_generate_test_code() {
        let router = Router::new();
        let tester = ContractTester::new(&router);

        let route = RouteMetadata {
            path: "/users".to_string(),
            method: "GET".to_string(),
            protocol: "rest".to_string(),
            description: Some("Get users".to_string()),
            request_schema: None,
            response_schema: None,
        };

        let code = tester.generate_test_code(&route);

        // "/users" maps to the suffix "_users", hence the double underscore.
        assert!(code.contains("test__users_contract"));
        assert!(code.contains("/users"));
        assert!(code.contains("GET"));
    }

    #[test]
    fn test_coverage_stats() {
        let router = Router::new();
        let tester = ContractTester::new(&router);
        let stats = tester.coverage_stats();

        assert_eq!(stats.get("total_routes"), Some(&0.0));
        assert_eq!(stats.get("tested_routes"), Some(&0.0));
        assert_eq!(stats.get("coverage_percent"), Some(&0.0));
        assert_eq!(stats.get("failed_tests"), Some(&0.0));
    }
}